diag_server/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Underhill diagnostics server.
5
6#![cfg(target_os = "linux")]
7
8mod diag_service;
9mod new_pty;
10
11pub use diag_service::DiagRequest;
12pub use diag_service::StartParams;
13
14use anyhow::Context;
15use cvm_tracing::CVM_ALLOWED;
16use futures::AsyncWriteExt;
17use futures::FutureExt;
18use mesh::CancelReason;
19use mesh_rpc::server::RpcReceiver;
20use mesh_rpc::service::Code;
21use mesh_rpc::service::Status;
22use pal_async::driver::Driver;
23use pal_async::interest::PollEvents;
24use pal_async::socket::PollReadyExt;
25use pal_async::socket::PolledSocket;
26use pal_async::task::Spawn;
27use pal_async::task::Task;
28use parking_lot::Mutex;
29use socket2::Socket;
30use std::collections::HashMap;
31use std::path::Path;
32use std::pin::pin;
33use std::sync::Arc;
34use unix_socket::UnixListener;
35use vmsocket::VmAddress;
36use vmsocket::VmListener;
37
/// The diagnostics server, which is a ttrpc server listening on a pair of
/// sockets (`AF_VSOCK` or Unix), one for control and one for data.
pub struct DiagServer {
    /// Listener for control connections; `serve` hands it to the RPC server.
    control_listener: Socket,
    /// Listener for data connections; `serve` hands it to the data
    /// connection accept loop.
    data_listener: Socket,
    /// State shared with the diag service handler, holding the table of
    /// accepted data connections.
    inner: Arc<Inner>,
    /// The RPC server that the services are registered on.
    server: mesh_rpc::Server,
}
48
49impl DiagServer {
50    /// Creates a server over VmSockets and starts listening.
51    pub fn new_vsock(control_address: VmAddress, data_address: VmAddress) -> anyhow::Result<Self> {
52        tracing::info!(CVM_ALLOWED, ?control_address, "control starting");
53        let control_listener =
54            VmListener::bind(control_address).context("failed to bind socket")?;
55
56        tracing::info!(CVM_ALLOWED, ?data_address, "data starting");
57        let data_listener = VmListener::bind(data_address).context("failed to bind socket")?;
58
59        Ok(Self::new_generic(
60            control_listener.into(),
61            data_listener.into(),
62        ))
63    }
64
65    /// Creates a server over Unix sockets and starts listening.
66    pub fn new_unix(control_address: &Path, data_address: &Path) -> anyhow::Result<Self> {
67        tracing::info!(CVM_ALLOWED, ?control_address, "control starting");
68        let control_listener =
69            UnixListener::bind(control_address).context("failed to bind socket")?;
70
71        tracing::info!(CVM_ALLOWED, ?data_address, "data starting");
72        let data_listener = UnixListener::bind(data_address).context("failed to bind socket")?;
73
74        Ok(Self::new_generic(
75            control_listener.into(),
76            data_listener.into(),
77        ))
78    }
79
80    fn new_generic(control_listener: Socket, data_listener: Socket) -> Self {
81        Self {
82            control_listener,
83            data_listener,
84            server: mesh_rpc::Server::new(),
85            inner: Arc::new(Inner {
86                connections: Mutex::new(DataConnections {
87                    next_id: 1, // connection IDs start at 1, as 0 is an invalid ID.
88                    active: Default::default(),
89                }),
90            }),
91        }
92    }
93
94    /// Serves requests until `cancel` is dropped.
95    pub async fn serve(
96        mut self,
97        driver: &(impl Driver + Spawn + Clone),
98        cancel: mesh::OneshotReceiver<()>,
99        request_send: mesh::Sender<DiagRequest>,
100    ) -> anyhow::Result<()> {
101        // Disable all diag requests for CVMs. Inspect filtering will be handled
102        // internally more granularly.
103        let (diag_recv, diag2_recv) = if underhill_confidentiality::confidential_filtering_enabled()
104        {
105            (RpcReceiver::disconnected(), RpcReceiver::disconnected())
106        } else {
107            (
108                self.server.add_service::<diag_proto::UnderhillDiag>(),
109                self.server.add_service::<diag_proto::OpenhclDiag>(),
110            )
111        };
112
113        let inspect_recv = self.server.add_service::<inspect_proto::InspectService>();
114
115        // TODO: split the profiler to a separate service provider.
116        let profile_recv = self
117            .server
118            .add_service::<azure_profiler_proto::AzureProfiler>();
119
120        let diag_service = Arc::new(diag_service::DiagServiceHandler::new(
121            request_send,
122            self.inner.clone(),
123        ));
124        let process = diag_service.process_requests(
125            driver,
126            diag_recv,
127            diag2_recv,
128            inspect_recv,
129            profile_recv,
130        );
131
132        let serve = self.server.run(driver, self.control_listener, cancel);
133        let data_connections = self
134            .inner
135            .process_data_connections(driver, self.data_listener);
136
137        futures::future::try_join3(serve, process, data_connections).await?;
138        Ok(())
139    }
140}
141
/// A data connection that has been accepted and is being held in the
/// connection table until it is claimed by ID or the peer hangs up.
#[derive(Debug)]
struct DataConnectionEntry {
    /// Sender used to notify the hangup task to return the socket.
    sender: mesh::OneshotSender<()>,
    /// Task used to wait for hangup notifications or a request to return the socket.
    /// It resolves to `None` if the connection was lost before being claimed.
    task: Task<Option<PolledSocket<Socket>>>,
}
149
150#[derive(Debug, Default)]
151struct DataConnections {
152    next_id: u64,
153    active: HashMap<u64, DataConnectionEntry>,
154}
155
156impl DataConnections {
157    fn take_connection(&mut self, id: u64) -> anyhow::Result<DataConnectionEntry> {
158        self.active
159            .remove(&id)
160            .ok_or_else(|| anyhow::anyhow!("invalid connection id"))
161    }
162}
163
/// State shared between the server's data connection accept loop and the
/// diag service handler.
struct Inner {
    /// Table of accepted data connections, guarded by a mutex because it is
    /// touched by the accept loop, the per-connection waiting tasks, and
    /// `take_connection`.
    connections: Mutex<DataConnections>,
}
167
168impl Inner {
169    async fn take_connection(&self, id: u64) -> anyhow::Result<PolledSocket<Socket>> {
170        let DataConnectionEntry { sender, task } = self.connections.lock().take_connection(id)?;
171
172        sender.send(());
173        task.await
174            .ok_or_else(|| anyhow::anyhow!("connection disconnected"))
175    }
176
177    /// Listen for data connections and add them to the internal connections lookup table as they arrive.
178    async fn process_data_connections(
179        self: &Arc<Self>,
180        driver: &(impl Driver + Spawn + Clone),
181        listener: Socket,
182    ) -> anyhow::Result<()> {
183        let mut listener = PolledSocket::new(driver, listener)?;
184
185        loop {
186            let (connection, _addr) = listener.accept().await?;
187            let mut socket = PolledSocket::new(driver, connection)?;
188            let inner = Arc::downgrade(self);
189
190            // Send the 8 byte connection id, then stash the connection in the lookup table to be used later.
191            let id;
192            {
193                let mut state = self.connections.lock();
194                id = state.next_id;
195                state.next_id += 1;
196
197                tracing::debug!(id, "new data connection");
198            }
199
200            let (sender, recv) = mesh::oneshot();
201
202            // Spawn a task that returns the socket when asked to, or removes itself from the map if disconnected.
203            let task = driver.spawn(format!("data connection {} waiting", id), async move {
204                match socket.write_all(&id.to_ne_bytes()).await {
205                    Ok(_) => {}
206                    Err(error) => {
207                        tracing::trace!(?error, "error writing connection id, removing.");
208                        if let Some(state) = inner.upgrade() {
209                            state.connections.lock().active.remove(&id);
210                        }
211
212                        return None;
213                    }
214                }
215
216                let mut return_future = pin!(async { recv.await.is_ok() }.fuse());
217                let hangup = futures::select! { // race semantics
218                    _ = socket.wait_ready(PollEvents::RDHUP).fuse() => true,
219                    _ = return_future => false,
220                };
221
222                if hangup {
223                    // Other side has disconnected, remove from the table if not already done.
224                    tracing::trace!(id, "data connection disconnected");
225                    if let Some(state) = inner.upgrade() {
226                        state.connections.lock().active.remove(&id);
227                    }
228
229                    None
230                } else {
231                    Some(socket)
232                }
233            });
234
235            let mut state = self.connections.lock();
236            let result = state
237                .active
238                .insert(id, DataConnectionEntry { sender, task });
239
240            if result.is_some() {
241                anyhow::bail!("connection id reused");
242            }
243        }
244    }
245}
246
247fn grpc_result<T>(result: Result<anyhow::Result<T>, CancelReason>) -> Result<T, Status> {
248    match result {
249        Ok(result) => match result {
250            Ok(value) => Ok(value),
251            Err(err) => Err(Status {
252                code: Code::Unknown as i32,
253                message: format!("{:#}", err),
254                details: vec![],
255            }),
256        },
257        Err(err) => Err(Status {
258            code: match &err {
259                CancelReason::Cancelled => Code::Cancelled,
260                CancelReason::DeadlineExceeded => Code::DeadlineExceeded,
261            } as i32,
262            message: format!("{:#}", err),
263            details: vec![],
264        }),
265    }
266}