state_unit/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! VM state machine unit handling.
5//!
6//! A state unit is a VM component (such as a device) that needs to react to
7//! changes in the VM state machine. It needs to start, stop, reset, and
8//! save/restore with the VM. (Save/restore is not really a state change but is
9//! modeled as such since it must be synchronized with actual state changes.)
10//!
11//! This module contains types and functions for defining and manipulating state
12//! units. It does this in three parts:
13//!
14//! 1. It defines an RPC enum [`StateRequest`] which is used to request that a
15//!    state unit change state (start, stop, etc.). Each state unit must handle
16//!    incoming state requests on a mesh receiver. This is the foundational
17//!    type of this model.
18//!
19//! 2. It defines a type [`StateUnits`], which is a collection of mesh senders
20//!    for sending `StateRequest`s. This is used to initiate and wait for state
21//!    changes across all the units in the VMM, handling any required dependency
22//!    ordering.
23//!
24//! 3. It defines a trait [`StateUnit`] that can be used to define handlers for
25//!    each state request. This is an optional convenience; not all state units
26//!    will have a type that implements this trait.
27//!
28//! This model allows for asynchronous, highly concurrent state changes, and it
29//! works across process boundaries thanks to `mesh`.
30
31use futures::FutureExt;
32use futures::StreamExt;
33use futures::future::join_all;
34use futures_concurrency::stream::Merge;
35use inspect::Inspect;
36use inspect::InspectMut;
37use mesh::MeshPayload;
38use mesh::Receiver;
39use mesh::Sender;
40use mesh::payload::Protobuf;
41use mesh::rpc::FailableRpc;
42use mesh::rpc::Rpc;
43use mesh::rpc::RpcError;
44use mesh::rpc::RpcSend;
45use pal_async::task::Spawn;
46use pal_async::task::Task;
47use parking_lot::Mutex;
48use std::collections::BTreeMap;
49use std::collections::HashMap;
50use std::collections::hash_map;
51use std::fmt::Debug;
52use std::fmt::Display;
53use std::future::Future;
54use std::pin::pin;
55use std::sync::Arc;
56use std::sync::Weak;
57use std::sync::atomic::AtomicU32;
58use std::sync::atomic::Ordering;
59use std::time::Instant;
60use thiserror::Error;
61use tracing::Instrument;
62use vmcore::save_restore::RestoreError;
63use vmcore::save_restore::SaveError;
64use vmcore::save_restore::SavedStateBlob;
65
/// A state change request.
///
/// Every variant except [`StateRequest::Inspect`] wraps an RPC object; the
/// unit handling the request completes that RPC with the result of the
/// corresponding operation (see [`StateRequest::apply`]).
#[derive(Debug, MeshPayload)]
pub enum StateRequest {
    /// Start asynchronous operations.
    Start(Rpc<(), ()>),

    /// Stop asynchronous operations.
    Stop(Rpc<(), ()>),

    /// Reset a stopped unit to initial state.
    Reset(FailableRpc<(), ()>),

    /// Save state of a stopped unit.
    ///
    /// The response is `None` when the unit has no state to save.
    Save(FailableRpc<(), Option<SavedStateBlob>>),

    /// Restore state of a stopped unit.
    Restore(FailableRpc<SavedStateBlob, ()>),

    /// Perform any post-restore actions.
    ///
    /// Sent after all dependencies have been restored but before starting.
    PostRestore(FailableRpc<(), ()>),

    /// Inspect state.
    Inspect(inspect::Deferred),
}
92
/// Trait implemented by an object that can act as a state unit.
///
/// Implementing this is optional, to be used with [`UnitBuilder::spawn`] or
/// [`StateRequest::apply`]; state units can also directly process incoming
/// [`StateRequest`]s.
///
/// Each method corresponds to one [`StateRequest`] variant; inspection is
/// provided through the `InspectMut` supertrait.
#[expect(async_fn_in_trait)] // Don't need Send bounds
pub trait StateUnit: InspectMut {
    /// Start asynchronous processing.
    async fn start(&mut self);

    /// Stop asynchronous processing.
    async fn stop(&mut self);

    /// Reset to initial state.
    ///
    /// Must only be called while stopped.
    async fn reset(&mut self) -> anyhow::Result<()>;

    /// Save state.
    ///
    /// Returns `Ok(None)` if the unit has no state to save.
    ///
    /// Must only be called while stopped.
    async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError>;

    /// Restore state.
    ///
    /// Must only be called while stopped.
    async fn restore(&mut self, buffer: SavedStateBlob) -> Result<(), RestoreError>;

    /// Complete the restore process, after all dependencies have been restored.
    ///
    /// The default implementation does nothing and succeeds.
    ///
    /// Must only be called while stopped.
    async fn post_restore(&mut self) -> anyhow::Result<()> {
        Ok(())
    }
}
128
129/// Runs a simple unit that only needs to respond to state requests.
130pub async fn run_unit<T: StateUnit>(mut unit: T, mut recv: Receiver<StateRequest>) -> T {
131    while let Some(req) = recv.next().await {
132        req.apply(&mut unit).await;
133    }
134    unit
135}
136
137/// Runs a state unit that can handle inspect requests while there is an active
138/// state transition.
139pub async fn run_async_unit<T>(unit: T, mut recv: Receiver<StateRequest>) -> T
140where
141    for<'a> &'a T: StateUnit,
142{
143    while let Some(req) = recv.next().await {
144        req.apply_with_concurrent_inspects(&mut &unit, &mut recv)
145            .await;
146    }
147    unit
148}
149
150impl StateRequest {
151    /// Runs this state request against `unit`, polling `recv` for incoming
152    /// inspect requests and applying them while any state transition is in
153    /// flight.
154    ///
155    /// For this to work, your state unit `T` should implement [`StateUnit`] for
156    /// `&'_ T`.
157    ///
158    /// Panics if a state transition arrives on `recv` while this one is being
159    /// processed. This would indicate a contract violation with [`StateUnits`].
160    pub async fn apply_with_concurrent_inspects<'a, T>(
161        self,
162        unit: &mut &'a T,
163        recv: &mut Receiver<StateRequest>,
164    ) where
165        &'a T: StateUnit,
166    {
167        match self {
168            StateRequest::Inspect(_) => {
169                // This request has no response and completes synchronously,
170                // so don't wait for concurrent requests.
171                self.apply(unit).await;
172            }
173
174            StateRequest::Start(_)
175            | StateRequest::Stop(_)
176            | StateRequest::Reset(_)
177            | StateRequest::Save(_)
178            | StateRequest::Restore(_)
179            | StateRequest::PostRestore(_) => {
180                // Handle for concurrent inspect requests.
181                enum Event {
182                    OpDone,
183                    Req(StateRequest),
184                }
185                let mut op_unit = *unit;
186                let op = pin!(self.apply(&mut op_unit).into_stream());
187                let mut stream = (op.map(|()| Event::OpDone), recv.map(Event::Req)).merge();
188                while let Some(Event::Req(next_req)) = stream.next().await {
189                    match next_req {
190                        StateRequest::Inspect(req) => req.inspect(&mut *unit),
191                        _ => panic!(
192                            "unexpected state transition {next_req:?} during state transition"
193                        ),
194                    }
195                }
196            }
197        }
198    }
199
200    /// Runs this state request against `unit`.
201    pub async fn apply(self, unit: &mut impl StateUnit) {
202        match self {
203            StateRequest::Start(rpc) => rpc.handle(async |()| unit.start().await).await,
204            StateRequest::Stop(rpc) => rpc.handle(async |()| unit.stop().await).await,
205            StateRequest::Reset(rpc) => rpc.handle_failable(async |()| unit.reset().await).await,
206            StateRequest::Save(rpc) => rpc.handle_failable(async |()| unit.save().await).await,
207            StateRequest::Restore(rpc) => {
208                rpc.handle_failable(async |buffer| unit.restore(buffer).await)
209                    .await
210            }
211            StateRequest::PostRestore(rpc) => {
212                rpc.handle_failable(async |()| unit.post_restore().await)
213                    .await
214            }
215            StateRequest::Inspect(req) => req.inspect(unit),
216        }
217    }
218}
219
/// A set of state units.
#[derive(Debug)]
pub struct StateUnits {
    /// The shared unit registry. Unit handles and inspectors hold weak
    /// references to it.
    inner: Arc<Mutex<Inner>>,
    /// Set by [`StateUnits::start`] and cleared by [`StateUnits::stop`].
    running: bool,
}
226
/// The state of an individual unit, as tracked by [`StateUnits`].
///
/// `Stopped` and `Running` are the resting states; the remaining variants are
/// interim states a unit is placed in while the corresponding operation is in
/// flight.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Inspect)]
enum State {
    Stopped,
    Starting,
    Running,
    Stopping,
    Resetting,
    Saving,
    Restoring,
    PostRestoring,
}
238
/// The lock-protected registry of units.
#[derive(Debug)]
struct Inner {
    /// The ID to assign to the next unit that is added.
    next_id: u64,
    /// The registered units, keyed by ID.
    units: BTreeMap<u64, Unit>,
    /// Maps each registered unit name to its ID.
    names: HashMap<Arc<str>, u64>,
}
245
/// The bookkeeping for a single registered unit.
#[derive(Debug)]
struct Unit {
    /// The unit's name, which is also its save ID.
    name: Arc<str>,
    /// The channel used to send state requests to the unit.
    send: Sender<StateRequest>,
    /// IDs of the units this unit depends on.
    dependencies: Vec<u64>,
    /// IDs of the units that depend on this unit.
    dependents: Vec<u64>,
    /// The unit's current state.
    state: State,
}
254
/// An error returned when a state unit name is already in use.
///
/// The wrapped value is the rejected name.
#[derive(Debug, Error)]
#[error("state unit name {0} is in use")]
pub struct NameInUse(Arc<str>);
259
/// An error produced when the RPC carrying a state request to a unit fails.
#[derive(Debug, Error)]
#[error("critical unit communication failure: {name}")]
struct UnitRecvError {
    /// The name of the unit that could not be reached.
    name: Arc<str>,
    /// The underlying RPC failure.
    #[source]
    source: RpcError,
}
267
/// Identifies a registered unit by both its name and its numeric ID.
#[derive(Debug, Clone)]
struct UnitId {
    /// The unit's registered name.
    name: Arc<str>,
    /// The unit's key in the registry.
    id: u64,
}
273
/// A handle returned by [`StateUnits::add`], used to remove the state unit.
///
/// Dropping the handle removes the unit unless the handle has been detached.
#[must_use]
#[derive(Debug)]
pub struct UnitHandle {
    /// The identity of the unit this handle refers to.
    id: UnitId,
    /// The registry the unit lives in; `None` once the handle has been
    /// detached or the removal has already been attempted.
    inner: Option<Weak<Mutex<Inner>>>,
}
281
impl Drop for UnitHandle {
    /// Removes the unit from its registry, unless the handle was detached or
    /// the unit was already removed through this handle.
    fn drop(&mut self) {
        self.remove_if();
    }
}
287
288impl UnitHandle {
289    /// Remove the state unit.
290    pub fn remove(mut self) {
291        self.remove_if();
292    }
293
294    /// Detach this handle, leaving the unit in place indefinitely.
295    pub fn detach(mut self) {
296        self.inner = None;
297    }
298
299    fn remove_if(&mut self) {
300        if let Some(inner) = self.inner.take().and_then(|inner| inner.upgrade()) {
301            let mut inner = inner.lock();
302            inner.units.remove(&self.id.id).expect("unit exists");
303            inner.names.remove(&self.id.name).expect("unit exists");
304        }
305    }
306}
307
/// An object returned by [`StateUnits::inspector`] to inspect state units while
/// state transitions may be in flight.
pub struct StateUnitsInspector {
    /// A weak reference to the unit registry, so that the inspector does not
    /// keep it alive.
    inner: Weak<Mutex<Inner>>,
}
313
314impl Inspect for StateUnits {
315    fn inspect(&self, req: inspect::Request<'_>) {
316        self.inner.lock().inspect(req);
317    }
318}
319
320impl Inspect for StateUnitsInspector {
321    fn inspect(&self, req: inspect::Request<'_>) {
322        if let Some(inner) = self.inner.upgrade() {
323            inner.lock().inspect(req);
324        }
325    }
326}
327
328impl Inspect for Inner {
329    fn inspect(&self, req: inspect::Request<'_>) {
330        let mut resp = req.respond();
331        for unit in self.units.values() {
332            resp.child(unit.name.as_ref(), |req| {
333                let mut resp = req.respond();
334                if !unit.dependencies.is_empty() {
335                    resp.field_with("dependencies", || {
336                        unit.dependencies
337                            .iter()
338                            .map(|id| self.units[id].name.as_ref())
339                            .collect::<Vec<_>>()
340                            .join(",")
341                    });
342                }
343                if !unit.dependents.is_empty() {
344                    resp.field_with("dependents", || {
345                        unit.dependents
346                            .iter()
347                            .map(|id| self.units[id].name.as_ref())
348                            .collect::<Vec<_>>()
349                            .join(",")
350                    });
351                }
352                resp.field("unit_state", unit.state);
353                unit.send
354                    .send(StateRequest::Inspect(resp.request().defer()))
355            });
356        }
357    }
358}
359
/// The saved state for an individual unit.
///
/// Produced by [`StateUnits::save`] and consumed by [`StateUnits::restore`].
#[derive(Protobuf)]
#[mesh(package = "state_unit")]
pub struct SavedStateUnit {
    /// The name of the state unit.
    #[mesh(1)]
    pub name: String,
    /// The opaque saved state blob.
    #[mesh(2)]
    pub state: SavedStateBlob,
}
371
/// An error from a state transition.
#[derive(Debug, Error)]
#[error("{op} failed")]
pub struct StateTransitionError {
    /// The name of the operation that failed.
    op: &'static str,
    /// The per-unit failures that caused the operation to fail.
    #[source]
    errors: UnitErrorSet,
}
380
381fn extract<T, E: Into<anyhow::Error>, U>(
382    op: &'static str,
383    iter: impl IntoIterator<Item = (Arc<str>, Result<T, E>)>,
384    mut f: impl FnMut(Arc<str>, T) -> Option<U>,
385) -> Result<Vec<U>, StateTransitionError> {
386    let mut result = Vec::new();
387    let mut errors = Vec::new();
388    for (name, item) in iter {
389        match item {
390            Ok(t) => {
391                if let Some(u) = f(name, t) {
392                    result.push(u);
393                }
394            }
395            Err(err) => errors.push((name, err.into())),
396        }
397    }
398    if errors.is_empty() {
399        Ok(result)
400    } else {
401        Err(StateTransitionError {
402            op,
403            errors: UnitErrorSet(errors),
404        })
405    }
406}
407
408fn check<E: Into<anyhow::Error>>(
409    op: &'static str,
410    iter: impl IntoIterator<Item = (Arc<str>, Result<(), E>)>,
411) -> Result<(), StateTransitionError> {
412    extract(op, iter, |_, _| Some(()))?;
413    Ok(())
414}
415
/// A list of `(unit name, error)` pairs describing which units failed an
/// operation and why.
#[derive(Debug)]
struct UnitErrorSet(Vec<(Arc<str>, anyhow::Error)>);
418
419impl Display for UnitErrorSet {
420    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
421        let mut map = f.debug_map();
422        for (name, err) in &self.0 {
423            map.entry(&format_args!("{}", name), &format_args!("{:#}", err));
424        }
425        map.finish()
426    }
427}
428
// Lets `UnitErrorSet` serve as the `#[source]` of `StateTransitionError`.
impl std::error::Error for UnitErrorSet {}
430
431impl StateUnits {
432    /// Creates a new instance with no initial units.
433    pub fn new() -> Self {
434        Self {
435            inner: Arc::new(Mutex::new(Inner {
436                next_id: 0,
437                units: BTreeMap::new(),
438                names: HashMap::new(),
439            })),
440            running: false,
441        }
442    }
443
444    /// Returns an inspector that can be used to inspect the state units while
445    /// state transitions are in process.
446    pub fn inspector(&self) -> StateUnitsInspector {
447        StateUnitsInspector {
448            inner: Arc::downgrade(&self.inner),
449        }
450    }
451
452    /// Save and restore will use `name` as the save ID, so this forms part of
453    /// the saved state.
454    ///
455    /// Note that the added unit will not be running after it is built/spawned,
456    /// even if the other units are running. Call
457    /// [`StateUnits::start_stopped_units`] when finished adding units.
458    pub fn add(&self, name: impl Into<Arc<str>>) -> UnitBuilder<'_> {
459        UnitBuilder {
460            units: self,
461            name: name.into(),
462            dependencies: Vec::new(),
463            dependents: Vec::new(),
464        }
465    }
466
    /// Check if state units are currently running.
    ///
    /// This reflects the last call to [`StateUnits::start`] or
    /// [`StateUnits::stop`], not the state of any individual unit.
    pub fn is_running(&self) -> bool {
        self.running
    }
471
472    /// Starts any units that are individually stopped, because they were added
473    /// via [`StateUnits::add`] while the VM was running.
474    ///
475    /// Does nothing if all units are stopped, via [`StateUnits::stop`].
476    pub async fn start_stopped_units(&mut self) {
477        if self.is_running() {
478            self.start().await;
479        }
480    }
481
    /// Starts all the state units.
    ///
    /// Each unit is started only after the units it depends on have started.
    /// Units that are already running are left alone.
    pub async fn start(&mut self) {
        // Start in dependency order: each unit waits on its dependencies.
        self.run_op(
            "start",
            None,
            State::Stopped,
            State::Starting,
            State::Running,
            StateRequest::Start,
            |_, _| Some(()),
            |unit| &unit.dependencies,
        )
        .await;
        self.running = true;
    }
497
    /// Stops all the state units.
    ///
    /// Each unit is stopped only after the units that depend on it have
    /// stopped. Units that are already stopped are left alone.
    ///
    /// Panics if not running.
    pub async fn stop(&mut self) {
        assert!(self.running);
        // Stop units in reverse dependency order so that a dependency is not
        // stopped before its dependant.
        self.run_op(
            "stop",
            None,
            State::Running,
            State::Stopping,
            State::Stopped,
            StateRequest::Stop,
            |_, _| Some(()),
            |unit| &unit.dependents,
        )
        .await;
        self.running = false;
    }
516
517    /// Resets all the state units.
518    ///
519    /// Panics if running.
520    pub async fn reset(&mut self) -> Result<(), StateTransitionError> {
521        assert!(!self.running);
522        // Reset in dependency order so that dependants observe their
523        // dependencies' reset state.
524        let r = self
525            .run_op(
526                "reset",
527                None,
528                State::Stopped,
529                State::Resetting,
530                State::Stopped,
531                StateRequest::Reset,
532                |_, _| Some(()),
533                |unit| &unit.dependencies,
534            )
535            .await;
536
537        check("reset", r)?;
538        Ok(())
539    }
540
541    /// Saves all the state units.
542    ///
543    /// Panics if running.
544    pub async fn save(&mut self) -> Result<Vec<SavedStateUnit>, StateTransitionError> {
545        assert!(!self.running);
546        // Save can occur in any order since it will not observably mutate
547        // state.
548        let r = self
549            .run_op(
550                "save",
551                None,
552                State::Stopped,
553                State::Saving,
554                State::Stopped,
555                StateRequest::Save,
556                |_, _| Some(()),
557                |_| &[],
558            )
559            .await;
560
561        let states = extract("save", r, |name, state| {
562            state.map(|state| SavedStateUnit {
563                name: name.to_string(),
564                state,
565            })
566        })?;
567
568        Ok(states)
569    }
570
571    /// Restores all the state units.
572    ///
573    /// Panics if running.
574    pub async fn restore(
575        &mut self,
576        states: Vec<SavedStateUnit>,
577    ) -> Result<(), StateTransitionError> {
578        assert!(!self.running);
579
580        #[derive(Debug, Error)]
581        enum RestoreUnitError {
582            #[error("unknown unit name")]
583            Unknown,
584            #[error("duplicate unit name")]
585            Duplicate,
586        }
587
588        let mut states_by_id = HashMap::new();
589        let mut r = Vec::new();
590        {
591            let inner = self.inner.lock();
592            for state in states {
593                match inner.names.get_key_value(state.name.as_str()) {
594                    Some((name, &id)) => {
595                        if states_by_id
596                            .insert(id, (name.clone(), state.state))
597                            .is_some()
598                        {
599                            r.push((name.clone(), Err(RestoreUnitError::Duplicate)));
600                        }
601                    }
602                    None => {
603                        r.push((state.name.into(), Err(RestoreUnitError::Unknown)));
604                    }
605                }
606            }
607        }
608
609        check("restore", r)?;
610
611        let r = self
612            .run_op(
613                "restore",
614                None,
615                State::Stopped,
616                State::Restoring,
617                State::Stopped,
618                StateRequest::Restore,
619                |id, _| states_by_id.remove(&id).map(|(_, blob)| blob),
620                |unit| &unit.dependencies,
621            )
622            .await;
623
624        // Make sure all the saved state was consumed. This could hit if a unit
625        // was removed concurrently with the restore.
626        check(
627            "restore",
628            states_by_id
629                .into_iter()
630                .map(|(_, (name, _))| (name, Err(RestoreUnitError::Unknown))),
631        )?;
632
633        check("restore", r)?;
634
635        self.post_restore().await?;
636        Ok(())
637    }
638
639    /// Completes the restore operation on all state units.
640    ///
641    /// Panics if running.
642    async fn post_restore(&mut self) -> Result<(), StateTransitionError> {
643        // Post-restore in any order because all state should be up-to-date
644        // after restore.
645        let r = self
646            .run_op(
647                "post_restore",
648                None,
649                State::Stopped,
650                State::PostRestoring,
651                State::Stopped,
652                StateRequest::PostRestore,
653                |_, _| Some(()),
654                |_| &[],
655            )
656            .await;
657
658        check("post_restore", r)?;
659        Ok(())
660    }
661
    /// Runs a state change operation on a set of units.
    ///
    /// `op` gives the name of the operation for tracing and error reporting
    /// purposes.
    ///
    /// `unit_ids` is the set of the units whose state should be changed. If
    /// `unit_ids` is `None`, all units change states.
    ///
    /// The old state for each unit must be `old_state`. During the operation,
    /// the unit is temporarily placed in `interim_state`. When complete, the
    /// unit is placed in `new_state`.
    ///
    /// Each unit waits for its dependencies to complete their state change
    /// operation before proceeding with their own state change. The
    /// dependencies list is computed for a unit by calling `deps`.
    ///
    /// To perform the state change, the unit is sent a request generated using
    /// `request`, with input generated by `input`. If `input` returns `None`,
    /// then communication with the unit is skipped, but the unit still
    /// transitions through the interim and into the new state, and its
    /// dependencies are still waited on by its dependents.
    ///
    /// Returns the response of every unit that was sent a request, paired
    /// with the unit's name.
    ///
    /// # Panics
    ///
    /// Panics if a selected unit is in neither `old_state` nor `new_state`,
    /// or if communication with a unit fails while the unit is still
    /// registered.
    async fn run_op<I: 'static, R: 'static + Send>(
        &self,
        op: &str,
        unit_ids: Option<&[u64]>,
        old_state: State,
        interim_state: State,
        new_state: State,
        request: impl Copy + FnOnce(Rpc<I, R>) -> StateRequest,
        mut input: impl FnMut(u64, &Unit) -> Option<I>,
        mut deps: impl FnMut(&Unit) -> &[u64],
    ) -> Vec<(Arc<str>, R)> {
        // One future per unit that needs to change state.
        let mut done = Vec::new();
        let ready_set;
        {
            // Build all the futures under the lock; the lock is released
            // before any of them is awaited.
            let mut inner = self.inner.lock();
            ready_set = inner.ready_set(unit_ids);
            for (&id, unit) in inner
                .units
                .iter_mut()
                .filter(|(id, _)| ready_set.0.contains_key(id))
            {
                if unit.state != old_state {
                    // The unit is already in the target state; mark it done
                    // right away so that nothing blocks waiting on it.
                    assert_eq!(
                        unit.state, new_state,
                        "unit {} in {:?} state, should be {:?} or {:?}",
                        unit.name, unit.state, old_state, new_state
                    );
                    ready_set.done(id, true);
                } else {
                    let name = unit.name.clone();
                    let input = (input)(id, unit);
                    let ready_set = ready_set.clone();
                    // Copy the dependency list so that the future does not
                    // borrow from the locked registry.
                    let deps = deps(unit).to_vec();
                    let fut = state_change(name.clone(), unit, request, input);
                    let recv = async move {
                        // Wait for the dependencies, perform the change, then
                        // release anything waiting on this unit.
                        ready_set.wait(op, id, &deps).await;
                        let r = fut.await;
                        ready_set.done(id, true);
                        (name, id, r)
                    };
                    done.push(recv);
                    unit.state = interim_state;
                }
            }
        }

        // Drive all the per-unit futures concurrently, with the lock released.
        let results = async {
            let start = Instant::now();
            let results = join_all(done).await;
            tracing::info!(duration = ?Instant::now() - start, "state change complete");
            results
        }
        .instrument(tracing::info_span!("state_change", operation = op))
        .await;

        let mut inner = self.inner.lock();
        let r = results
            .into_iter()
            .filter_map(|(name, id, r)| {
                match r {
                    Ok(Some(r)) => Some((name, r)),
                    // No request was sent to this unit, so there is no
                    // response to report.
                    Ok(None) => None,
                    Err(err) => {
                        // If the unit was removed during the operation, then
                        // ignore the failure. Otherwise, panic because unit
                        // failure is not recoverable. FUTURE: reconsider this
                        // position.
                        if inner.units.contains_key(&id) {
                            panic!("{:?}", err);
                        }
                        None
                    }
                }
            })
            .collect();
        // Move every selected unit that is still registered from the interim
        // state to the final state.
        for (_, unit) in inner
            .units
            .iter_mut()
            .filter(|(id, _)| ready_set.0.contains_key(id))
        {
            if unit.state == interim_state {
                unit.state = new_state;
            } else {
                assert_eq!(
                    unit.state, new_state,
                    "unit {} in {:?} state, should be {:?} or {:?}",
                    unit.name, unit.state, interim_state, new_state
                );
            }
        }
        r
    }
775}
776
777impl Inner {
778    fn ready_set(&self, unit_ids: Option<&[u64]>) -> ReadySet {
779        let map = |id, unit: &Unit| {
780            (
781                id,
782                ReadyState {
783                    name: unit.name.clone(),
784                    ready: Arc::new(Ready::default()),
785                },
786            )
787        };
788        let units = if let Some(unit_ids) = unit_ids {
789            unit_ids
790                .iter()
791                .map(|id| map(*id, &self.units[id]))
792                .collect()
793        } else {
794            self.units.iter().map(|(id, unit)| map(*id, unit)).collect()
795        };
796        ReadySet(Arc::new(units))
797    }
798}
799
/// The per-unit completion signals for one state change operation, keyed by
/// unit ID. Cloning shares the same underlying map.
#[derive(Clone)]
struct ReadySet(Arc<BTreeMap<u64, ReadyState>>);
802
/// The completion signal for a single unit within a [`ReadySet`].
#[derive(Clone)]
struct ReadyState {
    /// The unit's name, used for tracing.
    name: Arc<str>,
    /// Signaled once the unit has finished the operation.
    ready: Arc<Ready>,
}
808
809impl ReadySet {
810    async fn wait(&self, op: &str, id: u64, deps: &[u64]) -> bool {
811        for dep in deps {
812            if let Some(dep) = self.0.get(dep) {
813                if !dep.ready.is_ready() {
814                    tracing::debug!(
815                        device = self.0[&id].name.as_ref(),
816                        dependency = dep.name.as_ref(),
817                        operation = op,
818                        "waiting on dependency"
819                    );
820                }
821                if !dep.ready.wait().await {
822                    return false;
823                }
824            }
825        }
826        true
827    }
828
829    fn done(&self, id: u64, success: bool) {
830        self.0[&id].ready.signal(success);
831    }
832}
833
834/// Sends state change `request` to `unit` with `input`, wrapping the result
835/// future with a span, and wrapping its error with something more informative.
836///
837/// `operation` and `name` are used in tracing and error construction.
838fn state_change<I: 'static, R: 'static + Send, Req: FnOnce(Rpc<I, R>) -> StateRequest>(
839    name: Arc<str>,
840    unit: &Unit,
841    request: Req,
842    input: Option<I>,
843) -> impl Future<Output = Result<Option<R>, UnitRecvError>> + use<I, R, Req> {
844    let send = unit.send.clone();
845
846    async move {
847        let Some(input) = input else { return Ok(None) };
848        let span = tracing::info_span!("device_state_change", device = name.as_ref());
849        async move {
850            let start = Instant::now();
851            let r = send
852                .call(request, input)
853                .await
854                .map_err(|err| UnitRecvError { name, source: err });
855            tracing::debug!(duration = ?Instant::now() - start, "device state change complete");
856            r.map(Some)
857        }
858        .instrument(span)
859        .await
860    }
861}
862
/// A builder returned by [`StateUnits::add`].
#[derive(Debug)]
#[must_use]
pub struct UnitBuilder<'a> {
    /// The set of units the new unit will be added to.
    units: &'a StateUnits,
    /// The name of the new unit.
    name: Arc<str>,
    /// IDs of the units the new unit depends on.
    dependencies: Vec<u64>,
    /// IDs of the units that depend on the new unit.
    dependents: Vec<u64>,
}
872
873impl UnitBuilder<'_> {
874    /// Adds `handle` as a dependency of this new unit.
875    ///
876    /// Operations will be ordered to ensure that a dependency will stop after
877    /// its dependants, and that it will reset or restore before its dependants.
878    pub fn depends_on(mut self, handle: &UnitHandle) -> Self {
879        self.dependencies.push(self.handle_id(handle));
880        self
881    }
882
883    /// Adds this new unit as a dependency of `handle`.
884    ///
885    /// Operations will be ordered to ensure that a dependency will stop after
886    /// its dependants, and that it will reset or restore before its dependants.
887    pub fn dependency_of(mut self, handle: &UnitHandle) -> Self {
888        self.dependents.push(self.handle_id(handle));
889        self
890    }
891
892    fn handle_id(&self, handle: &UnitHandle) -> u64 {
893        // Ensure this handle is associated with this set of state units.
894        assert_eq!(
895            Weak::as_ptr(handle.inner.as_ref().unwrap()),
896            Arc::as_ptr(&self.units.inner)
897        );
898        handle.id.id
899    }
900
901    /// Adds a new state unit sending requests to `send`.
902    pub fn build(mut self, send: Sender<StateRequest>) -> Result<UnitHandle, NameInUse> {
903        let id = {
904            let mut inner = self.units.inner.lock();
905            let id = inner.next_id;
906            let entry = match inner.names.entry(self.name.clone()) {
907                hash_map::Entry::Occupied(_) => return Err(NameInUse(self.name)),
908                hash_map::Entry::Vacant(e) => e,
909            };
910            entry.insert(id);
911
912            // Dedup the dependencies and update the dependencies' lists of
913            // dependents.
914            self.dependencies.sort();
915            self.dependencies.dedup();
916            for &dep in &self.dependencies {
917                inner.units.get_mut(&dep).unwrap().dependents.push(id);
918            }
919
920            // Dedup the depenedents and update the dependents' lists of
921            // dependencies.
922            self.dependents.sort();
923            self.dependents.dedup();
924            for &dep in &self.dependents {
925                inner.units.get_mut(&dep).unwrap().dependencies.push(id);
926            }
927            inner.units.insert(
928                id,
929                Unit {
930                    name: self.name.clone(),
931                    send,
932                    dependencies: self.dependencies,
933                    dependents: self.dependents,
934                    state: State::Stopped,
935                },
936            );
937            let unit_id = UnitId {
938                name: self.name,
939                id,
940            };
941            inner.next_id += 1;
942            unit_id
943        };
944        Ok(UnitHandle {
945            id,
946            inner: Some(Arc::downgrade(&self.units.inner)),
947        })
948    }
949
950    /// Adds a unit as in [`Self::build`], then spawns a task for running the
951    /// unit.
952    ///
953    /// The channel to receive state change requests is passed to `f`, which
954    /// should return the future to evaluate to run the unit.
955    #[track_caller]
956    pub fn spawn<F, Fut>(
957        self,
958        spawner: impl Spawn,
959        f: F,
960    ) -> Result<SpawnedUnit<Fut::Output>, NameInUse>
961    where
962        F: FnOnce(Receiver<StateRequest>) -> Fut,
963        Fut: 'static + Send + Future,
964        Fut::Output: 'static + Send,
965    {
966        let (send, recv) = mesh::channel();
967        let task_name = format!("unit-{}", self.name);
968        let handle = self.build(send)?;
969        let fut = (f)(recv);
970        let task = spawner.spawn(task_name, fut);
971        Ok(SpawnedUnit { task, handle })
972    }
973}
974
975/// A handle to a spawned unit.
976#[must_use]
977pub struct SpawnedUnit<T> {
978    handle: UnitHandle,
979    task: Task<T>,
980}
981
982impl<T> Debug for SpawnedUnit<T> {
983    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
984        f.debug_struct("SpawnedUnit")
985            .field("handle", &self.handle)
986            .field("task", &self.task)
987            .finish()
988    }
989}
990
991impl<T> SpawnedUnit<T> {
992    /// Removes the unit and returns it.
993    pub async fn remove(self) -> T {
994        self.handle.remove();
995        self.task.await
996    }
997
998    /// Gets the unit handle for use with methods like
999    /// [`UnitBuilder::depends_on`].
1000    pub fn handle(&self) -> &UnitHandle {
1001        &self.handle
1002    }
1003}
1004
1005#[derive(Default)]
1006struct Ready {
1007    state: AtomicU32,
1008    event: event_listener::Event,
1009}
1010
1011impl Ready {
1012    /// Wakes everyone with `success`.
1013    fn signal(&self, success: bool) {
1014        self.state.store(success as u32 + 1, Ordering::Release);
1015        self.event.notify(usize::MAX);
1016    }
1017
1018    fn is_ready(&self) -> bool {
1019        self.state.load(Ordering::Acquire) != 0
1020    }
1021
1022    /// Waits for `signal` to be called and returns its `success` parameter.
1023    async fn wait(&self) -> bool {
1024        loop {
1025            let listener = self.event.listen();
1026            let state = self.state.load(Ordering::Acquire);
1027            if state != 0 {
1028                return state - 1 != 0;
1029            }
1030            listener.await;
1031        }
1032    }
1033}
1034
1035#[cfg(test)]
1036mod tests {
1037    use super::StateUnit;
1038    use super::StateUnits;
1039    use crate::run_unit;
1040    use inspect::InspectMut;
1041    use mesh::payload::Protobuf;
1042    use pal_async::DefaultDriver;
1043    use pal_async::async_test;
1044    use std::sync::Arc;
1045    use std::sync::atomic::AtomicBool;
1046    use std::sync::atomic::Ordering;
1047    use std::time::Duration;
1048    use test_with_tracing::test;
1049    use vmcore::save_restore::RestoreError;
1050    use vmcore::save_restore::SaveError;
1051    use vmcore::save_restore::SavedStateBlob;
1052    use vmcore::save_restore::SavedStateRoot;
1053
1054    #[derive(Default)]
1055    struct TestUnit {
1056        value: Arc<AtomicBool>,
1057        dep: Option<Arc<AtomicBool>>,
1058        /// If we should support saved state or not.
1059        support_saved_state: bool,
1060    }
1061
1062    #[derive(Protobuf, SavedStateRoot)]
1063    #[mesh(package = "test")]
1064    struct SavedState(bool);
1065
1066    impl StateUnit for TestUnit {
1067        async fn start(&mut self) {}
1068
1069        async fn stop(&mut self) {}
1070
1071        async fn reset(&mut self) -> anyhow::Result<()> {
1072            Ok(())
1073        }
1074
1075        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
1076            if self.support_saved_state {
1077                let state = SavedState(self.value.load(Ordering::Relaxed));
1078                Ok(Some(SavedStateBlob::new(state)))
1079            } else {
1080                Ok(None)
1081            }
1082        }
1083
1084        async fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
1085            assert!(self.dep.as_ref().is_none_or(|v| v.load(Ordering::Relaxed)));
1086
1087            if self.support_saved_state {
1088                let state: SavedState = state.parse()?;
1089                self.value.store(state.0, Ordering::Relaxed);
1090                Ok(())
1091            } else {
1092                Err(RestoreError::SavedStateNotSupported)
1093            }
1094        }
1095
1096        async fn post_restore(&mut self) -> anyhow::Result<()> {
1097            Ok(())
1098        }
1099    }
1100
1101    impl InspectMut for TestUnit {
1102        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
1103            req.respond();
1104        }
1105    }
1106
1107    struct TestUnitSetDep {
1108        dep: Arc<AtomicBool>,
1109        driver: DefaultDriver,
1110    }
1111
1112    impl StateUnit for TestUnitSetDep {
1113        async fn start(&mut self) {}
1114
1115        async fn stop(&mut self) {}
1116
1117        async fn reset(&mut self) -> anyhow::Result<()> {
1118            Ok(())
1119        }
1120
1121        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
1122            Ok(Some(SavedStateBlob::new(SavedState(true))))
1123        }
1124
1125        async fn restore(&mut self, _state: SavedStateBlob) -> Result<(), RestoreError> {
1126            pal_async::timer::PolledTimer::new(&self.driver)
1127                .sleep(Duration::from_millis(100))
1128                .await;
1129
1130            self.dep.store(true, Ordering::Relaxed);
1131
1132            Ok(())
1133        }
1134
1135        async fn post_restore(&mut self) -> anyhow::Result<()> {
1136            Ok(())
1137        }
1138    }
1139
1140    impl InspectMut for TestUnitSetDep {
1141        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
1142            req.respond();
1143        }
1144    }
1145
1146    #[async_test]
1147    async fn test_state_change(driver: DefaultDriver) {
1148        let mut units = StateUnits::new();
1149
1150        let a_val = Arc::new(AtomicBool::new(true));
1151
1152        let _a = units
1153            .add("a")
1154            .spawn(&driver, |recv| {
1155                run_unit(
1156                    TestUnit {
1157                        value: a_val.clone(),
1158                        dep: None,
1159                        support_saved_state: true,
1160                    },
1161                    recv,
1162                )
1163            })
1164            .unwrap();
1165        let _b = units
1166            .add("b")
1167            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
1168            .unwrap();
1169        units.start().await;
1170
1171        let _c = units
1172            .add("c")
1173            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv));
1174        units.stop().await;
1175        units.start().await;
1176
1177        units.stop().await;
1178
1179        let state = units.save().await.unwrap();
1180
1181        a_val.store(false, Ordering::Relaxed);
1182
1183        units.restore(state).await.unwrap();
1184
1185        assert!(a_val.load(Ordering::Relaxed));
1186    }
1187
1188    #[async_test]
1189    async fn test_dependencies(driver: DefaultDriver) {
1190        let mut units = StateUnits::new();
1191
1192        let a_val = Arc::new(AtomicBool::new(true));
1193
1194        let a = units
1195            .add("zzz")
1196            .spawn(&driver, |recv| {
1197                run_unit(
1198                    TestUnit {
1199                        value: a_val.clone(),
1200                        dep: None,
1201                        support_saved_state: true,
1202                    },
1203                    recv,
1204                )
1205            })
1206            .unwrap();
1207
1208        let _b = units
1209            .add("aaa")
1210            .depends_on(a.handle())
1211            .spawn(&driver, |recv| {
1212                run_unit(
1213                    TestUnit {
1214                        dep: Some(a_val.clone()),
1215                        value: Default::default(),
1216                        support_saved_state: true,
1217                    },
1218                    recv,
1219                )
1220            })
1221            .unwrap();
1222        units.start().await;
1223        units.stop().await;
1224
1225        let state = units.save().await.unwrap();
1226
1227        a_val.store(false, Ordering::Relaxed);
1228
1229        units.restore(state).await.unwrap();
1230    }
1231
1232    #[async_test]
1233    async fn test_dep_no_saved_state(driver: DefaultDriver) {
1234        let mut units = StateUnits::new();
1235
1236        let true_val = Arc::new(AtomicBool::new(true));
1237        let shared_val = Arc::new(AtomicBool::new(false));
1238
1239        let a = units
1240            .add("a")
1241            .spawn(&driver, |recv| {
1242                run_unit(
1243                    TestUnit {
1244                        value: true_val.clone(),
1245                        dep: Some(shared_val.clone()),
1246                        support_saved_state: true,
1247                    },
1248                    recv,
1249                )
1250            })
1251            .unwrap();
1252
1253        // no saved state
1254        // Note that restore is never called for this unit.
1255        let b = units
1256            .add("b_no_saved_state")
1257            .dependency_of(a.handle())
1258            .spawn(&driver, |recv| {
1259                run_unit(
1260                    TestUnit {
1261                        value: true_val.clone(),
1262                        dep: Some(shared_val.clone()),
1263                        support_saved_state: false,
1264                    },
1265                    recv,
1266                )
1267            })
1268            .unwrap();
1269
1270        // A has a transitive dependency on C via B.
1271        let _c = units
1272            .add("c")
1273            .dependency_of(b.handle())
1274            .spawn(&driver, |recv| {
1275                run_unit(
1276                    TestUnitSetDep {
1277                        dep: shared_val,
1278                        driver: driver.clone(),
1279                    },
1280                    recv,
1281                )
1282            })
1283            .unwrap();
1284
1285        let state = units.save().await.unwrap();
1286
1287        units.restore(state).await.unwrap();
1288    }
1289}