mesh_remote/
unix_node.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Unix socket-based mesh node implementation.
5//!
//! Each pair of nodes communicates using a single, bidirectional Unix socket.
//! On platforms that support it (Linux), a `SOCK_SEQPACKET` socket is used,
//! which provides message framing and atomic message sends. Otherwise, a
//! `SOCK_STREAM` socket is used, and the protocol includes a message size.
10//!
11//! File descriptors are sent between nodes using the `SCM_RIGHTS` functionality
12//! of Unix sockets.
13
14#![cfg(unix)]
15// UNSAFETY: Calls to libc send/recvmsg fns and the work to prepare their inputs
16// and handle their outputs (mem::zeroed, transmutes, from_raw_fds).
17#![expect(unsafe_code)]
18
19#[cfg(target_os = "linux")]
20mod memfd;
21
22use crate::common::InvitationAddress;
23use crate::protocol;
24use futures::FutureExt;
25use futures::StreamExt;
26use futures::channel::mpsc;
27use futures::future;
28use futures::future::BoxFuture;
29use io::ErrorKind;
30use mesh_channel::OneshotReceiver;
31use mesh_channel::OneshotSender;
32use mesh_channel::RecvError;
33use mesh_channel::channel;
34use mesh_channel::oneshot;
35use mesh_node::common::Address;
36use mesh_node::common::NodeId;
37use mesh_node::common::PortId;
38use mesh_node::local_node::Connect;
39use mesh_node::local_node::LocalNode;
40use mesh_node::local_node::OutgoingEvent;
41use mesh_node::local_node::Port;
42use mesh_node::local_node::RemoteNodeHandle;
43use mesh_node::local_node::SendEvent;
44use mesh_node::resource::OsResource;
45use mesh_node::resource::Resource;
46use mesh_protobuf::Protobuf;
47use pal_async::driver::SpawnDriver;
48use pal_async::interest::InterestSlot;
49use pal_async::interest::PollEvents;
50use pal_async::socket::PolledSocket;
51use pal_async::task::Spawn;
52use pal_async::task::Task;
53use parking_lot::Mutex;
54use socket2::Socket;
55use std::collections::HashMap;
56use std::collections::VecDeque;
57use std::fmt::Debug;
58use std::future::Future;
59use std::future::poll_fn;
60use std::io;
61use std::io::IoSlice;
62use std::io::IoSliceMut;
63use std::os::unix::prelude::*;
64use std::pin::pin;
65use std::sync::Arc;
66use thiserror::Error;
67use tracing::instrument;
68use unicycle::FuturesUnordered;
69use zerocopy::FromBytes;
70use zerocopy::FromZeros;
71use zerocopy::IntoBytes;
72
/// If true, use a SOCK_SEQPACKET socket. Otherwise, use a SOCK_STREAM socket.
///
/// SOCK_SEQPACKET is preferred where available because it allows us to avoid
/// separately tracking message boundaries. Most importantly, this makes it
/// straightforward to support sending messages from multiple threads
/// simultaneously.
///
/// Currently this is true only when compiling for Linux.
const USE_SEQPACKET: bool = cfg!(target_os = "linux");

/// The maximum packet size. Linux uses memfd to send data larger than this.
/// Other OSes just fail, so choose a larger size.
///
/// These values were chosen arbitrarily and have not been tested for
/// performance.
const MAX_PACKET_SIZE: usize = if cfg!(target_os = "linux") {
    0x4000
} else {
    0x40000
};

/// The largest serialized event that still fits in one packet next to the
/// packet header. `serialize_event` routes anything larger to
/// `serialize_large_event`.
const MAX_SMALL_EVENT_SIZE: usize = MAX_PACKET_SIZE - size_of::<protocol::PacketHeader>();
93
/// A node within a mesh that uses Unix sockets to communicate.
///
/// Each pairwise connection between two nodes in the mesh communicates via a
/// connected pair of Unix sockets, with each node holding one end.
///
/// If one node needs to send data to another but does not have a connection, it
/// sends a request to the leader node to establish a connection. The leader
/// creates a new socket pair and sends one end to each of the nodes, which the
/// two nodes can use to communicate.
pub struct UnixNode {
    /// Driver used to spawn the IO task and to wrap sockets received for new
    /// connections.
    driver: Arc<dyn SpawnDriver>,
    /// The underlying local mesh node.
    local_node: Arc<LocalNode>,
    /// Channel for sending requests to the mesh leader.
    to_leader: Arc<mesh_channel::Sender<LeaderRequest>>,
    /// Channel for submitting work to the IO task.
    tasks: Arc<mesh_channel::Sender<SmallTask>>,
    /// The spawned task that runs every submitted `SmallTask`.
    io_task: Task<()>,
    // TODO: consider reducing type complexity?
    /// Channel to this node's leader task, used to ask it to hand leadership
    /// to another node. `with_id` initializes this to `None`; `new` fills it
    /// in after starting the leader task.
    leader_resign_send:
        Mutex<Option<Arc<mesh_channel::Sender<(NodeId, mesh_channel::Sender<Followers>)>>>>,

    // meaningful drop
    /// The IO task races its work against the receiving half of this channel,
    /// so dropping this field is (presumably) what lets that task finish.
    _drop_send: OneshotSender<()>,
}
116
/// A request sent from a follower to the leader; handled by `run_leader`.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum LeaderRequest {
    /// Ask the leader to set up a connection between the requesting node and
    /// the node with the given ID.
    Connect(NodeId),
    /// Ask the leader to prepare an invitation for a new node. The port is
    /// forwarded to the new node as `InitialMessage::user_port`, and the
    /// resulting `Invitation` is sent back on the provided channel.
    Invite(Port, mesh_channel::Sender<Invitation>),
}
123
/// A request sent from the leader to a follower; handled by `run_follower`.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum FollowerRequest {
    /// Connect to the given node over the provided socket. The socket is
    /// `None` when the leader could not set up the connection (unknown target
    /// or socket pair creation failure).
    Connect(
        NodeId,
        #[mesh(
            encoding = "mesh_protobuf::encoding::OptionField<mesh_protobuf::encoding::ResourceField<OwnedFd>>"
        )]
        Option<Socket>,
    ),
}
135
/// The leader's per-follower channels, as handed to `run_leader` and passed on
/// to a new leader when leadership is transferred.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct Followers {
    /// One entry per follower: its node ID, the receiver for requests the
    /// follower sends to the leader, and the sender for requests the leader
    /// sends to the follower.
    list: Vec<(
        NodeId,
        mesh_channel::Receiver<LeaderRequest>,
        mesh_channel::Sender<FollowerRequest>,
    )>,
}
145
/// The message the leader sends to a newly invited node over the invitation's
/// initial port (see the `Invite` handling in `run_leader`).
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
struct InitialMessage {
    /// Channel the new node uses to send requests to the leader.
    leader_send: mesh_channel::Sender<LeaderRequest>,
    /// Channel on which the new node receives requests from the leader.
    follower_recv: mesh_channel::Receiver<FollowerRequest>,
    /// The port supplied by whoever requested the invitation.
    user_port: Port,
}
153
154/// Processes incoming requests from the leader to a follower. Currently the only
155/// such request is to add a connection to another node.
156#[instrument(skip_all, fields(local_id = ?local_node.id()))]
157async fn run_follower(
158    driver: &dyn SpawnDriver,
159    local_node: &Arc<LocalNode>,
160    mut recv: mesh_channel::Receiver<FollowerRequest>,
161    pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>>,
162    tasks: &mesh_channel::Sender<SmallTask>,
163) {
164    while let Ok(req) = recv.recv().await {
165        match req {
166            FollowerRequest::Connect(target_id, fd) => {
167                tracing::debug!(?target_id, "got connection request from leader");
168                let handle = pending_connections.lock().remove(&target_id);
169                let handle = handle.unwrap_or_else(|| local_node.get_remote_handle(target_id));
170
171                if let Some(fd) = fd {
172                    start_connection(
173                        tasks,
174                        local_node,
175                        target_id,
176                        handle,
177                        UnixSocket::new(driver, fd),
178                    );
179                } else {
180                    tracing::warn!(?target_id, "leader provided failed connection");
181                }
182            }
183        }
184    }
185}
186
/// Processes incoming requests from followers to the leader. Runs until there
/// are no more followers or until the leader is asked to transfer power to
/// another node via `resign_recv`.
///
/// `followers` supplies the initial set of follower channels. `driver` and
/// `tasks` are used to start the leader's own connection to each node it
/// invites.
///
/// If `resign_recv` ends without delivering a request, the loop also stops and
/// no handoff takes place.
#[instrument(skip_all, fields(local_id = ?local_node.id()))]
async fn run_leader(
    driver: &dyn SpawnDriver,
    local_node: &Arc<LocalNode>,
    mut resign_recv: mesh_channel::Receiver<(NodeId, mesh_channel::Sender<Followers>)>,
    followers: Followers,
    tasks: &mesh_channel::Sender<SmallTask>,
) {
    // Split the follower list: `senders` is keyed by node ID for lookups,
    // while `receivers` is a list so it can be polled with `select_all` and
    // indexed by the position that call reports.
    let mut senders = HashMap::new();
    let mut receivers = Vec::new();
    for (remote_id, recv, send) in followers.list {
        receivers.push((remote_id, recv));
        senders.insert(remote_id, send);
    }

    // The loop yields the resignation request (if any) that ended it.
    let new_leader_info = loop {
        if receivers.is_empty() {
            return;
        }
        // Wait for the next request from any follower, or for a resignation
        // request.
        let recvs = receivers
            .iter_mut()
            .map(|(_, recv)| poll_fn(|cx| recv.poll_recv(cx)));
        let (req, index, _) = futures::select! { // merge semantics
            r = resign_recv.next() => break r,
            r = future::select_all(recvs).fuse() => r,
        };
        // `index` identifies which follower's receiver produced `req`.
        let remote_id = receivers[index].0;
        match req {
            Ok(req) => match req {
                LeaderRequest::Connect(target_id) => {
                    tracing::debug!(?target_id, ?remote_id, "connection request");
                    let remote = senders
                        .get(&remote_id)
                        .expect("sender must exist to receive from it");
                    // `fd` stays `None` if the target is unknown or the
                    // socket pair cannot be created; the requester is told
                    // either way.
                    let mut fd = None;
                    if let Some(target) = senders.get(&target_id) {
                        match new_socket_pair() {
                            Ok((left, right)) => {
                                tracing::trace!(?target, "send to");
                                target.send(FollowerRequest::Connect(remote_id, Some(left)));
                                fd = Some(right);
                            }
                            Err(err) => {
                                tracing::warn!(
                                    ?target_id,
                                    ?remote_id,
                                    error = &err as &dyn std::error::Error,
                                    "failed to create socket pair for connection request"
                                );
                            }
                        }
                    } else {
                        tracing::warn!(?target_id, ?remote_id, "could not find target for remote");
                    }
                    remote.send(FollowerRequest::Connect(target_id, fd));
                }
                LeaderRequest::Invite(port, send) => {
                    tracing::debug!(?remote_id, "invitation request");
                    match new_socket_pair() {
                        Ok((left, right)) => {
                            // Channels between the leader and the node being
                            // invited.
                            let (leader_send, leader_recv) = channel();
                            let (follower_send, follower_recv) = channel();
                            // Allocate the invited node's ID and the port it
                            // will use to receive the initial message.
                            let remote_addr = Address {
                                node: NodeId::new(),
                                port: PortId::new(),
                            };
                            let local_port_id = PortId::new();
                            // The leader keeps `left` as its own connection
                            // to the new node; `right` travels in the
                            // invitation.
                            let handle = local_node.add_remote(remote_addr.node);
                            start_connection(
                                tasks,
                                local_node,
                                remote_addr.node,
                                handle,
                                UnixSocket::new(driver, left),
                            );
                            let init_send = OneshotSender::<InitialMessage>::from(
                                local_node.add_port(local_port_id, remote_addr),
                            );
                            init_send.send(InitialMessage {
                                leader_send,
                                follower_recv,
                                user_port: port,
                            });
                            let invitation = Invitation {
                                address: InvitationAddress {
                                    local_addr: remote_addr,
                                    remote_addr: Address {
                                        node: local_node.id(),
                                        port: local_port_id,
                                    },
                                },
                                fd: right.into(),
                            };
                            tracing::debug!(
                                invite_id = ?invitation.address.local_addr.node,
                                ?remote_id,
                                "inviting",
                            );
                            send.send(invitation);
                            // Track the invited node as a follower from now
                            // on.
                            senders.insert(remote_addr.node, follower_send);
                            receivers.push((remote_addr.node, leader_recv));
                        }
                        Err(err) => {
                            tracing::error!(
                                error = &err as &dyn std::error::Error,
                                "failed to create socket pair",
                            );
                        }
                    }
                }
            },
            Err(err) => {
                // The follower's channel is finished; only the
                // `RecvError::Error` case is logged.
                if let RecvError::Error(err) = err {
                    tracing::debug!(
                        ?remote_id,
                        error = &err as &dyn std::error::Error,
                        "leader connection to remote failed"
                    );
                }
                // Forget the follower, keeping `senders` and `receivers` in
                // sync.
                senders.remove(&remote_id);
                receivers.swap_remove(index);
            }
        }
    };

    if let Some((new_leader_id, new_leader_followers_sink)) = new_leader_info {
        if let Some(new_leader_send) = senders.get(&new_leader_id) {
            tracing::debug!(?new_leader_id, "resigning leadership");
            // Ensure there is a connection between every follower and the new
            // leader.
            for (remote_id, send) in senders.iter() {
                if new_leader_id != *remote_id {
                    match new_socket_pair() {
                        Ok((left, right)) => {
                            send.send(FollowerRequest::Connect(new_leader_id, Some(left)));
                            new_leader_send.send(FollowerRequest::Connect(*remote_id, Some(right)));
                        }
                        Err(err) => {
                            tracing::error!(
                                ?new_leader_id,
                                error = &err as &dyn std::error::Error,
                                "failed to connect node to new leader, mesh is leaderless",
                            );
                            return;
                        }
                    }
                }
            }

            // Send all the followers to the new leader.
            let mut followers = Vec::new();
            for (remote_id, recv) in receivers {
                let send = senders
                    .remove(&remote_id)
                    .expect("should be in sync with receivers");
                followers.push((remote_id, recv, send));
            }
            new_leader_followers_sink.send(Followers { list: followers });
        } else {
            tracing::error!(?new_leader_id, "new leader is unknown, mesh is leaderless");
        }
    }
}
353
/// A named unit of work: a boxed future plus a static name. This is used to
/// send work to the node's IO task, which traces the name when the future
/// starts and ends.
struct SmallTask {
    /// Name reported in trace output.
    name: &'static str,
    /// The work to run.
    future: BoxFuture<'static, ()>,
}
360
361impl Debug for SmallTask {
362    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363        f.pad("SmallTask")
364    }
365}
366
367impl SmallTask {
368    fn new(name: &'static str, f: impl 'static + Send + Future<Output = ()>) -> Self {
369        Self {
370            name,
371            future: Box::pin(f),
372        }
373    }
374}
375
/// An invitation allowing another process to join the mesh.
///
/// Created by [`UnixNode::invite`].
///
/// The leader builds the invitation in response to a
/// `LeaderRequest::Invite`, keeping the other end of the socket for itself.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct Invitation {
    /// The common invitation addresses.
    pub address: InvitationAddress,
    /// The Unix socket used to initiate communications with the mesh.
    pub fd: OwnedFd,
}
387
/// A command for a connection's send task (`run_send`).
#[derive(Debug)]
enum SenderCommand {
    /// Send `packet` to the remote node, transferring `fds` along with it.
    Send {
        /// The fully serialized packet, including its header.
        packet: Vec<u8>,
        /// OS resources to pass to the remote node with the packet.
        fds: Vec<OsResource>,
    },
    /// Drop the oldest retained fds; queued when the remote node sends a
    /// `RELEASE_FDS` packet.
    ReleaseFds {
        /// How many retained fds to drop.
        count: usize,
    },
}
398
/// The `SendEvent` implementation for one connection, handed to the remote
/// node handle when the connection starts.
#[derive(Clone)]
struct PacketSender {
    /// Queue to the connection's send task, used when a packet cannot be
    /// sent immediately.
    send: mpsc::UnboundedSender<SenderCommand>,
    /// The connection's socket, shared with the connection task so that
    /// packets can be sent directly from the caller.
    socket: Arc<UnixSocket>,
}
404
405impl SendEvent for PacketSender {
406    fn event(&self, event: OutgoingEvent<'_>) {
407        let (packet, fds) = match serialize_event(event) {
408            Ok(r) => r,
409            Err(err) => {
410                // FUTURE: fail the port or connection instead?
411                tracing::error!(
412                    error = &err as &dyn std::error::Error,
413                    "failed to serialize event"
414                );
415                return;
416            }
417        };
418
419        // If using SOCK_SEQPACKET, try to send the packet immediately. If this
420        // fails (likely due to EAGAIN), send the packet to the asynchronous
421        // task for deferred processing.
422        //
423        // When SOCK_STREAM is in use, this optimization cannot be tried,
424        // because for stream sockets we may need to issue multiple writes to
425        // send the whole message, and those writes cannot be interleaved
426        // correctly.
427        //
428        // N.B. This can lead to out of order messages. The event protocol is
429        //      responsible for handling this condition.
430        if !USE_SEQPACKET
431            || try_send(
432                self.socket.socket.lock().get(),
433                &[IoSlice::new(&packet)],
434                &fds,
435            )
436            .is_err()
437        {
438            let _ = self
439                .send
440                .unbounded_send(SenderCommand::Send { packet, fds });
441        }
442    }
443}
444
445fn serialize_event(event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
446    // Serialize the event to a memfd if it's too large to send inline.
447    if event.len() > MAX_SMALL_EVENT_SIZE {
448        return serialize_large_event(event);
449    }
450
451    // Serialize the event to a byte vector.
452    let cap = size_of::<protocol::PacketHeader>() + event.len();
453    let mut packet = Vec::with_capacity(cap);
454    packet.extend_from_slice(
455        protocol::PacketHeader {
456            packet_type: protocol::PacketType::EVENT,
457            reserved: [0; 7],
458        }
459        .as_bytes(),
460    );
461    let mut fds = Vec::new();
462    event.write_to(&mut packet, &mut fds);
463    assert_eq!(packet.len(), cap);
464    Ok((packet, fds))
465}
466
467#[cfg(target_os = "linux")]
468fn serialize_large_event(event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
469    let packet = protocol::PacketHeader {
470        packet_type: protocol::PacketType::LARGE_EVENT,
471        ..FromZeros::new_zeroed()
472    }
473    .as_bytes()
474    .to_vec();
475
476    let mut fds = Vec::new();
477
478    let mut memfd = memfd::MemfdBuilder::new(event.len())?;
479    event.write_to(&mut io::Cursor::new(&mut *memfd), &mut fds);
480    fds.insert(0, OsResource::Fd(memfd.seal()?.into()));
481
482    Ok((packet, fds))
483}
484
/// Large events are only supported on Linux; on other platforms this always
/// fails with `ErrorKind::Unsupported`.
#[cfg(not(target_os = "linux"))]
fn serialize_large_event(_event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
    let unsupported = io::Error::new(ErrorKind::Unsupported, "event too large for this OS");
    Err(unsupported)
}
492
impl Drop for PacketSender {
    /// Closes the send queue when the sender is dropped.
    fn drop(&mut self) {
        // Explicitly close the send channel so that the send task returns, even
        // though the send channel is also in use by the receive task (which
        // holds its own clone of the channel sender, so merely dropping this
        // one would not end the channel).
        self.send.close_channel();
    }
}
500
501/// Starts a connection processing task.
502fn start_connection(
503    tasks: &mesh_channel::Sender<SmallTask>,
504    local_node: &Arc<LocalNode>,
505    remote_id: NodeId,
506    handle: RemoteNodeHandle,
507    socket: UnixSocket,
508) {
509    #[expect(clippy::disallowed_methods)] // TODO
510    let (send, recv) = mpsc::unbounded();
511    let socket = Arc::new(socket);
512    let sender = PacketSender {
513        send: send.clone(),
514        socket: socket.clone(),
515    };
516    if handle.connect(sender) {
517        let task = SmallTask::new("run_connection", {
518            let local_node = local_node.clone();
519            run_connection(local_node, remote_id, send, recv, socket, handle)
520        });
521        tasks.send(task);
522        tracing::debug!(?remote_id, "connected");
523    } else {
524        // N.B. This is an expected condition in many scenarios, since the
525        //      leader does not track which connections have already been
526        //      made and so will often send duplicate connection requests.
527        tracing::debug!(?remote_id, "duplicate connection");
528    }
529}
530
/// Runs the packet processing loop.
///
/// Drives the receive and send halves of one connection concurrently until
/// the connection ends, then reports the outcome to `handle`: `disconnect` on
/// a clean end, `fail` with the error otherwise.
///
/// `send_send` is passed to the receive half so it can queue commands for the
/// send half, which consumes them from `send_recv`.
#[instrument(skip_all, fields(local_id = ?local_node.id(), remote_id = ?remote_id))]
async fn run_connection(
    local_node: Arc<LocalNode>,
    remote_id: NodeId,
    send_send: mpsc::UnboundedSender<SenderCommand>,
    send_recv: mpsc::UnboundedReceiver<SenderCommand>,
    socket: Arc<UnixSocket>,
    handle: RemoteNodeHandle,
) {
    // Owned here rather than by the send future, so anything `run_send`
    // retains stays alive until this function returns.
    let mut retained_fds = VecDeque::new();
    // Receive half: run to completion, log the outcome, and yield the result.
    let mut recv = pin!(
        async {
            let r = run_receive(&local_node, &remote_id, &socket, &send_send).await;
            match &r {
                Ok(_) => {
                    tracing::debug!("incoming socket disconnected");
                }
                Err(err) => {
                    tracing::error!(error = err as &dyn std::error::Error, "error receiving");
                }
            }
            r
        }
        .fuse()
    );
    // Send half: run to completion and log the outcome. Its error is only
    // logged, not propagated.
    let mut send = pin!(
        async {
            match run_send(send_recv, &socket, &mut retained_fds).await {
                Ok(_) => {
                    tracing::debug!("sending is done");
                }
                Err(err) => {
                    tracing::error!(error = &err as &dyn std::error::Error, "failed send");
                }
            }
        }
        .fuse()
    );
    // Whichever half finishes first decides how the connection winds down.
    let r = futures::select! { // race semantics
        r = recv => {
            // Notify the remote node that no more data will be sent.
            tracing::trace!("read complete, shutting down writes");
            let _ = socket.close_write().await;
            r
        }
        _ = send => {
            // Sending finished first: shut down writes, then keep receiving
            // until the receive half completes.
            match socket.close_write().await {
                Ok(()) => {
                    tracing::trace!("shutdown writes, waiting for reads");
                    recv.await
                }
                Err(err) => {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to shutdown writes, aborting connection",
                    );
                    Err(ReceiveError::Io(err))
                }
            }
        }
    };
    tracing::trace!("connection done");
    match r {
        Ok(()) => handle.disconnect(),
        Err(err) => handle.fail(err),
    }
}
599
/// An error that ends receive processing for a connection.
#[derive(Debug, Error)]
enum ReceiveError {
    /// The socket operation failed.
    #[error("i/o error")]
    Io(#[from] io::Error),
    /// The received packet was too short to contain a packet header.
    #[error("missing packet header")]
    NoHeader,
    /// A `RELEASE_FDS` packet was too short to contain its body.
    #[error("release fds packet too small")]
    BadReleaseFds,
    /// The packet header named a type this node does not handle.
    #[error("unknown packet type {0:?}")]
    UnknownPacketType(protocol::PacketType),
    /// A `LARGE_EVENT` packet arrived with no file descriptors attached.
    #[cfg(target_os = "linux")]
    #[error("memfd file descriptor not sent for large event")]
    MissingMemfd,
    /// The memfd carrying a large event could not be opened as a sealed
    /// memfd.
    #[cfg(target_os = "linux")]
    #[error("failed to map memfd")]
    Memfd(#[source] io::Error),
}
617
618/// Handles receive processing for the socket.
619async fn run_receive(
620    local_node: &LocalNode,
621    remote_id: &NodeId,
622    socket: &UnixSocket,
623    send: &mpsc::UnboundedSender<SenderCommand>,
624) -> Result<(), ReceiveError> {
625    let mut buf = vec![0; MAX_PACKET_SIZE];
626    let mut fds = Vec::new();
627    loop {
628        let len = socket.recv(&mut buf, &mut fds).await?;
629        if len == 0 {
630            break;
631        }
632        if cfg!(target_os = "macos") && !fds.is_empty() {
633            // Tell the opposite endpoint to release the fds it sent.
634            let _ = send.unbounded_send(SenderCommand::Send {
635                packet: protocol::ReleaseFds {
636                    header: protocol::PacketHeader {
637                        packet_type: protocol::PacketType::RELEASE_FDS,
638                        ..FromZeros::new_zeroed()
639                    },
640                    count: fds.len() as u64,
641                }
642                .as_bytes()
643                .to_vec(),
644                fds: Vec::new(),
645            });
646        }
647
648        let buf = &buf[..len];
649        let header = protocol::PacketHeader::read_from_prefix(buf)
650            .map_err(|_| ReceiveError::NoHeader)?
651            .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
652        match header.packet_type {
653            protocol::PacketType::EVENT => {
654                local_node.event(remote_id, &buf[size_of_val(&header)..], &mut fds);
655                fds.clear();
656            }
657            protocol::PacketType::RELEASE_FDS => {
658                let release_fds = protocol::ReleaseFds::read_from_prefix(buf)
659                    .map_err(|_| ReceiveError::BadReleaseFds)?
660                    .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
661                let _ = send.unbounded_send(SenderCommand::ReleaseFds {
662                    count: release_fds.count as usize,
663                });
664            }
665            #[cfg(target_os = "linux")]
666            protocol::PacketType::LARGE_EVENT => {
667                if fds.is_empty() {
668                    return Err(ReceiveError::MissingMemfd);
669                }
670                let OsResource::Fd(fd) = fds.remove(0);
671                let memfd = memfd::SealedMemfd::new(fd.into()).map_err(ReceiveError::Memfd)?;
672                local_node.event(remote_id, &memfd, &mut fds);
673                fds.clear();
674            }
675            ty => {
676                return Err(ReceiveError::UnknownPacketType(ty));
677            }
678        }
679    }
680    Ok(())
681}
682
/// A violation of the connection protocol by the remote node, detected by the
/// send task.
#[derive(Debug, Error)]
enum ProtocolError {
    /// The remote node asked to release more fds than are currently retained.
    #[error("request to release too many fds")]
    ReleasingTooManyFds,
}
688
689/// Handles send processing for the socket.
690async fn run_send(
691    mut recv: mpsc::UnboundedReceiver<SenderCommand>,
692    socket: &UnixSocket,
693    retained_fds: &mut VecDeque<OsResource>,
694) -> io::Result<()> {
695    while let Some(command) = recv.next().await {
696        match command {
697            SenderCommand::Send { packet, fds } => {
698                match socket.send(&packet, &fds).await {
699                    Ok(_) => (),
700                    Err(err) => {
701                        tracing::error!(
702                            fd_count = fds.len(),
703                            packet_len = packet.len(),
704                            "failed to send packet"
705                        );
706                        return Err(err);
707                    }
708                }
709                if cfg!(target_os = "macos") {
710                    // MacOS has a bug where it prematurely closes Unix sockets
711                    // if a file descriptor to one is closed while it is also in
712                    // the process of being sent across another Unix socket.
713                    // Retain the fds until the opposite endpoint sends a reply
714                    // message.
715                    if !fds.is_empty() {
716                        retained_fds.extend(fds);
717                    }
718                }
719            }
720            SenderCommand::ReleaseFds { count } => {
721                if retained_fds.len() < count {
722                    return Err(io::Error::other(ProtocolError::ReleasingTooManyFds));
723                }
724                retained_fds.drain(..count);
725            }
726        }
727    }
728    Ok(())
729}
730
/// An offer to take over as leader of the mesh. One trusted process must be the
/// leader at all times or the mesh will fail.
///
/// Created by `UnixNode::offer_leadership` and consumed by
/// `UnixNode::accept_leadership`.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct LeadershipOffer {
    /// Channel on which the accepting node sends its node ID together with a
    /// sender for receiving the follower list.
    send: mesh_channel::Sender<(NodeId, mesh_channel::Sender<Followers>)>,
}
738
739impl UnixNode {
740    /// Creates a new Unix node mesh with this node as the leader.
741    pub fn new(driver: impl SpawnDriver) -> Self {
742        let (to_leader_send, to_leader_recv) = channel();
743        let (from_leader_send, from_leader_recv) = channel();
744        let this = Self::with_id(
745            Arc::new(driver),
746            NodeId::new(),
747            to_leader_send,
748            from_leader_recv,
749        );
750
751        // Start a leader task. At this point the leader has only one follower, itself.
752        let (resign_send, resign_recv) = channel();
753        let resign_send = Arc::new(resign_send);
754        let followers = Followers {
755            list: vec![(this.local_node.id(), to_leader_recv, from_leader_send)],
756        };
757        let task = SmallTask::new("run_leader", {
758            let local_node = this.local_node.clone();
759            let tasks = this.tasks.clone();
760            let driver = this.driver.clone();
761            async move { run_leader(driver.as_ref(), &local_node, resign_recv, followers, &tasks).await }
762        });
763        this.tasks.send(task);
764        *this.leader_resign_send.lock() = Some(resign_send);
765
766        this
767    }
768
    /// Gets the node ID. This is mostly useful for diagnostics.
    ///
    /// This is the ID of the underlying local node.
    pub fn id(&self) -> NodeId {
        self.local_node.id()
    }
773
    /// Creates a node with `id` using `to_leader` and `from_leader` to
    /// communicate with the leader node.
    ///
    /// This spawns the node's IO task on `driver` and queues a `run_follower`
    /// task on it. The returned node has no leader task of its own:
    /// `leader_resign_send` starts out as `None`.
    fn with_id(
        driver: Arc<dyn SpawnDriver>,
        id: NodeId,
        to_leader: mesh_channel::Sender<LeaderRequest>,
        from_leader: mesh_channel::Receiver<FollowerRequest>,
    ) -> Self {
        let to_leader = Arc::new(to_leader);
        // Shared between the connector given to the local node and the
        // follower task, which removes an entry when the leader delivers a
        // connection for that node.
        let pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>> = Default::default();
        let local_node = Arc::new(LocalNode::with_id(
            id,
            Box::new(Connector {
                local_id: id,
                conn_req_send: to_leader.clone(),
                pending_connections: pending_connections.clone(),
            }),
        ));
        let (task_send, mut task_recv) = channel::<SmallTask>();
        let task_send = Arc::new(task_send);
        // `drop_send` is stored in the returned node; the IO task below races
        // its work against `drop_recv`.
        let (drop_send, drop_recv) = oneshot();

        // Spawn the task that runs IO tasks.
        let io_task = driver.spawn("unix-mesh-io", async move {
            let process = async {
                // All running tasks, polled together.
                let mut futs = FuturesUnordered::new();
                loop {
                    futures::select! { // merge semantics
                        // A running task finished; nothing more to do for it.
                        _ = futs.next() => {},
                        // A new task was submitted; start running it, with
                        // trace events marking its start and end.
                        task = task_recv.select_next_some() => {
                            futs.push(async move {
                                tracing::trace!(?id, name = task.name, "task start");
                                task.future.await;
                                tracing::trace!(?id, name = task.name, "task end");
                            });
                        },
                        // Ends the loop once `select!` considers every branch
                        // above terminated.
                        complete => break,
                    };
                }
            };
            // Stop at whichever comes first: the processing loop ending or
            // `drop_recv` completing.
            future::select(pin!(process), drop_recv).await;
        });

        // Queue the follower task that handles requests from the leader.
        task_send.send(SmallTask::new("run_follower", {
            let local_node = local_node.clone();
            let tasks = task_send.clone();
            let driver = driver.clone();
            async move {
                run_follower(
                    driver.as_ref(),
                    &local_node,
                    from_leader,
                    pending_connections,
                    &tasks,
                )
                .await
            }
        }));

        Self {
            driver,
            local_node,
            tasks: task_send,
            io_task,
            to_leader,
            leader_resign_send: Mutex::new(None),

            _drop_send: drop_send,
        }
    }
844
845    /// Returns an offer to hand the leadership to another node in the mesh.
846    ///
847    /// The offer should be sent over a channel and passed to
848    /// `accept_leadership` in the receiving node.
849    pub fn offer_leadership(&self) -> LeadershipOffer {
850        let (send, mut recv) = channel();
851        if let Some(leader_send) = self.leader_resign_send.lock().clone() {
852            // Start a task to wait for the offer to be acknowledged, then send
853            // the offer details to the leader thread.
854            let task = SmallTask::new("offer_leadership", async move {
855                if let Ok(r) = recv.recv().await {
856                    leader_send.send(r);
857                }
858            });
859            self.tasks.send(task);
860        }
861        LeadershipOffer { send }
862    }
863
864    /// Accepts a leadership offer, making this node the current leader.
865    pub fn accept_leadership(&self, offer: LeadershipOffer) {
866        let (send, mut recv) = channel();
867        offer.send.send((self.local_node.id(), send));
868
869        let (resign_send, resign_recv) = channel();
870        let resign_send = Arc::new(resign_send);
871        let task = SmallTask::new("accept_and_run_leader", {
872            let local_node = self.local_node.clone();
873            let tasks = self.tasks.clone();
874            let driver = self.driver.clone();
875            async move {
876                if let Ok(followers) = recv.recv().await {
877                    drop(recv);
878                    run_leader(driver.as_ref(), &local_node, resign_recv, followers, &tasks).await
879                }
880            }
881        });
882        self.tasks.send(task);
883        *self.leader_resign_send.lock() = Some(resign_send);
884    }
885
886    /// Invites another process to join the mesh, with `port` bridged with the
887    /// original port.
888    #[instrument(skip_all, fields(local_id = ?self.local_node.id()))]
889    pub async fn invite(&self, port: Port) -> io::Result<Invitation> {
890        let (invitation_send, mut invitation_recv) = channel();
891        self.to_leader
892            .send(LeaderRequest::Invite(port, invitation_send));
893        let invitation = invitation_recv
894            .recv()
895            .await
896            .map_err(|_| ErrorKind::ConnectionReset)?;
897        tracing::debug!(
898            invite_id = ?invitation.address.local_addr.node,
899            "received invitation",
900        );
901        Ok(invitation)
902    }
903
904    /// Joins an existing mesh via an invitation, briding `port` with the
905    /// initial port.
906    pub async fn join(
907        driver: impl SpawnDriver,
908        invitation: Invitation,
909        port: Port,
910    ) -> Result<Self, JoinError> {
911        Self::join_generic(Arc::new(driver), invitation, port).await
912    }
913
914    #[instrument(skip_all, fields(local_id = ?invitation.address.local_addr.node, remote_id = ?invitation.address.remote_addr.node))]
915    async fn join_generic(
916        driver: Arc<dyn SpawnDriver>,
917        invitation: Invitation,
918        port: Port,
919    ) -> Result<Self, JoinError> {
920        let (to_leader_send, to_leader_recv) = channel();
921        let (from_leader_send, from_leader_recv) = channel();
922        let this = Self::with_id(
923            driver,
924            invitation.address.local_addr.node,
925            to_leader_send,
926            from_leader_recv,
927        );
928
929        let handle = this
930            .local_node
931            .add_remote(invitation.address.remote_addr.node);
932        let init_recv = OneshotReceiver::<InitialMessage>::from(this.local_node.add_port(
933            invitation.address.local_addr.port,
934            invitation.address.remote_addr,
935        ));
936
937        start_connection(
938            &this.tasks,
939            &this.local_node,
940            invitation.address.remote_addr.node,
941            handle,
942            UnixSocket::new(this.driver.as_ref(), invitation.fd.into()),
943        );
944
945        let init_message = init_recv.await.map_err(JoinError)?;
946        to_leader_recv.bridge(init_message.leader_send);
947        from_leader_send.bridge(init_message.follower_recv);
948        port.bridge(init_message.user_port);
949
950        Ok(this)
951    }
952
    /// Shuts down the node, waiting for any sent messages to be sent to their
    /// destination.
    ///
    /// After this call, any active ports will no longer be able to receive
    /// messages.
    ///
    /// It is essential to call this before exiting a mesh process; until this
    /// returns, data loss could occur for other mesh nodes.
    ///
    /// This consumes the node. The teardown steps below run in a fixed order
    /// and the method only returns once the IO task has finished.
    pub async fn shutdown(mut self) {
        // Wait for any proxy ports to disassociate.
        self.local_node.wait_for_ports(false).await;
        // Drop all connections to the leader.
        drop(self.to_leader);
        self.local_node.drop_connector();
        // Terminate the leader task by dropping the stored resign sender, if
        // this node ever accepted leadership.
        self.leader_resign_send.get_mut().take();
        // Fail all nodes so that the send threads are dropped.
        self.local_node.fail_all_nodes();
        // Signal the IO task to tear down by dropping this handle to the task
        // queue.
        drop(self.tasks);
        // Wait for the IO task.
        self.io_task.await;
    }
976}
977
/// An error returned by [`UnixNode::join`].
///
/// Wraps the [`RecvError`] produced when waiting for the initial message of
/// an invitation fails; the wrapped error is exposed as the error source.
#[derive(Debug, Error)]
#[error("failed to accept invitation")]
pub struct JoinError(#[source] RecvError);
982
/// The connector used when the mesh needs to connect to a previously-recognized
/// node. Sends a message to the leader node to get a new socket to communicate
/// over.
#[derive(Debug)]
struct Connector {
    /// ID of the node this connector belongs to; included in trace output.
    local_id: NodeId,
    /// Channel on which `LeaderRequest::Connect` requests are sent.
    conn_req_send: Arc<mesh_channel::Sender<LeaderRequest>>,
    /// Handles for remote nodes with an outstanding connect request, keyed by
    /// the remote node's ID. The same map is handed to `run_follower` in
    /// `with_id` (presumably to complete the connection once the leader
    /// responds; confirm against `run_follower`).
    pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>>,
}
992
993impl Connect for Connector {
994    fn connect(&self, node_id: NodeId, handle: RemoteNodeHandle) {
995        tracing::trace!(local_id = ?self.local_id, remote_id = ?node_id, "connecting");
996        let old_request = self.pending_connections.lock().insert(node_id, handle);
997        if old_request.is_some() {
998            panic!("duplicate connection request for {:?}", node_id);
999        }
1000        self.conn_req_send.send(LeaderRequest::Connect(node_id))
1001    }
1002}
1003
1004/// Creates an AF_UNIX socket pair of the appropriate type.
1005fn new_socket_pair() -> Result<(Socket, Socket), io::Error> {
1006    let ty = if USE_SEQPACKET {
1007        socket2::Type::SEQPACKET
1008    } else {
1009        socket2::Type::STREAM
1010    };
1011    Socket::pair(socket2::Domain::UNIX, ty, None)
1012}
1013
/// An AF_UNIX connection to another node.
///
/// The underlying socket is `SOCK_SEQPACKET` when `USE_SEQPACKET` is set and
/// `SOCK_STREAM` otherwise; in the stream case each message is prefixed with
/// its length.
struct UnixSocket {
    /// The polled socket, behind a mutex so that the `&self` send and receive
    /// methods can each poll it.
    socket: Mutex<PolledSocket<Socket>>,
}
1018
/// Control-message buffer used to pass file descriptors with `SCM_RIGHTS`:
/// a `cmsghdr` followed by room for up to 64 raw file descriptors.
///
/// `#[repr(C)]` keeps the fields in declaration order so the struct can be
/// passed to `sendmsg`/`recvmsg` as the `msg_control` buffer.
#[repr(C)]
struct CmsgScmRights {
    /// The control message header (level, type, and length).
    hdr: libc::cmsghdr,
    /// The descriptor payload; only the entries covered by `hdr.cmsg_len` are
    /// meaningful.
    fds: [RawFd; 64],
}
1024
/// Advances a list of I/O slices by `n` bytes.
///
/// Leading buffers that are entirely covered by `n` are removed from `bufs`,
/// and the first remaining buffer is trimmed by the bytes left over. If `n`
/// covers every buffer, `bufs` is left empty.
///
/// This mirrors std's `IoSlice::advance_slices`, except that advancing past
/// the end of the buffers empties `bufs` instead of panicking.
fn advance_slices(bufs: &mut &mut [IoSlice<'_>], n: usize) {
    // Number of leading buffers to remove.
    let mut remove = 0;
    // Bytes still to skip. Counting down from `n`, rather than summing buffer
    // lengths, cannot overflow even if the buffers alias and are very large.
    let mut left = n;
    for buf in bufs.iter() {
        let Some(rest) = left.checked_sub(buf.len()) else {
            break;
        };
        left = rest;
        remove += 1;
    }

    *bufs = &mut std::mem::take(bufs)[remove..];
    if let Some(first) = bufs.first_mut() {
        // The loop stopped at this buffer, so `left < first.len()` and this
        // cannot panic.
        first.advance(left);
    }
}
1055
1056impl UnixSocket {
1057    fn new(driver: &dyn SpawnDriver, fd: Socket) -> Self {
1058        let socket = PolledSocket::new(driver, fd).unwrap();
1059        UnixSocket {
1060            socket: Mutex::new(socket),
1061        }
1062    }
1063
1064    async fn send(&self, msg: &[u8], fds: &[OsResource]) -> io::Result<()> {
1065        if USE_SEQPACKET {
1066            self.send_raw(&mut [IoSlice::new(msg)], fds).await?;
1067        } else {
1068            let len = (msg.len() as u32).to_le_bytes();
1069            let mut iov = [IoSlice::new(&len), IoSlice::new(msg)];
1070            self.send_all_raw(&mut iov, fds).await?;
1071        }
1072        Ok(())
1073    }
1074
1075    async fn send_raw(
1076        &self,
1077        iov: &mut [IoSlice<'_>],
1078        fds: &[OsResource],
1079    ) -> Result<usize, io::Error> {
1080        let n = poll_fn(|cx| {
1081            self.socket
1082                .lock()
1083                .poll_io(cx, InterestSlot::Write, PollEvents::OUT, |socket| {
1084                    try_send(socket.get(), iov, fds)
1085                })
1086        })
1087        .await?;
1088        Ok(n)
1089    }
1090
1091    async fn send_all_raw(
1092        &self,
1093        mut iov: &mut [IoSlice<'_>],
1094        mut fds: &[OsResource],
1095    ) -> Result<(), io::Error> {
1096        while !iov.is_empty() || !fds.is_empty() {
1097            let n = self.send_raw(iov, fds).await?;
1098            advance_slices(&mut iov, n);
1099            fds = &[];
1100        }
1101        Ok(())
1102    }
1103
1104    async fn recv(&self, buf: &mut [u8], fds: &mut Vec<OsResource>) -> io::Result<usize> {
1105        if USE_SEQPACKET {
1106            self.recv_raw(buf, fds).await
1107        } else {
1108            let mut len = [0; 4];
1109            if !self.recv_all_raw(&mut len, fds).await? {
1110                return Ok(0);
1111            }
1112            let len = u32::from_le_bytes(len) as usize;
1113            let buf = buf
1114                .get_mut(..len)
1115                .ok_or_else(|| io::Error::from_raw_os_error(libc::EMSGSIZE))?;
1116            if !self.recv_all_raw(buf, fds).await? {
1117                return Err(ErrorKind::UnexpectedEof.into());
1118            }
1119            Ok(len)
1120        }
1121    }
1122
1123    async fn recv_all_raw(
1124        &self,
1125        buf: &mut [u8],
1126        fds: &mut Vec<OsResource>,
1127    ) -> Result<bool, io::Error> {
1128        let mut read = 0;
1129        while read < buf.len() {
1130            let n = self.recv_raw(&mut buf[read..], fds).await?;
1131            if n == 0 {
1132                if read != 0 {
1133                    return Err(ErrorKind::UnexpectedEof.into());
1134                } else {
1135                    return Ok(false);
1136                }
1137            }
1138            read += n;
1139        }
1140        Ok(true)
1141    }
1142
1143    async fn recv_raw(
1144        &self,
1145        buf: &mut [u8],
1146        fds: &mut Vec<OsResource>,
1147    ) -> Result<usize, io::Error> {
1148        let n = poll_fn(|cx| {
1149            self.socket
1150                .lock()
1151                .poll_io(cx, InterestSlot::Read, PollEvents::IN, |socket| {
1152                    try_recv(socket.get(), buf, fds)
1153                })
1154        })
1155        .await?;
1156        Ok(n)
1157    }
1158
1159    async fn close_write(&self) -> io::Result<()> {
1160        self.socket.lock().get().shutdown(std::net::Shutdown::Write)
1161    }
1162}
1163
1164/// Sends a packet, including the specified file descriptors. May fail with
1165/// ErrorKind::WouldBlock.
1166// x86_64-unknown-linux-musl targets have a different type defn for
1167// `libc::cmsghdr`, hence why these lints are being suppressed.
1168#[allow(clippy::needless_update, clippy::useless_conversion)]
1169fn try_send(socket: &Socket, msg: &[IoSlice<'_>], fds: &[OsResource]) -> io::Result<usize> {
1170    let mut cmsg = CmsgScmRights {
1171        hdr: libc::cmsghdr {
1172            cmsg_level: libc::SOL_SOCKET,
1173            cmsg_type: libc::SCM_RIGHTS,
1174            cmsg_len: (size_of::<libc::cmsghdr>() + size_of_val(fds))
1175                .try_into()
1176                .unwrap(),
1177
1178            ..{
1179                // SAFETY: type has no invariants
1180                unsafe { std::mem::zeroed() }
1181            }
1182        },
1183        fds: [0; 64],
1184    };
1185    for (fdi, fdo) in fds.iter().zip(cmsg.fds.iter_mut()) {
1186        *fdo = match fdi {
1187            OsResource::Fd(fd) => fd.as_raw_fd(),
1188        }
1189    }
1190
1191    // SAFETY: type has no invariants
1192    let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
1193    hdr.msg_iov = msg.as_ptr() as *mut libc::iovec;
1194    hdr.msg_iovlen = msg.len().try_into().unwrap();
1195    hdr.msg_control = if fds.is_empty() {
1196        std::ptr::null_mut()
1197    } else {
1198        std::ptr::from_mut(&mut cmsg).cast::<libc::c_void>()
1199    };
1200    hdr.msg_controllen = if fds.is_empty() { 0 } else { cmsg.hdr.cmsg_len };
1201    // SAFETY: calling with appropriately initialized buffers.
1202    let n = unsafe { libc::sendmsg(socket.as_raw_fd(), &hdr, 0) };
1203    if n < 0 {
1204        return Err(io::Error::last_os_error());
1205    }
1206    Ok(n as usize)
1207}
1208
/// Receives the next packet. Returns the number of bytes read and any file
/// descriptors that were associated with the packet. May fail with
/// ErrorKind::WouldBlock.
///
/// Received descriptors are appended to `fds`. `Ok(0)` is returned when
/// `recvmsg` reads zero bytes.
///
/// # Errors
///
/// Returns the OS error reported by `recvmsg`, `ErrorKind::InvalidData` if
/// the control message is not `SOL_SOCKET`/`SCM_RIGHTS`, and `EMSGSIZE` if
/// the data or control message was truncated. In the truncation case, any
/// descriptors that did arrive have already been appended to `fds`.
///
/// # Panics
///
/// Panics if `buf` is empty, or if a zero-byte read carries control data.
fn try_recv(socket: &Socket, buf: &mut [u8], fds: &mut Vec<OsResource>) -> io::Result<usize> {
    assert!(!buf.is_empty());
    let mut iov = IoSliceMut::new(buf);
    // SAFETY: type has no invariants
    let mut cmsg: CmsgScmRights = unsafe { std::mem::zeroed() };
    // SAFETY: type has no invariants
    let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
    hdr.msg_iov = std::ptr::from_mut(&mut iov).cast::<libc::iovec>();
    hdr.msg_iovlen = 1;
    // Always offer the full control buffer; `recvmsg` updates
    // `msg_controllen` to the amount actually received.
    hdr.msg_control = std::ptr::from_mut(&mut cmsg).cast::<libc::c_void>();
    hdr.msg_controllen = size_of_val(&cmsg) as _;

    // On Linux, automatically set O_CLOEXEC on incoming fds.
    #[cfg(target_os = "linux")]
    let flags = libc::MSG_CMSG_CLOEXEC;
    #[cfg(not(target_os = "linux"))]
    let flags = 0;

    // SAFETY: calling with properly initialized buffers.
    let n = unsafe { libc::recvmsg(socket.as_raw_fd(), &mut hdr, flags) };
    if n < 0 {
        return Err(io::Error::last_os_error());
    }
    if n == 0 {
        assert_eq!(hdr.msg_controllen, 0);
        return Ok(0);
    }

    // Work out how many descriptors the control message carries, if any.
    let fd_count = if hdr.msg_controllen > 0 {
        if cmsg.hdr.cmsg_level != libc::SOL_SOCKET || cmsg.hdr.cmsg_type != libc::SCM_RIGHTS {
            // BUGBUG: need to loop: possible to leak fds
            return Err(ErrorKind::InvalidData.into());
        }
        #[allow(clippy::unnecessary_cast)] // cmsg_len is u32 on musl and usize on gnu.
        {
            (cmsg.hdr.cmsg_len as usize - size_of_val(&cmsg.hdr)) / size_of::<RawFd>()
        }
    } else {
        0
    };

    // Remember where the new descriptors start so only they are touched
    // below.
    let start = fds.len();
    fds.extend(cmsg.fds[..fd_count].iter().map(|x| {
        // SAFETY: according to the contract with the kernel, this
        // fd is now owned by the process.
        OsResource::Fd(unsafe { OwnedFd::from_raw_fd(*x) })
    }));

    // Set O_CLOEXEC on all received fds on platforms that don't support
    // MSG_CMSG_CLOEXEC (set above).
    if !cfg!(target_os = "linux") {
        for OsResource::Fd(fd) in &fds[start..] {
            set_cloexec(fd);
        }
    }

    // Check for truncation only after taking ownership of the fds.
    if hdr.msg_flags & (libc::MSG_TRUNC | libc::MSG_CTRUNC) != 0 {
        return Err(io::Error::from_raw_os_error(libc::EMSGSIZE));
    }
    Ok(n as usize)
}
1274
1275fn set_cloexec(fd: impl AsFd) {
1276    // SAFETY: using fcntl as documented.
1277    unsafe {
1278        let flags = libc::fcntl(fd.as_fd().as_raw_fd(), libc::F_GETFD);
1279        assert!(flags >= 0);
1280        let r = libc::fcntl(
1281            fd.as_fd().as_raw_fd(),
1282            libc::F_SETFD,
1283            flags | libc::FD_CLOEXEC,
1284        );
1285        assert!(r >= 0);
1286    }
1287}
1288
/// End-to-end tests that build small meshes of `UnixNode`s on one driver.
#[cfg(test)]
mod tests {
    use crate::unix::UnixNode;
    use mesh_channel::RecvError;
    use mesh_channel::channel;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use test_with_tracing::test;

    /// A leader invites a follower, one value crosses the bridged channel,
    /// and both nodes shut down cleanly.
    #[async_test]
    async fn test_basic(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (send, recv) = channel::<u32>();
        let invitation = leader.invite(recv.into()).await.unwrap();
        let (send2, mut recv2) = channel::<u32>();
        let follower = UnixNode::join(driver, invitation, send2.into())
            .await
            .unwrap();
        send.send(5);
        assert_eq!(recv2.recv().await.unwrap(), 5);
        // Drop both channel ends before shutting the nodes down.
        drop(send);
        drop(recv2);
        follower.shutdown().await;
        leader.shutdown().await;
    }

    /// A 16 MiB message is delivered intact between two nodes.
    #[cfg(target_os = "linux")]
    #[async_test]
    async fn test_huge_message(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (send, recv) = channel::<Vec<u8>>();
        let invitation = leader.invite(recv.into()).await.unwrap();
        let (send2, mut recv2) = channel::<Vec<u8>>();
        let follower = UnixNode::join(driver, invitation, send2.into())
            .await
            .unwrap();

        let v = vec![0xcc; 16 << 20];
        send.send(v.clone());
        let v2 = recv2.recv().await.unwrap();
        assert_eq!(v, v2);
        follower.shutdown().await;
        leader.shutdown().await;
    }

    /// Runs the shared message-size test over every size from 0 up to
    /// `MAX_SMALL_EVENT_SIZE + 0x1000`.
    #[cfg(target_os = "linux")]
    #[async_test]
    async fn test_message_sizes(driver: DefaultDriver) {
        let (p1, p2) = mesh_node::local_node::Port::new_pair();
        let (p3, p4) = mesh_node::local_node::Port::new_pair();
        let node1 = UnixNode::new(driver.clone());
        let invitation = node1.invite(p2).await.unwrap();
        let _node2 = UnixNode::join(driver.clone(), invitation, p3)
            .await
            .unwrap();

        crate::test_common::test_message_sizes(p1, p4, 0..=super::MAX_SMALL_EVENT_SIZE + 0x1000)
            .await;
    }

    /// The leader can still shut down after a follower and its channels were
    /// dropped without an explicit shutdown.
    #[async_test]
    async fn test_dropped_shutdown(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        {
            let (_send, recv) = channel::<u32>();
            let invitation = leader.invite(recv.into()).await.unwrap();
            let (send2, _recv2) = channel::<u32>();
            let _follower = UnixNode::join(driver, invitation, send2.into())
                .await
                .unwrap();
            // The follower and all channel ends are dropped here.
        }
        leader.shutdown().await;
    }

    /// A value sent just before the sender's node shuts down is still
    /// received on the other node.
    #[async_test]
    async fn test_send_shutdown(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (send, mut recv) = channel::<u32>();
        let invitation = leader.invite(send.into()).await.unwrap();
        let (send2, recv2) = channel::<u32>();
        let follower = UnixNode::join(driver, invitation, recv2.into())
            .await
            .unwrap();
        send2.send(5);
        drop(send2);
        follower.shutdown().await;
        assert_eq!(recv.recv().await.unwrap(), 5);
    }

    /// Discarding an invitation without joining surfaces as a receive error
    /// on the channel that was bridged to it.
    #[async_test]
    async fn test_failed_invitation(driver: DefaultDriver) {
        let leader = UnixNode::new(driver);
        let (send, mut recv) = channel::<()>();
        // The invitation is dropped immediately.
        leader.invite(send.into()).await.unwrap();
        assert!(matches!(
            recv.recv().await.unwrap_err(),
            RecvError::Error(_)
        ));
        drop(recv);
        leader.shutdown().await;
    }

    /// Three nodes: a value sent on node 3 reaches node 2 through a bridge
    /// made on node 1.
    #[async_test]
    async fn test_three(driver: DefaultDriver) {
        let (p1, p2) = channel::<u32>();
        let (p3, mut p4) = channel::<u32>();
        let (p5, p6) = channel::<u32>();
        let (p7, p8) = channel::<u32>();

        let node1 = UnixNode::new(driver.clone());

        let invitation = node1.invite(p2.into()).await.unwrap();
        let node2 = UnixNode::join(driver.clone(), invitation, p3.into())
            .await
            .unwrap();

        let invitation = node1.invite(p5.into()).await.unwrap();
        let node3 = UnixNode::join(driver, invitation, p8.into()).await.unwrap();

        // Bridge the two channels whose other ends were handed to the
        // invitations, linking p7 to p4.
        p1.bridge(p6);

        p7.send(5);

        assert_eq!(p4.recv().await.unwrap(), 5);
        drop(p4);
        drop(p7);
        futures::join!(node2.shutdown(), node3.shutdown());
        node1.shutdown().await;
    }

    /// Leadership is handed from node 1 to node 4, after which node 1 shuts
    /// down first and the remaining nodes follow.
    #[async_test]
    async fn test_handoff_leader(driver: DefaultDriver) {
        let (p1, p2) = channel::<u32>();
        let (p3, p4) = channel::<u32>();
        let (p5, p6) = channel::<u32>();
        let (p7, p8) = channel::<u32>();
        let (p9, p10) = channel();
        let (p11, mut p12) = channel();

        let node1 = UnixNode::new(driver.clone());

        let invitation = node1.invite(p2.into()).await.unwrap();
        let node2 = UnixNode::join(driver.clone(), invitation, p3.into())
            .await
            .unwrap();

        let invitation = node1.invite(p5.into()).await.unwrap();
        let node3 = UnixNode::join(driver.clone(), invitation, p8.into())
            .await
            .unwrap();

        let invitation = node1.invite(p10.into()).await.unwrap();
        let node4 = UnixNode::join(driver, invitation, p11.into())
            .await
            .unwrap();

        // Send the leadership offer from node 1 to node 4 over the mesh and
        // accept it there.
        p9.send(node1.offer_leadership());
        node4.accept_leadership(p12.recv().await.unwrap());
        drop(p9);
        drop(p12);
        p1.bridge(p6);

        // NOTE(review): this is a blocking sleep inside an async test,
        // presumably to give the handoff time to complete; confirm whether an
        // async timer would be more appropriate.
        std::thread::sleep(std::time::Duration::from_millis(200));

        node1.shutdown().await;
        drop(p4);
        drop(p7);
        node2.shutdown().await;
        node3.shutdown().await;

        // NOTE(review): blocking sleep again, before the new leader shuts
        // down.
        std::thread::sleep(std::time::Duration::from_millis(200));

        node4.shutdown().await;
    }
}