pipette/
agent.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! The main pipette agent, which is run when the process starts.

#![cfg(any(target_os = "linux", target_os = "windows"))]

use anyhow::Context;
use futures::future::FutureExt;
use futures_concurrency::future::RaceOk;
use mesh_remote::PointToPointMesh;
use pal_async::DefaultDriver;
use pal_async::socket::PolledSocket;
use pal_async::task::Spawn;
use pal_async::timer::PolledTimer;
use pipette_protocol::DiagnosticFile;
use pipette_protocol::PipetteBootstrap;
use pipette_protocol::PipetteRequest;
use socket2::Socket;
use std::time::Duration;
use unicycle::FuturesUnordered;
use vmsocket::VmAddress;
use vmsocket::VmSocket;

/// The pipette agent: the state assembled by [`Agent::new`] and consumed by
/// [`Agent::run`].
pub struct Agent {
    /// Driver used to build the mesh and the connection sockets; a reference
    /// to it is passed to every request handler.
    driver: DefaultDriver,
    /// Point-to-point mesh built over the connected socket. It is shut down
    /// when `run` leaves its request loop.
    mesh: PointToPointMesh,
    /// Receiving half of the request channel. The sending half is handed out
    /// in the `requests` field of the bootstrap message.
    request_recv: mesh::Receiver<PipetteRequest>,
    /// Sender for diagnostic files; a clone is given to each request handler.
    diag_file_send: DiagnosticSender,
    /// Oneshot that `run` signals after its request loop ends, just before
    /// the mesh is shut down. The receiving half is handed out in the `watch`
    /// field of the bootstrap message.
    watch_send: mesh::OneshotSender<()>,
}

/// Cloneable wrapper around the sending half of the [`DiagnosticFile`]
/// channel whose receiving half is handed out in the bootstrap message.
#[derive(Clone)]
pub struct DiagnosticSender(mesh::Sender<DiagnosticFile>);

impl Agent {
    /// Establishes the pipette connection and builds the agent around it.
    ///
    /// An outgoing connection attempt is raced against an incoming one and
    /// the first to succeed is used. If both fail, the returned error reports
    /// the client failure and the server failure together.
    ///
    /// The connected socket backs a [`PointToPointMesh`], which is given the
    /// receiving half of a oneshot carrying the [`PipetteBootstrap`] message.
    /// That message holds the request sender, the diagnostic file receiver,
    /// the watch receiver, and the value returned by tracing initialization.
    pub async fn new(driver: DefaultDriver) -> anyhow::Result<Self> {
        let connected = (connect_client(&driver), connect_server(&driver))
            .race_ok()
            .await;
        let socket = match connected {
            Ok(socket) => socket,
            Err(errors) => {
                // Both attempts failed; surface both causes in one message.
                let [client_err, server_err] = &*errors;
                return Err(anyhow::anyhow!(
                    "failed to connect. client error: {:#} server error: {:#}",
                    client_err,
                    server_err
                ));
            }
        };

        // The mesh takes the receiving half of the bootstrap oneshot now; the
        // message itself is sent below, once its contents have been created.
        let (bootstrap_send, bootstrap_recv) = mesh::oneshot::<PipetteBootstrap>();
        let mesh = PointToPointMesh::new(&driver, socket, bootstrap_recv.into());

        let (request_send, request_recv) = mesh::channel();
        let (diag_file_send, diag_file_recv) = mesh::channel();
        let (watch_send, watch_recv) = mesh::oneshot();
        let log = crate::trace::init_tracing();

        let bootstrap = PipetteBootstrap {
            requests: request_send,
            diag_file_recv,
            watch: watch_recv,
            log,
        };
        bootstrap_send.send(bootstrap);

        let diag_file_send = DiagnosticSender(diag_file_send);
        Ok(Self {
            driver,
            mesh,
            request_recv,
            diag_file_send,
            watch_send,
        })
    }

    /// Services incoming requests until the request channel reports an error.
    ///
    /// Each received request becomes a handler future that is polled
    /// alongside the receive loop. Once receiving fails, the loop ends, the
    /// watch oneshot is signaled, and the mesh is shut down. This function
    /// always returns `Ok(())`.
    pub async fn run(mut self) -> anyhow::Result<()> {
        let mut tasks = FuturesUnordered::new();
        loop {
            futures::select! {
                req = self.request_recv.recv().fuse() => {
                    let req = match req {
                        Ok(req) => req,
                        Err(e) => {
                            tracing::info!(?e, "request channel closed, shutting down");
                            break;
                        }
                    };
                    let diag_file_send = self.diag_file_send.clone();
                    tasks.push(handle_request(&self.driver, req, diag_file_send));
                }
                // Drive in-flight handlers; their results carry no data.
                _ = tasks.next() => {}
            }
        }
        self.watch_send.send(());
        self.mesh.shutdown().await;
        Ok(())
    }
}

/// Waits for an incoming connection on the pipette port.
///
/// Creates a VM socket, binds it to `PIPETTE_VSOCK_PORT` on the "any"
/// address, listens with a backlog of one, accepts a single connection, and
/// returns that connection wrapped in a [`PolledSocket`].
///
/// # Errors
///
/// Every step attaches context naming the step that failed, so that the
/// combined error reported by `Agent::new` identifies where the server path
/// broke.
async fn connect_server(driver: &DefaultDriver) -> anyhow::Result<PolledSocket<Socket>> {
    let mut socket = VmSocket::new().context("failed to create socket")?;
    socket
        .bind(VmAddress::vsock_any(pipette_protocol::PIPETTE_VSOCK_PORT))
        .context("failed to bind socket")?;
    let mut socket =
        PolledSocket::new(driver, socket.into()).context("failed to create polled socket")?;
    socket.listen(1).context("failed to listen on socket")?;
    // `accept` yields a tuple; only its first element (the connection) is used.
    let socket = socket
        .accept()
        .await
        .context("failed to accept connection")?
        .0;
    PolledSocket::new(driver, socket).context("failed to create polled socket")
}

/// Connects out to the host on the pipette port.
///
/// Creates a VM socket, raises its connect timeout to five seconds, wraps it
/// in a [`PolledSocket`], and connects to `PIPETTE_VSOCK_PORT` at the host
/// address.
///
/// # Errors
///
/// Every step attaches context naming the step that failed, so that the
/// combined error reported by `Agent::new` identifies where the client path
/// broke.
async fn connect_client(driver: &DefaultDriver) -> anyhow::Result<PolledSocket<Socket>> {
    let socket = VmSocket::new().context("failed to create socket")?;
    // Extend the default timeout of 2 seconds, as tests are often run in
    // parallel on a host, causing very heavy load on the overall system.
    socket
        .set_connect_timeout(Duration::from_secs(5))
        .context("failed to set socket timeout")?;
    let mut socket = PolledSocket::new(driver, socket)
        .context("failed to create polled socket")?
        .convert();
    socket
        .connect(&VmAddress::vsock_host(pipette_protocol::PIPETTE_VSOCK_PORT).into())
        .await
        .context("failed to connect to host")?;
    Ok(socket)
}

/// Dispatches one [`PipetteRequest`] to the handler for its variant.
///
/// `driver` is used only by the shutdown path, to create a timer and to spawn
/// the detached shutdown task. `_diag_file_send` is not used by any arm.
async fn handle_request(
    driver: &DefaultDriver,
    req: PipetteRequest,
    _diag_file_send: DiagnosticSender,
) {
    // Delay before the first shutdown attempt.
    //
    // Because pipette runs as a system service on Windows it is able to issue
    // a shutdown command before Windows has finished starting up and logging
    // in the user. This can put the system into a stuck state, where it is
    // completely unable to shut down. To avoid this, Windows waits for a
    // longer period before attempting to shut down.
    const INITIAL_SHUTDOWN_DELAY: Duration = if cfg!(windows) {
        Duration::from_secs(5)
    } else {
        Duration::from_millis(250)
    };
    // Pause between successive shutdown attempts.
    const SHUTDOWN_RETRY_INTERVAL: Duration = Duration::from_secs(5);

    match req {
        PipetteRequest::Ping(rpc) => rpc.handle_sync(|()| {
            tracing::info!("ping");
        }),
        PipetteRequest::Execute(rpc) => rpc.handle_failable_sync(crate::execute::handle_execute),
        PipetteRequest::Shutdown(rpc) => rpc.handle_sync(|request| {
            tracing::info!(shutdown_type = ?request.shutdown_type, "shutdown request");
            // TODO: handle this inline without waiting. Currently we spawn
            // a task so that the response is sent before the shutdown
            // starts, since hvlite fails to notice that the connection is
            // closed if we power off while a response is pending.
            let mut timer = PolledTimer::new(driver);
            driver
                .spawn("shutdown", async move {
                    timer.sleep(INITIAL_SHUTDOWN_DELAY).await;
                    // Keep retrying: the task only ends when the machine
                    // actually goes down.
                    loop {
                        if let Err(err) = crate::shutdown::handle_shutdown(request) {
                            tracing::error!(
                                error = err.as_ref() as &dyn std::error::Error,
                                "failed to shut down"
                            );
                        }
                        timer.sleep(SHUTDOWN_RETRY_INTERVAL).await;
                        tracing::warn!("still waiting to shut down, trying again");
                    }
                })
                .detach();
            Ok(())
        }),
        PipetteRequest::ReadFile(rpc) => rpc.handle_failable(read_file).await,
        PipetteRequest::WriteFile(rpc) => rpc.handle_failable(write_file).await,
    }
}

/// Streams the file at `request.path` into `request.sender`.
///
/// Returns an error if the file cannot be opened or if the copy fails.
async fn read_file(mut request: pipette_protocol::ReadFileRequest) -> anyhow::Result<()> {
    tracing::debug!(path = request.path, "Beginning file read request");
    let mut source = futures::io::AllowStdIo::new(fs_err::File::open(request.path)?);
    futures::io::copy(&mut source, &mut request.sender).await?;
    // Release the file handle before reporting completion.
    drop(source);
    tracing::debug!("file read request complete");
    Ok(())
}

/// Creates the file at `request.path` and fills it from `request.receiver`.
///
/// Returns an error if the file cannot be created or if the copy fails.
async fn write_file(mut request: pipette_protocol::WriteFileRequest) -> anyhow::Result<()> {
    tracing::debug!(path = request.path, "Beginning file write request");
    let mut sink = futures::io::AllowStdIo::new(fs_err::File::create(request.path)?);
    futures::io::copy(&mut request.receiver, &mut sink).await?;
    // Release the file handle before reporting completion.
    drop(sink);
    tracing::debug!("file write request complete");
    Ok(())
}

impl DiagnosticSender {
    /// Sends the file named `filename` over the diagnostic channel.
    ///
    /// The file is opened first, so nothing is sent if it cannot be opened.
    /// A [`DiagnosticFile`] carrying the file name and the read end of a
    /// fresh pipe is then sent, and the file contents are copied into the
    /// write end of that pipe.
    ///
    /// Returns an error if the file cannot be opened or if the copy fails.
    #[cfg_attr(not(windows), expect(dead_code))]
    pub async fn send(&self, filename: &str) -> anyhow::Result<()> {
        tracing::debug!(filename, "Beginning diagnostic file request");
        let file = fs_err::File::open(filename)?;
        let (receiver, mut writer) = mesh::pipe::pipe();
        let announcement = DiagnosticFile {
            name: filename.to_owned(),
            receiver,
        };
        self.0.send(announcement);
        let mut source = futures::io::AllowStdIo::new(file);
        futures::io::copy(&mut source, &mut writer).await?;
        // Close the file, then the write end of the pipe, before logging.
        drop(source);
        drop(writer);
        tracing::debug!("diagnostic request complete");
        Ok(())
    }
}