1#![cfg(any(target_os = "linux", target_os = "windows"))]
7
8use anyhow::Context;
9use futures::future::FutureExt;
10use futures_concurrency::future::RaceOk;
11use mesh_remote::PointToPointMesh;
12use pal_async::DefaultDriver;
13use pal_async::socket::PolledSocket;
14use pal_async::task::Spawn;
15use pal_async::timer::PolledTimer;
16use pipette_protocol::DiagnosticFile;
17use pipette_protocol::PipetteBootstrap;
18use pipette_protocol::PipetteRequest;
19use socket2::Socket;
20use std::time::Duration;
21use std::time::SystemTime;
22use unicycle::FuturesUnordered;
23use vmsocket::VmAddress;
24use vmsocket::VmSocket;
25
/// State owned by the pipette agent for the lifetime of one connection.
pub struct Agent {
    /// Async driver used to build polled sockets, timers, and spawned tasks.
    driver: DefaultDriver,
    /// Point-to-point mesh running over the connected socket; shut down at the
    /// end of [`Agent::run`].
    mesh: PointToPointMesh,
    /// Receiving end of the channel on which `PipetteRequest`s arrive.
    request_recv: mesh::Receiver<PipetteRequest>,
    /// Sender for diagnostic files; a clone is handed to every request handler.
    diag_file_send: DiagnosticSender,
    /// One-shot sender that is fired with `()` when the request loop exits.
    watch_send: mesh::OneshotSender<()>,
}
33
/// Cloneable handle wrapping the `mesh::Sender` on which `DiagnosticFile`
/// messages are sent.
#[derive(Clone)]
pub struct DiagnosticSender(mesh::Sender<DiagnosticFile>);
36
37impl Agent {
38 pub async fn new(driver: DefaultDriver) -> anyhow::Result<Self> {
39 let socket = (connect_client(&driver), connect_server(&driver))
40 .race_ok()
41 .await
42 .map_err(|e| {
43 let [e0, e1] = &*e;
44 anyhow::anyhow!(
45 "failed to connect. client error: {:#} server error: {:#}",
46 e0,
47 e1
48 )
49 })?;
50
51 let (bootstrap_send, bootstrap_recv) = mesh::oneshot::<PipetteBootstrap>();
52 let mesh = PointToPointMesh::new(&driver, socket, bootstrap_recv.into());
53
54 let (request_send, request_recv) = mesh::channel();
55 let (diag_file_send, diag_file_recv) = mesh::channel();
56 let (watch_send, watch_recv) = mesh::oneshot();
57 let log = crate::trace::init_tracing();
58
59 bootstrap_send.send(PipetteBootstrap {
60 requests: request_send,
61 diag_file_recv,
62 watch: watch_recv,
63 log,
64 });
65
66 Ok(Self {
67 driver,
68 mesh,
69 request_recv,
70 diag_file_send: DiagnosticSender(diag_file_send),
71 watch_send,
72 })
73 }
74
75 pub async fn run(mut self) -> anyhow::Result<()> {
76 let mut tasks = FuturesUnordered::new();
77 loop {
78 futures::select! {
79 req = self.request_recv.recv().fuse() => {
80 match req {
81 Ok(req) => {
82 tasks.push(handle_request(&self.driver, req, self.diag_file_send.clone()));
83 },
84 Err(e) => {
85 tracing::info!(?e, "request channel closed, shutting down");
86 break;
87 }
88 }
89 }
90 _ = tasks.next() => {}
91 }
92 }
93 self.watch_send.send(());
94 self.mesh.shutdown().await;
95 Ok(())
96 }
97}
98
99async fn connect_server(driver: &DefaultDriver) -> anyhow::Result<PolledSocket<Socket>> {
100 let mut socket = VmSocket::new()?;
101 socket.bind(VmAddress::vsock_any(pipette_protocol::PIPETTE_VSOCK_PORT))?;
102 let mut socket =
103 PolledSocket::new(driver, socket.into()).context("failed to create polled socket")?;
104 socket.listen(1)?;
105 let socket = socket
106 .accept()
107 .await
108 .context("failed to accept connection")?
109 .0;
110 PolledSocket::new(driver, socket).context("failed to create polled socket")
111}
112
113async fn connect_client(driver: &DefaultDriver) -> anyhow::Result<PolledSocket<Socket>> {
114 let socket = VmSocket::new()?;
115 socket
118 .set_connect_timeout(Duration::from_secs(5))
119 .context("failed to set socket timeout")?;
120 let mut socket = PolledSocket::new(driver, socket)
121 .context("failed to create polled socket")?
122 .convert();
123 socket
124 .connect(&VmAddress::vsock_host(pipette_protocol::PIPETTE_VSOCK_PORT).into())
125 .await?;
126 Ok(socket)
127}
128
129async fn handle_request(
130 driver: &DefaultDriver,
131 req: PipetteRequest,
132 _diag_file_send: DiagnosticSender,
133) {
134 match req {
135 PipetteRequest::Ping(rpc) => rpc.handle_sync(|()| {
136 tracing::info!("ping");
137 }),
138 PipetteRequest::Execute(rpc) => rpc.handle_failable_sync(crate::execute::handle_execute),
139 PipetteRequest::Shutdown(rpc) => {
140 rpc.handle_sync(|request| {
141 tracing::info!(shutdown_type = ?request.shutdown_type, "shutdown request");
142 let mut timer = PolledTimer::new(driver);
147 driver
148 .spawn("shutdown", async move {
149 #[cfg(windows)]
156 timer.sleep(Duration::from_secs(5)).await;
157 #[cfg(not(windows))]
158 timer.sleep(Duration::from_millis(250)).await;
159 loop {
160 if let Err(err) = crate::shutdown::handle_shutdown(request) {
161 tracing::error!(
162 error = err.as_ref() as &dyn std::error::Error,
163 "failed to shut down"
164 );
165 }
166 timer.sleep(Duration::from_secs(5)).await;
167 tracing::warn!("still waiting to shut down, trying again");
168 }
169 })
170 .detach();
171 Ok(())
172 })
173 }
174 PipetteRequest::ReadFile(rpc) => rpc.handle_failable(read_file).await,
175 PipetteRequest::WriteFile(rpc) => rpc.handle_failable(write_file).await,
176 PipetteRequest::GetTime(rpc) => rpc.handle_sync(|()| SystemTime::now().into()),
177 }
178}
179
180async fn read_file(mut request: pipette_protocol::ReadFileRequest) -> anyhow::Result<u64> {
181 tracing::debug!(path = request.path, "Beginning file read request");
182 let file = fs_err::File::open(request.path)?;
183 let n = futures::io::copy(&mut futures::io::AllowStdIo::new(file), &mut request.sender).await?;
184 tracing::debug!("file read request complete");
185 Ok(n)
186}
187
188async fn write_file(mut request: pipette_protocol::WriteFileRequest) -> anyhow::Result<u64> {
189 tracing::debug!(path = request.path, "Beginning file write request");
190 let file = fs_err::File::create(request.path)?;
191 let n = futures::io::copy(
192 &mut request.receiver,
193 &mut futures::io::AllowStdIo::new(file),
194 )
195 .await?;
196 tracing::debug!("file write request complete");
197 Ok(n)
198}
199
200impl DiagnosticSender {
201 #[cfg_attr(not(windows), expect(dead_code))]
202 pub async fn send(&self, filename: &str) -> anyhow::Result<()> {
203 tracing::debug!(filename, "Beginning diagnostic file request");
204 let file = fs_err::File::open(filename)?;
205 let (recv_pipe, mut send_pipe) = mesh::pipe::pipe();
206 self.0.send(DiagnosticFile {
207 name: filename.to_owned(),
208 receiver: recv_pipe,
209 });
210 futures::io::copy(&mut futures::io::AllowStdIo::new(file), &mut send_pipe).await?;
211 drop(send_pipe);
212 tracing::debug!("diagnostic request complete");
213 Ok(())
214 }
215}