mesh_rpc/
service.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Protobuf service support for mesh types.
5
6pub use grpc::Code;
7pub use grpc::Status;
8use mesh::local_node::Port;
9use mesh::payload::DefaultEncoding;
10use mesh::payload::MessageDecode;
11use mesh::payload::MessageEncode;
12use mesh::payload::Result;
13use mesh::payload::encoding::MessageEncoding;
14use mesh::payload::protobuf::FieldSizer;
15use mesh::payload::protobuf::FieldWriter;
16use mesh::payload::protobuf::MessageReader;
17use mesh::payload::protobuf::MessageSizer;
18use mesh::payload::protobuf::MessageWriter;
19use mesh::resource::Resource;
20
21mod grpc {
22    // Generated types use these crates, reference them here to ensure they are
23    // not removed by automated tooling.
24    use prost as _;
25    use prost_types as _;
26
27    include!(concat!(env!("OUT_DIR"), "/google.rpc.rs"));
28
29    impl std::fmt::Display for Code {
30        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31            write!(f, "{:?}", self)
32        }
33    }
34
35    impl std::error::Error for Code {}
36}
37
38/// A generic RPC value.
39///
40/// This type is designed to have the same encoding as [`DecodedRpc`].
41#[derive(mesh::MeshPayload)]
42pub(crate) struct GenericRpc {
43    #[mesh(1)]
44    pub method: String,
45    #[mesh(2)]
46    pub data: Vec<u8>,
47    #[mesh(3)]
48    pub port: Port, // TODO: transparent mesh::OneshotSender<std::result::Result<Vec<u8>, Status>>,
49}
50
51impl GenericRpc {
52    pub(crate) fn respond_status(self, status: Status) {
53        let sender =
54            mesh::OneshotSender::<std::result::Result<std::convert::Infallible, Status>>::from(
55                self.port,
56            );
57
58        sender.send(Err(status));
59    }
60}
61
62/// A generic RPC value, using borrows instead of owning types.
63#[derive(mesh::MeshPayload)]
64struct GenericRpcView<'a> {
65    #[mesh(1)]
66    method: &'a str,
67    #[mesh(2)]
68    data: &'a [u8],
69    #[mesh(3)]
70    port: Port,
71}
72
73/// Trait for service-specific RPC requests.
74pub trait ServiceRpc: 'static + Send + Sized {
75    /// The service name.
76    const NAME: &'static str;
77
78    /// The method name.
79    fn method(&self) -> &'static str;
80
81    /// Encode the request into a field.
82    fn encode(self, writer: FieldWriter<'_, '_, Resource>) -> Port;
83
84    /// Compute the field size of the request.
85    fn compute_size(&mut self, sizer: FieldSizer<'_>);
86
87    /// Decode the request from a field.
88    fn decode(
89        method: &str,
90        port: Port,
91        data: &[u8],
92    ) -> std::result::Result<Self, (ServiceRpcError, Port)>;
93}
94
95/// An error returned while decoding a method call.
96pub enum ServiceRpcError {
97    /// The method is unknown.
98    UnknownMethod,
99    /// The input could not be decoded.
100    InvalidInput(mesh::payload::Error),
101}
102
103#[doc(hidden)]
104pub(crate) enum DecodedRpc<T> {
105    Rpc(T),
106    Err {
107        rpc: GenericRpc,
108        err: ServiceRpcError,
109    },
110}
111
/// Marker type that carries the message encode/decode implementations for
/// [`DecodedRpc`].
pub(crate) struct DecodedRpcEncoder;
113
114impl<T: ServiceRpc> DefaultEncoding for DecodedRpc<T> {
115    type Encoding = MessageEncoding<DecodedRpcEncoder>;
116}
117
118impl<T: ServiceRpc> MessageEncode<DecodedRpc<T>, Resource> for DecodedRpcEncoder {
119    fn write_message(item: DecodedRpc<T>, mut writer: MessageWriter<'_, '_, Resource>) {
120        match item {
121            DecodedRpc::Rpc(rpc) => {
122                writer.field(1).bytes(rpc.method().as_bytes());
123                let port = rpc.encode(writer.field(2));
124                writer.field(3).resource(Resource::Port(port));
125            }
126            DecodedRpc::Err { rpc, err: _ } => {
127                <GenericRpc as DefaultEncoding>::Encoding::write_message(rpc, writer)
128            }
129        }
130    }
131
132    fn compute_message_size(item: &mut DecodedRpc<T>, mut sizer: MessageSizer<'_>) {
133        match item {
134            DecodedRpc::Rpc(rpc) => {
135                sizer.field(1).bytes(rpc.method().len());
136                rpc.compute_size(sizer.field(2));
137                sizer.field(3).resource();
138            }
139            DecodedRpc::Err { rpc, err: _ } => {
140                <GenericRpc as DefaultEncoding>::Encoding::compute_message_size(rpc, sizer)
141            }
142        }
143    }
144}
145
146impl<'a, T: ServiceRpc> MessageDecode<'a, DecodedRpc<T>, Resource> for DecodedRpcEncoder {
147    fn read_message(
148        item: &mut mesh::payload::inplace::InplaceOption<'_, DecodedRpc<T>>,
149        reader: MessageReader<'a, '_, Resource>,
150    ) -> Result<()> {
151        mesh::payload::inplace_none!(v: GenericRpcView<'_>);
152        <GenericRpcView<'_> as DefaultEncoding>::Encoding::read_message(&mut v, reader)?;
153        let v = v.take().expect("should be constructed");
154        let rpc = match T::decode(v.method, v.port, v.data) {
155            Ok(rpc) => DecodedRpc::Rpc(rpc),
156            Err((err, port)) => {
157                let rpc = GenericRpc {
158                    method: v.method.to_string(),
159                    data: v.data.to_vec(),
160                    port,
161                };
162                DecodedRpc::Err { rpc, err }
163            }
164        };
165        item.set(rpc);
166        Ok(())
167    }
168}