storvsp/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7#[cfg(feature = "ioperf")]
8pub mod ioperf;
9
10#[cfg(feature = "test")]
11pub mod test_helpers;
12
13#[cfg(not(feature = "test"))]
14mod test_helpers;
15
16pub mod resolver;
17mod save_restore;
18
19use crate::ring::gparange::GpnList;
20use crate::ring::gparange::MultiPagedRangeBuf;
21use anyhow::Context as _;
22use async_trait::async_trait;
23use fast_select::FastSelect;
24use futures::FutureExt;
25use futures::StreamExt;
26use futures::select_biased;
27use guestmem::AccessError;
28use guestmem::GuestMemory;
29use guestmem::MemoryRead;
30use guestmem::MemoryWrite;
31use guestmem::ranges::PagedRange;
32use guid::Guid;
33use inspect::Inspect;
34use inspect::InspectMut;
35use inspect_counters::Counter;
36use inspect_counters::Histogram;
37use oversized_box::OversizedBox;
38use parking_lot::Mutex;
39use parking_lot::RwLock;
40use ring::OutgoingPacketType;
41use scsi::AdditionalSenseCode;
42use scsi::ScsiOp;
43use scsi::ScsiStatus;
44use scsi::srb::SrbStatus;
45use scsi::srb::SrbStatusAndFlags;
46use scsi_buffers::RequestBuffers;
47use scsi_core::AsyncScsiDisk;
48use scsi_core::Request;
49use scsi_core::ScsiResult;
50use scsi_defs as scsi;
51use scsidisk::illegal_request_sense;
52use slab::Slab;
53use std::collections::hash_map::Entry;
54use std::collections::hash_map::HashMap;
55use std::fmt::Debug;
56use std::future::Future;
57use std::future::poll_fn;
58use std::pin::Pin;
59use std::sync::Arc;
60use std::sync::atomic::AtomicU32;
61use std::sync::atomic::Ordering::Relaxed;
62use std::task::Context;
63use std::task::Poll;
64use storvsp_resources::ScsiPath;
65use task_control::AsyncRun;
66use task_control::InspectTask;
67use task_control::StopTask;
68use task_control::TaskControl;
69use thiserror::Error;
70use tracing_helpers::ErrorValueExt;
71use unicycle::FuturesUnordered;
72use vmbus_async::queue;
73use vmbus_async::queue::ExternalDataError;
74use vmbus_async::queue::IncomingPacket;
75use vmbus_async::queue::OutgoingPacket;
76use vmbus_async::queue::Queue;
77use vmbus_channel::RawAsyncChannel;
78use vmbus_channel::bus::ChannelType;
79use vmbus_channel::bus::OfferParams;
80use vmbus_channel::bus::OpenRequest;
81use vmbus_channel::channel::ChannelControl;
82use vmbus_channel::channel::ChannelOpenError;
83use vmbus_channel::channel::DeviceResources;
84use vmbus_channel::channel::RestoreControl;
85use vmbus_channel::channel::SaveRestoreVmbusDevice;
86use vmbus_channel::channel::VmbusDevice;
87use vmbus_channel::gpadl_ring::GpadlRingMem;
88use vmbus_channel::gpadl_ring::gpadl_channel;
89use vmbus_core::protocol::UserDefinedData;
90use vmbus_ring as ring;
91use vmbus_ring::RingMem;
92use vmcore::save_restore::RestoreError;
93use vmcore::save_restore::SaveError;
94use vmcore::save_restore::SavedStateBlob;
95use vmcore::vm_task::VmTaskDriver;
96use vmcore::vm_task::VmTaskDriverSource;
97use zerocopy::FromBytes;
98use zerocopy::FromZeros;
99use zerocopy::Immutable;
100use zerocopy::IntoBytes;
101use zerocopy::KnownLayout;
102
/// The default IO queue depth at which the controller switches from
/// guest-signal-driven to poll-mode operation. This optimization reduces the
/// guest exit rate by relying on (typically-interrupt-driven) IO completions to
/// drive polling for new IO requests.
///
/// NOTE(review): presumably used to initialize
/// `ScsiControllerState::poll_mode_queue_depth`, which is exposed as a mutable
/// inspect field (`poll_mode_queue_depth`) and can therefore be tuned at
/// runtime; confirm at the construction site.
const DEFAULT_POLL_MODE_QUEUE_DEPTH: u32 = 1;
108
/// The storvsp VMBus SCSI controller device.
///
/// Holds one worker per VMBus channel (reported as `channels` in inspect) and
/// the controller and protocol state shared by all of them.
pub struct StorageDevice {
    /// Device instance ID. NOTE(review): presumably the VMBus offer instance
    /// ID; confirm.
    instance_id: Guid,
    /// NOTE(review): optional IDE path associated with this controller; its
    /// exact use is not visible in this section.
    ide_path: Option<ScsiPath>,
    /// Per-channel worker tasks and the drivers they run on.
    workers: Vec<WorkerAndDriver>,
    /// Controller state (attached disks, poll-mode queue depth) shared with
    /// every worker.
    controller: Arc<ScsiControllerState>,
    /// VMBus device resources for this device.
    resources: DeviceResources,
    /// Source of task drivers. NOTE(review): presumably used to create the
    /// per-channel drivers in `workers`.
    driver_source: VmTaskDriverSource,
    /// NOTE(review): presumably the upper bound on subchannels the guest may
    /// request; confirm.
    max_sub_channel_count: u16,
    /// Protocol negotiation state shared by the primary channel and all
    /// subchannels.
    protocol: Arc<Protocol>,
    /// Requested per-channel IO queue depth; `Worker::new` clamps it to at
    /// least 1.
    io_queue_depth: u32,
}
120
/// A channel worker task together with the driver it runs on.
#[derive(Inspect)]
struct WorkerAndDriver {
    /// The worker task; its inspect output is flattened into this node.
    #[inspect(flatten)]
    worker: TaskControl<WorkerState, Worker>,
    /// NOTE(review): presumably the driver the worker task is spawned on.
    driver: VmTaskDriver,
}
127
/// Stateless task type for [`Worker`]: implements the worker's run loop
/// (`AsyncRun`) and its inspect output (`InspectTask`).
struct WorkerState;
129
130impl InspectMut for StorageDevice {
131    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
132        let mut resp = req.respond();
133
134        let disks = self.controller.disks.read();
135        for (path, controller_disk) in disks.iter() {
136            resp.child(&format!("disks/{}", path), |req| {
137                controller_disk.disk.inspect(req);
138            });
139        }
140
141        resp.fields(
142            "channels",
143            self.workers
144                .iter()
145                .filter(|task| task.worker.has_state())
146                .enumerate(),
147        )
148        .field(
149            "poll_mode_queue_depth",
150            inspect::AtomicMut(&self.controller.poll_mode_queue_depth),
151        );
152    }
153}
154
/// Per-channel worker: packet-processing state plus the VMBus queue this
/// channel reads requests from and writes responses to.
struct Worker<T: RingMem = GpadlRingMem> {
    /// Processing state that does not depend on the ring memory type.
    inner: WorkerInner,
    /// Receiver half of the rescan notification channel whose sender is
    /// registered with the controller in `Worker::new`.
    rescan_notification: futures::channel::mpsc::Receiver<()>,
    /// Reusable select helper; `process_primary` uses it to wait for rescan
    /// notifications.
    fast_select: FastSelect,
    /// The VMBus packet queue for this channel.
    queue: Queue<T>,
}
161
/// Protocol negotiation state shared by the primary channel and all
/// subchannels.
struct Protocol {
    /// Current negotiation state; the primary channel's `process_primary`
    /// drives the state machine.
    state: RwLock<ProtocolState>,
    /// Signaled when `state` transitions to `ProtocolState::Ready`.
    ready: event_listener::Event,
}
167
/// Ring-memory-independent state of a channel worker.
struct WorkerInner {
    /// Protocol negotiation state shared with the other channels.
    protocol: Arc<Protocol>,
    /// Body size for packets this worker sends unsolicited (`send_packet`).
    /// Starts at `SCSI_REQUEST_LEN_V1`.
    request_size: usize,
    /// Controller state shared with the other channels.
    controller: Arc<ScsiControllerState>,
    /// This worker's channel index; 0 is the primary channel, which runs the
    /// protocol state machine.
    channel_index: u16,
    /// Executes SCSI requests against the controller's disks.
    scsi_queue: Arc<ScsiCommandQueue>,
    /// In-flight SCSI operation futures.
    scsi_requests: FuturesUnordered<ScsiRequest>,
    /// Per-request state, keyed by the request ID each `ScsiRequest` carries.
    scsi_requests_states: Slab<ScsiRequestState>,
    /// Recycled request buffers; `parse_packet` takes from this pool for
    /// `EXECUTE_SRB` packets.
    full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
    /// Recycled storage for boxed SCSI operation futures.
    future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
    /// NOTE(review): VMBus channel control handle; its use is not visible in
    /// this section.
    channel_control: ChannelControl,
    /// Configured IO queue depth for this channel, clamped to at least 1.
    /// NOTE(review): presumably limits the number of requests in flight.
    max_io_queue_depth: usize,
    /// Processing statistics, exposed through inspect.
    stats: WorkerStats,
}
182
/// Per-channel processing statistics, reported under `stats` in inspect.
///
/// NOTE(review): the counters are updated by the processing loop, which is not
/// visible in this section; the meanings below are inferred from the names.
#[derive(Debug, Default, Inspect)]
struct WorkerStats {
    /// SCSI requests submitted for execution.
    ios_submitted: Counter,
    /// SCSI requests that completed.
    ios_completed: Counter,
    /// Worker wakeups.
    wakes: Counter,
    /// Wakeups that found no work to do.
    wakes_spurious: Counter,
    /// Distribution of requests submitted per wakeup.
    per_wake_submissions: Histogram<10>,
    /// Distribution of requests completed per wakeup.
    per_wake_completions: Histogram<10>,
}
192
/// Supported storvsp protocol versions.
///
/// Each discriminant is the wire-format `major_minor` value for that version.
/// The derived ordering compares variants by discriminant, which lets callers
/// gate behavior on a minimum version (e.g. `version >= Version::Win7`).
#[repr(u16)]
#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
enum Version {
    Win6 = storvsp_protocol::VERSION_WIN6,
    Win7 = storvsp_protocol::VERSION_WIN7,
    Win8 = storvsp_protocol::VERSION_WIN8,
    Blue = storvsp_protocol::VERSION_BLUE,
}
201
/// Error returned by [`Version::parse`] for an unrecognized protocol version.
/// Carries the raw `major_minor` value the guest requested.
#[derive(Debug, Error)]
#[error("protocol version {0:#x} not supported")]
struct UnsupportedVersion(u16);
205
206impl Version {
207    fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
208        let version = match major_minor {
209            storvsp_protocol::VERSION_WIN6 => Self::Win6,
210            storvsp_protocol::VERSION_WIN7 => Self::Win7,
211            storvsp_protocol::VERSION_WIN8 => Self::Win8,
212            storvsp_protocol::VERSION_BLUE => Self::Blue,
213            version => return Err(UnsupportedVersion(version)),
214        };
215        assert_eq!(version as u16, major_minor);
216        Ok(version)
217    }
218
219    fn max_request_size(&self) -> usize {
220        match self {
221            Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
222            Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
223        }
224    }
225}
226
/// Protocol negotiation state shared by all channels.
#[derive(Copy, Clone)]
enum ProtocolState {
    /// Initialization handshake in progress.
    Init(InitState),
    /// Initialization has ended. Subchannel workers wait for this state before
    /// processing any packets, then use `version`.
    Ready {
        version: Version,
        subchannel_count: u16,
    },
}
235
/// Steps of the initialization handshake driven by the primary channel.
/// The inspect names reported for each step are noted below.
#[derive(Copy, Clone, Debug)]
enum InitState {
    /// Initial state (`begin_init`).
    Begin,
    /// Protocol version not yet chosen (`query_version`).
    QueryVersion,
    /// Protocol version chosen (`query_properties`).
    QueryProperties {
        version: Version,
    },
    /// Final step (`end_init`). NOTE(review): `subchannel_count` is
    /// presumably `None` until the guest requests subchannels.
    EndInitialization {
        version: Version,
        subchannel_count: Option<u16>,
    },
}
248
/// Backing storage for a boxed SCSI operation future: a `u64` array of
/// `SCSI_REQUEST_STACK_SIZE` bytes (rounded down to a multiple of 8).
type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];
/// The internal SCSI operation future type.
///
/// This is a boxed future of a large pre-determined size. The box is reused
/// after a SCSI request completes to avoid allocations in the hot path:
/// [`ScsiRequest`] holds the future in an `Option` and, once it completes,
/// takes it out and hands back the now-empty storage for reuse.
type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;

/// The amount of space reserved for a ScsiOpFuture.
///
/// This was chosen by running `cargo test -p storvsp -- --no-capture` and looking at the required
/// size that was given in the failure message. NOTE(review): the extra 272 bytes on top of
/// `ASYNC_SCSI_DISK_STACK_SIZE` presumably cover storvsp's own wrapper future; re-measure
/// whenever that future grows.
const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;
264
/// An in-flight SCSI operation, tagged with its request ID.
///
/// Resolves to the request ID, the SCSI result, and the emptied future
/// storage so the storage can be reused.
struct ScsiRequest {
    /// Key of this request's entry in `WorkerInner::scsi_requests_states`.
    request_id: usize,
    /// The operation future; taken (left `None`) once it completes.
    future: Option<ScsiOpFuture>,
}
269
270impl ScsiRequest {
271    fn new(
272        request_id: usize,
273        future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
274    ) -> Self {
275        Self {
276            request_id,
277            future: Some(future.into()),
278        }
279    }
280}
281
impl Future for ScsiRequest {
    /// The request ID, the operation's result, and the future's storage
    /// (produced by `OversizedBox::empty_pinned`) so it can be reused.
    type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);

    /// Polls the wrapped SCSI operation.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Poll::Ready`, because the
    /// future has already been taken out.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Pin::get_mut` requires `Self: Unpin`; the inner future is boxed and
        // pinned separately, so no pinning guarantee is lost here.
        let this = self.get_mut();
        let future = this.future.as_mut().unwrap().as_mut();
        let result = std::task::ready!(future.poll(cx));
        // Return the future so that its storage can be reused.
        let future = this.future.take().unwrap();
        Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
    }
}
294
/// Errors that stop a channel worker's packet processing. `WorkerState::run`
/// logs them rather than propagating them.
#[derive(Debug, Error)]
enum WorkerError {
    /// The guest sent a packet that could not be parsed.
    #[error("packet error")]
    PacketError(#[source] PacketError),
    /// The underlying VMBus queue failed.
    #[error("queue error")]
    Queue(#[source] queue::Error),
    /// Ring space was reserved ahead of time (see `poll_for_ring_space`), but
    /// a write still found the ring full.
    #[error("queue should have enough space but no longer does")]
    NotEnoughSpace,
}
304
/// Reasons `parse_packet` rejects an incoming packet.
#[derive(Debug, Error)]
enum PacketError {
    /// The packet has no transaction ID.
    #[error("Not transactional")]
    NotTransactional,
    /// The header names an operation this device does not handle.
    #[error("Unrecognized operation {0:?}")]
    UnrecognizedOperation(storvsp_protocol::Operation),
    /// The packet is a completion rather than a data packet.
    #[error("Invalid packet type")]
    InvalidPacketType,
    /// The SRB's external ranges are not a single range that covers
    /// `data_transfer_length` (see `Range::new`).
    #[error("Invalid data transfer length")]
    InvalidDataTransferLength,
    /// Reading the packet contents failed.
    // NOTE(review): the message embeds the source error while also marking it
    // `#[source]`, so error-chain printers will show it twice.
    #[error("Access error: {0}")]
    Access(#[source] AccessError),
    /// Reading the packet's external ranges failed.
    #[error("Range error")]
    Range(#[source] ExternalDataError),
}
320
/// The guest data buffer attached to a SCSI request.
#[derive(Debug, Default, Clone)]
struct Range {
    /// The packet's external ranges; `Range::new` admits at most one.
    buf: MultiPagedRangeBuf<GpnList>,
    /// Bytes to transfer: the request's `data_transfer_length`.
    len: usize,
    /// Passed to `RequestBuffers::new` as its write flag. It is set when the
    /// request's `data_in` field is nonzero.
    is_write: bool,
}
327
328impl Range {
329    fn new(
330        buf: MultiPagedRangeBuf<GpnList>,
331        request: &storvsp_protocol::ScsiRequest,
332    ) -> Option<Self> {
333        let len = request.data_transfer_length as usize;
334        let is_write = request.data_in != 0;
335        // Ensure there is exactly one range and it's large enough, or there are
336        // zero ranges and there is no associated SCSI buffer.
337        if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
338            return None;
339        }
340        Some(Self { buf, len, is_write })
341    }
342
343    fn buffer<'a>(&'a self, guest_memory: &'a GuestMemory) -> RequestBuffers<'a> {
344        let mut range = self.buf.first().unwrap_or_else(PagedRange::empty);
345        range.truncate(self.len);
346        RequestBuffers::new(guest_memory, range, self.is_write)
347    }
348}
349
/// A parsed storvsp packet.
#[derive(Debug)]
struct Packet {
    /// Operation-specific contents.
    data: PacketData,
    /// VMBus transaction ID; echoed in the completion for this packet.
    transaction_id: u64,
    /// Bytes following the header, capped at `SCSI_REQUEST_LEN_MAX`. The
    /// completion body for this packet is padded or truncated to this size.
    request_size: usize,
}
356
/// Operation-specific contents of a parsed storvsp packet.
#[derive(Debug)]
enum PacketData {
    /// `BEGIN_INITIALIZATION`.
    BeginInitialization,
    /// `END_INITIALIZATION`.
    EndInitialization,
    /// `QUERY_PROTOCOL_VERSION`, with the guest's requested `major_minor`.
    QueryProtocolVersion(u16),
    /// `QUERY_PROPERTIES`.
    QueryProperties,
    /// `CREATE_SUB_CHANNELS`, with the requested subchannel count.
    CreateSubChannels(u16),
    /// `EXECUTE_SRB`: the SCSI request together with its data buffer.
    ExecuteScsi(Arc<ScsiRequestAndRange>),
    /// `RESET_BUS`.
    ResetBus,
    /// `RESET_ADAPTER`.
    ResetAdapter,
    /// `RESET_LUN`.
    ResetLun,
}
369
/// Unit error type for range failures.
///
/// NOTE(review): not referenced in this section of the file; confirm whether
/// it is still used elsewhere before removing it.
#[derive(Debug)]
pub struct RangeError;
372
373fn parse_packet<T: RingMem>(
374    packet: &IncomingPacket<'_, T>,
375    pool: &mut Vec<Arc<ScsiRequestAndRange>>,
376) -> Result<Packet, PacketError> {
377    let packet = match packet {
378        IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
379        IncomingPacket::Data(packet) => packet,
380    };
381    let transaction_id = packet
382        .transaction_id()
383        .ok_or(PacketError::NotTransactional)?;
384
385    let mut reader = packet.reader();
386    let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
387    // You would expect that this should be limited to the current protocol
388    // version's maximum packet size, but this is not what Hyper-V does, and
389    // Linux 6.1 relies on this behavior during protocol initialization.
390    let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
391    let data = match header.operation {
392        storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
393        storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
394        storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
395            let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
396            reader
397                .read(version.as_mut_bytes())
398                .map_err(PacketError::Access)?;
399            PacketData::QueryProtocolVersion(version.major_minor)
400        }
401        storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
402        storvsp_protocol::Operation::EXECUTE_SRB => {
403            let mut full_request = pool.pop().unwrap_or_else(|| {
404                Arc::new(ScsiRequestAndRange {
405                    external_data: Range::default(),
406                    request: storvsp_protocol::ScsiRequest::new_zeroed(),
407                    request_size,
408                })
409            });
410
411            {
412                let full_request = Arc::get_mut(&mut full_request).unwrap();
413                let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
414                reader.read(request_buf).map_err(PacketError::Access)?;
415
416                let buf = packet.read_external_ranges().map_err(PacketError::Range)?;
417
418                full_request.external_data = Range::new(buf, &full_request.request)
419                    .ok_or(PacketError::InvalidDataTransferLength)?;
420            }
421
422            PacketData::ExecuteScsi(full_request)
423        }
424        storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
425        storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
426        storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
427        storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
428            let mut sub_channel_count: u16 = 0;
429            reader
430                .read(sub_channel_count.as_mut_bytes())
431                .map_err(PacketError::Access)?;
432            PacketData::CreateSubChannels(sub_channel_count)
433        }
434        _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
435    };
436
437    if let PacketData::ExecuteScsi(_) = data {
438        tracing::trace!(transaction_id, ?data, "parse_packet");
439    } else {
440        tracing::debug!(transaction_id, ?data, "parse_packet");
441    }
442
443    Ok(Packet {
444        data,
445        request_size,
446        transaction_id,
447    })
448}
449
450impl WorkerInner {
451    fn send_vmbus_packet<M: RingMem>(
452        &mut self,
453        writer: &mut queue::WriteBatch<'_, M>,
454        packet_type: OutgoingPacketType<'_>,
455        request_size: usize,
456        transaction_id: u64,
457        operation: storvsp_protocol::Operation,
458        status: storvsp_protocol::NtStatus,
459        payload: &[u8],
460    ) -> Result<(), WorkerError> {
461        let header = storvsp_protocol::Packet {
462            operation,
463            flags: 0,
464            status,
465        };
466
467        let packet_size = size_of_val(&header) + request_size;
468
469        // Zero pad or truncate the payload to the queue's packet size. This is
470        // necessary because Windows guests check that each packet's size is
471        // exactly the largest possible packet size for the negotiated protocol
472        // version.
473        let len = size_of_val(&header) + size_of_val(payload);
474        let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
475        let (payload_bytes, padding_bytes) = if len > packet_size {
476            (&payload[..packet_size - size_of_val(&header)], &[][..])
477        } else {
478            (payload, &padding[..packet_size - len])
479        };
480        assert_eq!(
481            size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
482            packet_size
483        );
484        writer
485            .try_write(&OutgoingPacket {
486                transaction_id,
487                packet_type,
488                payload: &[header.as_bytes(), payload_bytes, padding_bytes],
489            })
490            .map_err(|err| match err {
491                queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
492                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
493            })
494    }
495
496    fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
497        &mut self,
498        writer: &mut queue::WriteHalf<'_, M>,
499        operation: storvsp_protocol::Operation,
500        status: storvsp_protocol::NtStatus,
501        payload: &P,
502    ) -> Result<(), WorkerError> {
503        self.send_vmbus_packet(
504            &mut writer.batched(),
505            OutgoingPacketType::InBandNoCompletion,
506            self.request_size,
507            0,
508            operation,
509            status,
510            payload.as_bytes(),
511        )
512    }
513
514    fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
515        &mut self,
516        writer: &mut queue::WriteHalf<'_, M>,
517        packet: &Packet,
518        status: storvsp_protocol::NtStatus,
519        payload: &P,
520    ) -> Result<(), WorkerError> {
521        self.send_vmbus_packet(
522            &mut writer.batched(),
523            OutgoingPacketType::Completion,
524            packet.request_size,
525            packet.transaction_id,
526            storvsp_protocol::Operation::COMPLETE_IO,
527            status,
528            payload.as_bytes(),
529        )
530    }
531}
532
/// Executes guest SCSI requests against the controller's disks.
struct ScsiCommandQueue {
    /// Controller state holding the attached disks.
    controller: Arc<ScsiControllerState>,
    /// Guest memory used to access request data buffers.
    mem: GuestMemory,
    /// If set, replaces the request's path ID when looking up the target disk.
    /// REPORT LUNS still matches on the original path ID.
    force_path_id: Option<u8>,
}
538
impl ScsiCommandQueue {
    /// Executes one guest SCSI request and returns its result.
    ///
    /// - `REPORT LUNS` is always answered by the controller itself, from its
    ///   list of attached disks.
    /// - Any other command addressed to an attached disk is forwarded to that
    ///   disk.
    /// - For an address with no attached disk, a standard `INQUIRY` gets a
    ///   "logical unit not present" response, and every other command fails
    ///   with `SrbStatus::INVALID_LUN`.
    ///
    /// When `force_path_id` is set, it replaces the request's path ID for the
    /// disk lookup, but not for `REPORT LUNS`.
    async fn execute_scsi(
        &self,
        external_data: &Range,
        request: &storvsp_protocol::ScsiRequest,
    ) -> ScsiResult {
        // The first CDB byte is the operation code.
        let op = ScsiOp(request.payload[0]);
        let external_data = external_data.buffer(&self.mem);

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            "execute_scsi start...",
        );

        let path_id = self.force_path_id.unwrap_or(request.path_id);

        // The read guard is a temporary, so it is released at the end of this
        // statement. The disk entry is cloned out so it stays usable across
        // the `.await` below.
        let controller_disk = self
            .controller
            .disks
            .read()
            .get(&ScsiPath {
                path: path_id,
                target: request.target_id,
                lun: request.lun,
            })
            .cloned();

        let result = match op {
            ScsiOp::REPORT_LUNS => {
                const HEADER_SIZE: usize = size_of::<scsi::LunList>();
                // Collect the LUNs attached at this request's path/target.
                let mut luns: Vec<u8> = self
                    .controller
                    .disks
                    .read()
                    .iter()
                    .flat_map(|(path, _)| {
                        // Use the original path ID and not the forced one to
                        // match Hyper-V storvsp behavior.
                        if request.path_id == path.path && request.target_id == path.target {
                            Some(path.lun)
                        } else {
                            None
                        }
                    })
                    .collect();
                luns.sort_unstable();
                // Response layout: one 8-byte header holding the list length
                // in bytes, then one 8-byte entry per LUN with the LUN number
                // big-endian in its first two bytes.
                let mut data: Vec<u64> = vec![0; luns.len() + 1];
                let header = scsi::LunList {
                    length: (luns.len() as u32 * 8).into(),
                    reserved: [0; 4],
                };
                data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
                for (i, lun) in luns.iter().enumerate() {
                    data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
                }
                if external_data.len() >= HEADER_SIZE {
                    // Transfer as much of the list as fits in the buffer.
                    let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
                    external_data.writer().write(&data.as_bytes()[..tx]).map_or(
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        },
                        |_| ScsiResult {
                            scsi_status: ScsiStatus::GOOD,
                            srb_status: SrbStatus::SUCCESS,
                            tx,
                            sense_data: None,
                        },
                    )
                } else {
                    // The buffer cannot hold even the header: succeed without
                    // transferring any data.
                    ScsiResult {
                        scsi_status: ScsiStatus::GOOD,
                        srb_status: SrbStatus::SUCCESS,
                        tx: 0,
                        sense_data: None,
                    }
                }
            }
            // Any other command addressed to an attached disk is handled by the disk.
            _ if controller_disk.is_some() => {
                let mut cdb = [0; 16];
                cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
                controller_disk
                    .unwrap()
                    .disk
                    .execute_scsi(
                        &external_data,
                        &Request {
                            cdb,
                            srb_flags: request.srb_flags,
                        },
                    )
                    .await
            }
            // INQUIRY to an address with no attached disk.
            ScsiOp::INQUIRY => {
                let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
                    .unwrap()
                    .0; // TODO: zerocopy: ref-from-prefix: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
                // Reject the request if the buffer is smaller than the
                // allocation length, the direction is not
                // SCSI_IOCTL_DATA_IN, or the allocation length cannot hold
                // the inquiry header.
                if external_data.len() < cdb.allocation_length.get() as usize
                    || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
                    || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
                {
                    ScsiResult {
                        scsi_status: ScsiStatus::CHECK_CONDITION,
                        srb_status: SrbStatus::INVALID_REQUEST,
                        tx: 0,
                        sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
                    }
                } else {
                    let enable_vpd = cdb.flags.vpd();
                    if enable_vpd || cdb.page_code != 0 {
                        // cannot support VPD inquiry for non-existing device (lun).
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        }
                    } else {
                        // Standard inquiry: answer from the scsidisk template,
                        // with the device type marked as "logical unit not
                        // present".
                        const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
                        let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
                        data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;

                        if request.lun != 0 {
                            // Below fields are only set for lun0 inquiry so zero out here.
                            data.vendor_id = [0; 8];
                            data.product_id = [0; 16];
                            data.product_revision_level = [0; 4];
                        }

                        let datab = data.as_bytes();
                        // Transfer no more than the allocation length allows.
                        let tx = std::cmp::min(
                            cdb.allocation_length.get() as usize,
                            size_of::<scsi::InquiryData>(),
                        );
                        external_data.writer().write(&datab[..tx]).map_or(
                            ScsiResult {
                                scsi_status: ScsiStatus::CHECK_CONDITION,
                                srb_status: SrbStatus::INVALID_REQUEST,
                                tx: 0,
                                sense_data: Some(illegal_request_sense(
                                    AdditionalSenseCode::INVALID_CDB,
                                )),
                            },
                            |_| ScsiResult {
                                scsi_status: ScsiStatus::GOOD,
                                srb_status: SrbStatus::SUCCESS,
                                tx,
                                sense_data: None,
                            },
                        )
                    }
                }
            }
            // No disk at this address, and not a command the controller answers itself.
            _ => ScsiResult {
                scsi_status: ScsiStatus::CHECK_CONDITION,
                srb_status: SrbStatus::INVALID_LUN,
                tx: 0,
                sense_data: None,
            },
        };

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            result = ?result,
            "execute_scsi completed.",
        );
        result
    }
}
720
721impl<T: RingMem + 'static> Worker<T> {
722    fn new(
723        controller: Arc<ScsiControllerState>,
724        channel: RawAsyncChannel<T>,
725        channel_index: u16,
726        mem: GuestMemory,
727        channel_control: ChannelControl,
728        io_queue_depth: u32,
729        protocol: Arc<Protocol>,
730        force_path_id: Option<u8>,
731    ) -> anyhow::Result<Self> {
732        let queue = Queue::new(channel)?;
733        #[expect(clippy::disallowed_methods)] // TODO
734        let (source, target) = futures::channel::mpsc::channel(1);
735        controller.add_rescan_notification_source(source);
736
737        let max_io_queue_depth = io_queue_depth.max(1) as usize;
738        Ok(Self {
739            inner: WorkerInner {
740                protocol,
741                request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
742                controller: controller.clone(),
743                channel_index,
744                scsi_queue: Arc::new(ScsiCommandQueue {
745                    controller,
746                    mem,
747                    force_path_id,
748                }),
749                scsi_requests: FuturesUnordered::new(),
750                scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
751                channel_control,
752                max_io_queue_depth,
753                future_pool: Vec::new(),
754                full_request_pool: Vec::new(),
755                stats: Default::default(),
756            },
757            queue,
758            rescan_notification: target,
759            fast_select: FastSelect::new(),
760        })
761    }
762
763    async fn wait_for_scsi_requests_complete(&mut self) {
764        tracing::debug!(
765            channel_index = self.inner.channel_index,
766            "wait for IOs completed..."
767        );
768        while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
769            self.inner.scsi_requests_states.remove(id);
770        }
771    }
772}
773
774impl InspectTask<Worker> for WorkerState {
775    fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
776        if let Some(worker) = worker {
777            let mut resp = req.respond();
778            if worker.inner.channel_index == 0 {
779                let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
780                    ProtocolState::Init(state) => match state {
781                        InitState::Begin => ("begin_init", None, None),
782                        InitState::QueryVersion => ("query_version", None, None),
783                        InitState::QueryProperties { version } => {
784                            ("query_properties", Some(version), None)
785                        }
786                        InitState::EndInitialization {
787                            version,
788                            subchannel_count,
789                        } => ("end_init", Some(version), subchannel_count),
790                    },
791                    ProtocolState::Ready {
792                        version,
793                        subchannel_count,
794                    } => ("ready", Some(version), Some(subchannel_count)),
795                };
796                resp.field("state", state)
797                    .field("version", version)
798                    .field("subchannel_count", subchannel_count);
799            }
800            resp.field("pending_packets", worker.inner.scsi_requests_states.len())
801                .fields("io", worker.inner.scsi_requests_states.iter())
802                .field("stats", &worker.inner.stats)
803                .field("ring", &worker.queue)
804                .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
805        }
806    }
807}
808
impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
    /// Runs the worker's packet processing until it finishes, fails, or the
    /// task is stopped.
    ///
    /// The primary channel (index 0) runs the protocol state machine. Every
    /// other channel first waits for the protocol to reach
    /// `ProtocolState::Ready`, then processes packets using the negotiated
    /// version. Processing errors are logged, not returned; the only error
    /// returned is `Cancelled`, when `stop` fires.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        let fut = async {
            if worker.inner.channel_index == 0 {
                worker.process_primary().await
            } else {
                // Wait for initialization to end before processing any
                // subchannel packets.
                let protocol_version = loop {
                    // Register the listener before checking the state, so a
                    // `ready` notification that arrives between the check and
                    // the await is not missed.
                    let listener = worker.inner.protocol.ready.listen();
                    if let ProtocolState::Ready { version, .. } =
                        *worker.inner.protocol.state.read()
                    {
                        break version;
                    }
                    tracing::debug!("subchannel waiting for initialization to end");
                    listener.await
                };
                worker
                    .inner
                    .process_ready(&mut worker.queue, protocol_version)
                    .await
            }
        };

        // The outer `?` propagates cancellation; the inner processing result
        // is only logged.
        match stop.until_stopped(fut).await? {
            Ok(_) => {}
            Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
        }
        Ok(())
    }
}
845
846impl WorkerInner {
847    /// Awaits the next incoming packet, without checking for any other events (device add/remove notifications or available completions).
848    /// Increments the count of outstanding packets when returning `Ok(Packet)`.
849    async fn next_packet<'a, M: RingMem>(
850        &mut self,
851        reader: &'a mut queue::ReadHalf<'a, M>,
852    ) -> Result<Packet, WorkerError> {
853        let packet = reader.read().await.map_err(WorkerError::Queue)?;
854        let stor_packet =
855            parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
856        Ok(stor_packet)
857    }
858
859    /// Polls for enough ring space in the outgoing ring to send a packet.
860    ///
861    /// This is used to ensure there is enough space in the ring before
862    /// committing to sending a packet. This avoids the need to save pending
863    /// packets on the side if queue processing is interrupted while the ring is
864    /// full.
865    fn poll_for_ring_space<M: RingMem>(
866        &mut self,
867        cx: &mut Context<'_>,
868        writer: &mut queue::WriteHalf<'_, M>,
869    ) -> Poll<Result<(), WorkerError>> {
870        writer
871            .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
872            .map_err(WorkerError::Queue)
873    }
874}
875
876const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
877    size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
878);
879
impl<T: RingMem> Worker<T> {
    /// Processes the protocol state machine.
    ///
    /// This runs on the primary channel only. Each pass reads the shared
    /// `protocol.state`:
    ///
    /// - `Init(..)`: waits for ring space, reads exactly one packet and
    ///   always answers it with a completion. The expected packet advances
    ///   the state (`Begin` -> `QueryVersion` -> `QueryProperties` ->
    ///   `EndInitialization` -> `Ready`). An out-of-order packet is answered
    ///   with `INVALID_DEVICE_STATE` and leaves the state unchanged.
    /// - `Ready`: processes packets and IO completions via `process_ready`,
    ///   while also listening for rescan notifications.
    ///
    /// # Errors
    ///
    /// Returns the first error from reading, parsing, or sending a packet.
    /// Otherwise it returns whatever `process_ready` resolves with.
    async fn process_primary(&mut self) -> Result<(), WorkerError> {
        loop {
            // Copy the state out so the lock is not held across awaits.
            let current_state = *self.inner.protocol.state.read();
            match current_state {
                ProtocolState::Ready { version, .. } => {
                    // Race packet processing against rescan notifications
                    // (packet processing is polled first). On a rescan, the
                    // in-progress `process_ready` future is dropped. For
                    // Win7+ protocol versions the guest is sent
                    // ENUMERATE_BUS. Processing then restarts with a fresh
                    // `process_ready` call.
                    break loop {
                        select_biased! {
                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
                                if version >= Version::Win7
                                {
                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
                                }
                            }
                        }
                    };
                }
                ProtocolState::Init(state) => {
                    let (mut reader, mut writer) = self.queue.split();

                    // Ensure that subsequent calls to `send_completion` won't
                    // fail due to lack of ring space, to avoid keeping (and saving/restoring) interim states.
                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;

                    tracing::debug!(?state, "process_primary");
                    match state {
                        InitState::Begin => {
                            // Only BeginInitialization is accepted here.
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::BeginInitialization = packet.data {
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &(),
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::QueryVersion);
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryVersion => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
                                if let Ok(version) = Version::parse(major_minor) {
                                    // Supported version: echo it back and size
                                    // future requests for it.
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    self.inner.request_size = version.max_request_size();
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryProperties { version });

                                    tracelimit::info_ratelimited!(
                                        ?version,
                                        "scsi version negotiated"
                                    );
                                } else {
                                    // Unsupported version: report the mismatch
                                    // and let the guest try another version.
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryVersion);
                                }
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryProperties { version } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProperties = packet.data {
                                // Subchannels are only advertised for Win8+.
                                let multi_channel_supported = version >= Version::Win8;

                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &storvsp_protocol::ChannelProperties {
                                        max_transfer_bytes: 0x40000, // 256KB
                                        flags: {
                                            if multi_channel_supported {
                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
                                            } else {
                                                0
                                            }
                                        },
                                        maximum_sub_channel_count: if multi_channel_supported {
                                            self.inner.channel_control.max_subchannels()
                                        } else {
                                            0
                                        },
                                        reserved: 0,
                                        reserved2: 0,
                                        reserved3: [0, 0],
                                    },
                                )?;
                                // `None` means the guest may still request
                                // subchannels; `Some(0)` means it may not.
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::EndInitialization {
                                        version,
                                        subchannel_count: if multi_channel_supported {
                                            None
                                        } else {
                                            Some(0)
                                        },
                                    });
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            match packet.data {
                                // Subchannels may be created at most once
                                // during initialization, and only when
                                // multi-channel was advertised.
                                PacketData::CreateSubChannels(sub_channel_count)
                                    if subchannel_count.is_none() =>
                                {
                                    if let Err(err) = self
                                        .inner
                                        .channel_control
                                        .enable_subchannels(sub_channel_count)
                                    {
                                        tracelimit::warn_ratelimited!(
                                            ?err,
                                            "cannot enable subchannels"
                                        );
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
                                            &(),
                                        )?;
                                    } else {
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::SUCCESS,
                                            &(),
                                        )?;
                                        *self.inner.protocol.state.write() =
                                            ProtocolState::Init(InitState::EndInitialization {
                                                version,
                                                subchannel_count: Some(sub_channel_count),
                                            });
                                    }
                                }
                                PacketData::EndInitialization => {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &(),
                                    )?;
                                    // Reset the rescan notification event now, before the guest has a
                                    // chance to send any SCSI requests to scan the bus.
                                    self.rescan_notification.try_next().ok();
                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
                                        version,
                                        subchannel_count: subchannel_count.unwrap_or(0),
                                    };
                                    // Wake up subchannels waiting for the
                                    // protocol state to become ready.
                                    self.inner.protocol.ready.notify(usize::MAX);
                                }
                                _ => {
                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                        &(),
                                    )?;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
1093
1094fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
1095    match srb_status {
1096        SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
1097        SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
1098        SrbStatus::INVALID_LUN
1099        | SrbStatus::INVALID_TARGET_ID
1100        | SrbStatus::NO_DEVICE
1101        | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
1102        SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
1103        SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
1104        SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
1105            storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
1106        }
1107        SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
1108        SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
1109        SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
1110        _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
1111    }
1112}
1113
1114impl WorkerInner {
1115    /// Processes packets and SCSI completions after protocol negotiation has finished.
1116    async fn process_ready<M: RingMem>(
1117        &mut self,
1118        queue: &mut Queue<M>,
1119        protocol_version: Version,
1120    ) -> Result<(), WorkerError> {
1121        self.request_size = protocol_version.max_request_size();
1122        poll_fn(|cx| self.poll_process_ready(cx, queue)).await
1123    }
1124
1125    /// Processes packets and SCSI completions after protocol negotiation has finished.
1126    fn poll_process_ready<M: RingMem>(
1127        &mut self,
1128        cx: &mut Context<'_>,
1129        queue: &mut Queue<M>,
1130    ) -> Poll<Result<(), WorkerError>> {
1131        self.stats.wakes.increment();
1132
1133        let (mut reader, mut writer) = queue.split();
1134        let mut total_completions = 0;
1135        let mut total_submissions = 0;
1136        let poll_mode_queue_depth = self.controller.poll_mode_queue_depth.load(Relaxed) as usize;
1137
1138        loop {
1139            // Drive IOs forward and collect completions.
1140            'outer: while !self.scsi_requests_states.is_empty() {
1141                {
1142                    let mut batch = writer.batched();
1143                    loop {
1144                        // Ensure there is room for the completion before consuming
1145                        // the IO so that we don't have to track completed IOs whose
1146                        // completions haven't been sent.
1147                        if !batch
1148                            .can_write(MAX_VMBUS_PACKET_SIZE)
1149                            .map_err(WorkerError::Queue)?
1150                        {
1151                            // This batch is full but there may still be more completions.
1152                            break;
1153                        }
1154                        if let Poll::Ready(Some((request_id, result, future))) =
1155                            self.scsi_requests.poll_next_unpin(cx)
1156                        {
1157                            self.future_pool.push(future);
1158                            self.handle_completion(&mut batch, request_id, result)?;
1159                            total_completions += 1;
1160                        } else {
1161                            tracing::trace!("out of completions");
1162                            break 'outer;
1163                        }
1164                    }
1165                }
1166
1167                // Wait for enough space for any completion packets.
1168                if self.poll_for_ring_space(cx, &mut writer).is_pending() {
1169                    tracing::trace!("out of ring space");
1170                    break;
1171                }
1172            }
1173
1174            let mut submissions = 0;
1175            // Process new requests.
1176            'outer: loop {
1177                if self.scsi_requests_states.len() >= self.max_io_queue_depth {
1178                    break;
1179                }
1180                let mut batch = if self.scsi_requests_states.len() < poll_mode_queue_depth {
1181                    if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
1182                        batch.map_err(WorkerError::Queue)?
1183                    } else {
1184                        tracing::trace!("out of incoming packets");
1185                        break;
1186                    }
1187                } else {
1188                    match reader.try_read_batch() {
1189                        Ok(batch) => batch,
1190                        Err(queue::TryReadError::Empty) => {
1191                            tracing::trace!(
1192                                pending_io_count = self.scsi_requests_states.len(),
1193                                "out of incoming packets, keeping interrupts masked"
1194                            );
1195                            break;
1196                        }
1197                        Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
1198                    }
1199                };
1200
1201                let mut packets = batch.packets();
1202                loop {
1203                    if self.scsi_requests_states.len() >= self.max_io_queue_depth {
1204                        break 'outer;
1205                    }
1206                    // Wait for enough space for any completion packets that
1207                    // `handle_packet` may need to send, so that it isn't necessary
1208                    // to track pending completions.
1209                    if self.poll_for_ring_space(cx, &mut writer).is_pending() {
1210                        tracing::trace!("out of ring space");
1211                        break 'outer;
1212                    }
1213
1214                    let packet = if let Some(packet) = packets.next() {
1215                        packet.map_err(WorkerError::Queue)?
1216                    } else {
1217                        break;
1218                    };
1219
1220                    if self.handle_packet(&mut writer, &packet)? {
1221                        submissions += 1;
1222                    }
1223                }
1224            }
1225
1226            // Loop around to poll the IOs if any new IOs were submitted.
1227            if submissions == 0 {
1228                // No need to poll again.
1229                break;
1230            }
1231            total_submissions += submissions;
1232        }
1233
1234        if total_submissions != 0 || total_completions != 0 {
1235            self.stats.ios_submitted.add(total_submissions);
1236            self.stats
1237                .per_wake_submissions
1238                .add_sample(total_submissions);
1239            self.stats
1240                .per_wake_completions
1241                .add_sample(total_completions);
1242            self.stats.ios_completed.add(total_completions);
1243        } else {
1244            self.stats.wakes_spurious.increment();
1245        }
1246
1247        Poll::Pending
1248    }
1249
1250    fn handle_completion<M: RingMem>(
1251        &mut self,
1252        writer: &mut queue::WriteBatch<'_, M>,
1253        request_id: usize,
1254        result: ScsiResult,
1255    ) -> Result<(), WorkerError> {
1256        let state = self.scsi_requests_states.remove(request_id);
1257        let request_size = state.request.request_size;
1258
1259        // Push the request into the pool to avoid reallocating later.
1260        assert_eq!(
1261            Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
1262            1
1263        );
1264        self.full_request_pool.push(state.request);
1265
1266        let status = convert_srb_status_to_nt_status(result.srb_status);
1267        let mut payload = [0; 0x14];
1268        if let Some(sense) = result.sense_data {
1269            payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
1270            tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
1271        };
1272        let response = storvsp_protocol::ScsiRequest {
1273            length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
1274            scsi_status: result.scsi_status,
1275            srb_status: SrbStatusAndFlags::new()
1276                .with_status(result.srb_status)
1277                .with_autosense_valid(result.sense_data.is_some()),
1278            data_transfer_length: result.tx as u32,
1279            cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
1280            sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
1281            payload,
1282            ..storvsp_protocol::ScsiRequest::new_zeroed()
1283        };
1284        self.send_vmbus_packet(
1285            writer,
1286            OutgoingPacketType::Completion,
1287            request_size,
1288            state.transaction_id,
1289            storvsp_protocol::Operation::COMPLETE_IO,
1290            status,
1291            response.as_bytes(),
1292        )?;
1293        Ok(())
1294    }
1295
1296    fn handle_packet<M: RingMem>(
1297        &mut self,
1298        writer: &mut queue::WriteHalf<'_, M>,
1299        packet: &IncomingPacket<'_, M>,
1300    ) -> Result<bool, WorkerError> {
1301        let packet =
1302            parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
1303        let submitted_io = match packet.data {
1304            PacketData::ExecuteScsi(request) => {
1305                self.push_scsi_request(packet.transaction_id, request);
1306                true
1307            }
1308            PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
1309                // These operations have always been no-ops.
1310                self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
1311                false
1312            }
1313            PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
1314                if let Err(err) = self
1315                    .channel_control
1316                    .enable_subchannels(new_subchannel_count)
1317                {
1318                    tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
1319                    self.send_completion(
1320                        writer,
1321                        &packet,
1322                        storvsp_protocol::NtStatus::INVALID_PARAMETER,
1323                        &(),
1324                    )?;
1325                    false
1326                } else {
1327                    // Update the subchannel count in the protocol state for save.
1328                    if let ProtocolState::Ready {
1329                        subchannel_count, ..
1330                    } = &mut *self.protocol.state.write()
1331                    {
1332                        *subchannel_count = new_subchannel_count;
1333                    } else {
1334                        unreachable!()
1335                    }
1336
1337                    self.send_completion(
1338                        writer,
1339                        &packet,
1340                        storvsp_protocol::NtStatus::SUCCESS,
1341                        &(),
1342                    )?;
1343                    false
1344                }
1345            }
1346            _ => {
1347                tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
1348                self.send_completion(
1349                    writer,
1350                    &packet,
1351                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
1352                    &(),
1353                )?;
1354                false
1355            }
1356        };
1357        Ok(submitted_io)
1358    }
1359
1360    fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
1361        let scsi_queue = self.scsi_queue.clone();
1362        let scsi_request_state = ScsiRequestState {
1363            transaction_id,
1364            request: full_request.clone(),
1365        };
1366        let request_id = self.scsi_requests_states.insert(scsi_request_state);
1367        let future = self
1368            .future_pool
1369            .pop()
1370            .unwrap_or_else(|| OversizedBox::new(()));
1371        let future = OversizedBox::refill(future, async move {
1372            scsi_queue
1373                .execute_scsi(&full_request.external_data, &full_request.request)
1374                .await
1375        });
1376        let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
1377        self.scsi_requests.push(request);
1378    }
1379}
1380
1381impl<T: RingMem> Drop for Worker<T> {
1382    fn drop(&mut self) {
1383        self.inner
1384            .controller
1385            .remove_rescan_notification_source(&self.rescan_notification);
1386    }
1387}
1388
/// Error reporting that a SCSI path is already in use.
///
/// Displays the path as `path:target:lun`.
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathInUse(pub ScsiPath);

/// Error reporting that a SCSI path is not in use.
///
/// Displays the path as `path:target:lun`. Unlike [`ScsiPathInUse`], the
/// contained path is not public.
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathNotInUse(ScsiPath);
1396
/// Bookkeeping for one in-flight SCSI request, kept in the worker's
/// `scsi_requests_states` slab until the IO completes.
// NOTE(review): `handle_completion` asserts the request `Arc` has exactly one
// reference left, so a clone of this state that is still alive at completion
// time would panic.
#[derive(Clone)]
struct ScsiRequestState {
    /// Transaction ID of the guest's request packet, echoed back in the
    /// completion packet.
    transaction_id: u64,
    /// The parsed request that is being executed.
    request: Arc<ScsiRequestAndRange>,
}

/// A parsed SCSI request together with the data range passed to
/// `execute_scsi`.
#[derive(Debug)]
struct ScsiRequestAndRange {
    /// External data range handed to `execute_scsi` with the request
    /// (presumably the guest data buffer — confirm in `parse_packet`).
    external_data: Range,
    /// The SCSI request as received from the guest.
    request: storvsp_protocol::ScsiRequest,
    /// Request size used when sending this request's completion packet.
    request_size: usize,
}
1409
1410impl Inspect for ScsiRequestState {
1411    fn inspect(&self, req: inspect::Request<'_>) {
1412        req.respond()
1413            .field("transaction_id", self.transaction_id)
1414            .display(
1415                "address",
1416                &ScsiPath {
1417                    path: self.request.request.path_id,
1418                    target: self.request.request.target_id,
1419                    lun: self.request.request.lun,
1420                },
1421            )
1422            .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1423    }
1424}
1425
1426impl StorageDevice {
1427    /// Returns a new SCSI device.
1428    pub fn build_scsi(
1429        driver_source: &VmTaskDriverSource,
1430        controller: &ScsiController,
1431        instance_id: Guid,
1432        max_sub_channel_count: u16,
1433        io_queue_depth: u32,
1434    ) -> Self {
1435        Self::build_inner(
1436            driver_source,
1437            controller,
1438            instance_id,
1439            None,
1440            max_sub_channel_count,
1441            io_queue_depth,
1442        )
1443    }
1444
1445    /// Returns a new SCSI device for implementing an IDE accelerator channel
1446    /// for IDE device `device_id` on channel `channel_id`.
1447    pub fn build_ide(
1448        driver_source: &VmTaskDriverSource,
1449        channel_id: u8,
1450        device_id: u8,
1451        disk: ScsiControllerDisk,
1452        io_queue_depth: u32,
1453    ) -> Self {
1454        let path = ScsiPath {
1455            path: channel_id,
1456            target: device_id,
1457            lun: 0,
1458        };
1459
1460        let controller = ScsiController::new();
1461        controller.attach(path, disk).unwrap();
1462
1463        // Construct the specific GUID that drivers in the guest expect for this
1464        // IDE device.
1465        let instance_id = Guid {
1466            data1: channel_id.into(),
1467            data2: device_id.into(),
1468            data3: 0x8899,
1469            data4: [0; 8],
1470        };
1471        Self::build_inner(
1472            driver_source,
1473            &controller,
1474            instance_id,
1475            Some(path),
1476            0,
1477            io_queue_depth,
1478        )
1479    }
1480
1481    fn build_inner(
1482        driver_source: &VmTaskDriverSource,
1483        controller: &ScsiController,
1484        instance_id: Guid,
1485        ide_path: Option<ScsiPath>,
1486        max_sub_channel_count: u16,
1487        io_queue_depth: u32,
1488    ) -> Self {
1489        let workers = (0..max_sub_channel_count + 1)
1490            .map(|channel_index| WorkerAndDriver {
1491                worker: TaskControl::new(WorkerState),
1492                driver: driver_source
1493                    .builder()
1494                    .target_vp(0)
1495                    .run_on_target(true)
1496                    .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1497            })
1498            .collect();
1499
1500        Self {
1501            instance_id,
1502            ide_path,
1503            workers,
1504            controller: controller.state.clone(),
1505            resources: Default::default(),
1506            max_sub_channel_count,
1507            driver_source: driver_source.clone(),
1508            protocol: Arc::new(Protocol {
1509                state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1510                ready: Default::default(),
1511            }),
1512            io_queue_depth,
1513        }
1514    }
1515
1516    fn new_worker(
1517        &mut self,
1518        open_request: &OpenRequest,
1519        channel_index: u16,
1520    ) -> anyhow::Result<&mut Worker> {
1521        let controller = self.controller.clone();
1522
1523        let driver = self
1524            .driver_source
1525            .builder()
1526            .target_vp(open_request.open_data.target_vp)
1527            .run_on_target(true)
1528            .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1529
1530        let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1531            .context("failed to create vmbus channel")?;
1532
1533        let channel_control = self.resources.channel_control.clone();
1534
1535        tracing::debug!(
1536            target_vp = open_request.open_data.target_vp,
1537            channel_index,
1538            "packet processing starting...",
1539        );
1540
1541        // Force the path ID on incoming SCSI requests to match the IDE
1542        // channel ID, since guests do not reliably set the path ID
1543        // correctly.
1544        let force_path_id = self.ide_path.map(|p| p.path);
1545
1546        let worker = Worker::new(
1547            controller,
1548            channel,
1549            channel_index,
1550            self.resources
1551                .offer_resources
1552                .guest_memory(open_request)
1553                .clone(),
1554            channel_control,
1555            self.io_queue_depth,
1556            self.protocol.clone(),
1557            force_path_id,
1558        )
1559        .map_err(RestoreError::Other)?;
1560
1561        self.workers[channel_index as usize]
1562            .driver
1563            .retarget_vp(open_request.open_data.target_vp);
1564
1565        Ok(self.workers[channel_index as usize].worker.insert(
1566            &driver,
1567            format!("storvsp worker {}-{}", self.instance_id, channel_index),
1568            worker,
1569        ))
1570    }
1571}
1572
1573/// A disk that can be added to a SCSI controller.
1574#[derive(Clone)]
1575pub struct ScsiControllerDisk {
1576    disk: Arc<dyn AsyncScsiDisk>,
1577}
1578
1579impl ScsiControllerDisk {
1580    /// Creates a new controller disk from an async SCSI disk.
1581    pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1582        Self { disk }
1583    }
1584}
1585
/// Controller state shared (via `Arc`) between a [`ScsiController`] handle
/// and the storvsp device/workers that service it.
struct ScsiControllerState {
    /// Attached disks, keyed by SCSI path.
    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
    /// Senders signalled (best-effort, via `try_send`) whenever a disk is
    /// attached or removed so listeners can rescan the bus.
    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
    /// Initialized to `DEFAULT_POLL_MODE_QUEUE_DEPTH` by `ScsiController::new`.
    /// NOTE(review): its exact semantics are defined where it is read, outside
    /// this chunk — presumably the queue-depth threshold for poll mode.
    poll_mode_queue_depth: AtomicU32,
}

/// Handle to a SCSI controller's set of attached disks.
///
/// Disks may be attached and removed at runtime; each change notifies any
/// registered rescan listeners.
pub struct ScsiController {
    state: Arc<ScsiControllerState>,
}
1595
1596impl ScsiController {
1597    pub fn new() -> Self {
1598        Self {
1599            state: Arc::new(ScsiControllerState {
1600                disks: Default::default(),
1601                rescan_notification_source: Mutex::new(Vec::new()),
1602                poll_mode_queue_depth: AtomicU32::new(DEFAULT_POLL_MODE_QUEUE_DEPTH),
1603            }),
1604        }
1605    }
1606
1607    pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1608        match self.state.disks.write().entry(path) {
1609            Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1610            Entry::Vacant(entry) => entry.insert(disk),
1611        };
1612        for source in self.state.rescan_notification_source.lock().iter_mut() {
1613            // Ok to ignore errors here. If the channel is full a previous notification has not yet
1614            // been processed by the primary channel worker.
1615            source.try_send(()).ok();
1616        }
1617        Ok(())
1618    }
1619
1620    pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1621        match self.state.disks.write().entry(path) {
1622            Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1623            Entry::Occupied(entry) => {
1624                entry.remove();
1625            }
1626        }
1627        for source in self.state.rescan_notification_source.lock().iter_mut() {
1628            // Ok to ignore errors here. If the channel is full a previous notification has not yet
1629            // been processed by the primary channel worker.
1630            source.try_send(()).ok();
1631        }
1632        Ok(())
1633    }
1634}
1635
1636impl ScsiControllerState {
1637    fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1638        self.rescan_notification_source.lock().push(source);
1639    }
1640
1641    fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1642        let mut sources = self.rescan_notification_source.lock();
1643        if let Some(index) = sources
1644            .iter()
1645            .position(|source| source.is_connected_to(target))
1646        {
1647            sources.remove(index);
1648        }
1649    }
1650}
1651
1652#[async_trait]
1653impl VmbusDevice for StorageDevice {
1654    fn offer(&self) -> OfferParams {
1655        if let Some(path) = self.ide_path {
1656            let offer_properties = storvsp_protocol::OfferProperties {
1657                path_id: path.path,
1658                target_id: path.target,
1659                flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
1660                ..FromZeros::new_zeroed()
1661            };
1662            let mut user_defined = UserDefinedData::new_zeroed();
1663            offer_properties
1664                .write_to_prefix(&mut user_defined[..])
1665                .unwrap();
1666            OfferParams {
1667                interface_name: "ide-accel".to_owned(),
1668                instance_id: self.instance_id,
1669                interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
1670                channel_type: ChannelType::Interface { user_defined },
1671                ..Default::default()
1672            }
1673        } else {
1674            OfferParams {
1675                interface_name: "scsi".to_owned(),
1676                instance_id: self.instance_id,
1677                interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
1678                ..Default::default()
1679            }
1680        }
1681    }
1682
1683    fn max_subchannels(&self) -> u16 {
1684        self.max_sub_channel_count
1685    }
1686
1687    fn install(&mut self, resources: DeviceResources) {
1688        self.resources = resources;
1689    }
1690
1691    async fn open(
1692        &mut self,
1693        channel_index: u16,
1694        open_request: &OpenRequest,
1695    ) -> Result<(), ChannelOpenError> {
1696        tracing::debug!(channel_index, "scsi open channel");
1697        self.new_worker(open_request, channel_index)?;
1698        self.workers[channel_index as usize].worker.start();
1699        Ok(())
1700    }
1701
1702    async fn close(&mut self, channel_index: u16) {
1703        tracing::debug!(channel_index, "scsi close channel");
1704        let worker = &mut self.workers[channel_index as usize].worker;
1705        worker.stop().await;
1706        if worker.state_mut().is_some() {
1707            worker
1708                .state_mut()
1709                .unwrap()
1710                .wait_for_scsi_requests_complete()
1711                .await;
1712            worker.remove();
1713        }
1714        if channel_index == 0 {
1715            *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
1716        }
1717    }
1718
1719    async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
1720        self.workers[channel_index as usize]
1721            .driver
1722            .retarget_vp(target_vp);
1723    }
1724
1725    fn start(&mut self) {
1726        for task in self
1727            .workers
1728            .iter_mut()
1729            .filter(|task| task.worker.has_state() && !task.worker.is_running())
1730        {
1731            task.worker.start();
1732        }
1733    }
1734
1735    async fn stop(&mut self) {
1736        tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
1737        for task in self
1738            .workers
1739            .iter_mut()
1740            .filter(|task| task.worker.has_state() && task.worker.is_running())
1741        {
1742            task.worker.stop().await;
1743        }
1744    }
1745
1746    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1747        Some(self)
1748    }
1749}
1750
1751#[async_trait]
1752impl SaveRestoreVmbusDevice for StorageDevice {
1753    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1754        Ok(SavedStateBlob::new(self.save()?))
1755    }
1756
1757    async fn restore(
1758        &mut self,
1759        control: RestoreControl<'_>,
1760        state: SavedStateBlob,
1761    ) -> Result<(), RestoreError> {
1762        self.restore(control, state.parse()?).await
1763    }
1764}
1765
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestWorker;
    use crate::test_helpers::parse_guest_completion;
    use crate::test_helpers::parse_guest_completion_check_flags_status;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use scsi::srb::SrbStatus;
    use test_with_tracing::test;
    use vmbus_channel::connected_async_channels;

    // Discourage `Clone` for `ScsiController` outside the crate, but it is
    // necessary for testing. The fuzzer also uses `TestWorker`, which needs
    // a `clone` of the inner state, but is not in this crate.
    impl Clone for ScsiController {
        fn clone(&self) -> Self {
            ScsiController {
                state: self.state.clone(),
            }
        }
    }

    // Negotiates the protocol, writes one block to a RAM-backed disk, reads
    // it back into a different guest address, and checks the data matches.
    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        let controller = ScsiController::new();
        let disk = scsidisk::SimpleScsiDisk::new(
            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
            Default::default(),
        );
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                ScsiControllerDisk::new(Arc::new(disk)),
            )
            .unwrap();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Set up the buffer for a write request
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        // stop everything
        guest.verify_graceful_close(test_worker).await;
    }

    // Sends QUERY_PROTOCOL_VERSION packets of several total lengths carrying
    // an unsupported version, and checks the length of each completion and
    // that its status is REVISION_MISMATCH.
    #[async_test]
    async fn test_packet_sizes(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let header = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };

        let mut buf = [0u8; 128];
        storvsp_protocol::ProtocolVersion {
            major_minor: !0,
            reserved: 0,
        }
        .write_to_prefix(&mut buf[..])
        .unwrap(); // PANIC: Infallible since `ProtocolVersion` is less than 128 bytes

        // Each pair is (total request length in bytes, expected completion length).
        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
            guest
                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
                .await;

            guest
                .verify_completion(|packet| {
                    let IncomingPacket::Completion(packet) = packet else {
                        unreachable!()
                    };
                    assert_eq!(packet.reader().len(), resp_len);
                    assert_eq!(
                        packet
                            .reader()
                            .read_plain::<storvsp_protocol::Packet>()
                            .unwrap()
                            .status,
                        storvsp_protocol::NtStatus::REVISION_MISMATCH
                    );
                    Ok(())
                })
                .await;
        }
    }

    // Sending END_INITIALIZATION before BEGIN_INITIALIZATION must complete
    // with INVALID_DEVICE_STATE.
    #[async_test]
    async fn test_wrong_first_packet(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // Protocol negotiation done out of order
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
                );
                Ok(())
            })
            .await;
    }

    // A packet with an operation the worker does not handle (REMOVE_DEVICE)
    // terminates the worker with `PacketError::UnrecognizedOperation`.
    #[async_test]
    async fn test_unrecognized_operation(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // Send packet with unrecognized operation
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        match worker.teardown().await {
            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
                storvsp_protocol::Operation::REMOVE_DEVICE,
            ))) => {}
            result => panic!("Worker failed with unexpected result {:?}!", result),
        }
    }

    // After BEGIN_INITIALIZATION, QUERY_PROTOCOL_VERSION and QUERY_PROPERTIES,
    // a CREATE_SUB_CHANNELS request for more subchannels than allowed
    // completes with INVALID_PARAMETER.
    #[async_test]
    async fn test_too_many_subchannels(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        // Request more subchannels than the maximum subchannel count
        // (presumably zero for the test worker — TODO confirm in TestWorker).
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_PARAMETER
                );
                Ok(())
            })
            .await;
    }

    // BEGIN_INITIALIZATION sent again after negotiation has completed is
    // rejected with INVALID_DEVICE_STATE.
    #[async_test]
    async fn test_begin_init_on_ready(driver: DefaultDriver) {
        // set up the channels and worker
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Restart negotiation on a channel that has already finished it
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|p| {
                parse_guest_completion_check_flags_status(
                    p,
                    0,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                )
            })
            .await;
    }

    // Attaches and then removes disks while the guest is connected. After each
    // change it checks the bus-enumeration notification, the REPORT LUNS
    // result, and the status of a write to the affected LUN.
    #[async_test]
    async fn test_hot_add_remove(driver: DefaultDriver) {
        // set up channels and worker.
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        // create a controller with no disk yet.
        let controller = ScsiController::new();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Verify no LUNs are reported initially. The REPORT LUNS transfer is an
        // 8-byte header plus 8 bytes per reported LUN.
        let mut lun_list_buffer: [u8; 256] = [0; 256];
        let mut disk_count = 0;
        guest
            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
            })
            .await;
        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
        assert_eq!(lun_list_size, disk_count as u32 * 8);

        // Set up a buffer for writes.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();

        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
            })
            .await;

        // Add some disks while the guest is running.
        for lun in 0..4 {
            let disk = scsidisk::SimpleScsiDisk::new(
                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
                Default::default(),
            );
            controller
                .attach(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    ScsiControllerDisk::new(Arc::new(disk)),
                )
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count += 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 256)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
                })
                .await;
        }

        // Remove all disks while the guest is running.
        for lun in 0..4 {
            controller
                .remove(ScsiPath {
                    path: 0,
                    target: 0,
                    lun,
                })
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count -= 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
                })
                .await;
        }

        guest.verify_graceful_close(test_worker).await;
    }

    // Same write/read round trip as `test_channel_working`, but the protocol
    // negotiation (BEGIN_INITIALIZATION, QUERY_PROTOCOL_VERSION,
    // QUERY_PROPERTIES, END_INITIALIZATION) is driven packet by packet.
    #[async_test]
    pub async fn test_async_disk(driver: DefaultDriver) {
        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
        let controller = ScsiController::new();
        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
            device,
            Default::default(),
        )));
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                disk,
            )
            .unwrap();

        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let test_guest_mem = GuestMemory::allocate(16384);
        let worker = TestWorker::start(
            controller.clone(),
            &driver,
            test_guest_mem.clone(),
            host,
            None,
        );

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        guest.verify_graceful_close(worker).await;
    }
}