hyperv_ic_guest/
shutdown.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! The shutdown IC client.
5
6#![forbid(unsafe_code)]
7
8pub use hyperv_ic_protocol::shutdown::INSTANCE_ID;
9pub use hyperv_ic_protocol::shutdown::INTERFACE_ID;
10
11use guid::Guid;
12use hyperv_ic_protocol::FRAMEWORK_VERSION_1;
13use hyperv_ic_protocol::FRAMEWORK_VERSION_3;
14use hyperv_ic_protocol::Status;
15use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_1;
16use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3;
17use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_1;
18use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_2;
19use hyperv_ic_resources::shutdown::ShutdownParams;
20use hyperv_ic_resources::shutdown::ShutdownResult;
21use hyperv_ic_resources::shutdown::ShutdownType;
22use inspect::Inspect;
23use inspect::InspectMut;
24use mesh::rpc::Rpc;
25use mesh::rpc::RpcSend;
26use std::io::IoSlice;
27use std::mem::size_of_val;
28use task_control::Cancelled;
29use task_control::StopTask;
30use thiserror::Error;
31use vmbus_async::async_dgram::AsyncRecvExt;
32use vmbus_async::async_dgram::AsyncSendExt;
33use vmbus_async::pipe::MessagePipe;
34use vmbus_channel::RawAsyncChannel;
35use vmbus_channel::channel::ChannelOpenError;
36use vmbus_relay_intercept_device::OfferResponse;
37use vmbus_relay_intercept_device::SaveRestoreSimpleVmbusClientDevice;
38use vmbus_relay_intercept_device::SimpleVmbusClientDevice;
39use vmbus_relay_intercept_device::SimpleVmbusClientDeviceAsync;
40use vmbus_relay_intercept_device::ring_buffer::MemoryBlockRingBuffer;
41use vmbus_ring::RingMem;
42use vmcore::save_restore::NoSavedState;
43use zerocopy::FromBytes;
44use zerocopy::FromZeros;
45use zerocopy::IntoBytes;
46
47/// A shutdown IC client device.
48#[derive(InspectMut)]
49pub struct ShutdownGuestIc {
50    #[inspect(skip)]
51    send_shutdown_notification: mesh::Sender<Rpc<ShutdownParams, ShutdownResult>>,
52    #[inspect(skip)]
53    recv_shutdown_notification: Option<mesh::Receiver<Rpc<ShutdownParams, ShutdownResult>>>,
54}
55
56#[derive(Debug, Inspect)]
57#[inspect(tag = "channel_state")]
58enum ShutdownGuestChannelState {
59    NegotiateVersion,
60    Running {
61        #[inspect(display)]
62        framework_version: hyperv_ic_protocol::Version,
63        #[inspect(display)]
64        message_version: hyperv_ic_protocol::Version,
65    },
66}
67
68/// Established channel between guest and host.
69#[derive(InspectMut)]
70pub struct ShutdownGuestChannel {
71    /// Current state.
72    state: ShutdownGuestChannelState,
73    /// Vmbus pipe to the host.
74    #[inspect(mut)]
75    pipe: MessagePipe<MemoryBlockRingBuffer>,
76}
77
78#[derive(Debug, Error)]
79enum Error {
80    #[error("ring buffer error")]
81    Ring(#[source] std::io::Error),
82    #[error("truncated message")]
83    TruncatedMessage,
84}
85
86impl ShutdownGuestIc {
87    /// Returns a new shutdown IC client device.
88    pub fn new() -> Self {
89        let (send_shutdown_notification, recv_shutdown_notification) = mesh::channel();
90        Self {
91            send_shutdown_notification,
92            recv_shutdown_notification: Some(recv_shutdown_notification),
93        }
94    }
95
96    /// Returns the notifier that will receive any shutdown requests from the host.
97    pub fn get_shutdown_notifier(&mut self) -> mesh::Receiver<Rpc<ShutdownParams, ShutdownResult>> {
98        self.recv_shutdown_notification
99            .take()
100            .expect("can only be called once")
101    }
102}
103
104impl ShutdownGuestChannel {
105    fn new(pipe: MessagePipe<MemoryBlockRingBuffer>) -> Self {
106        Self {
107            state: ShutdownGuestChannelState::NegotiateVersion,
108            pipe,
109        }
110    }
111
112    async fn process(&mut self, ic: &mut ShutdownGuestIc) -> Result<(), Error> {
113        loop {
114            match read_from_pipe(&mut self.pipe).await {
115                Ok(buf) => {
116                    self.handle_host_message(&buf, ic).await;
117                }
118                Err(err) => {
119                    tracelimit::error_ratelimited!(
120                        err = &err as &dyn std::error::Error,
121                        "reading shutdown packet from host",
122                    );
123                }
124            }
125        }
126    }
127
128    async fn handle_host_message(&mut self, buf: &[u8], ic: &ShutdownGuestIc) {
129        // TODO: zerocopy: err (https://github.com/microsoft/openvmm/issues/759)
130        let (header, rest) = match hyperv_ic_protocol::Header::read_from_prefix(buf).ok() {
131            Some((h, r)) => (h, r),
132            None => {
133                tracelimit::error_ratelimited!("invalid shutdown packet from host",);
134                return;
135            }
136        };
137        match header.message_type {
138            hyperv_ic_protocol::MessageType::VERSION_NEGOTIATION => {
139                // Version negotiation can happen multiple times due to various
140                // state changes on the host. This message triggers a reset
141                // of the current state.
142                self.state = ShutdownGuestChannelState::NegotiateVersion;
143                if let Err(err) = self.handle_version_negotiation(&header, rest).await {
144                    tracelimit::error_ratelimited!(
145                        err = &err as &dyn std::error::Error,
146                        "Failed version negotiation"
147                    );
148                }
149            }
150            hyperv_ic_protocol::MessageType::SHUTDOWN
151                if matches!(self.state, ShutdownGuestChannelState::Running { .. }) =>
152            {
153                if let Err(err) = self.handle_shutdown_notification(&header, rest, ic).await {
154                    tracelimit::error_ratelimited!(
155                        err = &err as &dyn std::error::Error,
156                        "Failed processing shutdown message"
157                    );
158                }
159            }
160            _ => {
161                tracelimit::error_ratelimited!(r#type = ?header.message_type, state = ?self.state, "Unrecognized packet");
162            }
163        }
164    }
165
166    fn find_latest_supported_version<'a>(
167        buf: &'a [u8],
168        count: usize,
169        supported: &[hyperv_ic_protocol::Version],
170    ) -> (Option<hyperv_ic_protocol::Version>, &'a [u8]) {
171        let mut rest = buf;
172        let mut next_version;
173        let mut latest_version = None;
174        for _ in 0..count {
175            // TODO: zerocopy: err (https://github.com/microsoft/openvmm/issues/759)
176            (next_version, rest) = match hyperv_ic_protocol::Version::read_from_prefix(rest).ok() {
177                Some((n, r)) => (n, r),
178                None => {
179                    tracelimit::error_ratelimited!("truncated message version list");
180                    return (latest_version, rest);
181                }
182            };
183            for known in supported {
184                if known.major == next_version.major && known.minor == next_version.minor {
185                    if latest_version.is_some() {
186                        if next_version.major >= latest_version.unwrap().major {
187                            if next_version.major > latest_version.unwrap().major
188                                || next_version.minor > latest_version.unwrap().minor
189                            {
190                                latest_version = Some(next_version);
191                            }
192                        }
193                    } else {
194                        latest_version = Some(next_version);
195                    }
196                }
197            }
198        }
199        (latest_version, rest)
200    }
201
202    async fn handle_version_negotiation(
203        &mut self,
204        header: &hyperv_ic_protocol::Header,
205        msg: &[u8],
206    ) -> Result<(), Error> {
207        const FRAMEWORK_VERSIONS: &[hyperv_ic_protocol::Version] =
208            &[FRAMEWORK_VERSION_1, FRAMEWORK_VERSION_3];
209
210        const SHUTDOWN_VERSIONS: &[hyperv_ic_protocol::Version] = &[
211            SHUTDOWN_VERSION_1,
212            SHUTDOWN_VERSION_3,
213            SHUTDOWN_VERSION_3_1,
214            SHUTDOWN_VERSION_3_2,
215        ];
216
217        let (prefix, rest) = hyperv_ic_protocol::NegotiateMessage::read_from_prefix(msg)
218            .map_err(|_| Error::TruncatedMessage)?; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
219        let (latest_framework_version, rest) = Self::find_latest_supported_version(
220            rest,
221            prefix.framework_version_count as usize,
222            FRAMEWORK_VERSIONS,
223        );
224        let framework_version = if let Some(version) = latest_framework_version {
225            version
226        } else {
227            tracelimit::error_ratelimited!("Unsupported framework version");
228            FRAMEWORK_VERSIONS[FRAMEWORK_VERSIONS.len() - 1]
229        };
230        let (latest_message_version, _) = Self::find_latest_supported_version(
231            rest,
232            prefix.message_version_count as usize,
233            SHUTDOWN_VERSIONS,
234        );
235        let message_version = if let Some(version) = latest_message_version {
236            version
237        } else {
238            tracelimit::error_ratelimited!("Unsupported message version");
239            SHUTDOWN_VERSIONS[SHUTDOWN_VERSIONS.len() - 1]
240        };
241
242        let message = hyperv_ic_protocol::NegotiateMessage {
243            framework_version_count: 1,
244            message_version_count: 1,
245            ..FromZeros::new_zeroed()
246        };
247        let response = hyperv_ic_protocol::Header {
248            message_type: hyperv_ic_protocol::MessageType::VERSION_NEGOTIATION,
249            message_size: (size_of_val(&message)
250                + size_of_val(&framework_version)
251                + size_of_val(&message_version)) as u16,
252            status: Status::SUCCESS,
253            transaction_id: header.transaction_id,
254            flags: hyperv_ic_protocol::HeaderFlags::new()
255                .with_transaction(header.flags.transaction())
256                .with_response(true),
257            ..FromZeros::new_zeroed()
258        };
259        self.pipe
260            .send_vectored(&[
261                IoSlice::new(response.as_bytes()),
262                IoSlice::new(message.as_bytes()),
263                IoSlice::new(framework_version.as_bytes()),
264                IoSlice::new(message_version.as_bytes()),
265            ])
266            .await
267            .map_err(Error::Ring)?;
268
269        tracing::info!(%framework_version, %message_version, "version negotiated");
270        self.state = ShutdownGuestChannelState::Running {
271            framework_version,
272            message_version,
273        };
274        Ok(())
275    }
276
277    async fn handle_shutdown_notification(
278        &mut self,
279        header: &hyperv_ic_protocol::Header,
280        buf: &[u8],
281        ic: &ShutdownGuestIc,
282    ) -> Result<(), Error> {
283        let ShutdownGuestChannelState::Running {
284            framework_version,
285            message_version,
286        } = &self.state
287        else {
288            panic!("Shutdown message processing while in invalid state");
289        };
290
291        let message = hyperv_ic_protocol::shutdown::ShutdownMessage::read_from_prefix(buf)
292            .map_err(|_| Error::TruncatedMessage)?
293            .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
294        let shutdown_type = if message.flags.restart() {
295            ShutdownType::Reboot
296        } else if message.flags.hibernate() {
297            ShutdownType::Hibernate
298        } else {
299            ShutdownType::PowerOff
300        };
301        let params = ShutdownParams {
302            shutdown_type,
303            force: message.flags.force(),
304        };
305
306        // Notify the internal listener and wait for a response.
307        let status = match ic.send_shutdown_notification.call(|x| x, params).await {
308            Ok(ShutdownResult::Ok) => Status::SUCCESS,
309            Ok(ShutdownResult::Failed(x)) => Status(x),
310            Ok(ShutdownResult::NotReady) | Ok(ShutdownResult::AlreadyInProgress) | Err(_) => {
311                Status::FAIL
312            }
313        };
314
315        // Respond to the request.
316        let response = hyperv_ic_protocol::Header {
317            framework_version: *framework_version,
318            message_version: *message_version,
319            message_type: hyperv_ic_protocol::MessageType::SHUTDOWN,
320            message_size: 0,
321            status,
322            transaction_id: header.transaction_id,
323            flags: hyperv_ic_protocol::HeaderFlags::new()
324                .with_transaction(header.flags.transaction())
325                .with_response(true),
326            ..FromZeros::new_zeroed()
327        };
328        self.pipe
329            .send(response.as_bytes())
330            .await
331            .map_err(Error::Ring)
332    }
333}
334
335async fn read_from_pipe<T: RingMem>(pipe: &mut MessagePipe<T>) -> Result<Vec<u8>, Error> {
336    let mut buf = vec![0; hyperv_ic_protocol::MAX_MESSAGE_SIZE];
337    let n = pipe.recv(&mut buf).await.map_err(Error::Ring)?;
338    let buf = &buf[..n];
339    Ok(buf.to_vec())
340}
341
342impl SimpleVmbusClientDevice for ShutdownGuestIc {
343    type SavedState = NoSavedState;
344    type Runner = ShutdownGuestChannel;
345
346    fn instance_id(&self) -> Guid {
347        INSTANCE_ID
348    }
349
350    fn offer(&self, _offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse {
351        OfferResponse::Open
352    }
353
354    fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>) {
355        req.respond().merge(self).merge(runner);
356    }
357
358    fn open(
359        &mut self,
360        _channel_idx: u16,
361        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
362    ) -> Result<Self::Runner, ChannelOpenError> {
363        let pipe = MessagePipe::new(channel)?;
364        Ok(ShutdownGuestChannel::new(pipe))
365    }
366
367    fn close(&mut self, _channel_idx: u16) {}
368
369    fn supports_save_restore(
370        &mut self,
371    ) -> Option<
372        &mut dyn SaveRestoreSimpleVmbusClientDevice<
373            SavedState = Self::SavedState,
374            Runner = Self::Runner,
375        >,
376    > {
377        None
378    }
379}
380
381impl SimpleVmbusClientDeviceAsync for ShutdownGuestIc {
382    async fn run(
383        &mut self,
384        stop: &mut StopTask<'_>,
385        runner: &mut Self::Runner,
386    ) -> Result<(), Cancelled> {
387        stop.until_stopped(async {
388            match runner.process(self).await {
389                Ok(()) => {}
390                Err(err) => {
391                    tracing::error!(
392                        error = &err as &dyn std::error::Error,
393                        "shutdown ic relay error"
394                    )
395                }
396            }
397        })
398        .await
399    }
400}