mesh_channel/
rpc.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Remote Procedure Call functionality.
5
6use super::error::RemoteResult;
7use crate::OneshotReceiver;
8use crate::OneshotSender;
9use crate::RecvError;
10use crate::error::RemoteError;
11use crate::oneshot;
12use mesh_node::message::MeshField;
13use mesh_protobuf::Protobuf;
14use std::convert::Infallible;
15use std::fmt::Debug;
16use std::future::Future;
17use std::pin::Pin;
18use std::task::Poll;
19use std::task::ready;
20use thiserror::Error;
21
22/// An RPC message for a request with input of type `I` and output of type `R`.
23/// The receiver of the message should process the request and return results
24/// via the `Sender<R>`.
25#[derive(Protobuf)]
26#[mesh(
27    bound = "I: 'static + MeshField + Send, R: 'static + MeshField + Send",
28    resource = "mesh_node::resource::Resource"
29)]
30pub struct Rpc<I, R>(I, OneshotSender<R>);
31
32impl<I: Debug, R> Debug for Rpc<I, R> {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_tuple("Rpc").field(&self.0).finish()
35    }
36}
37
38/// An RPC message with a failable result.
39pub type FailableRpc<I, R> = Rpc<I, RemoteResult<R>>;
40
41impl<I, R: 'static + Send> Rpc<I, R> {
42    /// Returns a new RPC message with `input` and no one listening for the
43    /// result.
44    pub fn detached(input: I) -> Self {
45        let (result_send, _) = oneshot();
46        Rpc(input, result_send)
47    }
48
49    /// Returns the input to the RPC.
50    pub fn input(&self) -> &I {
51        &self.0
52    }
53
54    /// Splits the RPC into its input and an input-less RPC. This is useful when
55    /// the input is needed in one place but the RPC will be completed in
56    /// another.
57    pub fn split(self) -> (I, Rpc<(), R>) {
58        (self.0, Rpc((), self.1))
59    }
60
61    /// Handles an RPC request by calling `f` and sending the result to the
62    /// initiator.
63    pub fn handle_sync<F>(self, f: F)
64    where
65        F: FnOnce(I) -> R,
66    {
67        let r = f(self.0);
68        self.1.send(r);
69    }
70
71    /// Handles an RPC request by calling `f`, awaiting its result, and sending
72    /// the result to the initiator.
73    pub async fn handle<F>(self, f: F)
74    where
75        F: AsyncFnOnce(I) -> R,
76    {
77        let r = f(self.0).await;
78        self.1.send(r);
79    }
80
81    /// Handles an RPC request by calling `f`, awaiting its result, and sending
82    /// Ok results back to the initiator.
83    ///
84    /// If `f` fails, the error is propagated back to the caller, and the RPC
85    /// channel is dropped (resulting in a `RecvError::Closed` on the
86    /// initiator).
87    pub async fn handle_must_succeed<F, E>(self, f: F) -> Result<(), E>
88    where
89        F: AsyncFnOnce(I) -> Result<R, E>,
90    {
91        let r = f(self.0).await?;
92        self.1.send(r);
93        Ok(())
94    }
95
96    /// Completes the RPC with the specified result value.
97    pub fn complete(self, result: R) {
98        self.1.send(result);
99    }
100}
101
102impl<I, R: 'static + Send> Rpc<I, Result<R, RemoteError>> {
103    /// Handles an RPC request by calling `f` and sending the result to the
104    /// initiator, after converting any error to a [`RemoteError`].
105    pub fn handle_failable_sync<F, E>(self, f: F)
106    where
107        F: FnOnce(I) -> Result<R, E>,
108        E: Into<Box<dyn std::error::Error + Send + Sync>>,
109    {
110        let r = f(self.0);
111        self.1.send(r.map_err(RemoteError::new));
112    }
113
114    /// Handles an RPC request by calling `f`, awaiting its result, and sending
115    /// the result to the initiator, after converting any error to a
116    /// [`RemoteError`].
117    pub async fn handle_failable<F, E>(self, f: F)
118    where
119        F: AsyncFnOnce(I) -> Result<R, E>,
120        E: Into<Box<dyn std::error::Error + Send + Sync>>,
121    {
122        let r = f(self.0).await;
123        self.1.send(r.map_err(RemoteError::new));
124    }
125
126    /// Fails the RPC with the specified error.
127    pub fn fail<E>(self, error: E)
128    where
129        E: Into<Box<dyn std::error::Error + Send + Sync>>,
130    {
131        self.1.send(Err(RemoteError::new(error)));
132    }
133}
134
135/// A trait implemented by objects that can send RPC requests.
136pub trait RpcSend: Sized {
137    /// The message type for this sender.
138    type Message;
139
140    /// Send an RPC request.
141    fn send_rpc(self, message: Self::Message);
142
143    /// Issues a request and returns a channel to receive the result.
144    ///
145    /// `f` maps an [`Rpc`] object to the message type and is often an enum
146    /// variant name.
147    ///
148    /// `input` is the input to the call.
149    ///
150    /// # Example
151    ///
152    /// ```rust
153    /// # use mesh_channel::rpc::{Rpc, RpcSend};
154    /// # use mesh_channel::Sender;
155    /// enum Request {
156    ///     Add(Rpc<(u32, u32), u32>),
157    /// }
158    /// async fn add(send: &Sender<Request>) {
159    ///     assert_eq!(send.call(Request::Add, (3, 4)).await.unwrap(), 7);
160    /// }
161    /// ```
162    fn call<F, I, R>(self, f: F, input: I) -> PendingRpc<R>
163    where
164        F: FnOnce(Rpc<I, R>) -> Self::Message,
165        R: 'static + Send,
166    {
167        let (result_send, result_recv) = oneshot();
168        self.send_rpc(f(Rpc(input, result_send)));
169        PendingRpc(result_recv)
170    }
171
172    /// Issues a request and returns an object to receive the result.
173    ///
174    /// This is like [`RpcSend::call`], but for RPCs that return a [`Result`].
175    /// The returned object combines the channel error and the call's error into
176    /// a single [`RpcError`] type, which makes it easier to handle errors.
177    fn call_failable<F, I, T, E>(self, f: F, input: I) -> PendingFailableRpc<T, E>
178    where
179        F: FnOnce(Rpc<I, Result<T, E>>) -> Self::Message,
180        T: 'static + Send,
181        E: 'static + Send,
182    {
183        PendingFailableRpc(self.call(f, input))
184    }
185}
186
187/// A trait implemented by objects that can try to send RPC requests but may
188/// fail.
189pub trait TryRpcSend: Sized {
190    /// The message type for this sender.
191    type Message;
192    /// The error type returned when sending an RPC request fails.
193    type Error;
194
195    /// Tries to send an RPC request.
196    fn try_send_rpc(self, message: Self::Message) -> Result<(), Self::Error>;
197
198    /// Issues a request and returns a channel to receive the result.
199    ///
200    /// `f` maps an [`Rpc`] object to the message type and is often an enum
201    /// variant name.
202    ///
203    /// `input` is the input to the call.
204    ///
205    /// # Example
206    ///
207    /// ```rust
208    /// # use mesh_channel::rpc::{Rpc, RpcSend};
209    /// # use mesh_channel::Sender;
210    /// enum Request {
211    ///     Add(Rpc<(u32, u32), u32>),
212    /// }
213    /// async fn add(send: &Sender<Request>) {
214    ///     assert_eq!(send.call(Request::Add, (3, 4)).await.unwrap(), 7);
215    /// }
216    /// ```
217    fn try_call<F, I, R>(self, f: F, input: I) -> Result<PendingRpc<R>, Self::Error>
218    where
219        F: FnOnce(Rpc<I, R>) -> Self::Message,
220        R: 'static + Send,
221    {
222        let (result_send, result_recv) = oneshot();
223        self.try_send_rpc(f(Rpc(input, result_send)))?;
224        Ok(PendingRpc(result_recv))
225    }
226
227    /// Issues a request and returns an object to receive the result.
228    ///
229    /// This is like [`TryRpcSend::try_call`], but for RPCs that return a
230    /// [`Result`]. The returned object combines the channel error and the
231    /// call's error into a single [`RpcError`] type, which makes it easier to
232    /// handle errors.
233    fn try_call_failable<F, I, T, E>(
234        self,
235        f: F,
236        input: I,
237    ) -> Result<PendingFailableRpc<T, E>, Self::Error>
238    where
239        F: FnOnce(Rpc<I, Result<T, E>>) -> Self::Message,
240        T: 'static + Send,
241        E: 'static + Send,
242    {
243        Ok(PendingFailableRpc(self.try_call(f, input)?))
244    }
245}
246
247/// An error from an RPC call, via
248/// [`RpcSend::call_failable`] or [`RpcSend::call`].
249#[derive(Debug, Error)]
250pub enum RpcError<E = Infallible> {
251    #[error(transparent)]
252    Call(E),
253    #[error(transparent)]
254    Channel(RecvError),
255}
256
257/// The result future of an [`RpcSend::call`] call.
258#[must_use]
259#[derive(Debug)]
260pub struct PendingRpc<T>(OneshotReceiver<T>);
261
262impl<T: 'static + Send> Future for PendingRpc<T> {
263    type Output = Result<T, RpcError<Infallible>>;
264
265    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
266        Poll::Ready(ready!(Pin::new(&mut self.get_mut().0).poll(cx)).map_err(RpcError::Channel))
267    }
268}
269
270/// The result future of an [`RpcSend::call_failable`] call.
271#[must_use]
272#[derive(Debug)]
273pub struct PendingFailableRpc<T, E = RemoteError>(PendingRpc<Result<T, E>>);
274
275impl<T: 'static + Send, E: 'static + Send> Future for PendingFailableRpc<T, E> {
276    type Output = Result<T, RpcError<E>>;
277
278    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
279        let r = ready!(Pin::new(&mut self.get_mut().0).poll(cx));
280        match r {
281            Ok(Ok(t)) => Ok(t),
282            Ok(Err(e)) => Err(RpcError::Call(e)),
283            Err(RpcError::Channel(e)) => Err(RpcError::Channel(e)),
284        }
285        .into()
286    }
287}
288
289impl<T: 'static + Send> RpcSend for OneshotSender<T> {
290    type Message = T;
291    fn send_rpc(self, message: T) {
292        self.send(message);
293    }
294}
295
296impl<T: 'static + Send> RpcSend for &mesh_channel_core::Sender<T> {
297    type Message = T;
298    fn send_rpc(self, message: T) {
299        self.send(message);
300    }
301}