1#![cfg(target_os = "linux")]
7#![forbid(unsafe_code)]
8
9pub use hyperv_ic_protocol::shutdown::INSTANCE_ID;
10pub use hyperv_ic_protocol::shutdown::INTERFACE_ID;
11
12use guid::Guid;
13use hyperv_ic_protocol::FRAMEWORK_VERSION_1;
14use hyperv_ic_protocol::FRAMEWORK_VERSION_3;
15use hyperv_ic_protocol::Status;
16use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_1;
17use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3;
18use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_1;
19use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_2;
20use hyperv_ic_resources::shutdown::ShutdownParams;
21use hyperv_ic_resources::shutdown::ShutdownResult;
22use hyperv_ic_resources::shutdown::ShutdownType;
23use inspect::Inspect;
24use inspect::InspectMut;
25use mesh::rpc::Rpc;
26use mesh::rpc::RpcSend;
27use std::io::IoSlice;
28use std::mem::size_of_val;
29use task_control::Cancelled;
30use task_control::StopTask;
31use thiserror::Error;
32use vmbus_async::async_dgram::AsyncRecvExt;
33use vmbus_async::async_dgram::AsyncSendExt;
34use vmbus_async::pipe::MessagePipe;
35use vmbus_channel::RawAsyncChannel;
36use vmbus_channel::channel::ChannelOpenError;
37use vmbus_relay_intercept_device::OfferResponse;
38use vmbus_relay_intercept_device::SaveRestoreSimpleVmbusClientDevice;
39use vmbus_relay_intercept_device::SimpleVmbusClientDevice;
40use vmbus_relay_intercept_device::SimpleVmbusClientDeviceAsync;
41use vmbus_relay_intercept_device::ring_buffer::MemoryBlockRingBuffer;
42use vmbus_ring::RingMem;
43use vmcore::save_restore::NoSavedState;
44use zerocopy::FromBytes;
45use zerocopy::FromZeros;
46use zerocopy::IntoBytes;
47
48#[derive(InspectMut)]
50pub struct ShutdownGuestIc {
51 #[inspect(skip)]
52 send_shutdown_notification: mesh::Sender<Rpc<ShutdownParams, ShutdownResult>>,
53 #[inspect(skip)]
54 recv_shutdown_notification: Option<mesh::Receiver<Rpc<ShutdownParams, ShutdownResult>>>,
55}
56
57#[derive(Debug, Inspect)]
58#[inspect(tag = "channel_state")]
59enum ShutdownGuestChannelState {
60 NegotiateVersion,
61 Running {
62 #[inspect(display)]
63 framework_version: hyperv_ic_protocol::Version,
64 #[inspect(display)]
65 message_version: hyperv_ic_protocol::Version,
66 },
67}
68
69#[derive(InspectMut)]
71pub struct ShutdownGuestChannel {
72 state: ShutdownGuestChannelState,
74 #[inspect(mut)]
76 pipe: MessagePipe<MemoryBlockRingBuffer>,
77}
78
79#[derive(Debug, Error)]
80enum Error {
81 #[error("ring buffer error")]
82 Ring(#[source] std::io::Error),
83 #[error("truncated message")]
84 TruncatedMessage,
85}
86
87impl ShutdownGuestIc {
88 pub fn new() -> Self {
90 let (send_shutdown_notification, recv_shutdown_notification) = mesh::channel();
91 Self {
92 send_shutdown_notification,
93 recv_shutdown_notification: Some(recv_shutdown_notification),
94 }
95 }
96
97 pub fn get_shutdown_notifier(&mut self) -> mesh::Receiver<Rpc<ShutdownParams, ShutdownResult>> {
99 self.recv_shutdown_notification
100 .take()
101 .expect("can only be called once")
102 }
103}
104
105impl ShutdownGuestChannel {
106 fn new(pipe: MessagePipe<MemoryBlockRingBuffer>) -> Self {
107 Self {
108 state: ShutdownGuestChannelState::NegotiateVersion,
109 pipe,
110 }
111 }
112
113 async fn process(&mut self, ic: &mut ShutdownGuestIc) -> Result<(), Error> {
114 loop {
115 match read_from_pipe(&mut self.pipe).await {
116 Ok(buf) => {
117 self.handle_host_message(&buf, ic).await;
118 }
119 Err(err) => {
120 tracelimit::error_ratelimited!(
121 err = &err as &dyn std::error::Error,
122 "reading shutdown packet from host",
123 );
124 }
125 }
126 }
127 }
128
129 async fn handle_host_message(&mut self, buf: &[u8], ic: &ShutdownGuestIc) {
130 let (header, rest) = match hyperv_ic_protocol::Header::read_from_prefix(buf).ok() {
132 Some((h, r)) => (h, r),
133 None => {
134 tracelimit::error_ratelimited!("invalid shutdown packet from host",);
135 return;
136 }
137 };
138 match header.message_type {
139 hyperv_ic_protocol::MessageType::VERSION_NEGOTIATION => {
140 self.state = ShutdownGuestChannelState::NegotiateVersion;
144 if let Err(err) = self.handle_version_negotiation(&header, rest).await {
145 tracelimit::error_ratelimited!(
146 err = &err as &dyn std::error::Error,
147 "Failed version negotiation"
148 );
149 }
150 }
151 hyperv_ic_protocol::MessageType::SHUTDOWN
152 if matches!(self.state, ShutdownGuestChannelState::Running { .. }) =>
153 {
154 if let Err(err) = self.handle_shutdown_notification(&header, rest, ic).await {
155 tracelimit::error_ratelimited!(
156 err = &err as &dyn std::error::Error,
157 "Failed processing shutdown message"
158 );
159 }
160 }
161 _ => {
162 tracelimit::error_ratelimited!(r#type = ?header.message_type, state = ?self.state, "Unrecognized packet");
163 }
164 }
165 }
166
167 fn find_latest_supported_version<'a>(
168 buf: &'a [u8],
169 count: usize,
170 supported: &[hyperv_ic_protocol::Version],
171 ) -> (Option<hyperv_ic_protocol::Version>, &'a [u8]) {
172 let mut rest = buf;
173 let mut next_version;
174 let mut latest_version = None;
175 for _ in 0..count {
176 (next_version, rest) = match hyperv_ic_protocol::Version::read_from_prefix(rest).ok() {
178 Some((n, r)) => (n, r),
179 None => {
180 tracelimit::error_ratelimited!("truncated message version list");
181 return (latest_version, rest);
182 }
183 };
184 for known in supported {
185 if known.major == next_version.major && known.minor == next_version.minor {
186 if latest_version.is_some() {
187 if next_version.major >= latest_version.unwrap().major {
188 if next_version.major > latest_version.unwrap().major
189 || next_version.minor > latest_version.unwrap().minor
190 {
191 latest_version = Some(next_version);
192 }
193 }
194 } else {
195 latest_version = Some(next_version);
196 }
197 }
198 }
199 }
200 (latest_version, rest)
201 }
202
203 async fn handle_version_negotiation(
204 &mut self,
205 header: &hyperv_ic_protocol::Header,
206 msg: &[u8],
207 ) -> Result<(), Error> {
208 const FRAMEWORK_VERSIONS: &[hyperv_ic_protocol::Version] =
209 &[FRAMEWORK_VERSION_1, FRAMEWORK_VERSION_3];
210
211 const SHUTDOWN_VERSIONS: &[hyperv_ic_protocol::Version] = &[
212 SHUTDOWN_VERSION_1,
213 SHUTDOWN_VERSION_3,
214 SHUTDOWN_VERSION_3_1,
215 SHUTDOWN_VERSION_3_2,
216 ];
217
218 let (prefix, rest) = hyperv_ic_protocol::NegotiateMessage::read_from_prefix(msg)
219 .map_err(|_| Error::TruncatedMessage)?; let (latest_framework_version, rest) = Self::find_latest_supported_version(
221 rest,
222 prefix.framework_version_count as usize,
223 FRAMEWORK_VERSIONS,
224 );
225 let framework_version = if let Some(version) = latest_framework_version {
226 version
227 } else {
228 tracelimit::error_ratelimited!("Unsupported framework version");
229 FRAMEWORK_VERSIONS[FRAMEWORK_VERSIONS.len() - 1]
230 };
231 let (latest_message_version, _) = Self::find_latest_supported_version(
232 rest,
233 prefix.message_version_count as usize,
234 SHUTDOWN_VERSIONS,
235 );
236 let message_version = if let Some(version) = latest_message_version {
237 version
238 } else {
239 tracelimit::error_ratelimited!("Unsupported message version");
240 SHUTDOWN_VERSIONS[SHUTDOWN_VERSIONS.len() - 1]
241 };
242
243 let message = hyperv_ic_protocol::NegotiateMessage {
244 framework_version_count: 1,
245 message_version_count: 1,
246 ..FromZeros::new_zeroed()
247 };
248 let response = hyperv_ic_protocol::Header {
249 message_type: hyperv_ic_protocol::MessageType::VERSION_NEGOTIATION,
250 message_size: (size_of_val(&message)
251 + size_of_val(&framework_version)
252 + size_of_val(&message_version)) as u16,
253 status: Status::SUCCESS,
254 transaction_id: header.transaction_id,
255 flags: hyperv_ic_protocol::HeaderFlags::new()
256 .with_transaction(header.flags.transaction())
257 .with_response(true),
258 ..FromZeros::new_zeroed()
259 };
260 self.pipe
261 .send_vectored(&[
262 IoSlice::new(response.as_bytes()),
263 IoSlice::new(message.as_bytes()),
264 IoSlice::new(framework_version.as_bytes()),
265 IoSlice::new(message_version.as_bytes()),
266 ])
267 .await
268 .map_err(Error::Ring)?;
269
270 tracing::info!(%framework_version, %message_version, "version negotiated");
271 self.state = ShutdownGuestChannelState::Running {
272 framework_version,
273 message_version,
274 };
275 Ok(())
276 }
277
278 async fn handle_shutdown_notification(
279 &mut self,
280 header: &hyperv_ic_protocol::Header,
281 buf: &[u8],
282 ic: &ShutdownGuestIc,
283 ) -> Result<(), Error> {
284 let ShutdownGuestChannelState::Running {
285 framework_version,
286 message_version,
287 } = &self.state
288 else {
289 panic!("Shutdown message processing while in invalid state");
290 };
291
292 let message = hyperv_ic_protocol::shutdown::ShutdownMessage::read_from_prefix(buf)
293 .map_err(|_| Error::TruncatedMessage)?
294 .0; let shutdown_type = if message.flags.restart() {
296 ShutdownType::Reboot
297 } else if message.flags.hibernate() {
298 ShutdownType::Hibernate
299 } else {
300 ShutdownType::PowerOff
301 };
302 let params = ShutdownParams {
303 shutdown_type,
304 force: message.flags.force(),
305 };
306
307 let result = ic.send_shutdown_notification.call(|x| x, params).await;
309
310 let response = hyperv_ic_protocol::Header {
312 framework_version: *framework_version,
313 message_version: *message_version,
314 message_type: hyperv_ic_protocol::MessageType::SHUTDOWN,
315 message_size: 0,
316 status: if result.is_ok() {
317 Status::SUCCESS
318 } else {
319 Status::FAIL
320 },
321 transaction_id: header.transaction_id,
322 flags: hyperv_ic_protocol::HeaderFlags::new()
323 .with_transaction(header.flags.transaction())
324 .with_response(true),
325 ..FromZeros::new_zeroed()
326 };
327 self.pipe
328 .send(response.as_bytes())
329 .await
330 .map_err(Error::Ring)
331 }
332}
333
334async fn read_from_pipe<T: RingMem>(pipe: &mut MessagePipe<T>) -> Result<Vec<u8>, Error> {
335 let mut buf = vec![0; hyperv_ic_protocol::MAX_MESSAGE_SIZE];
336 let n = pipe.recv(&mut buf).await.map_err(Error::Ring)?;
337 let buf = &buf[..n];
338 Ok(buf.to_vec())
339}
340
341impl SimpleVmbusClientDevice for ShutdownGuestIc {
342 type SavedState = NoSavedState;
343 type Runner = ShutdownGuestChannel;
344
345 fn instance_id(&self) -> Guid {
346 INSTANCE_ID
347 }
348
349 fn offer(&self, _offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse {
350 OfferResponse::Open
351 }
352
353 fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>) {
354 req.respond().merge(self).merge(runner);
355 }
356
357 fn open(
358 &mut self,
359 _channel_idx: u16,
360 channel: RawAsyncChannel<MemoryBlockRingBuffer>,
361 ) -> Result<Self::Runner, ChannelOpenError> {
362 let pipe = MessagePipe::new(channel)?;
363 Ok(ShutdownGuestChannel::new(pipe))
364 }
365
366 fn close(&mut self, _channel_idx: u16) {}
367
368 fn supports_save_restore(
369 &mut self,
370 ) -> Option<
371 &mut dyn SaveRestoreSimpleVmbusClientDevice<
372 SavedState = Self::SavedState,
373 Runner = Self::Runner,
374 >,
375 > {
376 None
377 }
378}
379
380impl SimpleVmbusClientDeviceAsync for ShutdownGuestIc {
381 async fn run(
382 &mut self,
383 stop: &mut StopTask<'_>,
384 runner: &mut Self::Runner,
385 ) -> Result<(), Cancelled> {
386 stop.until_stopped(async {
387 match runner.process(self).await {
388 Ok(()) => {}
389 Err(err) => {
390 tracing::error!(
391 error = &err as &dyn std::error::Error,
392 "shutdown ic relay error"
393 )
394 }
395 }
396 })
397 .await
398 }
399}