1use super::Guid;
12use crate::HvsockRelayChannelHalf;
13use crate::ring::RingMem;
14use anyhow::Context;
15use futures::AsyncReadExt;
16use futures::AsyncWriteExt;
17use futures::StreamExt;
18use futures_concurrency::stream::Merge;
19use mesh::CancelContext;
20use pal_async::driver::SpawnDriver;
21use pal_async::socket::PolledSocket;
22use pal_async::task::Spawn;
23use pal_async::task::Task;
24use std::io::ErrorKind;
25use std::path::Path;
26use std::path::PathBuf;
27use std::sync::Arc;
28use std::time::Duration;
29use unicycle::FuturesUnordered;
30use unix_socket::UnixListener;
31use unix_socket::UnixStream;
32use vmbus_async::pipe::BytePipe;
33use vmbus_channel::bus::ChannelType;
34use vmbus_channel::bus::OfferParams;
35use vmbus_channel::bus::ParentBus;
36use vmbus_channel::offer::Offer;
37use vmbus_core::HvsockConnectRequest;
38use vmbus_core::HvsockConnectResult;
39
/// Relays hvsocket connections between the guest (over vmbus) and host Unix sockets.
pub struct HvsockRelay {
    /// State shared with the relay worker, listener and connect tasks.
    inner: Arc<RelayInner>,
    /// Sends requests (such as tasks to track) to the relay worker.
    host_send: mesh::Sender<RelayRequest>,
    /// Task running the relay worker; held for the lifetime of the relay.
    _relay_task: Task<()>,
    /// Task accepting hybrid vsock connections, present only if a listener was given.
    _listener_task: Option<Task<()>>,
}
46
/// Requests sent from host-side code to the relay worker.
enum RelayRequest {
    /// Add a task to the worker's task set so the worker polls it alongside its
    /// other events.
    AddTask(Task<()>),
}
50
/// State shared by the relay worker, the hybrid vsock listener and connect tasks.
struct RelayInner {
    /// Bus used to offer hvsocket channels to the guest.
    vmbus: Arc<dyn ParentBus>,
    /// Driver used to spawn tasks and create polled sockets.
    driver: Box<dyn SpawnDriver>,
}
55
56impl HvsockRelay {
57 pub fn new(
60 driver: impl SpawnDriver,
61 vmbus: Arc<dyn ParentBus>,
62 guest: HvsockRelayChannelHalf,
63 hybrid_vsock_path: Option<PathBuf>,
64 hybrid_vsock_listener: Option<UnixListener>,
65 ) -> anyhow::Result<Self> {
66 let inner = Arc::new(RelayInner {
67 vmbus,
68 driver: Box::new(driver),
69 });
70
71 let worker = HvsockRelayWorker {
72 guest_send: guest.response_send,
73 inner: inner.clone(),
74 tasks: Default::default(),
75 hybrid_vsock_path,
76 };
77
78 let (host_send, host_recv) = mesh::channel();
79
80 let _listener_task = if let Some(listener) = hybrid_vsock_listener {
81 let listener = PolledSocket::new(inner.driver.as_ref(), listener)?;
82 Some(
83 inner.driver.spawn(
84 "hvsock-listener",
85 ListenerWorker {
86 inner: inner.clone(),
87 host_send: host_send.clone(),
88 }
89 .run(listener),
90 ),
91 )
92 } else {
93 None
94 };
95
96 let task = inner
97 .driver
98 .spawn("hvsock relay", worker.run(guest.request_receive, host_recv));
99
100 Ok(Self {
101 host_send,
102 inner,
103 _relay_task: task,
104 _listener_task,
105 })
106 }
107
108 pub fn connect(
113 &self,
114 ctx: &mut CancelContext,
115 service_id: Guid,
116 ) -> impl std::future::Future<Output = anyhow::Result<UnixStream>> + Send + use<> {
117 let inner = self.inner.clone();
118 let host_send = self.host_send.clone();
119 let (send, recv) = mesh::oneshot();
120
121 let (mut ctx, cancel) = ctx.with_cancel();
123
124 let task = self.inner.driver.spawn("hvsock-connect", async move {
126 let r = async {
127 let (stream, task) = ctx
128 .until_cancelled(inner.connect_to_guest(service_id))
129 .await??;
130 host_send.send(RelayRequest::AddTask(task));
131 Ok(stream)
132 }
133 .await;
134
135 send.send(r);
136 });
137 self.host_send.send(RelayRequest::AddTask(task));
138 async move {
139 let _cancel = cancel;
140 recv.await?
141 }
142 }
143}
144
/// Accepts host connections on the hybrid vsock listener socket and relays them to
/// the guest.
struct ListenerWorker {
    /// Shared relay state (driver and vmbus).
    inner: Arc<RelayInner>,
    /// Used to register per-connection tasks with the relay worker.
    host_send: mesh::Sender<RelayRequest>,
}
149
150impl ListenerWorker {
151 async fn run(self, mut listener: PolledSocket<UnixListener>) {
152 loop {
153 let connection = match listener.accept().await {
154 Ok((connection, _address)) => connection,
155 Err(err) => {
156 tracing::error!(
157 error = &err as &dyn std::error::Error,
158 "failed to accept hybrid vsock connection, shutting down listener"
159 );
160 break;
161 }
162 };
163 match self.spawn_relay(connection).await {
164 Ok(task) => {
165 self.host_send.send(RelayRequest::AddTask(task));
166 }
167 Err(err) => {
168 tracing::warn!(
169 error = err.as_ref() as &dyn std::error::Error,
170 "relayed connection failed"
171 );
172 }
173 }
174 }
175 }
176
177 async fn spawn_relay(&self, connection: UnixStream) -> anyhow::Result<Task<()>> {
178 let mut socket = PolledSocket::new(self.inner.driver.as_ref(), connection)?;
179 let (service_id, format) = read_hybrid_vsock_connect(&mut socket).await?;
180
181 let instance_id = Guid::new_random();
182 let mut offer = Offer::new(
183 self.inner.driver.as_ref(),
184 self.inner.vmbus.as_ref(),
185 OfferParams {
186 interface_name: "hvsocket_connect".into(),
187 interface_id: service_id,
188 instance_id,
189 channel_type: ChannelType::HvSocket {
190 is_connect: true,
191 is_for_container: false,
192 silo_id: Guid::ZERO,
193 },
194 ..Default::default()
195 },
196 )
197 .await
198 .context("failed to offer channel")?;
199
200 let channel = CancelContext::new()
201 .with_timeout(Duration::from_secs(2))
202 .until_cancelled(offer.accept(self.inner.driver.as_ref()))
203 .await?
204 .context("failed to accept channel")?
205 .channel;
206
207 let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
208
209 tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
210
211 let task = self
212 .inner
213 .driver
214 .spawn("hvsock connection relay", async move {
215 let _offer = offer;
217
218 let s = match format {
220 ServiceIdFormat::Vsock => format!("OK {}\n", instance_id.data1),
221 ServiceIdFormat::HyperV => format!("OK {}\n", instance_id),
222 };
223 if let Err(err) = socket.write_all(s.as_bytes()).await {
224 tracing::error!(
225 %service_id,
226 error = &err as &dyn std::error::Error,
227 "failed to write OK response"
228 );
229 }
230
231 if let Err(err) = relay_connected(pipe, socket).await {
232 tracing::error!(
233 %service_id,
234 error = &err as &dyn std::error::Error,
235 "connection relay failed"
236 );
237 }
238 });
239
240 Ok(task)
241 }
242}
243
/// How the service ID was written in a hybrid vsock `CONNECT` request.
///
/// The same format is used for the `OK` reply sent back to the client.
#[derive(Debug)]
enum ServiceIdFormat {
    /// A decimal vsock port, mapped onto `VSOCK_TEMPLATE` via `data1`.
    Vsock,
    /// A full GUID service ID.
    HyperV,
}
249
250async fn read_hybrid_vsock_connect(
251 socket: &mut PolledSocket<UnixStream>,
252) -> anyhow::Result<(Guid, ServiceIdFormat)> {
253 let mut buf = [0; "CONNECT 00000000-facb-11e6-bd58-64006a7986d3\n".len()];
254 let mut i = 0;
255 while i == 0 || buf[i - 1] != b'\n' {
256 if i == buf.len() {
257 anyhow::bail!("connect request did not fit");
258 }
259 let n = socket
260 .read(&mut buf[i..])
261 .await
262 .context("failed to read connect request")?;
263 if n == 0 {
264 anyhow::bail!("no connect request");
265 }
266 i += n;
267 }
268
269 let rest = buf[..i - 1]
270 .strip_prefix(b"CONNECT ")
271 .context("invalid connect request")?;
272
273 let rest = std::str::from_utf8(rest).context("invalid connect request")?;
274 let (service_id, format) = if let Ok(port) = rest.parse::<u32>() {
275 (
276 Guid {
277 data1: port,
278 ..VSOCK_TEMPLATE
279 },
280 ServiceIdFormat::Vsock,
281 )
282 } else if let Ok(service_id) = rest.parse::<Guid>() {
283 (service_id, ServiceIdFormat::HyperV)
284 } else {
285 anyhow::bail!("invalid port or service ID: {}", rest);
286 };
287
288 tracing::debug!(%service_id, ?format, "got hybrid connect request");
289 Ok((service_id, format))
290}
291
/// A guest connect request that is still waiting for its result.
///
/// Call `done` to report the outcome. If the value is dropped without calling
/// `done`, its `Drop` impl reports failure to the guest.
struct PendingConnection {
    /// Channel used to send the connect result back to the guest.
    send: mesh::Sender<HvsockConnectResult>,
    /// The original request; the result is built from it.
    request: HvsockConnectRequest,
}
296
297impl PendingConnection {
298 fn done(self, success: bool) {
299 self.send
300 .send(HvsockConnectResult::from_request(&self.request, success));
301 std::mem::forget(self);
302 }
303}
304
305impl Drop for PendingConnection {
306 fn drop(&mut self) {
307 self.send
308 .send(HvsockConnectResult::from_request(&self.request, false));
309 }
310}
311
/// Service ID template for vsock ports.
///
/// A vsock port number is stored in `data1`; all other fields come from this GUID.
static VSOCK_TEMPLATE: Guid = guid::guid!("00000000-facb-11e6-bd58-64006a7986d3");
315
316fn vsock_port(service_id: &Guid) -> Option<u32> {
317 let stripped_id = Guid {
318 data1: 0,
319 ..*service_id
320 };
321 (VSOCK_TEMPLATE == stripped_id).then_some(service_id.data1)
322}
323
/// Worker that handles guest connect requests and drives relay-related tasks.
struct HvsockRelayWorker {
    /// Sends connect results back to the guest.
    guest_send: mesh::Sender<HvsockConnectResult>,
    /// Tasks held and polled by the worker (relays, connect attempts).
    tasks: FuturesUnordered<Task<()>>,
    /// Shared relay state (driver and vmbus).
    inner: Arc<RelayInner>,
    /// Base path for host Unix sockets.
    ///
    /// When `None`, guest connect requests are ignored and reported as failed.
    hybrid_vsock_path: Option<PathBuf>,
}
330
331impl HvsockRelayWorker {
332 async fn run(
333 mut self,
334 guest_recv: mesh::Receiver<HvsockConnectRequest>,
335 host_recv: mesh::Receiver<RelayRequest>,
336 ) {
337 enum Event {
338 Guest(HvsockConnectRequest),
339 Host(RelayRequest),
340 TaskDone(()),
341 }
342
343 let mut recv = (guest_recv.map(Event::Guest), host_recv.map(Event::Host)).merge();
344
345 while let Some(event) = (&mut recv, (&mut self.tasks).map(Event::TaskDone))
346 .merge()
347 .next()
348 .await
349 {
350 match event {
351 Event::Guest(request) => {
352 self.handle_connect_from_guest(request);
353 }
354 Event::Host(request) => match request {
355 RelayRequest::AddTask(task) => {
356 self.tasks.push(task);
357 }
358 },
359 Event::TaskDone(()) => {}
360 }
361 }
362 }
363
364 fn handle_connect_from_guest(&mut self, request: HvsockConnectRequest) {
365 if request.silo_id != Guid::ZERO {
366 tracelimit::warn_ratelimited!(?request, "Non-zero silo ID is currently ignored.")
367 }
368
369 let pending = PendingConnection {
371 send: self.guest_send.clone(),
372 request,
373 };
374 let (path, is_specific_path) = {
375 if let Some(hybrid_vsock_path) = &self.hybrid_vsock_path {
376 (hybrid_vsock_path.to_owned(), false)
377 } else {
378 tracing::debug!(request = ?&request, "ignoring hvsock connect request");
379 return;
380 }
381 };
382
383 let task = self.inner.driver.spawn(
384 format!(
385 "hvsock accept {}:{}",
386 request.service_id, request.endpoint_id
387 ),
388 {
389 let inner = self.inner.clone();
390 async move {
391 match inner
392 .relay_guest_connect_to_host(pending, path.as_ref(), is_specific_path)
393 .await
394 {
395 Ok(()) => {
396 tracing::debug!(request = ?&request, "relay done");
397 }
398 Err(err) => {
399 tracing::error!(
400 request = ?&request,
401 err = err.as_ref() as &dyn std::error::Error,
402 "relay error"
403 );
404 }
405 }
406 }
407 },
408 );
409 self.tasks.push(task);
410 }
411}
412
413impl RelayInner {
414 async fn relay_guest_connect_to_host(
415 &self,
416 pending: PendingConnection,
417 path: &Path,
418 is_specific_path: bool,
419 ) -> anyhow::Result<()> {
420 let request = &pending.request;
421 let socket = self
422 .connect_to_host_uds(request, path, is_specific_path)
423 .await?;
424
425 let mut offer = Offer::new(
426 self.driver.as_ref(),
427 self.vmbus.as_ref(),
428 OfferParams {
429 interface_name: "hvsocket".to_owned(),
430 instance_id: request.endpoint_id,
431 interface_id: request.service_id,
432 channel_type: ChannelType::HvSocket {
433 is_connect: false,
434 is_for_container: false,
435 silo_id: Guid::ZERO,
436 },
437 ..Default::default()
438 },
439 )
440 .await
441 .context("failed to offer channel")?;
442
443 pending.done(true);
446
447 let channel = offer.accept(self.driver.as_ref()).await?.channel;
448 let channel = BytePipe::new(channel)?;
449 relay_connected(channel, socket).await?;
450 drop(offer);
453 Ok(())
454 }
455
456 async fn connect_to_host_uds(
457 &self,
458 request: &HvsockConnectRequest,
459 path: &Path,
460 is_specific_path: bool,
461 ) -> anyhow::Result<PolledSocket<UnixStream>> {
462 if is_specific_path {
463 let socket = PolledSocket::connect_unix(self.driver.as_ref(), path)
465 .await
466 .with_context(|| {
467 format!(
468 "failed to connect to registered listener {} for {}",
469 path.display(),
470 request.service_id
471 )
472 })?;
473 return Ok(socket);
474 }
475
476 if let Some(port) = vsock_port(&request.service_id) {
477 let mut path = path.as_os_str().to_owned();
480 path.push(format!("_{port}"));
481 if let Ok(socket) = PolledSocket::connect_unix(self.driver.as_ref(), path).await {
482 return Ok(socket);
483 }
484 }
485
486 let mut path = path.as_os_str().to_owned();
489 path.push(format!("_{}", request.service_id));
490 let path = Path::new(&path);
491 let socket = PolledSocket::connect_unix(self.driver.as_ref(), path)
492 .await
493 .with_context(|| {
494 format!(
495 "failed to connect to hybrid vsock listener {} for {}",
496 path.display(),
497 request.service_id
498 )
499 })?;
500
501 Ok(socket)
502 }
503
504 async fn connect_to_guest(&self, service_id: Guid) -> anyhow::Result<(UnixStream, Task<()>)> {
505 let instance_id = Guid::new_random();
506 let mut offer = Offer::new(
507 &self.driver,
508 self.vmbus.as_ref(),
509 OfferParams {
510 interface_name: "hvsocket_connect".into(),
511 interface_id: service_id,
512 instance_id,
513 channel_type: ChannelType::HvSocket {
514 is_connect: true,
515 is_for_container: false,
516 silo_id: Guid::ZERO,
517 },
518 ..Default::default()
519 },
520 )
521 .await
522 .context("failed to offer channel")?;
523
524 let channel = offer
525 .accept(self.driver.as_ref())
526 .await
527 .context("failed to accept channel")?
528 .channel;
529 let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
530
531 tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
532
533 let (left, right) = UnixStream::pair().context("failed to create socket pair")?;
534 let right = PolledSocket::new(self.driver.as_ref(), right)
535 .context("failed to create polled socket")?;
536
537 let task = self.driver.spawn(
538 format!("hvsock {}:{}", service_id, instance_id),
539 async move {
540 let _offer = offer;
542 if let Err(err) = relay_connected(pipe, right).await {
543 tracing::error!(
544 %service_id,
545 error = &err as &dyn std::error::Error,
546 "connection relay failed"
547 );
548 }
549 },
550 );
551
552 Ok((left, task))
553 }
554}
555
556async fn relay_connected<T: RingMem + Unpin>(
557 channel: BytePipe<T>,
558 socket: PolledSocket<UnixStream>,
559) -> std::io::Result<()> {
560 let (channel_read, mut channel_write) = channel.split();
561 let (socket_read, mut socket_write) = socket.split();
562
563 let channel_to_socket = async {
564 futures::io::copy(channel_read, &mut socket_write).await?;
565 socket_write.close().await
566 };
567
568 let socket_to_channel = async {
569 futures::io::copy(socket_read, &mut channel_write).await?;
570 channel_write.close().await
571 };
572
573 match futures::future::try_join(channel_to_socket, socket_to_channel).await {
574 Ok(((), ())) => {}
575 Err(err) if err.kind() == ErrorKind::ConnectionReset => {}
576 Err(err) => return Err(err),
577 }
578 Ok(())
579}
580
#[cfg(test)]
mod tests {
    use super::relay_connected;
    use crate::ring::FlatRingMem;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::driver::Driver;
    use pal_async::socket::PolledSocket;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use unix_socket::UnixStream;
    use vmbus_async::pipe::BytePipe;
    use vmbus_async::pipe::connected_byte_pipes;

    /// Runs `relay_connected` between one end of an in-memory byte pipe pair and one
    /// end of a Unix socket pair.
    ///
    /// Returns the other end of the pipe (the channel side), the other end of the
    /// socket pair (the socket side), and the task running the relay.
    fn setup_relay<T: Driver + Spawn>(
        driver: &T,
    ) -> (
        BytePipe<FlatRingMem>,
        PolledSocket<UnixStream>,
        Task<std::io::Result<()>>,
    ) {
        let (hc, c) = connected_byte_pipes(4096);
        let (s, s2) = UnixStream::pair().unwrap();
        let s = PolledSocket::new(driver, s).unwrap();
        let s2 = PolledSocket::new(driver, s2).unwrap();
        let task = driver.spawn("test", async move { relay_connected(hc, s2).await });

        (c, s, task)
    }

    /// Data flows in both directions, and each direction keeps working after the
    /// other side has been closed for writing.
    #[async_test]
    async fn test_relay(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);

        let d = b"abcd";
        let mut v = [0; 4];

        // Channel -> socket.
        c.write_all(d).await.unwrap();
        s.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Socket -> channel.
        s.write_all(d).await.unwrap();
        c.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Data written just before the socket side closes still reaches the channel.
        s.write_all(d).await.unwrap();
        s.close().await.unwrap();
        c.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Channel -> socket still works after the socket side has closed.
        c.write_all(d).await.unwrap();
        s.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Once the channel side closes as well, the relay finishes without error.
        c.close().await.unwrap();
        task.await.unwrap();
    }

    /// Dropping the socket end is seen as EOF on the channel end, and the relay
    /// finishes without error.
    #[cfg(unix)] #[async_test]
    async fn test_relay_host_close(driver: DefaultDriver) {
        let (mut c, _, task) = setup_relay(&driver);

        let mut b = [0];
        assert_eq!(c.read(&mut b).await.unwrap(), 0);
        drop(c);
        task.await.unwrap();
    }

    /// Dropping the channel end is seen as EOF on the socket end, and the relay
    /// finishes without error.
    #[async_test]
    async fn test_relay_guest_close(driver: DefaultDriver) {
        let (_, mut s, task) = setup_relay(&driver);

        let mut b = [0];
        assert_eq!(s.read(&mut b).await.unwrap(), 0);
        drop(s);
        task.await.unwrap();
    }

    /// Closing the socket end for writing is forwarded as EOF to the channel end.
    #[async_test]
    async fn test_relay_forward_socket_shutdown(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);
        s.close().await.unwrap();
        let mut v = [0; 1];
        assert_eq!(c.read(&mut v).await.unwrap(), 0);
        drop(c);
        task.await.unwrap();
    }

    /// Closing the channel end for writing is forwarded as EOF to the socket end.
    #[async_test]
    async fn test_relay_forward_channel_shutdown(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);

        c.close().await.unwrap();
        let mut v = [0; 1];
        assert_eq!(s.read(&mut v).await.unwrap(), 0);
        drop(s);
        task.await.unwrap();
    }
}