pipette/
agent.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! The main pipette agent, which is run when the process starts.

6use anyhow::Context;
7use futures::future::FutureExt;
8use futures_concurrency::future::Race;
9use mesh_remote::PointToPointMesh;
10use pal_async::DefaultDriver;
11use pal_async::socket::PolledSocket;
12use pal_async::task::Spawn;
13use pal_async::timer::PolledTimer;
14use pipette_protocol::DiagnosticFile;
15use pipette_protocol::PipetteBootstrap;
16use pipette_protocol::PipetteRequest;
17use socket2::Socket;
18use std::time::Duration;
19use std::time::SystemTime;
20use unicycle::FuturesUnordered;
21use vmsocket::VmAddress;
22use vmsocket::VmSocket;
23
24pub struct Agent {
25    driver: DefaultDriver,
26    mesh: PointToPointMesh,
27    request_recv: mesh::Receiver<PipetteRequest>,
28    diag_file_send: DiagnosticSender,
29    watch_send: mesh::OneshotSender<()>,
30}
31
32#[derive(Clone)]
33pub struct DiagnosticSender(mesh::Sender<DiagnosticFile>);
34
35impl Agent {
36    pub async fn new(driver: DefaultDriver) -> anyhow::Result<Self> {
37        let socket = (connect_client(&driver), connect_server(&driver))
38            .race()
39            .await;
40
41        eprintln!("Pipette handshaking with host");
42        let (bootstrap_send, bootstrap_recv) = mesh::oneshot::<PipetteBootstrap>();
43        let mesh = PointToPointMesh::new(&driver, socket, bootstrap_recv.into());
44
45        let (request_send, request_recv) = mesh::channel();
46        let (diag_file_send, diag_file_recv) = mesh::channel();
47        let (watch_send, watch_recv) = mesh::oneshot();
48        eprintln!("Pipette initializing tracing");
49        let log = crate::trace::init_tracing();
50
51        bootstrap_send.send(PipetteBootstrap {
52            requests: request_send,
53            diag_file_recv,
54            watch: watch_recv,
55            log,
56        });
57        eprintln!("Pipette bootstrap sent to host");
58
59        Ok(Self {
60            driver,
61            mesh,
62            request_recv,
63            diag_file_send: DiagnosticSender(diag_file_send),
64            watch_send,
65        })
66    }
67
68    pub async fn run(mut self) -> anyhow::Result<()> {
69        let mut tasks = FuturesUnordered::new();
70        loop {
71            futures::select! {
72                req = self.request_recv.recv().fuse() => {
73                    match req {
74                        Ok(req) => {
75                            tasks.push(handle_request(&self.driver, req, self.diag_file_send.clone()));
76                        },
77                        Err(e) => {
78                            tracing::info!(?e, "request channel closed, shutting down");
79                            break;
80                        }
81                    }
82                }
83                _ = tasks.next() => {}
84            }
85        }
86        self.watch_send.send(());
87        self.mesh.shutdown().await;
88        Ok(())
89    }
90}
91
92async fn connect_server(driver: &DefaultDriver) -> PolledSocket<Socket> {
93    let server_core = async || {
94        let mut socket = VmSocket::new()?;
95        socket.bind(VmAddress::vsock_any(pipette_protocol::PIPETTE_VSOCK_PORT))?;
96        let mut socket =
97            PolledSocket::new(driver, socket.into()).context("failed to create polled socket")?;
98        socket.listen(1)?;
99        let socket = socket
100            .accept()
101            .await
102            .context("failed to accept connection")?
103            .0;
104        PolledSocket::new(driver, socket).context("failed to create polled server socket")
105    };
106
107    match server_core().await {
108        Ok(socket) => socket,
109        Err(err) => {
110            eprintln!("failed to stand up server: {:?}", err);
111            std::future::pending().await
112        }
113    }
114}
115
116async fn connect_client(driver: &DefaultDriver) -> PolledSocket<Socket> {
117    let client_core = async || {
118        let socket = VmSocket::new()?;
119        // Extend the default timeout of 2 seconds, as tests are often run in
120        // parallel on a host, causing very heavy load on the overall system.
121        socket
122            .set_connect_timeout(Duration::from_secs(5))
123            .context("failed to set socket timeout")?;
124        let mut socket = PolledSocket::new(driver, socket)
125            .context("failed to create polled client socket")?
126            .convert();
127        socket
128            .connect(&VmAddress::vsock_host(pipette_protocol::PIPETTE_VSOCK_PORT).into())
129            .await
130            .context("failed to connect")
131            .map(|()| socket)
132    };
133    loop {
134        let mut timer = PolledTimer::new(driver);
135        match client_core().await {
136            Ok(socket) => return socket,
137            Err(err) => {
138                eprintln!("failed to connect to server, retrying: {:?}", err);
139                timer.sleep(Duration::from_secs(1)).await;
140            }
141        }
142    }
143}
144
145async fn handle_request(
146    driver: &DefaultDriver,
147    req: PipetteRequest,
148    _diag_file_send: DiagnosticSender,
149) {
150    match req {
151        PipetteRequest::Ping(rpc) => rpc.handle_sync(|()| {
152            tracing::info!("ping");
153        }),
154        PipetteRequest::Execute(rpc) => rpc.handle_failable_sync(crate::execute::handle_execute),
155        PipetteRequest::Shutdown(rpc) => {
156            rpc.handle_sync(|request| {
157                tracing::info!(shutdown_type = ?request.shutdown_type, "shutdown request");
158                // TODO: handle this inline without waiting. Currently we spawn
159                // a task so that the response is sent before the shutdown
160                // starts, since OpenVMM fails to notice that the connection is
161                // closed if we power off while a response is pending.
162                let mut timer = PolledTimer::new(driver);
163                driver
164                    .spawn("shutdown", async move {
165                        // Because pipette runs as a system service on Windows
166                        // it is able to issue a shutdown command before Windows
167                        // has finished starting up and logging in the user. This
168                        // can put the system into a stuck state, where it is
169                        // completely unable to shut down. To avoid this, we
170                        // wait for a longer period before attempting to shut down.
171                        #[cfg(windows)]
172                        timer.sleep(Duration::from_secs(5)).await;
173                        #[cfg(not(windows))]
174                        timer.sleep(Duration::from_millis(250)).await;
175                        loop {
176                            if let Err(err) = crate::shutdown::handle_shutdown(request) {
177                                tracing::error!(
178                                    error = err.as_ref() as &dyn std::error::Error,
179                                    "failed to shut down"
180                                );
181                            }
182                            timer.sleep(Duration::from_secs(5)).await;
183                            tracing::warn!("still waiting to shut down, trying again");
184                        }
185                    })
186                    .detach();
187                Ok(())
188            })
189        }
190        PipetteRequest::ReadFile(rpc) => rpc.handle_failable(read_file).await,
191        PipetteRequest::WriteFile(rpc) => rpc.handle_failable(write_file).await,
192        PipetteRequest::GetTime(rpc) => rpc.handle_sync(|()| SystemTime::now().into()),
193        PipetteRequest::Crash(rpc) => rpc.handle_sync(|()| panic!("crash requested")),
194        PipetteRequest::KernelCrash(rpc) => {
195            rpc.handle_failable(async |()| {
196                crate::crash::trigger_kernel_crash()?;
197                std::future::pending::<()>().await;
198                anyhow::Ok(())
199            })
200            .await
201        }
202        #[cfg(target_os = "linux")]
203        PipetteRequest::Mount(rpc) => rpc.handle_failable_sync(crate::mount::handle_mount),
204        #[cfg(not(target_os = "linux"))]
205        PipetteRequest::Mount(rpc) => {
206            rpc.handle_failable_sync(|_| anyhow::bail!("mount not supported on this platform"))
207        }
208    }
209}
210
211async fn read_file(mut request: pipette_protocol::ReadFileRequest) -> anyhow::Result<u64> {
212    tracing::debug!(path = request.path, "Beginning file read request");
213    let file = fs_err::File::open(request.path)?;
214    let n = futures::io::copy(&mut futures::io::AllowStdIo::new(file), &mut request.sender).await?;
215    tracing::debug!("file read request complete");
216    Ok(n)
217}
218
219async fn write_file(mut request: pipette_protocol::WriteFileRequest) -> anyhow::Result<u64> {
220    tracing::debug!(path = request.path, "Beginning file write request");
221    let file = fs_err::File::create(request.path)?;
222    let n = futures::io::copy(
223        &mut request.receiver,
224        &mut futures::io::AllowStdIo::new(file),
225    )
226    .await?;
227    tracing::debug!("file write request complete");
228    Ok(n)
229}
230
231impl DiagnosticSender {
232    #[cfg_attr(not(windows), expect(dead_code))]
233    pub async fn send(&self, filename: &str) -> anyhow::Result<()> {
234        tracing::debug!(filename, "Beginning diagnostic file request");
235        let file = fs_err::File::open(filename)?;
236        let (recv_pipe, mut send_pipe) = mesh::pipe::pipe();
237        self.0.send(DiagnosticFile {
238            name: filename.to_owned(),
239            receiver: recv_pipe,
240        });
241        futures::io::copy(&mut futures::io::AllowStdIo::new(file), &mut send_pipe).await?;
242        drop(send_pipe);
243        tracing::debug!("diagnostic request complete");
244        Ok(())
245    }
246}