1use crate::bus::ChannelRequest;
7use crate::bus::ChannelServerRequest;
8use crate::bus::ModifyRequest;
9use crate::bus::OfferInput;
10use crate::bus::OfferParams;
11use crate::bus::OfferResources;
12use crate::bus::OpenRequest;
13use crate::bus::OpenResult;
14use crate::bus::ParentBus;
15use crate::gpadl::GpadlMap;
16use crate::gpadl::GpadlMapView;
17use anyhow::Context;
18use async_trait::async_trait;
19use futures::StreamExt;
20use futures::stream::SelectAll;
21use futures::stream::select;
22use inspect::Inspect;
23use inspect::InspectMut;
24use mesh::RecvError;
25use mesh::rpc::FailableRpc;
26use mesh::rpc::Rpc;
27use mesh::rpc::RpcSend;
28use pal_async::task::Spawn;
29use pal_async::task::Task;
30use pal_event::Event;
31use std::any::Any;
32use std::collections::BTreeSet;
33use std::marker::PhantomData;
34use std::pin::pin;
35use std::sync::Arc;
36use thiserror::Error;
37use tracing::instrument;
38use vmbus_core::TaggedStream;
39use vmbus_core::protocol::GpadlId;
40use vmbus_ring::gparange::MultiPagedRangeBuf;
41use vmcore::notify::Notify;
42use vmcore::save_restore::RestoreError;
43use vmcore::save_restore::SaveError;
44use vmcore::save_restore::SavedStateBlob;
45use vmcore::slim_event::SlimEvent;
46
/// Error returned by [`VmbusDevice::open`] when a channel cannot be opened.
pub type ChannelOpenError = anyhow::Error;
49
/// A device that exposes a vmbus channel (plus optional subchannels) and is
/// serviced by the channel task spawned when it is offered.
///
/// The `Any` supertrait is what lets [`ChannelHandle::revoke`] downcast the
/// device back to its concrete type.
#[async_trait]
pub trait VmbusDevice: Send + Any + InspectMut {
    /// Returns the parameters used to offer the primary channel. Subchannel
    /// offers reuse these with `subchannel_index` overridden.
    fn offer(&self) -> OfferParams;

    /// Maximum number of subchannels this device may enable.
    ///
    /// One notification event is allocated per channel (this value plus one
    /// for the primary channel), and [`ChannelControl::enable_subchannels`]
    /// rejects counts above it. Defaults to 0 (no subchannels).
    fn max_subchannels(&self) -> u16 {
        0
    }

    /// Hands the device its resources after the primary channel has been
    /// offered to the bus.
    fn install(&mut self, resources: DeviceResources);

    /// Opens channel `channel_idx` (0 is the primary channel).
    ///
    /// On error the open fails and the error is logged (rate-limited).
    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError>;

    /// Closes channel `channel_idx`. Only called for channels that are open.
    async fn close(&mut self, channel_idx: u16);

    /// Handles a request to change the target VP of channel `channel_idx`.
    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32);

    /// Starts the device. Any open/modify requests queued while it was
    /// stopped are delivered after this returns.
    fn start(&mut self);

    /// Stops the device. Called only when transitioning from running to
    /// stopped; while stopped, open/modify requests are queued.
    async fn stop(&mut self);

    /// Returns the save/restore interface, or `None` if unsupported (saving
    /// then yields no state and restoring fails).
    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice>;
}
89
/// Save/restore support for a [`VmbusDevice`], obtained through
/// [`VmbusDevice::supports_save_restore`].
#[async_trait]
pub trait SaveRestoreVmbusDevice: VmbusDevice {
    /// Saves the device's state.
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError>;

    /// Restores the device from `state`.
    ///
    /// `control` exposes the channel task's state so the implementation can
    /// restore the primary channel and its subchannels on the bus via
    /// [`RestoreControl::restore`].
    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError>;
}
106
/// Resources handed to a device through [`VmbusDevice::install`] once its
/// primary channel has been offered.
#[derive(Debug, Default)]
pub struct DeviceResources {
    /// Resources returned by the bus when the primary channel was offered.
    pub offer_resources: OfferResources,
    /// View of the GPADLs registered for this device's channels.
    pub gpadl_map: GpadlMapView,
    /// Handle used to request that subchannels be enabled.
    pub channel_control: ChannelControl,
    /// Per-channel resources indexed by channel index (0 is the primary
    /// channel), one entry per channel the device may have.
    pub channels: Vec<ChannelResources>,
}
119
/// Resources for a single channel (primary or subchannel).
#[derive(Debug)]
pub struct ChannelResources {
    /// Notification object for this channel. The guest-to-host interrupt
    /// reported when the channel opens is derived from this same object.
    pub event: Notify,
}
126
/// Lets a device ask its channel task to enable subchannels.
///
/// A `Default` instance has no sender and a maximum of 0: it rejects any
/// non-zero request and ignores a request for 0.
#[derive(Debug, Default, Clone)]
pub struct ChannelControl {
    /// Sender to the channel task; `None` for a default-constructed control.
    send: Option<mesh::Sender<u16>>,
    /// Upper bound on the requested subchannel count.
    max: u16,
}
133
/// Returned by [`ChannelControl::enable_subchannels`] when more subchannels
/// are requested than the device supports.
#[derive(Debug, Error)]
#[error("too many subchannels requested")]
pub struct TooManySubchannels;
138
139impl ChannelControl {
140 pub fn enable_subchannels(&self, count: u16) -> Result<(), TooManySubchannels> {
147 if count > self.max {
148 return Err(TooManySubchannels);
149 }
150 if let Some(send) = &self.send {
151 send.send(count);
152 }
153 Ok(())
154 }
155
156 pub fn max_subchannels(&self) -> u16 {
158 self.max
159 }
160}
161
/// Type-erased handle to the task that services an offered [`VmbusDevice`].
#[must_use]
pub(crate) struct GenericChannelHandle {
    // Lifecycle requests to the channel task; dropping this sender ends the
    // task's run loop.
    state_req: mesh::Sender<StateRequest>,
    // The channel task; it yields the device back when it finishes.
    task: Task<Box<dyn VmbusDevice>>,
}
170
/// Lifecycle requests sent from a [`GenericChannelHandle`] to its channel task.
#[derive(Debug)]
enum StateRequest {
    /// Start the device, then replay any open/modify requests queued while
    /// it was stopped.
    Start,
    /// Stop the device if it is running; while stopped, open/modify requests
    /// are queued instead of handled.
    Stop(Rpc<(), ()>),
    /// Discard any requests queued while stopped.
    Reset(Rpc<(), ()>),
    /// Save device state; `None` if the device does not support save/restore.
    Save(FailableRpc<(), Option<SavedStateBlob>>),
    /// Restore device state from the given blob.
    Restore(FailableRpc<SavedStateBlob, ()>),
    /// Inspect the device.
    Inspect(inspect::Deferred),
}
196
197impl std::fmt::Debug for GenericChannelHandle {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 f.pad("ChannelHandle")
200 }
201}
202
203impl GenericChannelHandle {
204 pub async fn revoke(self) -> Option<Box<dyn VmbusDevice>> {
206 drop(self.state_req);
207 Some(self.task.await)
208 }
209
210 pub fn start(&self) {
211 self.state_req.send(StateRequest::Start);
212 }
213
214 pub async fn stop(&self) {
215 self.state_req
216 .call(StateRequest::Stop, ())
217 .await
218 .expect("critical channel failure")
219 }
220
221 pub async fn reset(&self) {
222 self.state_req
223 .call(StateRequest::Reset, ())
224 .await
225 .expect("critical channel failure")
226 }
227
228 pub async fn save(&self) -> anyhow::Result<Option<SavedStateBlob>> {
229 self.state_req
230 .call(StateRequest::Save, ())
231 .await
232 .expect("critical channel failure")
233 .map_err(|err| err.into())
234 }
235
236 pub async fn restore(&self, buffer: SavedStateBlob) -> anyhow::Result<()> {
237 self.state_req
238 .call(StateRequest::Restore, buffer)
239 .await
240 .expect("critical channel failure")
241 .map_err(|err| err.into())
242 }
243}
244
245impl Inspect for GenericChannelHandle {
246 fn inspect(&self, req: inspect::Request<'_>) {
247 self.state_req.send(StateRequest::Inspect(req.defer()));
248 }
249}
250
/// Handle to an offered channel whose device type is `T`.
///
/// `T` exists only at the type level so `revoke` can return the concrete
/// device. The marker is `fn() -> Box<T>`, which stores no `T`, so `T` does
/// not affect whether the handle is `Send`/`Sync`.
#[must_use]
#[derive(Inspect)]
#[inspect(transparent)]
pub struct ChannelHandle<T: ?Sized>(GenericChannelHandle, PhantomData<fn() -> Box<T>>);
258
259impl<T: ?Sized> std::fmt::Debug for ChannelHandle<T> {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 f.pad("ChannelHandle")
262 }
263}
264
265impl<T: 'static + VmbusDevice> ChannelHandle<T> {
266 pub async fn revoke(self) -> Option<T> {
268 let device = self.0.revoke().await? as Box<dyn Any>;
269 Some(
270 *device
271 .downcast()
272 .expect("type must match the one used to create it"),
273 )
274 }
275}
276
277impl ChannelHandle<dyn VmbusDevice> {
278 pub async fn revoke(self) -> Option<Box<dyn VmbusDevice>> {
280 self.0.revoke().await
281 }
282}
283
284impl<T: 'static + VmbusDevice + ?Sized> ChannelHandle<T> {
285 pub fn start(&self) {
287 self.0.start()
288 }
289
290 pub async fn stop(&self) {
292 self.0.stop().await
293 }
294
295 pub async fn reset(&self) {
297 self.0.reset().await
298 }
299
300 pub async fn save(&self) -> anyhow::Result<Option<SavedStateBlob>> {
302 self.0.save().await
303 }
304
305 pub async fn restore(&self, buffer: SavedStateBlob) -> anyhow::Result<()> {
307 self.0.restore(buffer).await
308 }
309}
310
311async fn offer_generic(
312 driver: &impl Spawn,
313 bus: &(impl ParentBus + ?Sized),
314 mut channel: Box<dyn VmbusDevice>,
315) -> anyhow::Result<GenericChannelHandle> {
316 let offer = channel.offer();
317 let max_subchannels = channel.max_subchannels();
318 let instance_id = offer.instance_id;
319 let (request_send, request_recv) = mesh::channel();
320 let (server_request_send, server_request_recv) = mesh::channel();
321 let (state_req_send, state_req_recv) = mesh::channel();
322
323 let use_event = bus.use_event();
324
325 let events: Vec<_> = (0..max_subchannels + 1)
326 .map(|_| {
327 if use_event {
328 Notify::from_event(Event::new())
329 } else {
330 Notify::from_slim_event(Arc::new(SlimEvent::new()))
331 }
332 })
333 .collect();
334
335 let request = OfferInput {
336 params: offer,
337 request_send,
338 server_request_recv,
339 };
340
341 let gpadl_map = GpadlMap::new();
342
343 let offer_result = bus.add_child(request).await?;
344
345 let resources = events
346 .iter()
347 .map(|event| ChannelResources {
348 event: event.clone(),
349 })
350 .collect();
351
352 let (subchannel_enable_send, subchannel_enable_recv) = mesh::channel();
353 channel.install(DeviceResources {
354 offer_resources: offer_result,
355 gpadl_map: gpadl_map.clone().view(),
356 channels: resources,
357 channel_control: ChannelControl {
358 send: Some(subchannel_enable_send),
359 max: max_subchannels,
360 },
361 });
362
363 let bus = bus.clone_bus();
364 let task = driver.spawn(format!("vmbus offer {}", instance_id), async move {
365 let device = Device::new(
366 request_recv,
367 server_request_send,
368 events,
369 gpadl_map,
370 subchannel_enable_recv,
371 );
372 device
373 .run_channel(bus.as_ref(), channel.as_mut(), state_req_recv)
374 .await;
375 channel
376 });
377
378 Ok(GenericChannelHandle {
379 state_req: state_req_send,
380 task,
381 })
382}
383
/// Access to the channel task's state, passed to
/// [`SaveRestoreVmbusDevice::restore`] so the device can restore its channels
/// on the bus.
pub struct RestoreControl<'a> {
    device: &'a mut Device,
    bus: &'a dyn ParentBus,
    // Offer parameters of the primary channel; reused for subchannel offers.
    offer: OfferParams,
}
391
392impl RestoreControl<'_> {
393 pub async fn restore(
402 &mut self,
403 states: &[bool],
404 ) -> Result<Vec<Option<OpenRequest>>, ChannelRestoreError> {
405 self.device.restore(self.bus, &self.offer, states).await
406 }
407}
408
/// Error restoring a device's channels via [`RestoreControl::restore`].
#[derive(Debug, Error)]
pub enum ChannelRestoreError {
    /// Re-offering the saved subchannels failed.
    #[error("failed to enable subchannels")]
    EnablingSubchannels(#[source] anyhow::Error),
    /// The bus failed to restore a channel's state.
    #[error("failed to restore vmbus channel")]
    RestoreError(#[source] anyhow::Error),
    /// A GPADL that had been accepted before the save is invalid.
    #[error("failed to restore gpadl")]
    GpadlError(#[source] vmbus_ring::gparange::Error),
}
422
423impl From<ChannelRestoreError> for RestoreError {
424 fn from(err: ChannelRestoreError) -> Self {
425 RestoreError::Other(err.into())
426 }
427}
428
/// Whether the device is running, as driven by start/stop requests.
enum DeviceState {
    /// Channel requests are handled as they arrive.
    Running,
    /// Open and modify requests are queued here, tagged with their channel
    /// index, until the next start (or discarded by a reset).
    Stopped(Vec<(usize, ChannelRequest)>),
}
436
/// State owned by a device's channel task.
///
/// Per-channel vectors are indexed by channel index (0 is the primary
/// channel), except `subchannel_gpadls`, which is indexed by
/// `channel_idx - 1` because primary-channel GPADLs are not tracked there.
struct Device {
    // Running/stopped state; holds queued requests while stopped.
    state: DeviceState,
    // One server-request sender per currently offered channel.
    server_requests: Vec<mesh::Sender<ChannelServerRequest>>,
    // Whether each currently offered channel is open.
    open: Vec<bool>,
    // GPADL ids created on each subchannel, removed when subchannels are torn down.
    subchannel_gpadls: Vec<BTreeSet<GpadlId>>,
    // Bus requests from every offered channel, tagged with their channel index.
    requests: SelectAll<TaggedStream<usize, mesh::Receiver<ChannelRequest>>>,
    // One notification object per channel the device may have.
    events: Vec<Notify>,
    // GPADLs registered for all of the device's channels.
    gpadl_map: Arc<GpadlMap>,
    // Subchannel-count requests sent through the device's `ChannelControl`.
    subchannel_enable_recv: mesh::Receiver<u16>,
}
447
448impl Device {
449 fn new(
450 request_recv: mesh::Receiver<ChannelRequest>,
451 server_request_send: mesh::Sender<ChannelServerRequest>,
452 events: Vec<Notify>,
453 gpadl_map: Arc<GpadlMap>,
454 subchannel_enable_recv: mesh::Receiver<u16>,
455 ) -> Self {
456 let open: Vec<bool> = vec![false];
457 let subchannel_gpadls: Vec<BTreeSet<GpadlId>> = vec![];
458 let mut requests: SelectAll<TaggedStream<usize, mesh::Receiver<ChannelRequest>>> =
459 SelectAll::new();
460 requests.push(TaggedStream::new(0, request_recv));
461 Self {
462 state: DeviceState::Running,
463 server_requests: vec![server_request_send],
464 open,
465 subchannel_gpadls,
466 requests,
467 events,
468 gpadl_map,
469 subchannel_enable_recv,
470 }
471 }
472
473 async fn run_channel(
475 mut self,
476 bus: &dyn ParentBus,
477 channel: &mut dyn VmbusDevice,
478 state_req_recv: mesh::Receiver<StateRequest>,
479 ) {
480 enum Event {
481 Request(usize, Option<ChannelRequest>),
482 EnableSubchannels(u16),
483 StateRequest(Result<StateRequest, RecvError>),
484 }
485
486 let mut state_req_recv = pin!(futures::stream::unfold(state_req_recv, async |mut recv| {
487 Some((recv.recv().await, recv))
488 }));
489
490 let map_request = |(idx, req)| Event::Request(idx, req);
491 loop {
492 let mut s = select(
493 (&mut self.requests).map(map_request),
494 select(
495 (&mut self.subchannel_enable_recv).map(Event::EnableSubchannels),
496 (&mut state_req_recv).map(Event::StateRequest),
497 ),
498 );
499 if let Some(event) = s.next().await {
500 match event {
501 Event::Request(idx, Some(request)) => {
502 self.handle_channel_request(idx, request, channel).await;
503 }
504 Event::Request(_idx, None) => continue,
505 Event::EnableSubchannels(count) => {
506 let offer = channel.offer();
507 let _ = self.enable_channels(bus, &offer, count as usize + 1).await;
508 }
509 Event::StateRequest(Ok(request)) => {
510 self.handle_state_request(request, channel, bus).await;
511 }
512 Event::StateRequest(Err(_)) => {
513 break;
515 }
516 }
517 }
518 }
519 drop(self.server_requests);
521 for recv in self.requests.iter_mut() {
529 if recv.value().is_some() {
530 while recv.next().await.is_some() {}
531 }
532 }
533
534 for subchannel_idx in (0..self.open.len()).rev() {
535 if self.open[subchannel_idx] {
536 channel.close(subchannel_idx as u16).await;
537 }
538 }
539 }
540
541 #[instrument(level = "debug", skip_all, fields(channel_idx, ?request))]
542 async fn handle_channel_request(
543 &mut self,
544 channel_idx: usize,
545 request: ChannelRequest,
546 channel: &mut dyn VmbusDevice,
547 ) {
548 if matches!(request, ChannelRequest::Open(_) | ChannelRequest::Modify(_)) {
554 if let DeviceState::Stopped(pending_messages) = &mut self.state {
555 pending_messages.push((channel_idx, request));
556 return;
557 }
558 }
559
560 match request {
561 ChannelRequest::Open(rpc) => {
562 rpc.handle(async |open_request| {
563 self.handle_open(channel, channel_idx, open_request).await
564 })
565 .await
566 }
567 ChannelRequest::Close(rpc) => {
568 rpc.handle(async |()| {
569 self.handle_close(channel_idx, channel).await;
570 })
571 .await
572 }
573 ChannelRequest::Gpadl(rpc) => rpc.handle_sync(|gpadl| {
574 self.handle_gpadl(gpadl.id, gpadl.count, gpadl.buf, channel_idx);
575 true
576 }),
577 ChannelRequest::TeardownGpadl(rpc) => {
578 self.handle_teardown_gpadl(rpc, channel_idx);
579 }
580 ChannelRequest::Modify(rpc) => {
581 rpc.handle(async |req| {
582 self.handle_modify(channel, channel_idx, req).await;
583 0
584 })
585 .await
586 }
587 }
588 }
589
590 async fn handle_open(
591 &mut self,
592 channel: &mut dyn VmbusDevice,
593 channel_idx: usize,
594 open_request: OpenRequest,
595 ) -> Option<OpenResult> {
596 assert!(!self.open[channel_idx]);
597 let opened = if let Err(error) = channel.open(channel_idx as u16, &open_request).await {
600 tracelimit::error_ratelimited!(
601 error = error.as_ref() as &dyn std::error::Error,
602 "failed to open channel"
603 );
604 None
605 } else {
606 Some(OpenResult {
607 guest_to_host_interrupt: self.events[channel_idx].clone().interrupt(),
608 })
609 };
610 self.open[channel_idx] = opened.is_some();
611 opened
612 }
613
614 async fn handle_close(&mut self, channel_idx: usize, channel: &mut dyn VmbusDevice) {
615 assert!(self.open[channel_idx]);
616 if channel_idx == 0 {
617 self.server_requests.truncate(1);
619 for recv in self.requests.iter_mut() {
620 if let Some(&idx) = recv.value() {
621 if idx > 0 {
622 while recv.next().await.is_some() {}
623 }
624 }
625 }
626 for subchannel_idx in 1..self.open.len() {
627 if self.open[subchannel_idx] {
628 channel.close(subchannel_idx as u16).await;
629 }
630 for &gpadl_id in &self.subchannel_gpadls[subchannel_idx - 1] {
631 self.gpadl_map.remove(gpadl_id, Box::new(|| ()));
632 }
633 }
634 self.open.truncate(1);
635 self.subchannel_gpadls.clear();
636 }
637 channel.close(channel_idx as u16).await;
638 self.open[channel_idx] = false;
639 if channel_idx == 0 {
640 while self.subchannel_enable_recv.try_recv().is_ok() {}
642 }
643 }
644
645 fn handle_gpadl(&mut self, id: GpadlId, count: u16, buf: Vec<u64>, channel_idx: usize) {
646 self.gpadl_map
647 .add(id, MultiPagedRangeBuf::new(count.into(), buf).unwrap());
648 if channel_idx > 0 {
649 self.subchannel_gpadls[channel_idx - 1].insert(id);
650 }
651 }
652
653 fn handle_teardown_gpadl(&mut self, rpc: Rpc<GpadlId, ()>, channel_idx: usize) {
654 let id = *rpc.input();
655 if let Some(f) = self.gpadl_map.remove(
656 id,
657 Box::new(move || {
658 rpc.complete(());
659 }),
660 ) {
661 f()
662 }
663 if channel_idx > 0 {
664 assert!(self.subchannel_gpadls[channel_idx - 1].remove(&id));
665 }
666 }
667
668 async fn handle_modify(
669 &mut self,
670 channel: &mut dyn VmbusDevice,
671 channel_idx: usize,
672 req: ModifyRequest,
673 ) {
674 match req {
675 ModifyRequest::TargetVp { target_vp } => {
676 channel.retarget_vp(channel_idx as u16, target_vp).await
677 }
678 }
679 }
680
681 #[instrument(level = "debug", skip_all, fields(?request))]
682 async fn handle_state_request(
683 &mut self,
684 request: StateRequest,
685 channel: &mut dyn VmbusDevice,
686 bus: &dyn ParentBus,
687 ) {
688 match request {
689 StateRequest::Start => {
690 channel.start();
691 if let DeviceState::Stopped(pending_messages) =
692 std::mem::replace(&mut self.state, DeviceState::Running)
693 {
694 for (channel_idx, request) in pending_messages.into_iter() {
695 self.handle_channel_request(channel_idx, request, channel)
696 .await;
697 }
698 }
699 }
700 StateRequest::Stop(rpc) => {
701 if matches!(self.state, DeviceState::Running) {
702 self.state = DeviceState::Stopped(Vec::new());
703 rpc.handle(async |()| {
704 channel.stop().await;
705 })
706 .await;
707 } else {
708 rpc.complete(());
709 }
710 }
711 StateRequest::Reset(rpc) => {
712 if let DeviceState::Stopped(pending_messages) = &mut self.state {
713 pending_messages.clear();
714 }
715 rpc.complete(());
716 }
717 StateRequest::Save(rpc) => {
718 rpc.handle_failable(async |()| {
719 if let Some(channel) = channel.supports_save_restore() {
720 channel.save().await.map(Some)
721 } else {
722 Ok(None)
723 }
724 })
725 .await;
726 }
727 StateRequest::Restore(rpc) => {
728 rpc.handle_failable(async |buffer| {
729 let channel = channel
730 .supports_save_restore()
731 .context("saved state not supported")?;
732 let control = RestoreControl {
733 device: &mut *self,
734 offer: channel.offer(),
735 bus,
736 };
737 channel
738 .restore(control, buffer)
739 .await
740 .map_err(anyhow::Error::from)?;
741 anyhow::Ok(())
742 })
743 .await;
744 }
745 StateRequest::Inspect(deferred) => {
746 deferred.inspect(&mut *channel);
747 }
748 }
749 }
750
751 async fn enable_channels(
752 &mut self,
753 bus: &dyn ParentBus,
754 offer: &OfferParams,
755 count: usize,
756 ) -> anyhow::Result<()> {
757 let mut r = Ok(());
759 for subchannel_idx in self.server_requests.len()..count {
760 let (request_send, request_recv) = mesh::channel();
761 let (server_request_send, server_request_recv) = mesh::channel();
762 let request = OfferInput {
763 params: OfferParams {
764 subchannel_index: subchannel_idx as u16,
765 ..offer.clone()
766 },
767 request_send,
768 server_request_recv,
769 };
770 match bus.add_child(request).await {
771 Ok(_) => {
772 self.requests
773 .push(TaggedStream::new(subchannel_idx, request_recv));
774 self.server_requests.push(server_request_send);
775 self.subchannel_gpadls.push(BTreeSet::new());
776 self.open.push(false);
777 }
778 Err(err) => {
779 tracing::error!(
780 error = err.as_ref() as &dyn std::error::Error,
781 "could not offer subchannel"
782 );
783 if r.is_ok() {
784 r = Err(err);
785 }
786 }
787 }
788 }
789 r
790 }
791
792 pub async fn restore(
793 &mut self,
794 bus: &dyn ParentBus,
795 offer: &OfferParams,
796 states: &[bool],
797 ) -> Result<Vec<Option<OpenRequest>>, ChannelRestoreError> {
798 self.enable_channels(bus, offer, states.len())
799 .await
800 .map_err(ChannelRestoreError::EnablingSubchannels)?;
801
802 let mut results = Vec::with_capacity(states.len());
803 for (channel_idx, (open, event)) in states.iter().copied().zip(&self.events).enumerate() {
804 let open_result = open.then(|| OpenResult {
805 guest_to_host_interrupt: event.clone().interrupt(),
806 });
807 let result = self.server_requests[channel_idx]
808 .call_failable(ChannelServerRequest::Restore, open_result)
809 .await
810 .map_err(|err| ChannelRestoreError::RestoreError(err.into()))?;
811
812 assert!(open == result.open_request.is_some());
813
814 for gpadl in result.gpadls {
815 let buf =
816 match MultiPagedRangeBuf::new(gpadl.request.count.into(), gpadl.request.buf) {
817 Ok(buf) => buf,
818 Err(err) => {
819 if gpadl.accepted {
820 return Err(ChannelRestoreError::GpadlError(err));
821 } else {
822 continue;
825 }
826 }
827 };
828 self.gpadl_map.add(gpadl.request.id, buf);
829 if channel_idx > 0 {
830 self.subchannel_gpadls[channel_idx - 1].insert(gpadl.request.id);
831 }
832 }
833
834 results.push(result.open_request);
835 }
836 self.open.copy_from_slice(states);
837 Ok(results)
838 }
839}
840
841pub async fn offer_channel<T: 'static + VmbusDevice>(
844 driver: &impl Spawn,
845 bus: &(impl ParentBus + ?Sized),
846 channel: T,
847) -> anyhow::Result<ChannelHandle<T>> {
848 let handle = offer_generic(driver, bus, Box::new(channel)).await?;
849 Ok(ChannelHandle(handle, PhantomData))
850}
851
852pub async fn offer_generic_channel(
854 driver: &impl Spawn,
855 bus: &(impl ParentBus + ?Sized),
856 channel: Box<dyn VmbusDevice>,
857) -> anyhow::Result<ChannelHandle<dyn VmbusDevice>> {
858 let handle = offer_generic(driver, bus, channel).await?;
859 Ok(ChannelHandle(handle, PhantomData))
860}