Skip to main content

virtio_vsock/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Virtio vsock device implementation, per section 5.10 of the virtio specification.
5
6// UNSAFETY: Pointer casts between AtomicU8 and u8 to allow direct read/write into guest memory.
7#![expect(unsafe_code)]
8
9mod connections;
10pub mod resolver;
11mod ring;
12mod spec;
13mod unix_relay;
14
15#[cfg(test)]
16mod integration_tests;
17
18use crate::connections::ConnectionInstanceId;
19use crate::connections::ConnectionKey;
20use crate::connections::TX_BUF_SIZE;
21use crate::spec::Operation;
22use crate::spec::VSOCK_HEADER_SIZE;
23use crate::spec::VsockFeaturesBank0;
24use crate::spec::VsockPacket;
25use crate::spec::VsockPacketBuf;
26use anyhow::Context;
27use connections::ConnectionManager;
28use futures::FutureExt;
29use futures::StreamExt;
30use futures::future::OptionFuture;
31use futures::future::poll_fn;
32use futures::stream::Fuse;
33use guestmem::GuestMemory;
34use guestmem::LockedRange;
35use guestmem::LockedRangeImpl;
36use guestmem::ranges::PagedRange;
37use inspect::InspectMut;
38use pal_async::socket::PolledSocket;
39use pal_async::wait::PolledWait;
40use smallvec::SmallVec;
41use spec::VsockConfig;
42use spec::VsockHeader;
43use std::io::IoSlice;
44use std::io::IoSliceMut;
45use std::path::PathBuf;
46use std::pin::Pin;
47use task_control::AsyncRun;
48use task_control::StopTask;
49use task_control::TaskControl;
50use unicycle::FuturesUnordered;
51use unix_socket::UnixListener;
52use virtio::DeviceTraits;
53use virtio::VirtioDevice;
54use virtio::VirtioQueue;
55use virtio::VirtioQueueCallbackWork;
56use virtio::queue::VirtioQueuePayload;
57use virtio::regions::data_regions;
58use virtio::regions::try_build_gpn_list;
59use virtio::spec::VirtioDeviceFeatures;
60use virtio::spec::VirtioDeviceType;
61use vmcore::vm_task::VmTaskDriver;
62use vmcore::vm_task::VmTaskDriverSource;
63use zerocopy::FromZeros;
64use zerocopy::IntoBytes;
65
/// Index of the receive virtqueue, used to deliver packets from the host to the guest.
const RX_QUEUE_INDEX: usize = 0;
/// Index of the transmit virtqueue, on which the guest sends packets to the host.
const TX_QUEUE_INDEX: usize = 1;
/// Index of the event virtqueue (not used by this implementation).
const EVENT_QUEUE_INDEX: usize = 2;
/// Total number of virtqueues exposed by the device: rx, tx, and event.
const QUEUE_COUNT: usize = EVENT_QUEUE_INDEX + 1;
70
/// Virtio vsock device.
#[derive(InspectMut)]
pub struct VirtioVsockDevice {
    /// Context ID assigned to the guest. It is exposed through device config space and passed to
    /// each new connection manager.
    guest_cid: u64,
    /// Driver used to create queue event waiters and to run the worker task.
    driver: VmTaskDriver,
    /// Worker task servicing the queues. It is only populated once every queue has been started.
    #[inspect(skip)]
    worker: TaskControl<VsockWorker, VsockWorkerState>,
    /// Queues indexed by queue index. A queue lives here after it has been started but before all
    /// queues are handed to the worker, and again after `stop_queue` takes them back from it.
    #[inspect(skip)]
    started_queues: [Option<VirtioQueue>; QUEUE_COUNT],
    /// Path prefix for the Unix socket relay, cloned into each new connection manager.
    #[inspect(skip)]
    base_path: PathBuf,
}
83
84impl VirtioVsockDevice {
85    /// Create a new virtio-vsock device.
86    ///
87    /// `guest_cid` is the context ID assigned to the guest.
88    ///
89    /// `base_path` is the path prefix for Unix socket relay. For a vsock port P, the relay will
90    /// attempt to connect to `<base_path>_P`.
91    ///
92    /// `listener` is an pre-bound Unix listener for accepting host-initiated connections using the
93    /// hybrid vsock connect protocol.
94    pub fn new(
95        driver_source: &VmTaskDriverSource,
96        guest_cid: u64,
97        base_path: PathBuf,
98        listener: UnixListener,
99    ) -> anyhow::Result<Self> {
100        let driver = driver_source.simple();
101        let listener = PolledSocket::new(&driver, listener)
102            .context("failed to create polled socket for vsock relay listener")?;
103        Ok(Self {
104            guest_cid,
105            driver: driver.clone(),
106            worker: TaskControl::new(VsockWorker { driver, listener }),
107            started_queues: [const { None }; QUEUE_COUNT],
108            base_path,
109        })
110    }
111}
112
113impl VirtioDevice for VirtioVsockDevice {
114    fn traits(&self) -> DeviceTraits {
115        // Spec 5.10.3.2: The device SHOULD offer the VIRTIO_VSOCK_F_NO_IMPLIED_STREAM feature.
116        let features_bank0 = VsockFeaturesBank0::new()
117            .with_stream(true)
118            .with_no_implied_stream(true);
119        DeviceTraits {
120            device_id: VirtioDeviceType::VSOCK,
121            device_features: VirtioDeviceFeatures::new()
122                .with_device_specific_low(features_bank0.into_bits()),
123            max_queues: QUEUE_COUNT.try_into().unwrap(),
124            device_register_length: size_of::<VsockConfig>() as u32,
125            ..Default::default()
126        }
127    }
128
129    async fn read_registers_u32(&mut self, offset: u16) -> u32 {
130        // Device config: guest_cid is a 64-bit LE value.
131        let config = VsockConfig {
132            guest_cid: self.guest_cid.to_le(),
133        };
134        let bytes = config.as_bytes();
135        let offset = offset as usize;
136        if offset + 4 <= bytes.len() {
137            u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
138        } else {
139            0
140        }
141    }
142
143    async fn write_registers_u32(&mut self, offset: u16, val: u32) {
144        tracelimit::warn_ratelimited!(offset, val, "vsock: unexpected config write");
145    }
146
147    async fn start_queue(
148        &mut self,
149        idx: u16,
150        resources: virtio::QueueResources,
151        features: &VirtioDeviceFeatures,
152        initial_state: Option<virtio::queue::QueueState>,
153    ) -> anyhow::Result<()> {
154        if self
155            .started_queues
156            .get(idx as usize)
157            .ok_or_else(|| anyhow::anyhow!("invalid queue index {idx}"))?
158            .is_some()
159        {
160            anyhow::bail!("virtio queue already started");
161        }
162
163        // Spec 5.10.3.2: If no feature bit has been negotiated, the device SHOULD act as if
164        // VIRTIO_VSOCK_F_STREAM has been negotiated.
165        //
166        // If VIRTIO_VSOCK_F_SEQPACKET has been negotiated, but not
167        // VIRTIO_VSOCK_F_NO_IMPLIED_STREAM, the device MAY act as if VIRTIO_VSOCK_F_STREAM has also
168        // been negotiated.
169        let negotiated_features = VsockFeaturesBank0::from_bits(features.bank(0));
170        if negotiated_features.no_implied_stream() && !negotiated_features.stream() {
171            anyhow::bail!("guest does not support stream sockets");
172        }
173
174        let queue_event = PolledWait::new(&self.driver, resources.event)
175            .context("failed to create queue event")?;
176
177        let queue = VirtioQueue::new(
178            *features,
179            resources.params,
180            resources.guest_memory.clone(),
181            resources.notify,
182            queue_event,
183            initial_state,
184        )
185        .context("failed to create virtio queue")?;
186
187        self.started_queues[idx as usize] = Some(queue);
188
189        // Start the worker if all queues are started.
190        if self.started_queues.iter().all(|q| q.is_some()) {
191            let state = VsockWorkerState {
192                rx_queue: self.started_queues[RX_QUEUE_INDEX].take().unwrap(),
193                tx_queue: self.started_queues[TX_QUEUE_INDEX].take().unwrap().fuse(),
194                _event_queue: self.started_queues[EVENT_QUEUE_INDEX].take().unwrap(),
195                memory: resources.guest_memory.clone(),
196                connections: ConnectionManager::new(self.guest_cid, self.base_path.clone()),
197                rx_ready: FuturesUnordered::new(),
198                write_ready: FuturesUnordered::new(),
199            };
200
201            self.worker
202                .insert(self.driver.clone(), "virtio-vsock-worker", state);
203            self.worker.start();
204        }
205
206        Ok(())
207    }
208
209    async fn stop_queue(&mut self, idx: u16) -> Option<virtio::queue::QueueState> {
210        // Stop the worker task (cancels the run loop via until_stopped).
211        if self.worker.stop().await {
212            let state = self.worker.remove();
213
214            // Transfer the queues back, so we can return the state as each one is stopped
215            // individually.
216            self.started_queues[RX_QUEUE_INDEX] = Some(state.rx_queue);
217            self.started_queues[TX_QUEUE_INDEX] = Some(state.tx_queue.into_inner());
218            self.started_queues[EVENT_QUEUE_INDEX] = Some(state._event_queue);
219        }
220
221        // Remove the queue state (drops VirtioQueue).
222        self.started_queues[idx as usize]
223            .take()
224            .map(|queue| queue.queue_state())
225    }
226}
227
/// Indicates a connection that is ready to put data on the rx queue to send to the guest.
///
/// N.B. When a future returns an item, it's not always guaranteed that a packet is ready to be
///      sent. It could be a spurious wake from a poll, or a pending connection that's still reading
///      its connect request, etc.
pub enum RxReady {
    /// A connection has data or a control packet to send.
    Connection(ConnectionInstanceId),
    /// A pending connection has data.
    // NOTE(review): the u64 presumably identifies the pending connection within the connection
    // manager; confirm in `connections`.
    PendingConnection(u64),
    /// A RST packet should be sent for a connection that was removed or invalid.
    SendReset(ConnectionKey),
}
240
241/// A pinned future that resolves to an `RxReady` item.
242type RxReadyItem = Pin<Box<dyn Future<Output = RxReady> + Send>>;
243
244/// A pinned future that resolves to a `ConnectionInstanceId` for a connection that's ready to write
245/// buffered data to the unix socket.
246type WriteReadyItem = Pin<Box<dyn Future<Output = ConnectionInstanceId> + Send>>;
247
/// Represents futures returned from a function that the worker should wait on.
struct PendingFutures {
    /// Future resolving to an `RxReady` item. `VsockWorkerState::queue_pending` pushes it onto
    /// the worker's `rx_ready` set.
    rx_ready: Option<RxReadyItem>,
    /// Future resolving when a connection can write buffered data to its Unix socket.
    /// `VsockWorkerState::queue_pending` pushes it onto the worker's `write_ready` set.
    write_ready: Option<WriteReadyItem>,
}
253
254impl PendingFutures {
255    /// A value holding no pending futures.
256    const NONE: Self = Self {
257        rx_ready: None,
258        write_ready: None,
259    };
260
261    /// Create a new `PendingFutures` with the given RxReady future, and no WriteReady future.
262    fn rx(future: Option<RxReadyItem>) -> Self {
263        Self {
264            rx_ready: future,
265            write_ready: None,
266        }
267    }
268
269    /// Create a new `PendingFutures` with a future that is immediately ready with the given RxReady
270    /// item.
271    fn simple_rx(work: RxReady) -> Self {
272        Self {
273            rx_ready: Some(Box::pin(async move { work })),
274            write_ready: None,
275        }
276    }
277
278    /// Create a new `PendingFutures` with the given WriteReady future and RxReady futures.
279    fn new(work: Option<WriteReadyItem>, rx_work: Option<RxReady>) -> Self {
280        Self {
281            rx_ready: rx_work.map(|w| -> RxReadyItem { Box::pin(async move { w }) }),
282            write_ready: work,
283        }
284    }
285}
286
/// Transient worker state for all three queues.
///
/// NOTE: fields drop in declaration order, so keep that in mind before reordering them.
struct VsockWorkerState {
    /// Connection manager, created with the guest CID and relay base path when the worker starts.
    connections: ConnectionManager,
    /// Queue on which the device delivers packets to the guest.
    rx_queue: VirtioQueue,
    /// Queue on which the guest sends packets. It is fused so `select_next_some` can be used on
    /// it in the worker loop.
    tx_queue: Fuse<VirtioQueue>,
    // The event queue is not used by this implementation.
    _event_queue: VirtioQueue,
    /// Guest memory used to read tx packets and write rx packets.
    memory: GuestMemory,
    /// Futures resolving to an `RxReady` item, polled only while the rx queue has a buffer.
    rx_ready: FuturesUnordered<RxReadyItem>,
    /// Futures resolving when a connection is ready to write buffered data to its Unix socket.
    write_ready: FuturesUnordered<WriteReadyItem>,
}
298
299impl VsockWorkerState {
300    /// Queue pending futures returned from the connection manager to be processed by the worker run
301    /// loop.
302    fn queue_pending(&mut self, pending: PendingFutures) {
303        if let Some(work) = pending.rx_ready {
304            self.rx_ready.push(work);
305        }
306        if let Some(work) = pending.write_ready {
307            self.write_ready.push(work);
308        }
309    }
310}
311
/// The main worker for the virtio-vsock device.
struct VsockWorker {
    /// Driver passed to the connection manager when handling packets and host connects.
    driver: VmTaskDriver,
    /// Listener accepting host-initiated connections using the hybrid vsock connect protocol.
    listener: PolledSocket<UnixListener>,
}
317
impl VsockWorker {
    /// Handle a work item from the tx virtqueue (guest -> host).
    ///
    /// Errors from processing the packet are logged (rate-limited), not propagated. The descriptor
    /// is always completed, reporting 0 bytes written.
    fn handle_guest_tx(&mut self, state: &mut VsockWorkerState, work: VirtioQueueCallbackWork) {
        if let Err(err) = self.handle_guest_tx_inner(state, &work) {
            tracelimit::error_ratelimited!(
                error = err.as_ref() as &dyn std::error::Error,
                "error handling vsock tx work"
            );
        }

        state.tx_queue.get_mut().complete(work, 0);
    }

    /// Handle a work item from the TX virtqueue (guest -> host).
    ///
    /// Steps:
    /// 1. Read the vsock header.
    /// 2. Bound the payload length of RW packets by `TX_BUF_SIZE`.
    /// 3. Pass the packet to the connection manager. The payload is read directly from locked
    ///    guest memory when possible, and otherwise through a temporary bounce buffer.
    /// 4. Queue any futures the connection manager returns on the worker.
    fn handle_guest_tx_inner(
        &mut self,
        state: &mut VsockWorkerState,
        work: &VirtioQueueCallbackWork,
    ) -> anyhow::Result<()> {
        let mut header = VsockHeader::new_zeroed();
        work.read(&state.memory, header.as_mut_bytes())?;

        let rw_len = if header.operation() == Operation::RW {
            // Unaligned field read.
            let len = header.len;

            // The guest should never exceed our available credit, which cannot be larger than the
            // max buffer size. This check prevents a guest from consuming too much host memory if
            // we need to bounce the data through a temporary buffer.
            if len > TX_BUF_SIZE {
                anyhow::bail!("guest attempted to send packet with data length {len}");
            }

            len
        } else {
            // Ignore the length field for other packets (it should always be zero).
            0
        };

        tracing::trace!(?header, "got tx packet from guest");
        let pending = {
            if rw_len == 0 {
                // No payload, so handle the packet immediately.
                state
                    .connections
                    .handle_guest_tx(&self.driver, VsockPacket::new(header, &[]))
            } else if let Some(locked) = lock_payload_data(
                &state.memory,
                &work.payload,
                rw_len as u64,
                true,  // require_exact_len
                false, // writable
                LockedIoSlice::new(),
            )? {
                // We can read the payload directly from guest memory.
                state
                    .connections
                    .handle_guest_tx(&self.driver, VsockPacket::new(header, &locked.get().0))
            } else {
                // Use a temp bounce buffer if the payload couldn't be locked.
                let mut temp_buf = vec![0u8; rw_len as usize];
                let read_bytes =
                    work.read_at_offset(VSOCK_HEADER_SIZE as u64, &state.memory, &mut temp_buf)?;
                if read_bytes != temp_buf.len() {
                    anyhow::bail!(
                        "expected to read {} bytes of payload, but only read {}",
                        temp_buf.len(),
                        read_bytes
                    );
                }
                state.connections.handle_guest_tx(
                    &self.driver,
                    VsockPacket::new(header, &[IoSlice::new(&temp_buf)]),
                )
            }
        };

        state.queue_pending(pending);
        Ok(())
    }

    /// Helper to write a packet to the RX queue. Returns the number of bytes
    /// written. The caller is responsible for completing the work item.
    fn write_packet(
        state: &VsockWorkerState,
        queue_work: &VirtioQueueCallbackWork,
        packet: &VsockPacketBuf,
    ) -> anyhow::Result<u32> {
        tracing::trace!(?packet.header, "sending reply");
        let header_bytes = packet.header.as_bytes();
        queue_work
            .write(&state.memory, header_bytes)
            .context("failed to write vsock header to guest rx")?;

        // The data buffer is present only if this is an RW packet and the data could not be read
        // directly into the guest buffer.
        if !packet.data.is_empty() {
            queue_work
                .write_at_offset(header_bytes.len() as u64, &state.memory, &packet.data)
                .context("failed to write vsock data to guest rx")?;
        }

        // Count the header plus `header.len` payload bytes. The payload may have been written just
        // above, or already placed directly in guest memory.
        Ok(header_bytes.len() as u32 + packet.header.len)
    }

    /// Try to deliver pending rx packets to the guest via the rx virtqueue.
    ///
    /// Must only be called when the rx queue has an available buffer; the `expect`s below rely on
    /// the run loop having just peeked one successfully.
    fn handle_host_rx(&mut self, state: &mut VsockWorkerState, rx_ready: RxReady) {
        // Due to lifetime issues the PeekedWork cannot be passed into this function so get it
        // back here.
        let peeked_work = state
            .rx_queue
            .try_peek()
            .expect("peek already succeeded before")
            .expect("queue was already checked to have items");

        let (packet, pending) = state.connections.get_rx_packet(
            &state.memory,
            &self.driver,
            peeked_work.payload(),
            rx_ready,
        );

        // If there's a packet to send, write it to the guest.
        // NOTE(review): when no packet is produced, `peeked_work` is dropped without `consume()`,
        // presumably leaving the buffer available for a later packet.
        if let Some(packet) = packet {
            let queue_work = peeked_work.consume();
            let bytes = match Self::write_packet(state, &queue_work, &packet) {
                Ok(bytes) => bytes,
                Err(err) => {
                    tracelimit::error_ratelimited!(
                        error = err.as_ref() as &dyn std::error::Error,
                        "failed to write vsock packet"
                    );

                    // We can't recover from this. Remove the connection so any future attempts to use
                    // it will fail.
                    state
                        .connections
                        .remove(&ConnectionKey::from_rx_packet(&packet.header));
                    0
                }
            };
            state.rx_queue.complete(queue_work, bytes);
        }

        state.queue_pending(pending);
    }
}
465
impl AsyncRun<VsockWorkerState> for VsockWorker {
    /// The main worker loop.
    ///
    /// Each iteration waits for whichever of these happens first:
    /// * A connection is ready to write buffered data to its Unix socket.
    /// * The guest sends a packet on the tx queue.
    /// * A connection is ready to send to the guest. This is only awaited while the rx queue has
    ///   an available buffer.
    /// * The guest adds rx buffers. This is only awaited while the rx queue is empty.
    /// * A host-initiated Unix socket connection arrives.
    ///
    /// The loop ends when the task is stopped or peeking the rx queue fails.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut VsockWorkerState,
    ) -> Result<(), task_control::Cancelled> {
        stop.until_stopped(async {
            loop {
                let peeked = match state.rx_queue.try_peek() {
                    Ok(p) => p,
                    Err(err) => {
                        tracing::error!(
                            error = &err as &dyn std::error::Error,
                            "error peeking virtio rx queue"
                        );
                        // Leave the loop; `run` then returns Ok(()) and the worker task completes.
                        return false;
                    }
                };

                // Only poll rx-ready connections when the guest has supplied a buffer to write
                // into. Otherwise, wait for the guest to add buffers.
                let has_rx_work = peeked.is_some();
                let mut rx_ready =
                    OptionFuture::from(has_rx_work.then(|| state.rx_ready.select_next_some()));

                // This future unfortunately borrows state.rx_queue, which means peeked cannot be
                // used below.
                let mut rx_queue_kick = OptionFuture::from(
                    (!has_rx_work).then(|| poll_fn(|cx| state.rx_queue.poll_kick(cx)).fuse()),
                );

                // Wait for work to do from either host or guest.
                futures::select! {
                    id = state.write_ready.select_next_some() => {
                        let pending = state.connections.handle_write_ready(id);
                        state.queue_pending(pending);
                    }
                    r = state.tx_queue.select_next_some() => {
                        match r {
                            Ok(work) => self.handle_guest_tx(state, work),
                            Err(err) => tracing::error!(
                                error = &err as &dyn std::error::Error,
                                "error reading from virtio tx queue"
                            ),
                        }
                    }
                    r = rx_ready => {
                        // An empty OptionFuture is terminated and skipped by select!, so this arm
                        // only fires when it held a future, and the output is Some.
                        let work = r.unwrap();
                        self.handle_host_rx(state, work);
                    }
                    _ = rx_queue_kick => {
                        // New buffers are available in the rx queue; repeat the loop to peek again.
                    }
                    r = self.listener.accept().fuse() => {
                        match r {
                            Ok((stream, _)) => {
                                tracing::trace!("host unix socket accepted");
                                match state.connections.handle_host_connect(&self.driver, stream) {
                                    Err(err) => {
                                        tracing::error!(
                                            error = err.as_ref() as &dyn std::error::Error,
                                            "error handling Unix socket connect"
                                        );
                                    }
                                    Ok((read_work, timeout_work)) => {
                                        state.queue_pending(read_work);
                                        state.queue_pending(timeout_work);
                                    }
                                }
                            }
                            Err(err) => tracing::error!(
                                error = &err as &dyn std::error::Error,
                                "error accepting host connections"
                            ),
                        }
                    }
                };
            }
        })
        .await?;
        Ok(())
    }
}
548
549// Implementation of LockedRange that collects IoSlice items for use with socket vectored IO.
550// Uses SmallVec since this will nearly always have one item.
551struct LockedIoSlice<'a>(SmallVec<[IoSlice<'a>; 4]>);
552
553impl LockedIoSlice<'_> {
554    fn new() -> Self {
555        Self(SmallVec::new())
556    }
557}
558
559impl<'a> LockedRange<'a> for LockedIoSlice<'a> {
560    fn push_sub_range(&mut self, sub_range: &'a [std::sync::atomic::AtomicU8]) {
561        // SAFETY: Treating AtomicU8 as u8 for vectored IO. The lifetime annotations ensure the
562        // sub_range lives long enough for the IoSlice.
563        let slice =
564            unsafe { std::slice::from_raw_parts(sub_range.as_ptr().cast::<u8>(), sub_range.len()) };
565        self.0.push(IoSlice::new(slice));
566    }
567}
568
569// Same as LockedIoSlice but for mutable buffers.
570struct LockedIoSliceMut<'a>(SmallVec<[IoSliceMut<'a>; 4]>);
571
572impl LockedIoSliceMut<'_> {
573    fn new() -> Self {
574        Self(SmallVec::new())
575    }
576}
577
578impl<'a> LockedRange<'a> for LockedIoSliceMut<'a> {
579    fn push_sub_range(&mut self, sub_range: &'a [std::sync::atomic::AtomicU8]) {
580        // SAFETY: Treating AtomicU8 as mut u8 for vectored IO. The lifetime annotations ensure the
581        // sub_range lives long enough for the IoSliceMut. Treating the memory as mutable should be
582        // safe because AtomicU8 also provides interior mutability.
583        let slice = unsafe {
584            std::slice::from_raw_parts_mut(sub_range.as_ptr() as *mut u8, sub_range.len())
585        };
586        self.0.push(IoSliceMut::new(slice));
587    }
588}
589
590/// Attempts to lock the payload buffers for a virtio request.
591///
592/// Returns `Ok(Some(...))` if every region boundary falls on a page boundary (or regions are
593/// GPA-contiguous), so the whole chain can be expressed as one [`PagedRange`]. Returns `Ok(None)`
594/// if any interior boundary violates the constraint.
595fn lock_payload_data<'a, T: LockedRange<'a>>(
596    mem: &'a GuestMemory,
597    payload: &[VirtioQueuePayload],
598    data_len: u64,
599    require_exact_len: bool,
600    writable: bool,
601    locked_range: T,
602) -> anyhow::Result<Option<LockedRangeImpl<'a, T>>> {
603    let regions = data_regions(payload, writable, VSOCK_HEADER_SIZE as u64, data_len);
604    let gpn_list = try_build_gpn_list(regions);
605    let locked = if let Some((gpns, offset, len)) = &gpn_list {
606        if require_exact_len && *len != data_len as usize {
607            anyhow::bail!("data length mismatch in vsock tx packet");
608        }
609        let paged_range =
610            PagedRange::new(*offset, *len, gpns).expect("offset and len should be valid");
611        Some(mem.lock_range(paged_range, locked_range)?)
612    } else {
613        tracing::trace!("payload data is not representable in a single PagedRange");
614        None
615    };
616
617    Ok(locked)
618}