1use super::Guid;
12use crate::HvsockRelayChannelHalf;
13use crate::ring::RingMem;
14use anyhow::Context;
15use futures::AsyncReadExt;
16use futures::AsyncWriteExt;
17use futures::StreamExt;
18use futures_concurrency::stream::Merge;
19use hybrid_vsock::HYBRID_CONNECT_REQUEST_LEN;
20use hybrid_vsock::VsockPortOrId;
21use mesh::CancelContext;
22use pal_async::driver::SpawnDriver;
23use pal_async::socket::PolledSocket;
24use pal_async::task::Spawn;
25use pal_async::task::Task;
26use std::io::ErrorKind;
27use std::path::Path;
28use std::path::PathBuf;
29use std::sync::Arc;
30use std::time::Duration;
31use unicycle::FuturesUnordered;
32use unix_socket::UnixListener;
33use unix_socket::UnixStream;
34use vmbus_async::pipe::BytePipe;
35use vmbus_channel::bus::ChannelType;
36use vmbus_channel::bus::OfferParams;
37use vmbus_channel::bus::ParentBus;
38use vmbus_channel::offer::Offer;
39use vmbus_core::HvsockConnectRequest;
40use vmbus_core::HvsockConnectResult;
41
/// Relays hvsocket connections between the guest (via channels offered on a
/// vmbus `ParentBus`) and the host (via Unix domain sockets).
///
/// Owns the spawned relay worker task and, if a hybrid vsock listener was
/// provided, the listener task; both are dropped with this value.
pub struct HvsockRelay {
    /// State shared with the worker tasks (bus and spawner).
    inner: Arc<RelayInner>,
    /// Hands spawned tasks to the relay worker, which holds and polls them.
    host_send: mesh::Sender<RelayRequest>,
    _relay_task: Task<()>,
    _listener_task: Option<Task<()>>,
}
48
/// Requests sent from host-side code to the relay worker.
enum RelayRequest {
    /// A task for the relay worker to keep and poll until it completes.
    AddTask(Task<()>),
}
52
/// State shared (via `Arc`) between the relay handle and its worker tasks.
struct RelayInner {
    /// Bus on which hvsocket channels are offered to the guest.
    vmbus: Arc<dyn ParentBus>,
    /// Driver used to spawn tasks and to register sockets for polling.
    driver: Box<dyn SpawnDriver>,
}
57
58impl HvsockRelay {
59 pub fn new(
62 driver: impl SpawnDriver,
63 vmbus: Arc<dyn ParentBus>,
64 guest: HvsockRelayChannelHalf,
65 hybrid_vsock_path: Option<PathBuf>,
66 hybrid_vsock_listener: Option<UnixListener>,
67 ) -> anyhow::Result<Self> {
68 let inner = Arc::new(RelayInner {
69 vmbus,
70 driver: Box::new(driver),
71 });
72
73 let worker = HvsockRelayWorker {
74 guest_send: guest.response_send,
75 inner: inner.clone(),
76 tasks: Default::default(),
77 hybrid_vsock_path,
78 };
79
80 let (host_send, host_recv) = mesh::channel();
81
82 let _listener_task = if let Some(listener) = hybrid_vsock_listener {
83 let listener = PolledSocket::new(inner.driver.as_ref(), listener)?;
84 Some(
85 inner.driver.spawn(
86 "hvsock-listener",
87 ListenerWorker {
88 inner: inner.clone(),
89 host_send: host_send.clone(),
90 }
91 .run(listener),
92 ),
93 )
94 } else {
95 None
96 };
97
98 let task = inner
99 .driver
100 .spawn("hvsock relay", worker.run(guest.request_receive, host_recv));
101
102 Ok(Self {
103 host_send,
104 inner,
105 _relay_task: task,
106 _listener_task,
107 })
108 }
109
110 pub fn connect(
115 &self,
116 ctx: &mut CancelContext,
117 service_id: Guid,
118 ) -> impl Future<Output = anyhow::Result<UnixStream>> + Send + use<> {
119 let inner = self.inner.clone();
120 let host_send = self.host_send.clone();
121 let (send, recv) = mesh::oneshot();
122
123 let (mut ctx, cancel) = ctx.with_cancel();
125
126 let task = self.inner.driver.spawn("hvsock-connect", async move {
128 let r = async {
129 let (stream, task) = ctx
130 .until_cancelled(inner.connect_to_guest(service_id))
131 .await??;
132 host_send.send(RelayRequest::AddTask(task));
133 Ok(stream)
134 }
135 .await;
136
137 send.send(r);
138 });
139 self.host_send.send(RelayRequest::AddTask(task));
140 async move {
141 let _cancel = cancel;
142 recv.await?
143 }
144 }
145}
146
/// Accepts host-initiated hybrid vsock connections and relays them to the
/// guest.
struct ListenerWorker {
    /// State shared with the rest of the relay (bus and spawner).
    inner: Arc<RelayInner>,
    /// Hands spawned relay tasks to the relay worker, which holds and polls them.
    host_send: mesh::Sender<RelayRequest>,
}
151
152impl ListenerWorker {
153 async fn run(self, mut listener: PolledSocket<UnixListener>) {
154 loop {
155 let connection = match listener.accept().await {
156 Ok((connection, _address)) => connection,
157 Err(err) => {
158 tracing::error!(
159 error = &err as &dyn std::error::Error,
160 "failed to accept hybrid vsock connection, shutting down listener"
161 );
162 break;
163 }
164 };
165 match self.spawn_relay(connection).await {
166 Ok(task) => {
167 self.host_send.send(RelayRequest::AddTask(task));
168 }
169 Err(err) => {
170 tracing::warn!(
171 error = err.as_ref() as &dyn std::error::Error,
172 "relayed connection failed"
173 );
174 }
175 }
176 }
177 }
178
179 async fn spawn_relay(&self, connection: UnixStream) -> anyhow::Result<Task<()>> {
180 let mut socket = PolledSocket::new(self.inner.driver.as_ref(), connection)?;
181 let request = read_hybrid_vsock_connect(&mut socket).await?;
182
183 let instance_id = Guid::new_random();
184 let mut offer = Offer::new(
185 self.inner.driver.as_ref(),
186 self.inner.vmbus.as_ref(),
187 OfferParams {
188 interface_name: "hvsocket_connect".into(),
189 interface_id: request.id(),
190 instance_id,
191 channel_type: ChannelType::HvSocket {
192 is_connect: true,
193 is_for_container: false,
194 silo_id: Guid::ZERO,
195 },
196 ..Default::default()
197 },
198 )
199 .await
200 .context("failed to offer channel")?;
201
202 let channel = CancelContext::new()
203 .with_timeout(Duration::from_secs(2))
204 .until_cancelled(offer.wait_for_open(self.inner.driver.as_ref()))
205 .await?
206 .context("failed to accept channel")?
207 .accept()
208 .channel;
209
210 let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
211
212 tracing::debug!(service_id = %request.id(), endpoint_id = %instance_id, "connected host to guest");
213
214 let task = self
215 .inner
216 .driver
217 .spawn("hvsock connection relay", async move {
218 let _offer = offer;
220
221 let response = match request {
224 VsockPortOrId::Port(_) => VsockPortOrId::Port(instance_id.data1),
225 VsockPortOrId::Id(_) => VsockPortOrId::Id(instance_id),
226 };
227 let s = response.get_ok_response();
228 if let Err(err) = socket.write_all(s.as_bytes()).await {
229 tracing::error!(
230 service_id = %request.id(),
231 error = &err as &dyn std::error::Error,
232 "failed to write OK response"
233 );
234 }
235
236 if let Err(err) = relay_connected(pipe, socket).await {
237 tracing::error!(
238 service_id = %request.id(),
239 error = &err as &dyn std::error::Error,
240 "connection relay failed"
241 );
242 } else {
243 tracing::debug!(service_id = %request.id(), "connection relay finished");
244 }
245 });
246
247 Ok(task)
248 }
249}
250
251async fn read_hybrid_vsock_connect(
252 socket: &mut PolledSocket<UnixStream>,
253) -> anyhow::Result<VsockPortOrId> {
254 let mut buf = [0; HYBRID_CONNECT_REQUEST_LEN];
255 let mut i = 0;
256 while i == 0 || buf[i - 1] != b'\n' {
257 if i == buf.len() {
258 anyhow::bail!("connect request did not fit");
259 }
260 let n = socket
261 .read(&mut buf[i..])
262 .await
263 .context("failed to read connect request")?;
264 if n == 0 {
265 anyhow::bail!("no connect request");
266 }
267 i += n;
268 }
269
270 let request = VsockPortOrId::parse_connect_request(&buf[..i - 1])?;
271 tracing::debug!(?request, "got hybrid connect request");
272 Ok(request)
273}
274
275struct PendingConnection {
276 send: mesh::Sender<HvsockConnectResult>,
277 request: HvsockConnectRequest,
278}
279
280impl PendingConnection {
281 fn done(self, success: bool) {
282 self.send
283 .send(HvsockConnectResult::from_request(&self.request, success));
284 std::mem::forget(self);
285 }
286}
287
288impl Drop for PendingConnection {
289 fn drop(&mut self) {
290 self.send
291 .send(HvsockConnectResult::from_request(&self.request, false));
292 }
293}
294
295struct HvsockRelayWorker {
296 guest_send: mesh::Sender<HvsockConnectResult>,
297 tasks: FuturesUnordered<Task<()>>,
298 inner: Arc<RelayInner>,
299 hybrid_vsock_path: Option<PathBuf>,
300}
301
302impl HvsockRelayWorker {
303 async fn run(
304 mut self,
305 guest_recv: mesh::Receiver<HvsockConnectRequest>,
306 host_recv: mesh::Receiver<RelayRequest>,
307 ) {
308 enum Event {
309 Guest(HvsockConnectRequest),
310 Host(RelayRequest),
311 TaskDone(()),
312 }
313
314 let mut recv = (guest_recv.map(Event::Guest), host_recv.map(Event::Host)).merge();
315
316 while let Some(event) = (&mut recv, (&mut self.tasks).map(Event::TaskDone))
317 .merge()
318 .next()
319 .await
320 {
321 match event {
322 Event::Guest(request) => {
323 self.handle_connect_from_guest(request);
324 }
325 Event::Host(request) => match request {
326 RelayRequest::AddTask(task) => {
327 self.tasks.push(task);
328 }
329 },
330 Event::TaskDone(()) => {}
331 }
332 }
333 }
334
335 fn handle_connect_from_guest(&mut self, request: HvsockConnectRequest) {
336 if request.silo_id != Guid::ZERO {
337 tracelimit::warn_ratelimited!(?request, "Non-zero silo ID is currently ignored.")
338 }
339
340 let pending = PendingConnection {
342 send: self.guest_send.clone(),
343 request,
344 };
345 let path = {
346 if let Some(hybrid_vsock_path) = &self.hybrid_vsock_path {
347 hybrid_vsock_path.to_owned()
348 } else {
349 tracing::debug!(request = ?&request, "ignoring hvsock connect request");
350 return;
351 }
352 };
353
354 let task = self.inner.driver.spawn(
355 format!(
356 "hvsock accept {}:{}",
357 request.service_id, request.endpoint_id
358 ),
359 {
360 let inner = self.inner.clone();
361 async move {
362 match inner
363 .relay_guest_connect_to_host(pending, path.as_ref())
364 .await
365 {
366 Ok(()) => {
367 tracing::debug!(request = ?&request, "relay done");
368 }
369 Err(err) => {
370 tracelimit::error_ratelimited!(
371 request = ?&request,
372 err = err.as_ref() as &dyn std::error::Error,
373 "relay error"
374 );
375 }
376 }
377 }
378 },
379 );
380 self.tasks.push(task);
381 }
382}
383
384impl RelayInner {
385 async fn relay_guest_connect_to_host(
386 &self,
387 pending: PendingConnection,
388 base_path: &Path,
389 ) -> anyhow::Result<()> {
390 let request = &pending.request;
391
392 let vsock_request = VsockPortOrId::Id(request.service_id);
397 let path = vsock_request.host_uds_path(base_path)?;
398
399 let mut offer = Offer::new(
400 self.driver.as_ref(),
401 self.vmbus.as_ref(),
402 OfferParams {
403 interface_name: "hvsocket".to_owned(),
404 instance_id: request.endpoint_id,
405 interface_id: request.service_id,
406 channel_type: ChannelType::HvSocket {
407 is_connect: false,
408 is_for_container: false,
409 silo_id: Guid::ZERO,
410 },
411 ..Default::default()
412 },
413 )
414 .await
415 .context("failed to offer channel")?;
416
417 tracing::debug!(?request, "offered hvsocket channel to guest");
418 let service_id = request.service_id;
419
420 pending.done(true);
423
424 let channel = CancelContext::new()
426 .with_timeout(Duration::from_secs(5))
427 .until_cancelled(offer.wait_for_open(self.driver.as_ref()))
428 .await
429 .context("guest did not open hvsocket channel")??;
430
431 tracing::debug!(%service_id, "guest opened hvsocket channel");
432
433 let socket = PolledSocket::connect_unix(self.driver.as_ref(), &path)
435 .await
436 .with_context(|| {
437 format!(
438 "failed to connect to registered listener {} for {}",
439 path.display(),
440 service_id
441 )
442 })?;
443
444 tracing::debug!(%service_id, path = %path.display(), "connected to host uds socket");
445
446 let channel = channel.accept().channel;
448
449 let channel = BytePipe::new(channel)?;
450 if let Err(err) = relay_connected(channel, socket).await {
451 tracelimit::error_ratelimited!(
452 %service_id,
453 error = &err as &dyn std::error::Error,
454 "guest to host connection relay failed"
455 );
456 } else {
457 tracing::debug!(%service_id, "guest to host connection relay finished");
458 }
459
460 drop(offer);
463 Ok(())
464 }
465
466 async fn connect_to_guest(&self, service_id: Guid) -> anyhow::Result<(UnixStream, Task<()>)> {
467 let instance_id = Guid::new_random();
468 let mut offer = Offer::new(
469 &self.driver,
470 self.vmbus.as_ref(),
471 OfferParams {
472 interface_name: "hvsocket_connect".into(),
473 interface_id: service_id,
474 instance_id,
475 channel_type: ChannelType::HvSocket {
476 is_connect: true,
477 is_for_container: false,
478 silo_id: Guid::ZERO,
479 },
480 ..Default::default()
481 },
482 )
483 .await
484 .context("failed to offer channel")?;
485
486 let channel = offer
487 .wait_for_open(self.driver.as_ref())
488 .await
489 .context("failed to accept channel")?
490 .accept()
491 .channel;
492 let pipe = BytePipe::new(channel).context("failed to create vmbus pipe")?;
493
494 tracing::debug!(%service_id, endpoint_id = %instance_id, "connected host to guest");
495
496 let (left, right) = UnixStream::pair().context("failed to create socket pair")?;
497 let right = PolledSocket::new(self.driver.as_ref(), right)
498 .context("failed to create polled socket")?;
499
500 let task = self.driver.spawn(
501 format!("hvsock {}:{}", service_id, instance_id),
502 async move {
503 let _offer = offer;
505 if let Err(err) = relay_connected(pipe, right).await {
506 tracing::error!(
507 %service_id,
508 error = &err as &dyn std::error::Error,
509 "connection relay failed"
510 );
511 }
512 },
513 );
514
515 Ok((left, task))
516 }
517}
518
519async fn relay_connected<T: RingMem + Unpin>(
520 channel: BytePipe<T>,
521 socket: PolledSocket<UnixStream>,
522) -> std::io::Result<()> {
523 let (channel_read, mut channel_write) = channel.split();
524 let (socket_read, mut socket_write) = socket.split();
525
526 let channel_to_socket = async {
527 futures::io::copy(channel_read, &mut socket_write).await?;
528 socket_write.close().await
529 };
530
531 let socket_to_channel = async {
532 futures::io::copy(socket_read, &mut channel_write).await?;
533 channel_write.close().await
534 };
535
536 match futures::future::try_join(channel_to_socket, socket_to_channel).await {
537 Ok(((), ())) => {}
538 Err(err) if err.kind() == ErrorKind::ConnectionReset => {}
539 Err(err) => return Err(err),
540 }
541 Ok(())
542}
543
#[cfg(test)]
mod tests {
    use super::relay_connected;
    use crate::ring::FlatRingMem;
    use futures::AsyncReadExt;
    use futures::AsyncWriteExt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::driver::Driver;
    use pal_async::socket::PolledSocket;
    use pal_async::task::Spawn;
    use pal_async::task::Task;
    use unix_socket::UnixStream;
    use vmbus_async::pipe::BytePipe;
    use vmbus_async::pipe::connected_byte_pipes;

    /// Spawns `relay_connected` between one end of a connected byte-pipe pair
    /// and one end of a Unix socket pair.
    ///
    /// Returns the free pipe end `c` (standing in for the guest), the free
    /// socket end `s` (standing in for the host), and the relay task.
    fn setup_relay<T: Driver + Spawn>(
        driver: &T,
    ) -> (
        BytePipe<FlatRingMem>,
        PolledSocket<UnixStream>,
        Task<std::io::Result<()>>,
    ) {
        // NOTE(review): 4096 is presumably the ring size in bytes; confirm.
        let (hc, c) = connected_byte_pipes(4096);
        let (s, s2) = UnixStream::pair().unwrap();
        let s = PolledSocket::new(driver, s).unwrap();
        let s2 = PolledSocket::new(driver, s2).unwrap();
        let task = driver.spawn("test", async move { relay_connected(hc, s2).await });

        (c, s, task)
    }

    /// Data flows both ways. A half-close from the socket side is forwarded
    /// while the other direction keeps working. The relay finishes cleanly
    /// once both sides have closed.
    #[async_test]
    async fn test_relay(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);

        let d = b"abcd";
        let mut v = [0; 4];

        // Pipe -> socket.
        c.write_all(d).await.unwrap();
        s.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Socket -> pipe.
        s.write_all(d).await.unwrap();
        c.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Data written before the socket's half-close is still delivered.
        s.write_all(d).await.unwrap();
        s.close().await.unwrap();
        c.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // The pipe -> socket direction still works after that half-close.
        c.write_all(d).await.unwrap();
        s.read_exact(&mut v).await.unwrap();
        assert_eq!(&v, d);

        // Closing the remaining direction lets the relay task complete Ok.
        c.close().await.unwrap();
        task.await.unwrap();
    }

    // NOTE(review): restricted to unix; the reason is not shown here.
    #[cfg(unix)]
    #[async_test]
    async fn test_relay_host_close(driver: DefaultDriver) {
        // The socket end is dropped immediately (`_`), so the pipe sees EOF.
        let (mut c, _, task) = setup_relay(&driver);

        let mut b = [0];
        assert_eq!(c.read(&mut b).await.unwrap(), 0);
        drop(c);
        task.await.unwrap();
    }

    #[async_test]
    async fn test_relay_guest_close(driver: DefaultDriver) {
        // The pipe end is dropped immediately (`_`), so the socket sees EOF.
        let (_, mut s, task) = setup_relay(&driver);

        let mut b = [0];
        assert_eq!(s.read(&mut b).await.unwrap(), 0);
        drop(s);
        task.await.unwrap();
    }

    /// A write-side close of the socket shows up as EOF on the pipe.
    #[async_test]
    async fn test_relay_forward_socket_shutdown(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);
        s.close().await.unwrap();
        let mut v = [0; 1];
        assert_eq!(c.read(&mut v).await.unwrap(), 0);
        drop(c);
        task.await.unwrap();
    }

    /// A write-side close of the pipe shows up as EOF on the socket.
    #[async_test]
    async fn test_relay_forward_channel_shutdown(driver: DefaultDriver) {
        let (mut c, mut s, task) = setup_relay(&driver);

        c.close().await.unwrap();
        let mut v = [0; 1];
        assert_eq!(s.read(&mut v).await.unwrap(), 0);
        drop(s);
        task.await.unwrap();
    }
}