1#![cfg(unix)]
15#![expect(unsafe_code)]
18
19#[cfg(target_os = "linux")]
20mod memfd;
21
22use crate::common::InvitationAddress;
23use crate::protocol;
24use futures::FutureExt;
25use futures::StreamExt;
26use futures::channel::mpsc;
27use futures::future;
28use futures::future::BoxFuture;
29use io::ErrorKind;
30use mesh_channel::OneshotReceiver;
31use mesh_channel::OneshotSender;
32use mesh_channel::RecvError;
33use mesh_channel::channel;
34use mesh_channel::oneshot;
35use mesh_node::common::Address;
36use mesh_node::common::NodeId;
37use mesh_node::common::PortId;
38use mesh_node::local_node::Connect;
39use mesh_node::local_node::LocalNode;
40use mesh_node::local_node::OutgoingEvent;
41use mesh_node::local_node::Port;
42use mesh_node::local_node::RemoteNodeHandle;
43use mesh_node::local_node::SendEvent;
44use mesh_node::resource::OsResource;
45use mesh_node::resource::Resource;
46use mesh_protobuf::Protobuf;
47use pal_async::driver::SpawnDriver;
48use pal_async::interest::InterestSlot;
49use pal_async::interest::PollEvents;
50use pal_async::socket::PolledSocket;
51use pal_async::task::Spawn;
52use pal_async::task::Task;
53use parking_lot::Mutex;
54use socket2::Socket;
55use std::collections::HashMap;
56use std::collections::VecDeque;
57use std::fmt::Debug;
58use std::future::Future;
59use std::future::poll_fn;
60use std::io;
61use std::io::IoSlice;
62use std::io::IoSliceMut;
63use std::os::unix::prelude::*;
64use std::pin::pin;
65use std::sync::Arc;
66use thiserror::Error;
67use tracing::instrument;
68use unicycle::FuturesUnordered;
69use zerocopy::FromBytes;
70use zerocopy::FromZeros;
71use zerocopy::IntoBytes;
72
/// Whether connections use `SOCK_SEQPACKET` sockets (Linux) rather than
/// `SOCK_STREAM` sockets (all other Unix targets).
///
/// When this is `false`, `UnixSocket::send`/`UnixSocket::recv` frame each
/// packet with a 4-byte little-endian length prefix, and `PacketSender` never
/// attempts a direct (non-queued) send.
const USE_SEQPACKET: bool = cfg!(target_os = "linux");

/// Size in bytes of the buffer `run_receive` allocates for one incoming
/// packet, header included.
const MAX_PACKET_SIZE: usize = if cfg!(target_os = "linux") {
    0x4000
} else {
    0x40000
};

/// Largest event payload that `serialize_event` places inline after the packet
/// header; larger events are handed to `serialize_large_event`.
const MAX_SMALL_EVENT_SIZE: usize = MAX_PACKET_SIZE - size_of::<protocol::PacketHeader>();
93
/// A mesh node whose connections to other nodes run over Unix sockets.
///
/// A node is created either with [`UnixNode::new`], which also starts a leader
/// task on the node, or with [`UnixNode::join`], which connects to an existing
/// mesh using an [`Invitation`].
pub struct UnixNode {
    /// Driver used to spawn the I/O task and to register sockets for polling.
    driver: Arc<dyn SpawnDriver>,
    /// The underlying mesh node that owns ports and remote-node handles.
    local_node: Arc<LocalNode>,
    /// Channel for sending requests (connect, invite) to the mesh leader.
    to_leader: Arc<mesh_channel::Sender<LeaderRequest>>,
    /// Channel for submitting futures to be run on the I/O task.
    tasks: Arc<mesh_channel::Sender<SmallTask>>,
    /// The spawned task that runs every submitted `SmallTask`.
    io_task: Task<()>,
    /// Set while a leader task has been started by this node (`new` or
    /// `accept_leadership`); used by `offer_leadership` to forward the
    /// resignation message to that task.
    leader_resign_send:
        Mutex<Option<Arc<mesh_channel::Sender<(NodeId, mesh_channel::Sender<Followers>)>>>>,

    /// Paired with the receiver the I/O task races against; presumably
    /// dropping the node (and so this sender) lets that task finish.
    _drop_send: OneshotSender<()>,
}
116
/// A request sent from a node to the mesh leader (handled in `run_leader`).
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum LeaderRequest {
    /// Ask the leader to set up a connection between the requesting node and
    /// the node with the given ID.
    Connect(NodeId),
    /// Ask the leader to create an invitation for a new node. The port is
    /// forwarded to the invited node as `InitialMessage::user_port`, and the
    /// resulting [`Invitation`] is returned on the sender.
    Invite(Port, mesh_channel::Sender<Invitation>),
}
123
/// A request sent from the mesh leader to a follower (handled in
/// `run_follower`).
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum FollowerRequest {
    /// Connect to the given node over the provided socket. `None` means the
    /// leader could not provide a socket for the connection.
    Connect(
        NodeId,
        #[mesh(
            encoding = "mesh_protobuf::encoding::OptionField<mesh_protobuf::encoding::ResourceField<OwnedFd>>"
        )]
        Option<Socket>,
    ),
}
135
/// The leader's per-follower channels, handed to a new leader when leadership
/// is transferred.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct Followers {
    /// One entry per follower: its node ID, the receiver for requests coming
    /// from that follower, and the sender for requests going to it.
    list: Vec<(
        NodeId,
        mesh_channel::Receiver<LeaderRequest>,
        mesh_channel::Sender<FollowerRequest>,
    )>,
}
145
/// The first message the leader sends to a newly invited node, received in
/// `UnixNode::join_generic`.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
struct InitialMessage {
    /// Channel the new node uses to send requests to the leader.
    leader_send: mesh_channel::Sender<LeaderRequest>,
    /// Channel on which the new node receives requests from the leader.
    follower_recv: mesh_channel::Receiver<FollowerRequest>,
    /// The port supplied by the inviter, to be bridged with the joiner's port.
    user_port: Port,
}
153
154#[instrument(skip_all, fields(local_id = ?local_node.id()))]
157async fn run_follower(
158 driver: &dyn SpawnDriver,
159 local_node: &Arc<LocalNode>,
160 mut recv: mesh_channel::Receiver<FollowerRequest>,
161 pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>>,
162 tasks: &mesh_channel::Sender<SmallTask>,
163) {
164 while let Ok(req) = recv.recv().await {
165 match req {
166 FollowerRequest::Connect(target_id, fd) => {
167 tracing::debug!(?target_id, "got connection request from leader");
168 let handle = pending_connections.lock().remove(&target_id);
169 let handle = handle.unwrap_or_else(|| local_node.get_remote_handle(target_id));
170
171 if let Some(fd) = fd {
172 start_connection(
173 tasks,
174 local_node,
175 target_id,
176 handle,
177 UnixSocket::new(driver, fd),
178 );
179 } else {
180 tracing::warn!(?target_id, "leader provided failed connection");
181 }
182 }
183 }
184 }
185}
186
/// Leader loop: brokers connections between followers and creates invitations
/// for new nodes.
///
/// Runs until either every follower channel has closed (the function returns)
/// or something arrives on `resign_recv`. In the latter case, if a new leader
/// was named, every other follower is connected to it and the follower list is
/// handed over through the provided sink.
#[instrument(skip_all, fields(local_id = ?local_node.id()))]
async fn run_leader(
    driver: &dyn SpawnDriver,
    local_node: &Arc<LocalNode>,
    mut resign_recv: mesh_channel::Receiver<(NodeId, mesh_channel::Sender<Followers>)>,
    followers: Followers,
    tasks: &mesh_channel::Sender<SmallTask>,
) {
    // Split the follower list: senders are looked up by node ID, receivers are
    // polled together by index.
    let mut senders = HashMap::new();
    let mut receivers = Vec::new();
    for (remote_id, recv, send) in followers.list {
        receivers.push((remote_id, recv));
        senders.insert(remote_id, send);
    }

    let new_leader_info = loop {
        // No followers left to serve.
        if receivers.is_empty() {
            return;
        }
        let recvs = receivers
            .iter_mut()
            .map(|(_, recv)| poll_fn(|cx| recv.poll_recv(cx)));
        // Wait for a resignation message or a request from any follower.
        let (req, index, _) = futures::select! {
            r = resign_recv.next() => break r,
            r = future::select_all(recvs).fuse() => r,
        };
        let remote_id = receivers[index].0;
        match req {
            Ok(req) => match req {
                LeaderRequest::Connect(target_id) => {
                    tracing::debug!(?target_id, ?remote_id, "connection request");
                    let remote = senders
                        .get(&remote_id)
                        .expect("sender must exist to receive from it");
                    // Stays `None` if the target is unknown or the socket pair
                    // cannot be created; the requester is told either way.
                    let mut fd = None;
                    if let Some(target) = senders.get(&target_id) {
                        match new_socket_pair() {
                            Ok((left, right)) => {
                                tracing::trace!(?target, "send to");
                                target.send(FollowerRequest::Connect(remote_id, Some(left)));
                                fd = Some(right);
                            }
                            Err(err) => {
                                tracing::warn!(
                                    ?target_id,
                                    ?remote_id,
                                    error = &err as &dyn std::error::Error,
                                    "failed to create socket pair for connection request"
                                );
                            }
                        }
                    } else {
                        tracing::warn!(?target_id, ?remote_id, "could not find target for remote");
                    }
                    remote.send(FollowerRequest::Connect(target_id, fd));
                }
                LeaderRequest::Invite(port, send) => {
                    tracing::debug!(?remote_id, "invitation request");
                    match new_socket_pair() {
                        Ok((left, right)) => {
                            let (leader_send, leader_recv) = channel();
                            let (follower_send, follower_recv) = channel();
                            // Fresh identity for the node being invited.
                            let remote_addr = Address {
                                node: NodeId::new(),
                                port: PortId::new(),
                            };
                            let local_port_id = PortId::new();
                            // Connect this node to the invitee over `left`.
                            let handle = local_node.add_remote(remote_addr.node);
                            start_connection(
                                tasks,
                                local_node,
                                remote_addr.node,
                                handle,
                                UnixSocket::new(driver, left),
                            );
                            // Send the invitee its leader channels and the
                            // inviter's port on a dedicated initial port.
                            let init_send = OneshotSender::<InitialMessage>::from(
                                local_node.add_port(local_port_id, remote_addr),
                            );
                            init_send.send(InitialMessage {
                                leader_send,
                                follower_recv,
                                user_port: port,
                            });
                            // Addresses are from the invitee's point of view.
                            let invitation = Invitation {
                                address: InvitationAddress {
                                    local_addr: remote_addr,
                                    remote_addr: Address {
                                        node: local_node.id(),
                                        port: local_port_id,
                                    },
                                },
                                fd: right.into(),
                            };
                            tracing::debug!(
                                invite_id = ?invitation.address.local_addr.node,
                                ?remote_id,
                                "inviting",
                            );
                            send.send(invitation);
                            // Track the invitee as a follower from now on.
                            senders.insert(remote_addr.node, follower_send);
                            receivers.push((remote_addr.node, leader_recv));
                        }
                        Err(err) => {
                            tracing::error!(
                                error = &err as &dyn std::error::Error,
                                "failed to create socket pair",
                            );
                        }
                    }
                }
            },
            Err(err) => {
                // The follower's channel closed or failed: forget the follower.
                if let RecvError::Error(err) = err {
                    tracing::debug!(
                        ?remote_id,
                        error = &err as &dyn std::error::Error,
                        "leader connection to remote failed"
                    );
                }
                senders.remove(&remote_id);
                receivers.swap_remove(index);
            }
        }
    };

    if let Some((new_leader_id, new_leader_followers_sink)) = new_leader_info {
        if let Some(new_leader_send) = senders.get(&new_leader_id) {
            tracing::debug!(?new_leader_id, "resigning leadership");
            // Give every other follower a direct connection to the new leader.
            for (remote_id, send) in senders.iter() {
                if new_leader_id != *remote_id {
                    match new_socket_pair() {
                        Ok((left, right)) => {
                            send.send(FollowerRequest::Connect(new_leader_id, Some(left)));
                            new_leader_send.send(FollowerRequest::Connect(*remote_id, Some(right)));
                        }
                        Err(err) => {
                            tracing::error!(
                                ?new_leader_id,
                                error = &err as &dyn std::error::Error,
                                "failed to connect node to new leader, mesh is leaderless",
                            );
                            return;
                        }
                    }
                }
            }

            // Reassemble the follower list and hand it to the new leader.
            let mut followers = Vec::new();
            for (remote_id, recv) in receivers {
                let send = senders
                    .remove(&remote_id)
                    .expect("should be in sync with receivers");
                followers.push((remote_id, recv, send));
            }
            new_leader_followers_sink.send(Followers { list: followers });
        } else {
            tracing::error!(?new_leader_id, "new leader is unknown, mesh is leaderless");
        }
    }
}
353
/// A named future submitted to the node's I/O task for execution.
struct SmallTask {
    /// Name used in trace output when the task starts and ends.
    name: &'static str,
    /// The work to run.
    future: BoxFuture<'static, ()>,
}
360
361impl Debug for SmallTask {
362 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 f.pad("SmallTask")
364 }
365}
366
367impl SmallTask {
368 fn new(name: &'static str, f: impl 'static + Send + Future<Output = ()>) -> Self {
369 Self {
370 name,
371 future: Box::pin(f),
372 }
373 }
374}
375
/// An invitation to join the mesh, produced by [`UnixNode::invite`] and
/// consumed by [`UnixNode::join`].
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct Invitation {
    /// The addresses the joining node should use for itself and for the
    /// inviting side's initial port.
    pub address: InvitationAddress,
    /// One end of the socket pair that connects the joining node to the
    /// inviting side.
    pub fd: OwnedFd,
}
387
/// A command for a connection's send loop (`run_send`).
#[derive(Debug)]
enum SenderCommand {
    /// Send a packet, along with any file descriptors that accompany it.
    Send {
        packet: Vec<u8>,
        fds: Vec<OsResource>,
    },
    /// Drop the `count` oldest file descriptors that were retained after
    /// being sent (they are only retained on macOS).
    ReleaseFds {
        count: usize,
    },
}
398
/// The [`SendEvent`] implementation handed to a remote-node handle; it turns
/// outgoing events into packets on one connection.
#[derive(Clone)]
struct PacketSender {
    /// Queue feeding the connection's send loop.
    send: mpsc::UnboundedSender<SenderCommand>,
    /// The connection's socket, used for direct sends when possible.
    socket: Arc<UnixSocket>,
}
404
405impl SendEvent for PacketSender {
406 fn event(&self, event: OutgoingEvent<'_>) {
407 let (packet, fds) = match serialize_event(event) {
408 Ok(r) => r,
409 Err(err) => {
410 tracing::error!(
412 error = &err as &dyn std::error::Error,
413 "failed to serialize event"
414 );
415 return;
416 }
417 };
418
419 if !USE_SEQPACKET
431 || try_send(
432 self.socket.socket.lock().get(),
433 &[IoSlice::new(&packet)],
434 &fds,
435 )
436 .is_err()
437 {
438 let _ = self
439 .send
440 .unbounded_send(SenderCommand::Send { packet, fds });
441 }
442 }
443}
444
445fn serialize_event(event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
446 if event.len() > MAX_SMALL_EVENT_SIZE {
448 return serialize_large_event(event);
449 }
450
451 let cap = size_of::<protocol::PacketHeader>() + event.len();
453 let mut packet = Vec::with_capacity(cap);
454 packet.extend_from_slice(
455 protocol::PacketHeader {
456 packet_type: protocol::PacketType::EVENT,
457 reserved: [0; 7],
458 }
459 .as_bytes(),
460 );
461 let mut fds = Vec::new();
462 event.write_to(&mut packet, &mut fds);
463 assert_eq!(packet.len(), cap);
464 Ok((packet, fds))
465}
466
467#[cfg(target_os = "linux")]
468fn serialize_large_event(event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
469 let packet = protocol::PacketHeader {
470 packet_type: protocol::PacketType::LARGE_EVENT,
471 ..FromZeros::new_zeroed()
472 }
473 .as_bytes()
474 .to_vec();
475
476 let mut fds = Vec::new();
477
478 let mut memfd = memfd::MemfdBuilder::new(event.len())?;
479 event.write_to(&mut io::Cursor::new(&mut *memfd), &mut fds);
480 fds.insert(0, OsResource::Fd(memfd.seal()?.into()));
481
482 Ok((packet, fds))
483}
484
/// Large events are not supported on this platform; always returns an
/// `Unsupported` error.
#[cfg(not(target_os = "linux"))]
fn serialize_large_event(_event: OutgoingEvent<'_>) -> io::Result<(Vec<u8>, Vec<OsResource>)> {
    let unsupported = io::Error::new(ErrorKind::Unsupported, "event too large for this OS");
    Err(unsupported)
}
492
impl Drop for PacketSender {
    /// Closes the command queue from the sender side.
    ///
    /// `run_connection` keeps its own clone of the queue's sender, so merely
    /// dropping this one would not end the queue; closing it explicitly
    /// presumably lets the send loop finish once queued commands are drained.
    fn drop(&mut self) {
        self.send.close_channel();
    }
}
500
501fn start_connection(
503 tasks: &mesh_channel::Sender<SmallTask>,
504 local_node: &Arc<LocalNode>,
505 remote_id: NodeId,
506 handle: RemoteNodeHandle,
507 socket: UnixSocket,
508) {
509 #[expect(clippy::disallowed_methods)] let (send, recv) = mpsc::unbounded();
511 let socket = Arc::new(socket);
512 let sender = PacketSender {
513 send: send.clone(),
514 socket: socket.clone(),
515 };
516 if handle.connect(sender) {
517 let task = SmallTask::new("run_connection", {
518 let local_node = local_node.clone();
519 run_connection(local_node, remote_id, send, recv, socket, handle)
520 });
521 tasks.send(task);
522 tracing::debug!(?remote_id, "connected");
523 } else {
524 tracing::debug!(?remote_id, "duplicate connection");
528 }
529}
530
/// Drives one connection: runs the receive and send loops concurrently, shuts
/// the socket down when either finishes, and reports the outcome to `handle`.
///
/// The handle is told `disconnect` when receiving ended cleanly, and `fail`
/// when receiving failed or the write side could not be shut down.
#[instrument(skip_all, fields(local_id = ?local_node.id(), remote_id = ?remote_id))]
async fn run_connection(
    local_node: Arc<LocalNode>,
    remote_id: NodeId,
    send_send: mpsc::UnboundedSender<SenderCommand>,
    send_recv: mpsc::UnboundedReceiver<SenderCommand>,
    socket: Arc<UnixSocket>,
    handle: RemoteNodeHandle,
) {
    // File descriptors retained by the send loop; owned here so they live
    // until this function returns.
    let mut retained_fds = VecDeque::new();
    let mut recv = pin!(
        async {
            let r = run_receive(&local_node, &remote_id, &socket, &send_send).await;
            match &r {
                Ok(_) => {
                    tracing::debug!("incoming socket disconnected");
                }
                Err(err) => {
                    tracing::error!(error = err as &dyn std::error::Error, "error receiving");
                }
            }
            r
        }
        .fuse()
    );
    let mut send = pin!(
        async {
            match run_send(send_recv, &socket, &mut retained_fds).await {
                Ok(_) => {
                    tracing::debug!("sending is done");
                }
                Err(err) => {
                    tracing::error!(error = &err as &dyn std::error::Error, "failed send");
                }
            }
        }
        .fuse()
    );
    let r = futures::select! {
        r = recv => {
            // The receive side finished first: stop sending as well. The
            // shutdown is best effort; the receive result decides the outcome.
            tracing::trace!("read complete, shutting down writes");
            let _ = socket.close_write().await;
            r
        }
        _ = send => {
            // The send side finished first: shut down writes, then keep
            // receiving until the receive loop ends.
            match socket.close_write().await {
                Ok(()) => {
                    tracing::trace!("shutdown writes, waiting for reads");
                    recv.await
                }
                Err(err) => {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "failed to shutdown writes, aborting connection",
                    );
                    Err(ReceiveError::Io(err))
                }
            }
        }
    };
    tracing::trace!("connection done");
    match r {
        Ok(()) => handle.disconnect(),
        Err(err) => handle.fail(err),
    }
}
599
/// Errors that end a connection's receive loop (`run_receive`).
#[derive(Debug, Error)]
enum ReceiveError {
    /// The socket operation failed.
    #[error("i/o error")]
    Io(#[from] io::Error),
    /// The packet was too short to contain a packet header.
    #[error("missing packet header")]
    NoHeader,
    /// A `RELEASE_FDS` packet was too short to parse.
    #[error("release fds packet too small")]
    BadReleaseFds,
    /// The packet header named a type this build does not handle.
    #[error("unknown packet type {0:?}")]
    UnknownPacketType(protocol::PacketType),
    /// A `LARGE_EVENT` packet arrived without any file descriptor.
    #[cfg(target_os = "linux")]
    #[error("memfd file descriptor not sent for large event")]
    MissingMemfd,
    /// The memfd of a `LARGE_EVENT` packet could not be opened as a sealed
    /// memfd.
    #[cfg(target_os = "linux")]
    #[error("failed to map memfd")]
    Memfd(#[source] io::Error),
}
617
/// Receive loop for one connection: reads packets from `socket` and
/// dispatches them until the peer disconnects (a zero-length read) or an
/// error occurs.
///
/// `send` is the queue of the connection's send loop; it is used to
/// acknowledge received file descriptors (macOS only) and to forward the
/// peer's `RELEASE_FDS` requests.
async fn run_receive(
    local_node: &LocalNode,
    remote_id: &NodeId,
    socket: &UnixSocket,
    send: &mpsc::UnboundedSender<SenderCommand>,
) -> Result<(), ReceiveError> {
    let mut buf = vec![0; MAX_PACKET_SIZE];
    let mut fds = Vec::new();
    loop {
        let len = socket.recv(&mut buf, &mut fds).await?;
        if len == 0 {
            break;
        }
        if cfg!(target_os = "macos") && !fds.is_empty() {
            // Tell the peer how many fds arrived so it can drop its retained
            // copies. A closed queue is ignored.
            let _ = send.unbounded_send(SenderCommand::Send {
                packet: protocol::ReleaseFds {
                    header: protocol::PacketHeader {
                        packet_type: protocol::PacketType::RELEASE_FDS,
                        ..FromZeros::new_zeroed()
                    },
                    count: fds.len() as u64,
                }
                .as_bytes()
                .to_vec(),
                fds: Vec::new(),
            });
        }

        let buf = &buf[..len];
        let header = protocol::PacketHeader::read_from_prefix(buf)
            .map_err(|_| ReceiveError::NoHeader)?
            .0;
        match header.packet_type {
            protocol::PacketType::EVENT => {
                // The payload follows the header inline.
                local_node.event(remote_id, &buf[size_of_val(&header)..], &mut fds);
                fds.clear();
            }
            protocol::PacketType::RELEASE_FDS => {
                let release_fds = protocol::ReleaseFds::read_from_prefix(buf)
                    .map_err(|_| ReceiveError::BadReleaseFds)?
                    .0;
                // The send loop owns the retained fds, so forward the request.
                let _ = send.unbounded_send(SenderCommand::ReleaseFds {
                    count: release_fds.count as usize,
                });
            }
            #[cfg(target_os = "linux")]
            protocol::PacketType::LARGE_EVENT => {
                // The payload lives in a memfd passed as the first fd.
                if fds.is_empty() {
                    return Err(ReceiveError::MissingMemfd);
                }
                let OsResource::Fd(fd) = fds.remove(0);
                let memfd = memfd::SealedMemfd::new(fd.into()).map_err(ReceiveError::Memfd)?;
                local_node.event(remote_id, &memfd, &mut fds);
                fds.clear();
            }
            ty => {
                return Err(ReceiveError::UnknownPacketType(ty));
            }
        }
    }
    Ok(())
}
682
/// Protocol violations detected by the send loop (`run_send`).
#[derive(Debug, Error)]
enum ProtocolError {
    /// The peer asked to release more file descriptors than are retained.
    #[error("request to release too many fds")]
    ReleasingTooManyFds,
}
688
689async fn run_send(
691 mut recv: mpsc::UnboundedReceiver<SenderCommand>,
692 socket: &UnixSocket,
693 retained_fds: &mut VecDeque<OsResource>,
694) -> io::Result<()> {
695 while let Some(command) = recv.next().await {
696 match command {
697 SenderCommand::Send { packet, fds } => {
698 match socket.send(&packet, &fds).await {
699 Ok(_) => (),
700 Err(err) => {
701 tracing::error!(
702 fd_count = fds.len(),
703 packet_len = packet.len(),
704 "failed to send packet"
705 );
706 return Err(err);
707 }
708 }
709 if cfg!(target_os = "macos") {
710 if !fds.is_empty() {
716 retained_fds.extend(fds);
717 }
718 }
719 }
720 SenderCommand::ReleaseFds { count } => {
721 if retained_fds.len() < count {
722 return Err(io::Error::other(ProtocolError::ReleasingTooManyFds));
723 }
724 retained_fds.drain(..count);
725 }
726 }
727 }
728 Ok(())
729}
730
/// An offer to take over mesh leadership, created by
/// [`UnixNode::offer_leadership`] and consumed by
/// [`UnixNode::accept_leadership`].
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct LeadershipOffer {
    /// Used by the accepting node to send its node ID and a sink for the
    /// follower list.
    send: mesh_channel::Sender<(NodeId, mesh_channel::Sender<Followers>)>,
}
738
739impl UnixNode {
740 pub fn new(driver: impl SpawnDriver) -> Self {
742 let (to_leader_send, to_leader_recv) = channel();
743 let (from_leader_send, from_leader_recv) = channel();
744 let this = Self::with_id(
745 Arc::new(driver),
746 NodeId::new(),
747 to_leader_send,
748 from_leader_recv,
749 );
750
751 let (resign_send, resign_recv) = channel();
753 let resign_send = Arc::new(resign_send);
754 let followers = Followers {
755 list: vec![(this.local_node.id(), to_leader_recv, from_leader_send)],
756 };
757 let task = SmallTask::new("run_leader", {
758 let local_node = this.local_node.clone();
759 let tasks = this.tasks.clone();
760 let driver = this.driver.clone();
761 async move { run_leader(driver.as_ref(), &local_node, resign_recv, followers, &tasks).await }
762 });
763 this.tasks.send(task);
764 *this.leader_resign_send.lock() = Some(resign_send);
765
766 this
767 }
768
    /// Returns the ID of this node within the mesh.
    pub fn id(&self) -> NodeId {
        self.local_node.id()
    }
773
    /// Builds a node with the given ID, wired to a leader through the
    /// provided channels.
    ///
    /// Spawns the I/O task that runs every submitted [`SmallTask`], and
    /// submits the follower loop as the first task. No leader task is
    /// started here.
    fn with_id(
        driver: Arc<dyn SpawnDriver>,
        id: NodeId,
        to_leader: mesh_channel::Sender<LeaderRequest>,
        from_leader: mesh_channel::Receiver<FollowerRequest>,
    ) -> Self {
        let to_leader = Arc::new(to_leader);
        // Shared between the connector, which records outgoing connection
        // requests, and the follower loop, which completes them.
        let pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>> = Default::default();
        let local_node = Arc::new(LocalNode::with_id(
            id,
            Box::new(Connector {
                local_id: id,
                conn_req_send: to_leader.clone(),
                pending_connections: pending_connections.clone(),
            }),
        ));
        let (task_send, mut task_recv) = channel::<SmallTask>();
        let task_send = Arc::new(task_send);
        let (drop_send, drop_recv) = oneshot();

        let io_task = driver.spawn("unix-mesh-io", async move {
            // Run all submitted tasks concurrently until both the task
            // channel and the set of running tasks are exhausted.
            let process = async {
                let mut futs = FuturesUnordered::new();
                loop {
                    futures::select! {
                        _ = futs.next() => {},
                        task = task_recv.select_next_some() => {
                            futs.push(async move {
                                tracing::trace!(?id, name = task.name, "task start");
                                task.future.await;
                                tracing::trace!(?id, name = task.name, "task end");
                            });
                        },
                        complete => break,
                    };
                }
            };
            // Stop early once `drop_recv` resolves; its sender is stored in
            // the returned node as `_drop_send`.
            future::select(pin!(process), drop_recv).await;
        });

        task_send.send(SmallTask::new("run_follower", {
            let local_node = local_node.clone();
            let tasks = task_send.clone();
            let driver = driver.clone();
            async move {
                run_follower(
                    driver.as_ref(),
                    &local_node,
                    from_leader,
                    pending_connections,
                    &tasks,
                )
                .await
            }
        }));

        Self {
            driver,
            local_node,
            tasks: task_send,
            io_task,
            to_leader,
            leader_resign_send: Mutex::new(None),

            _drop_send: drop_send,
        }
    }
844
    /// Creates an offer that another node can accept with
    /// [`UnixNode::accept_leadership`] to take over leadership of the mesh.
    ///
    /// If this node has started a leader task, a helper task is submitted
    /// that waits for the acceptance and forwards it to that task as a
    /// resignation request. Otherwise the offer is still returned, but
    /// nothing is listening for its acceptance.
    pub fn offer_leadership(&self) -> LeadershipOffer {
        let (send, mut recv) = channel();
        if let Some(leader_send) = self.leader_resign_send.lock().clone() {
            let task = SmallTask::new("offer_leadership", async move {
                // Forward the acceptance, if it ever arrives.
                if let Ok(r) = recv.recv().await {
                    leader_send.send(r);
                }
            });
            self.tasks.send(task);
        }
        LeadershipOffer { send }
    }
863
864 pub fn accept_leadership(&self, offer: LeadershipOffer) {
866 let (send, mut recv) = channel();
867 offer.send.send((self.local_node.id(), send));
868
869 let (resign_send, resign_recv) = channel();
870 let resign_send = Arc::new(resign_send);
871 let task = SmallTask::new("accept_and_run_leader", {
872 let local_node = self.local_node.clone();
873 let tasks = self.tasks.clone();
874 let driver = self.driver.clone();
875 async move {
876 if let Ok(followers) = recv.recv().await {
877 drop(recv);
878 run_leader(driver.as_ref(), &local_node, resign_recv, followers, &tasks).await
879 }
880 }
881 });
882 self.tasks.send(task);
883 *self.leader_resign_send.lock() = Some(resign_send);
884 }
885
886 #[instrument(skip_all, fields(local_id = ?self.local_node.id()))]
889 pub async fn invite(&self, port: Port) -> io::Result<Invitation> {
890 let (invitation_send, mut invitation_recv) = channel();
891 self.to_leader
892 .send(LeaderRequest::Invite(port, invitation_send));
893 let invitation = invitation_recv
894 .recv()
895 .await
896 .map_err(|_| ErrorKind::ConnectionReset)?;
897 tracing::debug!(
898 invite_id = ?invitation.address.local_addr.node,
899 "received invitation",
900 );
901 Ok(invitation)
902 }
903
    /// Joins an existing mesh using `invitation`, bridging `port` with the
    /// port supplied by the inviter.
    ///
    /// # Errors
    ///
    /// Returns [`JoinError`] if the initial message from the inviting side
    /// cannot be received.
    pub async fn join(
        driver: impl SpawnDriver,
        invitation: Invitation,
        port: Port,
    ) -> Result<Self, JoinError> {
        Self::join_generic(Arc::new(driver), invitation, port).await
    }
913
    /// Non-generic implementation of [`UnixNode::join`].
    ///
    /// Creates the node with the ID named in the invitation, connects to the
    /// inviting node over the invitation's socket, waits for the
    /// [`InitialMessage`], and bridges the leader channels and the user port.
    #[instrument(skip_all, fields(local_id = ?invitation.address.local_addr.node, remote_id = ?invitation.address.remote_addr.node))]
    async fn join_generic(
        driver: Arc<dyn SpawnDriver>,
        invitation: Invitation,
        port: Port,
    ) -> Result<Self, JoinError> {
        // Local ends of the leader channels; they are bridged to the real
        // ones once the initial message arrives.
        let (to_leader_send, to_leader_recv) = channel();
        let (from_leader_send, from_leader_recv) = channel();
        let this = Self::with_id(
            driver,
            invitation.address.local_addr.node,
            to_leader_send,
            from_leader_recv,
        );

        // Register the inviting node and the port the initial message will
        // arrive on before starting the connection.
        let handle = this
            .local_node
            .add_remote(invitation.address.remote_addr.node);
        let init_recv = OneshotReceiver::<InitialMessage>::from(this.local_node.add_port(
            invitation.address.local_addr.port,
            invitation.address.remote_addr,
        ));

        start_connection(
            &this.tasks,
            &this.local_node,
            invitation.address.remote_addr.node,
            handle,
            UnixSocket::new(this.driver.as_ref(), invitation.fd.into()),
        );

        let init_message = init_recv.await.map_err(JoinError)?;
        to_leader_recv.bridge(init_message.leader_send);
        from_leader_send.bridge(init_message.follower_recv);
        port.bridge(init_message.user_port);

        Ok(this)
    }
952
    /// Shuts the node down and waits for its I/O task to finish.
    ///
    /// The steps run in a fixed order; the notes below describe the apparent
    /// intent of each step — NOTE(review): confirm against `LocalNode`.
    pub async fn shutdown(mut self) {
        // Wait on the local node's ports before tearing anything down.
        self.local_node.wait_for_ports(false).await;
        // Drop this node's sender to the leader.
        drop(self.to_leader);
        // Drop the connector, which holds another clone of that sender.
        self.local_node.drop_connector();
        // Drop the sender to any leader task started by this node.
        self.leader_resign_send.get_mut().take();
        // Fail all remote nodes — presumably so their connections end.
        self.local_node.fail_all_nodes();
        // Drop this node's task sender; the I/O task's loop ends only once
        // the task channel is closed and every running task has finished.
        drop(self.tasks);
        self.io_task.await;
    }
976}
977
/// Error returned by [`UnixNode::join`] when the initial message from the
/// inviting side could not be received.
#[derive(Debug, Error)]
#[error("failed to accept invitation")]
pub struct JoinError(#[source] RecvError);
982
/// The [`Connect`] implementation installed in the local node: it asks the
/// leader to set up connections to other nodes.
#[derive(Debug)]
struct Connector {
    /// ID of the owning node, used only for tracing.
    local_id: NodeId,
    /// Channel for sending connection requests to the leader.
    conn_req_send: Arc<mesh_channel::Sender<LeaderRequest>>,
    /// Handles of requested connections, taken by the follower loop when the
    /// leader responds.
    pending_connections: Arc<Mutex<HashMap<NodeId, RemoteNodeHandle>>>,
}
992
993impl Connect for Connector {
994 fn connect(&self, node_id: NodeId, handle: RemoteNodeHandle) {
995 tracing::trace!(local_id = ?self.local_id, remote_id = ?node_id, "connecting");
996 let old_request = self.pending_connections.lock().insert(node_id, handle);
997 if old_request.is_some() {
998 panic!("duplicate connection request for {:?}", node_id);
999 }
1000 self.conn_req_send.send(LeaderRequest::Connect(node_id))
1001 }
1002}
1003
1004fn new_socket_pair() -> Result<(Socket, Socket), io::Error> {
1006 let ty = if USE_SEQPACKET {
1007 socket2::Type::SEQPACKET
1008 } else {
1009 socket2::Type::STREAM
1010 };
1011 Socket::pair(socket2::Domain::UNIX, ty, None)
1012}
1013
/// A Unix socket registered with the async driver, able to send and receive
/// packets together with file descriptors.
struct UnixSocket {
    /// The polled socket; locked for each individual send, receive, or
    /// shutdown operation.
    socket: Mutex<PolledSocket<Socket>>,
}
1018
/// Control-message buffer for passing file descriptors with `SCM_RIGHTS`: a
/// `cmsghdr` followed immediately by room for up to 64 descriptors.
#[repr(C)]
struct CmsgScmRights {
    hdr: libc::cmsghdr,
    fds: [RawFd; 64],
}
1024
/// Advances `bufs` past the first `n` bytes.
///
/// Buffers that are consumed entirely are removed from the front of `bufs`,
/// and the first remaining buffer is shortened to its unconsumed tail. If `n`
/// is at least the total length of all buffers, `bufs` ends up empty; this
/// function does not panic in that case.
fn advance_slices(bufs: &mut &mut [IoSlice<'_>], n: usize) {
    // Number of leading buffers that are consumed entirely.
    let mut remove = 0;
    // Total length of those buffers.
    let mut accumulated_len = 0;
    for buf in bufs.iter() {
        if accumulated_len + buf.len() > n {
            break;
        } else {
            accumulated_len += buf.len();
            remove += 1;
        }
    }

    *bufs = &mut std::mem::take(bufs)[remove..];
    if let Some(first) = bufs.first_mut() {
        // The loop stopped at this buffer because it is longer than what
        // remains of `n`, so this offset is strictly within its bounds and
        // `advance` cannot panic.
        first.advance(n - accumulated_len);
    }
}
1055
1056impl UnixSocket {
1057 fn new(driver: &dyn SpawnDriver, fd: Socket) -> Self {
1058 let socket = PolledSocket::new(driver, fd).unwrap();
1059 UnixSocket {
1060 socket: Mutex::new(socket),
1061 }
1062 }
1063
1064 async fn send(&self, msg: &[u8], fds: &[OsResource]) -> io::Result<()> {
1065 if USE_SEQPACKET {
1066 self.send_raw(&mut [IoSlice::new(msg)], fds).await?;
1067 } else {
1068 let len = (msg.len() as u32).to_le_bytes();
1069 let mut iov = [IoSlice::new(&len), IoSlice::new(msg)];
1070 self.send_all_raw(&mut iov, fds).await?;
1071 }
1072 Ok(())
1073 }
1074
1075 async fn send_raw(
1076 &self,
1077 iov: &mut [IoSlice<'_>],
1078 fds: &[OsResource],
1079 ) -> Result<usize, io::Error> {
1080 let n = poll_fn(|cx| {
1081 self.socket
1082 .lock()
1083 .poll_io(cx, InterestSlot::Write, PollEvents::OUT, |socket| {
1084 try_send(socket.get(), iov, fds)
1085 })
1086 })
1087 .await?;
1088 Ok(n)
1089 }
1090
1091 async fn send_all_raw(
1092 &self,
1093 mut iov: &mut [IoSlice<'_>],
1094 mut fds: &[OsResource],
1095 ) -> Result<(), io::Error> {
1096 while !iov.is_empty() || !fds.is_empty() {
1097 let n = self.send_raw(iov, fds).await?;
1098 advance_slices(&mut iov, n);
1099 fds = &[];
1100 }
1101 Ok(())
1102 }
1103
1104 async fn recv(&self, buf: &mut [u8], fds: &mut Vec<OsResource>) -> io::Result<usize> {
1105 if USE_SEQPACKET {
1106 self.recv_raw(buf, fds).await
1107 } else {
1108 let mut len = [0; 4];
1109 if !self.recv_all_raw(&mut len, fds).await? {
1110 return Ok(0);
1111 }
1112 let len = u32::from_le_bytes(len) as usize;
1113 let buf = buf
1114 .get_mut(..len)
1115 .ok_or_else(|| io::Error::from_raw_os_error(libc::EMSGSIZE))?;
1116 if !self.recv_all_raw(buf, fds).await? {
1117 return Err(ErrorKind::UnexpectedEof.into());
1118 }
1119 Ok(len)
1120 }
1121 }
1122
1123 async fn recv_all_raw(
1124 &self,
1125 buf: &mut [u8],
1126 fds: &mut Vec<OsResource>,
1127 ) -> Result<bool, io::Error> {
1128 let mut read = 0;
1129 while read < buf.len() {
1130 let n = self.recv_raw(&mut buf[read..], fds).await?;
1131 if n == 0 {
1132 if read != 0 {
1133 return Err(ErrorKind::UnexpectedEof.into());
1134 } else {
1135 return Ok(false);
1136 }
1137 }
1138 read += n;
1139 }
1140 Ok(true)
1141 }
1142
1143 async fn recv_raw(
1144 &self,
1145 buf: &mut [u8],
1146 fds: &mut Vec<OsResource>,
1147 ) -> Result<usize, io::Error> {
1148 let n = poll_fn(|cx| {
1149 self.socket
1150 .lock()
1151 .poll_io(cx, InterestSlot::Read, PollEvents::IN, |socket| {
1152 try_recv(socket.get(), buf, fds)
1153 })
1154 })
1155 .await?;
1156 Ok(n)
1157 }
1158
    /// Shuts down the write half of the socket. The call is synchronous and
    /// never actually awaits.
    async fn close_write(&self) -> io::Result<()> {
        self.socket.lock().get().shutdown(std::net::Shutdown::Write)
    }
1162}
1163
1164#[allow(clippy::needless_update, clippy::useless_conversion)]
1169fn try_send(socket: &Socket, msg: &[IoSlice<'_>], fds: &[OsResource]) -> io::Result<usize> {
1170 let mut cmsg = CmsgScmRights {
1171 hdr: libc::cmsghdr {
1172 cmsg_level: libc::SOL_SOCKET,
1173 cmsg_type: libc::SCM_RIGHTS,
1174 cmsg_len: (size_of::<libc::cmsghdr>() + size_of_val(fds))
1175 .try_into()
1176 .unwrap(),
1177
1178 ..{
1179 unsafe { std::mem::zeroed() }
1181 }
1182 },
1183 fds: [0; 64],
1184 };
1185 for (fdi, fdo) in fds.iter().zip(cmsg.fds.iter_mut()) {
1186 *fdo = match fdi {
1187 OsResource::Fd(fd) => fd.as_raw_fd(),
1188 }
1189 }
1190
1191 let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
1193 hdr.msg_iov = msg.as_ptr() as *mut libc::iovec;
1194 hdr.msg_iovlen = msg.len().try_into().unwrap();
1195 hdr.msg_control = if fds.is_empty() {
1196 std::ptr::null_mut()
1197 } else {
1198 std::ptr::from_mut(&mut cmsg).cast::<libc::c_void>()
1199 };
1200 hdr.msg_controllen = if fds.is_empty() { 0 } else { cmsg.hdr.cmsg_len };
1201 let n = unsafe { libc::sendmsg(socket.as_raw_fd(), &hdr, 0) };
1203 if n < 0 {
1204 return Err(io::Error::last_os_error());
1205 }
1206 Ok(n as usize)
1207}
1208
/// Makes one `recvmsg` call on `socket`, reading into `buf` and appending any
/// file descriptors received via `SCM_RIGHTS` to `fds`.
///
/// Returns the number of bytes read; zero means the peer closed the
/// connection.
///
/// # Errors
///
/// Returns the OS error from `recvmsg`, `InvalidData` if the control message
/// is not `SCM_RIGHTS`, or `EMSGSIZE` if the data or the control message was
/// truncated.
///
/// # Panics
///
/// Panics if `buf` is empty, or if a zero-length read carries control data.
fn try_recv(socket: &Socket, buf: &mut [u8], fds: &mut Vec<OsResource>) -> io::Result<usize> {
    // With an empty buffer a zero return could not be told apart from EOF.
    assert!(!buf.is_empty());
    let mut iov = IoSliceMut::new(buf);
    // SAFETY: `CmsgScmRights` consists of a plain C header and an integer
    // array, for which all-zero bytes are a valid value.
    let mut cmsg: CmsgScmRights = unsafe { std::mem::zeroed() };
    // SAFETY: `msghdr` is a plain C struct of integers and raw pointers, for
    // which all-zero bytes are a valid value.
    let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
    hdr.msg_iov = std::ptr::from_mut(&mut iov).cast::<libc::iovec>();
    hdr.msg_iovlen = 1;
    hdr.msg_control = std::ptr::from_mut(&mut cmsg).cast::<libc::c_void>();
    hdr.msg_controllen = size_of_val(&cmsg) as _;

    // On Linux the received descriptors are marked close-on-exec by the
    // kernel; elsewhere this is done after the call (see below).
    #[cfg(target_os = "linux")]
    let flags = libc::MSG_CMSG_CLOEXEC;
    #[cfg(not(target_os = "linux"))]
    let flags = 0;

    // SAFETY: `hdr` points at `iov` and `cmsg`, both of which outlive the
    // call, and the lengths match the buffers they describe.
    let n = unsafe { libc::recvmsg(socket.as_raw_fd(), &mut hdr, flags) };
    if n < 0 {
        return Err(io::Error::last_os_error());
    }
    if n == 0 {
        assert_eq!(hdr.msg_controllen, 0);
        return Ok(0);
    }

    let fd_count = if hdr.msg_controllen > 0 {
        if cmsg.hdr.cmsg_level != libc::SOL_SOCKET || cmsg.hdr.cmsg_type != libc::SCM_RIGHTS {
            return Err(ErrorKind::InvalidData.into());
        }
        // `cmsg_len` covers the header plus the descriptors that follow it.
        #[allow(clippy::unnecessary_cast)]
        {
            (cmsg.hdr.cmsg_len as usize - size_of_val(&cmsg.hdr)) / size_of::<RawFd>()
        }
    } else {
        0
    };

    // Take ownership of the descriptors before any further error check, so
    // they are closed when `fds` is dropped even if an error is returned.
    let start = fds.len();
    fds.extend(cmsg.fds[..fd_count].iter().map(|x| {
        // SAFETY: the kernel just handed these descriptors to this process,
        // and nothing else owns them.
        OsResource::Fd(unsafe { OwnedFd::from_raw_fd(*x) })
    }));

    if !cfg!(target_os = "linux") {
        // No `MSG_CMSG_CLOEXEC` was used here, so set the flag per descriptor.
        for OsResource::Fd(fd) in &fds[start..] {
            set_cloexec(fd);
        }
    }

    if hdr.msg_flags & (libc::MSG_TRUNC | libc::MSG_CTRUNC) != 0 {
        return Err(io::Error::from_raw_os_error(libc::EMSGSIZE));
    }
    Ok(n as usize)
}
1274
1275fn set_cloexec(fd: impl AsFd) {
1276 unsafe {
1278 let flags = libc::fcntl(fd.as_fd().as_raw_fd(), libc::F_GETFD);
1279 assert!(flags >= 0);
1280 let r = libc::fcntl(
1281 fd.as_fd().as_raw_fd(),
1282 libc::F_SETFD,
1283 flags | libc::FD_CLOEXEC,
1284 );
1285 assert!(r >= 0);
1286 }
1287}
1288
1289#[cfg(test)]
1290mod tests {
1291 use crate::unix::UnixNode;
1292 use mesh_channel::RecvError;
1293 use mesh_channel::channel;
1294 use pal_async::DefaultDriver;
1295 use pal_async::async_test;
1296 use test_with_tracing::test;
1297
    /// A leader invites a follower; a value sent into the leader-side channel
    /// arrives on the follower-side channel, and both nodes shut down.
    #[async_test]
    async fn test_basic(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (send, recv) = channel::<u32>();
        let invitation = leader.invite(recv.into()).await.unwrap();
        let (send2, mut recv2) = channel::<u32>();
        let follower = UnixNode::join(driver, invitation, send2.into())
            .await
            .unwrap();
        send.send(5);
        assert_eq!(recv2.recv().await.unwrap(), 5);
        // Drop both channel ends before shutting down, which first waits on
        // the node's ports.
        drop(send);
        drop(recv2);
        follower.shutdown().await;
        leader.shutdown().await;
    }
1314
    /// A 16 MiB message, far larger than one packet, is delivered intact.
    /// Linux only, since large events are unsupported elsewhere.
    #[cfg(target_os = "linux")]
    #[async_test]
    async fn test_huge_message(driver: DefaultDriver) {
        let leader = UnixNode::new(driver.clone());
        let (send, recv) = channel::<Vec<u8>>();
        let invitation = leader.invite(recv.into()).await.unwrap();
        let (send2, mut recv2) = channel::<Vec<u8>>();
        let follower = UnixNode::join(driver, invitation, send2.into())
            .await
            .unwrap();

        let v = vec![0xcc; 16 << 20];
        send.send(v.clone());
        let v2 = recv2.recv().await.unwrap();
        assert_eq!(v, v2);
        follower.shutdown().await;
        leader.shutdown().await;
    }
1333
1334 #[cfg(target_os = "linux")]
1335 #[async_test]
1336 async fn test_message_sizes(driver: DefaultDriver) {
1337 let (p1, p2) = mesh_node::local_node::Port::new_pair();
1338 let (p3, p4) = mesh_node::local_node::Port::new_pair();
1339 let node1 = UnixNode::new(driver.clone());
1340 let invitation = node1.invite(p2).await.unwrap();
1341 let _node2 = UnixNode::join(driver.clone(), invitation, p3)
1342 .await
1343 .unwrap();
1344
1345 crate::test_common::test_message_sizes(p1, p4, 0..=super::MAX_SMALL_EVENT_SIZE + 0x1000)
1346 .await;
1347 }
1348
1349 #[async_test]
1350 async fn test_dropped_shutdown(driver: DefaultDriver) {
1351 let leader = UnixNode::new(driver.clone());
1352 {
1353 let (_send, recv) = channel::<u32>();
1354 let invitation = leader.invite(recv.into()).await.unwrap();
1355 let (send2, _recv2) = channel::<u32>();
1356 let _follower = UnixNode::join(driver, invitation, send2.into())
1357 .await
1358 .unwrap();
1359 }
1360 leader.shutdown().await;
1361 }
1362
1363 #[async_test]
1364 async fn test_send_shutdown(driver: DefaultDriver) {
1365 let leader = UnixNode::new(driver.clone());
1366 let (send, mut recv) = channel::<u32>();
1367 let invitation = leader.invite(send.into()).await.unwrap();
1368 let (send2, recv2) = channel::<u32>();
1369 let follower = UnixNode::join(driver, invitation, recv2.into())
1370 .await
1371 .unwrap();
1372 send2.send(5);
1373 drop(send2);
1374 follower.shutdown().await;
1375 assert_eq!(recv.recv().await.unwrap(), 5);
1376 }
1377
1378 #[async_test]
1379 async fn test_failed_invitation(driver: DefaultDriver) {
1380 let leader = UnixNode::new(driver);
1381 let (send, mut recv) = channel::<()>();
1382 leader.invite(send.into()).await.unwrap();
1383 assert!(matches!(
1384 recv.recv().await.unwrap_err(),
1385 RecvError::Error(_)
1386 ));
1387 drop(recv);
1388 leader.shutdown().await;
1389 }
1390
1391 #[async_test]
1392 async fn test_three(driver: DefaultDriver) {
1393 let (p1, p2) = channel::<u32>();
1394 let (p3, mut p4) = channel::<u32>();
1395 let (p5, p6) = channel::<u32>();
1396 let (p7, p8) = channel::<u32>();
1397
1398 let node1 = UnixNode::new(driver.clone());
1399
1400 let invitation = node1.invite(p2.into()).await.unwrap();
1401 let node2 = UnixNode::join(driver.clone(), invitation, p3.into())
1402 .await
1403 .unwrap();
1404
1405 let invitation = node1.invite(p5.into()).await.unwrap();
1406 let node3 = UnixNode::join(driver, invitation, p8.into()).await.unwrap();
1407
1408 p1.bridge(p6);
1409
1410 p7.send(5);
1411
1412 assert_eq!(p4.recv().await.unwrap(), 5);
1413 drop(p4);
1414 drop(p7);
1415 futures::join!(node2.shutdown(), node3.shutdown());
1416 node1.shutdown().await;
1417 }
1418
1419 #[async_test]
1420 async fn test_handoff_leader(driver: DefaultDriver) {
1421 let (p1, p2) = channel::<u32>();
1422 let (p3, p4) = channel::<u32>();
1423 let (p5, p6) = channel::<u32>();
1424 let (p7, p8) = channel::<u32>();
1425 let (p9, p10) = channel();
1426 let (p11, mut p12) = channel();
1427
1428 let node1 = UnixNode::new(driver.clone());
1429
1430 let invitation = node1.invite(p2.into()).await.unwrap();
1431 let node2 = UnixNode::join(driver.clone(), invitation, p3.into())
1432 .await
1433 .unwrap();
1434
1435 let invitation = node1.invite(p5.into()).await.unwrap();
1436 let node3 = UnixNode::join(driver.clone(), invitation, p8.into())
1437 .await
1438 .unwrap();
1439
1440 let invitation = node1.invite(p10.into()).await.unwrap();
1441 let node4 = UnixNode::join(driver, invitation, p11.into())
1442 .await
1443 .unwrap();
1444
1445 p9.send(node1.offer_leadership());
1446 node4.accept_leadership(p12.recv().await.unwrap());
1447 drop(p9);
1448 drop(p12);
1449 p1.bridge(p6);
1450
1451 std::thread::sleep(std::time::Duration::from_millis(200));
1452
1453 node1.shutdown().await;
1454 drop(p4);
1455 drop(p7);
1456 node2.shutdown().await;
1457 node3.shutdown().await;
1458
1459 std::thread::sleep(std::time::Duration::from_millis(200));
1460
1461 node4.shutdown().await;
1462 }
1463}