mesh_worker/
worker.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Infrastructure for workers that can run on mesh nodes.
5
6use anyhow::Context;
7use futures::Stream;
8use futures::StreamExt;
9use futures::executor::block_on;
10use futures::stream::FusedStream;
11use futures_concurrency::stream::Merge;
12use inspect::Inspect;
13use mesh::MeshPayload;
14use mesh::error::RemoteError;
15use mesh::rpc::FailableRpc;
16use mesh::rpc::RpcSend;
17use std::fmt;
18use std::marker::PhantomData;
19use std::pin::Pin;
20use std::task::Poll;
21use std::thread;
22use unicycle::FuturesUnordered;
23
/// A unique identifier for a worker, used to specify which worker to launch.
///
/// `T` is the type of the worker's launch parameters. It is only a marker, so
/// `WorkerId<T>` is `Copy`, `Clone`, and `Debug` regardless of which traits
/// `T` implements.
pub struct WorkerId<T>(&'static str, PhantomData<T>);

// The following impls are written by hand instead of derived: `#[derive]`
// would add `T: Clone` / `T: Copy` / `T: Debug` bounds even though no `T`
// value is ever stored.
impl<T> Clone for WorkerId<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for WorkerId<T> {}

impl<T> fmt::Debug for WorkerId<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Produces the same output `#[derive(Debug)]` would.
        f.debug_tuple("WorkerId")
            .field(&self.0)
            .field(&self.1)
            .finish()
    }
}

impl<T> WorkerId<T> {
    /// Makes a new worker ID with the name `id`.
    pub const fn new(id: &'static str) -> Self {
        Self(id, PhantomData)
    }

    /// Gets the ID string.
    pub const fn id(&self) -> &'static str {
        self.0
    }
}

impl<T> fmt::Display for WorkerId<T> {
    /// Writes the ID string, honoring width/fill/alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.0)
    }
}
45
46/// Trait implemented by workers.
47pub trait Worker: 'static + Sized {
48    /// Parameters passed to launch the worker. Used with [`Worker::new`].
49    ///
50    /// For this worker to be spawned on a remote node, `Parameters` must
51    /// implement [`MeshPayload`].
52    type Parameters: 'static + Send;
53
54    /// State used to implement hot restart. Used with [`Worker::restart`].
55    type State: 'static + MeshPayload + Send;
56
57    /// String identifying the Worker. Used when launching workers in separate processes
58    /// to specify which workers are supported and which worker to launch.
59    /// IDs must be unique within a given worker host.
60    const ID: WorkerId<Self::Parameters>;
61
62    /// Instantiates the worker.
63    ///
64    /// The worker should not start running yet, but it can allocate any resources
65    /// necessary to run.
66    fn new(parameters: Self::Parameters) -> anyhow::Result<Self>;
67
68    /// Restarts a worker from a previous worker's execution state.
69    fn restart(state: Self::State) -> anyhow::Result<Self>;
70
71    /// Synchronously runs the worker on the current thread.
72    ///
73    /// The worker should respond to commands sent in `recv`. If `recv` is closed,
74    /// the worker should exit.
75    ///
76    /// The worker ends when it returns from this function.
77    fn run(self, recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()>;
78}
79
/// Common requests for workers.
///
/// `T` is the worker's restart state type; [`Worker::run`] receives these as
/// `WorkerRpc<Self::State>`.
#[derive(Debug, MeshPayload)]
#[mesh(bound = "T: 'static + MeshPayload + Send")]
pub enum WorkerRpc<T> {
    /// Tear down.
    Stop,
    /// Tear down and send the state necessary to restart on the provided
    /// channel.
    ///
    /// The returned state is passed to [`Worker::restart`] on the replacement
    /// worker.
    Restart(FailableRpc<(), T>),
    /// Inspect the worker.
    ///
    /// Sent when the owning [`WorkerHandle`] is inspected.
    Inspect(inspect::Deferred),
}
92
/// How the worker described by a [`WorkerLaunchRequest`] should be created.
#[derive(Debug, MeshPayload)]
enum LaunchType {
    /// Build a fresh worker.
    New {
        /// The type-erased launch parameters (the worker's
        /// [`Worker::Parameters`] value).
        parameters: mesh::OwnedMessage,
    },
    /// Replace an existing worker, taking over its channels.
    Restart {
        /// RPC sender to the worker being replaced. It is used to request that
        /// worker's state and is bridged back if the restart fails.
        send: mesh::Sender<WorkerRpc<mesh::OwnedMessage>>,
        /// Event receiver from the worker being replaced. It is bridged back if
        /// the restart fails.
        events: mesh::Receiver<WorkerEvent>,
    },
}
103
/// A runner returned by [`worker_host()`]. Used to handle worker launch
/// requests.
///
/// It holds the receiving end of the channel that [`WorkerHost`] sends launch
/// requests on. Call [`WorkerHostRunner::run`] to service them.
///
/// This may be sent across processes via mesh.
#[derive(Debug, MeshPayload)]
pub struct WorkerHostRunner(mesh::Receiver<WorkerHostLaunchRequest>);
110
111impl WorkerHostRunner {
112    /// Runs the worker host until all corresponding [`WorkerHost`] instances
113    /// have been dropped and all workers have exited.
114    ///
115    /// `factory` provides the set of possible workers to launch. Typically,
116    /// this will be [`RegisteredWorkers`].
117    pub async fn run(mut self, factory: impl WorkerFactory) {
118        let mut rundown = FuturesUnordered::new();
119        loop {
120            let mut stream = ((&mut self.0).map(Some), (&mut rundown).map(|_| None)).merge();
121            let launch_params = match stream.next().await {
122                Some(Some(launch_params)) => launch_params,
123                Some(None) => continue,
124                None => break,
125            };
126
127            let _requestspan = tracing::info_span!("worker_host_launch_request").entered();
128
129            let result = factory.builder(&launch_params.name);
130            match result {
131                Ok(runner) => {
132                    // start a new thread and run the runner.
133                    let (rundown_send, rundown_recv) = mesh::oneshot::<()>();
134                    thread::Builder::new()
135                        .name(format!("worker-{}", &launch_params.name))
136                        .spawn(move || {
137                            launch_params.request.launch(runner);
138                            drop(rundown_send);
139                        })
140                        .expect("thread launch failed");
141
142                    rundown.push(rundown_recv);
143                }
144                Err(err) => {
145                    // TODO: kharp 2021-05-26 Better tracing of errors, maybe tracing_error?
146                    launch_params.request.fail(err);
147                }
148            }
149        }
150    }
151}
152
/// Represents a running [`Worker`] instance providing the ability to restart,
/// stop or wait for exit. To launch a worker and get a handle, use
/// [`WorkerHost::launch_worker`].
///
/// The handle is also a [`Stream`] of the [`WorkerEvent`]s the worker reports.
#[derive(Debug, MeshPayload, Inspect)]
pub struct WorkerHandle {
    /// The worker's ID string, reused when restarting on another host.
    #[inspect(skip)]
    name: String,
    /// RPC channel to the worker. Inspection is forwarded over it.
    #[inspect(flatten, send = "WorkerRpc::Inspect")]
    send: mesh::Sender<WorkerRpc<mesh::OwnedMessage>>,
    /// Lifecycle events reported by the worker.
    #[inspect(skip)]
    events: mesh::Receiver<WorkerEvent>,
}
165
166impl Stream for WorkerHandle {
167    type Item = WorkerEvent;
168
169    fn poll_next(
170        mut self: Pin<&mut Self>,
171        cx: &mut std::task::Context<'_>,
172    ) -> Poll<Option<Self::Item>> {
173        Poll::Ready(match std::task::ready!(self.events.poll_recv(cx)) {
174            Ok(event) => Some(event),
175            Err(mesh::RecvError::Error(err)) => Some(WorkerEvent::Failed(RemoteError::new(err))),
176            Err(mesh::RecvError::Closed) => None,
177        })
178    }
179}
180
181impl FusedStream for WorkerHandle {
182    fn is_terminated(&self) -> bool {
183        self.events.is_terminated()
184    }
185}
186
/// A lifecycle event for a worker.
///
/// Events are received by polling a [`WorkerHandle`], which implements
/// [`Stream`] with this item type.
#[derive(Debug, MeshPayload)]
pub enum WorkerEvent {
    /// The worker has stopped without error.
    Stopped,
    /// The worker has failed.
    ///
    /// This is also reported when the worker could not be constructed, and
    /// when the event channel itself reports an error.
    Failed(RemoteError),
    /// The worker has started or restarted successfully.
    ///
    /// Not sent for workers launched with [`launch_local_worker`].
    Started,
    /// The requested restart operation failed, but the worker is still running.
    RestartFailed(RemoteError),
}
199
200impl WorkerHandle {
201    /// Requests that the worker stop.
202    pub fn stop(&mut self) {
203        self.send.send(WorkerRpc::Stop);
204    }
205
206    /// Waits until the worker has stopped.
207    pub async fn join(&mut self) -> anyhow::Result<()> {
208        while let Some(event) = self.next().await {
209            if let WorkerEvent::Failed(err) = event {
210                return Err(err.into());
211            }
212        }
213        Ok(())
214    }
215
216    /// Stops the worker, then restarts it, using the same state, on `host`.
217    ///
218    /// This can be used to upgrade a worker at runtime if `host` is a
219    /// worker host in a new process.
220    pub fn restart(&mut self, host: &WorkerHost) {
221        let (send, recv) = mesh::channel();
222        let (events_send, events) = mesh::channel();
223
224        let send = std::mem::replace(&mut self.send, send);
225        let events = std::mem::replace(&mut self.events, events);
226
227        // Launch the new worker.
228        host.launch_worker_internal(
229            &self.name,
230            recv,
231            events_send,
232            LaunchType::Restart { send, events },
233        );
234    }
235}
236
/// A handle used to launch workers on a host.
///
/// You can get an instance of this by spawning a new host with
/// [`worker_host()`]. Clones send their launch requests to the same
/// [`WorkerHostRunner`].
///
/// This may be sent across processes via mesh.
#[derive(Debug, MeshPayload, Clone)]
pub struct WorkerHost(mesh::Sender<WorkerHostLaunchRequest>);
245
246/// Returns a new [`WorkerHost`], [`WorkerHostRunner`] pair.
247///
248/// The [`WorkerHost`] is used to launch workers, while the [`WorkerHostRunner`]
249/// is used to handle worker launch requests. The caller must start
250/// [`WorkerHostRunner::run()`] on an appropriate task before `WorkerHost` will
251/// be able to launch workers.
252///
253/// This is useful over just using [`launch_local_worker`] because it provides
254/// an indirection between the identity of the workers being launched
255/// (identified via [`WorkerId`]) and the concrete worker implementation. This
256/// can be used to swap worker implementations, improve build times, and to
257/// support launching workers across process boundaries.
258///
259/// To achieve this latter feat, note that either half of the returned tuple may
260/// be sent to over a mesh channel to another process, allowing a worker to be
261/// spawned in a separate process from the caller. This can be useful for fault
262/// or resource isolation and for security sandboxing.
263///
264/// # Example
265/// ```
266/// # use mesh_worker::{worker_host, WorkerHost, WorkerHostRunner, RegisteredWorkers, register_workers};
267/// # use mesh_worker::test_support::DUMMY_WORKER as MY_WORKER;
268/// # use futures::executor::block_on;
269/// # register_workers!(mesh_worker::test_support::DummyWorker<u32>);
270/// # block_on(async {
271/// let (host, runner) = worker_host();
272/// // Run the worker host on a separate thread. (Typically this would just be
273/// // a separate task in your async framework.)
274/// std::thread::spawn(|| block_on(runner.run(RegisteredWorkers)));
275/// // Launch a worker by ID. This will call to the worker host runner.
276/// host.launch_worker(MY_WORKER, ()).await.unwrap();
277/// # })
278/// ```
279pub fn worker_host() -> (WorkerHost, WorkerHostRunner) {
280    let (send, recv) = mesh::mpsc_channel();
281    (WorkerHost(send), WorkerHostRunner(recv))
282}
283
284impl WorkerHost {
285    /// Launches a [`Worker`] instance on this host.
286    ///
287    /// Returns before the worker has finished launching. Look for the
288    /// [`WorkerEvent::Started`] event to ensure the worker did not fail to
289    /// start.
290    pub fn start_worker<T>(&self, id: WorkerId<T>, params: T) -> anyhow::Result<WorkerHandle>
291    where
292        T: 'static + MeshPayload + Send,
293    {
294        self.start_worker_inner(id.id(), mesh::OwnedMessage::new(params))
295    }
296
297    fn start_worker_inner(
298        &self,
299        id: &str,
300        parameters: mesh::OwnedMessage,
301    ) -> anyhow::Result<WorkerHandle> {
302        let (events_send, events_recv) = mesh::channel();
303        let (rpc_send, rpc_recv) = mesh::channel();
304        self.launch_worker_internal(id, rpc_recv, events_send, LaunchType::New { parameters });
305        Ok(WorkerHandle {
306            name: id.to_string(),
307            send: rpc_send,
308            events: events_recv,
309        })
310    }
311
312    /// Launches a [`Worker`] instance on this host, waiting for the worker to
313    /// start running.
314    pub async fn launch_worker<T>(&self, id: WorkerId<T>, params: T) -> anyhow::Result<WorkerHandle>
315    where
316        T: 'static + MeshPayload + Send,
317    {
318        let mut handle = self.start_worker_inner(id.id(), mesh::OwnedMessage::new(params))?;
319        match handle.next().await.context("failed to launch worker")? {
320            WorkerEvent::Started => Ok(handle),
321            WorkerEvent::Failed(err) => Err(err).context("failed to launch worker")?,
322            WorkerEvent::Stopped | WorkerEvent::RestartFailed(_) => {
323                anyhow::bail!("received invalid worker event")
324            }
325        }
326    }
327
328    fn launch_worker_internal(
329        &self,
330        id: &str,
331        rpc_recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
332        events_send: mesh::Sender<WorkerEvent>,
333        launch_type: LaunchType,
334    ) {
335        let request = WorkerHostLaunchRequest {
336            name: id.to_string(),
337            request: WorkerLaunchRequest {
338                rpc: rpc_recv,
339                events: events_send,
340                launch_type,
341            },
342        };
343
344        self.0.send(request);
345    }
346}
347
348/// Launches a worker locally.
349///
350/// When launched via this API, a worker's parameters do not have to derive
351/// `MeshPayload`.
352///
353/// # Example
354/// ```
355/// # use mesh_worker::test_support::DUMMY_WORKER;
356/// # use mesh_worker::WorkerHost;
357/// # use mesh_worker::launch_local_worker;
358/// # type MyWorker = mesh_worker::test_support::DummyWorker<u32>;
359/// # futures::executor::block_on(async {
360/// let worker = launch_local_worker::<MyWorker>(()).await.unwrap();
361/// # })
362/// ```
363pub async fn launch_local_worker<T: Worker>(
364    parameters: T::Parameters,
365) -> anyhow::Result<WorkerHandle> {
366    let (rpc_send, rpc_recv) = mesh::channel();
367    let (events_send, events_recv) = mesh::channel();
368    let (result_send, result_recv) = mesh::oneshot();
369
370    thread::Builder::new()
371        .name(format!("worker-{}", &T::ID.id()))
372        .spawn(move || match T::new(parameters) {
373            Ok(worker) => {
374                result_send.send(Ok(()));
375                match worker.run(rpc_recv) {
376                    Ok(()) => {
377                        events_send.send(WorkerEvent::Stopped);
378                    }
379                    Err(err) => {
380                        events_send.send(WorkerEvent::Failed(RemoteError::new(err)));
381                    }
382                }
383            }
384            Err(err) => {
385                result_send.send(Err(err));
386            }
387        })
388        .expect("thread launch failed");
389
390    result_recv.await.unwrap()?;
391    Ok(WorkerHandle {
392        name: T::ID.id().to_owned(),
393        // Erase the type of the worker state.
394        send: mesh::local_node::Port::from(rpc_send).into(),
395        events: events_recv,
396    })
397}
398
/// Trait implemented by a type that can dispatch requests to a worker.
///
/// This trait is generally not used directly. Instead, use either
/// [`RegisteredWorkers`], or generate a factory type with the
/// [`crate::runnable_workers!`] macro.
pub trait WorkerFactory: 'static + Send + Sync {
    /// Returns a builder for the worker with the given name.
    ///
    /// `name` is a worker ID string (see [`WorkerId::id`]). The factories in
    /// this crate return an error for names they do not support.
    fn builder(&self, name: &str) -> anyhow::Result<WorkerBuilder>;
}
408
/// The worker side of a launch: the channels the new worker will use, and how
/// to create it.
#[derive(Debug, MeshPayload)]
struct WorkerLaunchRequest {
    /// Receives RPCs sent through the [`WorkerHandle`].
    rpc: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
    /// Sends lifecycle events to the [`WorkerHandle`].
    events: mesh::Sender<WorkerEvent>,
    /// Whether to construct a new worker or restart an existing one.
    launch_type: LaunchType,
}
415
impl WorkerLaunchRequest {
    /// Reports that no worker could be built for this request (the host's
    /// factory returned `err`).
    ///
    /// For a new launch, the handle receives [`WorkerEvent::Failed`]. For a
    /// restart, the handle receives [`WorkerEvent::RestartFailed`], and its
    /// channels are bridged back to the previous worker, which was never asked
    /// to restart.
    fn fail(self, err: anyhow::Error) {
        match self.launch_type {
            LaunchType::New { .. } => {
                self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
            }
            LaunchType::Restart { send, events } => {
                // Report the error and revert communications to the old worker.
                self.events
                    .send(WorkerEvent::RestartFailed(RemoteError::new(err)));
                self.rpc.bridge(send);
                self.events.bridge(events);
            }
        }
    }

    /// Builds the worker with `builder` and runs it on the current thread
    /// until it returns.
    ///
    /// Sends [`WorkerEvent::Started`] once the worker is built, then
    /// [`WorkerEvent::Stopped`] or [`WorkerEvent::Failed`] depending on the
    /// result of `run`. A build error is reported as [`WorkerEvent::Failed`].
    /// For a restart, the previous worker is first asked for its state via
    /// [`WorkerRpc::Restart`]. If that request fails, the handle receives
    /// [`WorkerEvent::RestartFailed`] and is bridged back to the previous
    /// worker.
    fn launch(self, builder: WorkerBuilder) {
        let worker = match self.launch_type {
            LaunchType::New { parameters } => {
                let _span =
                    tracing::info_span!("worker_new", name = builder.id, action = "new").entered();
                match builder.build_and_run(BuildRequest::New(parameters)) {
                    Ok(worker) => worker,
                    Err(err) => {
                        self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
                        return;
                    }
                }
            }
            LaunchType::Restart { send, events } => {
                // Ask the old worker for its state, blocking the current thread
                // until it answers.
                let state_recv = send.call_failable(WorkerRpc::Restart, ());
                let state = match block_on(state_recv) {
                    Ok(state) => state,
                    Err(err) => {
                        self.events
                            .send(WorkerEvent::RestartFailed(RemoteError::new(err)));
                        // Revert communications to the old worker.
                        self.events.bridge(events);
                        self.rpc.bridge(send);
                        return;
                    }
                };
                let _span =
                    tracing::info_span!("worker_new", name = builder.id, action = "restart")
                        .entered();
                // The old worker's `send`/`events` are dropped at the end of this
                // arm. From here on the handle talks only to the new worker.
                match builder.build_and_run(BuildRequest::Restart(state)) {
                    Ok(worker) => worker,
                    Err(err) => {
                        self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
                        return;
                    }
                }
            }
        };

        self.events.send(WorkerEvent::Started);
        match worker.run(self.rpc) {
            Ok(()) => {
                self.events.send(WorkerEvent::Stopped);
            }
            Err(err) => {
                self.events.send(WorkerEvent::Failed(RemoteError::new(err)));
            }
        }
    }
}
482
483/// A builder for a worker.
484pub struct WorkerBuilder {
485    inner: Box<dyn WorkerBuildAndRun>,
486    id: &'static str,
487}
488
489impl WorkerBuilder {
490    /// Returns a builder for `T`.
491    pub fn new<T: Worker>() -> Self
492    where
493        T::Parameters: MeshPayload,
494    {
495        Self {
496            inner: Box::new(BuilderInner::<T>(PhantomData)),
497            id: T::ID.id(),
498        }
499    }
500
501    fn build_and_run(self, request: BuildRequest) -> anyhow::Result<Box<dyn Run>> {
502        self.inner.build_and_run(request)
503    }
504}
505
/// Request passed to a type-erased worker builder.
#[doc(hidden)]
pub enum BuildRequest {
    /// Construct a new worker from serialized launch parameters.
    New(mesh::OwnedMessage),
    /// Construct a worker from the serialized state of a previous instance.
    Restart(mesh::OwnedMessage),
}

/// Type-erased builder for worker type `T`.
///
/// `PhantomData<fn() -> T>` keeps this type `Send` (as `WorkerBuildAndRun`
/// requires) whatever `T` is, since no `T` is actually stored.
struct BuilderInner<T: Worker>(PhantomData<fn() -> T>);

/// Object-safe interface for constructing a worker from a [`BuildRequest`].
trait WorkerBuildAndRun: Send {
    fn build_and_run(self: Box<Self>, request: BuildRequest) -> anyhow::Result<Box<dyn Run>>;
}

/// Object-safe interface for running a constructed worker with a receiver
/// whose state type has been erased.
trait Run {
    fn run(
        self: Box<Self>,
        recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
    ) -> anyhow::Result<()>;
}
524
525impl<T: Worker> Run for T {
526    fn run(
527        self: Box<Self>,
528        recv: mesh::Receiver<WorkerRpc<mesh::OwnedMessage>>,
529    ) -> anyhow::Result<()> {
530        // Unerase the type of the worker state.
531        let recv = mesh::local_node::Port::from(recv).into();
532        Worker::run(*self, recv)
533    }
534}
535
536impl<T: Worker> WorkerBuildAndRun for BuilderInner<T>
537where
538    T::Parameters: MeshPayload,
539{
540    fn build_and_run(self: Box<Self>, request: BuildRequest) -> anyhow::Result<Box<dyn Run>> {
541        let worker = match request {
542            BuildRequest::New(parameters) => {
543                T::new(parameters.parse().context("failed to receive parameters")?)
544            }
545            BuildRequest::Restart(state) => T::restart(
546                // Unerase the type of the worker state.
547                state
548                    .serialize()
549                    .into_message()
550                    .context("failed to parse restart state")?,
551            ),
552        }?;
553
554        Ok(Box::new(worker))
555    }
556}
557
/// Generates a type that defines the set of workers that can be run by a
/// worker host. The type matches a requested worker name against each listed
/// worker's [`Worker::ID`] and returns a builder for the first match.
///
/// The resulting type is an empty struct implementing the [`WorkerFactory`]
/// trait. An optional visibility (for example `pub`) may precede the type
/// name. Without one, the struct is private to the invoking module.
///
/// This is used to enumerate the list of worker types a host can instantiate.
///
/// Workers can be conditionally enabled by tagging them with a corresponding `#[cfg]` attr.
///
/// The generated code refers to `anyhow` directly, so the invoking crate must
/// depend on it.
///
/// # Example
///
/// ```no_run
/// # use mesh_worker::test_support;
/// # use mesh_worker::runnable_workers;
/// # type MyWorker1 = test_support::DummyWorker<u32>;
/// # type MyWorker2 = test_support::DummyWorker<i32>;
/// runnable_workers! {
///     RunnableWorkers {
///         MyWorker1,
///         #[cfg(unix)]
///         MyWorker2,
///     }
/// }
/// ```
#[macro_export]
macro_rules! runnable_workers {
    (
        $vis:vis $name:ident {
            $($(#[$vattr:meta])* $worker:ty),*$(,)?
        }
    ) => {

        #[derive(Debug, Clone)]
        $vis struct $name;

        impl $crate::WorkerFactory for $name {
            fn builder(&self, name: &str) -> anyhow::Result<$crate::WorkerBuilder> {
                $(
                    $(#[$vattr])*
                    {
                        if name == <$worker as $crate::Worker>::ID.id() {
                            return Ok($crate::WorkerBuilder::new::<$worker>());
                        }
                    }
                )*

                anyhow::bail!("unsupported worker {name}")
            }
        }
    };
}
609
610#[doc(hidden)]
611pub mod private {
612    // UNSAFETY: Needed for linkme.
613    #![expect(unsafe_code)]
614
615    use super::RegisteredWorkers;
616    use super::WorkerFactory;
617    use crate::Worker;
618    use crate::WorkerBuilder;
619    pub use linkme;
620    use mesh::MeshPayload;
621
622    // Use Option<&X> in case the linker inserts some stray nulls, as we think
623    // it might on Windows.
624    //
625    // See <https://devblogs.microsoft.com/oldnewthing/20181108-00/?p=100165>.
626    #[linkme::distributed_slice]
627    pub static WORKERS: [Option<&'static RegisteredWorker>] = [..];
628
629    // Always have at least one entry to work around linker bugs.
630    //
631    // See <https://github.com/llvm/llvm-project/issues/65855>.
632    #[linkme::distributed_slice(WORKERS)]
633    static WORKAROUND: Option<&'static RegisteredWorker> = None;
634
635    pub struct RegisteredWorker {
636        id: &'static str,
637        build: fn() -> WorkerBuilder,
638    }
639
640    impl RegisteredWorker {
641        pub const fn new<T: Worker>() -> Self
642        where
643            T::Parameters: MeshPayload,
644        {
645            Self {
646                id: T::ID.id(),
647                build: WorkerBuilder::new::<T>,
648            }
649        }
650    }
651
652    impl WorkerFactory for RegisteredWorkers {
653        fn builder(&self, name: &str) -> anyhow::Result<WorkerBuilder> {
654            for worker in WORKERS.iter().flatten() {
655                if worker.id == name {
656                    return Ok((worker.build)());
657                }
658            }
659            anyhow::bail!("unsupported worker {name}")
660        }
661    }
662
663    /// Registers workers for use with
664    /// [`RegisteredWorkers`](super::RegisteredWorkers).
665    ///
666    /// You can invoke this macro multiple times, even from different crates.
667    /// All registered workers will be available from any user of
668    /// `RegisteredWorkers`.
669    #[macro_export]
670    macro_rules! register_workers {
671        {} => {};
672        { $($(#[$attr:meta])* $worker:ty),+ $(,)? } => {
673            $(
674            $(#[$attr])*
675            const _: () = {
676                use $crate::private;
677                use private::linkme;
678
679                #[linkme::distributed_slice(private::WORKERS)]
680                #[linkme(crate = linkme)]
681                static WORKER: Option<&'static private::RegisteredWorker> = Some(&private::RegisteredWorker::new::<$worker>());
682            };
683            )*
684        };
685    }
686}
687
/// A worker factory that can build any worker built with
/// [`register_workers`](crate::register_workers).
///
/// Requested names are matched against each registered worker's
/// [`Worker::ID`].
///
/// ```
/// # use mesh_worker::register_workers;
/// # use mesh_worker::RegisteredWorkers;
/// # use futures::executor::block_on;
/// # type MyWorker1 = mesh_worker::test_support::DummyWorker<u32>;
/// # type MyWorker2 = mesh_worker::test_support::DummyWorker<i32>;
/// register_workers! {
///     MyWorker1,
///     MyWorker2,
/// }
///
/// // Construct a worker host for these workers.
/// let (host, runner) = mesh_worker::worker_host();
/// std::thread::spawn(|| block_on(runner.run(RegisteredWorkers)));
/// ```
#[derive(Debug, Clone)]
pub struct RegisteredWorkers;
708
/// A request to launch a worker on a host.
#[derive(Debug, MeshPayload)]
struct WorkerHostLaunchRequest {
    /// Name of the worker to launch: its ID string, as returned by
    /// [`WorkerId::id`].
    name: String,
    /// Request parameters.
    request: WorkerLaunchRequest,
}
717
#[cfg(test)]
mod tests {
    use super::Worker;
    use super::WorkerFactory;
    use super::WorkerId;
    use super::WorkerRpc;
    use crate::launch_local_worker;
    use crate::worker::WorkerEvent;
    use futures::StreamExt;
    use futures::executor::block_on;
    use mesh::MeshPayload;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use test_with_tracing::test;

    /// Worker that hands its configured `value` back as restart state.
    struct TestWorker {
        value: u64,
    }

    #[derive(MeshPayload, Default)]
    struct TestWorkerConfig {
        pub value: u64,
    }

    #[derive(MeshPayload)]
    struct TestWorkerState {
        pub value: u64,
    }

    impl Worker for TestWorker {
        type Parameters = TestWorkerConfig;
        type State = TestWorkerState;
        const ID: WorkerId<Self::Parameters> = WorkerId::new("TestWorker");

        fn new(parameters: Self::Parameters) -> anyhow::Result<Self> {
            let TestWorkerConfig { value } = parameters;
            Ok(Self { value })
        }

        fn restart(state: Self::State) -> anyhow::Result<Self> {
            let TestWorkerState { value } = state;
            Ok(Self { value })
        }

        fn run(self, mut recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            block_on(async {
                // Serve requests until stopped, restarted, or disconnected.
                loop {
                    let Ok(req) = recv.recv().await else {
                        break;
                    };
                    match req {
                        WorkerRpc::Stop => break,
                        WorkerRpc::Restart(rpc) => {
                            let state = TestWorkerState { value: self.value };
                            rpc.complete(Ok(state));
                            break;
                        }
                        WorkerRpc::Inspect(_deferred) => {}
                    }
                }
                Ok(())
            })
        }
    }

    /// Worker that is deliberately not listed in `RunnableWorkers1`.
    struct TestWorker2;

    impl Worker for TestWorker2 {
        type Parameters = ();
        type State = ();
        const ID: WorkerId<Self::Parameters> = WorkerId::new("TestWorker2");

        fn new(_parameters: Self::Parameters) -> anyhow::Result<Self> {
            Ok(Self)
        }

        fn restart(_state: ()) -> anyhow::Result<Self> {
            Ok(Self)
        }

        fn run(self, mut recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            block_on(async {
                loop {
                    let Ok(req) = recv.recv().await else {
                        break;
                    };
                    match req {
                        WorkerRpc::Stop => break,
                        WorkerRpc::Restart(rpc) => {
                            rpc.complete(Ok(()));
                            break;
                        }
                        WorkerRpc::Inspect(_deferred) => {}
                    }
                }
                Ok(())
            })
        }
    }

    runnable_workers! {
        RunnableWorkers1 {
            TestWorker,
        }
    }

    #[test]
    fn test_runnable_workers_unsupported() {
        assert!(RunnableWorkers1.builder("foobar").is_err());
    }

    #[async_test]
    async fn test_launch_worker_remote_supported_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        // N.B. The runner must be spawned first, because launch_worker waits
        // for it to respond.
        let task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let result = host.launch_worker(TestWorker::ID, Default::default()).await;
        assert!(result.is_ok());

        // Dropping the handle makes the worker exit. Dropping the host (which
        // owns the send port) then lets the runner exit.
        drop(result.unwrap());
        drop(host);
        task.await;
    }

    #[async_test]
    async fn test_launch_worker_remote_unsupported_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        // N.B. The runner must be spawned first, because launch_worker waits
        // for it to respond.
        let task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let result = host.launch_worker(TestWorker2::ID, ()).await;
        assert!(result.is_err());

        // Drop the host (owns the send port) so the runner exits.
        drop(host);
        task.await;
    }

    #[async_test]
    async fn test_launch_worker_remote_restart_worker(driver: DefaultDriver) {
        let (host, runner) = super::worker_host();

        // N.B. The runner must be spawned first, because launch_worker waits
        // for it to respond.
        let task = driver.spawn("runner", runner.run(RunnableWorkers1));

        let mut handle = host
            .launch_worker(TestWorker::ID, Default::default())
            .await
            .expect("worker launch failed");

        handle.restart(&host);
        assert!(matches!(handle.next().await.unwrap(), WorkerEvent::Started));

        handle.stop();
        assert!(matches!(handle.next().await.unwrap(), WorkerEvent::Stopped));
        assert!(handle.next().await.is_none());

        // Drop the host (owns the send port) so the runner exits.
        drop(host);
        task.await;
    }

    /// Worker whose parameters (a plain function pointer) are not `MeshPayload`.
    struct LocalWorker;

    impl Worker for LocalWorker {
        type Parameters = fn() -> anyhow::Result<()>;
        type State = ();
        const ID: WorkerId<Self::Parameters> = WorkerId::new("local");

        fn new(parameters: Self::Parameters) -> anyhow::Result<Self> {
            parameters()?;
            Ok(Self)
        }

        fn restart(_state: Self::State) -> anyhow::Result<Self> {
            unreachable!()
        }

        fn run(self, _recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
            Ok(())
        }
    }

    #[async_test]
    async fn test_launch_local_no_mesh() {
        let params: fn() -> anyhow::Result<()> = || Ok(());
        let mut worker = launch_local_worker::<LocalWorker>(params).await.unwrap();
        worker.join().await.unwrap();
    }
}
909
910/// Internal test support
911#[doc(hidden)]
912pub mod test_support {
913    use std::marker::PhantomData;
914
915    use crate::Worker;
916    use crate::WorkerId;
917    use crate::WorkerRpc;
918
919    // Worker that always fails. Used for doc tests.
920    pub struct DummyWorker<T>(PhantomData<T>);
921
922    pub const DUMMY_WORKER: WorkerId<()> = WorkerId::new("DummyWorker");
923
924    impl<T: 'static + Send> Worker for DummyWorker<T> {
925        type Parameters = ();
926        type State = ();
927        const ID: WorkerId<Self::Parameters> = DUMMY_WORKER;
928
929        fn new(_parameters: Self::Parameters) -> anyhow::Result<Self> {
930            Ok(Self(PhantomData))
931        }
932
933        fn restart(_state: Self::State) -> anyhow::Result<Self> {
934            Ok(Self(PhantomData))
935        }
936
937        fn run(self, _recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
938            todo!()
939        }
940    }
941}