1#![forbid(unsafe_code)]
32
33use futures::FutureExt;
34use futures::StreamExt;
35use futures::future::join_all;
36use futures_concurrency::stream::Merge;
37use inspect::Inspect;
38use inspect::InspectMut;
39use mesh::MeshPayload;
40use mesh::Receiver;
41use mesh::Sender;
42use mesh::payload::Protobuf;
43use mesh::rpc::FailableRpc;
44use mesh::rpc::Rpc;
45use mesh::rpc::RpcError;
46use mesh::rpc::RpcSend;
47use pal_async::task::Spawn;
48use pal_async::task::Task;
49use parking_lot::Mutex;
50use std::collections::BTreeMap;
51use std::collections::HashMap;
52use std::collections::hash_map;
53use std::fmt::Debug;
54use std::fmt::Display;
55use std::future::Future;
56use std::pin::pin;
57use std::sync::Arc;
58use std::sync::Weak;
59use std::sync::atomic::AtomicU32;
60use std::sync::atomic::Ordering;
61use std::time::Instant;
62use thiserror::Error;
63use tracing::Instrument;
64use vmcore::save_restore::RestoreError;
65use vmcore::save_restore::SaveError;
66use vmcore::save_restore::SavedStateBlob;
67
/// A request sent to a state unit over its request channel.
///
/// Every variant except `Inspect` carries an RPC that is completed with the
/// result of the operation once the unit has finished it (see
/// [`StateRequest::apply`]).
#[derive(Debug, MeshPayload)]
pub enum StateRequest {
    /// Start the unit.
    Start(Rpc<(), ()>),

    /// Stop the unit.
    Stop(Rpc<(), ()>),

    /// Reset the unit. May fail.
    Reset(FailableRpc<(), ()>),

    /// Save the unit's state. `None` means the unit has no saved state to
    /// contribute; such units are omitted from `StateUnits::save`'s output.
    Save(FailableRpc<(), Option<SavedStateBlob>>),

    /// Restore the unit's state from a blob previously produced by `Save`.
    /// May fail.
    Restore(FailableRpc<SavedStateBlob, ()>),

    /// Inspect the unit via a deferred inspection request.
    Inspect(inspect::Deferred),
}
89
/// An object that participates in state transitions managed by
/// [`StateUnits`].
///
/// Requests are delivered as [`StateRequest`]s; [`run_unit`] and
/// [`run_async_unit`] provide ready-made request loops.
// `async fn` in a public trait triggers this lint; it is expected here.
#[expect(async_fn_in_trait)]
pub trait StateUnit: InspectMut {
    /// Starts the unit.
    async fn start(&mut self);

    /// Stops the unit.
    async fn stop(&mut self);

    /// Resets the unit.
    ///
    /// `StateUnits::reset` only issues this while the set is not running.
    async fn reset(&mut self) -> anyhow::Result<()>;

    /// Saves the unit's state.
    ///
    /// Returns `Ok(None)` if the unit has no state to save.
    async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError>;

    /// Restores the unit's state from `buffer`, a blob previously returned by
    /// [`StateUnit::save`].
    async fn restore(&mut self, buffer: SavedStateBlob) -> Result<(), RestoreError>;
}
118
119pub async fn run_unit<T: StateUnit>(mut unit: T, mut recv: Receiver<StateRequest>) -> T {
121 while let Some(req) = recv.next().await {
122 req.apply(&mut unit).await;
123 }
124 unit
125}
126
127pub async fn run_async_unit<T>(unit: T, mut recv: Receiver<StateRequest>) -> T
130where
131 for<'a> &'a T: StateUnit,
132{
133 while let Some(req) = recv.next().await {
134 req.apply_with_concurrent_inspects(&mut &unit, &mut recv)
135 .await;
136 }
137 unit
138}
139
impl StateRequest {
    /// Applies this request to `unit`, continuing to service inspect requests
    /// arriving on `recv` until the operation completes.
    ///
    /// An inspect request is simply applied. For any other request, the
    /// operation and `recv` are polled together, and inspect requests
    /// received in the meantime are answered against `unit`.
    ///
    /// # Panics
    ///
    /// Panics if a non-inspect request arrives on `recv` before the current
    /// operation completes.
    pub async fn apply_with_concurrent_inspects<'a, T>(
        self,
        unit: &mut &'a T,
        recv: &mut Receiver<StateRequest>,
    ) where
        &'a T: StateUnit,
    {
        match self {
            StateRequest::Inspect(_) => {
                // Nothing to overlap with; handle it inline.
                self.apply(unit).await;
            }

            StateRequest::Start(_)
            | StateRequest::Stop(_)
            | StateRequest::Reset(_)
            | StateRequest::Save(_)
            | StateRequest::Restore(_) => {
                enum Event {
                    // The in-flight operation finished.
                    OpDone,
                    // A new request arrived on `recv`.
                    Req(StateRequest),
                }
                // Copy the shared reference so the operation and the
                // concurrent inspects each get their own `&mut &T`.
                let mut op_unit = *unit;
                // The operation's future becomes a one-item stream so it can
                // be merged with the request stream.
                let op = pin!(self.apply(&mut op_unit).into_stream());
                let mut stream = (op.map(|()| Event::OpDone), recv.map(Event::Req)).merge();
                // Keep going until the operation's completion event is seen.
                while let Some(Event::Req(next_req)) = stream.next().await {
                    match next_req {
                        StateRequest::Inspect(req) => req.inspect(&mut *unit),
                        _ => panic!(
                            "unexpected state transition {next_req:?} during state transition"
                        ),
                    }
                }
            }
        }
    }

    /// Applies this request to `unit`.
    ///
    /// For RPC-carrying requests, the RPC is completed with the result of the
    /// corresponding [`StateUnit`] method.
    pub async fn apply(self, unit: &mut impl StateUnit) {
        match self {
            StateRequest::Start(rpc) => rpc.handle(async |()| unit.start().await).await,
            StateRequest::Stop(rpc) => rpc.handle(async |()| unit.stop().await).await,
            StateRequest::Reset(rpc) => rpc.handle_failable(async |()| unit.reset().await).await,
            StateRequest::Save(rpc) => rpc.handle_failable(async |()| unit.save().await).await,
            StateRequest::Restore(rpc) => {
                rpc.handle_failable(async |buffer| unit.restore(buffer).await)
                    .await
            }
            StateRequest::Inspect(req) => req.inspect(unit),
        }
    }
}
204
/// A set of state units, with their dependency relationships and current
/// states.
#[derive(Debug)]
pub struct StateUnits {
    inner: Arc<Mutex<Inner>>,
    // Set by `start` and cleared by `stop`.
    running: bool,
}

/// The state of one unit as tracked by [`StateUnits`].
///
/// `Starting`, `Stopping`, `Resetting`, `Saving`, and `Restoring` are interim
/// states set while the corresponding operation is in flight.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Inspect)]
enum State {
    Stopped,
    Starting,
    Running,
    Stopping,
    Resetting,
    Saving,
    Restoring,
}

/// The registry of units, shared behind the mutex in [`StateUnits`].
#[derive(Debug)]
struct Inner {
    // The ID to assign to the next unit; it only ever increments.
    next_id: u64,
    // Units keyed by ID.
    units: BTreeMap<u64, Unit>,
    // Unit name -> ID, used to keep names unique.
    names: HashMap<Arc<str>, u64>,
}

/// Bookkeeping for one registered unit.
#[derive(Debug)]
struct Unit {
    name: Arc<str>,
    // Channel used to deliver `StateRequest`s to the unit.
    send: Sender<StateRequest>,
    // IDs of the units this unit depends on.
    dependencies: Vec<u64>,
    // IDs of the units that depend on this unit.
    dependents: Vec<u64>,
    state: State,
}

/// Error returned when adding a unit whose name is already registered.
#[derive(Debug, Error)]
#[error("state unit name {0} is in use")]
pub struct NameInUse(Arc<str>);

/// A failure to communicate with a unit during a state change.
#[derive(Debug, Error)]
#[error("critical unit communication failure: {name}")]
struct UnitRecvError {
    name: Arc<str>,
    #[source]
    source: RpcError,
}

/// A unit's ID together with its name.
#[derive(Debug, Clone)]
struct UnitId {
    name: Arc<str>,
    id: u64,
}
257
258#[must_use]
260#[derive(Debug)]
261pub struct UnitHandle {
262 id: UnitId,
263 inner: Option<Weak<Mutex<Inner>>>,
264}
265
266impl Drop for UnitHandle {
267 fn drop(&mut self) {
268 self.remove_if();
269 }
270}
271
272impl UnitHandle {
273 pub fn remove(mut self) {
275 self.remove_if();
276 }
277
278 pub fn detach(mut self) {
280 self.inner = None;
281 }
282
283 fn remove_if(&mut self) {
284 if let Some(inner) = self.inner.take().and_then(|inner| inner.upgrade()) {
285 let mut inner = inner.lock();
286 inner.units.remove(&self.id.id).expect("unit exists");
287 inner.names.remove(&self.id.name).expect("unit exists");
288 }
289 }
290}
291
292pub struct StateUnitsInspector {
295 inner: Weak<Mutex<Inner>>,
296}
297
298impl Inspect for StateUnits {
299 fn inspect(&self, req: inspect::Request<'_>) {
300 self.inner.lock().inspect(req);
301 }
302}
303
304impl Inspect for StateUnitsInspector {
305 fn inspect(&self, req: inspect::Request<'_>) {
306 if let Some(inner) = self.inner.upgrade() {
307 inner.lock().inspect(req);
308 }
309 }
310}
311
312impl Inspect for Inner {
313 fn inspect(&self, req: inspect::Request<'_>) {
314 let mut resp = req.respond();
315 for unit in self.units.values() {
316 resp.child(unit.name.as_ref(), |req| {
317 let mut resp = req.respond();
318 if !unit.dependencies.is_empty() {
319 resp.field_with("dependencies", || {
320 unit.dependencies
321 .iter()
322 .map(|id| self.units[id].name.as_ref())
323 .collect::<Vec<_>>()
324 .join(",")
325 });
326 }
327 if !unit.dependents.is_empty() {
328 resp.field_with("dependents", || {
329 unit.dependents
330 .iter()
331 .map(|id| self.units[id].name.as_ref())
332 .collect::<Vec<_>>()
333 .join(",")
334 });
335 }
336 resp.field("unit_state", unit.state)
337 .merge(inspect::adhoc(|req| {
338 unit.send.send(StateRequest::Inspect(req.defer()))
339 }));
340 });
341 }
342 }
343}
344
/// The saved state of one unit, as produced by [`StateUnits::save`].
#[derive(Protobuf)]
#[mesh(package = "state_unit")]
pub struct SavedStateUnit {
    /// The unit's name; used to match the state back to a unit on restore.
    #[mesh(1)]
    pub name: String,
    /// The blob returned by the unit's [`StateUnit::save`].
    #[mesh(2)]
    pub state: SavedStateBlob,
}
356
/// Error returned when a state transition fails for one or more units.
#[derive(Debug, Error)]
#[error("{op} failed")]
pub struct StateTransitionError {
    // The operation that failed, e.g. "reset", "save", or "restore".
    op: &'static str,
    // The individual per-unit failures.
    #[source]
    errors: UnitErrorSet,
}
365
366fn extract<T, E: Into<anyhow::Error>, U>(
367 op: &'static str,
368 iter: impl IntoIterator<Item = (Arc<str>, Result<T, E>)>,
369 mut f: impl FnMut(Arc<str>, T) -> Option<U>,
370) -> Result<Vec<U>, StateTransitionError> {
371 let mut result = Vec::new();
372 let mut errors = Vec::new();
373 for (name, item) in iter {
374 match item {
375 Ok(t) => {
376 if let Some(u) = f(name, t) {
377 result.push(u);
378 }
379 }
380 Err(err) => errors.push((name, err.into())),
381 }
382 }
383 if errors.is_empty() {
384 Ok(result)
385 } else {
386 Err(StateTransitionError {
387 op,
388 errors: UnitErrorSet(errors),
389 })
390 }
391}
392
393fn check<E: Into<anyhow::Error>>(
394 op: &'static str,
395 iter: impl IntoIterator<Item = (Arc<str>, Result<(), E>)>,
396) -> Result<(), StateTransitionError> {
397 extract(op, iter, |_, _| Some(()))?;
398 Ok(())
399}
400
401#[derive(Debug)]
402struct UnitErrorSet(Vec<(Arc<str>, anyhow::Error)>);
403
404impl Display for UnitErrorSet {
405 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406 let mut map = f.debug_map();
407 for (name, err) in &self.0 {
408 map.entry(&format_args!("{}", name), &format_args!("{:#}", err));
409 }
410 map.finish()
411 }
412}
413
414impl std::error::Error for UnitErrorSet {}
415
416impl StateUnits {
417 pub fn new() -> Self {
419 Self {
420 inner: Arc::new(Mutex::new(Inner {
421 next_id: 0,
422 units: BTreeMap::new(),
423 names: HashMap::new(),
424 })),
425 running: false,
426 }
427 }
428
429 pub fn inspector(&self) -> StateUnitsInspector {
432 StateUnitsInspector {
433 inner: Arc::downgrade(&self.inner),
434 }
435 }
436
437 pub fn add(&self, name: impl Into<Arc<str>>) -> UnitBuilder<'_> {
444 UnitBuilder {
445 units: self,
446 name: name.into(),
447 dependencies: Vec::new(),
448 dependents: Vec::new(),
449 }
450 }
451
452 pub fn is_running(&self) -> bool {
454 self.running
455 }
456
457 pub async fn start_stopped_units(&mut self) {
462 if self.is_running() {
463 self.start().await;
464 }
465 }
466
467 pub async fn start(&mut self) {
469 self.run_op(
470 "start",
471 None,
472 State::Stopped,
473 State::Starting,
474 State::Running,
475 StateRequest::Start,
476 |_, _| Some(()),
477 |unit| &unit.dependencies,
478 )
479 .await;
480 self.running = true;
481 }
482
483 pub async fn stop(&mut self) {
485 assert!(self.running);
486 self.run_op(
489 "stop",
490 None,
491 State::Running,
492 State::Stopping,
493 State::Stopped,
494 StateRequest::Stop,
495 |_, _| Some(()),
496 |unit| &unit.dependents,
497 )
498 .await;
499 self.running = false;
500 }
501
502 pub async fn reset(&mut self) -> Result<(), StateTransitionError> {
506 assert!(!self.running);
507 let r = self
510 .run_op(
511 "reset",
512 None,
513 State::Stopped,
514 State::Resetting,
515 State::Stopped,
516 StateRequest::Reset,
517 |_, _| Some(()),
518 |unit| &unit.dependencies,
519 )
520 .await;
521
522 check("reset", r)?;
523 Ok(())
524 }
525
526 pub async fn save(&mut self) -> Result<Vec<SavedStateUnit>, StateTransitionError> {
530 assert!(!self.running);
531 let r = self
534 .run_op(
535 "save",
536 None,
537 State::Stopped,
538 State::Saving,
539 State::Stopped,
540 StateRequest::Save,
541 |_, _| Some(()),
542 |_| &[],
543 )
544 .await;
545
546 let states = extract("save", r, |name, state| {
547 state.map(|state| SavedStateUnit {
548 name: name.to_string(),
549 state,
550 })
551 })?;
552
553 Ok(states)
554 }
555
556 pub async fn restore(
560 &mut self,
561 states: Vec<SavedStateUnit>,
562 ) -> Result<(), StateTransitionError> {
563 assert!(!self.running);
564
565 #[derive(Debug, Error)]
566 enum RestoreUnitError {
567 #[error("unknown unit name")]
568 Unknown,
569 #[error("duplicate unit name")]
570 Duplicate,
571 }
572
573 let mut states_by_id = HashMap::new();
574 let mut r = Vec::new();
575 {
576 let inner = self.inner.lock();
577 for state in states {
578 match inner.names.get_key_value(state.name.as_str()) {
579 Some((name, &id)) => {
580 if states_by_id
581 .insert(id, (name.clone(), state.state))
582 .is_some()
583 {
584 r.push((name.clone(), Err(RestoreUnitError::Duplicate)));
585 }
586 }
587 None => {
588 r.push((state.name.into(), Err(RestoreUnitError::Unknown)));
589 }
590 }
591 }
592 }
593
594 check("restore", r)?;
595
596 let r = self
597 .run_op(
598 "restore",
599 None,
600 State::Stopped,
601 State::Restoring,
602 State::Stopped,
603 StateRequest::Restore,
604 |id, _| states_by_id.remove(&id).map(|(_, blob)| blob),
605 |unit| &unit.dependencies,
606 )
607 .await;
608
609 check(
612 "restore",
613 states_by_id
614 .into_iter()
615 .map(|(_, (name, _))| (name, Err(RestoreUnitError::Unknown))),
616 )?;
617
618 check("restore", r)?;
619
620 Ok(())
621 }
622
623 async fn run_op<I: 'static, R: 'static + Send>(
645 &self,
646 op: &str,
647 unit_ids: Option<&[u64]>,
648 old_state: State,
649 interim_state: State,
650 new_state: State,
651 request: impl Copy + FnOnce(Rpc<I, R>) -> StateRequest,
652 mut input: impl FnMut(u64, &Unit) -> Option<I>,
653 mut deps: impl FnMut(&Unit) -> &[u64],
654 ) -> Vec<(Arc<str>, R)> {
655 let mut done = Vec::new();
656 let ready_set;
657 {
658 let mut inner = self.inner.lock();
659 ready_set = inner.ready_set(unit_ids);
660 for (&id, unit) in inner
661 .units
662 .iter_mut()
663 .filter(|(id, _)| ready_set.0.contains_key(id))
664 {
665 if unit.state != old_state {
666 assert_eq!(
667 unit.state, new_state,
668 "unit {} in {:?} state, should be {:?} or {:?}",
669 unit.name, unit.state, old_state, new_state
670 );
671 ready_set.done(id, true);
672 } else {
673 let name = unit.name.clone();
674 let input = (input)(id, unit);
675 let ready_set = ready_set.clone();
676 let deps = deps(unit).to_vec();
677 let fut = state_change(name.clone(), unit, request, input);
678 let recv = async move {
679 ready_set.wait(op, id, &deps).await;
680 let r = fut.await;
681 ready_set.done(id, true);
682 (name, id, r)
683 };
684 done.push(recv);
685 unit.state = interim_state;
686 }
687 }
688 }
689
690 let results = async {
691 let start = Instant::now();
692 let results = join_all(done).await;
693 tracing::info!(duration = ?Instant::now() - start, "state change complete");
694 results
695 }
696 .instrument(tracing::info_span!("state_change", operation = op))
697 .await;
698
699 let mut inner = self.inner.lock();
700 let r = results
701 .into_iter()
702 .filter_map(|(name, id, r)| {
703 match r {
704 Ok(Some(r)) => Some((name, r)),
705 Ok(None) => None,
706 Err(err) => {
707 if inner.units.contains_key(&id) {
712 panic!("{:?}", err);
713 }
714 None
715 }
716 }
717 })
718 .collect();
719 for (_, unit) in inner
720 .units
721 .iter_mut()
722 .filter(|(id, _)| ready_set.0.contains_key(id))
723 {
724 if unit.state == interim_state {
725 unit.state = new_state;
726 } else {
727 assert_eq!(
728 unit.state, new_state,
729 "unit {} in {:?} state, should be {:?} or {:?}",
730 unit.name, unit.state, interim_state, new_state
731 );
732 }
733 }
734 r
735 }
736}
737
738impl Inner {
739 fn ready_set(&self, unit_ids: Option<&[u64]>) -> ReadySet {
740 let map = |id, unit: &Unit| {
741 (
742 id,
743 ReadyState {
744 name: unit.name.clone(),
745 ready: Arc::new(Ready::default()),
746 },
747 )
748 };
749 let units = if let Some(unit_ids) = unit_ids {
750 unit_ids
751 .iter()
752 .map(|id| map(*id, &self.units[id]))
753 .collect()
754 } else {
755 self.units.iter().map(|(id, unit)| map(*id, unit)).collect()
756 };
757 ReadySet(Arc::new(units))
758 }
759}
760
761#[derive(Clone)]
762struct ReadySet(Arc<BTreeMap<u64, ReadyState>>);
763
764#[derive(Clone)]
765struct ReadyState {
766 name: Arc<str>,
767 ready: Arc<Ready>,
768}
769
770impl ReadySet {
771 async fn wait(&self, op: &str, id: u64, deps: &[u64]) -> bool {
772 for dep in deps {
773 if let Some(dep) = self.0.get(dep) {
774 if !dep.ready.is_ready() {
775 tracing::debug!(
776 device = self.0[&id].name.as_ref(),
777 dependency = dep.name.as_ref(),
778 operation = op,
779 "waiting on dependency"
780 );
781 }
782 if !dep.ready.wait().await {
783 return false;
784 }
785 }
786 }
787 true
788 }
789
790 fn done(&self, id: u64, success: bool) {
791 self.0[&id].ready.signal(success);
792 }
793}
794
795fn state_change<I: 'static, R: 'static + Send, Req: FnOnce(Rpc<I, R>) -> StateRequest>(
800 name: Arc<str>,
801 unit: &Unit,
802 request: Req,
803 input: Option<I>,
804) -> impl Future<Output = Result<Option<R>, UnitRecvError>> + use<I, R, Req> {
805 let send = unit.send.clone();
806
807 async move {
808 let Some(input) = input else { return Ok(None) };
809 let span = tracing::info_span!("device_state_change", device = name.as_ref());
810 async move {
811 let start = Instant::now();
812 let r = send
813 .call(request, input)
814 .await
815 .map_err(|err| UnitRecvError { name, source: err });
816 tracing::debug!(duration = ?Instant::now() - start, "device state change complete");
817 r.map(Some)
818 }
819 .instrument(span)
820 .await
821 }
822}
823
/// A builder for a new state unit, returned by [`StateUnits::add`].
#[derive(Debug)]
#[must_use]
pub struct UnitBuilder<'a> {
    units: &'a StateUnits,
    name: Arc<str>,
    // IDs of the units the new unit will depend on.
    dependencies: Vec<u64>,
    // IDs of the units that will depend on the new unit.
    dependents: Vec<u64>,
}
833
834impl UnitBuilder<'_> {
835 pub fn depends_on(mut self, handle: &UnitHandle) -> Self {
840 self.dependencies.push(self.handle_id(handle));
841 self
842 }
843
844 pub fn dependency_of(mut self, handle: &UnitHandle) -> Self {
849 self.dependents.push(self.handle_id(handle));
850 self
851 }
852
853 fn handle_id(&self, handle: &UnitHandle) -> u64 {
854 assert_eq!(
856 Weak::as_ptr(handle.inner.as_ref().unwrap()),
857 Arc::as_ptr(&self.units.inner)
858 );
859 handle.id.id
860 }
861
862 pub fn build(mut self, send: Sender<StateRequest>) -> Result<UnitHandle, NameInUse> {
864 let id = {
865 let mut inner = self.units.inner.lock();
866 let id = inner.next_id;
867 let entry = match inner.names.entry(self.name.clone()) {
868 hash_map::Entry::Occupied(_) => return Err(NameInUse(self.name)),
869 hash_map::Entry::Vacant(e) => e,
870 };
871 entry.insert(id);
872
873 self.dependencies.sort();
876 self.dependencies.dedup();
877 for &dep in &self.dependencies {
878 inner.units.get_mut(&dep).unwrap().dependents.push(id);
879 }
880
881 self.dependents.sort();
884 self.dependents.dedup();
885 for &dep in &self.dependents {
886 inner.units.get_mut(&dep).unwrap().dependencies.push(id);
887 }
888 inner.units.insert(
889 id,
890 Unit {
891 name: self.name.clone(),
892 send,
893 dependencies: self.dependencies,
894 dependents: self.dependents,
895 state: State::Stopped,
896 },
897 );
898 let unit_id = UnitId {
899 name: self.name,
900 id,
901 };
902 inner.next_id += 1;
903 unit_id
904 };
905 Ok(UnitHandle {
906 id,
907 inner: Some(Arc::downgrade(&self.units.inner)),
908 })
909 }
910
911 #[track_caller]
917 pub fn spawn<F, Fut>(
918 self,
919 spawner: impl Spawn,
920 f: F,
921 ) -> Result<SpawnedUnit<Fut::Output>, NameInUse>
922 where
923 F: FnOnce(Receiver<StateRequest>) -> Fut,
924 Fut: 'static + Send + Future,
925 Fut::Output: 'static + Send,
926 {
927 let (send, recv) = mesh::channel();
928 let task_name = format!("unit-{}", self.name);
929 let handle = self.build(send)?;
930 let fut = (f)(recv);
931 let task = spawner.spawn(task_name, fut);
932 Ok(SpawnedUnit { task, handle })
933 }
934}
935
936#[must_use]
938pub struct SpawnedUnit<T> {
939 handle: UnitHandle,
940 task: Task<T>,
941}
942
943impl<T> Debug for SpawnedUnit<T> {
944 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
945 f.debug_struct("SpawnedUnit")
946 .field("handle", &self.handle)
947 .field("task", &self.task)
948 .finish()
949 }
950}
951
952impl<T> SpawnedUnit<T> {
953 pub async fn remove(self) -> T {
955 self.handle.remove();
956 self.task.await
957 }
958
959 pub fn handle(&self) -> &UnitHandle {
962 &self.handle
963 }
964}
965
966#[derive(Default)]
967struct Ready {
968 state: AtomicU32,
969 event: event_listener::Event,
970}
971
972impl Ready {
973 fn signal(&self, success: bool) {
975 self.state.store(success as u32 + 1, Ordering::Release);
976 self.event.notify(usize::MAX);
977 }
978
979 fn is_ready(&self) -> bool {
980 self.state.load(Ordering::Acquire) != 0
981 }
982
983 async fn wait(&self) -> bool {
985 loop {
986 let listener = self.event.listen();
987 let state = self.state.load(Ordering::Acquire);
988 if state != 0 {
989 return state - 1 != 0;
990 }
991 listener.await;
992 }
993 }
994}
995
#[cfg(test)]
mod tests {
    use super::StateUnit;
    use super::StateUnits;
    use crate::run_unit;
    use inspect::InspectMut;
    use mesh::payload::Protobuf;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::Ordering;
    use std::time::Duration;
    use test_with_tracing::test;
    use vmcore::save_restore::RestoreError;
    use vmcore::save_restore::SaveError;
    use vmcore::save_restore::SavedStateBlob;
    use vmcore::save_restore::SavedStateRoot;

    /// A test unit whose saved state is the current value of `value`.
    #[derive(Default)]
    struct TestUnit {
        value: Arc<AtomicBool>,
        // If set, `restore` asserts that this flag is already true, which
        // checks that the unit setting it was restored first.
        dep: Option<Arc<AtomicBool>>,
        // If false, `save` returns `None` and `restore` fails.
        support_saved_state: bool,
    }

    #[derive(Protobuf, SavedStateRoot)]
    #[mesh(package = "test")]
    struct SavedState(bool);

    impl StateUnit for TestUnit {
        async fn start(&mut self) {}

        async fn stop(&mut self) {}

        async fn reset(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
            if self.support_saved_state {
                let state = SavedState(self.value.load(Ordering::Relaxed));
                Ok(Some(SavedStateBlob::new(state)))
            } else {
                Ok(None)
            }
        }

        async fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
            assert!(self.dep.as_ref().is_none_or(|v| v.load(Ordering::Relaxed)));

            if self.support_saved_state {
                let state: SavedState = state.parse()?;
                self.value.store(state.0, Ordering::Relaxed);
                Ok(())
            } else {
                Err(RestoreError::SavedStateNotSupported)
            }
        }
    }

    impl InspectMut for TestUnit {
        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
            req.respond();
        }
    }

    /// A test unit whose restore sets `dep` to true after a short delay.
    ///
    /// The delay makes it more likely that a missing dependency wait is
    /// caught.
    struct TestUnitSetDep {
        dep: Arc<AtomicBool>,
        driver: DefaultDriver,
    }

    impl StateUnit for TestUnitSetDep {
        async fn start(&mut self) {}

        async fn stop(&mut self) {}

        async fn reset(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
            Ok(Some(SavedStateBlob::new(SavedState(true))))
        }

        async fn restore(&mut self, _state: SavedStateBlob) -> Result<(), RestoreError> {
            pal_async::timer::PolledTimer::new(&self.driver)
                .sleep(Duration::from_millis(100))
                .await;

            self.dep.store(true, Ordering::Relaxed);

            Ok(())
        }
    }

    impl InspectMut for TestUnitSetDep {
        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
            req.respond();
        }
    }

    #[async_test]
    async fn test_state_change(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let a_val = Arc::new(AtomicBool::new(true));

        let _a = units
            .add("a")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: a_val.clone(),
                        dep: None,
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();
        let _b = units
            .add("b")
            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
            .unwrap();
        units.start().await;

        // Add a unit while running; it stays stopped until the next start.
        let _c = units
            .add("c")
            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
            .unwrap();
        units.stop().await;
        units.start().await;

        units.stop().await;

        let state = units.save().await.unwrap();

        a_val.store(false, Ordering::Relaxed);

        units.restore(state).await.unwrap();

        assert!(a_val.load(Ordering::Relaxed));
    }

    #[async_test]
    async fn test_dependencies(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let a_val = Arc::new(AtomicBool::new(true));

        let a = units
            .add("zzz")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: a_val.clone(),
                        dep: None,
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();

        // "aaa" sorts before "zzz", so ordering must come from the
        // dependency, not from names.
        let _b = units
            .add("aaa")
            .depends_on(a.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        dep: Some(a_val.clone()),
                        value: Default::default(),
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();
        units.start().await;
        units.stop().await;

        let state = units.save().await.unwrap();

        a_val.store(false, Ordering::Relaxed);

        units.restore(state).await.unwrap();
    }

    #[async_test]
    async fn test_dep_no_saved_state(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let true_val = Arc::new(AtomicBool::new(true));
        let shared_val = Arc::new(AtomicBool::new(false));

        let a = units
            .add("a")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: true_val.clone(),
                        dep: Some(shared_val.clone()),
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();

        // `b` has no saved state and so receives no restore request, but `a`
        // must still transitively wait for `c` through it.
        let b = units
            .add("b_no_saved_state")
            .dependency_of(a.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: true_val.clone(),
                        dep: Some(shared_val.clone()),
                        support_saved_state: false,
                    },
                    recv,
                )
            })
            .unwrap();

        let _c = units
            .add("c")
            .dependency_of(b.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnitSetDep {
                        dep: shared_val,
                        driver: driver.clone(),
                    },
                    recv,
                )
            })
            .unwrap();

        let state = units.save().await.unwrap();

        units.restore(state).await.unwrap();
    }
}