vmbus_client/
filter.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Support for filtering vmbus offers. This is useful for redirecting offers to
5//! separate client drivers.
6
7use crate::ConnectResult;
8use crate::OfferInfo;
9use futures::StreamExt;
10use futures_concurrency::stream::Merge;
11use guid::Guid;
12use inspect::Inspect;
13use inspect::InspectMut;
14use pal_async::task::Spawn;
15use pal_async::task::Task;
16use std::pin::pin;
17use vmbus_core::protocol::OfferChannel;
18
19/// A filter.
20///
21/// Create using [`ClientFilterBuilder`].
22pub struct ClientFilter {
23    req: mesh::Sender<FilterRequest>,
24    task: Task<()>,
25}
26
27impl Inspect for ClientFilter {
28    fn inspect(&self, req: inspect::Request<'_>) {
29        self.req.send(FilterRequest::Inspect(req.defer()));
30    }
31}
32
33enum FilterRequest {
34    Inspect(inspect::Deferred),
35}
36
37impl ClientFilter {
38    /// Shuts down the filter.
39    pub async fn shutdown(self) {
40        drop(self.req);
41        self.task.await;
42    }
43}
44
45/// A builder for creating a [`ClientFilter`].
46pub struct ClientFilterBuilder<'a> {
47    clients: Vec<&'a mut FilterDefinition>,
48}
49
50#[derive(InspectMut)]
51struct FilterWorker {
52    #[inspect(flatten)]
53    filters: Filters,
54    #[inspect(skip)]
55    clients: Vec<mesh::Sender<OfferInfo>>,
56}
57
58struct Filters {
59    interfaces: Vec<(Guid, usize)>,
60    instances: Vec<(Guid, Guid, usize)>,
61    rest: Option<usize>,
62    names: Vec<String>,
63}
64
65/// A single filter definition.
66pub struct FilterDefinition {
67    name: String,
68    interfaces: Vec<Guid>,
69    instances: Vec<(Guid, Guid)>,
70    rest: bool,
71    result: Option<ConnectResult>,
72}
73
74impl FilterDefinition {
75    /// Returns a new filter instance with the given name (for diagnostics).
76    pub fn new(name: impl Into<String>) -> Self {
77        Self {
78            name: name.into(),
79            interfaces: Vec::new(),
80            instances: Vec::new(),
81            rest: false,
82            result: None,
83        }
84    }
85
86    /// Adds the specified interface ID to the filter, to include offers for that interface.
87    pub fn by_interface(mut self, interface_id: Guid) -> Self {
88        self.interfaces.push(interface_id);
89        self
90    }
91
92    /// Adds the specified interface ID and instance ID to the filter, to
93    /// include offers for a specific offer instance.
94    pub fn by_instance(mut self, interface_id: Guid, instance_id: Guid) -> Self {
95        self.instances.push((interface_id, instance_id));
96        self
97    }
98
99    /// Filter all remaining offers that do not match any other filter.
100    pub fn rest(mut self) -> Self {
101        self.rest = true;
102        self
103    }
104
105    /// Takes a filtered connection result.
106    ///
107    /// This should be called only after the filter has been built and offers
108    /// have been processed, via [`ClientFilterBuilder::build`]. Panics
109    /// otherwise.
110    pub fn take(mut self) -> ConnectResult {
111        self.result
112            .take()
113            .expect("failed to call ClientFilterBuilder::build")
114    }
115}
116
117impl<'a> ClientFilterBuilder<'a> {
118    /// Creates a new filter builder.
119    pub fn new() -> Self {
120        Self {
121            clients: Vec::new(),
122        }
123    }
124
125    /// Adds a filter definition.
126    pub fn add(&mut self, client: &'a mut FilterDefinition) -> &mut Self {
127        self.clients.push(client);
128        self
129    }
130
131    /// Builds a filter instance, which applies the assigned filters to the
132    /// initial and dynamic offers in `connection`.
133    ///
134    /// Uses `driver` to spawn the filter worker task.
135    pub fn build(mut self, driver: impl Spawn, connection: ConnectResult) -> ClientFilter {
136        let mut filters = Filters {
137            interfaces: Vec::new(),
138            instances: Vec::new(),
139            rest: None,
140            names: self.clients.iter().map(|c| c.name.clone()).collect(),
141        };
142        let mut offer_send = Vec::with_capacity(self.clients.len());
143        for (i, client) in self.clients.iter_mut().enumerate() {
144            let (send, recv) = mesh::channel();
145            client.result = Some(ConnectResult {
146                version: connection.version,
147                offers: Vec::new(),
148                offer_recv: recv,
149            });
150            offer_send.push(send);
151            for &interface in &client.interfaces {
152                filters.interfaces.push((interface, i));
153            }
154            for &(interface, instance) in &client.instances {
155                filters.instances.push((interface, instance, i));
156            }
157            if client.rest {
158                assert!(filters.rest.is_none(), "multiple rest filters set");
159                filters.rest = Some(i);
160            }
161        }
162
163        for offer in connection.offers {
164            if let Some(i) = filters.find(&offer.offer) {
165                self.clients[i].result.as_mut().unwrap().offers.push(offer);
166            }
167        }
168
169        let (req_send, req_recv) = mesh::channel();
170        let mut worker = FilterWorker {
171            filters,
172            clients: offer_send,
173        };
174
175        let offer_recv = connection.offer_recv;
176        let task = driver.spawn("client_filter", async move {
177            worker.run(req_recv, offer_recv).await;
178        });
179        ClientFilter {
180            task,
181            req: req_send,
182        }
183    }
184}
185
186impl Filters {
187    fn find(&self, offer: &OfferChannel) -> Option<usize> {
188        let interface = &offer.interface_id;
189        let instance = &offer.instance_id;
190        let (&v, ty) = if let Some(v) = self.instances.iter().find_map(|(iface, inst, send)| {
191            ((iface, inst) == (interface, instance)).then_some(send)
192        }) {
193            (v, "instance")
194        } else if let Some(v) = self
195            .interfaces
196            .iter()
197            .find_map(|(iface, send)| (iface == interface).then_some(send))
198        {
199            (v, "interface")
200        } else if let Some(v) = self.rest.as_ref() {
201            (v, "rest")
202        } else {
203            tracing::warn!(%interface, %instance, "rejecting offer");
204            return None;
205        };
206        tracing::debug!(
207            %interface,
208            %instance,
209            filter_type = ty,
210            client = self.names[v],
211            "accepting offer"
212        );
213        Some(v)
214    }
215}
216
217impl Inspect for Filters {
218    fn inspect(&self, req: inspect::Request<'_>) {
219        let mut resp = req.respond();
220        let Self {
221            interfaces,
222            instances,
223            rest,
224            names,
225        } = self;
226        for &(interface, i) in interfaces {
227            resp.field(&format!("by_interface/{}", interface), &names[i]);
228        }
229        for &(interface, instance, i) in instances {
230            resp.field(
231                &format!("by_instance/{}_{}", interface, instance),
232                &names[i],
233            );
234        }
235        if let Some(i) = rest {
236            resp.field("rest", &names[*i]);
237        }
238    }
239}
240
241impl FilterWorker {
242    async fn run(&mut self, req: mesh::Receiver<FilterRequest>, offers: mesh::Receiver<OfferInfo>) {
243        enum Event {
244            Request(FilterRequest),
245            Done,
246            Offer(OfferInfo),
247        }
248        let req = req
249            .map(Event::Request)
250            .chain(futures::stream::once(async { Event::Done }));
251        let offers = offers.map(Event::Offer);
252        let mut events = pin!((req, offers).merge());
253
254        while let Some(event) = events.next().await {
255            match event {
256                Event::Request(FilterRequest::Inspect(deferred)) => deferred.inspect(&mut *self),
257                Event::Done => break,
258                Event::Offer(offer_info) => {
259                    if let Some(i) = self.filters.find(&offer_info.offer) {
260                        self.clients[i].send(offer_info);
261                    }
262                }
263            }
264        }
265    }
266}