1use super::Guid;
12use crate::HvsockRelayChannelHalf;
13use crate::ring::RingMem;
14use anyhow::Context;
15use fs_err::PathExt;
16use futures::AsyncReadExt;
17use futures::AsyncWriteExt;
18use futures::StreamExt;
19use futures_concurrency::stream::Merge;
20use mesh::CancelContext;
21use pal_async::driver::SpawnDriver;
22use pal_async::socket::PolledSocket;
23use pal_async::task::Spawn;
24use pal_async::task::Task;
25use std::io::ErrorKind;
26use std::path::Path;
27use std::path::PathBuf;
28use std::sync::Arc;
29use std::time::Duration;
30use unicycle::FuturesUnordered;
31use unix_socket::UnixListener;
32use unix_socket::UnixStream;
33use vmbus_async::pipe::BytePipe;
34use vmbus_channel::bus::ChannelType;
35use vmbus_channel::bus::OfferParams;
36use vmbus_channel::bus::ParentBus;
37use vmbus_channel::offer::Offer;
38use vmbus_core::HvsockConnectRequest;
39use vmbus_core::HvsockConnectResult;
40
41pub struct HvsockRelay {
42 inner: Arc<RelayInner>,
43 host_send: mesh::Sender<RelayRequest>,
44 _relay_task: Task<()>,
45 _listener_task: Option<Task<()>>,
46}
47
48enum RelayRequest {
49 AddTask(Task<()>),
50}
51
52struct RelayInner {
53 vmbus: Arc<dyn ParentBus>,
54 driver: Box<dyn SpawnDriver>,
55}
56
57impl HvsockRelay {
58 pub fn new(
61 driver: impl SpawnDriver,
62 vmbus: Arc<dyn ParentBus>,
63 guest: HvsockRelayChannelHalf,
64 hybrid_vsock_path: Option<PathBuf>,
65 hybrid_vsock_listener: Option<UnixListener>,
66 ) -> anyhow::Result<Self> {
67 let inner = Arc::new(RelayInner {
68 vmbus,
69 driver: Box::new(driver),
70 });
71
72 let worker = HvsockRelayWorker {
73 guest_send: guest.response_send,
74 inner: inner.clone(),
75 tasks: Default::default(),
76 hybrid_vsock_path,
77 };
78
79 let (host_send, host_recv) = mesh::channel();
80
81 let _listener_task = if let Some(listener) = hybrid_vsock_listener {
82 let listener = PolledSocket::new(inner.driver.as_ref(), listener)?;
83 Some(
84 inner.driver.spawn(
85 "hvsock-listener",
86 ListenerWorker {
87 inner: inner.clone(),
88 host_send: host_send.clone(),
89 }
90 .run(listener),
91 ),
92 )
93 } else {
94 None
95 };
96
97 let task = inner
98 .driver
99 .spawn("hvsock relay", worker.run(guest.request_receive, host_recv));
100
101 Ok(Self {
102 host_send,
103 inner,
104 _relay_task: task,
105 _listener_task,
106 })
107 }
108
109 pub fn connect(
114 &self,
115 ctx: &mut CancelContext,
116 service_id: Guid,
117 ) -> impl Future<Output = anyhow::Result<UnixStream>> + Send + use<> {
118 let inner = self.inner.clone();
119 let host_send = self.host_send.clone();
120 let (send, recv) = mesh::oneshot();
121
122 let (mut ctx, cancel) = ctx.with_cancel();
124
125 let task = self.inner.driver.spawn("hvsock-connect", async move {
127 let r = async {
128 let (stream, task) = ctx
129 .until_cancelled(inner.connect_to_guest(service_id))
130 .await??;
131 host_send.send(RelayRequest::AddTask(task));
132 Ok(stream)
133 }
134 .await;
135
136 send.send(r);
137 });
138 self.host_send.send(RelayRequest::AddTask(task));
139 async move {
140 let _cancel = cancel;
141 recv.await?
142 }
143 }
144}
145
146struct ListenerWorker {
147 inner: Arc<RelayInner>,
148 host_send: mesh::Sender<RelayRequest>,
149}
150
151impl ListenerWorker {
152 async fn run(self, mut listener: PolledSocket<UnixListener>) {
153 loop {
154 let connection = match listener.accept().await {
155 Ok((connection, _address)) => connection,
156 Err(err) => {
157 tracing::error!(
158 error = &err as &dyn std::error::Error,
159 "failed to accept hybrid vsock connection, shutting down listener"
160 );
161 break;
162 }
163 };
164 match self.spawn_relay(connection).await {
165 Ok(task) => {
166 self.host_send.send(RelayRequest::AddTask(task));
167 }
168 Err(err) => {
169 tracing::warn!(
170 error = err.as_ref() as &dyn std::error::Error,
171 "relayed connection failed"
172 );
173 }
174 }
175 }
176 }
177
178 async fn spawn_relay(&self, connection: UnixStream) -> anyhow::Result<Task<()>> {
179 let mut socket = PolledSocket::new(self.inner.driver.as_ref(), connection)?;
180 let (service_id, format) = read_hybrid_vsock_connect(&mut socket).await?;
181
182 let instance_id = Guid::new_random();
183 let mut offer = Offer::new(
184 self.inner.driver.as_ref(),
185 self.inner.vmbus.as_ref(),
186 OfferParams {
187 interface_name: "hvsocket_connect".into(),
188 interface_id: service_id,
189 instance_id,
190 channel_type: ChannelType::HvSocket {
191 is_connect: true,
192 is_for_container: false,
193 silo_id: Guid::ZERO,
194 },
195 ..Default::default()
196 },
197 )
198 .await
199 .context("failed to offer channel")?;
200
201 let channel = CancelContext::new()
202 .with_timeout(Duration::from_secs(2))
203 .until_cancelled(offer.wait_for_open(self.inner.driver.as_ref()))
204 .await?
205 .context("failed to accept channel")?
206 .accept()
207 .channel;
208
209 let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
210
211 tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
212
213 let task = self
214 .inner
215 .driver
216 .spawn("hvsock connection relay", async move {
217 let _offer = offer;
219
220 let s = match format {
222 ServiceIdFormat::Vsock => format!("OK {}\n", instance_id.data1),
223 ServiceIdFormat::HyperV => format!("OK {}\n", instance_id),
224 };
225 if let Err(err) = socket.write_all(s.as_bytes()).await {
226 tracing::error!(
227 %service_id,
228 error = &err as &dyn std::error::Error,
229 "failed to write OK response"
230 );
231 }
232
233 if let Err(err) = relay_connected(pipe, socket).await {
234 tracing::error!(
235 %service_id,
236 error = &err as &dyn std::error::Error,
237 "connection relay failed"
238 );
239 } else {
240 tracing::debug!(%service_id, "connection relay finished");
241 }
242 });
243
244 Ok(task)
245 }
246}
247
/// How a hybrid vsock client identified the target service in its connect
/// request.
#[derive(Debug)]
enum ServiceIdFormat {
    /// The request carried a numeric vsock port.
    Vsock,
    /// The request carried a full service ID GUID.
    HyperV,
}
253
254async fn read_hybrid_vsock_connect(
255 socket: &mut PolledSocket<UnixStream>,
256) -> anyhow::Result<(Guid, ServiceIdFormat)> {
257 let mut buf = [0; "CONNECT 00000000-facb-11e6-bd58-64006a7986d3\n".len()];
258 let mut i = 0;
259 while i == 0 || buf[i - 1] != b'\n' {
260 if i == buf.len() {
261 anyhow::bail!("connect request did not fit");
262 }
263 let n = socket
264 .read(&mut buf[i..])
265 .await
266 .context("failed to read connect request")?;
267 if n == 0 {
268 anyhow::bail!("no connect request");
269 }
270 i += n;
271 }
272
273 let rest = buf[..i - 1]
274 .strip_prefix(b"CONNECT ")
275 .context("invalid connect request")?;
276
277 let rest = std::str::from_utf8(rest).context("invalid connect request")?;
278 let (service_id, format) = if let Ok(port) = rest.parse::<u32>() {
279 (
280 Guid {
281 data1: port,
282 ..VSOCK_TEMPLATE
283 },
284 ServiceIdFormat::Vsock,
285 )
286 } else if let Ok(service_id) = rest.parse::<Guid>() {
287 (service_id, ServiceIdFormat::HyperV)
288 } else {
289 anyhow::bail!("invalid port or service ID: {}", rest);
290 };
291
292 tracing::debug!(%service_id, ?format, "got hybrid connect request");
293 Ok((service_id, format))
294}
295
296struct PendingConnection {
297 send: mesh::Sender<HvsockConnectResult>,
298 request: HvsockConnectRequest,
299}
300
301impl PendingConnection {
302 fn done(self, success: bool) {
303 self.send
304 .send(HvsockConnectResult::from_request(&self.request, success));
305 std::mem::forget(self);
306 }
307}
308
309impl Drop for PendingConnection {
310 fn drop(&mut self) {
311 self.send
312 .send(HvsockConnectResult::from_request(&self.request, false));
313 }
314}
315
316static VSOCK_TEMPLATE: Guid = guid::guid!("00000000-facb-11e6-bd58-64006a7986d3");
319
320fn vsock_port(service_id: &Guid) -> Option<u32> {
321 let stripped_id = Guid {
322 data1: 0,
323 ..*service_id
324 };
325 (VSOCK_TEMPLATE == stripped_id).then_some(service_id.data1)
326}
327
328struct HvsockRelayWorker {
329 guest_send: mesh::Sender<HvsockConnectResult>,
330 tasks: FuturesUnordered<Task<()>>,
331 inner: Arc<RelayInner>,
332 hybrid_vsock_path: Option<PathBuf>,
333}
334
335impl HvsockRelayWorker {
336 async fn run(
337 mut self,
338 guest_recv: mesh::Receiver<HvsockConnectRequest>,
339 host_recv: mesh::Receiver<RelayRequest>,
340 ) {
341 enum Event {
342 Guest(HvsockConnectRequest),
343 Host(RelayRequest),
344 TaskDone(()),
345 }
346
347 let mut recv = (guest_recv.map(Event::Guest), host_recv.map(Event::Host)).merge();
348
349 while let Some(event) = (&mut recv, (&mut self.tasks).map(Event::TaskDone))
350 .merge()
351 .next()
352 .await
353 {
354 match event {
355 Event::Guest(request) => {
356 self.handle_connect_from_guest(request);
357 }
358 Event::Host(request) => match request {
359 RelayRequest::AddTask(task) => {
360 self.tasks.push(task);
361 }
362 },
363 Event::TaskDone(()) => {}
364 }
365 }
366 }
367
368 fn handle_connect_from_guest(&mut self, request: HvsockConnectRequest) {
369 if request.silo_id != Guid::ZERO {
370 tracelimit::warn_ratelimited!(?request, "Non-zero silo ID is currently ignored.")
371 }
372
373 let pending = PendingConnection {
375 send: self.guest_send.clone(),
376 request,
377 };
378 let (path, is_specific_path) = {
379 if let Some(hybrid_vsock_path) = &self.hybrid_vsock_path {
380 (hybrid_vsock_path.to_owned(), false)
381 } else {
382 tracing::debug!(request = ?&request, "ignoring hvsock connect request");
383 return;
384 }
385 };
386
387 let task = self.inner.driver.spawn(
388 format!(
389 "hvsock accept {}:{}",
390 request.service_id, request.endpoint_id
391 ),
392 {
393 let inner = self.inner.clone();
394 async move {
395 match inner
396 .relay_guest_connect_to_host(pending, path.as_ref(), is_specific_path)
397 .await
398 {
399 Ok(()) => {
400 tracing::debug!(request = ?&request, "relay done");
401 }
402 Err(err) => {
403 tracelimit::error_ratelimited!(
404 request = ?&request,
405 err = err.as_ref() as &dyn std::error::Error,
406 "relay error"
407 );
408 }
409 }
410 }
411 },
412 );
413 self.tasks.push(task);
414 }
415}
416
417impl RelayInner {
418 async fn relay_guest_connect_to_host(
419 &self,
420 pending: PendingConnection,
421 base_path: &Path,
422 is_specific_path: bool,
423 ) -> anyhow::Result<()> {
424 let request = &pending.request;
425
426 let path = self.host_uds_path(request, base_path, is_specific_path)?;
431
432 let mut offer = Offer::new(
433 self.driver.as_ref(),
434 self.vmbus.as_ref(),
435 OfferParams {
436 interface_name: "hvsocket".to_owned(),
437 instance_id: request.endpoint_id,
438 interface_id: request.service_id,
439 channel_type: ChannelType::HvSocket {
440 is_connect: false,
441 is_for_container: false,
442 silo_id: Guid::ZERO,
443 },
444 ..Default::default()
445 },
446 )
447 .await
448 .context("failed to offer channel")?;
449
450 tracing::debug!(?request, "offered hvsocket channel to guest");
451 let service_id = request.service_id;
452
453 pending.done(true);
456
457 let channel = CancelContext::new()
459 .with_timeout(Duration::from_secs(5))
460 .until_cancelled(offer.wait_for_open(self.driver.as_ref()))
461 .await
462 .context("guest did not open hvsocket channel")??;
463
464 tracing::debug!(%service_id, "guest opened hvsocket channel");
465
466 let socket = PolledSocket::connect_unix(self.driver.as_ref(), &path)
468 .await
469 .with_context(|| {
470 format!(
471 "failed to connect to registered listener {} for {}",
472 path.display(),
473 service_id
474 )
475 })?;
476
477 tracing::debug!(%service_id, path = %path.display(), "connected to host uds socket");
478
479 let channel = channel.accept().channel;
481
482 let channel = BytePipe::new(channel)?;
483 if let Err(err) = relay_connected(channel, socket).await {
484 tracelimit::error_ratelimited!(
485 %service_id,
486 error = &err as &dyn std::error::Error,
487 "guest to host connection relay failed"
488 );
489 } else {
490 tracing::debug!(%service_id, "guest to host connection relay finished");
491 }
492
493 drop(offer);
496 Ok(())
497 }
498
499 fn host_uds_path(
500 &self,
501 request: &HvsockConnectRequest,
502 base_path: &Path,
503 is_specific_path: bool,
504 ) -> anyhow::Result<PathBuf> {
505 let mut path = base_path.as_os_str().to_owned();
506 if !is_specific_path {
507 if let Some(port) = vsock_port(&request.service_id) {
508 path.push(format!("_{port}"));
511 if Path::new(&path).fs_err_try_exists()? {
512 return Ok(path.into());
513 }
514 path.clear();
515 path.push(base_path);
516 }
517 path.push(format!("_{}", request.service_id));
518 }
519 if !Path::new(&path).fs_err_try_exists()? {
520 anyhow::bail!(
521 "no hybrid vsock listener based at {} for {}",
522 base_path.display(),
523 request.service_id
524 );
525 }
526 Ok(path.into())
527 }
528
529 async fn connect_to_guest(&self, service_id: Guid) -> anyhow::Result<(UnixStream, Task<()>)> {
530 let instance_id = Guid::new_random();
531 let mut offer = Offer::new(
532 &self.driver,
533 self.vmbus.as_ref(),
534 OfferParams {
535 interface_name: "hvsocket_connect".into(),
536 interface_id: service_id,
537 instance_id,
538 channel_type: ChannelType::HvSocket {
539 is_connect: true,
540 is_for_container: false,
541 silo_id: Guid::ZERO,
542 },
543 ..Default::default()
544 },
545 )
546 .await
547 .context("failed to offer channel")?;
548
549 let channel = offer
550 .wait_for_open(self.driver.as_ref())
551 .await
552 .context("failed to accept channel")?
553 .accept()
554 .channel;
555 let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
556
557 tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
558
559 let (left, right) = UnixStream::pair().context("failed to create socket pair")?;
560 let right = PolledSocket::new(self.driver.as_ref(), right)
561 .context("failed to create polled socket")?;
562
563 let task = self.driver.spawn(
564 format!("hvsock {}:{}", service_id, instance_id),
565 async move {
566 let _offer = offer;
568 if let Err(err) = relay_connected(pipe, right).await {
569 tracing::error!(
570 %service_id,
571 error = &err as &dyn std::error::Error,
572 "connection relay failed"
573 );
574 }
575 },
576 );
577
578 Ok((left, task))
579 }
580}
581
582async fn relay_connected<T: RingMem + Unpin>(
583 channel: BytePipe<T>,
584 socket: PolledSocket<UnixStream>,
585) -> std::io::Result<()> {
586 let (channel_read, mut channel_write) = channel.split();
587 let (socket_read, mut socket_write) = socket.split();
588
589 let channel_to_socket = async {
590 futures::io::copy(channel_read, &mut socket_write).await?;
591 socket_write.close().await
592 };
593
594 let socket_to_channel = async {
595 futures::io::copy(socket_read, &mut channel_write).await?;
596 channel_write.close().await
597 };
598
599 match futures::future::try_join(channel_to_socket, socket_to_channel).await {
600 Ok(((), ())) => {}
601 Err(err) if err.kind() == ErrorKind::ConnectionReset => {}
602 Err(err) => return Err(err),
603 }
604 Ok(())
605}
606
#[cfg(test)]
mod tests {
    use super::relay_connected;
    use crate::ring::FlatRingMem;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::driver::Driver;
    use pal_async::socket::PolledSocket;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use unix_socket::UnixStream;
    use vmbus_async::pipe::BytePipe;
    use vmbus_async::pipe::connected_byte_pipes;

    /// Starts `relay_connected` on a spawned task between one end of a
    /// connected byte pipe pair and one end of a Unix socket pair.
    ///
    /// Returns the free pipe end, the free socket end, and the relay task.
    fn setup_relay<T: Driver + Spawn>(
        driver: &T,
    ) -> (
        BytePipe<FlatRingMem>,
        PolledSocket<UnixStream>,
        Task<std::io::Result<()>>,
    ) {
        let (relay_pipe, pipe) = connected_byte_pipes(4096);
        let (sock, relay_sock) = UnixStream::pair().unwrap();
        let sock = PolledSocket::new(driver, sock).unwrap();
        let relay_sock = PolledSocket::new(driver, relay_sock).unwrap();
        let task = driver.spawn("test", async move {
            relay_connected(relay_pipe, relay_sock).await
        });

        (pipe, sock, task)
    }

    /// Data flows in both directions, including after the socket side has
    /// closed its write half.
    #[async_test]
    async fn test_relay(driver: DefaultDriver) {
        let (mut pipe, mut sock, task) = setup_relay(&driver);

        let payload = b"abcd";
        let mut received = [0; 4];

        // pipe -> socket
        pipe.write_all(payload).await.unwrap();
        sock.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        // socket -> pipe
        sock.write_all(payload).await.unwrap();
        pipe.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        // socket -> pipe, then close the socket's write half
        sock.write_all(payload).await.unwrap();
        sock.close().await.unwrap();
        pipe.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        // pipe -> socket still works afterwards
        pipe.write_all(payload).await.unwrap();
        sock.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, payload);

        pipe.close().await.unwrap();
        task.await.unwrap();
    }

    /// Dropping the socket end right away shows up as end of stream on the
    /// pipe.
    #[cfg(unix)]
    #[async_test]
    async fn test_relay_host_close(driver: DefaultDriver) {
        // The `_` pattern does not bind the socket, so it is dropped here.
        let (mut pipe, _, task) = setup_relay(&driver);

        let mut byte = [0];
        assert_eq!(pipe.read(&mut byte).await.unwrap(), 0);
        drop(pipe);
        task.await.unwrap();
    }

    /// Dropping the pipe end right away shows up as end of stream on the
    /// socket.
    #[async_test]
    async fn test_relay_guest_close(driver: DefaultDriver) {
        // The `_` pattern does not bind the pipe, so it is dropped here.
        let (_, mut sock, task) = setup_relay(&driver);

        let mut byte = [0];
        assert_eq!(sock.read(&mut byte).await.unwrap(), 0);
        drop(sock);
        task.await.unwrap();
    }

    /// Closing the socket's write half is forwarded to the pipe.
    #[async_test]
    async fn test_relay_forward_socket_shutdown(driver: DefaultDriver) {
        let (mut pipe, mut sock, task) = setup_relay(&driver);

        sock.close().await.unwrap();
        let mut byte = [0; 1];
        assert_eq!(pipe.read(&mut byte).await.unwrap(), 0);
        drop(pipe);
        task.await.unwrap();
    }

    /// Closing the pipe's write half is forwarded to the socket.
    #[async_test]
    async fn test_relay_forward_channel_shutdown(driver: DefaultDriver) {
        let (mut pipe, mut sock, task) = setup_relay(&driver);

        pipe.close().await.unwrap();
        let mut byte = [0; 1];
        assert_eq!(sock.read(&mut byte).await.unwrap(), 0);
        drop(sock);
        task.await.unwrap();
    }
}