1#![expect(unsafe_code)]
8
9mod connections;
10pub mod resolver;
11mod ring;
12mod spec;
13mod unix_relay;
14
15#[cfg(test)]
16mod integration_tests;
17
18use crate::connections::ConnectionInstanceId;
19use crate::connections::ConnectionKey;
20use crate::connections::TX_BUF_SIZE;
21use crate::spec::Operation;
22use crate::spec::VSOCK_HEADER_SIZE;
23use crate::spec::VsockFeaturesBank0;
24use crate::spec::VsockPacket;
25use crate::spec::VsockPacketBuf;
26use anyhow::Context;
27use connections::ConnectionManager;
28use futures::FutureExt;
29use futures::StreamExt;
30use futures::future::OptionFuture;
31use futures::future::poll_fn;
32use futures::stream::Fuse;
33use guestmem::GuestMemory;
34use guestmem::LockedRange;
35use guestmem::LockedRangeImpl;
36use guestmem::ranges::PagedRange;
37use inspect::InspectMut;
38use pal_async::socket::PolledSocket;
39use pal_async::wait::PolledWait;
40use smallvec::SmallVec;
41use spec::VsockConfig;
42use spec::VsockHeader;
43use std::io::IoSlice;
44use std::io::IoSliceMut;
45use std::path::PathBuf;
46use std::pin::Pin;
47use task_control::AsyncRun;
48use task_control::StopTask;
49use task_control::TaskControl;
50use unicycle::FuturesUnordered;
51use unix_socket::UnixListener;
52use virtio::DeviceTraits;
53use virtio::VirtioDevice;
54use virtio::VirtioQueue;
55use virtio::VirtioQueueCallbackWork;
56use virtio::queue::VirtioQueuePayload;
57use virtio::regions::data_regions;
58use virtio::regions::try_build_gpn_list;
59use virtio::spec::VirtioDeviceFeatures;
60use virtio::spec::VirtioDeviceType;
61use vmcore::vm_task::VmTaskDriver;
62use vmcore::vm_task::VmTaskDriverSource;
63use zerocopy::FromZeros;
64use zerocopy::IntoBytes;
65
/// Total number of virtqueues exposed by the device: rx, tx, and event.
const QUEUE_COUNT: usize = 3;
/// Index of the queue the device writes host-to-guest packets into.
const RX_QUEUE_INDEX: usize = 0;
/// Index of the queue on which the guest sends packets to the device.
const TX_QUEUE_INDEX: usize = 1;
/// Index of the event queue (held by the worker but not otherwise used in
/// this file).
const EVENT_QUEUE_INDEX: usize = 2;
70
/// A virtio-vsock device that relays vsock stream connections between the
/// guest and Unix sockets on the host.
#[derive(InspectMut)]
pub struct VirtioVsockDevice {
    /// The context ID reported to the guest through the device config space.
    guest_cid: u64,
    /// Task driver used for queue events and for running the worker.
    driver: VmTaskDriver,
    /// The packet-processing worker. It only holds state once all queues have
    /// been started.
    #[inspect(skip)]
    worker: TaskControl<VsockWorker, VsockWorkerState>,
    /// Queues that have been started but not yet handed to the worker, or that
    /// were reclaimed from the worker when it was stopped.
    #[inspect(skip)]
    started_queues: [Option<VirtioQueue>; QUEUE_COUNT],
    /// Base path handed to the connection manager.
    /// NOTE(review): presumably used to locate host Unix sockets for
    /// guest-initiated connections; confirm in `connections`.
    #[inspect(skip)]
    base_path: PathBuf,
}
83
84impl VirtioVsockDevice {
85 pub fn new(
95 driver_source: &VmTaskDriverSource,
96 guest_cid: u64,
97 base_path: PathBuf,
98 listener: UnixListener,
99 ) -> anyhow::Result<Self> {
100 let driver = driver_source.simple();
101 let listener = PolledSocket::new(&driver, listener)
102 .context("failed to create polled socket for vsock relay listener")?;
103 Ok(Self {
104 guest_cid,
105 driver: driver.clone(),
106 worker: TaskControl::new(VsockWorker { driver, listener }),
107 started_queues: [const { None }; QUEUE_COUNT],
108 base_path,
109 })
110 }
111}
112
113impl VirtioDevice for VirtioVsockDevice {
114 fn traits(&self) -> DeviceTraits {
115 let features_bank0 = VsockFeaturesBank0::new()
117 .with_stream(true)
118 .with_no_implied_stream(true);
119 DeviceTraits {
120 device_id: VirtioDeviceType::VSOCK,
121 device_features: VirtioDeviceFeatures::new()
122 .with_device_specific_low(features_bank0.into_bits()),
123 max_queues: QUEUE_COUNT.try_into().unwrap(),
124 device_register_length: size_of::<VsockConfig>() as u32,
125 ..Default::default()
126 }
127 }
128
129 async fn read_registers_u32(&mut self, offset: u16) -> u32 {
130 let config = VsockConfig {
132 guest_cid: self.guest_cid.to_le(),
133 };
134 let bytes = config.as_bytes();
135 let offset = offset as usize;
136 if offset + 4 <= bytes.len() {
137 u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
138 } else {
139 0
140 }
141 }
142
143 async fn write_registers_u32(&mut self, offset: u16, val: u32) {
144 tracelimit::warn_ratelimited!(offset, val, "vsock: unexpected config write");
145 }
146
147 async fn start_queue(
148 &mut self,
149 idx: u16,
150 resources: virtio::QueueResources,
151 features: &VirtioDeviceFeatures,
152 initial_state: Option<virtio::queue::QueueState>,
153 ) -> anyhow::Result<()> {
154 if self
155 .started_queues
156 .get(idx as usize)
157 .ok_or_else(|| anyhow::anyhow!("invalid queue index {idx}"))?
158 .is_some()
159 {
160 anyhow::bail!("virtio queue already started");
161 }
162
163 let negotiated_features = VsockFeaturesBank0::from_bits(features.bank(0));
170 if negotiated_features.no_implied_stream() && !negotiated_features.stream() {
171 anyhow::bail!("guest does not support stream sockets");
172 }
173
174 let queue_event = PolledWait::new(&self.driver, resources.event)
175 .context("failed to create queue event")?;
176
177 let queue = VirtioQueue::new(
178 *features,
179 resources.params,
180 resources.guest_memory.clone(),
181 resources.notify,
182 queue_event,
183 initial_state,
184 )
185 .context("failed to create virtio queue")?;
186
187 self.started_queues[idx as usize] = Some(queue);
188
189 if self.started_queues.iter().all(|q| q.is_some()) {
191 let state = VsockWorkerState {
192 rx_queue: self.started_queues[RX_QUEUE_INDEX].take().unwrap(),
193 tx_queue: self.started_queues[TX_QUEUE_INDEX].take().unwrap().fuse(),
194 _event_queue: self.started_queues[EVENT_QUEUE_INDEX].take().unwrap(),
195 memory: resources.guest_memory.clone(),
196 connections: ConnectionManager::new(self.guest_cid, self.base_path.clone()),
197 rx_ready: FuturesUnordered::new(),
198 write_ready: FuturesUnordered::new(),
199 };
200
201 self.worker
202 .insert(self.driver.clone(), "virtio-vsock-worker", state);
203 self.worker.start();
204 }
205
206 Ok(())
207 }
208
209 async fn stop_queue(&mut self, idx: u16) -> Option<virtio::queue::QueueState> {
210 if self.worker.stop().await {
212 let state = self.worker.remove();
213
214 self.started_queues[RX_QUEUE_INDEX] = Some(state.rx_queue);
217 self.started_queues[TX_QUEUE_INDEX] = Some(state.tx_queue.into_inner());
218 self.started_queues[EVENT_QUEUE_INDEX] = Some(state._event_queue);
219 }
220
221 self.started_queues[idx as usize]
223 .take()
224 .map(|queue| queue.queue_state())
225 }
226}
227
/// Work that may produce a packet for the guest's rx queue.
///
/// Values are handed to `ConnectionManager::get_rx_packet` once an rx
/// descriptor is available.
pub enum RxReady {
    /// The connection with this instance ID has rx work.
    /// NOTE(review): presumably host data or a control packet to deliver;
    /// confirm in `connections`.
    Connection(ConnectionInstanceId),
    /// Rx work for a pending connection identified by this `u64`.
    /// NOTE(review): presumably a host-initiated connection that the guest has
    /// not yet accepted; confirm in `connections`.
    PendingConnection(u64),
    /// A reset should be sent to the guest for the connection with this key.
    SendReset(ConnectionKey),
}
240
/// A boxed future that resolves to [`RxReady`] work once it can be processed.
type RxReadyItem = Pin<Box<dyn Future<Output = RxReady> + Send>>;

/// A boxed future that resolves to the ID of a connection to pass to
/// `ConnectionManager::handle_write_ready`.
/// NOTE(review): the exact "write ready" condition is defined in
/// `connections`; confirm there.
type WriteReadyItem = Pin<Box<dyn Future<Output = ConnectionInstanceId> + Send>>;
247
/// Follow-up futures produced while handling a packet or event.
///
/// The worker adds them to its ready sets with
/// `VsockWorkerState::queue_pending`.
struct PendingFutures {
    /// Future resolving to rx work for the guest, if any.
    rx_ready: Option<RxReadyItem>,
    /// Future resolving to a connection whose write-ready handling should run,
    /// if any.
    write_ready: Option<WriteReadyItem>,
}
253
254impl PendingFutures {
255 const NONE: Self = Self {
257 rx_ready: None,
258 write_ready: None,
259 };
260
261 fn rx(future: Option<RxReadyItem>) -> Self {
263 Self {
264 rx_ready: future,
265 write_ready: None,
266 }
267 }
268
269 fn simple_rx(work: RxReady) -> Self {
272 Self {
273 rx_ready: Some(Box::pin(async move { work })),
274 write_ready: None,
275 }
276 }
277
278 fn new(work: Option<WriteReadyItem>, rx_work: Option<RxReady>) -> Self {
280 Self {
281 rx_ready: rx_work.map(|w| -> RxReadyItem { Box::pin(async move { w }) }),
282 write_ready: work,
283 }
284 }
285}
286
/// State owned by the worker while all of the device's queues are running.
///
/// Moved back into the device's `started_queues` when the worker is stopped.
struct VsockWorkerState {
    /// Tracks vsock connections and produces packets and follow-up work for
    /// them.
    connections: ConnectionManager,
    /// Queue that host-to-guest packets are written into.
    rx_queue: VirtioQueue,
    /// Queue on which the guest sends packets. It is fused because
    /// `select_next_some` inside `futures::select!` requires a fused stream.
    tx_queue: Fuse<VirtioQueue>,
    /// The event queue. It is held so it can be returned when the worker
    /// stops, but is not otherwise used in this file.
    _event_queue: VirtioQueue,
    /// Guest memory used to read and write descriptor payloads.
    memory: GuestMemory,
    /// Futures that resolve when rx work is ready to become a guest packet.
    rx_ready: FuturesUnordered<RxReadyItem>,
    /// Futures that resolve to connections needing write-ready handling.
    write_ready: FuturesUnordered<WriteReadyItem>,
}
298
299impl VsockWorkerState {
300 fn queue_pending(&mut self, pending: PendingFutures) {
303 if let Some(work) = pending.rx_ready {
304 self.rx_ready.push(work);
305 }
306 if let Some(work) = pending.write_ready {
307 self.write_ready.push(work);
308 }
309 }
310}
311
/// The task that processes vsock traffic. Its per-run state lives in
/// [`VsockWorkerState`].
struct VsockWorker {
    /// Driver passed to the connection manager when handling packets and
    /// connections.
    driver: VmTaskDriver,
    /// Listener whose accepted host Unix socket connections are handed to the
    /// connection manager.
    listener: PolledSocket<UnixListener>,
}
317
318impl VsockWorker {
319 fn handle_guest_tx(&mut self, state: &mut VsockWorkerState, work: VirtioQueueCallbackWork) {
321 if let Err(err) = self.handle_guest_tx_inner(state, &work) {
322 tracelimit::error_ratelimited!(
323 error = err.as_ref() as &dyn std::error::Error,
324 "error handling vsock tx work"
325 );
326 }
327
328 state.tx_queue.get_mut().complete(work, 0);
329 }
330
331 fn handle_guest_tx_inner(
333 &mut self,
334 state: &mut VsockWorkerState,
335 work: &VirtioQueueCallbackWork,
336 ) -> anyhow::Result<()> {
337 let mut header = VsockHeader::new_zeroed();
338 work.read(&state.memory, header.as_mut_bytes())?;
339
340 let rw_len = if header.operation() == Operation::RW {
341 let len = header.len;
343
344 if len > TX_BUF_SIZE {
348 anyhow::bail!("guest attempted to send packet with data length {len}");
349 }
350
351 len
352 } else {
353 0
355 };
356
357 tracing::trace!(?header, "got tx packet from guest");
358 let pending = {
359 if rw_len == 0 {
360 state
362 .connections
363 .handle_guest_tx(&self.driver, VsockPacket::new(header, &[]))
364 } else if let Some(locked) = lock_payload_data(
365 &state.memory,
366 &work.payload,
367 rw_len as u64,
368 true,
369 false,
370 LockedIoSlice::new(),
371 )? {
372 state
374 .connections
375 .handle_guest_tx(&self.driver, VsockPacket::new(header, &locked.get().0))
376 } else {
377 let mut temp_buf = vec![0u8; rw_len as usize];
379 let read_bytes =
380 work.read_at_offset(VSOCK_HEADER_SIZE as u64, &state.memory, &mut temp_buf)?;
381 if read_bytes != temp_buf.len() {
382 anyhow::bail!(
383 "expected to read {} bytes of payload, but only read {}",
384 temp_buf.len(),
385 read_bytes
386 );
387 }
388 state.connections.handle_guest_tx(
389 &self.driver,
390 VsockPacket::new(header, &[IoSlice::new(&temp_buf)]),
391 )
392 }
393 };
394
395 state.queue_pending(pending);
396 Ok(())
397 }
398
399 fn write_packet(
402 state: &VsockWorkerState,
403 queue_work: &VirtioQueueCallbackWork,
404 packet: &VsockPacketBuf,
405 ) -> anyhow::Result<u32> {
406 tracing::trace!(?packet.header, "sending reply");
407 let header_bytes = packet.header.as_bytes();
408 queue_work
409 .write(&state.memory, header_bytes)
410 .context("failed to write vsock header to guest rx")?;
411
412 if !packet.data.is_empty() {
415 queue_work
416 .write_at_offset(header_bytes.len() as u64, &state.memory, &packet.data)
417 .context("failed to write vsock data to guest rx")?;
418 }
419
420 Ok(header_bytes.len() as u32 + packet.header.len)
421 }
422
423 fn handle_host_rx(&mut self, state: &mut VsockWorkerState, rx_ready: RxReady) {
425 let peeked_work = state
428 .rx_queue
429 .try_peek()
430 .expect("peek already succeeded before")
431 .expect("queue was already checked to have items");
432
433 let (packet, pending) = state.connections.get_rx_packet(
434 &state.memory,
435 &self.driver,
436 peeked_work.payload(),
437 rx_ready,
438 );
439
440 if let Some(packet) = packet {
442 let queue_work = peeked_work.consume();
443 let bytes = match Self::write_packet(state, &queue_work, &packet) {
444 Ok(bytes) => bytes,
445 Err(err) => {
446 tracelimit::error_ratelimited!(
447 error = err.as_ref() as &dyn std::error::Error,
448 "failed to write vsock packet"
449 );
450
451 state
454 .connections
455 .remove(&ConnectionKey::from_rx_packet(&packet.header));
456 0
457 }
458 };
459 state.rx_queue.complete(queue_work, bytes);
460 }
461
462 state.queue_pending(pending);
463 }
464}
465
impl AsyncRun<VsockWorkerState> for VsockWorker {
    /// Main worker loop. It waits for, and dispatches, four kinds of events:
    /// connection write readiness, guest tx descriptors, rx work (only while
    /// an rx descriptor is available), and new host Unix socket connections.
    ///
    /// Runs until stopped, in which case it returns `Err(Cancelled)`. If
    /// peeking the rx queue fails, the error is logged and the loop exits,
    /// returning `Ok(())`.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut VsockWorkerState,
    ) -> Result<(), task_control::Cancelled> {
        stop.until_stopped(async {
            loop {
                // Check whether the guest has made an rx descriptor available.
                let peeked = match state.rx_queue.try_peek() {
                    Ok(p) => p,
                    Err(err) => {
                        tracing::error!(
                            error = &err as &dyn std::error::Error,
                            "error peeking virtio rx queue"
                        );
                        return false;
                    }
                };

                let has_rx_work = peeked.is_some();
                // Only poll for ready rx work while there is a descriptor to
                // put the resulting packet in; `handle_host_rx` relies on this.
                // A `None` `OptionFuture` counts as terminated, so `select!`
                // skips it.
                let mut rx_ready =
                    OptionFuture::from(has_rx_work.then(|| state.rx_ready.select_next_some()));

                // With no rx descriptor available, wait for the guest to kick
                // the rx queue instead, so the next iteration re-peeks.
                let mut rx_queue_kick = OptionFuture::from(
                    (!has_rx_work).then(|| poll_fn(|cx| state.rx_queue.poll_kick(cx)).fuse()),
                );

                futures::select! {
                    // A connection is ready for its write-ready handling.
                    id = state.write_ready.select_next_some() => {
                        let pending = state.connections.handle_write_ready(id);
                        state.queue_pending(pending);
                    }
                    // The guest sent a packet on the tx queue.
                    r = state.tx_queue.select_next_some() => {
                        match r {
                            Ok(work) => self.handle_guest_tx(state, work),
                            Err(err) => tracing::error!(
                                error = &err as &dyn std::error::Error,
                                "error reading from virtio tx queue"
                            ),
                        }
                    }
                    // Rx work is ready and a descriptor is available. This
                    // branch is only enabled when `rx_ready` wraps a future,
                    // so the result is always `Some`.
                    r = rx_ready => {
                        let work = r.unwrap();
                        self.handle_host_rx(state, work);
                    }
                    // The guest kicked the rx queue; loop around to re-peek.
                    _ = rx_queue_kick => {
                    }
                    // A host process connected to the relay listener.
                    r = self.listener.accept().fuse() => {
                        match r {
                            Ok((stream, _)) => {
                                tracing::trace!("host unix socket accepted");
                                match state.connections.handle_host_connect(&self.driver, stream) {
                                    Err(err) => {
                                        tracing::error!(
                                            error = err.as_ref() as &dyn std::error::Error,
                                            "error handling Unix socket connect"
                                        );
                                    }
                                    Ok((read_work, timeout_work)) => {
                                        state.queue_pending(read_work);
                                        state.queue_pending(timeout_work);
                                    }
                                }
                            }
                            Err(err) => tracing::error!(
                                error = &err as &dyn std::error::Error,
                                "error accepting host connections"
                            ),
                        }
                    }
                };
            }
        })
        .await?;
        Ok(())
    }
}
548
/// A [`LockedRange`] that collects locked guest memory as read-only
/// [`IoSlice`]s, one per sub-range.
struct LockedIoSlice<'a>(SmallVec<[IoSlice<'a>; 4]>);

impl LockedIoSlice<'_> {
    /// Creates an empty collection.
    fn new() -> Self {
        Self(SmallVec::new())
    }
}

impl<'a> LockedRange<'a> for LockedIoSlice<'a> {
    /// Records `sub_range` as a byte slice.
    fn push_sub_range(&mut self, sub_range: &'a [std::sync::atomic::AtomicU8]) {
        // SAFETY: `AtomicU8` has the same size and alignment as `u8`, and the
        // new slice covers exactly the memory of `sub_range` for the same
        // lifetime `'a`.
        // NOTE(review): the memory may still be modified concurrently (hence
        // `AtomicU8`). Viewing it as `&[u8]` relies on the guest-memory locking
        // convention; confirm against the `guestmem` documentation.
        let slice =
            unsafe { std::slice::from_raw_parts(sub_range.as_ptr().cast::<u8>(), sub_range.len()) };
        self.0.push(IoSlice::new(slice));
    }
}
568
/// A [`LockedRange`] that collects locked guest memory as writable
/// [`IoSliceMut`]s, one per sub-range.
struct LockedIoSliceMut<'a>(SmallVec<[IoSliceMut<'a>; 4]>);

impl LockedIoSliceMut<'_> {
    /// Creates an empty collection.
    fn new() -> Self {
        Self(SmallVec::new())
    }
}

impl<'a> LockedRange<'a> for LockedIoSliceMut<'a> {
    /// Records `sub_range` as a mutable byte slice.
    fn push_sub_range(&mut self, sub_range: &'a [std::sync::atomic::AtomicU8]) {
        // SAFETY: `AtomicU8` has the same size and alignment as `u8`, and
        // writing through a pointer derived from `&AtomicU8` is allowed because
        // atomics have interior mutability. The slice covers exactly
        // `sub_range` for lifetime `'a`.
        // NOTE(review): this creates `&mut [u8]` over memory that is still
        // reachable through the shared `&[AtomicU8]` and may be touched
        // concurrently. Soundness relies on the guest-memory locking
        // convention; confirm against the `guestmem` documentation.
        let slice = unsafe {
            std::slice::from_raw_parts_mut(sub_range.as_ptr() as *mut u8, sub_range.len())
        };
        self.0.push(IoSliceMut::new(slice));
    }
}
589
590fn lock_payload_data<'a, T: LockedRange<'a>>(
596 mem: &'a GuestMemory,
597 payload: &[VirtioQueuePayload],
598 data_len: u64,
599 require_exact_len: bool,
600 writable: bool,
601 locked_range: T,
602) -> anyhow::Result<Option<LockedRangeImpl<'a, T>>> {
603 let regions = data_regions(payload, writable, VSOCK_HEADER_SIZE as u64, data_len);
604 let gpn_list = try_build_gpn_list(regions);
605 let locked = if let Some((gpns, offset, len)) = &gpn_list {
606 if require_exact_len && *len != data_len as usize {
607 anyhow::bail!("data length mismatch in vsock tx packet");
608 }
609 let paged_range =
610 PagedRange::new(*offset, *len, gpns).expect("offset and len should be valid");
611 Some(mem.lock_range(paged_range, locked_range)?)
612 } else {
613 tracing::trace!("payload data is not representable in a single PagedRange");
614 None
615 };
616
617 Ok(locked)
618}