state_unit/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! VM state machine unit handling.
5//!
6//! A state unit is a VM component (such as a device) that needs to react to
7//! changes in the VM state machine. It needs to start, stop, reset, and
8//! save/restore with the VM. (Save/restore is not really a state change but is
9//! modeled as such since it must be synchronized with actual state changes.)
10//!
11//! This module contains types and functions for defining and manipulating state
12//! units. It does this in three parts:
13//!
14//! 1. It defines an RPC enum [`StateRequest`] which is used to request that a
15//!    state unit change state (start, stop, etc.). Each state unit must handle
16//!    incoming state requests on a mesh receiver. This is the foundational
17//!    type of this model.
18//!
19//! 2. It defines a type [`StateUnits`], which is a collection of mesh senders
20//!    for sending `StateRequest`s. This is used to initiate and wait for state
21//!    changes across all the units in the VMM, handling any required dependency
22//!    ordering.
23//!
24//! 3. It defines a trait [`StateUnit`] that can be used to define handlers for
25//!    each state request. This is an optional convenience; not all state units
26//!    will have a type that implements this trait.
27//!
28//! This model allows for asynchronous, highly concurrent state changes, and it
29//! works across process boundaries thanks to `mesh`.
30
31#![forbid(unsafe_code)]
32
33use futures::FutureExt;
34use futures::StreamExt;
35use futures::future::join_all;
36use futures_concurrency::stream::Merge;
37use inspect::Inspect;
38use inspect::InspectMut;
39use mesh::MeshPayload;
40use mesh::Receiver;
41use mesh::Sender;
42use mesh::payload::Protobuf;
43use mesh::rpc::FailableRpc;
44use mesh::rpc::Rpc;
45use mesh::rpc::RpcError;
46use mesh::rpc::RpcSend;
47use pal_async::task::Spawn;
48use pal_async::task::Task;
49use parking_lot::Mutex;
50use std::collections::BTreeMap;
51use std::collections::HashMap;
52use std::collections::hash_map;
53use std::fmt::Debug;
54use std::fmt::Display;
55use std::future::Future;
56use std::pin::pin;
57use std::sync::Arc;
58use std::sync::Weak;
59use std::sync::atomic::AtomicU32;
60use std::sync::atomic::Ordering;
61use std::time::Instant;
62use thiserror::Error;
63use tracing::Instrument;
64use vmcore::save_restore::RestoreError;
65use vmcore::save_restore::SaveError;
66use vmcore::save_restore::SavedStateBlob;
67
/// A state change request.
///
/// Every variant except `Inspect` carries an RPC that the unit completes once
/// it has handled the request (see [`StateRequest::apply`]). `Inspect` has no
/// response of its own; it is answered through the deferred inspect request.
#[derive(Debug, MeshPayload)]
pub enum StateRequest {
    /// Start asynchronous operations.
    Start(Rpc<(), ()>),

    /// Stop asynchronous operations.
    Stop(Rpc<(), ()>),

    /// Reset a stopped unit to initial state.
    Reset(FailableRpc<(), ()>),

    /// Save state of a stopped unit. A `None` result means the unit has no
    /// state to save.
    Save(FailableRpc<(), Option<SavedStateBlob>>),

    /// Restore state of a stopped unit from a previously saved blob.
    Restore(FailableRpc<SavedStateBlob, ()>),

    /// Inspect state.
    Inspect(inspect::Deferred),
}
89
/// Trait implemented by an object that can act as a state unit.
///
/// Implementing this is optional, to be used with [`UnitBuilder::spawn`] or
/// [`StateRequest::apply`]; state units can also directly process incoming
/// [`StateRequest`]s.
///
/// Implementing this trait for `&T` (rather than `T`) allows use with
/// [`run_async_unit`], which keeps serving inspect requests while a state
/// transition is in flight.
#[expect(async_fn_in_trait)] // Don't need Send bounds
pub trait StateUnit: InspectMut {
    /// Start asynchronous processing.
    async fn start(&mut self);

    /// Stop asynchronous processing.
    async fn stop(&mut self);

    /// Reset to initial state.
    ///
    /// Must only be called while stopped.
    async fn reset(&mut self) -> anyhow::Result<()>;

    /// Save state.
    ///
    /// Must only be called while stopped. Returning `Ok(None)` causes the unit
    /// to be omitted from the states collected by [`StateUnits::save`].
    async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError>;

    /// Restore state.
    ///
    /// Must only be called while stopped. [`StateUnits::restore`] only sends a
    /// restore request to units that have a saved state entry.
    async fn restore(&mut self, buffer: SavedStateBlob) -> Result<(), RestoreError>;
}
118
119/// Runs a simple unit that only needs to respond to state requests.
120pub async fn run_unit<T: StateUnit>(mut unit: T, mut recv: Receiver<StateRequest>) -> T {
121    while let Some(req) = recv.next().await {
122        req.apply(&mut unit).await;
123    }
124    unit
125}
126
127/// Runs a state unit that can handle inspect requests while there is an active
128/// state transition.
129pub async fn run_async_unit<T>(unit: T, mut recv: Receiver<StateRequest>) -> T
130where
131    for<'a> &'a T: StateUnit,
132{
133    while let Some(req) = recv.next().await {
134        req.apply_with_concurrent_inspects(&mut &unit, &mut recv)
135            .await;
136    }
137    unit
138}
139
impl StateRequest {
    /// Runs this state request against `unit`, polling `recv` for incoming
    /// inspect requests and applying them while any state transition is in
    /// flight.
    ///
    /// For this to work, your state unit `T` should implement [`StateUnit`] for
    /// `&'_ T`.
    ///
    /// Panics if a state transition arrives on `recv` while this one is being
    /// processed. This would indicate a contract violation with [`StateUnits`].
    pub async fn apply_with_concurrent_inspects<'a, T>(
        self,
        unit: &mut &'a T,
        recv: &mut Receiver<StateRequest>,
    ) where
        &'a T: StateUnit,
    {
        match self {
            StateRequest::Inspect(_) => {
                // This request has no response and completes synchronously,
                // so don't wait for concurrent requests.
                self.apply(unit).await;
            }

            StateRequest::Start(_)
            | StateRequest::Stop(_)
            | StateRequest::Reset(_)
            | StateRequest::Save(_)
            | StateRequest::Restore(_) => {
                // Handle concurrent inspect requests while the operation runs.
                enum Event {
                    OpDone,
                    Req(StateRequest),
                }
                // Copy the shared reference so the in-flight operation and the
                // concurrent inspects each have their own `&mut &T`.
                let mut op_unit = *unit;
                let op = pin!(self.apply(&mut op_unit).into_stream());
                let mut stream = (op.map(|()| Event::OpDone), recv.map(Event::Req)).merge();
                // The loop exits as soon as the operation yields `OpDone`. If
                // `recv` closes first, the merged stream keeps polling the
                // operation until it completes.
                while let Some(Event::Req(next_req)) = stream.next().await {
                    match next_req {
                        StateRequest::Inspect(req) => req.inspect(&mut *unit),
                        _ => panic!(
                            "unexpected state transition {next_req:?} during state transition"
                        ),
                    }
                }
            }
        }
    }

    /// Runs this state request against `unit`.
    ///
    /// Each state change request completes its RPC with the result of the
    /// corresponding [`StateUnit`] method; inspect requests are answered
    /// directly from `unit`.
    pub async fn apply(self, unit: &mut impl StateUnit) {
        match self {
            StateRequest::Start(rpc) => rpc.handle(async |()| unit.start().await).await,
            StateRequest::Stop(rpc) => rpc.handle(async |()| unit.stop().await).await,
            StateRequest::Reset(rpc) => rpc.handle_failable(async |()| unit.reset().await).await,
            StateRequest::Save(rpc) => rpc.handle_failable(async |()| unit.save().await).await,
            StateRequest::Restore(rpc) => {
                rpc.handle_failable(async |buffer| unit.restore(buffer).await)
                    .await
            }
            StateRequest::Inspect(req) => req.inspect(unit),
        }
    }
}
204
/// A set of state units.
#[derive(Debug)]
pub struct StateUnits {
    /// The shared unit registry. Unit handles and inspectors hold weak
    /// references to it.
    inner: Arc<Mutex<Inner>>,
    /// Whether the set as a whole has been started (set by `start`, cleared
    /// by `stop`).
    running: bool,
}

/// The state of an individual unit, as tracked by [`StateUnits`].
///
/// `Stopped` and `Running` are the steady states. The others are interim
/// states, held only while the corresponding operation is in flight.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Inspect)]
enum State {
    Stopped,
    Starting,
    Running,
    Stopping,
    Resetting,
    Saving,
    Restoring,
}

/// The unit registry, protected by the mutex in [`StateUnits`].
#[derive(Debug)]
struct Inner {
    /// The ID to assign to the next unit. It only ever increments, so IDs are
    /// not reused.
    next_id: u64,
    /// Registered units, keyed by ID.
    units: BTreeMap<u64, Unit>,
    /// Maps each registered unit name to its ID; names are unique.
    names: HashMap<Arc<str>, u64>,
}

/// Bookkeeping for a single registered unit.
#[derive(Debug)]
struct Unit {
    /// The unit's name, also used as its save/restore ID.
    name: Arc<str>,
    /// Channel used to send state requests to the unit.
    send: Sender<StateRequest>,
    /// IDs of the units this unit depends on.
    dependencies: Vec<u64>,
    /// IDs of the units that depend on this unit.
    dependents: Vec<u64>,
    /// The unit's current state.
    state: State,
}
238
/// An error returned when a state unit name is already in use.
#[derive(Debug, Error)]
#[error("state unit name {0} is in use")]
pub struct NameInUse(Arc<str>);

/// An error communicating with a unit during a state change.
#[derive(Debug, Error)]
#[error("critical unit communication failure: {name}")]
struct UnitRecvError {
    /// The name of the unit whose RPC failed.
    name: Arc<str>,
    #[source]
    source: RpcError,
}

/// Identifies a registered unit by both its name and its ID.
#[derive(Debug, Clone)]
struct UnitId {
    name: Arc<str>,
    id: u64,
}

/// A handle returned by [`UnitBuilder::build`] (via [`StateUnits::add`]),
/// used to remove the state unit.
///
/// Dropping the handle also removes the unit, unless the handle was detached.
#[must_use]
#[derive(Debug)]
pub struct UnitHandle {
    id: UnitId,
    /// A weak reference to the owning registry; `None` once the unit has been
    /// removed or the handle detached.
    inner: Option<Weak<Mutex<Inner>>>,
}
265
266impl Drop for UnitHandle {
267    fn drop(&mut self) {
268        self.remove_if();
269    }
270}
271
272impl UnitHandle {
273    /// Remove the state unit.
274    pub fn remove(mut self) {
275        self.remove_if();
276    }
277
278    /// Detach this handle, leaving the unit in place indefinitely.
279    pub fn detach(mut self) {
280        self.inner = None;
281    }
282
283    fn remove_if(&mut self) {
284        if let Some(inner) = self.inner.take().and_then(|inner| inner.upgrade()) {
285            let mut inner = inner.lock();
286            inner.units.remove(&self.id.id).expect("unit exists");
287            inner.names.remove(&self.id.name).expect("unit exists");
288        }
289    }
290}
291
292/// An object returned by [`StateUnits::inspector`] to inspect state units while
293/// state transitions may be in flight.
294pub struct StateUnitsInspector {
295    inner: Weak<Mutex<Inner>>,
296}
297
298impl Inspect for StateUnits {
299    fn inspect(&self, req: inspect::Request<'_>) {
300        self.inner.lock().inspect(req);
301    }
302}
303
304impl Inspect for StateUnitsInspector {
305    fn inspect(&self, req: inspect::Request<'_>) {
306        if let Some(inner) = self.inner.upgrade() {
307            inner.lock().inspect(req);
308        }
309    }
310}
311
312impl Inspect for Inner {
313    fn inspect(&self, req: inspect::Request<'_>) {
314        let mut resp = req.respond();
315        for unit in self.units.values() {
316            resp.child(unit.name.as_ref(), |req| {
317                let mut resp = req.respond();
318                if !unit.dependencies.is_empty() {
319                    resp.field_with("dependencies", || {
320                        unit.dependencies
321                            .iter()
322                            .map(|id| self.units[id].name.as_ref())
323                            .collect::<Vec<_>>()
324                            .join(",")
325                    });
326                }
327                if !unit.dependents.is_empty() {
328                    resp.field_with("dependents", || {
329                        unit.dependents
330                            .iter()
331                            .map(|id| self.units[id].name.as_ref())
332                            .collect::<Vec<_>>()
333                            .join(",")
334                    });
335                }
336                resp.field("unit_state", unit.state)
337                    .merge(&inspect::send(&unit.send, StateRequest::Inspect));
338            });
339        }
340    }
341}
342
/// The saved state for an individual unit.
///
/// NOTE(review): the `#[mesh(N)]` field numbers presumably define the encoded
/// saved-state format; confirm compatibility before changing them.
#[derive(Protobuf)]
#[mesh(package = "state_unit")]
pub struct SavedStateUnit {
    /// The name of the state unit.
    #[mesh(1)]
    pub name: String,
    /// The opaque saved state blob.
    #[mesh(2)]
    pub state: SavedStateBlob,
}

/// An error from a state transition.
///
/// Aggregates the failures of every unit that failed during the operation.
#[derive(Debug, Error)]
#[error("{op} failed")]
pub struct StateTransitionError {
    /// The name of the failed operation (for example, `"reset"`).
    op: &'static str,
    /// The per-unit failures, keyed by unit name.
    #[source]
    errors: UnitErrorSet,
}
363
364fn extract<T, E: Into<anyhow::Error>, U>(
365    op: &'static str,
366    iter: impl IntoIterator<Item = (Arc<str>, Result<T, E>)>,
367    mut f: impl FnMut(Arc<str>, T) -> Option<U>,
368) -> Result<Vec<U>, StateTransitionError> {
369    let mut result = Vec::new();
370    let mut errors = Vec::new();
371    for (name, item) in iter {
372        match item {
373            Ok(t) => {
374                if let Some(u) = f(name, t) {
375                    result.push(u);
376                }
377            }
378            Err(err) => errors.push((name, err.into())),
379        }
380    }
381    if errors.is_empty() {
382        Ok(result)
383    } else {
384        Err(StateTransitionError {
385            op,
386            errors: UnitErrorSet(errors),
387        })
388    }
389}
390
391fn check<E: Into<anyhow::Error>>(
392    op: &'static str,
393    iter: impl IntoIterator<Item = (Arc<str>, Result<(), E>)>,
394) -> Result<(), StateTransitionError> {
395    extract(op, iter, |_, _| Some(()))?;
396    Ok(())
397}
398
399#[derive(Debug)]
400struct UnitErrorSet(Vec<(Arc<str>, anyhow::Error)>);
401
402impl Display for UnitErrorSet {
403    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404        let mut map = f.debug_map();
405        for (name, err) in &self.0 {
406            map.entry(&format_args!("{}", name), &format_args!("{:#}", err));
407        }
408        map.finish()
409    }
410}
411
412impl std::error::Error for UnitErrorSet {}
413
impl StateUnits {
    /// Creates a new instance with no initial units.
    ///
    /// The new set is not running.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(Inner {
                next_id: 0,
                units: BTreeMap::new(),
                names: HashMap::new(),
            })),
            running: false,
        }
    }

    /// Returns an inspector that can be used to inspect the state units while
    /// state transitions are in process.
    ///
    /// The inspector holds only a weak reference, so it does not keep the
    /// units alive.
    pub fn inspector(&self) -> StateUnitsInspector {
        StateUnitsInspector {
            inner: Arc::downgrade(&self.inner),
        }
    }

    /// Begins adding a state unit named `name`, returning a builder used to
    /// declare its dependencies and register it.
    ///
    /// Save and restore will use `name` as the save ID, so this forms part of
    /// the saved state.
    ///
    /// Note that the added unit will not be running after it is built/spawned,
    /// even if the other units are running. Call
    /// [`StateUnits::start_stopped_units`] when finished adding units.
    pub fn add(&self, name: impl Into<Arc<str>>) -> UnitBuilder<'_> {
        UnitBuilder {
            units: self,
            name: name.into(),
            dependencies: Vec::new(),
            dependents: Vec::new(),
        }
    }

    /// Check if state units are currently running.
    ///
    /// This reflects the most recent [`StateUnits::start`] or
    /// [`StateUnits::stop`] call, not the state of any individual unit.
    pub fn is_running(&self) -> bool {
        self.running
    }

    /// Starts any units that are individually stopped, because they were added
    /// via [`StateUnits::add`] while the VM was running.
    ///
    /// Does nothing if all units are stopped, via [`StateUnits::stop`].
    pub async fn start_stopped_units(&mut self) {
        // `start` skips units that are already running, so only the newly
        // added (still stopped) units are actually started.
        if self.is_running() {
            self.start().await;
        }
    }

    /// Starts all the state units.
    ///
    /// Each unit waits for its dependencies to finish starting before it is
    /// started. Units that are already running are left as they are.
    pub async fn start(&mut self) {
        self.run_op(
            "start",
            None,
            State::Stopped,
            State::Starting,
            State::Running,
            StateRequest::Start,
            |_, _| Some(()),
            |unit| &unit.dependencies,
        )
        .await;
        self.running = true;
    }

    /// Stops all the state units.
    ///
    /// Panics if not running.
    pub async fn stop(&mut self) {
        assert!(self.running);
        // Stop units in reverse dependency order so that a dependency is not
        // stopped before its dependents.
        self.run_op(
            "stop",
            None,
            State::Running,
            State::Stopping,
            State::Stopped,
            StateRequest::Stop,
            |_, _| Some(()),
            |unit| &unit.dependents,
        )
        .await;
        self.running = false;
    }

    /// Resets all the state units.
    ///
    /// Returns an error listing every unit whose reset failed.
    ///
    /// Panics if running.
    pub async fn reset(&mut self) -> Result<(), StateTransitionError> {
        assert!(!self.running);
        // Reset in dependency order so that dependents observe their
        // dependencies' reset state.
        let r = self
            .run_op(
                "reset",
                None,
                State::Stopped,
                State::Resetting,
                State::Stopped,
                StateRequest::Reset,
                |_, _| Some(()),
                |unit| &unit.dependencies,
            )
            .await;

        check("reset", r)?;
        Ok(())
    }

    /// Saves all the state units.
    ///
    /// Units that report no state (`None`) are omitted from the result.
    /// Returns an error listing every unit whose save failed.
    ///
    /// Panics if running.
    pub async fn save(&mut self) -> Result<Vec<SavedStateUnit>, StateTransitionError> {
        assert!(!self.running);
        // Save can occur in any order since it will not observably mutate
        // state.
        let r = self
            .run_op(
                "save",
                None,
                State::Stopped,
                State::Saving,
                State::Stopped,
                StateRequest::Save,
                |_, _| Some(()),
                |_| &[],
            )
            .await;

        let states = extract("save", r, |name, state| {
            state.map(|state| SavedStateUnit {
                name: name.to_string(),
                state,
            })
        })?;

        Ok(states)
    }

    /// Restores all the state units.
    ///
    /// Units without an entry in `states` are not contacted, but still pass
    /// through the restoring state. Returns an error if an entry names an
    /// unknown unit or appears more than once, if an entry is left unconsumed,
    /// or if any unit's restore fails.
    ///
    /// Panics if running.
    pub async fn restore(
        &mut self,
        states: Vec<SavedStateUnit>,
    ) -> Result<(), StateTransitionError> {
        assert!(!self.running);

        #[derive(Debug, Error)]
        enum RestoreUnitError {
            #[error("unknown unit name")]
            Unknown,
            #[error("duplicate unit name")]
            Duplicate,
        }

        // Map each saved state to its unit's ID, collecting name errors so
        // that all of them are reported together before anything is restored.
        let mut states_by_id = HashMap::new();
        let mut r = Vec::new();
        {
            let inner = self.inner.lock();
            for state in states {
                match inner.names.get_key_value(state.name.as_str()) {
                    Some((name, &id)) => {
                        if states_by_id
                            .insert(id, (name.clone(), state.state))
                            .is_some()
                        {
                            r.push((name.clone(), Err(RestoreUnitError::Duplicate)));
                        }
                    }
                    None => {
                        r.push((state.name.into(), Err(RestoreUnitError::Unknown)));
                    }
                }
            }
        }

        check("restore", r)?;

        // Restore in dependency order. Each unit's saved state is taken out
        // of the map as its request is built; a `None` input skips the unit.
        let r = self
            .run_op(
                "restore",
                None,
                State::Stopped,
                State::Restoring,
                State::Stopped,
                StateRequest::Restore,
                |id, _| states_by_id.remove(&id).map(|(_, blob)| blob),
                |unit| &unit.dependencies,
            )
            .await;

        // Make sure all the saved state was consumed. This could hit if a unit
        // was removed concurrently with the restore.
        check(
            "restore",
            states_by_id
                .into_iter()
                .map(|(_, (name, _))| (name, Err(RestoreUnitError::Unknown))),
        )?;

        check("restore", r)?;

        Ok(())
    }

    /// Runs a state change operation on a set of units.
    ///
    /// `op` gives the name of the operation for tracing and error reporting
    /// purposes.
    ///
    /// `unit_ids` is the set of the units whose state should be changed. If
    /// `unit_ids` is `None`, all units change states.
    ///
    /// The old state for each unit must be `old_state`. During the operation,
    /// the unit is temporarily placed in `interim_state`. When complete, the
    /// unit is placed in `new_state`. A unit that is already in `new_state` is
    /// not contacted, and counts as complete for its dependents.
    ///
    /// Each unit waits for its dependencies to complete their state change
    /// operation before proceeding with their own state change. The
    /// dependencies list is computed for a unit by calling `deps`.
    ///
    /// To perform the state change, the unit is sent a request generated using
    /// `request`, with input generated by `input`. If `input` returns `None`,
    /// then communication with the unit is skipped, but the unit still
    /// transitions through the interim and into the new state, and its
    /// dependencies are still waited on by its dependents.
    ///
    /// Returns the name and response of each unit that was sent a request.
    ///
    /// Panics if a unit starts in a state other than `old_state` or
    /// `new_state`, ends in a state other than `interim_state` or
    /// `new_state`, or fails to respond while still registered.
    async fn run_op<I: 'static, R: 'static + Send>(
        &self,
        op: &str,
        unit_ids: Option<&[u64]>,
        old_state: State,
        interim_state: State,
        new_state: State,
        request: impl Copy + FnOnce(Rpc<I, R>) -> StateRequest,
        mut input: impl FnMut(u64, &Unit) -> Option<I>,
        mut deps: impl FnMut(&Unit) -> &[u64],
    ) -> Vec<(Arc<str>, R)> {
        let mut done = Vec::new();
        let ready_set;
        {
            let mut inner = self.inner.lock();
            ready_set = inner.ready_set(unit_ids);
            for (&id, unit) in inner
                .units
                .iter_mut()
                .filter(|(id, _)| ready_set.0.contains_key(id))
            {
                if unit.state != old_state {
                    assert_eq!(
                        unit.state, new_state,
                        "unit {} in {:?} state, should be {:?} or {:?}",
                        unit.name, unit.state, old_state, new_state
                    );
                    // Already in the target state: mark it complete so its
                    // dependents do not wait on it.
                    ready_set.done(id, true);
                } else {
                    let name = unit.name.clone();
                    let input = (input)(id, unit);
                    let ready_set = ready_set.clone();
                    let deps = deps(unit).to_vec();
                    let fut = state_change(name.clone(), unit, request, input);
                    // The result of `wait` is not checked; every `done` call
                    // in this function signals success.
                    let recv = async move {
                        ready_set.wait(op, id, &deps).await;
                        let r = fut.await;
                        ready_set.done(id, true);
                        (name, id, r)
                    };
                    done.push(recv);
                    unit.state = interim_state;
                }
            }
        }
        // The lock is released here, before any request is sent or awaited.

        let results = async {
            let start = Instant::now();
            let results = join_all(done).await;
            tracing::info!(duration = ?Instant::now() - start, "state change complete");
            results
        }
        .instrument(tracing::info_span!("state_change", operation = op))
        .await;

        let mut inner = self.inner.lock();
        let r = results
            .into_iter()
            .filter_map(|(name, id, r)| {
                match r {
                    Ok(Some(r)) => Some((name, r)),
                    Ok(None) => None,
                    Err(err) => {
                        // If the unit was removed during the operation, then
                        // ignore the failure. Otherwise, panic because unit
                        // failure is not recoverable. FUTURE: reconsider this
                        // position.
                        if inner.units.contains_key(&id) {
                            panic!("{:?}", err);
                        }
                        None
                    }
                }
            })
            .collect();
        // Move every participating unit that is still registered out of the
        // interim state.
        for (_, unit) in inner
            .units
            .iter_mut()
            .filter(|(id, _)| ready_set.0.contains_key(id))
        {
            if unit.state == interim_state {
                unit.state = new_state;
            } else {
                assert_eq!(
                    unit.state, new_state,
                    "unit {} in {:?} state, should be {:?} or {:?}",
                    unit.name, unit.state, interim_state, new_state
                );
            }
        }
        r
    }
}
735
736impl Inner {
737    fn ready_set(&self, unit_ids: Option<&[u64]>) -> ReadySet {
738        let map = |id, unit: &Unit| {
739            (
740                id,
741                ReadyState {
742                    name: unit.name.clone(),
743                    ready: Arc::new(Ready::default()),
744                },
745            )
746        };
747        let units = if let Some(unit_ids) = unit_ids {
748            unit_ids
749                .iter()
750                .map(|id| map(*id, &self.units[id]))
751                .collect()
752        } else {
753            self.units.iter().map(|(id, unit)| map(*id, unit)).collect()
754        };
755        ReadySet(Arc::new(units))
756    }
757}
758
759#[derive(Clone)]
760struct ReadySet(Arc<BTreeMap<u64, ReadyState>>);
761
762#[derive(Clone)]
763struct ReadyState {
764    name: Arc<str>,
765    ready: Arc<Ready>,
766}
767
768impl ReadySet {
769    async fn wait(&self, op: &str, id: u64, deps: &[u64]) -> bool {
770        for dep in deps {
771            if let Some(dep) = self.0.get(dep) {
772                if !dep.ready.is_ready() {
773                    tracing::debug!(
774                        device = self.0[&id].name.as_ref(),
775                        dependency = dep.name.as_ref(),
776                        operation = op,
777                        "waiting on dependency"
778                    );
779                }
780                if !dep.ready.wait().await {
781                    return false;
782                }
783            }
784        }
785        true
786    }
787
788    fn done(&self, id: u64, success: bool) {
789        self.0[&id].ready.signal(success);
790    }
791}
792
793/// Sends state change `request` to `unit` with `input`, wrapping the result
794/// future with a span, and wrapping its error with something more informative.
795///
796/// `operation` and `name` are used in tracing and error construction.
797fn state_change<I: 'static, R: 'static + Send, Req: FnOnce(Rpc<I, R>) -> StateRequest>(
798    name: Arc<str>,
799    unit: &Unit,
800    request: Req,
801    input: Option<I>,
802) -> impl Future<Output = Result<Option<R>, UnitRecvError>> + use<I, R, Req> {
803    let send = unit.send.clone();
804
805    async move {
806        let Some(input) = input else { return Ok(None) };
807        let span = tracing::info_span!("device_state_change", device = name.as_ref());
808        async move {
809            let start = Instant::now();
810            let r = send
811                .call(request, input)
812                .await
813                .map_err(|err| UnitRecvError { name, source: err });
814            tracing::debug!(duration = ?Instant::now() - start, "device state change complete");
815            r.map(Some)
816        }
817        .instrument(span)
818        .await
819    }
820}
821
822/// A builder returned by [`StateUnits::add`].
823#[derive(Debug)]
824#[must_use]
825pub struct UnitBuilder<'a> {
826    units: &'a StateUnits,
827    name: Arc<str>,
828    dependencies: Vec<u64>,
829    dependents: Vec<u64>,
830}
831
832impl UnitBuilder<'_> {
833    /// Adds `handle` as a dependency of this new unit.
834    ///
835    /// Operations will be ordered to ensure that a dependency will stop after
836    /// its dependants, and that it will reset or restore before its dependants.
837    pub fn depends_on(mut self, handle: &UnitHandle) -> Self {
838        self.dependencies.push(self.handle_id(handle));
839        self
840    }
841
842    /// Adds this new unit as a dependency of `handle`.
843    ///
844    /// Operations will be ordered to ensure that a dependency will stop after
845    /// its dependants, and that it will reset or restore before its dependants.
846    pub fn dependency_of(mut self, handle: &UnitHandle) -> Self {
847        self.dependents.push(self.handle_id(handle));
848        self
849    }
850
851    fn handle_id(&self, handle: &UnitHandle) -> u64 {
852        // Ensure this handle is associated with this set of state units.
853        assert_eq!(
854            Weak::as_ptr(handle.inner.as_ref().unwrap()),
855            Arc::as_ptr(&self.units.inner)
856        );
857        handle.id.id
858    }
859
860    /// Adds a new state unit sending requests to `send`.
861    pub fn build(mut self, send: Sender<StateRequest>) -> Result<UnitHandle, NameInUse> {
862        let id = {
863            let mut inner = self.units.inner.lock();
864            let id = inner.next_id;
865            let entry = match inner.names.entry(self.name.clone()) {
866                hash_map::Entry::Occupied(_) => return Err(NameInUse(self.name)),
867                hash_map::Entry::Vacant(e) => e,
868            };
869            entry.insert(id);
870
871            // Dedup the dependencies and update the dependencies' lists of
872            // dependents.
873            self.dependencies.sort();
874            self.dependencies.dedup();
875            for &dep in &self.dependencies {
876                inner.units.get_mut(&dep).unwrap().dependents.push(id);
877            }
878
879            // Dedup the depenedents and update the dependents' lists of
880            // dependencies.
881            self.dependents.sort();
882            self.dependents.dedup();
883            for &dep in &self.dependents {
884                inner.units.get_mut(&dep).unwrap().dependencies.push(id);
885            }
886            inner.units.insert(
887                id,
888                Unit {
889                    name: self.name.clone(),
890                    send,
891                    dependencies: self.dependencies,
892                    dependents: self.dependents,
893                    state: State::Stopped,
894                },
895            );
896            let unit_id = UnitId {
897                name: self.name,
898                id,
899            };
900            inner.next_id += 1;
901            unit_id
902        };
903        Ok(UnitHandle {
904            id,
905            inner: Some(Arc::downgrade(&self.units.inner)),
906        })
907    }
908
909    /// Adds a unit as in [`Self::build`], then spawns a task for running the
910    /// unit.
911    ///
912    /// The channel to receive state change requests is passed to `f`, which
913    /// should return the future to evaluate to run the unit.
914    #[track_caller]
915    pub fn spawn<F, Fut>(
916        self,
917        spawner: impl Spawn,
918        f: F,
919    ) -> Result<SpawnedUnit<Fut::Output>, NameInUse>
920    where
921        F: FnOnce(Receiver<StateRequest>) -> Fut,
922        Fut: 'static + Send + Future,
923        Fut::Output: 'static + Send,
924    {
925        let (send, recv) = mesh::channel();
926        let task_name = format!("unit-{}", self.name);
927        let handle = self.build(send)?;
928        let fut = (f)(recv);
929        let task = spawner.spawn(task_name, fut);
930        Ok(SpawnedUnit { task, handle })
931    }
932}
933
/// A handle to a spawned unit.
///
/// Pairs the unit's [`UnitHandle`] with the task that runs the unit (as
/// created by `UnitBuilder::spawn`). Call [`SpawnedUnit::remove`] to remove
/// the unit and obtain the task's output.
#[must_use]
pub struct SpawnedUnit<T> {
    /// Handle identifying the unit within its `StateUnits` collection.
    handle: UnitHandle,
    /// The spawned task running the unit; it produces `T` when it completes.
    task: Task<T>,
}
940
941impl<T> Debug for SpawnedUnit<T> {
942    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
943        f.debug_struct("SpawnedUnit")
944            .field("handle", &self.handle)
945            .field("task", &self.task)
946            .finish()
947    }
948}
949
950impl<T> SpawnedUnit<T> {
951    /// Removes the unit and returns it.
952    pub async fn remove(self) -> T {
953        self.handle.remove();
954        self.task.await
955    }
956
957    /// Gets the unit handle for use with methods like
958    /// [`UnitBuilder::depends_on`].
959    pub fn handle(&self) -> &UnitHandle {
960        &self.handle
961    }
962}
963
/// A completion signal carrying a success/failure outcome, with support for
/// asynchronously waiting until it has been signaled.
#[derive(Default)]
struct Ready {
    /// 0 (the default) until `signal` is called; afterwards 1 if it was
    /// signaled with failure or 2 if it was signaled with success.
    state: AtomicU32,
    /// Notified by `signal` to wake everything waiting in `Ready::wait`.
    event: event_listener::Event,
}
969
970impl Ready {
971    /// Wakes everyone with `success`.
972    fn signal(&self, success: bool) {
973        self.state.store(success as u32 + 1, Ordering::Release);
974        self.event.notify(usize::MAX);
975    }
976
977    fn is_ready(&self) -> bool {
978        self.state.load(Ordering::Acquire) != 0
979    }
980
981    /// Waits for `signal` to be called and returns its `success` parameter.
982    async fn wait(&self) -> bool {
983        loop {
984            let listener = self.event.listen();
985            let state = self.state.load(Ordering::Acquire);
986            if state != 0 {
987                return state - 1 != 0;
988            }
989            listener.await;
990        }
991    }
992}
993
#[cfg(test)]
mod tests {
    use super::StateUnit;
    use super::StateUnits;
    use crate::run_unit;
    use inspect::InspectMut;
    use mesh::payload::Protobuf;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::Ordering;
    use std::time::Duration;
    use test_with_tracing::test;
    use vmcore::save_restore::RestoreError;
    use vmcore::save_restore::SaveError;
    use vmcore::save_restore::SavedStateBlob;
    use vmcore::save_restore::SavedStateRoot;

    /// A configurable test unit whose saved state is the boolean in `value`.
    #[derive(Default)]
    struct TestUnit {
        /// Saved by `save` and overwritten by `restore` (when saved state is
        /// supported).
        value: Arc<AtomicBool>,
        /// If set, `restore` asserts this flag is already `true`, which checks
        /// that the unit responsible for setting it was restored first.
        dep: Option<Arc<AtomicBool>>,
        /// If we should support saved state or not.
        support_saved_state: bool,
    }

    /// Saved state used by the test units: a single boolean.
    #[derive(Protobuf, SavedStateRoot)]
    #[mesh(package = "test")]
    struct SavedState(bool);

    impl StateUnit for TestUnit {
        async fn start(&mut self) {}

        async fn stop(&mut self) {}

        async fn reset(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
            if self.support_saved_state {
                let state = SavedState(self.value.load(Ordering::Relaxed));
                Ok(Some(SavedStateBlob::new(state)))
            } else {
                Ok(None)
            }
        }

        async fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
            // Any dependency must already have been restored.
            assert!(self.dep.as_ref().is_none_or(|v| v.load(Ordering::Relaxed)));

            if self.support_saved_state {
                let state: SavedState = state.parse()?;
                self.value.store(state.0, Ordering::Relaxed);
                Ok(())
            } else {
                Err(RestoreError::SavedStateNotSupported)
            }
        }
    }

    impl InspectMut for TestUnit {
        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
            req.respond();
        }
    }

    /// A test unit whose (deliberately slow) restore sets `dep` to `true`, so
    /// dependents that check `dep` fail unless restore ordering is honored.
    struct TestUnitSetDep {
        dep: Arc<AtomicBool>,
        driver: DefaultDriver,
    }

    impl StateUnit for TestUnitSetDep {
        async fn start(&mut self) {}

        async fn stop(&mut self) {}

        async fn reset(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
            Ok(Some(SavedStateBlob::new(SavedState(true))))
        }

        async fn restore(&mut self, _state: SavedStateBlob) -> Result<(), RestoreError> {
            // Delay so that a dependent restored too early observes `false`.
            pal_async::timer::PolledTimer::new(&self.driver)
                .sleep(Duration::from_millis(100))
                .await;

            self.dep.store(true, Ordering::Relaxed);

            Ok(())
        }
    }

    impl InspectMut for TestUnitSetDep {
        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
            req.respond();
        }
    }

    // Exercises start/stop (including adding a unit while running) and a
    // save/restore round trip.
    #[async_test]
    async fn test_state_change(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let a_val = Arc::new(AtomicBool::new(true));

        let _a = units
            .add("a")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: a_val.clone(),
                        dep: None,
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();
        let _b = units
            .add("b")
            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
            .unwrap();
        units.start().await;

        // Unwrap so that a failure to add the unit is not silently ignored.
        let _c = units
            .add("c")
            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
            .unwrap();
        units.stop().await;
        units.start().await;

        units.stop().await;

        let state = units.save().await.unwrap();

        a_val.store(false, Ordering::Relaxed);

        units.restore(state).await.unwrap();

        assert!(a_val.load(Ordering::Relaxed));
    }

    // Checks that restore honors an explicit dependency even when the
    // dependency's name sorts after the dependent's.
    #[async_test]
    async fn test_dependencies(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let a_val = Arc::new(AtomicBool::new(true));

        let a = units
            .add("zzz")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: a_val.clone(),
                        dep: None,
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();

        let _b = units
            .add("aaa")
            .depends_on(a.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        dep: Some(a_val.clone()),
                        value: Default::default(),
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();
        units.start().await;
        units.stop().await;

        let state = units.save().await.unwrap();

        a_val.store(false, Ordering::Relaxed);

        units.restore(state).await.unwrap();

        // "zzz" must have restored its saved `true` value.
        assert!(a_val.load(Ordering::Relaxed));
    }

    // Checks that a unit without saved state still participates in restore
    // ordering for the units that transitively depend through it.
    #[async_test]
    async fn test_dep_no_saved_state(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let true_val = Arc::new(AtomicBool::new(true));
        let shared_val = Arc::new(AtomicBool::new(false));

        let a = units
            .add("a")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: true_val.clone(),
                        dep: Some(shared_val.clone()),
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();

        // no saved state
        // Note that restore is never called for this unit.
        let b = units
            .add("b_no_saved_state")
            .dependency_of(a.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: true_val.clone(),
                        dep: Some(shared_val.clone()),
                        support_saved_state: false,
                    },
                    recv,
                )
            })
            .unwrap();

        // A has a transitive dependency on C via B.
        let _c = units
            .add("c")
            .dependency_of(b.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnitSetDep {
                        dep: shared_val,
                        driver: driver.clone(),
                    },
                    recv,
                )
            })
            .unwrap();

        let state = units.save().await.unwrap();

        units.restore(state).await.unwrap();
    }
}