1mod protocol;
5
6use crate::common::Address;
7use crate::common::NodeId;
8use crate::common::PortId;
9use crate::message::Message;
10use crate::message::OwnedMessage;
11use crate::resource::OsResource;
12use crate::resource::Resource;
13use futures_channel::oneshot;
14use mesh_protobuf::DefaultEncoding;
15use mesh_protobuf::buffer::Buf;
16use mesh_protobuf::buffer::Buffer;
17use mesh_protobuf::buffer::write_with;
18use mesh_protobuf::protobuf::Encoder;
19use parking_lot::Mutex;
20use parking_lot::MutexGuard;
21use parking_lot::RwLock;
22use std::any::Any;
23use std::cmp::Reverse;
24use std::collections::BinaryHeap;
25use std::collections::HashMap;
26use std::collections::VecDeque;
27use std::collections::hash_map;
28use std::fmt;
29use std::fmt::Debug;
30use std::fmt::Display;
31use std::marker::PhantomData;
32use std::num::Wrapping;
33use std::sync::Arc;
34use std::sync::Weak;
35use std::sync::atomic::AtomicBool;
36use std::sync::atomic::AtomicIsize;
37use std::sync::atomic::Ordering;
38use std::task::Waker;
39use thiserror::Error;
40use zerocopy::FromBytes;
41use zerocopy::FromZeros;
42use zerocopy::IntoBytes;
43use zerocopy::Ref;
44use zerocopy::Unalign;
45
/// One end of a pair of peered ports.
///
/// Dropping a `Port` closes it: the `Drop` impl calls `PortInner::close(None)`.
pub struct Port {
    /// Shared port state; peers refer to the same allocation through
    /// `PortRef::LocalPort`.
    inner: Arc<PortInner>,
}
54
55impl Debug for Port {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 Debug::fmt(&self.inner.id, f)
58 }
59}
60
impl Drop for Port {
    /// Closes the port without sending a final message.
    fn drop(&mut self) {
        self.inner.close(None);
    }
}
66
67impl Port {
68 pub fn new_pair() -> (Self, Self) {
70 let left_addr = Address {
71 node: NodeId::ZERO,
72 port: PortId::new(),
73 };
74 let right_addr = Address {
75 node: NodeId::ZERO,
76 port: PortId::new(),
77 };
78 let left = Self::new(
79 left_addr.port,
80 PortInnerState::new(PortActivity::Unreachable),
81 );
82 let right = Self::new(
83 right_addr.port,
84 PortInnerState::new(PortActivity::Peered(PortRef::LocalPort(left.inner.clone()))),
85 );
86 left.inner.state.lock().activity =
87 PortActivity::Peered(PortRef::LocalPort(right.inner.clone()));
88 tracing::trace!(left = ?left_addr.port, right = ?right_addr.port, "new port pair");
89 (left, right)
90 }
91
92 fn new(id: PortId, state: PortInnerState) -> Self {
94 Self {
95 inner: Arc::new(PortInner {
96 id,
97 state: Mutex::new(state),
98 }),
99 }
100 }
101
102 pub fn set_handler<T: HandlePortEvent>(self, handler: T) -> PortWithHandler<T> {
108 self.inner.set_handler(Box::new(handler));
109 PortWithHandler {
110 raw: self,
111 _phantom: PhantomData,
112 }
113 }
114
115 fn forget(self) {
117 self.into_inner();
118 }
119
120 fn repeer_if_done(&self, state: &mut PortInnerState) -> Option<Self> {
124 if matches!(state.activity, PortActivity::Done) {
125 let new_id = PortId::new();
126 let mut peer_state =
127 PortInnerState::new(PortActivity::Peered(PortRef::LocalPort(self.inner.clone())));
128 peer_state.next_local_seq = state.event_queue.next_peer_seq;
130 let peer_port = Self::new(new_id, peer_state);
131 state.set_activity(PortActivity::Peered(PortRef::LocalPort(
132 peer_port.inner.clone(),
133 )));
134 Some(peer_port)
135 } else {
136 None
137 }
138 }
139
    /// Consumes this port in preparation for transferring it to
    /// `remote_node`, returning the `protocol::ResourceData` describing it.
    ///
    /// A peered port is moved to `PortActivity::Sending`; a failed port stays
    /// failed and is described with a zero peer port.
    ///
    /// # Panics
    ///
    /// Panics if the port is in any activity other than `Peered` or `Failed`
    /// (after the `Done` case has been converted by `repeer_if_done`).
    fn prepare_to_send(self, remote_node: &Arc<RemoteNode>) -> protocol::ResourceData {
        // The address this port has had so far; reported as `old_node`/`old_port`.
        let old_address = Address {
            node: remote_node.local_node.id,
            port: self.inner.id,
        };

        // Fresh ID under which the port is described to the remote node.
        let port_id = PortId::new();
        let target = PortRef::RemotePort(remote_node.clone(), port_id);

        let mut state = PortInner::associate(&self.inner, &remote_node.local_node);

        // Reported to the receiver as `next_local_seq`: one past this port's
        // current next sequence number.
        let next_local_seq = state.next_local_seq + Wrapping(1);

        // If the port was already done, give it a new local peer. The returned
        // port is held until the end of the function and then dropped, which
        // closes it via `Port`'s `Drop` impl.
        let mut _port_to_close = self.repeer_if_done(&mut state);

        let mut port_to_associate = None;
        // `Unreachable` is only a placeholder while the old activity is
        // inspected; every non-panicking arm below stores a real activity.
        let (peer_node, peer_port) =
            match std::mem::replace(&mut state.activity, PortActivity::Unreachable) {
                PortActivity::Peered(peer) => {
                    let peer_addr = match &peer {
                        PortRef::LocalPort(peer_port) => {
                            // A local peer also has to be associated with the
                            // local node; done below, after `state` is released.
                            port_to_associate = Some(peer_port.clone());
                            (remote_node.local_node.id, Some(peer_port.id))
                        }
                        PortRef::RemotePort(peer_node, peer_port_id) => {
                            (peer_node.id, Some(*peer_port_id))
                        }
                    };
                    state.set_activity(PortActivity::Sending { peer, target });
                    peer_addr
                }
                PortActivity::Failed(err) => {
                    // Keep the failure; report the failing node if the error
                    // names one, otherwise the local node.
                    let node_id = *err.node_id().unwrap_or(&remote_node.local_node.id);
                    state.activity = PortActivity::Failed(err);
                    (node_id, None)
                }
                state => panic!("invalid state: {:?}", state),
            };

        // Release this port's state before touching the peer.
        drop(state);
        if let Some(port_to_associate) = &port_to_associate {
            // The value returned by `associate` is not needed here and is
            // dropped immediately.
            drop(PortInner::associate(
                port_to_associate,
                &remote_node.local_node,
            ))
        }

        // The port must not be closed when this handle goes away.
        self.forget();

        protocol::ResourceData {
            id: port_id.0.into(),
            next_local_seq: next_local_seq.0,
            reserved: 0,
            old_node: old_address.node.0.into(),
            old_port: old_address.port.0.into(),
            peer_node: peer_node.0.into(),
            peer_port: peer_port.map_or(protocol::Uuid::ZERO, |p| p.0.into()),
        }
    }
208
    /// Bridges two ports: each port starts proxying to the *other* port's
    /// peer, after which both handles are forgotten.
    ///
    /// If either port has failed, or the bridge would be circular (a port's
    /// proxy target would be itself), the affected port is failed instead.
    pub fn bridge(self, other: Self) {
        tracing::trace!(left = ?self.inner.id, right = ?other.inner.id, "bridging ports");

        // Snapshot of a port's peer and the sequence number its proxy should
        // start forwarding with, or the port's failure.
        let get_peer_info = |state: &PortInnerState| {
            match &state.activity {
                PortActivity::Peered(peer) => {
                    let peer = peer.clone();
                    let initial_seq = state.next_local_seq + Wrapping(1);
                    Ok((peer, initial_seq))
                }
                PortActivity::Failed(err) => Err(err.clone()),
                s => unreachable!("{:?}", s),
            }
        };

        // Turns `inner` into a proxy toward the peer described by
        // `target_info`, or fails it.
        let start_proxy = |inner: &PortInner,
                           state: &mut PortInnerState,
                           target_info: Result<(PortRef, Seq), NodeError>,
                           pending_events: &mut PendingEvents<'_>| {
            let result = match target_info {
                // The target is this very port, i.e. the two ports being
                // bridged are each other's peers.
                Ok((PortRef::LocalPort(ref target), _)) if target.id == inner.id => {
                    Err(NodeError::local(PortError::CircularBridge))
                }
                Ok((target, initial_seq)) => {
                    match std::mem::replace(&mut state.activity, PortActivity::Unreachable) {
                        PortActivity::Peered(peer) => {
                            state.start_proxy(peer, target, initial_seq, pending_events);
                            Ok(())
                        }
                        // Already failed: restore the failure and do nothing.
                        activity @ PortActivity::Failed(_) => {
                            state.activity = activity;
                            Ok(())
                        }
                        s => unreachable!("{s:?}"),
                    }
                }
                Err(err) => Err(err),
            };
            if let Err(err) = result {
                state.fail(pending_events, err);
                inner.disassociate(&mut *state);
            }
        };

        // Declared outside the locked scope so that any replacement peers are
        // dropped (and thereby closed) only after both locks are released.
        let (_this_repeer, _other_repeer);
        let mut pending_events = PendingEvents::new();
        {
            let (mut this_state, mut other_state) = PortInner::lock_two(&self.inner, &other.inner);
            _this_repeer = self.repeer_if_done(&mut this_state);
            _other_repeer = other.repeer_if_done(&mut other_state);
            // Both snapshots are taken before either port is modified.
            let this_peer_info = get_peer_info(&this_state);
            let other_peer_info = get_peer_info(&other_state);
            start_proxy(
                &self.inner,
                &mut this_state,
                other_peer_info,
                &mut pending_events,
            );
            start_proxy(
                &other.inner,
                &mut other_state,
                this_peer_info,
                &mut pending_events,
            );
        }

        // Deliver the queued events only after the locks are gone.
        pending_events.process();
        self.forget();
        other.forget();
    }
286
    /// Sends `message` to the port's peer.
    ///
    /// Delegates to `PortInner::send`.
    pub fn send(&self, message: Message<'_>) {
        self.inner.send(message);
    }
291
    /// Sends `message` and closes the port in a single close event.
    ///
    /// The handle is consumed via `into_inner`, so `Port`'s `Drop` impl does
    /// not run and the port is not closed a second time.
    pub fn send_and_close(self, message: Message<'_>) {
        self.into_inner().close(Some(message));
    }
297
    /// Sends `value` as a message, using its default protobuf encoding.
    ///
    /// The message is built with `crate::message::stack_message!` and passed
    /// to [`Port::send`].
    pub fn send_protobuf<T: DefaultEncoding>(&self, value: T)
    where
        T::Encoding: mesh_protobuf::MessageEncode<T, Resource>,
    {
        self.send(crate::message::stack_message!(value));
    }
312
    /// Sends `value` as a message, using its default protobuf encoding, and
    /// closes the port.
    ///
    /// The message is built with `crate::message::stack_message!` and passed
    /// to [`Port::send_and_close`].
    pub fn send_protobuf_and_close<T: DefaultEncoding>(self, value: T)
    where
        T::Encoding: mesh_protobuf::MessageEncode<T, Resource>,
    {
        self.send_and_close(crate::message::stack_message!(value));
    }
328
329 pub fn is_closed(&self) -> Result<bool, NodeError> {
330 match &self.inner.state.lock().activity {
331 PortActivity::Done => Ok(true),
332 PortActivity::Failed(err) => Err(err.clone()),
333 _ => Ok(false),
334 }
335 }
336
    /// Test-only helper that fails the port with `err` and then delivers the
    /// resulting events.
    #[cfg(test)]
    fn fail(self, err: NodeError) {
        let mut pending_events = PendingEvents::new();
        {
            // Scope the lock so it is released before events are processed.
            let mut state = self.inner.state.lock();
            state.fail(&mut pending_events, err);
        }
        pending_events.process();
    }
346}
347
/// A [`Port`] that has a handler of concrete type `T` installed.
///
/// The handler itself is stored in the port's state; this wrapper only
/// remembers its type so that it can be downcast again.
pub struct PortWithHandler<T> {
    /// The underlying port.
    raw: Port,
    /// Marker for the handler type; no `T` is stored in this struct.
    _phantom: PhantomData<Arc<Mutex<T>>>,
}
355
impl<T> Debug for PortWithHandler<T> {
    /// Formats as `PortWithHandler { raw: <port> }`; the handler is not shown,
    /// so `T` does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PortWithHandler")
            .field("raw", &self.raw)
            .finish()
    }
}
363
364impl<T: HandlePortEvent> From<PortWithHandler<T>> for Port {
365 fn from(port: PortWithHandler<T>) -> Self {
366 port.remove_handler().0
367 }
368}
369
370impl<T: Default + HandlePortEvent> From<Port> for PortWithHandler<T> {
371 fn from(port: Port) -> Self {
372 port.set_handler(Default::default())
373 }
374}
375
/// The small amount of unsafe code needed to move fields out of types that
/// implement (or contain types that implement) `Drop`.
mod unsafe_code {
    #![expect(unsafe_code)]

    use super::Port;
    use super::PortInner;
    use super::PortWithHandler;
    use std::mem::ManuallyDrop;
    use std::sync::Arc;

    impl Port {
        /// Extracts the inner `Arc` without running `Port`'s destructor, so
        /// the port is not closed.
        pub(super) fn into_inner(self) -> Arc<PortInner> {
            let Self { ref inner } = *ManuallyDrop::new(self);
            // SAFETY: `self` was moved into a `ManuallyDrop`, so neither its
            // destructor nor the field's destructor will ever run. Reading the
            // field out by value therefore yields the sole owner of this
            // `Arc` reference; nothing is duplicated or dropped twice.
            unsafe { <*const _>::read(inner) }
        }
    }

    impl<T> PortWithHandler<T> {
        /// Extracts the raw `Port` without dropping the wrapper, leaving the
        /// installed handler in place.
        pub(super) fn into_port_preserve_handler(self) -> Port {
            let Self {
                ref raw,
                _phantom: _,
            } = *ManuallyDrop::new(self);
            // SAFETY: `self` was moved into a `ManuallyDrop`, so `raw` is
            // never dropped in place. Reading it out by value transfers
            // ownership of the `Port` to the caller exactly once.
            unsafe { <*const _>::read(raw) }
        }
    }
}
406
407impl<T: HandlePortEvent> PortWithHandler<T> {
    /// Sends `message` to the port's peer; see [`Port::send`].
    pub fn send(&self, message: Message<'_>) {
        self.raw.send(message)
    }
412
    /// Reports whether the port is closed; see [`Port::is_closed`].
    pub fn is_closed(&self) -> Result<bool, NodeError> {
        self.raw.is_closed()
    }
416
    /// Splits the wrapper into the bare port and its handler.
    ///
    /// # Panics
    ///
    /// Panics if `drain_queue` returns no handler, or if the handler is not
    /// of type `T`.
    pub fn remove_handler(self) -> (Port, T) {
        // Take the port out without closing it or disturbing the handler.
        let port = self.into_port_preserve_handler();
        let handler = port.inner.drain_queue().unwrap() as Box<dyn Any>;
        (port, *handler.downcast().unwrap())
    }
422
    /// Runs `f` with mutable access to the installed handler and returns its
    /// result. The port's state lock is held for the duration of `f`.
    ///
    /// # Panics
    ///
    /// Panics if no handler is installed or it is not of type `T`.
    pub fn with_handler<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
        let mut state = self.raw.inner.state.lock();
        let handler = state.handler.as_mut().unwrap().as_mut() as &mut dyn Any;
        f(handler.downcast_mut().unwrap())
    }
428
429 pub fn with_port_and_handler<'a, R>(
430 &self,
431 f: impl FnOnce(&mut PortControl<'_, 'a>, &mut T) -> R,
432 ) -> R {
433 let mut pending_events = PendingEvents::new();
434 let mut state = self.raw.inner.state.lock();
435 let state = &mut *state;
436 let peer_and_seq = match &state.activity {
437 PortActivity::Peered(peer) => Some((peer, &mut state.next_local_seq)),
438 _ => None,
439 };
440 let mut control = PortControl {
441 peer_and_seq,
442 events: &mut pending_events,
443 };
444 let handler = state.handler.as_mut().unwrap().as_mut() as &mut dyn Any;
445 let r = f(&mut control, handler.downcast_mut().unwrap());
446 pending_events.process();
447 r
448 }
449}
450
/// Protobuf field encoding that transfers a port as a resource; it is the
/// default encoding for [`Port`].
pub struct PortField;
453
impl<T: Into<Port>, R: From<Port>> mesh_protobuf::FieldEncode<T, R> for PortField {
    /// Converts `item` into a `Port`, then into the resource type `R`, and
    /// writes it as a resource field.
    fn write_field(item: T, writer: mesh_protobuf::protobuf::FieldWriter<'_, '_, R>) {
        writer.resource(item.into().into());
    }

    /// Accounts for a single resource; the item itself is not inspected.
    fn compute_field_size(_item: &mut T, sizer: mesh_protobuf::protobuf::FieldSizer<'_>) {
        sizer.resource();
    }
}
463
/// Decode error returned by `PortField::default_field`, i.e. when a port
/// field has no value to decode.
#[derive(Debug, Error)]
#[error("missing port")]
struct MissingPort;
467
468impl<T: From<Port>, R> mesh_protobuf::FieldDecode<'_, T, R> for PortField
469where
470 Port: TryFrom<R>,
471 <Port as TryFrom<R>>::Error: 'static + std::error::Error + Send + Sync,
472{
473 fn read_field(
474 item: &mut mesh_protobuf::inplace::InplaceOption<'_, T>,
475 reader: mesh_protobuf::protobuf::FieldReader<'_, '_, R>,
476 ) -> mesh_protobuf::Result<()> {
477 item.set(
478 Port::try_from(reader.resource()?)
479 .map_err(mesh_protobuf::Error::new)?
480 .into(),
481 );
482 Ok(())
483 }
484
485 fn default_field(
486 _item: &mut mesh_protobuf::inplace::InplaceOption<'_, T>,
487 ) -> mesh_protobuf::Result<()> {
488 Err(mesh_protobuf::Error::new(MissingPort))
489 }
490}
491
/// Ports are encoded with [`PortField`] unless another encoding is requested.
impl DefaultEncoding for Port {
    type Encoding = PortField;
}
495
/// The local node. Dropping it fails all of its ports and remote nodes with a
/// "shutting down" error (see the `Drop` impl).
pub struct LocalNode {
    /// Shared node state, also referenced by each `RemoteNode`.
    inner: Arc<LocalNodeInner>,
    /// Optional `Connect` implementation.
    // NOTE(review): presumably used to establish connections to remote nodes
    // on demand; its use is outside this part of the file.
    connector: Mutex<Option<Box<dyn Connect>>>,
}
501
impl Drop for LocalNode {
    /// Fails every port and then every remote node with the same
    /// "shutting down" error.
    fn drop(&mut self) {
        let err = NodeError::shutting_down();
        // Ports are failed first, then nodes.
        self.inner.fail_all_ports(err.clone());
        self.inner.fail_all_nodes(err);
    }
}
513
/// Shared state of a [`LocalNode`].
#[derive(Debug)]
struct LocalNodeInner {
    /// This node's ID.
    id: NodeId,
    /// Mutable node state, guarded by a mutex.
    state: Mutex<LocalNodeState>,
}
520
/// Event sequence number; arithmetic on it wraps.
type Seq = Wrapping<u64>;

/// A value tagged with a sequence number.
///
/// Equality and ordering look only at the sequence number; the payload is
/// ignored, so `T` needs no comparison traits.
#[derive(Debug, Copy, Clone)]
struct SeqValue<T>(Seq, T);

impl<T> PartialEq for SeqValue<T> {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), std::cmp::Ordering::Equal)
    }
}

impl<T> Eq for SeqValue<T> {}

impl<T> PartialOrd for SeqValue<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T> Ord for SeqValue<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let Self(lhs, _) = self;
        let Self(rhs, _) = other;
        lhs.cmp(rhs)
    }
}
548
/// A remote node as seen from the local node.
struct RemoteNode {
    /// The remote node's ID.
    id: NodeId,
    /// The local node this remote node belongs to.
    local_node: Arc<LocalNodeInner>,
    /// Connection state: queuing events, active, or failed.
    state: RwLock<RemoteNodeState>,
    /// Set by `fail`; lets `check_failed` apply the failure later if the
    /// state lock could not be taken at the time.
    failed: AtomicBool,
    /// `Ok(())` until `fail` stores the failure reason.
    node_error: Mutex<Result<(), NodeError>>,
    /// Handle counter, initialised to 1 by `RemoteNode::new`.
    // NOTE(review): presumably counts outstanding `RemoteNodeHandle`s; the
    // code that updates it is outside this part of the file.
    handle_count: AtomicIsize,
}
563
impl Debug for RemoteNode {
    /// Shows the local node ID, the remote node ID and the failed flag; the
    /// connection state and error are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RemoteNode")
            .field("local_node", &self.local_node.id)
            .field("id", &self.id)
            .field("failed", &self.failed)
            .finish()
    }
}
573
/// Connection state of a [`RemoteNode`].
enum RemoteNodeState {
    /// Not yet connected; events are stored until `connect` is called.
    Queuing(Mutex<Vec<DeferredEvent>>),
    /// The node has failed; events are discarded.
    Failed,
    /// Connected; events are handed to the connection.
    Active(Box<dyn SendEvent>),
}
580
/// An event queued for a remote node that is not connected yet.
#[derive(Debug)]
struct DeferredEvent {
    /// Destination port on the remote node.
    port_id: PortId,
    /// Sequence number of the event.
    seq: Seq,
    /// The event, converted to an owned form so it can outlive the caller.
    event: OwnedPortEvent,
}
589
590impl RemoteNode {
591 fn new(local_node: Arc<LocalNodeInner>, id: NodeId) -> (Arc<Self>, RemoteNodeHandle) {
592 let this = Arc::new(Self {
593 local_node,
594 id,
595 state: RwLock::new(RemoteNodeState::Queuing(Default::default())),
596 failed: AtomicBool::new(false),
597 node_error: Mutex::new(Ok(())),
598 handle_count: AtomicIsize::new(1),
599 });
600 let handle = RemoteNodeHandle {
601 id,
602 remote_node: Arc::downgrade(&this),
603 };
604 (this, handle)
605 }
606
    /// Attaches `conn` to a node that is still queuing and replays the queued
    /// events through the normal `event` path.
    ///
    /// Returns `false` (dropping `conn`) if the node is already active or
    /// failed.
    fn connect(self: &Arc<Self>, conn: Box<dyn SendEvent>) -> bool {
        let events = {
            let mut state = self.state.write();
            match &mut *state {
                RemoteNodeState::Queuing(v) => {
                    // The write lock gives exclusive access, so the inner
                    // mutex does not need to be locked.
                    let v = std::mem::take(v.get_mut());
                    *state = RemoteNodeState::Active(conn);
                    v
                }
                _ => return false,
            }
        };
        // Apply a failure that was signalled while the lock was held.
        self.check_failed();
        for event in events {
            self.event(event.port_id, event.seq, event.event.into());
        }
        true
    }
626
627 fn check_failed(&self) {
628 if self.failed.load(Ordering::SeqCst) {
629 let _old = std::mem::replace(&mut *self.state.write(), RemoteNodeState::Failed);
630 }
631 }
632
    /// Records `err` as the node's failure and marks the node failed.
    ///
    /// The state is switched to `Failed` immediately only if the write lock
    /// is free; otherwise the flag set here lets `check_failed` do it later.
    fn fail(&self, err: NodeError) {
        // Store the error before publishing the flag, so that a reader who
        // observes `failed` also finds the error.
        *self.node_error.lock() = Err(err);
        self.failed.store(true, Ordering::SeqCst);
        // Non-blocking attempt: the lock may currently be held elsewhere.
        if let Some(mut state) = self.state.try_write() {
            let _old = std::mem::replace(&mut *state, RemoteNodeState::Failed);
        }
    }
644
    /// Delivers an event for `port_id` on this remote node.
    ///
    /// The event is queued while the node is not connected, forwarded to the
    /// connection while it is active, and silently dropped once it has failed.
    fn event(self: &Arc<Self>, port_id: PortId, seq: Seq, event: PortEvent<'_>) {
        match &*self.state.read() {
            RemoteNodeState::Queuing(v) => {
                v.lock().push(DeferredEvent {
                    port_id,
                    seq,
                    event: event.into_owned(),
                });
            }
            RemoteNodeState::Failed => (),
            RemoteNodeState::Active(conn) => {
                conn.event(OutgoingEvent::new(port_id, seq, event, self))
            }
        }
        // The read lock is released by now; apply a failure that `fail` could
        // not apply while it was held.
        self.check_failed();
    }
662
663 fn node_status(&self) -> Result<(), NodeError> {
665 if !self.failed.load(Ordering::SeqCst) {
666 return Ok(());
667 }
668 self.node_error.lock().clone()
669 }
670}
671
/// The shared part of a port, referenced by its `Port` handle and by peers.
#[derive(Debug)]
struct PortInner {
    /// The port's ID.
    id: PortId,
    /// The port's mutable state.
    state: Mutex<PortInnerState>,
}
677}
678
/// Handed to port event handlers so they can respond to the peer and
/// schedule wakeups; both are deferred through the pending-event list.
pub struct PortControl<'a, 'm> {
    /// The peer and this port's next outgoing sequence number, or `None` when
    /// the port has no peer to respond to.
    peer_and_seq: Option<(&'a PortRef, &'a mut Seq)>,
    /// Where responses and wakers are queued.
    events: &'a mut PendingEvents<'m>,
}
684
685impl<'a, 'm> PortControl<'a, 'm> {
686 fn peered(peer: &'a PortRef, seq: &'a mut Seq, events: &'a mut PendingEvents<'m>) -> Self {
687 Self {
688 peer_and_seq: Some((peer, seq)),
689 events,
690 }
691 }
692
693 fn unpeered(events: &'a mut PendingEvents<'m>) -> Self {
694 Self {
695 peer_and_seq: None,
696 events,
697 }
698 }
699
700 pub fn respond(&mut self, message: Message<'m>) {
702 if let Some((port_ref, seq)) = &mut self.peer_and_seq {
703 let this = **seq;
704 **seq += Wrapping(1);
705 self.events.push(
706 port_ref.clone(),
707 this,
708 PortEvent::Message {
709 message: Some(message),
710 close: false,
711 },
712 )
713 }
714 }
715
716 pub fn wake(&mut self, waker: Waker) {
718 self.events.wake(waker);
719 }
720}
721
/// Callbacks for events arriving at a port.
pub trait HandlePortEvent: 'static + Send {
    /// Called for each message delivered to the port.
    ///
    /// Returning an error fails the port with a "failed to parse message"
    /// error.
    fn message<'a>(
        &mut self,
        control: &mut PortControl<'_, 'a>,
        message: Message<'a>,
    ) -> Result<(), HandleMessageError>;

    /// Called when the port has finished.
    fn close(&mut self, control: &mut PortControl<'_, '_>);

    /// Called when the port has failed with `err`.
    fn fail(&mut self, control: &mut PortControl<'_, '_>, err: NodeError);

    /// Returns the messages the handler is still holding; used when the port
    /// starts proxying so that they can be forwarded to the new target.
    fn drain(&mut self) -> Vec<OwnedMessage>;
}
749
/// Error returned by [`HandlePortEvent::message`]; wraps an arbitrary boxed
/// error.
pub struct HandleMessageError(Box<dyn std::error::Error + Send + Sync>);
753
impl HandleMessageError {
    /// Wraps anything convertible into a boxed error.
    pub fn new<E: Into<Box<dyn std::error::Error + Send + Sync>>>(err: E) -> Self {
        Self(err.into())
    }
}
760
/// A port or node failure. Cloning is cheap: the details are shared behind
/// an `Arc`.
#[derive(Clone, Debug, Error)]
#[error(transparent)]
pub struct NodeError(Arc<NodeErrorInner>);
765
766impl NodeError {
767 fn new(node: &NodeId, source: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> Self {
768 Self(Arc::new(NodeErrorInner {
769 node_id: Some(*node),
770 source: source.into(),
771 }))
772 }
773
774 fn local(source: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> Self {
775 Self(Arc::new(NodeErrorInner {
776 node_id: None,
777 source: source.into(),
778 }))
779 }
780
781 fn shutting_down() -> Self {
782 Self::local(ShuttingDownError)
783 }
784
785 fn remote_node_id(&self) -> Option<&NodeId> {
786 if let Some(err) = self.0.source.downcast_ref::<RemotePortError>() {
787 Some(&err.0)
788 } else {
789 self.0.node_id.as_ref()
790 }
791 }
792
793 fn node_id(&self) -> Option<&NodeId> {
794 self.0.node_id.as_ref()
795 }
796}
797
/// Shared payload of a [`NodeError`].
#[derive(Debug, Error)]
struct NodeErrorInner {
    /// The node the failure is attributed to; `None` for local failures.
    node_id: Option<NodeId>,
    /// The underlying cause.
    source: Box<dyn std::error::Error + Send + Sync>,
}
803
804impl Display for NodeErrorInner {
805 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
806 if let Some(node_id) = self.node_id {
807 write!(f, "communication with node {node_id:?} failed")
808 } else {
809 write!(f, "local mesh failure")
810 }
811 }
812}
813
/// Source of the error produced by `NodeError::shutting_down`.
#[derive(Debug, Error)]
#[error("mesh is shutting down")]
struct ShuttingDownError;
817
/// Error for an unknown local port being received.
// NOTE(review): the code that raises this is outside this part of the file.
#[derive(Debug, Error)]
#[error("received unknown local port")]
struct UnknownLocalPort;
821
/// Error for a port that failed on a remote node; carries the ID of the node
/// that caused the failure, which `NodeError::remote_node_id` extracts.
#[derive(Debug, Error)]
#[error("port failed on remote node due to node {0:?}")]
struct RemotePortError(NodeId);
825
/// Error for a remote node that disconnected.
// NOTE(review): the code that raises this is outside this part of the file.
#[derive(Debug, Error)]
#[error("remote node disconnected")]
struct RemoteNodeDisconnected;
829
/// Error for a remote node that was dropped.
// NOTE(review): the code that raises this is outside this part of the file.
#[derive(Debug, Error)]
#[error("remote node dropped")]
struct RemoteNodeDropped;
833
/// Combines [`HandlePortEvent`] with `Any` so that a boxed handler can be
/// cast back to its concrete type.
trait HandlePortEventAndAny: HandlePortEvent + Any {}
835
impl Debug for dyn HandlePortEventAndAny {
    /// Prints a fixed placeholder, since handlers are not required to
    /// implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad("HandlePortEvent")
    }
}
841
/// Every handler automatically implements the combined trait.
impl<T: HandlePortEvent> HandlePortEventAndAny for T {}
843
/// The mutable state of a port, guarded by `PortInner::state`.
#[derive(Debug)]
struct PortInnerState {
    /// What the port is currently doing (peered, proxying, failed, ...).
    activity: PortActivity,
    /// The local node the port is associated with, if any.
    local_node: Option<Weak<LocalNodeInner>>,

    /// Incoming events that cannot be handled yet, plus the next sequence
    /// number expected from the peer.
    event_queue: EventQueue,
    /// The installed event handler, if any.
    handler: Option<Box<dyn HandlePortEventAndAny>>,

    /// Sequence number of the next event this port sends.
    next_local_seq: Seq,
    /// Set once the local side has closed the port.
    is_local_closed: bool,
}
856
/// Handler installed when a message arrives at a port that has no handler
/// yet; it stores the messages until they are drained.
#[derive(Default)]
struct QueuingHandler {
    /// Messages received so far, in arrival order.
    messages: Vec<OwnedMessage>,
}
864
impl HandlePortEvent for QueuingHandler {
    /// Stores an owned copy of the message; never fails.
    fn message(
        &mut self,
        _control: &mut PortControl<'_, '_>,
        message: Message<'_>,
    ) -> Result<(), HandleMessageError> {
        self.messages.push(message.into_owned());
        Ok(())
    }

    /// Closing needs no action; queued messages are kept.
    fn close(&mut self, _control: &mut PortControl<'_, '_>) {}

    /// Failure needs no action here; the error is recorded in the port state.
    fn fail(&mut self, _control: &mut PortControl<'_, '_>, _err: NodeError) {}

    /// Hands over all queued messages, leaving the queue empty.
    fn drain(&mut self) -> Vec<OwnedMessage> {
        std::mem::take(&mut self.messages)
    }
}
883
/// Reorders incoming events by sequence number.
#[derive(Debug)]
struct EventQueue {
    /// The sequence number of the next event to deliver.
    next_peer_seq: Seq,
    /// Events that arrived ahead of `next_peer_seq`; `Reverse` turns the
    /// max-heap into a min-heap so the lowest sequence number is on top.
    heap: BinaryHeap<Reverse<SeqValue<OwnedPortEvent>>>,
}
889
890impl EventQueue {
891 fn new() -> Self {
892 Self {
893 next_peer_seq: Wrapping(1),
894 heap: BinaryHeap::new(),
895 }
896 }
897
898 fn pop<'a>(&mut self, v: Option<(Seq, PortEvent<'a>)>) -> Option<PortEvent<'a>> {
903 if let Some((seq, event)) = v {
904 if seq == self.next_peer_seq {
905 self.next_peer_seq += Wrapping(1);
906 return Some(event);
907 }
908 self.add(seq, event);
909 }
910 if let Some(Reverse(SeqValue(seq, _))) = self.heap.peek() {
911 if *seq > self.next_peer_seq {
912 return None;
913 }
914 let Reverse(SeqValue(_, port_event)) = self.heap.pop().unwrap();
915 self.next_peer_seq += Wrapping(1);
916 return Some(port_event.into());
917 }
918 None
919 }
920
921 fn add(&mut self, seq: Seq, event: PortEvent<'_>) {
922 assert!(seq >= self.next_peer_seq);
923 self.heap.push(Reverse(SeqValue(seq, event.into_owned())));
924 }
925
926 fn is_empty(&self) -> bool {
927 self.heap.is_empty()
928 }
929}
930
/// What a port is currently doing.
#[derive(Clone, Debug)]
enum PortActivity {
    /// Exchanging events with `peer`.
    Peered(PortRef),
    /// Being transferred to `target`; incoming events are queued until
    /// proxying starts.
    Sending { peer: PortRef, target: PortRef },
    /// Forwarding events from `peer` to `target`.
    Proxying { peer: PortRef, target: PortRef },
    /// Failed with the given error.
    Failed(NodeError),
    /// Finished.
    Done,
    /// Placeholder left behind while the activity is temporarily moved out
    /// with `mem::replace`; never expected to be observed.
    Unreachable,
}
941
/// A reference to a port, either in this process or on a remote node.
#[derive(Clone)]
enum PortRef {
    /// A port on the local node.
    LocalPort(Arc<PortInner>),
    /// A port with the given ID on the given remote node.
    RemotePort(Arc<RemoteNode>, PortId),
}
948
impl Debug for PortRef {
    /// Prints only IDs: the port ID for local ports, and the node and port
    /// IDs for remote ports.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PortRef::LocalPort(port) => f.debug_tuple("LocalPort").field(&port.id).finish(),
            PortRef::RemotePort(remote_node, port_id) => f
                .debug_tuple("RemotePort")
                .field(&remote_node.id)
                .field(&port_id)
                .finish(),
        }
    }
}
961
962impl PortRef {
963 fn node_status(&self) -> Result<(), NodeError> {
964 match self {
965 PortRef::LocalPort(_) => Ok(()),
966 PortRef::RemotePort(node, _) => node.node_status(),
967 }
968 }
969
970 fn is_compatible_node(&self, local_node: &Option<Weak<LocalNodeInner>>) -> bool {
972 match local_node {
973 None => true,
974 Some(local_node) => match self {
975 PortRef::LocalPort(_) => true,
976 PortRef::RemotePort(node, _) => {
977 Weak::as_ptr(local_node) == Arc::as_ptr(&node.local_node)
978 }
979 },
980 }
981 }
982}
983
984impl PortInnerState {
985 fn new(activity: PortActivity) -> Self {
986 Self {
987 local_node: None,
988 activity,
989 next_local_seq: Wrapping(1),
990 event_queue: EventQueue::new(),
991 handler: None,
992 is_local_closed: false,
993 }
994 }
995
996 fn fail(&mut self, pending_events: &mut PendingEvents<'_>, err: NodeError) {
998 match std::mem::replace(&mut self.activity, PortActivity::Failed(err.clone())) {
999 PortActivity::Peered(peer) => {
1000 pending_events.push(peer, Wrapping(0), NonMessageEvent::FailPort(err));
1001 }
1002 PortActivity::Sending { peer, target } | PortActivity::Proxying { peer, target } => {
1003 pending_events.push(peer, Wrapping(0), NonMessageEvent::FailPort(err.clone()));
1004 pending_events.push(target, Wrapping(0), NonMessageEvent::FailPort(err.clone()));
1005 }
1006 activity @ PortActivity::Failed(_) => {
1007 self.activity = activity;
1009 }
1010 PortActivity::Done => {}
1011 PortActivity::Unreachable => unreachable!(),
1012 }
1013 }
1014
1015 fn set_activity(&mut self, activity: PortActivity) {
1016 self.activity = activity;
1017 }
1018
1019 fn next_peer_and_seq(&mut self) -> Option<(PortRef, Seq)> {
1022 match &self.activity {
1023 PortActivity::Peered(peer) => {
1024 let peer = peer.clone();
1025 let seq = self.next_local_seq;
1026 self.next_local_seq += Wrapping(1);
1027 Some((peer, seq))
1028 }
1029 PortActivity::Done | PortActivity::Failed(_) => None, s => unreachable!("{:?}", s),
1031 }
1032 }
1033}
1034
/// Reasons an incoming event is rejected.
// NOTE(review): the code that produces these is outside this part of the
// file; the descriptions below are read off the variant names.
#[derive(Debug)]
enum EventError {
    /// The event names a port that is not known.
    UnknownPort,
    /// The event data is shorter than expected.
    Truncated,
    /// The event type is not recognised; the value is kept for `Debug` only.
    UnknownEventType(#[expect(dead_code)] protocol::EventType),
    /// An OS resource expected by the event is missing.
    MissingOsResource,
}
1044
/// Events and wakeups collected while port locks are held, to be delivered
/// afterwards by [`PendingEvents::process`].
struct PendingEvents<'a> {
    /// Events for local ports, delivered in FIFO order.
    local_events: VecDeque<(Arc<PortInner>, Seq, PortEvent<'a>)>,
    /// Events for ports on remote nodes.
    remote_events: Vec<(Arc<RemoteNode>, PortId, Seq, PortEvent<'a>)>,
    /// Wakers to wake once all events have been delivered.
    wakers: Vec<Waker>,
}
1052
1053impl<'a> PendingEvents<'a> {
1054 fn new() -> Self {
1055 Self {
1056 local_events: VecDeque::new(),
1057 remote_events: Vec::new(),
1058 wakers: Vec::new(),
1059 }
1060 }
1061
1062 fn send_local(
1065 port: &Arc<PortInner>,
1066 remote_node_id: Option<&NodeId>,
1067 seq: Seq,
1068 event: PortEvent<'a>,
1069 ) {
1070 let mut this = Self::new();
1071 port.on_event(remote_node_id, seq, event, &mut this);
1072 this.process();
1073 }
1074
1075 fn send(port: &PortRef, seq: Seq, event: impl Into<PortEvent<'a>>) {
1078 let event = event.into();
1079 match port {
1080 PortRef::LocalPort(port) => Self::send_local(port, None, seq, event),
1081 PortRef::RemotePort(remote_node, port_id) => {
1082 remote_node.event(*port_id, seq, event);
1083 }
1084 }
1085 }
1086
1087 fn process(mut self) {
1089 while let Some((port, seq, event)) = self.local_events.pop_front() {
1090 port.on_event(None, seq, event, &mut self);
1091 }
1092 for (remote_node, port_id, seq, event) in self.remote_events.drain(..) {
1093 remote_node.event(port_id, seq, event);
1094 }
1095 for waker in self.wakers {
1096 waker.wake();
1097 }
1098 }
1099
1100 fn push_local(&mut self, port: Arc<PortInner>, seq: Seq, event: PortEvent<'a>) {
1102 self.local_events.push_back((port, seq, event));
1103 }
1104
1105 fn push(&mut self, port: PortRef, seq: Seq, event: impl Into<PortEvent<'a>>) {
1107 let event = event.into();
1108 match port {
1109 PortRef::LocalPort(port) => self.push_local(port, seq, event),
1110 PortRef::RemotePort(remote_node, port_id) => {
1111 self.remote_events.push((remote_node, port_id, seq, event));
1112 }
1113 }
1114 }
1115
1116 fn wake(&mut self, waker: Waker) {
1117 self.wakers.push(waker);
1118 }
1119}
1120
/// Protocol violations and local errors that cause a port to fail.
#[derive(Debug, Error)]
enum PortError {
    /// An event arrived with a sequence number below the next expected one
    /// (`next`).
    #[error("duplicate sequence number")]
    DuplicateSeq { next: Seq },
    /// An event arrived for a port that is done, or events were still queued
    /// behind a close.
    #[error("received event after port closed")]
    EventAfterClose,
    /// A peered port received an `AcknowledgeChangePeer` event.
    #[error("unexpected acknowledgement of peer change")]
    AckChangePeerInvalidState,
    /// Events were still queued behind the acknowledgement that ends a proxy.
    #[error("received event after proxy end")]
    EventAfterProxyEnd,
    /// A bridge would make a port proxy to itself.
    #[error("circular bridge")]
    CircularBridge,
    /// A proxy start was requested for a port that is not being sent.
    #[error("invalid state for proxy")]
    InvalidStateForProxy,
    /// The port's handler rejected a message.
    #[error("failed to parse message")]
    BadMessage(#[source] Box<dyn std::error::Error + Send + Sync>),
}
1139
/// Outcome of successfully handling an event in `PortInnerState::on_event`.
enum PortEventResult {
    /// Nothing further to do.
    None,
    /// The port has finished; the caller marks it `Done`, notifies the
    /// handler and disassociates the port.
    Done,
}
1148
1149impl PortInnerState {
    /// Handles one incoming event according to the port's current activity.
    ///
    /// Events are delivered in sequence order: an out-of-order event is
    /// queued, and any queued events that become due are handled in the same
    /// call. Outgoing events are pushed to `pending_events` rather than sent
    /// directly.
    ///
    /// # Errors
    ///
    /// Returns the error to fail the port with. Protocol violations are
    /// attributed to `remote_node_id` when it is given and are local errors
    /// otherwise; a `FailPort` event and an already-failed port yield their
    /// own error unchanged.
    fn on_event<'a>(
        &mut self,
        remote_node_id: Option<&NodeId>,
        seq: Seq,
        event: PortEvent<'a>,
        pending_events: &mut PendingEvents<'a>,
    ) -> Result<PortEventResult, NodeError> {
        // Failure notifications bypass sequencing entirely.
        if let PortEvent::Event(NonMessageEvent::FailPort(err)) = event {
            return Err(err);
        }

        // The labeled block evaluates to a `PortError`; the successful paths
        // return from the function directly.
        let err = 'error: {
            if seq < self.event_queue.next_peer_seq {
                break 'error PortError::DuplicateSeq {
                    next: self.event_queue.next_peer_seq,
                };
            }

            match &mut self.activity {
                PortActivity::Peered(peer) => {
                    // The new event is offered to the queue once; later
                    // iterations only drain what has become due.
                    let mut v = Some((seq, event));
                    while let Some(port_event) = self.event_queue.pop(v.take()) {
                        match port_event {
                            PortEvent::Message { message, close } => {
                                if let Some(message) = message {
                                    // Without a handler, messages are queued
                                    // until one is installed or drained.
                                    let handler = self
                                        .handler
                                        .get_or_insert_with(|| Box::new(QueuingHandler::default()));
                                    if let Err(err) = handler.message(
                                        &mut PortControl::peered(
                                            peer,
                                            &mut self.next_local_seq,
                                            pending_events,
                                        ),
                                        message,
                                    ) {
                                        break 'error PortError::BadMessage(err.0);
                                    }
                                }
                                if close {
                                    // Nothing may follow a close.
                                    if !self.event_queue.is_empty() {
                                        break 'error PortError::EventAfterClose;
                                    }
                                    // Answer the peer's close with our own
                                    // unless we already closed locally.
                                    if !self.is_local_closed {
                                        pending_events.push(
                                            peer.clone(),
                                            self.next_local_seq,
                                            PortEvent::Message {
                                                message: None,
                                                close: true,
                                            },
                                        );
                                    }
                                    return Ok(PortEventResult::Done);
                                }
                            }
                            PortEvent::Event(e) => match e {
                                NonMessageEvent::ChangePeer(new_peer, seq_delta) => {
                                    assert!(new_peer.is_compatible_node(&self.local_node));
                                    new_peer.node_status()?;
                                    // Switch to the new peer and acknowledge
                                    // the change to the old one.
                                    let old_peer = std::mem::replace(peer, new_peer);
                                    pending_events.push(
                                        old_peer,
                                        self.next_local_seq,
                                        NonMessageEvent::AcknowledgeChangePeer,
                                    );
                                    // Rebase outgoing sequence numbers for the
                                    // new peer by the delta it supplied.
                                    self.next_local_seq -= seq_delta;
                                }
                                NonMessageEvent::AcknowledgeChangePeer => {
                                    break 'error PortError::AckChangePeerInvalidState;
                                }
                                NonMessageEvent::AcknowledgePort | NonMessageEvent::FailPort(_) => {
                                    unreachable!()
                                }
                            },
                        }
                    }
                    return Ok(PortEventResult::None);
                }
                PortActivity::Sending { .. } => {
                    // Hold the event until the port starts proxying.
                    self.event_queue.add(seq, event);
                    return Ok(PortEventResult::None);
                }
                PortActivity::Proxying { peer: _, target } => {
                    let target = target.clone();

                    let mut v = Some((seq, event));
                    let mut next_seq = self.next_local_seq;
                    while let Some(port_event) = self.event_queue.pop(v.take()) {
                        match port_event {
                            // The peer has switched to the target; the proxy
                            // is no longer needed.
                            PortEvent::Event(NonMessageEvent::AcknowledgeChangePeer) => {
                                if !self.event_queue.is_empty() {
                                    break 'error PortError::EventAfterProxyEnd;
                                }
                                return Ok(PortEventResult::Done);
                            }
                            // Everything else is forwarded to the target with
                            // this port's own sequence numbers.
                            event => {
                                if let PortEvent::Event(NonMessageEvent::ChangePeer(new_peer, _)) =
                                    &event
                                {
                                    // Track the peer change locally as well.
                                    assert!(new_peer.is_compatible_node(&self.local_node));
                                    new_peer.node_status()?;
                                    self.set_activity(PortActivity::Proxying {
                                        peer: new_peer.clone(),
                                        target: target.clone(),
                                    });
                                }
                                pending_events.push(target.clone(), next_seq, event);
                                next_seq += Wrapping(1);
                            }
                        }
                    }

                    self.next_local_seq = next_seq;
                    return Ok(PortEventResult::None);
                }
                PortActivity::Done => PortError::EventAfterClose,
                PortActivity::Failed(err) => return Err(err.clone()),
                PortActivity::Unreachable => unreachable!(),
            }
        };
        if let Some(remote_node_id) = remote_node_id {
            Err(NodeError::new(remote_node_id, err))
        } else {
            Err(NodeError::local(err))
        }
    }
1280
    /// Turns this port into a proxy that forwards events from `peer` to
    /// `target`.
    ///
    /// Messages held by the handler and events that are already due in the
    /// queue are forwarded to `target` first, numbered from `initial_seq`.
    /// The peer is then told to talk to `target` directly via a `ChangePeer`
    /// event.
    fn start_proxy(
        &mut self,
        peer: PortRef,
        target: PortRef,
        initial_seq: Seq,
        pending_events: &mut PendingEvents<'_>,
    ) {
        let mut seq = initial_seq;

        if let Some(handler) = &mut self.handler {
            // Messages the handler has not consumed go to the target.
            for message in handler.drain() {
                pending_events.push(
                    target.clone(),
                    seq,
                    OwnedPortEvent::Message {
                        message: Some(message),
                        close: false,
                    },
                );
                seq += Wrapping(1);
            }
        }

        // Forward every queued event that is next in sequence.
        while let Some(port_event) = self.event_queue.pop(None) {
            pending_events.push(target.clone(), seq, port_event);
            seq += Wrapping(1);
        }

        // The `ChangePeer` event itself uses the port's sequence number from
        // before forwarding began.
        let change_seq = self.next_local_seq;

        self.next_local_seq = seq;
        // Difference between the next sequence number expected from the peer
        // and the next one this port forwards with; sent along with
        // `ChangePeer`.
        let delta = self.event_queue.next_peer_seq - self.next_local_seq;
        self.set_activity(PortActivity::Proxying {
            peer: peer.clone(),
            target: target.clone(),
        });

        pending_events.push(peer, change_seq, NonMessageEvent::ChangePeer(target, delta));
    }
1323}
1324
1325impl PortInner {
1326 fn send(&self, message: Message<'_>) {
1328 let peer_seq = {
1329 let mut state = self.state.lock();
1330 assert!(!state.is_local_closed);
1331 state.next_peer_and_seq()
1332 };
1333
1334 if let Some((peer, seq)) = peer_seq {
1335 PendingEvents::send(
1336 &peer,
1337 seq,
1338 PortEvent::Message {
1339 message: Some(message),
1340 close: false,
1341 },
1342 );
1343 }
1344 }
1345
    /// Closes the port locally, optionally sending a final `message` in the
    /// same event as the close.
    ///
    /// Nothing is sent if the port is already done or failed.
    ///
    /// # Panics
    ///
    /// Panics if the port has already been closed locally.
    fn close(&self, message: Option<Message<'_>>) {
        // Declared outside the locked scope so that the handler is dropped
        // only after the state lock has been released.
        let _old_handler;
        let peer_seq = {
            let mut state = self.state.lock();
            assert!(!state.is_local_closed);

            _old_handler = std::mem::take(&mut state.handler);

            state.is_local_closed = true;
            state.next_peer_and_seq()
        };

        if let Some((peer, seq)) = peer_seq {
            PendingEvents::send(
                &peer,
                seq,
                PortEvent::Message {
                    message,
                    close: true,
                },
            );
        }
    }
1372
1373 fn on_event<'a>(
1375 &self,
1376 remote_node_id: Option<&NodeId>,
1377 seq: Seq,
1378 event: PortEvent<'a>,
1379 pending_events: &mut PendingEvents<'a>,
1380 ) {
1381 let mut state = self.state.lock();
1382 let mut disassociate = false;
1383 match state.on_event(remote_node_id, seq, event, pending_events) {
1384 Ok(PortEventResult::None) => {}
1385 Ok(PortEventResult::Done) => {
1386 state.set_activity(PortActivity::Done);
1387 if let Some(handler) = &mut state.handler {
1388 handler.close(&mut PortControl::unpeered(pending_events));
1389 }
1390 disassociate = true;
1391 }
1392 Err(err) => {
1393 state.fail(pending_events, err.clone());
1394 if let Some(handler) = &mut state.handler {
1395 handler.fail(&mut PortControl::unpeered(pending_events), err);
1396 }
1397 disassociate = true;
1398 }
1399 }
1400
1401 if disassociate {
1402 self.disassociate(&mut state);
1403 }
1404 drop(state);
1405 }
1406
1407 fn start_proxy(
1409 &self,
1410 remote_node_id: &NodeId,
1411 initial_seq: Seq,
1412 pending_events: &mut PendingEvents<'_>,
1413 ) {
1414 tracing::trace!(port = ?self.id, initial_seq, "proxy starting");
1415 let mut state = self.state.lock();
1416
1417 let mut err = None;
1418 match std::mem::replace(&mut state.activity, PortActivity::Unreachable) {
1419 PortActivity::Sending { peer, target } => {
1420 state.start_proxy(peer, target, initial_seq, pending_events);
1421 }
1422 activity => {
1423 state.activity = activity;
1424 err = Some(NodeError::new(
1425 remote_node_id,
1426 PortError::InvalidStateForProxy,
1427 ));
1428 }
1429 };
1430
1431 if let Some(err) = err {
1432 self.disassociate(&mut state);
1433 if let Some(handler) = &mut state.handler {
1434 handler.fail(
1435 &mut PortControl::unpeered(pending_events),
1436 NodeError::new(remote_node_id, err),
1437 );
1438 }
1439
1440 drop(state);
1441 tracing::error!(port = ?self.id, "proxy from wrong state");
1443 }
1444 }
1445
1446 fn associate<'a>(
1450 inner: &'a Arc<Self>,
1451 local_node: &Arc<LocalNodeInner>,
1452 ) -> MutexGuard<'a, PortInnerState> {
1453 let mut state = inner.state.lock();
1454 match &state.local_node {
1455 Some(node) => assert_eq!(Arc::as_ptr(local_node), node.as_ptr()),
1456 None => {
1457 local_node
1458 .state
1459 .lock()
1460 .ports
1461 .insert(inner.id, inner.clone());
1462 state.local_node = Some(Arc::downgrade(local_node));
1463 }
1464 }
1465 state
1466 }
1467
1468 fn disassociate(&self, port_state: &mut PortInnerState) {
1470 if let Some(local_node) = port_state
1471 .local_node
1472 .take()
1473 .as_ref()
1474 .and_then(Weak::upgrade)
1475 {
1476 tracing::trace!(node = ?local_node.id, port = ?self.id, "disassociate port");
1477 let mut state = local_node.state.lock();
1478 state.ports.remove(&self.id);
1479 let shutdown = state.shutdown.take();
1480 drop(state);
1481 if shutdown.is_some() {
1483 tracing::trace!(node = ?local_node.id, "waking shutdown waiter");
1484 }
1485 }
1486 }
1487
1488 fn lock_two<'a>(
1491 left: &'a Self,
1492 right: &'a Self,
1493 ) -> (
1494 MutexGuard<'a, PortInnerState>,
1495 MutexGuard<'a, PortInnerState>,
1496 ) {
1497 let (lm, rm);
1501 if std::ptr::from_ref(left) < std::ptr::from_ref(right) {
1502 lm = left.state.lock();
1503 rm = right.state.lock();
1504 } else {
1505 rm = right.state.lock();
1506 lm = left.state.lock();
1507 }
1508 (lm, rm)
1509 }
1510
    /// Installs `handler` as the port's event handler.
    ///
    /// Messages drained from the previously installed handler (if any) are
    /// replayed into the new one. If the new handler rejects a message, the
    /// port is failed with a local `PortError::BadMessage` error and the
    /// replay stops. The new handler is then told about an already failed or
    /// finished port before it is stored. Events queued along the way are
    /// processed after the state lock has been released.
    fn set_handler(&self, mut handler: Box<dyn HandlePortEventAndAny>) {
        let mut pending_events = PendingEvents::new();
        {
            let mut state = self.state.lock();
            // Reborrow as a plain `&mut` so distinct fields can be borrowed
            // separately below.
            let state = &mut *state;
            // Only a peered port gives the handler a peer to send to.
            let peer_and_seq = match &state.activity {
                PortActivity::Peered(peer) => Some((peer, &mut state.next_local_seq)),
                _ => None,
            };
            let mut control = PortControl {
                peer_and_seq,
                events: &mut pending_events,
            };
            if let Some(mut old_handler) = state.handler.take() {
                for message in old_handler.drain() {
                    if let Err(err) = handler.message(&mut control, message.into()) {
                        state.fail(
                            &mut pending_events,
                            NodeError::local(PortError::BadMessage(err.0)),
                        );
                        break;
                    }
                }
            }
            match &state.activity {
                PortActivity::Peered(_) => {}
                PortActivity::Failed(err) => {
                    handler.fail(&mut PortControl::unpeered(&mut pending_events), err.clone())
                }
                PortActivity::Done => {
                    handler.close(&mut PortControl::unpeered(&mut pending_events))
                }
                // NOTE(review): presumably a handler is never installed on a
                // port that is sending or proxying — confirm against callers.
                _ => unreachable!(),
            }
            state.handler = Some(handler);
        }
        pending_events.process();
    }
1549
1550 fn drain_queue(&self) -> Option<Box<dyn HandlePortEventAndAny>> {
1551 let mut state = self.state.lock();
1552 let mut handler = state.handler.take();
1553 let messages = handler
1554 .as_mut()
1555 .map_or_else(Vec::new, |handler| handler.drain());
1556 if !messages.is_empty() {
1557 state.handler = Some(Box::new(QueuingHandler { messages }));
1558 }
1559 handler
1560 }
1561}
1562
/// A handle to a remote node known to a local node.
///
/// The handle refers to the remote node weakly. Cloning a handle increments
/// the remote node's handle count; dropping one hands it back through
/// `drop_remote_handle`.
pub struct RemoteNodeHandle {
    /// Identifier of the remote node.
    id: NodeId,
    /// Weak reference to the remote node's state.
    remote_node: Weak<RemoteNode>,
}
1569
1570impl Debug for RemoteNodeHandle {
1571 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1572 f.debug_struct("RemoteNodeHandle")
1573 .field("id", &self.id)
1574 .finish()
1575 }
1576}
1577
1578impl Drop for RemoteNodeHandle {
1579 fn drop(&mut self) {
1580 if let Some(remote_node) = self.remote_node.upgrade() {
1581 remote_node.local_node.drop_remote_handle(&remote_node);
1582 }
1583 }
1584}
1585
1586impl RemoteNodeHandle {
1587 pub fn id(&self) -> &NodeId {
1588 &self.id
1589 }
1590
1591 pub fn connect<T: 'static + SendEvent>(&self, conn: T) -> bool {
1593 if let Some(remote_node) = self.remote_node.upgrade() {
1594 remote_node.connect(Box::new(conn))
1595 } else {
1596 false
1597 }
1598 }
1599
1600 pub fn disconnect(&self) {
1601 self.fail(RemoteNodeDisconnected)
1602 }
1603
1604 pub fn fail(&self, err: impl Into<Box<dyn std::error::Error + Send + Sync>>) {
1605 if let Some(remote_node) = self.remote_node.upgrade() {
1606 remote_node
1607 .local_node
1608 .disconnect_remote(&remote_node, NodeError::new(&remote_node.id, err));
1609 }
1610 }
1611}
1612
1613impl Clone for RemoteNodeHandle {
1614 fn clone(&self) -> Self {
1615 if let Some(remote_node) = self.remote_node.upgrade() {
1616 assert!(remote_node.handle_count.fetch_add(1, Ordering::SeqCst) > 0);
1617 }
1618 Self {
1619 id: self.id,
1620 remote_node: self.remote_node.clone(),
1621 }
1622 }
1623}
1624
/// Mutable state of a local node, guarded by the node's mutex.
#[derive(Debug)]
struct LocalNodeState {
    /// Ports currently associated with this node, keyed by port id.
    ports: HashMap<PortId, Arc<PortInner>>,
    /// Remote nodes known to this node, keyed by node id.
    nodes: HashMap<NodeId, Arc<RemoteNode>>,
    /// Sender installed by `wait_for_ports`; it is taken (and thereby
    /// dropped) when a port is disassociated from the node.
    shutdown: Option<oneshot::Sender<()>>,
}
1632
/// An event addressed to a port, possibly borrowing its message payload.
#[derive(Debug)]
enum PortEvent<'a> {
    /// An optional message, optionally combined with a close notification.
    Message {
        /// The message payload, if any.
        message: Option<Message<'a>>,
        /// Whether the sender closed its end of the port.
        close: bool,
    },
    /// A control event that carries no message.
    Event(NonMessageEvent),
}
1642
1643impl From<NonMessageEvent> for PortEvent<'_> {
1644 fn from(value: NonMessageEvent) -> Self {
1645 PortEvent::Event(value)
1646 }
1647}
1648
1649impl From<OwnedPortEvent> for PortEvent<'_> {
1650 fn from(value: OwnedPortEvent) -> Self {
1651 match value {
1652 OwnedPortEvent::Message { message, close } => PortEvent::Message {
1653 message: message.map(Into::into),
1654 close,
1655 },
1656 OwnedPortEvent::Event(e) => PortEvent::Event(e),
1657 }
1658 }
1659}
1660
1661impl PortEvent<'_> {
1662 fn into_owned(self) -> OwnedPortEvent {
1663 match self {
1664 PortEvent::Message { message, close } => OwnedPortEvent::Message {
1665 message: message.map(|m| m.into_owned()),
1666 close,
1667 },
1668 PortEvent::Event(event) => OwnedPortEvent::Event(event),
1669 }
1670 }
1671}
1672
/// The owned counterpart of `PortEvent`, holding an `OwnedMessage`.
#[derive(Debug)]
enum OwnedPortEvent {
    /// An optional message, optionally combined with a close notification.
    Message {
        /// The message payload, if any.
        message: Option<OwnedMessage>,
        /// Whether the sender closed its end of the port.
        close: bool,
    },
    /// A control event that carries no message.
    Event(NonMessageEvent),
}
1682
/// Port control events that do not carry a message.
#[derive(Debug)]
enum NonMessageEvent {
    /// Names a new peer port together with a sequence number delta.
    ChangePeer(PortRef, Seq),
    /// Acknowledges a `ChangePeer` event.
    AcknowledgeChangePeer,
    /// Acknowledges receipt of a transferred port; sent back to the port's
    /// old address.
    AcknowledgePort,
    /// Reports that the port has failed with the given error.
    FailPort(NodeError),
}
1691
/// An event that is ready to be serialized and sent to a remote node.
pub struct OutgoingEvent<'a> {
    /// Destination port id, written into the event header.
    port_id: PortId,
    /// Sequence number, written into the event header.
    seq: Seq,
    /// The event payload, with messages wrapped in an encoder.
    event: EventAndEncoder<'a>,
    /// Serialized length in bytes, computed once in `new`.
    len: usize,
    /// The remote node this event is destined for.
    remote_node: &'a Arc<RemoteNode>,
}
1700
/// The payload of an `OutgoingEvent`.
enum EventAndEncoder<'a> {
    /// A message event; the message, if present, is held in an encoder so its
    /// length and resource count are available before it is written.
    Message {
        message: Option<Encoder<Message<'a>, <Message<'a> as DefaultEncoding>::Encoding, Resource>>,
        close: bool,
    },
    /// Any event that does not carry a message.
    Other(NonMessageEvent),
}
1708
1709impl<'a> OutgoingEvent<'a> {
1710 fn new(
1711 port_id: PortId,
1712 seq: Seq,
1713 event: PortEvent<'a>,
1714 remote_node: &'a Arc<RemoteNode>,
1715 ) -> Self {
1716 let mut len = size_of::<protocol::Event>();
1717 let event = match event {
1718 PortEvent::Message { message, close } => {
1719 let message = message.map(|m| {
1720 let message = Encoder::new(m);
1721 len += message.resource_count() * size_of::<protocol::ResourceData>();
1722 len += message.len();
1723 message
1724 });
1725 EventAndEncoder::Message { message, close }
1726 }
1727 PortEvent::Event(event) => match event {
1728 NonMessageEvent::ChangePeer(_, _) => {
1729 len += size_of::<protocol::ChangePeerData>();
1730 EventAndEncoder::Other(event)
1731 }
1732 NonMessageEvent::FailPort(_) => {
1733 len += size_of::<protocol::FailPortData>();
1734 EventAndEncoder::Other(event)
1735 }
1736 event @ (NonMessageEvent::AcknowledgeChangePeer
1737 | NonMessageEvent::AcknowledgePort) => EventAndEncoder::Other(event),
1738 },
1739 };
1740 Self {
1741 port_id,
1742 seq,
1743 event,
1744 len,
1745 remote_node,
1746 }
1747 }
1748
    /// Returns the serialized length of the event in bytes, as computed when
    /// the event was created.
    pub fn len(&self) -> usize {
        self.len
    }
1753
1754 pub fn write_to(self, buf: &mut dyn Buffer, os_resources: &mut impl Extend<OsResource>) {
1756 write_with(buf, |mut buf| {
1757 buf.write_split(size_of::<protocol::Event>(), |header_buf, buf| {
1758 self.write_split(header_buf, buf, os_resources);
1759 })
1760 })
1761 }
1762
    /// Writes the event payload into `buf` and the header into `header_buf`.
    ///
    /// The header is appended last, once its event type, sizes, and flags
    /// have been filled in. OS resources carried by a message are moved into
    /// `os_resources`.
    fn write_split(
        self,
        mut header_buf: Buf<'_>,
        mut buf: Buf<'_>,
        os_resources: &mut impl Extend<OsResource>,
    ) {
        let mut header = protocol::Event {
            port_id: self.port_id.0.into(),
            seq: self.seq.0,
            ..protocol::Event::new_zeroed()
        };
        match self.event {
            EventAndEncoder::Other(event) => match event {
                NonMessageEvent::ChangePeer(port, seq_delta) => {
                    let (node_id, port_id) = match port {
                        PortRef::LocalPort(port) => {
                            // Register the local port with the local node; the
                            // returned state guard is released immediately.
                            drop(PortInner::associate(&port, &self.remote_node.local_node));
                            (self.remote_node.local_node.id, port.id)
                        }
                        PortRef::RemotePort(remote_node, port_id) => (remote_node.id, port_id),
                    };
                    header.event_type = protocol::EventType::CHANGE_PEER;
                    header.message_size = size_of::<protocol::ChangePeerData>() as u32;
                    buf.append(
                        protocol::ChangePeerData {
                            node: node_id.0.into(),
                            port: port_id.0.into(),
                            seq_delta: seq_delta.0,
                            reserved: 0,
                        }
                        .as_bytes(),
                    );
                }
                NonMessageEvent::AcknowledgeChangePeer => {
                    header.event_type = protocol::EventType::ACKNOWLEDGE_CHANGE_PEER
                }
                NonMessageEvent::AcknowledgePort => {
                    header.event_type = protocol::EventType::ACKNOWLEDGE_PORT
                }
                NonMessageEvent::FailPort(err) => {
                    header.event_type = protocol::EventType::FAIL_PORT;
                    header.message_size = size_of::<protocol::FailPortData>() as u32;
                    buf.append(
                        protocol::FailPortData {
                            // Use the error's remote node id, falling back to
                            // the local node's id when the error has none.
                            node: err
                                .remote_node_id()
                                .unwrap_or(&self.remote_node.local_node.id)
                                .0
                                .into(),
                        }
                        .as_bytes(),
                    );
                }
            },
            EventAndEncoder::Message { message, close } => {
                let mut resources = Vec::new();
                header.event_type = protocol::EventType::MESSAGE;
                header.flags.set_close(close);
                if let Some(message) = message {
                    header.flags.set_message(true);
                    header.message_size = message.len() as u32;
                    header.resource_count = message.resource_count() as u32;
                    // The resource table precedes the encoded message bytes.
                    buf.write_split(
                        message.resource_count() * size_of::<protocol::ResourceData>(),
                        |mut resource_buf, mut message_buf| {
                            message.encode_into(&mut message_buf, &mut resources);
                            for resource in resources {
                                let data = match resource {
                                    Resource::Port(port) => port.prepare_to_send(self.remote_node),
                                    Resource::Os(r) => {
                                        // OS resources travel out of band; a
                                        // zeroed entry takes their slot.
                                        os_resources.extend([r]);
                                        protocol::ResourceData::new_zeroed()
                                    }
                                };
                                resource_buf.append(data.as_bytes());
                            }
                        },
                    );
                }
            }
        }

        header_buf.append(header.as_bytes());
    }
1848}
1849
/// A connection capable of delivering outgoing events.
pub trait SendEvent: Send + Sync {
    /// Delivers `event` over this connection.
    fn event(&self, event: OutgoingEvent<'_>);
}
1854
/// Called by a local node to establish a connection to a remote node it has
/// just created an entry for.
pub trait Connect: Send + Sync {
    /// Requests a connection to the node `node_id`; `handle` is the handle of
    /// the newly created remote node entry.
    fn connect(&self, node_id: NodeId, handle: RemoteNodeHandle);
}
1859
1860impl LocalNode {
1861 pub fn with_id(node_id: NodeId, connector: Box<dyn Connect>) -> Self {
1864 let node = Arc::new(LocalNodeInner {
1865 id: node_id,
1866 state: Mutex::new(LocalNodeState {
1867 ports: HashMap::new(),
1868 nodes: HashMap::new(),
1869 shutdown: None,
1870 }),
1871 });
1872 Self {
1873 inner: node,
1874 connector: Mutex::new(Some(connector)),
1875 }
1876 }
1877
    /// Returns this node's identifier.
    pub fn id(&self) -> NodeId {
        self.inner.id
    }
1882
    /// Returns whether no ports are currently associated with this node.
    #[cfg(test)]
    fn is_empty(&self) -> bool {
        self.inner.state.lock().ports.is_empty()
    }
1887
    /// Waits until no port associated with this node needs waiting for.
    ///
    /// If `all_ports` is true, every associated port counts. Otherwise only
    /// ports in the `Sending` or `Proxying` state count.
    pub async fn wait_for_ports(&self, all_ports: bool) {
        loop {
            #[allow(clippy::disallowed_methods)]
            let (send, recv) = oneshot::channel::<()>();
            // Install the sender before taking the snapshot of the ports, both
            // under the same lock.
            let ports: Vec<_> = {
                let mut state = self.inner.state.lock();
                state.shutdown = Some(send);
                state.ports.values().cloned().collect()
            };
            let left = ports
                .into_iter()
                .filter(|port| {
                    let wait = all_ports
                        || match &port.state.lock().activity {
                            PortActivity::Peered(_) => false,
                            PortActivity::Sending { .. } => true,
                            PortActivity::Proxying { .. } => true,
                            PortActivity::Failed(_) => false,
                            PortActivity::Done => false,
                            PortActivity::Unreachable => unreachable!(),
                        };
                    if wait {
                        tracing::trace!(node = ?self.id(), ?port, "waiting for port");
                    }
                    wait
                })
                .count();
            if left == 0 {
                tracing::debug!(node = ?self.id(), "no ports remain");
                return;
            }
            tracing::debug!(node = ?self.id(), count = left, "waiting for ports");
            // The result is ignored: completion of the receiver, successful or
            // not, just triggers another pass over the ports.
            let _ = recv.await;
        }
    }
1927
    /// Drops the connector, if one is still installed. Remote node entries
    /// created afterwards are no longer passed to a connector.
    pub fn drop_connector(&self) {
        self.connector.lock().take();
    }
1931
    /// Drops the connector and fails every known remote node with a
    /// "shutting down" error.
    pub fn fail_all_nodes(&self) {
        self.drop_connector();
        self.inner.fail_all_nodes(NodeError::shutting_down());
    }
1937
1938 pub fn add_port(&self, id: PortId, peer: Address) -> Port {
1939 tracing::trace!(node = ?self.inner.id, port = ?id, peer = ?peer, "importing port");
1940 let peer_node = self.get_remote(peer.node);
1941 let activity = PortActivity::Peered(PortRef::RemotePort(peer_node.clone(), peer.port));
1942
1943 let port = Port::new(id, PortInnerState::new(activity));
1944 {
1945 let mut state = PortInner::associate(&port.inner, &self.inner);
1946 if let Err(err) = peer_node.node_status() {
1947 state.set_activity(PortActivity::Failed(err));
1948 port.inner.disassociate(&mut state);
1949 }
1950 }
1951 port
1952 }
1953
1954 pub fn add_remote(&self, id: NodeId) -> RemoteNodeHandle {
1956 let (deferred_conn, handle) = RemoteNode::new(self.inner.clone(), id);
1957 self.inner.state.lock().nodes.insert(id, deferred_conn);
1958
1959 handle
1960 }
1961
1962 pub fn get_remote_handle(&self, id: NodeId) -> RemoteNodeHandle {
1965 let remote = self.get_remote(id);
1966 let handle = remote.handle_count.fetch_add(1, Ordering::SeqCst);
1967 assert!(handle >= 0);
1968 RemoteNodeHandle {
1969 id,
1970 remote_node: Arc::downgrade(&remote),
1971 }
1972 }
1973
1974 pub fn event(&self, remote_node_id: &NodeId, event: &[u8], os_resources: &mut Vec<OsResource>) {
1976 let parse = || {
1977 let header = protocol::Event::read_from_prefix(event).ok()?.0; let (resources, message) = Ref::from_prefix_with_elems(
1979 &event[size_of_val(&header)..],
1980 header.resource_count as usize,
1981 )
1982 .ok()?; let message = message.get(..header.message_size as usize)?;
1984 Some((header, resources, message))
1985 };
1986
1987 match parse() {
1988 Some((header, resources, message)) => {
1989 if let Err(error) =
1990 self.on_parsed_event(remote_node_id, &header, &resources, message, os_resources)
1991 {
1992 tracing::error!(
1993 node = ?self.inner.id,
1994 port = ?PortId(header.port_id.into()),
1995 seq = header.seq,
1996 ?error,
1997 "node event failure"
1998 );
1999 }
2000 }
2001 None => {
2002 tracing::error!(
2003 node = ?self.inner.id,
2004 "node event parse failure"
2005 );
2006 }
2007 }
2008 }
2009
    /// Dispatches an already parsed event to the local port named in its
    /// header.
    ///
    /// # Errors
    ///
    /// Returns an error if the destination port (or the port named by a
    /// `CHANGE_PEER` event) is unknown, if an OS resource referenced by the
    /// message is missing from `os_resources`, if a fixed-size payload is
    /// truncated, or if the event type is not recognized.
    fn on_parsed_event(
        &self,
        remote_node_id: &NodeId,
        header: &protocol::Event,
        resource_data: &[Unalign<protocol::ResourceData>],
        message: &[u8],
        os_resources: &mut Vec<OsResource>,
    ) -> Result<(), EventError> {
        let port_id = PortId(header.port_id.into());
        let seq = Wrapping(header.seq);

        tracing::trace!(
            node = ?self.inner.id,
            port = ?port_id,
            seq,
            event_type = ?header.event_type,
            "port event"
        );
        let port = self
            .get_local_port(port_id)
            .ok_or(EventError::UnknownPort)?;
        let mut os_resources = os_resources.drain(..);

        let port_event = match header.event_type {
            protocol::EventType::MESSAGE => {
                let message = if header.flags.message() {
                    let mut resources = Vec::with_capacity(resource_data.len());
                    for data in resource_data {
                        let data = data.get();
                        // An entry with a zero id stands for an OS resource,
                        // taken in order from `os_resources`; any other entry
                        // describes a transferred port.
                        let r = if data.id.is_zero() {
                            Resource::Os(os_resources.next().ok_or(EventError::MissingOsResource)?)
                        } else {
                            Resource::Port(self.receive_port(remote_node_id, data))
                        };
                        resources.push(r);
                    }
                    Some(Message::serialized(message, resources))
                } else {
                    None
                };
                PortEvent::Message {
                    message,
                    close: header.flags.close(),
                }
            }
            protocol::EventType::CHANGE_PEER => {
                let data = protocol::ChangePeerData::read_from_prefix(message)
                    .map_err(|_| EventError::Truncated)?
                    .0;
                let port = self
                    .get_port(Address {
                        node: NodeId(data.node.into()),
                        port: PortId(data.port.into()),
                    })
                    .ok_or(EventError::UnknownPort)?;
                NonMessageEvent::ChangePeer(port, Wrapping(data.seq_delta)).into()
            }
            protocol::EventType::ACKNOWLEDGE_CHANGE_PEER => {
                NonMessageEvent::AcknowledgeChangePeer.into()
            }
            protocol::EventType::ACKNOWLEDGE_PORT => {
                // Handled right here rather than through the port's event
                // queue: the acknowledged port starts proxying.
                let mut events = PendingEvents::new();
                port.start_proxy(remote_node_id, Wrapping(1), &mut events);
                events.process();
                return Ok(());
            }
            protocol::EventType::FAIL_PORT => {
                let data = protocol::FailPortData::read_from_prefix(message)
                    .map_err(|_| EventError::Truncated)?
                    .0;
                NonMessageEvent::FailPort(NodeError::new(
                    remote_node_id,
                    RemotePortError(NodeId(data.node.into())),
                ))
                .into()
            }
            ty => return Err(EventError::UnknownEventType(ty)),
        };
        PendingEvents::send_local(&port, Some(remote_node_id), seq, port_event);
        Ok(())
    }
2092
    /// Reconstructs a port that arrived as a message resource from the node
    /// `remote_node_id`.
    ///
    /// The port is created in the `Peered` state when its peer address is
    /// present and resolvable, and in the `Failed` state otherwise. A peered
    /// port is associated with this node; then, unless the peer's node or the
    /// sending node has failed, an `AcknowledgePort` event is sent to the
    /// port's old address.
    fn receive_port(&self, remote_node_id: &NodeId, data: protocol::ResourceData) -> Port {
        let old_address = Address {
            node: NodeId(data.old_node.into()),
            port: PortId(data.old_port.into()),
        };

        // A zero peer port id means the port arrived without a usable peer;
        // the error then names the node id carried in the peer node field.
        let peer_address = if !data.peer_port.is_zero() {
            Ok(Address {
                node: NodeId(data.peer_node.into()),
                port: PortId(data.peer_port.into()),
            })
        } else {
            Err(NodeError::new(
                remote_node_id,
                RemotePortError(NodeId(data.peer_node.into())),
            ))
        };

        tracing::trace!(
            node = ?self.inner.id,
            port = ?PortId(data.id.into()),
            old_address = ?old_address,
            peer = ?peer_address,
            "received port"
        );

        let peer;
        let activity = match peer_address.and_then(|addr| {
            self.get_port(addr)
                .ok_or_else(|| NodeError::new(remote_node_id, UnknownLocalPort))
        }) {
            Ok(peer_port) => {
                peer = Some(peer_port.clone());
                PortActivity::Peered(peer_port)
            }
            Err(err) => {
                tracing::warn!(
                    node = ?self.inner.id,
                    port = ?PortId(data.id.into()),
                    error = &err as &dyn std::error::Error,
                    old_address = ?old_address,
                    "received failed port",
                );
                peer = None;
                PortActivity::Failed(err)
            }
        };

        // The port continues with the local sequence number carried in the
        // resource data.
        let port = Port::new(
            PortId(data.id.into()),
            PortInnerState {
                next_local_seq: Wrapping(data.next_local_seq),
                ..PortInnerState::new(activity)
            },
        );
        if let Some(peer) = peer {
            let mut state = PortInner::associate(&port.inner, &self.inner);
            let source = self.get_remote(old_address.node);
            if let Err(err) = peer.node_status().and_then(|()| source.node_status()) {
                state.set_activity(PortActivity::Failed(err));
                port.inner.disassociate(&mut state);
            } else {
                // Release the port lock before sending the acknowledgement.
                drop(state);
                source.event(
                    old_address.port,
                    Wrapping(0),
                    NonMessageEvent::AcknowledgePort.into(),
                );
            }
        }
        port
    }
2166
    /// Returns the remote node entry for `id`, creating and registering one
    /// if none exists yet.
    ///
    /// For a newly created entry, the connector (if still installed) is asked
    /// to connect it.
    ///
    /// # Panics
    ///
    /// Panics if `id` is this node's own id.
    fn get_remote(&self, id: NodeId) -> Arc<RemoteNode> {
        assert!(id != self.id());
        let mut state = self.inner.state.lock();
        let remote_node = match state.nodes.entry(id) {
            hash_map::Entry::Occupied(entry) => entry.get().clone(),
            hash_map::Entry::Vacant(entry) => {
                let (remote_node, handle) = RemoteNode::new(self.inner.clone(), id);
                entry.insert(remote_node.clone());
                // Release the node state lock before calling the connector.
                drop(state);
                let connector = self.connector.lock();
                if let Some(connector) = &*connector {
                    connector.connect(id, handle);
                }
                remote_node
            }
        };
        remote_node
    }
2186
    /// Looks up a port associated with this node by its id.
    fn get_local_port(&self, port_id: PortId) -> Option<Arc<PortInner>> {
        self.inner.state.lock().ports.get(&port_id).cloned()
    }
2191
2192 fn get_port(&self, address: Address) -> Option<PortRef> {
2194 let peer = if address.node == self.inner.id {
2195 PortRef::LocalPort(self.get_local_port(address.port)?)
2196 } else {
2197 PortRef::RemotePort(self.get_remote(address.node), address.port)
2198 };
2199 Some(peer)
2200 }
2201}
2202
2203impl LocalNodeInner {
2204 fn fail_all_nodes(&self, err: NodeError) {
2206 let nodes = std::mem::take(&mut self.state.lock().nodes);
2207 for (_, node) in nodes {
2208 node.fail(err.clone());
2209 }
2210 }
2211
2212 fn fail_all_ports(&self, err: NodeError) {
2214 let ports = std::mem::take(&mut self.state.lock().ports);
2215 let mut pending_events = PendingEvents::new();
2216 let mut control = PortControl::unpeered(&mut pending_events);
2217 for (_, port) in ports {
2218 let mut state = port.state.lock();
2219 if let Some(handler) = &mut state.handler {
2220 handler.fail(&mut control, err.clone());
2221 }
2222 state.local_node = None;
2223 state.set_activity(PortActivity::Failed(err.clone()));
2224 }
2225 pending_events.process();
2226 }
2227
2228 fn drop_remote_handle(&self, remote_node: &Arc<RemoteNode>) {
2229 let count = remote_node.handle_count.fetch_sub(1, Ordering::SeqCst);
2230 assert!(count > 0);
2231 if count == 1 {
2232 self.disconnect_remote(
2233 remote_node,
2234 NodeError::new(&remote_node.id, RemoteNodeDropped),
2235 );
2236 }
2237 }
2238
    /// Fails `remote_node` with `err`, fails every local port that refers to
    /// that node, and finally removes the node's entry from this node.
    fn disconnect_remote(&self, remote_node: &Arc<RemoteNode>, err: NodeError) {
        tracing::trace!(node = ?self.id, remote_node = ?remote_node.id, "disconnecting node");

        remote_node.fail(err.clone());

        // Work on a snapshot so the node lock is not held while the individual
        // ports are locked.
        let ports: Vec<_> = self.state.lock().ports.values().cloned().collect();

        let mut pending_events = PendingEvents::new();
        for port in ports {
            let mut state = port.state.lock();
            // A port is affected if its peer or its target lives on the
            // disconnected node. Ports that already failed are skipped.
            let fail = match &state.activity {
                PortActivity::Failed(_) => continue,
                PortActivity::Proxying {
                    target: PortRef::RemotePort(node, _),
                    ..
                }
                | PortActivity::Proxying {
                    peer: PortRef::RemotePort(node, _),
                    ..
                }
                | PortActivity::Peered(PortRef::RemotePort(node, _))
                | PortActivity::Sending {
                    peer: PortRef::RemotePort(node, _),
                    ..
                }
                | PortActivity::Sending {
                    target: PortRef::RemotePort(node, _),
                    ..
                } if node.id == remote_node.id => true,
                _ => false,
            };
            if fail {
                state.fail(&mut pending_events, err.clone());
                if let Some(handler) = &mut state.handler {
                    handler.fail(&mut PortControl::unpeered(&mut pending_events), err.clone());
                }
                port.disassociate(&mut state);
                drop(state);

                tracing::debug!(
                    local_id = ?self.id,
                    port = ?port.id,
                    remote_id = ?remote_node.id,
                    error = &err as &dyn std::error::Error,
                    "port failed due to failed node"
                );
            }
        }
        pending_events.process();

        self.state.lock().nodes.remove(&remote_node.id);
    }
2297}
2298
2299#[cfg(test)]
2300pub mod tests {
2301 use super::*;
2302 use crate::message::MeshField;
2303 use crate::resource::SerializedMessage;
2304 use futures::stream::Stream;
2305 use pal_async::DefaultDriver;
2306 use pal_async::async_test;
2307 use pal_async::task::Spawn;
2308 use pal_async::task::Task;
2309 use std::future::Future;
2310 use std::future::poll_fn;
2311 use std::marker::PhantomData;
2312 use std::pin::Pin;
2313 use std::pin::pin;
2314 use std::task::Context;
2315 use std::task::Poll;
2316 use test_with_tracing::test;
2317
    /// Returns a future that yields to the executor once before completing.
    fn yield_once() -> YieldOnce {
        YieldOnce { yielded: false }
    }
2321
    /// Future that returns `Pending` on its first poll and `Ready` afterwards.
    struct YieldOnce {
        /// Whether the future has already returned `Pending` once.
        yielded: bool,
    }
2325
2326 impl Future for YieldOnce {
2327 type Output = ();
2328
2329 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
2330 if !self.yielded {
2331 self.yielded = true;
2332 cx.waker().wake_by_ref();
2333 return Poll::Pending;
2334 }
2335 ().into()
2336 }
2337 }
2338
    /// Reasons a non-blocking receive produced no message.
    #[derive(Debug)]
    pub enum TryRecvError {
        /// No message is queued right now.
        Empty,
        /// The queue is empty and the port was closed.
        Closed,
        /// The queue is empty and the port failed.
        Failed,
    }
2345
    /// Reasons an asynchronous receive completed without a message.
    #[derive(Debug)]
    pub enum RecvError {
        /// The queue is empty and the port was closed.
        Closed,
        /// The queue is empty and the port failed.
        Failed,
    }
2351
    /// Test-only typed channel on top of a port with a `Queue` handler; it
    /// sends values of type `T` and receives values of type `U`.
    struct Channel<T = SerializedMessage, U = SerializedMessage> {
        /// The underlying port with its queueing handler.
        port: PortWithHandler<Queue>,
        /// Records the send and receive types without storing values of them.
        _phantom: PhantomData<(fn(T), fn() -> U)>,
    }
2356
    /// Port handler that buffers incoming messages for the test channel.
    #[derive(Default)]
    struct Queue {
        /// Messages received but not yet consumed.
        queue: VecDeque<OwnedMessage>,
        /// Set once the handler has been told the port closed.
        closed: bool,
        /// Set once the handler has been told the port failed.
        failed: bool,
        /// Waker of a receiver that found the queue empty.
        waker: Option<Waker>,
    }
2364
2365 impl Queue {
2366 fn try_recv(&mut self) -> Result<OwnedMessage, TryRecvError> {
2367 if let Some(x) = self.queue.pop_front() {
2368 Ok(x)
2369 } else if self.closed {
2370 Err(TryRecvError::Closed)
2371 } else if self.failed {
2372 Err(TryRecvError::Failed)
2373 } else {
2374 Err(TryRecvError::Empty)
2375 }
2376 }
2377
2378 fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedMessage, RecvError>> {
2379 let r = if let Some(x) = self.queue.pop_front() {
2380 Ok(x)
2381 } else if self.closed {
2382 Err(RecvError::Closed)
2383 } else if self.failed {
2384 Err(RecvError::Failed)
2385 } else {
2386 self.waker = Some(cx.waker().clone());
2387 return Poll::Pending;
2388 };
2389 Poll::Ready(r)
2390 }
2391 }
2392
    /// Each notification updates the queue state and, if a receiver left a
    /// waker behind, hands that waker to `control` to be woken.
    impl HandlePortEvent for Queue {
        /// Queues an owned copy of the message.
        fn message(
            &mut self,
            control: &mut PortControl<'_, '_>,
            message: Message<'_>,
        ) -> Result<(), HandleMessageError> {
            self.queue.push_back(message.into_owned());
            if let Some(waker) = self.waker.take() {
                control.wake(waker);
            }
            Ok(())
        }

        /// Records that the port was closed.
        fn close(&mut self, control: &mut PortControl<'_, '_>) {
            self.closed = true;
            if let Some(waker) = self.waker.take() {
                control.wake(waker);
            }
        }

        /// Records that the port failed; the error itself is discarded.
        fn fail(&mut self, control: &mut PortControl<'_, '_>, _err: NodeError) {
            self.failed = true;
            if let Some(waker) = self.waker.take() {
                control.wake(waker);
            }
        }

        /// Removes and returns all queued messages, oldest first.
        fn drain(&mut self) -> Vec<OwnedMessage> {
            self.queue.drain(..).collect()
        }
    }
2424
    /// Wraps a port in a channel by installing an empty `Queue` handler.
    impl<T: MeshField, U: MeshField> From<Port> for Channel<T, U> {
        fn from(port: Port) -> Self {
            Self {
                port: port.set_handler(Queue::default()),
                _phantom: PhantomData,
            }
        }
    }
2433
    /// Recovers the underlying port by removing the channel's handler.
    impl<T, U> From<Channel<T, U>> for Port {
        fn from(channel: Channel<T, U>) -> Self {
            channel.port.remove_handler().0
        }
    }
2439
    impl<T: 'static + MeshField + Send, U: 'static + MeshField + Send> Channel<T, U> {
        /// Creates two connected channels with mirrored send/receive types.
        fn new_pair() -> (Self, Channel<U, T>) {
            let (left, right) = Port::new_pair();
            (left.into(), right.into())
        }

        /// Bridges the ports underlying this channel and `other`.
        fn bridge(self, other: Channel<U, T>) {
            Port::from(self).bridge(other.into())
        }

        /// Reinterprets the channel with different message types while
        /// keeping the same underlying port.
        fn change_types<T2: MeshField, U2: MeshField>(self) -> Channel<T2, U2> {
            let Self { port, _phantom: _ } = self;
            Channel {
                port,
                _phantom: PhantomData,
            }
        }

        /// Sends `t` wrapped in a single-element tuple message.
        fn send(&self, t: T) {
            self.port.send(Message::new((t,)));
        }

        /// Receives a value without waiting.
        ///
        /// Panics if a received message does not parse as `(U,)`.
        fn try_recv(&mut self) -> Result<U, TryRecvError> {
            self.port
                .with_handler(|queue| queue.try_recv())
                .map(|m| m.parse::<(U,)>().unwrap().0)
        }

        /// Waits for the next value.
        ///
        /// Panics if a received message does not parse as `(U,)`.
        async fn recv(&mut self) -> Result<U, RecvError> {
            poll_fn(|cx| self.port.with_handler(|queue| queue.poll_recv(cx)))
                .await
                .map(|m| m.parse::<(U,)>().unwrap().0)
        }
    }
2474
    /// Test helper: a local node plus a background task that feeds it events
    /// arriving over an in-process channel.
    struct RemoteLocalNode {
        /// Task running the event loop; held so it stays owned by the node.
        _task: Task<()>,
        /// The node under test.
        node: Arc<LocalNode>,
        /// Sending side of the channel the event loop reads from.
        send: futures_channel::mpsc::UnboundedSender<RemoteEvent>,
    }
2480
    /// A serialized event in transit between two test nodes.
    struct RemoteEvent {
        /// Id of the node that sent the event.
        node_id: NodeId,
        /// The serialized event bytes.
        data: Vec<u8>,
        /// OS resources accompanying the event.
        resources: Vec<OsResource>,
    }
2486
    /// Connector that ignores every connection request.
    #[derive(Debug)]
    struct NullConnect;
2489
    impl Connect for NullConnect {
        /// Does nothing; the handle is dropped immediately.
        fn connect(&self, _node_id: NodeId, _handle: RemoteNodeHandle) {}
    }
2493
    impl RemoteLocalNode {
        /// Creates a node with a fresh id and spawns a task on `driver` that
        /// delivers every event received on the channel to the node.
        fn new(driver: &impl Spawn) -> Self {
            #[expect(
                clippy::disallowed_methods,
                reason = "can't use mesh channels from mesh_node"
            )]
            let (send, recv) = futures_channel::mpsc::unbounded::<RemoteEvent>();
            let node = Arc::new(LocalNode::with_id(NodeId::new(), Box::new(NullConnect)));
            let task = driver.spawn("test", {
                let node = node.clone();
                async move {
                    let mut recv = pin!(recv);
                    while let Some(mut event) = poll_fn(|cx| recv.as_mut().poll_next(cx)).await {
                        node.event(&event.node_id, &event.data, &mut event.resources);
                    }
                }
            });
            Self {
                _task: task,
                node,
                send,
            }
        }

        /// Registers `other` as a remote of this node and connects it through
        /// `other`'s event channel. Returns the handle for the new remote.
        fn connect(self: &Arc<Self>, other: &Arc<Self>) -> RemoteNodeHandle {
            let handle = self.node.add_remote(other.node.id());
            handle.connect(EventsFrom {
                node_id: self.node.id(),
                send: other.send.clone(),
            });
            handle
        }
    }
2527
    /// Connection that forwards events to another test node's channel,
    /// tagged with the id of the sending node.
    struct EventsFrom {
        /// Id of the node the events originate from.
        node_id: NodeId,
        /// Channel of the receiving node.
        send: futures_channel::mpsc::UnboundedSender<RemoteEvent>,
    }
2532
    impl SendEvent for EventsFrom {
        /// Serializes `event` and queues it on the receiving node's channel.
        /// A send error is ignored.
        fn event(&self, event: OutgoingEvent<'_>) {
            let mut buffer = Vec::with_capacity(event.len());
            let mut os_resources = Vec::new();
            event.write_to(&mut buffer, &mut os_resources);
            self.send
                .unbounded_send(RemoteEvent {
                    node_id: self.node_id,
                    data: buffer,
                    resources: os_resources,
                })
                .ok();
        }
    }
2547
    /// A message sent on one end of a local pair arrives on the other end,
    /// after which the receiving queue is empty.
    #[test]
    fn test_local() {
        let (left, mut right) = Channel::<_, ()>::new_pair();
        left.send(SerializedMessage {
            data: b"abc".to_vec(),
            ..Default::default()
        });
        assert_eq!(right.try_recv().unwrap().data, b"abc");
        assert!(matches!(right.try_recv().unwrap_err(), TryRecvError::Empty));
    }
2558
2559 fn new_two_node_mesh(
2560 driver: &DefaultDriver,
2561 ) -> (
2562 Arc<RemoteLocalNode>,
2563 Arc<RemoteLocalNode>,
2564 Vec<RemoteNodeHandle>,
2565 ) {
2566 let node = Arc::new(RemoteLocalNode::new(driver));
2567 let node2 = Arc::new(RemoteLocalNode::new(driver));
2568 let mut v = Vec::new();
2569 let handle = node.connect(&node2);
2570 v.push(handle);
2571 let handle = node2.connect(&node);
2572 v.push(handle);
2573 (node, node2, v)
2574 }
2575
2576 fn new_three_node_mesh(
2577 driver: &DefaultDriver,
2578 ) -> (
2579 Arc<RemoteLocalNode>,
2580 Arc<RemoteLocalNode>,
2581 Arc<RemoteLocalNode>,
2582 Vec<RemoteNodeHandle>,
2583 ) {
2584 let node = Arc::new(RemoteLocalNode::new(driver));
2585 let node2 = Arc::new(RemoteLocalNode::new(driver));
2586 let node3 = Arc::new(RemoteLocalNode::new(driver));
2587 let mut v = Vec::new();
2588 for i in [&node, &node2, &node3][..].iter().copied() {
2589 for j in [&node, &node2, &node3][..].iter().copied() {
2590 if Arc::as_ptr(i) != Arc::as_ptr(j) {
2591 let handle = i.connect(j);
2592 v.push(handle);
2593 }
2594 }
2595 }
2596 (node, node2, node3, v)
2597 }
2598
2599 fn new_remote_port_pair(node1: &LocalNode, node2: &LocalNode) -> (Channel, Channel) {
2600 let left_id = PortId::new();
2601 let right_id = PortId::new();
2602 let left = node1.add_port(
2603 left_id,
2604 Address {
2605 node: node2.id(),
2606 port: right_id,
2607 },
2608 );
2609 let right = node2.add_port(
2610 right_id,
2611 Address {
2612 node: node1.id(),
2613 port: left_id,
2614 },
2615 );
2616 (left.into(), right.into())
2617 }
2618
2619 fn bmsg(data: &[u8]) -> SerializedMessage {
2620 SerializedMessage {
2621 data: data.into(),
2622 ..Default::default()
2623 }
2624 }
2625
2626 #[async_test]
2627 async fn test_remote(driver: DefaultDriver) {
2628 let (node, node2, _h) = new_two_node_mesh(&driver);
2629 {
2630 let (left, mut right) = new_remote_port_pair(&node.node, &node2.node);
2631 left.send(SerializedMessage {
2632 data: b"abc".to_vec(),
2633 ..Default::default()
2634 });
2635 assert_eq!(right.recv().await.unwrap().data, b"abc");
2636 }
2637 yield_once().await;
2638 assert!(node.node.is_empty());
2639 assert!(node2.node.is_empty());
2640 }
2641
2642 #[async_test]
2643 async fn test_send_port(driver: DefaultDriver) {
2644 let (node, node2, _h) = new_two_node_mesh(&driver);
2645 {
2646 let (left, mut right) = new_remote_port_pair(&node.node, &node2.node);
2647 let (left2, right2) = <Channel>::new_pair();
2648 left2.send(SerializedMessage {
2649 data: b"abc".to_vec(),
2650 ..Default::default()
2651 });
2652 left.send(SerializedMessage {
2653 resources: vec![Resource::Port(right2.into())],
2654 ..Default::default()
2655 });
2656 let r = right.recv().await.unwrap();
2657 let mut right2 =
2658 <Channel>::from(Port::try_from(r.resources.into_iter().next().unwrap()).unwrap());
2659 left2.send(SerializedMessage {
2660 data: b"def".to_vec(),
2661 ..Default::default()
2662 });
2663 assert_eq!(right2.recv().await.unwrap().data, b"abc");
2664 assert_eq!(right2.recv().await.unwrap().data, b"def");
2665 }
2666 yield_once().await;
2667 assert!(node.node.is_empty());
2668 assert!(node2.node.is_empty());
2669 }
2670
2671 #[async_test]
2672 async fn test_send_port_with_three_nodes(driver: DefaultDriver) {
2673 let (node, node2, node3, _h) = new_three_node_mesh(&driver);
2674 {
2675 let (left, mut right) = new_remote_port_pair(&node.node, &node2.node);
2676 let (left2, right2) = new_remote_port_pair(&node3.node, &node.node);
2677 left2.send(SerializedMessage {
2678 data: b"abc".to_vec(),
2679 ..Default::default()
2680 });
2681 left.send(SerializedMessage {
2682 resources: vec![Resource::Port(right2.into())],
2683 ..Default::default()
2684 });
2685 let r = right.recv().await.unwrap();
2686 let mut right2 =
2687 <Channel>::from(Port::try_from(r.resources.into_iter().next().unwrap()).unwrap());
2688 left2.send(SerializedMessage {
2689 data: b"def".to_vec(),
2690 ..Default::default()
2691 });
2692 assert_eq!(right2.recv().await.unwrap().data, b"abc");
2693 assert_eq!(right2.recv().await.unwrap().data, b"def");
2694 }
2695 yield_once().await;
2696 assert!(node.node.is_empty());
2697 assert!(node2.node.is_empty());
2698 assert!(node3.node.is_empty());
2699 }
2700
2701 #[async_test]
2702 async fn test_send_closed_port(driver: DefaultDriver) {
2703 let (node, node2, _h) = new_two_node_mesh(&driver);
2704 {
2705 let (left, mut right) = new_remote_port_pair(&node.node, &node2.node);
2706 let (left2, right2) = <Channel>::new_pair();
2707 drop(left2);
2708 left.send(SerializedMessage {
2709 resources: vec![Resource::Port(right2.into())],
2710 ..Default::default()
2711 });
2712 let r = right.recv().await.unwrap();
2713 let mut right2 =
2714 <Channel>::from(Port::try_from(r.resources.into_iter().next().unwrap()).unwrap());
2715 assert!(matches!(
2716 right2.try_recv().unwrap_err(),
2717 TryRecvError::Closed
2718 ));
2719 }
2720 yield_once().await;
2721 assert!(node.node.is_empty());
2722 assert!(node2.node.is_empty());
2723 }
2724
// A message queued before the sender is dropped is still delivered; the next
// receive reports the channel as closed.
#[test]
fn test_local_close() {
    let (tx, mut rx) = Channel::<_, ()>::new_pair();
    tx.send(bmsg(b"abc"));
    drop(tx);

    assert_eq!(rx.try_recv().unwrap().data, b"abc");

    let err = rx.try_recv().unwrap_err();
    assert!(matches!(err, TryRecvError::Closed));
}
2739
2740 #[async_test]
2741 async fn test_remote_close(driver: DefaultDriver) {
2742 let (node, node2, _h) = new_two_node_mesh(&driver);
2743 {
2744 let (left, mut right) = new_remote_port_pair(&node.node, &node2.node);
2745 left.send(SerializedMessage {
2746 data: b"abc".to_vec(),
2747 ..Default::default()
2748 });
2749 drop(left);
2750 assert_eq!(right.recv().await.unwrap().data, b"abc");
2751 assert!(matches!(
2752 right.try_recv().unwrap_err(),
2753 TryRecvError::Closed
2754 ));
2755 }
2756 yield_once().await;
2757 assert!(node.node.is_empty());
2758 assert!(node2.node.is_empty());
2759 }
2760
2761 #[async_test]
2762 async fn test_node_fail(driver: DefaultDriver) {
2763 let (node, node2, mut handles) = new_two_node_mesh(&driver);
2764 let (_left, mut right) = new_remote_port_pair(&node.node, &node2.node);
2765 handles.remove(1);
2766 assert!(matches!(
2767 right.try_recv().unwrap_err(),
2768 TryRecvError::Failed
2769 ));
2770 }
2771
2772 #[async_test]
2773 async fn test_send_failed_port(driver: DefaultDriver) {
2774 let (node, node2, node3, mut handles) = new_three_node_mesh(&driver);
2775 let (_left, right) = new_remote_port_pair(&node.node, &node2.node);
2776 let (left2, mut right2) = new_remote_port_pair(&node2.node, &node3.node);
2777 handles.remove(2);
2778 left2.send(SerializedMessage {
2779 resources: vec![Resource::Port(right.into())],
2780 ..Default::default()
2781 });
2782 let r = right2.recv().await.unwrap();
2783 let mut right =
2784 <Channel>::from(Port::try_from(r.resources.into_iter().next().unwrap()).unwrap());
2785 assert!(matches!(
2786 right.try_recv().unwrap_err(),
2787 TryRecvError::Failed
2788 ));
2789 }
2790
2791 #[async_test]
2792 async fn test_async(driver: DefaultDriver) {
2793 let (node, node2, _h) = new_two_node_mesh(&driver);
2794 let (left, mut right) = new_remote_port_pair(&node.node, &node2.node);
2795 let left = Arc::new(left);
2796 driver
2797 .spawn("test", {
2798 let left = left.clone();
2799 async move {
2800 left.send(SerializedMessage {
2801 data: b"abc".to_vec(),
2802 ..Default::default()
2803 });
2804 }
2805 })
2806 .detach();
2807 assert_eq!(right.recv().await.unwrap().data, b"abc");
2808 drop(left);
2809 }
2810
2811 #[async_test]
2812 async fn test_async_close(driver: DefaultDriver) {
2813 let (node, node2, _h) = new_two_node_mesh(&driver);
2814 let (left, mut right) = new_remote_port_pair(&node.node, &node2.node);
2815 driver
2816 .spawn("test", async move {
2817 drop(left);
2818 })
2819 .detach();
2820 assert!(matches!(right.recv().await.unwrap_err(), RecvError::Closed));
2821 }
2822
2823 #[async_test]
2824 async fn test_bridge_local() {
2825 let (p1, p2) = Channel::new_pair();
2826 let (p3, p4) = Channel::new_pair();
2827 test_bridge(p1, p2, p3, p4).await;
2828 }
2829
2830 #[async_test]
2831 async fn test_bridge_remote(driver: DefaultDriver) {
2832 let (node, node2, node3, _h) = new_three_node_mesh(&driver);
2833 let (p1, p2) = new_remote_port_pair(&node.node, &node2.node);
2834 let (p3, p4) = new_remote_port_pair(&node2.node, &node3.node);
2835 test_bridge(p1, p2, p3, p4).await;
2836 node.node.wait_for_ports(true).await;
2837 node2.node.wait_for_ports(true).await;
2838 node3.node.wait_for_ports(true).await;
2839 }
2840
2841 async fn test_bridge(p1: Channel, p2: Channel, mut p3: Channel, p4: Channel) {
2842 p1.send(bmsg(b"5"));
2843 p1.send(bmsg(b"6"));
2844 p1.send(bmsg(b"7"));
2845
2846 p2.send(bmsg(b"a"));
2847 p2.send(bmsg(b"b"));
2848
2849 p3.send(bmsg(b"1"));
2850 p3.send(bmsg(b"2"));
2851 p3.send(bmsg(b"3"));
2852 p3.send(bmsg(b"4"));
2853
2854 p4.send(bmsg(b"x"));
2855 p4.send(bmsg(b"y"));
2856 p4.send(bmsg(b"c"));
2857 p4.send(bmsg(b"d"));
2858 p4.send(bmsg(b"e"));
2859 p4.send(bmsg(b"f"));
2860 p4.send(bmsg(b"g"));
2861 p4.send(bmsg(b"h"));
2862
2863 p3.recv().await.unwrap();
2864 p3.recv().await.unwrap();
2865
2866 p2.bridge(p3);
2867
2868 p4.send(bmsg(b"i"));
2869 drop(p4);
2870
2871 let recv_all = async |mut p: Channel| {
2872 let mut v = Vec::new();
2873 loop {
2874 match p.recv().await {
2875 Ok(m) => v.push(m.data[0]),
2876 Err(RecvError::Closed) => break,
2877 Err(e) => return Err(e),
2878 }
2879 }
2880 Ok(v)
2881 };
2882
2883 assert_eq!(recv_all(p1).await.unwrap(), b"abcdefghi");
2884 }
2885
// Bridges the two ends of a single pair with each other; the test only
// requires that the call returns.
#[test]
fn test_bridge_self() {
    let (near, far) = Channel::<(), ()>::new_pair();
    near.bridge(far);
}
2892
2893 #[async_test]
2894 async fn test_fail_sent_port_to_failed_node(driver: DefaultDriver) {
2895 let (n1, n2, mut h) = new_two_node_mesh(&driver);
2896 let (p1, _p2) = new_remote_port_pair(&n1.node, &n2.node);
2897 let (mut p3, p4) = <Channel>::new_pair();
2898 p1.send(SerializedMessage {
2899 resources: vec![Resource::Port(p4.into())],
2900 ..Default::default()
2901 });
2902 h.remove(0);
2903 assert!(matches!(p3.recv().await.unwrap_err(), RecvError::Failed));
2904 }
2905
2906 #[async_test]
2907 async fn test_close_drop_port_with_queued_ports() {
2908 let (p1, p2) = Channel::<_, ()>::new_pair();
2909 let (mut p3, p4) = <Channel>::new_pair();
2910 p1.send(SerializedMessage {
2911 resources: vec![Resource::Port(p4.into())],
2912 ..Default::default()
2913 });
2914 drop(p2);
2915 assert!(matches!(p3.recv().await.unwrap_err(), RecvError::Closed));
2916 }
2917
2918 #[async_test]
2919 async fn test_close_send_port_to_dropped_port() {
2920 let (p1, p2) = Channel::<_, ()>::new_pair();
2921 let (mut p3, p4) = <Channel>::new_pair();
2922 drop(p2);
2923 p1.send(SerializedMessage {
2924 resources: vec![Resource::Port(p4.into())],
2925 ..Default::default()
2926 });
2927 assert!(matches!(p3.recv().await.unwrap_err(), RecvError::Closed));
2928 }
2929
2930 #[async_test]
2931 async fn test_change_sender_types() {
2932 let (p1, mut p2) = Channel::<u32, ()>::new_pair();
2933 let p1 = p1.change_types::<u64, ()>();
2934 p1.send(1);
2935 assert_eq!(p2.recv().await.unwrap(), 1);
2936 }
2937
2938 #[async_test]
2939 async fn test_change_receiver_types() {
2940 let (p1, p2) = Channel::<u32, ()>::new_pair();
2941 let mut p2 = p2.change_types::<(), u64>();
2942 p1.send(1);
2943 assert_eq!(p2.recv().await.unwrap(), 1);
2944 }
2945
2946 #[async_test]
2947 async fn test_change_both_types() {
2948 let (p1, p2) = Channel::<u32, ()>::new_pair();
2949 let p1 = p1.change_types::<u64, ()>();
2950 let mut p2 = p2.change_types::<(), u64>();
2951 p1.send(1);
2952 assert_eq!(p2.recv().await.unwrap(), 1);
2953 }
2954
2955 #[async_test]
2956 async fn test_change_from_generic() {
2957 let (p1, p2) = Channel::<SerializedMessage, SerializedMessage>::new_pair();
2958 let p1 = p1.change_types::<u64, ()>();
2959 let mut p2 = p2.change_types::<(), u32>();
2960 p1.send(1);
2961 assert_eq!(p2.recv().await.unwrap(), 1);
2962 }
2963
2964 #[async_test]
2965 async fn test_fail_port(driver: DefaultDriver) {
2966 #[derive(Debug, Error)]
2967 #[error("test failure")]
2968 struct ExplicitFailure;
2969
2970 let (node, node2, _h) = new_two_node_mesh(&driver);
2971 let (p1, mut p2) = new_remote_port_pair(&node.node, &node2.node);
2972 let p1 = Port::from(p1);
2973 p1.fail(NodeError::local(ExplicitFailure));
2974 let err = p2.recv().await.unwrap_err();
2975 assert!(matches!(err, RecvError::Failed));
2976 }
2977}