1#![cfg(target_os = "linux")]
7
8mod diag_service;
9mod new_pty;
10
11pub use diag_service::DiagRequest;
12pub use diag_service::StartParams;
13
14use anyhow::Context;
15use cvm_tracing::CVM_ALLOWED;
16use futures::AsyncWriteExt;
17use futures::FutureExt;
18use mesh::CancelReason;
19use mesh_rpc::server::RpcReceiver;
20use mesh_rpc::service::Code;
21use mesh_rpc::service::Status;
22use pal_async::driver::Driver;
23use pal_async::interest::PollEvents;
24use pal_async::socket::PollReadyExt;
25use pal_async::socket::PolledSocket;
26use pal_async::task::Spawn;
27use pal_async::task::Task;
28use parking_lot::Mutex;
29use socket2::Socket;
30use std::collections::HashMap;
31use std::path::Path;
32use std::pin::pin;
33use std::sync::Arc;
34use unix_socket::UnixListener;
35use vmsocket::VmAddress;
36use vmsocket::VmListener;
37
/// Diagnostics server built from a control listener and a data listener.
///
/// RPC services are registered on `server` and served over
/// `control_listener`. Connections accepted on `data_listener` are parked in
/// the shared `inner` state until they are claimed by id.
pub struct DiagServer {
    /// Listener handed to the mesh RPC server when `serve` runs.
    control_listener: Socket,
    /// Listener whose accepted connections are tracked in `inner`.
    data_listener: Socket,
    /// State shared with the diag service handler created in `serve`.
    inner: Arc<Inner>,
    /// RPC server on which services are registered in `serve`.
    server: mesh_rpc::Server,
}
48
49impl DiagServer {
50 pub fn new_vsock(control_address: VmAddress, data_address: VmAddress) -> anyhow::Result<Self> {
52 tracing::info!(CVM_ALLOWED, ?control_address, "control starting");
53 let control_listener =
54 VmListener::bind(control_address).context("failed to bind socket")?;
55
56 tracing::info!(CVM_ALLOWED, ?data_address, "data starting");
57 let data_listener = VmListener::bind(data_address).context("failed to bind socket")?;
58
59 Ok(Self::new_generic(
60 control_listener.into(),
61 data_listener.into(),
62 ))
63 }
64
65 pub fn new_unix(control_address: &Path, data_address: &Path) -> anyhow::Result<Self> {
67 tracing::info!(CVM_ALLOWED, ?control_address, "control starting");
68 let control_listener =
69 UnixListener::bind(control_address).context("failed to bind socket")?;
70
71 tracing::info!(CVM_ALLOWED, ?data_address, "data starting");
72 let data_listener = UnixListener::bind(data_address).context("failed to bind socket")?;
73
74 Ok(Self::new_generic(
75 control_listener.into(),
76 data_listener.into(),
77 ))
78 }
79
80 fn new_generic(control_listener: Socket, data_listener: Socket) -> Self {
81 Self {
82 control_listener,
83 data_listener,
84 server: mesh_rpc::Server::new(),
85 inner: Arc::new(Inner {
86 connections: Mutex::new(DataConnections {
87 next_id: 1, active: Default::default(),
89 }),
90 }),
91 }
92 }
93
94 pub async fn serve(
96 mut self,
97 driver: &(impl Driver + Spawn + Clone),
98 cancel: mesh::OneshotReceiver<()>,
99 request_send: mesh::Sender<DiagRequest>,
100 ) -> anyhow::Result<()> {
101 let (diag_recv, diag2_recv) = if underhill_confidentiality::confidential_filtering_enabled()
104 {
105 (RpcReceiver::disconnected(), RpcReceiver::disconnected())
106 } else {
107 (
108 self.server.add_service::<diag_proto::UnderhillDiag>(),
109 self.server.add_service::<diag_proto::OpenhclDiag>(),
110 )
111 };
112
113 let inspect_recv = self.server.add_service::<inspect_proto::InspectService>();
114
115 let profile_recv = self
117 .server
118 .add_service::<azure_profiler_proto::AzureProfiler>();
119
120 let diag_service = Arc::new(diag_service::DiagServiceHandler::new(
121 request_send,
122 self.inner.clone(),
123 ));
124 let process = diag_service.process_requests(
125 driver,
126 diag_recv,
127 diag2_recv,
128 inspect_recv,
129 profile_recv,
130 );
131
132 let serve = self.server.run(driver, self.control_listener, cancel);
133 let data_connections = self
134 .inner
135 .process_data_connections(driver, self.data_listener);
136
137 futures::future::try_join3(serve, process, data_connections).await?;
138 Ok(())
139 }
140}
141
/// An accepted data connection that is waiting to be claimed.
#[derive(Debug)]
struct DataConnectionEntry {
    /// Signaled by `Inner::take_connection` to ask the waiting task to hand
    /// back its socket.
    sender: mesh::OneshotSender<()>,
    /// The waiting task. It resolves to `Some(socket)` once claimed, or to
    /// `None` if writing the connection id failed or the peer hung up first.
    task: Task<Option<PolledSocket<Socket>>>,
}
149
/// Table of accepted data connections awaiting a claim.
#[derive(Debug, Default)]
struct DataConnections {
    /// Next id to hand out. `DiagServer` initializes this to 1; note that the
    /// derived `Default` would start at 0.
    next_id: u64,
    /// Pending connections, keyed by the id that was written to the peer.
    active: HashMap<u64, DataConnectionEntry>,
}
155
156impl DataConnections {
157 fn take_connection(&mut self, id: u64) -> anyhow::Result<DataConnectionEntry> {
158 self.active
159 .remove(&id)
160 .ok_or_else(|| anyhow::anyhow!("invalid connection id"))
161 }
162}
163
/// State shared between `DiagServer` and the diag service handler.
struct Inner {
    /// Pending data connections. This is behind a mutex because both the
    /// accept loop and the per-connection tasks modify it.
    connections: Mutex<DataConnections>,
}
167
168impl Inner {
169 async fn take_connection(&self, id: u64) -> anyhow::Result<PolledSocket<Socket>> {
170 let DataConnectionEntry { sender, task } = self.connections.lock().take_connection(id)?;
171
172 sender.send(());
173 task.await
174 .ok_or_else(|| anyhow::anyhow!("connection disconnected"))
175 }
176
177 async fn process_data_connections(
179 self: &Arc<Self>,
180 driver: &(impl Driver + Spawn + Clone),
181 listener: Socket,
182 ) -> anyhow::Result<()> {
183 let mut listener = PolledSocket::new(driver, listener)?;
184
185 loop {
186 let (connection, _addr) = listener.accept().await?;
187 let mut socket = PolledSocket::new(driver, connection)?;
188 let inner = Arc::downgrade(self);
189
190 let id;
192 {
193 let mut state = self.connections.lock();
194 id = state.next_id;
195 state.next_id += 1;
196
197 tracing::debug!(id, "new data connection");
198 }
199
200 let (sender, recv) = mesh::oneshot();
201
202 let task = driver.spawn(format!("data connection {} waiting", id), async move {
204 match socket.write_all(&id.to_ne_bytes()).await {
205 Ok(_) => {}
206 Err(error) => {
207 tracing::trace!(?error, "error writing connection id, removing.");
208 if let Some(state) = inner.upgrade() {
209 state.connections.lock().active.remove(&id);
210 }
211
212 return None;
213 }
214 }
215
216 let mut return_future = pin!(async { recv.await.is_ok() }.fuse());
217 let hangup = futures::select! { _ = socket.wait_ready(PollEvents::RDHUP).fuse() => true,
219 _ = return_future => false,
220 };
221
222 if hangup {
223 tracing::trace!(id, "data connection disconnected");
225 if let Some(state) = inner.upgrade() {
226 state.connections.lock().active.remove(&id);
227 }
228
229 None
230 } else {
231 Some(socket)
232 }
233 });
234
235 let mut state = self.connections.lock();
236 let result = state
237 .active
238 .insert(id, DataConnectionEntry { sender, task });
239
240 if result.is_some() {
241 anyhow::bail!("connection id reused");
242 }
243 }
244 }
245}
246
247fn grpc_result<T>(result: Result<anyhow::Result<T>, CancelReason>) -> Result<T, Status> {
248 match result {
249 Ok(result) => match result {
250 Ok(value) => Ok(value),
251 Err(err) => Err(Status {
252 code: Code::Unknown as i32,
253 message: format!("{:#}", err),
254 details: vec![],
255 }),
256 },
257 Err(err) => Err(Status {
258 code: match &err {
259 CancelReason::Cancelled => Code::Cancelled,
260 CancelReason::DeadlineExceeded => Code::DeadlineExceeded,
261 } as i32,
262 message: format!("{:#}", err),
263 details: vec![],
264 }),
265 }
266}