1use anyhow::Context;
7use futures::Stream;
8use futures::StreamExt;
9use futures::executor::block_on;
10use futures::stream::FusedStream;
11use futures_concurrency::stream::Merge;
12use inspect::Inspect;
13use mesh::MeshPayload;
14use mesh::error::RemoteError;
15use mesh::rpc::FailableRpc;
16use mesh::rpc::RpcSend;
17use std::fmt;
18use std::marker::PhantomData;
19use std::pin::Pin;
20use std::task::Poll;
21use std::thread;
22use unicycle::FuturesUnordered;
23
/// A typed identifier for a worker.
///
/// The identifier itself is a static string. The type parameter `T` is only
/// carried at the type level (see [`Worker::ID`], which ties an ID to the
/// worker's parameter type); no value of type `T` is ever stored.
///
/// `Clone`, `Copy`, and `Debug` are implemented by hand instead of derived:
/// a derive would require `T: Clone` / `T: Copy` / `T: Debug`, which would
/// make the ID non-copyable and non-printable for parameter types that do not
/// implement those traits, even though only a `PhantomData<T>` is held.
pub struct WorkerId<T>(&'static str, PhantomData<T>);

impl<T> Clone for WorkerId<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for WorkerId<T> {}

impl<T> fmt::Debug for WorkerId<T> {
    /// Renders the same tuple-struct layout a derived `Debug` would produce.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("WorkerId")
            .field(&self.0)
            .field(&self.1)
            .finish()
    }
}

impl<T> WorkerId<T> {
    /// Creates a worker ID from its string identifier.
    pub const fn new(id: &'static str) -> Self {
        Self(id, PhantomData)
    }

    /// Returns the string identifier.
    pub const fn id(&self) -> &'static str {
        self.0
    }
}

impl<T> fmt::Display for WorkerId<T> {
    /// Writes the string identifier, honoring width/alignment/fill flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.0)
    }
}
45
46pub trait Worker: 'static + Sized {
48 type Parameters: 'static + Send;
53
54 type State: 'static + MeshPayload + Send;
56
57 const ID: WorkerId<Self::Parameters>;
61
62 fn new(parameters: Self::Parameters) -> anyhow::Result<Self>;
67
68 fn restart(state: Self::State) -> anyhow::Result<Self>;
70
71 fn run(self, recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()>;
78}
79
80#[derive(Debug, MeshPayload)]
82#[mesh(bound = "T: 'static + MeshPayload + Send")]
83pub enum WorkerRpc<T> {
84 Stop,
86 Restart(FailableRpc<(), T>),
89 Inspect(inspect::Deferred),
91}
92
93#[derive(Debug, MeshPayload)]
94enum LaunchType {
95 New {
96 parameters: mesh::OwnedMessage,
97 },
98 Restart {
99 send: mesh::Sender<WorkerRpc<mesh::OwnedMessage>>,
100 events: mesh::Receiver<WorkerEvent>,
101 },
102}
103
104#[derive(Debug, MeshPayload)]
109pub struct WorkerHostRunner(mesh::Receiver<WorkerHostLaunchRequest>);
110
111impl WorkerHostRunner {
112 pub async fn run(mut self, factory: impl WorkerFactory) {
118 let mut rundown = FuturesUnordered::new();
119 loop {
120 let mut stream = ((&mut self.0).map(Some), (&mut rundown).map(|_| None)).merge();
121 let launch_params = match stream.next().await {
122 Some(Some(launch_params)) => launch_params,
123 Some(None) => continue,
124 None => break,
125 };
126
127 let _requestspan = tracing::info_span!("worker_host_launch_request").entered();
128
129 let result = factory.builder(&launch_params.name);
130 match result {
131 Ok(runner) => {
132 let (rundown_send, rundown_recv) = mesh::oneshot::<()>();
134 thread::Builder::new()
135 .name(format!("worker-{}", &launch_params.name))
136 .spawn(move || {
137 launch_params.request.launch(runner);
138 drop(rundown_send);
139 })
140 .expect("thread launch failed");
141
142 rundown.push(rundown_recv);
143 }
144 Err(err) => {
145 launch_params.request.fail(err);
147 }
148 }
149 }
150 }
151}
152
153#[derive(Debug, MeshPayload, Inspect)]
157pub struct WorkerHandle {
158 #[inspect(skip)]
159 name: String,
160 #[inspect(flatten, send = "WorkerRpc::Inspect")]
161 send: mesh::Sender<WorkerRpc<mesh::OwnedMessage>>,
162 #[inspect(skip)]
163 events: mesh::Receiver<WorkerEvent>,
164}
165
166impl Stream for WorkerHandle {
167 type Item = WorkerEvent;
168
169 fn poll_next(
170 mut self: Pin<&mut Self>,
171 cx: &mut std::task::Context<'_>,
172 ) -> Poll<Option<Self::Item>> {
173 Poll::Ready(match std::task::ready!(self.events.poll_recv(cx)) {
174 Ok(event) => Some(event),
175 Err(mesh::RecvError::Error(err)) => Some(WorkerEvent::Failed(RemoteError::new(err))),
176 Err(mesh::RecvError::Closed) => None,
177 })
178 }
179}
180
181impl FusedStream for WorkerHandle {
182 fn is_terminated(&self) -> bool {
183 self.events.is_terminated()
184 }
185}
186
187#[derive(Debug, MeshPayload)]
189pub enum WorkerEvent {
190 Stopped,
192 Failed(RemoteError),
194 Started,
196 RestartFailed(RemoteError),
198}
199
200impl WorkerHandle {
201 pub fn stop(&mut self) {
203 self.send.send(WorkerRpc::Stop);
204 }
205
206 pub async fn join(&mut self) -> anyhow::Result<()> {
208 while let Some(event) = self.next().await {
209 if let WorkerEvent::Failed(err) = event {
210 return Err(err.into());
211 }
212 }
213 Ok(())
214 }
215
216 pub fn restart(&mut self, host: &WorkerHost) {
221 let (send, recv) = mesh::channel();
222 let (events_send, events) = mesh::channel();
223
224 let send = std::mem::replace(&mut self.send, send);
225 let events = std::mem::replace(&mut self.events, events);
226
227 host.launch_worker_internal(
229 &self.name,
230 recv,
231 events_send,
232 LaunchType::Restart { send, events },
233 );
234 }
235}
236
237#[derive(Debug, MeshPayload, Clone)]
244pub struct WorkerHost(mesh::Sender<WorkerHostLaunchRequest>);
245
246pub fn worker_host() -> (WorkerHost, WorkerHostRunner) {
280 let (send, recv) = mesh::mpsc_channel();
281 (WorkerHost(send), WorkerHostRunner(recv))
282}
283
284impl WorkerHost {
285 pub fn start_worker<T>(&self, id: WorkerId<T>, params: T) -> anyhow::Result<WorkerHandle>
291 where
292 T: 'static + MeshPayload + Send,
293 {
294 self.start_worker_inner(id.id(), mesh::OwnedMessage::new(params))
295 }
296
297 fn start_worker_inner(
298 &self,
299 id: &str,
300 parameters: mesh::OwnedMessage,
301 ) -> anyhow::Result<WorkerHandle> {
302 let (events_send, events_recv) = mesh::channel();
303 let (rpc_send, rpc_recv) = mesh::channel();
304 self.launch_worker_internal(id, rpc_recv, events_send, LaunchType::New { parameters });
305 Ok(WorkerHandle {
306 name: id.to_string(),
307 send: rpc_send,
308 events: events_recv,
309 })
310 }
311
312 pub async fn launch_worker<T>(&self, id: WorkerId<T>, params: T) -> anyhow::Result<WorkerHandle>
315 where
316 T: 'static + MeshPayload + Send,
317 {
318 let mut handle = self.start_worker_inner(id.id(), mesh::OwnedMessage::new(params))?;
319 match handle.next().await.context("failed to launch worker")? {
320 WorkerEvent::Started => Ok(handle),
321 WorkerEvent::Failed(err) => Err(err).context("failed to launch worker")?,
322 WorkerEvent::Stopped | WorkerEvent::RestartFailed(_) => {
323 anyhow::bail!("received invalid worker event")
324 }
325 }
326 }
327
328 fn launch_worker_internal(
329 &self,
330 id: &str,
331 rpc_recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
332 events_send: mesh::Sender<WorkerEvent>,
333 launch_type: LaunchType,
334 ) {
335 let request = WorkerHostLaunchRequest {
336 name: id.to_string(),
337 request: WorkerLaunchRequest {
338 rpc: rpc_recv,
339 events: events_send,
340 launch_type,
341 },
342 };
343
344 self.0.send(request);
345 }
346}
347
348pub async fn launch_local_worker<T: Worker>(
364 parameters: T::Parameters,
365) -> anyhow::Result<WorkerHandle> {
366 let (rpc_send, rpc_recv) = mesh::channel();
367 let (events_send, events_recv) = mesh::channel();
368 let (result_send, result_recv) = mesh::oneshot();
369
370 thread::Builder::new()
371 .name(format!("worker-{}", &T::ID.id()))
372 .spawn(move || match T::new(parameters) {
373 Ok(worker) => {
374 result_send.send(Ok(()));
375 match worker.run(rpc_recv) {
376 Ok(()) => {
377 events_send.send(WorkerEvent::Stopped);
378 }
379 Err(err) => {
380 events_send.send(WorkerEvent::Failed(RemoteError::new(err)));
381 }
382 }
383 }
384 Err(err) => {
385 result_send.send(Err(err));
386 }
387 })
388 .expect("thread launch failed");
389
390 result_recv.await.unwrap()?;
391 Ok(WorkerHandle {
392 name: T::ID.id().to_owned(),
393 send: mesh::local_node::Port::from(rpc_send).into(),
395 events: events_recv,
396 })
397}
398
399pub trait WorkerFactory: 'static + Send + Sync {
405 fn builder(&self, name: &str) -> anyhow::Result<WorkerBuilder>;
407}
408
409#[derive(Debug, MeshPayload)]
410struct WorkerLaunchRequest {
411 rpc: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
412 events: mesh::Sender<WorkerEvent>,
413 launch_type: LaunchType,
414}
415
416impl WorkerLaunchRequest {
417 fn fail(self, err: anyhow::Error) {
418 match self.launch_type {
419 LaunchType::New { .. } => {
420 self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
421 }
422 LaunchType::Restart { send, events } => {
423 self.events
425 .send(WorkerEvent::RestartFailed(RemoteError::new(err)));
426 self.rpc.bridge(send);
427 self.events.bridge(events);
428 }
429 }
430 }
431
432 fn launch(self, builder: WorkerBuilder) {
433 let worker = match self.launch_type {
434 LaunchType::New { parameters } => {
435 let _span =
436 tracing::info_span!("worker_new", name = builder.id, action = "new").entered();
437 match builder.build_and_run(BuildRequest::New(parameters)) {
438 Ok(worker) => worker,
439 Err(err) => {
440 self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
441 return;
442 }
443 }
444 }
445 LaunchType::Restart { send, events } => {
446 let state_recv = send.call_failable(WorkerRpc::Restart, ());
447 let state = match block_on(state_recv) {
448 Ok(state) => state,
449 Err(err) => {
450 self.events
451 .send(WorkerEvent::RestartFailed(RemoteError::new(err)));
452 self.events.bridge(events);
454 self.rpc.bridge(send);
455 return;
456 }
457 };
458 let _span =
459 tracing::info_span!("worker_new", name = builder.id, action = "restart")
460 .entered();
461 match builder.build_and_run(BuildRequest::Restart(state)) {
462 Ok(worker) => worker,
463 Err(err) => {
464 self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
465 return;
466 }
467 }
468 }
469 };
470
471 self.events.send(WorkerEvent::Started);
472 match worker.run(self.rpc) {
473 Ok(()) => {
474 self.events.send(WorkerEvent::Stopped);
475 }
476 Err(err) => {
477 self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
478 }
479 }
480 }
481}
482
483pub struct WorkerBuilder {
485 inner: Box<dyn WorkerBuildAndRun>,
486 id: &'static str,
487}
488
489impl WorkerBuilder {
490 pub fn new<T: Worker>() -> Self
492 where
493 T::Parameters: MeshPayload,
494 {
495 Self {
496 inner: Box::new(BuilderInner::<T>(PhantomData)),
497 id: T::ID.id(),
498 }
499 }
500
501 fn build_and_run(self, request: BuildRequest) -> anyhow::Result<Box<dyn Run>> {
502 self.inner.build_and_run(request)
503 }
504}
505
506#[doc(hidden)]
507pub enum BuildRequest {
508 New(mesh::OwnedMessage),
509 Restart(mesh::OwnedMessage),
510}
511
512struct BuilderInner<T: Worker>(PhantomData<fn() -> T>);
513
514trait WorkerBuildAndRun: Send {
515 fn build_and_run(self: Box<Self>, request: BuildRequest) -> anyhow::Result<Box<dyn Run>>;
516}
517
518trait Run {
519 fn run(
520 self: Box<Self>,
521 recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
522 ) -> anyhow::Result<()>;
523}
524
525impl<T: Worker> Run for T {
526 fn run(
527 self: Box<Self>,
528 recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
529 ) -> anyhow::Result<()> {
530 let recv = mesh::local_node::Port::from(recv).into();
532 Worker::run(*self, recv)
533 }
534}
535
536impl<T: Worker> WorkerBuildAndRun for BuilderInner<T>
537where
538 T::Parameters: MeshPayload,
539{
540 fn build_and_run(self: Box<Self>, request: BuildRequest) -> anyhow::Result<Box<dyn Run>> {
541 let worker = match request {
542 BuildRequest::New(parameters) => {
543 T::new(parameters.parse().context("failed to receive parameters")?)
544 }
545 BuildRequest::Restart(state) => T::restart(
546 state
548 .serialize()
549 .into_message()
550 .context("failed to parse restart state")?,
551 ),
552 }?;
553
554 Ok(Box::new(worker))
555 }
556}
557
/// Declares a unit struct `$name` implementing `WorkerFactory` for a fixed
/// list of worker types.
///
/// Each listed type may carry attributes (for example `#[cfg(...)]`), which
/// are applied to that type's lookup block. Looking up a name that matches
/// none of the listed workers' IDs returns an "unsupported worker" error.
#[macro_export]
macro_rules! runnable_workers {
    (
        $name:ident {
            $($(#[$attr:meta])* $worker:ty),*$(,)?
        }
    ) => {

        #[derive(Debug, Clone)]
        struct $name;

        impl $crate::WorkerFactory for $name {
            fn builder(&self, name: &str) -> anyhow::Result<$crate::WorkerBuilder> {
                // One lookup block per listed worker, in declaration order.
                $(
                    $(#[$attr])*
                    {
                        if name == <$worker as $crate::Worker>::ID.id() {
                            return Ok($crate::WorkerBuilder::new::<$worker>());
                        }
                    }
                )*

                anyhow::bail!("unsupported worker {name}")
            }
        }
    };
}
609
610#[doc(hidden)]
611pub mod private {
612 #![expect(unsafe_code)]
614
615 use super::RegisteredWorkers;
616 use super::WorkerFactory;
617 use crate::Worker;
618 use crate::WorkerBuilder;
619 pub use linkme;
620 use mesh::MeshPayload;
621
622 #[linkme::distributed_slice]
627 pub static WORKERS: [Option<&'static RegisteredWorker>] = [..];
628
629 #[linkme::distributed_slice(WORKERS)]
633 static WORKAROUND: Option<&'static RegisteredWorker> = None;
634
635 pub struct RegisteredWorker {
636 id: &'static str,
637 build: fn() -> WorkerBuilder,
638 }
639
640 impl RegisteredWorker {
641 pub const fn new<T: Worker>() -> Self
642 where
643 T::Parameters: MeshPayload,
644 {
645 Self {
646 id: T::ID.id(),
647 build: WorkerBuilder::new::<T>,
648 }
649 }
650 }
651
652 impl WorkerFactory for RegisteredWorkers {
653 fn builder(&self, name: &str) -> anyhow::Result<WorkerBuilder> {
654 for worker in WORKERS.iter().flatten() {
655 if worker.id == name {
656 return Ok((worker.build)());
657 }
658 }
659 anyhow::bail!("unsupported worker {name}")
660 }
661 }
662
663 #[macro_export]
670 macro_rules! register_workers {
671 {} => {};
672 { $($(#[$attr:meta])* $worker:ty),+ $(,)? } => {
673 $(
674 $(#[$attr])*
675 const _: () = {
676 use $crate::private;
677 use private::linkme;
678
679 #[linkme::distributed_slice(private::WORKERS)]
680 #[linkme(crate = linkme)]
681 static WORKER: Option<&'static private::RegisteredWorker> = Some(&private::RegisteredWorker::new::<$worker>());
682 };
683 )*
684 };
685 }
686}
687
/// A `WorkerFactory` backed by the workers registered at link time with
/// `register_workers!`.
#[derive(Debug, Clone)]
pub struct RegisteredWorkers;
708
709#[derive(Debug, MeshPayload)]
711struct WorkerHostLaunchRequest {
712 name: String,
714 request: WorkerLaunchRequest,
716}
717
#[cfg(test)]
mod tests {
    use super::Worker;
    use super::WorkerFactory;
    use super::WorkerId;
    use super::WorkerRpc;
    use crate::launch_local_worker;
    use crate::worker::WorkerEvent;
    use futures::StreamExt;
    use futures::executor::block_on;
    use mesh::MeshPayload;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use test_with_tracing::test;

    /// A worker carrying a single value that survives a restart.
    struct TestWorker {
        value: u64,
    }

    #[derive(MeshPayload, Default)]
    struct TestWorkerConfig {
        pub value: u64,
    }

    #[derive(MeshPayload)]
    struct TestWorkerState {
        pub value: u64,
    }

    impl Worker for TestWorker {
        type Parameters = TestWorkerConfig;
        type State = TestWorkerState;
        const ID: WorkerId<Self::Parameters> = WorkerId::new("TestWorker");

        fn new(parameters: Self::Parameters) -> anyhow::Result<Self> {
            let TestWorkerConfig { value } = parameters;
            Ok(Self { value })
        }

        fn restart(state: Self::State) -> anyhow::Result<Self> {
            let TestWorkerState { value } = state;
            Ok(Self { value })
        }

        fn run(self, mut recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            block_on(async {
                // Serve requests until told to stop or restart, or until the
                // channel goes away.
                loop {
                    match recv.recv().await {
                        Ok(WorkerRpc::Stop) | Err(_) => break,
                        Ok(WorkerRpc::Restart(rpc)) => {
                            rpc.complete(Ok(TestWorkerState { value: self.value }));
                            break;
                        }
                        Ok(WorkerRpc::Inspect(_deferred)) => (),
                    }
                }
                Ok(())
            })
        }
    }

    /// A stateless worker that is deliberately left out of `RunnableWorkers1`.
    struct TestWorker2;

    impl Worker for TestWorker2 {
        type Parameters = ();
        type State = ();
        const ID: WorkerId<Self::Parameters> = WorkerId::new("TestWorker2");

        fn new(_parameters: Self::Parameters) -> anyhow::Result<Self> {
            Ok(Self)
        }

        fn restart(_state: ()) -> anyhow::Result<Self> {
            Ok(Self)
        }

        fn run(self, mut recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            block_on(async {
                loop {
                    match recv.recv().await {
                        Ok(WorkerRpc::Stop) | Err(_) => break,
                        Ok(WorkerRpc::Restart(rpc)) => {
                            rpc.complete(Ok(()));
                            break;
                        }
                        Ok(WorkerRpc::Inspect(_deferred)) => (),
                    }
                }
                Ok(())
            })
        }
    }

    runnable_workers! {
        RunnableWorkers1 {
            TestWorker,
        }
    }

    /// An unknown name must be rejected by the generated factory.
    #[test]
    fn test_runnable_workers_unsupported() {
        let lookup = RunnableWorkers1.builder("foobar");

        assert!(lookup.is_err());
    }

    /// A worker known to the factory launches successfully through the host.
    #[async_test]
    async fn test_launch_worker_remote_supported_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        let runner_task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let launched = host.launch_worker(TestWorker::ID, Default::default()).await;

        assert!(launched.is_ok());
        drop(launched.unwrap());
        // Closing the host lets the runner finish.
        drop(host);
        runner_task.await;
    }

    /// A worker unknown to the factory fails to launch.
    #[async_test]
    async fn test_launch_worker_remote_unsupported_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        let runner_task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let launched = host.launch_worker(TestWorker2::ID, ()).await;

        assert!(launched.is_err());
        drop(host);
        runner_task.await;
    }

    /// Restarting a running worker yields `Started`, and a later stop yields
    /// `Stopped` followed by the end of the event stream.
    #[async_test]
    async fn test_launch_worker_remote_restart_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        let runner_task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let launched = host.launch_worker(TestWorker::ID, Default::default()).await;

        let mut handle = launched.expect("worker launch failed");
        handle.restart(&host);
        assert!(matches!(handle.next().await.unwrap(), WorkerEvent::Started));
        handle.stop();
        assert!(matches!(handle.next().await.unwrap(), WorkerEvent::Stopped));

        assert!(handle.next().await.is_none());

        drop(host);
        runner_task.await;
    }

    /// A worker whose parameters are a plain function pointer, which is
    /// invoked during construction.
    struct LocalWorker;

    impl Worker for LocalWorker {
        type Parameters = fn() -> anyhow::Result<()>;
        type State = ();
        const ID: WorkerId<Self::Parameters> = WorkerId::new("local");

        fn new(parameters: Self::Parameters) -> anyhow::Result<Self> {
            parameters()?;
            Ok(Self)
        }

        fn restart(_state: Self::State) -> anyhow::Result<Self> {
            unreachable!()
        }

        fn run(self, _recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            Ok(())
        }
    }

    /// A local worker can be launched and joined without a worker host.
    #[async_test]
    async fn test_launch_local_no_mesh() {
        let mut handle = launch_local_worker::<LocalWorker>(|| Ok(())).await.unwrap();
        handle.join().await.unwrap();
    }
}
909
910#[doc(hidden)]
912pub mod test_support {
913 use std::marker::PhantomData;
914
915 use crate::Worker;
916 use crate::WorkerId;
917 use crate::WorkerRpc;
918
919 pub struct DummyWorker<T>(PhantomData<T>);
921
922 pub const DUMMY_WORKER: WorkerId<()> = WorkerId::new("DummyWorker");
923
924 impl<T: 'static + Send> Worker for DummyWorker<T> {
925 type Parameters = ();
926 type State = ();
927 const ID: WorkerId<Self::Parameters> = DUMMY_WORKER;
928
929 fn new(_parameters: Self::Parameters) -> anyhow::Result<Self> {
930 Ok(Self(PhantomData))
931 }
932
933 fn restart(_state: Self::State) -> anyhow::Result<Self> {
934 Ok(Self(PhantomData))
935 }
936
937 fn run(self, _recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
938 todo!()
939 }
940 }
941}