Skip to main content

mesh_remote/
unix_node.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Unix socket-based mesh node implementation.
5//!
//! Each pair of nodes communicates using a single, bidirectional Unix socket. On
//! platforms that support it (Linux), a `SOCK_SEQPACKET` socket is used, which
//! provides message framing and atomic message sends. Otherwise, a
//! `SOCK_STREAM` socket is used, and the protocol includes a message size.
10//!
11//! File descriptors are sent between nodes using the `SCM_RIGHTS` functionality
12//! of Unix sockets.
13
14#![cfg(unix)]
15
16#[cfg(target_os = "linux")]
17mod memfd;
18
19use crate::common::InvitationAddress;
20use crate::protocol;
21use crate::unix_common::try_recv;
22use crate::unix_common::try_send;
23use futures::FutureExt;
24use futures::StreamExt;
25use futures::channel::mpsc;
26use futures::future;
27use futures::future::BoxFuture;
28use io::ErrorKind;
29use mesh_channel::OneshotReceiver;
30use mesh_channel::OneshotSender;
31use mesh_channel::RecvError;
32use mesh_channel::channel;
33use mesh_channel::oneshot;
34use mesh_node::common::Address;
35use mesh_node::common::NodeId;
36use mesh_node::common::PortId;
37use mesh_node::local_node::Connect;
38use mesh_node::local_node::LocalNode;
39use mesh_node::local_node::OutgoingEvent;
40use mesh_node::local_node::Port;
41use mesh_node::local_node::RemoteNodeHandle;
42use mesh_node::local_node::SendEvent;
43use mesh_node::resource::OsResource;
44use mesh_node::resource::Resource;
45use mesh_protobuf::Protobuf;
46use pal_async::driver::SpawnDriver;
47use pal_async::interest::InterestSlot;
48use pal_async::interest::PollEvents;
49use pal_async::socket::PolledSocket;
50use pal_async::task::Spawn;
51use pal_async::task::Task;
52use parking_lot::Mutex;
53use socket2::Socket;
54use std::collections::HashMap;
55use std::collections::VecDeque;
56use std::fmt::Debug;
57use std::future::Future;
58use std::future::poll_fn;
59use std::io;
60use std::io::IoSlice;
61use std::os::unix::prelude::*;
62use std::pin::pin;
63use std::sync::Arc;
64use thiserror::Error;
65use tracing::instrument;
66use unicycle::FuturesUnordered;
67use zerocopy::FromBytes;
68use zerocopy::FromZeros;
69use zerocopy::IntoBytes;
70
/// Whether connections use a `SOCK_SEQPACKET` socket (`true`) or a
/// `SOCK_STREAM` socket (`false`).
///
/// `SOCK_SEQPACKET` is chosen wherever it exists because the socket itself
/// keeps message boundaries, so we do not have to. Most importantly, that
/// makes it simple to let several threads send messages at the same time.
const USE_SEQPACKET: bool = cfg!(target_os = "linux");

/// Upper bound on the size of one packet.
///
/// Linux sends anything larger out of line through a memfd. Other OSes have
/// no such fallback and simply fail, so they get a larger bound.
///
/// Both values are arbitrary and have not been tuned for performance.
#[cfg(target_os = "linux")]
const MAX_PACKET_SIZE: usize = 0x4000;

/// Upper bound on the size of one packet.
///
/// Linux sends anything larger out of line through a memfd. Other OSes have
/// no such fallback and simply fail, so they get a larger bound.
///
/// Both values are arbitrary and have not been tuned for performance.
#[cfg(not(target_os = "linux"))]
const MAX_PACKET_SIZE: usize = 0x40000;
89
90const MAX_SMALL_EVENT_SIZE: usize = MAX_PACKET_SIZE - size_of::<protocol::PacketHeader>();
91
/// A node within a mesh that uses Unix sockets to communicate.
///
/// Each pairwise connection between two nodes in the mesh uses a single
/// bidirectional socket pair, with each of the two nodes holding one end.
///
/// If one node needs to send data to another but does not have a connection, it
/// sends a request to the leader node to establish a connection. The leader
/// creates a new socket pair and sends one end to each of the nodes, which the
/// two nodes can use to communicate.
pub struct UnixNode {
    /// Driver used to spawn the IO task and to wrap sockets handed to the
    /// leader task.
    driver: Arc<dyn SpawnDriver>,
    /// The local mesh node served by this transport.
    local_node: Arc<LocalNode>,
    /// Channel for sending requests to the current leader.
    to_leader: Arc<mesh_channel::Sender<LeaderRequest>>,
    /// Queue of tasks executed by `io_task`.
    tasks: Arc<mesh_channel::Sender<SmallTask>>,
    /// The spawned task that runs everything queued on `tasks`.
    io_task: Task<()>,
    // TODO: consider reducing type complexity?
    // Set by `UnixNode::new` when this node starts as leader; sending a
    // `(NodeId, Sender<Followers>)` on it asks `run_leader` to resign in favor
    // of that node.
    leader_resign_send:
        Mutex<Option<Arc<mesh_channel::Sender<(NodeId, mesh_channel::Sender<Followers>)>>>>,

    // meaningful drop: dropping this sender presumably completes the
    // `drop_recv` half that the IO task selects on, stopping the IO task.
    // NOTE: fields drop in declaration order, so keep this one last.
    _drop_send: OneshotSender<()>,
}
114
/// Error returned when creating a mesh invitation fails.
///
/// Produced by `UnixMeshInviter::invite` when the leader drops the request
/// without sending back an invitation (for example, because the node shut
/// down or the leader could not create a socket pair).
#[derive(Debug, Error)]
#[error("mesh node shut down before invitation could be created")]
pub struct InviteError;
119
/// A Clone + Send handle for creating mesh invitations.
///
/// Extracted from a [`UnixNode`] via [`UnixNode::inviter()`]. This delegates
/// invitation requests to the node's leader task via the existing
/// `LeaderRequest::Invite` channel.
///
/// This allows creating invitations from any async task without holding a
/// reference to the node.
#[derive(Clone)]
pub(crate) struct UnixMeshInviter {
    /// Channel to the leader; `invite` sends `LeaderRequest::Invite` on it.
    to_leader: mesh_channel::Sender<LeaderRequest>,
}
132
133impl UnixMeshInviter {
134    /// Create a mesh invitation for a new node to join.
135    ///
136    /// The returned [`Invitation`] can be sent to another process to allow
137    /// it to join the mesh.
138    pub(crate) async fn invite(&self, port: Port) -> Result<Invitation, InviteError> {
139        let (invitation_send, mut invitation_recv) = channel();
140        self.to_leader
141            .send(LeaderRequest::Invite(port, invitation_send));
142        invitation_recv.recv().await.map_err(|_| InviteError)
143    }
144}
145
/// A request sent by a follower to the leader.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum LeaderRequest {
    /// Ask the leader to connect the requesting node to the given node. The
    /// leader replies to both nodes with `FollowerRequest::Connect`.
    Connect(NodeId),
    /// Ask the leader to invite a new node whose initial user port is the
    /// given `Port`. The invitation is sent back on the given sender.
    Invite(Port, mesh_channel::Sender<Invitation>),
}
152
/// A request sent by the leader to a follower.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum FollowerRequest {
    /// Use the given socket to talk to the given node. `None` means the leader
    /// could not set up the connection.
    Connect(
        NodeId,
        #[mesh(
            encoding = "mesh_protobuf::encoding::OptionField<mesh_protobuf::encoding::ResourceField<OwnedFd>>"
        )]
        Option<Socket>,
    ),
}
164
/// The complete follower set, handed from a resigning leader to the new one.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct Followers {
    /// One entry per follower: its node ID, the channel on which the leader
    /// receives its requests, and the channel on which the leader sends it
    /// requests.
    list: Vec<(
        NodeId,
        mesh_channel::Receiver<LeaderRequest>,
        mesh_channel::Sender<FollowerRequest>,
    )>,
}
174
/// First message the leader sends to an invited node over its initial port.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
struct InitialMessage {
    /// Channel on which the new node sends requests to the leader.
    leader_send: mesh_channel::Sender<LeaderRequest>,
    /// Channel on which the new node receives requests from the leader.
    follower_recv: mesh_channel::Receiver<FollowerRequest>,
    /// The port passed to the invite request.
    user_port: Port,
}
182
183/// Processes incoming requests from the leader to a follower. Currently the only
184/// such request is to add a connection to another node.
185#[instrument(skip_all, fields(local_id = ?local_node.id()))]
186async fn run_follower(
187    driver: &dyn SpawnDriver,
188    local_node: &Arc<LocalNode>,
189    mut recv: mesh_channel::Receiver<FollowerRequest>,
190    pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>>,
191    tasks: &mesh_channel::Sender<SmallTask>,
192) {
193    while let Ok(req) = recv.recv().await {
194        match req {
195            FollowerRequest::Connect(target_id, fd) => {
196                tracing::debug!(?target_id, "got connection request from leader");
197                let handle = pending_connections.lock().remove(&target_id);
198                let handle = handle.unwrap_or_else(|| local_node.get_remote_handle(target_id));
199
200                if let Some(fd) = fd {
201                    start_connection(
202                        tasks,
203                        local_node,
204                        target_id,
205                        handle,
206                        UnixSocket::new(driver, fd),
207                    );
208                } else {
209                    tracing::warn!(?target_id, "leader provided failed connection");
210                }
211            }
212        }
213    }
214}
215
/// Processes incoming requests from a follower to the leader. Runs until there
/// are no more followers or until the leader is asked to transfer power to
/// another node via `resign_recv`.
///
/// If `resign_recv` is closed, the loop also ends. In that case the leader
/// exits without handing its followers to anyone.
#[instrument(skip_all, fields(local_id = ?local_node.id()))]
async fn run_leader(
    driver: &dyn SpawnDriver,
    local_node: &Arc<LocalNode>,
    mut resign_recv: mesh_channel::Receiver<(NodeId, mesh_channel::Sender<Followers>)>,
    followers: Followers,
    tasks: &mesh_channel::Sender<SmallTask>,
) {
    // `senders` is keyed by follower ID for targeted replies. `receivers` is a
    // list so that all followers can be polled at once with `select_all`, which
    // reports the index of the receiver that became ready.
    let mut senders = HashMap::new();
    let mut receivers = Vec::new();
    for (remote_id, recv, send) in followers.list {
        receivers.push((remote_id, recv));
        senders.insert(remote_id, send);
    }

    // The loop exits only through `break`, which yields the resignation
    // request: `Some((new_leader_id, sink))`, or `None` if `resign_recv`
    // closed.
    let new_leader_info = loop {
        // `select_all` panics on an empty input, and with no followers there
        // is nothing left to serve.
        if receivers.is_empty() {
            return;
        }
        let recvs = receivers
            .iter_mut()
            .map(|(_, recv)| poll_fn(|cx| recv.poll_recv(cx)));
        let (req, index, _) = futures::select! { // merge semantics
            r = resign_recv.next() => break r,
            r = future::select_all(recvs).fuse() => r,
        };
        // `index` is the position in `receivers` of the follower that sent `req`.
        let remote_id = receivers[index].0;
        match req {
            Ok(req) => match req {
                LeaderRequest::Connect(target_id) => {
                    // Give one end of a new socket pair to the target and the
                    // other to the requester. The requester always gets a
                    // reply; `None` tells it the connection failed.
                    tracing::debug!(?target_id, ?remote_id, "connection request");
                    let remote = senders
                        .get(&remote_id)
                        .expect("sender must exist to receive from it");
                    let mut fd = None;
                    if let Some(target) = senders.get(&target_id) {
                        match new_socket_pair() {
                            Ok((left, right)) => {
                                tracing::trace!(?target, "send to");
                                target.send(FollowerRequest::Connect(remote_id, Some(left)));
                                fd = Some(right);
                            }
                            Err(err) => {
                                tracing::warn!(
                                    ?target_id,
                                    ?remote_id,
                                    error = &err as &dyn std::error::Error,
                                    "failed to create socket pair for connection request"
                                );
                            }
                        }
                    } else {
                        tracing::warn!(?target_id, ?remote_id, "could not find target for remote");
                    }
                    remote.send(FollowerRequest::Connect(target_id, fd));
                }
                LeaderRequest::Invite(port, send) => {
                    // Register a brand-new node. The leader keeps one socket
                    // end (connected immediately) and places the other in the
                    // invitation, then adds the new node as a follower.
                    tracing::debug!(?remote_id, "invitation request");
                    match new_socket_pair() {
                        Ok((left, right)) => {
                            let (leader_send, leader_recv) = channel();
                            let (follower_send, follower_recv) = channel();
                            let remote_addr = Address {
                                node: NodeId::new(),
                                port: PortId::new(),
                            };
                            let local_port_id = PortId::new();
                            let handle = local_node.add_remote(remote_addr.node);
                            start_connection(
                                tasks,
                                local_node,
                                remote_addr.node,
                                handle,
                                UnixSocket::new(driver, left),
                            );
                            // The initial message gives the new node its
                            // channels to and from the leader, plus the
                            // caller's port.
                            let init_send = OneshotSender::<InitialMessage>::from(
                                local_node.add_port(local_port_id, remote_addr),
                            );
                            init_send.send(InitialMessage {
                                leader_send,
                                follower_recv,
                                user_port: port,
                            });
                            let invitation = Invitation {
                                address: InvitationAddress {
                                    local_addr: remote_addr,
                                    remote_addr: Address {
                                        node: local_node.id(),
                                        port: local_port_id,
                                    },
                                },
                                fd: right.into(),
                            };
                            tracing::debug!(
                                invite_id = ?invitation.address.local_addr.node,
                                ?remote_id,
                                "inviting",
                            );
                            send.send(invitation);
                            senders.insert(remote_addr.node, follower_send);
                            receivers.push((remote_addr.node, leader_recv));
                        }
                        Err(err) => {
                            // Dropping `send` here makes the requester's
                            // receive fail.
                            tracing::error!(
                                error = &err as &dyn std::error::Error,
                                "failed to create socket pair",
                            );
                        }
                    }
                }
            },
            Err(err) => {
                // The follower's channel closed or failed, so stop tracking
                // it. `swap_remove` reorders `receivers`, which is fine because
                // IDs are read back via `receivers[index]` on every iteration.
                if let RecvError::Error(err) = err {
                    tracing::debug!(
                        ?remote_id,
                        error = &err as &dyn std::error::Error,
                        "leader connection to remote failed"
                    );
                }
                senders.remove(&remote_id);
                receivers.swap_remove(index);
            }
        }
    };

    if let Some((new_leader_id, new_leader_followers_sink)) = new_leader_info {
        if let Some(new_leader_send) = senders.get(&new_leader_id) {
            tracing::debug!(?new_leader_id, "resigning leadership");
            // Ensure there is a connection between every follower and the new
            // leader.
            for (remote_id, send) in senders.iter() {
                if new_leader_id != *remote_id {
                    match new_socket_pair() {
                        Ok((left, right)) => {
                            send.send(FollowerRequest::Connect(new_leader_id, Some(left)));
                            new_leader_send.send(FollowerRequest::Connect(*remote_id, Some(right)));
                        }
                        Err(err) => {
                            tracing::error!(
                                ?new_leader_id,
                                error = &err as &dyn std::error::Error,
                                "failed to connect node to new leader, mesh is leaderless",
                            );
                            return;
                        }
                    }
                }
            }

            // Send all the followers to the new leader. `senders` and
            // `receivers` always hold the same set of IDs.
            let mut followers = Vec::new();
            for (remote_id, recv) in receivers {
                let send = senders
                    .remove(&remote_id)
                    .expect("should be in sync with receivers");
                followers.push((remote_id, recv, send));
            }
            new_leader_followers_sink.send(Followers { list: followers });
        } else {
            tracing::error!(?new_leader_id, "new leader is unknown, mesh is leaderless");
        }
    }
}
382
/// A named unit of work, implemented as a boxed future, that is sent to the
/// node's IO task to be run there.
struct SmallTask {
    /// Name reported in the IO task's start/end trace events.
    name: &'static str,
    /// The work itself.
    future: BoxFuture<'static, ()>,
}
389
390impl Debug for SmallTask {
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        f.pad("SmallTask")
393    }
394}
395
396impl SmallTask {
397    fn new(name: &'static str, f: impl 'static + Send + Future<Output = ()>) -> Self {
398        Self {
399            name,
400            future: Box::pin(f),
401        }
402    }
403}
404
/// An invitation allowing another process to join the mesh.
///
/// Created by [`UnixNode::invite`].
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct Invitation {
    /// The common invitation addresses. `local_addr` is the address the
    /// invited node takes for itself; `remote_addr` is the port on the
    /// inviting leader node that it starts out connected to.
    pub address: InvitationAddress,
    /// The Unix socket used to initiate communications with the mesh. The
    /// other end of the socket pair is held by the leader that created the
    /// invitation.
    pub fd: OwnedFd,
}
416
/// Work items for a connection's send task (`run_send`).
#[derive(Debug)]
enum SenderCommand {
    /// Write `packet` to the socket, attaching `fds`.
    Send {
        packet: Vec<u8>,
        fds: Vec<OsResource>,
    },
    /// The peer has received fds we sent earlier; drop the oldest `count`
    /// retained ones (only populated on macOS).
    ReleaseFds {
        count: usize,
    },
}
427
/// The outgoing half of a connection, handed to the local node as its
/// [`SendEvent`] sink.
#[derive(Clone)]
struct PacketSender {
    /// Queue drained by the connection's send task.
    send: mpsc::UnboundedSender<SenderCommand>,
    /// The connection's socket, shared with `run_connection`, used for
    /// direct (non-queued) sends.
    socket: Arc<UnixSocket>,
}
433
434impl SendEvent for PacketSender {
435    fn event(&self, event: OutgoingEvent<'_>) {
436        let (packet, fds) = match serialize_event(event) {
437            Ok(r) => r,
438            Err(err) => {
439                // FUTURE: fail the port or connection instead?
440                tracing::error!(
441                    error = &err as &dyn std::error::Error,
442                    "failed to serialize event"
443                );
444                return;
445            }
446        };
447
448        // If using SOCK_SEQPACKET, try to send the packet immediately. If this
449        // fails (likely due to EAGAIN), send the packet to the asynchronous
450        // task for deferred processing.
451        //
452        // When SOCK_STREAM is in use, this optimization cannot be tried,
453        // because for stream sockets we may need to issue multiple writes to
454        // send the whole message, and those writes cannot be interleaved
455        // correctly.
456        //
457        // N.B. This can lead to out of order messages. The event protocol is
458        //      responsible for handling this condition.
459        if !USE_SEQPACKET
460            || try_send(
461                self.socket.socket.lock().get().as_fd(),
462                &[IoSlice::new(&packet)],
463                &fds,
464            )
465            .is_err()
466        {
467            let _ = self
468                .send
469                .unbounded_send(SenderCommand::Send { packet, fds });
470        }
471    }
472}
473
474fn serialize_event(event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
475    // Serialize the event to a memfd if it's too large to send inline.
476    if event.len() > MAX_SMALL_EVENT_SIZE {
477        return serialize_large_event(event);
478    }
479
480    // Serialize the event to a byte vector.
481    let cap = size_of::<protocol::PacketHeader>() + event.len();
482    let mut packet = Vec::with_capacity(cap);
483    packet.extend_from_slice(
484        protocol::PacketHeader {
485            packet_type: protocol::PacketType::EVENT,
486            reserved: [0; 7],
487        }
488        .as_bytes(),
489    );
490    let mut fds = Vec::new();
491    event.write_to(&mut packet, &mut fds);
492    assert_eq!(packet.len(), cap);
493    Ok((packet, fds))
494}
495
496#[cfg(target_os = "linux")]
497fn serialize_large_event(event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
498    let packet = protocol::PacketHeader {
499        packet_type: protocol::PacketType::LARGE_EVENT,
500        ..FromZeros::new_zeroed()
501    }
502    .as_bytes()
503    .to_vec();
504
505    let mut fds = Vec::new();
506
507    let mut memfd = memfd::MemfdBuilder::new(event.len())?;
508    event.write_to(&mut io::Cursor::new(&mut *memfd), &mut fds);
509    fds.insert(0, OsResource::Fd(memfd.seal()?.into()));
510
511    Ok((packet, fds))
512}
513
/// Large events need memfd, which only exists on Linux, so elsewhere they are
/// rejected.
#[cfg(not(target_os = "linux"))]
fn serialize_large_event(_event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
    let unsupported = io::Error::new(ErrorKind::Unsupported, "event too large for this OS");
    Err(unsupported)
}
521
522impl Drop for PacketSender {
523    fn drop(&mut self) {
524        // Explicitly close the send channel so that the send task returns, even
525        // though the send channel is also in use by the receive task.
526        self.send.close_channel();
527    }
528}
529
530/// Starts a connection processing task.
531fn start_connection(
532    tasks: &mesh_channel::Sender<SmallTask>,
533    local_node: &Arc<LocalNode>,
534    remote_id: NodeId,
535    handle: RemoteNodeHandle,
536    socket: UnixSocket,
537) {
538    #[expect(clippy::disallowed_methods)] // TODO
539    let (send, recv) = mpsc::unbounded();
540    let socket = Arc::new(socket);
541    let sender = PacketSender {
542        send: send.clone(),
543        socket: socket.clone(),
544    };
545    if handle.connect(sender) {
546        let task = SmallTask::new("run_connection", {
547            let local_node = local_node.clone();
548            run_connection(local_node, remote_id, send, recv, socket, handle)
549        });
550        tasks.send(task);
551        tracing::debug!(?remote_id, "connected");
552    } else {
553        // N.B. This is an expected condition in many scenarios, since the
554        //      leader does not track which connections have already been
555        //      made and so will often send duplicate connection requests.
556        tracing::debug!(?remote_id, "duplicate connection");
557    }
558}
559
560/// Runs the packet processing loop.
561#[instrument(skip_all, fields(local_id = ?local_node.id(), remote_id = ?remote_id))]
562async fn run_connection(
563    local_node: Arc<LocalNode>,
564    remote_id: NodeId,
565    send_send: mpsc::UnboundedSender<SenderCommand>,
566    send_recv: mpsc::UnboundedReceiver<SenderCommand>,
567    socket: Arc<UnixSocket>,
568    handle: RemoteNodeHandle,
569) {
570    let mut retained_fds = VecDeque::new();
571    let mut recv = pin!(
572        async {
573            let r = run_receive(&local_node, &remote_id, &socket, &send_send).await;
574            match &r {
575                Ok(_) => {
576                    tracing::debug!("incoming socket disconnected");
577                }
578                Err(err) => {
579                    tracing::error!(error = err as &dyn std::error::Error, "error receiving");
580                }
581            }
582            r
583        }
584        .fuse()
585    );
586    let mut send = pin!(
587        async {
588            match run_send(send_recv, &socket, &mut retained_fds).await {
589                Ok(_) => {
590                    tracing::debug!("sending is done");
591                }
592                Err(err) => {
593                    tracing::error!(error = &err as &dyn std::error::Error, "failed send");
594                }
595            }
596        }
597        .fuse()
598    );
599    let r = futures::select! { // race semantics
600        r = recv => {
601            // Notify the remote node that no more data will be sent.
602            tracing::trace!("read complete, shutting down writes");
603            let _ = socket.close_write().await;
604            r
605        }
606        _ = send => {
607            match socket.close_write().await {
608                Ok(()) => {
609                    tracing::trace!("shutdown writes, waiting for reads");
610                    recv.await
611                }
612                Err(err) => {
613                    tracing::error!(
614                        error = &err as &dyn std::error::Error,
615                        "failed to shutdown writes, aborting connection",
616                    );
617                    Err(ReceiveError::Io(err))
618                }
619            }
620        }
621    };
622    tracing::trace!("connection done");
623    match r {
624        Ok(()) => handle.disconnect(),
625        Err(err) => handle.fail(err),
626    }
627}
628
/// Errors that end a connection's receive side.
#[derive(Debug, Error)]
enum ReceiveError {
    /// A socket operation failed (also used when shutting down writes fails).
    #[error("i/o error")]
    Io(#[from] io::Error),
    /// The packet was shorter than a `PacketHeader`.
    #[error("missing packet header")]
    NoHeader,
    /// A `RELEASE_FDS` packet was shorter than `ReleaseFds`.
    #[error("release fds packet too small")]
    BadReleaseFds,
    /// The header carried a packet type this side does not handle.
    #[error("unknown packet type {0:?}")]
    UnknownPacketType(protocol::PacketType),
    /// A `LARGE_EVENT` packet arrived without any fds.
    #[cfg(target_os = "linux")]
    #[error("memfd file descriptor not sent for large event")]
    MissingMemfd,
    /// The fd received for a `LARGE_EVENT` could not be opened as a sealed memfd.
    #[cfg(target_os = "linux")]
    #[error("failed to map memfd")]
    Memfd(#[source] io::Error),
}
646
647/// Handles receive processing for the socket.
648async fn run_receive(
649    local_node: &LocalNode,
650    remote_id: &NodeId,
651    socket: &UnixSocket,
652    send: &mpsc::UnboundedSender<SenderCommand>,
653) -> Result<(), ReceiveError> {
654    let mut buf = vec![0; MAX_PACKET_SIZE];
655    let mut fds = Vec::new();
656    loop {
657        let len = socket.recv(&mut buf, &mut fds).await?;
658        if len == 0 {
659            break;
660        }
661        if cfg!(target_os = "macos") && !fds.is_empty() {
662            // Tell the opposite endpoint to release the fds it sent.
663            let _ = send.unbounded_send(SenderCommand::Send {
664                packet: protocol::ReleaseFds {
665                    header: protocol::PacketHeader {
666                        packet_type: protocol::PacketType::RELEASE_FDS,
667                        ..FromZeros::new_zeroed()
668                    },
669                    count: fds.len() as u64,
670                }
671                .as_bytes()
672                .to_vec(),
673                fds: Vec::new(),
674            });
675        }
676
677        let buf = &buf[..len];
678        let header = protocol::PacketHeader::read_from_prefix(buf)
679            .map_err(|_| ReceiveError::NoHeader)?
680            .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
681        match header.packet_type {
682            protocol::PacketType::EVENT => {
683                local_node.event(remote_id, &buf[size_of_val(&header)..], &mut fds);
684                fds.clear();
685            }
686            protocol::PacketType::RELEASE_FDS => {
687                let release_fds = protocol::ReleaseFds::read_from_prefix(buf)
688                    .map_err(|_| ReceiveError::BadReleaseFds)?
689                    .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
690                let _ = send.unbounded_send(SenderCommand::ReleaseFds {
691                    count: release_fds.count as usize,
692                });
693            }
694            #[cfg(target_os = "linux")]
695            protocol::PacketType::LARGE_EVENT => {
696                if fds.is_empty() {
697                    return Err(ReceiveError::MissingMemfd);
698                }
699                let OsResource::Fd(fd) = fds.remove(0);
700                let memfd = memfd::SealedMemfd::new(fd.into()).map_err(ReceiveError::Memfd)?;
701                local_node.event(remote_id, &memfd, &mut fds);
702                fds.clear();
703            }
704            ty => {
705                return Err(ReceiveError::UnknownPacketType(ty));
706            }
707        }
708    }
709    Ok(())
710}
711
/// Protocol violations detected by a connection's send task.
#[derive(Debug, Error)]
enum ProtocolError {
    /// The peer asked to release more fds than are currently retained.
    #[error("request to release too many fds")]
    ReleasingTooManyFds,
}
717
718/// Handles send processing for the socket.
719async fn run_send(
720    mut recv: mpsc::UnboundedReceiver<SenderCommand>,
721    socket: &UnixSocket,
722    retained_fds: &mut VecDeque<OsResource>,
723) -> io::Result<()> {
724    while let Some(command) = recv.next().await {
725        match command {
726            SenderCommand::Send { packet, fds } => {
727                match socket.send(&packet, &fds).await {
728                    Ok(_) => (),
729                    Err(err) => {
730                        tracing::error!(
731                            fd_count = fds.len(),
732                            packet_len = packet.len(),
733                            "failed to send packet"
734                        );
735                        return Err(err);
736                    }
737                }
738                if cfg!(target_os = "macos") {
739                    // MacOS has a bug where it prematurely closes Unix sockets
740                    // if a file descriptor to one is closed while it is also in
741                    // the process of being sent across another Unix socket.
742                    // Retain the fds until the opposite endpoint sends a reply
743                    // message.
744                    if !fds.is_empty() {
745                        retained_fds.extend(fds);
746                    }
747                }
748            }
749            SenderCommand::ReleaseFds { count } => {
750                if retained_fds.len() < count {
751                    return Err(io::Error::other(ProtocolError::ReleasingTooManyFds));
752                }
753                retained_fds.drain(..count);
754            }
755        }
756    }
757    Ok(())
758}
759
/// An offer to take over as leader of the mesh. One trusted process must be the
/// leader at all times or the mesh will fail.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct LeadershipOffer {
    // Same type as `run_leader`'s `resign_recv` channel; presumably connected
    // to the offering leader's resign channel. TODO confirm against
    // `offer_leadership`.
    send: mesh_channel::Sender<(NodeId, mesh_channel::Sender<Followers>)>,
}
767
768impl UnixNode {
769    /// Creates a new Unix node mesh with this node as the leader.
770    pub fn new(driver: impl SpawnDriver) -> Self {
771        let (to_leader_send, to_leader_recv) = channel();
772        let (from_leader_send, from_leader_recv) = channel();
773        let this = Self::with_id(
774            Arc::new(driver),
775            NodeId::new(),
776            to_leader_send,
777            from_leader_recv,
778        );
779
780        // Start a leader task. At this point the leader has only one follower, itself.
781        let (resign_send, resign_recv) = channel();
782        let resign_send = Arc::new(resign_send);
783        let followers = Followers {
784            list: vec![(this.local_node.id(), to_leader_recv, from_leader_send)],
785        };
786        let task = SmallTask::new("run_leader", {
787            let local_node = this.local_node.clone();
788            let tasks = this.tasks.clone();
789            let driver = this.driver.clone();
790            async move { run_leader(driver.as_ref(), &local_node, resign_recv, followers, &tasks).await }
791        });
792        this.tasks.send(task);
793        *this.leader_resign_send.lock() = Some(resign_send);
794
795        this
796    }
797
798    /// Gets the node ID. This is mostly useful for diagnostics.
799    pub fn id(&self) -> NodeId {
800        self.local_node.id()
801    }
802
    /// Creates a node with `id` using `to_leader` and `from_leader` to
    /// communicate with the leader node.
    ///
    /// Spawns the node's IO task on `driver` and queues a `run_follower` task
    /// on it that handles requests arriving on `from_leader`. The returned
    /// node has no leader task of its own (`leader_resign_send` is `None`).
    fn with_id(
        driver: Arc<dyn SpawnDriver>,
        id: NodeId,
        to_leader: mesh_channel::Sender<LeaderRequest>,
        from_leader: mesh_channel::Receiver<FollowerRequest>,
    ) -> Self {
        let to_leader = Arc::new(to_leader);
        // Shared with `Connector` (defined elsewhere) and `run_follower`, which
        // takes the handle for a node out of it when the leader supplies a
        // socket for that node.
        let pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>> = Default::default();
        let local_node = Arc::new(LocalNode::with_id(
            id,
            Box::new(Connector {
                local_id: id,
                conn_req_send: to_leader.clone(),
                pending_connections: pending_connections.clone(),
            }),
        ));
        // Queue of tasks that the IO task below runs.
        let (task_send, mut task_recv) = channel::<SmallTask>();
        let task_send = Arc::new(task_send);
        // `drop_send` is stored in the node; the IO task stops when
        // `drop_recv` resolves (presumably when `drop_send` is dropped).
        let (drop_send, drop_recv) = oneshot();

        // Spawn, via `driver`, the task that runs all queued tasks concurrently.
        let io_task = driver.spawn("unix-mesh-io", async move {
            let process = async {
                let mut futs = FuturesUnordered::new();
                loop {
                    futures::select! { // merge semantics
                        _ = futs.next() => {},
                        task = task_recv.select_next_some() => {
                            futs.push(async move {
                                tracing::trace!(?id, name = task.name, "task start");
                                task.future.await;
                                tracing::trace!(?id, name = task.name, "task end");
                            });
                        },
                        // Taken once every branch above has terminated.
                        complete => break,
                    };
                }
            };
            // Finish when either all work is done or the node is dropped.
            future::select(pin!(process), drop_recv).await;
        });

        task_send.send(SmallTask::new("run_follower", {
            let local_node = local_node.clone();
            let tasks = task_send.clone();
            let driver = driver.clone();
            async move {
                run_follower(
                    driver.as_ref(),
                    &local_node,
                    from_leader,
                    pending_connections,
                    &tasks,
                )
                .await
            }
        }));

        Self {
            driver,
            local_node,
            tasks: task_send,
            io_task,
            to_leader,
            leader_resign_send: Mutex::new(None),

            _drop_send: drop_send,
        }
    }
873
874    /// Returns an offer to hand the leadership to another node in the mesh.
875    ///
876    /// The offer should be sent over a channel and passed to
877    /// `accept_leadership` in the receiving node.
878    pub fn offer_leadership(&self) -> LeadershipOffer {
879        let (send, mut recv) = channel();
880        if let Some(leader_send) = self.leader_resign_send.lock().clone() {
881            // Start a task to wait for the offer to be acknowledged, then send
882            // the offer details to the leader thread.
883            let task = SmallTask::new("offer_leadership", async move {
884                if let Ok(r) = recv.recv().await {
885                    leader_send.send(r);
886                }
887            });
888            self.tasks.send(task);
889        }
890        LeadershipOffer { send }
891    }
892
893    /// Accepts a leadership offer, making this node the current leader.
894    pub fn accept_leadership(&self, offer: LeadershipOffer) {
895        let (send, mut recv) = channel();
896        offer.send.send((self.local_node.id(), send));
897
898        let (resign_send, resign_recv) = channel();
899        let resign_send = Arc::new(resign_send);
900        let task = SmallTask::new("accept_and_run_leader", {
901            let local_node = self.local_node.clone();
902            let tasks = self.tasks.clone();
903            let driver = self.driver.clone();
904            async move {
905                if let Ok(followers) = recv.recv().await {
906                    drop(recv);
907                    run_leader(driver.as_ref(), &local_node, resign_recv, followers, &tasks).await
908                }
909            }
910        });
911        self.tasks.send(task);
912        *self.leader_resign_send.lock() = Some(resign_send);
913    }
914
915    /// Invites another process to join the mesh, with `port` bridged with the
916    /// original port.
917    #[instrument(skip_all, fields(local_id = ?self.local_node.id()))]
918    pub async fn invite(&self, port: Port) -> Result<Invitation, InviteError> {
919        let (invitation_send, mut invitation_recv) = channel();
920        self.to_leader
921            .send(LeaderRequest::Invite(port, invitation_send));
922        let invitation = invitation_recv.recv().await.map_err(|_| InviteError)?;
923        tracing::debug!(
924            invite_id = ?invitation.address.local_addr.node,
925            "received invitation",
926        );
927        Ok(invitation)
928    }
929
930    /// Extract a [`UnixMeshInviter`] handle. Cheap — clones an Arc.
931    ///
932    /// The inviter can be used from any task to create mesh invitations
933    /// without holding a reference to this node.
934    pub(crate) fn inviter(&self) -> UnixMeshInviter {
935        UnixMeshInviter {
936            to_leader: (*self.to_leader).clone(),
937        }
938    }
939
940    /// Joins an existing mesh via an invitation, briding `port` with the
941    /// initial port.
942    pub async fn join(
943        driver: impl SpawnDriver,
944        invitation: Invitation,
945        port: Port,
946    ) -> Result<Self, JoinError> {
947        Self::join_generic(Arc::new(driver), invitation, port).await
948    }
949
950    #[instrument(skip_all, fields(local_id = ?invitation.address.local_addr.node, remote_id = ?invitation.address.remote_addr.node))]
951    async fn join_generic(
952        driver: Arc<dyn SpawnDriver>,
953        invitation: Invitation,
954        port: Port,
955    ) -> Result<Self, JoinError> {
956        let (to_leader_send, to_leader_recv) = channel();
957        let (from_leader_send, from_leader_recv) = channel();
958        let this = Self::with_id(
959            driver,
960            invitation.address.local_addr.node,
961            to_leader_send,
962            from_leader_recv,
963        );
964
965        let handle = this
966            .local_node
967            .add_remote(invitation.address.remote_addr.node);
968        let init_recv = OneshotReceiver::<InitialMessage>::from(this.local_node.add_port(
969            invitation.address.local_addr.port,
970            invitation.address.remote_addr,
971        ));
972
973        start_connection(
974            &this.tasks,
975            &this.local_node,
976            invitation.address.remote_addr.node,
977            handle,
978            UnixSocket::new(this.driver.as_ref(), invitation.fd.into()),
979        );
980
981        let init_message = init_recv.await.map_err(JoinError)?;
982        to_leader_recv.bridge(init_message.leader_send);
983        from_leader_send.bridge(init_message.follower_recv);
984        port.bridge(init_message.user_port);
985
986        Ok(this)
987    }
988
    /// Shuts down the node, waiting for any sent messages to be sent to their
    /// destination.
    ///
    /// After this call, any active ports will no longer be able to receive
    /// messages.
    ///
    /// It is essential to call this before exiting a mesh process; until this
    /// returns, data loss could occur for other mesh nodes.
    ///
    /// The steps run in a fixed order: drain ports, release the leader
    /// channels and any leader task, fail remote nodes, then close the task
    /// channel and await the IO task.
    pub async fn shutdown(mut self) {
        // Wait for any proxy ports to disassociate.
        self.local_node.wait_for_ports(false).await;
        // Drop all connections to the leader: this struct's sender first, then
        // the clone held by the connector (presumably released by
        // `drop_connector` -- confirm in `LocalNode`).
        drop(self.to_leader);
        self.local_node.drop_connector();
        // Terminate the leader task. If this node is the leader, dropping the
        // resign sender presumably ends `run_leader`'s resign channel.
        self.leader_resign_send.get_mut().take();
        // Fail all nodes so that the send threads are dropped.
        self.local_node.fail_all_nodes();
        // Signal the IO task to tear down: it stops once every task sender is
        // closed and the running tasks have finished.
        drop(self.tasks);
        // Wait for the IO task.
        self.io_task.await;
    }
1012}
1013
1014/// An error returned by [`UnixNode::join`].
1015#[derive(Debug, Error)]
1016#[error("failed to accept invitation")]
1017pub struct JoinError(#[source] RecvError);
1018
1019/// The connector used when the mesh needs to connect to a previously-recognized
1020/// node. Sends a message to the leader node to get a new socket to communicate
1021/// over.
1022#[derive(Debug)]
1023struct Connector {
1024    local_id: NodeId,
1025    conn_req_send: Arc<mesh_channel::Sender<LeaderRequest>>,
1026    pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>>,
1027}
1028
1029impl Connect for Connector {
1030    fn connect(&self, node_id: NodeId, handle: RemoteNodeHandle) {
1031        tracing::trace!(local_id = ?self.local_id, remote_id = ?node_id, "connecting");
1032        let old_request = self.pending_connections.lock().insert(node_id, handle);
1033        if old_request.is_some() {
1034            panic!("duplicate connection request for {:?}", node_id);
1035        }
1036        self.conn_req_send.send(LeaderRequest::Connect(node_id))
1037    }
1038}
1039
1040/// Creates an AF_UNIX socket pair of the appropriate type.
1041fn new_socket_pair() -> Result<(Socket, Socket), io::Error> {
1042    let ty = if USE_SEQPACKET {
1043        socket2::Type::SEQPACKET
1044    } else {
1045        socket2::Type::STREAM
1046    };
1047    Socket::pair(socket2::Domain::UNIX, ty, None)
1048}
1049
1050/// An AF_UNIX SOCK_SEQPACKET connection.
1051struct UnixSocket {
1052    socket: Mutex<PolledSocket<Socket>>,
1053}
1054
1055impl UnixSocket {
1056    fn new(driver: &dyn SpawnDriver, fd: Socket) -> Self {
1057        let socket = PolledSocket::new(driver, fd).unwrap();
1058        UnixSocket {
1059            socket: Mutex::new(socket),
1060        }
1061    }
1062
1063    async fn send(&self, msg: &[u8], fds: &[OsResource]) -> io::Result<()> {
1064        if USE_SEQPACKET {
1065            self.send_raw(&mut [IoSlice::new(msg)], fds).await?;
1066        } else {
1067            let len = (msg.len() as u32).to_le_bytes();
1068            let mut iov = [IoSlice::new(&len), IoSlice::new(msg)];
1069            self.send_all_raw(&mut iov, fds).await?;
1070        }
1071        Ok(())
1072    }
1073
1074    async fn send_raw(
1075        &self,
1076        iov: &mut [IoSlice<'_>],
1077        fds: &[OsResource],
1078    ) -> Result<usize, io::Error> {
1079        let n = poll_fn(|cx| {
1080            self.socket
1081                .lock()
1082                .poll_io(cx, InterestSlot::Write, PollEvents::OUT, |socket| {
1083                    try_send(socket.get().as_fd(), iov, fds)
1084                })
1085        })
1086        .await?;
1087        Ok(n)
1088    }
1089
1090    async fn send_all_raw(
1091        &self,
1092        mut iov: &mut [IoSlice<'_>],
1093        mut fds: &[OsResource],
1094    ) -> Result<(), io::Error> {
1095        while !iov.is_empty() || !fds.is_empty() {
1096            let n = self.send_raw(iov, fds).await?;
1097            IoSlice::advance_slices(&mut iov, n);
1098            fds = &[];
1099        }
1100        Ok(())
1101    }
1102
1103    async fn recv(&self, buf: &mut [u8], fds: &mut Vec<OsResource>) -> io::Result<usize> {
1104        if USE_SEQPACKET {
1105            self.recv_raw(buf, fds).await
1106        } else {
1107            let mut len = [0; 4];
1108            if !self.recv_all_raw(&mut len, fds).await? {
1109                return Ok(0);
1110            }
1111            let len = u32::from_le_bytes(len) as usize;
1112            let buf = buf
1113                .get_mut(..len)
1114                .ok_or_else(|| io::Error::from_raw_os_error(libc::EMSGSIZE))?;
1115            if !self.recv_all_raw(buf, fds).await? {
1116                return Err(ErrorKind::UnexpectedEof.into());
1117            }
1118            Ok(len)
1119        }
1120    }
1121
1122    async fn recv_all_raw(
1123        &self,
1124        buf: &mut [u8],
1125        fds: &mut Vec<OsResource>,
1126    ) -> Result<bool, io::Error> {
1127        let mut read = 0;
1128        while read < buf.len() {
1129            let n = self.recv_raw(&mut buf[read..], fds).await?;
1130            if n == 0 {
1131                if read != 0 {
1132                    return Err(ErrorKind::UnexpectedEof.into());
1133                } else {
1134                    return Ok(false);
1135                }
1136            }
1137            read += n;
1138        }
1139        Ok(true)
1140    }
1141
1142    async fn recv_raw(
1143        &self,
1144        buf: &mut [u8],
1145        fds: &mut Vec<OsResource>,
1146    ) -> Result<usize, io::Error> {
1147        let n = poll_fn(|cx| {
1148            self.socket
1149                .lock()
1150                .poll_io(cx, InterestSlot::Read, PollEvents::IN, |socket| {
1151                    try_recv(socket.get().as_fd(), buf, fds)
1152                })
1153        })
1154        .await?;
1155        Ok(n)
1156    }
1157
1158    async fn close_write(&self) -> io::Result<()> {
1159        self.socket.lock().get().shutdown(std::net::Shutdown::Write)
1160    }
1161}
1162
#[cfg(test)]
mod tests {
    use crate::unix::UnixNode;
    use mesh_channel::RecvError;
    use mesh_channel::channel;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::timer::PolledTimer;
    use std::time::Duration;
    use test_with_tracing::test;

    #[async_test]
    async fn test_basic(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (send, recv) = channel::<u32>();
        let invitation = leader.invite(recv.into()).await.unwrap();
        let (send2, mut recv2) = channel::<u32>();
        let follower = UnixNode::join(driver, invitation, send2.into())
            .await
            .unwrap();
        send.send(5);
        assert_eq!(recv2.recv().await.unwrap(), 5);
        drop(send);
        drop(recv2);
        follower.shutdown().await;
        leader.shutdown().await;
    }

    #[cfg(target_os = "linux")]
    #[async_test]
    async fn test_huge_message(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (send, recv) = channel::<Vec<u8>>();
        let invitation = leader.invite(recv.into()).await.unwrap();
        let (send2, mut recv2) = channel::<Vec<u8>>();
        let follower = UnixNode::join(driver, invitation, send2.into())
            .await
            .unwrap();

        let v = vec![0xcc; 16 << 20];
        send.send(v.clone());
        let v2 = recv2.recv().await.unwrap();
        assert_eq!(v, v2);
        follower.shutdown().await;
        leader.shutdown().await;
    }

    #[cfg(target_os = "linux")]
    #[async_test]
    async fn test_message_sizes(driver: DefaultDriver) {
        let (p1, p2) = mesh_node::local_node::Port::new_pair();
        let (p3, p4) = mesh_node::local_node::Port::new_pair();
        let node1 = UnixNode::new(driver.clone());
        let invitation = node1.invite(p2).await.unwrap();
        let _node2 = UnixNode::join(driver.clone(), invitation, p3)
            .await
            .unwrap();

        crate::test_common::test_message_sizes(p1, p4, 0..=super::MAX_SMALL_EVENT_SIZE + 0x1000)
            .await;
    }

    #[async_test]
    async fn test_dropped_shutdown(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        {
            let (_send, recv) = channel::<u32>();
            let invitation = leader.invite(recv.into()).await.unwrap();
            let (send2, _recv2) = channel::<u32>();
            let _follower = UnixNode::join(driver, invitation, send2.into())
                .await
                .unwrap();
        }
        leader.shutdown().await;
    }

    #[async_test]
    async fn test_send_shutdown(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (send, mut recv) = channel::<u32>();
        let invitation = leader.invite(send.into()).await.unwrap();
        let (send2, recv2) = channel::<u32>();
        let follower = UnixNode::join(driver, invitation, recv2.into())
            .await
            .unwrap();
        send2.send(5);
        drop(send2);
        follower.shutdown().await;
        assert_eq!(recv.recv().await.unwrap(), 5);
    }

    #[async_test]
    async fn test_failed_invitation(driver: DefaultDriver) {
        let leader = UnixNode::new(driver);
        let (send, mut recv) = channel::<()>();
        leader.invite(send.into()).await.unwrap();
        assert!(matches!(
            recv.recv().await.unwrap_err(),
            RecvError::Error(_)
        ));
        drop(recv);
        leader.shutdown().await;
    }

    #[async_test]
    async fn test_three(driver: DefaultDriver) {
        let (p1, p2) = channel::<u32>();
        let (p3, mut p4) = channel::<u32>();
        let (p5, p6) = channel::<u32>();
        let (p7, p8) = channel::<u32>();

        let node1 = UnixNode::new(driver.clone());

        let invitation = node1.invite(p2.into()).await.unwrap();
        let node2 = UnixNode::join(driver.clone(), invitation, p3.into())
            .await
            .unwrap();

        let invitation = node1.invite(p5.into()).await.unwrap();
        let node3 = UnixNode::join(driver, invitation, p8.into()).await.unwrap();

        p1.bridge(p6);

        p7.send(5);

        assert_eq!(p4.recv().await.unwrap(), 5);
        drop(p4);
        drop(p7);
        futures::join!(node2.shutdown(), node3.shutdown());
        node1.shutdown().await;
    }

    #[async_test]
    async fn test_handoff_leader(driver: DefaultDriver) {
        // Use an async timer rather than `std::thread::sleep`: a blocking
        // sleep would stall the executor, so no mesh traffic could make
        // progress during the wait.
        let mut timer = PolledTimer::new(&driver);

        let (p1, p2) = channel::<u32>();
        let (p3, p4) = channel::<u32>();
        let (p5, p6) = channel::<u32>();
        let (p7, p8) = channel::<u32>();
        let (p9, p10) = channel();
        let (p11, mut p12) = channel();

        let node1 = UnixNode::new(driver.clone());

        let invitation = node1.invite(p2.into()).await.unwrap();
        let node2 = UnixNode::join(driver.clone(), invitation, p3.into())
            .await
            .unwrap();

        let invitation = node1.invite(p5.into()).await.unwrap();
        let node3 = UnixNode::join(driver.clone(), invitation, p8.into())
            .await
            .unwrap();

        let invitation = node1.invite(p10.into()).await.unwrap();
        let node4 = UnixNode::join(driver, invitation, p11.into())
            .await
            .unwrap();

        p9.send(node1.offer_leadership());
        node4.accept_leadership(p12.recv().await.unwrap());
        drop(p9);
        drop(p12);
        p1.bridge(p6);

        // Give the leadership handoff and bridging time to propagate.
        timer.sleep(Duration::from_millis(200)).await;

        node1.shutdown().await;
        drop(p4);
        drop(p7);
        node2.shutdown().await;
        node3.shutdown().await;

        timer.sleep(Duration::from_millis(200)).await;

        node4.shutdown().await;
    }
}