mesh_channel/
cell.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! An implementation of a cell type that can be updated from a remote mesh
//! node.
6
7use super::bidir::Channel;
8use mesh_node::local_node::HandleMessageError;
9use mesh_node::local_node::HandlePortEvent;
10use mesh_node::local_node::NodeError;
11use mesh_node::local_node::Port;
12use mesh_node::local_node::PortControl;
13use mesh_node::local_node::PortWithHandler;
14use mesh_node::message::MeshField;
15use mesh_node::message::Message;
16use mesh_node::message::OwnedMessage;
17use mesh_node::resource::Resource;
18use mesh_protobuf::EncodeAs;
19use mesh_protobuf::Protobuf;
20use mesh_protobuf::SerializedMessage;
21use std::future::Future;
22use std::future::poll_fn;
23use std::task::Context;
24use std::task::Poll;
25use std::task::Waker;
26
27/// A cell containing a value that can be updated from a remote node.
28///
29/// Created by [`cell()`].
30#[derive(Debug, Clone, Protobuf)]
31#[mesh(bound = "T: MeshField + Send", resource = "Resource")]
32pub struct Cell<T: 'static + MeshField + Send + Sync + Clone>(EncodeAs<Inner<T>, EncodedCell<T>>);
33
34#[derive(Debug)]
35struct Inner<T> {
36    port: PortWithHandler<State<T>>,
37    last_id: u64,
38}
39
/// Per-cell state, installed as the handler on the cell's port.
#[derive(Debug)]
struct State<T> {
    /// Id of the update that produced `value`.
    id: u64,
    /// The cell's current value.
    value: T,
    /// Waker registered by a pending [`Cell::wait_next`], if any.
    waker: Option<Waker>,
}
46
47#[derive(Protobuf)]
48#[mesh(resource = "Resource")]
49struct EncodedCell<T> {
50    id: u64,
51    value: T,
52    port: Port,
53}
54
55/// A type used to update the value in one or more [`Cell`]s.
56#[derive(Debug, Protobuf)]
57#[mesh(resource = "Resource")]
58pub struct CellUpdater<T> {
59    value: T,
60    current_id: u64,
61    ports: Vec<(u64, Channel)>,
62}
63
64impl<T: 'static + Clone + MeshField + Sync + Send> CellUpdater<T> {
65    /// Creates a new cell updater with no associated cells.
66    pub fn new(value: T) -> Self {
67        Self {
68            value,
69            current_id: 0,
70            ports: Vec::new(),
71        }
72    }
73
74    /// Creates a new associated cell.
75    pub fn cell(&mut self) -> Cell<T> {
76        let (recv, send) = Port::new_pair();
77        send.send(Message::new(UpdateMessage {
78            id: self.current_id,
79            value: self.value.clone(),
80        }));
81        self.ports.push((self.current_id, send.into()));
82        Cell(EncodeAs::new(Inner::from_parts(
83            self.current_id,
84            self.value.clone(),
85            recv,
86        )))
87    }
88
89    /// Gets the current value.
90    pub fn get(&self) -> &T {
91        &self.value
92    }
93
94    /// Asynchronously updates the value in the associated cells.
95    pub fn set(&mut self, value: T) -> impl '_ + Future<Output = ()> + Unpin {
96        self.send_value(value);
97        self.process_incoming()
98    }
99
100    fn send_value(&mut self, value: T) {
101        self.value = value;
102        self.current_id += 1;
103        for (_, port) in self.ports.iter_mut() {
104            port.send(SerializedMessage::from_message(UpdateMessage {
105                id: self.current_id,
106                value: self.value.clone(),
107            }));
108        }
109    }
110
111    fn poll_one(&mut self, cx: &mut Context<'_>, i: usize) -> Poll<bool> {
112        loop {
113            let (id, port) = &mut self.ports[i];
114            if *id >= self.current_id {
115                break Poll::Ready(true);
116            }
117            let message = std::task::ready!(port.poll_recv(cx));
118            let message = message.ok().and_then(|m| m.into_message().ok());
119            match message {
120                Some(message) => match message {
121                    UpdateResponse::NewPort(new_id, new_port) => {
122                        if new_id < self.current_id {
123                            // This port has a stale value. Send it the new
124                            // value. We'll wait for its response in a
125                            // subsequent call.
126                            new_port.send(Message::new(UpdateMessage {
127                                id: self.current_id,
128                                value: self.value.clone(),
129                            }));
130                        }
131                        self.ports.push((new_id, new_port.into()));
132                    }
133                    UpdateResponse::Updated(new_id) => {
134                        if new_id > *id {
135                            *id = new_id;
136                        }
137                    }
138                },
139
140                None => {
141                    break Poll::Ready(false);
142                }
143            }
144        }
145    }
146
147    fn process_incoming(&mut self) -> impl '_ + Future<Output = ()> + Unpin {
148        poll_fn(|cx| {
149            let mut wait = false;
150            let mut i = 0;
151            while i < self.ports.len() {
152                match self.poll_one(cx, i) {
153                    Poll::Ready(true) => i += 1,
154                    Poll::Ready(false) => {
155                        self.ports.swap_remove(i);
156                    }
157                    Poll::Pending => {
158                        i += 1;
159                        wait = true;
160                    }
161                }
162            }
163            if wait { Poll::Pending } else { Poll::Ready(()) }
164        })
165    }
166}
167
168/// Creates a new cell and its associated updater.
169///
170/// Both the cell and the updater can be sent to remote processes via mesh channels.
171///
172/// ```rust
173/// # use mesh_channel::cell::cell;
174/// # use futures::executor::block_on;
175/// let (mut updater, cell) = cell::<u32>(5);
176/// assert_eq!(cell.get(), 5);
177/// block_on(updater.set(6));
178/// assert_eq!(cell.get(), 6);
179/// ```
180pub fn cell<T: 'static + MeshField + Send + Sync + Clone>(value: T) -> (CellUpdater<T>, Cell<T>) {
181    let mut updater = CellUpdater::new(value);
182    let cell = updater.cell();
183    (updater, cell)
184}
185
186impl<T: 'static + MeshField + Send + Sync + Clone> Clone for Inner<T> {
187    fn clone(&self) -> Self {
188        let (left, right) = Port::new_pair();
189        // Hold the lock for the whole operation to ensure the new port message
190        // is sent before the update callback sends the update response;
191        // otherwise, the updater will fail to see that there is a new port with
192        // a stale value.
193        let (id, value) = self.port.with_port_and_handler(|control, state| {
194            let id = state.id;
195            let value = state.value.clone();
196            control.respond(Message::new(UpdateResponse::NewPort(id, left)));
197            (id, value)
198        });
199        Self::from_parts(id, value, right)
200    }
201}
202
203impl<T: 'static + MeshField + Send + Sync + Clone> Cell<T> {
204    /// Gets a clone of the cell's current value.
205    pub fn get(&self) -> T
206    where
207        T: Clone,
208    {
209        self.0.port.with_handler(|state| state.value.clone())
210    }
211
212    /// Runs `f` with a reference to the cell's current value.
213    ///
214    /// While `f` is running, updates to the cell's value will not be
215    /// acknowledged (and the remote updater's `set` method will block).
216    pub fn with<F, R>(&self, f: F) -> R
217    where
218        F: FnOnce(&T) -> R,
219    {
220        self.0.port.with_handler(|state| f(&state.value))
221    }
222
223    /// Runs `f` with a mutable reference to the cell's current value.
224    ///
225    /// While `f` is running, updates to the cell's value will not be
226    /// acknowledged (and the remote updater's `set` method will block).
227    pub fn with_mut<F, R>(&self, f: F) -> R
228    where
229        F: FnOnce(&mut T) -> R,
230    {
231        self.0.port.with_handler(|state| f(&mut state.value))
232    }
233
234    /// Waits for a new value to be set.
235    pub async fn wait_next(&mut self) {
236        poll_fn(|cx| {
237            let mut old_waker = None;
238            let inner = &mut *self.0;
239            inner.port.with_handler(|state| {
240                if inner.last_id == state.id {
241                    old_waker = state.waker.replace(cx.waker().clone());
242                    return Poll::Pending;
243                }
244                inner.last_id = state.id;
245                Poll::Ready(())
246            })
247        })
248        .await
249    }
250}
251
252#[derive(Protobuf)]
253#[mesh(resource = "Resource")]
254struct UpdateMessage<T> {
255    value: T,
256    id: u64,
257}
258
259#[derive(Protobuf)]
260#[mesh(resource = "Resource")]
261enum UpdateResponse {
262    Updated(u64),
263    NewPort(u64, Port),
264}
265
266impl<T: 'static + MeshField + Send + Sync> HandlePortEvent for State<T> {
267    fn message(
268        &mut self,
269        control: &mut PortControl<'_, '_>,
270        message: Message<'_>,
271    ) -> Result<(), HandleMessageError> {
272        let UpdateMessage::<T> { id, value } = message.parse().map_err(HandleMessageError::new)?;
273        if self.id < id {
274            self.id = id;
275            self.value = value;
276            if let Some(waker) = self.waker.take() {
277                control.wake(waker);
278            }
279            control.respond(Message::new(UpdateResponse::Updated(id)));
280        }
281        Ok(())
282    }
283
284    fn close(&mut self, _control: &mut PortControl<'_, '_>) {}
285
286    fn fail(&mut self, _control: &mut PortControl<'_, '_>, _err: NodeError) {}
287
288    fn drain(&mut self) -> Vec<OwnedMessage> {
289        Vec::new()
290    }
291}
292
293impl<T: 'static + MeshField + Send + Sync> Inner<T> {
294    fn from_parts(id: u64, value: T, port: Port) -> Self {
295        let state = State {
296            id,
297            value,
298            waker: None,
299        };
300        Self {
301            port: port.set_handler(state),
302            last_id: id,
303        }
304    }
305
306    fn into_parts(self) -> (u64, T, Port) {
307        let (port, state) = self.port.remove_handler();
308        (state.id, state.value, port)
309    }
310}
311
312impl<T: 'static + MeshField + Send + Sync + Clone> From<Inner<T>> for EncodedCell<T> {
313    fn from(cell: Inner<T>) -> Self {
314        let (id, value, port) = cell.into_parts();
315        Self { id, value, port }
316    }
317}
318
319impl<T: 'static + MeshField + Send + Sync + Clone> From<EncodedCell<T>> for Inner<T> {
320    fn from(encoded: EncodedCell<T>) -> Self {
321        Inner::from_parts(encoded.id, encoded.value, encoded.port)
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::CellUpdater;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use std::future::poll_fn;
    use std::task::Poll;

    /// Returns `Pending` exactly once (after requesting an immediate re-poll),
    /// giving other tasks on the executor a chance to run.
    async fn yield_once() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                return Poll::Ready(());
            }
            yielded = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        })
        .await
    }

    #[async_test]
    async fn cell() {
        let (mut updater, cell) = super::cell("hey".to_string());
        updater.set("hello".to_string()).await;
        cell.with(|val| assert_eq!(val, "hello"));
    }

    #[async_test]
    async fn multi_cell() {
        let mut updater = CellUpdater::new(0);
        let c1 = updater.cell();
        let c2 = updater.cell();
        let c3 = updater.cell();
        // Cloned before the update: must be discovered and updated by `set`.
        let c4 = c3.clone();
        updater.set(5).await;
        // Created and cloned after the update: must start with the new value.
        let c5 = updater.cell();
        let c6 = c4.clone();
        for (n, c) in [&c1, &c2, &c3, &c4, &c5, &c6].into_iter().enumerate() {
            assert_eq!(c.get(), 5, "cell c{}", n + 1);
        }
    }

    #[async_test]
    async fn wait_next(driver: DefaultDriver) {
        let mut updater = CellUpdater::new(0);
        let mut c = updater.cell();
        for i in 1..100 {
            let waiter = driver.spawn("test", async move {
                c.wait_next().await;
                c
            });

            // Let `waiter` run until it blocks in `wait_next`.
            yield_once().await;

            // Publish without waiting for the acknowledgement.
            drop(updater.set(i));
            c = waiter.await;
            assert_eq!(c.get(), i);
        }
    }
}