vmbus_server/
hvsock.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! This module implements an hvsocket-to-Unix socket relay.
5//!
6//! This supports the [hybrid vsock connection model][1] established by
7//! Firecracker, extended to support Hyper-V sockets as well.
8//!
9//! [1]: <https://github.com/firecracker-microvm/firecracker/blob/7b2e87dc65fc45162303e5708b83c379cf1b0426/docs/vsock.md>
10
11use super::Guid;
12use crate::HvsockRelayChannelHalf;
13use crate::ring::RingMem;
14use anyhow::Context;
15use futures::AsyncReadExt;
16use futures::AsyncWriteExt;
17use futures::StreamExt;
18use futures_concurrency::stream::Merge;
19use mesh::CancelContext;
20use pal_async::driver::SpawnDriver;
21use pal_async::socket::PolledSocket;
22use pal_async::task::Spawn;
23use pal_async::task::Task;
24use std::io::ErrorKind;
25use std::path::Path;
26use std::path::PathBuf;
27use std::sync::Arc;
28use std::time::Duration;
29use unicycle::FuturesUnordered;
30use unix_socket::UnixListener;
31use unix_socket::UnixStream;
32use vmbus_async::pipe::BytePipe;
33use vmbus_channel::bus::ChannelType;
34use vmbus_channel::bus::OfferParams;
35use vmbus_channel::bus::ParentBus;
36use vmbus_channel::offer::Offer;
37use vmbus_core::HvsockConnectRequest;
38use vmbus_core::HvsockConnectResult;
39
/// Relay between guest hvsockets and host Unix sockets.
///
/// Created by [`HvsockRelay::new`], which spawns the background tasks whose
/// handles are stored here.
pub struct HvsockRelay {
    /// State shared with the spawned tasks (the bus and the driver).
    inner: Arc<RelayInner>,
    /// Sends requests to the relay worker task; used to hand it newly spawned
    /// tasks to track.
    host_send: mesh::Sender<RelayRequest>,
    /// Handle to the "hvsock relay" worker task. Held but never read.
    _relay_task: Task<()>,
    /// Handle to the "hvsock-listener" task. Only present when a hybrid vsock
    /// listener was passed to [`HvsockRelay::new`]. Held but never read.
    _listener_task: Option<Task<()>>,
}
46
/// Requests sent from the host side of the relay to the relay worker task.
enum RelayRequest {
    /// Hands a spawned task to the worker, which adds it to its set of
    /// tracked tasks.
    AddTask(Task<()>),
}
50
/// State shared between the relay handle and the tasks it spawns.
struct RelayInner {
    /// The bus that hvsocket channels are offered on.
    vmbus: Arc<dyn ParentBus>,
    /// Driver used to spawn tasks and to create polled sockets.
    driver: Box<dyn SpawnDriver>,
}
55
56impl HvsockRelay {
57    /// Creates and starts a relay thread, waiting for hvsocket connect requests
58    /// on `recv`.
59    pub fn new(
60        driver: impl SpawnDriver,
61        vmbus: Arc<dyn ParentBus>,
62        guest: HvsockRelayChannelHalf,
63        hybrid_vsock_path: Option<PathBuf>,
64        hybrid_vsock_listener: Option<UnixListener>,
65    ) -> anyhow::Result<Self> {
66        let inner = Arc::new(RelayInner {
67            vmbus,
68            driver: Box::new(driver),
69        });
70
71        let worker = HvsockRelayWorker {
72            guest_send: guest.response_send,
73            inner: inner.clone(),
74            tasks: Default::default(),
75            hybrid_vsock_path,
76        };
77
78        let (host_send, host_recv) = mesh::channel();
79
80        let _listener_task = if let Some(listener) = hybrid_vsock_listener {
81            let listener = PolledSocket::new(inner.driver.as_ref(), listener)?;
82            Some(
83                inner.driver.spawn(
84                    "hvsock-listener",
85                    ListenerWorker {
86                        inner: inner.clone(),
87                        host_send: host_send.clone(),
88                    }
89                    .run(listener),
90                ),
91            )
92        } else {
93            None
94        };
95
96        let task = inner
97            .driver
98            .spawn("hvsock relay", worker.run(guest.request_receive, host_recv));
99
100        Ok(Self {
101            host_send,
102            inner,
103            _relay_task: task,
104            _listener_task,
105        })
106    }
107
108    /// Connects to an hvsocket in the guest and returns a Unix socket that is
109    /// relayed to the hvsocket.
110    ///
111    /// Blocks until complete or cancelled.
112    pub fn connect(
113        &self,
114        ctx: &mut CancelContext,
115        service_id: Guid,
116    ) -> impl Future<Output = anyhow::Result<UnixStream>> + Send + use<> {
117        let inner = self.inner.clone();
118        let host_send = self.host_send.clone();
119        let (send, recv) = mesh::oneshot();
120
121        // Ensure the task gets dropped if the future is dropped.
122        let (mut ctx, cancel) = ctx.with_cancel();
123
124        // Spawn a task to initiate the connect to avoid keeping a reference on `RelayInner`.
125        let task = self.inner.driver.spawn("hvsock-connect", async move {
126            let r = async {
127                let (stream, task) = ctx
128                    .until_cancelled(inner.connect_to_guest(service_id))
129                    .await??;
130                host_send.send(RelayRequest::AddTask(task));
131                Ok(stream)
132            }
133            .await;
134
135            send.send(r);
136        });
137        self.host_send.send(RelayRequest::AddTask(task));
138        async move {
139            let _cancel = cancel;
140            recv.await?
141        }
142    }
143}
144
/// Worker that accepts connections on the hybrid vsock listener and relays
/// them to hvsockets in the guest.
struct ListenerWorker {
    /// Shared bus and driver.
    inner: Arc<RelayInner>,
    /// Used to hand spawned relay tasks to the relay worker.
    host_send: mesh::Sender<RelayRequest>,
}
149
150impl ListenerWorker {
151    async fn run(self, mut listener: PolledSocket<UnixListener>) {
152        loop {
153            let connection = match listener.accept().await {
154                Ok((connection, _address)) => connection,
155                Err(err) => {
156                    tracing::error!(
157                        error = &err as &dyn std::error::Error,
158                        "failed to accept hybrid vsock connection, shutting down listener"
159                    );
160                    break;
161                }
162            };
163            match self.spawn_relay(connection).await {
164                Ok(task) => {
165                    self.host_send.send(RelayRequest::AddTask(task));
166                }
167                Err(err) => {
168                    tracing::warn!(
169                        error = err.as_ref() as &dyn std::error::Error,
170                        "relayed connection failed"
171                    );
172                }
173            }
174        }
175    }
176
177    async fn spawn_relay(&self, connection: UnixStream) -> anyhow::Result<Task<()>> {
178        let mut socket = PolledSocket::new(self.inner.driver.as_ref(), connection)?;
179        let (service_id, format) = read_hybrid_vsock_connect(&mut socket).await?;
180
181        let instance_id = Guid::new_random();
182        let mut offer = Offer::new(
183            self.inner.driver.as_ref(),
184            self.inner.vmbus.as_ref(),
185            OfferParams {
186                interface_name: "hvsocket_connect".into(),
187                interface_id: service_id,
188                instance_id,
189                channel_type: ChannelType::HvSocket {
190                    is_connect: true,
191                    is_for_container: false,
192                    silo_id: Guid::ZERO,
193                },
194                ..Default::default()
195            },
196        )
197        .await
198        .context("failed to offer channel")?;
199
200        let channel = CancelContext::new()
201            .with_timeout(Duration::from_secs(2))
202            .until_cancelled(offer.accept(self.inner.driver.as_ref()))
203            .await?
204            .context("failed to accept channel")?
205            .channel;
206
207        let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
208
209        tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
210
211        let task = self
212            .inner
213            .driver
214            .spawn("hvsock connection relay", async move {
215                // Keep the offer alive until the relay completes.
216                let _offer = offer;
217
218                // Notify the client that connection was successful.
219                let s = match format {
220                    ServiceIdFormat::Vsock => format!("OK {}\n", instance_id.data1),
221                    ServiceIdFormat::HyperV => format!("OK {}\n", instance_id),
222                };
223                if let Err(err) = socket.write_all(s.as_bytes()).await {
224                    tracing::error!(
225                        %service_id,
226                        error = &err as &dyn std::error::Error,
227                        "failed to write OK response"
228                    );
229                }
230
231                if let Err(err) = relay_connected(pipe, socket).await {
232                    tracing::error!(
233                        %service_id,
234                        error = &err as &dyn std::error::Error,
235                        "connection relay failed"
236                    );
237                } else {
238                    tracing::debug!(%service_id, "connection relay finished");
239                }
240            });
241
242        Ok(task)
243    }
244}
245
/// How the service was identified in a hybrid vsock connect request. The
/// `OK` response uses the same format.
#[derive(Debug)]
enum ServiceIdFormat {
    /// The request named a numeric vsock port.
    Vsock,
    /// The request named a full Hyper-V socket service ID (a GUID).
    HyperV,
}
251
252async fn read_hybrid_vsock_connect(
253    socket: &mut PolledSocket<UnixStream>,
254) -> anyhow::Result<(Guid, ServiceIdFormat)> {
255    let mut buf = [0; "CONNECT 00000000-facb-11e6-bd58-64006a7986d3\n".len()];
256    let mut i = 0;
257    while i == 0 || buf[i - 1] != b'\n' {
258        if i == buf.len() {
259            anyhow::bail!("connect request did not fit");
260        }
261        let n = socket
262            .read(&mut buf[i..])
263            .await
264            .context("failed to read connect request")?;
265        if n == 0 {
266            anyhow::bail!("no connect request");
267        }
268        i += n;
269    }
270
271    let rest = buf[..i - 1]
272        .strip_prefix(b"CONNECT ")
273        .context("invalid connect request")?;
274
275    let rest = std::str::from_utf8(rest).context("invalid connect request")?;
276    let (service_id, format) = if let Ok(port) = rest.parse::<u32>() {
277        (
278            Guid {
279                data1: port,
280                ..VSOCK_TEMPLATE
281            },
282            ServiceIdFormat::Vsock,
283        )
284    } else if let Ok(service_id) = rest.parse::<Guid>() {
285        (service_id, ServiceIdFormat::HyperV)
286    } else {
287        anyhow::bail!("invalid port or service ID: {}", rest);
288    };
289
290    tracing::debug!(%service_id, ?format, "got hybrid connect request");
291    Ok((service_id, format))
292}
293
/// A guest connect request that has not been answered yet.
///
/// A result is sent on `send` either by [`PendingConnection::done`] or, if
/// that is never called, by the `Drop` implementation (as a failure).
struct PendingConnection {
    /// Channel on which the connect result is sent.
    send: mesh::Sender<HvsockConnectResult>,
    /// The request the result is built from.
    request: HvsockConnectRequest,
}
298
impl PendingConnection {
    /// Sends the result for this request, reporting `success`, and consumes
    /// the pending connection without running its `Drop` implementation.
    fn done(self, success: bool) {
        self.send
            .send(HvsockConnectResult::from_request(&self.request, success));
        // Skip `Drop`, which would otherwise send a second (failure) result.
        // NOTE(review): `forget` also skips dropping the fields, so this
        // clone of the sender is never released — confirm that is acceptable.
        std::mem::forget(self);
    }
}
306
/// Sends a failure result if the pending connection is dropped without
/// [`PendingConnection::done`] having been called (which bypasses this via
/// `mem::forget`), e.g. on an early return or error path.
impl Drop for PendingConnection {
    fn drop(&mut self) {
        self.send
            .send(HvsockConnectResult::from_request(&self.request, false));
    }
}
313
// This GUID is an embedding of the AF_VSOCK port into an AF_HYPERV service
// ID: the port is stored in `data1` (zero in this template) and the
// remaining fields are fixed.
static VSOCK_TEMPLATE: Guid = guid::guid!("00000000-facb-11e6-bd58-64006a7986d3");
317
318fn vsock_port(service_id: &Guid) -> Option<u32> {
319    let stripped_id = Guid {
320        data1: 0,
321        ..*service_id
322    };
323    (VSOCK_TEMPLATE == stripped_id).then_some(service_id.data1)
324}
325
/// State of the relay worker task, which services guest connect requests and
/// tracks the tasks spawned on behalf of the relay.
struct HvsockRelayWorker {
    /// Channel on which connect results are sent back to the guest.
    guest_send: mesh::Sender<HvsockConnectResult>,
    /// Tasks being tracked; polled by the worker loop until they finish.
    tasks: FuturesUnordered<Task<()>>,
    /// Shared bus and driver.
    inner: Arc<RelayInner>,
    /// Base path for host Unix socket listeners. When `None`, guest connect
    /// requests are not relayed.
    hybrid_vsock_path: Option<PathBuf>,
}
332
333impl HvsockRelayWorker {
334    async fn run(
335        mut self,
336        guest_recv: mesh::Receiver<HvsockConnectRequest>,
337        host_recv: mesh::Receiver<RelayRequest>,
338    ) {
339        enum Event {
340            Guest(HvsockConnectRequest),
341            Host(RelayRequest),
342            TaskDone(()),
343        }
344
345        let mut recv = (guest_recv.map(Event::Guest), host_recv.map(Event::Host)).merge();
346
347        while let Some(event) = (&mut recv, (&mut self.tasks).map(Event::TaskDone))
348            .merge()
349            .next()
350            .await
351        {
352            match event {
353                Event::Guest(request) => {
354                    self.handle_connect_from_guest(request);
355                }
356                Event::Host(request) => match request {
357                    RelayRequest::AddTask(task) => {
358                        self.tasks.push(task);
359                    }
360                },
361                Event::TaskDone(()) => {}
362            }
363        }
364    }
365
366    fn handle_connect_from_guest(&mut self, request: HvsockConnectRequest) {
367        if request.silo_id != Guid::ZERO {
368            tracelimit::warn_ratelimited!(?request, "Non-zero silo ID is currently ignored.")
369        }
370
371        // Wrap the connect request so that we are assured to send a response.
372        let pending = PendingConnection {
373            send: self.guest_send.clone(),
374            request,
375        };
376        let (path, is_specific_path) = {
377            if let Some(hybrid_vsock_path) = &self.hybrid_vsock_path {
378                (hybrid_vsock_path.to_owned(), false)
379            } else {
380                tracing::debug!(request = ?&request, "ignoring hvsock connect request");
381                return;
382            }
383        };
384
385        let task = self.inner.driver.spawn(
386            format!(
387                "hvsock accept {}:{}",
388                request.service_id, request.endpoint_id
389            ),
390            {
391                let inner = self.inner.clone();
392                async move {
393                    match inner
394                        .relay_guest_connect_to_host(pending, path.as_ref(), is_specific_path)
395                        .await
396                    {
397                        Ok(()) => {
398                            tracing::debug!(request = ?&request, "relay done");
399                        }
400                        Err(err) => {
401                            tracing::error!(
402                                request = ?&request,
403                                err = err.as_ref() as &dyn std::error::Error,
404                                "relay error"
405                            );
406                        }
407                    }
408                }
409            },
410        );
411        self.tasks.push(task);
412    }
413}
414
415impl RelayInner {
416    async fn relay_guest_connect_to_host(
417        &self,
418        pending: PendingConnection,
419        path: &Path,
420        is_specific_path: bool,
421    ) -> anyhow::Result<()> {
422        let request = &pending.request;
423        let socket = self
424            .connect_to_host_uds(request, path, is_specific_path)
425            .await?;
426
427        let mut offer = Offer::new(
428            self.driver.as_ref(),
429            self.vmbus.as_ref(),
430            OfferParams {
431                interface_name: "hvsocket".to_owned(),
432                instance_id: request.endpoint_id,
433                interface_id: request.service_id,
434                channel_type: ChannelType::HvSocket {
435                    is_connect: false,
436                    is_for_container: false,
437                    silo_id: Guid::ZERO,
438                },
439                ..Default::default()
440            },
441        )
442        .await
443        .context("failed to offer channel")?;
444
445        tracing::debug!(?request, "connected guest to host");
446        let service_id = request.service_id;
447
448        // Now that the channel is offered, report that the connection operation is
449        // done.
450        pending.done(true);
451
452        let channel = offer.accept(self.driver.as_ref()).await?.channel;
453        let channel = BytePipe::new(channel)?;
454        if let Err(err) = relay_connected(channel, socket).await {
455            tracing::error!(
456                %service_id,
457                error = &err as &dyn std::error::Error,
458                "guest to host connection relay failed"
459            );
460        } else {
461            tracing::debug!(%service_id, "guest to host connection relay finished");
462        }
463
464        // N.B. offer needs to stay alive until here to avoid revoking the channel
465        // before the relay is done.
466        drop(offer);
467        Ok(())
468    }
469
470    async fn connect_to_host_uds(
471        &self,
472        request: &HvsockConnectRequest,
473        path: &Path,
474        is_specific_path: bool,
475    ) -> anyhow::Result<PolledSocket<UnixStream>> {
476        if is_specific_path {
477            // `path` is the specific path we should connect to.
478            let socket = PolledSocket::connect_unix(self.driver.as_ref(), path)
479                .await
480                .with_context(|| {
481                    format!(
482                        "failed to connect to registered listener {} for {}",
483                        path.display(),
484                        request.service_id
485                    )
486                })?;
487            return Ok(socket);
488        }
489
490        if let Some(port) = vsock_port(&request.service_id) {
491            // This is a vsock connection, so try connecting after appending the
492            // port to the path.
493            let mut path = path.as_os_str().to_owned();
494            path.push(format!("_{port}"));
495            if let Ok(socket) = PolledSocket::connect_unix(self.driver.as_ref(), path).await {
496                return Ok(socket);
497            }
498        }
499
500        // This is not a vsock connection, or the vsock connection failed. Try
501        // connecting after appending the service ID to the path.
502        let mut path = path.as_os_str().to_owned();
503        path.push(format!("_{}", request.service_id));
504        let path = Path::new(&path);
505        let socket = PolledSocket::connect_unix(self.driver.as_ref(), path)
506            .await
507            .with_context(|| {
508                format!(
509                    "failed to connect to hybrid vsock listener {} for {}",
510                    path.display(),
511                    request.service_id
512                )
513            })?;
514
515        Ok(socket)
516    }
517
518    async fn connect_to_guest(&self, service_id: Guid) -> anyhow::Result<(UnixStream, Task<()>)> {
519        let instance_id = Guid::new_random();
520        let mut offer = Offer::new(
521            &self.driver,
522            self.vmbus.as_ref(),
523            OfferParams {
524                interface_name: "hvsocket_connect".into(),
525                interface_id: service_id,
526                instance_id,
527                channel_type: ChannelType::HvSocket {
528                    is_connect: true,
529                    is_for_container: false,
530                    silo_id: Guid::ZERO,
531                },
532                ..Default::default()
533            },
534        )
535        .await
536        .context("failed to offer channel")?;
537
538        let channel = offer
539            .accept(self.driver.as_ref())
540            .await
541            .context("failed to accept channel")?
542            .channel;
543        let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
544
545        tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
546
547        let (left, right) = UnixStream::pair().context("failed to create socket pair")?;
548        let right = PolledSocket::new(self.driver.as_ref(), right)
549            .context("failed to create polled socket")?;
550
551        let task = self.driver.spawn(
552            format!("hvsock {}:{}", service_id, instance_id),
553            async move {
554                // Keep the offer alive until the relay completes.
555                let _offer = offer;
556                if let Err(err) = relay_connected(pipe, right).await {
557                    tracing::error!(
558                        %service_id,
559                        error = &err as &dyn std::error::Error,
560                        "connection relay failed"
561                    );
562                }
563            },
564        );
565
566        Ok((left, task))
567    }
568}
569
570async fn relay_connected<T: RingMem + Unpin>(
571    channel: BytePipe<T>,
572    socket: PolledSocket<UnixStream>,
573) -> std::io::Result<()> {
574    let (channel_read, mut channel_write) = channel.split();
575    let (socket_read, mut socket_write) = socket.split();
576
577    let channel_to_socket = async {
578        futures::io::copy(channel_read, &mut socket_write).await?;
579        socket_write.close().await
580    };
581
582    let socket_to_channel = async {
583        futures::io::copy(socket_read, &mut channel_write).await?;
584        channel_write.close().await
585    };
586
587    match futures::future::try_join(channel_to_socket, socket_to_channel).await {
588        Ok(((), ())) => {}
589        Err(err) if err.kind() == ErrorKind::ConnectionReset => {}
590        Err(err) => return Err(err),
591    }
592    Ok(())
593}
594
/// Tests for `relay_connected`, using an in-memory byte pipe pair for the
/// channel side and a Unix socket pair for the socket side.
#[cfg(test)]
mod tests {
    use super::relay_connected;
    use crate::ring::FlatRingMem;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::driver::Driver;
    use pal_async::socket::PolledSocket;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use unix_socket::UnixStream;
    use vmbus_async::pipe::BytePipe;
    use vmbus_async::pipe::connected_byte_pipes;

    /// Spawns `relay_connected` between one end of a byte pipe pair and one
    /// end of a socket pair.
    ///
    /// Returns the opposite pipe end, the opposite socket end, and the relay
    /// task.
    fn setup_relay<T: Driver + Spawn>(
        driver: &T,
    ) -> (
        BytePipe<FlatRingMem>,
        PolledSocket<UnixStream>,
        Task<std::io::Result<()>>,
    ) {
        let (hc, c) = connected_byte_pipes(4096);
        let (s, s2) = UnixStream::pair().unwrap();
        let s = PolledSocket::new(driver, s).unwrap();
        let s2 = PolledSocket::new(driver, s2).unwrap();
        let task = driver.spawn("test", async move { relay_connected(hc, s2).await });

        (c, s, task)
    }

    /// Data flows in both directions, and continues to flow from the channel
    /// to the socket after the socket's write side has been closed.
    #[async_test]
    async fn test_relay(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);

        let d = b"abcd";
        let mut v = [0; 4];

        // c to s
        c.write_all(d).await.unwrap();
        s.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // s to c
        s.write_all(d).await.unwrap();
        c.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // s to c, then close the socket's write side
        s.write_all(d).await.unwrap();
        s.close().await.unwrap();
        c.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // c to s still works after the socket's write side was closed
        c.write_all(d).await.unwrap();
        s.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        c.close().await.unwrap();
        task.await.unwrap();
    }

    /// The socket end is not bound to a name (`_`), so it is dropped right
    /// away; the channel end must then observe end-of-stream and the relay
    /// must finish without error.
    #[cfg(unix)] // Windows does not deliver POLLHUP on Unix socket close.
    #[async_test]
    async fn test_relay_host_close(driver: DefaultDriver) {
        let (mut c, _, task) = setup_relay(&driver);

        let mut b = [0];
        assert_eq!(c.read(&mut b).await.unwrap(), 0);
        drop(c);
        task.await.unwrap();
    }

    /// The channel end is not bound to a name (`_`), so it is dropped right
    /// away; the socket end must then observe end-of-stream and the relay
    /// must finish without error.
    #[async_test]
    async fn test_relay_guest_close(driver: DefaultDriver) {
        let (_, mut s, task) = setup_relay(&driver);

        let mut b = [0];
        assert_eq!(s.read(&mut b).await.unwrap(), 0);
        drop(s);
        task.await.unwrap();
    }

    /// Closing the socket's write side is forwarded to the channel as
    /// end-of-stream.
    #[async_test]
    async fn test_relay_forward_socket_shutdown(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);
        s.close().await.unwrap();
        let mut v = [0; 1];
        assert_eq!(c.read(&mut v).await.unwrap(), 0);
        drop(c);
        task.await.unwrap();
    }

    /// Closing the channel's write side is forwarded to the socket as
    /// end-of-stream.
    #[async_test]
    async fn test_relay_forward_channel_shutdown(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);

        c.close().await.unwrap();
        let mut v = [0; 1];
        assert_eq!(s.read(&mut v).await.unwrap(), 0);
        drop(s);
        task.await.unwrap();
    }
}