storvsp/lib.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#![expect(missing_docs)]
#![forbid(unsafe_code)]

#[cfg(feature = "ioperf")]
pub mod ioperf;

#[cfg(feature = "test")]
pub mod test_helpers;

#[cfg(not(feature = "test"))]
mod test_helpers;

pub mod resolver;
mod save_restore;

use crate::ring::gparange::GpnList;
use crate::ring::gparange::MultiPagedRangeBuf;
use anyhow::Context as _;
use async_trait::async_trait;
use fast_select::FastSelect;
use futures::FutureExt;
use futures::StreamExt;
use futures::select_biased;
use guestmem::AccessError;
use guestmem::GuestMemory;
use guestmem::MemoryRead;
use guestmem::MemoryWrite;
use guestmem::ranges::PagedRange;
use guid::Guid;
use inspect::Inspect;
use inspect::InspectMut;
use inspect_counters::Counter;
use inspect_counters::Histogram;
use oversized_box::OversizedBox;
use parking_lot::Mutex;
use parking_lot::RwLock;
use ring::OutgoingPacketType;
use scsi::AdditionalSenseCode;
use scsi::ScsiOp;
use scsi::ScsiStatus;
use scsi::srb::SrbStatus;
use scsi::srb::SrbStatusAndFlags;
use scsi_buffers::RequestBuffers;
use scsi_core::AsyncScsiDisk;
use scsi_core::Request;
use scsi_core::ScsiResult;
use scsi_defs as scsi;
use scsidisk::illegal_request_sense;
use slab::Slab;
use std::collections::hash_map::Entry;
use std::collections::hash_map::HashMap;
use std::fmt::Debug;
use std::future::Future;
use std::future::poll_fn;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::Relaxed;
use std::task::Context;
use std::task::Poll;
use storvsp_resources::ScsiPath;
use task_control::AsyncRun;
use task_control::InspectTask;
use task_control::StopTask;
use task_control::TaskControl;
use thiserror::Error;
use tracing_helpers::ErrorValueExt;
use unicycle::FuturesUnordered;
use vmbus_async::queue;
use vmbus_async::queue::ExternalDataError;
use vmbus_async::queue::IncomingPacket;
use vmbus_async::queue::OutgoingPacket;
use vmbus_async::queue::Queue;
use vmbus_channel::RawAsyncChannel;
use vmbus_channel::bus::ChannelType;
use vmbus_channel::bus::OfferParams;
use vmbus_channel::bus::OpenRequest;
use vmbus_channel::channel::ChannelControl;
use vmbus_channel::channel::ChannelOpenError;
use vmbus_channel::channel::DeviceResources;
use vmbus_channel::channel::RestoreControl;
use vmbus_channel::channel::SaveRestoreVmbusDevice;
use vmbus_channel::channel::VmbusDevice;
use vmbus_channel::gpadl_ring::GpadlRingMem;
use vmbus_channel::gpadl_ring::gpadl_channel;
use vmbus_core::protocol::UserDefinedData;
use vmbus_ring as ring;
use vmbus_ring::RingMem;
use vmcore::save_restore::RestoreError;
use vmcore::save_restore::SaveError;
use vmcore::save_restore::SavedStateBlob;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::FromZeros;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

/// The IO queue depth at which the controller switches from guest-signal-driven
/// to poll-mode operation. This optimization reduces the guest exit rate by
/// relying on (typically-interrupt-driven) IO completions to drive polling for
/// new IO requests.
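/// With the default depth of 1, poll mode engages whenever at least one IO is outstanding.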
const DEFAULT_POLL_MODE_QUEUE_DEPTH: u32 = 1;

pub struct StorageDevice {
    instance_id: Guid,
    ide_path: Option<ScsiPath>,
    workers: Vec<WorkerAndDriver>,
    controller: Arc<ScsiControllerState>,
    resources: DeviceResources,
    driver_source: VmTaskDriverSource,
    max_sub_channel_count: u16,
    protocol: Arc<Protocol>,
    io_queue_depth: u32,
}

#[derive(Inspect)]
struct WorkerAndDriver {
    #[inspect(flatten)]
    worker: TaskControl<WorkerState, Worker>,
    driver: VmTaskDriver,
}

struct WorkerState;

impl InspectMut for StorageDevice {
    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
        let mut resp = req.respond();

        let disks = self.controller.disks.read();
        for (path, controller_disk) in disks.iter() {
            resp.child(&format!("disks/{}", path), |req| {
                controller_disk.disk.inspect(req);
            });
        }

        resp.fields(
            "channels",
            self.workers
                .iter()
                .filter(|task| task.worker.has_state())
                .enumerate(),
        )
        .field(
            "poll_mode_queue_depth",
            inspect::AtomicMut(&self.controller.poll_mode_queue_depth),
        );
    }
}

struct Worker<T: RingMem = GpadlRingMem> {
    inner: WorkerInner,
    rescan_notification: futures::channel::mpsc::Receiver<()>,
    fast_select: FastSelect,
    queue: Queue<T>,
}

struct Protocol {
    state: RwLock<ProtocolState>,
    /// Signaled when `state` transitions to `ProtocolState::Ready`.
    ready: event_listener::Event,
}

struct WorkerInner {
    protocol: Arc<Protocol>,
    request_size: usize,
    controller: Arc<ScsiControllerState>,
    channel_index: u16,
    scsi_queue: Arc<ScsiCommandQueue>,
    scsi_requests: FuturesUnordered<ScsiRequest>,
    scsi_requests_states: Slab<ScsiRequestState>,
    full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
    future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
    channel_control: ChannelControl,
    max_io_queue_depth: usize,
    stats: WorkerStats,
}

#[derive(Debug, Default, Inspect)]
struct WorkerStats {
    ios_submitted: Counter,
    ios_completed: Counter,
    wakes: Counter,
    wakes_spurious: Counter,
    per_wake_submissions: Histogram<10>,
    per_wake_completions: Histogram<10>,
}

#[repr(u16)]
#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
enum Version {
    Win6 = storvsp_protocol::VERSION_WIN6,
    Win7 = storvsp_protocol::VERSION_WIN7,
    Win8 = storvsp_protocol::VERSION_WIN8,
    Blue = storvsp_protocol::VERSION_BLUE,
}

#[derive(Debug, Error)]
#[error("protocol version {0:#x} not supported")]
struct UnsupportedVersion(u16);

impl Version {
    fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
        let version = match major_minor {
            storvsp_protocol::VERSION_WIN6 => Self::Win6,
            storvsp_protocol::VERSION_WIN7 => Self::Win7,
            storvsp_protocol::VERSION_WIN8 => Self::Win8,
            storvsp_protocol::VERSION_BLUE => Self::Blue,
            version => return Err(UnsupportedVersion(version)),
        };
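        // The enum discriminants are defined to be the protocol constants, so the
        // round trip back to u16 must reproduce the value that was just matched.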
        assert_eq!(version as u16, major_minor);
        Ok(version)
    }

    fn max_request_size(&self) -> usize {
        match self {
            Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
            Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
        }
    }
}

#[derive(Copy, Clone)]
enum ProtocolState {
    Init(InitState),
    Ready {
        version: Version,
        subchannel_count: u16,
    },
}

#[derive(Copy, Clone, Debug)]
enum InitState {
    Begin,
    QueryVersion,
    QueryProperties {
        version: Version,
    },
    EndInitialization {
        version: Version,
        subchannel_count: Option<u16>,
    },
}

/// The internal SCSI operation future type.
///
/// This is a boxed future of a large pre-determined size. The box is reused
/// after a SCSI request completes to avoid allocations in the hot path.
///
/// An Option type is used so that the future can be efficiently dropped (via
/// `Pin::set(x, None)`) before it is stashed away for reuse.
type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];
type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;

/// The amount of space reserved for a `ScsiOpFuture`.
///
/// This was chosen by running `cargo test -p storvsp -- --no-capture` and looking at the required
/// size that was given in the failure message.
const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;

struct ScsiRequest {
    request_id: usize,
    future: Option<ScsiOpFuture>,
}

impl ScsiRequest {
    fn new(
        request_id: usize,
        future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
    ) -> Self {
        Self {
            request_id,
            future: Some(future.into()),
        }
    }
}

impl Future for ScsiRequest {
    type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let future = this.future.as_mut().unwrap().as_mut();
        let result = std::task::ready!(future.poll(cx));
        // Return the future so that its storage can be reused.
        let future = this.future.take().unwrap();
        Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
    }
}
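
// A rough sketch (for illustration only; the names mirror `WorkerInner::push_scsi_request`
// and `poll_process_ready` below) of how the oversized box cycles between the pool and
// in-flight requests without reallocating:
//
//     let storage = future_pool.pop().unwrap_or_else(|| OversizedBox::new(()));
//     let future = OversizedBox::refill(storage, async move { /* execute_scsi(..) */ });
//     scsi_requests.push(ScsiRequest::new(request_id, oversized_box::coerce!(future)));
//     // ...later, `ScsiRequest::poll` yields the emptied box alongside the result...
//     future_pool.push(emptied_box);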

#[derive(Debug, Error)]
enum WorkerError {
    #[error("packet error")]
    PacketError(#[source] PacketError),
    #[error("queue error")]
    Queue(#[source] queue::Error),
    #[error("queue should have enough space but no longer does")]
    NotEnoughSpace,
}

#[derive(Debug, Error)]
enum PacketError {
    #[error("Not transactional")]
    NotTransactional,
    #[error("Unrecognized operation {0:?}")]
    UnrecognizedOperation(storvsp_protocol::Operation),
    #[error("Invalid packet type")]
    InvalidPacketType,
    #[error("Invalid data transfer length")]
    InvalidDataTransferLength,
    #[error("Access error: {0}")]
    Access(#[source] AccessError),
    #[error("Range error")]
    Range(#[source] ExternalDataError),
}

#[derive(Debug, Default, Clone)]
struct Range {
    buf: MultiPagedRangeBuf<GpnList>,
    len: usize,
    is_write: bool,
}

impl Range {
    fn new(
        buf: MultiPagedRangeBuf<GpnList>,
        request: &storvsp_protocol::ScsiRequest,
    ) -> Option<Self> {
        let len = request.data_transfer_length as usize;
        let is_write = request.data_in != 0;
        // Ensure there is exactly one range and it's large enough, or there are
        // zero ranges and there is no associated SCSI buffer.
        if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
            return None;
        }
        Some(Self { buf, len, is_write })
    }

    fn buffer<'a>(&'a self, guest_memory: &'a GuestMemory) -> RequestBuffers<'a> {
        let mut range = self.buf.first().unwrap_or_else(PagedRange::empty);
        range.truncate(self.len);
        RequestBuffers::new(guest_memory, range, self.is_write)
    }
}

#[derive(Debug)]
struct Packet {
    data: PacketData,
    transaction_id: u64,
    request_size: usize,
}

#[derive(Debug)]
enum PacketData {
    BeginInitialization,
    EndInitialization,
    QueryProtocolVersion(u16),
    QueryProperties,
    CreateSubChannels(u16),
    ExecuteScsi(Arc<ScsiRequestAndRange>),
    ResetBus,
    ResetAdapter,
    ResetLun,
}

#[derive(Debug)]
pub struct RangeError;

fn parse_packet<T: RingMem>(
    packet: &IncomingPacket<'_, T>,
    pool: &mut Vec<Arc<ScsiRequestAndRange>>,
) -> Result<Packet, PacketError> {
    let packet = match packet {
        IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
        IncomingPacket::Data(packet) => packet,
    };
    let transaction_id = packet
        .transaction_id()
        .ok_or(PacketError::NotTransactional)?;

    let mut reader = packet.reader();
    let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
    // You would expect that this should be limited to the current protocol
    // version's maximum packet size, but this is not what Hyper-V does, and
    // Linux 6.1 relies on this behavior during protocol initialization.
    let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
    let data = match header.operation {
        storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
        storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
        storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
            let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
            reader
                .read(version.as_mut_bytes())
                .map_err(PacketError::Access)?;
            PacketData::QueryProtocolVersion(version.major_minor)
        }
        storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
        storvsp_protocol::Operation::EXECUTE_SRB => {
            let mut full_request = pool.pop().unwrap_or_else(|| {
                Arc::new(ScsiRequestAndRange {
                    external_data: Range::default(),
                    request: storvsp_protocol::ScsiRequest::new_zeroed(),
                    request_size,
                })
            });

            {
                let full_request = Arc::get_mut(&mut full_request).unwrap();
                let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
                reader.read(request_buf).map_err(PacketError::Access)?;

                let buf = packet.read_external_ranges().map_err(PacketError::Range)?;

                full_request.external_data = Range::new(buf, &full_request.request)
                    .ok_or(PacketError::InvalidDataTransferLength)?;
            }

            PacketData::ExecuteScsi(full_request)
        }
        storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
        storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
        storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
        storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
            let mut sub_channel_count: u16 = 0;
            reader
                .read(sub_channel_count.as_mut_bytes())
                .map_err(PacketError::Access)?;
            PacketData::CreateSubChannels(sub_channel_count)
        }
        _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
    };

    if let PacketData::ExecuteScsi(_) = data {
        tracing::trace!(transaction_id, ?data, "parse_packet");
    } else {
        tracing::debug!(transaction_id, ?data, "parse_packet");
    }

    Ok(Packet {
        data,
        request_size,
        transaction_id,
    })
}

impl WorkerInner {
    fn send_vmbus_packet<M: RingMem>(
        &mut self,
        writer: &mut queue::WriteBatch<'_, M>,
        packet_type: OutgoingPacketType<'_>,
        request_size: usize,
        transaction_id: u64,
        operation: storvsp_protocol::Operation,
        status: storvsp_protocol::NtStatus,
        payload: &[u8],
    ) -> Result<(), WorkerError> {
        let header = storvsp_protocol::Packet {
            operation,
            flags: 0,
            status,
        };

        let packet_size = size_of_val(&header) + request_size;

        // Zero pad or truncate the payload to the queue's packet size. This is
        // necessary because Windows guests check that each packet's size is
        // exactly the largest possible packet size for the negotiated protocol
        // version.
        let len = size_of_val(&header) + size_of_val(payload);
        let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
        let (payload_bytes, padding_bytes) = if len > packet_size {
            (&payload[..packet_size - size_of_val(&header)], &[][..])
        } else {
            (payload, &padding[..packet_size - len])
        };
        assert_eq!(
            size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
            packet_size
        );
        writer
            .try_write(&OutgoingPacket {
                transaction_id,
                packet_type,
                payload: &[header.as_bytes(), payload_bytes, padding_bytes],
            })
            .map_err(|err| match err {
                queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
            })
    }

    fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
        &mut self,
        writer: &mut queue::WriteHalf<'_, M>,
        operation: storvsp_protocol::Operation,
        status: storvsp_protocol::NtStatus,
        payload: &P,
    ) -> Result<(), WorkerError> {
        self.send_vmbus_packet(
            &mut writer.batched(),
            OutgoingPacketType::InBandNoCompletion,
            self.request_size,
            0,
            operation,
            status,
            payload.as_bytes(),
        )
    }

    fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
        &mut self,
        writer: &mut queue::WriteHalf<'_, M>,
        packet: &Packet,
        status: storvsp_protocol::NtStatus,
        payload: &P,
    ) -> Result<(), WorkerError> {
        self.send_vmbus_packet(
            &mut writer.batched(),
            OutgoingPacketType::Completion,
            packet.request_size,
            packet.transaction_id,
            storvsp_protocol::Operation::COMPLETE_IO,
            status,
            payload.as_bytes(),
        )
    }
}

struct ScsiCommandQueue {
    controller: Arc<ScsiControllerState>,
    mem: GuestMemory,
    force_path_id: Option<u8>,
}

impl ScsiCommandQueue {
    async fn execute_scsi(
        &self,
        external_data: &Range,
        request: &storvsp_protocol::ScsiRequest,
    ) -> ScsiResult {
        let op = ScsiOp(request.payload[0]);
        let external_data = external_data.buffer(&self.mem);

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            "execute_scsi start...",
        );

        let path_id = self.force_path_id.unwrap_or(request.path_id);

        let controller_disk = self
            .controller
            .disks
            .read()
            .get(&ScsiPath {
                path: path_id,
                target: request.target_id,
                lun: request.lun,
            })
            .cloned();

        let result = match op {
            ScsiOp::REPORT_LUNS => {
                const HEADER_SIZE: usize = size_of::<scsi::LunList>();
                let mut luns: Vec<u8> = self
                    .controller
                    .disks
                    .read()
                    .iter()
                    .flat_map(|(path, _)| {
                        // Use the original path ID and not the forced one to
                        // match Hyper-V storvsp behavior.
                        if request.path_id == path.path && request.target_id == path.target {
                            Some(path.lun)
                        } else {
                            None
                        }
                    })
                    .collect();
                luns.sort_unstable();
                let mut data: Vec<u64> = vec![0; luns.len() + 1];
                let header = scsi::LunList {
                    length: (luns.len() as u32 * 8).into(),
                    reserved: [0; 4],
                };
                data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
                for (i, lun) in luns.iter().enumerate() {
                    data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
                }
                if external_data.len() >= HEADER_SIZE {
                    let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
                    external_data.writer().write(&data.as_bytes()[..tx]).map_or(
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        },
                        |_| ScsiResult {
                            scsi_status: ScsiStatus::GOOD,
                            srb_status: SrbStatus::SUCCESS,
                            tx,
                            sense_data: None,
                        },
                    )
                } else {
                    ScsiResult {
                        scsi_status: ScsiStatus::GOOD,
                        srb_status: SrbStatus::SUCCESS,
                        tx: 0,
                        sense_data: None,
                    }
                }
            }
            _ if controller_disk.is_some() => {
                let mut cdb = [0; 16];
                cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
                controller_disk
                    .unwrap()
                    .disk
                    .execute_scsi(
                        &external_data,
                        &Request {
                            cdb,
                            srb_flags: request.srb_flags,
                        },
                    )
                    .await
            }
            ScsiOp::INQUIRY => {
                let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
                    .unwrap()
                    .0; // TODO: zerocopy: ref-from-prefix: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
                if external_data.len() < cdb.allocation_length.get() as usize
                    || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
                    || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
                {
                    ScsiResult {
                        scsi_status: ScsiStatus::CHECK_CONDITION,
                        srb_status: SrbStatus::INVALID_REQUEST,
                        tx: 0,
                        sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
                    }
                } else {
                    let enable_vpd = cdb.flags.vpd();
                    if enable_vpd || cdb.page_code != 0 {
                        // Cannot support a VPD inquiry for a nonexistent device (LUN).
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        }
                    } else {
                        const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
                        let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
                        data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;

                        if request.lun != 0 {
                            // These fields are only set for the LUN 0 inquiry, so zero them out here.
                            data.vendor_id = [0; 8];
                            data.product_id = [0; 16];
                            data.product_revision_level = [0; 4];
                        }

                        let datab = data.as_bytes();
                        let tx = std::cmp::min(
                            cdb.allocation_length.get() as usize,
                            size_of::<scsi::InquiryData>(),
                        );
                        external_data.writer().write(&datab[..tx]).map_or(
                            ScsiResult {
                                scsi_status: ScsiStatus::CHECK_CONDITION,
                                srb_status: SrbStatus::INVALID_REQUEST,
                                tx: 0,
                                sense_data: Some(illegal_request_sense(
                                    AdditionalSenseCode::INVALID_CDB,
                                )),
                            },
                            |_| ScsiResult {
                                scsi_status: ScsiStatus::GOOD,
                                srb_status: SrbStatus::SUCCESS,
                                tx,
                                sense_data: None,
                            },
                        )
                    }
                }
            }
            _ => ScsiResult {
                scsi_status: ScsiStatus::CHECK_CONDITION,
                srb_status: SrbStatus::INVALID_LUN,
                tx: 0,
                sense_data: None,
            },
        };

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            result = ?result,
            "execute_scsi completed.",
        );
        result
    }
}

impl<T: RingMem + 'static> Worker<T> {
    fn new(
        controller: Arc<ScsiControllerState>,
        channel: RawAsyncChannel<T>,
        channel_index: u16,
        mem: GuestMemory,
        channel_control: ChannelControl,
        io_queue_depth: u32,
        protocol: Arc<Protocol>,
        force_path_id: Option<u8>,
    ) -> anyhow::Result<Self> {
        let queue = Queue::new(channel)?;
        #[expect(clippy::disallowed_methods)] // TODO
        let (source, target) = futures::channel::mpsc::channel(1);
        controller.add_rescan_notification_source(source);

        let max_io_queue_depth = io_queue_depth.max(1) as usize;
        Ok(Self {
            inner: WorkerInner {
                protocol,
                request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
                controller: controller.clone(),
                channel_index,
                scsi_queue: Arc::new(ScsiCommandQueue {
                    controller,
                    mem,
                    force_path_id,
                }),
                scsi_requests: FuturesUnordered::new(),
                scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
                channel_control,
                max_io_queue_depth,
                future_pool: Vec::new(),
                full_request_pool: Vec::new(),
                stats: Default::default(),
            },
            queue,
            rescan_notification: target,
            fast_select: FastSelect::new(),
        })
    }

    async fn wait_for_scsi_requests_complete(&mut self) {
        tracing::debug!(
            channel_index = self.inner.channel_index,
            "wait for IOs completed..."
        );
        while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
            self.inner.scsi_requests_states.remove(id);
        }
    }
}

impl InspectTask<Worker> for WorkerState {
    fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
        if let Some(worker) = worker {
            let mut resp = req.respond();
            if worker.inner.channel_index == 0 {
                let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
                    ProtocolState::Init(state) => match state {
                        InitState::Begin => ("begin_init", None, None),
                        InitState::QueryVersion => ("query_version", None, None),
                        InitState::QueryProperties { version } => {
                            ("query_properties", Some(version), None)
                        }
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => ("end_init", Some(version), subchannel_count),
                    },
                    ProtocolState::Ready {
                        version,
                        subchannel_count,
                    } => ("ready", Some(version), Some(subchannel_count)),
                };
                resp.field("state", state)
                    .field("version", version)
                    .field("subchannel_count", subchannel_count);
            }
            resp.field("pending_packets", worker.inner.scsi_requests_states.len())
                .fields("io", worker.inner.scsi_requests_states.iter())
                .field("stats", &worker.inner.stats)
                .field("ring", &worker.queue)
                .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
        }
    }
}

impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        let fut = async {
            if worker.inner.channel_index == 0 {
                worker.process_primary().await
            } else {
                // Wait for initialization to end before processing any
                // subchannel packets.
                let protocol_version = loop {
                    let listener = worker.inner.protocol.ready.listen();
                    if let ProtocolState::Ready { version, .. } =
                        *worker.inner.protocol.state.read()
                    {
                        break version;
                    }
                    tracing::debug!("subchannel waiting for initialization to end");
                    listener.await
                };
                worker
                    .inner
                    .process_ready(&mut worker.queue, protocol_version)
                    .await
            }
        };

        match stop.until_stopped(fut).await? {
            Ok(_) => {}
            Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
        }
        Ok(())
    }
}

impl WorkerInner {
    /// Awaits the next incoming packet, without checking for any other events
    /// (device add/remove notifications or available completions).
    async fn next_packet<'a, M: RingMem>(
        &mut self,
        reader: &'a mut queue::ReadHalf<'a, M>,
    ) -> Result<Packet, WorkerError> {
        let packet = reader.read().await.map_err(WorkerError::Queue)?;
        let stor_packet =
            parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
        Ok(stor_packet)
    }

    /// Polls for enough ring space in the outgoing ring to send a packet.
    ///
    /// This is used to ensure there is enough space in the ring before
    /// committing to sending a packet. This avoids the need to save pending
    /// packets on the side if queue processing is interrupted while the ring is
    /// full.
    fn poll_for_ring_space<M: RingMem>(
        &mut self,
        cx: &mut Context<'_>,
        writer: &mut queue::WriteHalf<'_, M>,
    ) -> Poll<Result<(), WorkerError>> {
        writer
            .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
            .map_err(WorkerError::Queue)
    }
}

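/// An upper bound on the size of any packet the worker writes: the storvsp header
/// plus the largest possible SCSI request payload. Ring space is reserved against
/// this bound (see `poll_for_ring_space`) before handling packets, so completions
/// never need to be queued on the side.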
const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
    size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
);

impl<T: RingMem> Worker<T> {
    /// Processes the protocol state machine.
    async fn process_primary(&mut self) -> Result<(), WorkerError> {
        loop {
            let current_state = *self.inner.protocol.state.read();
            match current_state {
                ProtocolState::Ready { version, .. } => {
                    break loop {
                        select_biased! {
                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
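                                // The controller requested a bus rescan (for example, after a
                                // disk was added or removed). Tell the guest to re-enumerate
                                // the bus; ENUMERATE_BUS is only sent for Win7+ protocols.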
                                if version >= Version::Win7
                                {
                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
                                }
                            }
                        }
                    };
                }
                ProtocolState::Init(state) => {
                    let (mut reader, mut writer) = self.queue.split();

                    // Ensure that subsequent calls to `send_completion` won't
                    // fail due to lack of ring space, to avoid keeping (and saving/restoring) interim states.
                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;

                    tracing::debug!(?state, "process_primary");
                    match state {
                        InitState::Begin => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::BeginInitialization = packet.data {
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &(),
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::QueryVersion);
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryVersion => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
                                if let Ok(version) = Version::parse(major_minor) {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    self.inner.request_size = version.max_request_size();
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryProperties { version });

                                    tracelimit::info_ratelimited!(
                                        ?version,
                                        "scsi version negotiated"
                                    );
                                } else {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryVersion);
                                }
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryProperties { version } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProperties = packet.data {
                                let multi_channel_supported = version >= Version::Win8;

                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &storvsp_protocol::ChannelProperties {
                                        max_transfer_bytes: 0x40000, // 256KB
                                        flags: {
                                            if multi_channel_supported {
                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
                                            } else {
                                                0
                                            }
                                        },
                                        maximum_sub_channel_count: if multi_channel_supported {
                                            self.inner.channel_control.max_subchannels()
                                        } else {
                                            0
                                        },
                                        reserved: 0,
                                        reserved2: 0,
                                        reserved3: [0, 0],
                                    },
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::EndInitialization {
                                        version,
                                        subchannel_count: if multi_channel_supported {
                                            None
                                        } else {
                                            Some(0)
                                        },
                                    });
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            match packet.data {
                                PacketData::CreateSubChannels(sub_channel_count)
                                    if subchannel_count.is_none() =>
                                {
                                    if let Err(err) = self
                                        .inner
                                        .channel_control
                                        .enable_subchannels(sub_channel_count)
                                    {
                                        tracelimit::warn_ratelimited!(
                                            ?err,
                                            "cannot enable subchannels"
                                        );
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
                                            &(),
                                        )?;
                                    } else {
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::SUCCESS,
                                            &(),
                                        )?;
                                        *self.inner.protocol.state.write() =
                                            ProtocolState::Init(InitState::EndInitialization {
                                                version,
                                                subchannel_count: Some(sub_channel_count),
                                            });
                                    }
                                }
                                PacketData::EndInitialization => {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &(),
                                    )?;
                                    // Reset the rescan notification event now, before the guest has a
                                    // chance to send any SCSI requests to scan the bus.
                                    self.rescan_notification.try_next().ok();
                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
                                        version,
                                        subchannel_count: subchannel_count.unwrap_or(0),
                                    };
                                    // Wake up subchannels waiting for the
                                    // protocol state to become ready.
                                    self.inner.protocol.ready.notify(usize::MAX);
                                }
                                _ => {
                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                        &(),
                                    )?;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
    match srb_status {
        SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
        SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
        SrbStatus::INVALID_LUN
        | SrbStatus::INVALID_TARGET_ID
        | SrbStatus::NO_DEVICE
        | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
        SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
        SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
        SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
            storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
        }
        SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
        SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
        SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
        _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
    }
}

impl WorkerInner {
    /// Processes packets and SCSI completions after protocol negotiation has finished.
    async fn process_ready<M: RingMem>(
        &mut self,
        queue: &mut Queue<M>,
        protocol_version: Version,
    ) -> Result<(), WorkerError> {
        self.request_size = protocol_version.max_request_size();
        poll_fn(|cx| self.poll_process_ready(cx, queue)).await
    }

    /// Processes packets and SCSI completions after protocol negotiation has finished.
    fn poll_process_ready<M: RingMem>(
        &mut self,
        cx: &mut Context<'_>,
        queue: &mut Queue<M>,
    ) -> Poll<Result<(), WorkerError>> {
        self.stats.wakes.increment();

        let (mut reader, mut writer) = queue.split();
        let mut total_completions = 0;
        let mut total_submissions = 0;
        let poll_mode_queue_depth = self.controller.poll_mode_queue_depth.load(Relaxed) as usize;

        loop {
            // Drive IOs forward and collect completions.
            'outer: while !self.scsi_requests_states.is_empty() {
                {
                    let mut batch = writer.batched();
                    loop {
                        // Ensure there is room for the completion before consuming
                        // the IO so that we don't have to track completed IOs whose
                        // completions haven't been sent.
                        if !batch
                            .can_write(MAX_VMBUS_PACKET_SIZE)
                            .map_err(WorkerError::Queue)?
                        {
                            // This batch is full but there may still be more completions.
                            break;
                        }
                        if let Poll::Ready(Some((request_id, result, future))) =
                            self.scsi_requests.poll_next_unpin(cx)
                        {
                            self.future_pool.push(future);
                            self.handle_completion(&mut batch, request_id, result)?;
                            total_completions += 1;
                        } else {
                            tracing::trace!("out of completions");
                            break 'outer;
                        }
                    }
                }

                // Wait for enough space for any completion packets.
                if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                    tracing::trace!("out of ring space");
                    break;
                }
            }

            let mut submissions = 0;
            // Process new requests.
            'outer: loop {
                if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                    break;
                }
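                // Poll-mode optimization: once at least `poll_mode_queue_depth` IOs are
                // outstanding, read the ring without re-arming the guest signal and let
                // IO completions drive the next poll. Below that threshold, fall back to
                // `poll_read_batch`, which re-enables signaling so the guest can notify
                // this worker of new work.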
                let mut batch = if self.scsi_requests_states.len() < poll_mode_queue_depth {
                    if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
                        batch.map_err(WorkerError::Queue)?
                    } else {
                        tracing::trace!("out of incoming packets");
                        break;
                    }
                } else {
                    match reader.try_read_batch() {
                        Ok(batch) => batch,
                        Err(queue::TryReadError::Empty) => {
                            tracing::trace!(
                                pending_io_count = self.scsi_requests_states.len(),
                                "out of incoming packets, keeping interrupts masked"
                            );
                            break;
                        }
                        Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
                    }
                };

                let mut packets = batch.packets();
                loop {
                    if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                        break 'outer;
                    }
                    // Wait for enough space for any completion packets that
                    // `handle_packet` may need to send, so that it isn't necessary
                    // to track pending completions.
                    if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                        tracing::trace!("out of ring space");
                        break 'outer;
                    }

                    let packet = if let Some(packet) = packets.next() {
                        packet.map_err(WorkerError::Queue)?
                    } else {
                        break;
                    };

                    if self.handle_packet(&mut writer, &packet)? {
                        submissions += 1;
                    }
                }
            }

            // Loop around to poll the IOs if any new IOs were submitted.
            if submissions == 0 {
                // No need to poll again.
                break;
            }
            total_submissions += submissions;
        }

        if total_submissions != 0 || total_completions != 0 {
            self.stats.ios_submitted.add(total_submissions);
            self.stats
                .per_wake_submissions
                .add_sample(total_submissions);
            self.stats
                .per_wake_completions
                .add_sample(total_completions);
            self.stats.ios_completed.add(total_completions);
        } else {
            self.stats.wakes_spurious.increment();
        }

        Poll::Pending
    }

    fn handle_completion<M: RingMem>(
        &mut self,
        writer: &mut queue::WriteBatch<'_, M>,
        request_id: usize,
        result: ScsiResult,
    ) -> Result<(), WorkerError> {
        let state = self.scsi_requests_states.remove(request_id);
        let request_size = state.request.request_size;

        // Push the request into the pool to avoid reallocating later.
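        // At this point the IO future has completed and dropped its clone of the Arc,
        // so this worker holds the only reference and can safely reuse the allocation.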
1260        assert_eq!(
1261            Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
1262            1
1263        );
1264        self.full_request_pool.push(state.request);
1265
1266        let status = convert_srb_status_to_nt_status(result.srb_status);
1267        let mut payload = [0; 0x14];
1268        if let Some(sense) = result.sense_data {
1269            payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
1270            tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
1271        };
1272        let response = storvsp_protocol::ScsiRequest {
1273            length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
1274            scsi_status: result.scsi_status,
1275            srb_status: SrbStatusAndFlags::new()
1276                .with_status(result.srb_status)
1277                .with_autosense_valid(result.sense_data.is_some()),
1278            data_transfer_length: result.tx as u32,
1279            cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
1280            sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
1281            payload,
1282            ..storvsp_protocol::ScsiRequest::new_zeroed()
1283        };
1284        self.send_vmbus_packet(
1285            writer,
1286            OutgoingPacketType::Completion,
1287            request_size,
1288            state.transaction_id,
1289            storvsp_protocol::Operation::COMPLETE_IO,
1290            status,
1291            response.as_bytes(),
1292        )?;
1293        Ok(())
1294    }
1295
1296    fn handle_packet<M: RingMem>(
1297        &mut self,
1298        writer: &mut queue::WriteHalf<'_, M>,
1299        packet: &IncomingPacket<'_, M>,
1300    ) -> Result<bool, WorkerError> {
1301        let packet =
1302            parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
1303        let submitted_io = match packet.data {
1304            PacketData::ExecuteScsi(request) => {
1305                self.push_scsi_request(packet.transaction_id, request);
1306                true
1307            }
1308            PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
1309                // These operations have always been no-ops.
1310                self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
1311                false
1312            }
1313            PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
1314                if let Err(err) = self
1315                    .channel_control
1316                    .enable_subchannels(new_subchannel_count)
1317                {
1318                    tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
1319                    self.send_completion(
1320                        writer,
1321                        &packet,
1322                        storvsp_protocol::NtStatus::INVALID_PARAMETER,
1323                        &(),
1324                    )?;
1325                    false
1326                } else {
1327                    // Update the subchannel count in the protocol state so it is captured on save.
1328                    if let ProtocolState::Ready {
1329                        subchannel_count, ..
1330                    } = &mut *self.protocol.state.write()
1331                    {
1332                        *subchannel_count = new_subchannel_count;
1333                    } else {
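                        // Packets are only processed after negotiation, so the
                        // protocol state must be Ready here.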
1334                        unreachable!()
1335                    }
1336
1337                    self.send_completion(
1338                        writer,
1339                        &packet,
1340                        storvsp_protocol::NtStatus::SUCCESS,
1341                        &(),
1342                    )?;
1343                    false
1344                }
1345            }
1346            _ => {
1347                tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
1348                self.send_completion(
1349                    writer,
1350                    &packet,
1351                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
1352                    &(),
1353                )?;
1354                false
1355            }
1356        };
1357        Ok(submitted_io)
1358    }
1359
1360    fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
1361        let scsi_queue = self.scsi_queue.clone();
1362        let scsi_request_state = ScsiRequestState {
1363            transaction_id,
1364            request: full_request.clone(),
1365        };
1366        let request_id = self.scsi_requests_states.insert(scsi_request_state);
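        // Reuse a pooled allocation for the per-request future so steady-state IO
        // avoids a fresh heap allocation; `OversizedBox::refill` places the new
        // async block into the recycled box.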
1367        let future = self
1368            .future_pool
1369            .pop()
1370            .unwrap_or_else(|| OversizedBox::new(()));
1371        let future = OversizedBox::refill(future, async move {
1372            scsi_queue
1373                .execute_scsi(&full_request.external_data, &full_request.request)
1374                .await
1375        });
1376        let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
1377        self.scsi_requests.push(request);
1378    }
1379}
1380
1381impl<T: RingMem> Drop for Worker<T> {
1382    fn drop(&mut self) {
1383        self.inner
1384            .controller
1385            .remove_rescan_notification_source(&self.rescan_notification);
1386    }
1387}
1388
1389#[derive(Debug, Error)]
1390#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
1391pub struct ScsiPathInUse(pub ScsiPath);
1392
1393#[derive(Debug, Error)]
1394#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
1395pub struct ScsiPathNotInUse(ScsiPath);
1396
1397#[derive(Clone)]
1398struct ScsiRequestState {
1399    transaction_id: u64,
1400    request: Arc<ScsiRequestAndRange>,
1401}
1402
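/// A parsed guest SCSI request together with its external data ranges and the
/// size of the originating packet (used to size the completion).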
1403#[derive(Debug)]
1404struct ScsiRequestAndRange {
1405    external_data: Range,
1406    request: storvsp_protocol::ScsiRequest,
1407    request_size: usize,
1408}
1409
1410impl Inspect for ScsiRequestState {
1411    fn inspect(&self, req: inspect::Request<'_>) {
1412        req.respond()
1413            .field("transaction_id", self.transaction_id)
1414            .display(
1415                "address",
1416                &ScsiPath {
1417                    path: self.request.request.path_id,
1418                    target: self.request.request.target_id,
1419                    lun: self.request.request.lun,
1420                },
1421            )
1422            .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1423    }
1424}
1425
1426impl StorageDevice {
1427    /// Returns a new SCSI device.
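    ///
    /// `max_sub_channel_count` is the number of vmbus subchannels the guest may
    /// request in addition to the primary channel, and `io_queue_depth` sets the
    /// target IO queue depth for each channel worker.
    ///
    /// A construction sketch, assuming a `VmTaskDriverSource` and an instance
    /// GUID are already in scope (the numeric values are illustrative only):
    ///
    /// ```ignore
    /// let controller = ScsiController::new();
    /// let device = StorageDevice::build_scsi(
    ///     &driver_source,
    ///     &controller,
    ///     instance_id,
    ///     3,   // max_sub_channel_count
    ///     256, // io_queue_depth
    /// );
    /// ```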
1428    pub fn build_scsi(
1429        driver_source: &VmTaskDriverSource,
1430        controller: &ScsiController,
1431        instance_id: Guid,
1432        max_sub_channel_count: u16,
1433        io_queue_depth: u32,
1434    ) -> Self {
1435        Self::build_inner(
1436            driver_source,
1437            controller,
1438            instance_id,
1439            None,
1440            max_sub_channel_count,
1441            io_queue_depth,
1442        )
1443    }
1444
1445    /// Returns a new SCSI device for implementing an IDE accelerator channel
1446    /// for IDE device `device_id` on channel `channel_id`.
1447    pub fn build_ide(
1448        driver_source: &VmTaskDriverSource,
1449        channel_id: u8,
1450        device_id: u8,
1451        disk: ScsiControllerDisk,
1452        io_queue_depth: u32,
1453    ) -> Self {
1454        let path = ScsiPath {
1455            path: channel_id,
1456            target: device_id,
1457            lun: 0,
1458        };
1459
1460        let controller = ScsiController::new();
1461        controller.attach(path, disk).unwrap();
1462
1463        // Construct the specific GUID that drivers in the guest expect for this
1464        // IDE device.
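        // For example, channel 1, device 0 maps to 00000001-0000-8899-0000-000000000000.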
1465        let instance_id = Guid {
1466            data1: channel_id.into(),
1467            data2: device_id.into(),
1468            data3: 0x8899,
1469            data4: [0; 8],
1470        };
1471        Self::build_inner(
1472            driver_source,
1473            &controller,
1474            instance_id,
1475            Some(path),
1476            0,
1477            io_queue_depth,
1478        )
1479    }
1480
1481    fn build_inner(
1482        driver_source: &VmTaskDriverSource,
1483        controller: &ScsiController,
1484        instance_id: Guid,
1485        ide_path: Option<ScsiPath>,
1486        max_sub_channel_count: u16,
1487        io_queue_depth: u32,
1488    ) -> Self {
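        // Build one worker slot per channel (primary plus subchannels). Drivers
        // initially target VP 0 and are retargeted when each channel opens.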
1489        let workers = (0..=max_sub_channel_count)
1490            .map(|channel_index| WorkerAndDriver {
1491                worker: TaskControl::new(WorkerState),
1492                driver: driver_source
1493                    .builder()
1494                    .target_vp(0)
1495                    .run_on_target(true)
1496                    .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1497            })
1498            .collect();
1499
1500        Self {
1501            instance_id,
1502            ide_path,
1503            workers,
1504            controller: controller.state.clone(),
1505            resources: Default::default(),
1506            max_sub_channel_count,
1507            driver_source: driver_source.clone(),
1508            protocol: Arc::new(Protocol {
1509                state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1510                ready: Default::default(),
1511            }),
1512            io_queue_depth,
1513        }
1514    }
1515
1516    fn new_worker(
1517        &mut self,
1518        open_request: &OpenRequest,
1519        channel_index: u16,
1520    ) -> anyhow::Result<&mut Worker> {
1521        let controller = self.controller.clone();
1522
1523        let driver = self
1524            .driver_source
1525            .builder()
1526            .target_vp(open_request.open_data.target_vp)
1527            .run_on_target(true)
1528            .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1529
1530        let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1531            .context("failed to create vmbus channel")?;
1532
1533        let channel_control = self.resources.channel_control.clone();
1534
1535        tracing::debug!(
1536            target_vp = open_request.open_data.target_vp,
1537            channel_index,
1538            "packet processing starting...",
1539        );
1540
1541        // Force the path ID on incoming SCSI requests to match the IDE
1542        // channel ID, since guests do not reliably set the path ID
1543        // correctly.
1544        let force_path_id = self.ide_path.map(|p| p.path);
1545
1546        let worker = Worker::new(
1547            controller,
1548            channel,
1549            channel_index,
1550            self.resources
1551                .offer_resources
1552                .guest_memory(open_request)
1553                .clone(),
1554            channel_control,
1555            self.io_queue_depth,
1556            self.protocol.clone(),
1557            force_path_id,
1558        )
1559        .map_err(RestoreError::Other)?;
1560
1561        self.workers[channel_index as usize]
1562            .driver
1563            .retarget_vp(open_request.open_data.target_vp);
1564
1565        Ok(self.workers[channel_index as usize].worker.insert(
1566            &driver,
1567            format!("storvsp worker {}-{}", self.instance_id, channel_index),
1568            worker,
1569        ))
1570    }
1571}
1572
1573/// A disk that can be added to a SCSI controller.
1574#[derive(Clone)]
1575pub struct ScsiControllerDisk {
1576    disk: Arc<dyn AsyncScsiDisk>,
1577}
1578
1579impl ScsiControllerDisk {
1580    /// Creates a new controller disk from an async SCSI disk.
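    ///
    /// A minimal usage sketch, mirroring the test setup later in this file (the
    /// RAM-backed disk is illustrative only):
    ///
    /// ```ignore
    /// let disk = scsidisk::SimpleScsiDisk::new(
    ///     disklayer_ram::ram_disk(64 * 1024, false).unwrap(),
    ///     Default::default(),
    /// );
    /// let controller = ScsiController::new();
    /// controller.attach(
    ///     ScsiPath { path: 0, target: 0, lun: 0 },
    ///     ScsiControllerDisk::new(Arc::new(disk)),
    /// )?;
    /// ```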
1581    pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1582        Self { disk }
1583    }
1584}
1585
1586struct ScsiControllerState {
1587    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
1588    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
1589    poll_mode_queue_depth: AtomicU32,
1590}
1591
1592pub struct ScsiController {
1593    state: Arc<ScsiControllerState>,
1594}
1595
1596impl ScsiController {
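    /// Creates a controller with the default poll-mode queue depth.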
1597    pub fn new() -> Self {
1598        Self::new_with_poll_mode_queue_depth(None)
1599    }
1600
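    /// Creates a controller with an explicit poll-mode queue depth; `None` falls
    /// back to `DEFAULT_POLL_MODE_QUEUE_DEPTH`.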
1601    pub fn new_with_poll_mode_queue_depth(poll_mode_queue_depth: Option<u32>) -> Self {
1602        Self {
1603            state: Arc::new(ScsiControllerState {
1604                disks: Default::default(),
1605                rescan_notification_source: Mutex::new(Vec::new()),
1606                poll_mode_queue_depth: AtomicU32::new(
1607                    poll_mode_queue_depth.unwrap_or(DEFAULT_POLL_MODE_QUEUE_DEPTH),
1608                ),
1609            }),
1610        }
1611    }
1612
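    /// Attaches `disk` at `path` and notifies any running channel workers so the
    /// guest can rescan the bus. Fails if a disk is already attached at `path`.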
1613    pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1614        match self.state.disks.write().entry(path) {
1615            Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1616            Entry::Vacant(entry) => entry.insert(disk),
1617        };
1618        for source in self.state.rescan_notification_source.lock().iter_mut() {
1619            // Ok to ignore errors here. If the channel is full, a previous notification has not
1620            // yet been processed by the primary channel worker, so a rescan is already pending.
1621            source.try_send(()).ok();
1622        }
1623        Ok(())
1624    }
1625
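    /// Removes the disk at `path` and notifies any running channel workers so the
    /// guest can rescan the bus. Fails if no disk is attached at `path`.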
1626    pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1627        match self.state.disks.write().entry(path) {
1628            Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1629            Entry::Occupied(entry) => {
1630                entry.remove();
1631            }
1632        }
1633        for source in self.state.rescan_notification_source.lock().iter_mut() {
1634            // Ok to ignore errors here. If the channel is full, a previous notification has not
1635            // yet been processed by the primary channel worker, so a rescan is already pending.
1636            source.try_send(()).ok();
1637        }
1638        Ok(())
1639    }
1640}
1641
1642impl ScsiControllerState {
1643    fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1644        self.rescan_notification_source.lock().push(source);
1645    }
1646
1647    fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1648        let mut sources = self.rescan_notification_source.lock();
1649        if let Some(index) = sources
1650            .iter()
1651            .position(|source| source.is_connected_to(target))
1652        {
1653            sources.remove(index);
1654        }
1655    }
1656}
1657
1658#[async_trait]
1659impl VmbusDevice for StorageDevice {
1660    fn offer(&self) -> OfferParams {
1661        if let Some(path) = self.ide_path {
1662            let offer_properties = storvsp_protocol::OfferProperties {
1663                path_id: path.path,
1664                target_id: path.target,
1665                flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
1666                ..FromZeros::new_zeroed()
1667            };
1668            let mut user_defined = UserDefinedData::new_zeroed();
1669            offer_properties
1670                .write_to_prefix(&mut user_defined[..])
1671                .unwrap();
1672            OfferParams {
1673                interface_name: "ide-accel".to_owned(),
1674                instance_id: self.instance_id,
1675                interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
1676                channel_type: ChannelType::Interface { user_defined },
1677                ..Default::default()
1678            }
1679        } else {
1680            OfferParams {
1681                interface_name: "scsi".to_owned(),
1682                instance_id: self.instance_id,
1683                interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
1684                ..Default::default()
1685            }
1686        }
1687    }
1688
1689    fn max_subchannels(&self) -> u16 {
1690        self.max_sub_channel_count
1691    }
1692
1693    fn install(&mut self, resources: DeviceResources) {
1694        self.resources = resources;
1695    }
1696
1697    async fn open(
1698        &mut self,
1699        channel_index: u16,
1700        open_request: &OpenRequest,
1701    ) -> Result<(), ChannelOpenError> {
1702        tracing::debug!(channel_index, "scsi open channel");
1703        self.new_worker(open_request, channel_index)?;
1704        self.workers[channel_index as usize].worker.start();
1705        Ok(())
1706    }
1707
1708    async fn close(&mut self, channel_index: u16) {
1709        tracing::debug!(channel_index, "scsi close channel");
1710        let worker = &mut self.workers[channel_index as usize].worker;
1711        worker.stop().await;
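        // Drain SCSI requests that were still in flight before dropping the
        // worker state. Closing the primary channel resets protocol negotiation.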
1712        if worker.state_mut().is_some() {
1713            worker
1714                .state_mut()
1715                .unwrap()
1716                .wait_for_scsi_requests_complete()
1717                .await;
1718            worker.remove();
1719        }
1720        if channel_index == 0 {
1721            *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
1722        }
1723    }
1724
1725    async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
1726        self.workers[channel_index as usize]
1727            .driver
1728            .retarget_vp(target_vp);
1729    }
1730
1731    fn start(&mut self) {
1732        for task in self
1733            .workers
1734            .iter_mut()
1735            .filter(|task| task.worker.has_state() && !task.worker.is_running())
1736        {
1737            task.worker.start();
1738        }
1739    }
1740
1741    async fn stop(&mut self) {
1742        tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
1743        for task in self
1744            .workers
1745            .iter_mut()
1746            .filter(|task| task.worker.has_state() && task.worker.is_running())
1747        {
1748            task.worker.stop().await;
1749        }
1750    }
1751
1752    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1753        Some(self)
1754    }
1755}
1756
1757#[async_trait]
1758impl SaveRestoreVmbusDevice for StorageDevice {
1759    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1760        Ok(SavedStateBlob::new(self.save()?))
1761    }
1762
1763    async fn restore(
1764        &mut self,
1765        control: RestoreControl<'_>,
1766        state: SavedStateBlob,
1767    ) -> Result<(), RestoreError> {
1768        self.restore(control, state.parse()?).await
1769    }
1770}
1771
1772#[cfg(test)]
1773mod tests {
1774    use super::*;
1775    use crate::test_helpers::TestWorker;
1776    use crate::test_helpers::parse_guest_completion;
1777    use crate::test_helpers::parse_guest_completion_check_flags_status;
1778    use pal_async::DefaultDriver;
1779    use pal_async::async_test;
1780    use scsi::srb::SrbStatus;
1781    use test_with_tracing::test;
1782    use vmbus_channel::connected_async_channels;
1783
1784    // Discourage `Clone` for `ScsiController` outside the crate, but it is
1785    // necessary for testing. The fuzzer, which lives outside this crate, also
1786    // uses `TestWorker`, which needs a `clone` of the inner state.
1787    impl Clone for ScsiController {
1788        fn clone(&self) -> Self {
1789            ScsiController {
1790                state: self.state.clone(),
1791            }
1792        }
1793    }
1794
1795    #[async_test]
1796    async fn test_channel_working(driver: DefaultDriver) {
1797        // set up the channels and worker
1798        let (host, guest) = connected_async_channels(16 * 1024);
1799        let guest_queue = Queue::new(guest).unwrap();
1800
1801        let test_guest_mem = GuestMemory::allocate(16384);
1802        let controller = ScsiController::new();
1803        let disk = scsidisk::SimpleScsiDisk::new(
1804            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
1805            Default::default(),
1806        );
1807        controller
1808            .attach(
1809                ScsiPath {
1810                    path: 0,
1811                    target: 0,
1812                    lun: 0,
1813                },
1814                ScsiControllerDisk::new(Arc::new(disk)),
1815            )
1816            .unwrap();
1817
1818        let test_worker = TestWorker::start(
1819            controller.clone(),
1820            driver.clone(),
1821            test_guest_mem.clone(),
1822            host,
1823            None,
1824        );
1825
1826        let mut guest = test_helpers::TestGuest {
1827            queue: guest_queue,
1828            transaction_id: 0,
1829        };
1830
1831        guest.perform_protocol_negotiation().await;
1832
1833        // Set up the buffer for a write request
1834        const IO_LEN: usize = 4 * 1024;
1835        let write_buf = [7u8; IO_LEN];
1836        let write_gpa = 4 * 1024u64;
1837        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
1838        guest
1839            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
1840            .await;
1841
1842        guest
1843            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
1844            .await;
1845
1846        let read_gpa = 8 * 1024u64;
1847        guest
1848            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
1849            .await;
1850
1851        guest
1852            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
1853            .await;
1854        let mut read_buf = [0u8; IO_LEN];
1855        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
1856        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
1857            assert_eq!(b1, b2);
1858        }
1859
1860        // stop everything
1861        guest.verify_graceful_close(test_worker).await;
1862    }
1863
1864    #[async_test]
1865    async fn test_packet_sizes(driver: DefaultDriver) {
1866        // set up the channels and worker
1867        let (host, guest) = connected_async_channels(16384);
1868        let guest_queue = Queue::new(guest).unwrap();
1869
1870        let test_guest_mem = GuestMemory::allocate(1024);
1871        let controller = ScsiController::new();
1872
1873        let _worker = TestWorker::start(
1874            controller.clone(),
1875            driver.clone(),
1876            test_guest_mem,
1877            host,
1878            None,
1879        );
1880
1881        let mut guest = test_helpers::TestGuest {
1882            queue: guest_queue,
1883            transaction_id: 0,
1884        };
1885
1886        let negotiate_packet = storvsp_protocol::Packet {
1887            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
1888            flags: 0,
1889            status: storvsp_protocol::NtStatus::SUCCESS,
1890        };
1891        guest
1892            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
1893            .await;
1894
1895        guest.verify_completion(parse_guest_completion).await;
1896
1897        let header = storvsp_protocol::Packet {
1898            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
1899            flags: 0,
1900            status: storvsp_protocol::NtStatus::SUCCESS,
1901        };
1902
1903        let mut buf = [0u8; 128];
1904        storvsp_protocol::ProtocolVersion {
1905            major_minor: !0,
1906            reserved: 0,
1907        }
1908        .write_to_prefix(&mut buf[..])
1909        .write_to_prefix(&mut buf[..])
1910        .unwrap(); // PANIC: infallible since `ProtocolVersion` is smaller than the 128-byte buffer
1910
1911        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
1912            guest
1913                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
1914                .await;
1915
1916            guest
1917                .verify_completion(|packet| {
1918                    let IncomingPacket::Completion(packet) = packet else {
1919                        unreachable!()
1920                    };
1921                    assert_eq!(packet.reader().len(), resp_len);
1922                    assert_eq!(
1923                        packet
1924                            .reader()
1925                            .read_plain::<storvsp_protocol::Packet>()
1926                            .unwrap()
1927                            .status,
1928                        storvsp_protocol::NtStatus::REVISION_MISMATCH
1929                    );
1930                    Ok(())
1931                })
1932                .await;
1933        }
1934    }
1935
1936    #[async_test]
1937    async fn test_wrong_first_packet(driver: DefaultDriver) {
1938        // set up the channels and worker
1939        let (host, guest) = connected_async_channels(16384);
1940        let guest_queue = Queue::new(guest).unwrap();
1941
1942        let test_guest_mem = GuestMemory::allocate(1024);
1943        let controller = ScsiController::new();
1944
1945        let _worker = TestWorker::start(
1946            controller.clone(),
1947            driver.clone(),
1948            test_guest_mem,
1949            host,
1950            None,
1951        );
1952
1953        let mut guest = test_helpers::TestGuest {
1954            queue: guest_queue,
1955            transaction_id: 0,
1956        };
1957
1958        // Protocol negotiation done out of order
1959        let negotiate_packet = storvsp_protocol::Packet {
1960            operation: storvsp_protocol::Operation::END_INITIALIZATION,
1961            flags: 0,
1962            status: storvsp_protocol::NtStatus::SUCCESS,
1963        };
1964        guest
1965            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
1966            .await;
1967
1968        guest
1969            .verify_completion(|packet| {
1970                let IncomingPacket::Completion(packet) = packet else {
1971                    unreachable!()
1972                };
1973                assert_eq!(
1974                    packet
1975                        .reader()
1976                        .read_plain::<storvsp_protocol::Packet>()
1977                        .unwrap()
1978                        .status,
1979                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
1980                );
1981                Ok(())
1982            })
1983            .await;
1984    }
1985
1986    #[async_test]
1987    async fn test_unrecognized_operation(driver: DefaultDriver) {
1988        // set up the channels and worker
1989        let (host, guest) = connected_async_channels(16384);
1990        let guest_queue = Queue::new(guest).unwrap();
1991
1992        let test_guest_mem = GuestMemory::allocate(1024);
1993        let controller = ScsiController::new();
1994
1995        let worker = TestWorker::start(
1996            controller.clone(),
1997            driver.clone(),
1998            test_guest_mem,
1999            host,
2000            None,
2001        );
2002
2003        let mut guest = test_helpers::TestGuest {
2004            queue: guest_queue,
2005            transaction_id: 0,
2006        };
2007
2008        // Send packet with unrecognized operation
2009        let negotiate_packet = storvsp_protocol::Packet {
2010            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
2011            flags: 0,
2012            status: storvsp_protocol::NtStatus::SUCCESS,
2013        };
2014        guest
2015            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2016            .await;
2017
2018        match worker.teardown().await {
2019            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
2020                storvsp_protocol::Operation::REMOVE_DEVICE,
2021            ))) => {}
2022            result => panic!("worker ended with unexpected result {:?}", result),
2023        }
2024    }
2025
2026    #[async_test]
2027    async fn test_too_many_subchannels(driver: DefaultDriver) {
2028        // set up the channels and worker
2029        let (host, guest) = connected_async_channels(16384);
2030        let guest_queue = Queue::new(guest).unwrap();
2031
2032        let test_guest_mem = GuestMemory::allocate(1024);
2033        let controller = ScsiController::new();
2034
2035        let _worker = TestWorker::start(
2036            controller.clone(),
2037            driver.clone(),
2038            test_guest_mem,
2039            host,
2040            None,
2041        );
2042
2043        let mut guest = test_helpers::TestGuest {
2044            queue: guest_queue,
2045            transaction_id: 0,
2046        };
2047
2048        let negotiate_packet = storvsp_protocol::Packet {
2049            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2050            flags: 0,
2051            status: storvsp_protocol::NtStatus::SUCCESS,
2052        };
2053        guest
2054            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2055            .await;
2056        guest.verify_completion(parse_guest_completion).await;
2057
2058        let version_packet = storvsp_protocol::Packet {
2059            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
2060            flags: 0,
2061            status: storvsp_protocol::NtStatus::SUCCESS,
2062        };
2063        let version = storvsp_protocol::ProtocolVersion {
2064            major_minor: storvsp_protocol::VERSION_BLUE,
2065            reserved: 0,
2066        };
2067        guest
2068            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
2069            .await;
2070        guest.verify_completion(parse_guest_completion).await;
2071
2072        let properties_packet = storvsp_protocol::Packet {
2073            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
2074            flags: 0,
2075            status: storvsp_protocol::NtStatus::SUCCESS,
2076        };
2077        guest
2078            .send_data_packet_sync(&[properties_packet.as_bytes()])
2079            .await;
2080
2081        guest.verify_completion(parse_guest_completion).await;
2082
2083        let negotiate_packet = storvsp_protocol::Packet {
2084            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
2085            flags: 0,
2086            status: storvsp_protocol::NtStatus::SUCCESS,
2087        };
2088        // Request more subchannels than the reported maximum_sub_channel_count allows.
2089        guest
2090            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
2091            .await;
2092
2093        guest
2094            .verify_completion(|packet| {
2095                let IncomingPacket::Completion(packet) = packet else {
2096                    unreachable!()
2097                };
2098                assert_eq!(
2099                    packet
2100                        .reader()
2101                        .read_plain::<storvsp_protocol::Packet>()
2102                        .unwrap()
2103                        .status,
2104                    storvsp_protocol::NtStatus::INVALID_PARAMETER
2105                );
2106                Ok(())
2107            })
2108            .await;
2109    }
2110
2111    #[async_test]
2112    async fn test_begin_init_on_ready(driver: DefaultDriver) {
2113        // set up the channels and worker
2114        let (host, guest) = connected_async_channels(16384);
2115        let guest_queue = Queue::new(guest).unwrap();
2116
2117        let test_guest_mem = GuestMemory::allocate(1024);
2118        let controller = ScsiController::new();
2119
2120        let _worker = TestWorker::start(
2121            controller.clone(),
2122            driver.clone(),
2123            test_guest_mem,
2124            host,
2125            None,
2126        );
2127
2128        let mut guest = test_helpers::TestGuest {
2129            queue: guest_queue,
2130            transaction_id: 0,
2131        };
2132
2133        guest.perform_protocol_negotiation().await;
2134
2135        // Negotiation already completed; send BEGIN_INITIALIZATION again on a ready channel.
2136        let negotiate_packet = storvsp_protocol::Packet {
2137            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2138            flags: 0,
2139            status: storvsp_protocol::NtStatus::SUCCESS,
2140        };
2141        guest
2142            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2143            .await;
2144
2145        guest
2146            .verify_completion(|p| {
2147                parse_guest_completion_check_flags_status(
2148                    p,
2149                    0,
2150                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
2151                )
2152            })
2153            .await;
2154    }
2155
2156    #[async_test]
2157    async fn test_hot_add_remove(driver: DefaultDriver) {
2158        // set up channels and worker.
2159        let (host, guest) = connected_async_channels(16 * 1024);
2160        let guest_queue = Queue::new(guest).unwrap();
2161
2162        let test_guest_mem = GuestMemory::allocate(16384);
2163        // create a controller with no disk yet.
2164        let controller = ScsiController::new();
2165
2166        let test_worker = TestWorker::start(
2167            controller.clone(),
2168            driver.clone(),
2169            test_guest_mem.clone(),
2170            host,
2171            None,
2172        );
2173
2174        let mut guest = test_helpers::TestGuest {
2175            queue: guest_queue,
2176            transaction_id: 0,
2177        };
2178
2179        guest.perform_protocol_negotiation().await;
2180
2181        // Verify no LUNs are reported initially.
2182        let mut lun_list_buffer: [u8; 256] = [0; 256];
2183        let mut disk_count = 0;
2184        guest
2185            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
2186            .await;
2187        guest
2188            .verify_completion(|p| {
2189                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
2190            })
2191            .await;
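        // REPORT LUNS parameter data begins with an 8-byte header whose first 4
        // bytes are the big-endian LUN list length, followed by one 8-byte entry
        // per LUN.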
2192        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
2193        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
2194        assert_eq!(lun_list_size, disk_count as u32 * 8);
2195
2196        // Set up a buffer for writes.
2197        const IO_LEN: usize = 4 * 1024;
2198        let write_buf = [7u8; IO_LEN];
2199        let write_gpa = 4 * 1024u64;
2200        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
2201
2202        guest
2203            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
2204            .await;
2205        guest
2206            .verify_completion(|p| {
2207                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
2208            })
2209            .await;
2210
2211        // Add some disks while the guest is running.
2212        for lun in 0..4 {
2213            let disk = scsidisk::SimpleScsiDisk::new(
2214                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
2215                Default::default(),
2216            );
2217            controller
2218                .attach(
2219                    ScsiPath {
2220                        path: 0,
2221                        target: 0,
2222                        lun,
2223                    },
2224                    ScsiControllerDisk::new(Arc::new(disk)),
2225                )
2226                .unwrap();
2227            guest
2228                .verify_completion(test_helpers::parse_guest_enumerate_bus)
2229                .await;
2230
2231            disk_count += 1;
2232            guest
2233                .send_report_luns_packet(ScsiPath::default(), 0, 256)
2234                .await;
2235            guest
2236                .verify_completion(|p| {
2237                    test_helpers::parse_guest_completed_io_check_tx_len(
2238                        p,
2239                        SrbStatus::SUCCESS,
2240                        Some((disk_count + 1) * 8),
2241                    )
2242                })
2243                .await;
2244            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
2245            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
2246            assert_eq!(lun_list_size, disk_count as u32 * 8);
2247
2248            guest
2249                .send_write_packet(
2250                    ScsiPath {
2251                        path: 0,
2252                        target: 0,
2253                        lun,
2254                    },
2255                    write_gpa,
2256                    1,
2257                    IO_LEN,
2258                )
2259                .await;
2260            guest
2261                .verify_completion(|p| {
2262                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
2263                })
2264                .await;
2265        }
2266
2267        // Remove all disks while the guest is running.
2268        for lun in 0..4 {
2269            controller
2270                .remove(ScsiPath {
2271                    path: 0,
2272                    target: 0,
2273                    lun,
2274                })
2275                .unwrap();
2276            guest
2277                .verify_completion(test_helpers::parse_guest_enumerate_bus)
2278                .await;
2279
2280            disk_count -= 1;
2281            guest
2282                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
2283                .await;
2284            guest
2285                .verify_completion(|p| {
2286                    test_helpers::parse_guest_completed_io_check_tx_len(
2287                        p,
2288                        SrbStatus::SUCCESS,
2289                        Some((disk_count + 1) * 8),
2290                    )
2291                })
2292                .await;
2293            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
2294            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
2295            assert_eq!(lun_list_size, disk_count as u32 * 8);
2296
2297            guest
2298                .send_write_packet(
2299                    ScsiPath {
2300                        path: 0,
2301                        target: 0,
2302                        lun,
2303                    },
2304                    write_gpa,
2305                    1,
2306                    IO_LEN,
2307                )
2308                .await;
2309            guest
2310                .verify_completion(|p| {
2311                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
2312                })
2313                .await;
2314        }
2315
2316        guest.verify_graceful_close(test_worker).await;
2317    }
2318
2319    #[async_test]
2320    pub async fn test_async_disk(driver: DefaultDriver) {
2321        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
2322        let controller = ScsiController::new();
2323        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
2324            device,
2325            Default::default(),
2326        )));
2327        controller
2328            .attach(
2329                ScsiPath {
2330                    path: 0,
2331                    target: 0,
2332                    lun: 0,
2333                },
2334                disk,
2335            )
2336            .unwrap();
2337
2338        let (host, guest) = connected_async_channels(16 * 1024);
2339        let guest_queue = Queue::new(guest).unwrap();
2340
2341        let mut guest = test_helpers::TestGuest {
2342            queue: guest_queue,
2343            transaction_id: 0,
2344        };
2345
2346        let test_guest_mem = GuestMemory::allocate(16384);
2347        let worker = TestWorker::start(
2348            controller.clone(),
2349            &driver,
2350            test_guest_mem.clone(),
2351            host,
2352            None,
2353        );
2354
2355        let negotiate_packet = storvsp_protocol::Packet {
2356            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2357            flags: 0,
2358            status: storvsp_protocol::NtStatus::SUCCESS,
2359        };
2360        guest
2361            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2362            .await;
2363        guest.verify_completion(parse_guest_completion).await;
2364
2365        let version_packet = storvsp_protocol::Packet {
2366            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
2367            flags: 0,
2368            status: storvsp_protocol::NtStatus::SUCCESS,
2369        };
2370        let version = storvsp_protocol::ProtocolVersion {
2371            major_minor: storvsp_protocol::VERSION_BLUE,
2372            reserved: 0,
2373        };
2374        guest
2375            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
2376            .await;
2377        guest.verify_completion(parse_guest_completion).await;
2378
2379        let properties_packet = storvsp_protocol::Packet {
2380            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
2381            flags: 0,
2382            status: storvsp_protocol::NtStatus::SUCCESS,
2383        };
2384        guest
2385            .send_data_packet_sync(&[properties_packet.as_bytes()])
2386            .await;
2387        guest.verify_completion(parse_guest_completion).await;
2388
2389        let negotiate_packet = storvsp_protocol::Packet {
2390            operation: storvsp_protocol::Operation::END_INITIALIZATION,
2391            flags: 0,
2392            status: storvsp_protocol::NtStatus::SUCCESS,
2393        };
2394        guest
2395            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2396            .await;
2397        guest.verify_completion(parse_guest_completion).await;
2398
2399        const IO_LEN: usize = 4 * 1024;
2400        let write_buf = [7u8; IO_LEN];
2401        let write_gpa = 4 * 1024u64;
2402        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
2403        guest
2404            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
2405            .await;
2406        guest
2407            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
2408            .await;
2409
2410        let read_gpa = 8 * 1024u64;
2411        guest
2412            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
2413            .await;
2414        guest
2415            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
2416            .await;
2417        let mut read_buf = [0u8; IO_LEN];
2418        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
2419        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
2420            assert_eq!(b1, b2);
2421        }
2422
2423        guest.verify_graceful_close(worker).await;
2424    }
2425}