1use futures::FutureExt;
32use futures::StreamExt;
33use futures::future::join_all;
34use futures_concurrency::stream::Merge;
35use inspect::Inspect;
36use inspect::InspectMut;
37use mesh::MeshPayload;
38use mesh::Receiver;
39use mesh::Sender;
40use mesh::payload::Protobuf;
41use mesh::rpc::FailableRpc;
42use mesh::rpc::Rpc;
43use mesh::rpc::RpcError;
44use mesh::rpc::RpcSend;
45use pal_async::task::Spawn;
46use pal_async::task::Task;
47use parking_lot::Mutex;
48use std::collections::BTreeMap;
49use std::collections::HashMap;
50use std::collections::hash_map;
51use std::fmt::Debug;
52use std::fmt::Display;
53use std::future::Future;
54use std::pin::pin;
55use std::sync::Arc;
56use std::sync::Weak;
57use std::sync::atomic::AtomicU32;
58use std::sync::atomic::Ordering;
59use std::time::Instant;
60use thiserror::Error;
61use tracing::Instrument;
62use vmcore::save_restore::RestoreError;
63use vmcore::save_restore::SaveError;
64use vmcore::save_restore::SavedStateBlob;
65
/// A request sent to a state unit over its [`Sender<StateRequest>`] channel.
///
/// Each variant is dispatched to the corresponding [`StateUnit`] method by
/// [`StateRequest::apply`].
#[derive(Debug, MeshPayload)]
pub enum StateRequest {
    /// Start the unit; handled by [`StateUnit::start`].
    Start(Rpc<(), ()>),

    /// Stop the unit; handled by [`StateUnit::stop`].
    Stop(Rpc<(), ()>),

    /// Reset the unit; handled by [`StateUnit::reset`], which may fail.
    Reset(FailableRpc<(), ()>),

    /// Save the unit's state; handled by [`StateUnit::save`]. A successful
    /// response of `None` means the unit produced no saved state.
    Save(FailableRpc<(), Option<SavedStateBlob>>),

    /// Restore the unit from the provided blob; handled by
    /// [`StateUnit::restore`].
    Restore(FailableRpc<SavedStateBlob, ()>),

    /// Run post-restore work; handled by [`StateUnit::post_restore`].
    PostRestore(FailableRpc<(), ()>),

    /// Inspect the unit; the deferred request is completed with the unit's
    /// [`InspectMut`] implementation.
    Inspect(inspect::Deferred),
}
92
/// Trait implemented by objects that can be driven through state transitions
/// via [`StateRequest`]s.
///
/// Implementors must also implement [`InspectMut`] so that
/// [`StateRequest::Inspect`] can be serviced.
#[expect(async_fn_in_trait)] pub trait StateUnit: InspectMut {
    /// Starts the unit. Invoked for [`StateRequest::Start`].
    async fn start(&mut self);

    /// Stops the unit. Invoked for [`StateRequest::Stop`].
    async fn stop(&mut self);

    /// Resets the unit. Invoked for [`StateRequest::Reset`].
    async fn reset(&mut self) -> anyhow::Result<()>;

    /// Saves the unit's state. Invoked for [`StateRequest::Save`].
    ///
    /// Returning `Ok(None)` means there is no saved state for this unit;
    /// [`StateUnits::save`] omits such units from its result.
    async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError>;

    /// Restores the unit from `buffer`. Invoked for
    /// [`StateRequest::Restore`].
    async fn restore(&mut self, buffer: SavedStateBlob) -> Result<(), RestoreError>;

    /// Runs after restore. Invoked for [`StateRequest::PostRestore`].
    ///
    /// The default implementation does nothing and succeeds.
    async fn post_restore(&mut self) -> anyhow::Result<()> {
        Ok(())
    }
}
128
129pub async fn run_unit<T: StateUnit>(mut unit: T, mut recv: Receiver<StateRequest>) -> T {
131 while let Some(req) = recv.next().await {
132 req.apply(&mut unit).await;
133 }
134 unit
135}
136
137pub async fn run_async_unit<T>(unit: T, mut recv: Receiver<StateRequest>) -> T
140where
141 for<'a> &'a T: StateUnit,
142{
143 while let Some(req) = recv.next().await {
144 req.apply_with_concurrent_inspects(&mut &unit, &mut recv)
145 .await;
146 }
147 unit
148}
149
impl StateRequest {
    /// Applies this request to `unit`, servicing any inspect requests that
    /// arrive on `recv` while a state transition is still running.
    ///
    /// An inspect request is applied directly. For every other request the
    /// operation and `recv` are polled together; inspect requests received in
    /// the meantime are answered against `unit`.
    ///
    /// # Panics
    ///
    /// Panics if a non-inspect request is received from `recv` while the state
    /// transition is in progress.
    pub async fn apply_with_concurrent_inspects<'a, T>(
        self,
        unit: &mut &'a T,
        recv: &mut Receiver<StateRequest>,
    ) where
        &'a T: StateUnit,
    {
        match self {
            StateRequest::Inspect(_) => {
                // Nothing to overlap with; just run the inspect.
                self.apply(unit).await;
            }

            StateRequest::Start(_)
            | StateRequest::Stop(_)
            | StateRequest::Reset(_)
            | StateRequest::Save(_)
            | StateRequest::Restore(_)
            | StateRequest::PostRestore(_) => {
                // Items produced by the merged stream below.
                enum Event {
                    // The state transition finished.
                    OpDone,
                    // Another request arrived while the transition was running.
                    Req(StateRequest),
                }
                // Copy the shared reference so the operation has its own
                // mutable slot, leaving `unit` available for inspects.
                let mut op_unit = *unit;
                let op = pin!(self.apply(&mut op_unit).into_stream());
                let mut stream = (op.map(|()| Event::OpDone), recv.map(Event::Req)).merge();
                // The loop ends as soon as the stream yields `OpDone` (the
                // pattern no longer matches) or the stream is exhausted.
                while let Some(Event::Req(next_req)) = stream.next().await {
                    match next_req {
                        StateRequest::Inspect(req) => req.inspect(&mut *unit),
                        _ => panic!(
                            "unexpected state transition {next_req:?} during state transition"
                        ),
                    }
                }
            }
        }
    }

    /// Applies this request to `unit`, running the matching [`StateUnit`]
    /// method and completing the request's RPC (or deferred inspect) with the
    /// result.
    pub async fn apply(self, unit: &mut impl StateUnit) {
        match self {
            StateRequest::Start(rpc) => rpc.handle(async |()| unit.start().await).await,
            StateRequest::Stop(rpc) => rpc.handle(async |()| unit.stop().await).await,
            StateRequest::Reset(rpc) => rpc.handle_failable(async |()| unit.reset().await).await,
            StateRequest::Save(rpc) => rpc.handle_failable(async |()| unit.save().await).await,
            StateRequest::Restore(rpc) => {
                rpc.handle_failable(async |buffer| unit.restore(buffer).await)
                    .await
            }
            StateRequest::PostRestore(rpc) => {
                rpc.handle_failable(async |()| unit.post_restore().await)
                    .await
            }
            StateRequest::Inspect(req) => req.inspect(unit),
        }
    }
}
219
/// A set of registered state units whose state transitions are driven
/// together.
#[derive(Debug)]
pub struct StateUnits {
    /// Shared registry of units; handles and inspectors hold weak references
    /// to it.
    inner: Arc<Mutex<Inner>>,
    /// Set by [`StateUnits::start`] and cleared by [`StateUnits::stop`].
    running: bool,
}
226
/// The state of an individual unit as tracked by [`StateUnits`].
///
/// The `-ing` variants are the interim states a unit is placed in while the
/// corresponding operation is in flight.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Inspect)]
enum State {
    /// Not running; the initial state of a newly built unit.
    Stopped,
    /// A start request is in flight.
    Starting,
    /// Started.
    Running,
    /// A stop request is in flight.
    Stopping,
    /// A reset request is in flight.
    Resetting,
    /// A save request is in flight.
    Saving,
    /// A restore request is in flight.
    Restoring,
    /// A post-restore request is in flight.
    PostRestoring,
}
238
/// The lock-protected registry behind [`StateUnits`].
#[derive(Debug)]
struct Inner {
    /// Id assigned to the next unit that is built.
    next_id: u64,
    /// Registered units, keyed by id.
    units: BTreeMap<u64, Unit>,
    /// Maps each registered unit name to its id.
    names: HashMap<Arc<str>, u64>,
}
245
/// Registry entry for a single unit.
#[derive(Debug)]
struct Unit {
    /// The unit's unique name.
    name: Arc<str>,
    /// Channel used to send requests to the unit.
    send: Sender<StateRequest>,
    /// Ids of the units this unit depends on.
    dependencies: Vec<u64>,
    /// Ids of the units that depend on this unit.
    dependents: Vec<u64>,
    /// Current state of the unit.
    state: State,
}
254
/// Error returned by [`UnitBuilder::build`] and [`UnitBuilder::spawn`] when a
/// unit with the same name is already registered. Carries the conflicting
/// name.
#[derive(Debug, Error)]
#[error("state unit name {0} is in use")]
pub struct NameInUse(Arc<str>);
259
/// Error produced when the RPC to a unit fails (as opposed to the unit
/// returning a failure of its own).
#[derive(Debug, Error)]
#[error("critical unit communication failure: {name}")]
struct UnitRecvError {
    /// Name of the unit the request was sent to.
    name: Arc<str>,
    /// The underlying RPC failure.
    #[source]
    source: RpcError,
}
267
/// Identifies a registered unit by both of its registry keys.
#[derive(Debug, Clone)]
struct UnitId {
    /// Key into `Inner::names`.
    name: Arc<str>,
    /// Key into `Inner::units`.
    id: u64,
}
273
/// A handle to a unit registered with a [`StateUnits`].
///
/// Dropping the handle removes the unit from the set unless
/// [`UnitHandle::detach`] was called first.
#[must_use]
#[derive(Debug)]
pub struct UnitHandle {
    /// The unit this handle refers to.
    id: UnitId,
    /// Weak reference to the owning registry; `None` once the handle has been
    /// detached or the unit has been removed.
    inner: Option<Weak<Mutex<Inner>>>,
}
281
impl Drop for UnitHandle {
    /// Removes the unit from its registry, unless the handle was detached or
    /// the unit was already removed.
    fn drop(&mut self) {
        self.remove_if();
    }
}
287
288impl UnitHandle {
289 pub fn remove(mut self) {
291 self.remove_if();
292 }
293
294 pub fn detach(mut self) {
296 self.inner = None;
297 }
298
299 fn remove_if(&mut self) {
300 if let Some(inner) = self.inner.take().and_then(|inner| inner.upgrade()) {
301 let mut inner = inner.lock();
302 inner.units.remove(&self.id.id).expect("unit exists");
303 inner.names.remove(&self.id.name).expect("unit exists");
304 }
305 }
306}
307
/// An object that can inspect a [`StateUnits`] without keeping it alive.
///
/// Created by [`StateUnits::inspector`].
pub struct StateUnitsInspector {
    /// Weak reference to the registry being inspected.
    inner: Weak<Mutex<Inner>>,
}
313
314impl Inspect for StateUnits {
315 fn inspect(&self, req: inspect::Request<'_>) {
316 self.inner.lock().inspect(req);
317 }
318}
319
320impl Inspect for StateUnitsInspector {
321 fn inspect(&self, req: inspect::Request<'_>) {
322 if let Some(inner) = self.inner.upgrade() {
323 inner.lock().inspect(req);
324 }
325 }
326}
327
328impl Inspect for Inner {
329 fn inspect(&self, req: inspect::Request<'_>) {
330 let mut resp = req.respond();
331 for unit in self.units.values() {
332 resp.child(unit.name.as_ref(), |req| {
333 let mut resp = req.respond();
334 if !unit.dependencies.is_empty() {
335 resp.field_with("dependencies", || {
336 unit.dependencies
337 .iter()
338 .map(|id| self.units[id].name.as_ref())
339 .collect::<Vec<_>>()
340 .join(",")
341 });
342 }
343 if !unit.dependents.is_empty() {
344 resp.field_with("dependents", || {
345 unit.dependents
346 .iter()
347 .map(|id| self.units[id].name.as_ref())
348 .collect::<Vec<_>>()
349 .join(",")
350 });
351 }
352 resp.field("unit_state", unit.state);
353 unit.send
354 .send(StateRequest::Inspect(resp.request().defer()))
355 });
356 }
357 }
358}
359
/// The saved state of a single unit, as produced by [`StateUnits::save`] and
/// consumed by [`StateUnits::restore`].
#[derive(Protobuf)]
#[mesh(package = "state_unit")]
pub struct SavedStateUnit {
    /// The name of the unit the state belongs to.
    #[mesh(1)]
    pub name: String,
    /// The unit's opaque saved state.
    #[mesh(2)]
    pub state: SavedStateBlob,
}
371
/// Error returned when one or more units fail a state transition.
#[derive(Debug, Error)]
#[error("{op} failed")]
pub struct StateTransitionError {
    /// Name of the operation that failed (e.g. `"reset"`).
    op: &'static str,
    /// The per-unit failures.
    #[source]
    errors: UnitErrorSet,
}
380
381fn extract<T, E: Into<anyhow::Error>, U>(
382 op: &'static str,
383 iter: impl IntoIterator<Item = (Arc<str>, Result<T, E>)>,
384 mut f: impl FnMut(Arc<str>, T) -> Option<U>,
385) -> Result<Vec<U>, StateTransitionError> {
386 let mut result = Vec::new();
387 let mut errors = Vec::new();
388 for (name, item) in iter {
389 match item {
390 Ok(t) => {
391 if let Some(u) = f(name, t) {
392 result.push(u);
393 }
394 }
395 Err(err) => errors.push((name, err.into())),
396 }
397 }
398 if errors.is_empty() {
399 Ok(result)
400 } else {
401 Err(StateTransitionError {
402 op,
403 errors: UnitErrorSet(errors),
404 })
405 }
406}
407
408fn check<E: Into<anyhow::Error>>(
409 op: &'static str,
410 iter: impl IntoIterator<Item = (Arc<str>, Result<(), E>)>,
411) -> Result<(), StateTransitionError> {
412 extract(op, iter, |_, _| Some(()))?;
413 Ok(())
414}
415
/// A list of `(unit name, error)` pairs collected from a failed operation.
#[derive(Debug)]
struct UnitErrorSet(Vec<(Arc<str>, anyhow::Error)>);
418
419impl Display for UnitErrorSet {
420 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
421 let mut map = f.debug_map();
422 for (name, err) in &self.0 {
423 map.entry(&format_args!("{}", name), &format_args!("{:#}", err));
424 }
425 map.finish()
426 }
427}
428
// No `source` is exposed: the individual errors are rendered by `Display`.
impl std::error::Error for UnitErrorSet {}
430
431impl StateUnits {
432 pub fn new() -> Self {
434 Self {
435 inner: Arc::new(Mutex::new(Inner {
436 next_id: 0,
437 units: BTreeMap::new(),
438 names: HashMap::new(),
439 })),
440 running: false,
441 }
442 }
443
444 pub fn inspector(&self) -> StateUnitsInspector {
447 StateUnitsInspector {
448 inner: Arc::downgrade(&self.inner),
449 }
450 }
451
452 pub fn add(&self, name: impl Into<Arc<str>>) -> UnitBuilder<'_> {
459 UnitBuilder {
460 units: self,
461 name: name.into(),
462 dependencies: Vec::new(),
463 dependents: Vec::new(),
464 }
465 }
466
467 pub fn is_running(&self) -> bool {
469 self.running
470 }
471
472 pub async fn start_stopped_units(&mut self) {
477 if self.is_running() {
478 self.start().await;
479 }
480 }
481
482 pub async fn start(&mut self) {
484 self.run_op(
485 "start",
486 None,
487 State::Stopped,
488 State::Starting,
489 State::Running,
490 StateRequest::Start,
491 |_, _| Some(()),
492 |unit| &unit.dependencies,
493 )
494 .await;
495 self.running = true;
496 }
497
498 pub async fn stop(&mut self) {
500 assert!(self.running);
501 self.run_op(
504 "stop",
505 None,
506 State::Running,
507 State::Stopping,
508 State::Stopped,
509 StateRequest::Stop,
510 |_, _| Some(()),
511 |unit| &unit.dependents,
512 )
513 .await;
514 self.running = false;
515 }
516
517 pub async fn reset(&mut self) -> Result<(), StateTransitionError> {
521 assert!(!self.running);
522 let r = self
525 .run_op(
526 "reset",
527 None,
528 State::Stopped,
529 State::Resetting,
530 State::Stopped,
531 StateRequest::Reset,
532 |_, _| Some(()),
533 |unit| &unit.dependencies,
534 )
535 .await;
536
537 check("reset", r)?;
538 Ok(())
539 }
540
541 pub async fn save(&mut self) -> Result<Vec<SavedStateUnit>, StateTransitionError> {
545 assert!(!self.running);
546 let r = self
549 .run_op(
550 "save",
551 None,
552 State::Stopped,
553 State::Saving,
554 State::Stopped,
555 StateRequest::Save,
556 |_, _| Some(()),
557 |_| &[],
558 )
559 .await;
560
561 let states = extract("save", r, |name, state| {
562 state.map(|state| SavedStateUnit {
563 name: name.to_string(),
564 state,
565 })
566 })?;
567
568 Ok(states)
569 }
570
571 pub async fn restore(
575 &mut self,
576 states: Vec<SavedStateUnit>,
577 ) -> Result<(), StateTransitionError> {
578 assert!(!self.running);
579
580 #[derive(Debug, Error)]
581 enum RestoreUnitError {
582 #[error("unknown unit name")]
583 Unknown,
584 #[error("duplicate unit name")]
585 Duplicate,
586 }
587
588 let mut states_by_id = HashMap::new();
589 let mut r = Vec::new();
590 {
591 let inner = self.inner.lock();
592 for state in states {
593 match inner.names.get_key_value(state.name.as_str()) {
594 Some((name, &id)) => {
595 if states_by_id
596 .insert(id, (name.clone(), state.state))
597 .is_some()
598 {
599 r.push((name.clone(), Err(RestoreUnitError::Duplicate)));
600 }
601 }
602 None => {
603 r.push((state.name.into(), Err(RestoreUnitError::Unknown)));
604 }
605 }
606 }
607 }
608
609 check("restore", r)?;
610
611 let r = self
612 .run_op(
613 "restore",
614 None,
615 State::Stopped,
616 State::Restoring,
617 State::Stopped,
618 StateRequest::Restore,
619 |id, _| states_by_id.remove(&id).map(|(_, blob)| blob),
620 |unit| &unit.dependencies,
621 )
622 .await;
623
624 check(
627 "restore",
628 states_by_id
629 .into_iter()
630 .map(|(_, (name, _))| (name, Err(RestoreUnitError::Unknown))),
631 )?;
632
633 check("restore", r)?;
634
635 self.post_restore().await?;
636 Ok(())
637 }
638
639 async fn post_restore(&mut self) -> Result<(), StateTransitionError> {
643 let r = self
646 .run_op(
647 "post_restore",
648 None,
649 State::Stopped,
650 State::PostRestoring,
651 State::Stopped,
652 StateRequest::PostRestore,
653 |_, _| Some(()),
654 |_| &[],
655 )
656 .await;
657
658 check("post_restore", r)?;
659 Ok(())
660 }
661
662 async fn run_op<I: 'static, R: 'static + Send>(
684 &self,
685 op: &str,
686 unit_ids: Option<&[u64]>,
687 old_state: State,
688 interim_state: State,
689 new_state: State,
690 request: impl Copy + FnOnce(Rpc<I, R>) -> StateRequest,
691 mut input: impl FnMut(u64, &Unit) -> Option<I>,
692 mut deps: impl FnMut(&Unit) -> &[u64],
693 ) -> Vec<(Arc<str>, R)> {
694 let mut done = Vec::new();
695 let ready_set;
696 {
697 let mut inner = self.inner.lock();
698 ready_set = inner.ready_set(unit_ids);
699 for (&id, unit) in inner
700 .units
701 .iter_mut()
702 .filter(|(id, _)| ready_set.0.contains_key(id))
703 {
704 if unit.state != old_state {
705 assert_eq!(
706 unit.state, new_state,
707 "unit {} in {:?} state, should be {:?} or {:?}",
708 unit.name, unit.state, old_state, new_state
709 );
710 ready_set.done(id, true);
711 } else {
712 let name = unit.name.clone();
713 let input = (input)(id, unit);
714 let ready_set = ready_set.clone();
715 let deps = deps(unit).to_vec();
716 let fut = state_change(name.clone(), unit, request, input);
717 let recv = async move {
718 ready_set.wait(op, id, &deps).await;
719 let r = fut.await;
720 ready_set.done(id, true);
721 (name, id, r)
722 };
723 done.push(recv);
724 unit.state = interim_state;
725 }
726 }
727 }
728
729 let results = async {
730 let start = Instant::now();
731 let results = join_all(done).await;
732 tracing::info!(duration = ?Instant::now() - start, "state change complete");
733 results
734 }
735 .instrument(tracing::info_span!("state_change", operation = op))
736 .await;
737
738 let mut inner = self.inner.lock();
739 let r = results
740 .into_iter()
741 .filter_map(|(name, id, r)| {
742 match r {
743 Ok(Some(r)) => Some((name, r)),
744 Ok(None) => None,
745 Err(err) => {
746 if inner.units.contains_key(&id) {
751 panic!("{:?}", err);
752 }
753 None
754 }
755 }
756 })
757 .collect();
758 for (_, unit) in inner
759 .units
760 .iter_mut()
761 .filter(|(id, _)| ready_set.0.contains_key(id))
762 {
763 if unit.state == interim_state {
764 unit.state = new_state;
765 } else {
766 assert_eq!(
767 unit.state, new_state,
768 "unit {} in {:?} state, should be {:?} or {:?}",
769 unit.name, unit.state, interim_state, new_state
770 );
771 }
772 }
773 r
774 }
775}
776
777impl Inner {
778 fn ready_set(&self, unit_ids: Option<&[u64]>) -> ReadySet {
779 let map = |id, unit: &Unit| {
780 (
781 id,
782 ReadyState {
783 name: unit.name.clone(),
784 ready: Arc::new(Ready::default()),
785 },
786 )
787 };
788 let units = if let Some(unit_ids) = unit_ids {
789 unit_ids
790 .iter()
791 .map(|id| map(*id, &self.units[id]))
792 .collect()
793 } else {
794 self.units.iter().map(|(id, unit)| map(*id, unit)).collect()
795 };
796 ReadySet(Arc::new(units))
797 }
798}
799
/// The completion signals of the units taking part in one operation, keyed by
/// unit id. Cloning shares the same underlying map.
#[derive(Clone)]
struct ReadySet(Arc<BTreeMap<u64, ReadyState>>);
802
/// One unit's entry in a [`ReadySet`].
#[derive(Clone)]
struct ReadyState {
    /// The unit's name, used when tracing.
    name: Arc<str>,
    /// Signalled when the unit is done with the current operation.
    ready: Arc<Ready>,
}
808
809impl ReadySet {
810 async fn wait(&self, op: &str, id: u64, deps: &[u64]) -> bool {
811 for dep in deps {
812 if let Some(dep) = self.0.get(dep) {
813 if !dep.ready.is_ready() {
814 tracing::debug!(
815 device = self.0[&id].name.as_ref(),
816 dependency = dep.name.as_ref(),
817 operation = op,
818 "waiting on dependency"
819 );
820 }
821 if !dep.ready.wait().await {
822 return false;
823 }
824 }
825 }
826 true
827 }
828
829 fn done(&self, id: u64, success: bool) {
830 self.0[&id].ready.signal(success);
831 }
832}
833
834fn state_change<I: 'static, R: 'static + Send, Req: FnOnce(Rpc<I, R>) -> StateRequest>(
839 name: Arc<str>,
840 unit: &Unit,
841 request: Req,
842 input: Option<I>,
843) -> impl Future<Output = Result<Option<R>, UnitRecvError>> + use<I, R, Req> {
844 let send = unit.send.clone();
845
846 async move {
847 let Some(input) = input else { return Ok(None) };
848 let span = tracing::info_span!("device_state_change", device = name.as_ref());
849 async move {
850 let start = Instant::now();
851 let r = send
852 .call(request, input)
853 .await
854 .map_err(|err| UnitRecvError { name, source: err });
855 tracing::debug!(duration = ?Instant::now() - start, "device state change complete");
856 r.map(Some)
857 }
858 .instrument(span)
859 .await
860 }
861}
862
/// A builder for registering a unit with a [`StateUnits`], created by
/// [`StateUnits::add`].
#[derive(Debug)]
#[must_use]
pub struct UnitBuilder<'a> {
    /// The set the unit will be registered with.
    units: &'a StateUnits,
    /// The name the unit will be registered under.
    name: Arc<str>,
    /// Ids of the units the new unit depends on.
    dependencies: Vec<u64>,
    /// Ids of the units that depend on the new unit.
    dependents: Vec<u64>,
}
872
873impl UnitBuilder<'_> {
874 pub fn depends_on(mut self, handle: &UnitHandle) -> Self {
879 self.dependencies.push(self.handle_id(handle));
880 self
881 }
882
883 pub fn dependency_of(mut self, handle: &UnitHandle) -> Self {
888 self.dependents.push(self.handle_id(handle));
889 self
890 }
891
892 fn handle_id(&self, handle: &UnitHandle) -> u64 {
893 assert_eq!(
895 Weak::as_ptr(handle.inner.as_ref().unwrap()),
896 Arc::as_ptr(&self.units.inner)
897 );
898 handle.id.id
899 }
900
901 pub fn build(mut self, send: Sender<StateRequest>) -> Result<UnitHandle, NameInUse> {
903 let id = {
904 let mut inner = self.units.inner.lock();
905 let id = inner.next_id;
906 let entry = match inner.names.entry(self.name.clone()) {
907 hash_map::Entry::Occupied(_) => return Err(NameInUse(self.name)),
908 hash_map::Entry::Vacant(e) => e,
909 };
910 entry.insert(id);
911
912 self.dependencies.sort();
915 self.dependencies.dedup();
916 for &dep in &self.dependencies {
917 inner.units.get_mut(&dep).unwrap().dependents.push(id);
918 }
919
920 self.dependents.sort();
923 self.dependents.dedup();
924 for &dep in &self.dependents {
925 inner.units.get_mut(&dep).unwrap().dependencies.push(id);
926 }
927 inner.units.insert(
928 id,
929 Unit {
930 name: self.name.clone(),
931 send,
932 dependencies: self.dependencies,
933 dependents: self.dependents,
934 state: State::Stopped,
935 },
936 );
937 let unit_id = UnitId {
938 name: self.name,
939 id,
940 };
941 inner.next_id += 1;
942 unit_id
943 };
944 Ok(UnitHandle {
945 id,
946 inner: Some(Arc::downgrade(&self.units.inner)),
947 })
948 }
949
950 #[track_caller]
956 pub fn spawn<F, Fut>(
957 self,
958 spawner: impl Spawn,
959 f: F,
960 ) -> Result<SpawnedUnit<Fut::Output>, NameInUse>
961 where
962 F: FnOnce(Receiver<StateRequest>) -> Fut,
963 Fut: 'static + Send + Future,
964 Fut::Output: 'static + Send,
965 {
966 let (send, recv) = mesh::channel();
967 let task_name = format!("unit-{}", self.name);
968 let handle = self.build(send)?;
969 let fut = (f)(recv);
970 let task = spawner.spawn(task_name, fut);
971 Ok(SpawnedUnit { task, handle })
972 }
973}
974
/// A registered unit together with the task that runs it, as returned by
/// [`UnitBuilder::spawn`].
#[must_use]
pub struct SpawnedUnit<T> {
    /// Handle to the unit's registration.
    handle: UnitHandle,
    /// The task running the unit; resolves to the unit future's output.
    task: Task<T>,
}
981
982impl<T> Debug for SpawnedUnit<T> {
983 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
984 f.debug_struct("SpawnedUnit")
985 .field("handle", &self.handle)
986 .field("task", &self.task)
987 .finish()
988 }
989}
990
991impl<T> SpawnedUnit<T> {
992 pub async fn remove(self) -> T {
994 self.handle.remove();
995 self.task.await
996 }
997
998 pub fn handle(&self) -> &UnitHandle {
1001 &self.handle
1002 }
1003}
1004
/// A one-shot completion signal carrying a success flag.
#[derive(Default)]
struct Ready {
    /// `0` while unsignalled; `1` once signalled with failure, `2` once
    /// signalled with success.
    state: AtomicU32,
    /// Notified whenever `state` is set.
    event: event_listener::Event,
}
1010
1011impl Ready {
1012 fn signal(&self, success: bool) {
1014 self.state.store(success as u32 + 1, Ordering::Release);
1015 self.event.notify(usize::MAX);
1016 }
1017
1018 fn is_ready(&self) -> bool {
1019 self.state.load(Ordering::Acquire) != 0
1020 }
1021
1022 async fn wait(&self) -> bool {
1024 loop {
1025 let listener = self.event.listen();
1026 let state = self.state.load(Ordering::Acquire);
1027 if state != 0 {
1028 return state - 1 != 0;
1029 }
1030 listener.await;
1031 }
1032 }
1033}
1034
#[cfg(test)]
mod tests {
    use super::StateUnit;
    use super::StateUnits;
    use crate::run_unit;
    use inspect::InspectMut;
    use mesh::payload::Protobuf;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::Ordering;
    use std::time::Duration;
    use test_with_tracing::test;
    use vmcore::save_restore::RestoreError;
    use vmcore::save_restore::SaveError;
    use vmcore::save_restore::SavedStateBlob;
    use vmcore::save_restore::SavedStateRoot;

    /// A unit whose saved state is a single boolean.
    #[derive(Default)]
    struct TestUnit {
        /// The value that is saved and restored.
        value: Arc<AtomicBool>,
        /// If set, a flag that must already be `true` when this unit is
        /// restored.
        dep: Option<Arc<AtomicBool>>,
        /// Whether the unit produces and accepts saved state.
        support_saved_state: bool,
    }

    #[derive(Protobuf, SavedStateRoot)]
    #[mesh(package = "test")]
    struct SavedState(bool);

    impl StateUnit for TestUnit {
        async fn start(&mut self) {}

        async fn stop(&mut self) {}

        async fn reset(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
            if self.support_saved_state {
                let state = SavedState(self.value.load(Ordering::Relaxed));
                Ok(Some(SavedStateBlob::new(state)))
            } else {
                Ok(None)
            }
        }

        async fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
            // Fails the test if this unit is restored before the flag it
            // watches has been set.
            assert!(self.dep.as_ref().is_none_or(|v| v.load(Ordering::Relaxed)));

            if self.support_saved_state {
                let state: SavedState = state.parse()?;
                self.value.store(state.0, Ordering::Relaxed);
                Ok(())
            } else {
                Err(RestoreError::SavedStateNotSupported)
            }
        }

        async fn post_restore(&mut self) -> anyhow::Result<()> {
            Ok(())
        }
    }

    impl InspectMut for TestUnit {
        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
            req.respond();
        }
    }

    /// A unit that sets a shared flag when it is restored, after a delay.
    struct TestUnitSetDep {
        /// The flag set by `restore`.
        dep: Arc<AtomicBool>,
        /// Driver used for the delay timer.
        driver: DefaultDriver,
    }

    impl StateUnit for TestUnitSetDep {
        async fn start(&mut self) {}

        async fn stop(&mut self) {}

        async fn reset(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
            Ok(Some(SavedStateBlob::new(SavedState(true))))
        }

        async fn restore(&mut self, _state: SavedStateBlob) -> Result<(), RestoreError> {
            // Delay so that a unit that fails to wait for this one would run
            // its restore before the flag is set.
            pal_async::timer::PolledTimer::new(&self.driver)
                .sleep(Duration::from_millis(100))
                .await;

            self.dep.store(true, Ordering::Relaxed);

            Ok(())
        }

        async fn post_restore(&mut self) -> anyhow::Result<()> {
            Ok(())
        }
    }

    impl InspectMut for TestUnitSetDep {
        fn inspect_mut(&mut self, req: inspect::Request<'_>) {
            req.respond();
        }
    }

    /// Exercises start/stop (including a unit added while running) and a
    /// save/restore round trip.
    #[async_test]
    async fn test_state_change(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let a_val = Arc::new(AtomicBool::new(true));

        let _a = units
            .add("a")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: a_val.clone(),
                        dep: None,
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();
        let _b = units
            .add("b")
            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
            .unwrap();
        units.start().await;

        // Added while the set is running, so it is still stopped here. Unwrap
        // so that a failed registration fails the test instead of being
        // silently ignored.
        let _c = units
            .add("c")
            .spawn(&driver, |recv| run_unit(TestUnit::default(), recv))
            .unwrap();
        units.stop().await;
        units.start().await;

        units.stop().await;

        let state = units.save().await.unwrap();

        // Change the live value so the restore has something to undo.
        a_val.store(false, Ordering::Relaxed);

        units.restore(state).await.unwrap();

        assert!(a_val.load(Ordering::Relaxed));
    }

    /// Checks that a unit is restored only after the unit it depends on.
    #[async_test]
    async fn test_dependencies(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let a_val = Arc::new(AtomicBool::new(true));

        let a = units
            .add("zzz")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: a_val.clone(),
                        dep: None,
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();

        // This unit's restore asserts that `a_val` is already `true`, i.e.
        // that the first unit has been restored.
        let _b = units
            .add("aaa")
            .depends_on(a.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        dep: Some(a_val.clone()),
                        value: Default::default(),
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();
        units.start().await;
        units.stop().await;

        let state = units.save().await.unwrap();

        a_val.store(false, Ordering::Relaxed);

        units.restore(state).await.unwrap();
    }

    /// Checks that restore ordering holds across a unit that has no saved
    /// state: `a` depends on `b`, which depends on `c`.
    #[async_test]
    async fn test_dep_no_saved_state(driver: DefaultDriver) {
        let mut units = StateUnits::new();

        let true_val = Arc::new(AtomicBool::new(true));
        let shared_val = Arc::new(AtomicBool::new(false));

        let a = units
            .add("a")
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: true_val.clone(),
                        dep: Some(shared_val.clone()),
                        support_saved_state: true,
                    },
                    recv,
                )
            })
            .unwrap();

        // Saves nothing, so it receives no restore request.
        let b = units
            .add("b_no_saved_state")
            .dependency_of(a.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnit {
                        value: true_val.clone(),
                        dep: Some(shared_val.clone()),
                        support_saved_state: false,
                    },
                    recv,
                )
            })
            .unwrap();

        // Sets `shared_val` during its (delayed) restore.
        let _c = units
            .add("c")
            .dependency_of(b.handle())
            .spawn(&driver, |recv| {
                run_unit(
                    TestUnitSetDep {
                        dep: shared_val,
                        driver: driver.clone(),
                    },
                    recv,
                )
            })
            .unwrap();

        let state = units.save().await.unwrap();

        units.restore(state).await.unwrap();
    }
}