vmbus_server/
hvsock.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! This module implements an hvsocket-to-Unix socket relay.
5//!
6//! This supports the [hybrid vsock connection model][1] established by
7//! Firecracker, extended to support Hyper-V sockets as well.
8//!
9//! [1]: <https://github.com/firecracker-microvm/firecracker/blob/7b2e87dc65fc45162303e5708b83c379cf1b0426/docs/vsock.md>
10
11use super::Guid;
12use crate::HvsockRelayChannelHalf;
13use crate::ring::RingMem;
14use anyhow::Context;
15use futures::AsyncReadExt;
16use futures::AsyncWriteExt;
17use futures::StreamExt;
18use futures_concurrency::stream::Merge;
19use mesh::CancelContext;
20use pal_async::driver::SpawnDriver;
21use pal_async::socket::PolledSocket;
22use pal_async::task::Spawn;
23use pal_async::task::Task;
24use std::io::ErrorKind;
25use std::path::Path;
26use std::path::PathBuf;
27use std::sync::Arc;
28use std::time::Duration;
29use unicycle::FuturesUnordered;
30use unix_socket::UnixListener;
31use unix_socket::UnixStream;
32use vmbus_async::pipe::BytePipe;
33use vmbus_channel::bus::ChannelType;
34use vmbus_channel::bus::OfferParams;
35use vmbus_channel::bus::ParentBus;
36use vmbus_channel::offer::Offer;
37use vmbus_core::HvsockConnectRequest;
38use vmbus_core::HvsockConnectResult;
39
/// Handle to a running hvsocket-to-Unix-socket relay.
///
/// Created by [`HvsockRelay::new`], which spawns the worker tasks stored
/// here. Host code uses [`HvsockRelay::connect`] to open connections into the
/// guest.
pub struct HvsockRelay {
    /// State shared with the spawned worker tasks.
    inner: Arc<RelayInner>,
    /// Sends requests (currently only tasks to track) to the relay worker.
    host_send: mesh::Sender<RelayRequest>,
    /// Task running `HvsockRelayWorker::run`; stored but never read.
    _relay_task: Task<()>,
    /// Task running `ListenerWorker::run`; `Some` only when a hybrid vsock
    /// listener was supplied to `new`.
    _listener_task: Option<Task<()>>,
}
46
/// Requests sent from host-side code to the relay worker.
enum RelayRequest {
    /// Hands a spawned task to the worker, which adds it to its task set.
    AddTask(Task<()>),
}
50
/// State shared (through an `Arc`) by the relay handle and its tasks.
struct RelayInner {
    /// Bus on which hvsocket channels are offered.
    vmbus: Arc<dyn ParentBus>,
    /// Driver used to spawn tasks and to create polled sockets.
    driver: Box<dyn SpawnDriver>,
}
55
56impl HvsockRelay {
57    /// Creates and starts a relay thread, waiting for hvsocket connect requests
58    /// on `recv`.
59    pub fn new(
60        driver: impl SpawnDriver,
61        vmbus: Arc<dyn ParentBus>,
62        guest: HvsockRelayChannelHalf,
63        hybrid_vsock_path: Option<PathBuf>,
64        hybrid_vsock_listener: Option<UnixListener>,
65    ) -> anyhow::Result<Self> {
66        let inner = Arc::new(RelayInner {
67            vmbus,
68            driver: Box::new(driver),
69        });
70
71        let worker = HvsockRelayWorker {
72            guest_send: guest.response_send,
73            inner: inner.clone(),
74            tasks: Default::default(),
75            hybrid_vsock_path,
76        };
77
78        let (host_send, host_recv) = mesh::channel();
79
80        let _listener_task = if let Some(listener) = hybrid_vsock_listener {
81            let listener = PolledSocket::new(inner.driver.as_ref(), listener)?;
82            Some(
83                inner.driver.spawn(
84                    "hvsock-listener",
85                    ListenerWorker {
86                        inner: inner.clone(),
87                        host_send: host_send.clone(),
88                    }
89                    .run(listener),
90                ),
91            )
92        } else {
93            None
94        };
95
96        let task = inner
97            .driver
98            .spawn("hvsock relay", worker.run(guest.request_receive, host_recv));
99
100        Ok(Self {
101            host_send,
102            inner,
103            _relay_task: task,
104            _listener_task,
105        })
106    }
107
108    /// Connects to an hvsocket in the guest and returns a Unix socket that is
109    /// relayed to the hvsocket.
110    ///
111    /// Blocks until complete or cancelled.
112    pub fn connect(
113        &self,
114        ctx: &mut CancelContext,
115        service_id: Guid,
116    ) -> impl std::future::Future<Output = anyhow::Result<UnixStream>> + Send + use<> {
117        let inner = self.inner.clone();
118        let host_send = self.host_send.clone();
119        let (send, recv) = mesh::oneshot();
120
121        // Ensure the task gets dropped if the future is dropped.
122        let (mut ctx, cancel) = ctx.with_cancel();
123
124        // Spawn a task to initiate the connect to avoid keeping a reference on `RelayInner`.
125        let task = self.inner.driver.spawn("hvsock-connect", async move {
126            let r = async {
127                let (stream, task) = ctx
128                    .until_cancelled(inner.connect_to_guest(service_id))
129                    .await??;
130                host_send.send(RelayRequest::AddTask(task));
131                Ok(stream)
132            }
133            .await;
134
135            send.send(r);
136        });
137        self.host_send.send(RelayRequest::AddTask(task));
138        async move {
139            let _cancel = cancel;
140            recv.await?
141        }
142    }
143}
144
/// Accepts host connections on the hybrid vsock listener and relays each one
/// into the guest.
struct ListenerWorker {
    /// Shared driver and bus used to offer channels and spawn relay tasks.
    inner: Arc<RelayInner>,
    /// Used to hand each spawned relay task to the relay worker.
    host_send: mesh::Sender<RelayRequest>,
}
149
150impl ListenerWorker {
151    async fn run(self, mut listener: PolledSocket<UnixListener>) {
152        loop {
153            let connection = match listener.accept().await {
154                Ok((connection, _address)) => connection,
155                Err(err) => {
156                    tracing::error!(
157                        error = &err as &dyn std::error::Error,
158                        "failed to accept hybrid vsock connection, shutting down listener"
159                    );
160                    break;
161                }
162            };
163            match self.spawn_relay(connection).await {
164                Ok(task) => {
165                    self.host_send.send(RelayRequest::AddTask(task));
166                }
167                Err(err) => {
168                    tracing::warn!(
169                        error = err.as_ref() as &dyn std::error::Error,
170                        "relayed connection failed"
171                    );
172                }
173            }
174        }
175    }
176
177    async fn spawn_relay(&self, connection: UnixStream) -> anyhow::Result<Task<()>> {
178        let mut socket = PolledSocket::new(self.inner.driver.as_ref(), connection)?;
179        let (service_id, format) = read_hybrid_vsock_connect(&mut socket).await?;
180
181        let instance_id = Guid::new_random();
182        let mut offer = Offer::new(
183            self.inner.driver.as_ref(),
184            self.inner.vmbus.as_ref(),
185            OfferParams {
186                interface_name: "hvsocket_connect".into(),
187                interface_id: service_id,
188                instance_id,
189                channel_type: ChannelType::HvSocket {
190                    is_connect: true,
191                    is_for_container: false,
192                    silo_id: Guid::ZERO,
193                },
194                ..Default::default()
195            },
196        )
197        .await
198        .context("failed to offer channel")?;
199
200        let channel = CancelContext::new()
201            .with_timeout(Duration::from_secs(2))
202            .until_cancelled(offer.accept(self.inner.driver.as_ref()))
203            .await?
204            .context("failed to accept channel")?
205            .channel;
206
207        let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
208
209        tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
210
211        let task = self
212            .inner
213            .driver
214            .spawn("hvsock connection relay", async move {
215                // Keep the offer alive until the relay completes.
216                let _offer = offer;
217
218                // Notify the client that connection was successful.
219                let s = match format {
220                    ServiceIdFormat::Vsock => format!("OK {}\n", instance_id.data1),
221                    ServiceIdFormat::HyperV => format!("OK {}\n", instance_id),
222                };
223                if let Err(err) = socket.write_all(s.as_bytes()).await {
224                    tracing::error!(
225                        %service_id,
226                        error = &err as &dyn std::error::Error,
227                        "failed to write OK response"
228                    );
229                }
230
231                if let Err(err) = relay_connected(pipe, socket).await {
232                    tracing::error!(
233                        %service_id,
234                        error = &err as &dyn std::error::Error,
235                        "connection relay failed"
236                    );
237                }
238            });
239
240        Ok(task)
241    }
242}
243
/// The form in which a client named the service in its `CONNECT` request.
///
/// The `OK` response reports the endpoint in the same form.
#[derive(Debug)]
enum ServiceIdFormat {
    /// A decimal vsock port number.
    Vsock,
    /// A full Hyper-V service ID GUID.
    HyperV,
}
249
250async fn read_hybrid_vsock_connect(
251    socket: &mut PolledSocket<UnixStream>,
252) -> anyhow::Result<(Guid, ServiceIdFormat)> {
253    let mut buf = [0; "CONNECT 00000000-facb-11e6-bd58-64006a7986d3\n".len()];
254    let mut i = 0;
255    while i == 0 || buf[i - 1] != b'\n' {
256        if i == buf.len() {
257            anyhow::bail!("connect request did not fit");
258        }
259        let n = socket
260            .read(&mut buf[i..])
261            .await
262            .context("failed to read connect request")?;
263        if n == 0 {
264            anyhow::bail!("no connect request");
265        }
266        i += n;
267    }
268
269    let rest = buf[..i - 1]
270        .strip_prefix(b"CONNECT ")
271        .context("invalid connect request")?;
272
273    let rest = std::str::from_utf8(rest).context("invalid connect request")?;
274    let (service_id, format) = if let Ok(port) = rest.parse::<u32>() {
275        (
276            Guid {
277                data1: port,
278                ..VSOCK_TEMPLATE
279            },
280            ServiceIdFormat::Vsock,
281        )
282    } else if let Ok(service_id) = rest.parse::<Guid>() {
283        (service_id, ServiceIdFormat::HyperV)
284    } else {
285        anyhow::bail!("invalid port or service ID: {}", rest);
286    };
287
288    tracing::debug!(%service_id, ?format, "got hybrid connect request");
289    Ok((service_id, format))
290}
291
/// A guest connect request that has not yet been answered.
///
/// A result is always sent for the request: either through `done`, or as a
/// failure when the value is dropped.
struct PendingConnection {
    /// Channel on which the result is reported.
    send: mesh::Sender<HvsockConnectResult>,
    /// The request the result is built from.
    request: HvsockConnectRequest,
}
296
impl PendingConnection {
    /// Reports the final outcome of the connect request.
    ///
    /// Consumes `self` without running its `Drop` impl, so the failure result
    /// that `Drop` would send is not sent on top of this one.
    fn done(self, success: bool) {
        self.send
            .send(HvsockConnectResult::from_request(&self.request, success));
        // Skip `Drop::drop`, which would report a second (failed) result.
        // NOTE(review): `forget` also skips dropping the `send` field, so this
        // sender clone is never released — confirm that is acceptable.
        std::mem::forget(self);
    }
}
304
impl Drop for PendingConnection {
    /// Reports the request as failed.
    ///
    /// This runs whenever the value is dropped without `done` being called,
    /// for example on an early return or an error path.
    fn drop(&mut self) {
        self.send
            .send(HvsockConnectResult::from_request(&self.request, false));
    }
}
311
/// This GUID is an embedding of the AF_VSOCK port into an
/// AF_HYPERV service ID: the port is stored in `data1` and the remaining
/// fields are fixed to the values below.
static VSOCK_TEMPLATE: Guid = guid::guid!("00000000-facb-11e6-bd58-64006a7986d3");
315
316fn vsock_port(service_id: &Guid) -> Option<u32> {
317    let stripped_id = Guid {
318        data1: 0,
319        ..*service_id
320    };
321    (VSOCK_TEMPLATE == stripped_id).then_some(service_id.data1)
322}
323
/// Worker that services guest connect requests and tracks relay tasks.
struct HvsockRelayWorker {
    /// Channel on which connect results are reported back for the guest.
    guest_send: mesh::Sender<HvsockConnectResult>,
    /// Tasks owned by the worker: its own relays plus those handed over
    /// through `RelayRequest::AddTask`.
    tasks: FuturesUnordered<Task<()>>,
    /// Shared driver and bus.
    inner: Arc<RelayInner>,
    /// Base path for host Unix sockets; when `None`, guest connect requests
    /// are not relayed.
    hybrid_vsock_path: Option<PathBuf>,
}
330
331impl HvsockRelayWorker {
332    async fn run(
333        mut self,
334        guest_recv: mesh::Receiver<HvsockConnectRequest>,
335        host_recv: mesh::Receiver<RelayRequest>,
336    ) {
337        enum Event {
338            Guest(HvsockConnectRequest),
339            Host(RelayRequest),
340            TaskDone(()),
341        }
342
343        let mut recv = (guest_recv.map(Event::Guest), host_recv.map(Event::Host)).merge();
344
345        while let Some(event) = (&mut recv, (&mut self.tasks).map(Event::TaskDone))
346            .merge()
347            .next()
348            .await
349        {
350            match event {
351                Event::Guest(request) => {
352                    self.handle_connect_from_guest(request);
353                }
354                Event::Host(request) => match request {
355                    RelayRequest::AddTask(task) => {
356                        self.tasks.push(task);
357                    }
358                },
359                Event::TaskDone(()) => {}
360            }
361        }
362    }
363
364    fn handle_connect_from_guest(&mut self, request: HvsockConnectRequest) {
365        if request.silo_id != Guid::ZERO {
366            tracelimit::warn_ratelimited!(?request, "Non-zero silo ID is currently ignored.")
367        }
368
369        // Wrap the connect request so that we are assured to send a response.
370        let pending = PendingConnection {
371            send: self.guest_send.clone(),
372            request,
373        };
374        let (path, is_specific_path) = {
375            if let Some(hybrid_vsock_path) = &self.hybrid_vsock_path {
376                (hybrid_vsock_path.to_owned(), false)
377            } else {
378                tracing::debug!(request = ?&request, "ignoring hvsock connect request");
379                return;
380            }
381        };
382
383        let task = self.inner.driver.spawn(
384            format!(
385                "hvsock accept {}:{}",
386                request.service_id, request.endpoint_id
387            ),
388            {
389                let inner = self.inner.clone();
390                async move {
391                    match inner
392                        .relay_guest_connect_to_host(pending, path.as_ref(), is_specific_path)
393                        .await
394                    {
395                        Ok(()) => {
396                            tracing::debug!(request = ?&request, "relay done");
397                        }
398                        Err(err) => {
399                            tracing::error!(
400                                request = ?&request,
401                                err = err.as_ref() as &dyn std::error::Error,
402                                "relay error"
403                            );
404                        }
405                    }
406                }
407            },
408        );
409        self.tasks.push(task);
410    }
411}
412
413impl RelayInner {
414    async fn relay_guest_connect_to_host(
415        &self,
416        pending: PendingConnection,
417        path: &Path,
418        is_specific_path: bool,
419    ) -> anyhow::Result<()> {
420        let request = &pending.request;
421        let socket = self
422            .connect_to_host_uds(request, path, is_specific_path)
423            .await?;
424
425        let mut offer = Offer::new(
426            self.driver.as_ref(),
427            self.vmbus.as_ref(),
428            OfferParams {
429                interface_name: "hvsocket".to_owned(),
430                instance_id: request.endpoint_id,
431                interface_id: request.service_id,
432                channel_type: ChannelType::HvSocket {
433                    is_connect: false,
434                    is_for_container: false,
435                    silo_id: Guid::ZERO,
436                },
437                ..Default::default()
438            },
439        )
440        .await
441        .context("failed to offer channel")?;
442
443        // Now that the channel is offered, report that the connection operation is
444        // done.
445        pending.done(true);
446
447        let channel = offer.accept(self.driver.as_ref()).await?.channel;
448        let channel = BytePipe::new(channel)?;
449        relay_connected(channel, socket).await?;
450        // N.B. offer needs to stay alive until here to avoid revoking the channel
451        // before the relay is done.
452        drop(offer);
453        Ok(())
454    }
455
456    async fn connect_to_host_uds(
457        &self,
458        request: &HvsockConnectRequest,
459        path: &Path,
460        is_specific_path: bool,
461    ) -> anyhow::Result<PolledSocket<UnixStream>> {
462        if is_specific_path {
463            // `path` is the specific path we should connect to.
464            let socket = PolledSocket::connect_unix(self.driver.as_ref(), path)
465                .await
466                .with_context(|| {
467                    format!(
468                        "failed to connect to registered listener {} for {}",
469                        path.display(),
470                        request.service_id
471                    )
472                })?;
473            return Ok(socket);
474        }
475
476        if let Some(port) = vsock_port(&request.service_id) {
477            // This is a vsock connection, so try connecting after appending the
478            // port to the path.
479            let mut path = path.as_os_str().to_owned();
480            path.push(format!("_{port}"));
481            if let Ok(socket) = PolledSocket::connect_unix(self.driver.as_ref(), path).await {
482                return Ok(socket);
483            }
484        }
485
486        // This is not a vsock connection, or the vsock connection failed. Try
487        // connecting after appending the service ID to the path.
488        let mut path = path.as_os_str().to_owned();
489        path.push(format!("_{}", request.service_id));
490        let path = Path::new(&path);
491        let socket = PolledSocket::connect_unix(self.driver.as_ref(), path)
492            .await
493            .with_context(|| {
494                format!(
495                    "failed to connect to hybrid vsock listener {} for {}",
496                    path.display(),
497                    request.service_id
498                )
499            })?;
500
501        Ok(socket)
502    }
503
504    async fn connect_to_guest(&self, service_id: Guid) -> anyhow::Result<(UnixStream, Task<()>)> {
505        let instance_id = Guid::new_random();
506        let mut offer = Offer::new(
507            &self.driver,
508            self.vmbus.as_ref(),
509            OfferParams {
510                interface_name: "hvsocket_connect".into(),
511                interface_id: service_id,
512                instance_id,
513                channel_type: ChannelType::HvSocket {
514                    is_connect: true,
515                    is_for_container: false,
516                    silo_id: Guid::ZERO,
517                },
518                ..Default::default()
519            },
520        )
521        .await
522        .context("failed to offer channel")?;
523
524        let channel = offer
525            .accept(self.driver.as_ref())
526            .await
527            .context("failed to accept channel")?
528            .channel;
529        let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
530
531        tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
532
533        let (left, right) = UnixStream::pair().context("failed to create socket pair")?;
534        let right = PolledSocket::new(self.driver.as_ref(), right)
535            .context("failed to create polled socket")?;
536
537        let task = self.driver.spawn(
538            format!("hvsock {}:{}", service_id, instance_id),
539            async move {
540                // Keep the offer alive until the relay completes.
541                let _offer = offer;
542                if let Err(err) = relay_connected(pipe, right).await {
543                    tracing::error!(
544                        %service_id,
545                        error = &err as &dyn std::error::Error,
546                        "connection relay failed"
547                    );
548                }
549            },
550        );
551
552        Ok((left, task))
553    }
554}
555
556async fn relay_connected<T: RingMem + Unpin>(
557    channel: BytePipe<T>,
558    socket: PolledSocket<UnixStream>,
559) -> std::io::Result<()> {
560    let (channel_read, mut channel_write) = channel.split();
561    let (socket_read, mut socket_write) = socket.split();
562
563    let channel_to_socket = async {
564        futures::io::copy(channel_read, &mut socket_write).await?;
565        socket_write.close().await
566    };
567
568    let socket_to_channel = async {
569        futures::io::copy(socket_read, &mut channel_write).await?;
570        channel_write.close().await
571    };
572
573    match futures::future::try_join(channel_to_socket, socket_to_channel).await {
574        Ok(((), ())) => {}
575        Err(err) if err.kind() == ErrorKind::ConnectionReset => {}
576        Err(err) => return Err(err),
577    }
578    Ok(())
579}
580
581#[cfg(test)]
582mod tests {
583    use super::relay_connected;
584    use crate::ring::FlatRingMem;
585    use futures::AsyncReadExt;
586    use futures::AsyncWriteExt;
587    use pal_async::DefaultDriver;
588    use pal_async::async_test;
589    use pal_async::driver::Driver;
590    use pal_async::socket::PolledSocket;
591    use pal_async::task::Spawn;
592    use pal_async::task::Task;
593    use unix_socket::UnixStream;
594    use vmbus_async::pipe::BytePipe;
595    use vmbus_async::pipe::connected_byte_pipes;
596
597    fn setup_relay<T: Driver + Spawn>(
598        driver: &T,
599    ) -> (
600        BytePipe<FlatRingMem>,
601        PolledSocket<UnixStream>,
602        Task<std::io::Result<()>>,
603    ) {
604        let (hc, c) = connected_byte_pipes(4096);
605        let (s, s2) = UnixStream::pair().unwrap();
606        let s = PolledSocket::new(driver, s).unwrap();
607        let s2 = PolledSocket::new(driver, s2).unwrap();
608        let task = driver.spawn("test", async move { relay_connected(hc, s2).await });
609
610        (c, s, task)
611    }
612
    /// Data flows in both directions, and keeps flowing from the pipe to the
    /// socket after the socket's write side has been closed.
    #[async_test]
    async fn test_relay(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);

        let d = b"abcd";
        let mut v = [0; 4];

        // c to s
        c.write_all(d).await.unwrap();
        s.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // s to c
        s.write_all(d).await.unwrap();
        c.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // s to c, then close the socket's write side; the data written before
        // the close must still arrive.
        s.write_all(d).await.unwrap();
        s.close().await.unwrap();
        c.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // c to s still works after the socket's write side was closed.
        c.write_all(d).await.unwrap();
        s.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Closing the pipe ends the remaining direction, so the relay task
        // finishes.
        c.close().await.unwrap();
        task.await.unwrap();
    }
644
    /// Dropping the socket end is seen as end of stream on the pipe end.
    #[cfg(unix)] // Windows does not deliver POLLHUP on Unix socket close.
    #[async_test]
    async fn test_relay_host_close(driver: DefaultDriver) {
        // The `_` pattern does not bind the socket end, so it is dropped
        // right away.
        let (mut c, _, task) = setup_relay(&driver);

        let mut b = [0];
        // A zero-length read means end of stream.
        assert_eq!(c.read(&mut b).await.unwrap(), 0);
        drop(c);
        task.await.unwrap();
    }
655
    /// Dropping the pipe end is seen as end of stream on the socket end.
    #[async_test]
    async fn test_relay_guest_close(driver: DefaultDriver) {
        // The `_` pattern does not bind the pipe end, so it is dropped right
        // away.
        let (_, mut s, task) = setup_relay(&driver);

        let mut b = [0];
        // A zero-length read means end of stream.
        assert_eq!(s.read(&mut b).await.unwrap(), 0);
        drop(s);
        task.await.unwrap();
    }
665
    /// Closing the socket's write side is forwarded to the pipe as end of
    /// stream.
    #[async_test]
    async fn test_relay_forward_socket_shutdown(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);
        s.close().await.unwrap();
        let mut v = [0; 1];
        // A zero-length read means end of stream.
        assert_eq!(c.read(&mut v).await.unwrap(), 0);
        drop(c);
        task.await.unwrap();
    }
675
    /// Closing the pipe's write side is forwarded to the socket as end of
    /// stream.
    #[async_test]
    async fn test_relay_forward_channel_shutdown(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);

        c.close().await.unwrap();
        let mut v = [0; 1];
        // A zero-length read means end of stream.
        assert_eq!(s.read(&mut v).await.unwrap(), 0);
        drop(s);
        task.await.unwrap();
    }
685    }
686}