1use super::Guid;
12use crate::HvsockRelayChannelHalf;
13use crate::ring::RingMem;
14use anyhow::Context;
15use futures::AsyncReadExt;
16use futures::AsyncWriteExt;
17use futures::StreamExt;
18use futures_concurrency::stream::Merge;
19use mesh::CancelContext;
20use pal_async::driver::SpawnDriver;
21use pal_async::socket::PolledSocket;
22use pal_async::task::Spawn;
23use pal_async::task::Task;
24use std::io::ErrorKind;
25use std::path::Path;
26use std::path::PathBuf;
27use std::sync::Arc;
28use std::time::Duration;
29use unicycle::FuturesUnordered;
30use unix_socket::UnixListener;
31use unix_socket::UnixStream;
32use vmbus_async::pipe::BytePipe;
33use vmbus_channel::bus::ChannelType;
34use vmbus_channel::bus::OfferParams;
35use vmbus_channel::bus::ParentBus;
36use vmbus_channel::offer::Offer;
37use vmbus_core::HvsockConnectRequest;
38use vmbus_core::HvsockConnectResult;
39
/// Relays hvsocket connections between the guest, via vmbus channels, and the
/// host, via Unix domain sockets.
pub struct HvsockRelay {
    /// State shared with the spawned relay tasks.
    inner: Arc<RelayInner>,
    /// Sends requests to the relay worker task.
    host_send: mesh::Sender<RelayRequest>,
    /// Handle for the task running the relay worker.
    _relay_task: Task<()>,
    /// Handle for the hybrid vsock listener task, if a listener was supplied.
    _listener_task: Option<Task<()>>,
}
46
/// Requests sent to the relay worker by host-side code.
enum RelayRequest {
    /// Hand a spawned task to the worker, which holds it until it completes.
    AddTask(Task<()>),
}
50
/// State shared between the relay and its tasks.
struct RelayInner {
    /// Bus on which hvsocket channels are offered to the guest.
    vmbus: Arc<dyn ParentBus>,
    /// Driver used to spawn tasks and to create polled sockets.
    driver: Box<dyn SpawnDriver>,
}
55
56impl HvsockRelay {
57 pub fn new(
60 driver: impl SpawnDriver,
61 vmbus: Arc<dyn ParentBus>,
62 guest: HvsockRelayChannelHalf,
63 hybrid_vsock_path: Option<PathBuf>,
64 hybrid_vsock_listener: Option<UnixListener>,
65 ) -> anyhow::Result<Self> {
66 let inner = Arc::new(RelayInner {
67 vmbus,
68 driver: Box::new(driver),
69 });
70
71 let worker = HvsockRelayWorker {
72 guest_send: guest.response_send,
73 inner: inner.clone(),
74 tasks: Default::default(),
75 hybrid_vsock_path,
76 };
77
78 let (host_send, host_recv) = mesh::channel();
79
80 let _listener_task = if let Some(listener) = hybrid_vsock_listener {
81 let listener = PolledSocket::new(inner.driver.as_ref(), listener)?;
82 Some(
83 inner.driver.spawn(
84 "hvsock-listener",
85 ListenerWorker {
86 inner: inner.clone(),
87 host_send: host_send.clone(),
88 }
89 .run(listener),
90 ),
91 )
92 } else {
93 None
94 };
95
96 let task = inner
97 .driver
98 .spawn("hvsock relay", worker.run(guest.request_receive, host_recv));
99
100 Ok(Self {
101 host_send,
102 inner,
103 _relay_task: task,
104 _listener_task,
105 })
106 }
107
108 pub fn connect(
113 &self,
114 ctx: &mut CancelContext,
115 service_id: Guid,
116 ) -> impl Future<Output = anyhow::Result<UnixStream>> + Send + use<> {
117 let inner = self.inner.clone();
118 let host_send = self.host_send.clone();
119 let (send, recv) = mesh::oneshot();
120
121 let (mut ctx, cancel) = ctx.with_cancel();
123
124 let task = self.inner.driver.spawn("hvsock-connect", async move {
126 let r = async {
127 let (stream, task) = ctx
128 .until_cancelled(inner.connect_to_guest(service_id))
129 .await??;
130 host_send.send(RelayRequest::AddTask(task));
131 Ok(stream)
132 }
133 .await;
134
135 send.send(r);
136 });
137 self.host_send.send(RelayRequest::AddTask(task));
138 async move {
139 let _cancel = cancel;
140 recv.await?
141 }
142 }
143}
144
/// Accepts hybrid vsock connections from host clients and relays them to the
/// guest.
struct ListenerWorker {
    /// State shared with the rest of the relay.
    inner: Arc<RelayInner>,
    /// Used to hand spawned tasks to the relay worker.
    host_send: mesh::Sender<RelayRequest>,
}
149
150impl ListenerWorker {
151 async fn run(self, mut listener: PolledSocket<UnixListener>) {
152 loop {
153 let connection = match listener.accept().await {
154 Ok((connection, _address)) => connection,
155 Err(err) => {
156 tracing::error!(
157 error = &err as &dyn std::error::Error,
158 "failed to accept hybrid vsock connection, shutting down listener"
159 );
160 break;
161 }
162 };
163 match self.spawn_relay(connection).await {
164 Ok(task) => {
165 self.host_send.send(RelayRequest::AddTask(task));
166 }
167 Err(err) => {
168 tracing::warn!(
169 error = err.as_ref() as &dyn std::error::Error,
170 "relayed connection failed"
171 );
172 }
173 }
174 }
175 }
176
177 async fn spawn_relay(&self, connection: UnixStream) -> anyhow::Result<Task<()>> {
178 let mut socket = PolledSocket::new(self.inner.driver.as_ref(), connection)?;
179 let (service_id, format) = read_hybrid_vsock_connect(&mut socket).await?;
180
181 let instance_id = Guid::new_random();
182 let mut offer = Offer::new(
183 self.inner.driver.as_ref(),
184 self.inner.vmbus.as_ref(),
185 OfferParams {
186 interface_name: "hvsocket_connect".into(),
187 interface_id: service_id,
188 instance_id,
189 channel_type: ChannelType::HvSocket {
190 is_connect: true,
191 is_for_container: false,
192 silo_id: Guid::ZERO,
193 },
194 ..Default::default()
195 },
196 )
197 .await
198 .context("failed to offer channel")?;
199
200 let channel = CancelContext::new()
201 .with_timeout(Duration::from_secs(2))
202 .until_cancelled(offer.accept(self.inner.driver.as_ref()))
203 .await?
204 .context("failed to accept channel")?
205 .channel;
206
207 let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
208
209 tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
210
211 let task = self
212 .inner
213 .driver
214 .spawn("hvsock connection relay", async move {
215 let _offer = offer;
217
218 let s = match format {
220 ServiceIdFormat::Vsock => format!("OK {}\n", instance_id.data1),
221 ServiceIdFormat::HyperV => format!("OK {}\n", instance_id),
222 };
223 if let Err(err) = socket.write_all(s.as_bytes()).await {
224 tracing::error!(
225 %service_id,
226 error = &err as &dyn std::error::Error,
227 "failed to write OK response"
228 );
229 }
230
231 if let Err(err) = relay_connected(pipe, socket).await {
232 tracing::error!(
233 %service_id,
234 error = &err as &dyn std::error::Error,
235 "connection relay failed"
236 );
237 } else {
238 tracing::debug!(%service_id, "connection relay finished");
239 }
240 });
241
242 Ok(task)
243 }
244}
245
/// How a hybrid vsock client identified the service in its `CONNECT` request.
///
/// This also selects the form of the `OK` reply written back to the client.
#[derive(Debug)]
enum ServiceIdFormat {
    /// A decimal vsock port; the reply carries the instance ID's `data1`.
    Vsock,
    /// A full service GUID; the reply carries the full instance ID.
    HyperV,
}
251
252async fn read_hybrid_vsock_connect(
253 socket: &mut PolledSocket<UnixStream>,
254) -> anyhow::Result<(Guid, ServiceIdFormat)> {
255 let mut buf = [0; "CONNECT 00000000-facb-11e6-bd58-64006a7986d3\n".len()];
256 let mut i = 0;
257 while i == 0 || buf[i - 1] != b'\n' {
258 if i == buf.len() {
259 anyhow::bail!("connect request did not fit");
260 }
261 let n = socket
262 .read(&mut buf[i..])
263 .await
264 .context("failed to read connect request")?;
265 if n == 0 {
266 anyhow::bail!("no connect request");
267 }
268 i += n;
269 }
270
271 let rest = buf[..i - 1]
272 .strip_prefix(b"CONNECT ")
273 .context("invalid connect request")?;
274
275 let rest = std::str::from_utf8(rest).context("invalid connect request")?;
276 let (service_id, format) = if let Ok(port) = rest.parse::<u32>() {
277 (
278 Guid {
279 data1: port,
280 ..VSOCK_TEMPLATE
281 },
282 ServiceIdFormat::Vsock,
283 )
284 } else if let Ok(service_id) = rest.parse::<Guid>() {
285 (service_id, ServiceIdFormat::HyperV)
286 } else {
287 anyhow::bail!("invalid port or service ID: {}", rest);
288 };
289
290 tracing::debug!(%service_id, ?format, "got hybrid connect request");
291 Ok((service_id, format))
292}
293
/// A guest connection request whose result has not been reported yet.
///
/// Dropping it without calling `done` sends a failure result.
struct PendingConnection {
    /// Channel on which the result is sent.
    send: mesh::Sender<HvsockConnectResult>,
    /// The original request, used to build the result.
    request: HvsockConnectRequest,
}
298
299impl PendingConnection {
300 fn done(self, success: bool) {
301 self.send
302 .send(HvsockConnectResult::from_request(&self.request, success));
303 std::mem::forget(self);
304 }
305}
306
307impl Drop for PendingConnection {
308 fn drop(&mut self) {
309 self.send
310 .send(HvsockConnectResult::from_request(&self.request, false));
311 }
312}
313
/// Service ID template for vsock ports. Port `p` maps to this GUID with `data1`
/// set to `p`; the template's own `data1` is zero.
static VSOCK_TEMPLATE: Guid = guid::guid!("00000000-facb-11e6-bd58-64006a7986d3");
317
318fn vsock_port(service_id: &Guid) -> Option<u32> {
319 let stripped_id = Guid {
320 data1: 0,
321 ..*service_id
322 };
323 (VSOCK_TEMPLATE == stripped_id).then_some(service_id.data1)
324}
325
/// Worker that services guest connection requests and holds the tasks spawned
/// on the relay's behalf until they complete.
struct HvsockRelayWorker {
    /// Channel for reporting connection results for guest requests.
    guest_send: mesh::Sender<HvsockConnectResult>,
    /// Tasks held until they complete.
    tasks: FuturesUnordered<Task<()>>,
    /// State shared with the rest of the relay.
    inner: Arc<RelayInner>,
    /// Base path for host Unix sockets. Guest requests are rejected when this
    /// is `None`.
    hybrid_vsock_path: Option<PathBuf>,
}
332
333impl HvsockRelayWorker {
334 async fn run(
335 mut self,
336 guest_recv: mesh::Receiver<HvsockConnectRequest>,
337 host_recv: mesh::Receiver<RelayRequest>,
338 ) {
339 enum Event {
340 Guest(HvsockConnectRequest),
341 Host(RelayRequest),
342 TaskDone(()),
343 }
344
345 let mut recv = (guest_recv.map(Event::Guest), host_recv.map(Event::Host)).merge();
346
347 while let Some(event) = (&mut recv, (&mut self.tasks).map(Event::TaskDone))
348 .merge()
349 .next()
350 .await
351 {
352 match event {
353 Event::Guest(request) => {
354 self.handle_connect_from_guest(request);
355 }
356 Event::Host(request) => match request {
357 RelayRequest::AddTask(task) => {
358 self.tasks.push(task);
359 }
360 },
361 Event::TaskDone(()) => {}
362 }
363 }
364 }
365
366 fn handle_connect_from_guest(&mut self, request: HvsockConnectRequest) {
367 if request.silo_id != Guid::ZERO {
368 tracelimit::warn_ratelimited!(?request, "Non-zero silo ID is currently ignored.")
369 }
370
371 let pending = PendingConnection {
373 send: self.guest_send.clone(),
374 request,
375 };
376 let (path, is_specific_path) = {
377 if let Some(hybrid_vsock_path) = &self.hybrid_vsock_path {
378 (hybrid_vsock_path.to_owned(), false)
379 } else {
380 tracing::debug!(request = ?&request, "ignoring hvsock connect request");
381 return;
382 }
383 };
384
385 let task = self.inner.driver.spawn(
386 format!(
387 "hvsock accept {}:{}",
388 request.service_id, request.endpoint_id
389 ),
390 {
391 let inner = self.inner.clone();
392 async move {
393 match inner
394 .relay_guest_connect_to_host(pending, path.as_ref(), is_specific_path)
395 .await
396 {
397 Ok(()) => {
398 tracing::debug!(request = ?&request, "relay done");
399 }
400 Err(err) => {
401 tracing::error!(
402 request = ?&request,
403 err = err.as_ref() as &dyn std::error::Error,
404 "relay error"
405 );
406 }
407 }
408 }
409 },
410 );
411 self.tasks.push(task);
412 }
413}
414
415impl RelayInner {
416 async fn relay_guest_connect_to_host(
417 &self,
418 pending: PendingConnection,
419 path: &Path,
420 is_specific_path: bool,
421 ) -> anyhow::Result<()> {
422 let request = &pending.request;
423 let socket = self
424 .connect_to_host_uds(request, path, is_specific_path)
425 .await?;
426
427 let mut offer = Offer::new(
428 self.driver.as_ref(),
429 self.vmbus.as_ref(),
430 OfferParams {
431 interface_name: "hvsocket".to_owned(),
432 instance_id: request.endpoint_id,
433 interface_id: request.service_id,
434 channel_type: ChannelType::HvSocket {
435 is_connect: false,
436 is_for_container: false,
437 silo_id: Guid::ZERO,
438 },
439 ..Default::default()
440 },
441 )
442 .await
443 .context("failed to offer channel")?;
444
445 tracing::debug!(?request, "connected guest to host");
446 let service_id = request.service_id;
447
448 pending.done(true);
451
452 let channel = offer.accept(self.driver.as_ref()).await?.channel;
453 let channel = BytePipe::new(channel)?;
454 if let Err(err) = relay_connected(channel, socket).await {
455 tracing::error!(
456 %service_id,
457 error = &err as &dyn std::error::Error,
458 "guest to host connection relay failed"
459 );
460 } else {
461 tracing::debug!(%service_id, "guest to host connection relay finished");
462 }
463
464 drop(offer);
467 Ok(())
468 }
469
470 async fn connect_to_host_uds(
471 &self,
472 request: &HvsockConnectRequest,
473 path: &Path,
474 is_specific_path: bool,
475 ) -> anyhow::Result<PolledSocket<UnixStream>> {
476 if is_specific_path {
477 let socket = PolledSocket::connect_unix(self.driver.as_ref(), path)
479 .await
480 .with_context(|| {
481 format!(
482 "failed to connect to registered listener {} for {}",
483 path.display(),
484 request.service_id
485 )
486 })?;
487 return Ok(socket);
488 }
489
490 if let Some(port) = vsock_port(&request.service_id) {
491 let mut path = path.as_os_str().to_owned();
494 path.push(format!("_{port}"));
495 if let Ok(socket) = PolledSocket::connect_unix(self.driver.as_ref(), path).await {
496 return Ok(socket);
497 }
498 }
499
500 let mut path = path.as_os_str().to_owned();
503 path.push(format!("_{}", request.service_id));
504 let path = Path::new(&path);
505 let socket = PolledSocket::connect_unix(self.driver.as_ref(), path)
506 .await
507 .with_context(|| {
508 format!(
509 "failed to connect to hybrid vsock listener {} for {}",
510 path.display(),
511 request.service_id
512 )
513 })?;
514
515 Ok(socket)
516 }
517
518 async fn connect_to_guest(&self, service_id: Guid) -> anyhow::Result<(UnixStream, Task<()>)> {
519 let instance_id = Guid::new_random();
520 let mut offer = Offer::new(
521 &self.driver,
522 self.vmbus.as_ref(),
523 OfferParams {
524 interface_name: "hvsocket_connect".into(),
525 interface_id: service_id,
526 instance_id,
527 channel_type: ChannelType::HvSocket {
528 is_connect: true,
529 is_for_container: false,
530 silo_id: Guid::ZERO,
531 },
532 ..Default::default()
533 },
534 )
535 .await
536 .context("failed to offer channel")?;
537
538 let channel = offer
539 .accept(self.driver.as_ref())
540 .await
541 .context("failed to accept channel")?
542 .channel;
543 let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
544
545 tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
546
547 let (left, right) = UnixStream::pair().context("failed to create socket pair")?;
548 let right = PolledSocket::new(self.driver.as_ref(), right)
549 .context("failed to create polled socket")?;
550
551 let task = self.driver.spawn(
552 format!("hvsock {}:{}", service_id, instance_id),
553 async move {
554 let _offer = offer;
556 if let Err(err) = relay_connected(pipe, right).await {
557 tracing::error!(
558 %service_id,
559 error = &err as &dyn std::error::Error,
560 "connection relay failed"
561 );
562 }
563 },
564 );
565
566 Ok((left, task))
567 }
568}
569
570async fn relay_connected<T: RingMem + Unpin>(
571 channel: BytePipe<T>,
572 socket: PolledSocket<UnixStream>,
573) -> std::io::Result<()> {
574 let (channel_read, mut channel_write) = channel.split();
575 let (socket_read, mut socket_write) = socket.split();
576
577 let channel_to_socket = async {
578 futures::io::copy(channel_read, &mut socket_write).await?;
579 socket_write.close().await
580 };
581
582 let socket_to_channel = async {
583 futures::io::copy(socket_read, &mut channel_write).await?;
584 channel_write.close().await
585 };
586
587 match futures::future::try_join(channel_to_socket, socket_to_channel).await {
588 Ok(((), ())) => {}
589 Err(err) if err.kind() == ErrorKind::ConnectionReset => {}
590 Err(err) => return Err(err),
591 }
592 Ok(())
593}
594
#[cfg(test)]
mod tests {
    //! Tests for `relay_connected`, using an in-memory byte pipe in place of a
    //! vmbus channel.

    use super::relay_connected;
    use crate::ring::FlatRingMem;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::driver::Driver;
    use pal_async::socket::PolledSocket;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use unix_socket::UnixStream;
    use vmbus_async::pipe::BytePipe;
    use vmbus_async::pipe::connected_byte_pipes;

    /// Spawns `relay_connected` between one end of a byte pipe pair and one
    /// end of a Unix socket pair.
    ///
    /// Returns the test's pipe end (`c`), the test's socket end (`s`), and the
    /// relay task.
    fn setup_relay<T: Driver + Spawn>(
        driver: &T,
    ) -> (
        BytePipe<FlatRingMem>,
        PolledSocket<UnixStream>,
        Task<std::io::Result<()>>,
    ) {
        let (hc, c) = connected_byte_pipes(4096);
        let (s, s2) = UnixStream::pair().unwrap();
        let s = PolledSocket::new(driver, s).unwrap();
        let s2 = PolledSocket::new(driver, s2).unwrap();
        let task = driver.spawn("test", async move { relay_connected(hc, s2).await });

        (c, s, task)
    }

    /// Data flows in both directions, and each side can still send after the
    /// other has closed its write half.
    #[async_test]
    async fn test_relay(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);

        let d = b"abcd";
        let mut v = [0; 4];

        // Pipe -> socket.
        c.write_all(d).await.unwrap();
        s.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Socket -> pipe.
        s.write_all(d).await.unwrap();
        c.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Data written before the socket closes is still delivered.
        s.write_all(d).await.unwrap();
        s.close().await.unwrap();
        c.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // The pipe -> socket direction keeps working after the socket's close.
        c.write_all(d).await.unwrap();
        s.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Closing the pipe ends the remaining direction, so the relay finishes.
        c.close().await.unwrap();
        task.await.unwrap();
    }

    /// Dropping the socket end is seen as EOF on the pipe, and the relay then
    /// completes without error.
    // NOTE(review): restricted to unix; the reason is not visible here.
    #[cfg(unix)]
    #[async_test]
    async fn test_relay_host_close(driver: DefaultDriver) {
        // `_` does not bind the socket end, so it is dropped immediately.
        let (mut c, _, task) = setup_relay(&driver);

        let mut b = [0];
        assert_eq!(c.read(&mut b).await.unwrap(), 0);
        drop(c);
        task.await.unwrap();
    }

    /// Dropping the pipe end is seen as EOF on the socket, and the relay then
    /// completes without error.
    #[async_test]
    async fn test_relay_guest_close(driver: DefaultDriver) {
        // `_` does not bind the pipe end, so it is dropped immediately.
        let (_, mut s, task) = setup_relay(&driver);

        let mut b = [0];
        assert_eq!(s.read(&mut b).await.unwrap(), 0);
        drop(s);
        task.await.unwrap();
    }

    /// Closing the socket end's write half is forwarded as EOF to the pipe.
    #[async_test]
    async fn test_relay_forward_socket_shutdown(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);
        s.close().await.unwrap();
        let mut v = [0; 1];
        assert_eq!(c.read(&mut v).await.unwrap(), 0);
        drop(c);
        task.await.unwrap();
    }

    /// Closing the pipe end's write half is forwarded as EOF to the socket.
    #[async_test]
    async fn test_relay_forward_channel_shutdown(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);

        c.close().await.unwrap();
        let mut v = [0; 1];
        assert_eq!(s.read(&mut v).await.unwrap(), 0);
        drop(s);
        task.await.unwrap();
    }
}