1#![forbid(unsafe_code)]
32
33use futures::FutureExt;
34use futures::StreamExt;
35use futures::future::join_all;
36use futures_concurrency::stream::Merge;
37use inspect::Inspect;
38use inspect::InspectMut;
39use mesh::MeshPayload;
40use mesh::Receiver;
41use mesh::Sender;
42use mesh::payload::Protobuf;
43use mesh::rpc::FailableRpc;
44use mesh::rpc::Rpc;
45use mesh::rpc::RpcError;
46use mesh::rpc::RpcSend;
47use pal_async::task::Spawn;
48use pal_async::task::Task;
49use parking_lot::Mutex;
50use std::collections::BTreeMap;
51use std::collections::HashMap;
52use std::collections::hash_map;
53use std::fmt::Debug;
54use std::fmt::Display;
55use std::future::Future;
56use std::pin::pin;
57use std::sync::Arc;
58use std::sync::Weak;
59use std::sync::atomic::AtomicU32;
60use std::sync::atomic::Ordering;
61use std::time::Instant;
62use thiserror::Error;
63use tracing::Instrument;
64use vmcore::save_restore::RestoreError;
65use vmcore::save_restore::SaveError;
66use vmcore::save_restore::SavedStateBlob;
67
/// A state change request sent to a state unit over its channel.
///
/// Each variant carries the RPC used to report completion back to the
/// requester; [`StateRequest::apply`] dispatches the request to the matching
/// [`StateUnit`] method.
#[derive(Debug, MeshPayload)]
pub enum StateRequest {
    /// Start the unit (dispatched to [`StateUnit::start`]).
    Start(Rpc<(), ()>),

    /// Stop the unit (dispatched to [`StateUnit::stop`]).
    Stop(Rpc<(), ()>),

    /// Reset the unit (dispatched to [`StateUnit::reset`]); may fail.
    Reset(FailableRpc<(), ()>),

    /// Save the unit's state (dispatched to [`StateUnit::save`]). A `None`
    /// result means the unit produced no saved state.
    Save(FailableRpc<(), Option<SavedStateBlob>>),

    /// Restore the unit from a saved state blob (dispatched to
    /// [`StateUnit::restore`]).
    Restore(FailableRpc<SavedStateBlob, ()>),

    /// Inspect the unit (dispatched to the unit's `InspectMut` impl).
    Inspect(inspect::Deferred),
}
89
/// A component that handles the state transitions carried by [`StateRequest`].
///
/// Implementations are typically driven by [`run_unit`] or
/// [`run_async_unit`], which apply each incoming request to the unit.
// NOTE: the `async_fn_in_trait` lint warns that callers cannot require the
// returned futures to be `Send`; that limitation is accepted here.
#[expect(async_fn_in_trait)]
pub trait StateUnit: InspectMut {
    /// Starts the unit. Invoked for [`StateRequest::Start`].
    async fn start(&mut self);

    /// Stops the unit. Invoked for [`StateRequest::Stop`].
    async fn stop(&mut self);

    /// Resets the unit. Invoked for [`StateRequest::Reset`]; any error is
    /// reported back to the requester.
    async fn reset(&mut self) -> anyhow::Result<()>;

    /// Saves the unit's state. Invoked for [`StateRequest::Save`].
    ///
    /// Returns `Ok(None)` if the unit has nothing to save; such a unit is
    /// omitted from [`StateUnits::save`]'s output and is not sent a restore
    /// request by [`StateUnits::restore`].
    async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError>;

    /// Restores the unit from a blob previously produced by
    /// [`StateUnit::save`]. Invoked for [`StateRequest::Restore`].
    async fn restore(&mut self, buffer: SavedStateBlob) -> Result<(), RestoreError>;
}
118
119pub async fn run_unit<T: StateUnit>(mut unit: T, mut recv: Receiver<StateRequest>) -> T {
121 while let Some(req) = recv.next().await {
122 req.apply(&mut unit).await;
123 }
124 unit
125}
126
127pub async fn run_async_unit<T>(unit: T, mut recv: Receiver<StateRequest>) -> T
130where
131 for<'a> &'a T: StateUnit,
132{
133 while let Some(req) = recv.next().await {
134 req.apply_with_concurrent_inspects(&mut &unit, &mut recv)
135 .await;
136 }
137 unit
138}
139
impl StateRequest {
    /// Applies this request to `unit`, and services [`StateRequest::Inspect`]
    /// requests arriving on `recv` while a state transition is in flight.
    ///
    /// `unit` is a shared reference (`&T` implements [`StateUnit`]), so the
    /// in-flight transition and concurrent inspections can each use their own
    /// copy of it.
    ///
    /// # Panics
    ///
    /// Panics if a request other than `Inspect` arrives on `recv` before the
    /// current state transition completes.
    pub async fn apply_with_concurrent_inspects<'a, T>(
        self,
        unit: &mut &'a T,
        recv: &mut Receiver<StateRequest>,
    ) where
        &'a T: StateUnit,
    {
        match self {
            StateRequest::Inspect(_) => {
                // Not a transition, so there is nothing to overlap with;
                // handle the inspection directly.
                self.apply(unit).await;
            }

            StateRequest::Start(_)
            | StateRequest::Stop(_)
            | StateRequest::Reset(_)
            | StateRequest::Save(_)
            | StateRequest::Restore(_) => {
                enum Event {
                    // The state transition future has completed.
                    OpDone,
                    // Another request arrived while the transition was running.
                    Req(StateRequest),
                }
                // Copy the shared reference so the transition owns its own
                // `&mut &T`, leaving `unit` free for concurrent inspections.
                let mut op_unit = *unit;
                let op = pin!(self.apply(&mut op_unit).into_stream());
                let mut stream = (op.map(|()| Event::OpDone), recv.map(Event::Req)).merge();
                // Keep servicing incoming requests until the transition
                // reports `OpDone`.
                while let Some(Event::Req(next_req)) = stream.next().await {
                    match next_req {
                        StateRequest::Inspect(req) => req.inspect(&mut *unit),
                        _ => panic!(
                            "unexpected state transition {next_req:?} during state transition"
                        ),
                    }
                }
            }
        }
    }

    /// Applies this request to `unit` by calling the matching [`StateUnit`]
    /// method and completing the request's RPC with the result.
    pub async fn apply(self, unit: &mut impl StateUnit) {
        match self {
            StateRequest::Start(rpc) => rpc.handle(async |()| unit.start().await).await,
            StateRequest::Stop(rpc) => rpc.handle(async |()| unit.stop().await).await,
            StateRequest::Reset(rpc) => rpc.handle_failable(async |()| unit.reset().await).await,
            StateRequest::Save(rpc) => rpc.handle_failable(async |()| unit.save().await).await,
            StateRequest::Restore(rpc) => {
                rpc.handle_failable(async |buffer| unit.restore(buffer).await)
                    .await
            }
            StateRequest::Inspect(req) => req.inspect(unit),
        }
    }
}
204
/// A set of state units, tracking each unit's dependencies and current state,
/// and driving state transitions across all of them.
#[derive(Debug)]
pub struct StateUnits {
    // Shared with `UnitHandle`s and `StateUnitsInspector`s, which hold weak
    // references to it.
    inner: Arc<Mutex<Inner>>,
    // Set by `start`, cleared by `stop`.
    running: bool,
}
211
/// The state of an individual unit, as tracked by [`StateUnits`].
///
/// `Starting`, `Stopping`, `Resetting`, `Saving` and `Restoring` are interim
/// states, held while the corresponding request is in flight.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Inspect)]
enum State {
    Stopped,
    Starting,
    Running,
    Stopping,
    Resetting,
    Saving,
    Restoring,
}
222
/// Mutable registry state shared by [`StateUnits`], its unit handles and its
/// inspectors.
#[derive(Debug)]
struct Inner {
    // Id handed to the next registered unit. Only ever incremented, so ids are
    // never reused.
    next_id: u64,
    // Registered units, keyed by id.
    units: BTreeMap<u64, Unit>,
    // Unit name -> unit id; used to enforce unique names.
    names: HashMap<Arc<str>, u64>,
}
229
/// Bookkeeping for a single registered unit.
#[derive(Debug)]
struct Unit {
    name: Arc<str>,
    // Channel used to send `StateRequest`s to the unit.
    send: Sender<StateRequest>,
    // Ids of units this unit depends on: they start, reset and restore before
    // it, and stop after it.
    dependencies: Vec<u64>,
    // Ids of units that depend on this unit.
    dependents: Vec<u64>,
    // Current state as tracked by `StateUnits::run_op`.
    state: State,
}
238
/// Error returned when registering a unit whose name is already taken.
#[derive(Debug, Error)]
#[error("state unit name {0} is in use")]
pub struct NameInUse(Arc<str>);
243
/// A unit's request channel failed while a state request was in flight.
#[derive(Debug, Error)]
#[error("critical unit communication failure: {name}")]
struct UnitRecvError {
    // Name of the unit whose RPC failed.
    name: Arc<str>,
    #[source]
    source: RpcError,
}
251
/// Identifies a registered unit by both its name and its id.
#[derive(Debug, Clone)]
struct UnitId {
    name: Arc<str>,
    id: u64,
}
257
/// Handle to a registered unit.
///
/// Dropping the handle (or calling [`UnitHandle::remove`]) unregisters the
/// unit; [`UnitHandle::detach`] leaves it registered instead.
// `must_use`: dropping the handle immediately would unregister the unit.
#[must_use]
#[derive(Debug)]
pub struct UnitHandle {
    id: UnitId,
    // `None` once the handle has been removed or detached.
    inner: Option<Weak<Mutex<Inner>>>,
}
265
266impl Drop for UnitHandle {
267 fn drop(&mut self) {
268 self.remove_if();
269 }
270}
271
272impl UnitHandle {
273 pub fn remove(mut self) {
275 self.remove_if();
276 }
277
278 pub fn detach(mut self) {
280 self.inner = None;
281 }
282
283 fn remove_if(&mut self) {
284 if let Some(inner) = self.inner.take().and_then(|inner| inner.upgrade()) {
285 let mut inner = inner.lock();
286 inner.units.remove(&self.id.id).expect("unit exists");
287 inner.names.remove(&self.id.name).expect("unit exists");
288 }
289 }
290}
291
/// Inspects the units of a [`StateUnits`] without keeping it alive.
pub struct StateUnitsInspector {
    // Weak, so inspection becomes a no-op once the `StateUnits` is dropped.
    inner: Weak<Mutex<Inner>>,
}
297
298impl Inspect for StateUnits {
299 fn inspect(&self, req: inspect::Request<'_>) {
300 self.inner.lock().inspect(req);
301 }
302}
303
304impl Inspect for StateUnitsInspector {
305 fn inspect(&self, req: inspect::Request<'_>) {
306 if let Some(inner) = self.inner.upgrade() {
307 inner.lock().inspect(req);
308 }
309 }
310}
311
312impl Inspect for Inner {
313 fn inspect(&self, req: inspect::Request<'_>) {
314 let mut resp = req.respond();
315 for unit in self.units.values() {
316 resp.child(unit.name.as_ref(), |req| {
317 let mut resp = req.respond();
318 if !unit.dependencies.is_empty() {
319 resp.field_with("dependencies", || {
320 unit.dependencies
321 .iter()
322 .map(|id| self.units[id].name.as_ref())
323 .collect::<Vec<_>>()
324 .join(",")
325 });
326 }
327 if !unit.dependents.is_empty() {
328 resp.field_with("dependents", || {
329 unit.dependents
330 .iter()
331 .map(|id| self.units[id].name.as_ref())
332 .collect::<Vec<_>>()
333 .join(",")
334 });
335 }
336 resp.field("unit_state", unit.state)
337 .merge(&inspect::send(&unit.send, StateRequest::Inspect));
338 });
339 }
340 }
341}
342
/// The saved state of a single unit, as produced by [`StateUnits::save`] and
/// consumed by [`StateUnits::restore`].
#[derive(Protobuf)]
#[mesh(package = "state_unit")]
pub struct SavedStateUnit {
    /// The name the unit was registered under; used to match the state to a
    /// unit on restore.
    #[mesh(1)]
    pub name: String,
    /// The blob returned by the unit's [`StateUnit::save`].
    #[mesh(2)]
    pub state: SavedStateBlob,
}
354
/// Error returned when a state transition fails for one or more units.
#[derive(Debug, Error)]
#[error("{op} failed")]
pub struct StateTransitionError {
    // Name of the failed operation, e.g. "reset", "save" or "restore".
    op: &'static str,
    // The per-unit errors, exposed as the error's source.
    #[source]
    errors: UnitErrorSet,
}
363
364fn extract<T, E: Into<anyhow::Error>, U>(
365 op: &'static str,
366 iter: impl IntoIterator<Item = (Arc<str>, Result<T, E>)>,
367 mut f: impl FnMut(Arc<str>, T) -> Option<U>,
368) -> Result<Vec<U>, StateTransitionError> {
369 let mut result = Vec::new();
370 let mut errors = Vec::new();
371 for (name, item) in iter {
372 match item {
373 Ok(t) => {
374 if let Some(u) = f(name, t) {
375 result.push(u);
376 }
377 }
378 Err(err) => errors.push((name, err.into())),
379 }
380 }
381 if errors.is_empty() {
382 Ok(result)
383 } else {
384 Err(StateTransitionError {
385 op,
386 errors: UnitErrorSet(errors),
387 })
388 }
389}
390
391fn check<E: Into<anyhow::Error>>(
392 op: &'static str,
393 iter: impl IntoIterator<Item = (Arc<str>, Result<(), E>)>,
394) -> Result<(), StateTransitionError> {
395 extract(op, iter, |_, _| Some(()))?;
396 Ok(())
397}
398
/// The errors from a failed state transition, one entry per failing unit
/// (unit name, error).
#[derive(Debug)]
struct UnitErrorSet(Vec<(Arc<str>, anyhow::Error)>);
401
402impl Display for UnitErrorSet {
403 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404 let mut map = f.debug_map();
405 for (name, err) in &self.0 {
406 map.entry(&format_args!("{}", name), &format_args!("{:#}", err));
407 }
408 map.finish()
409 }
410}
411
// Allows `UnitErrorSet` to be the `#[source]` of `StateTransitionError`.
impl std::error::Error for UnitErrorSet {}
413
impl StateUnits {
    /// Creates an empty set of state units. The set starts out not running.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(Inner {
                next_id: 0,
                units: BTreeMap::new(),
                names: HashMap::new(),
            })),
            running: false,
        }
    }

    /// Returns an inspector for the units that holds only a weak reference.
    /// Once this `StateUnits` is dropped, inspecting it does nothing.
    pub fn inspector(&self) -> StateUnitsInspector {
        StateUnitsInspector {
            inner: Arc::downgrade(&self.inner),
        }
    }

    /// Begins registering a new unit named `name`.
    ///
    /// The unit is not registered until [`UnitBuilder::build`] or
    /// [`UnitBuilder::spawn`] is called on the returned builder.
    pub fn add(&self, name: impl Into<Arc<str>>) -> UnitBuilder<'_> {
        UnitBuilder {
            units: self,
            name: name.into(),
            dependencies: Vec::new(),
            dependents: Vec::new(),
        }
    }

    /// Returns whether the set as a whole is running: `true` after
    /// [`Self::start`] and `false` after [`Self::stop`].
    pub fn is_running(&self) -> bool {
        self.running
    }

    /// If the set as a whole is running, starts any units that are still
    /// stopped (for example, units added since the last start). Does nothing
    /// otherwise.
    pub async fn start_stopped_units(&mut self) {
        if self.is_running() {
            self.start().await;
        }
    }

    /// Starts all stopped units, each only after all of its dependencies have
    /// started. Units that are already running are skipped.
    ///
    /// # Panics
    ///
    /// Panics if a unit is in a state other than stopped or running.
    pub async fn start(&mut self) {
        self.run_op(
            "start",
            None,
            State::Stopped,
            State::Starting,
            State::Running,
            StateRequest::Start,
            |_, _| Some(()),
            |unit| &unit.dependencies,
        )
        .await;
        self.running = true;
    }

    /// Stops all running units, each only after all of its dependents have
    /// stopped. Units that are already stopped are skipped.
    ///
    /// # Panics
    ///
    /// Panics if the set is not running.
    pub async fn stop(&mut self) {
        assert!(self.running);
        self.run_op(
            // Ordering is reversed relative to `start`: wait on dependents.
            "stop",
            None,
            State::Running,
            State::Stopping,
            State::Stopped,
            StateRequest::Stop,
            |_, _| Some(()),
            |unit| &unit.dependents,
        )
        .await;
        self.running = false;
    }

    /// Resets all units, each only after all of its dependencies have been
    /// reset.
    ///
    /// # Errors
    ///
    /// Returns an error listing every unit whose reset failed.
    ///
    /// # Panics
    ///
    /// Panics if the set is running.
    pub async fn reset(&mut self) -> Result<(), StateTransitionError> {
        assert!(!self.running);
        let r = self
            .run_op(
                "reset",
                None,
                State::Stopped,
                State::Resetting,
                State::Stopped,
                StateRequest::Reset,
                |_, _| Some(()),
                |unit| &unit.dependencies,
            )
            .await;

        check("reset", r)?;
        Ok(())
    }

    /// Saves the state of all units. Units are saved concurrently, with no
    /// dependency ordering. Units whose save returns `None` are omitted from
    /// the result.
    ///
    /// # Errors
    ///
    /// Returns an error listing every unit whose save failed.
    ///
    /// # Panics
    ///
    /// Panics if the set is running.
    pub async fn save(&mut self) -> Result<Vec<SavedStateUnit>, StateTransitionError> {
        assert!(!self.running);
        let r = self
            .run_op(
                "save",
                None,
                State::Stopped,
                State::Saving,
                State::Stopped,
                StateRequest::Save,
                |_, _| Some(()),
                // No ordering: every unit may save concurrently.
                |_| &[],
            )
            .await;

        let states = extract("save", r, |name, state| {
            state.map(|state| SavedStateUnit {
                name: name.to_string(),
                state,
            })
        })?;

        Ok(states)
    }

    /// Restores units from `states`, each only after all of its dependencies
    /// have been restored. Units with no entry in `states` are not sent a
    /// restore request.
    ///
    /// # Errors
    ///
    /// Returns an error if a saved state names an unknown unit, names the
    /// same unit more than once, or if any unit fails to restore.
    ///
    /// # Panics
    ///
    /// Panics if the set is running.
    pub async fn restore(
        &mut self,
        states: Vec<SavedStateUnit>,
    ) -> Result<(), StateTransitionError> {
        assert!(!self.running);

        #[derive(Debug, Error)]
        enum RestoreUnitError {
            #[error("unknown unit name")]
            Unknown,
            #[error("duplicate unit name")]
            Duplicate,
        }

        // Match each saved state to a registered unit id, recording an error
        // for unknown or repeated names.
        let mut states_by_id = HashMap::new();
        let mut r = Vec::new();
        {
            let inner = self.inner.lock();
            for state in states {
                match inner.names.get_key_value(state.name.as_str()) {
                    Some((name, &id)) => {
                        if states_by_id
                            .insert(id, (name.clone(), state.state))
                            .is_some()
                        {
                            r.push((name.clone(), Err(RestoreUnitError::Duplicate)));
                        }
                    }
                    None => {
                        r.push((state.name.into(), Err(RestoreUnitError::Unknown)));
                    }
                }
            }
        }

        check("restore", r)?;

        let r = self
            .run_op(
                "restore",
                None,
                State::Stopped,
                State::Restoring,
                State::Stopped,
                StateRequest::Restore,
                // Consume each unit's saved state; units without one get
                // `None` and are skipped by `state_change`.
                |id, _| states_by_id.remove(&id).map(|(_, blob)| blob),
                |unit| &unit.dependencies,
            )
            .await;

        check(
            // Any state left over was not consumed by `run_op` (for example,
            // its unit was removed after the lookup above); report it as
            // belonging to an unknown unit.
            "restore",
            states_by_id
                .into_iter()
                .map(|(_, (name, _))| (name, Err(RestoreUnitError::Unknown))),
        )?;

        check("restore", r)?;

        Ok(())
    }

    /// Runs one state transition across the registered units.
    ///
    /// - `unit_ids`: the units to operate on, or `None` for all units.
    /// - `old_state` / `interim_state` / `new_state`: units in `old_state`
    ///   move to `interim_state` while their request is in flight, and then to
    ///   `new_state`. Units already in `new_state` are skipped.
    /// - `request`: builds the `StateRequest` from its RPC.
    /// - `input`: produces each unit's RPC input. `None` means no request is
    ///   sent to that unit.
    /// - `deps`: the units (within this operation) that must finish before a
    ///   given unit's request is sent.
    ///
    /// Returns `(name, result)` for each unit that was sent a request and
    /// replied.
    ///
    /// # Panics
    ///
    /// Panics if a unit is in an unexpected state, or if a still-registered
    /// unit's channel fails.
    async fn run_op<I: 'static, R: 'static + Send>(
        &self,
        op: &str,
        unit_ids: Option<&[u64]>,
        old_state: State,
        interim_state: State,
        new_state: State,
        request: impl Copy + FnOnce(Rpc<I, R>) -> StateRequest,
        mut input: impl FnMut(u64, &Unit) -> Option<I>,
        mut deps: impl FnMut(&Unit) -> &[u64],
    ) -> Vec<(Arc<str>, R)> {
        let mut done = Vec::new();
        let ready_set;
        {
            // Build the per-unit futures under the lock; they are awaited
            // after the lock is released.
            let mut inner = self.inner.lock();
            ready_set = inner.ready_set(unit_ids);
            for (&id, unit) in inner
                .units
                .iter_mut()
                .filter(|(id, _)| ready_set.0.contains_key(id))
            {
                if unit.state != old_state {
                    assert_eq!(
                        unit.state, new_state,
                        "unit {} in {:?} state, should be {:?} or {:?}",
                        unit.name, unit.state, old_state, new_state
                    );
                    // Already in the target state: release any dependents
                    // immediately.
                    ready_set.done(id, true);
                } else {
                    let name = unit.name.clone();
                    let input = (input)(id, unit);
                    let ready_set = ready_set.clone();
                    let deps = deps(unit).to_vec();
                    let fut = state_change(name.clone(), unit, request, input);
                    let recv = async move {
                        // Wait for this unit's dependencies before sending
                        // its request, then release its own dependents.
                        ready_set.wait(op, id, &deps).await;
                        let r = fut.await;
                        ready_set.done(id, true);
                        (name, id, r)
                    };
                    done.push(recv);
                    unit.state = interim_state;
                }
            }
        }

        let results = async {
            let start = Instant::now();
            let results = join_all(done).await;
            tracing::info!(duration = ?Instant::now() - start, "state change complete");
            results
        }
        .instrument(tracing::info_span!("state_change", operation = op))
        .await;

        let mut inner = self.inner.lock();
        let r = results
            .into_iter()
            .filter_map(|(name, id, r)| {
                match r {
                    Ok(Some(r)) => Some((name, r)),
                    Ok(None) => None,
                    Err(err) => {
                        if inner.units.contains_key(&id) {
                            // The unit is still registered but its channel
                            // failed; treat this as fatal. A unit removed
                            // while the operation was in flight is ignored.
                            panic!("{:?}", err);
                        }
                        None
                    }
                }
            })
            .collect();
        // Move participating units out of the interim state.
        for (_, unit) in inner
            .units
            .iter_mut()
            .filter(|(id, _)| ready_set.0.contains_key(id))
        {
            if unit.state == interim_state {
                unit.state = new_state;
            } else {
                assert_eq!(
                    unit.state, new_state,
                    "unit {} in {:?} state, should be {:?} or {:?}",
                    unit.name, unit.state, interim_state, new_state
                );
            }
        }
        r
    }
}
735
736impl Inner {
737 fn ready_set(&self, unit_ids: Option<&[u64]>) -> ReadySet {
738 let map = |id, unit: &Unit| {
739 (
740 id,
741 ReadyState {
742 name: unit.name.clone(),
743 ready: Arc::new(Ready::default()),
744 },
745 )
746 };
747 let units = if let Some(unit_ids) = unit_ids {
748 unit_ids
749 .iter()
750 .map(|id| map(*id, &self.units[id]))
751 .collect()
752 } else {
753 self.units.iter().map(|(id, unit)| map(*id, unit)).collect()
754 };
755 ReadySet(Arc::new(units))
756 }
757}
758
/// Completion signals for the units participating in one operation, keyed by
/// unit id. Clones share the same signals.
#[derive(Clone)]
struct ReadySet(Arc<BTreeMap<u64, ReadyState>>);
761
/// A participating unit's name (for logging) and its completion signal.
#[derive(Clone)]
struct ReadyState {
    name: Arc<str>,
    ready: Arc<Ready>,
}
767
768impl ReadySet {
769 async fn wait(&self, op: &str, id: u64, deps: &[u64]) -> bool {
770 for dep in deps {
771 if let Some(dep) = self.0.get(dep) {
772 if !dep.ready.is_ready() {
773 tracing::debug!(
774 device = self.0[&id].name.as_ref(),
775 dependency = dep.name.as_ref(),
776 operation = op,
777 "waiting on dependency"
778 );
779 }
780 if !dep.ready.wait().await {
781 return false;
782 }
783 }
784 }
785 true
786 }
787
788 fn done(&self, id: u64, success: bool) {
789 self.0[&id].ready.signal(success);
790 }
791}
792
793fn state_change<I: 'static, R: 'static + Send, Req: FnOnce(Rpc<I, R>) -> StateRequest>(
798 name: Arc<str>,
799 unit: &Unit,
800 request: Req,
801 input: Option<I>,
802) -> impl Future<Output = Result<Option<R>, UnitRecvError>> + use<I, R, Req> {
803 let send = unit.send.clone();
804
805 async move {
806 let Some(input) = input else { return Ok(None) };
807 let span = tracing::info_span!("device_state_change", device = name.as_ref());
808 async move {
809 let start = Instant::now();
810 let r = send
811 .call(request, input)
812 .await
813 .map_err(|err| UnitRecvError { name, source: err });
814 tracing::debug!(duration = ?Instant::now() - start, "device state change complete");
815 r.map(Some)
816 }
817 .instrument(span)
818 .await
819 }
820}
821
/// Builder for registering a new unit, returned by [`StateUnits::add`].
#[derive(Debug)]
// `must_use`: the unit is not registered until `build` or `spawn` is called.
#[must_use]
pub struct UnitBuilder<'a> {
    units: &'a StateUnits,
    name: Arc<str>,
    // Ids of units the new unit will depend on.
    dependencies: Vec<u64>,
    // Ids of units that will depend on the new unit.
    dependents: Vec<u64>,
}
831
832impl UnitBuilder<'_> {
833 pub fn depends_on(mut self, handle: &UnitHandle) -> Self {
838 self.dependencies.push(self.handle_id(handle));
839 self
840 }
841
842 pub fn dependency_of(mut self, handle: &UnitHandle) -> Self {
847 self.dependents.push(self.handle_id(handle));
848 self
849 }
850
851 fn handle_id(&self, handle: &UnitHandle) -> u64 {
852 assert_eq!(
854 Weak::as_ptr(handle.inner.as_ref().unwrap()),
855 Arc::as_ptr(&self.units.inner)
856 );
857 handle.id.id
858 }
859
860 pub fn build(mut self, send: Sender<StateRequest>) -> Result<UnitHandle, NameInUse> {
862 let id = {
863 let mut inner = self.units.inner.lock();
864 let id = inner.next_id;
865 let entry = match inner.names.entry(self.name.clone()) {
866 hash_map::Entry::Occupied(_) => return Err(NameInUse(self.name)),
867 hash_map::Entry::Vacant(e) => e,
868 };
869 entry.insert(id);
870
871 self.dependencies.sort();
874 self.dependencies.dedup();
875 for &dep in &self.dependencies {
876 inner.units.get_mut(&dep).unwrap().dependents.push(id);
877 }
878
879 self.dependents.sort();
882 self.dependents.dedup();
883 for &dep in &self.dependents {
884 inner.units.get_mut(&dep).unwrap().dependencies.push(id);
885 }
886 inner.units.insert(
887 id,
888 Unit {
889 name: self.name.clone(),
890 send,
891 dependencies: self.dependencies,
892 dependents: self.dependents,
893 state: State::Stopped,
894 },
895 );
896 let unit_id = UnitId {
897 name: self.name,
898 id,
899 };
900 inner.next_id += 1;
901 unit_id
902 };
903 Ok(UnitHandle {
904 id,
905 inner: Some(Arc::downgrade(&self.units.inner)),
906 })
907 }
908
909 #[track_caller]
915 pub fn spawn<F, Fut>(
916 self,
917 spawner: impl Spawn,
918 f: F,
919 ) -> Result<SpawnedUnit<Fut::Output>, NameInUse>
920 where
921 F: FnOnce(Receiver<StateRequest>) -> Fut,
922 Fut: 'static + Send + Future,
923 Fut::Output: 'static + Send,
924 {
925 let (send, recv) = mesh::channel();
926 let task_name = format!("unit-{}", self.name);
927 let handle = self.build(send)?;
928 let fut = (f)(recv);
929 let task = spawner.spawn(task_name, fut);
930 Ok(SpawnedUnit { task, handle })
931 }
932}
933
/// A registered unit whose request loop runs as a spawned task.
// `must_use`: dropping it drops the handle, which unregisters the unit.
#[must_use]
pub struct SpawnedUnit<T> {
    handle: UnitHandle,
    task: Task<T>,
}
940
941impl<T> Debug for SpawnedUnit<T> {
942 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
943 f.debug_struct("SpawnedUnit")
944 .field("handle", &self.handle)
945 .field("task", &self.task)
946 .finish()
947 }
948}
949
950impl<T> SpawnedUnit<T> {
951 pub async fn remove(self) -> T {
953 self.handle.remove();
954 self.task.await
955 }
956
957 pub fn handle(&self) -> &UnitHandle {
960 &self.handle
961 }
962}
963
/// A completion signal that can be waited on by multiple tasks.
#[derive(Default)]
struct Ready {
    // 0 until signaled; then 1 for failure or 2 for success.
    state: AtomicU32,
    // Woken when `state` is set.
    event: event_listener::Event,
}
969
970impl Ready {
971 fn signal(&self, success: bool) {
973 self.state.store(success as u32 + 1, Ordering::Release);
974 self.event.notify(usize::MAX);
975 }
976
977 fn is_ready(&self) -> bool {
978 self.state.load(Ordering::Acquire) != 0
979 }
980
981 async fn wait(&self) -> bool {
983 loop {
984 let listener = self.event.listen();
985 let state = self.state.load(Ordering::Acquire);
986 if state != 0 {
987 return state - 1 != 0;
988 }
989 listener.await;
990 }
991 }
992}
993
#[cfg(test)]
mod tests {
    use super::StateUnit;
    use super::StateUnits;
    use crate::run_unit;
    use inspect::InspectMut;
    use mesh::payload::Protobuf;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::Ordering;
    use std::time::Duration;
    use test_with_tracing::test;
    use vmcore::save_restore::RestoreError;
    use vmcore::save_restore::SaveError;
    use vmcore::save_restore::SavedStateBlob;
    use vmcore::save_restore::SavedStateRoot;

    /// A unit whose saved state is the value of `value`.
    #[derive(Default)]
    struct TestUnit {
        value: Arc<AtomicBool>,
        // If set, `restore` asserts this flag is already true, i.e. that the
        // unit providing it was restored first.
        dep: Option<Arc<AtomicBool>>,
        // When false, `save` returns `None` and `restore` fails.
        support_saved_state: bool,
    }

    #[derive(Protobuf, SavedStateRoot)]
    #[mesh(package = "test")]
    struct SavedState(bool);

    impl StateUnit for TestUnit {
        async fn start(&mut self) {}

        async fn stop(&mut self) {}

        async fn reset(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
            if self.support_saved_state {
                let state = SavedState(self.value.load(Ordering::Relaxed));
                Ok(Some(SavedStateBlob::new(state)))
            } else {
                Ok(None)
            }
        }

        async fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
            assert!(self.dep.as_ref().is_none_or(|v| v.load(Ordering::Relaxed)));

            if self.support_saved_state {
                let state: SavedState = state.parse()?;
                self.value.store(state.0, Ordering::Relaxed);
                Ok(())
            } else {
                Err(RestoreError::SavedStateNotSupported)
            }
        }
    }

    impl InspectMut for TestUnit {
        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
            req.respond();
        }
    }

    /// A unit whose (deliberately slow) restore sets `dep` to true.
    struct TestUnitSetDep {
        dep: Arc<AtomicBool>,
        driver: DefaultDriver,
    }

    impl StateUnit for TestUnitSetDep {
        async fn start(&mut self) {}

        async fn stop(&mut self) {}

        async fn reset(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
            Ok(Some(SavedStateBlob::new(SavedState(true))))
        }

        async fn restore(&mut self, _state: SavedStateBlob) -> Result<(), RestoreError> {
            // Delay so that dependents which fail to wait would observe `dep`
            // still false.
            pal_async::timer::PolledTimer::new(&self.driver)
                .sleep(Duration::from_millis(100))
                .await;

            self.dep.store(true, Ordering::Relaxed);

            Ok(())
        }
    }

    impl InspectMut for TestUnitSetDep {
        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
            req.respond();
        }
    }

    #[async_test]
    async fn test_state_change(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let a_val = Arc::new(AtomicBool::new(true));

        let _a = units
            .add("a")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: a_val.clone(),
                        dep: None,
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();
        let _b = units
            .add("b")
            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
            .unwrap();
        units.start().await;

        // Add a unit while running; it must be tolerated (and started) by
        // subsequent transitions.
        let _c = units
            .add("c")
            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
            .unwrap();
        units.stop().await;
        units.start().await;

        units.stop().await;

        let state = units.save().await.unwrap();

        a_val.store(false, Ordering::Relaxed);

        units.restore(state).await.unwrap();

        assert!(a_val.load(Ordering::Relaxed));
    }

    #[async_test]
    async fn test_dependencies(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let a_val = Arc::new(AtomicBool::new(true));

        let a = units
            .add("zzz")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: a_val.clone(),
                        dep: None,
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();

        let _b = units
            .add("aaa")
            .depends_on(a.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        dep: Some(a_val.clone()),
                        value: Default::default(),
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();
        units.start().await;
        units.stop().await;

        let state = units.save().await.unwrap();

        a_val.store(false, Ordering::Relaxed);

        units.restore(state).await.unwrap();
    }

    #[async_test]
    async fn test_dep_no_saved_state(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let true_val = Arc::new(AtomicBool::new(true));
        let shared_val = Arc::new(AtomicBool::new(false));

        let a = units
            .add("a")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: true_val.clone(),
                        dep: Some(shared_val.clone()),
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();

        // `b` has no saved state, so it is not restored. `a` must still wait
        // for `b`'s own dependency `c`.
        let b = units
            .add("b_no_saved_state")
            .dependency_of(a.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: true_val.clone(),
                        dep: Some(shared_val.clone()),
                        support_saved_state: false,
                    },
                    recv,
                )
            })
            .unwrap();

        let _c = units
            .add("c")
            .dependency_of(b.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnitSetDep {
                        dep: shared_val,
                        driver: driver.clone(),
                    },
                    recv,
                )
            })
            .unwrap();

        let state = units.save().await.unwrap();

        units.restore(state).await.unwrap();
    }
}