1use crate::ConnectResult;
8use crate::OfferInfo;
9use futures::StreamExt;
10use futures_concurrency::stream::Merge;
11use guid::Guid;
12use inspect::Inspect;
13use inspect::InspectMut;
14use pal_async::task::Spawn;
15use pal_async::task::Task;
16use std::pin::pin;
17use vmbus_core::protocol::OfferChannel;
18
19pub struct ClientFilter {
23 req: mesh::Sender<FilterRequest>,
24 task: Task<()>,
25}
26
27impl Inspect for ClientFilter {
28 fn inspect(&self, req: inspect::Request<'_>) {
29 self.req.send(FilterRequest::Inspect(req.defer()));
30 }
31}
32
33enum FilterRequest {
34 Inspect(inspect::Deferred),
35}
36
37impl ClientFilter {
38 pub async fn shutdown(self) {
40 drop(self.req);
41 self.task.await;
42 }
43}
44
45pub struct ClientFilterBuilder<'a> {
47 clients: Vec<&'a mut FilterDefinition>,
48}
49
50#[derive(InspectMut)]
51struct FilterWorker {
52 #[inspect(flatten)]
53 filters: Filters,
54 #[inspect(skip)]
55 clients: Vec<mesh::Sender<OfferInfo>>,
56}
57
58struct Filters {
59 interfaces: Vec<(Guid, usize)>,
60 instances: Vec<(Guid, Guid, usize)>,
61 rest: Option<usize>,
62 names: Vec<String>,
63}
64
65pub struct FilterDefinition {
67 name: String,
68 interfaces: Vec<Guid>,
69 instances: Vec<(Guid, Guid)>,
70 rest: bool,
71 result: Option<ConnectResult>,
72}
73
74impl FilterDefinition {
75 pub fn new(name: impl Into<String>) -> Self {
77 Self {
78 name: name.into(),
79 interfaces: Vec::new(),
80 instances: Vec::new(),
81 rest: false,
82 result: None,
83 }
84 }
85
86 pub fn by_interface(mut self, interface_id: Guid) -> Self {
88 self.interfaces.push(interface_id);
89 self
90 }
91
92 pub fn by_instance(mut self, interface_id: Guid, instance_id: Guid) -> Self {
95 self.instances.push((interface_id, instance_id));
96 self
97 }
98
99 pub fn rest(mut self) -> Self {
101 self.rest = true;
102 self
103 }
104
105 pub fn take(mut self) -> ConnectResult {
111 self.result
112 .take()
113 .expect("failed to call ClientFilterBuilder::build")
114 }
115}
116
117impl<'a> ClientFilterBuilder<'a> {
118 pub fn new() -> Self {
120 Self {
121 clients: Vec::new(),
122 }
123 }
124
125 pub fn add(&mut self, client: &'a mut FilterDefinition) -> &mut Self {
127 self.clients.push(client);
128 self
129 }
130
131 pub fn build(mut self, driver: impl Spawn, connection: ConnectResult) -> ClientFilter {
136 let mut filters = Filters {
137 interfaces: Vec::new(),
138 instances: Vec::new(),
139 rest: None,
140 names: self.clients.iter().map(|c| c.name.clone()).collect(),
141 };
142 let mut offer_send = Vec::with_capacity(self.clients.len());
143 for (i, client) in self.clients.iter_mut().enumerate() {
144 let (send, recv) = mesh::channel();
145 client.result = Some(ConnectResult {
146 version: connection.version,
147 offers: Vec::new(),
148 offer_recv: recv,
149 });
150 offer_send.push(send);
151 for &interface in &client.interfaces {
152 filters.interfaces.push((interface, i));
153 }
154 for &(interface, instance) in &client.instances {
155 filters.instances.push((interface, instance, i));
156 }
157 if client.rest {
158 assert!(filters.rest.is_none(), "multiple rest filters set");
159 filters.rest = Some(i);
160 }
161 }
162
163 for offer in connection.offers {
164 if let Some(i) = filters.find(&offer.offer) {
165 self.clients[i].result.as_mut().unwrap().offers.push(offer);
166 }
167 }
168
169 let (req_send, req_recv) = mesh::channel();
170 let mut worker = FilterWorker {
171 filters,
172 clients: offer_send,
173 };
174
175 let offer_recv = connection.offer_recv;
176 let task = driver.spawn("client_filter", async move {
177 worker.run(req_recv, offer_recv).await;
178 });
179 ClientFilter {
180 task,
181 req: req_send,
182 }
183 }
184}
185
186impl Filters {
187 fn find(&self, offer: &OfferChannel) -> Option<usize> {
188 let interface = &offer.interface_id;
189 let instance = &offer.instance_id;
190 let (&v, ty) = if let Some(v) = self.instances.iter().find_map(|(iface, inst, send)| {
191 ((iface, inst) == (interface, instance)).then_some(send)
192 }) {
193 (v, "instance")
194 } else if let Some(v) = self
195 .interfaces
196 .iter()
197 .find_map(|(iface, send)| (iface == interface).then_some(send))
198 {
199 (v, "interface")
200 } else if let Some(v) = self.rest.as_ref() {
201 (v, "rest")
202 } else {
203 tracing::warn!(%interface, %instance, "rejecting offer");
204 return None;
205 };
206 tracing::debug!(
207 %interface,
208 %instance,
209 filter_type = ty,
210 client = self.names[v],
211 "accepting offer"
212 );
213 Some(v)
214 }
215}
216
217impl Inspect for Filters {
218 fn inspect(&self, req: inspect::Request<'_>) {
219 let mut resp = req.respond();
220 let Self {
221 interfaces,
222 instances,
223 rest,
224 names,
225 } = self;
226 for &(interface, i) in interfaces {
227 resp.field(&format!("by_interface/{}", interface), &names[i]);
228 }
229 for &(interface, instance, i) in instances {
230 resp.field(
231 &format!("by_instance/{}_{}", interface, instance),
232 &names[i],
233 );
234 }
235 if let Some(i) = rest {
236 resp.field("rest", &names[*i]);
237 }
238 }
239}
240
241impl FilterWorker {
242 async fn run(&mut self, req: mesh::Receiver<FilterRequest>, offers: mesh::Receiver<OfferInfo>) {
243 enum Event {
244 Request(FilterRequest),
245 Done,
246 Offer(OfferInfo),
247 }
248 let req = req
249 .map(Event::Request)
250 .chain(futures::stream::once(async { Event::Done }));
251 let offers = offers.map(Event::Offer);
252 let mut events = pin!((req, offers).merge());
253
254 while let Some(event) = events.next().await {
255 match event {
256 Event::Request(FilterRequest::Inspect(deferred)) => deferred.inspect(&mut *self),
257 Event::Done => break,
258 Event::Offer(offer_info) => {
259 if let Some(i) = self.filters.find(&offer_info.offer) {
260 self.clients[i].send(offer_info);
261 }
262 }
263 }
264 }
265 }
266}