mesh_worker/
worker.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Infrastructure for workers that can run on mesh nodes.
5
6use anyhow::Context;
7use futures::Stream;
8use futures::StreamExt;
9use futures::executor::block_on;
10use futures::stream::FusedStream;
11use futures_concurrency::stream::Merge;
12use inspect::Inspect;
13use mesh::MeshPayload;
14use mesh::error::RemoteError;
15use mesh::rpc::FailableRpc;
16use mesh::rpc::RpcSend;
17use std::fmt;
18use std::marker::PhantomData;
19use std::pin::Pin;
20use std::task::Poll;
21use std::thread;
22use unicycle::FuturesUnordered;
23
/// A unique identifier for a worker, used to specify which worker to launch.
///
/// The type parameter `T` is the worker's parameter type. No value of `T` is
/// stored; it is only carried as [`PhantomData`] alongside the static ID
/// string.
///
/// `WorkerId<T>` is `Copy`, `Clone`, and `Debug` for every `T`, including
/// parameter types that implement none of those traits.
pub struct WorkerId<T>(&'static str, PhantomData<T>);

// `Clone`, `Copy`, and `Debug` are written by hand rather than derived: the
// derives would add `T: Clone` / `T: Copy` / `T: Debug` bounds even though no
// `T` is ever stored, making IDs for non-`Copy` parameter types move-only.
impl<T> Clone for WorkerId<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for WorkerId<T> {}

impl<T> fmt::Debug for WorkerId<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same tuple-struct shape the derive produced.
        f.debug_tuple("WorkerId")
            .field(&self.0)
            .field(&self.1)
            .finish()
    }
}

impl<T> WorkerId<T> {
    /// Makes a new worker ID with the name `id`.
    pub const fn new(id: &'static str) -> Self {
        Self(id, PhantomData)
    }

    /// Gets the ID string.
    pub const fn id(&self) -> &'static str {
        self.0
    }
}

/// Displays the raw ID string, honoring any width/alignment in the format
/// spec.
impl<T> fmt::Display for WorkerId<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.0)
    }
}
45
/// Trait implemented by workers.
///
/// A worker is constructed with [`Worker::new`] (or with [`Worker::restart`]
/// when resuming from a previous instance's state) and is then driven to
/// completion by [`Worker::run`].
pub trait Worker: 'static + Sized {
    /// Parameters passed to launch the worker. Used with [`Worker::new`].
    ///
    /// For this worker to be spawned on a remote node, `Parameters` must
    /// implement [`MeshPayload`].
    type Parameters: 'static + Send;

    /// State used to implement hot restart. Used with [`Worker::restart`].
    ///
    /// A running worker hands this back in response to
    /// [`WorkerRpc::Restart`].
    type State: 'static + MeshPayload + Send;

    /// String identifying the Worker. Used when launching workers in separate processes
    /// to specify which workers are supported and which worker to launch.
    /// IDs must be unique within a given worker host.
    const ID: WorkerId<Self::Parameters>;

    /// Instantiates the worker.
    ///
    /// The worker should not start running yet, but it can allocate any resources
    /// necessary to run.
    fn new(parameters: Self::Parameters) -> anyhow::Result<Self>;

    /// Restarts a worker from a previous worker's execution state.
    fn restart(state: Self::State) -> anyhow::Result<Self>;

    /// Synchronously runs the worker on the current thread.
    ///
    /// The worker should respond to commands sent in `recv`. If `recv` is closed,
    /// the worker should exit.
    ///
    /// The worker ends when it returns from this function.
    fn run(self, recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()>;
}
79
/// Common requests for workers.
///
/// `T` is the worker's restart state type (see [`Worker::State`]); a worker
/// receives these requests on the channel passed to [`Worker::run`].
#[derive(Debug, MeshPayload)]
#[mesh(bound = "T: 'static + MeshPayload + Send")]
pub enum WorkerRpc<T> {
    /// Tear down.
    Stop,
    /// Tear down and send the state necessary to restart on the provided
    /// channel.
    Restart(FailableRpc<(), T>),
    /// Inspect the worker.
    Inspect(inspect::Deferred),
}
92
/// How a worker should be brought up in response to a launch request.
#[derive(Debug, MeshPayload)]
enum LaunchType {
    /// Build a fresh worker from its (type-erased) launch parameters.
    New {
        parameters: mesh::OwnedMessage,
    },
    /// Take over from an already running worker.
    Restart {
        // RPC channel to the old worker; used to request its restart state.
        send: mesh::Sender<WorkerRpc<mesh::OwnedMessage>>,
        // Event channel of the old worker; bridged back to the handle if the
        // restart cannot proceed.
        events: mesh::Receiver<WorkerEvent>,
    },
}
103
/// A runner returned by [`worker_host()`]. Used to handle worker launch
/// requests.
///
/// This may be sent across processes via mesh.
// The wrapped receiver is the receiving end of the launch-request channel
// whose sender lives in the paired `WorkerHost`.
#[derive(Debug, MeshPayload)]
pub struct WorkerHostRunner(mesh::Receiver<WorkerHostLaunchRequest>);
110
111impl WorkerHostRunner {
112    /// Runs the worker host until all corresponding [`WorkerHost`] instances
113    /// have been dropped and all workers have exited.
114    ///
115    /// `factory` provides the set of possible workers to launch. Typically,
116    /// this will be [`RegisteredWorkers`].
117    pub async fn run(mut self, factory: impl WorkerFactory) {
118        let mut rundown = FuturesUnordered::new();
119        loop {
120            let mut stream = ((&mut self.0).map(Some), (&mut rundown).map(|_| None)).merge();
121            let launch_params = match stream.next().await {
122                Some(Some(launch_params)) => launch_params,
123                Some(None) => continue,
124                None => break,
125            };
126
127            let _requestspan = tracing::info_span!("worker_host_launch_request").entered();
128
129            let result = factory.builder(&launch_params.name);
130            match result {
131                Ok(runner) => {
132                    // start a new thread and run the runner.
133                    let (rundown_send, rundown_recv) = mesh::oneshot::<()>();
134                    thread::Builder::new()
135                        .name(format!("worker-{}", &launch_params.name))
136                        .spawn(move || {
137                            launch_params.request.launch(runner);
138                            drop(rundown_send);
139                        })
140                        .expect("thread launch failed");
141
142                    rundown.push(rundown_recv);
143                }
144                Err(err) => {
145                    // TODO: kharp 2021-05-26 Better tracing of errors, maybe tracing_error?
146                    launch_params.request.fail(err);
147                }
148            }
149        }
150    }
151}
152
/// Represents a running [`Worker`] instance providing the ability to restart,
/// stop or wait for exit. To launch a worker and get a handle, use
/// [`WorkerHost::launch_worker`].
#[derive(Debug, MeshPayload)]
pub struct WorkerHandle {
    /// Worker ID string; reused to relaunch the same worker on restart.
    name: String,
    /// Channel for sending requests to the worker (state type erased).
    send: mesh::Sender<WorkerRpc<mesh::OwnedMessage>>,
    /// Channel on which the worker's lifetime events arrive.
    events: mesh::Receiver<WorkerEvent>,
}
162
163impl Inspect for WorkerHandle {
164    fn inspect(&self, req: inspect::Request<'_>) {
165        self.send.send(WorkerRpc::Inspect(req.defer()))
166    }
167}
168
169impl Stream for WorkerHandle {
170    type Item = WorkerEvent;
171
172    fn poll_next(
173        mut self: Pin<&mut Self>,
174        cx: &mut std::task::Context<'_>,
175    ) -> Poll<Option<Self::Item>> {
176        Poll::Ready(match std::task::ready!(self.events.poll_recv(cx)) {
177            Ok(event) => Some(event),
178            Err(mesh::RecvError::Error(err)) => Some(WorkerEvent::Failed(RemoteError::new(err))),
179            Err(mesh::RecvError::Closed) => None,
180        })
181    }
182}
183
184impl FusedStream for WorkerHandle {
185    fn is_terminated(&self) -> bool {
186        self.events.is_terminated()
187    }
188}
189
/// A lifetime event for a worker.
///
/// Events are observed by polling a [`WorkerHandle`] as a stream.
#[derive(Debug, MeshPayload)]
pub enum WorkerEvent {
    /// The worker has stopped without error.
    Stopped,
    /// The worker has failed.
    Failed(RemoteError),
    /// The worker has started or restarted successfully.
    Started,
    /// The requested restart operation failed, but the worker is still running.
    RestartFailed(RemoteError),
}
202
203impl WorkerHandle {
204    /// Requests that the worker stop.
205    pub fn stop(&mut self) {
206        self.send.send(WorkerRpc::Stop);
207    }
208
209    /// Waits until the worker has stopped.
210    pub async fn join(&mut self) -> anyhow::Result<()> {
211        while let Some(event) = self.next().await {
212            if let WorkerEvent::Failed(err) = event {
213                return Err(err.into());
214            }
215        }
216        Ok(())
217    }
218
219    /// Stops the worker, then restarts it, using the same state, on `host`.
220    ///
221    /// This can be used to upgrade a worker at runtime if `host` is a
222    /// worker host in a new process.
223    pub fn restart(&mut self, host: &WorkerHost) {
224        let (send, recv) = mesh::channel();
225        let (events_send, events) = mesh::channel();
226
227        let send = std::mem::replace(&mut self.send, send);
228        let events = std::mem::replace(&mut self.events, events);
229
230        // Launch the new worker.
231        host.launch_worker_internal(
232            &self.name,
233            recv,
234            events_send,
235            LaunchType::Restart { send, events },
236        );
237    }
238}
239
/// A handle used to launch workers on a host.
///
/// You can get an instance of this by spawning a new host with
/// [`worker_host()`].
///
/// This may be sent across processes via mesh.
// The wrapped sender is the sending end of the launch-request channel
// serviced by the paired `WorkerHostRunner`.
#[derive(Debug, MeshPayload, Clone)]
pub struct WorkerHost(mesh::Sender<WorkerHostLaunchRequest>);
248
249/// Returns a new [`WorkerHost`], [`WorkerHostRunner`] pair.
250///
251/// The [`WorkerHost`] is used to launch workers, while the [`WorkerHostRunner`]
252/// is used to handle worker launch requests. The caller must start
253/// [`WorkerHostRunner::run()`] on an appropriate task before `WorkerHost` will
254/// be able to launch workers.
255///
256/// This is useful over just using [`launch_local_worker`] because it provides
257/// an indirection between the identity of the workers being launched
258/// (identified via [`WorkerId`]) and the concrete worker implementation. This
259/// can be used to swap worker implementations, improve build times, and to
260/// support launching workers across process boundaries.
261///
262/// To achieve this latter feat, note that either half of the returned tuple may
263/// be sent to over a mesh channel to another process, allowing a worker to be
264/// spawned in a separate process from the caller. This can be useful for fault
265/// or resource isolation and for security sandboxing.
266///
267/// # Example
268/// ```
269/// # use mesh_worker::{worker_host, WorkerHost, WorkerHostRunner, RegisteredWorkers, register_workers};
270/// # use mesh_worker::test_support::DUMMY_WORKER as MY_WORKER;
271/// # use futures::executor::block_on;
272/// # register_workers!(mesh_worker::test_support::DummyWorker<u32>);
273/// # block_on(async {
274/// let (host, runner) = worker_host();
275/// // Run the worker host on a separate thread. (Typically this would just be
276/// // a separate task in your async framework.)
277/// std::thread::spawn(|| block_on(runner.run(RegisteredWorkers)));
278/// // Launch a worker by ID. This will call to the worker host runner.
279/// host.launch_worker(MY_WORKER, ()).await.unwrap();
280/// # })
281/// ```
282pub fn worker_host() -> (WorkerHost, WorkerHostRunner) {
283    let (send, recv) = mesh::mpsc_channel();
284    (WorkerHost(send), WorkerHostRunner(recv))
285}
286
287impl WorkerHost {
288    /// Launches a [`Worker`] instance on this host.
289    ///
290    /// Returns before the worker has finished launching. Look for the
291    /// [`WorkerEvent::Started`] event to ensure the worker did not fail to
292    /// start.
293    pub fn start_worker<T>(&self, id: WorkerId<T>, params: T) -> anyhow::Result<WorkerHandle>
294    where
295        T: 'static + MeshPayload + Send,
296    {
297        self.start_worker_inner(id.id(), mesh::OwnedMessage::new(params))
298    }
299
300    fn start_worker_inner(
301        &self,
302        id: &str,
303        parameters: mesh::OwnedMessage,
304    ) -> anyhow::Result<WorkerHandle> {
305        let (events_send, events_recv) = mesh::channel();
306        let (rpc_send, rpc_recv) = mesh::channel();
307        self.launch_worker_internal(id, rpc_recv, events_send, LaunchType::New { parameters });
308        Ok(WorkerHandle {
309            name: id.to_string(),
310            send: rpc_send,
311            events: events_recv,
312        })
313    }
314
315    /// Launches a [`Worker`] instance on this host, waiting for the worker to
316    /// start running.
317    pub async fn launch_worker<T>(&self, id: WorkerId<T>, params: T) -> anyhow::Result<WorkerHandle>
318    where
319        T: 'static + MeshPayload + Send,
320    {
321        let mut handle = self.start_worker_inner(id.id(), mesh::OwnedMessage::new(params))?;
322        match handle.next().await.context("failed to launch worker")? {
323            WorkerEvent::Started => Ok(handle),
324            WorkerEvent::Failed(err) => Err(err).context("failed to launch worker")?,
325            WorkerEvent::Stopped | WorkerEvent::RestartFailed(_) => {
326                anyhow::bail!("received invalid worker event")
327            }
328        }
329    }
330
331    fn launch_worker_internal(
332        &self,
333        id: &str,
334        rpc_recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
335        events_send: mesh::Sender<WorkerEvent>,
336        launch_type: LaunchType,
337    ) {
338        let request = WorkerHostLaunchRequest {
339            name: id.to_string(),
340            request: WorkerLaunchRequest {
341                rpc: rpc_recv,
342                events: events_send,
343                launch_type,
344            },
345        };
346
347        self.0.send(request);
348    }
349}
350
351/// Launches a worker locally.
352///
353/// When launched via this API, a worker's parameters do not have to derive
354/// `MeshPayload`.
355///
356/// # Example
357/// ```
358/// # use mesh_worker::test_support::DUMMY_WORKER;
359/// # use mesh_worker::WorkerHost;
360/// # use mesh_worker::launch_local_worker;
361/// # type MyWorker = mesh_worker::test_support::DummyWorker<u32>;
362/// # futures::executor::block_on(async {
363/// let worker = launch_local_worker::<MyWorker>(()).await.unwrap();
364/// # })
365/// ```
366pub async fn launch_local_worker<T: Worker>(
367    parameters: T::Parameters,
368) -> anyhow::Result<WorkerHandle> {
369    let (rpc_send, rpc_recv) = mesh::channel();
370    let (events_send, events_recv) = mesh::channel();
371    let (result_send, result_recv) = mesh::oneshot();
372
373    thread::Builder::new()
374        .name(format!("worker-{}", &T::ID.id()))
375        .spawn(move || match T::new(parameters) {
376            Ok(worker) => {
377                result_send.send(Ok(()));
378                match worker.run(rpc_recv) {
379                    Ok(()) => {
380                        events_send.send(WorkerEvent::Stopped);
381                    }
382                    Err(err) => {
383                        events_send.send(WorkerEvent::Failed(RemoteError::new(err)));
384                    }
385                }
386            }
387            Err(err) => {
388                result_send.send(Err(err));
389            }
390        })
391        .expect("thread launch failed");
392
393    result_recv.await.unwrap()?;
394    Ok(WorkerHandle {
395        name: T::ID.id().to_owned(),
396        // Erase the type of the worker state.
397        send: mesh::local_node::Port::from(rpc_send).into(),
398        events: events_recv,
399    })
400}
401
/// Trait implemented by a type that can dispatch requests to a worker.
///
/// This trait is generally not used directly. Instead, use either
/// [`RegisteredWorkers`], or generate a factory type with the
/// [`crate::runnable_workers!`] macro.
pub trait WorkerFactory: 'static + Send + Sync {
    /// Returns a builder for the worker with the given name.
    ///
    /// The implementations in this file return an error when no worker with
    /// that name is known to the factory.
    fn builder(&self, name: &str) -> anyhow::Result<WorkerBuilder>;
}
411
/// Everything a host needs to bring up one worker and wire it to its handle.
#[derive(Debug, MeshPayload)]
struct WorkerLaunchRequest {
    /// Requests from the handle, to be consumed by the worker.
    rpc: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
    /// Lifetime events, to be delivered to the handle.
    events: mesh::Sender<WorkerEvent>,
    /// Whether this is a fresh launch or a restart of a running worker.
    launch_type: LaunchType,
}
418
impl WorkerLaunchRequest {
    /// Rejects the request with `err` without launching anything.
    ///
    /// A new launch reports [`WorkerEvent::Failed`]. A restart reports
    /// [`WorkerEvent::RestartFailed`] and then reconnects the handle's
    /// channels to the old worker.
    fn fail(self, err: anyhow::Error) {
        match self.launch_type {
            LaunchType::New { .. } => {
                self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
            }
            LaunchType::Restart { send, events } => {
                // Report the error and revert communications to the old worker.
                self.events
                    .send(WorkerEvent::RestartFailed(RemoteError::new(err)));
                self.rpc.bridge(send);
                self.events.bridge(events);
            }
        }
    }

    /// Builds the worker with `builder` and runs it to completion on the
    /// current thread, reporting lifetime events along the way.
    ///
    /// Blocks until the worker exits (or until building it fails).
    fn launch(self, builder: WorkerBuilder) {
        let worker = match self.launch_type {
            LaunchType::New { parameters } => {
                let _span =
                    tracing::info_span!("worker_new", name = builder.id, action = "new").entered();
                match builder.build_and_run(BuildRequest::New(parameters)) {
                    Ok(worker) => worker,
                    Err(err) => {
                        self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
                        return;
                    }
                }
            }
            LaunchType::Restart { send, events } => {
                // Ask the old worker for its restart state, blocking this
                // thread until it answers.
                let state_recv = send.call_failable(WorkerRpc::Restart, ());
                let state = match block_on(state_recv) {
                    Ok(state) => state,
                    Err(err) => {
                        self.events
                            .send(WorkerEvent::RestartFailed(RemoteError::new(err)));
                        // Revert communications to the old worker.
                        self.events.bridge(events);
                        self.rpc.bridge(send);
                        return;
                    }
                };
                let _span =
                    tracing::info_span!("worker_new", name = builder.id, action = "restart")
                        .entered();
                match builder.build_and_run(BuildRequest::Restart(state)) {
                    Ok(worker) => worker,
                    Err(err) => {
                        // The restart state was already obtained from the
                        // old worker, so this is reported as `Failed` and the
                        // channels are not bridged back.
                        self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
                        return;
                    }
                }
            }
        };

        // The worker is built; announce it, then run until it exits.
        self.events.send(WorkerEvent::Started);
        match worker.run(self.rpc) {
            Ok(()) => {
                self.events.send(WorkerEvent::Stopped);
            }
            Err(err) => {
                self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
            }
        }
    }
}
485
/// A builder for a worker.
///
/// Returned by [`WorkerFactory::builder`]; it hides the concrete worker type
/// behind a trait object.
pub struct WorkerBuilder {
    /// Type-erased constructor for the concrete worker.
    inner: Box<dyn WorkerBuildAndRun>,
    /// The worker's ID string, recorded in tracing spans at launch.
    id: &'static str,
}
491
492impl WorkerBuilder {
493    /// Returns a builder for `T`.
494    pub fn new<T: Worker>() -> Self
495    where
496        T::Parameters: MeshPayload,
497    {
498        Self {
499            inner: Box::new(BuilderInner::<T>(PhantomData)),
500            id: T::ID.id(),
501        }
502    }
503
504    fn build_and_run(self, request: BuildRequest) -> anyhow::Result<Box<dyn Run>> {
505        self.inner.build_and_run(request)
506    }
507}
508
// How a worker should be constructed. Both payloads are type-erased
// messages that are converted back to the worker's concrete types when the
// worker is built.
#[doc(hidden)]
pub enum BuildRequest {
    // Construct via `Worker::new` from serialized parameters.
    New(mesh::OwnedMessage),
    // Construct via `Worker::restart` from serialized restart state.
    Restart(mesh::OwnedMessage),
}
514
/// Zero-sized marker that remembers the concrete worker type `T` for
/// [`WorkerBuilder`].
// `PhantomData<fn() -> T>` (rather than `PhantomData<T>`) keeps this type
// `Send` regardless of `T`, which `WorkerBuildAndRun: Send` requires.
struct BuilderInner<T: Worker>(PhantomData<fn() -> T>);
516
/// Object-safe constructor for a worker: turns a type-erased
/// [`BuildRequest`] into a runnable worker.
trait WorkerBuildAndRun: Send {
    /// Consumes the builder and constructs the worker.
    fn build_and_run(self: Box<Self>, request: BuildRequest) -> anyhow::Result<Box<dyn Run>>;
}
520
/// Object-safe counterpart of [`Worker::run`] that takes the request channel
/// with the worker state type erased.
trait Run {
    /// Runs the worker to completion on the current thread.
    fn run(
        self: Box<Self>,
        recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
    ) -> anyhow::Result<()>;
}
527
528impl<T: Worker> Run for T {
529    fn run(
530        self: Box<Self>,
531        recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
532    ) -> anyhow::Result<()> {
533        // Unerase the type of the worker state.
534        let recv = mesh::local_node::Port::from(recv).into();
535        Worker::run(*self, recv)
536    }
537}
538
539impl<T: Worker> WorkerBuildAndRun for BuilderInner<T>
540where
541    T::Parameters: MeshPayload,
542{
543    fn build_and_run(self: Box<Self>, request: BuildRequest) -> anyhow::Result<Box<dyn Run>> {
544        let worker = match request {
545            BuildRequest::New(parameters) => {
546                T::new(parameters.parse().context("failed to receive parameters")?)
547            }
548            BuildRequest::Restart(state) => T::restart(
549                // Unerase the type of the worker state.
550                state
551                    .serialize()
552                    .into_message()
553                    .context("failed to parse restart state")?,
554            ),
555        }?;
556
557        Ok(Box::new(worker))
558    }
559}
560
/// Generates a type that defines the set of workers that can be run by a
/// worker host.
///
/// The resulting type is an empty struct implementing the [`WorkerFactory`]
/// trait. Its `builder` method compares the requested worker name against the
/// ID of each listed worker and returns a builder for the first match, or an
/// error if none matches.
///
/// This is used to enumerate the list of worker types a host can instantiate.
///
/// Workers can be conditionally enabled by tagging them with a corresponding `#[cfg]` attr.
///
/// # Example
///
/// ```no_run
/// # use mesh_worker::test_support;
/// # use mesh_worker::runnable_workers;
/// # type MyWorker1 = test_support::DummyWorker<u32>;
/// # type MyWorker2 = test_support::DummyWorker<i32>;
/// runnable_workers! {
///     RunnableWorkers {
///         MyWorker1,
///         #[cfg(unix)]
///         MyWorker2,
///     }
/// }
/// ```
#[macro_export]
macro_rules! runnable_workers {
    (
        $name:ident {
            $($(#[$vattr:meta])* $worker:ty),*$(,)?
        }
    ) => {

        #[derive(Debug, Clone)]
        struct $name;

        impl $crate::WorkerFactory for $name {
            fn builder(&self, name: &str) -> anyhow::Result<$crate::WorkerBuilder> {
                $(
                    $(#[$vattr])*
                    {
                        if name == <$worker as $crate::Worker>::ID.id() {
                            return Ok($crate::WorkerBuilder::new::<$worker>());
                        }
                    }
                )*

                anyhow::bail!("unsupported worker {name}")
            }
        }
    };
}
612
// Implementation details of `register_workers!` and `RegisteredWorkers`.
// Public only so that the macro expansion can reach these items from other
// crates.
#[doc(hidden)]
pub mod private {
    // UNSAFETY: Needed for linkme.
    #![expect(unsafe_code)]

    use super::RegisteredWorkers;
    use super::WorkerFactory;
    use crate::Worker;
    use crate::WorkerBuilder;
    pub use linkme;
    use mesh::MeshPayload;

    // The link-time registry of workers, populated by `register_workers!`.
    //
    // Use Option<&X> in case the linker inserts some stray nulls, as we think
    // it might on Windows.
    //
    // See <https://devblogs.microsoft.com/oldnewthing/20181108-00/?p=100165>.
    #[linkme::distributed_slice]
    pub static WORKERS: [Option<&'static RegisteredWorker>] = [..];

    // Always have at least one entry to work around linker bugs.
    //
    // See <https://github.com/llvm/llvm-project/issues/65855>.
    #[linkme::distributed_slice(WORKERS)]
    static WORKAROUND: Option<&'static RegisteredWorker> = None;

    /// One registry entry: a worker ID and a function that makes a builder
    /// for that worker.
    pub struct RegisteredWorker {
        id: &'static str,
        build: fn() -> WorkerBuilder,
    }

    impl RegisteredWorker {
        /// Creates the registry entry for worker type `T`.
        pub const fn new<T: Worker>() -> Self
        where
            T::Parameters: MeshPayload,
        {
            Self {
                id: T::ID.id(),
                build: WorkerBuilder::new::<T>,
            }
        }
    }

    impl WorkerFactory for RegisteredWorkers {
        /// Returns a builder for the first registered worker whose ID equals
        /// `name`, or an error if there is none.
        fn builder(&self, name: &str) -> anyhow::Result<WorkerBuilder> {
            // `flatten` skips the `None` placeholder entries.
            for worker in WORKERS.iter().flatten() {
                if worker.id == name {
                    return Ok((worker.build)());
                }
            }
            anyhow::bail!("unsupported worker {name}")
        }
    }

    /// Registers workers for use with
    /// [`RegisteredWorkers`](super::RegisteredWorkers).
    ///
    /// You can invoke this macro multiple times, even from different crates.
    /// All registered workers will be available from any user of
    /// `RegisteredWorkers`.
    #[macro_export]
    macro_rules! register_workers {
        // No workers listed: expands to nothing.
        {} => {};
        // One anonymous const per worker, each adding an entry to `WORKERS`.
        { $($(#[$attr:meta])* $worker:ty),+ $(,)? } => {
            $(
            $(#[$attr])*
            const _: () = {
                use $crate::private;
                use private::linkme;

                #[linkme::distributed_slice(private::WORKERS)]
                #[linkme(crate = linkme)]
                static WORKER: Option<&'static private::RegisteredWorker> = Some(&private::RegisteredWorker::new::<$worker>());
            };
            )*
        };
    }
}
690
/// A worker factory that can build any worker built with
/// [`register_workers`](crate::register_workers).
///
/// Workers are looked up by their ID string among all registered workers.
///
/// ```
/// # use mesh_worker::register_workers;
/// # use mesh_worker::RegisteredWorkers;
/// # use futures::executor::block_on;
/// # type MyWorker1 = mesh_worker::test_support::DummyWorker<u32>;
/// # type MyWorker2 = mesh_worker::test_support::DummyWorker<i32>;
/// register_workers! {
///     MyWorker1,
///     MyWorker2,
/// }
///
/// // Construct a worker host for these workers.
/// let (host, runner) = mesh_worker::worker_host();
/// std::thread::spawn(|| block_on(runner.run(RegisteredWorkers)));
/// ```
#[derive(Debug, Clone)]
pub struct RegisteredWorkers;
711
/// A request to launch a worker on a host.
///
/// Sent from a [`WorkerHost`] to its paired [`WorkerHostRunner`].
#[derive(Debug, MeshPayload)]
struct WorkerHostLaunchRequest {
    /// Name of the worker to launch. This is the worker's ID string, which
    /// the runner passes to its [`WorkerFactory`].
    name: String,
    /// Request parameters.
    request: WorkerLaunchRequest,
}
720
#[cfg(test)]
mod tests {
    use super::Worker;
    use super::WorkerFactory;
    use super::WorkerId;
    use super::WorkerRpc;
    use crate::launch_local_worker;
    use crate::worker::WorkerEvent;
    use futures::StreamExt;
    use futures::executor::block_on;
    use mesh::MeshPayload;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use test_with_tracing::test;

    // Worker that carries a single value from its parameters into its
    // restart state.
    struct TestWorker {
        value: u64,
    }

    #[derive(MeshPayload, Default)]
    struct TestWorkerConfig {
        pub value: u64,
    }

    #[derive(MeshPayload)]
    struct TestWorkerState {
        pub value: u64,
    }

    impl Worker for TestWorker {
        type Parameters = TestWorkerConfig;
        type State = TestWorkerState;
        const ID: WorkerId<Self::Parameters> = WorkerId::new("TestWorker");

        fn new(parameters: Self::Parameters) -> anyhow::Result<Self> {
            Ok(Self {
                value: parameters.value,
            })
        }

        fn restart(state: Self::State) -> anyhow::Result<Self> {
            Ok(Self { value: state.value })
        }

        // Exits on `Stop`, on `Restart` (after handing back its state), or
        // when the channel closes.
        fn run(self, mut recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            block_on(async {
                while let Ok(req) = recv.recv().await {
                    match req {
                        WorkerRpc::Stop => break,
                        WorkerRpc::Restart(rpc) => {
                            rpc.complete(Ok(TestWorkerState { value: self.value }));
                            break;
                        }
                        WorkerRpc::Inspect(_deferred) => (),
                    }
                }
                Ok(())
            })
        }
    }

    // Worker that is deliberately left out of `RunnableWorkers1`, used to
    // exercise the "unsupported worker" path.
    struct TestWorker2;

    impl Worker for TestWorker2 {
        type Parameters = ();
        type State = ();
        const ID: WorkerId<Self::Parameters> = WorkerId::new("TestWorker2");

        fn new(_parameters: Self::Parameters) -> anyhow::Result<Self> {
            Ok(Self)
        }

        fn restart(_state: ()) -> anyhow::Result<Self> {
            Ok(Self)
        }

        fn run(self, mut recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            block_on(async {
                while let Ok(req) = recv.recv().await {
                    match req {
                        WorkerRpc::Stop => break,
                        WorkerRpc::Restart(rpc) => {
                            rpc.complete(Ok(()));
                            break;
                        }
                        WorkerRpc::Inspect(_deferred) => (),
                    }
                }
                Ok(())
            })
        }
    }

    // Factory that only knows about `TestWorker`.
    runnable_workers! {
        RunnableWorkers1 {
            TestWorker,
        }
    }

    // An unknown worker name yields an error from the generated factory.
    #[test]
    fn test_runnable_workers_unsupported() {
        let result = RunnableWorkers1.builder("foobar");

        assert!(result.is_err());
    }

    // A worker listed in the factory launches successfully through a host.
    #[async_test]
    async fn test_launch_worker_remote_supported_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        // N.B. remote host needs to start first to recv and respond as
        // launch_worker blocks waiting for the response.
        let task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let result = host.launch_worker(TestWorker::ID, Default::default()).await;

        assert!(result.is_ok());
        // drop the handle to get the worker to exit.
        drop(result.unwrap());
        // drop the host (owns the send port), to get the host to exit.
        drop(host);
        task.await;
    }

    // A worker missing from the factory fails to launch.
    #[async_test]
    async fn test_launch_worker_remote_unsupported_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        // N.B. remote host needs to start first to recv and respond as
        // launch_worker blocks waiting for the response.
        let task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let result = host.launch_worker(TestWorker2::ID, ()).await;

        assert!(result.is_err());
        // drop the target (owns the send port), to get the host to exit.
        drop(host);
        task.await;
    }

    // Restarting a running worker yields `Started`, and a subsequent stop
    // yields `Stopped` followed by the end of the event stream.
    #[async_test]
    async fn test_launch_worker_remote_restart_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        // N.B. remote host needs to start first to recv and respond as
        // launch_worker blocks waiting for the response.
        let task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let result = host.launch_worker(TestWorker::ID, Default::default()).await;

        let mut handle = result.expect("worker launch failed");
        handle.restart(&host);
        assert!(matches!(handle.next().await.unwrap(), WorkerEvent::Started));
        handle.stop();
        assert!(matches!(handle.next().await.unwrap(), WorkerEvent::Stopped));

        assert!(handle.next().await.is_none());

        // drop the target (owns the send port), to get the host to exit.
        drop(host);
        task.await;
    }

    // Worker whose parameters are a plain function pointer, i.e. not a
    // `MeshPayload`; it can only be launched locally.
    struct LocalWorker;

    impl Worker for LocalWorker {
        type Parameters = fn() -> anyhow::Result<()>;
        type State = ();
        const ID: WorkerId<Self::Parameters> = WorkerId::new("local");

        fn new(parameters: Self::Parameters) -> anyhow::Result<Self> {
            parameters()?;
            Ok(Self)
        }

        fn restart(_state: Self::State) -> anyhow::Result<Self> {
            unreachable!()
        }

        fn run(self, _recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            Ok(())
        }
    }

    // A local launch works without mesh-serializable parameters, and `join`
    // completes once the worker returns.
    #[async_test]
    async fn test_launch_local_no_mesh() {
        let mut worker = launch_local_worker::<LocalWorker>(|| Ok(())).await.unwrap();
        worker.join().await.unwrap();
    }
}
912
/// Internal test support
#[doc(hidden)]
pub mod test_support {
    use std::marker::PhantomData;

    use crate::Worker;
    use crate::WorkerId;
    use crate::WorkerRpc;

    // Placeholder worker used for doc tests. It can be constructed and
    // restarted, but `run` is unimplemented (`todo!()`), so it panics if it
    // is actually run.
    pub struct DummyWorker<T>(PhantomData<T>);

    // ID shared by every `DummyWorker<T>` instantiation.
    pub const DUMMY_WORKER: WorkerId<()> = WorkerId::new("DummyWorker");

    impl<T: 'static + Send> Worker for DummyWorker<T> {
        type Parameters = ();
        type State = ();
        const ID: WorkerId<Self::Parameters> = DUMMY_WORKER;

        fn new(_parameters: Self::Parameters) -> anyhow::Result<Self> {
            Ok(Self(PhantomData))
        }

        fn restart(_state: Self::State) -> anyhow::Result<Self> {
            Ok(Self(PhantomData))
        }

        fn run(self, _recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            todo!()
        }
    }
}