1use super::error::RemoteResult;
7use crate::OneshotReceiver;
8use crate::OneshotSender;
9use crate::RecvError;
10use crate::error::RemoteError;
11use crate::oneshot;
12use mesh_node::message::MeshField;
13use mesh_protobuf::Protobuf;
14use std::convert::Infallible;
15use std::fmt::Debug;
16use std::future::Future;
17use std::pin::Pin;
18use std::task::Poll;
19use std::task::ready;
20use thiserror::Error;
21
/// An RPC request: an input value of type `I` paired with the sending half of
/// a oneshot channel on which a response of type `R` can be delivered.
///
/// Consuming methods such as [`Rpc::complete`], [`Rpc::handle`] and
/// [`Rpc::fail`] send a response through that channel. Dropping an `Rpc`
/// without sending drops the response sender with no response sent.
///
/// The `mesh` attribute lets the type be encoded through `mesh_protobuf` when
/// both `I` and `R` are `'static + MeshField + Send`, with
/// `mesh_node::resource::Resource` as the resource type.
#[derive(Protobuf)]
#[mesh(
    bound = "I: 'static + MeshField + Send, R: 'static + MeshField + Send",
    resource = "mesh_node::resource::Resource"
)]
pub struct Rpc<I, R>(I, OneshotSender<R>);
31
32impl<I: Debug, R> Debug for Rpc<I, R> {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_tuple("Rpc").field(&self.0).finish()
35 }
36}
37
/// An [`Rpc`] whose response type is [`RemoteResult<R>`].
///
/// NOTE(review): the failable helpers (`handle_failable`, `fail`, ...) are
/// implemented on `Rpc<I, Result<R, RemoteError>>`. They apply here if
/// `RemoteResult<R>` is `Result<R, RemoteError>`, which is presumably the
/// case — confirm in `super::error`.
pub type FailableRpc<I, R> = Rpc<I, RemoteResult<R>>;
40
41impl<I, R: 'static + Send> Rpc<I, R> {
42 pub fn detached(input: I) -> Self {
45 let (result_send, _) = oneshot();
46 Rpc(input, result_send)
47 }
48
49 pub fn input(&self) -> &I {
51 &self.0
52 }
53
54 pub fn split(self) -> (I, Rpc<(), R>) {
58 (self.0, Rpc((), self.1))
59 }
60
61 pub fn handle_sync<F>(self, f: F)
64 where
65 F: FnOnce(I) -> R,
66 {
67 let r = f(self.0);
68 self.1.send(r);
69 }
70
71 pub async fn handle<F>(self, f: F)
74 where
75 F: AsyncFnOnce(I) -> R,
76 {
77 let r = f(self.0).await;
78 self.1.send(r);
79 }
80
81 pub async fn handle_must_succeed<F, E>(self, f: F) -> Result<(), E>
88 where
89 F: AsyncFnOnce(I) -> Result<R, E>,
90 {
91 let r = f(self.0).await?;
92 self.1.send(r);
93 Ok(())
94 }
95
96 pub fn complete(self, result: R) {
98 self.1.send(result);
99 }
100}
101
102impl<I, R: 'static + Send> Rpc<I, Result<R, RemoteError>> {
103 pub fn handle_failable_sync<F, E>(self, f: F)
106 where
107 F: FnOnce(I) -> Result<R, E>,
108 E: Into<Box<dyn std::error::Error + Send + Sync>>,
109 {
110 let r = f(self.0);
111 self.1.send(r.map_err(RemoteError::new));
112 }
113
114 pub async fn handle_failable<F, E>(self, f: F)
118 where
119 F: AsyncFnOnce(I) -> Result<R, E>,
120 E: Into<Box<dyn std::error::Error + Send + Sync>>,
121 {
122 let r = f(self.0).await;
123 self.1.send(r.map_err(RemoteError::new));
124 }
125
126 pub fn fail<E>(self, error: E)
128 where
129 E: Into<Box<dyn std::error::Error + Send + Sync>>,
130 {
131 self.1.send(Err(RemoteError::new(error)));
132 }
133}
134
135pub trait RpcSend: Sized {
137 type Message;
139
140 fn send_rpc(self, message: Self::Message);
142
143 fn call<F, I, R>(self, f: F, input: I) -> PendingRpc<R>
163 where
164 F: FnOnce(Rpc<I, R>) -> Self::Message,
165 R: 'static + Send,
166 {
167 let (result_send, result_recv) = oneshot();
168 self.send_rpc(f(Rpc(input, result_send)));
169 PendingRpc(result_recv)
170 }
171
172 fn call_failable<F, I, T, E>(self, f: F, input: I) -> PendingFailableRpc<T, E>
178 where
179 F: FnOnce(Rpc<I, Result<T, E>>) -> Self::Message,
180 T: 'static + Send,
181 E: 'static + Send,
182 {
183 PendingFailableRpc(self.call(f, input))
184 }
185}
186
187pub trait TryRpcSend: Sized {
190 type Message;
192 type Error;
194
195 fn try_send_rpc(self, message: Self::Message) -> Result<(), Self::Error>;
197
198 fn try_call<F, I, R>(self, f: F, input: I) -> Result<PendingRpc<R>, Self::Error>
218 where
219 F: FnOnce(Rpc<I, R>) -> Self::Message,
220 R: 'static + Send,
221 {
222 let (result_send, result_recv) = oneshot();
223 self.try_send_rpc(f(Rpc(input, result_send)))?;
224 Ok(PendingRpc(result_recv))
225 }
226
227 fn try_call_failable<F, I, T, E>(
234 self,
235 f: F,
236 input: I,
237 ) -> Result<PendingFailableRpc<T, E>, Self::Error>
238 where
239 F: FnOnce(Rpc<I, Result<T, E>>) -> Self::Message,
240 T: 'static + Send,
241 E: 'static + Send,
242 {
243 Ok(PendingFailableRpc(self.try_call(f, input)?))
244 }
245}
246
/// The error produced when awaiting a [`PendingRpc`] or [`PendingFailableRpc`].
///
/// `E` is the error type reported by the handler. It defaults to
/// [`Infallible`], which is what [`PendingRpc`] uses. Both variants are
/// `#[error(transparent)]`, so `Display` and `source` forward to the wrapped
/// error.
#[derive(Debug, Error)]
pub enum RpcError<E = Infallible> {
    /// The handler responded with an error value.
    #[error(transparent)]
    Call(E),
    /// Receiving the response from the oneshot channel failed.
    ///
    /// NOTE(review): presumably this happens when the response sender is
    /// dropped without sending — confirm against `RecvError`.
    #[error(transparent)]
    Channel(RecvError),
}
256
/// A future for the response to an RPC issued with [`RpcSend::call`] or
/// [`TryRpcSend::try_call`].
///
/// Resolves to `Ok(response)`, or to `Err(RpcError::Channel(_))` if receiving
/// from the response channel fails.
#[must_use]
#[derive(Debug)]
pub struct PendingRpc<T>(OneshotReceiver<T>);
261
262impl<T: 'static + Send> Future for PendingRpc<T> {
263 type Output = Result<T, RpcError<Infallible>>;
264
265 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
266 Poll::Ready(ready!(Pin::new(&mut self.get_mut().0).poll(cx)).map_err(RpcError::Channel))
267 }
268}
269
/// A future for the response to an RPC whose response is `Result<T, E>`,
/// issued with [`RpcSend::call_failable`] or [`TryRpcSend::try_call_failable`].
///
/// Resolves to one of:
/// - `Ok(t)` on success;
/// - `Err(RpcError::Call(e))` when the handler responded with `Err(e)`;
/// - `Err(RpcError::Channel(_))` when receiving the response fails.
///
/// `E` defaults to [`RemoteError`].
#[must_use]
#[derive(Debug)]
pub struct PendingFailableRpc<T, E = RemoteError>(PendingRpc<Result<T, E>>);
274
275impl<T: 'static + Send, E: 'static + Send> Future for PendingFailableRpc<T, E> {
276 type Output = Result<T, RpcError<E>>;
277
278 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
279 let r = ready!(Pin::new(&mut self.get_mut().0).poll(cx));
280 match r {
281 Ok(Ok(t)) => Ok(t),
282 Ok(Err(e)) => Err(RpcError::Call(e)),
283 Err(RpcError::Channel(e)) => Err(RpcError::Channel(e)),
284 }
285 .into()
286 }
287}
288
289impl<T: 'static + Send> RpcSend for OneshotSender<T> {
290 type Message = T;
291 fn send_rpc(self, message: T) {
292 self.send(message);
293 }
294}
295
296impl<T: 'static + Send> RpcSend for &mesh_channel_core::Sender<T> {
297 type Message = T;
298 fn send_rpc(self, message: T) {
299 self.send(message);
300 }
301}