hyperv_ic_guest/
shutdown.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! The shutdown IC client.
5
6#![cfg(target_os = "linux")]
7#![forbid(unsafe_code)]
8
9pub use hyperv_ic_protocol::shutdown::INSTANCE_ID;
10pub use hyperv_ic_protocol::shutdown::INTERFACE_ID;
11
12use guid::Guid;
13use hyperv_ic_protocol::FRAMEWORK_VERSION_1;
14use hyperv_ic_protocol::FRAMEWORK_VERSION_3;
15use hyperv_ic_protocol::Status;
16use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_1;
17use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3;
18use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_1;
19use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_2;
20use hyperv_ic_resources::shutdown::ShutdownParams;
21use hyperv_ic_resources::shutdown::ShutdownResult;
22use hyperv_ic_resources::shutdown::ShutdownType;
23use inspect::Inspect;
24use inspect::InspectMut;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use std::io::IoSlice;
28use std::mem::size_of_val;
29use task_control::Cancelled;
30use task_control::StopTask;
31use thiserror::Error;
32use vmbus_async::async_dgram::AsyncRecvExt;
33use vmbus_async::async_dgram::AsyncSendExt;
34use vmbus_async::pipe::MessagePipe;
35use vmbus_channel::RawAsyncChannel;
36use vmbus_channel::channel::ChannelOpenError;
37use vmbus_relay_intercept_device::OfferResponse;
38use vmbus_relay_intercept_device::SaveRestoreSimpleVmbusClientDevice;
39use vmbus_relay_intercept_device::SimpleVmbusClientDevice;
40use vmbus_relay_intercept_device::SimpleVmbusClientDeviceAsync;
41use vmbus_relay_intercept_device::ring_buffer::MemoryBlockRingBuffer;
42use vmbus_ring::RingMem;
43use vmcore::save_restore::NoSavedState;
44use zerocopy::FromBytes;
45use zerocopy::FromZeros;
46use zerocopy::IntoBytes;
47
48/// A shutdown IC client device.
49#[derive(InspectMut)]
50pub struct ShutdownGuestIc {
51    #[inspect(skip)]
52    send_shutdown_notification: mesh::Sender<Rpc<ShutdownParams, ShutdownResult>>,
53    #[inspect(skip)]
54    recv_shutdown_notification: Option<mesh::Receiver<Rpc<ShutdownParams, ShutdownResult>>>,
55}
56
57#[derive(Debug, Inspect)]
58#[inspect(tag = "channel_state")]
59enum ShutdownGuestChannelState {
60    NegotiateVersion,
61    Running {
62        #[inspect(display)]
63        framework_version: hyperv_ic_protocol::Version,
64        #[inspect(display)]
65        message_version: hyperv_ic_protocol::Version,
66    },
67}
68
69/// Established channel between guest and host.
70#[derive(InspectMut)]
71pub struct ShutdownGuestChannel {
72    /// Current state.
73    state: ShutdownGuestChannelState,
74    /// Vmbus pipe to the host.
75    #[inspect(mut)]
76    pipe: MessagePipe<MemoryBlockRingBuffer>,
77}
78
79#[derive(Debug, Error)]
80enum Error {
81    #[error("ring buffer error")]
82    Ring(#[source] std::io::Error),
83    #[error("truncated message")]
84    TruncatedMessage,
85}
86
87impl ShutdownGuestIc {
88    /// Returns a new shutdown IC client device.
89    pub fn new() -> Self {
90        let (send_shutdown_notification, recv_shutdown_notification) = mesh::channel();
91        Self {
92            send_shutdown_notification,
93            recv_shutdown_notification: Some(recv_shutdown_notification),
94        }
95    }
96
97    /// Returns the notifier that will receive any shutdown requests from the host.
98    pub fn get_shutdown_notifier(&mut self) -> mesh::Receiver<Rpc<ShutdownParams, ShutdownResult>> {
99        self.recv_shutdown_notification
100            .take()
101            .expect("can only be called once")
102    }
103}
104
105impl ShutdownGuestChannel {
106    fn new(pipe: MessagePipe<MemoryBlockRingBuffer>) -> Self {
107        Self {
108            state: ShutdownGuestChannelState::NegotiateVersion,
109            pipe,
110        }
111    }
112
113    async fn process(&mut self, ic: &mut ShutdownGuestIc) -> Result<(), Error> {
114        loop {
115            match read_from_pipe(&mut self.pipe).await {
116                Ok(buf) => {
117                    self.handle_host_message(&buf, ic).await;
118                }
119                Err(err) => {
120                    tracelimit::error_ratelimited!(
121                        err = &err as &dyn std::error::Error,
122                        "reading shutdown packet from host",
123                    );
124                }
125            }
126        }
127    }
128
129    async fn handle_host_message(&mut self, buf: &[u8], ic: &ShutdownGuestIc) {
130        // TODO: zerocopy: err (https://github.com/microsoft/openvmm/issues/759)
131        let (header, rest) = match hyperv_ic_protocol::Header::read_from_prefix(buf).ok() {
132            Some((h, r)) => (h, r),
133            None => {
134                tracelimit::error_ratelimited!("invalid shutdown packet from host",);
135                return;
136            }
137        };
138        match header.message_type {
139            hyperv_ic_protocol::MessageType::VERSION_NEGOTIATION => {
140                // Version negotiation can happen multiple times due to various
141                // state changes on the host. This message triggers a reset
142                // of the current state.
143                self.state = ShutdownGuestChannelState::NegotiateVersion;
144                if let Err(err) = self.handle_version_negotiation(&header, rest).await {
145                    tracelimit::error_ratelimited!(
146                        err = &err as &dyn std::error::Error,
147                        "Failed version negotiation"
148                    );
149                }
150            }
151            hyperv_ic_protocol::MessageType::SHUTDOWN
152                if matches!(self.state, ShutdownGuestChannelState::Running { .. }) =>
153            {
154                if let Err(err) = self.handle_shutdown_notification(&header, rest, ic).await {
155                    tracelimit::error_ratelimited!(
156                        err = &err as &dyn std::error::Error,
157                        "Failed processing shutdown message"
158                    );
159                }
160            }
161            _ => {
162                tracelimit::error_ratelimited!(r#type = ?header.message_type, state = ?self.state, "Unrecognized packet");
163            }
164        }
165    }
166
167    fn find_latest_supported_version<'a>(
168        buf: &'a [u8],
169        count: usize,
170        supported: &[hyperv_ic_protocol::Version],
171    ) -> (Option<hyperv_ic_protocol::Version>, &'a [u8]) {
172        let mut rest = buf;
173        let mut next_version;
174        let mut latest_version = None;
175        for _ in 0..count {
176            // TODO: zerocopy: err (https://github.com/microsoft/openvmm/issues/759)
177            (next_version, rest) = match hyperv_ic_protocol::Version::read_from_prefix(rest).ok() {
178                Some((n, r)) => (n, r),
179                None => {
180                    tracelimit::error_ratelimited!("truncated message version list");
181                    return (latest_version, rest);
182                }
183            };
184            for known in supported {
185                if known.major == next_version.major && known.minor == next_version.minor {
186                    if latest_version.is_some() {
187                        if next_version.major >= latest_version.unwrap().major {
188                            if next_version.major > latest_version.unwrap().major
189                                || next_version.minor > latest_version.unwrap().minor
190                            {
191                                latest_version = Some(next_version);
192                            }
193                        }
194                    } else {
195                        latest_version = Some(next_version);
196                    }
197                }
198            }
199        }
200        (latest_version, rest)
201    }
202
203    async fn handle_version_negotiation(
204        &mut self,
205        header: &hyperv_ic_protocol::Header,
206        msg: &[u8],
207    ) -> Result<(), Error> {
208        const FRAMEWORK_VERSIONS: &[hyperv_ic_protocol::Version] =
209            &[FRAMEWORK_VERSION_1, FRAMEWORK_VERSION_3];
210
211        const SHUTDOWN_VERSIONS: &[hyperv_ic_protocol::Version] = &[
212            SHUTDOWN_VERSION_1,
213            SHUTDOWN_VERSION_3,
214            SHUTDOWN_VERSION_3_1,
215            SHUTDOWN_VERSION_3_2,
216        ];
217
218        let (prefix, rest) = hyperv_ic_protocol::NegotiateMessage::read_from_prefix(msg)
219            .map_err(|_| Error::TruncatedMessage)?; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
220        let (latest_framework_version, rest) = Self::find_latest_supported_version(
221            rest,
222            prefix.framework_version_count as usize,
223            FRAMEWORK_VERSIONS,
224        );
225        let framework_version = if let Some(version) = latest_framework_version {
226            version
227        } else {
228            tracelimit::error_ratelimited!("Unsupported framework version");
229            FRAMEWORK_VERSIONS[FRAMEWORK_VERSIONS.len() - 1]
230        };
231        let (latest_message_version, _) = Self::find_latest_supported_version(
232            rest,
233            prefix.message_version_count as usize,
234            SHUTDOWN_VERSIONS,
235        );
236        let message_version = if let Some(version) = latest_message_version {
237            version
238        } else {
239            tracelimit::error_ratelimited!("Unsupported message version");
240            SHUTDOWN_VERSIONS[SHUTDOWN_VERSIONS.len() - 1]
241        };
242
243        let message = hyperv_ic_protocol::NegotiateMessage {
244            framework_version_count: 1,
245            message_version_count: 1,
246            ..FromZeros::new_zeroed()
247        };
248        let response = hyperv_ic_protocol::Header {
249            message_type: hyperv_ic_protocol::MessageType::VERSION_NEGOTIATION,
250            message_size: (size_of_val(&message)
251                + size_of_val(&framework_version)
252                + size_of_val(&message_version)) as u16,
253            status: Status::SUCCESS,
254            transaction_id: header.transaction_id,
255            flags: hyperv_ic_protocol::HeaderFlags::new()
256                .with_transaction(header.flags.transaction())
257                .with_response(true),
258            ..FromZeros::new_zeroed()
259        };
260        self.pipe
261            .send_vectored(&[
262                IoSlice::new(response.as_bytes()),
263                IoSlice::new(message.as_bytes()),
264                IoSlice::new(framework_version.as_bytes()),
265                IoSlice::new(message_version.as_bytes()),
266            ])
267            .await
268            .map_err(Error::Ring)?;
269
270        tracing::info!(%framework_version, %message_version, "version negotiated");
271        self.state = ShutdownGuestChannelState::Running {
272            framework_version,
273            message_version,
274        };
275        Ok(())
276    }
277
278    async fn handle_shutdown_notification(
279        &mut self,
280        header: &hyperv_ic_protocol::Header,
281        buf: &[u8],
282        ic: &ShutdownGuestIc,
283    ) -> Result<(), Error> {
284        let ShutdownGuestChannelState::Running {
285            framework_version,
286            message_version,
287        } = &self.state
288        else {
289            panic!("Shutdown message processing while in invalid state");
290        };
291
292        let message = hyperv_ic_protocol::shutdown::ShutdownMessage::read_from_prefix(buf)
293            .map_err(|_| Error::TruncatedMessage)?
294            .0; // TODO: zerocopy: map_err (https://github.com/microsoft/openvmm/issues/759)
295        let shutdown_type = if message.flags.restart() {
296            ShutdownType::Reboot
297        } else if message.flags.hibernate() {
298            ShutdownType::Hibernate
299        } else {
300            ShutdownType::PowerOff
301        };
302        let params = ShutdownParams {
303            shutdown_type,
304            force: message.flags.force(),
305        };
306
307        // Notify the internal listener and wait for a response.
308        let result = ic.send_shutdown_notification.call(|x| x, params).await;
309
310        // Respond to the request.
311        let response = hyperv_ic_protocol::Header {
312            framework_version: *framework_version,
313            message_version: *message_version,
314            message_type: hyperv_ic_protocol::MessageType::SHUTDOWN,
315            message_size: 0,
316            status: if result.is_ok() {
317                Status::SUCCESS
318            } else {
319                Status::FAIL
320            },
321            transaction_id: header.transaction_id,
322            flags: hyperv_ic_protocol::HeaderFlags::new()
323                .with_transaction(header.flags.transaction())
324                .with_response(true),
325            ..FromZeros::new_zeroed()
326        };
327        self.pipe
328            .send(response.as_bytes())
329            .await
330            .map_err(Error::Ring)
331    }
332}
333
334async fn read_from_pipe<T: RingMem>(pipe: &mut MessagePipe<T>) -> Result<Vec<u8>, Error> {
335    let mut buf = vec![0; hyperv_ic_protocol::MAX_MESSAGE_SIZE];
336    let n = pipe.recv(&mut buf).await.map_err(Error::Ring)?;
337    let buf = &buf[..n];
338    Ok(buf.to_vec())
339}
340
341impl SimpleVmbusClientDevice for ShutdownGuestIc {
342    type SavedState = NoSavedState;
343    type Runner = ShutdownGuestChannel;
344
345    fn instance_id(&self) -> Guid {
346        INSTANCE_ID
347    }
348
349    fn offer(&self, _offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse {
350        OfferResponse::Open
351    }
352
353    fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>) {
354        req.respond().merge(self).merge(runner);
355    }
356
357    fn open(
358        &mut self,
359        _channel_idx: u16,
360        channel: RawAsyncChannel<MemoryBlockRingBuffer>,
361    ) -> Result<Self::Runner, ChannelOpenError> {
362        let pipe = MessagePipe::new(channel)?;
363        Ok(ShutdownGuestChannel::new(pipe))
364    }
365
366    fn close(&mut self, _channel_idx: u16) {}
367
368    fn supports_save_restore(
369        &mut self,
370    ) -> Option<
371        &mut dyn SaveRestoreSimpleVmbusClientDevice<
372            SavedState = Self::SavedState,
373            Runner = Self::Runner,
374        >,
375    > {
376        None
377    }
378}
379
380impl SimpleVmbusClientDeviceAsync for ShutdownGuestIc {
381    async fn run(
382        &mut self,
383        stop: &mut StopTask<'_>,
384        runner: &mut Self::Runner,
385    ) -> Result<(), Cancelled> {
386        stop.until_stopped(async {
387            match runner.process(self).await {
388                Ok(()) => {}
389                Err(err) => {
390                    tracing::error!(
391                        error = &err as &dyn std::error::Error,
392                        "shutdown ic relay error"
393                    )
394                }
395            }
396        })
397        .await
398    }
399}