vmbus_relay/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Implements a vmbus channel relay, which consumes channels from the host
5//! vmbus control plane (via [`vmbus_client`]) and relays them as channels to
6//! the guest OS (via [`vmbus_server`]).
7//!
8//! This is used to allow the paravisor to implement the vmbus control plane
9//! while still passing through channels from the host, without any paravisor
10//! presence in the data plane.
11
12#![expect(missing_docs)]
13#![forbid(unsafe_code)]
14
15pub mod legacy_saved_state;
16mod saved_state;
17
18pub use saved_state::SavedState;
19
20use anyhow::Context;
21use anyhow::Result;
22use client::ModifyConnectionRequest;
23use futures::FutureExt;
24use futures::StreamExt;
25use futures::future::BoxFuture;
26use futures::future::OptionFuture;
27use futures::future::join_all;
28use guid::Guid;
29use inspect::Inspect;
30use inspect::InspectMut;
31use mesh::rpc::FailableRpc;
32use mesh::rpc::Rpc;
33use mesh::rpc::RpcSend;
34use pal_async::driver::SpawnDriver;
35use pal_async::task::Spawn;
36use pal_async::task::Task;
37use pal_event::Event;
38use std::collections::HashMap;
39use std::fmt::Debug;
40use std::future::Future;
41use std::pin::Pin;
42use std::sync::Arc;
43use std::sync::atomic::AtomicBool;
44use std::sync::atomic::Ordering;
45use unicycle::FuturesUnordered;
46use vmbus_channel::bus::ChannelRequest;
47use vmbus_channel::bus::ChannelServerRequest;
48use vmbus_channel::bus::GpadlRequest;
49use vmbus_channel::bus::ModifyRequest;
50use vmbus_channel::bus::OfferKey;
51use vmbus_channel::bus::OpenRequest;
52use vmbus_client as client;
53use vmbus_core::HvsockConnectRequest;
54use vmbus_core::HvsockConnectResult;
55use vmbus_core::VersionInfo;
56use vmbus_core::protocol;
57use vmbus_core::protocol::ChannelId;
58use vmbus_core::protocol::FeatureFlags;
59use vmbus_core::protocol::GpadlId;
60use vmbus_server::HvsockRelayChannelHalf;
61use vmbus_server::MnfUsage;
62use vmbus_server::ModifyRelayResponse;
63use vmbus_server::OfferInfo;
64use vmbus_server::OfferParamsInternal;
65use vmbus_server::Update;
66use vmbus_server::VmbusRelayChannelHalf;
67use vmbus_server::VmbusServerControl;
68use vmcore::interrupt::Interrupt;
69use vmcore::notify::Notify;
70use vmcore::notify::PolledNotify;
71
/// Requests the relay sends to a device that intercepts a host channel.
///
/// A host offer whose instance ID appears in the intercept list passed to
/// [`HostVmbusTransport::new`] is not offered to the guest through the vmbus server; it is
/// forwarded to the registered sender as [`InterceptChannelRequest::Offer`] instead.
pub enum InterceptChannelRequest {
    /// The relay has started; sent to each known intercepted channel's device.
    Start,
    /// The relay is stopping; the relay waits for this RPC to complete.
    Stop(Rpc<(), ()>),
    /// Requests the intercepting device's saved state.
    /// NOTE(review): sender is outside this file (presumably the saved_state module).
    Save(Rpc<(), vmcore::save_restore::SavedStateBlob>),
    /// Provides previously saved state to the intercepting device.
    /// NOTE(review): sender is outside this file (presumably the saved_state module).
    Restore(vmcore::save_restore::SavedStateBlob),
    /// A host offer whose instance ID matches the intercepted GUID.
    Offer(client::OfferInfo),
}
79
80const REQUIRED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
81    .with_channel_interrupt_redirection(true)
82    .with_guest_specified_signal_parameters(true)
83    .with_modify_connection(true);
84
/// Represents a relay between a vmbus server on the host, and the vmbus server running in
/// Underhill, allowing offers from the host and offers from Underhill to be mixed.
///
/// The host connection is established by the caller and handed to
/// [`HostVmbusTransport::new`], which relays the channels offered at connection time right
/// away. Channels hot-added by the host afterwards are only processed while the relay is
/// started (see [`HostVmbusTransport::start`] and [`HostVmbusTransport::stop`]).
#[derive(Inspect, Debug)]
pub struct HostVmbusTransport {
    /// The spawned relay task, owned by this object.
    #[inspect(skip)]
    _relay_task: Task<()>,
    /// Sends control requests (start, stop, save, restore, inspect) to the relay task.
    #[inspect(flatten, send = "TaskRequest::Inspect")]
    task_send: mesh::Sender<TaskRequest>,
}
97
98impl HostVmbusTransport {
99    /// Create a new instance of the host vmbus relay.
100    pub async fn new(
101        driver: impl SpawnDriver + Clone,
102        control: Arc<VmbusServerControl>,
103        channel: VmbusRelayChannelHalf,
104        hvsock_relay: HvsockRelayChannelHalf,
105        vmbus_client: client::VmbusClientAccess,
106        connection: client::ConnectResult,
107        intercept_list: Vec<(Guid, mesh::Sender<InterceptChannelRequest>)>,
108    ) -> Result<Self> {
109        if connection.version.feature_flags & REQUIRED_FEATURE_FLAGS != REQUIRED_FEATURE_FLAGS {
110            anyhow::bail!(
111                "host must support required feature flags. \
112                 Required: {REQUIRED_FEATURE_FLAGS:?}, actual: {:?}.",
113                connection.version.feature_flags
114            );
115        }
116
117        let mut relay_task = RelayTask::new(
118            Arc::new(driver.clone()),
119            control,
120            channel.response_send,
121            hvsock_relay,
122            vmbus_client,
123            connection.version,
124        );
125
126        relay_task.intercept_channels.extend(intercept_list);
127
128        for offer in connection.offers {
129            relay_task.handle_offer(offer).await?;
130        }
131
132        let (task_send, task_recv) = mesh::channel();
133
134        let relay_task = driver.spawn("vmbus hcl relay", async move {
135            relay_task
136                .run(channel.request_receive, connection.offer_recv, task_recv)
137                .await
138                .unwrap()
139        });
140
141        Ok(Self {
142            _relay_task: relay_task,
143            task_send,
144        })
145    }
146
147    pub fn start(&self) {
148        self.task_send.send(TaskRequest::Start);
149    }
150
151    pub async fn stop(&self) {
152        self.task_send.call(TaskRequest::Stop, ()).await.unwrap()
153    }
154
155    pub async fn save(&self) -> SavedState {
156        self.task_send.call(TaskRequest::Save, ()).await.unwrap()
157    }
158
159    pub async fn restore(&self, state: SavedState) -> Result<()> {
160        self.task_send
161            .call(TaskRequest::Restore, state)
162            .await
163            .unwrap()
164    }
165}
166
/// State needed to relay host-to-guest interrupts.
///
/// Installed by `RelayChannelTask::handle_open_channel` when the guest uses the channel bitmap,
/// and cleared when the channel is closed.
struct InterruptRelay {
    /// Event signaled when the host sends an interrupt.
    notify: PolledNotify,
    /// Interrupt used to signal the guest.
    interrupt: Interrupt,
    /// Event flag used to signal the guest.
    /// FUTURE: remove once this moves into `vmbus_client` saved state.
    event_flag: u16,
}
177
/// Requests sent from the relay task to an individual relayed channel's task.
enum RelayChannelRequest {
    /// Begin processing requests from the vmbus server.
    Start,
    /// Stop processing requests from the vmbus server.
    Stop(Rpc<(), ()>),
    /// Return the channel's saved state.
    Save(Rpc<(), saved_state::Channel>),
    /// Restore the channel from saved state.
    Restore(FailableRpc<saved_state::Channel, ()>),
    /// Inspect the channel task.
    Inspect(inspect::Deferred),
}
185
186impl Debug for RelayChannelRequest {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        match self {
189            RelayChannelRequest::Start => f.pad("Start"),
190            RelayChannelRequest::Stop(..) => f.pad("Stop"),
191            RelayChannelRequest::Save(..) => f.pad("Save"),
192            RelayChannelRequest::Restore(..) => f.pad("Restore"),
193            RelayChannelRequest::Inspect(..) => f.pad("Inspect"),
194        }
195    }
196}
197
/// The relay task's handle to a relayed channel's worker task.
#[derive(Inspect)]
struct RelayChannelInfo {
    /// Sends requests to the channel's worker task; inspect requests are forwarded through it.
    #[inspect(flatten, send = "RelayChannelRequest::Inspect")]
    relay_request_send: mesh::Sender<RelayChannelRequest>,
}
203
/// The relay task's record of a host channel, keyed by host channel ID.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ChannelInfo {
    /// A channel relayed to the guest through the vmbus server.
    #[inspect(transparent)]
    Relay(RelayChannelInfo),
    /// A channel handed to an intercepting device. Holds the offer's instance ID, which is the
    /// key into the relay's intercept list.
    #[inspect(transparent)]
    Intercept(Guid),
}
212
213impl RelayChannelInfo {
214    async fn stop(&self) {
215        if let Err(err) = self
216            .relay_request_send
217            .call(RelayChannelRequest::Stop, ())
218            .await
219        {
220            tracing::warn!(?err, "failed to request channel stop");
221        }
222    }
223
224    fn start(&self) {
225        self.relay_request_send.send(RelayChannelRequest::Start);
226    }
227}
228
/// Connects a Client channel to a Server Channel.
///
/// The client side is the host channel (via `vmbus_client`); the server side is the channel
/// offered to the guest by the local vmbus server.
#[derive(Inspect)]
struct RelayChannel {
    /// The Channel Id given to us by the client
    channel_id: ChannelId,
    /// The identifying key for this channel.
    key: OfferKey,
    /// Receives requests from the relay.
    #[inspect(skip)]
    relay_request_recv: mesh::Receiver<RelayChannelRequest>,
    /// Receives requests from the server.
    #[inspect(skip)]
    server_request_recv: mesh::Receiver<ChannelRequest>,
    /// Sends requests to the server; used to revoke the channel when its task exits.
    #[inspect(skip)]
    server_request_send: mesh::Sender<ChannelServerRequest>,
    /// Closed when the channel has been revoked.
    #[inspect(skip)]
    revoke_recv: mesh::OneshotReceiver<()>,
    /// Sends requests to the client
    #[inspect(skip)]
    request_send: mesh::Sender<client::ChannelRequest>,
    /// Indicates whether or not interrupts should be relayed. This is shared with the relay server
    /// connection, which sets this to true only if the guest uses the channel bitmap.
    use_interrupt_relay: Arc<AtomicBool>,
    /// State used to relay host-to-guest interrupts.
    #[inspect(with = "Option::is_some")]
    interrupt_relay: Option<InterruptRelay>,
    /// Futures waiting for GPADL teardown to complete before responding to
    /// `vmbus_server`.
    #[inspect(skip)]
    gpadls_tearing_down: FuturesUnordered<BoxFuture<'static, ()>>,
    /// True after an open request from the server succeeded, until the next close request.
    is_open: bool,
}
262
/// Worker task that services a single relayed channel.
#[derive(InspectMut)]
struct RelayChannelTask {
    /// Driver used to create the pollable notification for interrupt relay.
    #[inspect(skip)]
    driver: Arc<dyn SpawnDriver>,
    /// The channel being serviced.
    channel: RelayChannel,
    /// Whether requests from the vmbus server are currently being processed.
    running: bool,
}
270
271impl RelayChannelTask {
272    /// Relay open channel request from VTL0 to Host, responding with Open Result
273    async fn handle_open_channel(&mut self, open_request: &OpenRequest) -> Result<()> {
274        // If the guest uses the channel bitmap, the host can't send interrupts
275        // directly and they must be relayed.
276        let redirect_interrupt = self.channel.use_interrupt_relay.load(Ordering::SeqCst);
277        let (incoming_event, notify) = if redirect_interrupt {
278            let event = Event::new();
279            let notify = Notify::from_event(event.clone())
280                .pollable(self.driver.as_ref())
281                .context("failed to create polled notify")?;
282            Some((event, notify))
283        } else {
284            None
285        }
286        .unzip();
287
288        let opened = self
289            .channel
290            .request_send
291            .call_failable(
292                client::ChannelRequest::Open,
293                client::OpenRequest {
294                    open_data: open_request.open_data,
295                    incoming_event,
296                    use_vtl2_connection_id: false,
297                },
298            )
299            .await?;
300
301        if let Some(notify) = notify {
302            self.channel.interrupt_relay = Some(InterruptRelay {
303                notify,
304                interrupt: open_request.interrupt.clone(),
305                event_flag: opened.redirected_event_flag.unwrap(),
306            });
307        }
308
309        self.channel.is_open = true;
310
311        Ok(())
312    }
313
314    async fn handle_close_channel(&mut self) {
315        self.channel
316            .request_send
317            .call(client::ChannelRequest::Close, ())
318            .await
319            .ok();
320
321        self.channel.interrupt_relay = None;
322        self.channel.is_open = false;
323    }
324
325    /// Relay gpadl request from VTL0 to the Host and respond with gpadl created.
326    async fn handle_gpadl(&mut self, request: GpadlRequest) -> Result<()> {
327        self.channel
328            .request_send
329            .call_failable(client::ChannelRequest::Gpadl, request)
330            .await?;
331
332        Ok(())
333    }
334
335    fn handle_gpadl_teardown(&mut self, rpc: Rpc<GpadlId, ()>) {
336        let (gpadl_id, rpc) = rpc.split();
337        tracing::trace!(gpadl_id = gpadl_id.0, key = %self.channel.key, "Tearing down GPADL");
338
339        let call = self
340            .channel
341            .request_send
342            .call(client::ChannelRequest::TeardownGpadl, gpadl_id);
343
344        // We cannot wait for GpadlTorndown here, because the host may not send the GpadlTorndown
345        // message immediately, for example if the channel is still open and the host device still
346        // has the gpadl mapped. We should not block further requests while waiting for the
347        // response.
348        let key = self.channel.key;
349        self.channel.gpadls_tearing_down.push(Box::pin(async move {
350            if let Err(err) = call.await {
351                tracing::warn!(
352                    %key,
353                    error = &err as &dyn std::error::Error,
354                    "failed to send gpadl teardown"
355                );
356            }
357            rpc.complete(());
358        }));
359    }
360
361    async fn handle_modify_channel(&mut self, modify_request: ModifyRequest) -> Result<i32> {
362        let status = self
363            .channel
364            .request_send
365            .call(client::ChannelRequest::Modify, modify_request)
366            .await?;
367
368        Ok(status)
369    }
370
371    /// Dispatch requests sent by VTL0
372    async fn handle_server_request(&mut self, request: ChannelRequest) -> Result<()> {
373        tracing::trace!(key = %self.channel.key, request = ?request, "received channel request");
374        match request {
375            ChannelRequest::Open(rpc) => {
376                rpc.handle(async |open_request| {
377                    self.handle_open_channel(&open_request)
378                        .await
379                        .inspect_err(|err| {
380                            tracelimit::error_ratelimited!(
381                                err = err.as_ref() as &dyn std::error::Error,
382                                key = %self.channel.key,
383                                channel_id = self.channel.channel_id.0,
384                                "failed to open channel"
385                            );
386                        })
387                        .is_ok()
388                })
389                .await;
390            }
391            ChannelRequest::Gpadl(rpc) => {
392                rpc.handle(async |gpadl| {
393                    let id = gpadl.id;
394                    self.handle_gpadl(gpadl)
395                        .await
396                        .inspect_err(|err| {
397                            tracelimit::error_ratelimited!(
398                                err = err.as_ref() as &dyn std::error::Error,
399                                key = %self.channel.key,
400                                channel_id = self.channel.channel_id.0,
401                                gpadl_id = id.0,
402                                "failed to create gpadl"
403                            );
404                        })
405                        .is_ok()
406                })
407                .await;
408            }
409            ChannelRequest::Close(rpc) => {
410                rpc.handle(async |()| self.handle_close_channel().await)
411                    .await;
412            }
413            ChannelRequest::TeardownGpadl(rpc) => {
414                self.handle_gpadl_teardown(rpc);
415            }
416            ChannelRequest::Modify(rpc) => {
417                rpc.handle(async |request| self.handle_modify_channel(request).await.unwrap_or(-1))
418                    .await;
419            }
420        }
421
422        Ok(())
423    }
424
425    async fn handle_relay_request(&mut self, request: RelayChannelRequest) {
426        tracing::trace!(
427            channel_id = self.channel.channel_id.0,
428            key = %self.channel.key,
429            ?request,
430            "received relay request"
431        );
432
433        match request {
434            RelayChannelRequest::Start => self.running = true,
435            RelayChannelRequest::Stop(rpc) => rpc.handle_sync(|()| self.running = false),
436            RelayChannelRequest::Save(rpc) => rpc.handle_sync(|_| self.handle_save()),
437            RelayChannelRequest::Restore(rpc) => {
438                rpc.handle_failable(async |state| self.handle_restore(state).await)
439                    .await
440            }
441            RelayChannelRequest::Inspect(deferred) => deferred.inspect(self),
442        }
443    }
444
445    /// Request dispatch loop
446    async fn run(mut self) {
447        loop {
448            let mut relay_event = OptionFuture::from(
449                self.channel
450                    .interrupt_relay
451                    .as_mut()
452                    .map(|e| e.notify.wait().fuse()),
453            );
454
455            let mut server_request = OptionFuture::from(
456                self.running
457                    .then(|| self.channel.server_request_recv.next()),
458            );
459
460            futures::select! { // merge semantics
461                r = self.channel.relay_request_recv.next() => {
462                    match r {
463                        Some(request) => {
464                            // Needed to avoid conflicting &mut self borrow.
465                            drop(relay_event);
466                            self.handle_relay_request(request).await;
467                        }
468                        None => {
469                            break;
470                        }
471                    }
472                }
473                r = server_request => {
474                    match r.unwrap() {
475                        Some(request) => {
476                            // Needed to avoid conflicting &mut self borrow.
477                            drop(relay_event);
478                            self
479                                .handle_server_request(request)
480                                .await
481                                .expect("failed to get server request");
482                        }
483                        None => {
484                            break;
485                        }
486                    }
487                }
488                _r = (&mut self.channel.revoke_recv).fuse() => {
489                    break;
490                }
491                () = self.channel.gpadls_tearing_down.select_next_some() => {}
492                _r = relay_event => {
493                    // Needed to avoid conflicting interrupt_relay borrow.
494                    drop(relay_event);
495                    self.channel.interrupt_relay.as_ref().unwrap().interrupt.deliver();
496                }
497            }
498        }
499
500        // Drain GPADL teardown requests cleanly; these will all complete now
501        // that the channel has been revoked.
502        while let Some(()) = self.channel.gpadls_tearing_down.next().await {}
503
504        tracing::debug!(
505            channel_id = %self.channel.channel_id.0,
506            key = %self.channel.key,
507            "dropped channel"
508        );
509
510        // Dropping the channel would revoke it, but since that's not synchronized there's a chance
511        // we reoffer the channel before the server receives the revoke. Using the request ensures
512        // that won't happen.
513        if let Err(err) = self
514            .channel
515            .server_request_send
516            .call(ChannelServerRequest::Revoke, ())
517            .await
518        {
519            tracing::warn!(
520                channel_id = self.channel.channel_id.0,
521                key = %self.channel.key,
522                err = &err as &dyn std::error::Error,
523                "failed to send revoke request"
524            );
525        }
526    }
527}
528
/// Control requests sent from [`HostVmbusTransport`] to the relay task.
enum TaskRequest {
    /// Inspect the relay task.
    Inspect(inspect::Deferred),
    /// Return the relay's saved state.
    Save(Rpc<(), SavedState>),
    /// Restore the relay from saved state.
    Restore(Rpc<SavedState, Result<()>>),
    /// Start (or resume) the relay and all of its channels.
    Start,
    /// Stop all channels, then the relay itself.
    Stop(Rpc<(), ()>),
}
536
/// Dispatches requests between Server/Client.
///
/// Owns the relay state: the host connection (`vmbus_client`), the local vmbus server control
/// handle, and every known host channel. Each channel is either relayed to the guest by its own
/// `RelayChannelTask` or handed to an intercepting device.
#[derive(InspectMut)]
struct RelayTask {
    /// Spawns channel worker tasks.
    #[inspect(skip)]
    spawner: Arc<dyn SpawnDriver>,
    /// Access to the host vmbus connection.
    #[inspect(skip)]
    vmbus_client: client::VmbusClientAccess,
    /// Version and feature flags of the host connection.
    version: VersionInfo,
    /// Used to offer relayed channels to the local vmbus server.
    #[inspect(skip)]
    vmbus_control: Arc<VmbusServerControl>,
    /// All known host channels, keyed by host channel ID.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    channels: HashMap<ChannelId, ChannelInfo>,
    /// Relayed channel worker tasks; each yields its channel ID when it finishes.
    #[inspect(skip)]
    channel_workers: FuturesUnordered<Task<ChannelId>>,
    /// Devices that intercept host channels, keyed by offer instance ID.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_value(|_| ())")]
    intercept_channels: HashMap<Guid, mesh::Sender<InterceptChannelRequest>>,
    /// Whether host-to-guest interrupts must be relayed; shared with every channel task.
    use_interrupt_relay: Arc<AtomicBool>,
    /// Sends responses to modify-connection requests back to the vmbus server.
    #[inspect(skip)]
    server_response_send: mesh::Sender<ModifyRelayResponse>,
    /// Receives hvsock connect requests and sends their results.
    #[inspect(skip)]
    hvsock_relay: HvsockRelayChannelHalf,
    /// Outstanding hvsock connect requests to the host.
    #[inspect(skip)]
    hvsock_requests: FuturesUnordered<HvsockRequestFuture>,
    /// Whether the relay has been started.
    running: bool,
}
562
/// A pending hvsock connect request to the host. It resolves to the original request together
/// with the offer the host returned for it, if any.
type HvsockRequestFuture =
    Pin<Box<dyn Future<Output = (HvsockConnectRequest, Option<client::OfferInfo>)> + Sync + Send>>;
565
566impl RelayTask {
567    fn new(
568        spawner: Arc<dyn SpawnDriver>,
569        vmbus_control: Arc<VmbusServerControl>,
570        server_response_send: mesh::Sender<ModifyRelayResponse>,
571        hvsock_relay: HvsockRelayChannelHalf,
572        vmbus_client: client::VmbusClientAccess,
573        version: VersionInfo,
574    ) -> Self {
575        Self {
576            spawner,
577            vmbus_client,
578            version,
579            vmbus_control,
580            channels: HashMap::new(),
581            channel_workers: FuturesUnordered::new(),
582            intercept_channels: HashMap::new(),
583            use_interrupt_relay: Arc::new(AtomicBool::new(false)),
584            server_response_send,
585            hvsock_relay,
586            running: false,
587            hvsock_requests: FuturesUnordered::new(),
588        }
589    }
590
591    async fn handle_start(&mut self) {
592        if !self.running {
593            // Resume all channels.
594            for c in self.channels.values() {
595                match c {
596                    ChannelInfo::Relay(relay) => relay.start(),
597                    ChannelInfo::Intercept(id) => {
598                        let Some(intercept_channel) = self.intercept_channels.get(id) else {
599                            tracing::error!(%id, "Intercept device missing from list");
600                            continue;
601                        };
602                        intercept_channel.send(InterceptChannelRequest::Start);
603                    }
604                }
605            }
606
607            self.running = true;
608        }
609    }
610
611    async fn handle_stop(&mut self) {
612        if self.running {
613            // Stop all the channels before the relay itself can stop.
614            join_all(self.channels.values().map(|c| match c {
615                ChannelInfo::Relay(relay) => futures::future::Either::Left(relay.stop()),
616                ChannelInfo::Intercept(id) => futures::future::Either::Right(async {
617                    let id = *id;
618                    if let Some(intercept_channel) = self.intercept_channels.get(&id) {
619                        if let Err(err) = intercept_channel
620                            .call(InterceptChannelRequest::Stop, ())
621                            .await
622                        {
623                            tracing::error!(
624                                err = &err as &dyn std::error::Error,
625                                %id,
626                                "Failed to stop intercepted device"
627                            );
628                        }
629                    }
630                }),
631            }))
632            .await;
633
634            // Because requests are handled "synchronously" (async is used but everything is awaited
635            // before another request is handled), there is no need for rundown and the relay can
636            // stop immediately.
637            self.running = false;
638        }
639    }
640
641    /// Translates an offer received from the client to a server offer.
642    /// Additionally, sets up all the appropriate channels.
643    async fn handle_offer(&mut self, offer: client::OfferInfo) -> Result<()> {
644        let channel_id = offer.offer.channel_id.0;
645
646        if let Some(intercept) = self.intercept_channels.get(&offer.offer.instance_id) {
647            self.channels.insert(
648                ChannelId(channel_id),
649                ChannelInfo::Intercept(offer.offer.instance_id),
650            );
651            intercept.send(InterceptChannelRequest::Offer(offer));
652            return Ok(());
653        }
654
655        if self.channels.contains_key(&ChannelId(channel_id)) {
656            anyhow::bail!("channel {channel_id} already exists");
657        }
658
659        // Used to Recv requests from the server.
660        let (request_send, request_recv) = mesh::channel();
661        // Used to Send responses from the server
662        let (server_request_send, server_request_recv) = mesh::channel();
663
664        if offer.offer.is_dedicated != 1 {
665            tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
666        }
667
668        // If the vmbus server is handling MnF, instead of relaying it, it will ignore this monitor
669        // ID and allocate its own.
670        let use_mnf = if offer.offer.monitor_allocated != 0 {
671            MnfUsage::Relayed {
672                monitor_id: offer.offer.monitor_id,
673            }
674        } else {
675            MnfUsage::Disabled
676        };
677
678        let params = OfferParamsInternal {
679            interface_name: "host relay".to_owned(),
680            instance_id: offer.offer.instance_id,
681            interface_id: offer.offer.interface_id,
682            mmio_megabytes: offer.offer.mmio_megabytes,
683            mmio_megabytes_optional: offer.offer.mmio_megabytes_optional,
684            subchannel_index: offer.offer.subchannel_index,
685            use_mnf,
686            // Preserve channel enumeration order from the host within the same
687            // interface type.
688            offer_order: Some(channel_id.into()),
689            // Strip the confidential flags for relay channels if the host set them.
690            flags: offer
691                .offer
692                .flags
693                .with_confidential_ring_buffer(false)
694                .with_confidential_external_memory(false),
695            user_defined: offer.offer.user_defined,
696        };
697
698        let key = params.key();
699        let new_offer = OfferInfo {
700            params,
701            event: offer.guest_to_host_interrupt,
702            request_send,
703            server_request_recv,
704        };
705
706        // Don't send the client's channel and connection ID to the server. Instead, the server will
707        // decide its own IDs which are communicated back to the host as part of the open message
708        // using guest-specified signal parameters, which the host must support.
709        //
710        // The vmbus server will ignore the monitor ID and allocate its own if MNF is handled by it
711        // and not the host.
712        self.vmbus_control
713            .offer_core(new_offer)
714            .await
715            .with_context(|| format!("failed to offer relay channel {key}"))?;
716
717        let (relay_request_send, relay_request_recv) = mesh::channel();
718        let channel_task = RelayChannelTask {
719            driver: Arc::clone(&self.spawner),
720            channel: RelayChannel {
721                channel_id: ChannelId(channel_id),
722                key,
723                relay_request_recv,
724                request_send: offer.request_send,
725                revoke_recv: offer.revoke_recv,
726                server_request_send,
727                server_request_recv: request_recv,
728                use_interrupt_relay: Arc::clone(&self.use_interrupt_relay),
729                interrupt_relay: None,
730                gpadls_tearing_down: FuturesUnordered::new(),
731                is_open: false,
732            },
733            running: self.running,
734        };
735
736        let task = self.spawner.spawn("vmbus hcl channel worker", async move {
737            channel_task.run().await;
738            ChannelId(channel_id)
739        });
740
741        self.channels.insert(
742            ChannelId(channel_id),
743            ChannelInfo::Relay(RelayChannelInfo { relay_request_send }),
744        );
745        self.channel_workers.push(task);
746
747        Ok(())
748    }
749
750    async fn handle_revoked(&mut self, channel_id: ChannelId) {
751        // The task has already completed, so just remove the channel from the list.
752        self.channels
753            .remove(&channel_id)
754            .expect("channel should exist");
755    }
756
757    async fn handle_modify(
758        &mut self,
759        request: vmbus_server::ModifyRelayRequest,
760    ) -> ModifyRelayResponse {
761        // If the guest is requesting a version change, check whether that version is not newer
762        // than what the host supports.
763        if let Some(version) = request.version {
764            if (self.version.version as u32) < version {
765                return ModifyRelayResponse::Unsupported;
766            }
767        }
768
769        if let Some(use_interrupt_page) = request.use_interrupt_page {
770            // If the guest is using the channel bitmap, the host can't send interrupts directly and
771            // must relay them through Underhill.
772            self.use_interrupt_relay
773                .store(use_interrupt_page, Ordering::SeqCst);
774        }
775
776        // If the monitor page is not changing, there is no need to send any request to the host.
777        let state = match request.monitor_page {
778            Update::Unchanged => protocol::ConnectionState::SUCCESSFUL,
779            Update::Reset => {
780                self.vmbus_client
781                    .modify(ModifyConnectionRequest { monitor_page: None })
782                    .await
783            }
784            Update::Set(value) => {
785                self.vmbus_client
786                    .modify(ModifyConnectionRequest {
787                        monitor_page: Some(value),
788                    })
789                    .await
790            }
791        };
792
793        // Use Supported only for new connections (which have a version).
794        if request.version.is_some() {
795            ModifyRelayResponse::Supported(state, self.version.feature_flags)
796        } else {
797            ModifyRelayResponse::Modified(state)
798        }
799    }
800
801    async fn handle_server_request(&mut self, request: vmbus_server::ModifyRelayRequest) {
802        tracing::trace!(request = ?request, "received server request");
803        let result = self.handle_modify(request).await;
804        self.server_response_send.send(result);
805    }
806
807    fn handle_hvsock_request(&mut self, request: HvsockConnectRequest) {
808        tracing::debug!(request = ?request, "received hvsock connect request");
809        let fut = self.vmbus_client.connect_hvsock(request);
810        self.hvsock_requests
811            .push(Box::pin(fut.map(move |offer| (request, offer))));
812    }
813
814    async fn handle_hvsock_response(
815        &mut self,
816        request: HvsockConnectRequest,
817        offer: Option<client::OfferInfo>,
818    ) {
819        let success = if let Some(offer) = offer {
820            match self.handle_offer(offer).await {
821                Ok(()) => true,
822                Err(err) => {
823                    tracing::error!(
824                        error = err.as_ref() as &dyn std::error::Error,
825                        ?request,
826                        "failed add hvsock offer"
827                    );
828                    false
829                }
830            }
831        } else {
832            false
833        };
834        self.hvsock_relay
835            .response_send
836            .send(HvsockConnectResult::from_request(&request, success));
837    }
838
839    async fn handle_offer_request(&mut self, request: client::OfferInfo) -> Result<()> {
840        let offer = request.offer;
841        if let Err(err) = self.handle_offer(request).await {
842            tracing::error!(
843                error = err.as_ref() as &dyn std::error::Error,
844                ?offer,
845                "failed to hot add offer"
846            );
847        }
848
849        Ok(())
850    }
851
852    async fn run(
853        &mut self,
854        server_recv: mesh::Receiver<vmbus_server::ModifyRelayRequest>,
855        mut offer_recv: mesh::Receiver<client::OfferInfo>,
856        mut task_recv: mesh::Receiver<TaskRequest>,
857    ) -> Result<()> {
858        let mut server_recv = server_recv.fuse();
859        loop {
860            let mut offer_recv =
861                OptionFuture::from(self.running.then(|| offer_recv.select_next_some()));
862
863            futures::select! { // merge semantics
864                r = server_recv.select_next_some() => {
865                    self.handle_server_request(r).await;
866                }
867                r = self.hvsock_relay.request_receive.select_next_some() => {
868                    self.handle_hvsock_request(r);
869                }
870                r = self.hvsock_requests.select_next_some() => {
871                    self.handle_hvsock_response(r.0, r.1).await;
872                }
873                r = offer_recv => {
874                    self.handle_offer_request(r.unwrap()).await?;
875                }
876                r = task_recv.recv().fuse() => {
877                    match r.unwrap() {
878                        TaskRequest::Inspect(req) => req.inspect(&mut *self),
879                        TaskRequest::Save(rpc) => rpc.handle(async |()| {
880                             self.handle_save().await
881                        }).await,
882                        TaskRequest::Restore(rpc) => rpc.handle(async |state|  {
883                            self.handle_restore(state).await
884                        }).await,
885                        TaskRequest::Start => self.handle_start().await,
886                        TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
887                    }
888                }
889                r = self.channel_workers.select_next_some() => {
890                    self.handle_revoked(r).await;
891                }
892            }
893        }
894    }
895}