vmbus_server/
hvsock.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! This module implements an hvsocket-to-Unix socket relay.
5//!
6//! This supports the [hybrid vsock connection model][1] established by
7//! Firecracker, extended to support Hyper-V sockets as well.
8//!
9//! [1]: <https://github.com/firecracker-microvm/firecracker/blob/7b2e87dc65fc45162303e5708b83c379cf1b0426/docs/vsock.md>
10
11use super::Guid;
12use crate::HvsockRelayChannelHalf;
13use crate::ring::RingMem;
14use anyhow::Context;
15use futures::AsyncReadExt;
16use futures::AsyncWriteExt;
17use futures::StreamExt;
18use futures_concurrency::stream::Merge;
19use hybrid_vsock::HYBRID_CONNECT_REQUEST_LEN;
20use hybrid_vsock::VsockPortOrId;
21use mesh::CancelContext;
22use pal_async::driver::SpawnDriver;
23use pal_async::socket::PolledSocket;
24use pal_async::task::Spawn;
25use pal_async::task::Task;
26use std::io::ErrorKind;
27use std::path::Path;
28use std::path::PathBuf;
29use std::sync::Arc;
30use std::time::Duration;
31use unicycle::FuturesUnordered;
32use unix_socket::UnixListener;
33use unix_socket::UnixStream;
34use vmbus_async::pipe::BytePipe;
35use vmbus_channel::bus::ChannelType;
36use vmbus_channel::bus::OfferParams;
37use vmbus_channel::bus::ParentBus;
38use vmbus_channel::offer::Offer;
39use vmbus_core::HvsockConnectRequest;
40use vmbus_core::HvsockConnectResult;
41
/// Relays hvsocket connections between the guest and Unix sockets on the
/// host.
pub struct HvsockRelay {
    /// State (vmbus and driver) shared with the tasks spawned by the relay.
    inner: Arc<RelayInner>,
    /// Used to hand spawned tasks over to the relay worker.
    host_send: mesh::Sender<RelayRequest>,
    /// The task running the relay worker loop.
    _relay_task: Task<()>,
    /// The task accepting hybrid vsock connections, present only when a
    /// listener was provided at construction.
    _listener_task: Option<Task<()>>,
}
48
/// A request sent from the host side to the relay worker.
enum RelayRequest {
    /// Adds a task to the set of tasks polled by the relay worker.
    AddTask(Task<()>),
}
52
/// State shared between the relay and the tasks it spawns.
struct RelayInner {
    /// The bus on which hvsocket channels are offered.
    vmbus: Arc<dyn ParentBus>,
    /// The driver used to spawn tasks and to poll sockets.
    driver: Box<dyn SpawnDriver>,
}
57
58impl HvsockRelay {
59    /// Creates and starts a relay thread, waiting for hvsocket connect requests
60    /// on `recv`.
61    pub fn new(
62        driver: impl SpawnDriver,
63        vmbus: Arc<dyn ParentBus>,
64        guest: HvsockRelayChannelHalf,
65        hybrid_vsock_path: Option<PathBuf>,
66        hybrid_vsock_listener: Option<UnixListener>,
67    ) -> anyhow::Result<Self> {
68        let inner = Arc::new(RelayInner {
69            vmbus,
70            driver: Box::new(driver),
71        });
72
73        let worker = HvsockRelayWorker {
74            guest_send: guest.response_send,
75            inner: inner.clone(),
76            tasks: Default::default(),
77            hybrid_vsock_path,
78        };
79
80        let (host_send, host_recv) = mesh::channel();
81
82        let _listener_task = if let Some(listener) = hybrid_vsock_listener {
83            let listener = PolledSocket::new(inner.driver.as_ref(), listener)?;
84            Some(
85                inner.driver.spawn(
86                    "hvsock-listener",
87                    ListenerWorker {
88                        inner: inner.clone(),
89                        host_send: host_send.clone(),
90                    }
91                    .run(listener),
92                ),
93            )
94        } else {
95            None
96        };
97
98        let task = inner
99            .driver
100            .spawn("hvsock relay", worker.run(guest.request_receive, host_recv));
101
102        Ok(Self {
103            host_send,
104            inner,
105            _relay_task: task,
106            _listener_task,
107        })
108    }
109
110    /// Connects to an hvsocket in the guest and returns a Unix socket that is
111    /// relayed to the hvsocket.
112    ///
113    /// Blocks until complete or cancelled.
114    pub fn connect(
115        &self,
116        ctx: &mut CancelContext,
117        service_id: Guid,
118    ) -> impl Future<Output = anyhow::Result<UnixStream>> + Send + use<> {
119        let inner = self.inner.clone();
120        let host_send = self.host_send.clone();
121        let (send, recv) = mesh::oneshot();
122
123        // Ensure the task gets dropped if the future is dropped.
124        let (mut ctx, cancel) = ctx.with_cancel();
125
126        // Spawn a task to initiate the connect to avoid keeping a reference on `RelayInner`.
127        let task = self.inner.driver.spawn("hvsock-connect", async move {
128            let r = async {
129                let (stream, task) = ctx
130                    .until_cancelled(inner.connect_to_guest(service_id))
131                    .await??;
132                host_send.send(RelayRequest::AddTask(task));
133                Ok(stream)
134            }
135            .await;
136
137            send.send(r);
138        });
139        self.host_send.send(RelayRequest::AddTask(task));
140        async move {
141            let _cancel = cancel;
142            recv.await?
143        }
144    }
145}
146
/// Accepts hybrid vsock connections from the host and relays them to the
/// guest.
struct ListenerWorker {
    /// Shared vmbus and driver state.
    inner: Arc<RelayInner>,
    /// Used to hand spawned relay tasks over to the relay worker.
    host_send: mesh::Sender<RelayRequest>,
}
151
152impl ListenerWorker {
153    async fn run(self, mut listener: PolledSocket<UnixListener>) {
154        loop {
155            let connection = match listener.accept().await {
156                Ok((connection, _address)) => connection,
157                Err(err) => {
158                    tracing::error!(
159                        error = &err as &dyn std::error::Error,
160                        "failed to accept hybrid vsock connection, shutting down listener"
161                    );
162                    break;
163                }
164            };
165            match self.spawn_relay(connection).await {
166                Ok(task) => {
167                    self.host_send.send(RelayRequest::AddTask(task));
168                }
169                Err(err) => {
170                    tracing::warn!(
171                        error = err.as_ref() as &dyn std::error::Error,
172                        "relayed connection failed"
173                    );
174                }
175            }
176        }
177    }
178
179    async fn spawn_relay(&self, connection: UnixStream) -> anyhow::Result<Task<()>> {
180        let mut socket = PolledSocket::new(self.inner.driver.as_ref(), connection)?;
181        let request = read_hybrid_vsock_connect(&mut socket).await?;
182
183        let instance_id = Guid::new_random();
184        let mut offer = Offer::new(
185            self.inner.driver.as_ref(),
186            self.inner.vmbus.as_ref(),
187            OfferParams {
188                interface_name: "hvsocket_connect".into(),
189                interface_id: request.id(),
190                instance_id,
191                channel_type: ChannelType::HvSocket {
192                    is_connect: true,
193                    is_for_container: false,
194                    silo_id: Guid::ZERO,
195                },
196                ..Default::default()
197            },
198        )
199        .await
200        .context("failed to offer channel")?;
201
202        let channel = CancelContext::new()
203            .with_timeout(Duration::from_secs(2))
204            .until_cancelled(offer.wait_for_open(self.inner.driver.as_ref()))
205            .await?
206            .context("failed to accept channel")?
207            .accept()
208            .channel;
209
210        let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
211
212        tracing::debug!(service_id = %request.id(), endpoint_id = %instance_id, "connected host to guest");
213
214        let task = self
215            .inner
216            .driver
217            .spawn("hvsock connection relay", async move {
218                // Keep the offer alive until the relay completes.
219                let _offer = offer;
220
221                // Notify the client that connection was successful, using a format that matches the
222                // request.
223                let response = match request {
224                    VsockPortOrId::Port(_) => VsockPortOrId::Port(instance_id.data1),
225                    VsockPortOrId::Id(_) => VsockPortOrId::Id(instance_id),
226                };
227                let s = response.get_ok_response();
228                if let Err(err) = socket.write_all(s.as_bytes()).await {
229                    tracing::error!(
230                        service_id = %request.id(),
231                        error = &err as &dyn std::error::Error,
232                        "failed to write OK response"
233                    );
234                }
235
236                if let Err(err) = relay_connected(pipe, socket).await {
237                    tracing::error!(
238                        service_id = %request.id(),
239                        error = &err as &dyn std::error::Error,
240                        "connection relay failed"
241                    );
242                } else {
243                    tracing::debug!(service_id = %request.id(), "connection relay finished");
244                }
245            });
246
247        Ok(task)
248    }
249}
250
251async fn read_hybrid_vsock_connect(
252    socket: &mut PolledSocket<UnixStream>,
253) -> anyhow::Result<VsockPortOrId> {
254    let mut buf = [0; HYBRID_CONNECT_REQUEST_LEN];
255    let mut i = 0;
256    while i == 0 || buf[i - 1] != b'\n' {
257        if i == buf.len() {
258            anyhow::bail!("connect request did not fit");
259        }
260        let n = socket
261            .read(&mut buf[i..])
262            .await
263            .context("failed to read connect request")?;
264        if n == 0 {
265            anyhow::bail!("no connect request");
266        }
267        i += n;
268    }
269
270    let request = VsockPortOrId::parse_connect_request(&buf[..i - 1])?;
271    tracing::debug!(?request, "got hybrid connect request");
272    Ok(request)
273}
274
/// A guest connect request that has not been answered yet.
///
/// A failure result is sent when this is dropped, unless a result was already
/// reported through `done`.
struct PendingConnection {
    /// Where the connect result is sent.
    send: mesh::Sender<HvsockConnectResult>,
    /// The request being answered.
    request: HvsockConnectRequest,
}
279
280impl PendingConnection {
281    fn done(self, success: bool) {
282        self.send
283            .send(HvsockConnectResult::from_request(&self.request, success));
284        std::mem::forget(self);
285    }
286}
287
288impl Drop for PendingConnection {
289    fn drop(&mut self) {
290        self.send
291            .send(HvsockConnectResult::from_request(&self.request, false));
292    }
293}
294
/// The worker that services guest connect requests and drives relay tasks.
struct HvsockRelayWorker {
    /// Cloned into each pending connection to report connect results to the
    /// guest.
    guest_send: mesh::Sender<HvsockConnectResult>,
    /// The in-flight tasks polled by the worker loop.
    tasks: FuturesUnordered<Task<()>>,
    /// Shared vmbus and driver state.
    inner: Arc<RelayInner>,
    /// Base path used to derive the host Unix socket path for guest connect
    /// requests. When `None`, guest connect requests are failed.
    hybrid_vsock_path: Option<PathBuf>,
}
301
302impl HvsockRelayWorker {
303    async fn run(
304        mut self,
305        guest_recv: mesh::Receiver<HvsockConnectRequest>,
306        host_recv: mesh::Receiver<RelayRequest>,
307    ) {
308        enum Event {
309            Guest(HvsockConnectRequest),
310            Host(RelayRequest),
311            TaskDone(()),
312        }
313
314        let mut recv = (guest_recv.map(Event::Guest), host_recv.map(Event::Host)).merge();
315
316        while let Some(event) = (&mut recv, (&mut self.tasks).map(Event::TaskDone))
317            .merge()
318            .next()
319            .await
320        {
321            match event {
322                Event::Guest(request) => {
323                    self.handle_connect_from_guest(request);
324                }
325                Event::Host(request) => match request {
326                    RelayRequest::AddTask(task) => {
327                        self.tasks.push(task);
328                    }
329                },
330                Event::TaskDone(()) => {}
331            }
332        }
333    }
334
335    fn handle_connect_from_guest(&mut self, request: HvsockConnectRequest) {
336        if request.silo_id != Guid::ZERO {
337            tracelimit::warn_ratelimited!(?request, "Non-zero silo ID is currently ignored.")
338        }
339
340        // Wrap the connect request so that we are assured to send a response.
341        let pending = PendingConnection {
342            send: self.guest_send.clone(),
343            request,
344        };
345        let path = {
346            if let Some(hybrid_vsock_path) = &self.hybrid_vsock_path {
347                hybrid_vsock_path.to_owned()
348            } else {
349                tracing::debug!(request = ?&request, "ignoring hvsock connect request");
350                return;
351            }
352        };
353
354        let task = self.inner.driver.spawn(
355            format!(
356                "hvsock accept {}:{}",
357                request.service_id, request.endpoint_id
358            ),
359            {
360                let inner = self.inner.clone();
361                async move {
362                    match inner
363                        .relay_guest_connect_to_host(pending, path.as_ref())
364                        .await
365                    {
366                        Ok(()) => {
367                            tracing::debug!(request = ?&request, "relay done");
368                        }
369                        Err(err) => {
370                            tracelimit::error_ratelimited!(
371                                request = ?&request,
372                                err = err.as_ref() as &dyn std::error::Error,
373                                "relay error"
374                            );
375                        }
376                    }
377                }
378            },
379        );
380        self.tasks.push(task);
381    }
382}
383
384impl RelayInner {
385    async fn relay_guest_connect_to_host(
386        &self,
387        pending: PendingConnection,
388        base_path: &Path,
389    ) -> anyhow::Result<()> {
390        let request = &pending.request;
391
392        // Find the appropriate path to connect to. Don't connect until the
393        // channel with the guest has been established, since that's a
394        // failure-prone operation and we don't want the host to see a broken
395        // connection.
396        let vsock_request = VsockPortOrId::Id(request.service_id);
397        let path = vsock_request.host_uds_path(base_path)?;
398
399        let mut offer = Offer::new(
400            self.driver.as_ref(),
401            self.vmbus.as_ref(),
402            OfferParams {
403                interface_name: "hvsocket".to_owned(),
404                instance_id: request.endpoint_id,
405                interface_id: request.service_id,
406                channel_type: ChannelType::HvSocket {
407                    is_connect: false,
408                    is_for_container: false,
409                    silo_id: Guid::ZERO,
410                },
411                ..Default::default()
412            },
413        )
414        .await
415        .context("failed to offer channel")?;
416
417        tracing::debug!(?request, "offered hvsocket channel to guest");
418        let service_id = request.service_id;
419
420        // Now that the channel is offered, report that the connection operation is
421        // done.
422        pending.done(true);
423
424        // Give the guest a few seconds to open the channel.
425        let channel = CancelContext::new()
426            .with_timeout(Duration::from_secs(5))
427            .until_cancelled(offer.wait_for_open(self.driver.as_ref()))
428            .await
429            .context("guest did not open hvsocket channel")??;
430
431        tracing::debug!(%service_id, "guest opened hvsocket channel");
432
433        // Connect to the host Unix socket.
434        let socket = PolledSocket::connect_unix(self.driver.as_ref(), &path)
435            .await
436            .with_context(|| {
437                format!(
438                    "failed to connect to registered listener {} for {}",
439                    path.display(),
440                    service_id
441                )
442            })?;
443
444        tracing::debug!(%service_id, path = %path.display(), "connected to host uds socket");
445
446        // Accept the channel now that the host connection is established.
447        let channel = channel.accept().channel;
448
449        let channel = BytePipe::new(channel)?;
450        if let Err(err) = relay_connected(channel, socket).await {
451            tracelimit::error_ratelimited!(
452                %service_id,
453                error = &err as &dyn std::error::Error,
454                "guest to host connection relay failed"
455            );
456        } else {
457            tracing::debug!(%service_id, "guest to host connection relay finished");
458        }
459
460        // N.B. offer needs to stay alive until here to avoid revoking the channel
461        // before the relay is done.
462        drop(offer);
463        Ok(())
464    }
465
466    async fn connect_to_guest(&self, service_id: Guid) -> anyhow::Result<(UnixStream, Task<()>)> {
467        let instance_id = Guid::new_random();
468        let mut offer = Offer::new(
469            &self.driver,
470            self.vmbus.as_ref(),
471            OfferParams {
472                interface_name: "hvsocket_connect".into(),
473                interface_id: service_id,
474                instance_id,
475                channel_type: ChannelType::HvSocket {
476                    is_connect: true,
477                    is_for_container: false,
478                    silo_id: Guid::ZERO,
479                },
480                ..Default::default()
481            },
482        )
483        .await
484        .context("failed to offer channel")?;
485
486        let channel = offer
487            .wait_for_open(self.driver.as_ref())
488            .await
489            .context("failed to accept channel")?
490            .accept()
491            .channel;
492        let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
493
494        tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
495
496        let (left, right) = UnixStream::pair().context("failed to create socket pair")?;
497        let right = PolledSocket::new(self.driver.as_ref(), right)
498            .context("failed to create polled socket")?;
499
500        let task = self.driver.spawn(
501            format!("hvsock {}:{}", service_id, instance_id),
502            async move {
503                // Keep the offer alive until the relay completes.
504                let _offer = offer;
505                if let Err(err) = relay_connected(pipe, right).await {
506                    tracing::error!(
507                        %service_id,
508                        error = &err as &dyn std::error::Error,
509                        "connection relay failed"
510                    );
511                }
512            },
513        );
514
515        Ok((left, task))
516    }
517}
518
519async fn relay_connected<T: RingMem + Unpin>(
520    channel: BytePipe<T>,
521    socket: PolledSocket<UnixStream>,
522) -> std::io::Result<()> {
523    let (channel_read, mut channel_write) = channel.split();
524    let (socket_read, mut socket_write) = socket.split();
525
526    let channel_to_socket = async {
527        futures::io::copy(channel_read, &mut socket_write).await?;
528        socket_write.close().await
529    };
530
531    let socket_to_channel = async {
532        futures::io::copy(socket_read, &mut channel_write).await?;
533        channel_write.close().await
534    };
535
536    match futures::future::try_join(channel_to_socket, socket_to_channel).await {
537        Ok(((), ())) => {}
538        Err(err) if err.kind() == ErrorKind::ConnectionReset => {}
539        Err(err) => return Err(err),
540    }
541    Ok(())
542}
543
#[cfg(test)]
mod tests {
    use super::relay_connected;
    use crate::ring::FlatRingMem;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::driver::Driver;
    use pal_async::socket::PolledSocket;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use unix_socket::UnixStream;
    use vmbus_async::pipe::BytePipe;
    use vmbus_async::pipe::connected_byte_pipes;

    /// Spawns a relay between one end of a byte pipe pair and one end of a
    /// Unix socket pair, returning the two free ends and the relay task.
    fn setup_relay<T: Driver + Spawn>(
        driver: &T,
    ) -> (
        BytePipe<FlatRingMem>,
        PolledSocket<UnixStream>,
        Task<std::io::Result<()>>,
    ) {
        let (relay_pipe, pipe) = connected_byte_pipes(4096);
        let (sock, relay_sock) = UnixStream::pair().unwrap();
        let sock = PolledSocket::new(driver, sock).unwrap();
        let relay_sock = PolledSocket::new(driver, relay_sock).unwrap();
        let relay = driver.spawn("test", async move {
            relay_connected(relay_pipe, relay_sock).await
        });

        (pipe, sock, relay)
    }

    /// Data flows in both directions, including after one side has shut down
    /// its write half.
    #[async_test]
    async fn test_relay(driver: DefaultDriver) {
        let (mut pipe, mut sock, relay) = setup_relay(&driver);

        let payload = b"abcd";
        let mut received = [0; 4];

        // Pipe to socket.
        pipe.write_all(payload).await.unwrap();
        sock.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        // Socket to pipe.
        sock.write_all(payload).await.unwrap();
        pipe.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        // Socket to pipe, then shut down the socket's write half.
        sock.write_all(payload).await.unwrap();
        sock.close().await.unwrap();
        pipe.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        // Pipe to socket still works after the socket was half-closed.
        pipe.write_all(payload).await.unwrap();
        sock.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        pipe.close().await.unwrap();
        relay.await.unwrap();
    }

    /// Dropping the socket end is seen as end of stream on the pipe.
    #[cfg(unix)] // Windows does not deliver POLLHUP on Unix socket close.
    #[async_test]
    async fn test_relay_host_close(driver: DefaultDriver) {
        // The `_` pattern drops the socket end immediately.
        let (mut pipe, _, relay) = setup_relay(&driver);

        let mut byte = [0];
        assert_eq!(pipe.read(&mut byte).await.unwrap(), 0);
        drop(pipe);
        relay.await.unwrap();
    }

    /// Dropping the pipe end is seen as end of stream on the socket.
    #[async_test]
    async fn test_relay_guest_close(driver: DefaultDriver) {
        // The `_` pattern drops the pipe end immediately.
        let (_, mut sock, relay) = setup_relay(&driver);

        let mut byte = [0];
        assert_eq!(sock.read(&mut byte).await.unwrap(), 0);
        drop(sock);
        relay.await.unwrap();
    }

    /// Closing the socket's write half is forwarded to the pipe.
    #[async_test]
    async fn test_relay_forward_socket_shutdown(driver: DefaultDriver) {
        let (mut pipe, mut sock, relay) = setup_relay(&driver);
        sock.close().await.unwrap();
        let mut byte = [0; 1];
        assert_eq!(pipe.read(&mut byte).await.unwrap(), 0);
        drop(pipe);
        relay.await.unwrap();
    }

    /// Closing the pipe's write half is forwarded to the socket.
    #[async_test]
    async fn test_relay_forward_channel_shutdown(driver: DefaultDriver) {
        let (mut pipe, mut sock, relay) = setup_relay(&driver);

        pipe.close().await.unwrap();
        let mut byte = [0; 1];
        assert_eq!(sock.read(&mut byte).await.unwrap(), 0);
        drop(sock);
        relay.await.unwrap();
    }
}