1use crate::common::IcPipe;
7use crate::common::NegotiateState;
8use crate::common::Versions;
9use async_trait::async_trait;
10use futures::FutureExt;
11use futures::StreamExt;
12use futures::stream::once;
13use futures_concurrency::stream::Merge;
14use hyperv_ic_protocol::Status;
15use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_1;
16use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3;
17use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_1;
18use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_2;
19use hyperv_ic_resources::shutdown::ShutdownParams;
20use hyperv_ic_resources::shutdown::ShutdownResult;
21use hyperv_ic_resources::shutdown::ShutdownRpc;
22use hyperv_ic_resources::shutdown::ShutdownType;
23use inspect::Inspect;
24use inspect::InspectMut;
25use mesh::rpc::Rpc;
26use std::pin::pin;
27use task_control::Cancelled;
28use task_control::StopTask;
29use vmbus_channel::RawAsyncChannel;
30use vmbus_channel::bus::ChannelType;
31use vmbus_channel::bus::OfferParams;
32use vmbus_channel::channel::ChannelOpenError;
33use vmbus_channel::gpadl_ring::GpadlRingMem;
34use vmbus_channel::simple::SaveRestoreSimpleVmbusDevice;
35use vmbus_channel::simple::SimpleVmbusDevice;
36use zerocopy::IntoBytes;
37
38const SHUTDOWN_VERSIONS: &[hyperv_ic_protocol::Version] = &[
39 SHUTDOWN_VERSION_1,
40 SHUTDOWN_VERSION_3,
41 SHUTDOWN_VERSION_3_1,
42 SHUTDOWN_VERSION_3_2,
43];
44
45#[derive(InspectMut)]
47pub struct ShutdownIc {
48 #[inspect(skip)]
49 recv: mesh::Receiver<ShutdownRpc>,
50 #[inspect(skip)]
51 wait_ready: Vec<Rpc<(), mesh::OneshotReceiver<()>>>,
52}
53
54#[doc(hidden)]
55#[derive(InspectMut)]
56pub struct ShutdownChannel {
57 #[inspect(mut)]
58 pipe: IcPipe,
59 state: ChannelState,
60 #[inspect(with = "Option::is_some")]
61 pending_shutdown: Option<Rpc<(), ShutdownResult>>,
62}
63
64#[derive(Inspect)]
65#[inspect(external_tag)]
66enum ChannelState {
67 Negotiate(#[inspect(flatten)] NegotiateState),
68 Ready {
69 versions: Versions,
70 state: ReadyState,
71 #[inspect(with = "|x| x.len()")]
72 clients: Vec<mesh::OneshotSender<()>>,
73 },
74}
75
76#[derive(Inspect)]
77#[inspect(external_tag)]
78enum ReadyState {
79 Ready,
80 SendShutdown(#[inspect(skip)] ShutdownParams),
81 WaitShutdown,
82}
83
84impl ShutdownIc {
85 pub fn new(recv: mesh::Receiver<ShutdownRpc>) -> Self {
87 Self {
88 recv,
89 wait_ready: Vec::new(),
90 }
91 }
92}
93
94impl ShutdownChannel {
95 fn new(
96 channel: RawAsyncChannel<GpadlRingMem>,
97 restore_state: Option<ChannelState>,
98 ) -> Result<ShutdownChannel, ChannelOpenError> {
99 let pipe = IcPipe::new(channel)?;
100 Ok(Self {
101 pipe,
102 state: restore_state.unwrap_or(ChannelState::Negotiate(NegotiateState::default())),
103 pending_shutdown: None,
104 })
105 }
106
107 async fn process(&mut self, ic: &mut ShutdownIc) -> anyhow::Result<()> {
108 enum Event {
109 StateMachine(anyhow::Result<()>),
110 Request(ShutdownRpc),
111 }
112
113 loop {
114 let event = pin!(
115 (
116 once(
117 self.process_state_machine(&mut ic.wait_ready)
118 .map(Event::StateMachine)
119 ),
120 (&mut ic.recv).map(Event::Request),
121 )
122 .merge()
123 )
124 .next()
125 .await
126 .unwrap();
127 match event {
128 Event::StateMachine(r) => {
129 r?;
130 }
131 Event::Request(req) => match req {
132 ShutdownRpc::WaitReady(rpc) => match &mut self.state {
133 ChannelState::Negotiate(_) => ic.wait_ready.push(rpc),
134 ChannelState::Ready { clients, .. } => {
135 let (send, recv) = mesh::oneshot();
136 clients.retain(|c| !c.is_closed());
137 clients.push(send);
138 rpc.complete(recv);
139 }
140 },
141 ShutdownRpc::Shutdown(rpc) => match self.state {
142 ChannelState::Negotiate(_) => rpc.complete(ShutdownResult::NotReady),
143 ChannelState::Ready { ref mut state, .. } => match state {
144 ReadyState::Ready => {
145 let (input, rpc) = rpc.split();
146 self.pending_shutdown = Some(rpc);
147 *state = ReadyState::SendShutdown(input);
148 }
149 ReadyState::SendShutdown { .. } | ReadyState::WaitShutdown => {
150 rpc.complete(ShutdownResult::AlreadyInProgress)
151 }
152 },
153 },
154 },
155 }
156 }
157 }
158
159 async fn process_state_machine(
160 &mut self,
161 wait_ready: &mut Vec<Rpc<(), mesh::OneshotReceiver<()>>>,
162 ) -> anyhow::Result<()> {
163 match self.state {
164 ChannelState::Negotiate(ref mut state) => {
165 if let Some(versions) = self.pipe.negotiate(state, SHUTDOWN_VERSIONS).await? {
166 let clients = wait_ready
167 .drain(..)
168 .map(|rpc| {
169 let (send, recv) = mesh::oneshot();
170 rpc.complete(recv);
171 send
172 })
173 .collect();
174
175 self.state = ChannelState::Ready {
176 versions,
177 state: ReadyState::Ready,
178 clients,
179 };
180 }
181 }
182 ChannelState::Ready {
183 ref mut state,
184 ref versions,
185 clients: _,
186 } => match state {
187 ReadyState::Ready => std::future::pending().await,
188 ReadyState::SendShutdown(params) => {
189 let mut flags =
190 hyperv_ic_protocol::shutdown::ShutdownFlags::new().with_force(params.force);
191 match params.shutdown_type {
192 ShutdownType::PowerOff => {}
193 ShutdownType::Reboot => flags.set_restart(true),
194 ShutdownType::Hibernate => flags.set_hibernate(true),
195 }
196
197 let message = Box::new(hyperv_ic_protocol::shutdown::ShutdownMessage {
198 reason_code: hyperv_ic_protocol::shutdown::SHTDN_REASON_FLAG_PLANNED,
199 timeout_secs: 0,
200 flags,
201 message: [0; 2048],
202 });
203
204 self.pipe
205 .write_message(
206 versions,
207 hyperv_ic_protocol::MessageType::SHUTDOWN,
208 hyperv_ic_protocol::HeaderFlags::new()
209 .with_transaction(true)
210 .with_request(true),
211 message.as_bytes(),
212 )
213 .await?;
214
215 *state = ReadyState::WaitShutdown;
216 }
217 ReadyState::WaitShutdown => {
218 let (status, _) = self.pipe.read_response().await?;
219 let result = if status == Status::SUCCESS {
220 ShutdownResult::Ok
221 } else {
222 ShutdownResult::Failed(status.0)
223 };
224 if let Some(send) = self.pending_shutdown.take() {
225 send.complete(result);
226 }
227 *state = ReadyState::Ready;
228 }
229 },
230 }
231 Ok(())
232 }
233}
234
235#[async_trait]
236impl SimpleVmbusDevice for ShutdownIc {
237 type SavedState = save_restore::state::SavedState;
238 type Runner = ShutdownChannel;
239
240 fn offer(&self) -> OfferParams {
241 OfferParams {
242 interface_name: "shutdown_ic".to_owned(),
243 instance_id: hyperv_ic_protocol::shutdown::INSTANCE_ID,
244 interface_id: hyperv_ic_protocol::shutdown::INTERFACE_ID,
245 channel_type: ChannelType::Pipe { message_mode: true },
246 ..Default::default()
247 }
248 }
249
250 fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>) {
251 req.respond().merge(self).merge(runner);
252 }
253
254 fn open(
255 &mut self,
256 channel: RawAsyncChannel<GpadlRingMem>,
257 _guest_memory: guestmem::GuestMemory,
258 ) -> Result<Self::Runner, ChannelOpenError> {
259 ShutdownChannel::new(channel, None)
260 }
261
262 async fn run(
263 &mut self,
264 stop: &mut StopTask<'_>,
265 runner: &mut Self::Runner,
266 ) -> Result<(), Cancelled> {
267 stop.until_stopped(async {
268 match runner.process(self).await {
269 Ok(()) => {}
270 Err(err) => {
271 tracing::error!(
272 error = err.as_ref() as &dyn std::error::Error,
273 "shutdown ic error"
274 )
275 }
276 }
277 })
278 .await
279 }
280
281 fn supports_save_restore(
282 &mut self,
283 ) -> Option<
284 &mut dyn SaveRestoreSimpleVmbusDevice<SavedState = Self::SavedState, Runner = Self::Runner>,
285 > {
286 Some(self)
287 }
288}
289
290mod save_restore {
291 use super::*;
292
293 pub mod state {
294 use hyperv_ic_protocol;
295 use mesh::payload::Protobuf;
296 use vmcore::save_restore::SavedStateRoot;
297
298 #[derive(Copy, Clone, Eq, PartialEq, Protobuf)]
299 #[mesh(package = "shutdown_ic")]
300 pub struct Version {
301 #[mesh(1)]
302 pub major: u16,
303 #[mesh(2)]
304 pub minor: u16,
305 }
306
307 impl From<hyperv_ic_protocol::Version> for Version {
308 fn from(version: hyperv_ic_protocol::Version) -> Self {
309 Self {
310 major: version.major,
311 minor: version.minor,
312 }
313 }
314 }
315
316 impl From<Version> for hyperv_ic_protocol::Version {
317 fn from(version: Version) -> Self {
318 Self {
319 major: version.major,
320 minor: version.minor,
321 }
322 }
323 }
324
325 #[derive(Copy, Clone, Eq, PartialEq, Protobuf)]
326 #[mesh(package = "shutdown_ic")]
327 pub struct ShutdownParams {
328 #[mesh(1)]
329 pub shutdown_type: ShutdownType,
330 #[mesh(2)]
331 pub force: bool,
332 }
333
334 impl From<&hyperv_ic_resources::shutdown::ShutdownParams> for ShutdownParams {
335 fn from(params: &hyperv_ic_resources::shutdown::ShutdownParams) -> Self {
336 let shutdown_type = match params.shutdown_type {
337 hyperv_ic_resources::shutdown::ShutdownType::PowerOff => ShutdownType::PowerOff,
338 hyperv_ic_resources::shutdown::ShutdownType::Reboot => ShutdownType::Reboot,
339 hyperv_ic_resources::shutdown::ShutdownType::Hibernate => {
340 ShutdownType::Hibernate
341 }
342 };
343 Self {
344 shutdown_type,
345 force: params.force,
346 }
347 }
348 }
349
350 impl From<&ShutdownParams> for hyperv_ic_resources::shutdown::ShutdownParams {
351 fn from(params: &ShutdownParams) -> Self {
352 let shutdown_type = match params.shutdown_type {
353 ShutdownType::PowerOff => hyperv_ic_resources::shutdown::ShutdownType::PowerOff,
354 ShutdownType::Reboot => hyperv_ic_resources::shutdown::ShutdownType::Reboot,
355 ShutdownType::Hibernate => {
356 hyperv_ic_resources::shutdown::ShutdownType::Hibernate
357 }
358 };
359 Self {
360 shutdown_type,
361 force: params.force,
362 }
363 }
364 }
365
366 impl From<ShutdownParams> for hyperv_ic_resources::shutdown::ShutdownParams {
367 fn from(params: ShutdownParams) -> Self {
368 (¶ms).into()
369 }
370 }
371
372 #[derive(Copy, Clone, Eq, PartialEq, Protobuf)]
373 #[mesh(package = "shutdown_ic")]
374 pub enum ShutdownType {
375 #[mesh(1)]
376 PowerOff,
377 #[mesh(2)]
378 Reboot,
379 #[mesh(3)]
380 Hibernate,
381 }
382
383 #[derive(Protobuf, SavedStateRoot)]
384 #[mesh(package = "shutdown_ic")]
385 pub struct SavedState {
386 #[mesh(1)]
387 pub version: Option<(Version, Version)>,
388 #[mesh(2)]
389 pub shutdown_request: Option<ShutdownParams>,
390 #[mesh(3)]
391 pub waiting_on_version: bool,
392 #[mesh(4)]
393 pub waiting_on_shutdown_response: bool,
394 }
395 }
396
397 impl SaveRestoreSimpleVmbusDevice for ShutdownIc {
398 fn save_open(&mut self, runner: &Self::Runner) -> state::SavedState {
399 let (versions, shutdown_request, waiting_on_shutdown_response) =
400 if let ChannelState::Ready {
401 versions,
402 ref state,
403 clients: _,
404 } = runner.state
405 {
406 let request = if let ReadyState::SendShutdown(request) = state {
407 Some(request.into())
408 } else {
409 None
410 };
411 let waiting = matches!(state, ReadyState::WaitShutdown);
412 (Some(versions), request, waiting)
413 } else {
414 (None, None, false)
415 };
416 let waiting_on_version = matches!(
417 runner.state,
418 ChannelState::Negotiate(NegotiateState::WaitVersion)
419 );
420 state::SavedState {
421 version: versions.map(|v| (v.framework_version.into(), v.message_version.into())),
422 shutdown_request,
423 waiting_on_version,
424 waiting_on_shutdown_response,
425 }
426 }
427
428 fn restore_open(
429 &mut self,
430 saved_state: Self::SavedState,
431 channel: RawAsyncChannel<GpadlRingMem>,
432 ) -> Result<Self::Runner, ChannelOpenError> {
433 let state = if let Some((framework, message)) = saved_state.version {
434 let state = if let Some(request) = saved_state.shutdown_request {
435 ReadyState::SendShutdown(request.into())
436 } else if saved_state.waiting_on_shutdown_response {
437 ReadyState::WaitShutdown
438 } else {
439 ReadyState::Ready
440 };
441 ChannelState::Ready {
442 versions: Versions {
443 framework_version: framework.into(),
444 message_version: message.into(),
445 },
446 state,
447 clients: Vec::new(),
448 }
449 } else {
450 ChannelState::Negotiate(if saved_state.waiting_on_version {
451 NegotiateState::WaitVersion
452 } else {
453 NegotiateState::SendVersion
454 })
455 };
456 ShutdownChannel::new(channel, Some(state))
457 }
458 }
459}