pipette/
agent.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! The main pipette agent, which is run when the process starts.
5
6#![cfg(any(target_os = "linux", target_os = "windows"))]
7
8use anyhow::Context;
9use futures::future::FutureExt;
10use futures_concurrency::future::RaceOk;
11use mesh_remote::PointToPointMesh;
12use pal_async::DefaultDriver;
13use pal_async::socket::PolledSocket;
14use pal_async::task::Spawn;
15use pal_async::timer::PolledTimer;
16use pipette_protocol::DiagnosticFile;
17use pipette_protocol::PipetteBootstrap;
18use pipette_protocol::PipetteRequest;
19use socket2::Socket;
20use std::time::Duration;
21use std::time::SystemTime;
22use unicycle::FuturesUnordered;
23use vmsocket::VmAddress;
24use vmsocket::VmSocket;
25
/// The pipette agent: owns the connection to the host and serves its requests.
///
/// Constructed by [`Agent::new`], which connects to the host and hands it the
/// bootstrap channels, then driven by [`Agent::run`].
pub struct Agent {
    /// Driver used to spawn tasks and create timers while handling requests.
    driver: DefaultDriver,
    /// Point-to-point mesh over the host connection; shut down when `run` exits.
    mesh: PointToPointMesh,
    /// Incoming requests; the sending half is passed to the host in the bootstrap.
    request_recv: mesh::Receiver<PipetteRequest>,
    /// Sender for diagnostic files; the receiving half is passed to the host
    /// in the bootstrap.
    diag_file_send: DiagnosticSender,
    /// Signaled when `run` exits; the receiving half is passed to the host in
    /// the bootstrap.
    watch_send: mesh::OneshotSender<()>,
}
33
/// Cloneable handle for pushing diagnostic files to the host.
///
/// Wraps the sending half of the channel whose receiver is delivered to the
/// host as `PipetteBootstrap::diag_file_recv`.
#[derive(Clone)]
pub struct DiagnosticSender(mesh::Sender<DiagnosticFile>);
36
37impl Agent {
38    pub async fn new(driver: DefaultDriver) -> anyhow::Result<Self> {
39        let socket = (connect_client(&driver), connect_server(&driver))
40            .race_ok()
41            .await
42            .map_err(|e| {
43                let [e0, e1] = &*e;
44                anyhow::anyhow!(
45                    "failed to connect. client error: {:#} server error: {:#}",
46                    e0,
47                    e1
48                )
49            })?;
50
51        let (bootstrap_send, bootstrap_recv) = mesh::oneshot::<PipetteBootstrap>();
52        let mesh = PointToPointMesh::new(&driver, socket, bootstrap_recv.into());
53
54        let (request_send, request_recv) = mesh::channel();
55        let (diag_file_send, diag_file_recv) = mesh::channel();
56        let (watch_send, watch_recv) = mesh::oneshot();
57        let log = crate::trace::init_tracing();
58
59        bootstrap_send.send(PipetteBootstrap {
60            requests: request_send,
61            diag_file_recv,
62            watch: watch_recv,
63            log,
64        });
65
66        Ok(Self {
67            driver,
68            mesh,
69            request_recv,
70            diag_file_send: DiagnosticSender(diag_file_send),
71            watch_send,
72        })
73    }
74
75    pub async fn run(mut self) -> anyhow::Result<()> {
76        let mut tasks = FuturesUnordered::new();
77        loop {
78            futures::select! {
79                req = self.request_recv.recv().fuse() => {
80                    match req {
81                        Ok(req) => {
82                            tasks.push(handle_request(&self.driver, req, self.diag_file_send.clone()));
83                        },
84                        Err(e) => {
85                            tracing::info!(?e, "request channel closed, shutting down");
86                            break;
87                        }
88                    }
89                }
90                _ = tasks.next() => {}
91            }
92        }
93        self.watch_send.send(());
94        self.mesh.shutdown().await;
95        Ok(())
96    }
97}
98
99async fn connect_server(driver: &DefaultDriver) -> anyhow::Result<PolledSocket<Socket>> {
100    let mut socket = VmSocket::new()?;
101    socket.bind(VmAddress::vsock_any(pipette_protocol::PIPETTE_VSOCK_PORT))?;
102    let mut socket =
103        PolledSocket::new(driver, socket.into()).context("failed to create polled socket")?;
104    socket.listen(1)?;
105    let socket = socket
106        .accept()
107        .await
108        .context("failed to accept connection")?
109        .0;
110    PolledSocket::new(driver, socket).context("failed to create polled socket")
111}
112
113async fn connect_client(driver: &DefaultDriver) -> anyhow::Result<PolledSocket<Socket>> {
114    let socket = VmSocket::new()?;
115    // Extend the default timeout of 2 seconds, as tests are often run in
116    // parallel on a host, causing very heavy load on the overall system.
117    socket
118        .set_connect_timeout(Duration::from_secs(5))
119        .context("failed to set socket timeout")?;
120    let mut socket = PolledSocket::new(driver, socket)
121        .context("failed to create polled socket")?
122        .convert();
123    socket
124        .connect(&VmAddress::vsock_host(pipette_protocol::PIPETTE_VSOCK_PORT).into())
125        .await?;
126    Ok(socket)
127}
128
129async fn handle_request(
130    driver: &DefaultDriver,
131    req: PipetteRequest,
132    _diag_file_send: DiagnosticSender,
133) {
134    match req {
135        PipetteRequest::Ping(rpc) => rpc.handle_sync(|()| {
136            tracing::info!("ping");
137        }),
138        PipetteRequest::Execute(rpc) => rpc.handle_failable_sync(crate::execute::handle_execute),
139        PipetteRequest::Shutdown(rpc) => {
140            rpc.handle_sync(|request| {
141                tracing::info!(shutdown_type = ?request.shutdown_type, "shutdown request");
142                // TODO: handle this inline without waiting. Currently we spawn
143                // a task so that the response is sent before the shutdown
144                // starts, since hvlite fails to notice that the connection is
145                // closed if we power off while a response is pending.
146                let mut timer = PolledTimer::new(driver);
147                driver
148                    .spawn("shutdown", async move {
149                        // Because pipette runs as a system service on Windows
150                        // it is able to issue a shutdown command before Windows
151                        // has finished starting up and logging in the user. This
152                        // can put the system into a stuck state, where it is
153                        // completely unable to shut down. To avoid this, we
154                        // wait for a longer period before attempting to shut down.
155                        #[cfg(windows)]
156                        timer.sleep(Duration::from_secs(5)).await;
157                        #[cfg(not(windows))]
158                        timer.sleep(Duration::from_millis(250)).await;
159                        loop {
160                            if let Err(err) = crate::shutdown::handle_shutdown(request) {
161                                tracing::error!(
162                                    error = err.as_ref() as &dyn std::error::Error,
163                                    "failed to shut down"
164                                );
165                            }
166                            timer.sleep(Duration::from_secs(5)).await;
167                            tracing::warn!("still waiting to shut down, trying again");
168                        }
169                    })
170                    .detach();
171                Ok(())
172            })
173        }
174        PipetteRequest::ReadFile(rpc) => rpc.handle_failable(read_file).await,
175        PipetteRequest::WriteFile(rpc) => rpc.handle_failable(write_file).await,
176        PipetteRequest::GetTime(rpc) => rpc.handle_sync(|()| SystemTime::now().into()),
177    }
178}
179
180async fn read_file(mut request: pipette_protocol::ReadFileRequest) -> anyhow::Result<u64> {
181    tracing::debug!(path = request.path, "Beginning file read request");
182    let file = fs_err::File::open(request.path)?;
183    let n = futures::io::copy(&mut futures::io::AllowStdIo::new(file), &mut request.sender).await?;
184    tracing::debug!("file read request complete");
185    Ok(n)
186}
187
188async fn write_file(mut request: pipette_protocol::WriteFileRequest) -> anyhow::Result<u64> {
189    tracing::debug!(path = request.path, "Beginning file write request");
190    let file = fs_err::File::create(request.path)?;
191    let n = futures::io::copy(
192        &mut request.receiver,
193        &mut futures::io::AllowStdIo::new(file),
194    )
195    .await?;
196    tracing::debug!("file write request complete");
197    Ok(n)
198}
199
200impl DiagnosticSender {
201    #[cfg_attr(not(windows), expect(dead_code))]
202    pub async fn send(&self, filename: &str) -> anyhow::Result<()> {
203        tracing::debug!(filename, "Beginning diagnostic file request");
204        let file = fs_err::File::open(filename)?;
205        let (recv_pipe, mut send_pipe) = mesh::pipe::pipe();
206        self.0.send(DiagnosticFile {
207            name: filename.to_owned(),
208            receiver: recv_pipe,
209        });
210        futures::io::copy(&mut futures::io::AllowStdIo::new(file), &mut send_pipe).await?;
211        drop(send_pipe);
212        tracing::debug!("diagnostic request complete");
213        Ok(())
214    }
215}