vmbus_server/
hvsock.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! This module implements an hvsocket-to-Unix socket relay.
5//!
6//! This supports the [hybrid vsock connection model][1] established by
7//! Firecracker, extended to support Hyper-V sockets as well.
8//!
9//! [1]: <https://github.com/firecracker-microvm/firecracker/blob/7b2e87dc65fc45162303e5708b83c379cf1b0426/docs/vsock.md>
10
11use super::Guid;
12use crate::HvsockRelayChannelHalf;
13use crate::ring::RingMem;
14use anyhow::Context;
15use fs_err::PathExt;
16use futures::AsyncReadExt;
17use futures::AsyncWriteExt;
18use futures::StreamExt;
19use futures_concurrency::stream::Merge;
20use mesh::CancelContext;
21use pal_async::driver::SpawnDriver;
22use pal_async::socket::PolledSocket;
23use pal_async::task::Spawn;
24use pal_async::task::Task;
25use std::io::ErrorKind;
26use std::path::Path;
27use std::path::PathBuf;
28use std::sync::Arc;
29use std::time::Duration;
30use unicycle::FuturesUnordered;
31use unix_socket::UnixListener;
32use unix_socket::UnixStream;
33use vmbus_async::pipe::BytePipe;
34use vmbus_channel::bus::ChannelType;
35use vmbus_channel::bus::OfferParams;
36use vmbus_channel::bus::ParentBus;
37use vmbus_channel::offer::Offer;
38use vmbus_core::HvsockConnectRequest;
39use vmbus_core::HvsockConnectResult;
40
/// Relays hvsocket connections between the guest and host Unix sockets.
///
/// Owns the background tasks that service the relay; they are dropped along
/// with this value. Fields drop in declaration order.
pub struct HvsockRelay {
    /// State shared with the relay's tasks.
    inner: Arc<RelayInner>,
    /// Hands tasks to the relay worker so it can own them.
    host_send: mesh::Sender<RelayRequest>,
    /// The relay worker task (`HvsockRelayWorker::run`).
    _relay_task: Task<()>,
    /// Accept loop for the hybrid vsock listener, if one was provided.
    _listener_task: Option<Task<()>>,
}
47
/// Requests sent from the host side to the relay worker.
enum RelayRequest {
    /// Give the worker ownership of a task; it is kept in the worker's task
    /// set and polled until it completes.
    AddTask(Task<()>),
}
51
/// State shared between the relay's tasks.
struct RelayInner {
    /// Bus used to offer hvsocket channels to the guest.
    vmbus: Arc<dyn ParentBus>,
    /// Driver used to spawn tasks and register sockets for polling.
    driver: Box<dyn SpawnDriver>,
}
56
57impl HvsockRelay {
58    /// Creates and starts a relay thread, waiting for hvsocket connect requests
59    /// on `recv`.
60    pub fn new(
61        driver: impl SpawnDriver,
62        vmbus: Arc<dyn ParentBus>,
63        guest: HvsockRelayChannelHalf,
64        hybrid_vsock_path: Option<PathBuf>,
65        hybrid_vsock_listener: Option<UnixListener>,
66    ) -> anyhow::Result<Self> {
67        let inner = Arc::new(RelayInner {
68            vmbus,
69            driver: Box::new(driver),
70        });
71
72        let worker = HvsockRelayWorker {
73            guest_send: guest.response_send,
74            inner: inner.clone(),
75            tasks: Default::default(),
76            hybrid_vsock_path,
77        };
78
79        let (host_send, host_recv) = mesh::channel();
80
81        let _listener_task = if let Some(listener) = hybrid_vsock_listener {
82            let listener = PolledSocket::new(inner.driver.as_ref(), listener)?;
83            Some(
84                inner.driver.spawn(
85                    "hvsock-listener",
86                    ListenerWorker {
87                        inner: inner.clone(),
88                        host_send: host_send.clone(),
89                    }
90                    .run(listener),
91                ),
92            )
93        } else {
94            None
95        };
96
97        let task = inner
98            .driver
99            .spawn("hvsock relay", worker.run(guest.request_receive, host_recv));
100
101        Ok(Self {
102            host_send,
103            inner,
104            _relay_task: task,
105            _listener_task,
106        })
107    }
108
109    /// Connects to an hvsocket in the guest and returns a Unix socket that is
110    /// relayed to the hvsocket.
111    ///
112    /// Blocks until complete or cancelled.
113    pub fn connect(
114        &self,
115        ctx: &mut CancelContext,
116        service_id: Guid,
117    ) -> impl Future<Output = anyhow::Result<UnixStream>> + Send + use<> {
118        let inner = self.inner.clone();
119        let host_send = self.host_send.clone();
120        let (send, recv) = mesh::oneshot();
121
122        // Ensure the task gets dropped if the future is dropped.
123        let (mut ctx, cancel) = ctx.with_cancel();
124
125        // Spawn a task to initiate the connect to avoid keeping a reference on `RelayInner`.
126        let task = self.inner.driver.spawn("hvsock-connect", async move {
127            let r = async {
128                let (stream, task) = ctx
129                    .until_cancelled(inner.connect_to_guest(service_id))
130                    .await??;
131                host_send.send(RelayRequest::AddTask(task));
132                Ok(stream)
133            }
134            .await;
135
136            send.send(r);
137        });
138        self.host_send.send(RelayRequest::AddTask(task));
139        async move {
140            let _cancel = cancel;
141            recv.await?
142        }
143    }
144}
145
/// Accepts host connections on a hybrid vsock listener socket and relays them
/// to the guest.
struct ListenerWorker {
    /// State shared with the rest of the relay.
    inner: Arc<RelayInner>,
    /// Hands spawned relay tasks to the relay worker.
    host_send: mesh::Sender<RelayRequest>,
}
150
151impl ListenerWorker {
152    async fn run(self, mut listener: PolledSocket<UnixListener>) {
153        loop {
154            let connection = match listener.accept().await {
155                Ok((connection, _address)) => connection,
156                Err(err) => {
157                    tracing::error!(
158                        error = &err as &dyn std::error::Error,
159                        "failed to accept hybrid vsock connection, shutting down listener"
160                    );
161                    break;
162                }
163            };
164            match self.spawn_relay(connection).await {
165                Ok(task) => {
166                    self.host_send.send(RelayRequest::AddTask(task));
167                }
168                Err(err) => {
169                    tracing::warn!(
170                        error = err.as_ref() as &dyn std::error::Error,
171                        "relayed connection failed"
172                    );
173                }
174            }
175        }
176    }
177
178    async fn spawn_relay(&self, connection: UnixStream) -> anyhow::Result<Task<()>> {
179        let mut socket = PolledSocket::new(self.inner.driver.as_ref(), connection)?;
180        let (service_id, format) = read_hybrid_vsock_connect(&mut socket).await?;
181
182        let instance_id = Guid::new_random();
183        let mut offer = Offer::new(
184            self.inner.driver.as_ref(),
185            self.inner.vmbus.as_ref(),
186            OfferParams {
187                interface_name: "hvsocket_connect".into(),
188                interface_id: service_id,
189                instance_id,
190                channel_type: ChannelType::HvSocket {
191                    is_connect: true,
192                    is_for_container: false,
193                    silo_id: Guid::ZERO,
194                },
195                ..Default::default()
196            },
197        )
198        .await
199        .context("failed to offer channel")?;
200
201        let channel = CancelContext::new()
202            .with_timeout(Duration::from_secs(2))
203            .until_cancelled(offer.wait_for_open(self.inner.driver.as_ref()))
204            .await?
205            .context("failed to accept channel")?
206            .accept()
207            .channel;
208
209        let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
210
211        tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
212
213        let task = self
214            .inner
215            .driver
216            .spawn("hvsock connection relay", async move {
217                // Keep the offer alive until the relay completes.
218                let _offer = offer;
219
220                // Notify the client that connection was successful.
221                let s = match format {
222                    ServiceIdFormat::Vsock => format!("OK {}\n", instance_id.data1),
223                    ServiceIdFormat::HyperV => format!("OK {}\n", instance_id),
224                };
225                if let Err(err) = socket.write_all(s.as_bytes()).await {
226                    tracing::error!(
227                        %service_id,
228                        error = &err as &dyn std::error::Error,
229                        "failed to write OK response"
230                    );
231                }
232
233                if let Err(err) = relay_connected(pipe, socket).await {
234                    tracing::error!(
235                        %service_id,
236                        error = &err as &dyn std::error::Error,
237                        "connection relay failed"
238                    );
239                } else {
240                    tracing::debug!(%service_id, "connection relay finished");
241                }
242            });
243
244        Ok(task)
245    }
246}
247
/// How a hybrid vsock `CONNECT` request identified the target service.
///
/// The `OK` response echoes the connection's endpoint in the same format.
#[derive(Debug)]
enum ServiceIdFormat {
    /// A decimal AF_VSOCK port, embedded into `VSOCK_TEMPLATE`.
    Vsock,
    /// A Hyper-V service ID GUID.
    HyperV,
}
253
254async fn read_hybrid_vsock_connect(
255    socket: &mut PolledSocket<UnixStream>,
256) -> anyhow::Result<(Guid, ServiceIdFormat)> {
257    let mut buf = [0; "CONNECT 00000000-facb-11e6-bd58-64006a7986d3\n".len()];
258    let mut i = 0;
259    while i == 0 || buf[i - 1] != b'\n' {
260        if i == buf.len() {
261            anyhow::bail!("connect request did not fit");
262        }
263        let n = socket
264            .read(&mut buf[i..])
265            .await
266            .context("failed to read connect request")?;
267        if n == 0 {
268            anyhow::bail!("no connect request");
269        }
270        i += n;
271    }
272
273    let rest = buf[..i - 1]
274        .strip_prefix(b"CONNECT ")
275        .context("invalid connect request")?;
276
277    let rest = std::str::from_utf8(rest).context("invalid connect request")?;
278    let (service_id, format) = if let Ok(port) = rest.parse::<u32>() {
279        (
280            Guid {
281                data1: port,
282                ..VSOCK_TEMPLATE
283            },
284            ServiceIdFormat::Vsock,
285        )
286    } else if let Ok(service_id) = rest.parse::<Guid>() {
287        (service_id, ServiceIdFormat::HyperV)
288    } else {
289        anyhow::bail!("invalid port or service ID: {}", rest);
290    };
291
292    tracing::debug!(%service_id, ?format, "got hybrid connect request");
293    Ok((service_id, format))
294}
295
/// A guest hvsocket connect request that still needs a response.
///
/// Dropping this sends a failure result (see the `Drop` impl), so every
/// request is answered even on error paths.
struct PendingConnection {
    /// Channel on which the connect result is sent back.
    send: mesh::Sender<HvsockConnectResult>,
    /// The request being answered.
    request: HvsockConnectRequest,
}
300
301impl PendingConnection {
302    fn done(self, success: bool) {
303        self.send
304            .send(HvsockConnectResult::from_request(&self.request, success));
305        std::mem::forget(self);
306    }
307}
308
309impl Drop for PendingConnection {
310    fn drop(&mut self) {
311        self.send
312            .send(HvsockConnectResult::from_request(&self.request, false));
313    }
314}
315
/// Hyper-V service ID template used to embed an AF_VSOCK port: a vsock port
/// maps to this GUID with `data1` replaced by the port number (it is zero
/// here).
static VSOCK_TEMPLATE: Guid = guid::guid!("00000000-facb-11e6-bd58-64006a7986d3");
319
320fn vsock_port(service_id: &Guid) -> Option<u32> {
321    let stripped_id = Guid {
322        data1: 0,
323        ..*service_id
324    };
325    (VSOCK_TEMPLATE == stripped_id).then_some(service_id.data1)
326}
327
/// Worker that services guest connect requests and owns the relay's tasks.
struct HvsockRelayWorker {
    /// Sends connect results back to the guest.
    guest_send: mesh::Sender<HvsockConnectResult>,
    /// Tasks owned by the worker, polled until they complete.
    tasks: FuturesUnordered<Task<()>>,
    /// State shared with the rest of the relay.
    inner: Arc<RelayInner>,
    /// Base path for the host Unix sockets that guest connections are relayed
    /// to. If `None`, guest connect requests are refused.
    hybrid_vsock_path: Option<PathBuf>,
}
334
335impl HvsockRelayWorker {
336    async fn run(
337        mut self,
338        guest_recv: mesh::Receiver<HvsockConnectRequest>,
339        host_recv: mesh::Receiver<RelayRequest>,
340    ) {
341        enum Event {
342            Guest(HvsockConnectRequest),
343            Host(RelayRequest),
344            TaskDone(()),
345        }
346
347        let mut recv = (guest_recv.map(Event::Guest), host_recv.map(Event::Host)).merge();
348
349        while let Some(event) = (&mut recv, (&mut self.tasks).map(Event::TaskDone))
350            .merge()
351            .next()
352            .await
353        {
354            match event {
355                Event::Guest(request) => {
356                    self.handle_connect_from_guest(request);
357                }
358                Event::Host(request) => match request {
359                    RelayRequest::AddTask(task) => {
360                        self.tasks.push(task);
361                    }
362                },
363                Event::TaskDone(()) => {}
364            }
365        }
366    }
367
368    fn handle_connect_from_guest(&mut self, request: HvsockConnectRequest) {
369        if request.silo_id != Guid::ZERO {
370            tracelimit::warn_ratelimited!(?request, "Non-zero silo ID is currently ignored.")
371        }
372
373        // Wrap the connect request so that we are assured to send a response.
374        let pending = PendingConnection {
375            send: self.guest_send.clone(),
376            request,
377        };
378        let (path, is_specific_path) = {
379            if let Some(hybrid_vsock_path) = &self.hybrid_vsock_path {
380                (hybrid_vsock_path.to_owned(), false)
381            } else {
382                tracing::debug!(request = ?&request, "ignoring hvsock connect request");
383                return;
384            }
385        };
386
387        let task = self.inner.driver.spawn(
388            format!(
389                "hvsock accept {}:{}",
390                request.service_id, request.endpoint_id
391            ),
392            {
393                let inner = self.inner.clone();
394                async move {
395                    match inner
396                        .relay_guest_connect_to_host(pending, path.as_ref(), is_specific_path)
397                        .await
398                    {
399                        Ok(()) => {
400                            tracing::debug!(request = ?&request, "relay done");
401                        }
402                        Err(err) => {
403                            tracelimit::error_ratelimited!(
404                                request = ?&request,
405                                err = err.as_ref() as &dyn std::error::Error,
406                                "relay error"
407                            );
408                        }
409                    }
410                }
411            },
412        );
413        self.tasks.push(task);
414    }
415}
416
417impl RelayInner {
418    async fn relay_guest_connect_to_host(
419        &self,
420        pending: PendingConnection,
421        base_path: &Path,
422        is_specific_path: bool,
423    ) -> anyhow::Result<()> {
424        let request = &pending.request;
425
426        // Find the appropriate path to connect to. Don't connect until the
427        // channel with the guest has been established, since that's a
428        // failure-prone operation and we don't want the host to see a broken
429        // connection.
430        let path = self.host_uds_path(request, base_path, is_specific_path)?;
431
432        let mut offer = Offer::new(
433            self.driver.as_ref(),
434            self.vmbus.as_ref(),
435            OfferParams {
436                interface_name: "hvsocket".to_owned(),
437                instance_id: request.endpoint_id,
438                interface_id: request.service_id,
439                channel_type: ChannelType::HvSocket {
440                    is_connect: false,
441                    is_for_container: false,
442                    silo_id: Guid::ZERO,
443                },
444                ..Default::default()
445            },
446        )
447        .await
448        .context("failed to offer channel")?;
449
450        tracing::debug!(?request, "offered hvsocket channel to guest");
451        let service_id = request.service_id;
452
453        // Now that the channel is offered, report that the connection operation is
454        // done.
455        pending.done(true);
456
457        // Give the guest a few seconds to open the channel.
458        let channel = CancelContext::new()
459            .with_timeout(Duration::from_secs(5))
460            .until_cancelled(offer.wait_for_open(self.driver.as_ref()))
461            .await
462            .context("guest did not open hvsocket channel")??;
463
464        tracing::debug!(%service_id, "guest opened hvsocket channel");
465
466        // Connect to the host Unix socket.
467        let socket = PolledSocket::connect_unix(self.driver.as_ref(), &path)
468            .await
469            .with_context(|| {
470                format!(
471                    "failed to connect to registered listener {} for {}",
472                    path.display(),
473                    service_id
474                )
475            })?;
476
477        tracing::debug!(%service_id, path = %path.display(), "connected to host uds socket");
478
479        // Accept the channel now that the host connection is established.
480        let channel = channel.accept().channel;
481
482        let channel = BytePipe::new(channel)?;
483        if let Err(err) = relay_connected(channel, socket).await {
484            tracelimit::error_ratelimited!(
485                %service_id,
486                error = &err as &dyn std::error::Error,
487                "guest to host connection relay failed"
488            );
489        } else {
490            tracing::debug!(%service_id, "guest to host connection relay finished");
491        }
492
493        // N.B. offer needs to stay alive until here to avoid revoking the channel
494        // before the relay is done.
495        drop(offer);
496        Ok(())
497    }
498
499    fn host_uds_path(
500        &self,
501        request: &HvsockConnectRequest,
502        base_path: &Path,
503        is_specific_path: bool,
504    ) -> anyhow::Result<PathBuf> {
505        let mut path = base_path.as_os_str().to_owned();
506        if !is_specific_path {
507            if let Some(port) = vsock_port(&request.service_id) {
508                // This is a vsock connection, so try connecting after appending the
509                // port to the path.
510                path.push(format!("_{port}"));
511                if Path::new(&path).fs_err_try_exists()? {
512                    return Ok(path.into());
513                }
514                path.clear();
515                path.push(base_path);
516            }
517            path.push(format!("_{}", request.service_id));
518        }
519        if !Path::new(&path).fs_err_try_exists()? {
520            anyhow::bail!(
521                "no hybrid vsock listener based at {} for {}",
522                base_path.display(),
523                request.service_id
524            );
525        }
526        Ok(path.into())
527    }
528
529    async fn connect_to_guest(&self, service_id: Guid) -> anyhow::Result<(UnixStream, Task<()>)> {
530        let instance_id = Guid::new_random();
531        let mut offer = Offer::new(
532            &self.driver,
533            self.vmbus.as_ref(),
534            OfferParams {
535                interface_name: "hvsocket_connect".into(),
536                interface_id: service_id,
537                instance_id,
538                channel_type: ChannelType::HvSocket {
539                    is_connect: true,
540                    is_for_container: false,
541                    silo_id: Guid::ZERO,
542                },
543                ..Default::default()
544            },
545        )
546        .await
547        .context("failed to offer channel")?;
548
549        let channel = offer
550            .wait_for_open(self.driver.as_ref())
551            .await
552            .context("failed to accept channel")?
553            .accept()
554            .channel;
555        let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
556
557        tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
558
559        let (left, right) = UnixStream::pair().context("failed to create socket pair")?;
560        let right = PolledSocket::new(self.driver.as_ref(), right)
561            .context("failed to create polled socket")?;
562
563        let task = self.driver.spawn(
564            format!("hvsock {}:{}", service_id, instance_id),
565            async move {
566                // Keep the offer alive until the relay completes.
567                let _offer = offer;
568                if let Err(err) = relay_connected(pipe, right).await {
569                    tracing::error!(
570                        %service_id,
571                        error = &err as &dyn std::error::Error,
572                        "connection relay failed"
573                    );
574                }
575            },
576        );
577
578        Ok((left, task))
579    }
580}
581
582async fn relay_connected<T: RingMem + Unpin>(
583    channel: BytePipe<T>,
584    socket: PolledSocket<UnixStream>,
585) -> std::io::Result<()> {
586    let (channel_read, mut channel_write) = channel.split();
587    let (socket_read, mut socket_write) = socket.split();
588
589    let channel_to_socket = async {
590        futures::io::copy(channel_read, &mut socket_write).await?;
591        socket_write.close().await
592    };
593
594    let socket_to_channel = async {
595        futures::io::copy(socket_read, &mut channel_write).await?;
596        channel_write.close().await
597    };
598
599    match futures::future::try_join(channel_to_socket, socket_to_channel).await {
600        Ok(((), ())) => {}
601        Err(err) if err.kind() == ErrorKind::ConnectionReset => {}
602        Err(err) => return Err(err),
603    }
604    Ok(())
605}
606
#[cfg(test)]
mod tests {
    use super::relay_connected;
    use crate::ring::FlatRingMem;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::driver::Driver;
    use pal_async::socket::PolledSocket;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use unix_socket::UnixStream;
    use vmbus_async::pipe::BytePipe;
    use vmbus_async::pipe::connected_byte_pipes;

    /// Starts `relay_connected` between one end of a connected byte-pipe pair
    /// and one end of a Unix socket pair.
    ///
    /// Returns the remaining pipe end (the "guest" side), the remaining socket
    /// end (the "host" side), and the relay task.
    fn setup_relay<T: Driver + Spawn>(
        driver: &T,
    ) -> (
        BytePipe<FlatRingMem>,
        PolledSocket<UnixStream>,
        Task<std::io::Result<()>>,
    ) {
        let (relay_pipe, guest) = connected_byte_pipes(4096);
        let (host, relay_socket) = UnixStream::pair().unwrap();
        let host = PolledSocket::new(driver, host).unwrap();
        let relay_socket = PolledSocket::new(driver, relay_socket).unwrap();
        let relay = driver.spawn("test", async move {
            relay_connected(relay_pipe, relay_socket).await
        });

        (guest, host, relay)
    }

    #[async_test]
    async fn test_relay(driver: DefaultDriver) {
        let (mut guest, mut host, relay) = setup_relay(&driver);

        let payload = b"abcd";
        let mut received = [0; 4];

        // Guest to host.
        guest.write_all(payload).await.unwrap();
        host.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        // Host to guest.
        host.write_all(payload).await.unwrap();
        guest.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        // Host to guest, then close the host's write side.
        host.write_all(payload).await.unwrap();
        host.close().await.unwrap();
        guest.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        // Guest to host still flows after the host's write side is closed.
        guest.write_all(payload).await.unwrap();
        host.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        guest.close().await.unwrap();
        relay.await.unwrap();
    }

    #[cfg(unix)] // Windows does not deliver POLLHUP on Unix socket close.
    #[async_test]
    async fn test_relay_host_close(driver: DefaultDriver) {
        // The host socket end is dropped immediately by the `_` pattern.
        let (mut guest, _, relay) = setup_relay(&driver);

        let mut byte = [0];
        assert_eq!(guest.read(&mut byte).await.unwrap(), 0);
        drop(guest);
        relay.await.unwrap();
    }

    #[async_test]
    async fn test_relay_guest_close(driver: DefaultDriver) {
        // The guest pipe end is dropped immediately by the `_` pattern.
        let (_, mut host, relay) = setup_relay(&driver);

        let mut byte = [0];
        assert_eq!(host.read(&mut byte).await.unwrap(), 0);
        drop(host);
        relay.await.unwrap();
    }

    #[async_test]
    async fn test_relay_forward_socket_shutdown(driver: DefaultDriver) {
        let (mut guest, mut host, relay) = setup_relay(&driver);
        host.close().await.unwrap();
        let mut byte = [0; 1];
        assert_eq!(guest.read(&mut byte).await.unwrap(), 0);
        drop(guest);
        relay.await.unwrap();
    }

    #[async_test]
    async fn test_relay_forward_channel_shutdown(driver: DefaultDriver) {
        let (mut guest, mut host, relay) = setup_relay(&driver);

        guest.close().await.unwrap();
        let mut byte = [0; 1];
        assert_eq!(host.read(&mut byte).await.unwrap(), 0);
        drop(host);
        relay.await.unwrap();
    }
}