1use super::bidir::Channel;
8use mesh_node::local_node::HandleMessageError;
9use mesh_node::local_node::HandlePortEvent;
10use mesh_node::local_node::NodeError;
11use mesh_node::local_node::Port;
12use mesh_node::local_node::PortControl;
13use mesh_node::local_node::PortWithHandler;
14use mesh_node::message::MeshField;
15use mesh_node::message::Message;
16use mesh_node::message::OwnedMessage;
17use mesh_node::resource::Resource;
18use mesh_protobuf::EncodeAs;
19use mesh_protobuf::Protobuf;
20use mesh_protobuf::SerializedMessage;
21use std::future::Future;
22use std::future::poll_fn;
23use std::task::Context;
24use std::task::Poll;
25use std::task::Waker;
26
/// A local copy of a value that is published by a [`CellUpdater`].
///
/// Each cell owns a port whose handler applies newer values sent by the
/// updater. Cloning a cell creates a fresh port pair and announces the new port
/// to the updater, so the clone keeps receiving updates. When the cell is
/// encoded for mesh, it is converted to an `EncodedCell` (id, value, port).
#[derive(Debug, Clone, Protobuf)]
// NOTE(review): `bound` presumably replaces the trait bounds on the generated
// Protobuf impl — confirm against mesh_protobuf's derive documentation.
#[mesh(bound = "T: MeshField + Send", resource = "Resource")]
pub struct Cell<T: 'static + MeshField + Send + Sync + Clone>(EncodeAs<Inner<T>, EncodedCell<T>>);
33
/// The in-process representation of a [`Cell`].
#[derive(Debug)]
struct Inner<T> {
    /// Port to the updater; its handler (`State`) stores the latest value.
    port: PortWithHandler<State<T>>,
    /// The value id last observed by `Cell::wait_next` (initialized to the id
    /// the cell was created with).
    last_id: u64,
}
39
/// Port handler state for a cell: the most recent value and its id.
#[derive(Debug)]
struct State<T> {
    /// Id of `value`; incoming updates with an id not greater than this are ignored.
    id: u64,
    /// The most recently applied value.
    value: T,
    /// Waker registered by `Cell::wait_next`, woken when a newer value arrives.
    waker: Option<Waker>,
}
46
/// Serialized form of a [`Cell`]: the current id and value plus the raw port.
// NOTE(review): mesh_protobuf presumably numbers fields by declaration order;
// do not reorder these fields without checking wire compatibility.
#[derive(Protobuf)]
#[mesh(resource = "Resource")]
struct EncodedCell<T> {
    id: u64,
    value: T,
    port: Port,
}
54
/// Owner of a value that is pushed to any number of [`Cell`]s.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct CellUpdater<T> {
    /// The latest value set on the updater.
    value: T,
    /// Id of `value`; incremented every time a new value is sent.
    current_id: u64,
    /// One channel per known cell, paired with the highest id that cell has
    /// acknowledged (or was created with).
    ports: Vec<(u64, Channel)>,
}
63
64impl<T: 'static + Clone + MeshField + Sync + Send> CellUpdater<T> {
65 pub fn new(value: T) -> Self {
67 Self {
68 value,
69 current_id: 0,
70 ports: Vec::new(),
71 }
72 }
73
74 pub fn cell(&mut self) -> Cell<T> {
76 let (recv, send) = Port::new_pair();
77 send.send(Message::new(UpdateMessage {
78 id: self.current_id,
79 value: self.value.clone(),
80 }));
81 self.ports.push((self.current_id, send.into()));
82 Cell(EncodeAs::new(Inner::from_parts(
83 self.current_id,
84 self.value.clone(),
85 recv,
86 )))
87 }
88
89 pub fn get(&self) -> &T {
91 &self.value
92 }
93
94 pub fn set(&mut self, value: T) -> impl '_ + Future<Output = ()> + Unpin {
96 self.send_value(value);
97 self.process_incoming()
98 }
99
100 fn send_value(&mut self, value: T) {
101 self.value = value;
102 self.current_id += 1;
103 for (_, port) in self.ports.iter_mut() {
104 port.send(SerializedMessage::from_message(UpdateMessage {
105 id: self.current_id,
106 value: self.value.clone(),
107 }));
108 }
109 }
110
111 fn poll_one(&mut self, cx: &mut Context<'_>, i: usize) -> Poll<bool> {
112 loop {
113 let (id, port) = &mut self.ports[i];
114 if *id >= self.current_id {
115 break Poll::Ready(true);
116 }
117 let message = std::task::ready!(port.poll_recv(cx));
118 let message = message.ok().and_then(|m| m.into_message().ok());
119 match message {
120 Some(message) => match message {
121 UpdateResponse::NewPort(new_id, new_port) => {
122 if new_id < self.current_id {
123 new_port.send(Message::new(UpdateMessage {
127 id: self.current_id,
128 value: self.value.clone(),
129 }));
130 }
131 self.ports.push((new_id, new_port.into()));
132 }
133 UpdateResponse::Updated(new_id) => {
134 if new_id > *id {
135 *id = new_id;
136 }
137 }
138 },
139
140 None => {
141 break Poll::Ready(false);
142 }
143 }
144 }
145 }
146
147 fn process_incoming(&mut self) -> impl '_ + Future<Output = ()> + Unpin {
148 poll_fn(|cx| {
149 let mut wait = false;
150 let mut i = 0;
151 while i < self.ports.len() {
152 match self.poll_one(cx, i) {
153 Poll::Ready(true) => i += 1,
154 Poll::Ready(false) => {
155 self.ports.swap_remove(i);
156 }
157 Poll::Pending => {
158 i += 1;
159 wait = true;
160 }
161 }
162 }
163 if wait { Poll::Pending } else { Poll::Ready(()) }
164 })
165 }
166}
167
168pub fn cell<T: 'static + MeshField + Send + Sync + Clone>(value: T) -> (CellUpdater<T>, Cell<T>) {
181 let mut updater = CellUpdater::new(value);
182 let cell = updater.cell();
183 (updater, cell)
184}
185
186impl<T: 'static + MeshField + Send + Sync + Clone> Clone for Inner<T> {
187 fn clone(&self) -> Self {
188 let (left, right) = Port::new_pair();
189 let (id, value) = self.port.with_port_and_handler(|control, state| {
194 let id = state.id;
195 let value = state.value.clone();
196 control.respond(Message::new(UpdateResponse::NewPort(id, left)));
197 (id, value)
198 });
199 Self::from_parts(id, value, right)
200 }
201}
202
203impl<T: 'static + MeshField + Send + Sync + Clone> Cell<T> {
204 pub fn get(&self) -> T
206 where
207 T: Clone,
208 {
209 self.0.port.with_handler(|state| state.value.clone())
210 }
211
212 pub fn with<F, R>(&self, f: F) -> R
217 where
218 F: FnOnce(&T) -> R,
219 {
220 self.0.port.with_handler(|state| f(&state.value))
221 }
222
223 pub fn with_mut<F, R>(&self, f: F) -> R
228 where
229 F: FnOnce(&mut T) -> R,
230 {
231 self.0.port.with_handler(|state| f(&mut state.value))
232 }
233
234 pub async fn wait_next(&mut self) {
236 poll_fn(|cx| {
237 let mut old_waker = None;
238 let inner = &mut *self.0;
239 inner.port.with_handler(|state| {
240 if inner.last_id == state.id {
241 old_waker = state.waker.replace(cx.waker().clone());
242 return Poll::Pending;
243 }
244 inner.last_id = state.id;
245 Poll::Ready(())
246 })
247 })
248 .await
249 }
250}
251
/// Updater -> cell message carrying a new value and its id.
// NOTE(review): field order (`value` before `id`) presumably fixes the encoded
// field numbers; keep it as-is for wire compatibility.
#[derive(Protobuf)]
#[mesh(resource = "Resource")]
struct UpdateMessage<T> {
    value: T,
    id: u64,
}
258
/// Cell -> updater responses.
#[derive(Protobuf)]
#[mesh(resource = "Resource")]
enum UpdateResponse {
    /// The cell applied the value with this id.
    Updated(u64),
    /// A clone of the cell was created holding the value with this id; the
    /// port is the updater's end of the clone's new port pair.
    NewPort(u64, Port),
}
265
266impl<T: 'static + MeshField + Send + Sync> HandlePortEvent for State<T> {
267 fn message(
268 &mut self,
269 control: &mut PortControl<'_, '_>,
270 message: Message<'_>,
271 ) -> Result<(), HandleMessageError> {
272 let UpdateMessage::<T> { id, value } = message.parse().map_err(HandleMessageError::new)?;
273 if self.id < id {
274 self.id = id;
275 self.value = value;
276 if let Some(waker) = self.waker.take() {
277 control.wake(waker);
278 }
279 control.respond(Message::new(UpdateResponse::Updated(id)));
280 }
281 Ok(())
282 }
283
284 fn close(&mut self, _control: &mut PortControl<'_, '_>) {}
285
286 fn fail(&mut self, _control: &mut PortControl<'_, '_>, _err: NodeError) {}
287
288 fn drain(&mut self) -> Vec<OwnedMessage> {
289 Vec::new()
290 }
291}
292
293impl<T: 'static + MeshField + Send + Sync> Inner<T> {
294 fn from_parts(id: u64, value: T, port: Port) -> Self {
295 let state = State {
296 id,
297 value,
298 waker: None,
299 };
300 Self {
301 port: port.set_handler(state),
302 last_id: id,
303 }
304 }
305
306 fn into_parts(self) -> (u64, T, Port) {
307 let (port, state) = self.port.remove_handler();
308 (state.id, state.value, port)
309 }
310}
311
312impl<T: 'static + MeshField + Send + Sync + Clone> From<Inner<T>> for EncodedCell<T> {
313 fn from(cell: Inner<T>) -> Self {
314 let (id, value, port) = cell.into_parts();
315 Self { id, value, port }
316 }
317}
318
319impl<T: 'static + MeshField + Send + Sync + Clone> From<EncodedCell<T>> for Inner<T> {
320 fn from(encoded: EncodedCell<T>) -> Self {
321 Inner::from_parts(encoded.id, encoded.value, encoded.port)
322 }
323}
324
/// Tests for `Cell` / `CellUpdater` propagation, cloning, and `wait_next`.
#[cfg(test)]
mod tests {
    use super::CellUpdater;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use std::future::poll_fn;
    use std::task::Poll;

    /// A single cell observes a value set on its updater.
    #[async_test]
    async fn cell() {
        let (mut updater, cell) = super::cell("hey".to_string());
        updater.set("hello".to_string()).await;
        cell.with(|val| assert_eq!(&val, &"hello"));
    }

    /// Cells created before and after an update, and clones of them, all see
    /// the latest value.
    #[async_test]
    async fn multi_cell() {
        let mut updater = CellUpdater::new(0);
        let c1 = updater.cell();
        let c2 = updater.cell();
        let c3 = updater.cell();
        // Cloned before the update: must be caught up via its NewPort registration.
        let c4 = c3.clone();
        updater.set(5).await;
        let c5 = updater.cell();
        let c6 = c4.clone();
        assert_eq!(c1.get(), 5);
        assert_eq!(c2.get(), 5);
        assert_eq!(c3.get(), 5);
        assert_eq!(c4.get(), 5);
        assert_eq!(c5.get(), 5);
        assert_eq!(c6.get(), 5);
    }

    /// `wait_next` completes after each new value, repeatedly.
    #[async_test]
    async fn wait_next(driver: DefaultDriver) {
        let mut updater = CellUpdater::new(0);
        let mut c = updater.cell();
        for i in 1..100 {
            let t = driver.spawn("test", async {
                c.wait_next().await;
                c
            });

            // Yield exactly once (wake self, return Pending, then Ready);
            // presumably this lets the spawned task start waiting before the
            // value changes.
            let mut yielded = false;
            poll_fn(|cx| {
                if yielded {
                    Poll::Ready(())
                } else {
                    cx.waker().wake_by_ref();
                    yielded = true;
                    Poll::Pending
                }
            })
            .await;

            // `set` sends the update eagerly; dropping its future only skips
            // waiting for acknowledgments.
            drop(updater.set(i));
            c = t.await;
            assert_eq!(c.get(), i);
        }
    }
}