vmbus_relay/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Implements a vmbus channel relay, which consumes channels from the host
5//! vmbus control plane (via [`vmbus_client`]) and relays them as channels to
6//! the guest OS (via [`vmbus_server`]).
7//!
8//! This is used to allow the paravisor to implement the vmbus control plane
9//! while still passing through channels from the host, without any paravisor
10//! presence in the data plane.
11
12#![expect(missing_docs)]
13#![forbid(unsafe_code)]
14
15pub mod legacy_saved_state;
16mod saved_state;
17
18pub use saved_state::SavedState;
19
20use anyhow::Context;
21use anyhow::Result;
22use client::ModifyConnectionRequest;
23use futures::FutureExt;
24use futures::StreamExt;
25use futures::future::BoxFuture;
26use futures::future::OptionFuture;
27use futures::future::join_all;
28use guid::Guid;
29use inspect::Inspect;
30use inspect::InspectMut;
31use mesh::rpc::FailableRpc;
32use mesh::rpc::Rpc;
33use mesh::rpc::RpcSend;
34use pal_async::driver::SpawnDriver;
35use pal_async::task::Spawn;
36use pal_async::task::Task;
37use pal_event::Event;
38use std::collections::HashMap;
39use std::fmt::Debug;
40use std::future::Future;
41use std::pin::Pin;
42use std::sync::Arc;
43use std::sync::atomic::AtomicBool;
44use std::sync::atomic::Ordering;
45use unicycle::FuturesUnordered;
46use vmbus_channel::bus::ChannelRequest;
47use vmbus_channel::bus::ChannelServerRequest;
48use vmbus_channel::bus::GpadlRequest;
49use vmbus_channel::bus::ModifyRequest;
50use vmbus_channel::bus::OpenRequest;
51use vmbus_client as client;
52use vmbus_core::HvsockConnectRequest;
53use vmbus_core::HvsockConnectResult;
54use vmbus_core::VersionInfo;
55use vmbus_core::protocol;
56use vmbus_core::protocol::ChannelId;
57use vmbus_core::protocol::FeatureFlags;
58use vmbus_core::protocol::GpadlId;
59use vmbus_server::HvsockRelayChannelHalf;
60use vmbus_server::MnfUsage;
61use vmbus_server::ModifyRelayResponse;
62use vmbus_server::OfferInfo;
63use vmbus_server::OfferParamsInternal;
64use vmbus_server::Update;
65use vmbus_server::VmbusRelayChannelHalf;
66use vmbus_server::VmbusServerControl;
67use vmcore::interrupt::Interrupt;
68use vmcore::notify::Notify;
69use vmcore::notify::PolledNotify;
70
/// Requests sent by the relay to the owner of an intercepted channel.
///
/// A host offer whose instance ID has a registered intercept sender is not
/// offered to the guest; the offer and relay lifecycle requests are forwarded
/// over that sender instead.
pub enum InterceptChannelRequest {
    /// Sent to each intercepted channel when the relay starts.
    Start,
    /// Sent to each intercepted channel when the relay stops; the relay waits
    /// for the RPC to complete.
    Stop(Rpc<(), ()>),
    // NOTE(review): `Save` and `Restore` are not issued anywhere in this file;
    // presumably they are used by the save/restore code in `saved_state` -
    // confirm.
    Save(Rpc<(), vmcore::save_restore::SavedStateBlob>),
    Restore(vmcore::save_restore::SavedStateBlob),
    /// The host offer whose instance ID matched the intercept list.
    Offer(client::OfferInfo),
}
78
/// Feature flags the host connection must provide. `HostVmbusTransport::new`
/// fails when any of these is absent from the connection's feature flags.
const REQUIRED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
    .with_channel_interrupt_redirection(true)
    .with_guest_specified_signal_parameters(true)
    .with_modify_connection(true);
83
/// Represents a relay between a vmbus server on the host, and the vmbus server running in
/// Underhill, allowing offers from the host and offers from Underhill to be mixed.
///
/// The relay will connect to the host when it first receives a start request through its state
/// unit, and will remain connected until it is destroyed.
pub struct HostVmbusTransport {
    // Handle to the task spawned in `new`; never read after construction.
    _relay_task: Task<()>,
    // Sends `TaskRequest`s (start/stop/save/restore/inspect) to the relay task.
    task_send: mesh::Sender<TaskRequest>,
}
93
94impl HostVmbusTransport {
95    /// Create a new instance of the host vmbus relay.
96    pub async fn new(
97        driver: impl SpawnDriver + Clone,
98        control: Arc<VmbusServerControl>,
99        channel: VmbusRelayChannelHalf,
100        hvsock_relay: HvsockRelayChannelHalf,
101        vmbus_client: client::VmbusClientAccess,
102        connection: client::ConnectResult,
103        intercept_list: Vec<(Guid, mesh::Sender<InterceptChannelRequest>)>,
104    ) -> Result<Self> {
105        if connection.version.feature_flags & REQUIRED_FEATURE_FLAGS != REQUIRED_FEATURE_FLAGS {
106            anyhow::bail!(
107                "host must support required feature flags. \
108                 Required: {REQUIRED_FEATURE_FLAGS:?}, actual: {:?}.",
109                connection.version.feature_flags
110            );
111        }
112
113        let mut relay_task = RelayTask::new(
114            Arc::new(driver.clone()),
115            control,
116            channel.response_send,
117            hvsock_relay,
118            vmbus_client,
119            connection.version,
120        );
121
122        relay_task.intercept_channels.extend(intercept_list);
123
124        for offer in connection.offers {
125            relay_task.handle_offer(offer).await?;
126        }
127
128        let (task_send, task_recv) = mesh::channel();
129
130        let relay_task = driver.spawn("vmbus hcl relay", async move {
131            relay_task
132                .run(channel.request_receive, connection.offer_recv, task_recv)
133                .await
134                .unwrap()
135        });
136
137        Ok(Self {
138            _relay_task: relay_task,
139            task_send,
140        })
141    }
142
143    pub fn start(&self) {
144        self.task_send.send(TaskRequest::Start);
145    }
146
147    pub async fn stop(&self) {
148        self.task_send.call(TaskRequest::Stop, ()).await.unwrap()
149    }
150
151    pub async fn save(&self) -> SavedState {
152        self.task_send.call(TaskRequest::Save, ()).await.unwrap()
153    }
154
155    pub async fn restore(&self, state: SavedState) -> Result<()> {
156        self.task_send
157            .call(TaskRequest::Restore, state)
158            .await
159            .unwrap()
160    }
161}
162
163impl Inspect for HostVmbusTransport {
164    fn inspect(&self, req: inspect::Request<'_>) {
165        self.task_send.send(TaskRequest::Inspect(req.defer()));
166    }
167}
168
169impl Debug for HostVmbusTransport {
170    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        writeln!(fmt, "HostVmbusTransport")
172    }
173}
174
/// State needed to relay host-to-guest interrupts.
///
/// Created when a channel is opened while interrupt relaying is enabled, and
/// dropped when the channel is closed.
struct InterruptRelay {
    /// Event signaled when the host sends an interrupt.
    notify: PolledNotify,
    /// Interrupt used to signal the guest.
    interrupt: Interrupt,
    /// Event flag used to signal the guest.
    /// FUTURE: remove once this moves into `vmbus_client` saved state.
    event_flag: u16,
}
185
/// Requests sent from the relay task to a single channel worker.
enum RelayChannelRequest {
    /// Resume processing of requests from the server.
    Start,
    /// Pause processing of requests from the server; the RPC completes once
    /// the worker has recorded the stop.
    Stop(Rpc<(), ()>),
    /// Return the channel's saved state.
    Save(Rpc<(), saved_state::Channel>),
    /// Restore the channel from the given saved state; may fail.
    Restore(FailableRpc<saved_state::Channel, ()>),
    /// Inspect the channel worker.
    Inspect(inspect::Deferred),
}
193
194impl Debug for RelayChannelRequest {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        match self {
197            RelayChannelRequest::Start => f.pad("Start"),
198            RelayChannelRequest::Stop(..) => f.pad("Stop"),
199            RelayChannelRequest::Save(..) => f.pad("Save"),
200            RelayChannelRequest::Restore(..) => f.pad("Restore"),
201            RelayChannelRequest::Inspect(..) => f.pad("Inspect"),
202        }
203    }
204}
205
/// The relay task's handle to a relayed channel's worker.
struct RelayChannelInfo {
    // Sends `RelayChannelRequest`s to the channel worker.
    relay_request_send: mesh::Sender<RelayChannelRequest>,
}
209
210impl Inspect for RelayChannelInfo {
211    fn inspect(&self, req: inspect::Request<'_>) {
212        self.relay_request_send
213            .send(RelayChannelRequest::Inspect(req.defer()));
214    }
215}
216
/// How the relay task tracks a channel offered by the host.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ChannelInfo {
    /// The channel is relayed to the guest by a dedicated channel worker.
    #[inspect(transparent)]
    Relay(RelayChannelInfo),
    /// The channel is intercepted; the value is the instance ID used to look
    /// up the sender in `RelayTask::intercept_channels`.
    #[inspect(transparent)]
    Intercept(Guid),
}
225
226impl RelayChannelInfo {
227    async fn stop(&self) {
228        if let Err(err) = self
229            .relay_request_send
230            .call(RelayChannelRequest::Stop, ())
231            .await
232        {
233            tracing::warn!(?err, "failed to request channel stop");
234        }
235    }
236
237    fn start(&self) {
238        self.relay_request_send.send(RelayChannelRequest::Start);
239    }
240}
241
/// Connects a Client channel to a Server Channel.
#[derive(Inspect)]
struct RelayChannel {
    /// The Channel Id given to us by the client
    channel_id: ChannelId,
    /// Receives requests from the relay.
    #[inspect(skip)]
    relay_request_recv: mesh::Receiver<RelayChannelRequest>,
    /// Receives requests from the server.
    #[inspect(skip)]
    server_request_recv: mesh::Receiver<ChannelRequest>,
    /// Sends requests to the server; the worker uses it to send the revoke
    /// request when it exits.
    #[inspect(skip)]
    server_request_send: mesh::Sender<ChannelServerRequest>,
    /// Closed when the channel has been revoked.
    #[inspect(skip)]
    revoke_recv: mesh::OneshotReceiver<()>,
    /// Sends requests to the client
    #[inspect(skip)]
    request_send: mesh::Sender<client::ChannelRequest>,
    /// Indicates whether or not interrupts should be relayed. This is shared with the relay server
    /// connection, which sets this to true only if the guest uses the channel bitmap.
    use_interrupt_relay: Arc<AtomicBool>,
    /// State used to relay host-to-guest interrupts.
    #[inspect(with = "Option::is_some")]
    interrupt_relay: Option<InterruptRelay>,
    /// Futures waiting for GPADL teardown to complete before responding to
    /// `vmbus_server`.
    #[inspect(skip)]
    gpadls_tearing_down: FuturesUnordered<BoxFuture<'static, ()>>,
    /// Set after an open request to the host succeeds; cleared on close.
    is_open: bool,
}
273
/// Worker that services a single relayed channel.
#[derive(InspectMut)]
struct RelayChannelTask {
    /// Driver used to make the interrupt relay's notify object pollable.
    #[inspect(skip)]
    driver: Arc<dyn SpawnDriver>,
    /// The channel being relayed.
    channel: RelayChannel,
    /// Requests from the server are only processed while this is true.
    running: bool,
}
281
282impl RelayChannelTask {
283    /// Relay open channel request from VTL0 to Host, responding with Open Result
284    async fn handle_open_channel(&mut self, open_request: &OpenRequest) -> Result<()> {
285        // If the guest uses the channel bitmap, the host can't send interrupts
286        // directly and they must be relayed.
287        let redirect_interrupt = self.channel.use_interrupt_relay.load(Ordering::SeqCst);
288        let (incoming_event, notify) = if redirect_interrupt {
289            let event = Event::new();
290            let notify = Notify::from_event(event.clone())
291                .pollable(self.driver.as_ref())
292                .context("failed to create polled notify")?;
293            Some((event, notify))
294        } else {
295            None
296        }
297        .unzip();
298
299        let opened = self
300            .channel
301            .request_send
302            .call_failable(
303                client::ChannelRequest::Open,
304                client::OpenRequest {
305                    open_data: open_request.open_data,
306                    incoming_event,
307                    use_vtl2_connection_id: false,
308                },
309            )
310            .await?;
311
312        if let Some(notify) = notify {
313            self.channel.interrupt_relay = Some(InterruptRelay {
314                notify,
315                interrupt: open_request.interrupt.clone(),
316                event_flag: opened.redirected_event_flag.unwrap(),
317            });
318        }
319
320        self.channel.is_open = true;
321
322        Ok(())
323    }
324
325    async fn handle_close_channel(&mut self) {
326        self.channel
327            .request_send
328            .call(client::ChannelRequest::Close, ())
329            .await
330            .ok();
331
332        self.channel.interrupt_relay = None;
333        self.channel.is_open = false;
334    }
335
336    /// Relay gpadl request from VTL0 to the Host and respond with gpadl created.
337    async fn handle_gpadl(&mut self, request: GpadlRequest) -> Result<()> {
338        self.channel
339            .request_send
340            .call_failable(client::ChannelRequest::Gpadl, request)
341            .await?;
342
343        Ok(())
344    }
345
346    fn handle_gpadl_teardown(&mut self, rpc: Rpc<GpadlId, ()>) {
347        let (gpadl_id, rpc) = rpc.split();
348        tracing::trace!(gpadl_id = gpadl_id.0, "Tearing down GPADL");
349
350        let call = self
351            .channel
352            .request_send
353            .call(client::ChannelRequest::TeardownGpadl, gpadl_id);
354
355        // We cannot wait for GpadlTorndown here, because the host may not send the GpadlTorndown
356        // message immediately, for example if the channel is still open and the host device still
357        // has the gpadl mapped. We should not block further requests while waiting for the
358        // response.
359        self.channel.gpadls_tearing_down.push(Box::pin(async move {
360            if let Err(err) = call.await {
361                tracing::warn!(
362                    error = &err as &dyn std::error::Error,
363                    "failed to send gpadl teardown"
364                );
365            }
366            rpc.complete(());
367        }));
368    }
369
370    async fn handle_modify_channel(&mut self, modify_request: ModifyRequest) -> Result<i32> {
371        let status = self
372            .channel
373            .request_send
374            .call(client::ChannelRequest::Modify, modify_request)
375            .await?;
376
377        Ok(status)
378    }
379
380    /// Dispatch requests sent by VTL0
381    async fn handle_server_request(&mut self, request: ChannelRequest) -> Result<()> {
382        tracing::trace!(request = ?request, "received channel request");
383        match request {
384            ChannelRequest::Open(rpc) => {
385                rpc.handle(async |open_request| {
386                    self.handle_open_channel(&open_request)
387                        .await
388                        .inspect_err(|err| {
389                            tracelimit::error_ratelimited!(
390                                err = err.as_ref() as &dyn std::error::Error,
391                                channel_id = self.channel.channel_id.0,
392                                "failed to open channel"
393                            );
394                        })
395                        .is_ok()
396                })
397                .await;
398            }
399            ChannelRequest::Gpadl(rpc) => {
400                rpc.handle(async |gpadl| {
401                    let id = gpadl.id;
402                    self.handle_gpadl(gpadl)
403                        .await
404                        .inspect_err(|err| {
405                            tracelimit::error_ratelimited!(
406                                err = err.as_ref() as &dyn std::error::Error,
407                                channel_id = self.channel.channel_id.0,
408                                gpadl_id = id.0,
409                                "failed to create gpadl"
410                            );
411                        })
412                        .is_ok()
413                })
414                .await;
415            }
416            ChannelRequest::Close(rpc) => {
417                rpc.handle(async |()| self.handle_close_channel().await)
418                    .await;
419            }
420            ChannelRequest::TeardownGpadl(rpc) => {
421                self.handle_gpadl_teardown(rpc);
422            }
423            ChannelRequest::Modify(rpc) => {
424                rpc.handle(async |request| self.handle_modify_channel(request).await.unwrap_or(-1))
425                    .await;
426            }
427        }
428
429        Ok(())
430    }
431
432    async fn handle_relay_request(&mut self, request: RelayChannelRequest) {
433        tracing::trace!(
434            channel_id = self.channel.channel_id.0,
435            ?request,
436            "received relay request"
437        );
438
439        match request {
440            RelayChannelRequest::Start => self.running = true,
441            RelayChannelRequest::Stop(rpc) => rpc.handle_sync(|()| self.running = false),
442            RelayChannelRequest::Save(rpc) => rpc.handle_sync(|_| self.handle_save()),
443            RelayChannelRequest::Restore(rpc) => {
444                rpc.handle_failable(async |state| self.handle_restore(state).await)
445                    .await
446            }
447            RelayChannelRequest::Inspect(deferred) => deferred.inspect(self),
448        }
449    }
450
    /// Request dispatch loop.
    ///
    /// Services, until the channel goes away:
    /// - requests from the relay task (always),
    /// - requests from the server (only while `running` is true),
    /// - completion of queued GPADL teardowns,
    /// - host interrupts to forward to the guest (only while an interrupt
    ///   relay is present).
    ///
    /// The loop exits when the relay request channel or the server request
    /// channel closes, or when `revoke_recv` completes. It then drains the
    /// queued GPADL teardowns and sends an explicit revoke request to the
    /// server.
    async fn run(mut self) {
        loop {
            // Resolves when the host signals the relay event; absent (never
            // resolves) when no interrupt relay is set up.
            let mut relay_event = OptionFuture::from(
                self.channel
                    .interrupt_relay
                    .as_mut()
                    .map(|e| e.notify.wait().fuse()),
            );

            // Server requests are only polled while the worker is running, so
            // they stay queued in the channel while it is stopped.
            let mut server_request = OptionFuture::from(
                self.running
                    .then(|| self.channel.server_request_recv.next()),
            );

            futures::select! { // merge semantics
                r = self.channel.relay_request_recv.next() => {
                    match r {
                        Some(request) => {
                            // Needed to avoid conflicting &mut self borrow.
                            drop(relay_event);
                            self.handle_relay_request(request).await;
                        }
                        None => {
                            // The relay task dropped its sender.
                            break;
                        }
                    }
                }
                r = server_request => {
                    match r.unwrap() {
                        Some(request) => {
                            // Needed to avoid conflicting &mut self borrow.
                            drop(relay_event);
                            // NOTE(review): `handle_server_request` always
                            // returns `Ok` as written, so this `expect` does
                            // not fire today.
                            self
                                .handle_server_request(request)
                                .await
                                .expect("failed to get server request");
                        }
                        None => {
                            // The server dropped its request sender.
                            break;
                        }
                    }
                }
                _r = (&mut self.channel.revoke_recv).fuse() => {
                    break;
                }
                // A queued GPADL teardown finished; its future has already
                // completed the server's RPC.
                () = self.channel.gpadls_tearing_down.select_next_some() => {}
                _r = relay_event => {
                    // Needed to avoid conflicting interrupt_relay borrow.
                    drop(relay_event);
                    // Forward the host's interrupt to the guest.
                    self.channel.interrupt_relay.as_ref().unwrap().interrupt.deliver();
                }
            }
        }

        // Drain GPADL teardown requests cleanly; these will all complete now
        // that the channel has been revoked.
        // NOTE(review): the loop above also exits when the relay or server
        // request channel closes, not only on revoke - confirm the pending
        // teardowns complete in those cases too.
        while let Some(()) = self.channel.gpadls_tearing_down.next().await {}

        tracing::debug!(channel_id = %self.channel.channel_id.0, "dropped channel");

        // Dropping the channel would revoke it, but since that's not synchronized there's a chance
        // we reoffer the channel before the server receives the revoke. Using the request ensures
        // that won't happen.
        if let Err(err) = self
            .channel
            .server_request_send
            .call(ChannelServerRequest::Revoke, ())
            .await
        {
            tracing::warn!(
                channel_id = self.channel.channel_id.0,
                err = &err as &dyn std::error::Error,
                "failed to send revoke request"
            );
        }
    }
528}
529
/// Requests sent from `HostVmbusTransport` to the relay task.
enum TaskRequest {
    /// Inspect the relay task.
    Inspect(inspect::Deferred),
    /// Return the relay's saved state.
    Save(Rpc<(), SavedState>),
    /// Restore the relay from saved state; the result is returned through the
    /// RPC.
    Restore(Rpc<SavedState, Result<()>>),
    /// Start the relay and all of its channels.
    Start,
    /// Stop all channels and then the relay; completes once everything has
    /// stopped.
    Stop(Rpc<(), ()>),
}
537
/// Dispatches requests between Server/Client.
#[derive(InspectMut)]
struct RelayTask {
    /// Driver used to spawn the per-channel worker tasks.
    #[inspect(skip)]
    spawner: Arc<dyn SpawnDriver>,
    /// Access to the host connection; used for connection modification and
    /// hvsock connect requests.
    #[inspect(skip)]
    vmbus_client: client::VmbusClientAccess,
    /// Version information of the host connection, supplied at construction.
    version: VersionInfo,
    /// Control interface of the server; used to offer relayed channels.
    #[inspect(skip)]
    vmbus_control: Arc<VmbusServerControl>,
    /// Channels offered by the host, keyed by the channel ID from the host
    /// offer.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_key(|x| x.0)")]
    channels: HashMap<ChannelId, ChannelInfo>,
    /// Worker tasks for relayed channels; each resolves to its channel ID
    /// when the worker exits.
    #[inspect(skip)]
    channel_workers: FuturesUnordered<Task<ChannelId>>,
    /// Senders for intercepted devices, keyed by instance ID.
    #[inspect(with = "|x| inspect::iter_by_key(x).map_value(|_| ())")]
    intercept_channels: HashMap<Guid, mesh::Sender<InterceptChannelRequest>>,
    /// Shared with every channel worker; updated from the guest's
    /// `use_interrupt_page` setting in `handle_modify`.
    use_interrupt_relay: Arc<AtomicBool>,
    /// Sends the response to each server modify request.
    #[inspect(skip)]
    server_response_send: mesh::Sender<ModifyRelayResponse>,
    /// Receives hvsock connect requests and sends back their results.
    #[inspect(skip)]
    hvsock_relay: HvsockRelayChannelHalf,
    /// Hvsock connect requests that are waiting for the host's answer.
    #[inspect(skip)]
    hvsock_requests: FuturesUnordered<HvsockRequestFuture>,
    /// Whether the relay is started; new host offers are only consumed while
    /// this is true.
    running: bool,
}
563
/// Boxed future for an in-flight hvsock connect request. It resolves to the
/// original request together with the offer returned by the client, if any.
type HvsockRequestFuture =
    Pin<Box<dyn Future<Output = (HvsockConnectRequest, Option<client::OfferInfo>)> + Sync + Send>>;
566
567impl RelayTask {
568    fn new(
569        spawner: Arc<dyn SpawnDriver>,
570        vmbus_control: Arc<VmbusServerControl>,
571        server_response_send: mesh::Sender<ModifyRelayResponse>,
572        hvsock_relay: HvsockRelayChannelHalf,
573        vmbus_client: client::VmbusClientAccess,
574        version: VersionInfo,
575    ) -> Self {
576        Self {
577            spawner,
578            vmbus_client,
579            version,
580            vmbus_control,
581            channels: HashMap::new(),
582            channel_workers: FuturesUnordered::new(),
583            intercept_channels: HashMap::new(),
584            use_interrupt_relay: Arc::new(AtomicBool::new(false)),
585            server_response_send,
586            hvsock_relay,
587            running: false,
588            hvsock_requests: FuturesUnordered::new(),
589        }
590    }
591
592    async fn handle_start(&mut self) {
593        if !self.running {
594            // Resume all channels.
595            for c in self.channels.values() {
596                match c {
597                    ChannelInfo::Relay(relay) => relay.start(),
598                    ChannelInfo::Intercept(id) => {
599                        let Some(intercept_channel) = self.intercept_channels.get(id) else {
600                            tracing::error!(%id, "Intercept device missing from list");
601                            continue;
602                        };
603                        intercept_channel.send(InterceptChannelRequest::Start);
604                    }
605                }
606            }
607
608            self.running = true;
609        }
610    }
611
612    async fn handle_stop(&mut self) {
613        if self.running {
614            // Stop all the channels before the relay itself can stop.
615            join_all(self.channels.values().map(|c| match c {
616                ChannelInfo::Relay(relay) => futures::future::Either::Left(relay.stop()),
617                ChannelInfo::Intercept(id) => futures::future::Either::Right(async {
618                    let id = *id;
619                    if let Some(intercept_channel) = self.intercept_channels.get(&id) {
620                        if let Err(err) = intercept_channel
621                            .call(InterceptChannelRequest::Stop, ())
622                            .await
623                        {
624                            tracing::error!(
625                                err = &err as &dyn std::error::Error,
626                                %id,
627                                "Failed to stop intercepted device"
628                            );
629                        }
630                    }
631                }),
632            }))
633            .await;
634
635            // Because requests are handled "synchronously" (async is used but everything is awaited
636            // before another request is handled), there is no need for rundown and the relay can
637            // stop immediately.
638            self.running = false;
639        }
640    }
641
642    /// Translates an offer received from the client to a server offer.
643    /// Additionally, sets up all the appropriate channels.
644    async fn handle_offer(&mut self, offer: client::OfferInfo) -> Result<()> {
645        let channel_id = offer.offer.channel_id.0;
646
647        if self.channels.contains_key(&ChannelId(channel_id)) {
648            anyhow::bail!("channel {channel_id} already exists");
649        }
650
651        if let Some(intercept) = self.intercept_channels.get(&offer.offer.instance_id) {
652            self.channels.insert(
653                ChannelId(channel_id),
654                ChannelInfo::Intercept(offer.offer.instance_id),
655            );
656            intercept.send(InterceptChannelRequest::Offer(offer));
657            return Ok(());
658        }
659
660        // Used to Recv requests from the server.
661        let (request_send, request_recv) = mesh::channel();
662        // Used to Send responses from the server
663        let (server_request_send, server_request_recv) = mesh::channel();
664
665        if offer.offer.is_dedicated != 1 {
666            tracing::warn!(offer = ?offer.offer, "All offers should be dedicated with Win8+ host")
667        }
668
669        // If the vmbus server is handling MnF, instead of relaying it, it will ignore this monitor
670        // ID and allocate its own.
671        let use_mnf = if offer.offer.monitor_allocated != 0 {
672            MnfUsage::Relayed {
673                monitor_id: offer.offer.monitor_id,
674            }
675        } else {
676            MnfUsage::Disabled
677        };
678
679        let params = OfferParamsInternal {
680            interface_name: "host relay".to_owned(),
681            instance_id: offer.offer.instance_id,
682            interface_id: offer.offer.interface_id,
683            mmio_megabytes: offer.offer.mmio_megabytes,
684            mmio_megabytes_optional: offer.offer.mmio_megabytes_optional,
685            subchannel_index: offer.offer.subchannel_index,
686            use_mnf,
687            // Preserve channel enumeration order from the host within the same
688            // interface type.
689            offer_order: Some(channel_id),
690            // Strip the confidential flags for relay channels if the host set them.
691            flags: offer
692                .offer
693                .flags
694                .with_confidential_ring_buffer(false)
695                .with_confidential_external_memory(false),
696            user_defined: offer.offer.user_defined,
697        };
698
699        let key = params.key();
700        let new_offer = OfferInfo {
701            params,
702            event: offer.guest_to_host_interrupt,
703            request_send,
704            server_request_recv,
705        };
706
707        // Don't send the client's channel and connection ID to the server. Instead, the server will
708        // decide its own IDs which are communicated back to the host as part of the open message
709        // using guest-specified signal parameters, which the host must support.
710        //
711        // The vmbus server will ignore the monitor ID and allocate its own if MNF is handled by it
712        // and not the host.
713        self.vmbus_control
714            .offer_core(new_offer)
715            .await
716            .with_context(|| format!("failed to offer relay channel {key}"))?;
717
718        let (relay_request_send, relay_request_recv) = mesh::channel();
719        let channel_task = RelayChannelTask {
720            driver: Arc::clone(&self.spawner),
721            channel: RelayChannel {
722                channel_id: ChannelId(channel_id),
723                relay_request_recv,
724                request_send: offer.request_send,
725                revoke_recv: offer.revoke_recv,
726                server_request_send,
727                server_request_recv: request_recv,
728                use_interrupt_relay: Arc::clone(&self.use_interrupt_relay),
729                interrupt_relay: None,
730                gpadls_tearing_down: FuturesUnordered::new(),
731                is_open: false,
732            },
733            running: self.running,
734        };
735
736        let task = self.spawner.spawn("vmbus hcl channel worker", async move {
737            channel_task.run().await;
738            ChannelId(channel_id)
739        });
740
741        self.channels.insert(
742            ChannelId(channel_id),
743            ChannelInfo::Relay(RelayChannelInfo { relay_request_send }),
744        );
745        self.channel_workers.push(task);
746
747        Ok(())
748    }
749
750    async fn handle_revoked(&mut self, channel_id: ChannelId) {
751        // The task has already completed, so just remove the channel from the list.
752        self.channels
753            .remove(&channel_id)
754            .expect("channel should exist");
755    }
756
757    async fn handle_modify(
758        &mut self,
759        request: vmbus_server::ModifyRelayRequest,
760    ) -> ModifyRelayResponse {
761        // If the guest is requesting a version change, check whether that version is not newer
762        // than what the host supports.
763        if let Some(version) = request.version {
764            if (self.version.version as u32) < version {
765                return ModifyRelayResponse::Unsupported;
766            }
767        }
768
769        if let Some(use_interrupt_page) = request.use_interrupt_page {
770            // If the guest is using the channel bitmap, the host can't send interrupts directly and
771            // must relay them through Underhill.
772            self.use_interrupt_relay
773                .store(use_interrupt_page, Ordering::SeqCst);
774        }
775
776        // If the monitor page is not changing, there is no need to send any request to the host.
777        let state = match request.monitor_page {
778            Update::Unchanged => protocol::ConnectionState::SUCCESSFUL,
779            Update::Reset => {
780                self.vmbus_client
781                    .modify(ModifyConnectionRequest { monitor_page: None })
782                    .await
783            }
784            Update::Set(value) => {
785                self.vmbus_client
786                    .modify(ModifyConnectionRequest {
787                        monitor_page: Some(value),
788                    })
789                    .await
790            }
791        };
792
793        // Use Supported only for new connections (which have a version).
794        if request.version.is_some() {
795            ModifyRelayResponse::Supported(state, self.version.feature_flags)
796        } else {
797            ModifyRelayResponse::Modified(state)
798        }
799    }
800
801    async fn handle_server_request(&mut self, request: vmbus_server::ModifyRelayRequest) {
802        tracing::trace!(request = ?request, "received server request");
803        let result = self.handle_modify(request).await;
804        self.server_response_send.send(result);
805    }
806
807    fn handle_hvsock_request(&mut self, request: HvsockConnectRequest) {
808        tracing::debug!(request = ?request, "received hvsock connect request");
809        let fut = self.vmbus_client.connect_hvsock(request);
810        self.hvsock_requests
811            .push(Box::pin(fut.map(move |offer| (request, offer))));
812    }
813
814    async fn handle_hvsock_response(
815        &mut self,
816        request: HvsockConnectRequest,
817        offer: Option<client::OfferInfo>,
818    ) {
819        let success = if let Some(offer) = offer {
820            match self.handle_offer(offer).await {
821                Ok(()) => true,
822                Err(err) => {
823                    tracing::error!(
824                        error = err.as_ref() as &dyn std::error::Error,
825                        "failed add hvsock offer"
826                    );
827                    false
828                }
829            }
830        } else {
831            false
832        };
833        self.hvsock_relay
834            .response_send
835            .send(HvsockConnectResult::from_request(&request, success));
836    }
837
838    async fn handle_offer_request(&mut self, request: client::OfferInfo) -> Result<()> {
839        if let Err(err) = self.handle_offer(request).await {
840            tracing::error!(
841                error = err.as_ref() as &dyn std::error::Error,
842                "failed to hot add offer"
843            );
844        }
845
846        Ok(())
847    }
848
    /// Main dispatch loop of the relay task.
    ///
    /// Services, in a single `select!`:
    /// - modify requests from the server (`server_recv`),
    /// - hvsock connect requests and their pending answers,
    /// - new offers from the host (`offer_recv`, only while `running`),
    /// - control requests from `HostVmbusTransport` (`task_recv`),
    /// - completion of channel worker tasks.
    ///
    /// The loop has no `break`. It returns only if `handle_offer_request`
    /// returns an error, which it never does as written. Panics if
    /// `task_recv.recv()` yields an error.
    async fn run(
        &mut self,
        server_recv: mesh::Receiver<vmbus_server::ModifyRelayRequest>,
        mut offer_recv: mesh::Receiver<client::OfferInfo>,
        mut task_recv: mesh::Receiver<TaskRequest>,
    ) -> Result<()> {
        let mut server_recv = server_recv.fuse();
        loop {
            // Host offers are only polled while the relay is running, so they
            // stay queued in the channel while it is stopped.
            let mut offer_recv =
                OptionFuture::from(self.running.then(|| offer_recv.select_next_some()));

            futures::select! { // merge semantics
                r = server_recv.select_next_some() => {
                    self.handle_server_request(r).await;
                }
                r = self.hvsock_relay.request_receive.select_next_some() => {
                    self.handle_hvsock_request(r);
                }
                // An hvsock connect request received its answer.
                r = self.hvsock_requests.select_next_some() => {
                    self.handle_hvsock_response(r.0, r.1).await;
                }
                r = offer_recv => {
                    self.handle_offer_request(r.unwrap()).await?;
                }
                r = task_recv.recv().fuse() => {
                    match r.unwrap() {
                        TaskRequest::Inspect(req) => req.inspect(&mut *self),
                        TaskRequest::Save(rpc) => rpc.handle(async |()| {
                             self.handle_save().await
                        }).await,
                        TaskRequest::Restore(rpc) => rpc.handle(async |state|  {
                            self.handle_restore(state).await
                        }).await,
                        TaskRequest::Start => self.handle_start().await,
                        TaskRequest::Stop(rpc) => rpc.handle(async |()| self.handle_stop().await).await,
                    }
                }
                // A channel worker exited; it resolves to its channel ID.
                r = self.channel_workers.select_next_some() => {
                    self.handle_revoked(r).await;
                }
            }
        }
    }
892}