netvsp/
rx_bufs.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Data structure for tracking receive buffer state.
5
6use thiserror::Error;
7
/// State of networking receive buffers.
///
/// Tracks which receive buffers are free and groups allocated buffers into
/// chains, where each chain holds the buffers used by one VMBus request.
#[derive(Debug)]
pub struct RxBuffers {
    /// Chains together rx receive buffers that are used as part of the same
    /// VMBus request. `state[i]` specifies the index of the next receive buffer
    /// in the request, or `END` if `i` is the last buffer. The beginning of
    /// each chain has `state[id] & START_MASK == START_MASK`. `INVALID`
    /// indicates the buffer is not in use.
    state: Vec<u32>,
}
17
/// Flag bit set on the state entry of the first buffer in each chain.
const START_MASK: u32 = 1 << 31;
/// State entry of a buffer that belongs to no chain: every non-flag bit set.
const INVALID: u32 = START_MASK - 1;
/// Link value stored by the last buffer of a chain; one below `INVALID`.
const END: u32 = INVALID - 1;
21
22#[derive(Debug, Error)]
23#[error("suballocation is already in use")]
24pub struct SubAllocationInUse;
25
26impl RxBuffers {
27    pub fn new(count: u32) -> Self {
28        Self {
29            state: (0..count).map(|_| INVALID).collect(),
30        }
31    }
32
33    pub fn is_free(&self, id: u32) -> bool {
34        self.state[id as usize] == INVALID
35    }
36
37    pub fn allocate<I: Iterator<Item = u32> + Clone>(
38        &mut self,
39        ids: impl IntoIterator<Item = u32, IntoIter = I>,
40    ) -> Result<(), SubAllocationInUse> {
41        let ids = ids.into_iter();
42        let first = ids.clone().next().unwrap();
43        let next_ids = ids.clone().skip(1).chain(std::iter::once(END));
44        for (n, (id, next_id)) in ids.clone().zip(next_ids).enumerate() {
45            if self.state[id as usize] != INVALID {
46                for id in ids.take(n) {
47                    self.state[id as usize] = INVALID;
48                }
49                return Err(SubAllocationInUse);
50            }
51            self.state[id as usize] = next_id;
52        }
53        self.state[first as usize] |= START_MASK;
54        Ok(())
55    }
56
57    pub fn free(&mut self, id: u32) -> Option<FreeIterator<'_>> {
58        let next = self.state.get(id as usize)?;
59        if next & START_MASK == 0 {
60            return None;
61        }
62        Some(FreeIterator {
63            id,
64            state: &mut self.state,
65        })
66    }
67
68    pub fn allocated(&self) -> RxIterator<'_> {
69        RxIterator {
70            id: 0,
71            chained_rx_id: &self.state,
72        }
73    }
74}
75
76pub struct RxIterator<'a> {
77    id: usize,
78    chained_rx_id: &'a Vec<u32>,
79}
80
81impl<'a> Iterator for RxIterator<'a> {
82    type Item = ReadIterator<'a>;
83
84    fn next(&mut self) -> Option<Self::Item> {
85        while self.id < self.chained_rx_id.len() {
86            let id = self.id;
87            self.id += 1;
88            if self.chained_rx_id[id] & START_MASK != 0 {
89                return Some(ReadIterator {
90                    id: id as u32,
91                    state: self.chained_rx_id,
92                });
93            }
94        }
95        None
96    }
97}
98
99pub struct ReadIterator<'a> {
100    id: u32,
101    state: &'a Vec<u32>,
102}
103
104impl Iterator for ReadIterator<'_> {
105    type Item = u32;
106
107    fn next(&mut self) -> Option<Self::Item> {
108        let id = self.id;
109        if id == END {
110            return None;
111        }
112        self.id = self.state[id as usize] & !START_MASK;
113        Some(id)
114    }
115}
116
117pub struct FreeIterator<'a> {
118    id: u32,
119    state: &'a mut Vec<u32>,
120}
121
122impl Iterator for FreeIterator<'_> {
123    type Item = u32;
124
125    fn next(&mut self) -> Option<Self::Item> {
126        let id = self.id;
127        if id == END {
128            return None;
129        }
130        self.id = self.state[id as usize] & !START_MASK;
131        self.state[id as usize] = INVALID;
132        Some(id)
133    }
134}
135
136impl Drop for FreeIterator<'_> {
137    fn drop(&mut self) {
138        while self.next().is_some() {}
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::RxBuffers;

    #[test]
    fn test_rx_bufs() {
        let mut bufs = RxBuffers::new(20);

        // Two disjoint chains can be allocated side by side.
        bufs.allocate([0, 1, 2]).unwrap();
        bufs.allocate([6, 9, 5]).unwrap();

        // Buffer 0 is already taken, so the whole request fails and the
        // buffers it claimed before the conflict (3, 10, 15) are released.
        assert!(bufs.allocate([3, 10, 15, 0, 4]).is_err());
        bufs.allocate([3, 10, 12]).unwrap();

        assert!(!bufs.is_free(1));
        assert!(!bufs.is_free(3));
        assert!(bufs.is_free(4));

        // Only the head of a chain can be freed.
        assert!(bufs.free(9).is_none());
        assert!(bufs.free(12).is_none());

        let freed: Vec<u32> = bufs.free(6).unwrap().collect();
        assert_eq!(freed, [6, 9, 5]);

        let chains: Vec<Vec<u32>> = bufs.allocated().map(|chain| chain.collect()).collect();
        let expected: [Vec<u32>; 2] = [vec![0, 1, 2], vec![3, 10, 12]];
        assert_eq!(chains, expected);
    }
}