mesh_rpc/
client.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! TTRPC client.
5
6use crate::message::MESSAGE_TYPE_REQUEST;
7use crate::message::MESSAGE_TYPE_RESPONSE;
8use crate::message::ReadResult;
9use crate::message::Request;
10use crate::message::Response;
11use crate::message::TooLongError;
12use crate::message::read_message;
13use crate::message::write_message;
14use crate::rpc::ProtocolError;
15use crate::rpc::status_from_err;
16use crate::service::Code;
17use crate::service::DecodedRpc;
18use crate::service::GenericRpc;
19use crate::service::ServiceRpc;
20use crate::service::Status;
21use anyhow::Context as _;
22use futures::AsyncRead;
23use futures::AsyncReadExt;
24use futures::AsyncWrite;
25use futures::FutureExt;
26use futures::StreamExt;
27use futures_concurrency::future::Race;
28use mesh::Deadline;
29use mesh::MeshPayload;
30use mesh::payload::EncodeAs;
31use mesh::payload::Timestamp;
32use pal_async::driver::Driver;
33use pal_async::socket::PolledSocket;
34use pal_async::task::Spawn;
35use pal_async::task::Task;
36use pal_async::timer::Instant;
37use pal_async::timer::PolledTimer;
38use parking_lot::Mutex;
39use std::collections::HashMap;
40use std::collections::VecDeque;
41use std::future::Future;
42use std::future::pending;
43use std::pin::pin;
44use std::task::ready;
45use std::time::Duration;
46use unix_socket::UnixStream;
47
/// A TTRPC client connection.
///
/// Requests are queued to a worker task that is spawned when the client is
/// built; the worker owns the connection to the server.
pub struct Client {
    /// Type-erased sender used to queue requests to the worker task.
    send: mesh::Sender<mesh::OwnedMessage>,
    /// The worker task, awaited by [`Client::shutdown`].
    task: Task<()>,
}
53
/// A request queued from a [`CallBuilder`] to the client worker task.
#[derive(MeshPayload)]
struct ClientRequest<T> {
    /// The name of the service being called.
    service: String,
    /// The deadline for the request, if any.
    deadline: Option<EncodeAs<Deadline, Timestamp>>,
    /// If true, the request stays queued when a connection attempt fails
    /// instead of being failed with the connection error.
    wait_ready: bool,
    /// The RPC to issue.
    rpc: T,
}
61
/// Dials a connection to a server.
///
/// The client worker calls [`Dial::dial`] each time it needs to establish a
/// connection.
pub trait Dial: 'static + Send {
    /// A bidirectional byte stream connection to the server.
    type Stream: 'static + Send + AsyncRead + AsyncWrite;

    /// Connects to the server.
    ///
    /// Returns the connected stream, or the I/O error that prevented the
    /// connection from being established.
    fn dial(&mut self) -> impl Future<Output = std::io::Result<Self::Stream>> + Send;
}
70
/// A [`Dial`] implementation that connects to a Unix domain socket.
///
/// The first field is the driver passed to the socket connect call; the
/// second is the path of the socket to connect to.
pub struct UnixDialier<T>(T, std::path::PathBuf);
73
74impl<T: Driver> UnixDialier<T> {
75    /// Returns a new dialier that connects to `path`.
76    pub fn new(driver: T, path: impl Into<std::path::PathBuf>) -> Self {
77        Self(driver, path.into())
78    }
79}
80
81impl<T: Driver> Dial for UnixDialier<T> {
82    type Stream = PolledSocket<UnixStream>;
83
84    fn dial(&mut self) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
85        PolledSocket::connect_unix(&self.0, &self.1)
86    }
87}
88
/// A [`Dial`] implementation that uses an existing connection.
///
/// Once the connection terminates, subsequent connections will fail.
///
/// The wrapped option holds the connection until the first dial takes it.
pub struct ExistingConnection<T>(Option<T>);
93
94impl<T: 'static + Send + AsyncRead + AsyncWrite> ExistingConnection<T> {
95    /// Returns a new dialier that uses `socket`, once.
96    pub fn new(socket: T) -> Self {
97        Self(Some(socket))
98    }
99}
100
101impl<T: 'static + Send + AsyncRead + AsyncWrite> Dial for ExistingConnection<T> {
102    type Stream = T;
103
104    async fn dial(&mut self) -> std::io::Result<Self::Stream> {
105        self.0.take().ok_or_else(|| {
106            std::io::Error::new(
107                std::io::ErrorKind::AddrNotAvailable,
108                "connection already used",
109            )
110        })
111    }
112}
113
/// A builder for [`Client`].
pub struct ClientBuilder {
    /// How long to wait after a connection failure before a new connection
    /// attempt may complete.
    retry_timeout: Duration,
}
118
119impl ClientBuilder {
120    /// Returns a new client builder.
121    pub fn new() -> Self {
122        Self {
123            // Use the gRPC default.
124            retry_timeout: Duration::from_secs(20),
125        }
126    }
127
128    /// Sets the timeout for a failed connection before attempting to reconnect.
129    pub fn retry_timeout(&mut self, timeout: Duration) -> &mut Self {
130        self.retry_timeout = timeout;
131        self
132    }
133
134    /// Builds a new client from a dialier.
135    pub fn build(&self, driver: &(impl Driver + Spawn), dialer: impl Dial) -> Client {
136        let (send, recv) = mesh::channel();
137        let worker = ClientWorker {
138            timer: PolledTimer::new(driver),
139            failure_timer: PolledTimer::new(driver),
140            dialer,
141            waiting: VecDeque::new(),
142            rpc_recv: Some(recv),
143            last_failure: None,
144            failure_timeout: self.retry_timeout,
145        };
146        let task = driver.spawn("ttrpc client", worker.run());
147        Client {
148            // Erase the type of the sender.
149            send: mesh::local_node::Port::from(send).into(),
150            task,
151        }
152    }
153}
154
155impl Client {
156    /// Creates a new client from a dialer.
157    pub fn new(driver: &(impl Driver + Spawn), dialer: impl Dial) -> Self {
158        ClientBuilder::new().build(driver, dialer)
159    }
160
161    /// Returns a [`CallBuilder`] to build RPCs.
162    pub fn call(&self) -> CallBuilder<'_> {
163        CallBuilder {
164            client: self,
165            deadline: None,
166            wait_ready: false,
167        }
168    }
169
170    /// Shuts down the client, waiting for the associated task to complete.
171    pub async fn shutdown(self) {
172        drop(self.send);
173        self.task.await;
174    }
175}
176
/// A builder for RPCs returned by [`Client::call`].
pub struct CallBuilder<'a> {
    /// The client the RPC will be sent through.
    client: &'a Client,
    /// The deadline to attach to the RPC, if any.
    deadline: Option<Deadline>,
    /// Whether the RPC should stay queued across failed connection attempts.
    wait_ready: bool,
}
183
/// A future representing an RPC call.
///
/// Resolves to the RPC's result, or to a [`Status`] with
/// `Code::Unavailable` if the underlying receive fails.
pub struct Call<T>(mesh::OneshotReceiver<Result<T, Status>>);
186
187impl<T> std::fmt::Debug for Call<T> {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        f.debug_tuple("CallFuture").field(&self.0).finish()
190    }
191}
192
193impl CallBuilder<'_> {
194    /// Sets the timeout for the RPC.
195    ///
196    /// Internally, this will immediately compute a deadline that is `timeout` from now.
197    pub fn timeout(&mut self, timeout: Option<Duration>) -> &mut Self {
198        self.deadline = timeout.and_then(|timeout| Deadline::now().checked_add(timeout));
199        self
200    }
201
202    /// Sets the deadline for the RPC.
203    pub fn deadline(&mut self, deadline: Option<Deadline>) -> &mut Self {
204        self.deadline = deadline;
205        self
206    }
207
208    /// Sets whether the client should wait for the server to be ready before
209    /// sending the RPC.
210    ///
211    /// If this is not set and a connection to the server cannot be established,
212    /// the RPC will fail. Otherwise, the RPC will keep waiting for a connection
213    /// until its deadline.
214    pub fn wait_ready(&mut self, wait_ready: bool) -> &mut Self {
215        self.wait_ready = wait_ready;
216        self
217    }
218
219    /// Starts the RPC.
220    ///
221    /// To get the RPC result, `await` the returned future.
222    #[must_use]
223    pub fn start<F, R, T, U>(&self, rpc: F, input: T) -> Call<U>
224    where
225        F: FnOnce(T, mesh::OneshotSender<Result<U, Status>>) -> R,
226        R: ServiceRpc,
227        U: 'static + MeshPayload + Send,
228    {
229        let (send, recv) = mesh::oneshot();
230
231        self.client
232            .send
233            .send(mesh::OwnedMessage::new(ClientRequest {
234                service: R::NAME.to_string(),
235                deadline: self.deadline.map(Into::into),
236                rpc: DecodedRpc::Rpc(rpc(input, send)),
237                wait_ready: self.wait_ready,
238            }));
239
240        Call(recv)
241    }
242
243    /// Used to send unknown requests for testing.
244    #[cfg(test)]
245    pub(crate) fn start_raw(&self, service: &str, method: &str, data: Vec<u8>) -> Call<Vec<u8>> {
246        let (send, recv) = mesh::oneshot();
247
248        self.client
249            .send
250            .send(mesh::OwnedMessage::new(ClientRequest {
251                service: service.to_string(),
252                deadline: self.deadline.map(Into::into),
253                rpc: GenericRpc {
254                    method: method.to_string(),
255                    data,
256                    port: send.into(),
257                },
258                wait_ready: self.wait_ready,
259            }));
260
261        Call(recv)
262    }
263}
264
265impl<T: 'static + Send> Future for Call<T> {
266    type Output = Result<T, Status>;
267
268    fn poll(
269        self: std::pin::Pin<&mut Self>,
270        cx: &mut std::task::Context<'_>,
271    ) -> std::task::Poll<Self::Output> {
272        match ready!(self.get_mut().0.poll_unpin(cx)) {
273            Ok(r) => r,
274            Err(err) => Err(status_from_err(Code::Unavailable, err)),
275        }
276        .into()
277    }
278}
279
/// The state of the client's worker task, which establishes connections and
/// exchanges requests and responses with the server.
struct ClientWorker<T> {
    /// Used to establish connections to the server.
    dialer: T,
    /// Timer used to wake up when the earliest deadline among the waiting
    /// requests expires.
    timer: PolledTimer,
    /// Timer used to delay a connection attempt after a failure.
    failure_timer: PolledTimer,
    /// Requests that have been received but not yet sent on a connection.
    waiting: VecDeque<ClientRequest<GenericRpc>>,
    /// The receiver for new requests. Set to `None` once it reports that no
    /// more requests will arrive.
    rpc_recv: Option<mesh::Receiver<ClientRequest<GenericRpc>>>,
    /// The time of the most recent connection failure, if any.
    last_failure: Option<Instant>,
    /// How long after `last_failure` to wait before dialing again.
    failure_timeout: Duration,
}
289
290impl<T: Dial> ClientWorker<T> {
291    async fn run(mut self) {
292        loop {
293            let r = match self.wait_connect().await {
294                None => break,
295                Some(Ok(stream)) => {
296                    tracing::debug!("connection established");
297                    self.run_connection(stream).await.inspect_err(|err| {
298                        tracing::debug!(
299                            error = err.as_ref() as &dyn std::error::Error,
300                            "connection failed"
301                        );
302                    })
303                }
304                Some(Err(err)) => {
305                    tracing::debug!(error = &err as &dyn std::error::Error, "failed to connect");
306                    Err(err.into())
307                }
308            };
309            if let Err(err) = r {
310                let status = status_from_err(Code::Unavailable, err);
311                self.waiting = self
312                    .waiting
313                    .drain(..)
314                    .filter_map(|req| {
315                        if req.wait_ready {
316                            return Some(req);
317                        }
318                        req.rpc.respond_status(status.clone());
319                        None
320                    })
321                    .collect();
322                self.last_failure = Some(Instant::now());
323            }
324        }
325        tracing::debug!("shutting down");
326    }
327
328    async fn wait_connect(&mut self) -> Option<std::io::Result<T::Stream>> {
329        let mut dial = pin!(self.dialer.dial());
330        while self.rpc_recv.is_some() || !self.waiting.is_empty() {
331            let oldest_deadline = self
332                .waiting
333                .iter()
334                .filter_map(|v| v.deadline.map(|d| *d))
335                .min();
336            let sleep = async {
337                if let Some(deadline) = oldest_deadline {
338                    self.timer.sleep(deadline - Deadline::now()).await;
339                } else {
340                    pending().await
341                }
342            };
343            let next = async {
344                if let Some(recv) = &mut self.rpc_recv {
345                    recv.next().await
346                } else {
347                    pending().await
348                }
349            };
350            let connect = async {
351                if !self.waiting.is_empty() {
352                    if let Some(last_failure) = self.last_failure {
353                        self.failure_timer
354                            .sleep_until(last_failure + self.failure_timeout)
355                            .await;
356                    }
357                    (&mut dial).await
358                } else {
359                    pending().await
360                }
361            };
362
363            enum Event<T> {
364                Request(Option<ClientRequest<GenericRpc>>),
365                Timeout(()),
366                Connect(std::io::Result<T>),
367            }
368
369            match (
370                connect.map(Event::Connect),
371                next.map(Event::Request),
372                sleep.map(Event::Timeout),
373            )
374                .race()
375                .await
376            {
377                Event::Request(req) => {
378                    if let Some(req) = req {
379                        self.waiting.push_back(req);
380                    } else {
381                        self.rpc_recv = None;
382                    }
383                }
384                Event::Timeout(()) => {
385                    let now = Deadline::now();
386                    self.waiting = self
387                        .waiting
388                        .drain(..)
389                        .filter_map(|req| {
390                            if let Some(deadline) = req.deadline {
391                                if *deadline <= now {
392                                    req.rpc.respond_status(Status {
393                                        code: Code::DeadlineExceeded as i32,
394                                        message: "deadline exceeded".to_string(),
395                                        details: Vec::new(),
396                                    });
397                                    return None;
398                                }
399                            }
400                            Some(req)
401                        })
402                        .collect();
403                }
404                Event::Connect(r) => {
405                    return Some(r);
406                }
407            }
408        }
409        None
410    }
411
412    async fn run_connection(&mut self, stream: T::Stream) -> anyhow::Result<()> {
413        let (mut reader, mut writer) = AsyncReadExt::split(stream);
414        let responses = Mutex::new(HashMap::<u32, mesh::OneshotSender<mesh::OwnedMessage>>::new());
415        let recv_task = async {
416            while let Some(message) = read_message(&mut reader)
417                .await
418                .context("fatal connection error")?
419            {
420                let stream_id = message.stream_id;
421                tracing::debug!(stream_id, "response");
422
423                let response_send = responses.lock().remove(&stream_id);
424
425                let Some(response_send) = response_send else {
426                    tracing::error!(stream_id, "response for unknown stream");
427                    continue;
428                };
429
430                let result = handle_message(message);
431
432                response_send.send(mesh::OwnedMessage::new(result));
433            }
434            Ok(())
435        };
436
437        let send_task = async {
438            let mut next_stream_id = 1;
439            loop {
440                let request = if let Some(req) = self.waiting.pop_front() {
441                    Some(req)
442                } else if let Some(recv) = &mut self.rpc_recv {
443                    recv.next().await
444                } else {
445                    None
446                };
447                let Some(request) = request else {
448                    break;
449                };
450                responses
451                    .lock()
452                    .insert(next_stream_id, request.rpc.port.into());
453
454                let payload = mesh::payload::encode(Request {
455                    service: request.service,
456                    method: request.rpc.method,
457                    payload: request.rpc.data,
458                    timeout_nano: request.deadline.map_or(0, |deadline| {
459                        (*deadline - Deadline::now()).as_nanos() as u64
460                    }),
461                    metadata: vec![],
462                });
463
464                write_message(&mut writer, next_stream_id, MESSAGE_TYPE_REQUEST, &payload)
465                    .await
466                    .context("failed to write to connection")?;
467
468                next_stream_id = next_stream_id.wrapping_add(2);
469            }
470            Ok(())
471        };
472
473        (send_task, recv_task).race().await
474    }
475}
476
477fn handle_message(message: ReadResult) -> Result<Vec<u8>, Status> {
478    match message.message_type {
479        MESSAGE_TYPE_RESPONSE => {
480            let payload = message.payload.map_err(|err @ TooLongError { .. }| {
481                status_from_err(Code::ResourceExhausted, err)
482            })?;
483
484            let response = mesh::payload::decode(&payload)
485                .map_err(|err| status_from_err(Code::Unknown, err))?;
486
487            match response {
488                Response::Payload(payload) => Ok(payload),
489                Response::Status(status) => Err(status),
490            }
491        }
492        ty => Err(status_from_err(
493            Code::Internal,
494            ProtocolError::InvalidMessageType(ty),
495        )),
496    }
497}
498
#[cfg(test)]
mod tests {
    use super::Client;
    use super::Dial;
    use crate::service::Code;
    use mesh::CancelContext;
    use mesh::Deadline;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::socket::PolledSocket;
    use std::future::pending;
    use std::net::TcpStream;
    use std::time::Duration;
    use test_with_tracing::test;

    /// A dialer whose connection attempt never completes.
    struct NeverDial;

    impl Dial for NeverDial {
        type Stream = PolledSocket<TcpStream>;

        async fn dial(&mut self) -> std::io::Result<Self::Stream> {
            pending().await
        }
    }

    /// A dialer whose connection attempt always fails with `NotConnected`.
    struct FailDial;

    impl Dial for FailDial {
        type Stream = PolledSocket<TcpStream>;

        async fn dial(&mut self) -> std::io::Result<Self::Stream> {
            Err(std::io::ErrorKind::NotConnected.into())
        }
    }

    /// A request that does not wait for readiness fails with `Unavailable`
    /// when the connection attempt fails.
    #[async_test]
    async fn test_failed_connect(driver: DefaultDriver) {
        let client = Client::new(&driver, FailDial);
        let err = client
            .call()
            .start_raw("service", "method", vec![])
            .await
            .unwrap_err();

        assert_eq!(err.code, Code::Unavailable as i32);
        assert!(err.message.contains("not connected"));
    }

    /// A request with a distant deadline stays pending while the connection
    /// attempt is outstanding.
    #[async_test]
    async fn test_delayed_connect_never(driver: DefaultDriver) {
        let client = Client::new(&driver, NeverDial);

        // The request should not fail within the cancel context timeout.
        CancelContext::new()
            .with_timeout(Duration::from_millis(250))
            .until_cancelled(
                client
                    .call()
                    .deadline(Some(Deadline::now() + Duration::from_secs(60)))
                    .start_raw("service", "method", vec![]),
            )
            .await
            .unwrap_err();
    }

    /// A request whose deadline passes while the connection attempt is
    /// outstanding fails with `DeadlineExceeded`.
    #[async_test]
    async fn test_delayed_connect(driver: DefaultDriver) {
        let client = Client::new(&driver, NeverDial);
        let err = client
            .call()
            .deadline(Some(Deadline::now() + Duration::from_millis(200)))
            .start_raw("service", "method", vec![])
            .await
            .unwrap_err();

        assert_eq!(err.code, Code::DeadlineExceeded as i32);
    }
}