vmbus_relay/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Implements a vmbus channel relay, which consumes channels from the host
5//! vmbus control plane (via [`vmbus_client`]) and relays them as channels to
6//! the guest OS (via [`vmbus_server`]).
7//!
8//! This is used to allow the paravisor to implement the vmbus control plane
9//! while still passing through channels from the host, without any paravisor
10//! presence in the data plane.
11
12#![expect(missing_docs)]
13#![forbid(unsafe_code)]
14
15pub mod legacy_saved_state;
16mod saved_state;
17
18pub use saved_state::SavedState;
19
20use anyhow::Context;
21use anyhow::Result;
22use client::ModifyConnectionRequest;
23use futures::FutureExt;
24use futures::StreamExt;
25use futures::future::BoxFuture;
26use futures::future::OptionFuture;
27use futures::future::join_all;
28use guid::Guid;
29use inspect::Inspect;
30use inspect::InspectMut;
31use mesh::rpc::FailableRpc;
32use mesh::rpc::Rpc;
33use mesh::rpc::RpcSend;
34use pal_async::driver::SpawnDriver;
35use pal_async::task::Spawn;
36use pal_async::task::Task;
37use pal_event::Event;
38use std::collections::HashMap;
39use std::fmt::Debug;
40use std::future::Future;
41use std::pin::Pin;
42use std::sync::Arc;
43use std::sync::atomic::AtomicBool;
44use std::sync::atomic::Ordering;
45use unicycle::FuturesUnordered;
46use vmbus_channel::bus::ChannelRequest;
47use vmbus_channel::bus::ChannelServerRequest;
48use vmbus_channel::bus::GpadlRequest;
49use vmbus_channel::bus::ModifyRequest;
50use vmbus_channel::bus::OpenRequest;
51use vmbus_channel::bus::OpenResult;
52use vmbus_client as client;
53use vmbus_core::HvsockConnectRequest;
54use vmbus_core::HvsockConnectResult;
55use vmbus_core::VersionInfo;
56use vmbus_core::protocol;
57use vmbus_core::protocol::ChannelId;
58use vmbus_core::protocol::FeatureFlags;
59use vmbus_core::protocol::GpadlId;
60use vmbus_server::HvsockRelayChannelHalf;
61use vmbus_server::MnfUsage;
62use vmbus_server::ModifyConnectionResponse;
63use vmbus_server::OfferInfo;
64use vmbus_server::OfferParamsInternal;
65use vmbus_server::Update;
66use vmbus_server::VmbusRelayChannelHalf;
67use vmbus_server::VmbusServerControl;
68use vmcore::interrupt::Interrupt;
69use vmcore::notify::Notify;
70use vmcore::notify::PolledNotify;
71
/// Requests sent by the relay to a device that intercepts a host channel
/// (registered through the `intercept_list` passed to
/// [`HostVmbusTransport::new`]) instead of having it relayed to the guest.
pub enum InterceptChannelRequest {
    /// Sent when the relay is started.
    Start,
    /// Sent when the relay is stopped; the relay waits for the RPC to complete.
    Stop(Rpc<(), ()>),
    /// Requests the device's saved state.
    Save(Rpc<(), vmcore::save_restore::SavedStateBlob>),
    /// Provides previously saved state to the device.
    Restore(vmcore::save_restore::SavedStateBlob),
    /// Sent when the host offers a channel whose instance ID matches the GUID
    /// this device was registered with.
    Offer(client::OfferInfo),
}
79
/// Feature flags the host must report for the relay to be created; checked in
/// [`HostVmbusTransport::new`].
///
/// Guest-specified signal parameters allow the guest-facing server to choose
/// its own channel and connection IDs instead of passing the host's through.
const REQUIRED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_channel_interrupt_redirection(true)
    .with_guest_specified_signal_parameters(true)
    .with_modify_connection(true);
84
/// Represents a relay between a vmbus server on the host and the vmbus server running in
/// Underhill, allowing offers from the host and offers from Underhill to be mixed.
///
/// The relay is created from an already-established host connection (see
/// [`HostVmbusTransport::new`]) and does not service channels until
/// [`HostVmbusTransport::start`] is called.
pub struct HostVmbusTransport {
    /// Background task running the relay's dispatch loop, held so its lifetime
    /// is tied to this object.
    /// NOTE(review): assumes dropping a `pal_async` `Task` cancels it — confirm.
    _relay_task: Task<()>,
    /// Sends control requests (start/stop/save/restore/inspect) to the relay task.
    task_send: mesh::Sender<TaskRequest>,
}
94
95impl HostVmbusTransport {
96    /// Create a new instance of the host vmbus relay.
97    pub async fn new(
98        driver: impl SpawnDriver + Clone,
99        control: Arc<VmbusServerControl>,
100        channel: VmbusRelayChannelHalf,
101        hvsock_relay: HvsockRelayChannelHalf,
102        vmbus_client: client::VmbusClientAccess,
103        connection: client::ConnectResult,
104        intercept_list: Vec<(Guid, mesh::Sender<InterceptChannelRequest>)>,
105    ) -> Result<Self> {
106        if connection.version.feature_flags & REQUIRED_FEATURE_FLAGS != REQUIRED_FEATURE_FLAGS {
107            anyhow::bail!(
108                "host must support required feature flags. \
109                 Required: {REQUIRED_FEATURE_FLAGS:?}, actual: {:?}.",
110                connection.version.feature_flags
111            );
112        }
113
114        let mut relay_task = RelayTask::new(
115            Arc::new(driver.clone()),
116            control,
117            channel.response_send,
118            hvsock_relay,
119            vmbus_client,
120            connection.version,
121        );
122
123        relay_task.intercept_channels.extend(intercept_list);
124
125        for offer in connection.offers {
126            relay_task.handle_offer(offer).await?;
127        }
128
129        let (task_send, task_recv) = mesh::channel();
130
131        let relay_task = driver.spawn("vmbus hcl relay", async move {
132            relay_task
133                .run(channel.request_receive, connection.offer_recv, task_recv)
134                .await
135                .unwrap()
136        });
137
138        Ok(Self {
139            _relay_task: relay_task,
140            task_send,
141        })
142    }
143
144    pub fn start(&self) {
145        self.task_send.send(TaskRequest::Start);
146    }
147
148    pub async fn stop(&self) {
149        self.task_send.call(TaskRequest::Stop, ()).await.unwrap()
150    }
151
152    pub async fn save(&self) -> SavedState {
153        self.task_send.call(TaskRequest::Save, ()).await.unwrap()
154    }
155
156    pub async fn restore(&self, state: SavedState) -> Result<()> {
157        self.task_send
158            .call(TaskRequest::Restore, state)
159            .await
160            .unwrap()
161    }
162}
163
164impl Inspect for HostVmbusTransport {
165    fn inspect(&self, req: inspect::Request<'_>) {
166        self.task_send.send(TaskRequest::Inspect(req.defer()));
167    }
168}
169
170impl Debug for HostVmbusTransport {
171    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        writeln!(fmt, "HostVmbusTransport")
173    }
174}
175
/// State needed to relay host-to-guest interrupts for an open channel.
///
/// Only present while the channel is open and the guest uses the channel
/// bitmap (so the host cannot signal the guest directly).
struct InterruptRelay {
    /// Event signaled when the host sends an interrupt.
    notify: PolledNotify,
    /// Interrupt used to signal the guest.
    interrupt: Interrupt,
    /// Event flag used to signal the guest, taken from the host's
    /// `redirected_event_flag` in the open response.
    /// FUTURE: remove once this moves into `vmbus_client` saved state.
    event_flag: u16,
}
186
/// Requests sent from the relay task to an individual channel worker.
enum RelayChannelRequest {
    /// Begin servicing requests from the vmbus server.
    Start,
    /// Stop servicing requests from the vmbus server.
    Stop(Rpc<(), ()>),
    /// Returns the channel's saved state.
    Save(Rpc<(), saved_state::Channel>),
    /// Restores the channel from saved state.
    Restore(FailableRpc<saved_state::Channel, ()>),
    /// Inspects the channel worker.
    Inspect(inspect::Deferred),
}
194
195impl Debug for RelayChannelRequest {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        match self {
198            RelayChannelRequest::Start => f.pad("Start"),
199            RelayChannelRequest::Stop(..) => f.pad("Stop"),
200            RelayChannelRequest::Save(..) => f.pad("Save"),
201            RelayChannelRequest::Restore(..) => f.pad("Restore"),
202            RelayChannelRequest::Inspect(..) => f.pad("Inspect"),
203        }
204    }
205}
206
/// The relay task's handle to a relayed channel's worker.
struct RelayChannelInfo {
    /// Sends control requests to the channel's worker task.
    relay_request_send: mesh::Sender<RelayChannelRequest>,
}
210
211impl Inspect for RelayChannelInfo {
212    fn inspect(&self, req: inspect::Request<'_>) {
213        self.relay_request_send
214            .send(RelayChannelRequest::Inspect(req.defer()));
215    }
216}
217
/// How a channel offered by the host is being handled.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ChannelInfo {
    /// Relayed to the guest by a channel worker task.
    #[inspect(transparent)]
    Relay(RelayChannelInfo),
    /// Handed to the intercept device registered for this instance ID.
    #[inspect(transparent)]
    Intercept(Guid),
}
226
227impl RelayChannelInfo {
228    async fn stop(&self) {
229        if let Err(err) = self
230            .relay_request_send
231            .call(RelayChannelRequest::Stop, ())
232            .await
233        {
234            tracing::warn!(?err, "failed to request channel stop");
235        }
236    }
237
238    fn start(&self) {
239        self.relay_request_send.send(RelayChannelRequest::Start);
240    }
241}
242
/// Connects a Client channel (from the host) to a Server Channel (offered to the guest).
#[derive(Inspect)]
struct RelayChannel {
    /// The Channel Id given to us by the client
    channel_id: ChannelId,
    /// Receives requests from the relay.
    #[inspect(skip)]
    relay_request_recv: mesh::Receiver<RelayChannelRequest>,
    /// Receives requests from the server.
    #[inspect(skip)]
    server_request_recv: mesh::Receiver<ChannelRequest>,
    /// Sends requests to the server; used to revoke the channel when the
    /// worker exits.
    #[inspect(skip)]
    server_request_send: mesh::Sender<ChannelServerRequest>,
    /// Closed when the channel has been revoked.
    #[inspect(skip)]
    revoke_recv: mesh::OneshotReceiver<()>,
    /// Sends requests to the client
    #[inspect(skip)]
    request_send: mesh::Sender<client::ChannelRequest>,
    /// Indicates whether or not interrupts should be relayed. This is shared with the relay server
    /// connection, which sets this to true only if the guest uses the channel bitmap.
    use_interrupt_relay: Arc<AtomicBool>,
    /// State used to relay host-to-guest interrupts.
    #[inspect(with = "Option::is_some")]
    interrupt_relay: Option<InterruptRelay>,
    /// Futures waiting for GPADL teardown to complete before responding to
    /// `vmbus_server`.
    #[inspect(skip)]
    gpadls_tearing_down: FuturesUnordered<BoxFuture<'static, ()>>,
    /// Whether the channel is open: set after a successful open, cleared on close.
    is_open: bool,
}
274
/// Worker task that relays a single host channel to the guest.
#[derive(InspectMut)]
struct RelayChannelTask {
    /// Driver used to create pollable notifications for interrupt relaying.
    #[inspect(skip)]
    driver: Arc<dyn SpawnDriver>,
    /// Per-channel state.
    channel: RelayChannel,
    /// Whether requests from the vmbus server are currently being serviced.
    running: bool,
}
282
impl RelayChannelTask {
    /// Relays an open channel request from VTL0 to the host, responding with the open result.
    ///
    /// If the guest uses the channel bitmap (`use_interrupt_relay` is set), a
    /// local event is given to the host as the incoming event, and host
    /// interrupts are relayed to the guest through `interrupt_relay`.
    ///
    /// # Errors
    ///
    /// Fails if the pollable notify cannot be created or the host open fails.
    ///
    /// # Panics
    ///
    /// Panics if interrupts are being relayed but the host's open response has
    /// no `redirected_event_flag`.
    async fn handle_open_channel(&mut self, open_request: &OpenRequest) -> Result<OpenResult> {
        // If the guest uses the channel bitmap, the host can't send interrupts
        // directly and they must be relayed.
        let redirect_interrupt = self.channel.use_interrupt_relay.load(Ordering::SeqCst);
        let (incoming_event, notify) = if redirect_interrupt {
            let event = Event::new();
            let notify = Notify::from_event(event.clone())
                .pollable(self.driver.as_ref())
                .context("failed to create polled notify")?;
            Some((event, notify))
        } else {
            None
        }
        .unzip();

        let opened = self
            .channel
            .request_send
            .call_failable(
                client::ChannelRequest::Open,
                client::OpenRequest {
                    open_data: open_request.open_data,
                    incoming_event,
                    use_vtl2_connection_id: false,
                },
            )
            .await?;

        if let Some(notify) = notify {
            self.channel.interrupt_relay = Some(InterruptRelay {
                notify,
                interrupt: open_request.interrupt.clone(),
                event_flag: opened.redirected_event_flag.unwrap(),
            });
        }

        self.channel.is_open = true;

        Ok(OpenResult {
            guest_to_host_interrupt: opened.guest_to_host_signal,
        })
    }

    /// Asks the host to close the channel, ignoring any failure, then clears
    /// the interrupt relay and the open flag.
    async fn handle_close_channel(&mut self) {
        self.channel
            .request_send
            .call(client::ChannelRequest::Close, ())
            .await
            .ok();

        self.channel.interrupt_relay = None;
        self.channel.is_open = false;
    }

    /// Relays a gpadl request from VTL0 to the host and responds once the gpadl is created.
    ///
    /// # Errors
    ///
    /// Fails if the host request fails.
    async fn handle_gpadl(&mut self, request: GpadlRequest) -> Result<()> {
        self.channel
            .request_send
            .call_failable(client::ChannelRequest::Gpadl, request)
            .await?;

        Ok(())
    }

    /// Relays a GPADL teardown request to the host without blocking.
    ///
    /// The server's RPC is completed from `gpadls_tearing_down` once the host
    /// request finishes, whether or not it succeeds.
    fn handle_gpadl_teardown(&mut self, rpc: Rpc<GpadlId, ()>) {
        let (gpadl_id, rpc) = rpc.split();
        tracing::trace!(gpadl_id = gpadl_id.0, "Tearing down GPADL");

        let call = self
            .channel
            .request_send
            .call(client::ChannelRequest::TeardownGpadl, gpadl_id);

        // We cannot wait for GpadlTorndown here, because the host may not send the GpadlTorndown
        // message immediately, for example if the channel is still open and the host device still
        // has the gpadl mapped. We should not block further requests while waiting for the
        // response.
        self.channel.gpadls_tearing_down.push(Box::pin(async move {
            if let Err(err) = call.await {
                tracing::warn!(
                    error = &err as &dyn std::error::Error,
                    "failed to send gpadl teardown"
                );
            }
            rpc.complete(());
        }));
    }

    /// Relays a modify-channel request to the host and returns its status.
    ///
    /// # Errors
    ///
    /// Fails if the request to the host fails.
    async fn handle_modify_channel(&mut self, modify_request: ModifyRequest) -> Result<i32> {
        let status = self
            .channel
            .request_send
            .call(client::ChannelRequest::Modify, modify_request)
            .await?;

        Ok(status)
    }

    /// Dispatches requests sent by VTL0 (via the vmbus server).
    ///
    /// Failures are logged and reported through each request's RPC: `None`
    /// for open, `false` for gpadl creation, and `-1` for modify. As written,
    /// this always returns `Ok(())`.
    async fn handle_server_request(&mut self, request: ChannelRequest) -> Result<()> {
        tracing::trace!(request = ?request, "received channel request");
        match request {
            ChannelRequest::Open(rpc) => {
                rpc.handle(async |open_request| {
                    self.handle_open_channel(&open_request)
                        .await
                        .inspect_err(|err| {
                            tracelimit::error_ratelimited!(
                                err = err.as_ref() as &dyn std::error::Error,
                                channel_id = self.channel.channel_id.0,
                                "failed to open channel"
                            );
                        })
                        .ok()
                })
                .await;
            }
            ChannelRequest::Gpadl(rpc) => {
                rpc.handle(async |gpadl| {
                    let id = gpadl.id;
                    self.handle_gpadl(gpadl)
                        .await
                        .inspect_err(|err| {
                            tracelimit::error_ratelimited!(
                                err = err.as_ref() as &dyn std::error::Error,
                                channel_id = self.channel.channel_id.0,
                                gpadl_id = id.0,
                                "failed to create gpadl"
                            );
                        })
                        .is_ok()
                })
                .await;
            }
            ChannelRequest::Close(rpc) => {
                rpc.handle(async |()| self.handle_close_channel().await)
                    .await;
            }
            ChannelRequest::TeardownGpadl(rpc) => {
                // Completed asynchronously; see `handle_gpadl_teardown`.
                self.handle_gpadl_teardown(rpc);
            }
            ChannelRequest::Modify(rpc) => {
                rpc.handle(async |request| self.handle_modify_channel(request).await.unwrap_or(-1))
                    .await;
            }
        }

        Ok(())
    }

    /// Handles a control request from the relay task (start, stop, save,
    /// restore, inspect).
    async fn handle_relay_request(&mut self, request: RelayChannelRequest) {
        tracing::trace!(
            channel_id = self.channel.channel_id.0,
            ?request,
            "received relay request"
        );

        match request {
            RelayChannelRequest::Start => self.running = true,
            RelayChannelRequest::Stop(rpc) => rpc.handle_sync(|()| self.running = false),
            RelayChannelRequest::Save(rpc) => rpc.handle_sync(|_| self.handle_save()),
            RelayChannelRequest::Restore(rpc) => {
                rpc.handle_failable(async |state| self.handle_restore(state).await)
                    .await
            }
            RelayChannelRequest::Inspect(deferred) => deferred.inspect(self),
        }
    }

    /// Request dispatch loop.
    ///
    /// Exits when the relay drops its request sender, the server request
    /// stream ends, or `revoke_recv` completes. It then waits for all pending
    /// GPADL teardowns and asks the server to revoke the channel.
    async fn run(mut self) {
        loop {
            // Only pending while an interrupt relay exists (channel open with
            // interrupt relaying enabled).
            let mut relay_event = OptionFuture::from(
                self.channel
                    .interrupt_relay
                    .as_mut()
                    .map(|e| e.notify.wait().fuse()),
            );

            // Server requests are only accepted while the channel is running.
            let mut server_request = OptionFuture::from(
                self.running
                    .then(|| self.channel.server_request_recv.next()),
            );

            futures::select! { // merge semantics
                r = self.channel.relay_request_recv.next() => {
                    match r {
                        Some(request) => {
                            // Needed to avoid conflicting &mut self borrow.
                            drop(relay_event);
                            self.handle_relay_request(request).await;
                        }
                        None => {
                            break;
                        }
                    }
                }
                r = server_request => {
                    match r.unwrap() {
                        Some(request) => {
                            // Needed to avoid conflicting &mut self borrow.
                            drop(relay_event);
                            // `handle_server_request` currently always returns Ok.
                            self
                                .handle_server_request(request)
                                .await
                                .expect("failed to get server request");
                        }
                        None => {
                            break;
                        }
                    }
                }
                _r = (&mut self.channel.revoke_recv).fuse() => {
                    break;
                }
                () = self.channel.gpadls_tearing_down.select_next_some() => {}
                _r = relay_event => {
                    // Needed to avoid conflicting interrupt_relay borrow.
                    drop(relay_event);
                    self.channel.interrupt_relay.as_ref().unwrap().interrupt.deliver();
                }
            }
        }

        // Drain GPADL teardown requests cleanly; these will all complete now
        // that the channel has been revoked.
        while let Some(()) = self.channel.gpadls_tearing_down.next().await {}

        tracing::debug!(channel_id = %self.channel.channel_id.0, "dropped channel");

        // Dropping the channel would revoke it, but since that's not synchronized there's a chance
        // we reoffer the channel before the server receives the revoke. Using the request ensures
        // that won't happen.
        if let Err(err) = self
            .channel
            .server_request_send
            .call(ChannelServerRequest::Revoke, ())
            .await
        {
            tracing::warn!(
                channel_id = self.channel.channel_id.0,
                err = &err as &dyn std::error::Error,
                "failed to send revoke request"
            );
        }
    }
}
532
/// Control requests sent from [`HostVmbusTransport`] to the relay task.
enum TaskRequest {
    /// Inspects the relay task.
    Inspect(inspect::Deferred),
    /// Saves the relay's state.
    Save(Rpc<(), SavedState>),
    /// Restores the relay's state.
    Restore(Rpc<SavedState, Result<()>>),
    /// Starts the relay and all of its channels.
    Start,
    /// Stops the relay and all of its channels, completing once they have stopped.
    Stop(Rpc<(), ()>),
}
540
/// Dispatches requests between Server/Client.
#[derive(InspectMut)]
struct RelayTask {
    /// Driver used to spawn channel worker tasks.
    #[inspect(skip)]
    spawner: Arc<dyn SpawnDriver>,
    /// Access to the host vmbus client (connection modify, hvsock connect).
    #[inspect(skip)]
    vmbus_client: client::VmbusClientAccess,
    /// Version information from the host connection.
    version: VersionInfo,
    /// Control interface of the guest-facing vmbus server, used to offer
    /// relayed channels.
    #[inspect(skip)]
    vmbus_control: Arc<VmbusServerControl>,
    /// Every channel offered by the host, keyed by the host's channel ID.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    channels: HashMap<ChannelId, ChannelInfo>,
    /// Worker tasks for relayed channels; each resolves to its channel ID when
    /// the worker exits.
    #[inspect(skip)]
    channel_workers: FuturesUnordered<Task<ChannelId>>,
    /// Devices that handle host offers with a given instance ID themselves
    /// instead of having them relayed.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_value(|_| ())")]
    intercept_channels: HashMap<Guid, mesh::Sender<InterceptChannelRequest>>,
    /// Shared with every channel worker; updated from the guest's modify
    /// connection requests (`use_interrupt_page`).
    use_interrupt_relay: Arc<AtomicBool>,
    /// Sends responses to modify connection requests back to the server.
    #[inspect(skip)]
    server_response_send: mesh::Sender<ModifyConnectionResponse>,
    /// Receives hvsock connect requests from, and sends results to, the server.
    #[inspect(skip)]
    hvsock_relay: HvsockRelayChannelHalf,
    /// Hvsock connect requests awaiting a response from the host.
    #[inspect(skip)]
    hvsock_requests: FuturesUnordered<HvsockRequestFuture>,
    /// Whether the relay has been started.
    running: bool,
}
566
/// A pending hvsock connect request to the host. Resolves to the original
/// request and the host's offer for it; `None` is treated as a failed connect.
type HvsockRequestFuture =
    Pin<Box<dyn Future<Output = (HvsockConnectRequest, Option<client::OfferInfo>)> + Sync + Send>>;
569
570impl RelayTask {
571    fn new(
572        spawner: Arc<dyn SpawnDriver>,
573        vmbus_control: Arc<VmbusServerControl>,
574        server_response_send: mesh::Sender<ModifyConnectionResponse>,
575        hvsock_relay: HvsockRelayChannelHalf,
576        vmbus_client: client::VmbusClientAccess,
577        version: VersionInfo,
578    ) -> Self {
579        Self {
580            spawner,
581            vmbus_client,
582            version,
583            vmbus_control,
584            channels: HashMap::new(),
585            channel_workers: FuturesUnordered::new(),
586            intercept_channels: HashMap::new(),
587            use_interrupt_relay: Arc::new(AtomicBool::new(false)),
588            server_response_send,
589            hvsock_relay,
590            running: false,
591            hvsock_requests: FuturesUnordered::new(),
592        }
593    }
594
595    async fn handle_start(&mut self) {
596        if !self.running {
597            // Resume all channels.
598            for c in self.channels.values() {
599                match c {
600                    ChannelInfo::Relay(relay) => relay.start(),
601                    ChannelInfo::Intercept(id) => {
602                        let Some(intercept_channel) = self.intercept_channels.get(id) else {
603                            tracing::error!(%id, "Intercept device missing from list");
604                            continue;
605                        };
606                        intercept_channel.send(InterceptChannelRequest::Start);
607                    }
608                }
609            }
610
611            self.running = true;
612        }
613    }
614
615    async fn handle_stop(&mut self) {
616        if self.running {
617            // Stop all the channels before the relay itself can stop.
618            join_all(self.channels.values().map(|c| match c {
619                ChannelInfo::Relay(relay) => futures::future::Either::Left(relay.stop()),
620                ChannelInfo::Intercept(id) => futures::future::Either::Right(async {
621                    let id = *id;
622                    if let Some(intercept_channel) = self.intercept_channels.get(&id) {
623                        if let Err(err) = intercept_channel
624                            .call(InterceptChannelRequest::Stop, ())
625                            .await
626                        {
627                            tracing::error!(
628                                err = &err as &dyn std::error::Error,
629                                %id,
630                                "Failed to stop intercepted device"
631                            );
632                        }
633                    }
634                }),
635            }))
636            .await;
637
638            // Because requests are handled "synchronously" (async is used but everything is awaited
639            // before another request is handled), there is no need for rundown and the relay can
640            // stop immediately.
641            self.running = false;
642        }
643    }
644
645    /// Translates an offer received from the client to a server offer.
646    /// Additionally, sets up all the appropriate channels.
647    async fn handle_offer(&mut self, offer: client::OfferInfo) -> Result<()> {
648        let channel_id = offer.offer.channel_id.0;
649
650        if self.channels.contains_key(&ChannelId(channel_id)) {
651            anyhow::bail!("channel {channel_id} already exists");
652        }
653
654        if let Some(intercept) = self.intercept_channels.get(&offer.offer.instance_id) {
655            self.channels.insert(
656                ChannelId(channel_id),
657                ChannelInfo::Intercept(offer.offer.instance_id),
658            );
659            intercept.send(InterceptChannelRequest::Offer(offer));
660            return Ok(());
661        }
662
663        // Used to Recv requests from the server.
664        let (request_send, request_recv) = mesh::channel();
665        // Used to Send responses from the server
666        let (server_request_send, server_request_recv) = mesh::channel();
667
668        if offer.offer.is_dedicated != 1 {
669            tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
670        }
671
672        // If the vmbus server is handling MnF, instead of relaying it, it will ignore this monitor
673        // ID and allocate its own.
674        let use_mnf = if offer.offer.monitor_allocated != 0 {
675            MnfUsage::Relayed {
676                monitor_id: offer.offer.monitor_id,
677            }
678        } else {
679            MnfUsage::Disabled
680        };
681
682        let params = OfferParamsInternal {
683            interface_name: "host relay".to_owned(),
684            instance_id: offer.offer.instance_id,
685            interface_id: offer.offer.interface_id,
686            mmio_megabytes: offer.offer.mmio_megabytes,
687            mmio_megabytes_optional: offer.offer.mmio_megabytes_optional,
688            subchannel_index: offer.offer.subchannel_index,
689            use_mnf,
690            // Preserve channel enumeration order from the host within the same
691            // interface type.
692            offer_order: Some(channel_id),
693            // Strip the confidential flags for relay channels if the host set them.
694            flags: offer
695                .offer
696                .flags
697                .with_confidential_ring_buffer(false)
698                .with_confidential_external_memory(false),
699            user_defined: offer.offer.user_defined,
700        };
701
702        let key = params.key();
703        let new_offer = OfferInfo {
704            params,
705            request_send,
706            server_request_recv,
707        };
708
709        // Don't send the client's channel and connection ID to the server. Instead, the server will
710        // decide its own IDs which are communicated back to the host as part of the open message
711        // using guest-specified signal parameters, which the host must support.
712        //
713        // The vmbus server will ignore the monitor ID and allocate its own if MNF is handled by it
714        // and not the host.
715        self.vmbus_control
716            .offer_core(new_offer)
717            .await
718            .with_context(|| format!("failed to offer relay channel {key}"))?;
719
720        let (relay_request_send, relay_request_recv) = mesh::channel();
721        let channel_task = RelayChannelTask {
722            driver: Arc::clone(&self.spawner),
723            channel: RelayChannel {
724                channel_id: ChannelId(channel_id),
725                relay_request_recv,
726                request_send: offer.request_send,
727                revoke_recv: offer.revoke_recv,
728                server_request_send,
729                server_request_recv: request_recv,
730                use_interrupt_relay: Arc::clone(&self.use_interrupt_relay),
731                interrupt_relay: None,
732                gpadls_tearing_down: FuturesUnordered::new(),
733                is_open: false,
734            },
735            running: self.running,
736        };
737
738        let task = self.spawner.spawn("vmbus hcl channel worker", async move {
739            channel_task.run().await;
740            ChannelId(channel_id)
741        });
742
743        self.channels.insert(
744            ChannelId(channel_id),
745            ChannelInfo::Relay(RelayChannelInfo { relay_request_send }),
746        );
747        self.channel_workers.push(task);
748
749        Ok(())
750    }
751
752    async fn handle_revoked(&mut self, channel_id: ChannelId) {
753        // The task has already completed, so just remove the channel from the list.
754        self.channels
755            .remove(&channel_id)
756            .expect("channel should exist");
757    }
758
759    async fn handle_modify(
760        &mut self,
761        request: vmbus_server::ModifyRelayRequest,
762    ) -> ModifyConnectionResponse {
763        // If the guest is requesting a version change, check whether that version is not newer
764        // than what the host supports.
765        if let Some(version) = request.version {
766            if (self.version.version as u32) < version {
767                return ModifyConnectionResponse::Unsupported;
768            }
769        }
770
771        if let Some(use_interrupt_page) = request.use_interrupt_page {
772            // If the guest is using the channel bitmap, the host can't send interrupts directly and
773            // must relay them through Underhill.
774            self.use_interrupt_relay
775                .store(use_interrupt_page, Ordering::SeqCst);
776        }
777
778        // If the monitor page is not changing, there is no need to send any request to the host.
779        let state = match request.monitor_page {
780            Update::Unchanged => protocol::ConnectionState::SUCCESSFUL,
781            Update::Reset => {
782                self.vmbus_client
783                    .modify(ModifyConnectionRequest { monitor_page: None })
784                    .await
785            }
786            Update::Set(value) => {
787                self.vmbus_client
788                    .modify(ModifyConnectionRequest {
789                        monitor_page: Some(value),
790                    })
791                    .await
792            }
793        };
794
795        ModifyConnectionResponse::Supported(state, self.version.feature_flags)
796    }
797
798    async fn handle_server_request(&mut self, request: vmbus_server::ModifyRelayRequest) {
799        tracing::trace!(request = ?request, "received server request");
800        let result = self.handle_modify(request).await;
801        self.server_response_send.send(result);
802    }
803
804    fn handle_hvsock_request(&mut self, request: HvsockConnectRequest) {
805        tracing::debug!(request = ?request, "received hvsock connect request");
806        let fut = self.vmbus_client.connect_hvsock(request);
807        self.hvsock_requests
808            .push(Box::pin(fut.map(move |offer| (request, offer))));
809    }
810
811    async fn handle_hvsock_response(
812        &mut self,
813        request: HvsockConnectRequest,
814        offer: Option<client::OfferInfo>,
815    ) {
816        let success = if let Some(offer) = offer {
817            match self.handle_offer(offer).await {
818                Ok(()) => true,
819                Err(err) => {
820                    tracing::error!(
821                        error = err.as_ref() as &dyn std::error::Error,
822                        "failed add hvsock offer"
823                    );
824                    false
825                }
826            }
827        } else {
828            false
829        };
830        self.hvsock_relay
831            .response_send
832            .send(HvsockConnectResult::from_request(&request, success));
833    }
834
835    async fn handle_offer_request(&mut self, request: client::OfferInfo) -> Result<()> {
836        if let Err(err) = self.handle_offer(request).await {
837            tracing::error!(
838                error = err.as_ref() as &dyn std::error::Error,
839                "failed to hot add offer"
840            );
841        }
842
843        Ok(())
844    }
845
846    async fn run(
847        &mut self,
848        server_recv: mesh::Receiver<vmbus_server::ModifyRelayRequest>,
849        mut offer_recv: mesh::Receiver<client::OfferInfo>,
850        mut task_recv: mesh::Receiver<TaskRequest>,
851    ) -> Result<()> {
852        let mut server_recv = server_recv.fuse();
853        loop {
854            let mut offer_recv =
855                OptionFuture::from(self.running.then(|| offer_recv.select_next_some()));
856
857            futures::select! { // merge semantics
858                r = server_recv.select_next_some() => {
859                    self.handle_server_request(r).await;
860                }
861                r = self.hvsock_relay.request_receive.select_next_some() => {
862                    self.handle_hvsock_request(r);
863                }
864                r = self.hvsock_requests.select_next_some() => {
865                    self.handle_hvsock_response(r.0, r.1).await;
866                }
867                r = offer_recv => {
868                    self.handle_offer_request(r.unwrap()).await?;
869                }
870                r = task_recv.recv().fuse() => {
871                    match r.unwrap() {
872                        TaskRequest::Inspect(req) => req.inspect(&mut *self),
873                        TaskRequest::Save(rpc) => rpc.handle(async |()| {
874                             self.handle_save().await
875                        }).await,
876                        TaskRequest::Restore(rpc) => rpc.handle(async |state|  {
877                            self.handle_restore(state).await
878                        }).await,
879                        TaskRequest::Start => self.handle_start().await,
880                        TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
881                    }
882                }
883                r = self.channel_workers.select_next_some() => {
884                    self.handle_revoked(r).await;
885                }
886            }
887        }
888    }
889}