1use anyhow::Context;
7use futures::future::FutureExt;
8use futures_concurrency::future::Race;
9use mesh_remote::PointToPointMesh;
10use pal_async::DefaultDriver;
11use pal_async::socket::PolledSocket;
12use pal_async::task::Spawn;
13use pal_async::timer::PolledTimer;
14use pipette_protocol::DiagnosticFile;
15use pipette_protocol::PipetteBootstrap;
16use pipette_protocol::PipetteRequest;
17use socket2::Socket;
18use std::time::Duration;
19use std::time::SystemTime;
20use unicycle::FuturesUnordered;
21use vmsocket::VmAddress;
22use vmsocket::VmSocket;
23
/// State of the pipette agent once a connection to the host exists.
///
/// Built by [`Agent::new`] after the host connection has been established and
/// the bootstrap message sent. Consumed by [`Agent::run`], which serves host
/// requests until the request channel closes.
pub struct Agent {
    /// Driver used to build the mesh and, in request handling, to create
    /// timers and spawn tasks.
    driver: DefaultDriver,
    /// Point-to-point mesh over the host connection; shut down at the end of
    /// `run`.
    mesh: PointToPointMesh,
    /// Receiving half of the request channel whose sender was handed to the
    /// host in the bootstrap message.
    request_recv: mesh::Receiver<PipetteRequest>,
    /// Handle for sending diagnostic files to the host.
    diag_file_send: DiagnosticSender,
    /// Sent `()` once the request loop in `run` exits; the matching receiver
    /// was handed to the host in the bootstrap message.
    watch_send: mesh::OneshotSender<()>,
}
31
/// Cloneable handle for streaming diagnostic files to the host.
///
/// Wraps the sending half of a channel whose receiving half is given to the
/// host in the [`PipetteBootstrap`] message.
#[derive(Clone)]
pub struct DiagnosticSender(mesh::Sender<DiagnosticFile>);
34
35impl Agent {
36 pub async fn new(driver: DefaultDriver) -> anyhow::Result<Self> {
37 let socket = (connect_client(&driver), connect_server(&driver))
38 .race()
39 .await;
40
41 eprintln!("Pipette handshaking with host");
42 let (bootstrap_send, bootstrap_recv) = mesh::oneshot::<PipetteBootstrap>();
43 let mesh = PointToPointMesh::new(&driver, socket, bootstrap_recv.into());
44
45 let (request_send, request_recv) = mesh::channel();
46 let (diag_file_send, diag_file_recv) = mesh::channel();
47 let (watch_send, watch_recv) = mesh::oneshot();
48 eprintln!("Pipette initializing tracing");
49 let log = crate::trace::init_tracing();
50
51 bootstrap_send.send(PipetteBootstrap {
52 requests: request_send,
53 diag_file_recv,
54 watch: watch_recv,
55 log,
56 });
57 eprintln!("Pipette bootstrap sent to host");
58
59 Ok(Self {
60 driver,
61 mesh,
62 request_recv,
63 diag_file_send: DiagnosticSender(diag_file_send),
64 watch_send,
65 })
66 }
67
68 pub async fn run(mut self) -> anyhow::Result<()> {
69 let mut tasks = FuturesUnordered::new();
70 loop {
71 futures::select! {
72 req = self.request_recv.recv().fuse() => {
73 match req {
74 Ok(req) => {
75 tasks.push(handle_request(&self.driver, req, self.diag_file_send.clone()));
76 },
77 Err(e) => {
78 tracing::info!(?e, "request channel closed, shutting down");
79 break;
80 }
81 }
82 }
83 _ = tasks.next() => {}
84 }
85 }
86 self.watch_send.send(());
87 self.mesh.shutdown().await;
88 Ok(())
89 }
90}
91
92async fn connect_server(driver: &DefaultDriver) -> PolledSocket<Socket> {
93 let server_core = async || {
94 let mut socket = VmSocket::new()?;
95 socket.bind(VmAddress::vsock_any(pipette_protocol::PIPETTE_VSOCK_PORT))?;
96 let mut socket =
97 PolledSocket::new(driver, socket.into()).context("failed to create polled socket")?;
98 socket.listen(1)?;
99 let socket = socket
100 .accept()
101 .await
102 .context("failed to accept connection")?
103 .0;
104 PolledSocket::new(driver, socket).context("failed to create polled server socket")
105 };
106
107 match server_core().await {
108 Ok(socket) => socket,
109 Err(err) => {
110 eprintln!("failed to stand up server: {:?}", err);
111 std::future::pending().await
112 }
113 }
114}
115
116async fn connect_client(driver: &DefaultDriver) -> PolledSocket<Socket> {
117 let client_core = async || {
118 let socket = VmSocket::new()?;
119 socket
122 .set_connect_timeout(Duration::from_secs(5))
123 .context("failed to set socket timeout")?;
124 let mut socket = PolledSocket::new(driver, socket)
125 .context("failed to create polled client socket")?
126 .convert();
127 socket
128 .connect(&VmAddress::vsock_host(pipette_protocol::PIPETTE_VSOCK_PORT).into())
129 .await
130 .context("failed to connect")
131 .map(|()| socket)
132 };
133 loop {
134 let mut timer = PolledTimer::new(driver);
135 match client_core().await {
136 Ok(socket) => return socket,
137 Err(err) => {
138 eprintln!("failed to connect to server, retrying: {:?}", err);
139 timer.sleep(Duration::from_secs(1)).await;
140 }
141 }
142 }
143}
144
/// Dispatches a single request from the host to its handler.
///
/// Most requests are answered synchronously through `handle_sync` /
/// `handle_failable_sync`. File transfers and kernel-crash requests are
/// awaited, so this future stays pending until they finish.
/// `_diag_file_send` is accepted but not used by any request handled here.
async fn handle_request(
    driver: &DefaultDriver,
    req: PipetteRequest,
    _diag_file_send: DiagnosticSender,
) {
    match req {
        // Liveness check: logs and replies with `()`.
        PipetteRequest::Ping(rpc) => rpc.handle_sync(|()| {
            tracing::info!("ping");
        }),
        PipetteRequest::Execute(rpc) => rpc.handle_failable_sync(crate::execute::handle_execute),
        PipetteRequest::Shutdown(rpc) => {
            rpc.handle_sync(|request| {
                tracing::info!(shutdown_type = ?request.shutdown_type, "shutdown request");
                // Create the timer here, from the borrowed `driver`, so the
                // spawned task owns it and does not need to capture `driver`.
                let mut timer = PolledTimer::new(driver);
                // Run the shutdown on a detached task. That lets this closure
                // return `Ok(())` right away, presumably so the reply reaches
                // the host before shutdown starts.
                driver
                    .spawn("shutdown", async move {
                        // Initial delay before the first shutdown attempt:
                        // 5 s on Windows, 250 ms elsewhere.
                        // NOTE(review): the reason Windows needs the longer
                        // delay is not visible here; confirm.
                        #[cfg(windows)]
                        timer.sleep(Duration::from_secs(5)).await;
                        #[cfg(not(windows))]
                        timer.sleep(Duration::from_millis(250)).await;
                        // Never exits on its own. A successful
                        // `handle_shutdown` does not end the loop, so if the
                        // process is still running 5 s later, shutdown is
                        // attempted again.
                        loop {
                            if let Err(err) = crate::shutdown::handle_shutdown(request) {
                                tracing::error!(
                                    error = err.as_ref() as &dyn std::error::Error,
                                    "failed to shut down"
                                );
                            }
                            timer.sleep(Duration::from_secs(5)).await;
                            tracing::warn!("still waiting to shut down, trying again");
                        }
                    })
                    .detach();
                Ok(())
            })
        }
        PipetteRequest::ReadFile(rpc) => rpc.handle_failable(read_file).await,
        PipetteRequest::WriteFile(rpc) => rpc.handle_failable(write_file).await,
        // Replies with the guest's current system time.
        PipetteRequest::GetTime(rpc) => rpc.handle_sync(|()| SystemTime::now().into()),
        // Deliberately panics the agent.
        PipetteRequest::Crash(rpc) => rpc.handle_sync(|()| panic!("crash requested")),
        PipetteRequest::KernelCrash(rpc) => {
            // An error from `trigger_kernel_crash` is reported through the
            // RPC. On success this future waits forever and never replies.
            rpc.handle_failable(async |()| {
                crate::crash::trigger_kernel_crash()?;
                std::future::pending::<()>().await;
                anyhow::Ok(())
            })
            .await
        }
        #[cfg(target_os = "linux")]
        PipetteRequest::Mount(rpc) => rpc.handle_failable_sync(crate::mount::handle_mount),
        // Mounting is only implemented on Linux; everywhere else it fails.
        #[cfg(not(target_os = "linux"))]
        PipetteRequest::Mount(rpc) => {
            rpc.handle_failable_sync(|_| anyhow::bail!("mount not supported on this platform"))
        }
    }
}
210
211async fn read_file(mut request: pipette_protocol::ReadFileRequest) -> anyhow::Result<u64> {
212 tracing::debug!(path = request.path, "Beginning file read request");
213 let file = fs_err::File::open(request.path)?;
214 let n = futures::io::copy(&mut futures::io::AllowStdIo::new(file), &mut request.sender).await?;
215 tracing::debug!("file read request complete");
216 Ok(n)
217}
218
219async fn write_file(mut request: pipette_protocol::WriteFileRequest) -> anyhow::Result<u64> {
220 tracing::debug!(path = request.path, "Beginning file write request");
221 let file = fs_err::File::create(request.path)?;
222 let n = futures::io::copy(
223 &mut request.receiver,
224 &mut futures::io::AllowStdIo::new(file),
225 )
226 .await?;
227 tracing::debug!("file write request complete");
228 Ok(n)
229}
230
231impl DiagnosticSender {
232 #[cfg_attr(not(windows), expect(dead_code))]
233 pub async fn send(&self, filename: &str) -> anyhow::Result<()> {
234 tracing::debug!(filename, "Beginning diagnostic file request");
235 let file = fs_err::File::open(filename)?;
236 let (recv_pipe, mut send_pipe) = mesh::pipe::pipe();
237 self.0.send(DiagnosticFile {
238 name: filename.to_owned(),
239 receiver: recv_pipe,
240 });
241 futures::io::copy(&mut futures::io::AllowStdIo::new(file), &mut send_pipe).await?;
242 drop(send_pipe);
243 tracing::debug!("diagnostic request complete");
244 Ok(())
245 }
246}