mesh_remote/
point_to_point.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Point-to-point mesh implementation.
5
6use futures::AsyncBufReadExt;
7use futures::AsyncRead;
8use futures::AsyncReadExt;
9use futures::AsyncWrite;
10use futures::AsyncWriteExt;
11use futures::StreamExt;
12use futures::TryFutureExt;
13use futures::future::try_join;
14use futures::io::BufReader;
15use futures_concurrency::future::Race;
16use mesh_channel::cancel::Cancel;
17use mesh_channel::cancel::CancelContext;
18use mesh_channel::cancel::CancelReason;
19use mesh_node::common::Address;
20use mesh_node::common::NodeId;
21use mesh_node::common::PortId;
22use mesh_node::common::Uuid;
23use mesh_node::local_node::Connect;
24use mesh_node::local_node::LocalNode;
25use mesh_node::local_node::OutgoingEvent;
26use mesh_node::local_node::Port;
27use mesh_node::local_node::SendEvent;
28use pal_async::task::Spawn;
29use pal_async::task::Task;
30use std::io;
31use std::pin::pin;
32use thiserror::Error;
33use tracing::Instrument;
34use zerocopy::FromBytes;
35use zerocopy::FromZeros;
36use zerocopy::Immutable;
37use zerocopy::IntoBytes;
38use zerocopy::KnownLayout;
39
/// A mesh that consists of exactly two nodes, communicating over an arbitrary
/// bidirectional byte stream.
///
/// This byte stream could be a stream socket, a Windows named pipe, or a serial
/// port, for example.
///
/// There is no support for OS resources (handles or file descriptors) in this
/// mesh implementation. Any attempt to send OS resources will fail the
/// underlying channel.
///
/// The connection is serviced by a background task spawned in
/// [`PointToPointMesh::new`]; call [`PointToPointMesh::shutdown`] to cancel it
/// and wait for it to exit.
#[must_use]
pub struct PointToPointMesh {
    /// Task spawned by [`PointToPointMesh::new`] that runs the connection.
    task: Task<()>,
    /// Cancels the connection task; fired by [`PointToPointMesh::shutdown`].
    cancel: Cancel,
}
54
55impl PointToPointMesh {
56    /// Makes a new mesh over the connection `conn`, with initial port `port`.
57    ///
58    /// ```rust
59    /// # use mesh_remote::PointToPointMesh;
60    /// # use mesh_channel::channel;
61    /// # use unix_socket::UnixStream;
62    /// # use pal_async::socket::PolledSocket;
63    /// # pal_async::DefaultPool::run_with(async |driver| {
64    /// let (left, right) = UnixStream::pair().unwrap();
65    /// let (a, ax) = channel::<u32>();
66    /// let (bx, mut b) = channel::<u32>();
67    /// let left = PointToPointMesh::new(&driver, PolledSocket::new(&driver, left).unwrap(), ax.into());
68    /// let right = PointToPointMesh::new(&driver, PolledSocket::new(&driver, right).unwrap(), bx.into());
69    /// a.send(5);
70    /// assert_eq!(b.recv().await.unwrap(), 5);
71    /// # })
72    /// ```
73    pub fn new(
74        spawn: impl Spawn,
75        conn: impl 'static + AsyncRead + AsyncWrite + Send + Unpin,
76        port: Port,
77    ) -> Self {
78        let local_address = Address {
79            node: NodeId::new(),
80            port: PortId::new(),
81        };
82        let (mut ctx, cancel) = CancelContext::new().with_cancel();
83        let task = spawn.spawn(
84            format!("mesh-point-to-point-{:?}", local_address.node),
85            async move {
86                if let Err(err) = handle_comms(&mut ctx, Box::new(conn), local_address, port).await
87                {
88                    tracing::error!(error = &err as &dyn std::error::Error, "io failure");
89                }
90            }
91            .instrument(tracing::info_span!("mesh-point-to-point", node = ?local_address.node)),
92        );
93
94        Self { task, cancel }
95    }
96
97    /// Shuts down the mesh. Any pending messages are dropped.
98    pub async fn shutdown(mut self) {
99        self.cancel.cancel();
100        self.task.await;
101    }
102}
103
104trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + Send {}
105impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadWrite for T {}
106
107#[derive(Debug, Error)]
108enum TaskError {
109    #[error("cancelled")]
110    Cancelled(#[from] CancelReason),
111    #[error("failed to change addresses")]
112    Exchange(#[source] io::Error),
113    #[error("failed to send data")]
114    Send(#[source] io::Error),
115    #[error("failed to receive data")]
116    Recv(#[source] io::Error),
117}
118
/// Drives the point-to-point connection over `conn` until it ends.
///
/// 1. Splits `conn` into read/write halves and exchanges `local_address` with
///    the peer to learn `remote_address`.
/// 2. Creates a `LocalNode` with ID `local_address.node` whose connector is
///    `NullConnector`, which fails every connection request it receives.
/// 3. Registers the peer as a remote node backed by `PointToPointConnection`,
///    which queues serialized outgoing events on a channel read by
///    `send_loop`.
/// 4. Adds a local port with ID `local_address.port` targeting
///    `remote_address` and bridges it to the caller's `port`.
/// 5. Races `recv_loop` against `send_loop` until one finishes or `ctx` is
///    cancelled.
///
/// # Errors
///
/// Returns `TaskError::Cancelled` if `ctx` is cancelled during the address
/// exchange, or `TaskError::Exchange` if the exchange itself fails. Once the
/// loops have started, their outcome is reported to the remote handle instead
/// (`remote.disconnect()` on success, `remote.fail(err)` on error) and this
/// function returns `Ok(())`.
async fn handle_comms(
    ctx: &mut CancelContext,
    conn: Box<dyn AsyncReadWrite>,
    local_address: Address,
    port: Port,
) -> Result<(), TaskError> {
    let (mut read, mut write) = conn.split();
    let node = LocalNode::with_id(local_address.node, Box::new(NullConnector));

    tracing::debug!("exchanging addresses");
    // The first `?` propagates cancellation (via `From<CancelReason>`); the
    // second propagates the I/O result of the exchange itself.
    let remote_address = ctx
        .until_cancelled(exchange_addresses(local_address, &mut write, &mut read))
        .await?
        .map_err(TaskError::Exchange)?;

    tracing::debug!(?local_address, ?remote_address, "connected to remote node");

    let remote = node.add_remote(remote_address.node);
    // Outgoing events for the peer are serialized by `PointToPointConnection`
    // and handed to `send_loop` through this channel.
    let (send_event, recv_event) = mesh_channel::channel();
    remote.connect(PointToPointConnection(send_event));
    let init_port = node.add_port(local_address.port, remote_address);
    init_port.bridge(port);

    let recv_loop = recv_loop(&remote_address.node, read, &node).map_err(TaskError::Recv);
    let send_loop = send_loop(recv_event, write).map_err(TaskError::Send);

    // Run until either send or receive finishes. If sending is done, then the
    // remote node has been disconnected from `LocalNode`, so no more events
    // need to be received. If receiving is done, then the remote node has
    // disconnected its pipe, so it will not be accepting any more events.
    let mut fut = pin!((recv_loop, send_loop).race());

    let r = match ctx.until_cancelled(fut.as_mut()).await {
        Ok(r) => r,
        Err(_) => {
            // Cancelled: keep polling the loops while waiting on
            // `wait_for_ports(false)` and then failing all nodes. Failing the
            // remote presumably closes the event channel, ending `send_loop`
            // and with it the race. NOTE(review): confirm against
            // `LocalNode::fail_all_nodes`.
            let shutdown = async {
                node.wait_for_ports(false).await;
                node.fail_all_nodes();
                Ok(())
            };
            try_join(shutdown, fut).await.map(|((), ())| ())
        }
    };
    match r {
        Ok(()) => remote.disconnect(),
        Err(err) => remote.fail(err),
    }
    Ok(())
}
168
169async fn exchange_addresses(
170    local_address: Address,
171    write: &mut (impl AsyncWrite + Unpin),
172    read: &mut (impl AsyncRead + Unpin),
173) -> io::Result<Address> {
174    #[repr(C)]
175    #[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
176    struct Message {
177        magic: [u8; 4],
178        node: [u8; 16],
179        port: [u8; 16],
180    }
181
182    const MAGIC: [u8; 4] = *b"mesh";
183    let local_msg = Message {
184        magic: MAGIC,
185        node: (local_address.node.0).0,
186        port: (local_address.port.0).0,
187    };
188
189    let mut remote_msg = Message::new_zeroed();
190    try_join(
191        write.write_all(local_msg.as_bytes()),
192        read.read_exact(remote_msg.as_mut_bytes()),
193    )
194    .await?;
195
196    if remote_msg.magic != MAGIC {
197        return Err(io::Error::new(
198            io::ErrorKind::InvalidData,
199            "invalid address header",
200        ));
201    }
202
203    Ok(Address::new(
204        NodeId(Uuid(remote_msg.node)),
205        PortId(Uuid(remote_msg.port)),
206    ))
207}
208
209async fn recv_loop(
210    remote_id: &NodeId,
211    read: impl AsyncRead + Unpin,
212    node: &LocalNode,
213) -> io::Result<()> {
214    let mut read = BufReader::new(read);
215    loop {
216        let mut b = [0; 8];
217        if read.fill_buf().await?.is_empty() {
218            break;
219        }
220        read.read_exact(&mut b).await?;
221        let len = u64::from_le_bytes(b) as usize;
222        let buf = read.buffer();
223        if buf.len() >= len {
224            // Parse the event directly from the buffer.
225            node.event(remote_id, buf, &mut Vec::new());
226            read.consume_unpin(len);
227        } else {
228            // Read the whole event into a new buffer.
229            let mut b = vec![0; len];
230            read.read_exact(&mut b).await?;
231            node.event(remote_id, &b, &mut Vec::new());
232        }
233    }
234    tracing::debug!("recv loop done");
235    Ok(())
236}
237
238async fn send_loop(
239    mut recv_event: mesh_channel::Receiver<Vec<u8>>,
240    mut write: impl AsyncWrite + Unpin,
241) -> io::Result<()> {
242    while let Some(event) = recv_event.next().await {
243        write.write_all(&(event.len() as u64).to_le_bytes()).await?;
244        write.write_all(&event).await?;
245    }
246    tracing::debug!("send loop done");
247    Ok(())
248}
249
250#[derive(Debug)]
251struct PointToPointConnection(mesh_channel::Sender<Vec<u8>>);
252
253impl SendEvent for PointToPointConnection {
254    fn event(&self, event: OutgoingEvent<'_>) {
255        let len = event.len();
256        let mut v = Vec::with_capacity(len);
257        let mut resources = Vec::new();
258        event.write_to(&mut v, &mut resources);
259        if !resources.is_empty() {
260            // Still send the message so that the receiving side gets an error
261            // when decoding. Otherwise, the only other option at this point is
262            // to fail the whole connection, which is probably not what you
263            // want.
264            tracing::warn!("cannot send OS resources across a point-to-point connection");
265        }
266        self.0.send(v);
267    }
268}
269
270#[derive(Debug)]
271struct NullConnector;
272
273impl Connect for NullConnector {
274    fn connect(&self, _node_id: NodeId, handle: mesh_node::local_node::RemoteNodeHandle) {
275        handle.fail(NoMesh);
276    }
277}
278
279#[derive(Debug, Error)]
280#[error("no extra connections allowed in point-to-point mesh")]
281struct NoMesh;
282
#[cfg(test)]
mod tests {
    use super::PointToPointMesh;
    use mesh_channel::channel;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::socket::PolledSocket;
    use test_with_tracing::test;
    use unix_socket::UnixStream;

    /// Joins two meshes with a Unix socket pair, sends a value across, and
    /// shuts both meshes down.
    #[async_test]
    async fn test_point_to_point(driver: DefaultDriver) {
        let (sock_a, sock_b) = UnixStream::pair().unwrap();
        let sock_a = PolledSocket::new(&driver, sock_a).unwrap();
        let sock_b = PolledSocket::new(&driver, sock_b).unwrap();

        // The receiving end of the first channel lives in mesh A, and the
        // sending end of the second lives in mesh B. A value sent on `tx` must
        // therefore cross the socket to reach `rx`.
        let (tx, tx_peer) = channel::<u32>();
        let (rx_peer, mut rx) = channel::<u32>();
        let mesh_a = PointToPointMesh::new(&driver, sock_a, tx_peer.into());
        let mesh_b = PointToPointMesh::new(&driver, sock_b, rx_peer.into());

        tx.send(5);
        assert_eq!(rx.recv().await.unwrap(), 5);

        mesh_a.shutdown().await;
        mesh_b.shutdown().await;
    }
}