1use futures::AsyncBufReadExt;
7use futures::AsyncRead;
8use futures::AsyncReadExt;
9use futures::AsyncWrite;
10use futures::AsyncWriteExt;
11use futures::StreamExt;
12use futures::TryFutureExt;
13use futures::future::try_join;
14use futures::io::BufReader;
15use futures_concurrency::future::Race;
16use mesh_channel::cancel::Cancel;
17use mesh_channel::cancel::CancelContext;
18use mesh_channel::cancel::CancelReason;
19use mesh_node::common::Address;
20use mesh_node::common::NodeId;
21use mesh_node::common::PortId;
22use mesh_node::common::Uuid;
23use mesh_node::local_node::Connect;
24use mesh_node::local_node::LocalNode;
25use mesh_node::local_node::OutgoingEvent;
26use mesh_node::local_node::Port;
27use mesh_node::local_node::SendEvent;
28use pal_async::task::Spawn;
29use pal_async::task::Task;
30use std::io;
31use std::pin::pin;
32use thiserror::Error;
33use tracing::Instrument;
34use zerocopy::FromBytes;
35use zerocopy::FromZeros;
36use zerocopy::Immutable;
37use zerocopy::IntoBytes;
38use zerocopy::KnownLayout;
39
40#[must_use]
52pub struct PointToPointMesh {
53 task: Task<()>,
54 cancel: Cancel,
55}
56
57impl PointToPointMesh {
58 pub fn new(
76 spawn: impl Spawn,
77 conn: impl 'static + AsyncRead + AsyncWrite + Send + Unpin,
78 port: Port,
79 ) -> Self {
80 let local_address = Address {
81 node: NodeId::new(),
82 port: PortId::new(),
83 };
84 let (mut ctx, cancel) = CancelContext::new().with_cancel();
85 let task = spawn.spawn(
86 format!("mesh-point-to-point-{:?}", local_address.node),
87 async move {
88 if let Err(err) = handle_comms(&mut ctx, Box::new(conn), local_address, port).await
89 {
90 tracing::error!(error = &err as &dyn std::error::Error, "io failure");
91 }
92 }
93 .instrument(tracing::info_span!("mesh-point-to-point", node = ?local_address.node)),
94 );
95
96 Self { task, cancel }
97 }
98
99 pub async fn shutdown(mut self) {
101 self.cancel.cancel();
102 self.task.await;
103 }
104}
105
106trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + Send {}
107impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadWrite for T {}
108
109#[derive(Debug, Error)]
110enum TaskError {
111 #[error("cancelled")]
112 Cancelled(#[from] CancelReason),
113 #[error("failed to change addresses")]
114 Exchange(#[source] io::Error),
115 #[error("failed to send data")]
116 Send(#[source] io::Error),
117 #[error("failed to receive data")]
118 Recv(#[source] io::Error),
119}
120
121async fn handle_comms(
122 ctx: &mut CancelContext,
123 conn: Box<dyn AsyncReadWrite>,
124 local_address: Address,
125 port: Port,
126) -> Result<(), TaskError> {
127 let (mut read, mut write) = conn.split();
128 let node = LocalNode::with_id(local_address.node, Box::new(NullConnector));
129
130 tracing::debug!("exchanging addresses");
131 let remote_address = ctx
132 .until_cancelled(exchange_addresses(local_address, &mut write, &mut read))
133 .await?
134 .map_err(TaskError::Exchange)?;
135
136 tracing::debug!(?local_address, ?remote_address, "connected to remote node");
137
138 let remote = node.add_remote(remote_address.node);
139 let (send_event, recv_event) = mesh_channel::channel();
140 remote.connect(PointToPointConnection(send_event));
141 let init_port = node.add_port(local_address.port, remote_address);
142 init_port.bridge(port);
143
144 let recv_loop = recv_loop(&remote_address.node, read, &node).map_err(TaskError::Recv);
145 let send_loop = send_loop(recv_event, write).map_err(TaskError::Send);
146
147 let mut fut = pin!((recv_loop, send_loop).race());
152
153 let r = match ctx.until_cancelled(fut.as_mut()).await {
154 Ok(r) => r,
155 Err(_) => {
156 let shutdown = async {
157 node.wait_for_ports(false).await;
158 node.fail_all_nodes();
159 Ok(())
160 };
161 try_join(shutdown, fut).await.map(|((), ())| ())
162 }
163 };
164 match r {
165 Ok(()) => remote.disconnect(),
166 Err(err) => remote.fail(err),
167 }
168 Ok(())
169}
170
171async fn exchange_addresses(
172 local_address: Address,
173 write: &mut (impl AsyncWrite + Unpin),
174 read: &mut (impl AsyncRead + Unpin),
175) -> io::Result<Address> {
176 #[repr(C)]
177 #[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
178 struct Message {
179 magic: [u8; 4],
180 node: [u8; 16],
181 port: [u8; 16],
182 }
183
184 const MAGIC: [u8; 4] = *b"mesh";
185 let local_msg = Message {
186 magic: MAGIC,
187 node: (local_address.node.0).0,
188 port: (local_address.port.0).0,
189 };
190
191 let mut remote_msg = Message::new_zeroed();
192 try_join(
193 write.write_all(local_msg.as_bytes()),
194 read.read_exact(remote_msg.as_mut_bytes()),
195 )
196 .await?;
197
198 if remote_msg.magic != MAGIC {
199 return Err(io::Error::new(
200 io::ErrorKind::InvalidData,
201 "invalid address header",
202 ));
203 }
204
205 Ok(Address::new(
206 NodeId(Uuid(remote_msg.node)),
207 PortId(Uuid(remote_msg.port)),
208 ))
209}
210
211async fn recv_loop(
212 remote_id: &NodeId,
213 read: impl AsyncRead + Unpin,
214 node: &LocalNode,
215) -> io::Result<()> {
216 let mut read = BufReader::new(read);
217 loop {
218 let mut b = [0; 8];
219 if read.fill_buf().await?.is_empty() {
220 break;
221 }
222 read.read_exact(&mut b).await?;
223 let len = u64::from_le_bytes(b) as usize;
224 let buf = read.buffer();
225 if buf.len() >= len {
226 node.event(remote_id, buf, &mut Vec::new());
228 read.consume_unpin(len);
229 } else {
230 let mut b = vec![0; len];
232 read.read_exact(&mut b).await?;
233 node.event(remote_id, &b, &mut Vec::new());
234 }
235 }
236 tracing::debug!("recv loop done");
237 Ok(())
238}
239
240async fn send_loop(
241 mut recv_event: mesh_channel::Receiver<Vec<u8>>,
242 mut write: impl AsyncWrite + Unpin,
243) -> io::Result<()> {
244 while let Some(event) = recv_event.next().await {
245 write.write_all(&(event.len() as u64).to_le_bytes()).await?;
246 write.write_all(&event).await?;
247 }
248 tracing::debug!("send loop done");
249 Ok(())
250}
251
252#[derive(Debug)]
253struct PointToPointConnection(mesh_channel::Sender<Vec<u8>>);
254
255impl SendEvent for PointToPointConnection {
256 fn event(&self, event: OutgoingEvent<'_>) {
257 let len = event.len();
258 let mut v = Vec::with_capacity(len);
259 let mut resources = Vec::new();
260 event.write_to(&mut v, &mut resources);
261 if !resources.is_empty() {
262 tracing::warn!("cannot send OS resources across a point-to-point connection");
267 }
268 self.0.send(v);
269 }
270}
271
272#[derive(Debug)]
273struct NullConnector;
274
275impl Connect for NullConnector {
276 fn connect(&self, _node_id: NodeId, handle: mesh_node::local_node::RemoteNodeHandle) {
277 handle.fail(NoMesh);
278 }
279}
280
281#[derive(Debug, Error)]
282#[error("no extra connections allowed in point-to-point mesh")]
283struct NoMesh;
284
#[cfg(test)]
mod tests {
    use super::PointToPointMesh;
    use mesh_channel::channel;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::socket::PolledSocket;
    use test_with_tracing::test;
    use unix_socket::UnixStream;

    /// A value sent into one end of the mesh arrives at the other end.
    #[async_test]
    async fn test_point_to_point(driver: DefaultDriver) {
        let (left_stream, right_stream) = UnixStream::pair().unwrap();
        let left_stream = PolledSocket::new(&driver, left_stream).unwrap();
        let right_stream = PolledSocket::new(&driver, right_stream).unwrap();

        let (tx, tx_peer) = channel::<u32>();
        let (rx_peer, mut rx) = channel::<u32>();
        let left_mesh = PointToPointMesh::new(&driver, left_stream, tx_peer.into());
        let right_mesh = PointToPointMesh::new(&driver, right_stream, rx_peer.into());

        tx.send(5);
        assert_eq!(rx.recv().await.unwrap(), 5);

        left_mesh.shutdown().await;
        right_mesh.shutdown().await;
    }

    /// Sending an OS resource across the mesh does not deliver it; the
    /// receiver observes an error instead of a value.
    #[async_test]
    async fn test_point_to_point_os_resource_dropped(driver: DefaultDriver) {
        let (left_stream, right_stream) = UnixStream::pair().unwrap();
        let left_stream = PolledSocket::new(&driver, left_stream).unwrap();
        let right_stream = PolledSocket::new(&driver, right_stream).unwrap();

        let (sender, sender_port) = channel::<UnixStream>();
        let (receiver_port, mut receiver) = channel::<UnixStream>();
        let left_mesh = PointToPointMesh::new(&driver, left_stream, sender_port.into());
        let right_mesh = PointToPointMesh::new(&driver, right_stream, receiver_port.into());

        // Keep the other end of the pair alive for the rest of the test.
        let (resource, _resource_peer) = UnixStream::pair().unwrap();
        sender.send(resource);
        drop(sender);

        left_mesh.shutdown().await;
        right_mesh.shutdown().await;

        let received = receiver.recv().await;
        assert!(received.is_err(), "expected error, got {:?}", received);
    }
}