1use anyhow::Context;
7use futures::Stream;
8use futures::StreamExt;
9use futures::executor::block_on;
10use futures::stream::FusedStream;
11use futures_concurrency::stream::Merge;
12use inspect::Inspect;
13use mesh::MeshPayload;
14use mesh::error::RemoteError;
15use mesh::rpc::FailableRpc;
16use mesh::rpc::RpcSend;
17use std::fmt;
18use std::marker::PhantomData;
19use std::pin::Pin;
20use std::task::Poll;
21use std::thread;
22use unicycle::FuturesUnordered;
23
/// A typed identifier for a worker.
///
/// The type parameter `T` is the worker's parameter type (see [`Worker::ID`]).
/// This lets [`WorkerHost::start_worker`] and [`WorkerHost::launch_worker`]
/// accept only parameters of the matching type.
///
/// `Clone`, `Copy`, and `Debug` are implemented by hand rather than derived.
/// `#[derive]` would add `T: Clone` / `T: Copy` / `T: Debug` bounds even though
/// only a `PhantomData<T>` is stored. That would make IDs for non-`Copy`
/// parameter types impossible to copy or clone.
pub struct WorkerId<T>(&'static str, PhantomData<T>);

impl<T> Clone for WorkerId<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for WorkerId<T> {}

impl<T> fmt::Debug for WorkerId<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Produces the same output the derived implementation did.
        f.debug_tuple("WorkerId")
            .field(&self.0)
            .field(&self.1)
            .finish()
    }
}

impl<T> WorkerId<T> {
    /// Creates a worker ID from its string name.
    pub const fn new(id: &'static str) -> Self {
        Self(id, PhantomData)
    }

    /// Returns the worker's string name.
    pub const fn id(&self) -> &'static str {
        self.0
    }
}

impl<T> fmt::Display for WorkerId<T> {
    /// Writes the ID's name, honoring width/alignment/fill flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.0)
    }
}
45
/// A worker that can be launched through a [`WorkerHost`] or with
/// [`launch_local_worker`].
pub trait Worker: 'static + Sized {
    /// The parameters passed to [`Worker::new`].
    ///
    /// Launching through a [`WorkerHost`] additionally requires these to be
    /// `MeshPayload` (see [`WorkerHost::start_worker`] and
    /// [`WorkerBuilder::new`]); [`launch_local_worker`] does not.
    type Parameters: 'static + Send;

    /// State handed from a running worker to its replacement on restart
    /// (see [`WorkerRpc::Restart`] and [`Worker::restart`]).
    type State: 'static + MeshPayload + Send;

    /// The worker's identifier. Worker factories match launch requests
    /// against its string name.
    const ID: WorkerId<Self::Parameters>;

    /// Creates a new worker from its parameters.
    fn new(parameters: Self::Parameters) -> anyhow::Result<Self>;

    /// Creates a worker from the state returned by a previous instance in
    /// response to [`WorkerRpc::Restart`].
    fn restart(state: Self::State) -> anyhow::Result<Self>;

    /// Runs the worker until it exits, servicing requests from `recv`.
    ///
    /// The launcher reports `Ok(())` as [`WorkerEvent::Stopped`] and an error
    /// as [`WorkerEvent::Failed`].
    fn run(self, recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()>;
}
79
/// Requests sent from a [`WorkerHandle`] to a running worker.
#[derive(Debug, MeshPayload)]
#[mesh(bound = "T: 'static + MeshPayload + Send")]
pub enum WorkerRpc<T> {
    /// Asks the worker to stop. Sent by [`WorkerHandle::stop`].
    Stop,
    /// Requests the worker's state so a replacement can be built with
    /// [`Worker::restart`]. NOTE(review): the worker is presumably expected to
    /// exit after completing this (the workers in this file's tests do).
    Restart(FailableRpc<(), T>),
    /// Forwards an inspection request to the worker. Sent by the [`Inspect`]
    /// implementation of [`WorkerHandle`].
    Inspect(inspect::Deferred),
}
92
/// How a launch request should create its worker.
#[derive(Debug, MeshPayload)]
enum LaunchType {
    /// Build a new worker from serialized [`Worker::Parameters`].
    New {
        parameters: mesh::OwnedMessage,
    },
    /// Replace an existing worker, taking over its state.
    Restart {
        /// RPC sender to the existing worker. It is used to request the
        /// worker's state; if the restart fails, the new RPC receiver is
        /// bridged to it.
        send: mesh::Sender<WorkerRpc<mesh::OwnedMessage>>,
        /// Event receiver of the existing worker. If the restart fails, the
        /// new event sender is bridged to it.
        events: mesh::Receiver<WorkerEvent>,
    },
}
103
/// The receiving side of a [`WorkerHost`], created by [`worker_host`].
///
/// Call [`WorkerHostRunner::run`] to service launch requests.
#[derive(Debug, MeshPayload)]
pub struct WorkerHostRunner(mesh::Receiver<WorkerHostLaunchRequest>);
110
impl WorkerHostRunner {
    /// Services launch requests from the associated [`WorkerHost`]s, using
    /// `factory` to look up each requested worker by name.
    ///
    /// Each worker whose builder is found runs on its own thread, named
    /// `worker-<name>`. If `factory.builder` returns an error, the failure is
    /// reported on the requester's event channel instead.
    ///
    /// Returns when the merged stream of launch requests and thread
    /// completions ends. The launch channel ends once every `WorkerHost`
    /// sender is dropped. NOTE(review): this presumably also waits for all
    /// spawned worker threads to exit. That relies on unicycle's
    /// `FuturesUnordered` ending only once it is empty — confirm.
    pub async fn run(mut self, factory: impl WorkerFactory) {
        // One receiver per spawned thread. Each resolves when its thread drops
        // the paired sender after `launch` returns.
        let mut rundown = FuturesUnordered::new();
        loop {
            // The merged stream is rebuilt every iteration so its borrow of
            // `rundown` ends before a new receiver is pushed below.
            // `Some(request)` is a launch request; `None` means a thread ended.
            let mut stream = ((&mut self.0).map(Some), (&mut rundown).map(|_| None)).merge();
            let launch_params = match stream.next().await {
                Some(Some(launch_params)) => launch_params,
                Some(None) => continue,
                None => break,
            };

            let _requestspan = tracing::info_span!("worker_host_launch_request").entered();

            let result = factory.builder(&launch_params.name);
            match result {
                Ok(runner) => {
                    let (rundown_send, rundown_recv) = mesh::oneshot::<()>();
                    thread::Builder::new()
                        .name(format!("worker-{}", &launch_params.name))
                        .spawn(move || {
                            launch_params.request.launch(runner);
                            // Signals the rundown receiver that this thread is done.
                            drop(rundown_send);
                        })
                        .expect("thread launch failed");

                    rundown.push(rundown_recv);
                }
                Err(err) => {
                    // Unknown worker: report it on the requester's event channel.
                    launch_params.request.fail(err);
                }
            }
        }
    }
}
152
/// A handle to a launched worker.
///
/// The handle is a [`Stream`] of [`WorkerEvent`]s describing the worker's
/// lifecycle. It can also stop, restart, or inspect the worker.
#[derive(Debug, MeshPayload)]
pub struct WorkerHandle {
    /// The worker's ID string; resent to the host when restarting.
    name: String,
    /// Sends requests to the worker.
    send: mesh::Sender<WorkerRpc<mesh::OwnedMessage>>,
    /// Receives lifecycle events from the worker.
    events: mesh::Receiver<WorkerEvent>,
}
162
163impl Inspect for WorkerHandle {
164 fn inspect(&self, req: inspect::Request<'_>) {
165 self.send.send(WorkerRpc::Inspect(req.defer()))
166 }
167}
168
169impl Stream for WorkerHandle {
170 type Item = WorkerEvent;
171
172 fn poll_next(
173 mut self: Pin<&mut Self>,
174 cx: &mut std::task::Context<'_>,
175 ) -> Poll<Option<Self::Item>> {
176 Poll::Ready(match std::task::ready!(self.events.poll_recv(cx)) {
177 Ok(event) => Some(event),
178 Err(mesh::RecvError::Error(err)) => Some(WorkerEvent::Failed(RemoteError::new(err))),
179 Err(mesh::RecvError::Closed) => None,
180 })
181 }
182}
183
184impl FusedStream for WorkerHandle {
185 fn is_terminated(&self) -> bool {
186 self.events.is_terminated()
187 }
188}
189
/// Lifecycle events reported on a [`WorkerHandle`].
#[derive(Debug, MeshPayload)]
pub enum WorkerEvent {
    /// The worker's `run` returned `Ok(())`.
    Stopped,
    /// The worker failed. Causes include:
    /// - the worker could not be looked up or built;
    /// - its `run` returned an error;
    /// - the event channel reported an error.
    Failed(RemoteError),
    /// The worker was built and is about to run.
    Started,
    /// A restart could not be performed. The handle's new channels are
    /// bridged back to the previous worker's channels.
    RestartFailed(RemoteError),
}
202
203impl WorkerHandle {
204 pub fn stop(&mut self) {
206 self.send.send(WorkerRpc::Stop);
207 }
208
209 pub async fn join(&mut self) -> anyhow::Result<()> {
211 while let Some(event) = self.next().await {
212 if let WorkerEvent::Failed(err) = event {
213 return Err(err.into());
214 }
215 }
216 Ok(())
217 }
218
219 pub fn restart(&mut self, host: &WorkerHost) {
224 let (send, recv) = mesh::channel();
225 let (events_send, events) = mesh::channel();
226
227 let send = std::mem::replace(&mut self.send, send);
228 let events = std::mem::replace(&mut self.events, events);
229
230 host.launch_worker_internal(
232 &self.name,
233 recv,
234 events_send,
235 LaunchType::Restart { send, events },
236 );
237 }
238}
239
/// A handle for launching workers on a [`WorkerHostRunner`].
///
/// Created by [`worker_host`]. It is `Clone` and `MeshPayload`, so it can be
/// shared and sent over mesh channels.
#[derive(Debug, MeshPayload, Clone)]
pub struct WorkerHost(mesh::Sender<WorkerHostLaunchRequest>);
248
249pub fn worker_host() -> (WorkerHost, WorkerHostRunner) {
283 let (send, recv) = mesh::mpsc_channel();
284 (WorkerHost(send), WorkerHostRunner(recv))
285}
286
287impl WorkerHost {
288 pub fn start_worker<T>(&self, id: WorkerId<T>, params: T) -> anyhow::Result<WorkerHandle>
294 where
295 T: 'static + MeshPayload + Send,
296 {
297 self.start_worker_inner(id.id(), mesh::OwnedMessage::new(params))
298 }
299
300 fn start_worker_inner(
301 &self,
302 id: &str,
303 parameters: mesh::OwnedMessage,
304 ) -> anyhow::Result<WorkerHandle> {
305 let (events_send, events_recv) = mesh::channel();
306 let (rpc_send, rpc_recv) = mesh::channel();
307 self.launch_worker_internal(id, rpc_recv, events_send, LaunchType::New { parameters });
308 Ok(WorkerHandle {
309 name: id.to_string(),
310 send: rpc_send,
311 events: events_recv,
312 })
313 }
314
315 pub async fn launch_worker<T>(&self, id: WorkerId<T>, params: T) -> anyhow::Result<WorkerHandle>
318 where
319 T: 'static + MeshPayload + Send,
320 {
321 let mut handle = self.start_worker_inner(id.id(), mesh::OwnedMessage::new(params))?;
322 match handle.next().await.context("failed to launch worker")? {
323 WorkerEvent::Started => Ok(handle),
324 WorkerEvent::Failed(err) => Err(err).context("failed to launch worker")?,
325 WorkerEvent::Stopped | WorkerEvent::RestartFailed(_) => {
326 anyhow::bail!("received invalid worker event")
327 }
328 }
329 }
330
331 fn launch_worker_internal(
332 &self,
333 id: &str,
334 rpc_recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
335 events_send: mesh::Sender<WorkerEvent>,
336 launch_type: LaunchType,
337 ) {
338 let request = WorkerHostLaunchRequest {
339 name: id.to_string(),
340 request: WorkerLaunchRequest {
341 rpc: rpc_recv,
342 events: events_send,
343 launch_type,
344 },
345 };
346
347 self.0.send(request);
348 }
349}
350
351pub async fn launch_local_worker<T: Worker>(
367 parameters: T::Parameters,
368) -> anyhow::Result<WorkerHandle> {
369 let (rpc_send, rpc_recv) = mesh::channel();
370 let (events_send, events_recv) = mesh::channel();
371 let (result_send, result_recv) = mesh::oneshot();
372
373 thread::Builder::new()
374 .name(format!("worker-{}", &T::ID.id()))
375 .spawn(move || match T::new(parameters) {
376 Ok(worker) => {
377 result_send.send(Ok(()));
378 match worker.run(rpc_recv) {
379 Ok(()) => {
380 events_send.send(WorkerEvent::Stopped);
381 }
382 Err(err) => {
383 events_send.send(WorkerEvent::Failed(RemoteError::new(err)));
384 }
385 }
386 }
387 Err(err) => {
388 result_send.send(Err(err));
389 }
390 })
391 .expect("thread launch failed");
392
393 result_recv.await.unwrap()?;
394 Ok(WorkerHandle {
395 name: T::ID.id().to_owned(),
396 send: mesh::local_node::Port::from(rpc_send).into(),
398 events: events_recv,
399 })
400}
401
/// Looks up workers by name for [`WorkerHostRunner::run`].
///
/// Implemented by the type that [`runnable_workers!`] generates and by
/// [`RegisteredWorkers`].
pub trait WorkerFactory: 'static + Send + Sync {
    /// Returns a builder for the worker named `name`, or an error if the worker
    /// is not supported.
    fn builder(&self, name: &str) -> anyhow::Result<WorkerBuilder>;
}
411
/// The channels and launch instructions for a single worker.
#[derive(Debug, MeshPayload)]
struct WorkerLaunchRequest {
    /// Receives requests from the worker's handle; passed to the worker's `run`.
    rpc: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
    /// Sends lifecycle events to the worker's handle.
    events: mesh::Sender<WorkerEvent>,
    /// Whether to build a new worker or restart an existing one.
    launch_type: LaunchType,
}
418
impl WorkerLaunchRequest {
    /// Reports that the worker could not be launched, e.g. because the
    /// factory returned an error for its name.
    ///
    /// - For a new worker, sends [`WorkerEvent::Failed`].
    /// - For a restart, sends [`WorkerEvent::RestartFailed`] and bridges the
    ///   handle's new channels to the previous worker's channels, so the
    ///   handle stays connected to the previous worker.
    fn fail(self, err: anyhow::Error) {
        match self.launch_type {
            LaunchType::New { .. } => {
                self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
            }
            LaunchType::Restart { send, events } => {
                self.events
                    .send(WorkerEvent::RestartFailed(RemoteError::new(err)));
                self.rpc.bridge(send);
                self.events.bridge(events);
            }
        }
    }

    /// Builds the worker with `builder` and runs it on the current thread,
    /// reporting progress on `self.events`.
    ///
    /// Event sequence:
    /// - [`WorkerEvent::Started`] once the worker is built;
    /// - then [`WorkerEvent::Stopped`] or [`WorkerEvent::Failed`], depending on
    ///   the result of `run`.
    ///
    /// Build failures are reported as [`WorkerEvent::Failed`]. Failing to obtain
    /// the previous worker's state during a restart is reported as
    /// [`WorkerEvent::RestartFailed`].
    fn launch(self, builder: WorkerBuilder) {
        let worker = match self.launch_type {
            LaunchType::New { parameters } => {
                let _span =
                    tracing::info_span!("worker_new", name = builder.id, action = "new").entered();
                match builder.build_and_run(BuildRequest::New(parameters)) {
                    Ok(worker) => worker,
                    Err(err) => {
                        self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
                        return;
                    }
                }
            }
            LaunchType::Restart { send, events } => {
                // Ask the previous worker for its state. Blocking is acceptable
                // here: in this file, `launch` is called on the dedicated thread
                // spawned by `WorkerHostRunner::run`.
                let state_recv = send.call_failable(WorkerRpc::Restart, ());
                let state = match block_on(state_recv) {
                    Ok(state) => state,
                    Err(err) => {
                        self.events
                            .send(WorkerEvent::RestartFailed(RemoteError::new(err)));
                        // Reconnect the handle's new channels to the previous worker.
                        self.events.bridge(events);
                        self.rpc.bridge(send);
                        return;
                    }
                };
                // On success, `events` (the previous worker's event receiver) is
                // dropped at the end of this arm. Events the previous worker sends
                // after handing over its state are therefore discarded.
                let _span =
                    tracing::info_span!("worker_new", name = builder.id, action = "restart")
                        .entered();
                match builder.build_and_run(BuildRequest::Restart(state)) {
                    Ok(worker) => worker,
                    Err(err) => {
                        self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
                        return;
                    }
                }
            }
        };

        self.events.send(WorkerEvent::Started);
        match worker.run(self.rpc) {
            Ok(()) => {
                self.events.send(WorkerEvent::Stopped);
            }
            Err(err) => {
                self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
            }
        }
    }
}
485
486pub struct WorkerBuilder {
488 inner: Box<dyn WorkerBuildAndRun>,
489 id: &'static str,
490}
491
492impl WorkerBuilder {
493 pub fn new<T: Worker>() -> Self
495 where
496 T::Parameters: MeshPayload,
497 {
498 Self {
499 inner: Box::new(BuilderInner::<T>(PhantomData)),
500 id: T::ID.id(),
501 }
502 }
503
504 fn build_and_run(self, request: BuildRequest) -> anyhow::Result<Box<dyn Run>> {
505 self.inner.build_and_run(request)
506 }
507}
508
/// The serialized input used to construct a worker.
#[doc(hidden)]
pub enum BuildRequest {
    /// Serialized `Worker::Parameters`, for [`Worker::new`].
    New(mesh::OwnedMessage),
    /// Serialized `Worker::State` from the previous instance, for
    /// [`Worker::restart`].
    Restart(mesh::OwnedMessage),
}

/// Builder for the concrete worker type `T`.
///
/// `PhantomData<fn() -> T>` is used rather than `PhantomData<T>`. This keeps
/// `BuilderInner` `Send` even when `T` is not (`Worker` does not require
/// `Send`), as the `WorkerBuildAndRun: Send` supertrait needs.
struct BuilderInner<T: Worker>(PhantomData<fn() -> T>);

/// Object-safe interface for constructing a worker from a [`BuildRequest`].
trait WorkerBuildAndRun: Send {
    fn build_and_run(self: Box<Self>, request: BuildRequest) -> anyhow::Result<Box<dyn Run>>;
}

/// Object-safe wrapper around [`Worker::run`] that accepts an untyped RPC
/// receiver.
trait Run {
    fn run(
        self: Box<Self>,
        recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
    ) -> anyhow::Result<()>;
}
527
528impl<T: Worker> Run for T {
529 fn run(
530 self: Box<Self>,
531 recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
532 ) -> anyhow::Result<()> {
533 let recv = mesh::local_node::Port::from(recv).into();
535 Worker::run(*self, recv)
536 }
537}
538
539impl<T: Worker> WorkerBuildAndRun for BuilderInner<T>
540where
541 T::Parameters: MeshPayload,
542{
543 fn build_and_run(self: Box<Self>, request: BuildRequest) -> anyhow::Result<Box<dyn Run>> {
544 let worker = match request {
545 BuildRequest::New(parameters) => {
546 T::new(parameters.parse().context("failed to receive parameters")?)
547 }
548 BuildRequest::Restart(state) => T::restart(
549 state
551 .serialize()
552 .into_message()
553 .context("failed to parse restart state")?,
554 ),
555 }?;
556
557 Ok(Box::new(worker))
558 }
559}
560
/// Defines a unit struct `$name` implementing [`WorkerFactory`] for the listed
/// worker types, matching requested names against each type's [`Worker::ID`].
///
/// Each worker type may be preceded by attributes (e.g. `#[cfg(...)]`). They
/// are applied to that worker's lookup block. Unknown names produce an
/// "unsupported worker" error.
///
/// The expansion refers to `anyhow` directly, so the invoking crate must
/// depend on it.
///
/// ```ignore
/// runnable_workers! {
///     MyWorkers {
///         WorkerA,
///         #[cfg(feature = "b")]
///         WorkerB,
///     }
/// }
/// ```
#[macro_export]
macro_rules! runnable_workers {
    (
        $name:ident {
            $($(#[$vattr:meta])* $worker:ty),*$(,)?
        }
    ) => {

        #[derive(Debug, Clone)]
        struct $name;

        impl $crate::WorkerFactory for $name {
            fn builder(&self, name: &str) -> anyhow::Result<$crate::WorkerBuilder> {
                $(
                    $(#[$vattr])*
                    {
                        if name == <$worker as $crate::Worker>::ID.id() {
                            return Ok($crate::WorkerBuilder::new::<$worker>());
                        }
                    }
                )*

                anyhow::bail!("unsupported worker {name}")
            }
        }
    };
}
612
/// Implementation details used by `register_workers!`; not a public API.
#[doc(hidden)]
pub mod private {
    // NOTE(review): the `unsafe_code` expectation presumably covers code
    // generated by `linkme::distributed_slice` below — confirm.
    #![expect(unsafe_code)]

    use super::RegisteredWorkers;
    use super::WorkerFactory;
    use crate::Worker;
    use crate::WorkerBuilder;
    pub use linkme;
    use mesh::MeshPayload;

    /// Linker-assembled list of workers added by `register_workers!`.
    /// `None` entries are skipped during lookup.
    #[linkme::distributed_slice]
    pub static WORKERS: [Option<&'static RegisteredWorker>] = [..];

    // Placeholder `None` entry. NOTE(review): the reason is not visible here.
    // It presumably guarantees the slice has at least one element — confirm.
    #[linkme::distributed_slice(WORKERS)]
    static WORKAROUND: Option<&'static RegisteredWorker> = None;

    /// A registered worker: its ID string and a constructor for its builder.
    pub struct RegisteredWorker {
        id: &'static str,
        build: fn() -> WorkerBuilder,
    }

    impl RegisteredWorker {
        /// Describes worker type `T` for registration.
        pub const fn new<T: Worker>() -> Self
        where
            T::Parameters: MeshPayload,
        {
            Self {
                id: T::ID.id(),
                build: WorkerBuilder::new::<T>,
            }
        }
    }

    impl WorkerFactory for RegisteredWorkers {
        /// Returns a builder for the first registered worker whose ID is
        /// `name`. Returns an error if no registered worker matches.
        fn builder(&self, name: &str) -> anyhow::Result<WorkerBuilder> {
            for worker in WORKERS.iter().flatten() {
                if worker.id == name {
                    return Ok((worker.build)());
                }
            }
            anyhow::bail!("unsupported worker {name}")
        }
    }

    /// Registers worker types with [`RegisteredWorkers`] by adding an entry for
    /// each one to `WORKERS`.
    ///
    /// Attributes preceding a worker type (e.g. `#[cfg(...)]`) apply to that
    /// worker's registration.
    #[macro_export]
    macro_rules! register_workers {
        {} => {};
        { $($(#[$attr:meta])* $worker:ty),+ $(,)? } => {
            $(
            $(#[$attr])*
            const _: () = {
                use $crate::private;
                use private::linkme;

                #[linkme::distributed_slice(private::WORKERS)]
                #[linkme(crate = linkme)]
                static WORKER: Option<&'static private::RegisteredWorker> = Some(&private::RegisteredWorker::new::<$worker>());
            };
            )*
        };
    }
}
690
/// A [`WorkerFactory`] supporting every worker registered with
/// `register_workers!`. Registrations are collected into a linker-assembled
/// slice via `linkme`.
#[derive(Debug, Clone)]
pub struct RegisteredWorkers;
711
/// A launch request sent from a [`WorkerHost`] to its [`WorkerHostRunner`].
#[derive(Debug, MeshPayload)]
struct WorkerHostLaunchRequest {
    /// The worker ID. It is passed to [`WorkerFactory::builder`] and used in
    /// the worker thread's name.
    name: String,
    /// The worker's channels and launch type.
    request: WorkerLaunchRequest,
}
720
#[cfg(test)]
mod tests {
    use super::Worker;
    use super::WorkerFactory;
    use super::WorkerId;
    use super::WorkerRpc;
    use crate::launch_local_worker;
    use crate::worker::WorkerEvent;
    use futures::StreamExt;
    use futures::executor::block_on;
    use mesh::MeshPayload;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    // Shadows the built-in `#[test]` attribute for the tests below
    // (presumably to enable tracing output).
    use test_with_tracing::test;

    /// Worker whose only state is a value carried across restarts.
    struct TestWorker {
        value: u64,
    }

    #[derive(MeshPayload, Default)]
    struct TestWorkerConfig {
        pub value: u64,
    }

    #[derive(MeshPayload)]
    struct TestWorkerState {
        pub value: u64,
    }

    impl Worker for TestWorker {
        type Parameters = TestWorkerConfig;
        type State = TestWorkerState;
        const ID: WorkerId<Self::Parameters> = WorkerId::new("TestWorker");

        fn new(parameters: Self::Parameters) -> anyhow::Result<Self> {
            Ok(Self {
                value: parameters.value,
            })
        }

        fn restart(state: Self::State) -> anyhow::Result<Self> {
            Ok(Self { value: state.value })
        }

        // Services requests until stopped, restarted, or the channel closes.
        fn run(self, mut recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            block_on(async {
                while let Ok(req) = recv.recv().await {
                    match req {
                        WorkerRpc::Stop => break,
                        WorkerRpc::Restart(rpc) => {
                            // Hand the state to the replacement, then exit.
                            rpc.complete(Ok(TestWorkerState { value: self.value }));
                            break;
                        }
                        WorkerRpc::Inspect(_deferred) => (),
                    }
                }
                Ok(())
            })
        }
    }

    /// A worker deliberately left out of `RunnableWorkers1`.
    struct TestWorker2;

    impl Worker for TestWorker2 {
        type Parameters = ();
        type State = ();
        const ID: WorkerId<Self::Parameters> = WorkerId::new("TestWorker2");

        fn new(_parameters: Self::Parameters) -> anyhow::Result<Self> {
            Ok(Self)
        }

        fn restart(_state: ()) -> anyhow::Result<Self> {
            Ok(Self)
        }

        fn run(self, mut recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            block_on(async {
                while let Ok(req) = recv.recv().await {
                    match req {
                        WorkerRpc::Stop => break,
                        WorkerRpc::Restart(rpc) => {
                            rpc.complete(Ok(()));
                            break;
                        }
                        WorkerRpc::Inspect(_deferred) => (),
                    }
                }
                Ok(())
            })
        }
    }

    // Factory that only knows `TestWorker`.
    runnable_workers! {
        RunnableWorkers1 {
            TestWorker,
        }
    }

    #[test]
    fn test_runnable_workers_unsupported() {
        let result = RunnableWorkers1.builder("foobar");

        assert!(result.is_err());
    }

    #[async_test]
    async fn test_launch_worker_remote_supported_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        let task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let result = host.launch_worker(TestWorker::ID, Default::default()).await;

        assert!(result.is_ok());
        // Dropping the handle closes the worker's RPC channel, ending its run loop.
        drop(result.unwrap());
        // Dropping the host closes the runner's launch channel.
        drop(host);
        // The runner task then completes.
        task.await;
    }

    #[async_test]
    async fn test_launch_worker_remote_unsupported_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        let task = driver.spawn("runner", runner.run(RunnableWorkers1));

        // `RunnableWorkers1` does not list `TestWorker2`, so launching fails.
        let result = host.launch_worker(TestWorker2::ID, ()).await;

        assert!(result.is_err());
        drop(host);
        task.await;
    }

    #[async_test]
    async fn test_launch_worker_remote_restart_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        let task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let result = host.launch_worker(TestWorker::ID, Default::default()).await;

        let mut handle = result.expect("worker launch failed");
        handle.restart(&host);
        // The old worker's `Stopped` event is not delivered. Its event receiver
        // is dropped during a successful restart, so the next event is the new
        // worker's `Started`.
        assert!(matches!(handle.next().await.unwrap(), WorkerEvent::Started));
        handle.stop();
        assert!(matches!(handle.next().await.unwrap(), WorkerEvent::Stopped));

        // The worker thread has exited, so the event stream ends.
        assert!(handle.next().await.is_none());

        drop(host);
        task.await;
    }

    /// Worker whose parameter is a function pointer called during `new`. This
    /// exercises `launch_local_worker` with non-serialized parameters.
    struct LocalWorker;

    impl Worker for LocalWorker {
        type Parameters = fn() -> anyhow::Result<()>;
        type State = ();
        const ID: WorkerId<Self::Parameters> = WorkerId::new("local");

        fn new(parameters: Self::Parameters) -> anyhow::Result<Self> {
            parameters()?;
            Ok(Self)
        }

        fn restart(_state: Self::State) -> anyhow::Result<Self> {
            unreachable!()
        }

        fn run(self, _recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            Ok(())
        }
    }

    #[async_test]
    async fn test_launch_local_no_mesh() {
        let mut worker = launch_local_worker::<LocalWorker>(|| Ok(())).await.unwrap();
        worker.join().await.unwrap();
    }
}
912
913#[doc(hidden)]
915pub mod test_support {
916 use std::marker::PhantomData;
917
918 use crate::Worker;
919 use crate::WorkerId;
920 use crate::WorkerRpc;
921
922 pub struct DummyWorker<T>(PhantomData<T>);
924
925 pub const DUMMY_WORKER: WorkerId<()> = WorkerId::new("DummyWorker");
926
927 impl<T: 'static + Send> Worker for DummyWorker<T> {
928 type Parameters = ();
929 type State = ();
930 const ID: WorkerId<Self::Parameters> = DUMMY_WORKER;
931
932 fn new(_parameters: Self::Parameters) -> anyhow::Result<Self> {
933 Ok(Self(PhantomData))
934 }
935
936 fn restart(_state: Self::State) -> anyhow::Result<Self> {
937 Ok(Self(PhantomData))
938 }
939
940 fn run(self, _recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
941 todo!()
942 }
943 }
944}