diag_server/
diag_service.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! RPC service for diagnostics.
5
6use crate::grpc_result;
7use crate::new_pty;
8use anyhow::Context;
9use azure_profiler_proto::AzureProfiler;
10use azure_profiler_proto::ProfileRequest;
11use diag_proto::ExecRequest;
12use diag_proto::ExecResponse;
13use diag_proto::FILE_LINE_MAX;
14use diag_proto::FileRequest;
15use diag_proto::KmsgRequest;
16use diag_proto::NetworkPacketCaptureRequest;
17use diag_proto::NetworkPacketCaptureResponse;
18use diag_proto::OpenhclDiag;
19use diag_proto::StartRequest;
20use diag_proto::UnderhillDiag;
21use diag_proto::WaitRequest;
22use diag_proto::WaitResponse;
23use diag_proto::network_packet_capture_request::Operation;
24use futures::AsyncRead;
25use futures::AsyncReadExt;
26use futures::AsyncWrite;
27use futures::AsyncWriteExt;
28use futures::FutureExt;
29use futures::StreamExt;
30use futures::future::join_all;
31use futures::io::AllowStdIo;
32use futures_concurrency::stream::Merge;
33use inspect::InspectionBuilder;
34use inspect_proto::InspectRequest;
35use inspect_proto::InspectResponse2;
36use inspect_proto::InspectService;
37use inspect_proto::UpdateRequest;
38use inspect_proto::UpdateResponse2;
39use mesh::CancelContext;
40use mesh::rpc::FailableRpc;
41use mesh::rpc::RpcSend;
42use mesh_rpc::server::RpcReceiver;
43use net_packet_capture::OperationData;
44use net_packet_capture::PacketCaptureOperation;
45use net_packet_capture::PacketCaptureParams;
46use net_packet_capture::StartData;
47use pal::unix::process::Stdio;
48use pal_async::driver::Driver;
49use pal_async::interest::InterestSlot;
50use pal_async::interest::PollEvents;
51use pal_async::pipe::PolledPipe;
52use pal_async::socket::AsSockRef;
53use pal_async::socket::PollReady;
54use pal_async::socket::PollReadyExt;
55use pal_async::socket::PolledSocket;
56use pal_async::task::Spawn;
57use pal_async::task::Task;
58use parking_lot::Mutex;
59use socket2::Socket;
60use std::collections::HashMap;
61use std::fs::File;
62use std::future::poll_fn;
63use std::io;
64use std::io::Read;
65use std::os::unix::fs::FileTypeExt;
66use std::os::unix::prelude::*;
67use std::process::ExitStatus;
68use std::sync::Arc;
69
70/// A diagnostics request.
71#[derive(Debug, mesh::MeshPayload)]
72pub enum DiagRequest {
73    /// Start the VM, if it has not already been started.
74    Start(FailableRpc<StartParams, ()>),
75    /// Inspect the VM.
76    Inspect(inspect::Deferred),
77    /// Crash the VM
78    Crash(i32),
79    /// Restart the worker.
80    Restart(FailableRpc<(), ()>),
81    /// Pause VTL0
82    Pause(FailableRpc<(), ()>),
83    /// Resume VTL0
84    Resume(FailableRpc<(), ()>),
85    /// Save VTL2 state
86    Save(FailableRpc<(), Vec<u8>>),
87    /// Setup network trace
88    PacketCapture(FailableRpc<PacketCaptureParams<Socket>, PacketCaptureParams<Socket>>),
89    /// Profile VTL2
90    #[cfg(feature = "profiler")]
91    Profile(FailableRpc<profiler_worker::ProfilerRequest, ()>),
92}
93
94/// Additional parameters provided as part of a delayed start request.
#[derive(Debug, mesh::MeshPayload)]
pub struct StartParams {
    /// Environment variables to set (`Some`) or remove (`None`).
    pub env: Vec<(String, Option<String>)>,
    /// Command line arguments to append.
    pub args: Vec<String>,
}
102
/// Handler for the diagnostics RPC services. Dispatches requests either to
/// the worker (via `request_send`) or to local process/file operations.
pub(crate) struct DiagServiceHandler {
    /// Channel used to forward [`DiagRequest`]s to the worker.
    request_send: mesh::Sender<DiagRequest>,
    /// Wait tasks for spawned child processes, keyed by pid.
    children: Mutex<HashMap<i32, Task<ExitStatus>>>,
    /// When set, inspection is restricted to this sensitivity level (used on
    /// CVMs to expose only nodes marked safe).
    inspect_sensitivity_level: Option<inspect::SensitivityLevel>,
    /// Source of data-connection sockets (see `take_connection`).
    inner: Arc<crate::Inner>,
}
109
110impl DiagServiceHandler {
111    pub fn new(request_send: mesh::Sender<DiagRequest>, inner: Arc<crate::Inner>) -> Self {
112        Self {
113            children: Default::default(),
114            request_send,
115            // On CVMs only allow inspecting nodes defined as safe.
116            inspect_sensitivity_level: if underhill_confidentiality::confidential_filtering_enabled(
117            ) {
118                Some(inspect::SensitivityLevel::Safe)
119            } else {
120                None
121            },
122            // TODO: use a remotable type for `Inner`, which is just used to get
123            // data connection sockets.
124            inner,
125        }
126    }
127
    /// Services requests from all four diagnostics RPC services, returning
    /// once every request stream has completed.
    ///
    /// Each request is handled on its own detached task so a slow request
    /// cannot stall the others.
    pub async fn process_requests(
        self: &Arc<Self>,
        driver: &(impl Driver + Spawn + Clone),
        diag_recv: RpcReceiver<UnderhillDiag>,
        diag2_recv: RpcReceiver<OpenhclDiag>,
        inspect_recv: RpcReceiver<InspectService>,
        profile_recv: RpcReceiver<AzureProfiler>,
    ) -> anyhow::Result<()> {
        // Tag indicating which service a request arrived on, so all four
        // receivers can be multiplexed onto a single stream.
        enum Event {
            Diag(UnderhillDiag),
            Diag2(OpenhclDiag),
            Inspect(InspectService),
            Profile(AzureProfiler),
        }
        let mut s = (
            diag_recv.map(|(ctx, req)| (ctx, Event::Diag(req))),
            diag2_recv.map(|(ctx, req)| (ctx, Event::Diag2(req))),
            inspect_recv.map(|(ctx, req)| (ctx, Event::Inspect(req))),
            profile_recv.map(|(ctx, req)| (ctx, Event::Profile(req))),
        )
            .merge();

        while let Some((ctx, req)) = s.next().await {
            // Each request carries its own cancel context; spawn and detach
            // so the handler's lifetime is independent of this loop.
            driver
                .spawn("diag request", {
                    let driver = driver.clone();
                    let this = self.clone();
                    async move {
                        match req {
                            Event::Diag(req) => this.handle_diag_request(&driver, req, ctx).await,
                            Event::Diag2(req) => this.handle_diag2_request(&driver, req, ctx).await,
                            Event::Inspect(req) => this.handle_inspect_request(req, ctx).await,
                            Event::Profile(req) => this.handle_profile_request(req, ctx).await,
                        }
                    }
                })
                .detach();
        }
        Ok(())
    }
168
169    async fn take_connection(&self, id: u64) -> anyhow::Result<PolledSocket<Socket>> {
170        self.inner.take_connection(id).await
171    }
172
173    async fn handle_inspect_request(&self, req: InspectService, mut ctx: CancelContext) {
174        match req {
175            InspectService::Inspect(request, response) => {
176                let inspect_response = self.handle_inspect(&request, ctx).await;
177                response.send(grpc_result(Ok(Ok(inspect_response))));
178            }
179            InspectService::Update(request, response) => {
180                response.send(grpc_result(
181                    ctx.until_cancelled(self.handle_update(&request)).await,
182                ));
183            }
184        }
185    }
186
187    async fn handle_profile_request(&self, req: AzureProfiler, mut ctx: CancelContext) {
188        match req {
189            AzureProfiler::Profile(request, response) => response.send(grpc_result(
190                ctx.until_cancelled(self.handle_profile(request)).await,
191            )),
192        }
193    }
194
    /// Dispatches a legacy (Underhill) diagnostics RPC to its handler,
    /// wrapping the outcome for the gRPC response channel.
    async fn handle_diag_request(
        &self,
        driver: &(impl Driver + Spawn + Clone),
        req: UnderhillDiag,
        mut ctx: CancelContext,
    ) {
        match req {
            UnderhillDiag::Exec(request, response) => response.send(grpc_result(
                ctx.until_cancelled(self.handle_exec(driver, &request))
                    .await,
            )),
            UnderhillDiag::Wait(request, response) => response.send(grpc_result(
                ctx.until_cancelled(self.handle_wait(&request)).await,
            )),
            UnderhillDiag::Start(request, response) => {
                response.send(grpc_result(
                    ctx.until_cancelled(self.handle_start(request)).await,
                ));
            }
            // Kmsg spawns a background relay and returns promptly, so it is
            // not wrapped in the cancel context.
            UnderhillDiag::Kmsg(request, response) => {
                response.send(grpc_result(Ok(self.handle_kmsg(driver, &request).await)))
            }
            UnderhillDiag::Crash(request, response) => {
                response.send(grpc_result(
                    ctx.until_cancelled(self.handle_crash(request)).await,
                ));
            }
            UnderhillDiag::Restart(_, response) => {
                response.send(grpc_result(
                    ctx.until_cancelled(self.handle_restart()).await,
                ));
            }
            // Like Kmsg, ReadFile only spawns a relay task, so no cancel
            // wrapping is applied.
            UnderhillDiag::ReadFile(request, response) => response.send(grpc_result(Ok(self
                .handle_read_file(driver, &request)
                .await))),
            UnderhillDiag::Pause(_, response) => {
                response.send(grpc_result(ctx.until_cancelled(self.handle_pause()).await))
            }
            UnderhillDiag::PacketCapture(request, response) => response.send(grpc_result(
                ctx.until_cancelled(self.handle_packet_capture(&request))
                    .await,
            )),
            UnderhillDiag::Resume(_, response) => {
                response.send(grpc_result(ctx.until_cancelled(self.handle_resume()).await))
            }
            UnderhillDiag::DumpSavedState((), response) => response.send(grpc_result(
                ctx.until_cancelled(self.handle_dump_saved_state()).await,
            )),
        }
    }
245
246    async fn handle_diag2_request(
247        &self,
248        _driver: &(impl Driver + Spawn + Clone),
249        req: OpenhclDiag,
250        _ctx: CancelContext,
251    ) {
252        match req {
253            OpenhclDiag::Ping((), response) => {
254                response.send(Ok(()));
255            }
256        }
257    }
258
259    async fn handle_start(&self, request: StartRequest) -> anyhow::Result<()> {
260        let params = StartParams {
261            env: request
262                .env
263                .into_iter()
264                .map(|pair| (pair.name, pair.value))
265                .collect(),
266            args: request.args,
267        };
268        self.request_send
269            .call_failable(DiagRequest::Start, params)
270            .await?;
271        Ok(())
272    }
273
274    async fn handle_crash(&self, request: diag_proto::CrashRequest) -> anyhow::Result<()> {
275        self.request_send.send(DiagRequest::Crash(request.pid));
276
277        Ok(())
278    }
279
    /// Spawns a child process per `request`, wiring its stdio to the
    /// pre-registered data connections, and registers a wait task so the
    /// client can later retrieve the exit status via `handle_wait`.
    ///
    /// Returns the child's pid.
    async fn handle_exec(
        &self,
        driver: &(impl Driver + Spawn + Clone),
        request: &ExecRequest,
    ) -> anyhow::Result<ExecResponse> {
        tracing::info!(
            command = %request.command,
            stdin = request.stdin,
            stdout = request.stdout,
            stderr = request.stderr,
            tty = request.tty,
            "exec request"
        );

        // A connection id of 0 means the stream was not requested.
        let stdin = if request.stdin != 0 {
            Some(
                self.take_connection(request.stdin)
                    .await
                    .context("failed to get stdin conn")?,
            )
        } else {
            None
        };
        let stdout = if request.stdout != 0 {
            Some(
                self.take_connection(request.stdout)
                    .await
                    .context("failed to get stdout conn")?,
            )
        } else {
            None
        };
        let stderr = if request.stderr != 0 {
            Some(
                self.take_connection(request.stderr)
                    .await
                    .context("failed to get stderr conn")?,
            )
        } else {
            None
        };

        let mut builder = pal::unix::process::Builder::new(&request.command);
        builder.args(&request.args);
        if request.clear_env {
            builder.env_clear();
        }
        // An env pair without a value removes the variable instead.
        for diag_proto::EnvPair { name, value } in &request.env {
            if let Some(value) = value {
                builder.env(name, value);
            } else {
                builder.env_remove(name);
            }
        }

        // HACK: A hack to fix segfault caused by glibc bug in L1 TDX VMM.
        // Should be removed after glibc update or a clean CPUID virtualization solution.
        // Please refer to https://github.com/microsoft/openvmm-deps/issues/21 for more information.
        // xtask-fmt allow-target-arch cpu-intrinsic
        #[cfg(target_arch = "x86_64")]
        {
            let result =
                safe_intrinsics::cpuid(hvdef::HV_CPUID_FUNCTION_MS_HV_ISOLATION_CONFIGURATION, 0);
            // Value 3 means TDX.
            let tdx_isolated = (result.ebx & 0xF) == 3;
            if tdx_isolated {
                builder.env("GLIBC_TUNABLES", "glibc.cpu.x86_non_temporal_threshold=0x11a000:glibc.cpu.x86_rep_movsb_threshold=0x4000");
            }
        };

        // Stdio is wired in one of three modes: a pty (tty requests), direct
        // socket passthrough (raw_socket_io), or pipes bridged to the
        // connections by relay tasks.
        let mut stdin_relay = None;
        let mut stdout_relay = None;
        let mut stderr_relay = None;
        let mut raw_stdout = None;
        let mut raw_stderr = None;
        let mut child = {
            // Declared outside the branches so the fds stay alive until after
            // `spawn()` below.
            let (stdin_pipes, stdout_pipes, stderr_pipes);
            let stdin_socket;
            let stdout_socket;
            let stderr_socket;
            let pty;
            if request.tty {
                pty = new_pty::new_pty().context("failed to create pty")?;

                let primary = PolledPipe::new(driver, pty.0)
                    .context("failed to create polled pty primary")?;

                let secondary = &pty.1;

                let (primary_read, primary_write) = primary.split();
                if let Some(stdin) = stdin {
                    stdin_relay = Some(driver.spawn("pty stdin relay", async move {
                        relay(stdin, primary_write).await;
                    }));
                }
                if let Some(stdout) = stdout {
                    stdout_relay =
                        Some(driver.spawn("pty stdout relay", relay(primary_read, stdout)));
                }

                // The child gets the pty secondary for all three streams.
                builder
                    .setsid(true)
                    .controlling_terminal(secondary.as_fd())
                    .stdin(Stdio::Fd(secondary.as_fd()))
                    .stdout(Stdio::Fd(secondary.as_fd()))
                    .stderr(Stdio::Fd(secondary.as_fd()));
            } else if request.raw_socket_io {
                // Hand the sockets to the child directly; no relay tasks.
                if let Some(stdin) = stdin {
                    stdin_socket = stdin.into_inner();
                    builder.stdin(Stdio::Fd(stdin_socket.as_fd()));
                }
                if let Some(stdout) = stdout {
                    stdout_socket = raw_stdout.insert(stdout.into_inner());
                    builder.stdout(Stdio::Fd(stdout_socket.as_fd()));
                    if request.combine_stderr {
                        builder.stderr(Stdio::Fd(stdout_socket.as_fd()));
                    }
                }
                if let Some(stderr) = stderr {
                    stderr_socket = raw_stderr.insert(stderr.into_inner());
                    builder.stderr(Stdio::Fd(stderr_socket.as_fd()));
                }
            } else {
                if let Some(stdin) = stdin {
                    stdin_pipes = pal::unix::pipe::pair().context("failed to create pipe pair")?;
                    let pipe = PolledPipe::new(driver, stdin_pipes.1)
                        .context("failed to create polled pipe")?;
                    stdin_relay = Some(driver.spawn("stdin relay", async move {
                        relay(stdin, pipe).await;
                    }));
                    builder.stdin(Stdio::Fd(stdin_pipes.0.as_fd()));
                }
                if let Some(stdout) = stdout {
                    stdout_pipes = pal::unix::pipe::pair().context("failed to create pipe pair")?;
                    let pipe = PolledPipe::new(driver, stdout_pipes.0)
                        .context("failed to create polled pipe")?;
                    stdout_relay = Some(driver.spawn("stdout relay", relay(pipe, stdout)));
                    builder.stdout(Stdio::Fd(stdout_pipes.1.as_fd()));
                    if request.combine_stderr {
                        builder.stderr(Stdio::Fd(stdout_pipes.1.as_fd()));
                    }
                }
                if let Some(stderr) = stderr {
                    stderr_pipes = pal::unix::pipe::pair().context("failed to create pipe pair")?;
                    let pipe = PolledPipe::new(driver, stderr_pipes.0)
                        .context("failed to create polled pipe")?;
                    stderr_relay = Some(driver.spawn("stderr relay", relay(pipe, stderr)));
                    builder.stderr(Stdio::Fd(stderr_pipes.1.as_fd()));
                }
            }

            builder
                .spawn()
                .with_context(|| format!("failed to launch {}", &request.command))?
        };

        let pid = child.id();

        tracing::info!(pid, "spawned child");

        let mut child_ready = driver
            .new_dyn_fd_ready(child.as_fd().as_raw_fd())
            .expect("failed creating child poll");

        let status = driver.spawn("diag child wait", {
            let driver = driver.clone();
            async move {
                // Wait for the child's fd to become readable, then reap the
                // exit status.
                poll_fn(|cx| child_ready.poll_fd_ready(cx, InterestSlot::Read, PollEvents::IN))
                    .await;
                let status = child.try_wait().unwrap().unwrap();
                tracing::info!(pid, ?status, "child exited");

                // The process is gone, so the stdin relay's job is done.
                drop(stdin_relay);

                // Shut down raw stdout and stderr to notify the host that there
                // is no more data.
                let finish_raw = |raw: Option<Socket>| {
                    raw.and_then(|raw| {
                        let _ = raw.as_sock_ref().shutdown(std::net::Shutdown::Write);
                        PolledSocket::new(&driver, raw).ok()
                    })
                };
                let raw_stdout = finish_raw(raw_stdout);
                let raw_stderr = finish_raw(raw_stderr);

                // Wait for the host to finish with the stdout and stderr
                // sockets, but don't block the process exit notification.
                driver
                    .spawn("socket-wait", async move {
                        let await_output_relay = async |task, raw| {
                            let socket = if let Some(task) = task {
                                Some(task.await)
                            } else {
                                raw
                            };
                            if let Some(socket) = socket {
                                // Wait for the host to close the socket to ensure that all
                                // the data is written.
                                let _ = futures::io::copy(socket, &mut futures::io::sink()).await;
                            }
                        };

                        await_output_relay(stdout_relay, raw_stdout).await;
                        await_output_relay(stderr_relay, raw_stderr).await;
                    })
                    .detach();

                status
            }
        });

        // Record the wait task so `handle_wait` can retrieve the exit status.
        self.children.lock().insert(pid, status);
        Ok(ExecResponse { pid })
    }
495
496    async fn handle_wait(&self, request: &WaitRequest) -> anyhow::Result<WaitResponse> {
497        tracing::debug!(pid = request.pid, "wait request");
498        let channel = self
499            .children
500            .lock()
501            .remove(&request.pid)
502            .context("pid not found")?;
503
504        let status = channel.await;
505        let exit_code = status.code().unwrap_or(255);
506
507        tracing::debug!(pid = request.pid, exit_code, "wait complete");
508
509        Ok(WaitResponse { exit_code })
510    }
511
512    async fn handle_inspect(
513        &self,
514        request: &InspectRequest,
515        mut ctx: CancelContext,
516    ) -> InspectResponse2 {
517        tracing::debug!(
518            path = request.path.as_str(),
519            depth = request.depth,
520            "inspect request"
521        );
522        let mut inspection = InspectionBuilder::new(&request.path)
523            .depth(Some(request.depth as usize))
524            .sensitivity(self.inspect_sensitivity_level)
525            .inspect(inspect::send(&self.request_send, DiagRequest::Inspect));
526
527        // Don't return early on cancel, as we want to return the partial
528        // inspection results.
529        let _ = ctx.until_cancelled(inspection.resolve()).await;
530
531        let result = inspection.results();
532        InspectResponse2 { result }
533    }
534
535    async fn handle_update(&self, request: &UpdateRequest) -> anyhow::Result<UpdateResponse2> {
536        tracing::debug!(
537            path = request.path.as_str(),
538            value = request.value.as_str(),
539            "update request"
540        );
541        let new_value = InspectionBuilder::new(&request.path)
542            .sensitivity(self.inspect_sensitivity_level)
543            .update(
544                &request.value,
545                inspect::send(&self.request_send, DiagRequest::Inspect),
546            )
547            .await?;
548        Ok(UpdateResponse2 { new_value })
549    }
550
551    async fn handle_kmsg(
552        &self,
553        driver: &(impl Driver + Spawn + Clone),
554        request: &KmsgRequest,
555    ) -> anyhow::Result<()> {
556        self.handle_read_file_request(driver, request.conn, request.follow, "/dev/kmsg")
557            .await
558    }
559
560    async fn handle_read_file(
561        &self,
562        driver: &(impl Driver + Spawn + Clone),
563        request: &FileRequest,
564    ) -> anyhow::Result<()> {
565        self.handle_read_file_request(driver, request.conn, request.follow, &request.file_path)
566            .await
567    }
568
569    async fn handle_packet_capture(
570        &self,
571        request: &NetworkPacketCaptureRequest,
572    ) -> anyhow::Result<NetworkPacketCaptureResponse> {
573        let operation = if request.operation == Operation::Query as i32 {
574            PacketCaptureOperation::Query
575        } else if request.operation == Operation::Start as i32 {
576            PacketCaptureOperation::Start
577        } else if request.operation == Operation::Stop as i32 {
578            PacketCaptureOperation::Stop
579        } else {
580            anyhow::bail!("unsupported request type {}", request.operation);
581        };
582
583        let op_data = match operation {
584            // Query the number of streams needed, starting with a value of 0.
585            PacketCaptureOperation::Query => Some(OperationData::OpQueryData(0)),
586            PacketCaptureOperation::Start => {
587                let Some(op_data) = &request.op_data else {
588                    anyhow::bail!("missing start operation parameters");
589                };
590
591                match op_data {
592                    diag_proto::network_packet_capture_request::OpData::StartData(start_data) => {
593                        let writers = join_all(start_data.conns.iter().map(async |c| {
594                            let conn = self.take_connection(*c).await?;
595                            Ok(conn.into_inner())
596                        }))
597                        .await
598                        .into_iter()
599                        .collect::<anyhow::Result<Vec<Socket>>>()?;
600                        Some(OperationData::OpStartData(StartData {
601                            writers,
602                            snaplen: start_data.snaplen,
603                        }))
604                    }
605                }
606            }
607            _ => None,
608        };
609
610        let params = PacketCaptureParams { operation, op_data };
611        let params = self
612            .request_send
613            .call_failable(DiagRequest::PacketCapture, params)
614            .await?;
615        let num_streams = match params.op_data {
616            Some(OperationData::OpQueryData(num_streams)) => num_streams,
617            _ => 0,
618        };
619        Ok(NetworkPacketCaptureResponse { num_streams })
620    }
621
    /// Handles a profiling request by forwarding it, along with its data
    /// connection, to the worker.
    async fn handle_profile(&self, request: ProfileRequest) -> anyhow::Result<()> {
        // Take the connection in both build configurations so the id is
        // consumed either way.
        let conn = self.take_connection(request.conn).await?;
        #[cfg(feature = "profiler")]
        {
            let profiler_request = profiler_worker::ProfilerRequest {
                profiler_args: request.profiler_args,
                duration: request.duration,
                conn: conn.into_inner(),
            };

            self.request_send
                .call_failable(DiagRequest::Profile, profiler_request)
                .await?;
        }
        #[cfg(not(feature = "profiler"))]
        {
            // Profiler feature disabled, drop the connection.
            drop(conn);
            tracing::error!("Profiler feature disabled");
        }
        Ok(())
    }
644
    /// Common implementation for kmsg/file read requests: opens `file_path`
    /// and spawns a detached task that relays its contents to the data
    /// connection `conn`.
    ///
    /// Character devices are relayed via polled, null-terminated record
    /// reads; regular files are copied directly. Other file types (e.g.
    /// directories) are rejected.
    ///
    /// NOTE: `follow` is only honored for character devices; regular files
    /// are copied once.
    async fn handle_read_file_request(
        &self,
        driver: &(impl Driver + Spawn + Clone),
        conn: u64,
        follow: bool,
        file_path: &str,
    ) -> anyhow::Result<()> {
        let mut conn = self.take_connection(conn).await?;
        let file = fs_err::File::open(file_path).context("failed to open file")?;

        let file_meta = file.metadata()?;

        if file_meta.file_type().is_char_device() {
            let file =
                PolledPipe::new(driver, file.into()).context("failed to create polled pipe")?;

            driver
                .spawn("read file relay", async move {
                    if let Err(err) = relay_read_file(file, conn, follow).await {
                        tracing::warn!(
                            error = &*err as &dyn std::error::Error,
                            "read file relay failed"
                        );
                    }
                })
                .detach();
        } else if file_meta.file_type().is_file() {
            driver
                .spawn("read file relay", async move {
                    // Since this is a file, and in Underhill files are backed
                    // by RAM, allow blocking reads directly on this thread,
                    // since the reads should be satisfied instantly.
                    //
                    // (If this becomes a problem, we can spawn a thread to do
                    // this, or use io-uring.)
                    if let Err(err) =
                        futures::io::copy(AllowStdIo::new(File::from(file)), &mut conn).await
                    {
                        tracing::warn!(
                            error = &err as &dyn std::error::Error,
                            "read file relay failed"
                        );
                    }
                })
                .detach();
        } else {
            anyhow::bail!("cannot read directory");
        }

        Ok(())
    }
696
697    async fn handle_restart(&self) -> anyhow::Result<()> {
698        self.request_send
699            .call_failable(DiagRequest::Restart, ())
700            .await?;
701        Ok(())
702    }
703
704    async fn handle_pause(&self) -> anyhow::Result<()> {
705        self.request_send
706            .call_failable(DiagRequest::Pause, ())
707            .await?;
708        Ok(())
709    }
710
711    async fn handle_resume(&self) -> anyhow::Result<()> {
712        self.request_send
713            .call_failable(DiagRequest::Resume, ())
714            .await?;
715        Ok(())
716    }
717
718    async fn handle_dump_saved_state(&self) -> anyhow::Result<diag_proto::DumpSavedStateResponse> {
719        let data = self
720            .request_send
721            .call_failable(DiagRequest::Save, ())
722            .await?;
723
724        Ok(diag_proto::DumpSavedStateResponse { data })
725    }
726}
727
/// Copies bytes from `read` to `write` until `read` reaches EOF, the peer of
/// `write` hangs up, or an I/O error occurs; then closes `write`.
///
/// Returns `write` so the caller can keep using the underlying connection
/// (e.g. to wait for the peer to drain it).
async fn relay<
    R: 'static + AsyncRead + Unpin + Send,
    W: 'static + AsyncWrite + PollReady + Unpin + Send,
>(
    mut read: R,
    mut write: W,
) -> W {
    let mut buffer = [0; 1024];
    let result: anyhow::Result<_> = async {
        loop {
            // Race the next read against hangup detection on the writer.
            let n = futures::select! { // merge semantics
                n = read.read(&mut buffer).fuse() => n.context("read failed")?,
                _ = write.wait_ready(PollEvents::RDHUP).fuse() => {
                    // RDHUP indicates the connection is closed or shut down.
                    // Although generically this does not indicate that the
                    // connection does not want to _read_ any more data, for our
                    // use cases it does (either we are using a unidirectional
                    // pipe/socket, or we are using a pty, which never returns
                    // RDHUP but does return HUP, which is just as good).
                    //
                    // Stop this relay to propagate the close notification to
                    // the other endpoint.
                    break;
                }
            };
            if n == 0 {
                break;
            }
            write
                .write_all(&buffer[..n])
                .await
                .context("write failed")?;
        }
        Ok(())
    }
    .await;
    // Always close the writer, even on error, to propagate shutdown.
    let _ = write.close().await;
    if let Err(err) = result {
        tracing::warn!(error = &*err as &dyn std::error::Error, "relay error");
    }
    write
}
770
/// Relays record-oriented reads from `file` to `conn`, writing each record
/// followed by a null terminator.
///
/// With `follow`, keeps waiting for new records until the peer hangs up;
/// otherwise drains only what is immediately available (stopping at EAGAIN).
async fn relay_read_file(
    mut file: PolledPipe,
    mut conn: PolledSocket<Socket>,
    follow: bool,
) -> anyhow::Result<()> {
    // Reserve one byte of the buffer for the null terminator appended below.
    let mut buffer = [0; FILE_LINE_MAX];
    loop {
        let n = if follow {
            futures::select! { // race semantics
                _ = conn.wait_ready(PollEvents::RDHUP).fuse() => break,
                n = file.read(&mut buffer[..FILE_LINE_MAX - 1]).fuse() => n
            }
        } else {
            // The caller just wants the current contents of file, so issue a
            // nonblocking, non-async read, and handle EAGAIN below.
            file.get().read(&mut buffer[..FILE_LINE_MAX - 1])
        };
        let n = match n {
            Ok(0) => break,
            Ok(count) => count,
            Err(e) => {
                match e.kind() {
                    io::ErrorKind::BrokenPipe => {
                        // The kmsg interface returns EPIPE if an entry has overwritten another in the ring.
                        // Retry the read which has the kernel move the seek position to the next available record.
                        continue;
                    }
                    io::ErrorKind::WouldBlock => {
                        // There are no more messages.
                        assert!(!follow);
                        break;
                    }
                    _ => return Err(e).context("file read failed"),
                }
            }
        };
        // Reads are capped at FILE_LINE_MAX - 1, so n is always in bounds.
        assert!(
            n < buffer.len(),
            "the file returned a line bigger than its maximum"
        );
        // Add a null terminator.
        buffer[n] = 0;
        // Write the message followed by a null terminator.
        conn.write_all(&buffer[..n + 1])
            .await
            .context("socket write failed")?;
    }
    Ok(())
}