mesh_rpc/
server.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! TTRPC server.
5
6use crate::message::MESSAGE_TYPE_REQUEST;
7use crate::message::MESSAGE_TYPE_RESPONSE;
8use crate::message::ReadResult;
9use crate::message::Request;
10use crate::message::Response;
11use crate::message::TooLongError;
12use crate::message::read_message;
13use crate::message::write_message;
14use crate::rpc::ProtocolError;
15use crate::rpc::status_from_err;
16use crate::service::Code;
17use crate::service::DecodedRpc;
18use crate::service::GenericRpc;
19use crate::service::ServiceRpc;
20use crate::service::ServiceRpcError;
21use crate::service::Status;
22use futures::FutureExt;
23use futures::Stream;
24use futures::StreamExt;
25use futures::stream::FusedStream;
26use futures_concurrency::future::TryJoin;
27use futures_concurrency::stream::Merge;
28use mesh::CancelContext;
29use mesh::MeshPayload;
30use mesh::local_node::Port;
31use pal_async::driver::Driver;
32use pal_async::socket::AsSockRef;
33use pal_async::socket::Listener;
34use pal_async::socket::PolledSocket;
35use std::collections::HashMap;
36use std::io::Read;
37use std::io::Write;
38use std::pin::Pin;
39use std::task::ready;
40use unicycle::FuturesUnordered;
41
42/// A ttrpc server.
43#[derive(Debug, Default)]
44pub struct Server {
45    services: HashMap<&'static str, mesh::Sender<(CancelContext, GenericRpc)>>,
46}
47
/// A receiver for RPC requests for a given service.
///
/// Returned by [`Server::add_service`]. Each item is paired with the
/// [`CancelContext`] the request was issued under. When polled as a
/// [`Stream`], requests that arrive as `DecodedRpc::Err` (an unknown method or
/// invalid input) are answered with an error status internally and are not
/// yielded to the caller.
#[derive(MeshPayload)]
#[mesh(bound = "T: ServiceRpc")]
pub struct RpcReceiver<T>(mesh::Receiver<(CancelContext, DecodedRpc<T>)>);
54
55impl<T: ServiceRpc> RpcReceiver<T> {
56    /// Returns a disconnected stream, useful for when a service is dynamically
57    /// not registered.
58    pub fn disconnected() -> Self {
59        let (_send, recv) = mesh::channel();
60        Self(recv)
61    }
62}
63
64impl<T: ServiceRpc> Stream for RpcReceiver<T> {
65    type Item = (CancelContext, T);
66
67    fn poll_next(
68        self: Pin<&mut Self>,
69        cx: &mut std::task::Context<'_>,
70    ) -> std::task::Poll<Option<Self::Item>> {
71        let this = self.get_mut();
72        while let Some((ctx, rpc)) = ready!(Pin::new(&mut this.0).poll_next(cx)) {
73            match rpc {
74                DecodedRpc::Rpc(rpc) => return Some((ctx, rpc)).into(),
75                DecodedRpc::Err { rpc, err } => {
76                    rpc.fail(err);
77                }
78            }
79        }
80        None.into()
81    }
82}
83
84impl<T: ServiceRpc> FusedStream for RpcReceiver<T> {
85    fn is_terminated(&self) -> bool {
86        self.0.is_terminated()
87    }
88}
89
90impl GenericRpc {
91    fn fail(self, err: ServiceRpcError) {
92        let status = match err {
93            ServiceRpcError::UnknownMethod => Status {
94                code: Code::Unimplemented.into(),
95                message: format!("unknown method {}", self.method),
96                details: Vec::new(),
97            },
98            ServiceRpcError::InvalidInput(error) => status_from_err(Code::InvalidArgument, error),
99        };
100        self.respond_status(status);
101    }
102}
103
104impl Server {
105    /// Creates a new ttrpc server.
106    pub fn new() -> Self {
107        Self {
108            services: Default::default(),
109        }
110    }
111
112    /// Adds or updates a channel for receiving service requests.
113    pub fn add_service<T: ServiceRpc>(&mut self) -> RpcReceiver<T> {
114        let (send, recv) = mesh::channel();
115        self.services.insert(T::NAME, Port::from(send).into());
116        RpcReceiver(recv)
117    }
118
119    /// Runs the server using the ttrpc transport, listening on `listener` and
120    /// servicing connections until `cancel`.
121    pub async fn run(
122        &mut self,
123        driver: &(impl Driver + ?Sized),
124        listener: impl Listener,
125        cancel: mesh::OneshotReceiver<()>,
126    ) -> anyhow::Result<()> {
127        let mut listener = PolledSocket::new(driver, listener)?;
128        let mut tasks = FuturesUnordered::new();
129        let mut cancel = cancel.fuse();
130        loop {
131            let conn = futures::select! { // merge semantics
132                r = listener.accept().fuse() => r,
133                _ = tasks.next() => continue,
134                _ = cancel => break,
135            };
136            if let Ok(conn) = conn.and_then(|(conn, _)| PolledSocket::new(driver, conn)) {
137                tasks.push(async {
138                    let _ = self.serve(conn).await.map_err(|err| {
139                        tracing::error!(
140                            error = err.as_ref() as &dyn std::error::Error,
141                            "connection error"
142                        )
143                    });
144                });
145            }
146        }
147        Ok(())
148    }
149
150    /// Runs the server, servicing a single connection `conn`.
151    pub async fn run_single(
152        &mut self,
153        driver: &(impl Driver + ?Sized),
154        conn: impl AsSockRef + Read + Write,
155    ) -> anyhow::Result<()> {
156        self.serve(PolledSocket::new(driver, conn)?).await
157    }
158
159    async fn serve(
160        &self,
161        stream: PolledSocket<impl AsSockRef + Read + Write>,
162    ) -> anyhow::Result<()> {
163        let (mut reader, mut writer) = stream.split();
164        let (stream_send, mut stream_recv) = mesh::channel();
165        let ctx = CancelContext::new();
166        let recv_task = async {
167            let stream_send = stream_send; // move into this task
168            while let Some(message) = read_message(&mut reader).await? {
169                let (send, recv) = mesh::oneshot::<Result<Vec<u8>, Status>>();
170                stream_send.send((message.stream_id, recv));
171
172                let handle = handle_message(message).and_then(|request| {
173                    let service = self.services.get(request.service.as_str()).ok_or_else(|| {
174                        status_from_err(
175                            Code::Unimplemented,
176                            anyhow::anyhow!("unknown service {}", request.service),
177                        )
178                    })?;
179
180                    let ctx = if request.timeout_nano == 0 {
181                        ctx.clone()
182                    } else {
183                        ctx.with_timeout(std::time::Duration::from_nanos(request.timeout_nano))
184                    };
185
186                    Ok(move |port| {
187                        service.send((
188                            ctx,
189                            GenericRpc {
190                                method: request.method,
191                                data: request.payload,
192                                port,
193                            },
194                        ));
195                    })
196                });
197
198                match handle {
199                    Ok(handle) => handle(send.into()),
200                    Err(err) => send.send(Err(err)),
201                }
202            }
203            Ok(())
204        };
205        let send_task = async {
206            let mut responses = FuturesUnordered::new();
207            enum Event<T> {
208                Request((u32, mesh::OneshotReceiver<Result<Vec<u8>, Status>>)),
209                Response(T),
210            }
211            while let Some(event) = (
212                (&mut stream_recv).map(Event::Request),
213                (&mut responses).map(Event::Response),
214            )
215                .merge()
216                .next()
217                .await
218            {
219                match event {
220                    Event::Request((stream_id, recv)) => {
221                        let recv = recv.map(move |r| {
222                            (
223                                stream_id,
224                                match r {
225                                    Ok(Ok(payload)) => Response::Payload(payload),
226                                    Ok(Err(status)) => Response::Status(status),
227                                    Err(err) => {
228                                        Response::Status(status_from_err(Code::Internal, err))
229                                    }
230                                },
231                            )
232                        });
233                        responses.push(recv);
234                    }
235                    Event::Response((stream_id, payload)) => {
236                        write_message(
237                            &mut writer,
238                            stream_id,
239                            MESSAGE_TYPE_RESPONSE,
240                            &mesh::payload::encode(payload),
241                        )
242                        .await?;
243                    }
244                }
245            }
246            anyhow::Result::<_>::Ok(())
247        };
248        (recv_task, send_task).try_join().await?;
249        Ok(())
250    }
251}
252
253fn handle_message(message: ReadResult) -> Result<Request, Status> {
254    if message.stream_id % 2 != 1 {
255        return Err(status_from_err(
256            Code::InvalidArgument,
257            ProtocolError::EvenStreamId,
258        ));
259    }
260
261    match message.message_type {
262        MESSAGE_TYPE_REQUEST => {
263            let payload = message.payload.map_err(|err @ TooLongError { .. }| {
264                status_from_err(Code::ResourceExhausted, err)
265            })?;
266            let request = mesh::payload::decode::<Request>(&payload)
267                .map_err(|err| status_from_err(Code::InvalidArgument, err))?;
268
269            tracing::debug!(
270                stream_id = message.stream_id,
271                service = %request.service,
272                method = %request.method,
273                timeout = request.timeout_nano / 1000 / 1000,
274                "message",
275            );
276
277            Ok(request)
278        }
279        ty => Err(status_from_err(
280            Code::InvalidArgument,
281            ProtocolError::InvalidMessageType(ty),
282        )),
283    }
284}
285
#[cfg(feature = "grpc")]
mod grpc {
    use super::Server;
    use crate::rpc::status_from_err;
    use crate::service::Code;
    use crate::service::GenericRpc;
    use crate::service::Status;
    use anyhow::Context as _;
    use futures::AsyncRead as _;
    use futures::AsyncWrite;
    use futures::FutureExt;
    use futures::StreamExt;
    use futures_concurrency::stream::Merge;
    use h2::RecvStream;
    use h2::server::SendResponse;
    use http::HeaderMap;
    use http::HeaderValue;
    use mesh::CancelContext;
    use pal_async::driver::Driver;
    use pal_async::socket::AsSockRef;
    use pal_async::socket::Listener;
    use pal_async::socket::PolledSocket;
    use prost::bytes::Bytes;
    use std::io::Read;
    use std::io::Write;
    use std::pin::Pin;
    use std::task::ready;
    use thiserror::Error;
    use unicycle::FuturesUnordered;

    /// Errors that end processing of a single gRPC request.
    #[derive(Debug, Error)]
    enum RequestError {
        #[error("http error")]
        Http(#[from] http::Error),
        #[error("http2 error")]
        H2(#[from] h2::Error),
        // Intercepted by `serve_grpc` and turned into a plain HTTP error
        // response, so this message is not normally surfaced.
        #[error("unreachable")]
        Status(http::StatusCode),
        #[error("invalid message header")]
        InvalidHeader,
    }

    impl From<http::StatusCode> for RequestError {
        fn from(status: http::StatusCode) -> Self {
            RequestError::Status(status)
        }
    }

    impl Server {
        /// Runs the server using the gRPC transport, listening on `listener` and servicing connections until
        /// `cancel`.
        ///
        /// All accepted connections are serviced concurrently within this
        /// future and are dropped when `cancel` fires. A failure to accept or
        /// set up an individual connection is logged and the server keeps
        /// listening.
        ///
        /// # Errors
        ///
        /// Returns an error if `listener` cannot be registered with `driver`.
        pub async fn run_grpc(
            &mut self,
            driver: &(impl Driver + ?Sized),
            listener: impl Listener,
            cancel: mesh::OneshotReceiver<()>,
        ) -> anyhow::Result<()> {
            let mut listener = PolledSocket::new(driver, listener)?;
            let mut tasks = FuturesUnordered::new();
            let mut cancel = cancel.fuse();
            loop {
                let conn = futures::select! { // merge semantics
                    r = listener.accept().fuse() => r,
                    _ = tasks.next() => continue,
                    _ = cancel => break,
                };
                match conn.and_then(|(conn, _)| PolledSocket::new(driver, conn)) {
                    Ok(conn) => {
                        tasks.push(async {
                            let _ = self.serve_grpc(conn).await.map_err(|err| {
                                tracing::error!(
                                    error = err.as_ref() as &dyn std::error::Error,
                                    "connection error"
                                )
                            });
                        });
                    }
                    Err(err) => {
                        // Best effort: keep listening, but make the failure
                        // visible.
                        tracing::warn!(
                            error = &err as &dyn std::error::Error,
                            "failed to accept connection"
                        );
                    }
                }
            }
            Ok(())
        }

        /// Services gRPC requests over HTTP/2 on a single connection.
        ///
        /// Each HTTP/2 stream is handled as one request, concurrently with the
        /// others. A request that fails with an HTTP status error is answered
        /// with that status; any other request error terminates the whole
        /// connection.
        async fn serve_grpc(
            &self,
            stream: PolledSocket<impl AsSockRef + Read + Write>,
        ) -> anyhow::Result<()> {
            // Adapts the futures-style socket to the tokio I/O traits `h2`
            // requires.
            struct Wrap<T>(T);

            impl<T: AsSockRef + Read> tokio::io::AsyncRead for Wrap<PolledSocket<T>> {
                fn poll_read(
                    self: Pin<&mut Self>,
                    cx: &mut std::task::Context<'_>,
                    buf: &mut tokio::io::ReadBuf<'_>,
                ) -> std::task::Poll<std::io::Result<()>> {
                    let n = ready!(
                        Pin::new(&mut self.get_mut().0).poll_read(cx, buf.initialize_unfilled())
                    )?;
                    buf.advance(n);
                    std::task::Poll::Ready(Ok(()))
                }
            }

            impl<T: AsSockRef + Write> tokio::io::AsyncWrite for Wrap<PolledSocket<T>> {
                fn poll_write(
                    self: Pin<&mut Self>,
                    cx: &mut std::task::Context<'_>,
                    buf: &[u8],
                ) -> std::task::Poll<Result<usize, std::io::Error>> {
                    Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
                }

                fn poll_flush(
                    self: Pin<&mut Self>,
                    cx: &mut std::task::Context<'_>,
                ) -> std::task::Poll<Result<(), std::io::Error>> {
                    Pin::new(&mut self.get_mut().0).poll_flush(cx)
                }

                fn poll_shutdown(
                    self: Pin<&mut Self>,
                    cx: &mut std::task::Context<'_>,
                ) -> std::task::Poll<Result<(), std::io::Error>> {
                    Pin::new(&mut self.get_mut().0).poll_close(cx)
                }
            }

            let mut conn = h2::server::handshake(Wrap(stream))
                .await
                .context("failed http2 handshake")?;

            let mut tasks = FuturesUnordered::new();

            loop {
                enum Event<A, B> {
                    Accept(A),
                    Task(Result<(), B>),
                }

                let r = (
                    (&mut conn).map(Event::Accept),
                    (&mut tasks).map(Event::Task),
                )
                    .merge()
                    .next()
                    .await;

                let (req, mut resp) = match r {
                    None => break,
                    Some(Event::Task(r)) => {
                        // A non-status request error tears down the connection.
                        r?;
                        continue;
                    }
                    Some(Event::Accept(r)) => r.context("failed http2 stream accept")?,
                };

                let task = async move {
                    match self.handle_request(req, &mut resp).await {
                        Err(RequestError::Status(status)) => {
                            tracing::debug!(status = status.as_u16(), "request error");
                            resp.send_response(
                                http::Response::builder().status(status).body(())?,
                                true,
                            )?;
                            Ok(())
                        }
                        r => r,
                    }
                };
                tasks.push(task);
            }

            std::future::poll_fn(|cx| conn.poll_closed(cx)).await?;
            Ok(())
        }

        /// Handles a single gRPC request on one HTTP/2 stream.
        ///
        /// Validates the method, content type, and optional `grpc-timeout`
        /// header, then sends response headers, invokes the RPC, and finishes
        /// the stream with the response message (on success) and the gRPC
        /// status trailers.
        ///
        /// # Errors
        ///
        /// Returns [`RequestError::Status`] for failures detected before the
        /// response headers are sent; any later failure is an HTTP/2, HTTP, or
        /// framing error.
        async fn handle_request(
            &self,
            req: http::Request<RecvStream>,
            resp: &mut SendResponse<Bytes>,
        ) -> Result<(), RequestError> {
            tracing::debug!(url = %req.uri(), "rpc request");

            if req.method() != http::Method::POST {
                Err(http::StatusCode::METHOD_NOT_ALLOWED)?
            }
            let content_type = req.headers().get("content-type");
            match content_type.map(|v| v.as_bytes()) {
                Some(b"application/grpc" | b"application/grpc+proto") => {}
                _ => Err(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)?,
            }

            let response =
                http::Response::builder().header("content-type", "application/grpc+proto");

            // `grpc-timeout` is an integer followed by a single unit character.
            let ctx = if let Some(timeout) = req.headers().get("grpc-timeout") {
                let timeout = timeout
                    .to_str()
                    .map_err(|_| http::StatusCode::BAD_REQUEST)?;
                let mul = match timeout
                    .bytes()
                    .last()
                    .ok_or(http::StatusCode::BAD_REQUEST)?
                {
                    b'H' => std::time::Duration::from_secs(60 * 60),
                    b'M' => std::time::Duration::from_secs(60),
                    b'S' => std::time::Duration::from_secs(1),
                    b'm' => std::time::Duration::from_millis(1),
                    b'u' => std::time::Duration::from_micros(1),
                    b'n' => std::time::Duration::from_nanos(1),
                    _ => Err(http::StatusCode::BAD_REQUEST)?,
                };
                // The unit byte is ASCII, so slicing it off stays on a char
                // boundary.
                let timeout = timeout[..timeout.len() - 1]
                    .parse::<u32>()
                    .map_err(|_| http::StatusCode::BAD_REQUEST)?;
                CancelContext::new().with_timeout(mul * timeout)
            } else {
                CancelContext::new()
            };

            let (head, body) = req.into_parts();
            let path = head.uri.path();
            let path = path.strip_prefix('/').ok_or(http::StatusCode::NOT_FOUND)?;
            let (service, method) = path.split_once('/').ok_or(http::StatusCode::NOT_FOUND)?;

            // No returning HTTP status code errors after this point.
            let mut resp = resp.send_response(response.body(())?, false)?;

            let result = self.invoke_rpc(service, method, body, ctx).await?;

            let mut trailers = HeaderMap::new();
            match result {
                Ok(data) => {
                    tracing::debug!(service, method, "rpc success");

                    // Length-prefixed message: uncompressed flag, then a
                    // big-endian u32 length, then the payload.
                    let mut buf = Vec::with_capacity(5 + data.len());
                    buf.push(0);
                    buf.extend(&(data.len() as u32).to_be_bytes());
                    buf.extend(data);
                    resp.send_data(buf.into(), false)?;
                    trailers.insert("grpc-status", const { HeaderValue::from_static("0") });
                }
                Err(status) => {
                    tracing::debug!(service, method, ?status, "rpc error");

                    trailers.insert("grpc-status", status.code.into());
                    // Percent-encoded and base64 output are plain ASCII, so
                    // these header conversions cannot fail.
                    trailers.insert(
                        "grpc-message",
                        urlencoding::encode(&status.message)
                            .into_owned()
                            .try_into()
                            .unwrap(),
                    );
                    trailers.insert(
                        "grpc-status-details-bin",
                        base64::Engine::encode(
                            &base64::engine::general_purpose::STANDARD,
                            prost::Message::encode_to_vec(&status),
                        )
                        .try_into()
                        .unwrap(),
                    );
                }
            }
            resp.send_trailers(trailers)?;
            Ok(())
        }

        /// Reads the first length-prefixed message from `body` and dispatches
        /// it to `service` as a [`GenericRpc`].
        ///
        /// Returns `Ok(Err(status))` when the service is unknown or the RPC
        /// fails, and `Err(_)` for transport or framing errors.
        async fn invoke_rpc(
            &self,
            service: &str,
            method: &str,
            mut body: RecvStream,
            ctx: CancelContext,
        ) -> Result<Result<Vec<u8>, Status>, RequestError> {
            let Some(service) = self.services.get(service) else {
                return Ok(Err(Status {
                    code: Code::Unimplemented.into(),
                    message: format!("unknown service {}", service),
                    details: Vec::new(),
                }));
            };

            // For now, only non-stream RPCs are supported, so read the first
            // message and ignore the rest.
            //
            // FUTURE: change the `GenericRpc` type to include channels for
            // streams.

            let mut buf = Vec::new();

            // Read data frames until the header is complete.
            while buf.len() < 5 {
                let data = body.data().await.ok_or(RequestError::InvalidHeader)??;
                buf.extend(&data);
                body.flow_control().release_capacity(data.len())?;
            }
            let hdr = buf.get(0..5).ok_or(RequestError::InvalidHeader)?;
            if hdr[0] != 0 {
                // Compression was not advertised as supported, so the client
                // should not send compressed messages.
                return Err(RequestError::InvalidHeader);
            }
            let len = u32::from_be_bytes(hdr[1..5].try_into().unwrap()) as usize;

            buf.drain(..5);
            while buf.len() < len {
                let data = body.data().await.ok_or(RequestError::InvalidHeader)??;
                buf.extend(&data);
                body.flow_control().release_capacity(data.len())?;
            }
            // The last frame may carry bytes beyond this message (e.g. the
            // start of a following message); they must not leak into the
            // payload handed to the service.
            buf.truncate(len);

            let (send, recv) = mesh::oneshot();

            let rpc = GenericRpc {
                method: method.to_owned(),
                data: buf,
                port: send.into(),
            };

            service.send((ctx, rpc));

            Ok(recv
                .await
                .unwrap_or_else(|err| Err(status_from_err(Code::Internal, err))))
        }
    }
}
611
#[cfg(test)]
mod tests {
    use crate::Client;
    use crate::Server;
    use crate::client::ExistingConnection;
    use crate::service::Code;
    use crate::service::ServiceRpc;
    use futures::StreamExt;
    use futures::executor::block_on;
    use pal_async::DefaultPool;
    use pal_async::local::block_with_io;
    use pal_async::socket::PolledSocket;
    use test_with_tracing::test;

    // Items for the `ttrpc.example.v1` test service, included from the build
    // output directory (`OUT_DIR`).
    mod items {
        include!(concat!(env!("OUT_DIR"), "/ttrpc.example.v1.rs"));
    }

    /// End-to-end test over a Unix socket pair.
    ///
    /// The server thread services the single connection, the client thread
    /// issues one `Method1` call and one call to an unknown method, and the
    /// test's own thread plays the service implementation.
    #[test]
    fn client_server() {
        let (c, s) = unix_socket::UnixStream::pair().unwrap();
        let mut server = Server::new();
        let mut recv = server.add_service::<items::Example>();
        // `server` moves into this thread, so its service senders are dropped
        // when the thread finishes.
        let server_thread = std::thread::spawn(move || {
            block_with_io(async |driver| server.run_single(&driver, s).await)
        });

        let client_thread = std::thread::spawn(move || {
            DefaultPool::run_with(async |driver| {
                let client = Client::new(
                    &driver,
                    ExistingConnection::new(PolledSocket::new(&driver, c).unwrap()),
                );
                let response = client
                    .call()
                    .start(
                        items::Example::Method1,
                        items::Method1Request {
                            foo: "abc".to_string(),
                            bar: "def".to_string(),
                        },
                    )
                    .await
                    .unwrap();

                assert_eq!(&response.foo, "abc123");
                assert_eq!(&response.bar, "def456");

                // An unknown method on a known service must fail with
                // `Unimplemented`.
                let status = client
                    .call()
                    .start_raw(items::Example::NAME, "unknown", Vec::new())
                    .await
                    .unwrap_err();

                assert_eq!(status.code, Code::Unimplemented as i32);

                client.shutdown().await;
            })
        });

        // Act as the service: answer the single well-formed request.
        block_on(async {
            let (_, req) = recv.next().await.unwrap();
            match req {
                items::Example::Method1(input, resp) => {
                    assert_eq!(&input.foo, "abc");
                    assert_eq!(&input.bar, "def");
                    resp.send(Ok(items::Method1Response {
                        foo: input.foo + "123",
                        bar: input.bar + "456",
                    }));
                }
                _ => panic!("{:?}", &req),
            }

            // The unknown-method call is presumably answered inside
            // `RpcReceiver` and never yielded here; once the server thread
            // exits and drops its senders, the stream ends.
            assert!(recv.next().await.is_none());
        });

        client_thread.join().unwrap();
        server_thread.join().unwrap().unwrap();
    }
}