1use crate::bus::ChannelRequest;
7use crate::bus::ChannelServerRequest;
8use crate::bus::ModifyRequest;
9use crate::bus::OfferInput;
10use crate::bus::OfferParams;
11use crate::bus::OfferResources;
12use crate::bus::OpenRequest;
13use crate::bus::ParentBus;
14use crate::gpadl::GpadlMap;
15use crate::gpadl::GpadlMapView;
16use anyhow::Context;
17use async_trait::async_trait;
18use futures::StreamExt;
19use futures::stream::SelectAll;
20use futures::stream::select;
21use inspect::Inspect;
22use inspect::InspectMut;
23use mesh::RecvError;
24use mesh::rpc::FailableRpc;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use pal_async::task::Spawn;
28use pal_async::task::Task;
29use pal_event::Event;
30use std::any::Any;
31use std::collections::BTreeSet;
32use std::marker::PhantomData;
33use std::pin::pin;
34use std::sync::Arc;
35use thiserror::Error;
36use tracing::instrument;
37use vmbus_core::TaggedStream;
38use vmbus_core::protocol::GpadlId;
39use vmbus_ring::gparange::MultiPagedRangeBuf;
40use vmcore::notify::Notify;
41use vmcore::save_restore::RestoreError;
42use vmcore::save_restore::SaveError;
43use vmcore::save_restore::SavedStateBlob;
44use vmcore::slim_event::SlimEvent;
45
/// Error returned by [`VmbusDevice::open`] when a channel cannot be opened.
///
/// The channel task logs the error and leaves the channel marked as closed.
pub type ChannelOpenError = anyhow::Error;
48
/// A device offered on a vmbus parent bus and serviced by a channel task.
///
/// Channel index 0 is the primary channel; indexes `1..=max_subchannels()`
/// are subchannels.
#[async_trait]
pub trait VmbusDevice: Send + Any + InspectMut {
    /// Returns the parameters used to offer the primary channel (and, with
    /// an adjusted subchannel index, each subchannel) to the bus.
    fn offer(&self) -> OfferParams;

    /// Returns the maximum number of subchannels the device may enable.
    ///
    /// One notification event is allocated per possible channel at offer
    /// time. Defaults to no subchannels.
    fn max_subchannels(&self) -> u16 {
        0
    }

    /// Hands the device its resources.
    ///
    /// Called once, after the primary channel has been offered on the bus
    /// and before the channel task is spawned.
    fn install(&mut self, resources: DeviceResources);

    /// Opens channel `channel_idx` using `open_request`.
    ///
    /// On error, the failure is logged and the channel stays closed.
    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError>;

    /// Closes the open channel `channel_idx`.
    ///
    /// When the primary channel is closed, any open subchannels are closed
    /// first.
    async fn close(&mut self, channel_idx: u16);

    /// Retargets channel `channel_idx` to `target_vp`, in response to a
    /// modify request.
    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32);

    /// Starts the device.
    ///
    /// Open and modify requests queued while the device was stopped are
    /// dispatched after this returns.
    fn start(&mut self);

    /// Stops the device.
    ///
    /// Until the next start, open and modify requests are queued rather than
    /// dispatched.
    async fn stop(&mut self);

    /// Returns the device as a [`SaveRestoreVmbusDevice`] if it supports
    /// saved state.
    ///
    /// When this returns `None`, saving yields no state and restoring fails.
    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice>;
}
88
/// Saved-state support for a [`VmbusDevice`].
#[async_trait]
pub trait SaveRestoreVmbusDevice: VmbusDevice {
    /// Saves the device's state.
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError>;

    /// Restores the device from `state`.
    ///
    /// `control` lets the device restore the bus-level state of its channels
    /// (open state and GPADLs) via [`RestoreControl::restore`].
    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError>;
}
105
/// Resources handed to a device through [`VmbusDevice::install`].
#[derive(Debug, Default)]
pub struct DeviceResources {
    /// Resources returned by the parent bus when the primary channel was
    /// offered.
    pub offer_resources: OfferResources,
    /// View of the GPADLs registered on this device's channels.
    pub gpadl_map: GpadlMapView,
    /// Control used to enable subchannels.
    pub channel_control: ChannelControl,
    /// Per-channel resources, indexed by channel index (0 is the primary
    /// channel), covering the primary channel and every possible subchannel.
    pub channels: Vec<ChannelResources>,
}
118
/// Resources for a single channel (primary or subchannel).
#[derive(Debug)]
pub struct ChannelResources {
    /// Notification object for the channel. Its interrupt side is what is
    /// passed to the bus when the channel is offered.
    pub event: Notify,
}
125
/// Handle a device uses to enable subchannels on its offer.
///
/// The [`Default`] value is not connected to any channel. It permits zero
/// subchannels, and enable requests made through it are dropped.
#[derive(Debug, Default, Clone)]
pub struct ChannelControl {
    // Sender for subchannel enable requests; `None` when not connected.
    send: Option<mesh::Sender<u16>>,
    // Maximum number of subchannels that may be enabled.
    max: u16,
}
132
/// Error returned by [`ChannelControl::enable_subchannels`] when more
/// subchannels are requested than the device supports.
#[derive(Debug, Error)]
#[error("too many subchannels requested")]
pub struct TooManySubchannels;
137
138impl ChannelControl {
139 pub fn enable_subchannels(&self, count: u16) -> Result<(), TooManySubchannels> {
146 if count > self.max {
147 return Err(TooManySubchannels);
148 }
149 if let Some(send) = &self.send {
150 send.send(count);
151 }
152 Ok(())
153 }
154
155 pub fn max_subchannels(&self) -> u16 {
157 self.max
158 }
159}
160
/// Type-erased handle to an offered channel and the task servicing it.
///
/// Control requests are sent to the task over `state_req`. Dropping that
/// sender makes the task leave its event loop and return the device, which
/// is retrieved by awaiting `task`.
#[must_use]
pub(crate) struct GenericChannelHandle {
    state_req: mesh::Sender<StateRequest>,
    task: Task<Box<dyn VmbusDevice>>,
}
169
/// Requests sent from a [`GenericChannelHandle`] to its channel task.
#[derive(Debug)]
enum StateRequest {
    /// Start the device, then dispatch any open/modify requests that were
    /// queued while it was stopped.
    Start,

    /// Stop the device if it is running. Later open/modify requests are
    /// queued until the next `Start`.
    Stop(Rpc<(), ()>),

    /// Discard any open/modify requests queued while the device is stopped.
    Reset(Rpc<(), ()>),

    /// Save the device's state. Yields `None` if the device does not support
    /// saved state.
    Save(FailableRpc<(), Option<SavedStateBlob>>),

    /// Restore the device from a saved state blob. Fails if the device does
    /// not support saved state.
    Restore(FailableRpc<SavedStateBlob, ()>),

    /// Inspect the device.
    Inspect(inspect::Deferred),
}
195
196impl std::fmt::Debug for GenericChannelHandle {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 f.pad("ChannelHandle")
199 }
200}
201
202impl GenericChannelHandle {
203 pub async fn revoke(self) -> Option<Box<dyn VmbusDevice>> {
205 drop(self.state_req);
206 Some(self.task.await)
207 }
208
209 pub fn start(&self) {
210 self.state_req.send(StateRequest::Start);
211 }
212
213 pub async fn stop(&self) {
214 self.state_req
215 .call(StateRequest::Stop, ())
216 .await
217 .expect("critical channel failure")
218 }
219
220 pub async fn reset(&self) {
221 self.state_req
222 .call(StateRequest::Reset, ())
223 .await
224 .expect("critical channel failure")
225 }
226
227 pub async fn save(&self) -> anyhow::Result<Option<SavedStateBlob>> {
228 self.state_req
229 .call(StateRequest::Save, ())
230 .await
231 .expect("critical channel failure")
232 .map_err(|err| err.into())
233 }
234
235 pub async fn restore(&self, buffer: SavedStateBlob) -> anyhow::Result<()> {
236 self.state_req
237 .call(StateRequest::Restore, buffer)
238 .await
239 .expect("critical channel failure")
240 .map_err(|err| err.into())
241 }
242}
243
244impl Inspect for GenericChannelHandle {
245 fn inspect(&self, req: inspect::Request<'_>) {
246 self.state_req.send(StateRequest::Inspect(req.defer()));
247 }
248}
249
/// Typed handle to an offered channel.
///
/// `T` is the device type the channel was offered with, or `dyn VmbusDevice`
/// for type-erased offers. It is recorded only in the type, so that `revoke`
/// can return the device as `T`.
#[must_use]
#[derive(Inspect)]
#[inspect(transparent)]
pub struct ChannelHandle<T: ?Sized>(GenericChannelHandle, PhantomData<fn() -> Box<T>>);
257
258impl<T: ?Sized> std::fmt::Debug for ChannelHandle<T> {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 f.pad("ChannelHandle")
261 }
262}
263
264impl<T: 'static + VmbusDevice> ChannelHandle<T> {
265 pub async fn revoke(self) -> Option<T> {
267 let device = self.0.revoke().await? as Box<dyn Any>;
268 Some(
269 *device
270 .downcast()
271 .expect("type must match the one used to create it"),
272 )
273 }
274}
275
276impl ChannelHandle<dyn VmbusDevice> {
277 pub async fn revoke(self) -> Option<Box<dyn VmbusDevice>> {
279 self.0.revoke().await
280 }
281}
282
283impl<T: 'static + VmbusDevice + ?Sized> ChannelHandle<T> {
284 pub fn start(&self) {
286 self.0.start()
287 }
288
289 pub async fn stop(&self) {
291 self.0.stop().await
292 }
293
294 pub async fn reset(&self) {
296 self.0.reset().await
297 }
298
299 pub async fn save(&self) -> anyhow::Result<Option<SavedStateBlob>> {
301 self.0.save().await
302 }
303
304 pub async fn restore(&self, buffer: SavedStateBlob) -> anyhow::Result<()> {
306 self.0.restore(buffer).await
307 }
308}
309
310async fn offer_generic(
311 driver: &impl Spawn,
312 bus: &(impl ParentBus + ?Sized),
313 mut channel: Box<dyn VmbusDevice>,
314) -> anyhow::Result<GenericChannelHandle> {
315 let offer = channel.offer();
316 let max_subchannels = channel.max_subchannels();
317 let instance_id = offer.instance_id;
318 let (request_send, request_recv) = mesh::channel();
319 let (server_request_send, server_request_recv) = mesh::channel();
320 let (state_req_send, state_req_recv) = mesh::channel();
321
322 let use_event = bus.use_event();
323
324 let events: Vec<_> = (0..max_subchannels + 1)
325 .map(|_| {
326 if use_event {
327 Notify::from_event(Event::new())
328 } else {
329 Notify::from_slim_event(Arc::new(SlimEvent::new()))
330 }
331 })
332 .collect();
333
334 let request = OfferInput {
335 params: offer,
336 event: events[0].clone().interrupt(),
337 request_send,
338 server_request_recv,
339 };
340
341 let gpadl_map = GpadlMap::new();
342
343 let offer_result = bus.add_child(request).await?;
344
345 let resources = events
346 .iter()
347 .map(|event| ChannelResources {
348 event: event.clone(),
349 })
350 .collect();
351
352 let (subchannel_enable_send, subchannel_enable_recv) = mesh::channel();
353 channel.install(DeviceResources {
354 offer_resources: offer_result,
355 gpadl_map: gpadl_map.clone().view(),
356 channels: resources,
357 channel_control: ChannelControl {
358 send: Some(subchannel_enable_send),
359 max: max_subchannels,
360 },
361 });
362
363 let bus = bus.clone_bus();
364 let task = driver.spawn(format!("vmbus offer {}", instance_id), async move {
365 let device = Device::new(
366 request_recv,
367 server_request_send,
368 events,
369 gpadl_map,
370 subchannel_enable_recv,
371 );
372 device
373 .run_channel(bus.as_ref(), channel.as_mut(), state_req_recv)
374 .await;
375 channel
376 });
377
378 Ok(GenericChannelHandle {
379 state_req: state_req_send,
380 task,
381 })
382}
383
/// Passed to [`SaveRestoreVmbusDevice::restore`] so the device can restore
/// the bus-level state of its channels.
pub struct RestoreControl<'a> {
    // The channel task state being restored.
    device: &'a mut Device,
    // Bus used to re-offer subchannels.
    bus: &'a dyn ParentBus,
    // The device's offer parameters at the time of the restore request.
    offer: OfferParams,
}
391
392impl RestoreControl<'_> {
393 pub async fn restore(
402 &mut self,
403 states: &[bool],
404 ) -> Result<Vec<Option<OpenRequest>>, ChannelRestoreError> {
405 self.device.restore(self.bus, &self.offer, states).await
406 }
407}
408
/// Errors that can occur while restoring a channel's bus-level state.
#[derive(Debug, Error)]
pub enum ChannelRestoreError {
    /// Offering the subchannels described by the saved state failed.
    #[error("failed to enable subchannels")]
    EnablingSubchannels(#[source] anyhow::Error),
    /// The server failed to restore one of the channels.
    #[error("failed to restore vmbus channel")]
    RestoreError(#[source] anyhow::Error),
    /// A GPADL that had been accepted before the save could not be
    /// reconstructed.
    #[error("failed to restore gpadl")]
    GpadlError(#[source] vmbus_ring::gparange::Error),
}
422
423impl From<ChannelRestoreError> for RestoreError {
424 fn from(err: ChannelRestoreError) -> Self {
425 RestoreError::Other(err.into())
426 }
427}
428
/// Run state of a channel's device.
enum DeviceState {
    /// The device is running; channel requests are dispatched immediately.
    Running,
    /// The device is stopped. Open and modify requests received meanwhile
    /// are queued here, tagged with their channel index, and dispatched when
    /// the device is started again.
    Stopped(Vec<(usize, ChannelRequest)>),
}
436
/// State owned by the task that services one offered device and its
/// subchannels.
///
/// Per-channel vectors are indexed by channel index (0 is the primary
/// channel) unless noted otherwise.
struct Device {
    // Whether the device is running, or stopped with queued requests.
    state: DeviceState,
    // Server request senders, one per offered channel.
    server_requests: Vec<mesh::Sender<ChannelServerRequest>>,
    // Whether each offered channel is currently open.
    open: Vec<bool>,
    // GPADLs registered on each subchannel; entry `i` is subchannel `i + 1`.
    subchannel_gpadls: Vec<BTreeSet<GpadlId>>,
    // Request streams of all offered channels, tagged with the channel index.
    requests: SelectAll<TaggedStream<usize, mesh::Receiver<ChannelRequest>>>,
    // Notification events, one per possible channel.
    events: Vec<Notify>,
    // GPADL map shared with the device through a `GpadlMapView`.
    gpadl_map: Arc<GpadlMap>,
    // Subchannel enable requests sent via `ChannelControl::enable_subchannels`.
    subchannel_enable_recv: mesh::Receiver<u16>,
}
447
448impl Device {
449 fn new(
450 request_recv: mesh::Receiver<ChannelRequest>,
451 server_request_send: mesh::Sender<ChannelServerRequest>,
452 events: Vec<Notify>,
453 gpadl_map: Arc<GpadlMap>,
454 subchannel_enable_recv: mesh::Receiver<u16>,
455 ) -> Self {
456 let open: Vec<bool> = vec![false];
457 let subchannel_gpadls: Vec<BTreeSet<GpadlId>> = vec![];
458 let mut requests: SelectAll<TaggedStream<usize, mesh::Receiver<ChannelRequest>>> =
459 SelectAll::new();
460 requests.push(TaggedStream::new(0, request_recv));
461 Self {
462 state: DeviceState::Running,
463 server_requests: vec![server_request_send],
464 open,
465 subchannel_gpadls,
466 requests,
467 events,
468 gpadl_map,
469 subchannel_enable_recv,
470 }
471 }
472
473 async fn run_channel(
475 mut self,
476 bus: &dyn ParentBus,
477 channel: &mut dyn VmbusDevice,
478 state_req_recv: mesh::Receiver<StateRequest>,
479 ) {
480 enum Event {
481 Request(usize, Option<ChannelRequest>),
482 EnableSubchannels(u16),
483 StateRequest(Result<StateRequest, RecvError>),
484 }
485
486 let mut state_req_recv = pin!(futures::stream::unfold(state_req_recv, async |mut recv| {
487 Some((recv.recv().await, recv))
488 }));
489
490 let map_request = |(idx, req)| Event::Request(idx, req);
491 loop {
492 let mut s = select(
493 (&mut self.requests).map(map_request),
494 select(
495 (&mut self.subchannel_enable_recv).map(Event::EnableSubchannels),
496 (&mut state_req_recv).map(Event::StateRequest),
497 ),
498 );
499 if let Some(event) = s.next().await {
500 match event {
501 Event::Request(idx, Some(request)) => {
502 self.handle_channel_request(idx, request, channel).await;
503 }
504 Event::Request(_idx, None) => continue,
505 Event::EnableSubchannels(count) => {
506 let offer = channel.offer();
507 let _ = self.enable_channels(bus, &offer, count as usize + 1).await;
508 }
509 Event::StateRequest(Ok(request)) => {
510 self.handle_state_request(request, channel, bus).await;
511 }
512 Event::StateRequest(Err(_)) => {
513 break;
515 }
516 }
517 }
518 }
519 drop(self.server_requests);
521 for recv in self.requests.iter_mut() {
529 if recv.value().is_some() {
530 while recv.next().await.is_some() {}
531 }
532 }
533
534 for subchannel_idx in (0..self.open.len()).rev() {
535 if self.open[subchannel_idx] {
536 channel.close(subchannel_idx as u16).await;
537 }
538 }
539 }
540
541 #[instrument(level = "debug", skip_all, fields(channel_idx, ?request))]
542 async fn handle_channel_request(
543 &mut self,
544 channel_idx: usize,
545 request: ChannelRequest,
546 channel: &mut dyn VmbusDevice,
547 ) {
548 if matches!(request, ChannelRequest::Open(_) | ChannelRequest::Modify(_)) {
554 if let DeviceState::Stopped(pending_messages) = &mut self.state {
555 pending_messages.push((channel_idx, request));
556 return;
557 }
558 }
559
560 match request {
561 ChannelRequest::Open(rpc) => {
562 rpc.handle(async |open_request| {
563 self.handle_open(channel, channel_idx, open_request).await
564 })
565 .await
566 }
567 ChannelRequest::Close(rpc) => {
568 rpc.handle(async |()| {
569 self.handle_close(channel_idx, channel).await;
570 })
571 .await
572 }
573 ChannelRequest::Gpadl(rpc) => rpc.handle_sync(|gpadl| {
574 self.handle_gpadl(gpadl.id, gpadl.count, gpadl.buf, channel_idx);
575 true
576 }),
577 ChannelRequest::TeardownGpadl(rpc) => {
578 self.handle_teardown_gpadl(rpc, channel_idx);
579 }
580 ChannelRequest::Modify(rpc) => {
581 rpc.handle(async |req| {
582 self.handle_modify(channel, channel_idx, req).await;
583 0
584 })
585 .await
586 }
587 }
588 }
589
590 async fn handle_open(
591 &mut self,
592 channel: &mut dyn VmbusDevice,
593 channel_idx: usize,
594 open_request: OpenRequest,
595 ) -> bool {
596 assert!(!self.open[channel_idx]);
597 let opened = channel
600 .open(channel_idx as u16, &open_request)
601 .await
602 .inspect_err(|error| {
603 tracelimit::error_ratelimited!(
604 error = error.as_ref() as &dyn std::error::Error,
605 "failed to open channel"
606 );
607 })
608 .is_ok();
609 self.open[channel_idx] = opened;
610 opened
611 }
612
613 async fn handle_close(&mut self, channel_idx: usize, channel: &mut dyn VmbusDevice) {
614 assert!(self.open[channel_idx]);
615 if channel_idx == 0 {
616 self.server_requests.truncate(1);
618 for recv in self.requests.iter_mut() {
619 if let Some(&idx) = recv.value() {
620 if idx > 0 {
621 while recv.next().await.is_some() {}
622 }
623 }
624 }
625 for subchannel_idx in 1..self.open.len() {
626 if self.open[subchannel_idx] {
627 channel.close(subchannel_idx as u16).await;
628 }
629 for &gpadl_id in &self.subchannel_gpadls[subchannel_idx - 1] {
630 self.gpadl_map.remove(gpadl_id, Box::new(|| ()));
631 }
632 }
633 self.open.truncate(1);
634 self.subchannel_gpadls.clear();
635 }
636 channel.close(channel_idx as u16).await;
637 self.open[channel_idx] = false;
638 if channel_idx == 0 {
639 while self.subchannel_enable_recv.try_recv().is_ok() {}
641 }
642 }
643
644 fn handle_gpadl(&mut self, id: GpadlId, count: u16, buf: Vec<u64>, channel_idx: usize) {
645 self.gpadl_map
646 .add(id, MultiPagedRangeBuf::new(count.into(), buf).unwrap());
647 if channel_idx > 0 {
648 self.subchannel_gpadls[channel_idx - 1].insert(id);
649 }
650 }
651
652 fn handle_teardown_gpadl(&mut self, rpc: Rpc<GpadlId, ()>, channel_idx: usize) {
653 let id = *rpc.input();
654 if let Some(f) = self.gpadl_map.remove(
655 id,
656 Box::new(move || {
657 rpc.complete(());
658 }),
659 ) {
660 f()
661 }
662 if channel_idx > 0 {
663 assert!(self.subchannel_gpadls[channel_idx - 1].remove(&id));
664 }
665 }
666
667 async fn handle_modify(
668 &mut self,
669 channel: &mut dyn VmbusDevice,
670 channel_idx: usize,
671 req: ModifyRequest,
672 ) {
673 match req {
674 ModifyRequest::TargetVp { target_vp } => {
675 channel.retarget_vp(channel_idx as u16, target_vp).await
676 }
677 }
678 }
679
680 #[instrument(level = "debug", skip_all, fields(?request))]
681 async fn handle_state_request(
682 &mut self,
683 request: StateRequest,
684 channel: &mut dyn VmbusDevice,
685 bus: &dyn ParentBus,
686 ) {
687 match request {
688 StateRequest::Start => {
689 channel.start();
690 if let DeviceState::Stopped(pending_messages) =
691 std::mem::replace(&mut self.state, DeviceState::Running)
692 {
693 for (channel_idx, request) in pending_messages.into_iter() {
694 self.handle_channel_request(channel_idx, request, channel)
695 .await;
696 }
697 }
698 }
699 StateRequest::Stop(rpc) => {
700 if matches!(self.state, DeviceState::Running) {
701 self.state = DeviceState::Stopped(Vec::new());
702 rpc.handle(async |()| {
703 channel.stop().await;
704 })
705 .await;
706 } else {
707 rpc.complete(());
708 }
709 }
710 StateRequest::Reset(rpc) => {
711 if let DeviceState::Stopped(pending_messages) = &mut self.state {
712 pending_messages.clear();
713 }
714 rpc.complete(());
715 }
716 StateRequest::Save(rpc) => {
717 rpc.handle_failable(async |()| {
718 if let Some(channel) = channel.supports_save_restore() {
719 channel.save().await.map(Some)
720 } else {
721 Ok(None)
722 }
723 })
724 .await;
725 }
726 StateRequest::Restore(rpc) => {
727 rpc.handle_failable(async |buffer| {
728 let channel = channel
729 .supports_save_restore()
730 .context("saved state not supported")?;
731 let control = RestoreControl {
732 device: &mut *self,
733 offer: channel.offer(),
734 bus,
735 };
736 channel
737 .restore(control, buffer)
738 .await
739 .map_err(anyhow::Error::from)?;
740 anyhow::Ok(())
741 })
742 .await;
743 }
744 StateRequest::Inspect(deferred) => {
745 deferred.inspect(&mut *channel);
746 }
747 }
748 }
749
750 async fn enable_channels(
751 &mut self,
752 bus: &dyn ParentBus,
753 offer: &OfferParams,
754 count: usize,
755 ) -> anyhow::Result<()> {
756 let mut r = Ok(());
758 for subchannel_idx in self.server_requests.len()..count {
759 let (request_send, request_recv) = mesh::channel();
760 let (server_request_send, server_request_recv) = mesh::channel();
761 let request = OfferInput {
762 params: OfferParams {
763 subchannel_index: subchannel_idx as u16,
764 ..offer.clone()
765 },
766 event: self.events[subchannel_idx].clone().interrupt(),
767 request_send,
768 server_request_recv,
769 };
770 match bus.add_child(request).await {
771 Ok(_) => {
772 self.requests
773 .push(TaggedStream::new(subchannel_idx, request_recv));
774 self.server_requests.push(server_request_send);
775 self.subchannel_gpadls.push(BTreeSet::new());
776 self.open.push(false);
777 }
778 Err(err) => {
779 tracing::error!(
780 error = err.as_ref() as &dyn std::error::Error,
781 "could not offer subchannel"
782 );
783 if r.is_ok() {
784 r = Err(err);
785 }
786 }
787 }
788 }
789 r
790 }
791
792 pub async fn restore(
793 &mut self,
794 bus: &dyn ParentBus,
795 offer: &OfferParams,
796 states: &[bool],
797 ) -> Result<Vec<Option<OpenRequest>>, ChannelRestoreError> {
798 self.enable_channels(bus, offer, states.len())
799 .await
800 .map_err(ChannelRestoreError::EnablingSubchannels)?;
801
802 let mut results = Vec::with_capacity(states.len());
803 for (channel_idx, open) in states.iter().copied().enumerate() {
804 let result = self.server_requests[channel_idx]
805 .call_failable(ChannelServerRequest::Restore, open)
806 .await
807 .map_err(|err| ChannelRestoreError::RestoreError(err.into()))?;
808
809 assert!(open == result.open_request.is_some());
810
811 for gpadl in result.gpadls {
812 let buf =
813 match MultiPagedRangeBuf::new(gpadl.request.count.into(), gpadl.request.buf) {
814 Ok(buf) => buf,
815 Err(err) => {
816 if gpadl.accepted {
817 return Err(ChannelRestoreError::GpadlError(err));
818 } else {
819 continue;
822 }
823 }
824 };
825 self.gpadl_map.add(gpadl.request.id, buf);
826 if channel_idx > 0 {
827 self.subchannel_gpadls[channel_idx - 1].insert(gpadl.request.id);
828 }
829 }
830
831 results.push(result.open_request);
832 }
833 self.open.copy_from_slice(states);
834 Ok(results)
835 }
836}
837
838pub async fn offer_channel<T: 'static + VmbusDevice>(
841 driver: &impl Spawn,
842 bus: &(impl ParentBus + ?Sized),
843 channel: T,
844) -> anyhow::Result<ChannelHandle<T>> {
845 let handle = offer_generic(driver, bus, Box::new(channel)).await?;
846 Ok(ChannelHandle(handle, PhantomData))
847}
848
849pub async fn offer_generic_channel(
851 driver: &impl Spawn,
852 bus: &(impl ParentBus + ?Sized),
853 channel: Box<dyn VmbusDevice>,
854) -> anyhow::Result<ChannelHandle<dyn VmbusDevice>> {
855 let handle = offer_generic(driver, bus, channel).await?;
856 Ok(ChannelHandle(handle, PhantomData))
857}