1use crate::bus::ChannelRequest;
7use crate::bus::ChannelServerRequest;
8use crate::bus::ModifyRequest;
9use crate::bus::OfferInput;
10use crate::bus::OfferParams;
11use crate::bus::OfferResources;
12use crate::bus::OpenRequest;
13use crate::bus::ParentBus;
14use crate::gpadl::GpadlMap;
15use crate::gpadl::GpadlMapView;
16use anyhow::Context;
17use async_trait::async_trait;
18use futures::StreamExt;
19use futures::stream::SelectAll;
20use futures::stream::select;
21use inspect::Inspect;
22use inspect::InspectMut;
23use mesh::RecvError;
24use mesh::rpc::FailableRpc;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use pal_async::task::Spawn;
28use pal_async::task::Task;
29use pal_event::Event;
30use std::any::Any;
31use std::collections::BTreeSet;
32use std::marker::PhantomData;
33use std::pin::pin;
34use std::sync::Arc;
35use thiserror::Error;
36use tracing::instrument;
37use vmbus_core::TaggedStream;
38use vmbus_core::protocol::GpadlId;
39use vmbus_ring::gparange::MultiPagedRangeBuf;
40use vmcore::notify::Notify;
41use vmcore::save_restore::RestoreError;
42use vmcore::save_restore::SaveError;
43use vmcore::save_restore::SavedStateBlob;
44use vmcore::slim_event::SlimEvent;
45
/// Error returned by [`VmbusDevice::open`] when a channel cannot be opened.
pub type ChannelOpenError = anyhow::Error;
48
/// A device that exposes a vmbus offer consisting of a primary channel and,
/// optionally, subchannels.
///
/// The channel task spawned by [`offer_channel`] / [`offer_generic_channel`]
/// owns the device and calls these methods in response to channel requests
/// and to state requests sent through the returned [`ChannelHandle`].
#[async_trait]
pub trait VmbusDevice: Send + Any + InspectMut {
    /// Returns the parameters used to offer the primary channel.
    ///
    /// Subchannel offers reuse these parameters with a different
    /// `subchannel_index`.
    fn offer(&self) -> OfferParams;

    /// Returns the maximum number of subchannels this device may enable.
    ///
    /// Defaults to 0 (no subchannels). One event is allocated for the primary
    /// channel plus each of these subchannels when the device is offered.
    fn max_subchannels(&self) -> u16 {
        0
    }

    /// Installs the resources for the device.
    ///
    /// Called once, after the primary offer has been added to the bus and
    /// before the channel task is spawned.
    fn install(&mut self, resources: DeviceResources);

    /// Opens channel `channel_idx` (0 is the primary channel).
    ///
    /// On error the failure is logged and the open request is answered as
    /// unsuccessful.
    async fn open(
        &mut self,
        channel_idx: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError>;

    /// Closes channel `channel_idx`.
    ///
    /// Also called for every channel that is still open when the offer is
    /// revoked, and for open subchannels when the primary channel is closed.
    async fn close(&mut self, channel_idx: u16);

    /// Retargets channel `channel_idx` to processor `target_vp`.
    async fn retarget_vp(&mut self, channel_idx: u16, target_vp: u32);

    /// Starts the device.
    fn start(&mut self);

    /// Stops the device.
    ///
    /// Open and modify requests that arrive while stopped are deferred until
    /// the next start.
    async fn stop(&mut self);

    /// Returns the save/restore interface if the device supports saved state,
    /// or `None` otherwise.
    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice>;
}
88
/// Saved-state support for a [`VmbusDevice`].
#[async_trait]
pub trait SaveRestoreVmbusDevice: VmbusDevice {
    /// Saves the device state.
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError>;

    /// Restores the device from `state`.
    ///
    /// The device uses `control` to restore the vmbus state of its channels.
    /// [`RestoreControl::restore`] returns, per channel, the open request of
    /// each channel that was open.
    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError>;
}
105
/// Resources handed to a device through [`VmbusDevice::install`].
#[derive(Debug, Default)]
pub struct DeviceResources {
    /// Resources returned by the parent bus when the primary channel was
    /// offered.
    pub offer_resources: OfferResources,
    /// View of the GPADLs registered on this offer's channels.
    pub gpadl_map: GpadlMapView,
    /// Handle used by the device to request subchannels.
    pub channel_control: ChannelControl,
    /// Per-channel resources indexed by channel index (0 is the primary
    /// channel), one for every channel the device may ever have.
    pub channels: Vec<ChannelResources>,
}
118
/// Resources for a single channel (primary or subchannel).
#[derive(Debug)]
pub struct ChannelResources {
    /// Event for this channel; its `interrupt()` form is what the channel's
    /// offer passes to the bus.
    pub event: Notify,
}
125
/// Handle a device uses to enable subchannels.
#[derive(Debug, Default, Clone)]
pub struct ChannelControl {
    // Sender to the channel task. `None` for a default-constructed control,
    // in which case enable requests are dropped.
    send: Option<mesh::Sender<u16>>,
    // Maximum number of subchannels the device may enable.
    max: u16,
}
132
/// Error returned by [`ChannelControl::enable_subchannels`] when more
/// subchannels are requested than the device supports.
#[derive(Debug, Error)]
#[error("too many subchannels requested")]
pub struct TooManySubchannels;
137
138impl ChannelControl {
139 pub fn enable_subchannels(&self, count: u16) -> Result<(), TooManySubchannels> {
146 if count > self.max {
147 return Err(TooManySubchannels);
148 }
149 if let Some(send) = &self.send {
150 send.send(count);
151 }
152 Ok(())
153 }
154
155 pub fn max_subchannels(&self) -> u16 {
157 self.max
158 }
159}
160
/// Untyped handle to an offered channel and the task that runs it.
///
/// `revoke` drops the state-request sender, which makes the task's receive
/// fail and ends its main loop; the task then returns the device.
#[must_use]
#[derive(Inspect)]
pub(crate) struct GenericChannelHandle {
    // Sender for state requests to the channel task; inspection is forwarded
    // through it as `StateRequest::Inspect`.
    #[inspect(flatten, send = "StateRequest::Inspect")]
    state_req: mesh::Sender<StateRequest>,
    // The channel task; it yields the device when it finishes.
    #[inspect(skip)]
    task: Task<Box<dyn VmbusDevice>>,
}
172
/// Requests sent from a [`GenericChannelHandle`] to its channel task.
#[derive(Debug)]
enum StateRequest {
    /// Start the device, then replay any open/modify requests that were
    /// deferred while it was stopped.
    Start,
    /// Stop the device (no-op if already stopped). Open/modify requests are
    /// deferred until the next `Start`.
    Stop(Rpc<(), ()>),
    /// Discard deferred open/modify requests; only has an effect while the
    /// device is stopped.
    Reset(Rpc<(), ()>),
    /// Save the device state. Yields `None` if the device does not support
    /// save/restore.
    Save(FailableRpc<(), Option<SavedStateBlob>>),
    /// Restore the device from a saved blob. Fails if the device does not
    /// support save/restore.
    Restore(FailableRpc<SavedStateBlob, ()>),
    /// Inspect the device.
    Inspect(inspect::Deferred),
}
198
199impl std::fmt::Debug for GenericChannelHandle {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 f.pad("ChannelHandle")
202 }
203}
204
205impl GenericChannelHandle {
206 pub async fn revoke(self) -> Option<Box<dyn VmbusDevice>> {
208 drop(self.state_req);
209 Some(self.task.await)
210 }
211
212 pub fn start(&self) {
213 self.state_req.send(StateRequest::Start);
214 }
215
216 pub async fn stop(&self) {
217 self.state_req
218 .call(StateRequest::Stop, ())
219 .await
220 .expect("critical channel failure")
221 }
222
223 pub async fn reset(&self) {
224 self.state_req
225 .call(StateRequest::Reset, ())
226 .await
227 .expect("critical channel failure")
228 }
229
230 pub async fn save(&self) -> anyhow::Result<Option<SavedStateBlob>> {
231 self.state_req
232 .call(StateRequest::Save, ())
233 .await
234 .expect("critical channel failure")
235 .map_err(|err| err.into())
236 }
237
238 pub async fn restore(&self, buffer: SavedStateBlob) -> anyhow::Result<()> {
239 self.state_req
240 .call(StateRequest::Restore, buffer)
241 .await
242 .expect("critical channel failure")
243 .map_err(|err| err.into())
244 }
245}
246
/// Typed handle to a channel offered with [`offer_channel`] or
/// [`offer_generic_channel`].
///
/// `T` records the device type so that `revoke` can hand the device back;
/// the `PhantomData` marker carries `T` without storing one.
#[must_use]
#[derive(Inspect)]
#[inspect(transparent)]
pub struct ChannelHandle<T: ?Sized>(GenericChannelHandle, PhantomData<fn() -> Box<T>>);
254
255impl<T: ?Sized> std::fmt::Debug for ChannelHandle<T> {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 f.pad("ChannelHandle")
258 }
259}
260
261impl<T: 'static + VmbusDevice> ChannelHandle<T> {
262 pub async fn revoke(self) -> Option<T> {
264 let device = self.0.revoke().await? as Box<dyn Any>;
265 Some(
266 *device
267 .downcast()
268 .expect("type must match the one used to create it"),
269 )
270 }
271}
272
273impl ChannelHandle<dyn VmbusDevice> {
274 pub async fn revoke(self) -> Option<Box<dyn VmbusDevice>> {
276 self.0.revoke().await
277 }
278}
279
280impl<T: 'static + VmbusDevice + ?Sized> ChannelHandle<T> {
281 pub fn start(&self) {
283 self.0.start()
284 }
285
286 pub async fn stop(&self) {
288 self.0.stop().await
289 }
290
291 pub async fn reset(&self) {
293 self.0.reset().await
294 }
295
296 pub async fn save(&self) -> anyhow::Result<Option<SavedStateBlob>> {
298 self.0.save().await
299 }
300
301 pub async fn restore(&self, buffer: SavedStateBlob) -> anyhow::Result<()> {
303 self.0.restore(buffer).await
304 }
305}
306
307async fn offer_generic(
308 driver: &impl Spawn,
309 bus: &(impl ParentBus + ?Sized),
310 mut channel: Box<dyn VmbusDevice>,
311) -> anyhow::Result<GenericChannelHandle> {
312 let offer = channel.offer();
313 let max_subchannels = channel.max_subchannels();
314 let instance_id = offer.instance_id;
315 let (request_send, request_recv) = mesh::channel();
316 let (server_request_send, server_request_recv) = mesh::channel();
317 let (state_req_send, state_req_recv) = mesh::channel();
318
319 let use_event = bus.use_event();
320
321 let events: Vec<_> = (0..max_subchannels + 1)
322 .map(|_| {
323 if use_event {
324 Notify::from_event(Event::new())
325 } else {
326 Notify::from_slim_event(Arc::new(SlimEvent::new()))
327 }
328 })
329 .collect();
330
331 let request = OfferInput {
332 params: offer,
333 event: events[0].clone().interrupt(),
334 request_send,
335 server_request_recv,
336 };
337
338 let gpadl_map = GpadlMap::new();
339
340 let offer_result = bus.add_child(request).await?;
341
342 let resources = events
343 .iter()
344 .map(|event| ChannelResources {
345 event: event.clone(),
346 })
347 .collect();
348
349 let (subchannel_enable_send, subchannel_enable_recv) = mesh::channel();
350 channel.install(DeviceResources {
351 offer_resources: offer_result,
352 gpadl_map: gpadl_map.clone().view(),
353 channels: resources,
354 channel_control: ChannelControl {
355 send: Some(subchannel_enable_send),
356 max: max_subchannels,
357 },
358 });
359
360 let bus = bus.clone_bus();
361 let task = driver.spawn(format!("vmbus offer {}", instance_id), async move {
362 let device = Device::new(
363 request_recv,
364 server_request_send,
365 events,
366 gpadl_map,
367 subchannel_enable_recv,
368 );
369 device
370 .run_channel(bus.as_ref(), channel.as_mut(), state_req_recv)
371 .await;
372 channel
373 });
374
375 Ok(GenericChannelHandle {
376 state_req: state_req_send,
377 task,
378 })
379}
380
/// Passed to [`SaveRestoreVmbusDevice::restore`] so the device can restore
/// the vmbus state of its channels.
pub struct RestoreControl<'a> {
    // Channel task state being restored.
    device: &'a mut Device,
    // Bus used to re-offer subchannels.
    bus: &'a dyn ParentBus,
    // The device's offer parameters, used as the template for subchannels.
    offer: OfferParams,
}
388
389impl RestoreControl<'_> {
390 pub async fn restore(
399 &mut self,
400 states: &[bool],
401 ) -> Result<Vec<Option<OpenRequest>>, ChannelRestoreError> {
402 self.device.restore(self.bus, &self.offer, states).await
403 }
404}
405
/// Error restoring a device's vmbus channels via [`RestoreControl::restore`].
#[derive(Debug, Error)]
pub enum ChannelRestoreError {
    /// Offering the subchannels described by the saved state failed.
    #[error("failed to enable subchannels")]
    EnablingSubchannels(#[source] anyhow::Error),
    /// The server failed to restore a channel.
    #[error("failed to restore vmbus channel")]
    RestoreError(#[source] anyhow::Error),
    /// A GPADL the device had accepted could not be rebuilt from its saved
    /// buffer.
    #[error("failed to restore gpadl")]
    GpadlError(#[source] vmbus_ring::gparange::Error),
}
419
420impl From<ChannelRestoreError> for RestoreError {
421 fn from(err: ChannelRestoreError) -> Self {
422 RestoreError::Other(err.into())
423 }
424}
425
/// Whether the channel task forwards requests to the device immediately.
enum DeviceState {
    /// Requests are handled as they arrive.
    Running,
    /// Open and modify requests received while stopped, tagged with their
    /// channel index, queued for replay on the next start.
    Stopped(Vec<(usize, ChannelRequest)>),
}
433
/// State owned by a single offer's channel task.
struct Device {
    // Running/stopped state, including requests deferred while stopped.
    state: DeviceState,
    // Server request senders, indexed by channel index.
    server_requests: Vec<mesh::Sender<ChannelServerRequest>>,
    // Whether each channel is open, indexed by channel index.
    open: Vec<bool>,
    // GPADLs registered on each subchannel; entry `i` belongs to channel `i + 1`.
    subchannel_gpadls: Vec<BTreeSet<GpadlId>>,
    // Requests from every channel, merged and tagged with the channel index.
    requests: SelectAll<TaggedStream<usize, mesh::Receiver<ChannelRequest>>>,
    // One event per possible channel (primary plus max subchannels).
    events: Vec<Notify>,
    // GPADLs registered on this offer; the device holds a view of this map.
    gpadl_map: Arc<GpadlMap>,
    // Subchannel enable requests sent through the device's `ChannelControl`.
    subchannel_enable_recv: mesh::Receiver<u16>,
}
444
445impl Device {
446 fn new(
447 request_recv: mesh::Receiver<ChannelRequest>,
448 server_request_send: mesh::Sender<ChannelServerRequest>,
449 events: Vec<Notify>,
450 gpadl_map: Arc<GpadlMap>,
451 subchannel_enable_recv: mesh::Receiver<u16>,
452 ) -> Self {
453 let open: Vec<bool> = vec![false];
454 let subchannel_gpadls: Vec<BTreeSet<GpadlId>> = vec![];
455 let mut requests: SelectAll<TaggedStream<usize, mesh::Receiver<ChannelRequest>>> =
456 SelectAll::new();
457 requests.push(TaggedStream::new(0, request_recv));
458 Self {
459 state: DeviceState::Running,
460 server_requests: vec![server_request_send],
461 open,
462 subchannel_gpadls,
463 requests,
464 events,
465 gpadl_map,
466 subchannel_enable_recv,
467 }
468 }
469
470 async fn run_channel(
472 mut self,
473 bus: &dyn ParentBus,
474 channel: &mut dyn VmbusDevice,
475 state_req_recv: mesh::Receiver<StateRequest>,
476 ) {
477 enum Event {
478 Request(usize, Option<ChannelRequest>),
479 EnableSubchannels(u16),
480 StateRequest(Result<StateRequest, RecvError>),
481 }
482
483 let mut state_req_recv = pin!(futures::stream::unfold(state_req_recv, async |mut recv| {
484 Some((recv.recv().await, recv))
485 }));
486
487 let map_request = |(idx, req)| Event::Request(idx, req);
488 loop {
489 let mut s = select(
490 (&mut self.requests).map(map_request),
491 select(
492 (&mut self.subchannel_enable_recv).map(Event::EnableSubchannels),
493 (&mut state_req_recv).map(Event::StateRequest),
494 ),
495 );
496 if let Some(event) = s.next().await {
497 match event {
498 Event::Request(idx, Some(request)) => {
499 self.handle_channel_request(idx, request, channel).await;
500 }
501 Event::Request(_idx, None) => continue,
502 Event::EnableSubchannels(count) => {
503 let offer = channel.offer();
504 let _ = self.enable_channels(bus, &offer, count as usize + 1).await;
505 }
506 Event::StateRequest(Ok(request)) => {
507 self.handle_state_request(request, channel, bus).await;
508 }
509 Event::StateRequest(Err(_)) => {
510 break;
512 }
513 }
514 }
515 }
516 drop(self.server_requests);
518 for recv in self.requests.iter_mut() {
526 if recv.value().is_some() {
527 while recv.next().await.is_some() {}
528 }
529 }
530
531 for subchannel_idx in (0..self.open.len()).rev() {
532 if self.open[subchannel_idx] {
533 channel.close(subchannel_idx as u16).await;
534 }
535 }
536 }
537
538 #[instrument(level = "debug", skip_all, fields(channel_idx, ?request))]
539 async fn handle_channel_request(
540 &mut self,
541 channel_idx: usize,
542 request: ChannelRequest,
543 channel: &mut dyn VmbusDevice,
544 ) {
545 if matches!(request, ChannelRequest::Open(_) | ChannelRequest::Modify(_)) {
551 if let DeviceState::Stopped(pending_messages) = &mut self.state {
552 pending_messages.push((channel_idx, request));
553 return;
554 }
555 }
556
557 match request {
558 ChannelRequest::Open(rpc) => {
559 rpc.handle(async |open_request| {
560 self.handle_open(channel, channel_idx, open_request).await
561 })
562 .await
563 }
564 ChannelRequest::Close(rpc) => {
565 rpc.handle(async |()| {
566 self.handle_close(channel_idx, channel).await;
567 })
568 .await
569 }
570 ChannelRequest::Gpadl(rpc) => rpc.handle_sync(|gpadl| {
571 self.handle_gpadl(gpadl.id, gpadl.count, gpadl.buf, channel_idx);
572 true
573 }),
574 ChannelRequest::TeardownGpadl(rpc) => {
575 self.handle_teardown_gpadl(rpc, channel_idx);
576 }
577 ChannelRequest::Modify(rpc) => {
578 rpc.handle(async |req| {
579 self.handle_modify(channel, channel_idx, req).await;
580 0
581 })
582 .await
583 }
584 }
585 }
586
587 async fn handle_open(
588 &mut self,
589 channel: &mut dyn VmbusDevice,
590 channel_idx: usize,
591 open_request: OpenRequest,
592 ) -> bool {
593 assert!(!self.open[channel_idx]);
594 let opened = channel
597 .open(channel_idx as u16, &open_request)
598 .await
599 .inspect_err(|error| {
600 tracelimit::error_ratelimited!(
601 error = error.as_ref() as &dyn std::error::Error,
602 "failed to open channel"
603 );
604 })
605 .is_ok();
606 self.open[channel_idx] = opened;
607 opened
608 }
609
610 async fn handle_close(&mut self, channel_idx: usize, channel: &mut dyn VmbusDevice) {
611 assert!(self.open[channel_idx]);
612 if channel_idx == 0 {
613 self.server_requests.truncate(1);
615 for recv in self.requests.iter_mut() {
616 if let Some(&idx) = recv.value() {
617 if idx > 0 {
618 while recv.next().await.is_some() {}
619 }
620 }
621 }
622 for subchannel_idx in 1..self.open.len() {
623 if self.open[subchannel_idx] {
624 channel.close(subchannel_idx as u16).await;
625 }
626 for &gpadl_id in &self.subchannel_gpadls[subchannel_idx - 1] {
627 self.gpadl_map.remove(gpadl_id, Box::new(|| ()));
628 }
629 }
630 self.open.truncate(1);
631 self.subchannel_gpadls.clear();
632 }
633 channel.close(channel_idx as u16).await;
634 self.open[channel_idx] = false;
635 if channel_idx == 0 {
636 while self.subchannel_enable_recv.try_recv().is_ok() {}
638 }
639 }
640
641 fn handle_gpadl(&mut self, id: GpadlId, count: u16, buf: Vec<u64>, channel_idx: usize) {
642 self.gpadl_map
643 .add(id, MultiPagedRangeBuf::new(count.into(), buf).unwrap());
644 if channel_idx > 0 {
645 self.subchannel_gpadls[channel_idx - 1].insert(id);
646 }
647 }
648
649 fn handle_teardown_gpadl(&mut self, rpc: Rpc<GpadlId, ()>, channel_idx: usize) {
650 let id = *rpc.input();
651 if let Some(f) = self.gpadl_map.remove(
652 id,
653 Box::new(move || {
654 rpc.complete(());
655 }),
656 ) {
657 f()
658 }
659 if channel_idx > 0 {
660 assert!(self.subchannel_gpadls[channel_idx - 1].remove(&id));
661 }
662 }
663
664 async fn handle_modify(
665 &mut self,
666 channel: &mut dyn VmbusDevice,
667 channel_idx: usize,
668 req: ModifyRequest,
669 ) {
670 match req {
671 ModifyRequest::TargetVp { target_vp } => {
672 channel.retarget_vp(channel_idx as u16, target_vp).await
673 }
674 }
675 }
676
677 #[instrument(level = "debug", skip_all, fields(?request))]
678 async fn handle_state_request(
679 &mut self,
680 request: StateRequest,
681 channel: &mut dyn VmbusDevice,
682 bus: &dyn ParentBus,
683 ) {
684 match request {
685 StateRequest::Start => {
686 channel.start();
687 if let DeviceState::Stopped(pending_messages) =
688 std::mem::replace(&mut self.state, DeviceState::Running)
689 {
690 for (channel_idx, request) in pending_messages.into_iter() {
691 self.handle_channel_request(channel_idx, request, channel)
692 .await;
693 }
694 }
695 }
696 StateRequest::Stop(rpc) => {
697 if matches!(self.state, DeviceState::Running) {
698 self.state = DeviceState::Stopped(Vec::new());
699 rpc.handle(async |()| {
700 channel.stop().await;
701 })
702 .await;
703 } else {
704 rpc.complete(());
705 }
706 }
707 StateRequest::Reset(rpc) => {
708 if let DeviceState::Stopped(pending_messages) = &mut self.state {
709 pending_messages.clear();
710 }
711 rpc.complete(());
712 }
713 StateRequest::Save(rpc) => {
714 rpc.handle_failable(async |()| {
715 if let Some(channel) = channel.supports_save_restore() {
716 channel.save().await.map(Some)
717 } else {
718 Ok(None)
719 }
720 })
721 .await;
722 }
723 StateRequest::Restore(rpc) => {
724 rpc.handle_failable(async |buffer| {
725 let channel = channel
726 .supports_save_restore()
727 .context("saved state not supported")?;
728 let control = RestoreControl {
729 device: &mut *self,
730 offer: channel.offer(),
731 bus,
732 };
733 channel
734 .restore(control, buffer)
735 .await
736 .map_err(anyhow::Error::from)?;
737 anyhow::Ok(())
738 })
739 .await;
740 }
741 StateRequest::Inspect(deferred) => {
742 deferred.inspect(&mut *channel);
743 }
744 }
745 }
746
747 async fn enable_channels(
748 &mut self,
749 bus: &dyn ParentBus,
750 offer: &OfferParams,
751 count: usize,
752 ) -> anyhow::Result<()> {
753 let mut r = Ok(());
755 for subchannel_idx in self.server_requests.len()..count {
756 let (request_send, request_recv) = mesh::channel();
757 let (server_request_send, server_request_recv) = mesh::channel();
758 let request = OfferInput {
759 params: OfferParams {
760 subchannel_index: subchannel_idx as u16,
761 ..offer.clone()
762 },
763 event: self.events[subchannel_idx].clone().interrupt(),
764 request_send,
765 server_request_recv,
766 };
767 match bus.add_child(request).await {
768 Ok(_) => {
769 self.requests
770 .push(TaggedStream::new(subchannel_idx, request_recv));
771 self.server_requests.push(server_request_send);
772 self.subchannel_gpadls.push(BTreeSet::new());
773 self.open.push(false);
774 }
775 Err(err) => {
776 tracing::error!(
777 error = err.as_ref() as &dyn std::error::Error,
778 "could not offer subchannel"
779 );
780 if r.is_ok() {
781 r = Err(err);
782 }
783 }
784 }
785 }
786 r
787 }
788
789 pub async fn restore(
790 &mut self,
791 bus: &dyn ParentBus,
792 offer: &OfferParams,
793 states: &[bool],
794 ) -> Result<Vec<Option<OpenRequest>>, ChannelRestoreError> {
795 self.enable_channels(bus, offer, states.len())
796 .await
797 .map_err(ChannelRestoreError::EnablingSubchannels)?;
798
799 let mut results = Vec::with_capacity(states.len());
800 for (channel_idx, open) in states.iter().copied().enumerate() {
801 let result = self.server_requests[channel_idx]
802 .call_failable(ChannelServerRequest::Restore, open)
803 .await
804 .map_err(|err| ChannelRestoreError::RestoreError(err.into()))?;
805
806 assert!(open == result.open_request.is_some());
807
808 for gpadl in result.gpadls {
809 let buf =
810 match MultiPagedRangeBuf::new(gpadl.request.count.into(), gpadl.request.buf) {
811 Ok(buf) => buf,
812 Err(err) => {
813 if gpadl.accepted {
814 return Err(ChannelRestoreError::GpadlError(err));
815 } else {
816 continue;
819 }
820 }
821 };
822 self.gpadl_map.add(gpadl.request.id, buf);
823 if channel_idx > 0 {
824 self.subchannel_gpadls[channel_idx - 1].insert(gpadl.request.id);
825 }
826 }
827
828 results.push(result.open_request);
829 }
830 self.open.copy_from_slice(states);
831 Ok(results)
832 }
833}
834
835pub async fn offer_channel<T: 'static + VmbusDevice>(
838 driver: &impl Spawn,
839 bus: &(impl ParentBus + ?Sized),
840 channel: T,
841) -> anyhow::Result<ChannelHandle<T>> {
842 let handle = offer_generic(driver, bus, Box::new(channel)).await?;
843 Ok(ChannelHandle(handle, PhantomData))
844}
845
846pub async fn offer_generic_channel(
848 driver: &impl Spawn,
849 bus: &(impl ParentBus + ?Sized),
850 channel: Box<dyn VmbusDevice>,
851) -> anyhow::Result<ChannelHandle<dyn VmbusDevice>> {
852 let handle = offer_generic(driver, bus, channel).await?;
853 Ok(ChannelHandle(handle, PhantomData))
854}