Skip to main content

mesh_rpc/
service.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Protobuf service support for mesh types.
5
6pub use grpc::Code;
7pub use grpc::Status;
8use mesh::local_node::Port;
9use mesh::payload::DefaultEncoding;
10use mesh::payload::MessageDecode;
11use mesh::payload::MessageEncode;
12use mesh::payload::Result;
13use mesh::payload::encoding::MessageEncoding;
14use mesh::payload::protobuf::FieldSizer;
15use mesh::payload::protobuf::FieldWriter;
16use mesh::payload::protobuf::MessageReader;
17use mesh::payload::protobuf::MessageSizer;
18use mesh::payload::protobuf::MessageWriter;
19use mesh::resource::Resource;
20
21#[expect(clippy::allow_attributes)]
22mod grpc {
23    // Generated types use these crates, reference them here to ensure they are
24    // not removed by automated tooling.
25    use prost as _;
26    use prost_types as _;
27
28    include!(concat!(env!("OUT_DIR"), "/google.rpc.rs"));
29
30    impl std::fmt::Display for Code {
31        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32            write!(f, "{:?}", self)
33        }
34    }
35
36    impl std::error::Error for Code {}
37}
38
/// A generic RPC value.
///
/// This type is designed to have the same encoding as [`DecodedRpc`].
#[derive(mesh::MeshPayload)]
pub(crate) struct GenericRpc {
    /// The method name (field 1).
    #[mesh(1)]
    pub method: String,
    /// The raw request bytes (field 2), i.e. the `data` that is handed to
    /// `ServiceRpc::decode`.
    #[mesh(2)]
    pub data: Vec<u8>,
    /// The port carried with the request (field 3). `respond_status`
    /// converts it into a oneshot sender to deliver an error status.
    #[mesh(3)]
    pub port: Port, // TODO: transparent mesh::OneshotSender<std::result::Result<Vec<u8>, Status>>,
}
51
52impl GenericRpc {
53    pub(crate) fn respond_status(self, status: Status) {
54        let sender =
55            mesh::OneshotSender::<std::result::Result<std::convert::Infallible, Status>>::from(
56                self.port,
57            );
58
59        sender.send(Err(status));
60    }
61}
62
/// A generic RPC value, using borrows instead of owning types.
///
/// Uses the same field numbers as [`GenericRpc`], but `method` and `data`
/// are borrowed for `'a` rather than owned.
#[derive(mesh::MeshPayload)]
struct GenericRpcView<'a> {
    /// The method name (field 1).
    #[mesh(1)]
    method: &'a str,
    /// The raw request bytes (field 2).
    #[mesh(2)]
    data: &'a [u8],
    /// The port carried with the request (field 3).
    #[mesh(3)]
    port: Port,
}
73
/// Trait for service-specific RPC requests.
///
/// Implementors describe how a typed request maps onto the
/// method / data / port message layout used when encoding and decoding RPCs
/// in this module.
pub trait ServiceRpc: 'static + Send + Sized {
    /// The service name.
    const NAME: &'static str;

    /// The method name.
    ///
    /// The encoder in this module writes this string as field 1 of the
    /// message.
    fn method(&self) -> &'static str;

    /// Encode the request into a field.
    ///
    /// Consumes the request and returns a [`Port`]; the encoder in this
    /// module writes that port as a resource in field 3, after the request
    /// itself has been written to field 2 via `writer`.
    fn encode(self, writer: FieldWriter<'_, '_, Resource>) -> Port;

    /// Compute the field size of the request.
    ///
    /// Called with the sizer for field 2, mirroring [`ServiceRpc::encode`].
    fn compute_size(&mut self, sizer: FieldSizer<'_>);

    /// Decode the request from a field.
    ///
    /// `method`, `data` and `port` are the three fields read from the
    /// incoming message.
    ///
    /// # Errors
    ///
    /// Returns a [`ServiceRpcError`] together with a [`Port`]. The decoder in
    /// this module stores that port in the generic fallback RPC value so the
    /// request can still be answered.
    fn decode(
        method: &str,
        port: Port,
        data: &[u8],
    ) -> std::result::Result<Self, (ServiceRpcError, Port)>;
}
95
/// An error returned while decoding a method call.
// NOTE(review): this public type has no `Debug` impl visible in this file;
// consider deriving it if `mesh::payload::Error` implements `Debug` — confirm.
pub enum ServiceRpcError {
    /// The method is unknown.
    UnknownMethod,
    /// The input could not be decoded.
    InvalidInput(mesh::payload::Error),
}
103
/// The outcome of decoding an incoming RPC as the service request type `T`.
#[doc(hidden)]
pub(crate) enum DecodedRpc<T> {
    /// The request was successfully decoded into `T`.
    Rpc(T),
    /// Decoding failed; the request is kept in generic form alongside the
    /// error.
    Err {
        /// The request as received (method, raw data and port).
        rpc: GenericRpc,
        /// Why decoding failed.
        err: ServiceRpcError,
    },
}
112
/// Marker type that carries the message encode/decode implementations for
/// [`DecodedRpc`].
pub(crate) struct DecodedRpcEncoder;

// Make `DecodedRpcEncoder` (as a message encoding) the default encoding for
// every `DecodedRpc<T>` whose `T` is a service request type.
impl<T: ServiceRpc> DefaultEncoding for DecodedRpc<T> {
    type Encoding = MessageEncoding<DecodedRpcEncoder>;
}
118
119impl<T: ServiceRpc> MessageEncode<DecodedRpc<T>, Resource> for DecodedRpcEncoder {
120    fn write_message(item: DecodedRpc<T>, mut writer: MessageWriter<'_, '_, Resource>) {
121        match item {
122            DecodedRpc::Rpc(rpc) => {
123                writer.field(1).bytes(rpc.method().as_bytes());
124                let port = rpc.encode(writer.field(2));
125                writer.field(3).resource(Resource::Port(port));
126            }
127            DecodedRpc::Err { rpc, err: _ } => {
128                <GenericRpc as DefaultEncoding>::Encoding::write_message(rpc, writer)
129            }
130        }
131    }
132
133    fn compute_message_size(item: &mut DecodedRpc<T>, mut sizer: MessageSizer<'_>) {
134        match item {
135            DecodedRpc::Rpc(rpc) => {
136                sizer.field(1).bytes(rpc.method().len());
137                rpc.compute_size(sizer.field(2));
138                sizer.field(3).resource();
139            }
140            DecodedRpc::Err { rpc, err: _ } => {
141                <GenericRpc as DefaultEncoding>::Encoding::compute_message_size(rpc, sizer)
142            }
143        }
144    }
145}
146
147impl<'a, T: ServiceRpc> MessageDecode<'a, DecodedRpc<T>, Resource> for DecodedRpcEncoder {
148    fn read_message(
149        item: &mut mesh::payload::inplace::InplaceOption<'_, DecodedRpc<T>>,
150        reader: MessageReader<'a, '_, Resource>,
151    ) -> Result<()> {
152        mesh::payload::inplace_none!(v: GenericRpcView<'_>);
153        <GenericRpcView<'_> as DefaultEncoding>::Encoding::read_message(&mut v, reader)?;
154        let v = v.take().expect("should be constructed");
155        let rpc = match T::decode(v.method, v.port, v.data) {
156            Ok(rpc) => DecodedRpc::Rpc(rpc),
157            Err((err, port)) => {
158                let rpc = GenericRpc {
159                    method: v.method.to_string(),
160                    data: v.data.to_vec(),
161                    port,
162                };
163                DecodedRpc::Err { rpc, err }
164            }
165        };
166        item.set(rpc);
167        Ok(())
168    }
169}