storvsp/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! VMBus SCSI controller emulator (StorVSP).
5//!
6//! StorVSP implements the Hyper-V synthetic SCSI protocol — a VMBus-based
7//! transport that carries SCSI CDBs between the guest's `storvsc` driver and
8//! the VMM. This is not a standard SCSI transport (like iSCSI or SAS); it's a
9//! Hyper-V-specific wire format defined in [`storvsp_protocol`].
10//!
11//! # Architecture
12//!
13//! The crate uses a multi-worker model. The primary VMBus channel handles
14//! protocol version negotiation (Win6 through Blue); sub-channels process I/O
15//! in parallel. Each worker owns a VMBus ring and processes packets
16//! concurrently via `FuturesUnordered`.
17//!
18//! StorVSP handles the transport (ring buffer management, GPADL setup, packet
19//! framing, sub-channel lifecycle) and a few SCSI control commands directly
20//! (`REPORT_LUNS`, `INQUIRY` for absent targets). All actual I/O is delegated
21//! to [`AsyncScsiDisk`] implementations — StorVSP
22//! never interprets SCSI data CDBs itself.
23//!
24//! For the channel/sub-channel model, CPU affinity, and performance
25//! characteristics, see the
26//! [StorVSP Channels & Subchannels](https://openvmm.dev/reference/devices/vmbus/storvsp_channels.html)
27//! page in the OpenVMM Guide.
28//!
29//! # Key types
30//!
31//! - [`StorageDevice`] — the VMBus device. Implements `VmbusDevice` and
32//!   `SaveRestoreVmbusDevice`.
33//! - [`ScsiController`] — manages attached disks by [`ScsiPath`]. Supports
34//!   runtime attach/remove.
35//! - [`ScsiControllerDisk`] — wraps `Arc<dyn AsyncScsiDisk>`.
36//!
37//! # Performance
38//!
39//! Poll-mode optimization: when pending I/O count exceeds
40//! `poll_mode_queue_depth`, the worker switches from interrupt-driven to
41//! busy-poll for new requests, reducing guest exit frequency. Future storage
42//! for SCSI request processing is pooled to avoid allocation on the hot path.
43
44#![expect(missing_docs)]
45#![forbid(unsafe_code)]
46
47#[cfg(feature = "ioperf")]
48pub mod ioperf;
49
50#[cfg(feature = "test")]
51pub mod test_helpers;
52
53#[cfg(not(feature = "test"))]
54mod test_helpers;
55
56pub mod resolver;
57mod save_restore;
58
59use crate::ring::gparange::MultiPagedRangeBuf;
60use anyhow::Context as _;
61use async_trait::async_trait;
62use fast_select::FastSelect;
63use futures::FutureExt;
64use futures::StreamExt;
65use futures::select_biased;
66use guestmem::AccessError;
67use guestmem::GuestMemory;
68use guestmem::MemoryRead;
69use guestmem::MemoryWrite;
70use guestmem::ranges::PagedRange;
71use guid::Guid;
72use inspect::Inspect;
73use inspect::InspectMut;
74use inspect_counters::Counter;
75use inspect_counters::Histogram;
76use oversized_box::OversizedBox;
77use parking_lot::Mutex;
78use parking_lot::RwLock;
79use ring::OutgoingPacketType;
80use scsi::AdditionalSenseCode;
81use scsi::ScsiOp;
82use scsi::ScsiStatus;
83use scsi::srb::SrbStatus;
84use scsi::srb::SrbStatusAndFlags;
85use scsi_buffers::RequestBuffers;
86use scsi_core::AsyncScsiDisk;
87use scsi_core::Request;
88use scsi_core::ScsiResult;
89use scsi_defs as scsi;
90use scsidisk::illegal_request_sense;
91use slab::Slab;
92use std::collections::hash_map::Entry;
93use std::collections::hash_map::HashMap;
94use std::fmt::Debug;
95use std::future::Future;
96use std::future::poll_fn;
97use std::pin::Pin;
98use std::sync::Arc;
99use std::sync::atomic::AtomicU32;
100use std::sync::atomic::Ordering::Relaxed;
101use std::task::Context;
102use std::task::Poll;
103use storvsp_resources::ScsiPath;
104use task_control::AsyncRun;
105use task_control::InspectTask;
106use task_control::StopTask;
107use task_control::TaskControl;
108use thiserror::Error;
109use tracing_helpers::ErrorValueExt;
110use unicycle::FuturesUnordered;
111use vmbus_async::queue;
112use vmbus_async::queue::ExternalDataError;
113use vmbus_async::queue::IncomingPacket;
114use vmbus_async::queue::OutgoingPacket;
115use vmbus_async::queue::Queue;
116use vmbus_channel::RawAsyncChannel;
117use vmbus_channel::bus::ChannelType;
118use vmbus_channel::bus::OfferParams;
119use vmbus_channel::bus::OpenRequest;
120use vmbus_channel::channel::ChannelControl;
121use vmbus_channel::channel::ChannelOpenError;
122use vmbus_channel::channel::DeviceResources;
123use vmbus_channel::channel::RestoreControl;
124use vmbus_channel::channel::SaveRestoreVmbusDevice;
125use vmbus_channel::channel::VmbusDevice;
126use vmbus_channel::gpadl_ring::GpadlRingMem;
127use vmbus_channel::gpadl_ring::gpadl_channel;
128use vmbus_core::protocol::UserDefinedData;
129use vmbus_ring as ring;
130use vmbus_ring::RingMem;
131use vmcore::save_restore::RestoreError;
132use vmcore::save_restore::SaveError;
133use vmcore::save_restore::SavedStateBlob;
134use vmcore::vm_task::VmTaskDriver;
135use vmcore::vm_task::VmTaskDriverSource;
136use zerocopy::FromBytes;
137use zerocopy::FromZeros;
138use zerocopy::Immutable;
139use zerocopy::IntoBytes;
140use zerocopy::KnownLayout;
141
/// The IO queue depth at which the controller switches from guest-signal-driven
/// to poll-mode operation. This optimization reduces the guest exit rate by
/// relying on (typically-interrupt-driven) IO completions to drive polling for
/// new IO requests.
///
/// A depth of 1 means any single outstanding IO is enough to enter poll mode.
const DEFAULT_POLL_MODE_QUEUE_DEPTH: u32 = 1;
147
/// The StorVSP VMBus device: owns the per-channel workers and the shared
/// SCSI controller state.
pub struct StorageDevice {
    /// The VMBus channel instance ID.
    instance_id: Guid,
    /// Optional fixed SCSI path — NOTE(review): presumably set for the
    /// IDE-accelerator variant of the device; confirm against the
    /// constructors, which are outside this chunk.
    ide_path: Option<ScsiPath>,
    /// One worker per channel; index 0 is the primary channel.
    workers: Vec<WorkerAndDriver>,
    /// Disk table and rescan-notification state shared with the workers.
    controller: Arc<ScsiControllerState>,
    /// Resources provided when the VMBus channel was offered/opened.
    resources: DeviceResources,
    /// Source of task drivers used to run per-channel workers.
    driver_source: VmTaskDriverSource,
    /// Maximum number of subchannels the guest may create.
    max_sub_channel_count: u16,
    /// Protocol negotiation state shared across channels.
    protocol: Arc<Protocol>,
    /// Requested IO queue depth for the workers.
    io_queue_depth: u32,
}
159
/// A per-channel worker task paired with the VM task driver it runs on.
#[derive(Inspect)]
struct WorkerAndDriver {
    #[inspect(flatten)]
    worker: TaskControl<WorkerState, Worker>,
    driver: VmTaskDriver,
}
166
167struct WorkerState;
168
impl InspectMut for StorageDevice {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        let mut resp = req.respond();

        // Expose each attached disk as a `disks/<path>` child node.
        let disks = self.controller.disks.read();
        for (path, controller_disk) in disks.iter() {
            resp.child(&format!("disks/{}", path), |req| {
                controller_disk.disk.inspect(req);
            });
        }

        // List only workers that currently have state (opened channels).
        // NOTE(review): `enumerate` runs after `filter`, so the keys are
        // positions within the filtered list rather than the underlying
        // channel indices — confirm this is intended.
        resp.fields(
            "channels",
            self.workers
                .iter()
                .filter(|task| task.worker.has_state())
                .enumerate(),
        )
        .field(
            // Mutable through inspect: tunes when workers enter poll mode.
            "poll_mode_queue_depth",
            inspect::AtomicMut(&self.controller.poll_mode_queue_depth),
        );
    }
}
193
/// A per-channel packet-processing worker.
struct Worker<T: RingMem = GpadlRingMem> {
    /// State shared with the packet/IO processing helpers.
    inner: WorkerInner,
    /// Receives a message whenever the controller's disk set changes (wired
    /// up via `add_rescan_notification_source` in [`Worker::new`]).
    rescan_notification: futures::channel::mpsc::Receiver<()>,
    /// Future-selection helper (see [`FastSelect`]).
    fast_select: FastSelect,
    /// The VMBus ring buffer queue for this channel.
    queue: Queue<T>,
}
200
/// Protocol negotiation state shared between the primary channel and the
/// subchannel workers.
struct Protocol {
    state: RwLock<ProtocolState>,
    /// Signaled when `state` transitions to `ProtocolState::Ready`.
    ready: event_listener::Event,
}
206
/// Channel-agnostic worker state: protocol info, allocation pools, and
/// in-flight SCSI request tracking.
struct WorkerInner {
    /// Negotiation state shared across all channels of the device.
    protocol: Arc<Protocol>,
    /// Payload size used when padding outgoing packets; starts at the V1
    /// size before negotiation completes.
    request_size: usize,
    /// Shared controller (disk table) state.
    controller: Arc<ScsiControllerState>,
    /// This worker's channel index; 0 is the primary channel.
    channel_index: u16,
    /// SCSI execution context shared with the request futures.
    scsi_queue: Arc<ScsiCommandQueue>,
    /// In-flight SCSI operations, polled concurrently.
    scsi_requests: FuturesUnordered<ScsiRequest>,
    /// Per-request state, keyed by the request IDs in `scsi_requests`.
    scsi_requests_states: Slab<ScsiRequestState>,
    /// Pool of reusable request/range allocations for EXECUTE_SRB packets.
    full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
    /// Pool of reusable future storage boxes for SCSI operations.
    future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
    /// Control interface for the VMBus channel.
    channel_control: ChannelControl,
    /// Maximum number of concurrently outstanding SCSI requests (>= 1).
    max_io_queue_depth: usize,
    /// IO statistics, exposed via inspect.
    stats: WorkerStats,
}
221
/// Per-worker IO statistics, exposed through `Inspect`.
#[derive(Debug, Default, Inspect)]
struct WorkerStats {
    /// SCSI requests submitted.
    ios_submitted: Counter,
    /// SCSI requests completed.
    ios_completed: Counter,
    /// Times the worker was woken.
    wakes: Counter,
    /// Wakes that found no work to do.
    wakes_spurious: Counter,
    /// Distribution of submissions per wake.
    per_wake_submissions: Histogram<10>,
    /// Distribution of completions per wake.
    per_wake_completions: Histogram<10>,
}
231
/// Supported storvsp protocol versions. The discriminants are the on-the-wire
/// major/minor encodings from [`storvsp_protocol`], so ordering follows
/// protocol age.
#[repr(u16)]
#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
enum Version {
    Win6 = storvsp_protocol::VERSION_WIN6,
    Win7 = storvsp_protocol::VERSION_WIN7,
    Win8 = storvsp_protocol::VERSION_WIN8,
    Blue = storvsp_protocol::VERSION_BLUE,
}
240
/// Error returned when the guest requests a protocol version this
/// implementation does not support.
#[derive(Debug, Error)]
#[error("protocol version {0:#x} not supported")]
struct UnsupportedVersion(u16);
244
245impl Version {
246    fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
247        let version = match major_minor {
248            storvsp_protocol::VERSION_WIN6 => Self::Win6,
249            storvsp_protocol::VERSION_WIN7 => Self::Win7,
250            storvsp_protocol::VERSION_WIN8 => Self::Win8,
251            storvsp_protocol::VERSION_BLUE => Self::Blue,
252            version => return Err(UnsupportedVersion(version)),
253        };
254        assert_eq!(version as u16, major_minor);
255        Ok(version)
256    }
257
258    fn max_request_size(&self) -> usize {
259        match self {
260            Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
261            Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
262        }
263    }
264}
265
/// Top-level protocol negotiation state.
#[derive(Copy, Clone)]
enum ProtocolState {
    /// Negotiation in progress; see [`InitState`] for the current step.
    Init(InitState),
    /// Negotiation complete; SCSI IO may flow.
    Ready {
        version: Version,
        subchannel_count: u16,
    },
}
274
/// The steps of protocol initialization, driven by packets from the guest.
#[derive(Copy, Clone, Debug)]
enum InitState {
    /// Waiting for BEGIN_INITIALIZATION.
    Begin,
    /// Waiting for QUERY_PROTOCOL_VERSION.
    QueryVersion,
    /// Version negotiated; waiting for QUERY_PROPERTIES.
    QueryProperties {
        version: Version,
    },
    /// Waiting for END_INITIALIZATION. `subchannel_count` holds the requested
    /// count when already known — NOTE(review): the handler that sets it is
    /// outside this chunk; confirm when CREATE_SUB_CHANNELS is accepted.
    EndInitialization {
        version: Version,
        subchannel_count: Option<u16>,
    },
}
287
/// The internal SCSI operation future type.
///
/// This is a boxed future of a large pre-determined size. The box is reused
/// after a SCSI request completes to avoid allocations in the hot path.
///
/// An Option type is used so that the future can be efficiently dropped (via
/// `Pin::set(x, None)`) before it is stashed away for reuse.
///
/// The storage is expressed in `u64`s so the backing buffer is 8-byte
/// aligned.
type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];
/// A pinned, type-erased SCSI operation future stored in [`ScsiOpStorage`].
type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;
297
/// The amount of space reserved for a ScsiOpFuture.
///
/// This was chosen by running `cargo test -p storvsp -- --no-capture` and looking at the required
/// size that was given in the failure message.
///
/// NOTE(review): the extra 272 bytes cover the wrapper state added around the
/// disk's own future; re-derive this if the wrapping code changes.
const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;
303
/// An in-flight SCSI operation, pairing the pooled future with the slab key
/// of its per-request state.
struct ScsiRequest {
    request_id: usize,
    // `Some` while the operation is running; taken by `poll` on completion
    // so the storage can be recycled.
    future: Option<ScsiOpFuture>,
}
308
309impl ScsiRequest {
310    fn new(
311        request_id: usize,
312        future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
313    ) -> Self {
314        Self {
315            request_id,
316            future: Some(future.into()),
317        }
318    }
319}
320
321impl Future for ScsiRequest {
322    type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);
323
324    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
325        let this = self.get_mut();
326        let future = this.future.as_mut().unwrap().as_mut();
327        let result = std::task::ready!(future.poll(cx));
328        // Return the future so that its storage can be reused.
329        let future = this.future.take().unwrap();
330        Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
331    }
332}
333
/// Fatal errors that terminate a worker's packet-processing loop.
#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    PacketError(#[source] PacketError),
    #[error("queue error")]
    Queue(#[source] queue::Error),
    /// Ring space was expected to be available (reserved earlier) but the
    /// write still failed for lack of space.
    #[error("queue should have enough space but no longer does")]
    NotEnoughSpace,
}
343
/// Errors encountered while parsing an incoming storvsp packet.
#[derive(Debug, Error)]
enum PacketError {
    #[error("Not transactional")]
    NotTransactional,
    #[error("Unrecognized operation {0:?}")]
    UnrecognizedOperation(storvsp_protocol::Operation),
    #[error("Invalid packet type")]
    InvalidPacketType,
    #[error("Invalid data transfer length")]
    InvalidDataTransferLength,
    #[error("Access error: {0}")]
    Access(#[source] AccessError),
    #[error("Range error")]
    Range(#[source] ExternalDataError),
}
359
/// A validated view of a request's external data buffer.
#[derive(Debug, Default, Clone)]
struct Range {
    /// The transfer length declared in the SCSI request.
    len: usize,
    /// Mirrors `request.data_in != 0` — NOTE(review): presumably whether the
    /// guest buffer will be written by the device; confirm against
    /// `RequestBuffers::new` semantics.
    is_write: bool,
}
365
366impl Range {
367    fn new(buf: &MultiPagedRangeBuf, request: &storvsp_protocol::ScsiRequest) -> Option<Self> {
368        let len = request.data_transfer_length as usize;
369        let is_write = request.data_in != 0;
370        // Ensure there is exactly one range and it's large enough, or there are
371        // zero ranges and there is no associated SCSI buffer.
372        if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
373            return None;
374        }
375        Some(Self { len, is_write })
376    }
377
378    fn buffer<'a>(
379        &'a self,
380        buf: &'a MultiPagedRangeBuf,
381        guest_memory: &'a GuestMemory,
382    ) -> RequestBuffers<'a> {
383        let mut range = buf.first().unwrap_or_else(PagedRange::empty);
384        range.truncate(self.len);
385        RequestBuffers::new(guest_memory, range, self.is_write)
386    }
387}
388
/// A parsed incoming storvsp packet.
#[derive(Debug)]
struct Packet {
    /// The decoded operation payload.
    data: PacketData,
    /// The VMBus transaction ID, echoed in the completion.
    transaction_id: u64,
    /// Number of payload bytes in the incoming packet (capped at
    /// `SCSI_REQUEST_LEN_MAX`); completions are padded to this size.
    request_size: usize,
}
395
/// The decoded body of an incoming packet, one variant per supported
/// [`storvsp_protocol::Operation`].
#[derive(Debug)]
enum PacketData {
    BeginInitialization,
    EndInitialization,
    /// The guest-proposed major/minor protocol version.
    QueryProtocolVersion(u16),
    QueryProperties,
    /// The number of subchannels the guest wants created.
    CreateSubChannels(u16),
    /// A SCSI request, along with its validated external data ranges.
    ExecuteScsi(Arc<ScsiRequestAndRange>),
    ResetBus,
    ResetAdapter,
    ResetLun,
}
408
/// Error indicating invalid data ranges for a request — NOTE(review): the
/// code that returns this is outside this chunk; confirm its exact use.
#[derive(Debug)]
pub struct RangeError;
411
412fn parse_packet<T: RingMem>(
413    packet: &IncomingPacket<'_, T>,
414    pool: &mut Vec<Arc<ScsiRequestAndRange>>,
415) -> Result<Packet, PacketError> {
416    let packet = match packet {
417        IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
418        IncomingPacket::Data(packet) => packet,
419    };
420    let transaction_id = packet
421        .transaction_id()
422        .ok_or(PacketError::NotTransactional)?;
423
424    let mut reader = packet.reader();
425    let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
426    // You would expect that this should be limited to the current protocol
427    // version's maximum packet size, but this is not what Hyper-V does, and
428    // Linux 6.1 relies on this behavior during protocol initialization.
429    let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
430    let data = match header.operation {
431        storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
432        storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
433        storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
434            let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
435            reader
436                .read(version.as_mut_bytes())
437                .map_err(PacketError::Access)?;
438            PacketData::QueryProtocolVersion(version.major_minor)
439        }
440        storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
441        storvsp_protocol::Operation::EXECUTE_SRB => {
442            let mut full_request = pool.pop().unwrap_or_else(|| {
443                Arc::new(ScsiRequestAndRange {
444                    external_data: Range::default(),
445                    external_data_buf: MultiPagedRangeBuf::new(),
446                    request: storvsp_protocol::ScsiRequest::new_zeroed(),
447                    request_size,
448                })
449            });
450
451            {
452                let full_request = Arc::get_mut(&mut full_request).unwrap();
453                let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
454                reader.read(request_buf).map_err(PacketError::Access)?;
455
456                full_request.external_data_buf.clear();
457                packet
458                    .read_external_ranges(&mut full_request.external_data_buf)
459                    .map_err(PacketError::Range)?;
460
461                full_request.external_data =
462                    Range::new(&full_request.external_data_buf, &full_request.request)
463                        .ok_or(PacketError::InvalidDataTransferLength)?;
464            }
465
466            PacketData::ExecuteScsi(full_request)
467        }
468        storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
469        storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
470        storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
471        storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
472            let mut sub_channel_count: u16 = 0;
473            reader
474                .read(sub_channel_count.as_mut_bytes())
475                .map_err(PacketError::Access)?;
476            PacketData::CreateSubChannels(sub_channel_count)
477        }
478        _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
479    };
480
481    if let PacketData::ExecuteScsi(_) = data {
482        tracing::trace!(transaction_id, ?data, "parse_packet");
483    } else {
484        tracing::debug!(transaction_id, ?data, "parse_packet");
485    }
486
487    Ok(Packet {
488        data,
489        request_size,
490        transaction_id,
491    })
492}
493
impl WorkerInner {
    /// Writes a single storvsp packet (header plus payload) to the outgoing
    /// ring, padding or truncating the payload so that the packet body is
    /// exactly `size_of::<storvsp_protocol::Packet>() + request_size` bytes.
    fn send_vmbus_packet<M: RingMem>(
        &mut self,
        writer: &mut queue::WriteBatch<'_, M>,
        packet_type: OutgoingPacketType<'_>,
        request_size: usize,
        transaction_id: u64,
        operation: storvsp_protocol::Operation,
        status: storvsp_protocol::NtStatus,
        payload: &[u8],
    ) -> Result<(), WorkerError> {
        let header = storvsp_protocol::Packet {
            operation,
            flags: 0,
            status,
        };

        let packet_size = size_of_val(&header) + request_size;

        // Zero pad or truncate the payload to the queue's packet size. This is
        // necessary because Windows guests check that each packet's size is
        // exactly the largest possible packet size for the negotiated protocol
        // version.
        let len = size_of_val(&header) + size_of_val(payload);
        let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
        let (payload_bytes, padding_bytes) = if len > packet_size {
            (&payload[..packet_size - size_of_val(&header)], &[][..])
        } else {
            (payload, &padding[..packet_size - len])
        };
        assert_eq!(
            size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
            packet_size
        );
        writer
            .try_write(&OutgoingPacket {
                transaction_id,
                packet_type,
                payload: &[header.as_bytes(), payload_bytes, padding_bytes],
            })
            .map_err(|err| match err {
                // Callers are expected to have verified ring space first.
                queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
            })
    }

    /// Sends an in-band packet that requests no completion, padded to the
    /// worker's negotiated request size.
    fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
        &mut self,
        writer: &mut queue::WriteHalf<'_, M>,
        operation: storvsp_protocol::Operation,
        status: storvsp_protocol::NtStatus,
        payload: &P,
    ) -> Result<(), WorkerError> {
        self.send_vmbus_packet(
            &mut writer.batched(),
            OutgoingPacketType::InBandNoCompletion,
            self.request_size,
            0,
            operation,
            status,
            payload.as_bytes(),
        )
    }

    /// Sends the completion for `packet`, echoing its transaction ID and
    /// padding to the size of the original request.
    fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
        &mut self,
        writer: &mut queue::WriteHalf<'_, M>,
        packet: &Packet,
        status: storvsp_protocol::NtStatus,
        payload: &P,
    ) -> Result<(), WorkerError> {
        self.send_vmbus_packet(
            &mut writer.batched(),
            OutgoingPacketType::Completion,
            packet.request_size,
            packet.transaction_id,
            storvsp_protocol::Operation::COMPLETE_IO,
            status,
            payload.as_bytes(),
        )
    }
}
576
/// Shared SCSI execution context for a worker: resolves (path, target, lun)
/// to a disk and runs the command.
struct ScsiCommandQueue {
    controller: Arc<ScsiControllerState>,
    /// Guest memory used to transfer the request's external data.
    mem: GuestMemory,
    /// When set, overrides the guest-provided path ID for disk lookup —
    /// NOTE(review): presumably used by the IDE-accelerator path; confirm
    /// against the constructors (outside this chunk).
    force_path_id: Option<u8>,
}
582
impl ScsiCommandQueue {
    /// Executes a single SCSI request.
    ///
    /// `REPORT_LUNS` is handled directly by enumerating the attached disks,
    /// and `INQUIRY` is synthesized for LUNs with no backing disk; all other
    /// operations are delegated to the resolved disk's `execute_scsi`.
    async fn execute_scsi(&self, full_request: &ScsiRequestAndRange) -> ScsiResult {
        let request = &full_request.request;
        // The first CDB byte is the SCSI operation code.
        let op = ScsiOp(request.payload[0]);
        let external_data = full_request
            .external_data
            .buffer(&full_request.external_data_buf, &self.mem);

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            "execute_scsi start...",
        );

        let path_id = self.force_path_id.unwrap_or(request.path_id);

        let controller_disk = self
            .controller
            .disks
            .read()
            .get(&ScsiPath {
                path: path_id,
                target: request.target_id,
                lun: request.lun,
            })
            .cloned();

        let result = match op {
            // Handled here even when a disk is attached at the addressed
            // LUN, since REPORT_LUNS must enumerate all LUNs on the target.
            ScsiOp::REPORT_LUNS => {
                const HEADER_SIZE: usize = size_of::<scsi::LunList>();
                let mut luns: Vec<u8> = self
                    .controller
                    .disks
                    .read()
                    .iter()
                    .flat_map(|(path, _)| {
                        // Use the original path ID and not the forced one to
                        // match Hyper-V storvsp behavior.
                        if request.path_id == path.path && request.target_id == path.target {
                            Some(path.lun)
                        } else {
                            None
                        }
                    })
                    .collect();
                luns.sort_unstable();
                // One u64 entry per LUN, plus one for the 8-byte header.
                let mut data: Vec<u64> = vec![0; luns.len() + 1];
                let header = scsi::LunList {
                    length: (luns.len() as u32 * 8).into(),
                    reserved: [0; 4],
                };
                data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
                for (i, lun) in luns.iter().enumerate() {
                    // The LUN number occupies the first two bytes of each
                    // 8-byte entry, big-endian.
                    data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
                }
                if external_data.len() >= HEADER_SIZE {
                    let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
                    external_data.writer().write(&data.as_bytes()[..tx]).map_or(
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        },
                        |_| ScsiResult {
                            scsi_status: ScsiStatus::GOOD,
                            srb_status: SrbStatus::SUCCESS,
                            tx,
                            sense_data: None,
                        },
                    )
                } else {
                    // Buffer too small for even the header: succeed with no
                    // data transferred.
                    ScsiResult {
                        scsi_status: ScsiStatus::GOOD,
                        srb_status: SrbStatus::SUCCESS,
                        tx: 0,
                        sense_data: None,
                    }
                }
            }
            // A disk is attached at this LUN: delegate the command to it.
            _ if controller_disk.is_some() => {
                let mut cdb = [0; 16];
                cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
                controller_disk
                    .unwrap()
                    .disk
                    .execute_scsi(
                        &external_data,
                        &Request {
                            cdb,
                            srb_flags: request.srb_flags,
                        },
                    )
                    .await
            }
            // No disk at this LUN: synthesize an INQUIRY response reporting
            // "logical unit not present".
            ScsiOp::INQUIRY => {
                let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
                    .unwrap()
                    .0; // TODO: zerocopy: ref-from-prefix: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
                if external_data.len() < cdb.allocation_length.get() as usize
                    || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
                    || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
                {
                    ScsiResult {
                        scsi_status: ScsiStatus::CHECK_CONDITION,
                        srb_status: SrbStatus::INVALID_REQUEST,
                        tx: 0,
                        sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
                    }
                } else {
                    let enable_vpd = cdb.flags.vpd();
                    if enable_vpd || cdb.page_code != 0 {
                        // cannot support VPD inquiry for non-existing device (lun).
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        }
                    } else {
                        const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
                        let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
                        data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;

                        if request.lun != 0 {
                            // Below fields are only set for lun0 inquiry so zero out here.
                            data.vendor_id = [0; 8];
                            data.product_id = [0; 16];
                            data.product_revision_level = [0; 4];
                        }

                        let datab = data.as_bytes();
                        let tx = std::cmp::min(
                            cdb.allocation_length.get() as usize,
                            size_of::<scsi::InquiryData>(),
                        );
                        external_data.writer().write(&datab[..tx]).map_or(
                            ScsiResult {
                                scsi_status: ScsiStatus::CHECK_CONDITION,
                                srb_status: SrbStatus::INVALID_REQUEST,
                                tx: 0,
                                sense_data: Some(illegal_request_sense(
                                    AdditionalSenseCode::INVALID_CDB,
                                )),
                            },
                            |_| ScsiResult {
                                scsi_status: ScsiStatus::GOOD,
                                srb_status: SrbStatus::SUCCESS,
                                tx,
                                sense_data: None,
                            },
                        )
                    }
                }
            }
            // Any other command addressed at an absent LUN fails.
            _ => ScsiResult {
                scsi_status: ScsiStatus::CHECK_CONDITION,
                srb_status: SrbStatus::INVALID_LUN,
                tx: 0,
                sense_data: None,
            },
        };

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            result = ?result,
            "execute_scsi completed.",
        );
        result
    }
}
763
impl<T: RingMem + 'static> Worker<T> {
    /// Creates a worker for channel `channel_index` (0 is the primary
    /// channel; other indexes are subchannels).
    ///
    /// `io_queue_depth` bounds the number of concurrently outstanding SCSI
    /// requests and is clamped to at least 1.
    fn new(
        controller: Arc<ScsiControllerState>,
        channel: RawAsyncChannel<T>,
        channel_index: u16,
        mem: GuestMemory,
        channel_control: ChannelControl,
        io_queue_depth: u32,
        protocol: Arc<Protocol>,
        force_path_id: Option<u8>,
    ) -> anyhow::Result<Self> {
        let queue = Queue::new(channel)?;
        // Register for disk attach/remove notifications from the controller.
        #[expect(clippy::disallowed_methods)] // TODO
        let (source, target) = futures::channel::mpsc::channel(1);
        controller.add_rescan_notification_source(source);

        let max_io_queue_depth = io_queue_depth.max(1) as usize;
        Ok(Self {
            inner: WorkerInner {
                protocol,
                // Assume the smaller V1 request size until negotiation
                // selects a protocol version.
                request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
                controller: controller.clone(),
                channel_index,
                scsi_queue: Arc::new(ScsiCommandQueue {
                    controller,
                    mem,
                    force_path_id,
                }),
                scsi_requests: FuturesUnordered::new(),
                scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
                channel_control,
                max_io_queue_depth,
                future_pool: Vec::new(),
                full_request_pool: Vec::new(),
                stats: Default::default(),
            },
            queue,
            rescan_notification: target,
            fast_select: FastSelect::new(),
        })
    }

    /// Waits for every in-flight SCSI request to complete, removing each
    /// request's state as it finishes.
    async fn wait_for_scsi_requests_complete(&mut self) {
        tracing::debug!(
            channel_index = self.inner.channel_index,
            "wait for IOs completed..."
        );
        while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
            self.inner.scsi_requests_states.remove(id);
        }
    }
}
816
impl InspectTask<Worker> for WorkerState {
    /// Reports a worker's protocol/IO state for diagnostics; emits nothing
    /// when the task has no worker attached.
    fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
        if let Some(worker) = worker {
            let mut resp = req.respond();
            // Protocol negotiation state is only reported on the primary
            // channel, which owns negotiation.
            if worker.inner.channel_index == 0 {
                let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
                    ProtocolState::Init(state) => match state {
                        InitState::Begin => ("begin_init", None, None),
                        InitState::QueryVersion => ("query_version", None, None),
                        InitState::QueryProperties { version } => {
                            ("query_properties", Some(version), None)
                        }
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => ("end_init", Some(version), subchannel_count),
                    },
                    ProtocolState::Ready {
                        version,
                        subchannel_count,
                    } => ("ready", Some(version), Some(subchannel_count)),
                };
                resp.field("state", state)
                    .field("version", version)
                    .field("subchannel_count", subchannel_count);
            }
            // Per-channel IO state: in-flight requests, stats, and the ring.
            resp.field("pending_packets", worker.inner.scsi_requests_states.len())
                .fields("io", worker.inner.scsi_requests_states.iter())
                .field("stats", &worker.inner.stats)
                .field("ring", &worker.queue)
                .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
        }
    }
}
851
impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
    /// Runs the worker until cancelled.
    ///
    /// The primary channel (index 0) drives the protocol state machine;
    /// subchannels first wait for negotiation to complete and then process
    /// I/O only.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        let fut = async {
            if worker.inner.channel_index == 0 {
                worker.process_primary().await
            } else {
                // Wait for initialization to end before processing any
                // subchannel packets.
                let protocol_version = loop {
                    // Register the listener _before_ checking the state so a
                    // notification arriving between the check and the await
                    // cannot be lost.
                    let listener = worker.inner.protocol.ready.listen();
                    if let ProtocolState::Ready { version, .. } =
                        *worker.inner.protocol.state.read()
                    {
                        break version;
                    }
                    tracing::debug!("subchannel waiting for initialization to end");
                    listener.await
                };
                worker
                    .inner
                    .process_ready(&mut worker.queue, protocol_version)
                    .await
            }
        };

        // Packet-processing errors are logged but do not fail the task; only
        // cancellation propagates to the caller.
        match stop.until_stopped(fut).await? {
            Ok(_) => {}
            Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
        }
        Ok(())
    }
}
888
impl WorkerInner {
    /// Awaits the next incoming packet, without checking for any other events
    /// (device add/remove notifications or available completions).
    ///
    /// `full_request_pool` is passed to `parse_packet` so parsing can reuse
    /// pooled request allocations.
    ///
    /// NOTE(review): an earlier comment here claimed this increments the
    /// count of outstanding packets, but no counter is updated in this
    /// function — outstanding requests are tracked by callers (e.g. via
    /// `scsi_requests_states`).
    async fn next_packet<'a, M: RingMem>(
        &mut self,
        reader: &'a mut queue::ReadHalf<'a, M>,
    ) -> Result<Packet, WorkerError> {
        let packet = reader.read().await.map_err(WorkerError::Queue)?;
        let stor_packet =
            parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
        Ok(stor_packet)
    }

    /// Polls for enough ring space in the outgoing ring to send a packet.
    ///
    /// This is used to ensure there is enough space in the ring before
    /// committing to sending a packet. This avoids the need to save pending
    /// packets on the side if queue processing is interrupted while the ring is
    /// full.
    fn poll_for_ring_space<M: RingMem>(
        &mut self,
        cx: &mut Context<'_>,
        writer: &mut queue::WriteHalf<'_, M>,
    ) -> Poll<Result<(), WorkerError>> {
        // Reserve room for the largest packet we ever send, so the send
        // itself cannot fail for lack of space.
        writer
            .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
            .map_err(WorkerError::Queue)
    }
}
918
/// The size of the largest in-band VMBus packet storvsp sends: the protocol
/// header plus the maximum-size SCSI request payload. Used to reserve outgoing
/// ring space before committing to send a packet or completion.
const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
    size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
);
922
impl<T: RingMem> Worker<T> {
    /// Processes the protocol state machine.
    ///
    /// Runs only on the primary channel. Walks the negotiation states
    /// (`Begin` -> `QueryVersion` -> `QueryProperties` -> `EndInitialization`)
    /// and then, once the protocol is `Ready`, processes I/O while also
    /// listening for rescan notifications from the controller.
    async fn process_primary(&mut self) -> Result<(), WorkerError> {
        loop {
            // Re-read the shared protocol state each iteration; the arms
            // below advance it under the write lock as packets arrive.
            let current_state = *self.inner.protocol.state.read();
            match current_state {
                ProtocolState::Ready { version, .. } => {
                    break loop {
                        select_biased! {
                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
                                // ENUMERATE_BUS is only sent for Win7 and
                                // later protocol versions.
                                if version >= Version::Win7
                                {
                                    tracing::debug!("rescan notification received, sending ENUMERATE_BUS");
                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
                                }
                            }
                        }
                    };
                }
                ProtocolState::Init(state) => {
                    let (mut reader, mut writer) = self.queue.split();

                    // Ensure that subsequent calls to `send_completion` won't
                    // fail due to lack of ring space, to avoid keeping (and saving/restoring) interim states.
                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;

                    tracing::debug!(?state, "process_primary");
                    match state {
                        // Expect BEGIN_INITIALIZATION, then advance to
                        // version negotiation.
                        InitState::Begin => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::BeginInitialization = packet.data {
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &(),
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::QueryVersion);
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        // Negotiate the protocol version. On a rejected
                        // version the state is left unchanged so the guest
                        // may retry with a different version.
                        InitState::QueryVersion => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
                                if let Ok(version) = Version::parse(major_minor) {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    // The wire size of a SCSI request depends
                                    // on the negotiated version.
                                    self.inner.request_size = version.max_request_size();
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryProperties { version });

                                    tracelimit::info_ratelimited!(
                                        ?version,
                                        "scsi version negotiated"
                                    );
                                } else {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryVersion);
                                }
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        // Report channel properties; multi-channel (and thus
                        // subchannels) is only advertised for Win8+.
                        InitState::QueryProperties { version } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProperties = packet.data {
                                let multi_channel_supported = version >= Version::Win8;

                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &storvsp_protocol::ChannelProperties {
                                        max_transfer_bytes: 0x40000, // 256KB
                                        flags: {
                                            if multi_channel_supported {
                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
                                            } else {
                                                0
                                            }
                                        },
                                        maximum_sub_channel_count: if multi_channel_supported {
                                            self.inner.channel_control.max_subchannels()
                                        } else {
                                            0
                                        },
                                        reserved: 0,
                                        reserved2: 0,
                                        reserved3: [0, 0],
                                    },
                                )?;
                                // `subchannel_count: None` means the guest may
                                // still request subchannels before ending
                                // initialization.
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::EndInitialization {
                                        version,
                                        subchannel_count: if multi_channel_supported {
                                            None
                                        } else {
                                            Some(0)
                                        },
                                    });
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        // Accept an optional CREATE_SUB_CHANNELS (at most
                        // once) followed by END_INITIALIZATION.
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            match packet.data {
                                PacketData::CreateSubChannels(sub_channel_count)
                                    if subchannel_count.is_none() =>
                                {
                                    if let Err(err) = self
                                        .inner
                                        .channel_control
                                        .enable_subchannels(sub_channel_count)
                                    {
                                        tracelimit::warn_ratelimited!(
                                            ?err,
                                            "cannot enable subchannels"
                                        );
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
                                            &(),
                                        )?;
                                    } else {
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::SUCCESS,
                                            &(),
                                        )?;
                                        *self.inner.protocol.state.write() =
                                            ProtocolState::Init(InitState::EndInitialization {
                                                version,
                                                subchannel_count: Some(sub_channel_count),
                                            });
                                    }
                                }
                                PacketData::EndInitialization => {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &(),
                                    )?;
                                    // Reset the rescan notification event now, before the guest has a
                                    // chance to send any SCSI requests to scan the bus.
                                    self.rescan_notification.try_next().ok();
                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
                                        version,
                                        subchannel_count: subchannel_count.unwrap_or(0),
                                    };
                                    // Wake up subchannels waiting for the
                                    // protocol state to become ready.
                                    self.inner.protocol.ready.notify(usize::MAX);
                                }
                                _ => {
                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                        &(),
                                    )?;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
1137
1138fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
1139    match srb_status {
1140        SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
1141        SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
1142        SrbStatus::INVALID_LUN
1143        | SrbStatus::INVALID_TARGET_ID
1144        | SrbStatus::NO_DEVICE
1145        | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
1146        SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
1147        SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
1148        SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
1149            storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
1150        }
1151        SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
1152        SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
1153        SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
1154        _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
1155    }
1156}
1157
impl WorkerInner {
    /// Processes packets and SCSI completions after protocol negotiation has finished.
    async fn process_ready<M: RingMem>(
        &mut self,
        queue: &mut Queue<M>,
        protocol_version: Version,
    ) -> Result<(), WorkerError> {
        // The wire size of an incoming SCSI request depends on the negotiated
        // protocol version.
        self.request_size = protocol_version.max_request_size();
        poll_fn(|cx| self.poll_process_ready(cx, queue)).await
    }

    /// Processes packets and SCSI completions after protocol negotiation has finished.
    ///
    /// Returns `Ready` only on error; otherwise stays `Pending`, relying on
    /// the ring and the I/O futures to wake `cx` when there is more work.
    fn poll_process_ready<M: RingMem>(
        &mut self,
        cx: &mut Context<'_>,
        queue: &mut Queue<M>,
    ) -> Poll<Result<(), WorkerError>> {
        self.stats.wakes.increment();

        let (mut reader, mut writer) = queue.split();
        let mut total_completions = 0;
        let mut total_submissions = 0;
        let poll_mode_queue_depth = self.controller.poll_mode_queue_depth.load(Relaxed) as usize;

        loop {
            // Drive IOs forward and collect completions.
            'outer: while !self.scsi_requests_states.is_empty() {
                {
                    let mut batch = writer.batched();
                    loop {
                        // Ensure there is room for the completion before consuming
                        // the IO so that we don't have to track completed IOs whose
                        // completions haven't been sent.
                        if !batch
                            .can_write(MAX_VMBUS_PACKET_SIZE)
                            .map_err(WorkerError::Queue)?
                        {
                            // This batch is full but there may still be more completions.
                            break;
                        }
                        if let Poll::Ready(Some((request_id, result, future))) =
                            self.scsi_requests.poll_next_unpin(cx)
                        {
                            // Recycle the future's storage for later requests.
                            self.future_pool.push(future);
                            self.handle_completion(&mut batch, request_id, result)?;
                            total_completions += 1;
                        } else {
                            tracing::trace!("out of completions");
                            break 'outer;
                        }
                    }
                }

                // Wait for enough space for any completion packets.
                if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                    tracing::trace!("out of ring space");
                    break;
                }
            }

            let mut submissions = 0;
            // Process new requests.
            'outer: loop {
                if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                    break;
                }
                // Below `poll_mode_queue_depth` pending IOs, poll the ring
                // with `cx` (interrupt-driven). At or above it, read without
                // registering a waker, keeping the ring interrupt masked
                // (busy-poll mode).
                let mut batch = if self.scsi_requests_states.len() < poll_mode_queue_depth {
                    if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
                        batch.map_err(WorkerError::Queue)?
                    } else {
                        tracing::trace!("out of incoming packets");
                        break;
                    }
                } else {
                    match reader.try_read_batch() {
                        Ok(batch) => batch,
                        Err(queue::TryReadError::Empty) => {
                            tracing::trace!(
                                pending_io_count = self.scsi_requests_states.len(),
                                "out of incoming packets, keeping interrupts masked"
                            );
                            break;
                        }
                        Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
                    }
                };

                let mut packets = batch.packets();
                loop {
                    if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                        break 'outer;
                    }
                    // Wait for enough space for any completion packets that
                    // `handle_packet` may need to send, so that it isn't necessary
                    // to track pending completions.
                    if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                        tracing::trace!("out of ring space");
                        break 'outer;
                    }

                    let packet = if let Some(packet) = packets.next() {
                        packet.map_err(WorkerError::Queue)?
                    } else {
                        break;
                    };

                    if self.handle_packet(&mut writer, &packet)? {
                        submissions += 1;
                    }
                }
            }

            // Loop around to poll the IOs if any new IOs were submitted.
            if submissions == 0 {
                // No need to poll again.
                break;
            }
            total_submissions += submissions;
        }

        if total_submissions != 0 || total_completions != 0 {
            self.stats.ios_submitted.add(total_submissions);
            self.stats
                .per_wake_submissions
                .add_sample(total_submissions);
            self.stats
                .per_wake_completions
                .add_sample(total_completions);
            self.stats.ios_completed.add(total_completions);
        } else {
            self.stats.wakes_spurious.increment();
        }

        Poll::Pending
    }

    /// Sends the COMPLETE_IO completion packet for a finished SCSI request
    /// and recycles the request's bookkeeping.
    fn handle_completion<M: RingMem>(
        &mut self,
        writer: &mut queue::WriteBatch<'_, M>,
        request_id: usize,
        result: ScsiResult,
    ) -> Result<(), WorkerError> {
        let state = self.scsi_requests_states.remove(request_id);
        let request_size = state.request.request_size;

        // Push the request into the pool to avoid reallocating later.
        // Pooling requires exclusive ownership of the allocation; this
        // asserts that no other clone of the request remains.
        assert_eq!(
            Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
            1
        );
        self.full_request_pool.push(state.request);

        let status = convert_srb_status_to_nt_status(result.srb_status);
        // Sense data (if any) is copied into the fixed-size payload area.
        let mut payload = [0; 0x14];
        if let Some(sense) = result.sense_data {
            payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
            tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
        };
        let response = storvsp_protocol::ScsiRequest {
            length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
            scsi_status: result.scsi_status,
            srb_status: SrbStatusAndFlags::new()
                .with_status(result.srb_status)
                .with_autosense_valid(result.sense_data.is_some()),
            data_transfer_length: result.tx as u32,
            cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
            sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
            payload,
            ..storvsp_protocol::ScsiRequest::new_zeroed()
        };
        self.send_vmbus_packet(
            writer,
            OutgoingPacketType::Completion,
            request_size,
            state.transaction_id,
            storvsp_protocol::Operation::COMPLETE_IO,
            status,
            response.as_bytes(),
        )?;
        Ok(())
    }

    /// Handles one incoming ring packet while in the ready state.
    ///
    /// Returns `Ok(true)` if a SCSI I/O was submitted (so the caller should
    /// poll the I/O futures again), `Ok(false)` for packets completed inline.
    fn handle_packet<M: RingMem>(
        &mut self,
        writer: &mut queue::WriteHalf<'_, M>,
        packet: &IncomingPacket<'_, M>,
    ) -> Result<bool, WorkerError> {
        let packet =
            parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
        let submitted_io = match packet.data {
            PacketData::ExecuteScsi(request) => {
                self.push_scsi_request(packet.transaction_id, request);
                true
            }
            PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
                // These operations have always been no-ops.
                self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
                false
            }
            // Subchannel creation is only honored on the primary channel.
            PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
                if let Err(err) = self
                    .channel_control
                    .enable_subchannels(new_subchannel_count)
                {
                    tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
                    self.send_completion(
                        writer,
                        &packet,
                        storvsp_protocol::NtStatus::INVALID_PARAMETER,
                        &(),
                    )?;
                    false
                } else {
                    // Update the subchannel count in the protocol state for save.
                    if let ProtocolState::Ready {
                        subchannel_count, ..
                    } = &mut *self.protocol.state.write()
                    {
                        *subchannel_count = new_subchannel_count;
                    } else {
                        unreachable!()
                    }

                    self.send_completion(
                        writer,
                        &packet,
                        storvsp_protocol::NtStatus::SUCCESS,
                        &(),
                    )?;
                    false
                }
            }
            _ => {
                tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
                self.send_completion(
                    writer,
                    &packet,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                    &(),
                )?;
                false
            }
        };
        Ok(submitted_io)
    }

    /// Starts execution of a SCSI request, tracking it in
    /// `scsi_requests_states` and pushing its future onto `scsi_requests`.
    fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
        let scsi_queue = self.scsi_queue.clone();
        let scsi_request_state = ScsiRequestState {
            transaction_id,
            request: full_request.clone(),
        };
        let request_id = self.scsi_requests_states.insert(scsi_request_state);
        // Reuse pooled future storage to avoid allocating on the hot path.
        let future = self
            .future_pool
            .pop()
            .unwrap_or_else(|| OversizedBox::new(()));
        let future = OversizedBox::refill(future, async move {
            scsi_queue.execute_scsi(full_request.as_ref()).await
        });
        let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
        self.scsi_requests.push(request);
    }
}
1422
1423impl<T: RingMem> Drop for Worker<T> {
1424    fn drop(&mut self) {
1425        self.inner
1426            .controller
1427            .remove_rescan_notification_source(&self.rescan_notification);
1428    }
1429}
1430
/// Error returned when attaching a disk at a [`ScsiPath`] that already has a
/// device attached.
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathInUse(pub ScsiPath);
1434
/// Error returned when removing a device from a [`ScsiPath`] that has no
/// device attached.
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathNotInUse(ScsiPath);
1438
/// Bookkeeping for an in-flight SCSI request, kept in the worker's
/// `scsi_requests_states` slab while the I/O executes.
#[derive(Clone)]
struct ScsiRequestState {
    // VMBus transaction id, echoed back in the completion packet.
    transaction_id: u64,
    // The parsed request, shared with the executing future.
    request: Arc<ScsiRequestAndRange>,
}
1444
/// A parsed SCSI request together with the external data range it refers to.
/// Instances are pooled (`full_request_pool`) to avoid reallocation on the
/// hot path.
#[derive(Debug)]
struct ScsiRequestAndRange {
    // Guest memory range for the request's data transfer.
    external_data: Range,
    // Backing buffer for the paged ranges of `external_data`.
    external_data_buf: MultiPagedRangeBuf,
    request: storvsp_protocol::ScsiRequest,
    // The request's size on the wire; depends on the negotiated protocol
    // version (see `Version::max_request_size`).
    request_size: usize,
}
1452
1453impl Inspect for ScsiRequestState {
1454    fn inspect(&self, req: inspect::Request<'_>) {
1455        req.respond()
1456            .field("transaction_id", self.transaction_id)
1457            .display(
1458                "address",
1459                &ScsiPath {
1460                    path: self.request.request.path_id,
1461                    target: self.request.request.target_id,
1462                    lun: self.request.request.lun,
1463                },
1464            )
1465            .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1466    }
1467}
1468
1469impl StorageDevice {
1470    /// Returns a new SCSI device.
1471    pub fn build_scsi(
1472        driver_source: &VmTaskDriverSource,
1473        controller: &ScsiController,
1474        instance_id: Guid,
1475        max_sub_channel_count: u16,
1476        io_queue_depth: u32,
1477    ) -> Self {
1478        Self::build_inner(
1479            driver_source,
1480            controller,
1481            instance_id,
1482            None,
1483            max_sub_channel_count,
1484            io_queue_depth,
1485        )
1486    }
1487
1488    /// Returns a new SCSI device for implementing an IDE accelerator channel
1489    /// for IDE device `device_id` on channel `channel_id`.
1490    pub fn build_ide(
1491        driver_source: &VmTaskDriverSource,
1492        channel_id: u8,
1493        device_id: u8,
1494        disk: ScsiControllerDisk,
1495        io_queue_depth: u32,
1496    ) -> Self {
1497        let path = ScsiPath {
1498            path: channel_id,
1499            target: device_id,
1500            lun: 0,
1501        };
1502
1503        let controller = ScsiController::new();
1504        controller.attach(path, disk).unwrap();
1505
1506        // Construct the specific GUID that drivers in the guest expect for this
1507        // IDE device.
1508        let instance_id = Guid {
1509            data1: channel_id.into(),
1510            data2: device_id.into(),
1511            data3: 0x8899,
1512            data4: [0; 8],
1513        };
1514        Self::build_inner(
1515            driver_source,
1516            &controller,
1517            instance_id,
1518            Some(path),
1519            0,
1520            io_queue_depth,
1521        )
1522    }
1523
1524    fn build_inner(
1525        driver_source: &VmTaskDriverSource,
1526        controller: &ScsiController,
1527        instance_id: Guid,
1528        ide_path: Option<ScsiPath>,
1529        max_sub_channel_count: u16,
1530        io_queue_depth: u32,
1531    ) -> Self {
1532        let workers = (0..max_sub_channel_count + 1)
1533            .map(|channel_index| WorkerAndDriver {
1534                worker: TaskControl::new(WorkerState),
1535                driver: driver_source
1536                    .builder()
1537                    .target_vp(0)
1538                    .run_on_target(true)
1539                    .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1540            })
1541            .collect();
1542
1543        Self {
1544            instance_id,
1545            ide_path,
1546            workers,
1547            controller: controller.state.clone(),
1548            resources: Default::default(),
1549            max_sub_channel_count,
1550            driver_source: driver_source.clone(),
1551            protocol: Arc::new(Protocol {
1552                state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1553                ready: Default::default(),
1554            }),
1555            io_queue_depth,
1556        }
1557    }
1558
1559    fn new_worker(
1560        &mut self,
1561        open_request: &OpenRequest,
1562        channel_index: u16,
1563    ) -> anyhow::Result<&mut Worker> {
1564        let controller = self.controller.clone();
1565
1566        // VMBus doesn't provide a target VP if the channel is not using interrupts. Run on VP 0 in
1567        // that case.
1568        let target_vp = open_request.open_data.target_vp.unwrap_or_default();
1569        let driver = self
1570            .driver_source
1571            .builder()
1572            .target_vp(target_vp)
1573            .run_on_target(true)
1574            .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1575
1576        let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1577            .context("failed to create vmbus channel")?;
1578
1579        let channel_control = self.resources.channel_control.clone();
1580
1581        tracing::debug!(
1582            target_vp = open_request.open_data.target_vp,
1583            channel_index,
1584            "packet processing starting...",
1585        );
1586
1587        // Force the path ID on incoming SCSI requests to match the IDE
1588        // channel ID, since guests do not reliably set the path ID
1589        // correctly.
1590        let force_path_id = self.ide_path.map(|p| p.path);
1591
1592        let worker = Worker::new(
1593            controller,
1594            channel,
1595            channel_index,
1596            self.resources
1597                .offer_resources
1598                .guest_memory(open_request)
1599                .clone(),
1600            channel_control,
1601            self.io_queue_depth,
1602            self.protocol.clone(),
1603            force_path_id,
1604        )
1605        .map_err(RestoreError::Other)?;
1606
1607        self.workers[channel_index as usize]
1608            .driver
1609            .retarget_vp(target_vp);
1610
1611        Ok(self.workers[channel_index as usize].worker.insert(
1612            &driver,
1613            format!("storvsp worker {}-{}", self.instance_id, channel_index),
1614            worker,
1615        ))
1616    }
1617}
1618
/// A disk that can be added to a SCSI controller.
#[derive(Clone)]
pub struct ScsiControllerDisk {
    // Shared handle to the disk backend; SCSI I/O for the attached path is
    // delegated to this object.
    disk: Arc<dyn AsyncScsiDisk>,
}
1624
1625impl ScsiControllerDisk {
1626    /// Creates a new controller disk from an async SCSI disk.
1627    pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1628        Self { disk }
1629    }
1630}
1631
/// Shared state behind a [`ScsiController`], referenced by the device and
/// its workers.
struct ScsiControllerState {
    // Disks currently attached, keyed by SCSI path (path/target/lun).
    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
    // Channels used to notify workers that the disk set changed so they can
    // tell the guest to rescan the bus.
    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
    // Pending-I/O threshold at which a worker switches from interrupt-driven
    // to poll-mode processing (see module docs).
    poll_mode_queue_depth: AtomicU32,
}
1637
/// Manages the set of disks attached to a StorVSP instance, keyed by
/// [`ScsiPath`]. Supports runtime attach/remove.
pub struct ScsiController {
    // State shared with the device's channel workers.
    state: Arc<ScsiControllerState>,
}
1641
1642impl ScsiController {
1643    pub fn new() -> Self {
1644        Self::new_with_poll_mode_queue_depth(None)
1645    }
1646
1647    pub fn new_with_poll_mode_queue_depth(poll_mode_queue_depth: Option<u32>) -> Self {
1648        Self {
1649            state: Arc::new(ScsiControllerState {
1650                disks: Default::default(),
1651                rescan_notification_source: Mutex::new(Vec::new()),
1652                poll_mode_queue_depth: AtomicU32::new(
1653                    poll_mode_queue_depth.unwrap_or(DEFAULT_POLL_MODE_QUEUE_DEPTH),
1654                ),
1655            }),
1656        }
1657    }
1658
1659    pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1660        match self.state.disks.write().entry(path) {
1661            Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1662            Entry::Vacant(entry) => entry.insert(disk),
1663        };
1664        for source in self.state.rescan_notification_source.lock().iter_mut() {
1665            // Ok to ignore errors here. If the channel is full a previous notification has not yet
1666            // been processed by the primary channel worker.
1667            source.try_send(()).ok();
1668        }
1669        Ok(())
1670    }
1671
1672    pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1673        match self.state.disks.write().entry(path) {
1674            Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1675            Entry::Occupied(entry) => {
1676                entry.remove();
1677            }
1678        }
1679        for source in self.state.rescan_notification_source.lock().iter_mut() {
1680            // Ok to ignore errors here. If the channel is full a previous notification has not yet
1681            // been processed by the primary channel worker.
1682            source.try_send(()).ok();
1683        }
1684        Ok(())
1685    }
1686}
1687
1688impl ScsiControllerState {
1689    fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1690        self.rescan_notification_source.lock().push(source);
1691    }
1692
1693    fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1694        let mut sources = self.rescan_notification_source.lock();
1695        if let Some(index) = sources
1696            .iter()
1697            .position(|source| source.is_connected_to(target))
1698        {
1699            sources.remove(index);
1700        }
1701    }
1702}
1703
#[async_trait]
impl VmbusDevice for StorageDevice {
    /// Describes the VMBus offer: an IDE accelerator channel (with the IDE
    /// path/target encoded in the user-defined offer data) when `ide_path`
    /// is set, otherwise a plain synthetic SCSI controller.
    fn offer(&self) -> OfferParams {
        if let Some(path) = self.ide_path {
            let offer_properties = storvsp_protocol::OfferProperties {
                path_id: path.path,
                target_id: path.target,
                flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
                ..FromZeros::new_zeroed()
            };
            let mut user_defined = UserDefinedData::new_zeroed();
            // Infallible: the offer properties fit within the user-defined
            // data area, so `write_to_prefix` cannot fail.
            offer_properties
                .write_to_prefix(&mut user_defined[..])
                .unwrap();
            OfferParams {
                interface_name: "ide-accel".to_owned(),
                instance_id: self.instance_id,
                interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
                channel_type: ChannelType::Interface { user_defined },
                ..Default::default()
            }
        } else {
            OfferParams {
                interface_name: "scsi".to_owned(),
                instance_id: self.instance_id,
                interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
                ..Default::default()
            }
        }
    }

    /// Returns the maximum number of sub-channels this device supports.
    fn max_subchannels(&self) -> u16 {
        self.max_sub_channel_count
    }

    fn install(&mut self, resources: DeviceResources) {
        self.resources = resources;
    }

    /// Opens a channel: builds a worker for `channel_index` and starts it.
    async fn open(
        &mut self,
        channel_index: u16,
        open_request: &OpenRequest,
    ) -> Result<(), ChannelOpenError> {
        tracing::debug!(channel_index, "scsi open channel");
        self.new_worker(open_request, channel_index)?;
        self.workers[channel_index as usize].worker.start();
        Ok(())
    }

    /// Closes a channel: stops the worker, drains in-flight SCSI requests,
    /// and tears the worker down. Closing the primary channel (index 0)
    /// also resets protocol state so the guest must renegotiate.
    async fn close(&mut self, channel_index: u16) {
        tracing::debug!(channel_index, "scsi close channel");
        let worker = &mut self.workers[channel_index as usize].worker;
        worker.stop().await;
        if worker.state_mut().is_some() {
            // Wait for outstanding I/O to complete before removing the
            // worker so no request is abandoned mid-flight. (The two-step
            // `is_some()`/`unwrap()` keeps the mutable borrow short enough
            // to allow the subsequent `remove()`.)
            worker
                .state_mut()
                .unwrap()
                .wait_for_scsi_requests_complete()
                .await;
            worker.remove();
        }
        if channel_index == 0 {
            *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
        }
    }

    /// Retargets the per-channel task driver to a new virtual processor.
    async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
        self.workers[channel_index as usize]
            .driver
            .retarget_vp(target_vp);
    }

    /// Resumes every worker that has state but is not currently running.
    fn start(&mut self) {
        for task in self
            .workers
            .iter_mut()
            .filter(|task| task.worker.has_state() && !task.worker.is_running())
        {
            task.worker.start();
        }
    }

    /// Pauses every running worker; workers keep their state so `start` can
    /// resume them.
    async fn stop(&mut self) {
        tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
        for task in self
            .workers
            .iter_mut()
            .filter(|task| task.worker.has_state() && task.worker.is_running())
        {
            task.worker.stop().await;
        }
    }

    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
        Some(self)
    }
}
1802
#[async_trait]
impl SaveRestoreVmbusDevice for StorageDevice {
    /// Saves device state by delegating to the inherent (non-trait) `save`
    /// method — note the inner call is synchronous, so it cannot recurse
    /// into this trait method — and wrapping the result in a typed blob.
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
        Ok(SavedStateBlob::new(self.save()?))
    }

    /// Restores device state by parsing the blob into the saved-state type
    /// and delegating to the inherent `restore` method.
    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError> {
        self.restore(control, state.parse()?).await
    }
}
1817
1818#[cfg(test)]
1819mod tests {
1820    use super::*;
1821    use crate::test_helpers::TestWorker;
1822    use crate::test_helpers::parse_guest_completion;
1823    use crate::test_helpers::parse_guest_completion_check_flags_status;
1824    use pal_async::DefaultDriver;
1825    use pal_async::async_test;
1826    use scsi::srb::SrbStatus;
1827    use test_with_tracing::test;
1828    use vmbus_channel::connected_async_channels;
1829
1830    // Discourage `Clone` for `ScsiController` outside the crate, but it is
1831    // necessary for testing. The fuzzer also uses `TestWorker`, which needs
1832    // a `clone` of the inner state, but is not in this crate.
1833    impl Clone for ScsiController {
1834        fn clone(&self) -> Self {
1835            ScsiController {
1836                state: self.state.clone(),
1837            }
1838        }
1839    }
1840
    // End-to-end smoke test: negotiate the protocol, write a block to a RAM
    // disk, read it back, and verify the data round-trips through guest
    // memory.
    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        let controller = ScsiController::new();
        let disk = scsidisk::SimpleScsiDisk::new(
            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
            Default::default(),
        );
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                ScsiControllerDisk::new(Arc::new(disk)),
            )
            .unwrap();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Set up the buffer for a write request
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        // Read the same sector back into a different guest buffer.
        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        // The read-back data must match what was written.
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        // stop everything
        guest.verify_graceful_close(test_worker).await;
    }
1909
    // Verifies the device tolerates a range of QUERY_PROTOCOL_VERSION packet
    // sizes and responds with completion payloads of the expected length
    // (and REVISION_MISMATCH, since the requested version is bogus).
    #[async_test]
    async fn test_packet_sizes(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let header = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };

        // `major_minor: !0` is intentionally invalid so every request below
        // completes with REVISION_MISMATCH.
        let mut buf = [0u8; 128];
        storvsp_protocol::ProtocolVersion {
            major_minor: !0,
            reserved: 0,
        }
        .write_to_prefix(&mut buf[..])
        .unwrap(); // PANIC: Infallible since `ProtocolVersion` is less than 128 bytes

        // Each pair is (request packet size sent, expected completion size).
        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
            guest
                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
                .await;

            guest
                .verify_completion(|packet| {
                    let IncomingPacket::Completion(packet) = packet else {
                        unreachable!()
                    };
                    assert_eq!(packet.reader().len(), resp_len);
                    assert_eq!(
                        packet
                            .reader()
                            .read_plain::<storvsp_protocol::Packet>()
                            .unwrap()
                            .status,
                        storvsp_protocol::NtStatus::REVISION_MISMATCH
                    );
                    Ok(())
                })
                .await;
        }
    }
1981
    // Sending END_INITIALIZATION before BEGIN_INITIALIZATION must be
    // rejected with INVALID_DEVICE_STATE.
    #[async_test]
    async fn test_wrong_first_packet(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // Protocol negotiation done out of order
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
                );
                Ok(())
            })
            .await;
    }
2031
    // An operation the device does not recognize must fail the worker with
    // `PacketError::UnrecognizedOperation` (observed via teardown).
    #[async_test]
    async fn test_unrecognized_operation(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // Send packet with unrecognized operation
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        // The worker should exit with the expected error rather than
        // completing the packet.
        match worker.teardown().await {
            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
                storvsp_protocol::Operation::REMOVE_DEVICE,
            ))) => {}
            result => panic!("Worker failed with unexpected result {:?}!", result),
        }
    }
2071
    // Requesting more sub-channels than the device supports must be rejected
    // with INVALID_PARAMETER after a full protocol negotiation.
    #[async_test]
    async fn test_too_many_subchannels(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // Full negotiation sequence: begin init, query version, query
        // properties.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        // Create sub channels more than maximum_sub_channel_count
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_PARAMETER
                );
                Ok(())
            })
            .await;
    }
2156
    // BEGIN_INITIALIZATION sent again after negotiation has already
    // completed must be rejected with INVALID_DEVICE_STATE.
    #[async_test]
    async fn test_begin_init_on_ready(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Begin initialization again, now that negotiation is complete.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|p| {
                parse_guest_completion_check_flags_status(
                    p,
                    0,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                )
            })
            .await;
    }
2201
    // Hot-plug test: attach and remove disks while the guest is running,
    // checking that each change triggers a bus-rescan notification, that
    // REPORT LUNS reflects the current disk count, and that I/O to a missing
    // LUN fails with INVALID_LUN.
    #[async_test]
    async fn test_hot_add_remove(driver: DefaultDriver) {
        // set up channels and worker.
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        // create a controller with no disk yet.
        let controller = ScsiController::new();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Verify no LUNs are reported initially.
        let mut lun_list_buffer: [u8; 256] = [0; 256];
        let mut disk_count = 0;
        guest
            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
            })
            .await;
        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
        // REPORT LUNS data begins with a big-endian LUN list length (8 bytes
        // per LUN).
        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
        assert_eq!(lun_list_size, disk_count as u32 * 8);

        // Set up a buffer for writes.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();

        // With no disk attached, a write must fail with INVALID_LUN.
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
            })
            .await;

        // Add some disks while the guest is running.
        for lun in 0..4 {
            let disk = scsidisk::SimpleScsiDisk::new(
                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
                Default::default(),
            );
            controller
                .attach(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    ScsiControllerDisk::new(Arc::new(disk)),
                )
                .unwrap();
            // Each attach should push a bus-enumeration notification to the
            // guest.
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count += 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 256)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            // I/O to the newly added LUN should now succeed.
            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
                })
                .await;
        }

        // Remove all disks while the guest is running.
        for lun in 0..4 {
            controller
                .remove(ScsiPath {
                    path: 0,
                    target: 0,
                    lun,
                })
                .unwrap();
            // Each removal should also trigger a bus-enumeration
            // notification.
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count -= 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            // I/O to the removed LUN must fail again.
            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
                })
                .await;
        }

        guest.verify_graceful_close(test_worker).await;
    }
2364
2365    #[async_test]
2366    pub async fn test_async_disk(driver: DefaultDriver) {
2367        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
2368        let controller = ScsiController::new();
2369        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
2370            device,
2371            Default::default(),
2372        )));
2373        controller
2374            .attach(
2375                ScsiPath {
2376                    path: 0,
2377                    target: 0,
2378                    lun: 0,
2379                },
2380                disk,
2381            )
2382            .unwrap();
2383
2384        let (host, guest) = connected_async_channels(16 * 1024);
2385        let guest_queue = Queue::new(guest).unwrap();
2386
2387        let mut guest = test_helpers::TestGuest {
2388            queue: guest_queue,
2389            transaction_id: 0,
2390        };
2391
2392        let test_guest_mem = GuestMemory::allocate(16384);
2393        let worker = TestWorker::start(
2394            controller.clone(),
2395            &driver,
2396            test_guest_mem.clone(),
2397            host,
2398            None,
2399        );
2400
2401        let negotiate_packet = storvsp_protocol::Packet {
2402            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2403            flags: 0,
2404            status: storvsp_protocol::NtStatus::SUCCESS,
2405        };
2406        guest
2407            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2408            .await;
2409        guest.verify_completion(parse_guest_completion).await;
2410
2411        let version_packet = storvsp_protocol::Packet {
2412            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
2413            flags: 0,
2414            status: storvsp_protocol::NtStatus::SUCCESS,
2415        };
2416        let version = storvsp_protocol::ProtocolVersion {
2417            major_minor: storvsp_protocol::VERSION_BLUE,
2418            reserved: 0,
2419        };
2420        guest
2421            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
2422            .await;
2423        guest.verify_completion(parse_guest_completion).await;
2424
2425        let properties_packet = storvsp_protocol::Packet {
2426            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
2427            flags: 0,
2428            status: storvsp_protocol::NtStatus::SUCCESS,
2429        };
2430        guest
2431            .send_data_packet_sync(&[properties_packet.as_bytes()])
2432            .await;
2433        guest.verify_completion(parse_guest_completion).await;
2434
2435        let negotiate_packet = storvsp_protocol::Packet {
2436            operation: storvsp_protocol::Operation::END_INITIALIZATION,
2437            flags: 0,
2438            status: storvsp_protocol::NtStatus::SUCCESS,
2439        };
2440        guest
2441            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2442            .await;
2443        guest.verify_completion(parse_guest_completion).await;
2444
2445        const IO_LEN: usize = 4 * 1024;
2446        let write_buf = [7u8; IO_LEN];
2447        let write_gpa = 4 * 1024u64;
2448        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
2449        guest
2450            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
2451            .await;
2452        guest
2453            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
2454            .await;
2455
2456        let read_gpa = 8 * 1024u64;
2457        guest
2458            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
2459            .await;
2460        guest
2461            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
2462            .await;
2463        let mut read_buf = [0u8; IO_LEN];
2464        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
2465        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
2466            assert_eq!(b1, b2);
2467        }
2468
2469        guest.verify_graceful_close(worker).await;
2470    }
2471}