storvsp/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7#[cfg(feature = "ioperf")]
8pub mod ioperf;
9
10#[cfg(feature = "test")]
11pub mod test_helpers;
12
13#[cfg(not(feature = "test"))]
14mod test_helpers;
15
16pub mod resolver;
17mod save_restore;
18
19use crate::ring::gparange::MultiPagedRangeBuf;
20use anyhow::Context as _;
21use async_trait::async_trait;
22use fast_select::FastSelect;
23use futures::FutureExt;
24use futures::StreamExt;
25use futures::select_biased;
26use guestmem::AccessError;
27use guestmem::GuestMemory;
28use guestmem::MemoryRead;
29use guestmem::MemoryWrite;
30use guestmem::ranges::PagedRange;
31use guid::Guid;
32use inspect::Inspect;
33use inspect::InspectMut;
34use inspect_counters::Counter;
35use inspect_counters::Histogram;
36use oversized_box::OversizedBox;
37use parking_lot::Mutex;
38use parking_lot::RwLock;
39use ring::OutgoingPacketType;
40use scsi::AdditionalSenseCode;
41use scsi::ScsiOp;
42use scsi::ScsiStatus;
43use scsi::srb::SrbStatus;
44use scsi::srb::SrbStatusAndFlags;
45use scsi_buffers::RequestBuffers;
46use scsi_core::AsyncScsiDisk;
47use scsi_core::Request;
48use scsi_core::ScsiResult;
49use scsi_defs as scsi;
50use scsidisk::illegal_request_sense;
51use slab::Slab;
52use std::collections::hash_map::Entry;
53use std::collections::hash_map::HashMap;
54use std::fmt::Debug;
55use std::future::Future;
56use std::future::poll_fn;
57use std::pin::Pin;
58use std::sync::Arc;
59use std::sync::atomic::AtomicU32;
60use std::sync::atomic::Ordering::Relaxed;
61use std::task::Context;
62use std::task::Poll;
63use storvsp_resources::ScsiPath;
64use task_control::AsyncRun;
65use task_control::InspectTask;
66use task_control::StopTask;
67use task_control::TaskControl;
68use thiserror::Error;
69use tracing_helpers::ErrorValueExt;
70use unicycle::FuturesUnordered;
71use vmbus_async::queue;
72use vmbus_async::queue::ExternalDataError;
73use vmbus_async::queue::IncomingPacket;
74use vmbus_async::queue::OutgoingPacket;
75use vmbus_async::queue::Queue;
76use vmbus_channel::RawAsyncChannel;
77use vmbus_channel::bus::ChannelType;
78use vmbus_channel::bus::OfferParams;
79use vmbus_channel::bus::OpenRequest;
80use vmbus_channel::channel::ChannelControl;
81use vmbus_channel::channel::ChannelOpenError;
82use vmbus_channel::channel::DeviceResources;
83use vmbus_channel::channel::RestoreControl;
84use vmbus_channel::channel::SaveRestoreVmbusDevice;
85use vmbus_channel::channel::VmbusDevice;
86use vmbus_channel::gpadl_ring::GpadlRingMem;
87use vmbus_channel::gpadl_ring::gpadl_channel;
88use vmbus_core::protocol::UserDefinedData;
89use vmbus_ring as ring;
90use vmbus_ring::RingMem;
91use vmcore::save_restore::RestoreError;
92use vmcore::save_restore::SaveError;
93use vmcore::save_restore::SavedStateBlob;
94use vmcore::vm_task::VmTaskDriver;
95use vmcore::vm_task::VmTaskDriverSource;
96use zerocopy::FromBytes;
97use zerocopy::FromZeros;
98use zerocopy::Immutable;
99use zerocopy::IntoBytes;
100use zerocopy::KnownLayout;
101
/// The IO queue depth at which the controller switches from guest-signal-driven
/// to poll-mode operation. This optimization reduces the guest exit rate by
/// relying on (typically interrupt-driven) IO completions to drive polling for
/// new IO requests.
///
/// NOTE(review): presumably the initial value of the controller's
/// `poll_mode_queue_depth`, which is adjustable through inspect; confirm where
/// this default is applied.
const DEFAULT_POLL_MODE_QUEUE_DEPTH: u32 = 1;
107
/// A storage controller device: its per-channel workers plus the controller
/// and protocol state shared between them.
pub struct StorageDevice {
    instance_id: Guid,
    // NOTE(review): presumably set when this controller fronts an IDE-attached
    // device; confirm against the constructor.
    ide_path: Option<ScsiPath>,
    /// Channel workers; those with state are reported by inspect under
    /// `channels`.
    workers: Vec<WorkerAndDriver>,
    /// Shared controller state, including the attached `disks` and the
    /// `poll_mode_queue_depth` setting.
    controller: Arc<ScsiControllerState>,
    resources: DeviceResources,
    driver_source: VmTaskDriverSource,
    max_sub_channel_count: u16,
    /// Protocol negotiation state; workers hold the same `Arc<Protocol>` type.
    protocol: Arc<Protocol>,
    // NOTE(review): presumably forwarded to `Worker::new` as `io_queue_depth`.
    io_queue_depth: u32,
}
119
/// A channel worker task paired with the driver associated with it.
#[derive(Inspect)]
struct WorkerAndDriver {
    /// The worker task; its inspect output is flattened into this node.
    #[inspect(flatten)]
    worker: TaskControl<WorkerState, Worker>,
    driver: VmTaskDriver,
}
126
/// Stateless task type that implements `AsyncRun` and `InspectTask` for
/// [`Worker`].
struct WorkerState;
128
129impl InspectMut for StorageDevice {
130    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
131        let mut resp = req.respond();
132
133        let disks = self.controller.disks.read();
134        for (path, controller_disk) in disks.iter() {
135            resp.child(&format!("disks/{}", path), |req| {
136                controller_disk.disk.inspect(req);
137            });
138        }
139
140        resp.fields(
141            "channels",
142            self.workers
143                .iter()
144                .filter(|task| task.worker.has_state())
145                .enumerate(),
146        )
147        .field(
148            "poll_mode_queue_depth",
149            inspect::AtomicMut(&self.controller.poll_mode_queue_depth),
150        );
151    }
152}
153
/// Per-channel worker: the vmbus queue for the channel plus the state needed
/// to process packets on it.
struct Worker<T: RingMem = GpadlRingMem> {
    /// Channel-independent processing state.
    inner: WorkerInner,
    /// Receiving end of the rescan notification channel whose sender is
    /// registered with the controller in `Worker::new`.
    rescan_notification: futures::channel::mpsc::Receiver<()>,
    fast_select: FastSelect,
    /// The vmbus queue built from the channel passed to `Worker::new`.
    queue: Queue<T>,
}
160
/// Protocol negotiation state, shared between workers through an `Arc`.
struct Protocol {
    /// The current negotiation state.
    state: RwLock<ProtocolState>,
    /// Signaled when `state` transitions to `ProtocolState::Ready`.
    ready: event_listener::Event,
}
166
/// Worker state that does not depend on the ring memory type.
struct WorkerInner {
    /// Shared protocol negotiation state.
    protocol: Arc<Protocol>,
    /// Request size used for packets sent by `send_packet`; starts at
    /// `SCSI_REQUEST_LEN_V1`.
    request_size: usize,
    controller: Arc<ScsiControllerState>,
    /// Index of this worker's channel; index 0 runs the primary (negotiating)
    /// loop.
    channel_index: u16,
    scsi_queue: Arc<ScsiCommandQueue>,
    /// In-flight SCSI request futures.
    scsi_requests: FuturesUnordered<ScsiRequest>,
    /// Per-request state, keyed by the request id yielded by `scsi_requests`.
    scsi_requests_states: Slab<ScsiRequestState>,
    /// Reusable request allocations consumed by `parse_packet`.
    full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
    // NOTE(review): presumably holds the emptied boxes returned by completed
    // `ScsiRequest` futures so their storage can be reused; confirm at the
    // completion site.
    future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
    channel_control: ChannelControl,
    /// The configured IO queue depth, clamped to at least 1.
    max_io_queue_depth: usize,
    /// Counters reported by inspect under `stats`.
    stats: WorkerStats,
}
181
/// Per-worker counters and histograms, reported by inspect under `stats`.
// NOTE(review): the update sites are outside this view; field meanings are
// taken from their names only.
#[derive(Debug, Default, Inspect)]
struct WorkerStats {
    ios_submitted: Counter,
    ios_completed: Counter,
    wakes: Counter,
    wakes_spurious: Counter,
    per_wake_submissions: Histogram<10>,
    per_wake_completions: Histogram<10>,
}
191
/// A supported storvsp protocol version.
///
/// Each discriminant is the corresponding `storvsp_protocol::VERSION_*` wire
/// value, so `version as u16` yields the value exchanged with the guest. The
/// derived ordering follows variant declaration order.
#[repr(u16)]
#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
enum Version {
    Win6 = storvsp_protocol::VERSION_WIN6,
    Win7 = storvsp_protocol::VERSION_WIN7,
    Win8 = storvsp_protocol::VERSION_WIN8,
    Blue = storvsp_protocol::VERSION_BLUE,
}
200
/// Error returned by `Version::parse` for a version value that matches no
/// known protocol version; carries the rejected value.
#[derive(Debug, Error)]
#[error("protocol version {0:#x} not supported")]
struct UnsupportedVersion(u16);
204
205impl Version {
206    fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
207        let version = match major_minor {
208            storvsp_protocol::VERSION_WIN6 => Self::Win6,
209            storvsp_protocol::VERSION_WIN7 => Self::Win7,
210            storvsp_protocol::VERSION_WIN8 => Self::Win8,
211            storvsp_protocol::VERSION_BLUE => Self::Blue,
212            version => return Err(UnsupportedVersion(version)),
213        };
214        assert_eq!(version as u16, major_minor);
215        Ok(version)
216    }
217
218    fn max_request_size(&self) -> usize {
219        match self {
220            Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
221            Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
222        }
223    }
224}
225
/// The protocol negotiation state held in `Protocol::state`.
#[derive(Copy, Clone)]
enum ProtocolState {
    /// Initialization is still in progress; the payload records how far it
    /// has progressed.
    Init(InitState),
    /// Initialization has ended with the given protocol version and
    /// subchannel count.
    Ready {
        version: Version,
        subchannel_count: u16,
    },
}
234
/// Progress through protocol initialization.
///
/// Inspect reports these as `begin_init`, `query_version`,
/// `query_properties`, and `end_init` respectively.
// NOTE(review): the transitions live in the primary worker loop, outside this
// view; variant order presumably mirrors the negotiation order.
#[derive(Copy, Clone, Debug)]
enum InitState {
    Begin,
    QueryVersion,
    /// A protocol version has been selected.
    QueryProperties {
        version: Version,
    },
    /// A protocol version has been selected; a subchannel count may have been
    /// recorded as well.
    EndInitialization {
        version: Version,
        subchannel_count: Option<u16>,
    },
}
247
/// Backing storage for a [`ScsiOpFuture`]: `SCSI_REQUEST_STACK_SIZE / 8`
/// `u64` words.
type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];

/// The internal SCSI operation future type.
///
/// This is a boxed future of a large pre-determined size. The box is reused
/// after a SCSI request completes to avoid allocations in the hot path.
///
/// `ScsiRequest` stores this in an `Option` so that the completed future can
/// be taken out and its box emptied before the storage is handed back for
/// reuse.
type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;
257
/// The amount of space, in bytes, reserved for a `ScsiOpFuture`.
///
/// This was chosen by running `cargo test -p storvsp -- --no-capture` and
/// looking at the required size that was given in the failure message.
const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;
263
/// An in-flight SCSI operation, tagged with the id of the request it belongs
/// to.
struct ScsiRequest {
    /// Id yielded alongside the result when the future completes.
    request_id: usize,
    /// The running operation. This is `Some` until the future completes, at
    /// which point `poll` takes it out to reclaim its storage.
    future: Option<ScsiOpFuture>,
}
268
269impl ScsiRequest {
270    fn new(
271        request_id: usize,
272        future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
273    ) -> Self {
274        Self {
275            request_id,
276            future: Some(future.into()),
277        }
278    }
279}
280
281impl Future for ScsiRequest {
282    type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);
283
284    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
285        let this = self.get_mut();
286        let future = this.future.as_mut().unwrap().as_mut();
287        let result = std::task::ready!(future.poll(cx));
288        // Return the future so that its storage can be reused.
289        let future = this.future.take().unwrap();
290        Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
291    }
292}
293
/// Errors that end a worker's packet processing.
#[derive(Debug, Error)]
enum WorkerError {
    /// An incoming packet could not be parsed.
    #[error("packet error")]
    PacketError(#[source] PacketError),
    /// The underlying vmbus queue reported an error.
    #[error("queue error")]
    Queue(#[source] queue::Error),
    /// A write found the outgoing ring full.
    #[error("queue should have enough space but no longer does")]
    NotEnoughSpace,
}
303
/// Reasons `parse_packet` rejects an incoming packet.
#[derive(Debug, Error)]
enum PacketError {
    /// The packet carries no transaction id.
    #[error("Not transactional")]
    NotTransactional,
    /// The header's operation code is not one that is handled.
    #[error("Unrecognized operation {0:?}")]
    UnrecognizedOperation(storvsp_protocol::Operation),
    /// The packet is a completion packet rather than a data packet.
    #[error("Invalid packet type")]
    InvalidPacketType,
    /// The external data ranges do not match the request's data transfer
    /// length.
    #[error("Invalid data transfer length")]
    InvalidDataTransferLength,
    /// Reading the packet contents failed.
    #[error("Access error: {0}")]
    Access(#[source] AccessError),
    /// Reading the packet's external ranges failed.
    #[error("Range error")]
    Range(#[source] ExternalDataError),
}
319
/// The validated data-transfer description for one SCSI request.
#[derive(Debug, Default, Clone)]
struct Range {
    /// Transfer length in bytes, taken from the request's
    /// `data_transfer_length`.
    len: usize,
    /// True when the request's `data_in` field is nonzero; passed to
    /// `RequestBuffers::new` as its write flag.
    is_write: bool,
}
325
326impl Range {
327    fn new(buf: &MultiPagedRangeBuf, request: &storvsp_protocol::ScsiRequest) -> Option<Self> {
328        let len = request.data_transfer_length as usize;
329        let is_write = request.data_in != 0;
330        // Ensure there is exactly one range and it's large enough, or there are
331        // zero ranges and there is no associated SCSI buffer.
332        if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
333            return None;
334        }
335        Some(Self { len, is_write })
336    }
337
338    fn buffer<'a>(
339        &'a self,
340        buf: &'a MultiPagedRangeBuf,
341        guest_memory: &'a GuestMemory,
342    ) -> RequestBuffers<'a> {
343        let mut range = buf.first().unwrap_or_else(PagedRange::empty);
344        range.truncate(self.len);
345        RequestBuffers::new(guest_memory, range, self.is_write)
346    }
347}
348
/// A parsed incoming storvsp packet.
#[derive(Debug)]
struct Packet {
    /// The decoded operation and its payload.
    data: PacketData,
    /// The vmbus transaction id of the packet.
    transaction_id: u64,
    /// Number of payload bytes following the header, capped at
    /// `SCSI_REQUEST_LEN_MAX`; `send_completion` uses this to size the
    /// completion packet.
    request_size: usize,
}
355
/// The operation carried by an incoming packet, as decoded by `parse_packet`.
#[derive(Debug)]
enum PacketData {
    BeginInitialization,
    EndInitialization,
    /// Carries the `major_minor` value read from the packet.
    QueryProtocolVersion(u16),
    QueryProperties,
    /// Carries the `u16` count read from the packet.
    CreateSubChannels(u16),
    /// Carries the SCSI request together with its validated external data
    /// range.
    ExecuteScsi(Arc<ScsiRequestAndRange>),
    ResetBus,
    ResetAdapter,
    ResetLun,
}
368
/// Marker error type for an invalid range.
// NOTE(review): no use of this type is visible in this part of the file.
#[derive(Debug)]
pub struct RangeError;
371
372fn parse_packet<T: RingMem>(
373    packet: &IncomingPacket<'_, T>,
374    pool: &mut Vec<Arc<ScsiRequestAndRange>>,
375) -> Result<Packet, PacketError> {
376    let packet = match packet {
377        IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
378        IncomingPacket::Data(packet) => packet,
379    };
380    let transaction_id = packet
381        .transaction_id()
382        .ok_or(PacketError::NotTransactional)?;
383
384    let mut reader = packet.reader();
385    let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
386    // You would expect that this should be limited to the current protocol
387    // version's maximum packet size, but this is not what Hyper-V does, and
388    // Linux 6.1 relies on this behavior during protocol initialization.
389    let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
390    let data = match header.operation {
391        storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
392        storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
393        storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
394            let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
395            reader
396                .read(version.as_mut_bytes())
397                .map_err(PacketError::Access)?;
398            PacketData::QueryProtocolVersion(version.major_minor)
399        }
400        storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
401        storvsp_protocol::Operation::EXECUTE_SRB => {
402            let mut full_request = pool.pop().unwrap_or_else(|| {
403                Arc::new(ScsiRequestAndRange {
404                    external_data: Range::default(),
405                    external_data_buf: MultiPagedRangeBuf::new(),
406                    request: storvsp_protocol::ScsiRequest::new_zeroed(),
407                    request_size,
408                })
409            });
410
411            {
412                let full_request = Arc::get_mut(&mut full_request).unwrap();
413                let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
414                reader.read(request_buf).map_err(PacketError::Access)?;
415
416                full_request.external_data_buf.clear();
417                packet
418                    .read_external_ranges(&mut full_request.external_data_buf)
419                    .map_err(PacketError::Range)?;
420
421                full_request.external_data =
422                    Range::new(&full_request.external_data_buf, &full_request.request)
423                        .ok_or(PacketError::InvalidDataTransferLength)?;
424            }
425
426            PacketData::ExecuteScsi(full_request)
427        }
428        storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
429        storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
430        storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
431        storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
432            let mut sub_channel_count: u16 = 0;
433            reader
434                .read(sub_channel_count.as_mut_bytes())
435                .map_err(PacketError::Access)?;
436            PacketData::CreateSubChannels(sub_channel_count)
437        }
438        _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
439    };
440
441    if let PacketData::ExecuteScsi(_) = data {
442        tracing::trace!(transaction_id, ?data, "parse_packet");
443    } else {
444        tracing::debug!(transaction_id, ?data, "parse_packet");
445    }
446
447    Ok(Packet {
448        data,
449        request_size,
450        transaction_id,
451    })
452}
453
454impl WorkerInner {
455    fn send_vmbus_packet<M: RingMem>(
456        &mut self,
457        writer: &mut queue::WriteBatch<'_, M>,
458        packet_type: OutgoingPacketType<'_>,
459        request_size: usize,
460        transaction_id: u64,
461        operation: storvsp_protocol::Operation,
462        status: storvsp_protocol::NtStatus,
463        payload: &[u8],
464    ) -> Result<(), WorkerError> {
465        let header = storvsp_protocol::Packet {
466            operation,
467            flags: 0,
468            status,
469        };
470
471        let packet_size = size_of_val(&header) + request_size;
472
473        // Zero pad or truncate the payload to the queue's packet size. This is
474        // necessary because Windows guests check that each packet's size is
475        // exactly the largest possible packet size for the negotiated protocol
476        // version.
477        let len = size_of_val(&header) + size_of_val(payload);
478        let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
479        let (payload_bytes, padding_bytes) = if len > packet_size {
480            (&payload[..packet_size - size_of_val(&header)], &[][..])
481        } else {
482            (payload, &padding[..packet_size - len])
483        };
484        assert_eq!(
485            size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
486            packet_size
487        );
488        writer
489            .try_write(&OutgoingPacket {
490                transaction_id,
491                packet_type,
492                payload: &[header.as_bytes(), payload_bytes, padding_bytes],
493            })
494            .map_err(|err| match err {
495                queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
496                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
497            })
498    }
499
500    fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
501        &mut self,
502        writer: &mut queue::WriteHalf<'_, M>,
503        operation: storvsp_protocol::Operation,
504        status: storvsp_protocol::NtStatus,
505        payload: &P,
506    ) -> Result<(), WorkerError> {
507        self.send_vmbus_packet(
508            &mut writer.batched(),
509            OutgoingPacketType::InBandNoCompletion,
510            self.request_size,
511            0,
512            operation,
513            status,
514            payload.as_bytes(),
515        )
516    }
517
518    fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
519        &mut self,
520        writer: &mut queue::WriteHalf<'_, M>,
521        packet: &Packet,
522        status: storvsp_protocol::NtStatus,
523        payload: &P,
524    ) -> Result<(), WorkerError> {
525        self.send_vmbus_packet(
526            &mut writer.batched(),
527            OutgoingPacketType::Completion,
528            packet.request_size,
529            packet.transaction_id,
530            storvsp_protocol::Operation::COMPLETE_IO,
531            status,
532            payload.as_bytes(),
533        )
534    }
535}
536
/// Executes SCSI requests against the controller's disks.
struct ScsiCommandQueue {
    /// Shared controller state holding the attached disks.
    controller: Arc<ScsiControllerState>,
    /// Guest memory used to access request data buffers.
    mem: GuestMemory,
    /// When set, replaces the request's path id for disk lookup.
    force_path_id: Option<u8>,
}
542
impl ScsiCommandQueue {
    /// Executes one guest SCSI request and returns its result.
    ///
    /// Dispatch order:
    /// 1. `REPORT_LUNS` is answered here from the controller's disk table,
    ///    whether or not a disk exists at the addressed path.
    /// 2. Any other operation is forwarded to the disk at the (possibly
    ///    forced) path when one is attached.
    /// 3. `INQUIRY` addressed to a missing disk gets a synthesized response
    ///    whose device type is "logical unit not present".
    /// 4. Everything else fails with `SrbStatus::INVALID_LUN`.
    async fn execute_scsi(&self, full_request: &ScsiRequestAndRange) -> ScsiResult {
        let request = &full_request.request;
        // The first CDB byte is the operation code.
        let op = ScsiOp(request.payload[0]);
        let external_data = full_request
            .external_data
            .buffer(&full_request.external_data_buf, &self.mem);

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            "execute_scsi start...",
        );

        let path_id = self.force_path_id.unwrap_or(request.path_id);

        // Clone the entry out of the table so that the read lock is released
        // before the request is executed.
        let controller_disk = self
            .controller
            .disks
            .read()
            .get(&ScsiPath {
                path: path_id,
                target: request.target_id,
                lun: request.lun,
            })
            .cloned();

        let result = match op {
            ScsiOp::REPORT_LUNS => {
                const HEADER_SIZE: usize = size_of::<scsi::LunList>();
                // Collect the LUNs of every disk on the requested path and
                // target.
                let mut luns: Vec<u8> = self
                    .controller
                    .disks
                    .read()
                    .iter()
                    .flat_map(|(path, _)| {
                        // Use the original path ID and not the forced one to
                        // match Hyper-V storvsp behavior.
                        if request.path_id == path.path && request.target_id == path.target {
                            Some(path.lun)
                        } else {
                            None
                        }
                    })
                    .collect();
                luns.sort_unstable();
                // One 8-byte slot for the header followed by one per LUN.
                let mut data: Vec<u64> = vec![0; luns.len() + 1];
                let header = scsi::LunList {
                    length: (luns.len() as u32 * 8).into(),
                    reserved: [0; 4],
                };
                data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
                for (i, lun) in luns.iter().enumerate() {
                    // The LUN goes big-endian into the first two bytes of its
                    // slot.
                    data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
                }
                if external_data.len() >= HEADER_SIZE {
                    // Return as much of the list as fits in the guest buffer.
                    let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
                    external_data.writer().write(&data.as_bytes()[..tx]).map_or(
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        },
                        |_| ScsiResult {
                            scsi_status: ScsiStatus::GOOD,
                            srb_status: SrbStatus::SUCCESS,
                            tx,
                            sense_data: None,
                        },
                    )
                } else {
                    // The buffer cannot even hold the header: report success
                    // with nothing transferred.
                    ScsiResult {
                        scsi_status: ScsiStatus::GOOD,
                        srb_status: SrbStatus::SUCCESS,
                        tx: 0,
                        sense_data: None,
                    }
                }
            }
            _ if controller_disk.is_some() => {
                // Forward the request to the attached disk.
                let mut cdb = [0; 16];
                cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
                controller_disk
                    .unwrap()
                    .disk
                    .execute_scsi(
                        &external_data,
                        &Request {
                            cdb,
                            srb_flags: request.srb_flags,
                        },
                    )
                    .await
            }
            ScsiOp::INQUIRY => {
                // INQUIRY addressed to a path with no disk attached.
                let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
                    .unwrap()
                    .0; // TODO: zerocopy: ref-from-prefix: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
                // Reject the request if the guest buffer is smaller than the
                // allocation length, the transfer direction is not data-in,
                // or the allocation length cannot hold the inquiry header.
                if external_data.len() < cdb.allocation_length.get() as usize
                    || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
                    || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
                {
                    ScsiResult {
                        scsi_status: ScsiStatus::CHECK_CONDITION,
                        srb_status: SrbStatus::INVALID_REQUEST,
                        tx: 0,
                        sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
                    }
                } else {
                    let enable_vpd = cdb.flags.vpd();
                    if enable_vpd || cdb.page_code != 0 {
                        // VPD inquiry cannot be supported for a non-existent
                        // device (LUN).
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        }
                    } else {
                        const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
                        let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
                        data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;

                        if request.lun != 0 {
                            // The fields below are only set for a LUN 0
                            // inquiry, so zero them out here.
                            data.vendor_id = [0; 8];
                            data.product_id = [0; 16];
                            data.product_revision_level = [0; 4];
                        }

                        let datab = data.as_bytes();
                        // Transfer no more than the guest asked for and no
                        // more than the inquiry data holds.
                        let tx = std::cmp::min(
                            cdb.allocation_length.get() as usize,
                            size_of::<scsi::InquiryData>(),
                        );
                        external_data.writer().write(&datab[..tx]).map_or(
                            ScsiResult {
                                scsi_status: ScsiStatus::CHECK_CONDITION,
                                srb_status: SrbStatus::INVALID_REQUEST,
                                tx: 0,
                                sense_data: Some(illegal_request_sense(
                                    AdditionalSenseCode::INVALID_CDB,
                                )),
                            },
                            |_| ScsiResult {
                                scsi_status: ScsiStatus::GOOD,
                                srb_status: SrbStatus::SUCCESS,
                                tx,
                                sense_data: None,
                            },
                        )
                    }
                }
            }
            // Any other operation addressed to a path with no disk attached.
            _ => ScsiResult {
                scsi_status: ScsiStatus::CHECK_CONDITION,
                srb_status: SrbStatus::INVALID_LUN,
                tx: 0,
                sense_data: None,
            },
        };

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            result = ?result,
            "execute_scsi completed.",
        );
        result
    }
}
723
724impl<T: RingMem + 'static> Worker<T> {
725    fn new(
726        controller: Arc<ScsiControllerState>,
727        channel: RawAsyncChannel<T>,
728        channel_index: u16,
729        mem: GuestMemory,
730        channel_control: ChannelControl,
731        io_queue_depth: u32,
732        protocol: Arc<Protocol>,
733        force_path_id: Option<u8>,
734    ) -> anyhow::Result<Self> {
735        let queue = Queue::new(channel)?;
736        #[expect(clippy::disallowed_methods)] // TODO
737        let (source, target) = futures::channel::mpsc::channel(1);
738        controller.add_rescan_notification_source(source);
739
740        let max_io_queue_depth = io_queue_depth.max(1) as usize;
741        Ok(Self {
742            inner: WorkerInner {
743                protocol,
744                request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
745                controller: controller.clone(),
746                channel_index,
747                scsi_queue: Arc::new(ScsiCommandQueue {
748                    controller,
749                    mem,
750                    force_path_id,
751                }),
752                scsi_requests: FuturesUnordered::new(),
753                scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
754                channel_control,
755                max_io_queue_depth,
756                future_pool: Vec::new(),
757                full_request_pool: Vec::new(),
758                stats: Default::default(),
759            },
760            queue,
761            rescan_notification: target,
762            fast_select: FastSelect::new(),
763        })
764    }
765
766    async fn wait_for_scsi_requests_complete(&mut self) {
767        tracing::debug!(
768            channel_index = self.inner.channel_index,
769            "wait for IOs completed..."
770        );
771        while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
772            self.inner.scsi_requests_states.remove(id);
773        }
774    }
775}
776
777impl InspectTask<Worker> for WorkerState {
778    fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
779        if let Some(worker) = worker {
780            let mut resp = req.respond();
781            if worker.inner.channel_index == 0 {
782                let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
783                    ProtocolState::Init(state) => match state {
784                        InitState::Begin => ("begin_init", None, None),
785                        InitState::QueryVersion => ("query_version", None, None),
786                        InitState::QueryProperties { version } => {
787                            ("query_properties", Some(version), None)
788                        }
789                        InitState::EndInitialization {
790                            version,
791                            subchannel_count,
792                        } => ("end_init", Some(version), subchannel_count),
793                    },
794                    ProtocolState::Ready {
795                        version,
796                        subchannel_count,
797                    } => ("ready", Some(version), Some(subchannel_count)),
798                };
799                resp.field("state", state)
800                    .field("version", version)
801                    .field("subchannel_count", subchannel_count);
802            }
803            resp.field("pending_packets", worker.inner.scsi_requests_states.len())
804                .fields("io", worker.inner.scsi_requests_states.iter())
805                .field("stats", &worker.inner.stats)
806                .field("ring", &worker.queue)
807                .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
808        }
809    }
810}
811
812impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
813    async fn run(
814        &mut self,
815        stop: &mut StopTask<'_>,
816        worker: &mut Worker<T>,
817    ) -> Result<(), task_control::Cancelled> {
818        let fut = async {
819            if worker.inner.channel_index == 0 {
820                worker.process_primary().await
821            } else {
822                // Wait for initialization to end before processing any
823                // subchannel packets.
824                let protocol_version = loop {
825                    let listener = worker.inner.protocol.ready.listen();
826                    if let ProtocolState::Ready { version, .. } =
827                        *worker.inner.protocol.state.read()
828                    {
829                        break version;
830                    }
831                    tracing::debug!("subchannel waiting for initialization to end");
832                    listener.await
833                };
834                worker
835                    .inner
836                    .process_ready(&mut worker.queue, protocol_version)
837                    .await
838            }
839        };
840
841        match stop.until_stopped(fut).await? {
842            Ok(_) => {}
843            Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
844        }
845        Ok(())
846    }
847}
848
849impl WorkerInner {
850    /// Awaits the next incoming packet, without checking for any other events (device add/remove notifications or available completions).
851    /// Increments the count of outstanding packets when returning `Ok(Packet)`.
852    async fn next_packet<'a, M: RingMem>(
853        &mut self,
854        reader: &'a mut queue::ReadHalf<'a, M>,
855    ) -> Result<Packet, WorkerError> {
856        let packet = reader.read().await.map_err(WorkerError::Queue)?;
857        let stor_packet =
858            parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
859        Ok(stor_packet)
860    }
861
862    /// Polls for enough ring space in the outgoing ring to send a packet.
863    ///
864    /// This is used to ensure there is enough space in the ring before
865    /// committing to sending a packet. This avoids the need to save pending
866    /// packets on the side if queue processing is interrupted while the ring is
867    /// full.
868    fn poll_for_ring_space<M: RingMem>(
869        &mut self,
870        cx: &mut Context<'_>,
871        writer: &mut queue::WriteHalf<'_, M>,
872    ) -> Poll<Result<(), WorkerError>> {
873        writer
874            .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
875            .map_err(WorkerError::Queue)
876    }
877}
878
/// Ring space reserved before sending any packet: the in-band ring packet size
/// for a `storvsp_protocol::Packet` header followed by the largest SCSI request
/// payload (`storvsp_protocol::SCSI_REQUEST_LEN_MAX`).
///
/// Used both by `poll_for_ring_space` and by the batched completion path in
/// `poll_process_ready` to check for room before committing to a send.
const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
    size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
);
882
883impl<T: RingMem> Worker<T> {
884    /// Processes the protocol state machine.
885    async fn process_primary(&mut self) -> Result<(), WorkerError> {
886        loop {
887            let current_state = *self.inner.protocol.state.read();
888            match current_state {
889                ProtocolState::Ready { version, .. } => {
890                    break loop {
891                        select_biased! {
892                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
893                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
894                                if version >= Version::Win7
895                                {
896                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
897                                }
898                            }
899                        }
900                    };
901                }
902                ProtocolState::Init(state) => {
903                    let (mut reader, mut writer) = self.queue.split();
904
905                    // Ensure that subsequent calls to `send_completion` won't
906                    // fail due to lack of ring space, to avoid keeping (and saving/restoring) interim states.
907                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;
908
909                    tracing::debug!(?state, "process_primary");
910                    match state {
911                        InitState::Begin => {
912                            let packet = self.inner.next_packet(&mut reader).await?;
913                            if let PacketData::BeginInitialization = packet.data {
914                                self.inner.send_completion(
915                                    &mut writer,
916                                    &packet,
917                                    storvsp_protocol::NtStatus::SUCCESS,
918                                    &(),
919                                )?;
920                                *self.inner.protocol.state.write() =
921                                    ProtocolState::Init(InitState::QueryVersion);
922                            } else {
923                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
924                                self.inner.send_completion(
925                                    &mut writer,
926                                    &packet,
927                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
928                                    &(),
929                                )?;
930                            }
931                        }
932                        InitState::QueryVersion => {
933                            let packet = self.inner.next_packet(&mut reader).await?;
934                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
935                                if let Ok(version) = Version::parse(major_minor) {
936                                    self.inner.send_completion(
937                                        &mut writer,
938                                        &packet,
939                                        storvsp_protocol::NtStatus::SUCCESS,
940                                        &storvsp_protocol::ProtocolVersion {
941                                            major_minor,
942                                            reserved: 0,
943                                        },
944                                    )?;
945                                    self.inner.request_size = version.max_request_size();
946                                    *self.inner.protocol.state.write() =
947                                        ProtocolState::Init(InitState::QueryProperties { version });
948
949                                    tracelimit::info_ratelimited!(
950                                        ?version,
951                                        "scsi version negotiated"
952                                    );
953                                } else {
954                                    self.inner.send_completion(
955                                        &mut writer,
956                                        &packet,
957                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
958                                        &storvsp_protocol::ProtocolVersion {
959                                            major_minor,
960                                            reserved: 0,
961                                        },
962                                    )?;
963                                    *self.inner.protocol.state.write() =
964                                        ProtocolState::Init(InitState::QueryVersion);
965                                }
966                            } else {
967                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
968                                self.inner.send_completion(
969                                    &mut writer,
970                                    &packet,
971                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
972                                    &(),
973                                )?;
974                            }
975                        }
976                        InitState::QueryProperties { version } => {
977                            let packet = self.inner.next_packet(&mut reader).await?;
978                            if let PacketData::QueryProperties = packet.data {
979                                let multi_channel_supported = version >= Version::Win8;
980
981                                self.inner.send_completion(
982                                    &mut writer,
983                                    &packet,
984                                    storvsp_protocol::NtStatus::SUCCESS,
985                                    &storvsp_protocol::ChannelProperties {
986                                        max_transfer_bytes: 0x40000, // 256KB
987                                        flags: {
988                                            if multi_channel_supported {
989                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
990                                            } else {
991                                                0
992                                            }
993                                        },
994                                        maximum_sub_channel_count: if multi_channel_supported {
995                                            self.inner.channel_control.max_subchannels()
996                                        } else {
997                                            0
998                                        },
999                                        reserved: 0,
1000                                        reserved2: 0,
1001                                        reserved3: [0, 0],
1002                                    },
1003                                )?;
1004                                *self.inner.protocol.state.write() =
1005                                    ProtocolState::Init(InitState::EndInitialization {
1006                                        version,
1007                                        subchannel_count: if multi_channel_supported {
1008                                            None
1009                                        } else {
1010                                            Some(0)
1011                                        },
1012                                    });
1013                            } else {
1014                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
1015                                self.inner.send_completion(
1016                                    &mut writer,
1017                                    &packet,
1018                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
1019                                    &(),
1020                                )?;
1021                            }
1022                        }
1023                        InitState::EndInitialization {
1024                            version,
1025                            subchannel_count,
1026                        } => {
1027                            let packet = self.inner.next_packet(&mut reader).await?;
1028                            match packet.data {
1029                                PacketData::CreateSubChannels(sub_channel_count)
1030                                    if subchannel_count.is_none() =>
1031                                {
1032                                    if let Err(err) = self
1033                                        .inner
1034                                        .channel_control
1035                                        .enable_subchannels(sub_channel_count)
1036                                    {
1037                                        tracelimit::warn_ratelimited!(
1038                                            ?err,
1039                                            "cannot enable subchannels"
1040                                        );
1041                                        self.inner.send_completion(
1042                                            &mut writer,
1043                                            &packet,
1044                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
1045                                            &(),
1046                                        )?;
1047                                    } else {
1048                                        self.inner.send_completion(
1049                                            &mut writer,
1050                                            &packet,
1051                                            storvsp_protocol::NtStatus::SUCCESS,
1052                                            &(),
1053                                        )?;
1054                                        *self.inner.protocol.state.write() =
1055                                            ProtocolState::Init(InitState::EndInitialization {
1056                                                version,
1057                                                subchannel_count: Some(sub_channel_count),
1058                                            });
1059                                    }
1060                                }
1061                                PacketData::EndInitialization => {
1062                                    self.inner.send_completion(
1063                                        &mut writer,
1064                                        &packet,
1065                                        storvsp_protocol::NtStatus::SUCCESS,
1066                                        &(),
1067                                    )?;
1068                                    // Reset the rescan notification event now, before the guest has a
1069                                    // chance to send any SCSI requests to scan the bus.
1070                                    self.rescan_notification.try_next().ok();
1071                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
1072                                        version,
1073                                        subchannel_count: subchannel_count.unwrap_or(0),
1074                                    };
1075                                    // Wake up subchannels waiting for the
1076                                    // protocol state to become ready.
1077                                    self.inner.protocol.ready.notify(usize::MAX);
1078                                }
1079                                _ => {
1080                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
1081                                    self.inner.send_completion(
1082                                        &mut writer,
1083                                        &packet,
1084                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
1085                                        &(),
1086                                    )?;
1087                                }
1088                            }
1089                        }
1090                    }
1091                }
1092            }
1093        }
1094    }
1095}
1096
1097fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
1098    match srb_status {
1099        SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
1100        SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
1101        SrbStatus::INVALID_LUN
1102        | SrbStatus::INVALID_TARGET_ID
1103        | SrbStatus::NO_DEVICE
1104        | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
1105        SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
1106        SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
1107        SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
1108            storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
1109        }
1110        SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
1111        SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
1112        SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
1113        _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
1114    }
1115}
1116
impl WorkerInner {
    /// Processes packets and SCSI completions after protocol negotiation has
    /// finished.
    ///
    /// Sets `request_size` from `protocol_version` and then repeatedly polls
    /// `poll_process_ready`. Because that function only ever yields
    /// `Poll::Pending` or an error, this future resolves only with an error.
    async fn process_ready<M: RingMem>(
        &mut self,
        queue: &mut Queue<M>,
        protocol_version: Version,
    ) -> Result<(), WorkerError> {
        self.request_size = protocol_version.max_request_size();
        poll_fn(|cx| self.poll_process_ready(cx, queue)).await
    }

    /// One wake's worth of ready-state processing: sends completions for
    /// finished SCSI requests, then accepts new packets from the guest, and
    /// repeats for as long as new IOs keep being submitted.
    ///
    /// Never returns `Poll::Ready(Ok(()))`; the function ends in
    /// `Poll::Pending`, and failures leave early through `?` as an error.
    /// Updates the wake, submission and completion statistics on the way out.
    fn poll_process_ready<M: RingMem>(
        &mut self,
        cx: &mut Context<'_>,
        queue: &mut Queue<M>,
    ) -> Poll<Result<(), WorkerError>> {
        self.stats.wakes.increment();

        let (mut reader, mut writer) = queue.split();
        let mut total_completions = 0;
        let mut total_submissions = 0;
        // Below this many outstanding IOs the incoming ring is polled with
        // `cx`; at or above it, only a non-waiting `try_read_batch` is used.
        let poll_mode_queue_depth = self.controller.poll_mode_queue_depth.load(Relaxed) as usize;

        loop {
            // Drive IOs forward and collect completions.
            'outer: while !self.scsi_requests_states.is_empty() {
                {
                    // Scoped so the batch releases `writer` before the ring
                    // space poll below.
                    let mut batch = writer.batched();
                    loop {
                        // Ensure there is room for the completion before consuming
                        // the IO so that we don't have to track completed IOs whose
                        // completions haven't been sent.
                        if !batch
                            .can_write(MAX_VMBUS_PACKET_SIZE)
                            .map_err(WorkerError::Queue)?
                        {
                            // This batch is full but there may still be more completions.
                            break;
                        }
                        if let Poll::Ready(Some((request_id, result, future))) =
                            self.scsi_requests.poll_next_unpin(cx)
                        {
                            // Keep the finished future's allocation for reuse
                            // by `push_scsi_request`.
                            self.future_pool.push(future);
                            self.handle_completion(&mut batch, request_id, result)?;
                            total_completions += 1;
                        } else {
                            tracing::trace!("out of completions");
                            break 'outer;
                        }
                    }
                }

                // Wait for enough space for any completion packets.
                if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                    tracing::trace!("out of ring space");
                    break;
                }
            }

            let mut submissions = 0;
            // Process new requests.
            'outer: loop {
                // Stop accepting work once the configured queue depth is
                // reached.
                if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                    break;
                }
                let mut batch = if self.scsi_requests_states.len() < poll_mode_queue_depth {
                    if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
                        batch.map_err(WorkerError::Queue)?
                    } else {
                        tracing::trace!("out of incoming packets");
                        break;
                    }
                } else {
                    match reader.try_read_batch() {
                        Ok(batch) => batch,
                        Err(queue::TryReadError::Empty) => {
                            tracing::trace!(
                                pending_io_count = self.scsi_requests_states.len(),
                                "out of incoming packets, keeping interrupts masked"
                            );
                            break;
                        }
                        Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
                    }
                };

                let mut packets = batch.packets();
                loop {
                    if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                        break 'outer;
                    }
                    // Wait for enough space for any completion packets that
                    // `handle_packet` may need to send, so that it isn't necessary
                    // to track pending completions.
                    if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                        tracing::trace!("out of ring space");
                        break 'outer;
                    }

                    let packet = if let Some(packet) = packets.next() {
                        packet.map_err(WorkerError::Queue)?
                    } else {
                        // Batch exhausted; go back and try to read another.
                        break;
                    };

                    // `handle_packet` returns true only when the packet
                    // started a SCSI IO.
                    if self.handle_packet(&mut writer, &packet)? {
                        submissions += 1;
                    }
                }
            }

            // Loop around to poll the IOs if any new IOs were submitted.
            if submissions == 0 {
                // No need to poll again.
                break;
            }
            total_submissions += submissions;
        }

        if total_submissions != 0 || total_completions != 0 {
            self.stats.ios_submitted.add(total_submissions);
            self.stats
                .per_wake_submissions
                .add_sample(total_submissions);
            self.stats
                .per_wake_completions
                .add_sample(total_completions);
            self.stats.ios_completed.add(total_completions);
        } else {
            // Woken without submitting or completing anything.
            self.stats.wakes_spurious.increment();
        }

        Poll::Pending
    }

    /// Retires the SCSI request identified by `request_id` and writes its
    /// `COMPLETE_IO` completion packet to `writer`.
    ///
    /// The request's tracking state is removed, its request buffer is
    /// returned to `full_request_pool`, and `result` is translated into a
    /// `storvsp_protocol::ScsiRequest` response (SCSI/SRB status, transfer
    /// length and, when present, sense data).
    ///
    /// # Panics
    ///
    /// Panics if any other strong or weak reference to the request is still
    /// alive when it is returned to the pool.
    fn handle_completion<M: RingMem>(
        &mut self,
        writer: &mut queue::WriteBatch<'_, M>,
        request_id: usize,
        result: ScsiResult,
    ) -> Result<(), WorkerError> {
        let state = self.scsi_requests_states.remove(request_id);
        let request_size = state.request.request_size;

        // Push the request into the pool to avoid reallocating later.
        // The count check requires that the clone handed to the IO future in
        // `push_scsi_request` has already been released.
        assert_eq!(
            Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
            1
        );
        self.full_request_pool.push(state.request);

        let status = convert_srb_status_to_nt_status(result.srb_status);
        // The payload stays zeroed unless the result carries sense data.
        let mut payload = [0; 0x14];
        if let Some(sense) = result.sense_data {
            payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
            tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
        };
        let response = storvsp_protocol::ScsiRequest {
            length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
            scsi_status: result.scsi_status,
            srb_status: SrbStatusAndFlags::new()
                .with_status(result.srb_status)
                .with_autosense_valid(result.sense_data.is_some()),
            data_transfer_length: result.tx as u32,
            cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
            sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
            payload,
            ..storvsp_protocol::ScsiRequest::new_zeroed()
        };
        // The completion is tagged with the transaction ID saved when the
        // request was submitted.
        self.send_vmbus_packet(
            writer,
            OutgoingPacketType::Completion,
            request_size,
            state.transaction_id,
            storvsp_protocol::Operation::COMPLETE_IO,
            status,
            response.as_bytes(),
        )?;
        Ok(())
    }

    /// Parses and dispatches one incoming packet received in the ready state.
    ///
    /// Returns `Ok(true)` when the packet started a SCSI IO, whose completion
    /// is sent later by `handle_completion`. Returns `Ok(false)` when the
    /// packet was answered immediately with a completion written to `writer`.
    ///
    /// Parse failures are returned as `WorkerError::PacketError`.
    fn handle_packet<M: RingMem>(
        &mut self,
        writer: &mut queue::WriteHalf<'_, M>,
        packet: &IncomingPacket<'_, M>,
    ) -> Result<bool, WorkerError> {
        let packet =
            parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
        let submitted_io = match packet.data {
            PacketData::ExecuteScsi(request) => {
                self.push_scsi_request(packet.transaction_id, request);
                true
            }
            PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
                // These operations have always been no-ops.
                self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
                false
            }
            // Subchannel creation is only honored on the primary channel; on
            // any other channel it falls through to the unexpected-packet arm.
            PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
                if let Err(err) = self
                    .channel_control
                    .enable_subchannels(new_subchannel_count)
                {
                    tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
                    self.send_completion(
                        writer,
                        &packet,
                        storvsp_protocol::NtStatus::INVALID_PARAMETER,
                        &(),
                    )?;
                    false
                } else {
                    // Update the subchannel count in the protocol state for save.
                    if let ProtocolState::Ready {
                        subchannel_count, ..
                    } = &mut *self.protocol.state.write()
                    {
                        *subchannel_count = new_subchannel_count;
                    } else {
                        // NOTE(review): relies on this function only being
                        // reached once the protocol state is `Ready` — confirm
                        // against all callers.
                        unreachable!()
                    }

                    self.send_completion(
                        writer,
                        &packet,
                        storvsp_protocol::NtStatus::SUCCESS,
                        &(),
                    )?;
                    false
                }
            }
            _ => {
                tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
                self.send_completion(
                    writer,
                    &packet,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                    &(),
                )?;
                false
            }
        };
        Ok(submitted_io)
    }

    /// Starts executing a SCSI request and begins tracking it.
    ///
    /// Records `transaction_id` and the request in `scsi_requests_states`,
    /// then wraps the `execute_scsi` call in a future that is queued on
    /// `scsi_requests` under the tracking ID returned by the insert.
    fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
        let scsi_queue = self.scsi_queue.clone();
        // One reference is kept in the tracking state; the other moves into
        // the future below.
        let scsi_request_state = ScsiRequestState {
            transaction_id,
            request: full_request.clone(),
        };
        let request_id = self.scsi_requests_states.insert(scsi_request_state);
        // Reuse a pooled box when one is available instead of allocating.
        let future = self
            .future_pool
            .pop()
            .unwrap_or_else(|| OversizedBox::new(()));
        let future = OversizedBox::refill(future, async move {
            scsi_queue.execute_scsi(full_request.as_ref()).await
        });
        let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
        self.scsi_requests.push(request);
    }
}
1381
1382impl<T: RingMem> Drop for Worker<T> {
1383    fn drop(&mut self) {
1384        self.inner
1385            .controller
1386            .remove_rescan_notification_source(&self.rescan_notification);
1387    }
1388}
1389
/// Error reporting that the wrapped SCSI path (path:target:lun) is already in
/// use.
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathInUse(pub ScsiPath);
1393
/// Error reporting that the wrapped SCSI path (path:target:lun) is not in use.
///
/// Unlike `ScsiPathInUse`, the wrapped path is not public.
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathNotInUse(ScsiPath);
1397
/// Tracking state kept for each SCSI request between submission
/// (`push_scsi_request`) and completion (`handle_completion`).
#[derive(Clone)]
struct ScsiRequestState {
    /// Transaction ID of the incoming packet; reused for the completion packet.
    transaction_id: u64,
    /// The parsed request. It is returned to the request pool on completion.
    request: Arc<ScsiRequestAndRange>,
}
1403
/// A parsed SCSI request together with its external data description.
#[derive(Debug)]
struct ScsiRequestAndRange {
    // NOTE(review): presumably the guest memory range for the data transfer —
    // confirm against `parse_packet`.
    external_data: Range,
    // NOTE(review): presumably the backing buffer for `external_data` — confirm
    // against `parse_packet`.
    external_data_buf: MultiPagedRangeBuf,
    /// The protocol-level SCSI request (path/target/lun, CDB payload, ...).
    request: storvsp_protocol::ScsiRequest,
    /// Passed to `send_vmbus_packet` when the completion is sent.
    request_size: usize,
}
1411
1412impl Inspect for ScsiRequestState {
1413    fn inspect(&self, req: inspect::Request<'_>) {
1414        req.respond()
1415            .field("transaction_id", self.transaction_id)
1416            .display(
1417                "address",
1418                &ScsiPath {
1419                    path: self.request.request.path_id,
1420                    target: self.request.request.target_id,
1421                    lun: self.request.request.lun,
1422                },
1423            )
1424            .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1425    }
1426}
1427
1428impl StorageDevice {
1429    /// Returns a new SCSI device.
1430    pub fn build_scsi(
1431        driver_source: &VmTaskDriverSource,
1432        controller: &ScsiController,
1433        instance_id: Guid,
1434        max_sub_channel_count: u16,
1435        io_queue_depth: u32,
1436    ) -> Self {
1437        Self::build_inner(
1438            driver_source,
1439            controller,
1440            instance_id,
1441            None,
1442            max_sub_channel_count,
1443            io_queue_depth,
1444        )
1445    }
1446
1447    /// Returns a new SCSI device for implementing an IDE accelerator channel
1448    /// for IDE device `device_id` on channel `channel_id`.
1449    pub fn build_ide(
1450        driver_source: &VmTaskDriverSource,
1451        channel_id: u8,
1452        device_id: u8,
1453        disk: ScsiControllerDisk,
1454        io_queue_depth: u32,
1455    ) -> Self {
1456        let path = ScsiPath {
1457            path: channel_id,
1458            target: device_id,
1459            lun: 0,
1460        };
1461
1462        let controller = ScsiController::new();
1463        controller.attach(path, disk).unwrap();
1464
1465        // Construct the specific GUID that drivers in the guest expect for this
1466        // IDE device.
1467        let instance_id = Guid {
1468            data1: channel_id.into(),
1469            data2: device_id.into(),
1470            data3: 0x8899,
1471            data4: [0; 8],
1472        };
1473        Self::build_inner(
1474            driver_source,
1475            &controller,
1476            instance_id,
1477            Some(path),
1478            0,
1479            io_queue_depth,
1480        )
1481    }
1482
1483    fn build_inner(
1484        driver_source: &VmTaskDriverSource,
1485        controller: &ScsiController,
1486        instance_id: Guid,
1487        ide_path: Option<ScsiPath>,
1488        max_sub_channel_count: u16,
1489        io_queue_depth: u32,
1490    ) -> Self {
1491        let workers = (0..max_sub_channel_count + 1)
1492            .map(|channel_index| WorkerAndDriver {
1493                worker: TaskControl::new(WorkerState),
1494                driver: driver_source
1495                    .builder()
1496                    .target_vp(0)
1497                    .run_on_target(true)
1498                    .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1499            })
1500            .collect();
1501
1502        Self {
1503            instance_id,
1504            ide_path,
1505            workers,
1506            controller: controller.state.clone(),
1507            resources: Default::default(),
1508            max_sub_channel_count,
1509            driver_source: driver_source.clone(),
1510            protocol: Arc::new(Protocol {
1511                state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1512                ready: Default::default(),
1513            }),
1514            io_queue_depth,
1515        }
1516    }
1517
1518    fn new_worker(
1519        &mut self,
1520        open_request: &OpenRequest,
1521        channel_index: u16,
1522    ) -> anyhow::Result<&mut Worker> {
1523        let controller = self.controller.clone();
1524
1525        let driver = self
1526            .driver_source
1527            .builder()
1528            .target_vp(open_request.open_data.target_vp)
1529            .run_on_target(true)
1530            .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1531
1532        let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1533            .context("failed to create vmbus channel")?;
1534
1535        let channel_control = self.resources.channel_control.clone();
1536
1537        tracing::debug!(
1538            target_vp = open_request.open_data.target_vp,
1539            channel_index,
1540            "packet processing starting...",
1541        );
1542
1543        // Force the path ID on incoming SCSI requests to match the IDE
1544        // channel ID, since guests do not reliably set the path ID
1545        // correctly.
1546        let force_path_id = self.ide_path.map(|p| p.path);
1547
1548        let worker = Worker::new(
1549            controller,
1550            channel,
1551            channel_index,
1552            self.resources
1553                .offer_resources
1554                .guest_memory(open_request)
1555                .clone(),
1556            channel_control,
1557            self.io_queue_depth,
1558            self.protocol.clone(),
1559            force_path_id,
1560        )
1561        .map_err(RestoreError::Other)?;
1562
1563        self.workers[channel_index as usize]
1564            .driver
1565            .retarget_vp(open_request.open_data.target_vp);
1566
1567        Ok(self.workers[channel_index as usize].worker.insert(
1568            &driver,
1569            format!("storvsp worker {}-{}", self.instance_id, channel_index),
1570            worker,
1571        ))
1572    }
1573}
1574
1575/// A disk that can be added to a SCSI controller.
1576#[derive(Clone)]
1577pub struct ScsiControllerDisk {
1578    disk: Arc<dyn AsyncScsiDisk>,
1579}
1580
1581impl ScsiControllerDisk {
1582    /// Creates a new controller disk from an async SCSI disk.
1583    pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1584        Self { disk }
1585    }
1586}
1587
/// State held behind the `Arc` in [`ScsiController`]; a `StorageDevice` built
/// from a controller holds a clone of the same `Arc`.
struct ScsiControllerState {
    /// Disks currently attached to the controller, keyed by SCSI path.
    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
    /// Senders that are signaled, on a best-effort basis, whenever a disk is
    /// attached or removed.
    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
    /// Set at construction from the caller's value, or from
    /// `DEFAULT_POLL_MODE_QUEUE_DEPTH` when none is given.
    // NOTE(review): the readers of this value are outside this part of the
    // file — confirm how it is consumed.
    poll_mode_queue_depth: AtomicU32,
}
1593
1594pub struct ScsiController {
1595    state: Arc<ScsiControllerState>,
1596}
1597
1598impl ScsiController {
1599    pub fn new() -> Self {
1600        Self::new_with_poll_mode_queue_depth(None)
1601    }
1602
1603    pub fn new_with_poll_mode_queue_depth(poll_mode_queue_depth: Option<u32>) -> Self {
1604        Self {
1605            state: Arc::new(ScsiControllerState {
1606                disks: Default::default(),
1607                rescan_notification_source: Mutex::new(Vec::new()),
1608                poll_mode_queue_depth: AtomicU32::new(
1609                    poll_mode_queue_depth.unwrap_or(DEFAULT_POLL_MODE_QUEUE_DEPTH),
1610                ),
1611            }),
1612        }
1613    }
1614
1615    pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1616        match self.state.disks.write().entry(path) {
1617            Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1618            Entry::Vacant(entry) => entry.insert(disk),
1619        };
1620        for source in self.state.rescan_notification_source.lock().iter_mut() {
1621            // Ok to ignore errors here. If the channel is full a previous notification has not yet
1622            // been processed by the primary channel worker.
1623            source.try_send(()).ok();
1624        }
1625        Ok(())
1626    }
1627
1628    pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1629        match self.state.disks.write().entry(path) {
1630            Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1631            Entry::Occupied(entry) => {
1632                entry.remove();
1633            }
1634        }
1635        for source in self.state.rescan_notification_source.lock().iter_mut() {
1636            // Ok to ignore errors here. If the channel is full a previous notification has not yet
1637            // been processed by the primary channel worker.
1638            source.try_send(()).ok();
1639        }
1640        Ok(())
1641    }
1642}
1643
1644impl ScsiControllerState {
1645    fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1646        self.rescan_notification_source.lock().push(source);
1647    }
1648
1649    fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1650        let mut sources = self.rescan_notification_source.lock();
1651        if let Some(index) = sources
1652            .iter()
1653            .position(|source| source.is_connected_to(target))
1654        {
1655            sources.remove(index);
1656        }
1657    }
1658}
1659
1660#[async_trait]
1661impl VmbusDevice for StorageDevice {
1662    fn offer(&self) -> OfferParams {
1663        if let Some(path) = self.ide_path {
1664            let offer_properties = storvsp_protocol::OfferProperties {
1665                path_id: path.path,
1666                target_id: path.target,
1667                flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
1668                ..FromZeros::new_zeroed()
1669            };
1670            let mut user_defined = UserDefinedData::new_zeroed();
1671            offer_properties
1672                .write_to_prefix(&mut user_defined[..])
1673                .unwrap();
1674            OfferParams {
1675                interface_name: "ide-accel".to_owned(),
1676                instance_id: self.instance_id,
1677                interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
1678                channel_type: ChannelType::Interface { user_defined },
1679                ..Default::default()
1680            }
1681        } else {
1682            OfferParams {
1683                interface_name: "scsi".to_owned(),
1684                instance_id: self.instance_id,
1685                interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
1686                ..Default::default()
1687            }
1688        }
1689    }
1690
1691    fn max_subchannels(&self) -> u16 {
1692        self.max_sub_channel_count
1693    }
1694
1695    fn install(&mut self, resources: DeviceResources) {
1696        self.resources = resources;
1697    }
1698
1699    async fn open(
1700        &mut self,
1701        channel_index: u16,
1702        open_request: &OpenRequest,
1703    ) -> Result<(), ChannelOpenError> {
1704        tracing::debug!(channel_index, "scsi open channel");
1705        self.new_worker(open_request, channel_index)?;
1706        self.workers[channel_index as usize].worker.start();
1707        Ok(())
1708    }
1709
1710    async fn close(&mut self, channel_index: u16) {
1711        tracing::debug!(channel_index, "scsi close channel");
1712        let worker = &mut self.workers[channel_index as usize].worker;
1713        worker.stop().await;
1714        if worker.state_mut().is_some() {
1715            worker
1716                .state_mut()
1717                .unwrap()
1718                .wait_for_scsi_requests_complete()
1719                .await;
1720            worker.remove();
1721        }
1722        if channel_index == 0 {
1723            *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
1724        }
1725    }
1726
1727    async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
1728        self.workers[channel_index as usize]
1729            .driver
1730            .retarget_vp(target_vp);
1731    }
1732
1733    fn start(&mut self) {
1734        for task in self
1735            .workers
1736            .iter_mut()
1737            .filter(|task| task.worker.has_state() && !task.worker.is_running())
1738        {
1739            task.worker.start();
1740        }
1741    }
1742
1743    async fn stop(&mut self) {
1744        tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
1745        for task in self
1746            .workers
1747            .iter_mut()
1748            .filter(|task| task.worker.has_state() && task.worker.is_running())
1749        {
1750            task.worker.stop().await;
1751        }
1752    }
1753
1754    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1755        Some(self)
1756    }
1757}
1758
1759#[async_trait]
1760impl SaveRestoreVmbusDevice for StorageDevice {
1761    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1762        Ok(SavedStateBlob::new(self.save()?))
1763    }
1764
1765    async fn restore(
1766        &mut self,
1767        control: RestoreControl<'_>,
1768        state: SavedStateBlob,
1769    ) -> Result<(), RestoreError> {
1770        self.restore(control, state.parse()?).await
1771    }
1772}
1773
1774#[cfg(test)]
1775mod tests {
1776    use super::*;
1777    use crate::test_helpers::TestWorker;
1778    use crate::test_helpers::parse_guest_completion;
1779    use crate::test_helpers::parse_guest_completion_check_flags_status;
1780    use pal_async::DefaultDriver;
1781    use pal_async::async_test;
1782    use scsi::srb::SrbStatus;
1783    use test_with_tracing::test;
1784    use vmbus_channel::connected_async_channels;
1785
1786    // Discourage `Clone` for `ScsiController` outside the crate, but it is
1787    // necessary for testing. The fuzzer also uses `TestWorker`, which needs
1788    // a `clone` of the inner state, but is not in this crate.
1789    impl Clone for ScsiController {
1790        fn clone(&self) -> Self {
1791            ScsiController {
1792                state: self.state.clone(),
1793            }
1794        }
1795    }
1796
1797    #[async_test]
1798    async fn test_channel_working(driver: DefaultDriver) {
1799        // set up the channels and worker
1800        let (host, guest) = connected_async_channels(16 * 1024);
1801        let guest_queue = Queue::new(guest).unwrap();
1802
1803        let test_guest_mem = GuestMemory::allocate(16384);
1804        let controller = ScsiController::new();
1805        let disk = scsidisk::SimpleScsiDisk::new(
1806            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
1807            Default::default(),
1808        );
1809        controller
1810            .attach(
1811                ScsiPath {
1812                    path: 0,
1813                    target: 0,
1814                    lun: 0,
1815                },
1816                ScsiControllerDisk::new(Arc::new(disk)),
1817            )
1818            .unwrap();
1819
1820        let test_worker = TestWorker::start(
1821            controller.clone(),
1822            driver.clone(),
1823            test_guest_mem.clone(),
1824            host,
1825            None,
1826        );
1827
1828        let mut guest = test_helpers::TestGuest {
1829            queue: guest_queue,
1830            transaction_id: 0,
1831        };
1832
1833        guest.perform_protocol_negotiation().await;
1834
1835        // Set up the buffer for a write request
1836        const IO_LEN: usize = 4 * 1024;
1837        let write_buf = [7u8; IO_LEN];
1838        let write_gpa = 4 * 1024u64;
1839        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
1840        guest
1841            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
1842            .await;
1843
1844        guest
1845            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
1846            .await;
1847
1848        let read_gpa = 8 * 1024u64;
1849        guest
1850            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
1851            .await;
1852
1853        guest
1854            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
1855            .await;
1856        let mut read_buf = [0u8; IO_LEN];
1857        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
1858        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
1859            assert_eq!(b1, b2);
1860        }
1861
1862        // stop everything
1863        guest.verify_graceful_close(test_worker).await;
1864    }
1865
1866    #[async_test]
1867    async fn test_packet_sizes(driver: DefaultDriver) {
1868        // set up the channels and worker
1869        let (host, guest) = connected_async_channels(16384);
1870        let guest_queue = Queue::new(guest).unwrap();
1871
1872        let test_guest_mem = GuestMemory::allocate(1024);
1873        let controller = ScsiController::new();
1874
1875        let _worker = TestWorker::start(
1876            controller.clone(),
1877            driver.clone(),
1878            test_guest_mem,
1879            host,
1880            None,
1881        );
1882
1883        let mut guest = test_helpers::TestGuest {
1884            queue: guest_queue,
1885            transaction_id: 0,
1886        };
1887
1888        let negotiate_packet = storvsp_protocol::Packet {
1889            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
1890            flags: 0,
1891            status: storvsp_protocol::NtStatus::SUCCESS,
1892        };
1893        guest
1894            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
1895            .await;
1896
1897        guest.verify_completion(parse_guest_completion).await;
1898
1899        let header = storvsp_protocol::Packet {
1900            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
1901            flags: 0,
1902            status: storvsp_protocol::NtStatus::SUCCESS,
1903        };
1904
1905        let mut buf = [0u8; 128];
1906        storvsp_protocol::ProtocolVersion {
1907            major_minor: !0,
1908            reserved: 0,
1909        }
1910        .write_to_prefix(&mut buf[..])
1911        .unwrap(); // PANIC: Infallable since `ProtcolVersion` is less than 128 bytes
1912
1913        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
1914            guest
1915                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
1916                .await;
1917
1918            guest
1919                .verify_completion(|packet| {
1920                    let IncomingPacket::Completion(packet) = packet else {
1921                        unreachable!()
1922                    };
1923                    assert_eq!(packet.reader().len(), resp_len);
1924                    assert_eq!(
1925                        packet
1926                            .reader()
1927                            .read_plain::<storvsp_protocol::Packet>()
1928                            .unwrap()
1929                            .status,
1930                        storvsp_protocol::NtStatus::REVISION_MISMATCH
1931                    );
1932                    Ok(())
1933                })
1934                .await;
1935        }
1936    }
1937
1938    #[async_test]
1939    async fn test_wrong_first_packet(driver: DefaultDriver) {
1940        // set up the channels and worker
1941        let (host, guest) = connected_async_channels(16384);
1942        let guest_queue = Queue::new(guest).unwrap();
1943
1944        let test_guest_mem = GuestMemory::allocate(1024);
1945        let controller = ScsiController::new();
1946
1947        let _worker = TestWorker::start(
1948            controller.clone(),
1949            driver.clone(),
1950            test_guest_mem,
1951            host,
1952            None,
1953        );
1954
1955        let mut guest = test_helpers::TestGuest {
1956            queue: guest_queue,
1957            transaction_id: 0,
1958        };
1959
1960        // Protocol negotiation done out of order
1961        let negotiate_packet = storvsp_protocol::Packet {
1962            operation: storvsp_protocol::Operation::END_INITIALIZATION,
1963            flags: 0,
1964            status: storvsp_protocol::NtStatus::SUCCESS,
1965        };
1966        guest
1967            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
1968            .await;
1969
1970        guest
1971            .verify_completion(|packet| {
1972                let IncomingPacket::Completion(packet) = packet else {
1973                    unreachable!()
1974                };
1975                assert_eq!(
1976                    packet
1977                        .reader()
1978                        .read_plain::<storvsp_protocol::Packet>()
1979                        .unwrap()
1980                        .status,
1981                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
1982                );
1983                Ok(())
1984            })
1985            .await;
1986    }
1987
1988    #[async_test]
1989    async fn test_unrecognized_operation(driver: DefaultDriver) {
1990        // set up the channels and worker
1991        let (host, guest) = connected_async_channels(16384);
1992        let guest_queue = Queue::new(guest).unwrap();
1993
1994        let test_guest_mem = GuestMemory::allocate(1024);
1995        let controller = ScsiController::new();
1996
1997        let worker = TestWorker::start(
1998            controller.clone(),
1999            driver.clone(),
2000            test_guest_mem,
2001            host,
2002            None,
2003        );
2004
2005        let mut guest = test_helpers::TestGuest {
2006            queue: guest_queue,
2007            transaction_id: 0,
2008        };
2009
2010        // Send packet with unrecognized operation
2011        let negotiate_packet = storvsp_protocol::Packet {
2012            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
2013            flags: 0,
2014            status: storvsp_protocol::NtStatus::SUCCESS,
2015        };
2016        guest
2017            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2018            .await;
2019
2020        match worker.teardown().await {
2021            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
2022                storvsp_protocol::Operation::REMOVE_DEVICE,
2023            ))) => {}
2024            result => panic!("Worker failed with unexpected result {:?}!", result),
2025        }
2026    }
2027
2028    #[async_test]
2029    async fn test_too_many_subchannels(driver: DefaultDriver) {
2030        // set up the channels and worker
2031        let (host, guest) = connected_async_channels(16384);
2032        let guest_queue = Queue::new(guest).unwrap();
2033
2034        let test_guest_mem = GuestMemory::allocate(1024);
2035        let controller = ScsiController::new();
2036
2037        let _worker = TestWorker::start(
2038            controller.clone(),
2039            driver.clone(),
2040            test_guest_mem,
2041            host,
2042            None,
2043        );
2044
2045        let mut guest = test_helpers::TestGuest {
2046            queue: guest_queue,
2047            transaction_id: 0,
2048        };
2049
2050        let negotiate_packet = storvsp_protocol::Packet {
2051            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2052            flags: 0,
2053            status: storvsp_protocol::NtStatus::SUCCESS,
2054        };
2055        guest
2056            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2057            .await;
2058        guest.verify_completion(parse_guest_completion).await;
2059
2060        let version_packet = storvsp_protocol::Packet {
2061            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
2062            flags: 0,
2063            status: storvsp_protocol::NtStatus::SUCCESS,
2064        };
2065        let version = storvsp_protocol::ProtocolVersion {
2066            major_minor: storvsp_protocol::VERSION_BLUE,
2067            reserved: 0,
2068        };
2069        guest
2070            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
2071            .await;
2072        guest.verify_completion(parse_guest_completion).await;
2073
2074        let properties_packet = storvsp_protocol::Packet {
2075            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
2076            flags: 0,
2077            status: storvsp_protocol::NtStatus::SUCCESS,
2078        };
2079        guest
2080            .send_data_packet_sync(&[properties_packet.as_bytes()])
2081            .await;
2082
2083        guest.verify_completion(parse_guest_completion).await;
2084
2085        let negotiate_packet = storvsp_protocol::Packet {
2086            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
2087            flags: 0,
2088            status: storvsp_protocol::NtStatus::SUCCESS,
2089        };
2090        // Create sub channels more than maximum_sub_channel_count
2091        guest
2092            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
2093            .await;
2094
2095        guest
2096            .verify_completion(|packet| {
2097                let IncomingPacket::Completion(packet) = packet else {
2098                    unreachable!()
2099                };
2100                assert_eq!(
2101                    packet
2102                        .reader()
2103                        .read_plain::<storvsp_protocol::Packet>()
2104                        .unwrap()
2105                        .status,
2106                    storvsp_protocol::NtStatus::INVALID_PARAMETER
2107                );
2108                Ok(())
2109            })
2110            .await;
2111    }
2112
2113    #[async_test]
2114    async fn test_begin_init_on_ready(driver: DefaultDriver) {
2115        // set up the channels and worker
2116        let (host, guest) = connected_async_channels(16384);
2117        let guest_queue = Queue::new(guest).unwrap();
2118
2119        let test_guest_mem = GuestMemory::allocate(1024);
2120        let controller = ScsiController::new();
2121
2122        let _worker = TestWorker::start(
2123            controller.clone(),
2124            driver.clone(),
2125            test_guest_mem,
2126            host,
2127            None,
2128        );
2129
2130        let mut guest = test_helpers::TestGuest {
2131            queue: guest_queue,
2132            transaction_id: 0,
2133        };
2134
2135        guest.perform_protocol_negotiation().await;
2136
2137        // Protocol negotiation done out of order
2138        let negotiate_packet = storvsp_protocol::Packet {
2139            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2140            flags: 0,
2141            status: storvsp_protocol::NtStatus::SUCCESS,
2142        };
2143        guest
2144            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2145            .await;
2146
2147        guest
2148            .verify_completion(|p| {
2149                parse_guest_completion_check_flags_status(
2150                    p,
2151                    0,
2152                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
2153                )
2154            })
2155            .await;
2156    }
2157
2158    #[async_test]
2159    async fn test_hot_add_remove(driver: DefaultDriver) {
2160        // set up channels and worker.
2161        let (host, guest) = connected_async_channels(16 * 1024);
2162        let guest_queue = Queue::new(guest).unwrap();
2163
2164        let test_guest_mem = GuestMemory::allocate(16384);
2165        // create a controller with no disk yet.
2166        let controller = ScsiController::new();
2167
2168        let test_worker = TestWorker::start(
2169            controller.clone(),
2170            driver.clone(),
2171            test_guest_mem.clone(),
2172            host,
2173            None,
2174        );
2175
2176        let mut guest = test_helpers::TestGuest {
2177            queue: guest_queue,
2178            transaction_id: 0,
2179        };
2180
2181        guest.perform_protocol_negotiation().await;
2182
2183        // Verify no LUNs are reported initially.
2184        let mut lun_list_buffer: [u8; 256] = [0; 256];
2185        let mut disk_count = 0;
2186        guest
2187            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
2188            .await;
2189        guest
2190            .verify_completion(|p| {
2191                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
2192            })
2193            .await;
2194        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
2195        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
2196        assert_eq!(lun_list_size, disk_count as u32 * 8);
2197
2198        // Set up a buffer for writes.
2199        const IO_LEN: usize = 4 * 1024;
2200        let write_buf = [7u8; IO_LEN];
2201        let write_gpa = 4 * 1024u64;
2202        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
2203
2204        guest
2205            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
2206            .await;
2207        guest
2208            .verify_completion(|p| {
2209                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
2210            })
2211            .await;
2212
2213        // Add some disks while the guest is running.
2214        for lun in 0..4 {
2215            let disk = scsidisk::SimpleScsiDisk::new(
2216                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
2217                Default::default(),
2218            );
2219            controller
2220                .attach(
2221                    ScsiPath {
2222                        path: 0,
2223                        target: 0,
2224                        lun,
2225                    },
2226                    ScsiControllerDisk::new(Arc::new(disk)),
2227                )
2228                .unwrap();
2229            guest
2230                .verify_completion(test_helpers::parse_guest_enumerate_bus)
2231                .await;
2232
2233            disk_count += 1;
2234            guest
2235                .send_report_luns_packet(ScsiPath::default(), 0, 256)
2236                .await;
2237            guest
2238                .verify_completion(|p| {
2239                    test_helpers::parse_guest_completed_io_check_tx_len(
2240                        p,
2241                        SrbStatus::SUCCESS,
2242                        Some((disk_count + 1) * 8),
2243                    )
2244                })
2245                .await;
2246            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
2247            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
2248            assert_eq!(lun_list_size, disk_count as u32 * 8);
2249
2250            guest
2251                .send_write_packet(
2252                    ScsiPath {
2253                        path: 0,
2254                        target: 0,
2255                        lun,
2256                    },
2257                    write_gpa,
2258                    1,
2259                    IO_LEN,
2260                )
2261                .await;
2262            guest
2263                .verify_completion(|p| {
2264                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
2265                })
2266                .await;
2267        }
2268
2269        // Remove all disks while the guest is running.
2270        for lun in 0..4 {
2271            controller
2272                .remove(ScsiPath {
2273                    path: 0,
2274                    target: 0,
2275                    lun,
2276                })
2277                .unwrap();
2278            guest
2279                .verify_completion(test_helpers::parse_guest_enumerate_bus)
2280                .await;
2281
2282            disk_count -= 1;
2283            guest
2284                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
2285                .await;
2286            guest
2287                .verify_completion(|p| {
2288                    test_helpers::parse_guest_completed_io_check_tx_len(
2289                        p,
2290                        SrbStatus::SUCCESS,
2291                        Some((disk_count + 1) * 8),
2292                    )
2293                })
2294                .await;
2295            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
2296            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
2297            assert_eq!(lun_list_size, disk_count as u32 * 8);
2298
2299            guest
2300                .send_write_packet(
2301                    ScsiPath {
2302                        path: 0,
2303                        target: 0,
2304                        lun,
2305                    },
2306                    write_gpa,
2307                    1,
2308                    IO_LEN,
2309                )
2310                .await;
2311            guest
2312                .verify_completion(|p| {
2313                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
2314                })
2315                .await;
2316        }
2317
2318        guest.verify_graceful_close(test_worker).await;
2319    }
2320
2321    #[async_test]
2322    pub async fn test_async_disk(driver: DefaultDriver) {
2323        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
2324        let controller = ScsiController::new();
2325        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
2326            device,
2327            Default::default(),
2328        )));
2329        controller
2330            .attach(
2331                ScsiPath {
2332                    path: 0,
2333                    target: 0,
2334                    lun: 0,
2335                },
2336                disk,
2337            )
2338            .unwrap();
2339
2340        let (host, guest) = connected_async_channels(16 * 1024);
2341        let guest_queue = Queue::new(guest).unwrap();
2342
2343        let mut guest = test_helpers::TestGuest {
2344            queue: guest_queue,
2345            transaction_id: 0,
2346        };
2347
2348        let test_guest_mem = GuestMemory::allocate(16384);
2349        let worker = TestWorker::start(
2350            controller.clone(),
2351            &driver,
2352            test_guest_mem.clone(),
2353            host,
2354            None,
2355        );
2356
2357        let negotiate_packet = storvsp_protocol::Packet {
2358            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2359            flags: 0,
2360            status: storvsp_protocol::NtStatus::SUCCESS,
2361        };
2362        guest
2363            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2364            .await;
2365        guest.verify_completion(parse_guest_completion).await;
2366
2367        let version_packet = storvsp_protocol::Packet {
2368            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
2369            flags: 0,
2370            status: storvsp_protocol::NtStatus::SUCCESS,
2371        };
2372        let version = storvsp_protocol::ProtocolVersion {
2373            major_minor: storvsp_protocol::VERSION_BLUE,
2374            reserved: 0,
2375        };
2376        guest
2377            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
2378            .await;
2379        guest.verify_completion(parse_guest_completion).await;
2380
2381        let properties_packet = storvsp_protocol::Packet {
2382            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
2383            flags: 0,
2384            status: storvsp_protocol::NtStatus::SUCCESS,
2385        };
2386        guest
2387            .send_data_packet_sync(&[properties_packet.as_bytes()])
2388            .await;
2389        guest.verify_completion(parse_guest_completion).await;
2390
2391        let negotiate_packet = storvsp_protocol::Packet {
2392            operation: storvsp_protocol::Operation::END_INITIALIZATION,
2393            flags: 0,
2394            status: storvsp_protocol::NtStatus::SUCCESS,
2395        };
2396        guest
2397            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2398            .await;
2399        guest.verify_completion(parse_guest_completion).await;
2400
2401        const IO_LEN: usize = 4 * 1024;
2402        let write_buf = [7u8; IO_LEN];
2403        let write_gpa = 4 * 1024u64;
2404        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
2405        guest
2406            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
2407            .await;
2408        guest
2409            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
2410            .await;
2411
2412        let read_gpa = 8 * 1024u64;
2413        guest
2414            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
2415            .await;
2416        guest
2417            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
2418            .await;
2419        let mut read_buf = [0u8; IO_LEN];
2420        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
2421        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
2422            assert_eq!(b1, b2);
2423        }
2424
2425        guest.verify_graceful_close(worker).await;
2426    }
2427}