mesh_build/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A code generator for protobuf service definitions.
5//!
6//! Used with the prost protobuf code generator.
7
8#![forbid(unsafe_code)]
9
10use heck::ToUpperCamelCase;
11use proc_macro2::Span;
12use syn::Ident;
13
14/// A service generator for mesh services.
15pub struct MeshServiceGenerator {
16    replacements: Vec<(syn::TypePath, syn::Type)>,
17}
18
19impl MeshServiceGenerator {
20    /// Creates a new service generator.
21    pub fn new() -> Self {
22        Self {
23            replacements: Vec::new(),
24        }
25    }
26
27    /// Configures the generator to replace any instance of Rust `ty` with
28    /// `replacement`.
29    ///
30    /// This can be useful when some input or output messages already have mesh
31    /// types defined, and you want to use them instead of the generated prost
32    /// types.
33    pub fn replace_type(mut self, ty: &str, replacement: &str) -> Self {
34        let ty = syn::parse_str(ty).unwrap();
35        let replacement = syn::parse_str(replacement).unwrap();
36        self.replacements.push((ty, replacement));
37        self
38    }
39
40    fn lookup_type(&self, ty: &str) -> syn::Type {
41        let ty: syn::Type = syn::parse_str(ty).unwrap_or_else(|err| {
42            panic!("failed to parse type {}: {}", ty, err);
43        });
44        if let syn::Type::Path(ty) = &ty {
45            for (from, to) in &self.replacements {
46                if from == ty {
47                    return to.clone();
48                }
49            }
50        }
51        ty
52    }
53}
54
55impl prost_build::ServiceGenerator for MeshServiceGenerator {
56    fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
57        let name = format!("{}.{}", service.package, service.proto_name);
58        let ident = Ident::new(&service.name, Span::call_site());
59        let method_names: Vec<_> = service.methods.iter().map(|m| &m.proto_name).collect();
60        let method_idents: Vec<_> = service
61            .methods
62            .iter()
63            .map(|m| Ident::new(&m.name.to_upper_camel_case(), Span::call_site()))
64            .collect();
65        let request_types: Vec<_> = service
66            .methods
67            .iter()
68            .map(|m| self.lookup_type(&m.input_type))
69            .collect();
70        let response_types: Vec<_> = service
71            .methods
72            .iter()
73            .map(|m| self.lookup_type(&m.output_type))
74            .collect();
75
76        *buf += &quote::quote! {
77            #[derive(Debug)]
78            pub enum #ident {
79                #(
80                    #method_idents(
81                        #request_types,
82                        ::mesh::OneshotSender<::core::result::Result<#response_types, ::mesh_rpc::service::Status>>,
83                    ),
84                )*
85            }
86
87            impl #ident {
88                #[allow(dead_code)]
89                pub fn fail(self, status: ::mesh_rpc::service::Status) {
90                    match self {
91                        #(
92                            #ident::#method_idents(_, response) => response.send(Err(status)),
93                        )*
94                    }
95                }
96            }
97
98            impl ::mesh_rpc::service::ServiceRpc for #ident {
99                const NAME: &'static str = #name;
100
101                fn method(&self) -> &'static str {
102                    match self {
103                        #(
104                            #ident::#method_idents(_, _) => #method_names,
105                        )*
106                    }
107                }
108
109                fn encode(
110                    self,
111                    writer: ::mesh::payload::protobuf::FieldWriter<'_, '_, ::mesh::resource::Resource>,
112                ) -> ::mesh::local_node::Port {
113                    match self {
114                        #(
115                            #ident::#method_idents(req, port) => {
116                                <<#request_types as ::mesh::payload::DefaultEncoding>::Encoding as ::mesh::payload::FieldEncode<_, _>>::write_field(req, writer);
117                                port.into()
118                            }
119                        )*
120                    }
121                }
122
123                fn compute_size(&mut self, sizer: ::mesh::payload::protobuf::FieldSizer<'_>) {
124                    match self {
125                        #(
126                            #ident::#method_idents(req, _) => {
127                                <<#request_types as ::mesh::payload::DefaultEncoding>::Encoding as ::mesh::payload::FieldEncode::<_, ::mesh::resource::Resource>>::compute_field_size(
128                                    req,
129                                    sizer);
130                            }
131                        )*
132                    }
133                }
134
135                fn decode(
136                    method: &str,
137                    port: ::mesh::local_node::Port,
138                    data: &[u8],
139                ) -> Result<Self, (::mesh_rpc::service::ServiceRpcError, ::mesh::local_node::Port)> {
140                    match method {
141                        #(
142                            #method_names => {
143                                match mesh::payload::decode(data) {
144                                    Ok(req) => Ok(#ident::#method_idents(req, port.into())),
145                                    Err(e) => Err((::mesh_rpc::service::ServiceRpcError::InvalidInput(e), port)),
146                                }
147                            }
148                        )*
149                        _ => Err((::mesh_rpc::service::ServiceRpcError::UnknownMethod, port)),
150                    }
151                }
152            }
153        }
154        .to_string();
155    }
156}