storvsp/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! VMBus SCSI controller emulator (StorVSP).
5//!
6//! StorVSP implements the Hyper-V synthetic SCSI protocol — a VMBus-based
7//! transport that carries SCSI CDBs between the guest's `storvsc` driver and
8//! the VMM. This is not a standard SCSI transport (like iSCSI or SAS); it's a
9//! Hyper-V-specific wire format defined in [`storvsp_protocol`].
10//!
11//! # Architecture
12//!
13//! The crate uses a multi-worker model. The primary VMBus channel handles
14//! protocol version negotiation (Win6 through Blue); sub-channels process I/O
15//! in parallel. Each worker owns a VMBus ring and processes packets
16//! concurrently via `FuturesUnordered`.
17//!
18//! StorVSP handles the transport (ring buffer management, GPADL setup, packet
19//! framing, sub-channel lifecycle) and a few SCSI control commands directly
20//! (`REPORT_LUNS`, `INQUIRY` for absent targets). All actual I/O is delegated
21//! to [`AsyncScsiDisk`] implementations — StorVSP
22//! never interprets SCSI data CDBs itself.
23//!
24//! # Key types
25//!
26//! - [`StorageDevice`] — the VMBus device. Implements `VmbusDevice` and
27//!   `SaveRestoreVmbusDevice`.
28//! - [`ScsiController`] — manages attached disks by [`ScsiPath`]. Supports
29//!   runtime attach/remove.
30//! - [`ScsiControllerDisk`] — wraps `Arc<dyn AsyncScsiDisk>`.
31//!
32//! # Performance
33//!
34//! Poll-mode optimization: when pending I/O count exceeds
35//! `poll_mode_queue_depth`, the worker switches from interrupt-driven to
36//! busy-poll for new requests, reducing guest exit frequency. Future storage
37//! for SCSI request processing is pooled to avoid allocation on the hot path.
38
39#![expect(missing_docs)]
40#![forbid(unsafe_code)]
41
42#[cfg(feature = "ioperf")]
43pub mod ioperf;
44
45#[cfg(feature = "test")]
46pub mod test_helpers;
47
48#[cfg(not(feature = "test"))]
49mod test_helpers;
50
51pub mod resolver;
52mod save_restore;
53
54use crate::ring::gparange::MultiPagedRangeBuf;
55use anyhow::Context as _;
56use async_trait::async_trait;
57use fast_select::FastSelect;
58use futures::FutureExt;
59use futures::StreamExt;
60use futures::select_biased;
61use guestmem::AccessError;
62use guestmem::GuestMemory;
63use guestmem::MemoryRead;
64use guestmem::MemoryWrite;
65use guestmem::ranges::PagedRange;
66use guid::Guid;
67use inspect::Inspect;
68use inspect::InspectMut;
69use inspect_counters::Counter;
70use inspect_counters::Histogram;
71use oversized_box::OversizedBox;
72use parking_lot::Mutex;
73use parking_lot::RwLock;
74use ring::OutgoingPacketType;
75use scsi::AdditionalSenseCode;
76use scsi::ScsiOp;
77use scsi::ScsiStatus;
78use scsi::srb::SrbStatus;
79use scsi::srb::SrbStatusAndFlags;
80use scsi_buffers::RequestBuffers;
81use scsi_core::AsyncScsiDisk;
82use scsi_core::Request;
83use scsi_core::ScsiResult;
84use scsi_defs as scsi;
85use scsidisk::illegal_request_sense;
86use slab::Slab;
87use std::collections::hash_map::Entry;
88use std::collections::hash_map::HashMap;
89use std::fmt::Debug;
90use std::future::Future;
91use std::future::poll_fn;
92use std::pin::Pin;
93use std::sync::Arc;
94use std::sync::atomic::AtomicU32;
95use std::sync::atomic::Ordering::Relaxed;
96use std::task::Context;
97use std::task::Poll;
98use storvsp_resources::ScsiPath;
99use task_control::AsyncRun;
100use task_control::InspectTask;
101use task_control::StopTask;
102use task_control::TaskControl;
103use thiserror::Error;
104use tracing_helpers::ErrorValueExt;
105use unicycle::FuturesUnordered;
106use vmbus_async::queue;
107use vmbus_async::queue::ExternalDataError;
108use vmbus_async::queue::IncomingPacket;
109use vmbus_async::queue::OutgoingPacket;
110use vmbus_async::queue::Queue;
111use vmbus_channel::RawAsyncChannel;
112use vmbus_channel::bus::ChannelType;
113use vmbus_channel::bus::OfferParams;
114use vmbus_channel::bus::OpenRequest;
115use vmbus_channel::channel::ChannelControl;
116use vmbus_channel::channel::ChannelOpenError;
117use vmbus_channel::channel::DeviceResources;
118use vmbus_channel::channel::RestoreControl;
119use vmbus_channel::channel::SaveRestoreVmbusDevice;
120use vmbus_channel::channel::VmbusDevice;
121use vmbus_channel::gpadl_ring::GpadlRingMem;
122use vmbus_channel::gpadl_ring::gpadl_channel;
123use vmbus_core::protocol::UserDefinedData;
124use vmbus_ring as ring;
125use vmbus_ring::RingMem;
126use vmcore::save_restore::RestoreError;
127use vmcore::save_restore::SaveError;
128use vmcore::save_restore::SavedStateBlob;
129use vmcore::vm_task::VmTaskDriver;
130use vmcore::vm_task::VmTaskDriverSource;
131use zerocopy::FromBytes;
132use zerocopy::FromZeros;
133use zerocopy::Immutable;
134use zerocopy::IntoBytes;
135use zerocopy::KnownLayout;
136
/// The IO queue depth at which the controller switches from guest-signal-driven
/// to poll-mode operation. This optimization reduces the guest exit rate by
/// relying on (typically-interrupt-driven) IO completions to drive polling for
/// new IO requests.
///
/// This is only the default: the live value lives in the controller state's
/// `poll_mode_queue_depth` atomic, which is exposed mutably via inspect.
const DEFAULT_POLL_MODE_QUEUE_DEPTH: u32 = 1;
142
/// The VMBus synthetic SCSI controller device (StorVSP).
///
/// One instance backs one VMBus offer; per-channel work is delegated to
/// `workers` (index 0 is the primary channel, the rest are sub-channels).
pub struct StorageDevice {
    // VMBus instance ID used for the channel offer.
    instance_id: Guid,
    // NOTE(review): presumably set when this controller fronts an IDE disk;
    // appears related to the per-worker `force_path_id` — confirm against the
    // offer/open code outside this chunk.
    ide_path: Option<ScsiPath>,
    // One worker per channel; entry 0 is the primary channel.
    workers: Vec<WorkerAndDriver>,
    controller: Arc<ScsiControllerState>,
    resources: DeviceResources,
    driver_source: VmTaskDriverSource,
    max_sub_channel_count: u16,
    // Protocol negotiation state shared by all workers.
    protocol: Arc<Protocol>,
    // Maximum outstanding IOs per worker (clamped to >= 1 in `Worker::new`).
    io_queue_depth: u32,
}
154
/// A channel worker task paired with the VM task driver it runs on.
#[derive(Inspect)]
struct WorkerAndDriver {
    #[inspect(flatten)]
    worker: TaskControl<WorkerState, Worker>,
    driver: VmTaskDriver,
}
161
162struct WorkerState;
163
164impl InspectMut for StorageDevice {
165    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
166        let mut resp = req.respond();
167
168        let disks = self.controller.disks.read();
169        for (path, controller_disk) in disks.iter() {
170            resp.child(&format!("disks/{}", path), |req| {
171                controller_disk.disk.inspect(req);
172            });
173        }
174
175        resp.fields(
176            "channels",
177            self.workers
178                .iter()
179                .filter(|task| task.worker.has_state())
180                .enumerate(),
181        )
182        .field(
183            "poll_mode_queue_depth",
184            inspect::AtomicMut(&self.controller.poll_mode_queue_depth),
185        );
186    }
187}
188
/// Per-channel worker: owns the VMBus ring queue and all request bookkeeping
/// for one (sub)channel.
struct Worker<T: RingMem = GpadlRingMem> {
    inner: WorkerInner,
    // Receives a notification when the attached-disk set changes (hot
    // attach/remove), registered via `add_rescan_notification_source`.
    rescan_notification: futures::channel::mpsc::Receiver<()>,
    fast_select: FastSelect,
    queue: Queue<T>,
}
195
/// Protocol negotiation state, shared between the primary channel (which
/// drives negotiation) and sub-channels (which block until it completes).
struct Protocol {
    state: RwLock<ProtocolState>,
    /// Signaled when `state` transitions to `ProtocolState::Ready`.
    ready: event_listener::Event,
}
201
/// The worker state that is independent of the ring memory type.
struct WorkerInner {
    protocol: Arc<Protocol>,
    // On-the-wire SCSI request size for the negotiated protocol version;
    // starts at the V1 size until negotiation completes.
    request_size: usize,
    controller: Arc<ScsiControllerState>,
    // 0 for the primary channel, 1.. for sub-channels.
    channel_index: u16,
    scsi_queue: Arc<ScsiCommandQueue>,
    // In-flight SCSI operations, polled concurrently.
    scsi_requests: FuturesUnordered<ScsiRequest>,
    // Per-request state, keyed by the slab id that doubles as the request id.
    scsi_requests_states: Slab<ScsiRequestState>,
    // Pools of reusable allocations to keep the IO hot path allocation-free.
    full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
    future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
    channel_control: ChannelControl,
    max_io_queue_depth: usize,
    stats: WorkerStats,
}
216
/// Per-worker performance counters, exposed via inspect.
#[derive(Debug, Default, Inspect)]
struct WorkerStats {
    ios_submitted: Counter,
    ios_completed: Counter,
    wakes: Counter,
    // NOTE(review): presumably wakes that found no work; the counting site is
    // in the worker loop outside this chunk — confirm there.
    wakes_spurious: Counter,
    per_wake_submissions: Histogram<10>,
    per_wake_completions: Histogram<10>,
}
226
/// Supported storvsp protocol versions, oldest to newest. The discriminants
/// are exactly the wire `major_minor` values (asserted in `Version::parse`).
#[repr(u16)]
#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
enum Version {
    Win6 = storvsp_protocol::VERSION_WIN6,
    Win7 = storvsp_protocol::VERSION_WIN7,
    Win8 = storvsp_protocol::VERSION_WIN8,
    Blue = storvsp_protocol::VERSION_BLUE,
}
235
/// Error returned when the guest requests a protocol version this server
/// does not implement.
#[derive(Debug, Error)]
#[error("protocol version {0:#x} not supported")]
struct UnsupportedVersion(u16);
239
240impl Version {
241    fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
242        let version = match major_minor {
243            storvsp_protocol::VERSION_WIN6 => Self::Win6,
244            storvsp_protocol::VERSION_WIN7 => Self::Win7,
245            storvsp_protocol::VERSION_WIN8 => Self::Win8,
246            storvsp_protocol::VERSION_BLUE => Self::Blue,
247            version => return Err(UnsupportedVersion(version)),
248        };
249        assert_eq!(version as u16, major_minor);
250        Ok(version)
251    }
252
253    fn max_request_size(&self) -> usize {
254        match self {
255            Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
256            Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
257        }
258    }
259}
260
/// Overall protocol negotiation state for the device.
#[derive(Copy, Clone)]
enum ProtocolState {
    /// Negotiation in progress; see [`InitState`] for the current step.
    Init(InitState),
    /// Negotiation complete; IO may flow on all channels.
    Ready {
        version: Version,
        subchannel_count: u16,
    },
}
269
/// Steps of protocol negotiation, in the order the guest drives them:
/// BEGIN_INITIALIZATION, QUERY_PROTOCOL_VERSION, QUERY_PROPERTIES,
/// END_INITIALIZATION.
#[derive(Copy, Clone, Debug)]
enum InitState {
    Begin,
    QueryVersion,
    QueryProperties {
        version: Version,
    },
    EndInitialization {
        version: Version,
        // NOTE(review): appears to be populated if the guest sends
        // CREATE_SUB_CHANNELS; the transition code is outside this chunk —
        // confirm there.
        subchannel_count: Option<u16>,
    },
}
282
/// The internal SCSI operation future type.
///
/// This is a boxed future of a large pre-determined size. The box is reused
/// after a SCSI request completes to avoid allocations in the hot path.
///
/// An Option type is used so that the future can be efficiently dropped (via
/// `Pin::set(x, None)`) before it is stashed away for reuse.
type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];
type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;

/// The amount of space reserved for a ScsiOpFuture.
///
/// This was chosen by running `cargo test -p storvsp -- --no-capture` and looking at the required
/// size that was given in the failure message
// NOTE(review): the 272-byte slack on top of the disk future size is
// empirically derived (per the comment above) and may need re-measuring when
// the wrapping future grows.
const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;
298
/// An in-flight SCSI operation: a pooled future plus the slab id used to
/// locate its per-request state on completion.
struct ScsiRequest {
    request_id: usize,
    // Always `Some` until the future completes; taken out in `poll` so the
    // box's storage can be recycled.
    future: Option<ScsiOpFuture>,
}
303
304impl ScsiRequest {
305    fn new(
306        request_id: usize,
307        future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
308    ) -> Self {
309        Self {
310            request_id,
311            future: Some(future.into()),
312        }
313    }
314}
315
impl Future for ScsiRequest {
    // On completion, yields the request id, the SCSI result, and the emptied
    // box so its storage can go back into the worker's future pool.
    type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `ScsiRequest` is `Unpin` (the inner future stays pinned behind its
        // own box), so projecting out of the `Pin` is fine here.
        let this = self.get_mut();
        let future = this.future.as_mut().unwrap().as_mut();
        let result = std::task::ready!(future.poll(cx));
        // Return the future so that its storage can be reused.
        let future = this.future.take().unwrap();
        Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
    }
}
328
/// Fatal errors that terminate a channel worker's processing loop.
#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    PacketError(#[source] PacketError),
    #[error("queue error")]
    Queue(#[source] queue::Error),
    // Indicates a bookkeeping bug: ring space is reserved before writing, so
    // a subsequent full-queue failure should be impossible.
    #[error("queue should have enough space but no longer does")]
    NotEnoughSpace,
}
338
339#[derive(Debug, Error)]
340enum PacketError {
341    #[error("Not transactional")]
342    NotTransactional,
343    #[error("Unrecognized operation {0:?}")]
344    UnrecognizedOperation(storvsp_protocol::Operation),
345    #[error("Invalid packet type")]
346    InvalidPacketType,
347    #[error("Invalid data transfer length")]
348    InvalidDataTransferLength,
349    #[error("Access error: {0}")]
350    Access(#[source] AccessError),
351    #[error("Range error")]
352    Range(#[source] ExternalDataError),
353}
354
/// Validated shape of a request's external data buffer: byte length and
/// transfer direction.
#[derive(Debug, Default, Clone)]
struct Range {
    len: usize,
    // True when the guest buffer will be written (request.data_in != 0).
    is_write: bool,
}
360
361impl Range {
362    fn new(buf: &MultiPagedRangeBuf, request: &storvsp_protocol::ScsiRequest) -> Option<Self> {
363        let len = request.data_transfer_length as usize;
364        let is_write = request.data_in != 0;
365        // Ensure there is exactly one range and it's large enough, or there are
366        // zero ranges and there is no associated SCSI buffer.
367        if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
368            return None;
369        }
370        Some(Self { len, is_write })
371    }
372
373    fn buffer<'a>(
374        &'a self,
375        buf: &'a MultiPagedRangeBuf,
376        guest_memory: &'a GuestMemory,
377    ) -> RequestBuffers<'a> {
378        let mut range = buf.first().unwrap_or_else(PagedRange::empty);
379        range.truncate(self.len);
380        RequestBuffers::new(guest_memory, range, self.is_write)
381    }
382}
383
/// A fully-parsed incoming packet: the decoded payload plus the VMBus
/// transaction id and the on-the-wire request size (echoed when sizing the
/// completion packet).
#[derive(Debug)]
struct Packet {
    data: PacketData,
    transaction_id: u64,
    request_size: usize,
}
390
/// Decoded packet payload, one variant per supported
/// `storvsp_protocol::Operation`.
#[derive(Debug)]
enum PacketData {
    BeginInitialization,
    EndInitialization,
    /// The guest-proposed protocol version (`major_minor`).
    QueryProtocolVersion(u16),
    QueryProperties,
    /// The number of sub-channels the guest is requesting.
    CreateSubChannels(u16),
    /// A SCSI request, with its (pooled) buffer and validated data range.
    ExecuteScsi(Arc<ScsiRequestAndRange>),
    ResetBus,
    ResetAdapter,
    ResetLun,
}
403
/// Error indicating a request described an invalid guest memory range.
// NOTE(review): no uses are visible in this chunk; confirm callers before
// documenting more precisely.
#[derive(Debug)]
pub struct RangeError;
406
/// Parses one incoming ring packet into a [`Packet`].
///
/// `pool` supplies reusable `ScsiRequestAndRange` allocations for the
/// `EXECUTE_SRB` hot path; a new allocation is made only when the pool is
/// empty.
fn parse_packet<T: RingMem>(
    packet: &IncomingPacket<'_, T>,
    pool: &mut Vec<Arc<ScsiRequestAndRange>>,
) -> Result<Packet, PacketError> {
    // Only data packets are expected here; completions are a protocol error.
    let packet = match packet {
        IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
        IncomingPacket::Data(packet) => packet,
    };
    let transaction_id = packet
        .transaction_id()
        .ok_or(PacketError::NotTransactional)?;

    let mut reader = packet.reader();
    let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
    // You would expect that this should be limited to the current protocol
    // version's maximum packet size, but this is not what Hyper-V does, and
    // Linux 6.1 relies on this behavior during protocol initialization.
    let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
    let data = match header.operation {
        storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
        storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
        storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
            let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
            reader
                .read(version.as_mut_bytes())
                .map_err(PacketError::Access)?;
            PacketData::QueryProtocolVersion(version.major_minor)
        }
        storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
        storvsp_protocol::Operation::EXECUTE_SRB => {
            // Reuse a pooled allocation when possible to keep this path
            // allocation-free.
            let mut full_request = pool.pop().unwrap_or_else(|| {
                Arc::new(ScsiRequestAndRange {
                    external_data: Range::default(),
                    external_data_buf: MultiPagedRangeBuf::new(),
                    request: storvsp_protocol::ScsiRequest::new_zeroed(),
                    request_size,
                })
            });

            {
                // Pooled entries must be uniquely owned by the time they are
                // returned to the pool, so `get_mut` cannot fail.
                let full_request = Arc::get_mut(&mut full_request).unwrap();
                // NOTE(review): for a pooled entry, `request_size` and any
                // request bytes beyond the freshly-read prefix retain values
                // from a previous use; confirm nothing downstream reads
                // `ScsiRequestAndRange::request_size` (completions use
                // `Packet::request_size` instead).
                let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
                reader.read(request_buf).map_err(PacketError::Access)?;

                full_request.external_data_buf.clear();
                packet
                    .read_external_ranges(&mut full_request.external_data_buf)
                    .map_err(PacketError::Range)?;

                full_request.external_data =
                    Range::new(&full_request.external_data_buf, &full_request.request)
                        .ok_or(PacketError::InvalidDataTransferLength)?;
            }

            PacketData::ExecuteScsi(full_request)
        }
        storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
        storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
        storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
        storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
            let mut sub_channel_count: u16 = 0;
            reader
                .read(sub_channel_count.as_mut_bytes())
                .map_err(PacketError::Access)?;
            PacketData::CreateSubChannels(sub_channel_count)
        }
        _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
    };

    // SCSI requests are the hot path, so log them at trace rather than debug.
    if let PacketData::ExecuteScsi(_) = data {
        tracing::trace!(transaction_id, ?data, "parse_packet");
    } else {
        tracing::debug!(transaction_id, ?data, "parse_packet");
    }

    Ok(Packet {
        data,
        request_size,
        transaction_id,
    })
}
488
impl WorkerInner {
    /// Writes one storvsp packet (header + payload) to the ring.
    ///
    /// The payload is zero-padded (or truncated) so the total packet size is
    /// exactly `size_of::<storvsp_protocol::Packet>() + request_size`.
    fn send_vmbus_packet<M: RingMem>(
        &mut self,
        writer: &mut queue::WriteBatch<'_, M>,
        packet_type: OutgoingPacketType<'_>,
        request_size: usize,
        transaction_id: u64,
        operation: storvsp_protocol::Operation,
        status: storvsp_protocol::NtStatus,
        payload: &[u8],
    ) -> Result<(), WorkerError> {
        let header = storvsp_protocol::Packet {
            operation,
            flags: 0,
            status,
        };

        let packet_size = size_of_val(&header) + request_size;

        // Zero pad or truncate the payload to the queue's packet size. This is
        // necessary because Windows guests check that each packet's size is
        // exactly the largest possible packet size for the negotiated protocol
        // version.
        let len = size_of_val(&header) + size_of_val(payload);
        let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
        let (payload_bytes, padding_bytes) = if len > packet_size {
            (&payload[..packet_size - size_of_val(&header)], &[][..])
        } else {
            (payload, &padding[..packet_size - len])
        };
        // Sanity-check the padding arithmetic above.
        assert_eq!(
            size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
            packet_size
        );
        writer
            .try_write(&OutgoingPacket {
                transaction_id,
                packet_type,
                payload: &[header.as_bytes(), payload_bytes, padding_bytes],
            })
            .map_err(|err| match err {
                // Space was reserved before this point, so a full queue is a
                // bookkeeping bug, not a transient condition.
                queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
            })
    }

    /// Sends an unsolicited in-band packet (no completion requested,
    /// transaction id 0), sized to the negotiated `self.request_size`.
    fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
        &mut self,
        writer: &mut queue::WriteHalf<'_, M>,
        operation: storvsp_protocol::Operation,
        status: storvsp_protocol::NtStatus,
        payload: &P,
    ) -> Result<(), WorkerError> {
        self.send_vmbus_packet(
            &mut writer.batched(),
            OutgoingPacketType::InBandNoCompletion,
            self.request_size,
            0,
            operation,
            status,
            payload.as_bytes(),
        )
    }

    /// Sends the completion for `packet`, echoing its transaction id and
    /// request size.
    fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
        &mut self,
        writer: &mut queue::WriteHalf<'_, M>,
        packet: &Packet,
        status: storvsp_protocol::NtStatus,
        payload: &P,
    ) -> Result<(), WorkerError> {
        self.send_vmbus_packet(
            &mut writer.batched(),
            OutgoingPacketType::Completion,
            packet.request_size,
            packet.transaction_id,
            storvsp_protocol::Operation::COMPLETE_IO,
            status,
            payload.as_bytes(),
        )
    }
}
571
/// Shared SCSI execution context for a channel: resolves (path, target, lun)
/// to an attached disk and runs commands against it.
struct ScsiCommandQueue {
    controller: Arc<ScsiControllerState>,
    mem: GuestMemory,
    // When set, overrides the guest-supplied path id for disk lookup (but not
    // for REPORT_LUNS matching; see `execute_scsi`).
    force_path_id: Option<u8>,
}
577
impl ScsiCommandQueue {
    /// Executes one SCSI request and returns its result.
    ///
    /// Commands addressed to an attached disk are delegated to its
    /// `AsyncScsiDisk`. `REPORT_LUNS` is always answered here; `INQUIRY` for
    /// an absent LUN returns "logical unit not present"; any other command
    /// for an absent LUN fails with `INVALID_LUN`.
    async fn execute_scsi(&self, full_request: &ScsiRequestAndRange) -> ScsiResult {
        let request = &full_request.request;
        // The opcode is the first CDB byte.
        let op = ScsiOp(request.payload[0]);
        let external_data = full_request
            .external_data
            .buffer(&full_request.external_data_buf, &self.mem);

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            "execute_scsi start...",
        );

        // Disk lookup honors the forced path id, when configured.
        let path_id = self.force_path_id.unwrap_or(request.path_id);

        let controller_disk = self
            .controller
            .disks
            .read()
            .get(&ScsiPath {
                path: path_id,
                target: request.target_id,
                lun: request.lun,
            })
            .cloned();

        let result = match op {
            ScsiOp::REPORT_LUNS => {
                const HEADER_SIZE: usize = size_of::<scsi::LunList>();
                // Collect the LUNs on this (path, target).
                let mut luns: Vec<u8> = self
                    .controller
                    .disks
                    .read()
                    .iter()
                    .flat_map(|(path, _)| {
                        // Use the original path ID and not the forced one to
                        // match Hyper-V storvsp behavior.
                        if request.path_id == path.path && request.target_id == path.target {
                            Some(path.lun)
                        } else {
                            None
                        }
                    })
                    .collect();
                luns.sort_unstable();
                // One u64 per LUN entry plus one for the 8-byte header; the
                // header occupies data[0].
                let mut data: Vec<u64> = vec![0; luns.len() + 1];
                let header = scsi::LunList {
                    length: (luns.len() as u32 * 8).into(),
                    reserved: [0; 4],
                };
                data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
                // Encode each LUN big-endian in the first two bytes of its
                // 8-byte entry.
                for (i, lun) in luns.iter().enumerate() {
                    data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
                }
                if external_data.len() >= HEADER_SIZE {
                    // Copy as much of the list as the guest buffer allows.
                    let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
                    external_data.writer().write(&data.as_bytes()[..tx]).map_or(
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        },
                        |_| ScsiResult {
                            scsi_status: ScsiStatus::GOOD,
                            srb_status: SrbStatus::SUCCESS,
                            tx,
                            sense_data: None,
                        },
                    )
                } else {
                    // Buffer too small for even the header: succeed with no
                    // data transferred.
                    ScsiResult {
                        scsi_status: ScsiStatus::GOOD,
                        srb_status: SrbStatus::SUCCESS,
                        tx: 0,
                        sense_data: None,
                    }
                }
            }
            // A disk is attached at this path: delegate the CDB to it.
            _ if controller_disk.is_some() => {
                let mut cdb = [0; 16];
                cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
                controller_disk
                    .unwrap()
                    .disk
                    .execute_scsi(
                        &external_data,
                        &Request {
                            cdb,
                            srb_flags: request.srb_flags,
                        },
                    )
                    .await
            }
            // INQUIRY for an absent LUN: report "logical unit not present".
            ScsiOp::INQUIRY => {
                let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
                    .unwrap()
                    .0; // TODO: zerocopy: ref-from-prefix: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
                if external_data.len() < cdb.allocation_length.get() as usize
                    || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
                    || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
                {
                    ScsiResult {
                        scsi_status: ScsiStatus::CHECK_CONDITION,
                        srb_status: SrbStatus::INVALID_REQUEST,
                        tx: 0,
                        sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
                    }
                } else {
                    let enable_vpd = cdb.flags.vpd();
                    if enable_vpd || cdb.page_code != 0 {
                        // cannot support VPD inquiry for non-existing device (lun).
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        }
                    } else {
                        const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
                        let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
                        data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;

                        if request.lun != 0 {
                            // Below fields are only set for lun0 inquiry so zero out here.
                            data.vendor_id = [0; 8];
                            data.product_id = [0; 16];
                            data.product_revision_level = [0; 4];
                        }

                        let datab = data.as_bytes();
                        let tx = std::cmp::min(
                            cdb.allocation_length.get() as usize,
                            size_of::<scsi::InquiryData>(),
                        );
                        external_data.writer().write(&datab[..tx]).map_or(
                            ScsiResult {
                                scsi_status: ScsiStatus::CHECK_CONDITION,
                                srb_status: SrbStatus::INVALID_REQUEST,
                                tx: 0,
                                sense_data: Some(illegal_request_sense(
                                    AdditionalSenseCode::INVALID_CDB,
                                )),
                            },
                            |_| ScsiResult {
                                scsi_status: ScsiStatus::GOOD,
                                srb_status: SrbStatus::SUCCESS,
                                tx,
                                sense_data: None,
                            },
                        )
                    }
                }
            }
            // Anything else for an absent LUN fails.
            _ => ScsiResult {
                scsi_status: ScsiStatus::CHECK_CONDITION,
                srb_status: SrbStatus::INVALID_LUN,
                tx: 0,
                sense_data: None,
            },
        };

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            result = ?result,
            "execute_scsi completed.",
        );
        result
    }
}
758
759impl<T: RingMem + 'static> Worker<T> {
760    fn new(
761        controller: Arc<ScsiControllerState>,
762        channel: RawAsyncChannel<T>,
763        channel_index: u16,
764        mem: GuestMemory,
765        channel_control: ChannelControl,
766        io_queue_depth: u32,
767        protocol: Arc<Protocol>,
768        force_path_id: Option<u8>,
769    ) -> anyhow::Result<Self> {
770        let queue = Queue::new(channel)?;
771        #[expect(clippy::disallowed_methods)] // TODO
772        let (source, target) = futures::channel::mpsc::channel(1);
773        controller.add_rescan_notification_source(source);
774
775        let max_io_queue_depth = io_queue_depth.max(1) as usize;
776        Ok(Self {
777            inner: WorkerInner {
778                protocol,
779                request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
780                controller: controller.clone(),
781                channel_index,
782                scsi_queue: Arc::new(ScsiCommandQueue {
783                    controller,
784                    mem,
785                    force_path_id,
786                }),
787                scsi_requests: FuturesUnordered::new(),
788                scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
789                channel_control,
790                max_io_queue_depth,
791                future_pool: Vec::new(),
792                full_request_pool: Vec::new(),
793                stats: Default::default(),
794            },
795            queue,
796            rescan_notification: target,
797            fast_select: FastSelect::new(),
798        })
799    }
800
801    async fn wait_for_scsi_requests_complete(&mut self) {
802        tracing::debug!(
803            channel_index = self.inner.channel_index,
804            "wait for IOs completed..."
805        );
806        while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
807            self.inner.scsi_requests_states.remove(id);
808        }
809    }
810}
811
812impl InspectTask<Worker> for WorkerState {
813    fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
814        if let Some(worker) = worker {
815            let mut resp = req.respond();
816            if worker.inner.channel_index == 0 {
817                let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
818                    ProtocolState::Init(state) => match state {
819                        InitState::Begin => ("begin_init", None, None),
820                        InitState::QueryVersion => ("query_version", None, None),
821                        InitState::QueryProperties { version } => {
822                            ("query_properties", Some(version), None)
823                        }
824                        InitState::EndInitialization {
825                            version,
826                            subchannel_count,
827                        } => ("end_init", Some(version), subchannel_count),
828                    },
829                    ProtocolState::Ready {
830                        version,
831                        subchannel_count,
832                    } => ("ready", Some(version), Some(subchannel_count)),
833                };
834                resp.field("state", state)
835                    .field("version", version)
836                    .field("subchannel_count", subchannel_count);
837            }
838            resp.field("pending_packets", worker.inner.scsi_requests_states.len())
839                .fields("io", worker.inner.scsi_requests_states.iter())
840                .field("stats", &worker.inner.stats)
841                .field("ring", &worker.queue)
842                .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
843        }
844    }
845}
846
impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
    /// Runs the worker until `stop` is signaled.
    ///
    /// The primary channel runs the protocol state machine; subchannels block
    /// until negotiation completes, then process I/O directly.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        let fut = async {
            if worker.inner.channel_index == 0 {
                worker.process_primary().await
            } else {
                // Wait for initialization to end before processing any
                // subchannel packets.
                let protocol_version = loop {
                    // Register the listener *before* checking the state, so a
                    // notification arriving between the check and the await
                    // cannot be lost.
                    let listener = worker.inner.protocol.ready.listen();
                    if let ProtocolState::Ready { version, .. } =
                        *worker.inner.protocol.state.read()
                    {
                        break version;
                    }
                    tracing::debug!("subchannel waiting for initialization to end");
                    listener.await
                };
                worker
                    .inner
                    .process_ready(&mut worker.queue, protocol_version)
                    .await
            }
        };

        // Processing errors are logged, not propagated: the worker returns
        // `Err` only when cancelled via `stop`.
        match stop.until_stopped(fut).await? {
            Ok(_) => {}
            Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
        }
        Ok(())
    }
}
883
impl WorkerInner {
    /// Awaits the next incoming packet, without checking for any other events
    /// (device add/remove notifications or available completions).
    ///
    /// This only reads and parses the packet (reusing pooled
    /// `ScsiRequestAndRange` allocations); bookkeeping of outstanding
    /// requests happens later, in `push_scsi_request`.
    async fn next_packet<'a, M: RingMem>(
        &mut self,
        reader: &'a mut queue::ReadHalf<'a, M>,
    ) -> Result<Packet, WorkerError> {
        let packet = reader.read().await.map_err(WorkerError::Queue)?;
        let stor_packet =
            parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
        Ok(stor_packet)
    }

    /// Polls for enough ring space in the outgoing ring to send a packet.
    ///
    /// This is used to ensure there is enough space in the ring before
    /// committing to sending a packet. This avoids the need to save pending
    /// packets on the side if queue processing is interrupted while the ring is
    /// full.
    ///
    /// Reserves `MAX_VMBUS_PACKET_SIZE`, the worst-case outgoing packet size.
    fn poll_for_ring_space<M: RingMem>(
        &mut self,
        cx: &mut Context<'_>,
        writer: &mut queue::WriteHalf<'_, M>,
    ) -> Poll<Result<(), WorkerError>> {
        writer
            .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
            .map_err(WorkerError::Queue)
    }
}
913
/// Worst-case size of an outgoing in-band VMBus packet: the protocol header
/// plus the largest SCSI request payload. Ring space for this amount is
/// reserved before committing to send any packet.
const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
    size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
);
917
impl<T: RingMem> Worker<T> {
    /// Processes the protocol state machine.
    ///
    /// Runs on the primary channel only. While the protocol is in
    /// `ProtocolState::Init`, each iteration handles exactly one handshake
    /// packet and advances `InitState`; once the state reaches `Ready`, this
    /// switches to I/O processing and returns only on error.
    async fn process_primary(&mut self) -> Result<(), WorkerError> {
        loop {
            let current_state = *self.inner.protocol.state.read();
            match current_state {
                ProtocolState::Ready { version, .. } => {
                    // Process I/O, interleaving bus-rescan notifications from
                    // the controller.
                    break loop {
                        select_biased! {
                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
                                // ENUMERATE_BUS only exists in the Win7 and
                                // later protocol versions.
                                if version >= Version::Win7
                                {
                                    tracing::debug!("rescan notification received, sending ENUMERATE_BUS");
                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
                                }
                            }
                        }
                    };
                }
                ProtocolState::Init(state) => {
                    let (mut reader, mut writer) = self.queue.split();

                    // Ensure that subsequent calls to `send_completion` won't
                    // fail due to lack of ring space, to avoid keeping (and saving/restoring) interim states.
                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;

                    tracing::debug!(?state, "process_primary");
                    match state {
                        // Expect BEGIN_INITIALIZATION first; anything else is
                        // rejected with INVALID_DEVICE_STATE.
                        InitState::Begin => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::BeginInitialization = packet.data {
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &(),
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::QueryVersion);
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        // Negotiate the protocol version proposed by the guest.
                        InitState::QueryVersion => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
                                if let Ok(version) = Version::parse(major_minor) {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    self.inner.request_size = version.max_request_size();
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryProperties { version });

                                    tracelimit::info_ratelimited!(
                                        ?version,
                                        "scsi version negotiated"
                                    );
                                } else {
                                    // Unknown version: reject it and remain in
                                    // QueryVersion so the guest can propose
                                    // another one.
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryVersion);
                                }
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        // Report channel properties: transfer size limit and
                        // subchannel capability.
                        InitState::QueryProperties { version } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProperties = packet.data {
                                // Subchannel support starts with the Win8
                                // protocol version.
                                let multi_channel_supported = version >= Version::Win8;

                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &storvsp_protocol::ChannelProperties {
                                        max_transfer_bytes: 0x40000, // 256KB
                                        flags: {
                                            if multi_channel_supported {
                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
                                            } else {
                                                0
                                            }
                                        },
                                        maximum_sub_channel_count: if multi_channel_supported {
                                            self.inner.channel_control.max_subchannels()
                                        } else {
                                            0
                                        },
                                        reserved: 0,
                                        reserved2: 0,
                                        reserved3: [0, 0],
                                    },
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::EndInitialization {
                                        version,
                                        // `None` means the guest may still send
                                        // CREATE_SUB_CHANNELS before ending
                                        // initialization.
                                        subchannel_count: if multi_channel_supported {
                                            None
                                        } else {
                                            Some(0)
                                        },
                                    });
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        // Accept an optional CREATE_SUB_CHANNELS (at most one,
                        // guarded by `subchannel_count.is_none()`) followed by
                        // END_INITIALIZATION.
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            match packet.data {
                                PacketData::CreateSubChannels(sub_channel_count)
                                    if subchannel_count.is_none() =>
                                {
                                    if let Err(err) = self
                                        .inner
                                        .channel_control
                                        .enable_subchannels(sub_channel_count)
                                    {
                                        tracelimit::warn_ratelimited!(
                                            ?err,
                                            "cannot enable subchannels"
                                        );
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
                                            &(),
                                        )?;
                                    } else {
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::SUCCESS,
                                            &(),
                                        )?;
                                        *self.inner.protocol.state.write() =
                                            ProtocolState::Init(InitState::EndInitialization {
                                                version,
                                                subchannel_count: Some(sub_channel_count),
                                            });
                                    }
                                }
                                PacketData::EndInitialization => {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &(),
                                    )?;
                                    // Reset the rescan notification event now, before the guest has a
                                    // chance to send any SCSI requests to scan the bus.
                                    self.rescan_notification.try_next().ok();
                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
                                        version,
                                        subchannel_count: subchannel_count.unwrap_or(0),
                                    };
                                    // Wake up subchannels waiting for the
                                    // protocol state to become ready.
                                    self.inner.protocol.ready.notify(usize::MAX);
                                }
                                _ => {
                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                        &(),
                                    )?;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
1132
1133fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
1134    match srb_status {
1135        SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
1136        SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
1137        SrbStatus::INVALID_LUN
1138        | SrbStatus::INVALID_TARGET_ID
1139        | SrbStatus::NO_DEVICE
1140        | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
1141        SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
1142        SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
1143        SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
1144            storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
1145        }
1146        SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
1147        SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
1148        SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
1149        _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
1150    }
1151}
1152
impl WorkerInner {
    /// Processes packets and SCSI completions after protocol negotiation has finished.
    async fn process_ready<M: RingMem>(
        &mut self,
        queue: &mut Queue<M>,
        protocol_version: Version,
    ) -> Result<(), WorkerError> {
        // Subchannels call this directly, so derive the request size from the
        // negotiated version here rather than relying on the primary channel
        // having set it on this worker.
        self.request_size = protocol_version.max_request_size();
        poll_fn(|cx| self.poll_process_ready(cx, queue)).await
    }

    /// Processes packets and SCSI completions after protocol negotiation has finished.
    ///
    /// Never resolves with `Ok`: it returns `Poll::Pending` when all pending
    /// work is drained, and `Poll::Ready(Err(..))` only on failure.
    fn poll_process_ready<M: RingMem>(
        &mut self,
        cx: &mut Context<'_>,
        queue: &mut Queue<M>,
    ) -> Poll<Result<(), WorkerError>> {
        self.stats.wakes.increment();

        let (mut reader, mut writer) = queue.split();
        let mut total_completions = 0;
        let mut total_submissions = 0;
        let poll_mode_queue_depth = self.controller.poll_mode_queue_depth.load(Relaxed) as usize;

        loop {
            // Drive IOs forward and collect completions.
            'outer: while !self.scsi_requests_states.is_empty() {
                {
                    let mut batch = writer.batched();
                    loop {
                        // Ensure there is room for the completion before consuming
                        // the IO so that we don't have to track completed IOs whose
                        // completions haven't been sent.
                        if !batch
                            .can_write(MAX_VMBUS_PACKET_SIZE)
                            .map_err(WorkerError::Queue)?
                        {
                            // This batch is full but there may still be more completions.
                            break;
                        }
                        if let Poll::Ready(Some((request_id, result, future))) =
                            self.scsi_requests.poll_next_unpin(cx)
                        {
                            // Return the future's storage to the pool so later
                            // requests avoid a fresh allocation.
                            self.future_pool.push(future);
                            self.handle_completion(&mut batch, request_id, result)?;
                            total_completions += 1;
                        } else {
                            tracing::trace!("out of completions");
                            break 'outer;
                        }
                    }
                }

                // Wait for enough space for any completion packets.
                if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                    tracing::trace!("out of ring space");
                    break;
                }
            }

            let mut submissions = 0;
            // Process new requests.
            'outer: loop {
                if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                    break;
                }
                // Below `poll_mode_queue_depth` pending IOs, arm the ring
                // interrupt (`poll_read_batch`); at or above it, busy-poll
                // with interrupts masked to reduce guest exits.
                let mut batch = if self.scsi_requests_states.len() < poll_mode_queue_depth {
                    if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
                        batch.map_err(WorkerError::Queue)?
                    } else {
                        tracing::trace!("out of incoming packets");
                        break;
                    }
                } else {
                    match reader.try_read_batch() {
                        Ok(batch) => batch,
                        Err(queue::TryReadError::Empty) => {
                            tracing::trace!(
                                pending_io_count = self.scsi_requests_states.len(),
                                "out of incoming packets, keeping interrupts masked"
                            );
                            break;
                        }
                        Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
                    }
                };

                let mut packets = batch.packets();
                loop {
                    if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                        break 'outer;
                    }
                    // Wait for enough space for any completion packets that
                    // `handle_packet` may need to send, so that it isn't necessary
                    // to track pending completions.
                    if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                        tracing::trace!("out of ring space");
                        break 'outer;
                    }

                    let packet = if let Some(packet) = packets.next() {
                        packet.map_err(WorkerError::Queue)?
                    } else {
                        break;
                    };

                    if self.handle_packet(&mut writer, &packet)? {
                        submissions += 1;
                    }
                }
            }

            // Loop around to poll the IOs if any new IOs were submitted.
            if submissions == 0 {
                // No need to poll again.
                break;
            }
            total_submissions += submissions;
        }

        if total_submissions != 0 || total_completions != 0 {
            self.stats.ios_submitted.add(total_submissions);
            self.stats
                .per_wake_submissions
                .add_sample(total_submissions);
            self.stats
                .per_wake_completions
                .add_sample(total_completions);
            self.stats.ios_completed.add(total_completions);
        } else {
            self.stats.wakes_spurious.increment();
        }

        Poll::Pending
    }

    /// Sends the completion packet for a finished SCSI request and recycles
    /// its pooled request allocation.
    fn handle_completion<M: RingMem>(
        &mut self,
        writer: &mut queue::WriteBatch<'_, M>,
        request_id: usize,
        result: ScsiResult,
    ) -> Result<(), WorkerError> {
        let state = self.scsi_requests_states.remove(request_id);
        let request_size = state.request.request_size;

        // Push the request into the pool to avoid reallocating later.
        // Invariant: by completion time this must be the sole reference to the
        // request (the executing future has released its clone).
        assert_eq!(
            Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
            1
        );
        self.full_request_pool.push(state.request);

        let status = convert_srb_status_to_nt_status(result.srb_status);
        // Fixed-size (0x14-byte) sense buffer carried in the completion
        // payload; left zeroed when there is no sense data.
        let mut payload = [0; 0x14];
        if let Some(sense) = result.sense_data {
            payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
            tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
        };
        let response = storvsp_protocol::ScsiRequest {
            length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
            scsi_status: result.scsi_status,
            srb_status: SrbStatusAndFlags::new()
                .with_status(result.srb_status)
                .with_autosense_valid(result.sense_data.is_some()),
            data_transfer_length: result.tx as u32,
            cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
            sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
            payload,
            ..storvsp_protocol::ScsiRequest::new_zeroed()
        };
        self.send_vmbus_packet(
            writer,
            OutgoingPacketType::Completion,
            request_size,
            state.transaction_id,
            storvsp_protocol::Operation::COMPLETE_IO,
            status,
            response.as_bytes(),
        )?;
        Ok(())
    }

    /// Handles one incoming packet in the ready state.
    ///
    /// Returns `Ok(true)` if the packet submitted a new SCSI I/O (its
    /// completion will be sent later), `Ok(false)` if it was completed inline.
    fn handle_packet<M: RingMem>(
        &mut self,
        writer: &mut queue::WriteHalf<'_, M>,
        packet: &IncomingPacket<'_, M>,
    ) -> Result<bool, WorkerError> {
        let packet =
            parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
        let submitted_io = match packet.data {
            PacketData::ExecuteScsi(request) => {
                self.push_scsi_request(packet.transaction_id, request);
                true
            }
            PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
                // These operations have always been no-ops.
                self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
                false
            }
            // Post-initialization subchannel creation is only valid on the
            // primary channel; on subchannels it falls through to the
            // rejection arm below.
            PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
                if let Err(err) = self
                    .channel_control
                    .enable_subchannels(new_subchannel_count)
                {
                    tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
                    self.send_completion(
                        writer,
                        &packet,
                        storvsp_protocol::NtStatus::INVALID_PARAMETER,
                        &(),
                    )?;
                    false
                } else {
                    // Update the subchannel count in the protocol state for save.
                    if let ProtocolState::Ready {
                        subchannel_count, ..
                    } = &mut *self.protocol.state.write()
                    {
                        *subchannel_count = new_subchannel_count;
                    } else {
                        unreachable!()
                    }

                    self.send_completion(
                        writer,
                        &packet,
                        storvsp_protocol::NtStatus::SUCCESS,
                        &(),
                    )?;
                    false
                }
            }
            _ => {
                tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
                self.send_completion(
                    writer,
                    &packet,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                    &(),
                )?;
                false
            }
        };
        Ok(submitted_io)
    }

    /// Starts asynchronous execution of a SCSI request: registers it in
    /// `scsi_requests_states` and pushes its future onto `scsi_requests`.
    fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
        let scsi_queue = self.scsi_queue.clone();
        let scsi_request_state = ScsiRequestState {
            transaction_id,
            request: full_request.clone(),
        };
        let request_id = self.scsi_requests_states.insert(scsi_request_state);
        // Reuse pooled future storage when available to avoid a heap
        // allocation on the hot path.
        let future = self
            .future_pool
            .pop()
            .unwrap_or_else(|| OversizedBox::new(()));
        let future = OversizedBox::refill(future, async move {
            scsi_queue.execute_scsi(full_request.as_ref()).await
        });
        let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
        self.scsi_requests.push(request);
    }
}
1417
impl<T: RingMem> Drop for Worker<T> {
    fn drop(&mut self) {
        // Unregister this worker's rescan notification receiver so the
        // controller stops sending notifications to a dropped channel.
        self.inner
            .controller
            .remove_rescan_notification_source(&self.rescan_notification);
    }
}
1425
/// Error returned when attaching a disk at a [`ScsiPath`] that already has a
/// disk attached.
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathInUse(pub ScsiPath);
1429
/// Error returned when removing a disk from a [`ScsiPath`] that has no disk
/// attached.
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathNotInUse(ScsiPath);
1433
/// Per-request bookkeeping kept while a SCSI I/O is outstanding.
#[derive(Clone)]
struct ScsiRequestState {
    // VMBus transaction ID, echoed in the completion packet to correlate it
    // with the guest's request.
    transaction_id: u64,
    // The parsed request, shared with the executing future; cloning this
    // state only bumps the `Arc` refcount.
    request: Arc<ScsiRequestAndRange>,
}
1439
/// A parsed SCSI request together with its guest-memory data ranges.
///
/// Instances are recycled through `full_request_pool` to avoid per-I/O
/// allocation on the hot path.
#[derive(Debug)]
struct ScsiRequestAndRange {
    // Guest memory range for the request's data transfer.
    external_data: Range,
    // Backing buffer for the paged range(s) — presumably reused via the pool;
    // TODO confirm.
    external_data_buf: MultiPagedRangeBuf,
    // The protocol-level request header and payload.
    request: storvsp_protocol::ScsiRequest,
    // Size of the request as received; echoed back when sending the
    // completion.
    request_size: usize,
}
1447
1448impl Inspect for ScsiRequestState {
1449    fn inspect(&self, req: inspect::Request<'_>) {
1450        req.respond()
1451            .field("transaction_id", self.transaction_id)
1452            .display(
1453                "address",
1454                &ScsiPath {
1455                    path: self.request.request.path_id,
1456                    target: self.request.request.target_id,
1457                    lun: self.request.request.lun,
1458                },
1459            )
1460            .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1461    }
1462}
1463
1464impl StorageDevice {
    /// Returns a new SCSI device.
    ///
    /// `max_sub_channel_count` is the number of VMBus subchannels offered in
    /// addition to the primary channel; `io_queue_depth` bounds concurrently
    /// outstanding SCSI requests.
    pub fn build_scsi(
        driver_source: &VmTaskDriverSource,
        controller: &ScsiController,
        instance_id: Guid,
        max_sub_channel_count: u16,
        io_queue_depth: u32,
    ) -> Self {
        Self::build_inner(
            driver_source,
            controller,
            instance_id,
            None,
            max_sub_channel_count,
            io_queue_depth,
        )
    }
1482
1483    /// Returns a new SCSI device for implementing an IDE accelerator channel
1484    /// for IDE device `device_id` on channel `channel_id`.
1485    pub fn build_ide(
1486        driver_source: &VmTaskDriverSource,
1487        channel_id: u8,
1488        device_id: u8,
1489        disk: ScsiControllerDisk,
1490        io_queue_depth: u32,
1491    ) -> Self {
1492        let path = ScsiPath {
1493            path: channel_id,
1494            target: device_id,
1495            lun: 0,
1496        };
1497
1498        let controller = ScsiController::new();
1499        controller.attach(path, disk).unwrap();
1500
1501        // Construct the specific GUID that drivers in the guest expect for this
1502        // IDE device.
1503        let instance_id = Guid {
1504            data1: channel_id.into(),
1505            data2: device_id.into(),
1506            data3: 0x8899,
1507            data4: [0; 8],
1508        };
1509        Self::build_inner(
1510            driver_source,
1511            &controller,
1512            instance_id,
1513            Some(path),
1514            0,
1515            io_queue_depth,
1516        )
1517    }
1518
1519    fn build_inner(
1520        driver_source: &VmTaskDriverSource,
1521        controller: &ScsiController,
1522        instance_id: Guid,
1523        ide_path: Option<ScsiPath>,
1524        max_sub_channel_count: u16,
1525        io_queue_depth: u32,
1526    ) -> Self {
1527        let workers = (0..max_sub_channel_count + 1)
1528            .map(|channel_index| WorkerAndDriver {
1529                worker: TaskControl::new(WorkerState),
1530                driver: driver_source
1531                    .builder()
1532                    .target_vp(0)
1533                    .run_on_target(true)
1534                    .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1535            })
1536            .collect();
1537
1538        Self {
1539            instance_id,
1540            ide_path,
1541            workers,
1542            controller: controller.state.clone(),
1543            resources: Default::default(),
1544            max_sub_channel_count,
1545            driver_source: driver_source.clone(),
1546            protocol: Arc::new(Protocol {
1547                state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1548                ready: Default::default(),
1549            }),
1550            io_queue_depth,
1551        }
1552    }
1553
1554    fn new_worker(
1555        &mut self,
1556        open_request: &OpenRequest,
1557        channel_index: u16,
1558    ) -> anyhow::Result<&mut Worker> {
1559        let controller = self.controller.clone();
1560
1561        // VMBus doesn't provide a target VP if the channel is not using interrupts. Run on VP 0 in
1562        // that case.
1563        let target_vp = open_request.open_data.target_vp.unwrap_or_default();
1564        let driver = self
1565            .driver_source
1566            .builder()
1567            .target_vp(target_vp)
1568            .run_on_target(true)
1569            .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1570
1571        let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1572            .context("failed to create vmbus channel")?;
1573
1574        let channel_control = self.resources.channel_control.clone();
1575
1576        tracing::debug!(
1577            target_vp = open_request.open_data.target_vp,
1578            channel_index,
1579            "packet processing starting...",
1580        );
1581
1582        // Force the path ID on incoming SCSI requests to match the IDE
1583        // channel ID, since guests do not reliably set the path ID
1584        // correctly.
1585        let force_path_id = self.ide_path.map(|p| p.path);
1586
1587        let worker = Worker::new(
1588            controller,
1589            channel,
1590            channel_index,
1591            self.resources
1592                .offer_resources
1593                .guest_memory(open_request)
1594                .clone(),
1595            channel_control,
1596            self.io_queue_depth,
1597            self.protocol.clone(),
1598            force_path_id,
1599        )
1600        .map_err(RestoreError::Other)?;
1601
1602        self.workers[channel_index as usize]
1603            .driver
1604            .retarget_vp(target_vp);
1605
1606        Ok(self.workers[channel_index as usize].worker.insert(
1607            &driver,
1608            format!("storvsp worker {}-{}", self.instance_id, channel_index),
1609            worker,
1610        ))
1611    }
1612}
1613
/// A disk that can be added to a SCSI controller.
#[derive(Clone)]
pub struct ScsiControllerDisk {
    /// The backing disk implementation; shared via `Arc` so cloned handles
    /// refer to the same underlying disk.
    disk: Arc<dyn AsyncScsiDisk>,
}
1619
1620impl ScsiControllerDisk {
1621    /// Creates a new controller disk from an async SCSI disk.
1622    pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1623        Self { disk }
1624    }
1625}
1626
/// State shared between a [`ScsiController`] handle and the workers servicing
/// its channels.
struct ScsiControllerState {
    /// Attached disks, keyed by path/target/LUN.
    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
    /// One sender per registered worker; signaled on attach/remove so the
    /// guest can be told to rescan the bus.
    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
    /// Pending-I/O count at which a worker switches to poll-mode processing.
    poll_mode_queue_depth: AtomicU32,
}
1632
/// Manages the set of disks exposed by a storage device, keyed by
/// [`ScsiPath`]. Disks may be attached and removed at runtime.
pub struct ScsiController {
    // Shared with channel workers (and with clones of the controller used in
    // tests/fuzzing).
    state: Arc<ScsiControllerState>,
}
1636
1637impl ScsiController {
1638    pub fn new() -> Self {
1639        Self::new_with_poll_mode_queue_depth(None)
1640    }
1641
1642    pub fn new_with_poll_mode_queue_depth(poll_mode_queue_depth: Option<u32>) -> Self {
1643        Self {
1644            state: Arc::new(ScsiControllerState {
1645                disks: Default::default(),
1646                rescan_notification_source: Mutex::new(Vec::new()),
1647                poll_mode_queue_depth: AtomicU32::new(
1648                    poll_mode_queue_depth.unwrap_or(DEFAULT_POLL_MODE_QUEUE_DEPTH),
1649                ),
1650            }),
1651        }
1652    }
1653
1654    pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1655        match self.state.disks.write().entry(path) {
1656            Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1657            Entry::Vacant(entry) => entry.insert(disk),
1658        };
1659        for source in self.state.rescan_notification_source.lock().iter_mut() {
1660            // Ok to ignore errors here. If the channel is full a previous notification has not yet
1661            // been processed by the primary channel worker.
1662            source.try_send(()).ok();
1663        }
1664        Ok(())
1665    }
1666
1667    pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1668        match self.state.disks.write().entry(path) {
1669            Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1670            Entry::Occupied(entry) => {
1671                entry.remove();
1672            }
1673        }
1674        for source in self.state.rescan_notification_source.lock().iter_mut() {
1675            // Ok to ignore errors here. If the channel is full a previous notification has not yet
1676            // been processed by the primary channel worker.
1677            source.try_send(()).ok();
1678        }
1679        Ok(())
1680    }
1681}
1682
1683impl ScsiControllerState {
1684    fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1685        self.rescan_notification_source.lock().push(source);
1686    }
1687
1688    fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1689        let mut sources = self.rescan_notification_source.lock();
1690        if let Some(index) = sources
1691            .iter()
1692            .position(|source| source.is_connected_to(target))
1693        {
1694            sources.remove(index);
1695        }
1696    }
1697}
1698
1699#[async_trait]
1700impl VmbusDevice for StorageDevice {
1701    fn offer(&self) -> OfferParams {
1702        if let Some(path) = self.ide_path {
1703            let offer_properties = storvsp_protocol::OfferProperties {
1704                path_id: path.path,
1705                target_id: path.target,
1706                flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
1707                ..FromZeros::new_zeroed()
1708            };
1709            let mut user_defined = UserDefinedData::new_zeroed();
1710            offer_properties
1711                .write_to_prefix(&mut user_defined[..])
1712                .unwrap();
1713            OfferParams {
1714                interface_name: "ide-accel".to_owned(),
1715                instance_id: self.instance_id,
1716                interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
1717                channel_type: ChannelType::Interface { user_defined },
1718                ..Default::default()
1719            }
1720        } else {
1721            OfferParams {
1722                interface_name: "scsi".to_owned(),
1723                instance_id: self.instance_id,
1724                interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
1725                ..Default::default()
1726            }
1727        }
1728    }
1729
1730    fn max_subchannels(&self) -> u16 {
1731        self.max_sub_channel_count
1732    }
1733
1734    fn install(&mut self, resources: DeviceResources) {
1735        self.resources = resources;
1736    }
1737
1738    async fn open(
1739        &mut self,
1740        channel_index: u16,
1741        open_request: &OpenRequest,
1742    ) -> Result<(), ChannelOpenError> {
1743        tracing::debug!(channel_index, "scsi open channel");
1744        self.new_worker(open_request, channel_index)?;
1745        self.workers[channel_index as usize].worker.start();
1746        Ok(())
1747    }
1748
1749    async fn close(&mut self, channel_index: u16) {
1750        tracing::debug!(channel_index, "scsi close channel");
1751        let worker = &mut self.workers[channel_index as usize].worker;
1752        worker.stop().await;
1753        if worker.state_mut().is_some() {
1754            worker
1755                .state_mut()
1756                .unwrap()
1757                .wait_for_scsi_requests_complete()
1758                .await;
1759            worker.remove();
1760        }
1761        if channel_index == 0 {
1762            *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
1763        }
1764    }
1765
1766    async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
1767        self.workers[channel_index as usize]
1768            .driver
1769            .retarget_vp(target_vp);
1770    }
1771
1772    fn start(&mut self) {
1773        for task in self
1774            .workers
1775            .iter_mut()
1776            .filter(|task| task.worker.has_state() && !task.worker.is_running())
1777        {
1778            task.worker.start();
1779        }
1780    }
1781
1782    async fn stop(&mut self) {
1783        tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
1784        for task in self
1785            .workers
1786            .iter_mut()
1787            .filter(|task| task.worker.has_state() && task.worker.is_running())
1788        {
1789            task.worker.stop().await;
1790        }
1791    }
1792
1793    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1794        Some(self)
1795    }
1796}
1797
#[async_trait]
impl SaveRestoreVmbusDevice for StorageDevice {
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
        // Inherent methods take precedence over trait methods in Rust, so
        // `self.save()` here calls the inherent implementation (see the
        // save_restore module) rather than recursing into this trait method.
        Ok(SavedStateBlob::new(self.save()?))
    }

    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError> {
        // Parse the blob into the typed saved state, then delegate to the
        // inherent `restore` implementation.
        self.restore(control, state.parse()?).await
    }
}
1812
1813#[cfg(test)]
1814mod tests {
1815    use super::*;
1816    use crate::test_helpers::TestWorker;
1817    use crate::test_helpers::parse_guest_completion;
1818    use crate::test_helpers::parse_guest_completion_check_flags_status;
1819    use pal_async::DefaultDriver;
1820    use pal_async::async_test;
1821    use scsi::srb::SrbStatus;
1822    use test_with_tracing::test;
1823    use vmbus_channel::connected_async_channels;
1824
1825    // Discourage `Clone` for `ScsiController` outside the crate, but it is
1826    // necessary for testing. The fuzzer also uses `TestWorker`, which needs
1827    // a `clone` of the inner state, but is not in this crate.
1828    impl Clone for ScsiController {
1829        fn clone(&self) -> Self {
1830            ScsiController {
1831                state: self.state.clone(),
1832            }
1833        }
1834    }
1835
    /// End-to-end smoke test: negotiate the protocol, write a page to a RAM
    /// disk, read it back, and verify the contents round-trip.
    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        let controller = ScsiController::new();
        let disk = scsidisk::SimpleScsiDisk::new(
            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
            Default::default(),
        );
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                ScsiControllerDisk::new(Arc::new(disk)),
            )
            .unwrap();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Set up the buffer for a write request
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        // Read the same block back into a different guest buffer.
        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        // The data read back must match the data written.
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        // stop everything
        guest.verify_graceful_close(test_worker).await;
    }
1904
    /// Verifies that QUERY_PROTOCOL_VERSION requests of various sizes produce
    /// completions of the expected sizes, each carrying REVISION_MISMATCH
    /// since the requested version (`!0`) is unsupported.
    #[async_test]
    async fn test_packet_sizes(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let header = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };

        // `!0` is not a valid major/minor version, so every request below
        // should complete with REVISION_MISMATCH.
        let mut buf = [0u8; 128];
        storvsp_protocol::ProtocolVersion {
            major_minor: !0,
            reserved: 0,
        }
        .write_to_prefix(&mut buf[..])
        .unwrap(); // PANIC: infallible since `ProtocolVersion` is smaller than the 128-byte buffer

        // Each tuple is (request packet length, expected completion length).
        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
            guest
                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
                .await;

            guest
                .verify_completion(|packet| {
                    let IncomingPacket::Completion(packet) = packet else {
                        unreachable!()
                    };
                    assert_eq!(packet.reader().len(), resp_len);
                    assert_eq!(
                        packet
                            .reader()
                            .read_plain::<storvsp_protocol::Packet>()
                            .unwrap()
                            .status,
                        storvsp_protocol::NtStatus::REVISION_MISMATCH
                    );
                    Ok(())
                })
                .await;
        }
    }
1976
    /// Sends END_INITIALIZATION as the very first packet (before
    /// BEGIN_INITIALIZATION) and expects an INVALID_DEVICE_STATE completion.
    #[async_test]
    async fn test_wrong_first_packet(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // Protocol negotiation done out of order: END before BEGIN.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
                );
                Ok(())
            })
            .await;
    }
2026
    /// Sends a packet whose operation the server does not handle and expects
    /// the worker task to fail with `PacketError::UnrecognizedOperation`.
    #[async_test]
    async fn test_unrecognized_operation(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // Send packet with unrecognized operation
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        // The worker should tear down reporting the unrecognized operation.
        match worker.teardown().await {
            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
                storvsp_protocol::Operation::REMOVE_DEVICE,
            ))) => {}
            result => panic!("Worker failed with unexpected result {:?}!", result),
        }
    }
2066
    /// Completes negotiation, then requests more sub-channels than the device
    /// supports and expects an INVALID_PARAMETER completion.
    #[async_test]
    async fn test_too_many_subchannels(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // Manual protocol negotiation: BEGIN_INITIALIZATION ...
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        // ... then QUERY_PROTOCOL_VERSION with a supported version ...
        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        // ... then QUERY_PROPERTIES.
        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        // Create sub channels more than maximum_sub_channel_count
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_PARAMETER
                );
                Ok(())
            })
            .await;
    }
2151
    /// Completes protocol negotiation, then sends BEGIN_INITIALIZATION again;
    /// the device should reject it with INVALID_DEVICE_STATE.
    #[async_test]
    async fn test_begin_init_on_ready(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // BEGIN_INITIALIZATION is only valid as the first negotiation step;
        // resending it after negotiation completed is a protocol violation.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|p| {
                parse_guest_completion_check_flags_status(
                    p,
                    0,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                )
            })
            .await;
    }
2196
    /// Attaches and removes disks while the guest is running, verifying the
    /// bus-rescan notification, the REPORT LUNS output, and I/O success or
    /// INVALID_LUN status at each step.
    #[async_test]
    async fn test_hot_add_remove(driver: DefaultDriver) {
        // set up channels and worker.
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        // create a controller with no disk yet.
        let controller = ScsiController::new();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Verify no LUNs are reported initially.
        let mut lun_list_buffer: [u8; 256] = [0; 256];
        let mut disk_count = 0;
        guest
            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
            .await;
        // An empty LUN list still carries the 8-byte REPORT LUNS header.
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
            })
            .await;
        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
        assert_eq!(lun_list_size, disk_count as u32 * 8);

        // Set up a buffer for writes.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();

        // Writing with no disk attached fails with INVALID_LUN.
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
            })
            .await;

        // Add some disks while the guest is running.
        for lun in 0..4 {
            let disk = scsidisk::SimpleScsiDisk::new(
                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
                Default::default(),
            );
            controller
                .attach(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    ScsiControllerDisk::new(Arc::new(disk)),
                )
                .unwrap();
            // Each attach should trigger a bus-rescan notification.
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count += 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 256)
                .await;
            // Transfer length is the 8-byte header plus 8 bytes per LUN —
            // hence (disk_count + 1) * 8.
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            // The newly attached LUN should accept I/O.
            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
                })
                .await;
        }

        // Remove all disks while the guest is running.
        for lun in 0..4 {
            controller
                .remove(ScsiPath {
                    path: 0,
                    target: 0,
                    lun,
                })
                .unwrap();
            // Each remove should also trigger a bus-rescan notification.
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count -= 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            // I/O to the removed LUN should now fail with INVALID_LUN.
            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
                })
                .await;
        }

        guest.verify_graceful_close(test_worker).await;
    }
2359
2360    #[async_test]
2361    pub async fn test_async_disk(driver: DefaultDriver) {
2362        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
2363        let controller = ScsiController::new();
2364        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
2365            device,
2366            Default::default(),
2367        )));
2368        controller
2369            .attach(
2370                ScsiPath {
2371                    path: 0,
2372                    target: 0,
2373                    lun: 0,
2374                },
2375                disk,
2376            )
2377            .unwrap();
2378
2379        let (host, guest) = connected_async_channels(16 * 1024);
2380        let guest_queue = Queue::new(guest).unwrap();
2381
2382        let mut guest = test_helpers::TestGuest {
2383            queue: guest_queue,
2384            transaction_id: 0,
2385        };
2386
2387        let test_guest_mem = GuestMemory::allocate(16384);
2388        let worker = TestWorker::start(
2389            controller.clone(),
2390            &driver,
2391            test_guest_mem.clone(),
2392            host,
2393            None,
2394        );
2395
2396        let negotiate_packet = storvsp_protocol::Packet {
2397            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2398            flags: 0,
2399            status: storvsp_protocol::NtStatus::SUCCESS,
2400        };
2401        guest
2402            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2403            .await;
2404        guest.verify_completion(parse_guest_completion).await;
2405
2406        let version_packet = storvsp_protocol::Packet {
2407            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
2408            flags: 0,
2409            status: storvsp_protocol::NtStatus::SUCCESS,
2410        };
2411        let version = storvsp_protocol::ProtocolVersion {
2412            major_minor: storvsp_protocol::VERSION_BLUE,
2413            reserved: 0,
2414        };
2415        guest
2416            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
2417            .await;
2418        guest.verify_completion(parse_guest_completion).await;
2419
2420        let properties_packet = storvsp_protocol::Packet {
2421            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
2422            flags: 0,
2423            status: storvsp_protocol::NtStatus::SUCCESS,
2424        };
2425        guest
2426            .send_data_packet_sync(&[properties_packet.as_bytes()])
2427            .await;
2428        guest.verify_completion(parse_guest_completion).await;
2429
2430        let negotiate_packet = storvsp_protocol::Packet {
2431            operation: storvsp_protocol::Operation::END_INITIALIZATION,
2432            flags: 0,
2433            status: storvsp_protocol::NtStatus::SUCCESS,
2434        };
2435        guest
2436            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2437            .await;
2438        guest.verify_completion(parse_guest_completion).await;
2439
2440        const IO_LEN: usize = 4 * 1024;
2441        let write_buf = [7u8; IO_LEN];
2442        let write_gpa = 4 * 1024u64;
2443        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
2444        guest
2445            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
2446            .await;
2447        guest
2448            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
2449            .await;
2450
2451        let read_gpa = 8 * 1024u64;
2452        guest
2453            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
2454            .await;
2455        guest
2456            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
2457            .await;
2458        let mut read_buf = [0u8; IO_LEN];
2459        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
2460        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
2461            assert_eq!(b1, b2);
2462        }
2463
2464        guest.verify_graceful_close(worker).await;
2465    }
2466}