1use futures::AsyncBufReadExt;
7use futures::AsyncRead;
8use futures::AsyncReadExt;
9use futures::AsyncWrite;
10use futures::AsyncWriteExt;
11use futures::StreamExt;
12use futures::TryFutureExt;
13use futures::future::try_join;
14use futures::io::BufReader;
15use futures_concurrency::future::Race;
16use mesh_channel::cancel::Cancel;
17use mesh_channel::cancel::CancelContext;
18use mesh_channel::cancel::CancelReason;
19use mesh_node::common::Address;
20use mesh_node::common::NodeId;
21use mesh_node::common::PortId;
22use mesh_node::common::Uuid;
23use mesh_node::local_node::Connect;
24use mesh_node::local_node::LocalNode;
25use mesh_node::local_node::OutgoingEvent;
26use mesh_node::local_node::Port;
27use mesh_node::local_node::SendEvent;
28use pal_async::task::Spawn;
29use pal_async::task::Task;
30use std::io;
31use std::pin::pin;
32use thiserror::Error;
33use tracing::Instrument;
34use zerocopy::FromBytes;
35use zerocopy::FromZeros;
36use zerocopy::Immutable;
37use zerocopy::IntoBytes;
38use zerocopy::KnownLayout;
39
/// A mesh connection to a single remote mesh node over one bidirectional byte
/// stream.
///
/// Created by [`PointToPointMesh::new`]. The connection is driven by a
/// background task that runs until [`PointToPointMesh::shutdown`] is called or
/// the connection ends.
#[must_use]
pub struct PointToPointMesh {
    /// Background task driving the connection.
    task: Task<()>,
    /// Cancellation handle for the background task, triggered by `shutdown`.
    cancel: Cancel,
}
54
55impl PointToPointMesh {
56 pub fn new(
74 spawn: impl Spawn,
75 conn: impl 'static + AsyncRead + AsyncWrite + Send + Unpin,
76 port: Port,
77 ) -> Self {
78 let local_address = Address {
79 node: NodeId::new(),
80 port: PortId::new(),
81 };
82 let (mut ctx, cancel) = CancelContext::new().with_cancel();
83 let task = spawn.spawn(
84 format!("mesh-point-to-point-{:?}", local_address.node),
85 async move {
86 if let Err(err) = handle_comms(&mut ctx, Box::new(conn), local_address, port).await
87 {
88 tracing::error!(error = &err as &dyn std::error::Error, "io failure");
89 }
90 }
91 .instrument(tracing::info_span!("mesh-point-to-point", node = ?local_address.node)),
92 );
93
94 Self { task, cancel }
95 }
96
97 pub async fn shutdown(mut self) {
99 self.cancel.cancel();
100 self.task.await;
101 }
102}
103
104trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + Send {}
105impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadWrite for T {}
106
107#[derive(Debug, Error)]
108enum TaskError {
109 #[error("cancelled")]
110 Cancelled(#[from] CancelReason),
111 #[error("failed to change addresses")]
112 Exchange(#[source] io::Error),
113 #[error("failed to send data")]
114 Send(#[source] io::Error),
115 #[error("failed to receive data")]
116 Recv(#[source] io::Error),
117}
118
/// Runs a single point-to-point connection over `conn` until it ends or `ctx`
/// is cancelled.
///
/// Steps:
/// 1. Exchanges addresses with the peer (`exchange_addresses`).
/// 2. Creates a local node whose connector (`NullConnector`) fails every
///    connection it is asked to make, and registers the peer as a remote node
///    whose outgoing events are queued on a channel drained by `send_loop`.
/// 3. Adds a local port at `local_address.port`, associated with the peer's
///    address, and bridges it with `port`.
/// 4. Runs `recv_loop` and `send_loop` until either one completes.
///
/// # Errors
///
/// Returns `TaskError::Cancelled` or `TaskError::Exchange` if the address
/// exchange is cancelled or fails. After the exchange, loop failures are
/// reported through the remote node handle (`remote.fail`), and the function
/// returns `Ok(())`.
async fn handle_comms(
    ctx: &mut CancelContext,
    conn: Box<dyn AsyncReadWrite>,
    local_address: Address,
    port: Port,
) -> Result<(), TaskError> {
    let (mut read, mut write) = conn.split();
    // `NullConnector` fails any connection the node asks it to establish; the
    // only remote set up here is the peer registered below.
    let node = LocalNode::with_id(local_address.node, Box::new(NullConnector));

    tracing::debug!("exchanging addresses");
    // The first `?` propagates cancellation (converted into
    // `TaskError::Cancelled` via `From<CancelReason>`); the second propagates
    // handshake failures as `TaskError::Exchange`.
    let remote_address = ctx
        .until_cancelled(exchange_addresses(local_address, &mut write, &mut read))
        .await?
        .map_err(TaskError::Exchange)?;

    tracing::debug!(?local_address, ?remote_address, "connected to remote node");

    let remote = node.add_remote(remote_address.node);
    // Events the node sends to the peer are serialized by
    // `PointToPointConnection` and queued on this channel for `send_loop`.
    let (send_event, recv_event) = mesh_channel::channel();
    remote.connect(PointToPointConnection(send_event));
    let init_port = node.add_port(local_address.port, remote_address);
    init_port.bridge(port);

    let recv_loop = recv_loop(&remote_address.node, read, &node).map_err(TaskError::Recv);
    let send_loop = send_loop(recv_event, write).map_err(TaskError::Send);

    // Whichever loop completes first (Ok or Err) determines the outcome.
    let mut fut = pin!((recv_loop, send_loop).race());

    let r = match ctx.until_cancelled(fut.as_mut()).await {
        Ok(r) => r,
        Err(_) => {
            // Cancelled: shut the node down while continuing to drive the
            // loops, so events produced during shutdown can still be written.
            // `try_join` returns early if `fut` fails; otherwise it waits for
            // both futures. NOTE(review): presumably `fail_all_nodes` drops the
            // `PointToPointConnection`, closing `recv_event` and ending
            // `send_loop` - confirm against `LocalNode`.
            let shutdown = async {
                node.wait_for_ports(false).await;
                node.fail_all_nodes();
                Ok(())
            };
            try_join(shutdown, fut).await.map(|((), ())| ())
        }
    };
    // Report the connection outcome to the remote node handle.
    match r {
        Ok(()) => remote.disconnect(),
        Err(err) => remote.fail(err),
    }
    Ok(())
}
168
169async fn exchange_addresses(
170 local_address: Address,
171 write: &mut (impl AsyncWrite + Unpin),
172 read: &mut (impl AsyncRead + Unpin),
173) -> io::Result<Address> {
174 #[repr(C)]
175 #[derive(IntoBytes, Immutable, KnownLayout, FromBytes)]
176 struct Message {
177 magic: [u8; 4],
178 node: [u8; 16],
179 port: [u8; 16],
180 }
181
182 const MAGIC: [u8; 4] = *b"mesh";
183 let local_msg = Message {
184 magic: MAGIC,
185 node: (local_address.node.0).0,
186 port: (local_address.port.0).0,
187 };
188
189 let mut remote_msg = Message::new_zeroed();
190 try_join(
191 write.write_all(local_msg.as_bytes()),
192 read.read_exact(remote_msg.as_mut_bytes()),
193 )
194 .await?;
195
196 if remote_msg.magic != MAGIC {
197 return Err(io::Error::new(
198 io::ErrorKind::InvalidData,
199 "invalid address header",
200 ));
201 }
202
203 Ok(Address::new(
204 NodeId(Uuid(remote_msg.node)),
205 PortId(Uuid(remote_msg.port)),
206 ))
207}
208
209async fn recv_loop(
210 remote_id: &NodeId,
211 read: impl AsyncRead + Unpin,
212 node: &LocalNode,
213) -> io::Result<()> {
214 let mut read = BufReader::new(read);
215 loop {
216 let mut b = [0; 8];
217 if read.fill_buf().await?.is_empty() {
218 break;
219 }
220 read.read_exact(&mut b).await?;
221 let len = u64::from_le_bytes(b) as usize;
222 let buf = read.buffer();
223 if buf.len() >= len {
224 node.event(remote_id, buf, &mut Vec::new());
226 read.consume_unpin(len);
227 } else {
228 let mut b = vec![0; len];
230 read.read_exact(&mut b).await?;
231 node.event(remote_id, &b, &mut Vec::new());
232 }
233 }
234 tracing::debug!("recv loop done");
235 Ok(())
236}
237
238async fn send_loop(
239 mut recv_event: mesh_channel::Receiver<Vec<u8>>,
240 mut write: impl AsyncWrite + Unpin,
241) -> io::Result<()> {
242 while let Some(event) = recv_event.next().await {
243 write.write_all(&(event.len() as u64).to_le_bytes()).await?;
244 write.write_all(&event).await?;
245 }
246 tracing::debug!("send loop done");
247 Ok(())
248}
249
250#[derive(Debug)]
251struct PointToPointConnection(mesh_channel::Sender<Vec<u8>>);
252
253impl SendEvent for PointToPointConnection {
254 fn event(&self, event: OutgoingEvent<'_>) {
255 let len = event.len();
256 let mut v = Vec::with_capacity(len);
257 let mut resources = Vec::new();
258 event.write_to(&mut v, &mut resources);
259 if !resources.is_empty() {
260 tracing::warn!("cannot send OS resources across a point-to-point connection");
265 }
266 self.0.send(v);
267 }
268}
269
270#[derive(Debug)]
271struct NullConnector;
272
273impl Connect for NullConnector {
274 fn connect(&self, _node_id: NodeId, handle: mesh_node::local_node::RemoteNodeHandle) {
275 handle.fail(NoMesh);
276 }
277}
278
279#[derive(Debug, Error)]
280#[error("no extra connections allowed in point-to-point mesh")]
281struct NoMesh;
282
#[cfg(test)]
mod tests {
    use super::PointToPointMesh;
    use mesh_channel::channel;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::socket::PolledSocket;
    use test_with_tracing::test;
    use unix_socket::UnixStream;

    /// Sends a value across two point-to-point meshes joined by a Unix
    /// socket pair, then shuts both meshes down.
    #[async_test]
    async fn test_point_to_point(driver: DefaultDriver) {
        let (left_sock, right_sock) = UnixStream::pair().unwrap();
        let left_sock = PolledSocket::new(&driver, left_sock).unwrap();
        let right_sock = PolledSocket::new(&driver, right_sock).unwrap();

        // `sender`'s peer port is handed to the left mesh, and `receiver`'s
        // peer port to the right mesh.
        let (sender, left_port) = channel::<u32>();
        let (right_port, mut receiver) = channel::<u32>();
        let left_mesh = PointToPointMesh::new(&driver, left_sock, left_port.into());
        let right_mesh = PointToPointMesh::new(&driver, right_sock, right_port.into());

        sender.send(5);
        assert_eq!(receiver.recv().await.unwrap(), 5);

        left_mesh.shutdown().await;
        right_mesh.shutdown().await;
    }
}
307}