hyperv_ic/
shutdown.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! The shutdown IC.
5
6use crate::common::IcPipe;
7use crate::common::NegotiateState;
8use crate::common::Versions;
9use async_trait::async_trait;
10use futures::FutureExt;
11use futures::StreamExt;
12use futures::stream::once;
13use futures_concurrency::stream::Merge;
14use hyperv_ic_protocol::Status;
15use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_1;
16use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3;
17use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_1;
18use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_2;
19use hyperv_ic_resources::shutdown::ShutdownParams;
20use hyperv_ic_resources::shutdown::ShutdownResult;
21use hyperv_ic_resources::shutdown::ShutdownRpc;
22use hyperv_ic_resources::shutdown::ShutdownType;
23use inspect::Inspect;
24use inspect::InspectMut;
25use mesh::rpc::Rpc;
26use std::pin::pin;
27use task_control::Cancelled;
28use task_control::StopTask;
29use vmbus_channel::RawAsyncChannel;
30use vmbus_channel::bus::ChannelType;
31use vmbus_channel::bus::OfferParams;
32use vmbus_channel::channel::ChannelOpenError;
33use vmbus_channel::gpadl_ring::GpadlRingMem;
34use vmbus_channel::simple::SaveRestoreSimpleVmbusDevice;
35use vmbus_channel::simple::SimpleVmbusDevice;
36use zerocopy::IntoBytes;
37
/// Shutdown IC message versions supported by this device, listed from oldest
/// to newest. Passed to `IcPipe::negotiate` during version negotiation.
const SHUTDOWN_VERSIONS: &[hyperv_ic_protocol::Version] = &[
    SHUTDOWN_VERSION_1,
    SHUTDOWN_VERSION_3,
    SHUTDOWN_VERSION_3_1,
    SHUTDOWN_VERSION_3_2,
];
44
/// A shutdown IC device.
///
/// Services [`ShutdownRpc`] requests received on `recv` and forwards shutdown
/// requests to the guest over the `shutdown_ic` vmbus pipe channel.
#[derive(InspectMut)]
pub struct ShutdownIc {
    /// Incoming `WaitReady` and `Shutdown` requests.
    #[inspect(skip)]
    recv: mesh::Receiver<ShutdownRpc>,
    /// `WaitReady` requests received while version negotiation is still in
    /// progress. They are completed when negotiation finishes. This list is
    /// stored on the device rather than on the per-channel runner
    /// (`ShutdownChannel`).
    #[inspect(skip)]
    wait_ready: Vec<Rpc<(), mesh::OneshotReceiver<()>>>,
}
53
/// Per-channel runner for [`ShutdownIc`]. It is created when the channel is
/// opened or restored.
#[doc(hidden)]
#[derive(InspectMut)]
pub struct ShutdownChannel {
    /// IC message pipe built over the raw vmbus channel.
    #[inspect(mut)]
    pipe: IcPipe,
    /// Current protocol state: negotiating or ready.
    state: ChannelState,
    /// Completion handle for the in-flight shutdown request, if that request
    /// arrived as an RPC on this runner. A runner starts with `None`,
    /// including after a restore, because RPC handles are not saved.
    #[inspect(with = "Option::is_some")]
    pending_shutdown: Option<Rpc<(), ShutdownResult>>,
}
63
/// Protocol state of an open shutdown IC channel.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ChannelState {
    /// Version negotiation with the guest is still in progress.
    Negotiate(#[inspect(flatten)] NegotiateState),
    /// Version negotiation has completed.
    Ready {
        /// The negotiated framework and message versions.
        versions: Versions,
        /// Progress of the current shutdown request, if any.
        state: ReadyState,
        /// Senders paired with the receivers handed to `WaitReady` callers.
        /// Nothing in this file ever sends on them. Presumably callers
        /// observe the sender being dropped (e.g. when this state or the
        /// runner goes away) — confirm against the client side.
        #[inspect(with = "|x| x.len()")]
        clients: Vec<mesh::OneshotSender<()>>,
    },
}
75
/// Shutdown progress once the channel has finished negotiation.
#[derive(Inspect)]
#[inspect(external_tag)]
enum ReadyState {
    /// Idle. A new shutdown request may be accepted.
    Ready,
    /// A shutdown request was accepted, but its message has not yet been
    /// written to the guest.
    SendShutdown(#[inspect(skip)] ShutdownParams),
    /// The shutdown message was written; waiting for the guest's response.
    WaitShutdown,
}
83
84impl ShutdownIc {
85    /// Returns a new shutdown IC, using `recv` to receive shutdown requests.
86    pub fn new(recv: mesh::Receiver<ShutdownRpc>) -> Self {
87        Self {
88            recv,
89            wait_ready: Vec::new(),
90        }
91    }
92}
93
impl ShutdownChannel {
    /// Creates a runner over `channel`.
    ///
    /// `restore_state` is a state rebuilt from saved state. When it is `None`,
    /// the runner starts version negotiation from `NegotiateState::default()`.
    /// A new runner never has a pending shutdown RPC.
    ///
    /// # Errors
    ///
    /// Returns the error from `IcPipe::new` if the pipe cannot be created.
    fn new(
        channel: RawAsyncChannel<GpadlRingMem>,
        restore_state: Option<ChannelState>,
    ) -> Result<ShutdownChannel, ChannelOpenError> {
        let pipe = IcPipe::new(channel)?;
        Ok(Self {
            pipe,
            state: restore_state.unwrap_or(ChannelState::Negotiate(NegotiateState::default())),
            pending_shutdown: None,
        })
    }

    /// Drives the channel by servicing both the guest protocol and incoming
    /// [`ShutdownRpc`] requests. It only returns when the state machine
    /// reports an error.
    ///
    /// Each loop iteration races one step of `process_state_machine` against
    /// the next request from `ic.recv`. If a request wins, the unfinished
    /// state-machine step is dropped and started again on the next iteration.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `process_state_machine`.
    async fn process(&mut self, ic: &mut ShutdownIc) -> anyhow::Result<()> {
        enum Event {
            StateMachine(anyhow::Result<()>),
            Request(ShutdownRpc),
        }

        loop {
            // Race one state-machine step against the next request.
            // `once(...)` always yields its single item before the merged
            // stream can finish, so `next()` cannot return `None` here.
            //
            // NOTE(review): dropping the losing state-machine future assumes
            // the `IcPipe` operations it awaits are cancellation-safe. `state`
            // is only updated after each await completes. Confirm.
            let event = pin!(
                (
                    once(
                        self.process_state_machine(&mut ic.wait_ready)
                            .map(Event::StateMachine)
                    ),
                    (&mut ic.recv).map(Event::Request),
                )
                    .merge()
            )
            .next()
            .await
            .unwrap();
            match event {
                Event::StateMachine(r) => {
                    // Propagate protocol errors; a successful step just loops.
                    r?;
                }
                Event::Request(req) => match req {
                    ShutdownRpc::WaitReady(rpc) => match &mut self.state {
                        // Not ready yet: queue the request. It is completed
                        // when `process_state_machine` finishes negotiation.
                        ChannelState::Negotiate(_) => ic.wait_ready.push(rpc),
                        ChannelState::Ready { clients, .. } => {
                            // Already ready: answer immediately. First drop
                            // any stored senders that report `is_closed()`,
                            // so the list does not grow without bound.
                            let (send, recv) = mesh::oneshot();
                            clients.retain(|c| !c.is_closed());
                            clients.push(send);
                            rpc.complete(recv);
                        }
                    },
                    ShutdownRpc::Shutdown(rpc) => match self.state {
                        // The guest has not finished version negotiation.
                        ChannelState::Negotiate(_) => rpc.complete(ShutdownResult::NotReady),
                        ChannelState::Ready { ref mut state, .. } => match state {
                            ReadyState::Ready => {
                                // Keep the completion half so the guest's
                                // response can be reported later. The state
                                // machine sends the message built from `input`.
                                let (input, rpc) = rpc.split();
                                self.pending_shutdown = Some(rpc);
                                *state = ReadyState::SendShutdown(input);
                            }
                            // Only one shutdown may be outstanding at a time.
                            ReadyState::SendShutdown { .. } | ReadyState::WaitShutdown => {
                                rpc.complete(ShutdownResult::AlreadyInProgress)
                            }
                        },
                    },
                },
            }
        }
    }

    /// Advances the protocol state machine by one step.
    ///
    /// - `Negotiate`: runs a negotiation step with `SHUTDOWN_VERSIONS`. Once
    ///   versions are agreed, it completes every queued `wait_ready` RPC and
    ///   switches to `Ready`.
    /// - `Ready::Ready`: nothing to do, so this never completes.
    /// - `Ready::SendShutdown`: writes the shutdown message, then switches to
    ///   `WaitShutdown`.
    /// - `Ready::WaitShutdown`: reads the guest's response, reports it to the
    ///   pending RPC if there is one, then switches back to `Ready`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the pipe's negotiate/write/read operations.
    async fn process_state_machine(
        &mut self,
        wait_ready: &mut Vec<Rpc<(), mesh::OneshotReceiver<()>>>,
    ) -> anyhow::Result<()> {
        match self.state {
            ChannelState::Negotiate(ref mut state) => {
                // `None` means negotiation is not finished yet; the next call
                // continues from the updated `state`.
                if let Some(versions) = self.pipe.negotiate(state, SHUTDOWN_VERSIONS).await? {
                    // Answer every caller that was waiting for readiness,
                    // keeping the sender half of each answer.
                    let clients = wait_ready
                        .drain(..)
                        .map(|rpc| {
                            let (send, recv) = mesh::oneshot();
                            rpc.complete(recv);
                            send
                        })
                        .collect();

                    self.state = ChannelState::Ready {
                        versions,
                        state: ReadyState::Ready,
                        clients,
                    };
                }
            }
            ChannelState::Ready {
                ref mut state,
                ref versions,
                clients: _,
            } => match state {
                // Nothing to send or receive. Park forever; in `process`,
                // incoming requests are what drive progress in this state.
                ReadyState::Ready => std::future::pending().await,
                ReadyState::SendShutdown(params) => {
                    // Power-off sets neither the restart nor the hibernate flag.
                    let mut flags =
                        hyperv_ic_protocol::shutdown::ShutdownFlags::new().with_force(params.force);
                    match params.shutdown_type {
                        ShutdownType::PowerOff => {}
                        ShutdownType::Reboot => flags.set_restart(true),
                        ShutdownType::Hibernate => flags.set_hibernate(true),
                    }

                    // NOTE(review): presumably boxed because of the 2048-byte
                    // `message` field.
                    let message = Box::new(hyperv_ic_protocol::shutdown::ShutdownMessage {
                        reason_code: hyperv_ic_protocol::shutdown::SHTDN_REASON_FLAG_PLANNED,
                        timeout_secs: 0,
                        flags,
                        message: [0; 2048],
                    });

                    self.pipe
                        .write_message(
                            versions,
                            hyperv_ic_protocol::MessageType::SHUTDOWN,
                            hyperv_ic_protocol::HeaderFlags::new()
                                .with_transaction(true)
                                .with_request(true),
                            message.as_bytes(),
                        )
                        .await?;

                    // Advance only after the write completes. If this future
                    // is dropped earlier, the state stays `SendShutdown` and
                    // the message is sent again on the next step.
                    *state = ReadyState::WaitShutdown;
                }
                ReadyState::WaitShutdown => {
                    let (status, _) = self.pipe.read_response().await?;
                    let result = if status == Status::SUCCESS {
                        ShutdownResult::Ok
                    } else {
                        ShutdownResult::Failed(status.0)
                    };
                    // After a restore there is no RPC to answer, so the
                    // result is dropped.
                    if let Some(send) = self.pending_shutdown.take() {
                        send.complete(result);
                    }
                    *state = ReadyState::Ready;
                }
            },
        }
        Ok(())
    }
}
234
235#[async_trait]
236impl SimpleVmbusDevice for ShutdownIc {
237    type SavedState = save_restore::state::SavedState;
238    type Runner = ShutdownChannel;
239
240    fn offer(&self) -> OfferParams {
241        OfferParams {
242            interface_name: "shutdown_ic".to_owned(),
243            instance_id: hyperv_ic_protocol::shutdown::INSTANCE_ID,
244            interface_id: hyperv_ic_protocol::shutdown::INTERFACE_ID,
245            channel_type: ChannelType::Pipe { message_mode: true },
246            ..Default::default()
247        }
248    }
249
250    fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>) {
251        req.respond().merge(self).merge(runner);
252    }
253
254    fn open(
255        &mut self,
256        channel: RawAsyncChannel<GpadlRingMem>,
257        _guest_memory: guestmem::GuestMemory,
258    ) -> Result<Self::Runner, ChannelOpenError> {
259        ShutdownChannel::new(channel, None)
260    }
261
262    async fn run(
263        &mut self,
264        stop: &mut StopTask<'_>,
265        runner: &mut Self::Runner,
266    ) -> Result<(), Cancelled> {
267        stop.until_stopped(async {
268            match runner.process(self).await {
269                Ok(()) => {}
270                Err(err) => {
271                    tracing::error!(
272                        error = err.as_ref() as &dyn std::error::Error,
273                        "shutdown ic error"
274                    )
275                }
276            }
277        })
278        .await
279    }
280
281    fn supports_save_restore(
282        &mut self,
283    ) -> Option<
284        &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
285    > {
286        Some(self)
287    }
288}
289
290mod save_restore {
291    use super::*;
292
293    pub mod state {
294        use hyperv_ic_protocol;
295        use mesh::payload::Protobuf;
296        use vmcore::save_restore::SavedStateRoot;
297
298        #[derive(Copy, Clone, Eq, PartialEq, Protobuf)]
299        #[mesh(package = "shutdown_ic")]
300        pub struct Version {
301            #[mesh(1)]
302            pub major: u16,
303            #[mesh(2)]
304            pub minor: u16,
305        }
306
307        impl From<hyperv_ic_protocol::Version> for Version {
308            fn from(version: hyperv_ic_protocol::Version) -> Self {
309                Self {
310                    major: version.major,
311                    minor: version.minor,
312                }
313            }
314        }
315
316        impl From<Version> for hyperv_ic_protocol::Version {
317            fn from(version: Version) -> Self {
318                Self {
319                    major: version.major,
320                    minor: version.minor,
321                }
322            }
323        }
324
325        #[derive(Copy, Clone, Eq, PartialEq, Protobuf)]
326        #[mesh(package = "shutdown_ic")]
327        pub struct ShutdownParams {
328            #[mesh(1)]
329            pub shutdown_type: ShutdownType,
330            #[mesh(2)]
331            pub force: bool,
332        }
333
334        impl From<&hyperv_ic_resources::shutdown::ShutdownParams> for ShutdownParams {
335            fn from(params: &hyperv_ic_resources::shutdown::ShutdownParams) -> Self {
336                let shutdown_type = match params.shutdown_type {
337                    hyperv_ic_resources::shutdown::ShutdownType::PowerOff => ShutdownType::PowerOff,
338                    hyperv_ic_resources::shutdown::ShutdownType::Reboot => ShutdownType::Reboot,
339                    hyperv_ic_resources::shutdown::ShutdownType::Hibernate => {
340                        ShutdownType::Hibernate
341                    }
342                };
343                Self {
344                    shutdown_type,
345                    force: params.force,
346                }
347            }
348        }
349
350        impl From<&ShutdownParams> for hyperv_ic_resources::shutdown::ShutdownParams {
351            fn from(params: &ShutdownParams) -> Self {
352                let shutdown_type = match params.shutdown_type {
353                    ShutdownType::PowerOff => hyperv_ic_resources::shutdown::ShutdownType::PowerOff,
354                    ShutdownType::Reboot => hyperv_ic_resources::shutdown::ShutdownType::Reboot,
355                    ShutdownType::Hibernate => {
356                        hyperv_ic_resources::shutdown::ShutdownType::Hibernate
357                    }
358                };
359                Self {
360                    shutdown_type,
361                    force: params.force,
362                }
363            }
364        }
365
366        impl From<ShutdownParams> for hyperv_ic_resources::shutdown::ShutdownParams {
367            fn from(params: ShutdownParams) -> Self {
368                (&params).into()
369            }
370        }
371
372        #[derive(Copy, Clone, Eq, PartialEq, Protobuf)]
373        #[mesh(package = "shutdown_ic")]
374        pub enum ShutdownType {
375            #[mesh(1)]
376            PowerOff,
377            #[mesh(2)]
378            Reboot,
379            #[mesh(3)]
380            Hibernate,
381        }
382
383        #[derive(Protobuf, SavedStateRoot)]
384        #[mesh(package = "shutdown_ic")]
385        pub struct SavedState {
386            #[mesh(1)]
387            pub version: Option<(Version, Version)>,
388            #[mesh(2)]
389            pub shutdown_request: Option<ShutdownParams>,
390            #[mesh(3)]
391            pub waiting_on_version: bool,
392            #[mesh(4)]
393            pub waiting_on_shutdown_response: bool,
394        }
395    }
396
397    impl SaveRestoreSimpleVmbusDevice for ShutdownIc {
398        fn save_open(&mut self, runner: &Self::Runner) -> state::SavedState {
399            let (versions, shutdown_request, waiting_on_shutdown_response) =
400                if let ChannelState::Ready {
401                    versions,
402                    ref state,
403                    clients: _,
404                } = runner.state
405                {
406                    let request = if let ReadyState::SendShutdown(request) = state {
407                        Some(request.into())
408                    } else {
409                        None
410                    };
411                    let waiting = matches!(state, ReadyState::WaitShutdown);
412                    (Some(versions), request, waiting)
413                } else {
414                    (None, None, false)
415                };
416            let waiting_on_version = matches!(
417                runner.state,
418                ChannelState::Negotiate(NegotiateState::WaitVersion)
419            );
420            state::SavedState {
421                version: versions.map(|v| (v.framework_version.into(), v.message_version.into())),
422                shutdown_request,
423                waiting_on_version,
424                waiting_on_shutdown_response,
425            }
426        }
427
428        fn restore_open(
429            &mut self,
430            saved_state: Self::SavedState,
431            channel: RawAsyncChannel<GpadlRingMem>,
432        ) -> Result<Self::Runner, ChannelOpenError> {
433            let state = if let Some((framework, message)) = saved_state.version {
434                let state = if let Some(request) = saved_state.shutdown_request {
435                    ReadyState::SendShutdown(request.into())
436                } else if saved_state.waiting_on_shutdown_response {
437                    ReadyState::WaitShutdown
438                } else {
439                    ReadyState::Ready
440                };
441                ChannelState::Ready {
442                    versions: Versions {
443                        framework_version: framework.into(),
444                        message_version: message.into(),
445                    },
446                    state,
447                    clients: Vec::new(),
448                }
449            } else {
450                ChannelState::Negotiate(if saved_state.waiting_on_version {
451                    NegotiateState::WaitVersion
452                } else {
453                    NegotiateState::SendVersion
454                })
455            };
456            ShutdownChannel::new(channel, Some(state))
457        }
458    }
459}