state_unit/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! VM state machine unit handling.
5//!
6//! A state unit is a VM component (such as a device) that needs to react to
7//! changes in the VM state machine. It needs to start, stop, reset, and
8//! save/restore with the VM. (Save/restore is not really a state change but is
9//! modeled as such since it must be synchronized with actual state changes.)
10//!
11//! This module contains types and functions for defining and manipulating state
12//! units. It does this in three parts:
13//!
14//! 1. It defines an RPC enum [`StateRequest`] which is used to request that a
15//!    state unit change state (start, stop, etc.). Each state unit must handle
16//!    incoming state requests on a mesh receiver. This is the foundational
17//!    type of this model.
18//!
19//! 2. It defines a type [`StateUnits`], which is a collection of mesh senders
20//!    for sending `StateRequest`s. This is used to initiate and wait for state
21//!    changes across all the units in the VMM, handling any required dependency
22//!    ordering.
23//!
24//! 3. It defines a trait [`StateUnit`] that can be used to define handlers for
25//!    each state request. This is an optional convenience; not all state units
26//!    will have a type that implements this trait.
27//!
28//! This model allows for asynchronous, highly concurrent state changes, and it
29//! works across process boundaries thanks to `mesh`.
30
31#![forbid(unsafe_code)]
32
33use futures::FutureExt;
34use futures::StreamExt;
35use futures::future::join_all;
36use futures_concurrency::stream::Merge;
37use inspect::Inspect;
38use inspect::InspectMut;
39use mesh::MeshPayload;
40use mesh::Receiver;
41use mesh::Sender;
42use mesh::payload::Protobuf;
43use mesh::rpc::FailableRpc;
44use mesh::rpc::Rpc;
45use mesh::rpc::RpcError;
46use mesh::rpc::RpcSend;
47use pal_async::task::Spawn;
48use pal_async::task::Task;
49use parking_lot::Mutex;
50use std::collections::BTreeMap;
51use std::collections::HashMap;
52use std::collections::hash_map;
53use std::fmt::Debug;
54use std::fmt::Display;
55use std::future::Future;
56use std::pin::pin;
57use std::sync::Arc;
58use std::sync::Weak;
59use std::sync::atomic::AtomicU32;
60use std::sync::atomic::Ordering;
61use std::time::Instant;
62use thiserror::Error;
63use tracing::Instrument;
64use vmcore::save_restore::RestoreError;
65use vmcore::save_restore::SaveError;
66use vmcore::save_restore::SavedStateBlob;
67
/// A state change request.
///
/// Each state unit receives these on a mesh receiver. Every variant except
/// [`StateRequest::Inspect`] carries an RPC, so the sender can wait for the
/// unit to finish handling the request.
#[derive(Debug, MeshPayload)]
pub enum StateRequest {
    /// Start asynchronous operations.
    Start(Rpc<(), ()>),

    /// Stop asynchronous operations.
    Stop(Rpc<(), ()>),

    /// Reset a stopped unit to initial state.
    Reset(FailableRpc<(), ()>),

    /// Save state of a stopped unit.
    ///
    /// A unit that responds with `None` is omitted from the saved state
    /// returned by [`StateUnits::save`].
    Save(FailableRpc<(), Option<SavedStateBlob>>),

    /// Restore state of a stopped unit.
    Restore(FailableRpc<SavedStateBlob, ()>),

    /// Inspect state.
    Inspect(inspect::Deferred),
}
89
/// Trait implemented by an object that can act as a state unit.
///
/// Implementing this is optional, to be used with [`UnitBuilder::spawn`] or
/// [`StateRequest::apply`]; state units can also directly process incoming
/// [`StateRequest`]s.
///
/// Each method corresponds to one RPC-carrying [`StateRequest`] variant;
/// inspect requests are served through the [`InspectMut`] supertrait.
#[expect(async_fn_in_trait)] // Don't need Send bounds
pub trait StateUnit: InspectMut {
    /// Start asynchronous processing.
    async fn start(&mut self);

    /// Stop asynchronous processing.
    async fn stop(&mut self);

    /// Reset to initial state.
    ///
    /// Must only be called while stopped.
    async fn reset(&mut self) -> anyhow::Result<()>;

    /// Save state.
    ///
    /// Must only be called while stopped.
    ///
    /// Returning `Ok(None)` means the unit contributes no saved state.
    async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError>;

    /// Restore state from `buffer`.
    ///
    /// Must only be called while stopped.
    async fn restore(&mut self, buffer: SavedStateBlob) -> Result<(), RestoreError>;
}
118
119/// Runs a simple unit that only needs to respond to state requests.
120pub async fn run_unit<T: StateUnit>(mut unit: T, mut recv: Receiver<StateRequest>) -> T {
121    while let Some(req) = recv.next().await {
122        req.apply(&mut unit).await;
123    }
124    unit
125}
126
127/// Runs a state unit that can handle inspect requests while there is an active
128/// state transition.
129pub async fn run_async_unit<T>(unit: T, mut recv: Receiver<StateRequest>) -> T
130where
131    for<'a> &'a T: StateUnit,
132{
133    while let Some(req) = recv.next().await {
134        req.apply_with_concurrent_inspects(&mut &unit, &mut recv)
135            .await;
136    }
137    unit
138}
139
140impl StateRequest {
    /// Runs this state request against `unit`, polling `recv` for incoming
    /// inspect requests and applying them while any state transition is in
    /// flight.
    ///
    /// For this to work, your state unit `T` should implement [`StateUnit`] for
    /// `&'_ T`.
    ///
    /// # Panics
    ///
    /// Panics if a state transition arrives on `recv` while this one is being
    /// processed. This would indicate a contract violation with [`StateUnits`].
    pub async fn apply_with_concurrent_inspects<'a, T>(
        self,
        unit: &mut &'a T,
        recv: &mut Receiver<StateRequest>,
    ) where
        &'a T: StateUnit,
    {
        match self {
            StateRequest::Inspect(_) => {
                // This request has no response and completes synchronously,
                // so don't wait for concurrent requests.
                self.apply(unit).await;
            }

            StateRequest::Start(_)
            | StateRequest::Stop(_)
            | StateRequest::Reset(_)
            | StateRequest::Save(_)
            | StateRequest::Restore(_) => {
                // Run the transition while also servicing inspect requests
                // that arrive on `recv` in the meantime.
                enum Event {
                    // The state transition finished.
                    OpDone,
                    // Another request arrived while the transition was running.
                    Req(StateRequest),
                }
                // Copy the shared reference so the transition future and the
                // concurrent inspects each have their own `&T` to borrow
                // mutably.
                let mut op_unit = *unit;
                let op = pin!(self.apply(&mut op_unit).into_stream());
                let mut stream = (op.map(|()| Event::OpDone), recv.map(Event::Req)).merge();
                // The loop exits as soon as the merged stream yields anything
                // other than a request, i.e. when the transition completes.
                while let Some(Event::Req(next_req)) = stream.next().await {
                    match next_req {
                        StateRequest::Inspect(req) => req.inspect(&mut *unit),
                        _ => panic!(
                            "unexpected state transition {next_req:?} during state transition"
                        ),
                    }
                }
            }
        }
    }
188
189    /// Runs this state request against `unit`.
190    pub async fn apply(self, unit: &mut impl StateUnit) {
191        match self {
192            StateRequest::Start(rpc) => rpc.handle(async |()| unit.start().await).await,
193            StateRequest::Stop(rpc) => rpc.handle(async |()| unit.stop().await).await,
194            StateRequest::Reset(rpc) => rpc.handle_failable(async |()| unit.reset().await).await,
195            StateRequest::Save(rpc) => rpc.handle_failable(async |()| unit.save().await).await,
196            StateRequest::Restore(rpc) => {
197                rpc.handle_failable(async |buffer| unit.restore(buffer).await)
198                    .await
199            }
200            StateRequest::Inspect(req) => req.inspect(unit),
201        }
202    }
203}
204
/// A set of state units.
///
/// Used to register units (see [`StateUnits::add`]) and to drive state
/// transitions across all of them.
#[derive(Debug)]
pub struct StateUnits {
    /// The registered units. Shared (weakly) with unit handles and inspectors.
    inner: Arc<Mutex<Inner>>,
    /// Whether the set as a whole is running: set by [`StateUnits::start`] and
    /// cleared by [`StateUnits::stop`].
    running: bool,
}
211
/// The state of an individual unit, as tracked by [`StateUnits`].
///
/// `Stopped` and `Running` are the resting states; the others are interim
/// states that a unit is placed in while the corresponding operation is in
/// flight.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Inspect)]
enum State {
    /// Not running. Reset, save, and restore start from and return to this
    /// state.
    Stopped,
    /// A start operation is in flight.
    Starting,
    /// Running.
    Running,
    /// A stop operation is in flight.
    Stopping,
    /// A reset operation is in flight.
    Resetting,
    /// A save operation is in flight.
    Saving,
    /// A restore operation is in flight.
    Restoring,
}
222
/// The lock-protected contents of a [`StateUnits`].
#[derive(Debug)]
struct Inner {
    /// The ID to assign to the next unit that is added.
    next_id: u64,
    /// The registered units, keyed by ID.
    units: BTreeMap<u64, Unit>,
    /// Maps each registered unit name to its ID.
    names: HashMap<Arc<str>, u64>,
}
229
/// A registered state unit.
#[derive(Debug)]
struct Unit {
    /// The unit's name, also used as its save ID.
    name: Arc<str>,
    /// The channel used to send state requests to the unit.
    send: Sender<StateRequest>,
    /// IDs of the units that this unit depends on.
    dependencies: Vec<u64>,
    /// IDs of the units that depend on this unit.
    dependents: Vec<u64>,
    /// The unit's current state.
    state: State,
}
238
/// An error returned when a state unit name is already in use.
///
/// Returned by [`UnitBuilder::build`] and [`UnitBuilder::spawn`]; the payload
/// is the rejected name.
#[derive(Debug, Error)]
#[error("state unit name {0} is in use")]
pub struct NameInUse(Arc<str>);
243
/// The error produced when the RPC carrying a state request to a unit fails.
#[derive(Debug, Error)]
#[error("critical unit communication failure: {name}")]
struct UnitRecvError {
    /// The name of the unit that the request was sent to.
    name: Arc<str>,
    /// The underlying RPC failure.
    #[source]
    source: RpcError,
}
251
/// Identifies a registered unit by both of the keys used in [`Inner`].
#[derive(Debug, Clone)]
struct UnitId {
    /// The unit's name (its key in `Inner::names`).
    name: Arc<str>,
    /// The unit's numeric ID (its key in `Inner::units`).
    id: u64,
}
257
/// A handle to a state unit registered with [`StateUnits`], returned by
/// [`UnitBuilder::build`].
///
/// Dropping the handle removes the unit. Use [`UnitHandle::detach`] to leave
/// the unit in place instead.
#[must_use]
#[derive(Debug)]
pub struct UnitHandle {
    /// The name and numeric ID of the unit this handle refers to.
    id: UnitId,
    /// The owning unit set, or `None` once the handle has been detached or the
    /// unit has been removed.
    inner: Option<Weak<Mutex<Inner>>>,
}
265
impl Drop for UnitHandle {
    /// Removes the unit from its set, unless the handle was detached or the
    /// unit was already removed.
    fn drop(&mut self) {
        self.remove_if();
    }
}
271
272impl UnitHandle {
273    /// Remove the state unit.
274    pub fn remove(mut self) {
275        self.remove_if();
276    }
277
278    /// Detach this handle, leaving the unit in place indefinitely.
279    pub fn detach(mut self) {
280        self.inner = None;
281    }
282
283    fn remove_if(&mut self) {
284        if let Some(inner) = self.inner.take().and_then(|inner| inner.upgrade()) {
285            let mut inner = inner.lock();
286            inner.units.remove(&self.id.id).expect("unit exists");
287            inner.names.remove(&self.id.name).expect("unit exists");
288        }
289    }
290}
291
/// An object returned by [`StateUnits::inspector`] to inspect state units while
/// state transitions may be in flight.
///
/// It holds only a weak reference to the units, so it does not keep them alive.
pub struct StateUnitsInspector {
    /// Weak reference to the contents of the originating [`StateUnits`].
    inner: Weak<Mutex<Inner>>,
}
297
298impl Inspect for StateUnits {
299    fn inspect(&self, req: inspect::Request<'_>) {
300        self.inner.lock().inspect(req);
301    }
302}
303
304impl Inspect for StateUnitsInspector {
305    fn inspect(&self, req: inspect::Request<'_>) {
306        if let Some(inner) = self.inner.upgrade() {
307            inner.lock().inspect(req);
308        }
309    }
310}
311
312impl Inspect for Inner {
313    fn inspect(&self, req: inspect::Request<'_>) {
314        let mut resp = req.respond();
315        for unit in self.units.values() {
316            resp.child(unit.name.as_ref(), |req| {
317                let mut resp = req.respond();
318                if !unit.dependencies.is_empty() {
319                    resp.field_with("dependencies", || {
320                        unit.dependencies
321                            .iter()
322                            .map(|id| self.units[id].name.as_ref())
323                            .collect::<Vec<_>>()
324                            .join(",")
325                    });
326                }
327                if !unit.dependents.is_empty() {
328                    resp.field_with("dependents", || {
329                        unit.dependents
330                            .iter()
331                            .map(|id| self.units[id].name.as_ref())
332                            .collect::<Vec<_>>()
333                            .join(",")
334                    });
335                }
336                resp.field("unit_state", unit.state)
337                    .merge(inspect::adhoc(|req| {
338                        unit.send.send(StateRequest::Inspect(req.defer()))
339                    }));
340            });
341        }
342    }
343}
344
/// The saved state for an individual unit.
///
/// Produced by [`StateUnits::save`] and consumed by [`StateUnits::restore`],
/// which matches entries to units by `name`.
#[derive(Protobuf)]
#[mesh(package = "state_unit")]
pub struct SavedStateUnit {
    /// The name of the state unit.
    #[mesh(1)]
    pub name: String,
    /// The opaque saved state blob.
    #[mesh(2)]
    pub state: SavedStateBlob,
}
356
/// An error from a state transition.
///
/// Carries the name of the failed operation and, as its source, the set of
/// per-unit errors.
#[derive(Debug, Error)]
#[error("{op} failed")]
pub struct StateTransitionError {
    /// The name of the operation that failed (e.g. "reset").
    op: &'static str,
    /// The errors reported for individual units.
    #[source]
    errors: UnitErrorSet,
}
365
366fn extract<T, E: Into<anyhow::Error>, U>(
367    op: &'static str,
368    iter: impl IntoIterator<Item = (Arc<str>, Result<T, E>)>,
369    mut f: impl FnMut(Arc<str>, T) -> Option<U>,
370) -> Result<Vec<U>, StateTransitionError> {
371    let mut result = Vec::new();
372    let mut errors = Vec::new();
373    for (name, item) in iter {
374        match item {
375            Ok(t) => {
376                if let Some(u) = f(name, t) {
377                    result.push(u);
378                }
379            }
380            Err(err) => errors.push((name, err.into())),
381        }
382    }
383    if errors.is_empty() {
384        Ok(result)
385    } else {
386        Err(StateTransitionError {
387            op,
388            errors: UnitErrorSet(errors),
389        })
390    }
391}
392
393fn check<E: Into<anyhow::Error>>(
394    op: &'static str,
395    iter: impl IntoIterator<Item = (Arc<str>, Result<(), E>)>,
396) -> Result<(), StateTransitionError> {
397    extract(op, iter, |_, _| Some(()))?;
398    Ok(())
399}
400
/// A list of per-unit errors, each paired with the name of the unit that
/// reported it.
#[derive(Debug)]
struct UnitErrorSet(Vec<(Arc<str>, anyhow::Error)>);
403
404impl Display for UnitErrorSet {
405    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406        let mut map = f.debug_map();
407        for (name, err) in &self.0 {
408            map.entry(&format_args!("{}", name), &format_args!("{:#}", err));
409        }
410        map.finish()
411    }
412}
413
// Lets `UnitErrorSet` serve as the `#[source]` of `StateTransitionError`. The
// per-unit errors are reported through `Display`; `source()` is not overridden.
impl std::error::Error for UnitErrorSet {}
415
416impl StateUnits {
417    /// Creates a new instance with no initial units.
418    pub fn new() -> Self {
419        Self {
420            inner: Arc::new(Mutex::new(Inner {
421                next_id: 0,
422                units: BTreeMap::new(),
423                names: HashMap::new(),
424            })),
425            running: false,
426        }
427    }
428
429    /// Returns an inspector that can be used to inspect the state units while
430    /// state transitions are in process.
431    pub fn inspector(&self) -> StateUnitsInspector {
432        StateUnitsInspector {
433            inner: Arc::downgrade(&self.inner),
434        }
435    }
436
437    /// Save and restore will use `name` as the save ID, so this forms part of
438    /// the saved state.
439    ///
440    /// Note that the added unit will not be running after it is built/spawned,
441    /// even if the other units are running. Call
442    /// [`StateUnits::start_stopped_units`] when finished adding units.
443    pub fn add(&self, name: impl Into<Arc<str>>) -> UnitBuilder<'_> {
444        UnitBuilder {
445            units: self,
446            name: name.into(),
447            dependencies: Vec::new(),
448            dependents: Vec::new(),
449        }
450    }
451
    /// Check if state units are currently running.
    ///
    /// This reflects the last completed [`StateUnits::start`] or
    /// [`StateUnits::stop`], not the state of any individual unit.
    pub fn is_running(&self) -> bool {
        self.running
    }
456
457    /// Starts any units that are individually stopped, because they were added
458    /// via [`StateUnits::add`] while the VM was running.
459    ///
460    /// Does nothing if all units are stopped, via [`StateUnits::stop`].
461    pub async fn start_stopped_units(&mut self) {
462        if self.is_running() {
463            self.start().await;
464        }
465    }
466
    /// Starts all the state units.
    ///
    /// Each unit waits for the units it depends on to finish starting before
    /// it is started itself. Units that are already running are left alone,
    /// so this can be called again after adding units to a running set.
    pub async fn start(&mut self) {
        self.run_op(
            "start",
            None,
            State::Stopped,
            State::Starting,
            State::Running,
            StateRequest::Start,
            |_, _| Some(()),
            |unit| &unit.dependencies,
        )
        .await;
        self.running = true;
    }
482
    /// Stops all the state units.
    ///
    /// Units that are already stopped (e.g. added while running but never
    /// started) are left alone.
    ///
    /// Panics if not running.
    pub async fn stop(&mut self) {
        assert!(self.running);
        // Stop units in reverse dependency order so that a dependency is not
        // stopped before its dependant.
        self.run_op(
            "stop",
            None,
            State::Running,
            State::Stopping,
            State::Stopped,
            StateRequest::Stop,
            |_, _| Some(()),
            |unit| &unit.dependents,
        )
        .await;
        self.running = false;
    }
501
502    /// Resets all the state units.
503    ///
504    /// Panics if running.
505    pub async fn reset(&mut self) -> Result<(), StateTransitionError> {
506        assert!(!self.running);
507        // Reset in dependency order so that dependants observe their
508        // dependencies' reset state.
509        let r = self
510            .run_op(
511                "reset",
512                None,
513                State::Stopped,
514                State::Resetting,
515                State::Stopped,
516                StateRequest::Reset,
517                |_, _| Some(()),
518                |unit| &unit.dependencies,
519            )
520            .await;
521
522        check("reset", r)?;
523        Ok(())
524    }
525
526    /// Saves all the state units.
527    ///
528    /// Panics if running.
529    pub async fn save(&mut self) -> Result<Vec<SavedStateUnit>, StateTransitionError> {
530        assert!(!self.running);
531        // Save can occur in any order since it will not observably mutate
532        // state.
533        let r = self
534            .run_op(
535                "save",
536                None,
537                State::Stopped,
538                State::Saving,
539                State::Stopped,
540                StateRequest::Save,
541                |_, _| Some(()),
542                |_| &[],
543            )
544            .await;
545
546        let states = extract("save", r, |name, state| {
547            state.map(|state| SavedStateUnit {
548                name: name.to_string(),
549                state,
550            })
551        })?;
552
553        Ok(states)
554    }
555
556    /// Restores all the state units.
557    ///
558    /// Panics if running.
559    pub async fn restore(
560        &mut self,
561        states: Vec<SavedStateUnit>,
562    ) -> Result<(), StateTransitionError> {
563        assert!(!self.running);
564
565        #[derive(Debug, Error)]
566        enum RestoreUnitError {
567            #[error("unknown unit name")]
568            Unknown,
569            #[error("duplicate unit name")]
570            Duplicate,
571        }
572
573        let mut states_by_id = HashMap::new();
574        let mut r = Vec::new();
575        {
576            let inner = self.inner.lock();
577            for state in states {
578                match inner.names.get_key_value(state.name.as_str()) {
579                    Some((name, &id)) => {
580                        if states_by_id
581                            .insert(id, (name.clone(), state.state))
582                            .is_some()
583                        {
584                            r.push((name.clone(), Err(RestoreUnitError::Duplicate)));
585                        }
586                    }
587                    None => {
588                        r.push((state.name.into(), Err(RestoreUnitError::Unknown)));
589                    }
590                }
591            }
592        }
593
594        check("restore", r)?;
595
596        let r = self
597            .run_op(
598                "restore",
599                None,
600                State::Stopped,
601                State::Restoring,
602                State::Stopped,
603                StateRequest::Restore,
604                |id, _| states_by_id.remove(&id).map(|(_, blob)| blob),
605                |unit| &unit.dependencies,
606            )
607            .await;
608
609        // Make sure all the saved state was consumed. This could hit if a unit
610        // was removed concurrently with the restore.
611        check(
612            "restore",
613            states_by_id
614                .into_iter()
615                .map(|(_, (name, _))| (name, Err(RestoreUnitError::Unknown))),
616        )?;
617
618        check("restore", r)?;
619
620        Ok(())
621    }
622
    /// Runs a state change operation on a set of units.
    ///
    /// `op` gives the name of the operation for tracing and error reporting
    /// purposes.
    ///
    /// `unit_ids` is the set of the units whose state should be changed. If
    /// `unit_ids` is `None`, all units change states.
    ///
    /// The old state for each unit must be `old_state`. During the operation,
    /// the unit is temporarily placed in `interim_state`. When complete, the
    /// unit is placed in `new_state`. A unit that is already in `new_state` is
    /// skipped and treated as complete.
    ///
    /// Each unit waits for its dependencies to complete their state change
    /// operation before proceeding with their own state change. The
    /// dependencies list is computed for a unit by calling `deps`.
    ///
    /// To perform the state change, the unit is sent a request generated using
    /// `request`, with input generated by `input`. If `input` returns `None`,
    /// then communication with the unit is skipped, but the unit still
    /// transitions through the interim and into the new state, and its
    /// dependencies are still waited on by its dependents.
    ///
    /// Returns the name and response of each unit that was sent a request and
    /// responded.
    ///
    /// Panics if a unit is in neither `old_state` nor `new_state`, or if
    /// communication with a unit fails while the unit is still registered.
    async fn run_op<I: 'static, R: 'static + Send>(
        &self,
        op: &str,
        unit_ids: Option<&[u64]>,
        old_state: State,
        interim_state: State,
        new_state: State,
        request: impl Copy + FnOnce(Rpc<I, R>) -> StateRequest,
        mut input: impl FnMut(u64, &Unit) -> Option<I>,
        mut deps: impl FnMut(&Unit) -> &[u64],
    ) -> Vec<(Arc<str>, R)> {
        // One future per unit that needs to change state.
        let mut done = Vec::new();
        let ready_set;
        {
            // Build all the futures while holding the lock; none of them is
            // polled until after the lock is released.
            let mut inner = self.inner.lock();
            ready_set = inner.ready_set(unit_ids);
            for (&id, unit) in inner
                .units
                .iter_mut()
                .filter(|(id, _)| ready_set.0.contains_key(id))
            {
                if unit.state != old_state {
                    // The unit is already in the target state. Mark it done
                    // right away so that its dependents don't wait on it.
                    assert_eq!(
                        unit.state, new_state,
                        "unit {} in {:?} state, should be {:?} or {:?}",
                        unit.name, unit.state, old_state, new_state
                    );
                    ready_set.done(id, true);
                } else {
                    let name = unit.name.clone();
                    let input = (input)(id, unit);
                    let ready_set = ready_set.clone();
                    // Copy the dependency list so that the future does not
                    // borrow from the locked unit map.
                    let deps = deps(unit).to_vec();
                    let fut = state_change(name.clone(), unit, request, input);
                    let recv = async move {
                        // Dependencies outside the ready set are not waited
                        // on. The result of `wait` is not checked here.
                        ready_set.wait(op, id, &deps).await;
                        let r = fut.await;
                        ready_set.done(id, true);
                        (name, id, r)
                    };
                    done.push(recv);
                    unit.state = interim_state;
                }
            }
        }

        // Drive all the per-unit futures concurrently, without the lock held.
        let results = async {
            let start = Instant::now();
            let results = join_all(done).await;
            tracing::info!(duration = ?Instant::now() - start, "state change complete");
            results
        }
        .instrument(tracing::info_span!("state_change", operation = op))
        .await;

        let mut inner = self.inner.lock();
        let r = results
            .into_iter()
            .filter_map(|(name, id, r)| {
                match r {
                    Ok(Some(r)) => Some((name, r)),
                    // The unit was not contacted, so there is no response.
                    Ok(None) => None,
                    Err(err) => {
                        // If the unit was removed during the operation, then
                        // ignore the failure. Otherwise, panic because unit
                        // failure is not recoverable. FUTURE: reconsider this
                        // position.
                        if inner.units.contains_key(&id) {
                            panic!("{:?}", err);
                        }
                        None
                    }
                }
            })
            .collect();
        // Move the participating units that are still registered from the
        // interim state into the new state.
        for (_, unit) in inner
            .units
            .iter_mut()
            .filter(|(id, _)| ready_set.0.contains_key(id))
        {
            if unit.state == interim_state {
                unit.state = new_state;
            } else {
                assert_eq!(
                    unit.state, new_state,
                    "unit {} in {:?} state, should be {:?} or {:?}",
                    unit.name, unit.state, interim_state, new_state
                );
            }
        }
        r
    }
736}
737
738impl Inner {
739    fn ready_set(&self, unit_ids: Option<&[u64]>) -> ReadySet {
740        let map = |id, unit: &Unit| {
741            (
742                id,
743                ReadyState {
744                    name: unit.name.clone(),
745                    ready: Arc::new(Ready::default()),
746                },
747            )
748        };
749        let units = if let Some(unit_ids) = unit_ids {
750            unit_ids
751                .iter()
752                .map(|id| map(*id, &self.units[id]))
753                .collect()
754        } else {
755            self.units.iter().map(|(id, unit)| map(*id, unit)).collect()
756        };
757        ReadySet(Arc::new(units))
758    }
759}
760
/// The per-unit completion signals for a single state change operation, keyed
/// by unit ID. Cloning shares the same underlying map.
#[derive(Clone)]
struct ReadySet(Arc<BTreeMap<u64, ReadyState>>);
763
/// One unit's entry in a [`ReadySet`].
#[derive(Clone)]
struct ReadyState {
    /// The unit's name, used for tracing.
    name: Arc<str>,
    /// Signaled once the unit has completed the operation.
    ready: Arc<Ready>,
}
769
770impl ReadySet {
771    async fn wait(&self, op: &str, id: u64, deps: &[u64]) -> bool {
772        for dep in deps {
773            if let Some(dep) = self.0.get(dep) {
774                if !dep.ready.is_ready() {
775                    tracing::debug!(
776                        device = self.0[&id].name.as_ref(),
777                        dependency = dep.name.as_ref(),
778                        operation = op,
779                        "waiting on dependency"
780                    );
781                }
782                if !dep.ready.wait().await {
783                    return false;
784                }
785            }
786        }
787        true
788    }
789
790    fn done(&self, id: u64, success: bool) {
791        self.0[&id].ready.signal(success);
792    }
793}
794
795/// Sends state change `request` to `unit` with `input`, wrapping the result
796/// future with a span, and wrapping its error with something more informative.
797///
798/// `operation` and `name` are used in tracing and error construction.
799fn state_change<I: 'static, R: 'static + Send, Req: FnOnce(Rpc<I, R>) -> StateRequest>(
800    name: Arc<str>,
801    unit: &Unit,
802    request: Req,
803    input: Option<I>,
804) -> impl Future<Output = Result<Option<R>, UnitRecvError>> + use<I, R, Req> {
805    let send = unit.send.clone();
806
807    async move {
808        let Some(input) = input else { return Ok(None) };
809        let span = tracing::info_span!("device_state_change", device = name.as_ref());
810        async move {
811            let start = Instant::now();
812            let r = send
813                .call(request, input)
814                .await
815                .map_err(|err| UnitRecvError { name, source: err });
816            tracing::debug!(duration = ?Instant::now() - start, "device state change complete");
817            r.map(Some)
818        }
819        .instrument(span)
820        .await
821    }
822}
823
/// A builder returned by [`StateUnits::add`].
///
/// Nothing is registered until [`UnitBuilder::build`] or
/// [`UnitBuilder::spawn`] is called.
#[derive(Debug)]
#[must_use]
pub struct UnitBuilder<'a> {
    /// The set that the unit will be added to.
    units: &'a StateUnits,
    /// The name of the new unit.
    name: Arc<str>,
    /// IDs of the units that the new unit depends on.
    dependencies: Vec<u64>,
    /// IDs of the units that depend on the new unit.
    dependents: Vec<u64>,
}
833
834impl UnitBuilder<'_> {
835    /// Adds `handle` as a dependency of this new unit.
836    ///
837    /// Operations will be ordered to ensure that a dependency will stop after
838    /// its dependants, and that it will reset or restore before its dependants.
839    pub fn depends_on(mut self, handle: &UnitHandle) -> Self {
840        self.dependencies.push(self.handle_id(handle));
841        self
842    }
843
844    /// Adds this new unit as a dependency of `handle`.
845    ///
846    /// Operations will be ordered to ensure that a dependency will stop after
847    /// its dependants, and that it will reset or restore before its dependants.
848    pub fn dependency_of(mut self, handle: &UnitHandle) -> Self {
849        self.dependents.push(self.handle_id(handle));
850        self
851    }
852
853    fn handle_id(&self, handle: &UnitHandle) -> u64 {
854        // Ensure this handle is associated with this set of state units.
855        assert_eq!(
856            Weak::as_ptr(handle.inner.as_ref().unwrap()),
857            Arc::as_ptr(&self.units.inner)
858        );
859        handle.id.id
860    }
861
862    /// Adds a new state unit sending requests to `send`.
863    pub fn build(mut self, send: Sender<StateRequest>) -> Result<UnitHandle, NameInUse> {
864        let id = {
865            let mut inner = self.units.inner.lock();
866            let id = inner.next_id;
867            let entry = match inner.names.entry(self.name.clone()) {
868                hash_map::Entry::Occupied(_) => return Err(NameInUse(self.name)),
869                hash_map::Entry::Vacant(e) => e,
870            };
871            entry.insert(id);
872
873            // Dedup the dependencies and update the dependencies' lists of
874            // dependents.
875            self.dependencies.sort();
876            self.dependencies.dedup();
877            for &dep in &self.dependencies {
878                inner.units.get_mut(&dep).unwrap().dependents.push(id);
879            }
880
881            // Dedup the depenedents and update the dependents' lists of
882            // dependencies.
883            self.dependents.sort();
884            self.dependents.dedup();
885            for &dep in &self.dependents {
886                inner.units.get_mut(&dep).unwrap().dependencies.push(id);
887            }
888            inner.units.insert(
889                id,
890                Unit {
891                    name: self.name.clone(),
892                    send,
893                    dependencies: self.dependencies,
894                    dependents: self.dependents,
895                    state: State::Stopped,
896                },
897            );
898            let unit_id = UnitId {
899                name: self.name,
900                id,
901            };
902            inner.next_id += 1;
903            unit_id
904        };
905        Ok(UnitHandle {
906            id,
907            inner: Some(Arc::downgrade(&self.units.inner)),
908        })
909    }
910
911    /// Adds a unit as in [`Self::build`], then spawns a task for running the
912    /// unit.
913    ///
914    /// The channel to receive state change requests is passed to `f`, which
915    /// should return the future to evaluate to run the unit.
916    #[track_caller]
917    pub fn spawn<F, Fut>(
918        self,
919        spawner: impl Spawn,
920        f: F,
921    ) -> Result<SpawnedUnit<Fut::Output>, NameInUse>
922    where
923        F: FnOnce(Receiver<StateRequest>) -> Fut,
924        Fut: 'static + Send + Future,
925        Fut::Output: 'static + Send,
926    {
927        let (send, recv) = mesh::channel();
928        let task_name = format!("unit-{}", self.name);
929        let handle = self.build(send)?;
930        let fut = (f)(recv);
931        let task = spawner.spawn(task_name, fut);
932        Ok(SpawnedUnit { task, handle })
933    }
934}
935
936/// A handle to a spawned unit.
937#[must_use]
938pub struct SpawnedUnit<T> {
939    handle: UnitHandle,
940    task: Task<T>,
941}
942
943impl<T> Debug for SpawnedUnit<T> {
944    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
945        f.debug_struct("SpawnedUnit")
946            .field("handle", &self.handle)
947            .field("task", &self.task)
948            .finish()
949    }
950}
951
952impl<T> SpawnedUnit<T> {
953    /// Removes the unit and returns it.
954    pub async fn remove(self) -> T {
955        self.handle.remove();
956        self.task.await
957    }
958
959    /// Gets the unit handle for use with methods like
960    /// [`UnitBuilder::depends_on`].
961    pub fn handle(&self) -> &UnitHandle {
962        &self.handle
963    }
964}
965
966#[derive(Default)]
967struct Ready {
968    state: AtomicU32,
969    event: event_listener::Event,
970}
971
972impl Ready {
973    /// Wakes everyone with `success`.
974    fn signal(&self, success: bool) {
975        self.state.store(success as u32 + 1, Ordering::Release);
976        self.event.notify(usize::MAX);
977    }
978
979    fn is_ready(&self) -> bool {
980        self.state.load(Ordering::Acquire) != 0
981    }
982
983    /// Waits for `signal` to be called and returns its `success` parameter.
984    async fn wait(&self) -> bool {
985        loop {
986            let listener = self.event.listen();
987            let state = self.state.load(Ordering::Acquire);
988            if state != 0 {
989                return state - 1 != 0;
990            }
991            listener.await;
992        }
993    }
994}
995
#[cfg(test)]
mod tests {
    use super::StateUnit;
    use super::StateUnits;
    use crate::run_unit;
    use inspect::InspectMut;
    use mesh::payload::Protobuf;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::Ordering;
    use std::time::Duration;
    use test_with_tracing::test;
    use vmcore::save_restore::RestoreError;
    use vmcore::save_restore::SaveError;
    use vmcore::save_restore::SavedStateBlob;
    use vmcore::save_restore::SavedStateRoot;

    /// A state unit whose only state is a shared boolean.
    ///
    /// On restore it first asserts that `dep` (if any) is already `true`,
    /// which lets tests detect a unit being restored before something it
    /// depends on.
    #[derive(Default)]
    struct TestUnit {
        /// The value captured by `save` and written back by `restore`.
        value: Arc<AtomicBool>,
        /// Flag that must already be `true` when `restore` runs, if set.
        dep: Option<Arc<AtomicBool>>,
        /// If we should support saved state or not.
        support_saved_state: bool,
    }

    /// Saved state payload shared by the test units: a single boolean.
    #[derive(Protobuf, SavedStateRoot)]
    #[mesh(package = "test")]
    struct SavedState(bool);

    impl StateUnit for TestUnit {
        async fn start(&mut self) {}

        async fn stop(&mut self) {}

        async fn reset(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        /// Saves the current `value`, or reports no state (`Ok(None)`) when
        /// saved state is not supported.
        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
            if self.support_saved_state {
                let state = SavedState(self.value.load(Ordering::Relaxed));
                Ok(Some(SavedStateBlob::new(state)))
            } else {
                Ok(None)
            }
        }

        /// Restores `value` from `state`.
        ///
        /// Panics if `dep` is set but not yet `true`. Fails with
        /// `SavedStateNotSupported` when saved state is not supported.
        async fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
            assert!(self.dep.as_ref().is_none_or(|v| v.load(Ordering::Relaxed)));

            if self.support_saved_state {
                let state: SavedState = state.parse()?;
                self.value.store(state.0, Ordering::Relaxed);
                Ok(())
            } else {
                Err(RestoreError::SavedStateNotSupported)
            }
        }
    }

    impl InspectMut for TestUnit {
        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
            req.respond();
        }
    }

    /// A state unit that sets a shared flag to `true` when it is restored,
    /// after a short delay.
    struct TestUnitSetDep {
        /// The flag set by `restore`.
        dep: Arc<AtomicBool>,
        /// Driver used for the delay timer in `restore`.
        driver: DefaultDriver,
    }

    impl StateUnit for TestUnitSetDep {
        async fn start(&mut self) {}

        async fn stop(&mut self) {}

        async fn reset(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
            Ok(Some(SavedStateBlob::new(SavedState(true))))
        }

        /// Sleeps for 100ms and only then sets `dep`, so a unit that is
        /// restored without waiting for this one would observe `false`.
        async fn restore(&mut self, _state: SavedStateBlob) -> Result<(), RestoreError> {
            pal_async::timer::PolledTimer::new(&self.driver)
                .sleep(Duration::from_millis(100))
                .await;

            self.dep.store(true, Ordering::Relaxed);

            Ok(())
        }
    }

    impl InspectMut for TestUnitSetDep {
        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
            req.respond();
        }
    }

    /// Exercises start/stop transitions (including adding a unit after the
    /// set has been started) and a save/restore round trip.
    #[async_test]
    async fn test_state_change(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let a_val = Arc::new(AtomicBool::new(true));

        let _a = units
            .add("a")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: a_val.clone(),
                        dep: None,
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();
        let _b = units
            .add("b")
            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
            .unwrap();
        units.start().await;

        // Add a unit after the others have been started. Registration must
        // succeed, otherwise the transitions below would not exercise it.
        let _c = units
            .add("c")
            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
            .unwrap();
        units.stop().await;
        units.start().await;

        units.stop().await;

        let state = units.save().await.unwrap();

        // Clobber the value so that the assertion below only passes if
        // restore actually wrote the saved value back.
        a_val.store(false, Ordering::Relaxed);

        units.restore(state).await.unwrap();

        assert!(a_val.load(Ordering::Relaxed));
    }

    /// Checks that a unit is restored only after the unit it depends on.
    ///
    /// The names ("zzz" for the dependency, "aaa" for the dependent) are
    /// presumably chosen so that name order is the opposite of dependency
    /// order.
    #[async_test]
    async fn test_dependencies(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let a_val = Arc::new(AtomicBool::new(true));

        let a = units
            .add("zzz")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: a_val.clone(),
                        dep: None,
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();

        let _b = units
            .add("aaa")
            .depends_on(a.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        dep: Some(a_val.clone()),
                        value: Default::default(),
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();
        units.start().await;
        units.stop().await;

        let state = units.save().await.unwrap();

        // The dependent's restore asserts this flag is `true` again, which
        // only holds once the dependency has restored its saved value.
        a_val.store(false, Ordering::Relaxed);

        units.restore(state).await.unwrap();
    }

    /// Checks that restore ordering holds across an intermediate unit that
    /// has no saved state.
    #[async_test]
    async fn test_dep_no_saved_state(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let true_val = Arc::new(AtomicBool::new(true));
        let shared_val = Arc::new(AtomicBool::new(false));

        let a = units
            .add("a")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: true_val.clone(),
                        dep: Some(shared_val.clone()),
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();

        // no saved state
        // Note that restore is never called for this unit.
        let b = units
            .add("b_no_saved_state")
            .dependency_of(a.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: true_val.clone(),
                        dep: Some(shared_val.clone()),
                        support_saved_state: false,
                    },
                    recv,
                )
            })
            .unwrap();

        // A has a transitive dependency on C via B.
        let _c = units
            .add("c")
            .dependency_of(b.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnitSetDep {
                        dep: shared_val,
                        driver: driver.clone(),
                    },
                    recv,
                )
            })
            .unwrap();

        let state = units.save().await.unwrap();

        // A's restore asserts `shared_val`, which only C's (delayed) restore
        // sets, so this fails if A is restored before C has finished.
        units.restore(state).await.unwrap();
    }
}