1#![cfg(unix)]
15
16#[cfg(target_os = "linux")]
17mod memfd;
18
19use crate::common::InvitationAddress;
20use crate::protocol;
21use crate::unix_common::try_recv;
22use crate::unix_common::try_send;
23use futures::FutureExt;
24use futures::StreamExt;
25use futures::channel::mpsc;
26use futures::future;
27use futures::future::BoxFuture;
28use io::ErrorKind;
29use mesh_channel::OneshotReceiver;
30use mesh_channel::OneshotSender;
31use mesh_channel::RecvError;
32use mesh_channel::channel;
33use mesh_channel::oneshot;
34use mesh_node::common::Address;
35use mesh_node::common::NodeId;
36use mesh_node::common::PortId;
37use mesh_node::local_node::Connect;
38use mesh_node::local_node::LocalNode;
39use mesh_node::local_node::OutgoingEvent;
40use mesh_node::local_node::Port;
41use mesh_node::local_node::RemoteNodeHandle;
42use mesh_node::local_node::SendEvent;
43use mesh_node::resource::OsResource;
44use mesh_node::resource::Resource;
45use mesh_protobuf::Protobuf;
46use pal_async::driver::SpawnDriver;
47use pal_async::interest::InterestSlot;
48use pal_async::interest::PollEvents;
49use pal_async::socket::PolledSocket;
50use pal_async::task::Spawn;
51use pal_async::task::Task;
52use parking_lot::Mutex;
53use socket2::Socket;
54use std::collections::HashMap;
55use std::collections::VecDeque;
56use std::fmt::Debug;
57use std::future::Future;
58use std::future::poll_fn;
59use std::io;
60use std::io::IoSlice;
61use std::os::unix::prelude::*;
62use std::pin::pin;
63use std::sync::Arc;
64use thiserror::Error;
65use tracing::instrument;
66use unicycle::FuturesUnordered;
67use zerocopy::FromBytes;
68use zerocopy::FromZeros;
69use zerocopy::IntoBytes;
70
71const USE_SEQPACKET: bool = cfg!(target_os = "linux");
78
79const MAX_PACKET_SIZE: usize = if cfg!(target_os = "linux") {
85 0x4000
86} else {
87 0x40000
88};
89
90const MAX_SMALL_EVENT_SIZE: usize = MAX_PACKET_SIZE - size_of::<protocol::PacketHeader>();
91
92pub struct UnixNode {
102 driver: Arc<dyn SpawnDriver>,
103 local_node: Arc<LocalNode>,
104 to_leader: Arc<mesh_channel::Sender<LeaderRequest>>,
105 tasks: Arc<mesh_channel::Sender<SmallTask>>,
106 io_task: Task<()>,
107 leader_resign_send:
109 Mutex<Option<Arc<mesh_channel::Sender<(NodeId, mesh_channel::Sender<Followers>)>>>>,
110
111 _drop_send: OneshotSender<()>,
113}
114
115#[derive(Debug, Error)]
117#[error("mesh node shut down before invitation could be created")]
118pub struct InviteError;
119
120#[derive(Clone)]
129pub(crate) struct UnixMeshInviter {
130 to_leader: mesh_channel::Sender<LeaderRequest>,
131}
132
133impl UnixMeshInviter {
134 pub(crate) async fn invite(&self, port: Port) -> Result<Invitation, InviteError> {
139 let (invitation_send, mut invitation_recv) = channel();
140 self.to_leader
141 .send(LeaderRequest::Invite(port, invitation_send));
142 invitation_recv.recv().await.map_err(|_| InviteError)
143 }
144}
145
146#[derive(Debug, Protobuf)]
147#[mesh(resource = "Resource")]
148enum LeaderRequest {
149 Connect(NodeId),
150 Invite(Port, mesh_channel::Sender<Invitation>),
151}
152
153#[derive(Debug, Protobuf)]
154#[mesh(resource = "Resource")]
155enum FollowerRequest {
156 Connect(
157 NodeId,
158 #[mesh(
159 encoding = "mesh_protobuf::encoding::OptionField<mesh_protobuf::encoding::ResourceField<OwnedFd>>"
160 )]
161 Option<Socket>,
162 ),
163}
164
165#[derive(Debug, Protobuf)]
166#[mesh(resource = "Resource")]
167pub struct Followers {
168 list: Vec<(
169 NodeId,
170 mesh_channel::Receiver<LeaderRequest>,
171 mesh_channel::Sender<FollowerRequest>,
172 )>,
173}
174
175#[derive(Debug, Protobuf)]
176#[mesh(resource = "Resource")]
177struct InitialMessage {
178 leader_send: mesh_channel::Sender<LeaderRequest>,
179 follower_recv: mesh_channel::Receiver<FollowerRequest>,
180 user_port: Port,
181}
182
183#[instrument(skip_all, fields(local_id = ?local_node.id()))]
186async fn run_follower(
187 driver: &dyn SpawnDriver,
188 local_node: &Arc<LocalNode>,
189 mut recv: mesh_channel::Receiver<FollowerRequest>,
190 pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>>,
191 tasks: &mesh_channel::Sender<SmallTask>,
192) {
193 while let Ok(req) = recv.recv().await {
194 match req {
195 FollowerRequest::Connect(target_id, fd) => {
196 tracing::debug!(?target_id, "got connection request from leader");
197 let handle = pending_connections.lock().remove(&target_id);
198 let handle = handle.unwrap_or_else(|| local_node.get_remote_handle(target_id));
199
200 if let Some(fd) = fd {
201 start_connection(
202 tasks,
203 local_node,
204 target_id,
205 handle,
206 UnixSocket::new(driver, fd),
207 );
208 } else {
209 tracing::warn!(?target_id, "leader provided failed connection");
210 }
211 }
212 }
213 }
214}
215
216#[instrument(skip_all, fields(local_id = ?local_node.id()))]
220async fn run_leader(
221 driver: &dyn SpawnDriver,
222 local_node: &Arc<LocalNode>,
223 mut resign_recv: mesh_channel::Receiver<(NodeId, mesh_channel::Sender<Followers>)>,
224 followers: Followers,
225 tasks: &mesh_channel::Sender<SmallTask>,
226) {
227 let mut senders = HashMap::new();
228 let mut receivers = Vec::new();
229 for (remote_id, recv, send) in followers.list {
230 receivers.push((remote_id, recv));
231 senders.insert(remote_id, send);
232 }
233
234 let new_leader_info = loop {
235 if receivers.is_empty() {
236 return;
237 }
238 let recvs = receivers
239 .iter_mut()
240 .map(|(_, recv)| poll_fn(|cx| recv.poll_recv(cx)));
241 let (req, index, _) = futures::select! { r = resign_recv.next() => break r,
243 r = future::select_all(recvs).fuse() => r,
244 };
245 let remote_id = receivers[index].0;
246 match req {
247 Ok(req) => match req {
248 LeaderRequest::Connect(target_id) => {
249 tracing::debug!(?target_id, ?remote_id, "connection request");
250 let remote = senders
251 .get(&remote_id)
252 .expect("sender must exist to receive from it");
253 let mut fd = None;
254 if let Some(target) = senders.get(&target_id) {
255 match new_socket_pair() {
256 Ok((left, right)) => {
257 tracing::trace!(?target, "send to");
258 target.send(FollowerRequest::Connect(remote_id, Some(left)));
259 fd = Some(right);
260 }
261 Err(err) => {
262 tracing::warn!(
263 ?target_id,
264 ?remote_id,
265 error = &err as &dyn std::error::Error,
266 "failed to create socket pair for connection request"
267 );
268 }
269 }
270 } else {
271 tracing::warn!(?target_id, ?remote_id, "could not find target for remote");
272 }
273 remote.send(FollowerRequest::Connect(target_id, fd));
274 }
275 LeaderRequest::Invite(port, send) => {
276 tracing::debug!(?remote_id, "invitation request");
277 match new_socket_pair() {
278 Ok((left, right)) => {
279 let (leader_send, leader_recv) = channel();
280 let (follower_send, follower_recv) = channel();
281 let remote_addr = Address {
282 node: NodeId::new(),
283 port: PortId::new(),
284 };
285 let local_port_id = PortId::new();
286 let handle = local_node.add_remote(remote_addr.node);
287 start_connection(
288 tasks,
289 local_node,
290 remote_addr.node,
291 handle,
292 UnixSocket::new(driver, left),
293 );
294 let init_send = OneshotSender::<InitialMessage>::from(
295 local_node.add_port(local_port_id, remote_addr),
296 );
297 init_send.send(InitialMessage {
298 leader_send,
299 follower_recv,
300 user_port: port,
301 });
302 let invitation = Invitation {
303 address: InvitationAddress {
304 local_addr: remote_addr,
305 remote_addr: Address {
306 node: local_node.id(),
307 port: local_port_id,
308 },
309 },
310 fd: right.into(),
311 };
312 tracing::debug!(
313 invite_id = ?invitation.address.local_addr.node,
314 ?remote_id,
315 "inviting",
316 );
317 send.send(invitation);
318 senders.insert(remote_addr.node, follower_send);
319 receivers.push((remote_addr.node, leader_recv));
320 }
321 Err(err) => {
322 tracing::error!(
323 error = &err as &dyn std::error::Error,
324 "failed to create socket pair",
325 );
326 }
327 }
328 }
329 },
330 Err(err) => {
331 if let RecvError::Error(err) = err {
332 tracing::debug!(
333 ?remote_id,
334 error = &err as &dyn std::error::Error,
335 "leader connection to remote failed"
336 );
337 }
338 senders.remove(&remote_id);
339 receivers.swap_remove(index);
340 }
341 }
342 };
343
344 if let Some((new_leader_id, new_leader_followers_sink)) = new_leader_info {
345 if let Some(new_leader_send) = senders.get(&new_leader_id) {
346 tracing::debug!(?new_leader_id, "resigning leadership");
347 for (remote_id, send) in senders.iter() {
350 if new_leader_id != *remote_id {
351 match new_socket_pair() {
352 Ok((left, right)) => {
353 send.send(FollowerRequest::Connect(new_leader_id, Some(left)));
354 new_leader_send.send(FollowerRequest::Connect(*remote_id, Some(right)));
355 }
356 Err(err) => {
357 tracing::error!(
358 ?new_leader_id,
359 error = &err as &dyn std::error::Error,
360 "failed to connect node to new leader, mesh is leaderless",
361 );
362 return;
363 }
364 }
365 }
366 }
367
368 let mut followers = Vec::new();
370 for (remote_id, recv) in receivers {
371 let send = senders
372 .remove(&remote_id)
373 .expect("should be in sync with receivers");
374 followers.push((remote_id, recv, send));
375 }
376 new_leader_followers_sink.send(Followers { list: followers });
377 } else {
378 tracing::error!(?new_leader_id, "new leader is unknown, mesh is leaderless");
379 }
380 }
381}
382
383struct SmallTask {
386 name: &'static str,
387 future: BoxFuture<'static, ()>,
388}
389
390impl Debug for SmallTask {
391 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 f.pad("SmallTask")
393 }
394}
395
396impl SmallTask {
397 fn new(name: &'static str, f: impl 'static + Send + Future<Output = ()>) -> Self {
398 Self {
399 name,
400 future: Box::pin(f),
401 }
402 }
403}
404
405#[derive(Debug, Protobuf)]
409#[mesh(resource = "Resource")]
410pub struct Invitation {
411 pub address: InvitationAddress,
413 pub fd: OwnedFd,
415}
416
417#[derive(Debug)]
418enum SenderCommand {
419 Send {
420 packet: Vec<u8>,
421 fds: Vec<OsResource>,
422 },
423 ReleaseFds {
424 count: usize,
425 },
426}
427
428#[derive(Clone)]
429struct PacketSender {
430 send: mpsc::UnboundedSender<SenderCommand>,
431 socket: Arc<UnixSocket>,
432}
433
434impl SendEvent for PacketSender {
435 fn event(&self, event: OutgoingEvent<'_>) {
436 let (packet, fds) = match serialize_event(event) {
437 Ok(r) => r,
438 Err(err) => {
439 tracing::error!(
441 error = &err as &dyn std::error::Error,
442 "failed to serialize event"
443 );
444 return;
445 }
446 };
447
448 if !USE_SEQPACKET
460 || try_send(
461 self.socket.socket.lock().get().as_fd(),
462 &[IoSlice::new(&packet)],
463 &fds,
464 )
465 .is_err()
466 {
467 let _ = self
468 .send
469 .unbounded_send(SenderCommand::Send { packet, fds });
470 }
471 }
472}
473
474fn serialize_event(event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
475 if event.len() > MAX_SMALL_EVENT_SIZE {
477 return serialize_large_event(event);
478 }
479
480 let cap = size_of::<protocol::PacketHeader>() + event.len();
482 let mut packet = Vec::with_capacity(cap);
483 packet.extend_from_slice(
484 protocol::PacketHeader {
485 packet_type: protocol::PacketType::EVENT,
486 reserved: [0; 7],
487 }
488 .as_bytes(),
489 );
490 let mut fds = Vec::new();
491 event.write_to(&mut packet, &mut fds);
492 assert_eq!(packet.len(), cap);
493 Ok((packet, fds))
494}
495
496#[cfg(target_os = "linux")]
497fn serialize_large_event(event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
498 let packet = protocol::PacketHeader {
499 packet_type: protocol::PacketType::LARGE_EVENT,
500 ..FromZeros::new_zeroed()
501 }
502 .as_bytes()
503 .to_vec();
504
505 let mut fds = Vec::new();
506
507 let mut memfd = memfd::MemfdBuilder::new(event.len())?;
508 event.write_to(&mut io::Cursor::new(&mut *memfd), &mut fds);
509 fds.insert(0, OsResource::Fd(memfd.seal()?.into()));
510
511 Ok((packet, fds))
512}
513
/// Large events are not supported on this OS, so this always fails with
/// `ErrorKind::Unsupported`.
#[cfg(not(target_os = "linux"))]
fn serialize_large_event(_event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
    let unsupported = io::Error::new(ErrorKind::Unsupported, "event too large for this OS");
    Err(unsupported)
}
521
522impl Drop for PacketSender {
523 fn drop(&mut self) {
524 self.send.close_channel();
527 }
528}
529
530fn start_connection(
532 tasks: &mesh_channel::Sender<SmallTask>,
533 local_node: &Arc<LocalNode>,
534 remote_id: NodeId,
535 handle: RemoteNodeHandle,
536 socket: UnixSocket,
537) {
538 #[expect(clippy::disallowed_methods)] let (send, recv) = mpsc::unbounded();
540 let socket = Arc::new(socket);
541 let sender = PacketSender {
542 send: send.clone(),
543 socket: socket.clone(),
544 };
545 if handle.connect(sender) {
546 let task = SmallTask::new("run_connection", {
547 let local_node = local_node.clone();
548 run_connection(local_node, remote_id, send, recv, socket, handle)
549 });
550 tasks.send(task);
551 tracing::debug!(?remote_id, "connected");
552 } else {
553 tracing::debug!(?remote_id, "duplicate connection");
557 }
558}
559
560#[instrument(skip_all, fields(local_id = ?local_node.id(), remote_id = ?remote_id))]
562async fn run_connection(
563 local_node: Arc<LocalNode>,
564 remote_id: NodeId,
565 send_send: mpsc::UnboundedSender<SenderCommand>,
566 send_recv: mpsc::UnboundedReceiver<SenderCommand>,
567 socket: Arc<UnixSocket>,
568 handle: RemoteNodeHandle,
569) {
570 let mut retained_fds = VecDeque::new();
571 let mut recv = pin!(
572 async {
573 let r = run_receive(&local_node, &remote_id, &socket, &send_send).await;
574 match &r {
575 Ok(_) => {
576 tracing::debug!("incoming socket disconnected");
577 }
578 Err(err) => {
579 tracing::error!(error = err as &dyn std::error::Error, "error receiving");
580 }
581 }
582 r
583 }
584 .fuse()
585 );
586 let mut send = pin!(
587 async {
588 match run_send(send_recv, &socket, &mut retained_fds).await {
589 Ok(_) => {
590 tracing::debug!("sending is done");
591 }
592 Err(err) => {
593 tracing::error!(error = &err as &dyn std::error::Error, "failed send");
594 }
595 }
596 }
597 .fuse()
598 );
599 let r = futures::select! { r = recv => {
601 tracing::trace!("read complete, shutting down writes");
603 let _ = socket.close_write().await;
604 r
605 }
606 _ = send => {
607 match socket.close_write().await {
608 Ok(()) => {
609 tracing::trace!("shutdown writes, waiting for reads");
610 recv.await
611 }
612 Err(err) => {
613 tracing::error!(
614 error = &err as &dyn std::error::Error,
615 "failed to shutdown writes, aborting connection",
616 );
617 Err(ReceiveError::Io(err))
618 }
619 }
620 }
621 };
622 tracing::trace!("connection done");
623 match r {
624 Ok(()) => handle.disconnect(),
625 Err(err) => handle.fail(err),
626 }
627}
628
629#[derive(Debug, Error)]
630enum ReceiveError {
631 #[error("i/o error")]
632 Io(#[from] io::Error),
633 #[error("missing packet header")]
634 NoHeader,
635 #[error("release fds packet too small")]
636 BadReleaseFds,
637 #[error("unknown packet type {0:?}")]
638 UnknownPacketType(protocol::PacketType),
639 #[cfg(target_os = "linux")]
640 #[error("memfd file descriptor not sent for large event")]
641 MissingMemfd,
642 #[cfg(target_os = "linux")]
643 #[error("failed to map memfd")]
644 Memfd(#[source] io::Error),
645}
646
647async fn run_receive(
649 local_node: &LocalNode,
650 remote_id: &NodeId,
651 socket: &UnixSocket,
652 send: &mpsc::UnboundedSender<SenderCommand>,
653) -> Result<(), ReceiveError> {
654 let mut buf = vec![0; MAX_PACKET_SIZE];
655 let mut fds = Vec::new();
656 loop {
657 let len = socket.recv(&mut buf, &mut fds).await?;
658 if len == 0 {
659 break;
660 }
661 if cfg!(target_os = "macos") && !fds.is_empty() {
662 let _ = send.unbounded_send(SenderCommand::Send {
664 packet: protocol::ReleaseFds {
665 header: protocol::PacketHeader {
666 packet_type: protocol::PacketType::RELEASE_FDS,
667 ..FromZeros::new_zeroed()
668 },
669 count: fds.len() as u64,
670 }
671 .as_bytes()
672 .to_vec(),
673 fds: Vec::new(),
674 });
675 }
676
677 let buf = &buf[..len];
678 let header = protocol::PacketHeader::read_from_prefix(buf)
679 .map_err(|_| ReceiveError::NoHeader)?
680 .0; match header.packet_type {
682 protocol::PacketType::EVENT => {
683 local_node.event(remote_id, &buf[size_of_val(&header)..], &mut fds);
684 fds.clear();
685 }
686 protocol::PacketType::RELEASE_FDS => {
687 let release_fds = protocol::ReleaseFds::read_from_prefix(buf)
688 .map_err(|_| ReceiveError::BadReleaseFds)?
689 .0; let _ = send.unbounded_send(SenderCommand::ReleaseFds {
691 count: release_fds.count as usize,
692 });
693 }
694 #[cfg(target_os = "linux")]
695 protocol::PacketType::LARGE_EVENT => {
696 if fds.is_empty() {
697 return Err(ReceiveError::MissingMemfd);
698 }
699 let OsResource::Fd(fd) = fds.remove(0);
700 let memfd = memfd::SealedMemfd::new(fd.into()).map_err(ReceiveError::Memfd)?;
701 local_node.event(remote_id, &memfd, &mut fds);
702 fds.clear();
703 }
704 ty => {
705 return Err(ReceiveError::UnknownPacketType(ty));
706 }
707 }
708 }
709 Ok(())
710}
711
712#[derive(Debug, Error)]
713enum ProtocolError {
714 #[error("request to release too many fds")]
715 ReleasingTooManyFds,
716}
717
718async fn run_send(
720 mut recv: mpsc::UnboundedReceiver<SenderCommand>,
721 socket: &UnixSocket,
722 retained_fds: &mut VecDeque<OsResource>,
723) -> io::Result<()> {
724 while let Some(command) = recv.next().await {
725 match command {
726 SenderCommand::Send { packet, fds } => {
727 match socket.send(&packet, &fds).await {
728 Ok(_) => (),
729 Err(err) => {
730 tracing::error!(
731 fd_count = fds.len(),
732 packet_len = packet.len(),
733 "failed to send packet"
734 );
735 return Err(err);
736 }
737 }
738 if cfg!(target_os = "macos") {
739 if !fds.is_empty() {
745 retained_fds.extend(fds);
746 }
747 }
748 }
749 SenderCommand::ReleaseFds { count } => {
750 if retained_fds.len() < count {
751 return Err(io::Error::other(ProtocolError::ReleasingTooManyFds));
752 }
753 retained_fds.drain(..count);
754 }
755 }
756 }
757 Ok(())
758}
759
760#[derive(Debug, Protobuf)]
763#[mesh(resource = "Resource")]
764pub struct LeadershipOffer {
765 send: mesh_channel::Sender<(NodeId, mesh_channel::Sender<Followers>)>,
766}
767
768impl UnixNode {
769 pub fn new(driver: impl SpawnDriver) -> Self {
771 let (to_leader_send, to_leader_recv) = channel();
772 let (from_leader_send, from_leader_recv) = channel();
773 let this = Self::with_id(
774 Arc::new(driver),
775 NodeId::new(),
776 to_leader_send,
777 from_leader_recv,
778 );
779
780 let (resign_send, resign_recv) = channel();
782 let resign_send = Arc::new(resign_send);
783 let followers = Followers {
784 list: vec![(this.local_node.id(), to_leader_recv, from_leader_send)],
785 };
786 let task = SmallTask::new("run_leader", {
787 let local_node = this.local_node.clone();
788 let tasks = this.tasks.clone();
789 let driver = this.driver.clone();
790 async move { run_leader(driver.as_ref(), &local_node, resign_recv, followers, &tasks).await }
791 });
792 this.tasks.send(task);
793 *this.leader_resign_send.lock() = Some(resign_send);
794
795 this
796 }
797
798 pub fn id(&self) -> NodeId {
800 self.local_node.id()
801 }
802
803 fn with_id(
806 driver: Arc<dyn SpawnDriver>,
807 id: NodeId,
808 to_leader: mesh_channel::Sender<LeaderRequest>,
809 from_leader: mesh_channel::Receiver<FollowerRequest>,
810 ) -> Self {
811 let to_leader = Arc::new(to_leader);
812 let pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>> = Default::default();
813 let local_node = Arc::new(LocalNode::with_id(
814 id,
815 Box::new(Connector {
816 local_id: id,
817 conn_req_send: to_leader.clone(),
818 pending_connections: pending_connections.clone(),
819 }),
820 ));
821 let (task_send, mut task_recv) = channel::<SmallTask>();
822 let task_send = Arc::new(task_send);
823 let (drop_send, drop_recv) = oneshot();
824
825 let io_task = driver.spawn("unix-mesh-io", async move {
827 let process = async {
828 let mut futs = FuturesUnordered::new();
829 loop {
830 futures::select! { _ = futs.next() => {},
832 task = task_recv.select_next_some() => {
833 futs.push(async move {
834 tracing::trace!(?id, name = task.name, "task start");
835 task.future.await;
836 tracing::trace!(?id, name = task.name, "task end");
837 });
838 },
839 complete => break,
840 };
841 }
842 };
843 future::select(pin!(process), drop_recv).await;
844 });
845
846 task_send.send(SmallTask::new("run_follower", {
847 let local_node = local_node.clone();
848 let tasks = task_send.clone();
849 let driver = driver.clone();
850 async move {
851 run_follower(
852 driver.as_ref(),
853 &local_node,
854 from_leader,
855 pending_connections,
856 &tasks,
857 )
858 .await
859 }
860 }));
861
862 Self {
863 driver,
864 local_node,
865 tasks: task_send,
866 io_task,
867 to_leader,
868 leader_resign_send: Mutex::new(None),
869
870 _drop_send: drop_send,
871 }
872 }
873
874 pub fn offer_leadership(&self) -> LeadershipOffer {
879 let (send, mut recv) = channel();
880 if let Some(leader_send) = self.leader_resign_send.lock().clone() {
881 let task = SmallTask::new("offer_leadership", async move {
884 if let Ok(r) = recv.recv().await {
885 leader_send.send(r);
886 }
887 });
888 self.tasks.send(task);
889 }
890 LeadershipOffer { send }
891 }
892
893 pub fn accept_leadership(&self, offer: LeadershipOffer) {
895 let (send, mut recv) = channel();
896 offer.send.send((self.local_node.id(), send));
897
898 let (resign_send, resign_recv) = channel();
899 let resign_send = Arc::new(resign_send);
900 let task = SmallTask::new("accept_and_run_leader", {
901 let local_node = self.local_node.clone();
902 let tasks = self.tasks.clone();
903 let driver = self.driver.clone();
904 async move {
905 if let Ok(followers) = recv.recv().await {
906 drop(recv);
907 run_leader(driver.as_ref(), &local_node, resign_recv, followers, &tasks).await
908 }
909 }
910 });
911 self.tasks.send(task);
912 *self.leader_resign_send.lock() = Some(resign_send);
913 }
914
915 #[instrument(skip_all, fields(local_id = ?self.local_node.id()))]
918 pub async fn invite(&self, port: Port) -> Result<Invitation, InviteError> {
919 let (invitation_send, mut invitation_recv) = channel();
920 self.to_leader
921 .send(LeaderRequest::Invite(port, invitation_send));
922 let invitation = invitation_recv.recv().await.map_err(|_| InviteError)?;
923 tracing::debug!(
924 invite_id = ?invitation.address.local_addr.node,
925 "received invitation",
926 );
927 Ok(invitation)
928 }
929
930 pub(crate) fn inviter(&self) -> UnixMeshInviter {
935 UnixMeshInviter {
936 to_leader: (*self.to_leader).clone(),
937 }
938 }
939
940 pub async fn join(
943 driver: impl SpawnDriver,
944 invitation: Invitation,
945 port: Port,
946 ) -> Result<Self, JoinError> {
947 Self::join_generic(Arc::new(driver), invitation, port).await
948 }
949
950 #[instrument(skip_all, fields(local_id = ?invitation.address.local_addr.node, remote_id = ?invitation.address.remote_addr.node))]
951 async fn join_generic(
952 driver: Arc<dyn SpawnDriver>,
953 invitation: Invitation,
954 port: Port,
955 ) -> Result<Self, JoinError> {
956 let (to_leader_send, to_leader_recv) = channel();
957 let (from_leader_send, from_leader_recv) = channel();
958 let this = Self::with_id(
959 driver,
960 invitation.address.local_addr.node,
961 to_leader_send,
962 from_leader_recv,
963 );
964
965 let handle = this
966 .local_node
967 .add_remote(invitation.address.remote_addr.node);
968 let init_recv = OneshotReceiver::<InitialMessage>::from(this.local_node.add_port(
969 invitation.address.local_addr.port,
970 invitation.address.remote_addr,
971 ));
972
973 start_connection(
974 &this.tasks,
975 &this.local_node,
976 invitation.address.remote_addr.node,
977 handle,
978 UnixSocket::new(this.driver.as_ref(), invitation.fd.into()),
979 );
980
981 let init_message = init_recv.await.map_err(JoinError)?;
982 to_leader_recv.bridge(init_message.leader_send);
983 from_leader_send.bridge(init_message.follower_recv);
984 port.bridge(init_message.user_port);
985
986 Ok(this)
987 }
988
989 pub async fn shutdown(mut self) {
998 self.local_node.wait_for_ports(false).await;
1000 drop(self.to_leader);
1002 self.local_node.drop_connector();
1003 self.leader_resign_send.get_mut().take();
1005 self.local_node.fail_all_nodes();
1007 drop(self.tasks);
1009 self.io_task.await;
1011 }
1012}
1013
1014#[derive(Debug, Error)]
1016#[error("failed to accept invitation")]
1017pub struct JoinError(#[source] RecvError);
1018
1019#[derive(Debug)]
1023struct Connector {
1024 local_id: NodeId,
1025 conn_req_send: Arc<mesh_channel::Sender<LeaderRequest>>,
1026 pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>>,
1027}
1028
1029impl Connect for Connector {
1030 fn connect(&self, node_id: NodeId, handle: RemoteNodeHandle) {
1031 tracing::trace!(local_id = ?self.local_id, remote_id = ?node_id, "connecting");
1032 let old_request = self.pending_connections.lock().insert(node_id, handle);
1033 if old_request.is_some() {
1034 panic!("duplicate connection request for {:?}", node_id);
1035 }
1036 self.conn_req_send.send(LeaderRequest::Connect(node_id))
1037 }
1038}
1039
1040fn new_socket_pair() -> Result<(Socket, Socket), io::Error> {
1042 let ty = if USE_SEQPACKET {
1043 socket2::Type::SEQPACKET
1044 } else {
1045 socket2::Type::STREAM
1046 };
1047 Socket::pair(socket2::Domain::UNIX, ty, None)
1048}
1049
1050struct UnixSocket {
1052 socket: Mutex<PolledSocket<Socket>>,
1053}
1054
1055impl UnixSocket {
1056 fn new(driver: &dyn SpawnDriver, fd: Socket) -> Self {
1057 let socket = PolledSocket::new(driver, fd).unwrap();
1058 UnixSocket {
1059 socket: Mutex::new(socket),
1060 }
1061 }
1062
1063 async fn send(&self, msg: &[u8], fds: &[OsResource]) -> io::Result<()> {
1064 if USE_SEQPACKET {
1065 self.send_raw(&mut [IoSlice::new(msg)], fds).await?;
1066 } else {
1067 let len = (msg.len() as u32).to_le_bytes();
1068 let mut iov = [IoSlice::new(&len), IoSlice::new(msg)];
1069 self.send_all_raw(&mut iov, fds).await?;
1070 }
1071 Ok(())
1072 }
1073
1074 async fn send_raw(
1075 &self,
1076 iov: &mut [IoSlice<'_>],
1077 fds: &[OsResource],
1078 ) -> Result<usize, io::Error> {
1079 let n = poll_fn(|cx| {
1080 self.socket
1081 .lock()
1082 .poll_io(cx, InterestSlot::Write, PollEvents::OUT, |socket| {
1083 try_send(socket.get().as_fd(), iov, fds)
1084 })
1085 })
1086 .await?;
1087 Ok(n)
1088 }
1089
1090 async fn send_all_raw(
1091 &self,
1092 mut iov: &mut [IoSlice<'_>],
1093 mut fds: &[OsResource],
1094 ) -> Result<(), io::Error> {
1095 while !iov.is_empty() || !fds.is_empty() {
1096 let n = self.send_raw(iov, fds).await?;
1097 IoSlice::advance_slices(&mut iov, n);
1098 fds = &[];
1099 }
1100 Ok(())
1101 }
1102
1103 async fn recv(&self, buf: &mut [u8], fds: &mut Vec<OsResource>) -> io::Result<usize> {
1104 if USE_SEQPACKET {
1105 self.recv_raw(buf, fds).await
1106 } else {
1107 let mut len = [0; 4];
1108 if !self.recv_all_raw(&mut len, fds).await? {
1109 return Ok(0);
1110 }
1111 let len = u32::from_le_bytes(len) as usize;
1112 let buf = buf
1113 .get_mut(..len)
1114 .ok_or_else(|| io::Error::from_raw_os_error(libc::EMSGSIZE))?;
1115 if !self.recv_all_raw(buf, fds).await? {
1116 return Err(ErrorKind::UnexpectedEof.into());
1117 }
1118 Ok(len)
1119 }
1120 }
1121
1122 async fn recv_all_raw(
1123 &self,
1124 buf: &mut [u8],
1125 fds: &mut Vec<OsResource>,
1126 ) -> Result<bool, io::Error> {
1127 let mut read = 0;
1128 while read < buf.len() {
1129 let n = self.recv_raw(&mut buf[read..], fds).await?;
1130 if n == 0 {
1131 if read != 0 {
1132 return Err(ErrorKind::UnexpectedEof.into());
1133 } else {
1134 return Ok(false);
1135 }
1136 }
1137 read += n;
1138 }
1139 Ok(true)
1140 }
1141
1142 async fn recv_raw(
1143 &self,
1144 buf: &mut [u8],
1145 fds: &mut Vec<OsResource>,
1146 ) -> Result<usize, io::Error> {
1147 let n = poll_fn(|cx| {
1148 self.socket
1149 .lock()
1150 .poll_io(cx, InterestSlot::Read, PollEvents::IN, |socket| {
1151 try_recv(socket.get().as_fd(), buf, fds)
1152 })
1153 })
1154 .await?;
1155 Ok(n)
1156 }
1157
1158 async fn close_write(&self) -> io::Result<()> {
1159 self.socket.lock().get().shutdown(std::net::Shutdown::Write)
1160 }
1161}
1162
#[cfg(test)]
mod tests {
    use crate::unix::UnixNode;
    use mesh_channel::RecvError;
    use mesh_channel::channel;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use test_with_tracing::test;

    /// A value sent on the leader's side arrives on the follower's side.
    #[async_test]
    async fn test_basic(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (leader_tx, invited_port) = channel::<u32>();
        let invitation = leader.invite(invited_port.into()).await.unwrap();
        let (joined_port, mut follower_rx) = channel::<u32>();
        let follower = UnixNode::join(driver, invitation, joined_port.into())
            .await
            .unwrap();

        leader_tx.send(5);
        assert_eq!(follower_rx.recv().await.unwrap(), 5);

        drop(leader_tx);
        drop(follower_rx);
        follower.shutdown().await;
        leader.shutdown().await;
    }

    /// A 16 MiB message, far above the inline packet size, round-trips.
    #[cfg(target_os = "linux")]
    #[async_test]
    async fn test_huge_message(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (leader_tx, invited_port) = channel::<Vec<u8>>();
        let invitation = leader.invite(invited_port.into()).await.unwrap();
        let (joined_port, mut follower_rx) = channel::<Vec<u8>>();
        let follower = UnixNode::join(driver, invitation, joined_port.into())
            .await
            .unwrap();

        let payload = vec![0xcc; 16 << 20];
        leader_tx.send(payload.clone());
        let received = follower_rx.recv().await.unwrap();
        assert_eq!(payload, received);

        follower.shutdown().await;
        leader.shutdown().await;
    }

    /// Exercises message sizes on both sides of the small-event limit.
    #[cfg(target_os = "linux")]
    #[async_test]
    async fn test_message_sizes(driver: DefaultDriver) {
        let (near_end, invited_port) = mesh_node::local_node::Port::new_pair();
        let (joined_port, far_end) = mesh_node::local_node::Port::new_pair();
        let node1 = UnixNode::new(driver.clone());
        let invitation = node1.invite(invited_port).await.unwrap();
        let _node2 = UnixNode::join(driver.clone(), invitation, joined_port)
            .await
            .unwrap();

        let sizes = 0..=super::MAX_SMALL_EVENT_SIZE + 0x1000;
        crate::test_common::test_message_sizes(near_end, far_end, sizes).await;
    }

    /// The leader can shut down after a follower was dropped without being
    /// shut down itself.
    #[async_test]
    async fn test_dropped_shutdown(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        {
            let (_leader_tx, invited_port) = channel::<u32>();
            let invitation = leader.invite(invited_port.into()).await.unwrap();
            let (joined_port, _follower_rx) = channel::<u32>();
            let _follower = UnixNode::join(driver, invitation, joined_port.into())
                .await
                .unwrap();
        }
        leader.shutdown().await;
    }

    /// A value sent by the follower before it shuts down is still received
    /// by the leader.
    #[async_test]
    async fn test_send_shutdown(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (invited_port, mut leader_rx) = channel::<u32>();
        let invitation = leader.invite(invited_port.into()).await.unwrap();
        let (follower_tx, joined_port) = channel::<u32>();
        let follower = UnixNode::join(driver, invitation, joined_port.into())
            .await
            .unwrap();

        follower_tx.send(5);
        drop(follower_tx);
        follower.shutdown().await;
        assert_eq!(leader_rx.recv().await.unwrap(), 5);
    }

    /// Dropping an invitation without joining fails the invited port.
    #[async_test]
    async fn test_failed_invitation(driver: DefaultDriver) {
        let leader = UnixNode::new(driver);
        let (invited_port, mut leader_rx) = channel::<()>();
        // The invitation is discarded immediately.
        leader.invite(invited_port.into()).await.unwrap();
        assert!(matches!(
            leader_rx.recv().await.unwrap_err(),
            RecvError::Error(_)
        ));
        drop(leader_rx);
        leader.shutdown().await;
    }

    /// Two followers exchange a value through channels bridged on the leader.
    #[async_test]
    async fn test_three(driver: DefaultDriver) {
        let (p1, p2) = channel::<u32>();
        let (p3, mut p4) = channel::<u32>();
        let (p5, p6) = channel::<u32>();
        let (p7, p8) = channel::<u32>();

        let node1 = UnixNode::new(driver.clone());

        let first_invite = node1.invite(p2.into()).await.unwrap();
        let node2 = UnixNode::join(driver.clone(), first_invite, p3.into())
            .await
            .unwrap();

        let second_invite = node1.invite(p5.into()).await.unwrap();
        let node3 = UnixNode::join(driver, second_invite, p8.into())
            .await
            .unwrap();

        // Splice the two leader-side ends together.
        p1.bridge(p6);

        p7.send(5);
        assert_eq!(p4.recv().await.unwrap(), 5);

        drop(p4);
        drop(p7);
        futures::join!(node2.shutdown(), node3.shutdown());
        node1.shutdown().await;
    }

    /// Leadership is handed from the first node to the fourth, after which
    /// every node can still shut down.
    #[async_test]
    async fn test_handoff_leader(driver: DefaultDriver) {
        let (p1, p2) = channel::<u32>();
        let (p3, p4) = channel::<u32>();
        let (p5, p6) = channel::<u32>();
        let (p7, p8) = channel::<u32>();
        let (offer_tx, p10) = channel();
        let (p11, mut offer_rx) = channel();

        let node1 = UnixNode::new(driver.clone());

        let invite2 = node1.invite(p2.into()).await.unwrap();
        let node2 = UnixNode::join(driver.clone(), invite2, p3.into())
            .await
            .unwrap();

        let invite3 = node1.invite(p5.into()).await.unwrap();
        let node3 = UnixNode::join(driver.clone(), invite3, p8.into())
            .await
            .unwrap();

        let invite4 = node1.invite(p10.into()).await.unwrap();
        let node4 = UnixNode::join(driver, invite4, p11.into())
            .await
            .unwrap();

        // Ship the offer across the mesh to node4 and accept it there.
        offer_tx.send(node1.offer_leadership());
        node4.accept_leadership(offer_rx.recv().await.unwrap());
        drop(offer_tx);
        drop(offer_rx);
        p1.bridge(p6);

        let settle = std::time::Duration::from_millis(200);
        std::thread::sleep(settle);

        node1.shutdown().await;
        drop(p4);
        drop(p7);
        node2.shutdown().await;
        node3.shutdown().await;

        std::thread::sleep(settle);

        node4.shutdown().await;
    }
}