mesh_remote/
point_to_point.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Point-to-point mesh implementation.
5
6use futures::AsyncBufReadExt;
7use futures::AsyncRead;
8use futures::AsyncReadExt;
9use futures::AsyncWrite;
10use futures::AsyncWriteExt;
11use futures::StreamExt;
12use futures::TryFutureExt;
13use futures::future::try_join;
14use futures::io::BufReader;
15use futures_concurrency::future::Race;
16use mesh_channel::cancel::Cancel;
17use mesh_channel::cancel::CancelContext;
18use mesh_channel::cancel::CancelReason;
19use mesh_node::common::Address;
20use mesh_node::common::NodeId;
21use mesh_node::common::PortId;
22use mesh_node::common::Uuid;
23use mesh_node::local_node::Connect;
24use mesh_node::local_node::LocalNode;
25use mesh_node::local_node::OutgoingEvent;
26use mesh_node::local_node::Port;
27use mesh_node::local_node::SendEvent;
28use pal_async::task::Spawn;
29use pal_async::task::Task;
30use std::io;
31use std::pin::pin;
32use thiserror::Error;
33use tracing::Instrument;
34use zerocopy::FromBytes;
35use zerocopy::FromZeros;
36use zerocopy::Immutable;
37use zerocopy::IntoBytes;
38use zerocopy::KnownLayout;
39
40/// A mesh that consists of exactly two nodes, communicating over an arbitrary
41/// bidirectional byte stream.
42///
43/// This byte stream could be a stream socket, a Windows named pipe, or a serial
44/// port, for example.
45///
46/// There is no support for OS resources (handles or file descriptors) in this
47/// mesh implementation. If a message containing OS resources is sent, the
48/// resources are dropped and the message is lost. Because this breaks the
49/// port's event sequence, all subsequent messages on the same port are also
50/// lost. Avoid sending OS resources over a point-to-point mesh.
51#[must_use]
52pub struct PointToPointMesh {
53    task: Task<()>,
54    cancel: Cancel,
55}
56
57impl PointToPointMesh {
58    /// Makes a new mesh over the connection `conn`, with initial port `port`.
59    ///
60    /// ```rust
61    /// # use mesh_remote::PointToPointMesh;
62    /// # use mesh_channel::channel;
63    /// # use unix_socket::UnixStream;
64    /// # use pal_async::socket::PolledSocket;
65    /// # pal_async::DefaultPool::run_with(async |driver| {
66    /// let (left, right) = UnixStream::pair().unwrap();
67    /// let (a, ax) = channel::<u32>();
68    /// let (bx, mut b) = channel::<u32>();
69    /// let left = PointToPointMesh::new(&driver, PolledSocket::new(&driver, left).unwrap(), ax.into());
70    /// let right = PointToPointMesh::new(&driver, PolledSocket::new(&driver, right).unwrap(), bx.into());
71    /// a.send(5);
72    /// assert_eq!(b.recv().await.unwrap(), 5);
73    /// # })
74    /// ```
75    pub fn new(
76        spawn: impl Spawn,
77        conn: impl 'static + AsyncRead + AsyncWrite + Send + Unpin,
78        port: Port,
79    ) -> Self {
80        let local_address = Address {
81            node: NodeId::new(),
82            port: PortId::new(),
83        };
84        let (mut ctx, cancel) = CancelContext::new().with_cancel();
85        let task = spawn.spawn(
86            format!("mesh-point-to-point-{:?}", local_address.node),
87            async move {
88                if let Err(err) = handle_comms(&mut ctx, Box::new(conn), local_address, port).await
89                {
90                    tracing::error!(error = &err as &dyn std::error::Error, "io failure");
91                }
92            }
93            .instrument(tracing::info_span!("mesh-point-to-point", node = ?local_address.node)),
94        );
95
96        Self { task, cancel }
97    }
98
99    /// Shuts down the mesh. Any pending messages are dropped.
100    pub async fn shutdown(mut self) {
101        self.cancel.cancel();
102        self.task.await;
103    }
104}
105
106trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + Send {}
107impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadWrite for T {}
108
109#[derive(Debug, Error)]
110enum TaskError {
111    #[error("cancelled")]
112    Cancelled(#[from] CancelReason),
113    #[error("failed to change addresses")]
114    Exchange(#[source] io::Error),
115    #[error("failed to send data")]
116    Send(#[source] io::Error),
117    #[error("failed to receive data")]
118    Recv(#[source] io::Error),
119}
120
121async fn handle_comms(
122    ctx: &mut CancelContext,
123    conn: Box<dyn AsyncReadWrite>,
124    local_address: Address,
125    port: Port,
126) -> Result<(), TaskError> {
127    let (mut read, mut write) = conn.split();
128    let node = LocalNode::with_id(local_address.node, Box::new(NullConnector));
129
130    tracing::debug!("exchanging addresses");
131    let remote_address = ctx
132        .until_cancelled(exchange_addresses(local_address, &mut write, &mut read))
133        .await?
134        .map_err(TaskError::Exchange)?;
135
136    tracing::debug!(?local_address, ?remote_address, "connected to remote node");
137
138    let remote = node.add_remote(remote_address.node);
139    let (send_event, recv_event) = mesh_channel::channel();
140    remote.connect(PointToPointConnection(send_event));
141    let init_port = node.add_port(local_address.port, remote_address);
142    init_port.bridge(port);
143
144    let recv_loop = recv_loop(&remote_address.node, read, &node).map_err(TaskError::Recv);
145    let send_loop = send_loop(recv_event, write).map_err(TaskError::Send);
146
147    // Run until either send or receive finishes. If sending is done, then the
148    // remote node has been disconnected from `LocalNode`, so no more events
149    // need to be received. If receiving is done, then the remote node has
150    // disconnected its pipe, so it will not be accepting any more events.
151    let mut fut = pin!((recv_loop, send_loop).race());
152
153    let r = match ctx.until_cancelled(fut.as_mut()).await {
154        Ok(r) => r,
155        Err(_) => {
156            let shutdown = async {
157                node.wait_for_ports(false).await;
158                node.fail_all_nodes();
159                Ok(())
160            };
161            try_join(shutdown, fut).await.map(|((), ())| ())
162        }
163    };
164    match r {
165        Ok(()) => remote.disconnect(),
166        Err(err) => remote.fail(err),
167    }
168    Ok(())
169}
170
171async fn exchange_addresses(
172    local_address: Address,
173    write: &mut (impl AsyncWrite + Unpin),
174    read: &mut (impl AsyncRead + Unpin),
175) -> io::Result<Address> {
176    #[repr(C)]
177    #[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
178    struct Message {
179        magic: [u8; 4],
180        node: [u8; 16],
181        port: [u8; 16],
182    }
183
184    const MAGIC: [u8; 4] = *b"mesh";
185    let local_msg = Message {
186        magic: MAGIC,
187        node: (local_address.node.0).0,
188        port: (local_address.port.0).0,
189    };
190
191    let mut remote_msg = Message::new_zeroed();
192    try_join(
193        write.write_all(local_msg.as_bytes()),
194        read.read_exact(remote_msg.as_mut_bytes()),
195    )
196    .await?;
197
198    if remote_msg.magic != MAGIC {
199        return Err(io::Error::new(
200            io::ErrorKind::InvalidData,
201            "invalid address header",
202        ));
203    }
204
205    Ok(Address::new(
206        NodeId(Uuid(remote_msg.node)),
207        PortId(Uuid(remote_msg.port)),
208    ))
209}
210
211async fn recv_loop(
212    remote_id: &NodeId,
213    read: impl AsyncRead + Unpin,
214    node: &LocalNode,
215) -> io::Result<()> {
216    let mut read = BufReader::new(read);
217    loop {
218        let mut b = [0; 8];
219        if read.fill_buf().await?.is_empty() {
220            break;
221        }
222        read.read_exact(&mut b).await?;
223        let len = u64::from_le_bytes(b) as usize;
224        let buf = read.buffer();
225        if buf.len() >= len {
226            // Parse the event directly from the buffer.
227            node.event(remote_id, buf, &mut Vec::new());
228            read.consume_unpin(len);
229        } else {
230            // Read the whole event into a new buffer.
231            let mut b = vec![0; len];
232            read.read_exact(&mut b).await?;
233            node.event(remote_id, &b, &mut Vec::new());
234        }
235    }
236    tracing::debug!("recv loop done");
237    Ok(())
238}
239
240async fn send_loop(
241    mut recv_event: mesh_channel::Receiver<Vec<u8>>,
242    mut write: impl AsyncWrite + Unpin,
243) -> io::Result<()> {
244    while let Some(event) = recv_event.next().await {
245        write.write_all(&(event.len() as u64).to_le_bytes()).await?;
246        write.write_all(&event).await?;
247    }
248    tracing::debug!("send loop done");
249    Ok(())
250}
251
252#[derive(Debug)]
253struct PointToPointConnection(mesh_channel::Sender<Vec<u8>>);
254
255impl SendEvent for PointToPointConnection {
256    fn event(&self, event: OutgoingEvent<'_>) {
257        let len = event.len();
258        let mut v = Vec::with_capacity(len);
259        let mut resources = Vec::new();
260        event.write_to(&mut v, &mut resources);
261        if !resources.is_empty() {
262            // Still send the message so that the receiving side gets an error
263            // when decoding. Otherwise, the only other option at this point is
264            // to fail the whole connection, which is probably not what you
265            // want.
266            tracing::warn!("cannot send OS resources across a point-to-point connection");
267        }
268        self.0.send(v);
269    }
270}
271
/// Connector that refuses every outbound connection.
///
/// A point-to-point mesh only ever talks to its single peer, so connections to
/// any other node are rejected.
#[derive(Debug)]
struct NullConnector;
274
275impl Connect for NullConnector {
276    fn connect(&self, _node_id: NodeId, handle: mesh_node::local_node::RemoteNodeHandle) {
277        handle.fail(NoMesh);
278    }
279}
280
281#[derive(Debug, Error)]
282#[error("no extra connections allowed in point-to-point mesh")]
283struct NoMesh;
284
#[cfg(test)]
mod tests {
    use super::PointToPointMesh;
    use mesh_channel::channel;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::socket::PolledSocket;
    use test_with_tracing::test;
    use unix_socket::UnixStream;

    /// Connects two meshes over a socket pair and checks that a value sent on
    /// one side arrives on the other.
    #[async_test]
    async fn test_point_to_point(driver: DefaultDriver) {
        let (left_sock, right_sock) = UnixStream::pair().unwrap();
        let left_sock = PolledSocket::new(&driver, left_sock).unwrap();
        let right_sock = PolledSocket::new(&driver, right_sock).unwrap();
        let (tx, tx_port) = channel::<u32>();
        let (rx_port, mut rx) = channel::<u32>();
        let left = PointToPointMesh::new(&driver, left_sock, tx_port.into());
        let right = PointToPointMesh::new(&driver, right_sock, rx_port.into());
        tx.send(5);
        assert_eq!(rx.recv().await.unwrap(), 5);
        left.shutdown().await;
        right.shutdown().await;
    }

    /// Sending OS resources over a point-to-point mesh silently breaks the
    /// affected port. The message carrying the resource is lost, and so is
    /// every later message on the same port, because the port's sequence
    /// counter falls out of sync.
    #[async_test]
    async fn test_point_to_point_os_resource_dropped(driver: DefaultDriver) {
        let (left_sock, right_sock) = UnixStream::pair().unwrap();
        let left_sock = PolledSocket::new(&driver, left_sock).unwrap();
        let right_sock = PolledSocket::new(&driver, right_sock).unwrap();

        // A UnixStream is an OS resource (OwnedFd on Unix, OwnedSocket on
        // Windows), so sending one exercises the OS-resource path.
        let (tx, tx_port) = channel::<UnixStream>();
        let (rx_port, mut rx) = channel::<UnixStream>();
        let left = PointToPointMesh::new(&driver, left_sock, tx_port.into());
        let right = PointToPointMesh::new(&driver, right_sock, rx_port.into());

        // This message carries an OS resource, so it is silently dropped.
        let (resource, _peer) = UnixStream::pair().unwrap();
        tx.send(resource);

        // Close the channel. The close event is never delivered because the
        // port is waiting on the missing sequence number, so the receiver only
        // observes an error once the meshes shut down.
        drop(tx);

        left.shutdown().await;
        right.shutdown().await;

        // After shutdown the receiver must report an error, not a message.
        let outcome = rx.recv().await;
        assert!(outcome.is_err(), "expected error, got {:?}", outcome);
    }
}