1use crate::message::MESSAGE_TYPE_REQUEST;
7use crate::message::MESSAGE_TYPE_RESPONSE;
8use crate::message::ReadResult;
9use crate::message::Request;
10use crate::message::Response;
11use crate::message::TooLongError;
12use crate::message::read_message;
13use crate::message::write_message;
14use crate::rpc::ProtocolError;
15use crate::rpc::status_from_err;
16use crate::service::Code;
17use crate::service::DecodedRpc;
18use crate::service::GenericRpc;
19use crate::service::ServiceRpc;
20use crate::service::ServiceRpcError;
21use crate::service::Status;
22use futures::FutureExt;
23use futures::Stream;
24use futures::StreamExt;
25use futures::stream::FusedStream;
26use futures_concurrency::future::TryJoin;
27use futures_concurrency::stream::Merge;
28use mesh::CancelContext;
29use mesh::MeshPayload;
30use mesh::local_node::Port;
31use pal_async::driver::Driver;
32use pal_async::socket::AsSockRef;
33use pal_async::socket::Listener;
34use pal_async::socket::PolledSocket;
35use std::collections::HashMap;
36use std::io::Read;
37use std::io::Write;
38use std::pin::Pin;
39use std::task::ready;
40use unicycle::FuturesUnordered;
41
/// A ttrpc server that routes incoming requests to services registered with
/// [`Server::add_service`], looked up by service name.
#[derive(Debug, Default)]
pub struct Server {
    /// Registered services keyed by service name. Each entry is the channel on
    /// which that service's RPCs are delivered, paired with the cancellation
    /// context for the call.
    services: HashMap<&'static str, mesh::Sender<(CancelContext, GenericRpc)>>,
}
47
/// Receiving end for the RPCs of service `T`, as returned by
/// [`Server::add_service`].
///
/// Used as a [`Stream`] of `(CancelContext, T)`. RPCs that could not be decoded
/// (unknown method or invalid input) are answered with an error status while
/// polling and are never yielded.
#[derive(MeshPayload)]
#[mesh(bound = "T: ServiceRpc")]
pub struct RpcReceiver<T>(mesh::Receiver<(CancelContext, DecodedRpc<T>)>);
54
55impl<T: ServiceRpc> RpcReceiver<T> {
56 pub fn disconnected() -> Self {
59 let (_send, recv) = mesh::channel();
60 Self(recv)
61 }
62}
63
64impl<T: ServiceRpc> Stream for RpcReceiver<T> {
65 type Item = (CancelContext, T);
66
67 fn poll_next(
68 self: Pin<&mut Self>,
69 cx: &mut std::task::Context<'_>,
70 ) -> std::task::Poll<Option<Self::Item>> {
71 let this = self.get_mut();
72 while let Some((ctx, rpc)) = ready!(Pin::new(&mut this.0).poll_next(cx)) {
73 match rpc {
74 DecodedRpc::Rpc(rpc) => return Some((ctx, rpc)).into(),
75 DecodedRpc::Err { rpc, err } => {
76 rpc.fail(err);
77 }
78 }
79 }
80 None.into()
81 }
82}
83
/// Termination mirrors the underlying mesh receiver.
impl<T: ServiceRpc> FusedStream for RpcReceiver<T> {
    fn is_terminated(&self) -> bool {
        self.0.is_terminated()
    }
}
89
90impl GenericRpc {
91 fn fail(self, err: ServiceRpcError) {
92 let status = match err {
93 ServiceRpcError::UnknownMethod => Status {
94 code: Code::Unimplemented.into(),
95 message: format!("unknown method {}", self.method),
96 details: Vec::new(),
97 },
98 ServiceRpcError::InvalidInput(error) => status_from_err(Code::InvalidArgument, error),
99 };
100 self.respond_status(status);
101 }
102}
103
104impl Server {
105 pub fn new() -> Self {
107 Self {
108 services: Default::default(),
109 }
110 }
111
112 pub fn add_service<T: ServiceRpc>(&mut self) -> RpcReceiver<T> {
114 let (send, recv) = mesh::channel();
115 self.services.insert(T::NAME, Port::from(send).into());
116 RpcReceiver(recv)
117 }
118
119 pub async fn run(
122 &mut self,
123 driver: &(impl Driver + ?Sized),
124 listener: impl Listener,
125 cancel: mesh::OneshotReceiver<()>,
126 ) -> anyhow::Result<()> {
127 let mut listener = PolledSocket::new(driver, listener)?;
128 let mut tasks = FuturesUnordered::new();
129 let mut cancel = cancel.fuse();
130 loop {
131 let conn = futures::select! { r = listener.accept().fuse() => r,
133 _ = tasks.next() => continue,
134 _ = cancel => break,
135 };
136 if let Ok(conn) = conn.and_then(|(conn, _)| PolledSocket::new(driver, conn)) {
137 tasks.push(async {
138 let _ = self.serve(conn).await.map_err(|err| {
139 tracing::error!(
140 error = err.as_ref() as &dyn std::error::Error,
141 "connection error"
142 )
143 });
144 });
145 }
146 }
147 Ok(())
148 }
149
150 pub async fn run_single(
152 &mut self,
153 driver: &(impl Driver + ?Sized),
154 conn: impl AsSockRef + Read + Write,
155 ) -> anyhow::Result<()> {
156 self.serve(PolledSocket::new(driver, conn)?).await
157 }
158
159 async fn serve(
160 &self,
161 stream: PolledSocket<impl AsSockRef + Read + Write>,
162 ) -> anyhow::Result<()> {
163 let (mut reader, mut writer) = stream.split();
164 let (stream_send, mut stream_recv) = mesh::channel();
165 let ctx = CancelContext::new();
166 let recv_task = async {
167 let stream_send = stream_send; while let Some(message) = read_message(&mut reader).await? {
169 let (send, recv) = mesh::oneshot::<Result<Vec<u8>, Status>>();
170 stream_send.send((message.stream_id, recv));
171
172 let handle = handle_message(message).and_then(|request| {
173 let service = self.services.get(request.service.as_str()).ok_or_else(|| {
174 status_from_err(
175 Code::Unimplemented,
176 anyhow::anyhow!("unknown service {}", request.service),
177 )
178 })?;
179
180 let ctx = if request.timeout_nano == 0 {
181 ctx.clone()
182 } else {
183 ctx.with_timeout(std::time::Duration::from_nanos(request.timeout_nano))
184 };
185
186 Ok(move |port| {
187 service.send((
188 ctx,
189 GenericRpc {
190 method: request.method,
191 data: request.payload,
192 port,
193 },
194 ));
195 })
196 });
197
198 match handle {
199 Ok(handle) => handle(send.into()),
200 Err(err) => send.send(Err(err)),
201 }
202 }
203 Ok(())
204 };
205 let send_task = async {
206 let mut responses = FuturesUnordered::new();
207 enum Event<T> {
208 Request((u32, mesh::OneshotReceiver<Result<Vec<u8>, Status>>)),
209 Response(T),
210 }
211 while let Some(event) = (
212 (&mut stream_recv).map(Event::Request),
213 (&mut responses).map(Event::Response),
214 )
215 .merge()
216 .next()
217 .await
218 {
219 match event {
220 Event::Request((stream_id, recv)) => {
221 let recv = recv.map(move |r| {
222 (
223 stream_id,
224 match r {
225 Ok(Ok(payload)) => Response::Payload(payload),
226 Ok(Err(status)) => Response::Status(status),
227 Err(err) => {
228 Response::Status(status_from_err(Code::Internal, err))
229 }
230 },
231 )
232 });
233 responses.push(recv);
234 }
235 Event::Response((stream_id, payload)) => {
236 write_message(
237 &mut writer,
238 stream_id,
239 MESSAGE_TYPE_RESPONSE,
240 &mesh::payload::encode(payload),
241 )
242 .await?;
243 }
244 }
245 }
246 anyhow::Result::<_>::Ok(())
247 };
248 (recv_task, send_task).try_join().await?;
249 Ok(())
250 }
251}
252
253fn handle_message(message: ReadResult) -> Result<Request, Status> {
254 if message.stream_id % 2 != 1 {
255 return Err(status_from_err(
256 Code::InvalidArgument,
257 ProtocolError::EvenStreamId,
258 ));
259 }
260
261 match message.message_type {
262 MESSAGE_TYPE_REQUEST => {
263 let payload = message.payload.map_err(|err @ TooLongError { .. }| {
264 status_from_err(Code::ResourceExhausted, err)
265 })?;
266 let request = mesh::payload::decode::<Request>(&payload)
267 .map_err(|err| status_from_err(Code::InvalidArgument, err))?;
268
269 tracing::debug!(
270 stream_id = message.stream_id,
271 service = %request.service,
272 method = %request.method,
273 timeout = request.timeout_nano / 1000 / 1000,
274 "message",
275 );
276
277 Ok(request)
278 }
279 ty => Err(status_from_err(
280 Code::InvalidArgument,
281 ProtocolError::InvalidMessageType(ty),
282 )),
283 }
284}
285
#[cfg(feature = "grpc")]
mod grpc {
    use super::Server;
    use crate::rpc::status_from_err;
    use crate::service::Code;
    use crate::service::GenericRpc;
    use crate::service::Status;
    use anyhow::Context as _;
    use futures::AsyncRead as _;
    use futures::AsyncWrite;
    use futures::FutureExt;
    use futures::StreamExt;
    use futures_concurrency::stream::Merge;
    use h2::RecvStream;
    use h2::server::SendResponse;
    use http::HeaderMap;
    use http::HeaderValue;
    use mesh::CancelContext;
    use pal_async::driver::Driver;
    use pal_async::socket::AsSockRef;
    use pal_async::socket::Listener;
    use pal_async::socket::PolledSocket;
    use prost::bytes::Bytes;
    use std::io::Read;
    use std::io::Write;
    use std::pin::Pin;
    use std::task::ready;
    use thiserror::Error;
    use unicycle::FuturesUnordered;

    /// Failure while handling a single gRPC request stream.
    #[derive(Debug, Error)]
    enum RequestError {
        #[error("http error")]
        Http(#[from] http::Error),
        #[error("http2 error")]
        H2(#[from] h2::Error),
        /// Request rejected before response headers were sent. `serve_grpc`
        /// answers it with a plain HTTP response carrying this status code, so
        /// its display text is not expected to be observed.
        #[error("unreachable")]
        Status(http::StatusCode),
        /// The body was not a single well-formed, uncompressed,
        /// length-prefixed gRPC message.
        #[error("invalid message header")]
        InvalidHeader,
    }

    impl From<http::StatusCode> for RequestError {
        fn from(status: http::StatusCode) -> Self {
            RequestError::Status(status)
        }
    }

    impl Server {
        /// Accepts gRPC (HTTP/2) connections on `listener` and serves them
        /// concurrently until `cancel` completes.
        ///
        /// Errors on an individual connection are logged and do not stop the
        /// server. Failures to accept a connection, or to register it with
        /// `driver`, are also logged, and the accept loop continues.
        ///
        /// # Errors
        ///
        /// Returns an error if `listener` cannot be registered with `driver`.
        pub async fn run_grpc(
            &mut self,
            driver: &(impl Driver + ?Sized),
            listener: impl Listener,
            cancel: mesh::OneshotReceiver<()>,
        ) -> anyhow::Result<()> {
            let mut listener = PolledSocket::new(driver, listener)?;
            let mut tasks = FuturesUnordered::new();
            let mut cancel = cancel.fuse();
            loop {
                let conn = futures::select! {
                    r = listener.accept().fuse() => r,
                    // Drive in-flight connections while waiting for the next accept.
                    _ = tasks.next() => continue,
                    _ = cancel => break,
                };
                match conn.and_then(|(conn, _)| PolledSocket::new(driver, conn)) {
                    Ok(conn) => tasks.push(async {
                        let _ = self.serve_grpc(conn).await.map_err(|err| {
                            tracing::error!(
                                error = err.as_ref() as &dyn std::error::Error,
                                "connection error"
                            )
                        });
                    }),
                    // Previously dropped silently; keep accepting, but make
                    // the failure visible.
                    Err(err) => {
                        tracing::warn!(
                            error = &err as &dyn std::error::Error,
                            "failed to accept connection"
                        );
                    }
                }
            }
            Ok(())
        }

        /// Serves one HTTP/2 connection, handling each request stream as an
        /// independent unary gRPC call.
        ///
        /// # Errors
        ///
        /// Returns an error if the HTTP/2 handshake fails, accepting a stream
        /// fails, or the connection fails while closing. Errors confined to a
        /// single request stream are logged, and the connection keeps serving
        /// its other streams.
        async fn serve_grpc(
            &self,
            stream: PolledSocket<impl AsSockRef + Read + Write>,
        ) -> anyhow::Result<()> {
            /// Adapts a [`PolledSocket`] to the tokio I/O traits taken by
            /// `h2::server::handshake`.
            struct Wrap<T>(T);

            impl<T: AsSockRef + Read> tokio::io::AsyncRead for Wrap<PolledSocket<T>> {
                fn poll_read(
                    self: Pin<&mut Self>,
                    cx: &mut std::task::Context<'_>,
                    buf: &mut tokio::io::ReadBuf<'_>,
                ) -> std::task::Poll<std::io::Result<()>> {
                    let n = ready!(
                        Pin::new(&mut self.get_mut().0).poll_read(cx, buf.initialize_unfilled())
                    )?;
                    buf.advance(n);
                    std::task::Poll::Ready(Ok(()))
                }
            }

            impl<T: AsSockRef + Write> tokio::io::AsyncWrite for Wrap<PolledSocket<T>> {
                fn poll_write(
                    self: Pin<&mut Self>,
                    cx: &mut std::task::Context<'_>,
                    buf: &[u8],
                ) -> std::task::Poll<Result<usize, std::io::Error>> {
                    Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
                }

                fn poll_flush(
                    self: Pin<&mut Self>,
                    cx: &mut std::task::Context<'_>,
                ) -> std::task::Poll<Result<(), std::io::Error>> {
                    Pin::new(&mut self.get_mut().0).poll_flush(cx)
                }

                fn poll_shutdown(
                    self: Pin<&mut Self>,
                    cx: &mut std::task::Context<'_>,
                ) -> std::task::Poll<Result<(), std::io::Error>> {
                    Pin::new(&mut self.get_mut().0).poll_close(cx)
                }
            }

            let mut conn = h2::server::handshake(Wrap(stream))
                .await
                .context("failed http2 handshake")?;

            let mut tasks = FuturesUnordered::new();

            loop {
                enum Event<A> {
                    Accept(A),
                    Task(Result<(), RequestError>),
                }

                // The merged stream ends only once the connection has stopped
                // yielding streams and all request tasks have finished.
                let r = (
                    (&mut conn).map(Event::Accept),
                    (&mut tasks).map(Event::Task),
                )
                    .merge()
                    .next()
                    .await;

                let (req, mut resp) = match r {
                    None => break,
                    Some(Event::Task(Ok(()))) => continue,
                    Some(Event::Task(Err(err))) => {
                        // A failure on one request stream (for example, the
                        // client resetting it) must not tear down the whole
                        // connection and every other in-flight RPC on it.
                        tracing::debug!(
                            error = &err as &dyn std::error::Error,
                            "grpc request failed"
                        );
                        continue;
                    }
                    Some(Event::Accept(r)) => r.context("failed http2 stream accept")?,
                };

                let task = async move {
                    match self.handle_request(req, &mut resp).await {
                        // Rejected before any response was sent: reply with a
                        // bare HTTP status and end the stream.
                        Err(RequestError::Status(status)) => {
                            tracing::debug!(status = status.as_u16(), "request error");
                            resp.send_response(
                                http::Response::builder().status(status).body(())?,
                                true,
                            )?;
                            Ok(())
                        }
                        r => r,
                    }
                };
                tasks.push(task);
            }

            std::future::poll_fn(|cx| conn.poll_closed(cx)).await?;
            Ok(())
        }

        /// Validates an HTTP/2 request as a unary gRPC call, invokes it, and
        /// writes the response message plus `grpc-status` trailers.
        ///
        /// Before response headers are sent, validation failures are returned
        /// as [`RequestError::Status`]:
        /// - a non-POST method yields 405,
        /// - an unsupported content type yields 415,
        /// - a malformed `grpc-timeout` yields 400,
        /// - an unparseable path yields 404.
        async fn handle_request(
            &self,
            req: http::Request<RecvStream>,
            resp: &mut SendResponse<Bytes>,
        ) -> Result<(), RequestError> {
            tracing::debug!(url = %req.uri(), "rpc request");

            if req.method() != http::Method::POST {
                Err(http::StatusCode::METHOD_NOT_ALLOWED)?
            }
            let content_type = req.headers().get("content-type");
            match content_type.map(|v| v.as_bytes()) {
                Some(b"application/grpc" | b"application/grpc+proto") => {}
                _ => Err(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)?,
            }

            let response =
                http::Response::builder().header("content-type", "application/grpc+proto");

            // `grpc-timeout` is an integer followed by a single unit character.
            let ctx = if let Some(timeout) = req.headers().get("grpc-timeout") {
                let timeout = timeout
                    .to_str()
                    .map_err(|_| http::StatusCode::BAD_REQUEST)?;
                let mul = match timeout
                    .bytes()
                    .last()
                    .ok_or(http::StatusCode::BAD_REQUEST)?
                {
                    b'H' => std::time::Duration::from_secs(60 * 60),
                    b'M' => std::time::Duration::from_secs(60),
                    b'S' => std::time::Duration::from_secs(1),
                    b'm' => std::time::Duration::from_millis(1),
                    b'u' => std::time::Duration::from_micros(1),
                    b'n' => std::time::Duration::from_nanos(1),
                    _ => Err(http::StatusCode::BAD_REQUEST)?,
                };
                // The unit byte is ASCII, so this slice is on a char boundary.
                let timeout = timeout[..timeout.len() - 1]
                    .parse::<u32>()
                    .map_err(|_| http::StatusCode::BAD_REQUEST)?;
                CancelContext::new().with_timeout(mul * timeout)
            } else {
                CancelContext::new()
            };

            // The path is expected to be `/{service}/{method}`.
            let (head, body) = req.into_parts();
            let path = head.uri.path();
            let path = path.strip_prefix('/').ok_or(http::StatusCode::NOT_FOUND)?;
            let (service, method) = path.split_once('/').ok_or(http::StatusCode::NOT_FOUND)?;

            let mut resp = resp.send_response(response.body(())?, false)?;

            let result = self.invoke_rpc(service, method, body, ctx).await?;

            let mut trailers = HeaderMap::new();
            match result {
                Ok(data) => {
                    tracing::debug!(service, method, "rpc success");

                    // Length-prefixed message: uncompressed flag + big-endian length.
                    let mut buf = Vec::with_capacity(5 + data.len());
                    buf.push(0);
                    buf.extend(&(data.len() as u32).to_be_bytes());
                    buf.extend(data);
                    resp.send_data(buf.into(), false)?;
                    trailers.insert("grpc-status", const { HeaderValue::from_static("0") });
                }
                Err(status) => {
                    tracing::debug!(service, method, ?status, "rpc error");

                    trailers.insert("grpc-status", status.code.into());
                    // Percent-encoding yields ASCII only, so it is a valid header value.
                    trailers.insert(
                        "grpc-message",
                        urlencoding::encode(&status.message)
                            .into_owned()
                            .try_into()
                            .unwrap(),
                    );
                    // Base64 output is ASCII only, so it is a valid header value.
                    trailers.insert(
                        "grpc-status-details-bin",
                        base64::Engine::encode(
                            &base64::engine::general_purpose::STANDARD,
                            prost::Message::encode_to_vec(&status),
                        )
                        .try_into()
                        .unwrap(),
                    );
                }
            }
            resp.send_trailers(trailers)?;
            Ok(())
        }

        /// Reads the single length-prefixed request message from `body`,
        /// dispatches it to `service`, and waits for the result.
        ///
        /// Returns `Ok(Err(status))` for RPC-level failures: an unknown service,
        /// an error status from the service, or a failed response channel.
        /// Returns `Err(_)` for transport or framing failures.
        async fn invoke_rpc(
            &self,
            service: &str,
            method: &str,
            mut body: RecvStream,
            ctx: CancelContext,
        ) -> Result<Result<Vec<u8>, Status>, RequestError> {
            let Some(service) = self.services.get(service) else {
                return Ok(Err(Status {
                    code: Code::Unimplemented.into(),
                    message: format!("unknown service {}", service),
                    details: Vec::new(),
                }));
            };

            let mut buf = Vec::new();

            // Message prefix: a 1-byte compressed flag, then a 4-byte
            // big-endian message length.
            while buf.len() < 5 {
                let data = body.data().await.ok_or(RequestError::InvalidHeader)??;
                buf.extend(&data);
                // Return flow-control credit for consumed data. This can fail,
                // so propagate the h2 error instead of panicking.
                body.flow_control().release_capacity(data.len())?;
            }
            let hdr = buf.get(0..5).ok_or(RequestError::InvalidHeader)?;
            // Compressed messages are not supported.
            if hdr[0] != 0 {
                return Err(RequestError::InvalidHeader);
            }
            let len = u32::from_be_bytes(hdr[1..5].try_into().unwrap()) as usize;

            buf.drain(..5);
            // NOTE(review): `len` is client-controlled and uncapped, so a client
            // can make the server buffer up to 4 GiB here. Consider enforcing a
            // maximum message size.
            while buf.len() < len {
                let data = body.data().await.ok_or(RequestError::InvalidHeader)??;
                buf.extend(&data);
                body.flow_control().release_capacity(data.len())?;
            }
            // Hand the service exactly the declared message. Bytes that arrived
            // after it in the same DATA frames are not part of this message and
            // would otherwise corrupt decoding.
            buf.truncate(len);

            let (send, recv) = mesh::oneshot();

            let rpc = GenericRpc {
                method: method.to_owned(),
                data: buf,
                port: send.into(),
            };

            service.send((ctx, rpc));

            // A failed response channel (presumably the RPC was dropped
            // without a reply) is reported as an internal error.
            Ok(recv
                .await
                .unwrap_or_else(|err| Err(status_from_err(Code::Internal, err))))
        }
    }
}
611
#[cfg(test)]
mod tests {
    use crate::Client;
    use crate::Server;
    use crate::client::ExistingConnection;
    use crate::service::Code;
    use crate::service::ServiceRpc;
    use futures::StreamExt;
    use futures::executor::block_on;
    use pal_async::DefaultPool;
    use pal_async::local::block_with_io;
    use pal_async::socket::PolledSocket;
    use test_with_tracing::test;

    // Service and message types for the `ttrpc.example.v1` test package,
    // included from OUT_DIR (presumably produced by the build script).
    mod items {
        include!(concat!(env!("OUT_DIR"), "/ttrpc.example.v1.rs"));
    }

    /// End-to-end ttrpc round trip over a connected Unix socket pair.
    ///
    /// Three threads cooperate:
    /// - the server thread serves the single connection,
    /// - the client thread issues the calls,
    /// - the test thread acts as the `Example` service implementation.
    #[test]
    fn client_server() {
        let (c, s) = unix_socket::UnixStream::pair().unwrap();
        let mut server = Server::new();
        let mut recv = server.add_service::<items::Example>();
        // `server` is moved into this thread and dropped once `run_single` returns.
        let server_thread = std::thread::spawn(move || {
            block_with_io(async |driver| server.run_single(&driver, s).await)
        });

        let client_thread = std::thread::spawn(move || {
            DefaultPool::run_with(async |driver| {
                let client = Client::new(
                    &driver,
                    ExistingConnection::new(PolledSocket::new(&driver, c).unwrap()),
                );
                // Known method: answered by the service loop on the test thread below.
                let response = client
                    .call()
                    .start(
                        items::Example::Method1,
                        items::Method1Request {
                            foo: "abc".to_string(),
                            bar: "def".to_string(),
                        },
                    )
                    .await
                    .unwrap();

                assert_eq!(&response.foo, "abc123");
                assert_eq!(&response.bar, "def456");

                // Unknown method on a known service: must fail with Unimplemented
                // and never be yielded to the service loop as a request.
                let status = client
                    .call()
                    .start_raw(items::Example::NAME, "unknown", Vec::new())
                    .await
                    .unwrap_err();

                assert_eq!(status.code, Code::Unimplemented as i32);

                client.shutdown().await;
            })
        });

        // Act as the service: answer exactly one Method1 call.
        block_on(async {
            let (_, req) = recv.next().await.unwrap();
            match req {
                items::Example::Method1(input, resp) => {
                    assert_eq!(&input.foo, "abc");
                    assert_eq!(&input.bar, "def");
                    resp.send(Ok(items::Method1Response {
                        foo: input.foo + "123",
                        bar: input.bar + "456",
                    }));
                }
                _ => panic!("{:?}", &req),
            }

            // No further RPCs should be yielded; the stream is expected to end
            // once the server side (and its service sender) goes away.
            assert!(recv.next().await.is_none());
        });

        client_thread.join().unwrap();
        server_thread.join().unwrap().unwrap();
    }
}