1#![forbid(unsafe_code)]
9
10use heck::ToUpperCamelCase;
11use proc_macro2::Span;
12use syn::Ident;
13
14pub struct MeshServiceGenerator {
16 replacements: Vec<(syn::TypePath, syn::Type)>,
17}
18
19impl MeshServiceGenerator {
20 pub fn new() -> Self {
22 Self {
23 replacements: Vec::new(),
24 }
25 }
26
27 pub fn replace_type(mut self, ty: &str, replacement: &str) -> Self {
34 let ty = syn::parse_str(ty).unwrap();
35 let replacement = syn::parse_str(replacement).unwrap();
36 self.replacements.push((ty, replacement));
37 self
38 }
39
40 fn lookup_type(&self, ty: &str) -> syn::Type {
41 let ty: syn::Type = syn::parse_str(ty).unwrap_or_else(|err| {
42 panic!("failed to parse type {}: {}", ty, err);
43 });
44 if let syn::Type::Path(ty) = &ty {
45 for (from, to) in &self.replacements {
46 if from == ty {
47 return to.clone();
48 }
49 }
50 }
51 ty
52 }
53}
54
55impl prost_build::ServiceGenerator for MeshServiceGenerator {
56 fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
57 let name = format!("{}.{}", service.package, service.proto_name);
58 let ident = Ident::new(&service.name, Span::call_site());
59 let method_names: Vec<_> = service.methods.iter().map(|m| &m.proto_name).collect();
60 let method_idents: Vec<_> = service
61 .methods
62 .iter()
63 .map(|m| Ident::new(&m.name.to_upper_camel_case(), Span::call_site()))
64 .collect();
65 let request_types: Vec<_> = service
66 .methods
67 .iter()
68 .map(|m| self.lookup_type(&m.input_type))
69 .collect();
70 let response_types: Vec<_> = service
71 .methods
72 .iter()
73 .map(|m| self.lookup_type(&m.output_type))
74 .collect();
75
76 *buf += "e::quote! {
77 #[derive(Debug)]
78 pub enum #ident {
79 #(
80 #method_idents(
81 #request_types,
82 ::mesh::OneshotSender<::core::result::Result<#response_types, ::mesh_rpc::service::Status>>,
83 ),
84 )*
85 }
86
87 impl #ident {
88 #[allow(dead_code)]
89 pub fn fail(self, status: ::mesh_rpc::service::Status) {
90 match self {
91 #(
92 #ident::#method_idents(_, response) => response.send(Err(status)),
93 )*
94 }
95 }
96 }
97
98 impl ::mesh_rpc::service::ServiceRpc for #ident {
99 const NAME: &'static str = #name;
100
101 fn method(&self) -> &'static str {
102 match self {
103 #(
104 #ident::#method_idents(_, _) => #method_names,
105 )*
106 }
107 }
108
109 fn encode(
110 self,
111 writer: ::mesh::payload::protobuf::FieldWriter<'_, '_, ::mesh::resource::Resource>,
112 ) -> ::mesh::local_node::Port {
113 match self {
114 #(
115 #ident::#method_idents(req, port) => {
116 <<#request_types as ::mesh::payload::DefaultEncoding>::Encoding as ::mesh::payload::FieldEncode<_, _>>::write_field(req, writer);
117 port.into()
118 }
119 )*
120 }
121 }
122
123 fn compute_size(&mut self, sizer: ::mesh::payload::protobuf::FieldSizer<'_>) {
124 match self {
125 #(
126 #ident::#method_idents(req, _) => {
127 <<#request_types as ::mesh::payload::DefaultEncoding>::Encoding as ::mesh::payload::FieldEncode::<_, ::mesh::resource::Resource>>::compute_field_size(
128 req,
129 sizer);
130 }
131 )*
132 }
133 }
134
135 fn decode(
136 method: &str,
137 port: ::mesh::local_node::Port,
138 data: &[u8],
139 ) -> Result<Self, (::mesh_rpc::service::ServiceRpcError, ::mesh::local_node::Port)> {
140 match method {
141 #(
142 #method_names => {
143 match mesh::payload::decode(data) {
144 Ok(req) => Ok(#ident::#method_idents(req, port.into())),
145 Err(e) => Err((::mesh_rpc::service::ServiceRpcError::InvalidInput(e), port)),
146 }
147 }
148 )*
149 _ => Err((::mesh_rpc::service::ServiceRpcError::UnknownMethod, port)),
150 }
151 }
152 }
153 }
154 .to_string();
155 }
156}