1use crate::message::MESSAGE_TYPE_REQUEST;
7use crate::message::MESSAGE_TYPE_RESPONSE;
8use crate::message::ReadResult;
9use crate::message::Request;
10use crate::message::Response;
11use crate::message::TooLongError;
12use crate::message::read_message;
13use crate::message::write_message;
14use crate::rpc::ProtocolError;
15use crate::rpc::status_from_err;
16use crate::service::Code;
17use crate::service::DecodedRpc;
18use crate::service::GenericRpc;
19use crate::service::ServiceRpc;
20use crate::service::Status;
21use anyhow::Context as _;
22use futures::AsyncRead;
23use futures::AsyncReadExt;
24use futures::AsyncWrite;
25use futures::FutureExt;
26use futures::StreamExt;
27use futures_concurrency::future::Race;
28use mesh::Deadline;
29use mesh::MeshPayload;
30use mesh::payload::EncodeAs;
31use mesh::payload::Timestamp;
32use pal_async::driver::Driver;
33use pal_async::socket::PolledSocket;
34use pal_async::task::Spawn;
35use pal_async::task::Task;
36use pal_async::timer::Instant;
37use pal_async::timer::PolledTimer;
38use parking_lot::Mutex;
39use std::collections::HashMap;
40use std::collections::VecDeque;
41use std::future::Future;
42use std::future::pending;
43use std::pin::pin;
44use std::task::ready;
45use std::time::Duration;
46use unix_socket::UnixStream;
47
48pub struct Client {
50 send: mesh::Sender<mesh::OwnedMessage>,
51 task: Task<()>,
52}
53
54#[derive(MeshPayload)]
55struct ClientRequest<T> {
56 service: String,
57 deadline: Option<EncodeAs<Deadline, Timestamp>>,
58 wait_ready: bool,
59 rpc: T,
60}
61
62pub trait Dial: 'static + Send {
64 type Stream: 'static + Send + AsyncRead + AsyncWrite;
66
67 fn dial(&mut self) -> impl Future<Output = std::io::Result<Self::Stream>> + Send;
69}
70
/// A [`Dial`] implementation that connects to a Unix socket.
///
/// Holds the driver used to create the polled socket (field `0`) and the
/// socket path to connect to (field `1`).
pub struct UnixDialier<T>(T, std::path::PathBuf);
73
74impl<T: Driver> UnixDialier<T> {
75 pub fn new(driver: T, path: impl Into<std::path::PathBuf>) -> Self {
77 Self(driver, path.into())
78 }
79}
80
81impl<T: Driver> Dial for UnixDialier<T> {
82 type Stream = PolledSocket<UnixStream>;
83
84 fn dial(&mut self) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
85 PolledSocket::connect_unix(&self.0, &self.1)
86 }
87}
88
/// A [`Dial`] implementation wrapping an already-established connection.
///
/// The connection is handed out by the first dial; the slot is then empty and
/// any later dial fails.
pub struct ExistingConnection<T>(Option<T>);
93
94impl<T: 'static + Send + AsyncRead + AsyncWrite> ExistingConnection<T> {
95 pub fn new(socket: T) -> Self {
97 Self(Some(socket))
98 }
99}
100
101impl<T: 'static + Send + AsyncRead + AsyncWrite> Dial for ExistingConnection<T> {
102 type Stream = T;
103
104 async fn dial(&mut self) -> std::io::Result<Self::Stream> {
105 self.0.take().ok_or_else(|| {
106 std::io::Error::new(
107 std::io::ErrorKind::AddrNotAvailable,
108 "connection already used",
109 )
110 })
111 }
112}
113
/// A builder for [`Client`].
pub struct ClientBuilder {
    /// How long the worker waits after a failure before dialing again.
    retry_timeout: Duration,
}
118
119impl ClientBuilder {
120 pub fn new() -> Self {
122 Self {
123 retry_timeout: Duration::from_secs(20),
125 }
126 }
127
128 pub fn retry_timeout(&mut self, timeout: Duration) -> &mut Self {
130 self.retry_timeout = timeout;
131 self
132 }
133
134 pub fn build(&self, driver: &(impl Driver + Spawn), dialer: impl Dial) -> Client {
136 let (send, recv) = mesh::channel();
137 let worker = ClientWorker {
138 timer: PolledTimer::new(driver),
139 failure_timer: PolledTimer::new(driver),
140 dialer,
141 waiting: VecDeque::new(),
142 rpc_recv: Some(recv),
143 last_failure: None,
144 failure_timeout: self.retry_timeout,
145 };
146 let task = driver.spawn("ttrpc client", worker.run());
147 Client {
148 send: mesh::local_node::Port::from(send).into(),
150 task,
151 }
152 }
153}
154
155impl Client {
156 pub fn new(driver: &(impl Driver + Spawn), dialer: impl Dial) -> Self {
158 ClientBuilder::new().build(driver, dialer)
159 }
160
161 pub fn call(&self) -> CallBuilder<'_> {
163 CallBuilder {
164 client: self,
165 deadline: None,
166 wait_ready: false,
167 }
168 }
169
170 pub async fn shutdown(self) {
172 drop(self.send);
173 self.task.await;
174 }
175}
176
177pub struct CallBuilder<'a> {
179 client: &'a Client,
180 deadline: Option<Deadline>,
181 wait_ready: bool,
182}
183
184pub struct Call<T>(mesh::OneshotReceiver<Result<T, Status>>);
186
187impl<T> std::fmt::Debug for Call<T> {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 f.debug_tuple("CallFuture").field(&self.0).finish()
190 }
191}
192
193impl CallBuilder<'_> {
194 pub fn timeout(&mut self, timeout: Option<Duration>) -> &mut Self {
198 self.deadline = timeout.and_then(|timeout| Deadline::now().checked_add(timeout));
199 self
200 }
201
202 pub fn deadline(&mut self, deadline: Option<Deadline>) -> &mut Self {
204 self.deadline = deadline;
205 self
206 }
207
208 pub fn wait_ready(&mut self, wait_ready: bool) -> &mut Self {
215 self.wait_ready = wait_ready;
216 self
217 }
218
219 #[must_use]
223 pub fn start<F, R, T, U>(&self, rpc: F, input: T) -> Call<U>
224 where
225 F: FnOnce(T, mesh::OneshotSender<Result<U, Status>>) -> R,
226 R: ServiceRpc,
227 U: 'static + MeshPayload + Send,
228 {
229 let (send, recv) = mesh::oneshot();
230
231 self.client
232 .send
233 .send(mesh::OwnedMessage::new(ClientRequest {
234 service: R::NAME.to_string(),
235 deadline: self.deadline.map(Into::into),
236 rpc: DecodedRpc::Rpc(rpc(input, send)),
237 wait_ready: self.wait_ready,
238 }));
239
240 Call(recv)
241 }
242
243 #[cfg(test)]
245 pub(crate) fn start_raw(&self, service: &str, method: &str, data: Vec<u8>) -> Call<Vec<u8>> {
246 let (send, recv) = mesh::oneshot();
247
248 self.client
249 .send
250 .send(mesh::OwnedMessage::new(ClientRequest {
251 service: service.to_string(),
252 deadline: self.deadline.map(Into::into),
253 rpc: GenericRpc {
254 method: method.to_string(),
255 data,
256 port: send.into(),
257 },
258 wait_ready: self.wait_ready,
259 }));
260
261 Call(recv)
262 }
263}
264
265impl<T: 'static + Send> Future for Call<T> {
266 type Output = Result<T, Status>;
267
268 fn poll(
269 self: std::pin::Pin<&mut Self>,
270 cx: &mut std::task::Context<'_>,
271 ) -> std::task::Poll<Self::Output> {
272 match ready!(self.get_mut().0.poll_unpin(cx)) {
273 Ok(r) => r,
274 Err(err) => Err(status_from_err(Code::Unavailable, err)),
275 }
276 .into()
277 }
278}
279
280struct ClientWorker<T> {
281 dialer: T,
282 timer: PolledTimer,
283 failure_timer: PolledTimer,
284 waiting: VecDeque<ClientRequest<GenericRpc>>,
285 rpc_recv: Option<mesh::Receiver<ClientRequest<GenericRpc>>>,
286 last_failure: Option<Instant>,
287 failure_timeout: Duration,
288}
289
290impl<T: Dial> ClientWorker<T> {
291 async fn run(mut self) {
292 loop {
293 let r = match self.wait_connect().await {
294 None => break,
295 Some(Ok(stream)) => {
296 tracing::debug!("connection established");
297 self.run_connection(stream).await.inspect_err(|err| {
298 tracing::debug!(
299 error = err.as_ref() as &dyn std::error::Error,
300 "connection failed"
301 );
302 })
303 }
304 Some(Err(err)) => {
305 tracing::debug!(error = &err as &dyn std::error::Error, "failed to connect");
306 Err(err.into())
307 }
308 };
309 if let Err(err) = r {
310 let status = status_from_err(Code::Unavailable, err);
311 self.waiting = self
312 .waiting
313 .drain(..)
314 .filter_map(|req| {
315 if req.wait_ready {
316 return Some(req);
317 }
318 req.rpc.respond_status(status.clone());
319 None
320 })
321 .collect();
322 self.last_failure = Some(Instant::now());
323 }
324 }
325 tracing::debug!("shutting down");
326 }
327
328 async fn wait_connect(&mut self) -> Option<std::io::Result<T::Stream>> {
329 let mut dial = pin!(self.dialer.dial());
330 while self.rpc_recv.is_some() || !self.waiting.is_empty() {
331 let oldest_deadline = self
332 .waiting
333 .iter()
334 .filter_map(|v| v.deadline.map(|d| *d))
335 .min();
336 let sleep = async {
337 if let Some(deadline) = oldest_deadline {
338 self.timer.sleep(deadline - Deadline::now()).await;
339 } else {
340 pending().await
341 }
342 };
343 let next = async {
344 if let Some(recv) = &mut self.rpc_recv {
345 recv.next().await
346 } else {
347 pending().await
348 }
349 };
350 let connect = async {
351 if !self.waiting.is_empty() {
352 if let Some(last_failure) = self.last_failure {
353 self.failure_timer
354 .sleep_until(last_failure + self.failure_timeout)
355 .await;
356 }
357 (&mut dial).await
358 } else {
359 pending().await
360 }
361 };
362
363 enum Event<T> {
364 Request(Option<ClientRequest<GenericRpc>>),
365 Timeout(()),
366 Connect(std::io::Result<T>),
367 }
368
369 match (
370 connect.map(Event::Connect),
371 next.map(Event::Request),
372 sleep.map(Event::Timeout),
373 )
374 .race()
375 .await
376 {
377 Event::Request(req) => {
378 if let Some(req) = req {
379 self.waiting.push_back(req);
380 } else {
381 self.rpc_recv = None;
382 }
383 }
384 Event::Timeout(()) => {
385 let now = Deadline::now();
386 self.waiting = self
387 .waiting
388 .drain(..)
389 .filter_map(|req| {
390 if let Some(deadline) = req.deadline {
391 if *deadline <= now {
392 req.rpc.respond_status(Status {
393 code: Code::DeadlineExceeded as i32,
394 message: "deadline exceeded".to_string(),
395 details: Vec::new(),
396 });
397 return None;
398 }
399 }
400 Some(req)
401 })
402 .collect();
403 }
404 Event::Connect(r) => {
405 return Some(r);
406 }
407 }
408 }
409 None
410 }
411
412 async fn run_connection(&mut self, stream: T::Stream) -> anyhow::Result<()> {
413 let (mut reader, mut writer) = AsyncReadExt::split(stream);
414 let responses = Mutex::new(HashMap::<u32, mesh::OneshotSender<mesh::OwnedMessage>>::new());
415 let recv_task = async {
416 while let Some(message) = read_message(&mut reader)
417 .await
418 .context("fatal connection error")?
419 {
420 let stream_id = message.stream_id;
421 tracing::debug!(stream_id, "response");
422
423 let response_send = responses.lock().remove(&stream_id);
424
425 let Some(response_send) = response_send else {
426 tracing::error!(stream_id, "response for unknown stream");
427 continue;
428 };
429
430 let result = handle_message(message);
431
432 response_send.send(mesh::OwnedMessage::new(result));
433 }
434 Ok(())
435 };
436
437 let send_task = async {
438 let mut next_stream_id = 1;
439 loop {
440 let request = if let Some(req) = self.waiting.pop_front() {
441 Some(req)
442 } else if let Some(recv) = &mut self.rpc_recv {
443 recv.next().await
444 } else {
445 None
446 };
447 let Some(request) = request else {
448 break;
449 };
450 responses
451 .lock()
452 .insert(next_stream_id, request.rpc.port.into());
453
454 let payload = mesh::payload::encode(Request {
455 service: request.service,
456 method: request.rpc.method,
457 payload: request.rpc.data,
458 timeout_nano: request.deadline.map_or(0, |deadline| {
459 (*deadline - Deadline::now()).as_nanos() as u64
460 }),
461 metadata: vec![],
462 });
463
464 write_message(&mut writer, next_stream_id, MESSAGE_TYPE_REQUEST, &payload)
465 .await
466 .context("failed to write to connection")?;
467
468 next_stream_id = next_stream_id.wrapping_add(2);
469 }
470 Ok(())
471 };
472
473 (send_task, recv_task).race().await
474 }
475}
476
477fn handle_message(message: ReadResult) -> Result<Vec<u8>, Status> {
478 match message.message_type {
479 MESSAGE_TYPE_RESPONSE => {
480 let payload = message.payload.map_err(|err @ TooLongError { .. }| {
481 status_from_err(Code::ResourceExhausted, err)
482 })?;
483
484 let response = mesh::payload::decode(&payload)
485 .map_err(|err| status_from_err(Code::Unknown, err))?;
486
487 match response {
488 Response::Payload(payload) => Ok(payload),
489 Response::Status(status) => Err(status),
490 }
491 }
492 ty => Err(status_from_err(
493 Code::Internal,
494 ProtocolError::InvalidMessageType(ty),
495 )),
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::Client;
    use super::Dial;
    use crate::service::Code;
    use mesh::CancelContext;
    use mesh::Deadline;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::socket::PolledSocket;
    use std::future::pending;
    use std::net::TcpStream;
    use std::time::Duration;
    use test_with_tracing::test;

    /// A dialer whose dial attempt never completes.
    struct NeverDial;

    impl Dial for NeverDial {
        type Stream = PolledSocket<TcpStream>;

        async fn dial(&mut self) -> std::io::Result<Self::Stream> {
            pending().await
        }
    }

    /// A dialer whose dial attempt always fails with `NotConnected`.
    struct FailDial;

    impl Dial for FailDial {
        type Stream = PolledSocket<TcpStream>;

        async fn dial(&mut self) -> std::io::Result<Self::Stream> {
            Err(std::io::ErrorKind::NotConnected.into())
        }
    }

    /// A failed dial fails the call with `Unavailable`, carrying the I/O
    /// error text.
    #[async_test]
    async fn test_failed_connect(driver: DefaultDriver) {
        let client = Client::new(&driver, FailDial);
        let call = client.call().start_raw("service", "method", vec![]);
        let status = call.await.unwrap_err();

        assert_eq!(status.code, Code::Unavailable as i32);
        assert!(status.message.contains("not connected"));
    }

    /// With a dial that never completes and a far-off deadline, the call is
    /// still pending when the 250ms cancel timeout fires.
    #[async_test]
    async fn test_delayed_connect_never(driver: DefaultDriver) {
        let client = Client::new(&driver, NeverDial);

        CancelContext::new()
            .with_timeout(Duration::from_millis(250))
            .until_cancelled(
                client
                    .call()
                    .deadline(Some(Deadline::now() + Duration::from_secs(60)))
                    .start_raw("service", "method", vec![]),
            )
            .await
            .unwrap_err();
    }

    /// With a dial that never completes, the call fails with
    /// `DeadlineExceeded` once its deadline passes.
    #[async_test]
    async fn test_delayed_connect(driver: DefaultDriver) {
        let client = Client::new(&driver, NeverDial);
        let deadline = Deadline::now() + Duration::from_millis(200);
        let call = client
            .call()
            .deadline(Some(deadline))
            .start_raw("service", "method", vec![]);
        let status = call.await.unwrap_err();

        assert_eq!(status.code, Code::DeadlineExceeded as i32);
    }
}