storvsp/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7#[cfg(feature = "ioperf")]
8pub mod ioperf;
9
10#[cfg(feature = "test")]
11pub mod test_helpers;
12
13#[cfg(not(feature = "test"))]
14mod test_helpers;
15
16pub mod resolver;
17mod save_restore;
18
19use crate::ring::gparange::GpnList;
20use crate::ring::gparange::MultiPagedRangeBuf;
21use anyhow::Context as _;
22use async_trait::async_trait;
23use fast_select::FastSelect;
24use futures::FutureExt;
25use futures::StreamExt;
26use futures::select_biased;
27use guestmem::AccessError;
28use guestmem::GuestMemory;
29use guestmem::MemoryRead;
30use guestmem::MemoryWrite;
31use guestmem::ranges::PagedRange;
32use guid::Guid;
33use inspect::Inspect;
34use inspect::InspectMut;
35use inspect_counters::Counter;
36use inspect_counters::Histogram;
37use oversized_box::OversizedBox;
38use parking_lot::Mutex;
39use parking_lot::RwLock;
40use ring::OutgoingPacketType;
41use scsi::AdditionalSenseCode;
42use scsi::ScsiOp;
43use scsi::ScsiStatus;
44use scsi::srb::SrbStatus;
45use scsi::srb::SrbStatusAndFlags;
46use scsi_buffers::RequestBuffers;
47use scsi_core::AsyncScsiDisk;
48use scsi_core::Request;
49use scsi_core::ScsiResult;
50use scsi_defs as scsi;
51use scsidisk::illegal_request_sense;
52use slab::Slab;
53use std::collections::hash_map::Entry;
54use std::collections::hash_map::HashMap;
55use std::fmt::Debug;
56use std::future::Future;
57use std::future::poll_fn;
58use std::pin::Pin;
59use std::sync::Arc;
60use std::task::Context;
61use std::task::Poll;
62use storvsp_resources::ScsiPath;
63use task_control::AsyncRun;
64use task_control::InspectTask;
65use task_control::StopTask;
66use task_control::TaskControl;
67use thiserror::Error;
68use tracing_helpers::ErrorValueExt;
69use unicycle::FuturesUnordered;
70use vmbus_async::queue;
71use vmbus_async::queue::ExternalDataError;
72use vmbus_async::queue::IncomingPacket;
73use vmbus_async::queue::OutgoingPacket;
74use vmbus_async::queue::Queue;
75use vmbus_channel::RawAsyncChannel;
76use vmbus_channel::bus::ChannelType;
77use vmbus_channel::bus::OfferParams;
78use vmbus_channel::bus::OpenRequest;
79use vmbus_channel::channel::ChannelControl;
80use vmbus_channel::channel::ChannelOpenError;
81use vmbus_channel::channel::DeviceResources;
82use vmbus_channel::channel::RestoreControl;
83use vmbus_channel::channel::SaveRestoreVmbusDevice;
84use vmbus_channel::channel::VmbusDevice;
85use vmbus_channel::gpadl_ring::GpadlRingMem;
86use vmbus_channel::gpadl_ring::gpadl_channel;
87use vmbus_core::protocol::UserDefinedData;
88use vmbus_ring as ring;
89use vmbus_ring::RingMem;
90use vmcore::save_restore::RestoreError;
91use vmcore::save_restore::SaveError;
92use vmcore::save_restore::SavedStateBlob;
93use vmcore::vm_task::VmTaskDriver;
94use vmcore::vm_task::VmTaskDriverSource;
95use zerocopy::FromBytes;
96use zerocopy::FromZeros;
97use zerocopy::Immutable;
98use zerocopy::IntoBytes;
99use zerocopy::KnownLayout;
100
101pub struct StorageDevice {
102    instance_id: Guid,
103    ide_path: Option<ScsiPath>,
104    workers: Vec<WorkerAndDriver>,
105    controller: Arc<ScsiControllerState>,
106    resources: DeviceResources,
107    driver_source: VmTaskDriverSource,
108    max_sub_channel_count: u16,
109    protocol: Arc<Protocol>,
110    io_queue_depth: u32,
111}
112
113#[derive(Inspect)]
114struct WorkerAndDriver {
115    #[inspect(flatten)]
116    worker: TaskControl<WorkerState, Worker>,
117    driver: VmTaskDriver,
118}
119
120struct WorkerState;
121
122impl InspectMut for StorageDevice {
123    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
124        let mut resp = req.respond();
125
126        let disks = self.controller.disks.read();
127        for (path, controller_disk) in disks.iter() {
128            resp.child(&format!("disks/{}", path), |req| {
129                controller_disk.disk.inspect(req);
130            });
131        }
132
133        resp.fields(
134            "channels",
135            self.workers
136                .iter()
137                .filter(|task| task.worker.has_state())
138                .enumerate(),
139        );
140    }
141}
142
143struct Worker<T: RingMem = GpadlRingMem> {
144    inner: WorkerInner,
145    rescan_notification: futures::channel::mpsc::Receiver<()>,
146    fast_select: FastSelect,
147    queue: Queue<T>,
148}
149
150struct Protocol {
151    state: RwLock<ProtocolState>,
152    /// Signaled when `state` transitions to `ProtocolState::Ready`.
153    ready: event_listener::Event,
154}
155
156struct WorkerInner {
157    protocol: Arc<Protocol>,
158    request_size: usize,
159    controller: Arc<ScsiControllerState>,
160    channel_index: u16,
161    scsi_queue: Arc<ScsiCommandQueue>,
162    scsi_requests: FuturesUnordered<ScsiRequest>,
163    scsi_requests_states: Slab<ScsiRequestState>,
164    full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
165    future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
166    channel_control: ChannelControl,
167    max_io_queue_depth: usize,
168    stats: WorkerStats,
169}
170
171#[derive(Debug, Default, Inspect)]
172struct WorkerStats {
173    ios_submitted: Counter,
174    ios_completed: Counter,
175    wakes: Counter,
176    wakes_spurious: Counter,
177    per_wake_submissions: Histogram<10>,
178    per_wake_completions: Histogram<10>,
179}
180
181#[repr(u16)]
182#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
183enum Version {
184    Win6 = storvsp_protocol::VERSION_WIN6,
185    Win7 = storvsp_protocol::VERSION_WIN7,
186    Win8 = storvsp_protocol::VERSION_WIN8,
187    Blue = storvsp_protocol::VERSION_BLUE,
188}
189
190#[derive(Debug, Error)]
191#[error("protocol version {0:#x} not supported")]
192struct UnsupportedVersion(u16);
193
194impl Version {
195    fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
196        let version = match major_minor {
197            storvsp_protocol::VERSION_WIN6 => Self::Win6,
198            storvsp_protocol::VERSION_WIN7 => Self::Win7,
199            storvsp_protocol::VERSION_WIN8 => Self::Win8,
200            storvsp_protocol::VERSION_BLUE => Self::Blue,
201            version => return Err(UnsupportedVersion(version)),
202        };
203        assert_eq!(version as u16, major_minor);
204        Ok(version)
205    }
206
207    fn max_request_size(&self) -> usize {
208        match self {
209            Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
210            Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
211        }
212    }
213}
214
215#[derive(Copy, Clone)]
216enum ProtocolState {
217    Init(InitState),
218    Ready {
219        version: Version,
220        subchannel_count: u16,
221    },
222}
223
224#[derive(Copy, Clone, Debug)]
225enum InitState {
226    Begin,
227    QueryVersion,
228    QueryProperties {
229        version: Version,
230    },
231    EndInitialization {
232        version: Version,
233        subchannel_count: Option<u16>,
234    },
235}
236
237/// The internal SCSI operation future type.
238///
239/// This is a boxed future of a large pre-determined size. The box is reused
240/// after a SCSI request completes to avoid allocations in the hot path.
241///
242/// An Option type is used so that the future can be efficiently dropped (via
243/// `Pin::set(x, None)`) before it is stashed away for reuse.
244type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];
245type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;
246
247/// The amount of space reserved for a ScsiOpFuture.
248///
249/// This was chosen by running `cargo test -p storvsp -- --no-capture` and looking at the required
250/// size that was given in the failure message
251const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;
252
253struct ScsiRequest {
254    request_id: usize,
255    future: Option<ScsiOpFuture>,
256}
257
258impl ScsiRequest {
259    fn new(
260        request_id: usize,
261        future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
262    ) -> Self {
263        Self {
264            request_id,
265            future: Some(future.into()),
266        }
267    }
268}
269
270impl Future for ScsiRequest {
271    type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);
272
273    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
274        let this = self.get_mut();
275        let future = this.future.as_mut().unwrap().as_mut();
276        let result = std::task::ready!(future.poll(cx));
277        // Return the future so that its storage can be reused.
278        let future = this.future.take().unwrap();
279        Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
280    }
281}
282
283#[derive(Debug, Error)]
284enum WorkerError {
285    #[error("packet error")]
286    PacketError(#[source] PacketError),
287    #[error("queue error")]
288    Queue(#[source] queue::Error),
289    #[error("queue should have enough space but no longer does")]
290    NotEnoughSpace,
291}
292
293#[derive(Debug, Error)]
294enum PacketError {
295    #[error("Not transactional")]
296    NotTransactional,
297    #[error("Unrecognized operation {0:?}")]
298    UnrecognizedOperation(storvsp_protocol::Operation),
299    #[error("Invalid packet type")]
300    InvalidPacketType,
301    #[error("Invalid data transfer length")]
302    InvalidDataTransferLength,
303    #[error("Access error: {0}")]
304    Access(#[source] AccessError),
305    #[error("Range error")]
306    Range(#[source] ExternalDataError),
307}
308
309#[derive(Debug, Default, Clone)]
310struct Range {
311    buf: MultiPagedRangeBuf<GpnList>,
312    len: usize,
313    is_write: bool,
314}
315
316impl Range {
317    fn new(
318        buf: MultiPagedRangeBuf<GpnList>,
319        request: &storvsp_protocol::ScsiRequest,
320    ) -> Option<Self> {
321        let len = request.data_transfer_length as usize;
322        let is_write = request.data_in != 0;
323        // Ensure there is exactly one range and it's large enough, or there are
324        // zero ranges and there is no associated SCSI buffer.
325        if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
326            return None;
327        }
328        Some(Self { buf, len, is_write })
329    }
330
331    fn buffer<'a>(&'a self, guest_memory: &'a GuestMemory) -> RequestBuffers<'a> {
332        let mut range = self.buf.first().unwrap_or_else(PagedRange::empty);
333        range.truncate(self.len);
334        RequestBuffers::new(guest_memory, range, self.is_write)
335    }
336}
337
338#[derive(Debug)]
339struct Packet {
340    data: PacketData,
341    transaction_id: u64,
342    request_size: usize,
343}
344
345#[derive(Debug)]
346enum PacketData {
347    BeginInitialization,
348    EndInitialization,
349    QueryProtocolVersion(u16),
350    QueryProperties,
351    CreateSubChannels(u16),
352    ExecuteScsi(Arc<ScsiRequestAndRange>),
353    ResetBus,
354    ResetAdapter,
355    ResetLun,
356}
357
358#[derive(Debug)]
359pub struct RangeError;
360
361fn parse_packet<T: RingMem>(
362    packet: &IncomingPacket<'_, T>,
363    pool: &mut Vec<Arc<ScsiRequestAndRange>>,
364) -> Result<Packet, PacketError> {
365    let packet = match packet {
366        IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
367        IncomingPacket::Data(packet) => packet,
368    };
369    let transaction_id = packet
370        .transaction_id()
371        .ok_or(PacketError::NotTransactional)?;
372
373    let mut reader = packet.reader();
374    let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
375    // You would expect that this should be limited to the current protocol
376    // version's maximum packet size, but this is not what Hyper-V does, and
377    // Linux 6.1 relies on this behavior during protocol initialization.
378    let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
379    let data = match header.operation {
380        storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
381        storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
382        storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
383            let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
384            reader
385                .read(version.as_mut_bytes())
386                .map_err(PacketError::Access)?;
387            PacketData::QueryProtocolVersion(version.major_minor)
388        }
389        storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
390        storvsp_protocol::Operation::EXECUTE_SRB => {
391            let mut full_request = pool.pop().unwrap_or_else(|| {
392                Arc::new(ScsiRequestAndRange {
393                    external_data: Range::default(),
394                    request: storvsp_protocol::ScsiRequest::new_zeroed(),
395                    request_size,
396                })
397            });
398
399            {
400                let full_request = Arc::get_mut(&mut full_request).unwrap();
401                let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
402                reader.read(request_buf).map_err(PacketError::Access)?;
403
404                let buf = packet.read_external_ranges().map_err(PacketError::Range)?;
405
406                full_request.external_data = Range::new(buf, &full_request.request)
407                    .ok_or(PacketError::InvalidDataTransferLength)?;
408            }
409
410            PacketData::ExecuteScsi(full_request)
411        }
412        storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
413        storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
414        storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
415        storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
416            let mut sub_channel_count: u16 = 0;
417            reader
418                .read(sub_channel_count.as_mut_bytes())
419                .map_err(PacketError::Access)?;
420            PacketData::CreateSubChannels(sub_channel_count)
421        }
422        _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
423    };
424
425    if let PacketData::ExecuteScsi(_) = data {
426        tracing::trace!(transaction_id, ?data, "parse_packet");
427    } else {
428        tracing::debug!(transaction_id, ?data, "parse_packet");
429    }
430
431    Ok(Packet {
432        data,
433        request_size,
434        transaction_id,
435    })
436}
437
438impl WorkerInner {
439    fn send_vmbus_packet<M: RingMem>(
440        &mut self,
441        writer: &mut queue::WriteBatch<'_, M>,
442        packet_type: OutgoingPacketType<'_>,
443        request_size: usize,
444        transaction_id: u64,
445        operation: storvsp_protocol::Operation,
446        status: storvsp_protocol::NtStatus,
447        payload: &[u8],
448    ) -> Result<(), WorkerError> {
449        let header = storvsp_protocol::Packet {
450            operation,
451            flags: 0,
452            status,
453        };
454
455        let packet_size = size_of_val(&header) + request_size;
456
457        // Zero pad or truncate the payload to the queue's packet size. This is
458        // necessary because Windows guests check that each packet's size is
459        // exactly the largest possible packet size for the negotiated protocol
460        // version.
461        let len = size_of_val(&header) + size_of_val(payload);
462        let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
463        let (payload_bytes, padding_bytes) = if len > packet_size {
464            (&payload[..packet_size - size_of_val(&header)], &[][..])
465        } else {
466            (payload, &padding[..packet_size - len])
467        };
468        assert_eq!(
469            size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
470            packet_size
471        );
472        writer
473            .try_write(&OutgoingPacket {
474                transaction_id,
475                packet_type,
476                payload: &[header.as_bytes(), payload_bytes, padding_bytes],
477            })
478            .map_err(|err| match err {
479                queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
480                queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
481            })
482    }
483
484    fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
485        &mut self,
486        writer: &mut queue::WriteHalf<'_, M>,
487        operation: storvsp_protocol::Operation,
488        status: storvsp_protocol::NtStatus,
489        payload: &P,
490    ) -> Result<(), WorkerError> {
491        self.send_vmbus_packet(
492            &mut writer.batched(),
493            OutgoingPacketType::InBandNoCompletion,
494            self.request_size,
495            0,
496            operation,
497            status,
498            payload.as_bytes(),
499        )
500    }
501
502    fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
503        &mut self,
504        writer: &mut queue::WriteHalf<'_, M>,
505        packet: &Packet,
506        status: storvsp_protocol::NtStatus,
507        payload: &P,
508    ) -> Result<(), WorkerError> {
509        self.send_vmbus_packet(
510            &mut writer.batched(),
511            OutgoingPacketType::Completion,
512            packet.request_size,
513            packet.transaction_id,
514            storvsp_protocol::Operation::COMPLETE_IO,
515            status,
516            payload.as_bytes(),
517        )
518    }
519}
520
521struct ScsiCommandQueue {
522    controller: Arc<ScsiControllerState>,
523    mem: GuestMemory,
524    force_path_id: Option<u8>,
525}
526
527impl ScsiCommandQueue {
528    async fn execute_scsi(
529        &self,
530        external_data: &Range,
531        request: &storvsp_protocol::ScsiRequest,
532    ) -> ScsiResult {
533        let op = ScsiOp(request.payload[0]);
534        let external_data = external_data.buffer(&self.mem);
535
536        tracing::trace!(
537            path_id = request.path_id,
538            target_id = request.target_id,
539            lun = request.lun,
540            op = ?op,
541            "execute_scsi start...",
542        );
543
544        let path_id = self.force_path_id.unwrap_or(request.path_id);
545
546        let controller_disk = self
547            .controller
548            .disks
549            .read()
550            .get(&ScsiPath {
551                path: path_id,
552                target: request.target_id,
553                lun: request.lun,
554            })
555            .cloned();
556
557        let result = match op {
558            ScsiOp::REPORT_LUNS => {
559                const HEADER_SIZE: usize = size_of::<scsi::LunList>();
560                let mut luns: Vec<u8> = self
561                    .controller
562                    .disks
563                    .read()
564                    .iter()
565                    .flat_map(|(path, _)| {
566                        // Use the original path ID and not the forced one to
567                        // match Hyper-V storvsp behavior.
568                        if request.path_id == path.path && request.target_id == path.target {
569                            Some(path.lun)
570                        } else {
571                            None
572                        }
573                    })
574                    .collect();
575                luns.sort_unstable();
576                let mut data: Vec<u64> = vec![0; luns.len() + 1];
577                let header = scsi::LunList {
578                    length: (luns.len() as u32 * 8).into(),
579                    reserved: [0; 4],
580                };
581                data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
582                for (i, lun) in luns.iter().enumerate() {
583                    data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
584                }
585                if external_data.len() >= HEADER_SIZE {
586                    let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
587                    external_data.writer().write(&data.as_bytes()[..tx]).map_or(
588                        ScsiResult {
589                            scsi_status: ScsiStatus::CHECK_CONDITION,
590                            srb_status: SrbStatus::INVALID_REQUEST,
591                            tx: 0,
592                            sense_data: Some(illegal_request_sense(
593                                AdditionalSenseCode::INVALID_CDB,
594                            )),
595                        },
596                        |_| ScsiResult {
597                            scsi_status: ScsiStatus::GOOD,
598                            srb_status: SrbStatus::SUCCESS,
599                            tx,
600                            sense_data: None,
601                        },
602                    )
603                } else {
604                    ScsiResult {
605                        scsi_status: ScsiStatus::GOOD,
606                        srb_status: SrbStatus::SUCCESS,
607                        tx: 0,
608                        sense_data: None,
609                    }
610                }
611            }
612            _ if controller_disk.is_some() => {
613                let mut cdb = [0; 16];
614                cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
615                controller_disk
616                    .unwrap()
617                    .disk
618                    .execute_scsi(
619                        &external_data,
620                        &Request {
621                            cdb,
622                            srb_flags: request.srb_flags,
623                        },
624                    )
625                    .await
626            }
627            ScsiOp::INQUIRY => {
628                let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
629                    .unwrap()
630                    .0; // TODO: zerocopy: ref-from-prefix: use-rest-of-range (https://github.com/microsoft/openvmm/issues/759)
631                if external_data.len() < cdb.allocation_length.get() as usize
632                    || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
633                    || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
634                {
635                    ScsiResult {
636                        scsi_status: ScsiStatus::CHECK_CONDITION,
637                        srb_status: SrbStatus::INVALID_REQUEST,
638                        tx: 0,
639                        sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
640                    }
641                } else {
642                    let enable_vpd = cdb.flags.vpd();
643                    if enable_vpd || cdb.page_code != 0 {
644                        // cannot support VPD inquiry for non-existing device (lun).
645                        ScsiResult {
646                            scsi_status: ScsiStatus::CHECK_CONDITION,
647                            srb_status: SrbStatus::INVALID_REQUEST,
648                            tx: 0,
649                            sense_data: Some(illegal_request_sense(
650                                AdditionalSenseCode::INVALID_CDB,
651                            )),
652                        }
653                    } else {
654                        const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
655                        let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
656                        data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;
657
658                        if request.lun != 0 {
659                            // Below fields are only set for lun0 inquiry so zero out here.
660                            data.vendor_id = [0; 8];
661                            data.product_id = [0; 16];
662                            data.product_revision_level = [0; 4];
663                        }
664
665                        let datab = data.as_bytes();
666                        let tx = std::cmp::min(
667                            cdb.allocation_length.get() as usize,
668                            size_of::<scsi::InquiryData>(),
669                        );
670                        external_data.writer().write(&datab[..tx]).map_or(
671                            ScsiResult {
672                                scsi_status: ScsiStatus::CHECK_CONDITION,
673                                srb_status: SrbStatus::INVALID_REQUEST,
674                                tx: 0,
675                                sense_data: Some(illegal_request_sense(
676                                    AdditionalSenseCode::INVALID_CDB,
677                                )),
678                            },
679                            |_| ScsiResult {
680                                scsi_status: ScsiStatus::GOOD,
681                                srb_status: SrbStatus::SUCCESS,
682                                tx,
683                                sense_data: None,
684                            },
685                        )
686                    }
687                }
688            }
689            _ => ScsiResult {
690                scsi_status: ScsiStatus::CHECK_CONDITION,
691                srb_status: SrbStatus::INVALID_LUN,
692                tx: 0,
693                sense_data: None,
694            },
695        };
696
697        tracing::trace!(
698            path_id = request.path_id,
699            target_id = request.target_id,
700            lun = request.lun,
701            op = ?op,
702            result = ?result,
703            "execute_scsi completed.",
704        );
705        result
706    }
707}
708
709impl<T: RingMem + 'static> Worker<T> {
710    fn new(
711        controller: Arc<ScsiControllerState>,
712        channel: RawAsyncChannel<T>,
713        channel_index: u16,
714        mem: GuestMemory,
715        channel_control: ChannelControl,
716        io_queue_depth: u32,
717        protocol: Arc<Protocol>,
718        force_path_id: Option<u8>,
719    ) -> anyhow::Result<Self> {
720        let queue = Queue::new(channel)?;
721        #[expect(clippy::disallowed_methods)] // TODO
722        let (source, target) = futures::channel::mpsc::channel(1);
723        controller.add_rescan_notification_source(source);
724
725        let max_io_queue_depth = io_queue_depth.max(1) as usize;
726        Ok(Self {
727            inner: WorkerInner {
728                protocol,
729                request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
730                controller: controller.clone(),
731                channel_index,
732                scsi_queue: Arc::new(ScsiCommandQueue {
733                    controller,
734                    mem,
735                    force_path_id,
736                }),
737                scsi_requests: FuturesUnordered::new(),
738                scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
739                channel_control,
740                max_io_queue_depth,
741                future_pool: Vec::new(),
742                full_request_pool: Vec::new(),
743                stats: Default::default(),
744            },
745            queue,
746            rescan_notification: target,
747            fast_select: FastSelect::new(),
748        })
749    }
750
751    async fn wait_for_scsi_requests_complete(&mut self) {
752        tracing::debug!(
753            channel_index = self.inner.channel_index,
754            "wait for IOs completed..."
755        );
756        while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
757            self.inner.scsi_requests_states.remove(id);
758        }
759    }
760}
761
762impl InspectTask<Worker> for WorkerState {
763    fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
764        if let Some(worker) = worker {
765            let mut resp = req.respond();
766            if worker.inner.channel_index == 0 {
767                let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
768                    ProtocolState::Init(state) => match state {
769                        InitState::Begin => ("begin_init", None, None),
770                        InitState::QueryVersion => ("query_version", None, None),
771                        InitState::QueryProperties { version } => {
772                            ("query_properties", Some(version), None)
773                        }
774                        InitState::EndInitialization {
775                            version,
776                            subchannel_count,
777                        } => ("end_init", Some(version), subchannel_count),
778                    },
779                    ProtocolState::Ready {
780                        version,
781                        subchannel_count,
782                    } => ("ready", Some(version), Some(subchannel_count)),
783                };
784                resp.field("state", state)
785                    .field("version", version)
786                    .field("subchannel_count", subchannel_count);
787            }
788            resp.field("pending_packets", worker.inner.scsi_requests_states.len())
789                .fields("io", worker.inner.scsi_requests_states.iter())
790                .field("stats", &worker.inner.stats)
791                .field("ring", &worker.queue)
792                .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
793        }
794    }
795}
796
797impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
798    async fn run(
799        &mut self,
800        stop: &mut StopTask<'_>,
801        worker: &mut Worker<T>,
802    ) -> Result<(), task_control::Cancelled> {
803        let fut = async {
804            if worker.inner.channel_index == 0 {
805                worker.process_primary().await
806            } else {
807                // Wait for initialization to end before processing any
808                // subchannel packets.
809                let protocol_version = loop {
810                    let listener = worker.inner.protocol.ready.listen();
811                    if let ProtocolState::Ready { version, .. } =
812                        *worker.inner.protocol.state.read()
813                    {
814                        break version;
815                    }
816                    tracing::debug!("subchannel waiting for initialization to end");
817                    listener.await
818                };
819                worker
820                    .inner
821                    .process_ready(&mut worker.queue, protocol_version)
822                    .await
823            }
824        };
825
826        match stop.until_stopped(fut).await? {
827            Ok(_) => {}
828            Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
829        }
830        Ok(())
831    }
832}
833
834impl WorkerInner {
835    /// Awaits the next incoming packet, without checking for any other events (device add/remove notifications or available completions).
836    /// Increments the count of outstanding packets when returning `Ok(Packet)`.
837    async fn next_packet<'a, M: RingMem>(
838        &mut self,
839        reader: &'a mut queue::ReadHalf<'a, M>,
840    ) -> Result<Packet, WorkerError> {
841        let packet = reader.read().await.map_err(WorkerError::Queue)?;
842        let stor_packet =
843            parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
844        Ok(stor_packet)
845    }
846
847    /// Polls for enough ring space in the outgoing ring to send a packet.
848    ///
849    /// This is used to ensure there is enough space in the ring before
850    /// committing to sending a packet. This avoids the need to save pending
851    /// packets on the side if queue processing is interrupted while the ring is
852    /// full.
853    fn poll_for_ring_space<M: RingMem>(
854        &mut self,
855        cx: &mut Context<'_>,
856        writer: &mut queue::WriteHalf<'_, M>,
857    ) -> Poll<Result<(), WorkerError>> {
858        writer
859            .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
860            .map_err(WorkerError::Queue)
861    }
862}
863
864const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
865    size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
866);
867
impl<T: RingMem> Worker<T> {
    /// Processes the protocol state machine.
    ///
    /// Used only for the primary channel (channel 0). While the shared
    /// protocol state is `Init`, this handles exactly one packet per iteration
    /// and writes the next `InitState` back to `protocol.state`. The sequence is
    /// Begin -> QueryVersion -> QueryProperties -> EndInitialization -> Ready.
    /// An out-of-order packet is completed with `INVALID_DEVICE_STATE` and
    /// leaves the state unchanged.
    ///
    /// Once `Ready`, packets are handled by `process_ready`, while rescan
    /// notifications are turned into `ENUMERATE_BUS` packets. `process_ready`
    /// never returns `Ok`, so in practice this only returns with an error.
    async fn process_primary(&mut self) -> Result<(), WorkerError> {
        loop {
            // Copy the state out so the lock is not held across awaits.
            let current_state = *self.inner.protocol.state.read();
            match current_state {
                ProtocolState::Ready { version, .. } => {
                    // `select_biased!` prefers packet processing; a rescan
                    // notification is reported to the guest with an
                    // ENUMERATE_BUS packet, but only for Win7+ protocols.
                    break loop {
                        select_biased! {
                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
                                if version >= Version::Win7
                                {
                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
                                }
                            }
                        }
                    };
                }
                ProtocolState::Init(state) => {
                    let (mut reader, mut writer) = self.queue.split();

                    // Ensure that subsequent calls to `send_completion` won't
                    // fail due to lack of ring space, to avoid keeping (and saving/restoring) interim states.
                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;

                    tracing::debug!(?state, "process_primary");
                    match state {
                        InitState::Begin => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::BeginInitialization = packet.data {
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &(),
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::QueryVersion);
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryVersion => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
                                if let Ok(version) = Version::parse(major_minor) {
                                    // Echo the accepted version back to the guest.
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    self.inner.request_size = version.max_request_size();
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryProperties { version });

                                    tracelimit::info_ratelimited!(
                                        ?version,
                                        "scsi version negotiated"
                                    );
                                } else {
                                    // Unsupported version: report the mismatch
                                    // and stay in QueryVersion so the guest can
                                    // retry with another version.
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryVersion);
                                }
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryProperties { version } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProperties = packet.data {
                                // Multi-channel is only advertised for Win8+
                                // protocol versions.
                                let multi_channel_supported = version >= Version::Win8;

                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &storvsp_protocol::ChannelProperties {
                                        max_transfer_bytes: 0x40000, // 256KB
                                        flags: {
                                            if multi_channel_supported {
                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
                                            } else {
                                                0
                                            }
                                        },
                                        maximum_sub_channel_count: if multi_channel_supported {
                                            self.inner.channel_control.max_subchannels()
                                        } else {
                                            0
                                        },
                                        reserved: 0,
                                        reserved2: 0,
                                        reserved3: [0, 0],
                                    },
                                )?;
                                // `None` means the guest may still request
                                // subchannels; `Some(0)` rules that out when
                                // multi-channel was not negotiated.
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::EndInitialization {
                                        version,
                                        subchannel_count: if multi_channel_supported {
                                            None
                                        } else {
                                            Some(0)
                                        },
                                    });
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            match packet.data {
                                // Subchannels may be created at most once during
                                // initialization, and only if multi-channel was
                                // negotiated (count still undecided).
                                PacketData::CreateSubChannels(sub_channel_count)
                                    if subchannel_count.is_none() =>
                                {
                                    if let Err(err) = self
                                        .inner
                                        .channel_control
                                        .enable_subchannels(sub_channel_count)
                                    {
                                        tracelimit::warn_ratelimited!(
                                            ?err,
                                            "cannot enable subchannels"
                                        );
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
                                            &(),
                                        )?;
                                    } else {
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::SUCCESS,
                                            &(),
                                        )?;
                                        *self.inner.protocol.state.write() =
                                            ProtocolState::Init(InitState::EndInitialization {
                                                version,
                                                subchannel_count: Some(sub_channel_count),
                                            });
                                    }
                                }
                                PacketData::EndInitialization => {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &(),
                                    )?;
                                    // Reset the rescan notification event now, before the guest has a
                                    // chance to send any SCSI requests to scan the bus.
                                    self.rescan_notification.try_next().ok();
                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
                                        version,
                                        subchannel_count: subchannel_count.unwrap_or(0),
                                    };
                                    // Wake up subchannels waiting for the
                                    // protocol state to become ready.
                                    self.inner.protocol.ready.notify(usize::MAX);
                                }
                                _ => {
                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                        &(),
                                    )?;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
1081
1082fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
1083    match srb_status {
1084        SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
1085        SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
1086        SrbStatus::INVALID_LUN
1087        | SrbStatus::INVALID_TARGET_ID
1088        | SrbStatus::NO_DEVICE
1089        | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
1090        SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
1091        SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
1092        SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
1093            storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
1094        }
1095        SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
1096        SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
1097        SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
1098        _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
1099    }
1100}
1101
1102impl WorkerInner {
1103    /// Processes packets and SCSI completions after protocol negotiation has finished.
1104    async fn process_ready<M: RingMem>(
1105        &mut self,
1106        queue: &mut Queue<M>,
1107        protocol_version: Version,
1108    ) -> Result<(), WorkerError> {
1109        self.request_size = protocol_version.max_request_size();
1110        poll_fn(|cx| self.poll_process_ready(cx, queue)).await
1111    }
1112
1113    /// Processes packets and SCSI completions after protocol negotiation has finished.
1114    fn poll_process_ready<M: RingMem>(
1115        &mut self,
1116        cx: &mut Context<'_>,
1117        queue: &mut Queue<M>,
1118    ) -> Poll<Result<(), WorkerError>> {
1119        self.stats.wakes.increment();
1120
1121        let (mut reader, mut writer) = queue.split();
1122        let mut total_completions = 0;
1123        let mut total_submissions = 0;
1124
1125        loop {
1126            // Drive IOs forward and collect completions.
1127            'outer: while !self.scsi_requests_states.is_empty() {
1128                {
1129                    let mut batch = writer.batched();
1130                    loop {
1131                        // Ensure there is room for the completion before consuming
1132                        // the IO so that we don't have to track completed IOs whose
1133                        // completions haven't been sent.
1134                        if !batch
1135                            .can_write(MAX_VMBUS_PACKET_SIZE)
1136                            .map_err(WorkerError::Queue)?
1137                        {
1138                            // This batch is full but there may still be more completions.
1139                            break;
1140                        }
1141                        if let Poll::Ready(Some((request_id, result, future))) =
1142                            self.scsi_requests.poll_next_unpin(cx)
1143                        {
1144                            self.future_pool.push(future);
1145                            self.handle_completion(&mut batch, request_id, result)?;
1146                            total_completions += 1;
1147                        } else {
1148                            tracing::trace!("out of completions");
1149                            break 'outer;
1150                        }
1151                    }
1152                }
1153
1154                // Wait for enough space for any completion packets.
1155                if self.poll_for_ring_space(cx, &mut writer).is_pending() {
1156                    tracing::trace!("out of ring space");
1157                    break;
1158                }
1159            }
1160
1161            let mut submissions = 0;
1162            // Process new requests.
1163            'outer: loop {
1164                if self.scsi_requests_states.len() >= self.max_io_queue_depth {
1165                    break;
1166                }
1167                let mut batch = if self.scsi_requests_states.is_empty() {
1168                    if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
1169                        batch.map_err(WorkerError::Queue)?
1170                    } else {
1171                        tracing::trace!("out of incoming packets");
1172                        break;
1173                    }
1174                } else {
1175                    match reader.try_read_batch() {
1176                        Ok(batch) => batch,
1177                        Err(queue::TryReadError::Empty) => {
1178                            tracing::trace!(
1179                                pending_io_count = self.scsi_requests_states.len(),
1180                                "out of incoming packets, keeping interrupts masked"
1181                            );
1182                            break;
1183                        }
1184                        Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
1185                    }
1186                };
1187
1188                let mut packets = batch.packets();
1189                loop {
1190                    if self.scsi_requests_states.len() >= self.max_io_queue_depth {
1191                        break 'outer;
1192                    }
1193                    // Wait for enough space for any completion packets that
1194                    // `handle_packet` may need to send, so that it isn't necessary
1195                    // to track pending completions.
1196                    if self.poll_for_ring_space(cx, &mut writer).is_pending() {
1197                        tracing::trace!("out of ring space");
1198                        break 'outer;
1199                    }
1200
1201                    let packet = if let Some(packet) = packets.next() {
1202                        packet.map_err(WorkerError::Queue)?
1203                    } else {
1204                        break;
1205                    };
1206
1207                    if self.handle_packet(&mut writer, &packet)? {
1208                        submissions += 1;
1209                    }
1210                }
1211            }
1212
1213            // Loop around to poll the IOs if any new IOs were submitted.
1214            if submissions == 0 {
1215                // No need to poll again.
1216                break;
1217            }
1218            total_submissions += submissions;
1219        }
1220
1221        if total_submissions != 0 || total_completions != 0 {
1222            self.stats.ios_submitted.add(total_submissions);
1223            self.stats
1224                .per_wake_submissions
1225                .add_sample(total_submissions);
1226            self.stats
1227                .per_wake_completions
1228                .add_sample(total_completions);
1229            self.stats.ios_completed.add(total_completions);
1230        } else {
1231            self.stats.wakes_spurious.increment();
1232        }
1233
1234        Poll::Pending
1235    }
1236
1237    fn handle_completion<M: RingMem>(
1238        &mut self,
1239        writer: &mut queue::WriteBatch<'_, M>,
1240        request_id: usize,
1241        result: ScsiResult,
1242    ) -> Result<(), WorkerError> {
1243        let state = self.scsi_requests_states.remove(request_id);
1244        let request_size = state.request.request_size;
1245
1246        // Push the request into the pool to avoid reallocating later.
1247        assert_eq!(
1248            Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
1249            1
1250        );
1251        self.full_request_pool.push(state.request);
1252
1253        let status = convert_srb_status_to_nt_status(result.srb_status);
1254        let mut payload = [0; 0x14];
1255        if let Some(sense) = result.sense_data {
1256            payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
1257            tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
1258        };
1259        let response = storvsp_protocol::ScsiRequest {
1260            length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
1261            scsi_status: result.scsi_status,
1262            srb_status: SrbStatusAndFlags::new()
1263                .with_status(result.srb_status)
1264                .with_autosense_valid(result.sense_data.is_some()),
1265            data_transfer_length: result.tx as u32,
1266            cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
1267            sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
1268            payload,
1269            ..storvsp_protocol::ScsiRequest::new_zeroed()
1270        };
1271        self.send_vmbus_packet(
1272            writer,
1273            OutgoingPacketType::Completion,
1274            request_size,
1275            state.transaction_id,
1276            storvsp_protocol::Operation::COMPLETE_IO,
1277            status,
1278            response.as_bytes(),
1279        )?;
1280        Ok(())
1281    }
1282
1283    fn handle_packet<M: RingMem>(
1284        &mut self,
1285        writer: &mut queue::WriteHalf<'_, M>,
1286        packet: &IncomingPacket<'_, M>,
1287    ) -> Result<bool, WorkerError> {
1288        let packet =
1289            parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
1290        let submitted_io = match packet.data {
1291            PacketData::ExecuteScsi(request) => {
1292                self.push_scsi_request(packet.transaction_id, request);
1293                true
1294            }
1295            PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
1296                // These operations have always been no-ops.
1297                self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
1298                false
1299            }
1300            PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
1301                if let Err(err) = self
1302                    .channel_control
1303                    .enable_subchannels(new_subchannel_count)
1304                {
1305                    tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
1306                    self.send_completion(
1307                        writer,
1308                        &packet,
1309                        storvsp_protocol::NtStatus::INVALID_PARAMETER,
1310                        &(),
1311                    )?;
1312                    false
1313                } else {
1314                    // Update the subchannel count in the protocol state for save.
1315                    if let ProtocolState::Ready {
1316                        subchannel_count, ..
1317                    } = &mut *self.protocol.state.write()
1318                    {
1319                        *subchannel_count = new_subchannel_count;
1320                    } else {
1321                        unreachable!()
1322                    }
1323
1324                    self.send_completion(
1325                        writer,
1326                        &packet,
1327                        storvsp_protocol::NtStatus::SUCCESS,
1328                        &(),
1329                    )?;
1330                    false
1331                }
1332            }
1333            _ => {
1334                tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
1335                self.send_completion(
1336                    writer,
1337                    &packet,
1338                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
1339                    &(),
1340                )?;
1341                false
1342            }
1343        };
1344        Ok(submitted_io)
1345    }
1346
1347    fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
1348        let scsi_queue = self.scsi_queue.clone();
1349        let scsi_request_state = ScsiRequestState {
1350            transaction_id,
1351            request: full_request.clone(),
1352        };
1353        let request_id = self.scsi_requests_states.insert(scsi_request_state);
1354        let future = self
1355            .future_pool
1356            .pop()
1357            .unwrap_or_else(|| OversizedBox::new(()));
1358        let future = OversizedBox::refill(future, async move {
1359            scsi_queue
1360                .execute_scsi(&full_request.external_data, &full_request.request)
1361                .await
1362        });
1363        let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
1364        self.scsi_requests.push(request);
1365    }
1366}
1367
1368impl<T: RingMem> Drop for Worker<T> {
1369    fn drop(&mut self) {
1370        self.inner
1371            .controller
1372            .remove_rescan_notification_source(&self.rescan_notification);
1373    }
1374}
1375
1376#[derive(Debug, Error)]
1377#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
1378pub struct ScsiPathInUse(pub ScsiPath);
1379
1380#[derive(Debug, Error)]
1381#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
1382pub struct ScsiPathNotInUse(ScsiPath);
1383
1384#[derive(Clone)]
1385struct ScsiRequestState {
1386    transaction_id: u64,
1387    request: Arc<ScsiRequestAndRange>,
1388}
1389
1390#[derive(Debug)]
1391struct ScsiRequestAndRange {
1392    external_data: Range,
1393    request: storvsp_protocol::ScsiRequest,
1394    request_size: usize,
1395}
1396
1397impl Inspect for ScsiRequestState {
1398    fn inspect(&self, req: inspect::Request<'_>) {
1399        req.respond()
1400            .field("transaction_id", self.transaction_id)
1401            .display(
1402                "address",
1403                &ScsiPath {
1404                    path: self.request.request.path_id,
1405                    target: self.request.request.target_id,
1406                    lun: self.request.request.lun,
1407                },
1408            )
1409            .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1410    }
1411}
1412
1413impl StorageDevice {
1414    /// Returns a new SCSI device.
1415    pub fn build_scsi(
1416        driver_source: &VmTaskDriverSource,
1417        controller: &ScsiController,
1418        instance_id: Guid,
1419        max_sub_channel_count: u16,
1420        io_queue_depth: u32,
1421    ) -> Self {
1422        Self::build_inner(
1423            driver_source,
1424            controller,
1425            instance_id,
1426            None,
1427            max_sub_channel_count,
1428            io_queue_depth,
1429        )
1430    }
1431
1432    /// Returns a new SCSI device for implementing an IDE accelerator channel
1433    /// for IDE device `device_id` on channel `channel_id`.
1434    pub fn build_ide(
1435        driver_source: &VmTaskDriverSource,
1436        channel_id: u8,
1437        device_id: u8,
1438        disk: ScsiControllerDisk,
1439        io_queue_depth: u32,
1440    ) -> Self {
1441        let path = ScsiPath {
1442            path: channel_id,
1443            target: device_id,
1444            lun: 0,
1445        };
1446
1447        let controller = ScsiController::new();
1448        controller.attach(path, disk).unwrap();
1449
1450        // Construct the specific GUID that drivers in the guest expect for this
1451        // IDE device.
1452        let instance_id = Guid {
1453            data1: channel_id.into(),
1454            data2: device_id.into(),
1455            data3: 0x8899,
1456            data4: [0; 8],
1457        };
1458        Self::build_inner(
1459            driver_source,
1460            &controller,
1461            instance_id,
1462            Some(path),
1463            0,
1464            io_queue_depth,
1465        )
1466    }
1467
1468    fn build_inner(
1469        driver_source: &VmTaskDriverSource,
1470        controller: &ScsiController,
1471        instance_id: Guid,
1472        ide_path: Option<ScsiPath>,
1473        max_sub_channel_count: u16,
1474        io_queue_depth: u32,
1475    ) -> Self {
1476        let workers = (0..max_sub_channel_count + 1)
1477            .map(|channel_index| WorkerAndDriver {
1478                worker: TaskControl::new(WorkerState),
1479                driver: driver_source
1480                    .builder()
1481                    .target_vp(0)
1482                    .run_on_target(true)
1483                    .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1484            })
1485            .collect();
1486
1487        Self {
1488            instance_id,
1489            ide_path,
1490            workers,
1491            controller: controller.state.clone(),
1492            resources: Default::default(),
1493            max_sub_channel_count,
1494            driver_source: driver_source.clone(),
1495            protocol: Arc::new(Protocol {
1496                state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1497                ready: Default::default(),
1498            }),
1499            io_queue_depth,
1500        }
1501    }
1502
1503    fn new_worker(
1504        &mut self,
1505        open_request: &OpenRequest,
1506        channel_index: u16,
1507    ) -> anyhow::Result<&mut Worker> {
1508        let controller = self.controller.clone();
1509
1510        let driver = self
1511            .driver_source
1512            .builder()
1513            .target_vp(open_request.open_data.target_vp)
1514            .run_on_target(true)
1515            .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1516
1517        let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1518            .context("failed to create vmbus channel")?;
1519
1520        let channel_control = self.resources.channel_control.clone();
1521
1522        tracing::debug!(
1523            target_vp = open_request.open_data.target_vp,
1524            channel_index,
1525            "packet processing starting...",
1526        );
1527
1528        // Force the path ID on incoming SCSI requests to match the IDE
1529        // channel ID, since guests do not reliably set the path ID
1530        // correctly.
1531        let force_path_id = self.ide_path.map(|p| p.path);
1532
1533        let worker = Worker::new(
1534            controller,
1535            channel,
1536            channel_index,
1537            self.resources
1538                .offer_resources
1539                .guest_memory(open_request)
1540                .clone(),
1541            channel_control,
1542            self.io_queue_depth,
1543            self.protocol.clone(),
1544            force_path_id,
1545        )
1546        .map_err(RestoreError::Other)?;
1547
1548        self.workers[channel_index as usize]
1549            .driver
1550            .retarget_vp(open_request.open_data.target_vp);
1551
1552        Ok(self.workers[channel_index as usize].worker.insert(
1553            &driver,
1554            format!("storvsp worker {}-{}", self.instance_id, channel_index),
1555            worker,
1556        ))
1557    }
1558}
1559
1560/// A disk that can be added to a SCSI controller.
1561#[derive(Clone)]
1562pub struct ScsiControllerDisk {
1563    disk: Arc<dyn AsyncScsiDisk>,
1564}
1565
1566impl ScsiControllerDisk {
1567    /// Creates a new controller disk from an async SCSI disk.
1568    pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1569        Self { disk }
1570    }
1571}
1572
/// State shared between a [`ScsiController`] and the storage devices built
/// from it.
struct ScsiControllerState {
    // Disks currently attached to the controller, keyed by SCSI path.
    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
    // Senders that are notified (best effort) whenever a disk is attached or
    // removed.
    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
}

/// Handle used to attach disks to, and remove disks from, a SCSI controller.
pub struct ScsiController {
    // Reference-counted so the same state can be handed to storage devices.
    state: Arc<ScsiControllerState>,
}
1581
1582impl ScsiController {
1583    pub fn new() -> Self {
1584        Self {
1585            state: Arc::new(ScsiControllerState {
1586                disks: Default::default(),
1587                rescan_notification_source: Mutex::new(Vec::new()),
1588            }),
1589        }
1590    }
1591
1592    pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1593        match self.state.disks.write().entry(path) {
1594            Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1595            Entry::Vacant(entry) => entry.insert(disk),
1596        };
1597        for source in self.state.rescan_notification_source.lock().iter_mut() {
1598            // Ok to ignore errors here. If the channel is full a previous notification has not yet
1599            // been processed by the primary channel worker.
1600            source.try_send(()).ok();
1601        }
1602        Ok(())
1603    }
1604
1605    pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1606        match self.state.disks.write().entry(path) {
1607            Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1608            Entry::Occupied(entry) => {
1609                entry.remove();
1610            }
1611        }
1612        for source in self.state.rescan_notification_source.lock().iter_mut() {
1613            // Ok to ignore errors here. If the channel is full a previous notification has not yet
1614            // been processed by the primary channel worker.
1615            source.try_send(()).ok();
1616        }
1617        Ok(())
1618    }
1619}
1620
1621impl ScsiControllerState {
1622    fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1623        self.rescan_notification_source.lock().push(source);
1624    }
1625
1626    fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1627        let mut sources = self.rescan_notification_source.lock();
1628        if let Some(index) = sources
1629            .iter()
1630            .position(|source| source.is_connected_to(target))
1631        {
1632            sources.remove(index);
1633        }
1634    }
1635}
1636
1637#[async_trait]
1638impl VmbusDevice for StorageDevice {
1639    fn offer(&self) -> OfferParams {
1640        if let Some(path) = self.ide_path {
1641            let offer_properties = storvsp_protocol::OfferProperties {
1642                path_id: path.path,
1643                target_id: path.target,
1644                flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
1645                ..FromZeros::new_zeroed()
1646            };
1647            let mut user_defined = UserDefinedData::new_zeroed();
1648            offer_properties
1649                .write_to_prefix(&mut user_defined[..])
1650                .unwrap();
1651            OfferParams {
1652                interface_name: "ide-accel".to_owned(),
1653                instance_id: self.instance_id,
1654                interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
1655                channel_type: ChannelType::Interface { user_defined },
1656                ..Default::default()
1657            }
1658        } else {
1659            OfferParams {
1660                interface_name: "scsi".to_owned(),
1661                instance_id: self.instance_id,
1662                interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
1663                ..Default::default()
1664            }
1665        }
1666    }
1667
1668    fn max_subchannels(&self) -> u16 {
1669        self.max_sub_channel_count
1670    }
1671
1672    fn install(&mut self, resources: DeviceResources) {
1673        self.resources = resources;
1674    }
1675
1676    async fn open(
1677        &mut self,
1678        channel_index: u16,
1679        open_request: &OpenRequest,
1680    ) -> Result<(), ChannelOpenError> {
1681        tracing::debug!(channel_index, "scsi open channel");
1682        self.new_worker(open_request, channel_index)?;
1683        self.workers[channel_index as usize].worker.start();
1684        Ok(())
1685    }
1686
1687    async fn close(&mut self, channel_index: u16) {
1688        tracing::debug!(channel_index, "scsi close channel");
1689        let worker = &mut self.workers[channel_index as usize].worker;
1690        worker.stop().await;
1691        if worker.state_mut().is_some() {
1692            worker
1693                .state_mut()
1694                .unwrap()
1695                .wait_for_scsi_requests_complete()
1696                .await;
1697            worker.remove();
1698        }
1699        if channel_index == 0 {
1700            *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
1701        }
1702    }
1703
1704    async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
1705        self.workers[channel_index as usize]
1706            .driver
1707            .retarget_vp(target_vp);
1708    }
1709
1710    fn start(&mut self) {
1711        for task in self
1712            .workers
1713            .iter_mut()
1714            .filter(|task| task.worker.has_state() && !task.worker.is_running())
1715        {
1716            task.worker.start();
1717        }
1718    }
1719
1720    async fn stop(&mut self) {
1721        tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
1722        for task in self
1723            .workers
1724            .iter_mut()
1725            .filter(|task| task.worker.has_state() && task.worker.is_running())
1726        {
1727            task.worker.stop().await;
1728        }
1729    }
1730
1731    fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1732        Some(self)
1733    }
1734}
1735
1736#[async_trait]
1737impl SaveRestoreVmbusDevice for StorageDevice {
1738    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1739        Ok(SavedStateBlob::new(self.save()?))
1740    }
1741
1742    async fn restore(
1743        &mut self,
1744        control: RestoreControl<'_>,
1745        state: SavedStateBlob,
1746    ) -> Result<(), RestoreError> {
1747        self.restore(control, state.parse()?).await
1748    }
1749}
1750
1751#[cfg(test)]
1752mod tests {
1753    use super::*;
1754    use crate::test_helpers::TestWorker;
1755    use crate::test_helpers::parse_guest_completion;
1756    use crate::test_helpers::parse_guest_completion_check_flags_status;
1757    use pal_async::DefaultDriver;
1758    use pal_async::async_test;
1759    use scsi::srb::SrbStatus;
1760    use test_with_tracing::test;
1761    use vmbus_channel::connected_async_channels;
1762
1763    // Discourage `Clone` for `ScsiController` outside the crate, but it is
1764    // necessary for testing. The fuzzer also uses `TestWorker`, which needs
1765    // a `clone` of the inner state, but is not in this crate.
1766    impl Clone for ScsiController {
1767        fn clone(&self) -> Self {
1768            ScsiController {
1769                state: self.state.clone(),
1770            }
1771        }
1772    }
1773
1774    #[async_test]
1775    async fn test_channel_working(driver: DefaultDriver) {
1776        // set up the channels and worker
1777        let (host, guest) = connected_async_channels(16 * 1024);
1778        let guest_queue = Queue::new(guest).unwrap();
1779
1780        let test_guest_mem = GuestMemory::allocate(16384);
1781        let controller = ScsiController::new();
1782        let disk = scsidisk::SimpleScsiDisk::new(
1783            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
1784            Default::default(),
1785        );
1786        controller
1787            .attach(
1788                ScsiPath {
1789                    path: 0,
1790                    target: 0,
1791                    lun: 0,
1792                },
1793                ScsiControllerDisk::new(Arc::new(disk)),
1794            )
1795            .unwrap();
1796
1797        let test_worker = TestWorker::start(
1798            controller.clone(),
1799            driver.clone(),
1800            test_guest_mem.clone(),
1801            host,
1802            None,
1803        );
1804
1805        let mut guest = test_helpers::TestGuest {
1806            queue: guest_queue,
1807            transaction_id: 0,
1808        };
1809
1810        guest.perform_protocol_negotiation().await;
1811
1812        // Set up the buffer for a write request
1813        const IO_LEN: usize = 4 * 1024;
1814        let write_buf = [7u8; IO_LEN];
1815        let write_gpa = 4 * 1024u64;
1816        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
1817        guest
1818            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
1819            .await;
1820
1821        guest
1822            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
1823            .await;
1824
1825        let read_gpa = 8 * 1024u64;
1826        guest
1827            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
1828            .await;
1829
1830        guest
1831            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
1832            .await;
1833        let mut read_buf = [0u8; IO_LEN];
1834        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
1835        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
1836            assert_eq!(b1, b2);
1837        }
1838
1839        // stop everything
1840        guest.verify_graceful_close(test_worker).await;
1841    }
1842
1843    #[async_test]
1844    async fn test_packet_sizes(driver: DefaultDriver) {
1845        // set up the channels and worker
1846        let (host, guest) = connected_async_channels(16384);
1847        let guest_queue = Queue::new(guest).unwrap();
1848
1849        let test_guest_mem = GuestMemory::allocate(1024);
1850        let controller = ScsiController::new();
1851
1852        let _worker = TestWorker::start(
1853            controller.clone(),
1854            driver.clone(),
1855            test_guest_mem,
1856            host,
1857            None,
1858        );
1859
1860        let mut guest = test_helpers::TestGuest {
1861            queue: guest_queue,
1862            transaction_id: 0,
1863        };
1864
1865        let negotiate_packet = storvsp_protocol::Packet {
1866            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
1867            flags: 0,
1868            status: storvsp_protocol::NtStatus::SUCCESS,
1869        };
1870        guest
1871            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
1872            .await;
1873
1874        guest.verify_completion(parse_guest_completion).await;
1875
1876        let header = storvsp_protocol::Packet {
1877            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
1878            flags: 0,
1879            status: storvsp_protocol::NtStatus::SUCCESS,
1880        };
1881
1882        let mut buf = [0u8; 128];
1883        storvsp_protocol::ProtocolVersion {
1884            major_minor: !0,
1885            reserved: 0,
1886        }
1887        .write_to_prefix(&mut buf[..])
1888        .unwrap(); // PANIC: Infallable since `ProtcolVersion` is less than 128 bytes
1889
1890        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
1891            guest
1892                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
1893                .await;
1894
1895            guest
1896                .verify_completion(|packet| {
1897                    let IncomingPacket::Completion(packet) = packet else {
1898                        unreachable!()
1899                    };
1900                    assert_eq!(packet.reader().len(), resp_len);
1901                    assert_eq!(
1902                        packet
1903                            .reader()
1904                            .read_plain::<storvsp_protocol::Packet>()
1905                            .unwrap()
1906                            .status,
1907                        storvsp_protocol::NtStatus::REVISION_MISMATCH
1908                    );
1909                    Ok(())
1910                })
1911                .await;
1912        }
1913    }
1914
1915    #[async_test]
1916    async fn test_wrong_first_packet(driver: DefaultDriver) {
1917        // set up the channels and worker
1918        let (host, guest) = connected_async_channels(16384);
1919        let guest_queue = Queue::new(guest).unwrap();
1920
1921        let test_guest_mem = GuestMemory::allocate(1024);
1922        let controller = ScsiController::new();
1923
1924        let _worker = TestWorker::start(
1925            controller.clone(),
1926            driver.clone(),
1927            test_guest_mem,
1928            host,
1929            None,
1930        );
1931
1932        let mut guest = test_helpers::TestGuest {
1933            queue: guest_queue,
1934            transaction_id: 0,
1935        };
1936
1937        // Protocol negotiation done out of order
1938        let negotiate_packet = storvsp_protocol::Packet {
1939            operation: storvsp_protocol::Operation::END_INITIALIZATION,
1940            flags: 0,
1941            status: storvsp_protocol::NtStatus::SUCCESS,
1942        };
1943        guest
1944            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
1945            .await;
1946
1947        guest
1948            .verify_completion(|packet| {
1949                let IncomingPacket::Completion(packet) = packet else {
1950                    unreachable!()
1951                };
1952                assert_eq!(
1953                    packet
1954                        .reader()
1955                        .read_plain::<storvsp_protocol::Packet>()
1956                        .unwrap()
1957                        .status,
1958                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
1959                );
1960                Ok(())
1961            })
1962            .await;
1963    }
1964
1965    #[async_test]
1966    async fn test_unrecognized_operation(driver: DefaultDriver) {
1967        // set up the channels and worker
1968        let (host, guest) = connected_async_channels(16384);
1969        let guest_queue = Queue::new(guest).unwrap();
1970
1971        let test_guest_mem = GuestMemory::allocate(1024);
1972        let controller = ScsiController::new();
1973
1974        let worker = TestWorker::start(
1975            controller.clone(),
1976            driver.clone(),
1977            test_guest_mem,
1978            host,
1979            None,
1980        );
1981
1982        let mut guest = test_helpers::TestGuest {
1983            queue: guest_queue,
1984            transaction_id: 0,
1985        };
1986
1987        // Send packet with unrecognized operation
1988        let negotiate_packet = storvsp_protocol::Packet {
1989            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
1990            flags: 0,
1991            status: storvsp_protocol::NtStatus::SUCCESS,
1992        };
1993        guest
1994            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
1995            .await;
1996
1997        match worker.teardown().await {
1998            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
1999                storvsp_protocol::Operation::REMOVE_DEVICE,
2000            ))) => {}
2001            result => panic!("Worker failed with unexpected result {:?}!", result),
2002        }
2003    }
2004
2005    #[async_test]
2006    async fn test_too_many_subchannels(driver: DefaultDriver) {
2007        // set up the channels and worker
2008        let (host, guest) = connected_async_channels(16384);
2009        let guest_queue = Queue::new(guest).unwrap();
2010
2011        let test_guest_mem = GuestMemory::allocate(1024);
2012        let controller = ScsiController::new();
2013
2014        let _worker = TestWorker::start(
2015            controller.clone(),
2016            driver.clone(),
2017            test_guest_mem,
2018            host,
2019            None,
2020        );
2021
2022        let mut guest = test_helpers::TestGuest {
2023            queue: guest_queue,
2024            transaction_id: 0,
2025        };
2026
2027        let negotiate_packet = storvsp_protocol::Packet {
2028            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2029            flags: 0,
2030            status: storvsp_protocol::NtStatus::SUCCESS,
2031        };
2032        guest
2033            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2034            .await;
2035        guest.verify_completion(parse_guest_completion).await;
2036
2037        let version_packet = storvsp_protocol::Packet {
2038            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
2039            flags: 0,
2040            status: storvsp_protocol::NtStatus::SUCCESS,
2041        };
2042        let version = storvsp_protocol::ProtocolVersion {
2043            major_minor: storvsp_protocol::VERSION_BLUE,
2044            reserved: 0,
2045        };
2046        guest
2047            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
2048            .await;
2049        guest.verify_completion(parse_guest_completion).await;
2050
2051        let properties_packet = storvsp_protocol::Packet {
2052            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
2053            flags: 0,
2054            status: storvsp_protocol::NtStatus::SUCCESS,
2055        };
2056        guest
2057            .send_data_packet_sync(&[properties_packet.as_bytes()])
2058            .await;
2059
2060        guest.verify_completion(parse_guest_completion).await;
2061
2062        let negotiate_packet = storvsp_protocol::Packet {
2063            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
2064            flags: 0,
2065            status: storvsp_protocol::NtStatus::SUCCESS,
2066        };
2067        // Create sub channels more than maximum_sub_channel_count
2068        guest
2069            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
2070            .await;
2071
2072        guest
2073            .verify_completion(|packet| {
2074                let IncomingPacket::Completion(packet) = packet else {
2075                    unreachable!()
2076                };
2077                assert_eq!(
2078                    packet
2079                        .reader()
2080                        .read_plain::<storvsp_protocol::Packet>()
2081                        .unwrap()
2082                        .status,
2083                    storvsp_protocol::NtStatus::INVALID_PARAMETER
2084                );
2085                Ok(())
2086            })
2087            .await;
2088    }
2089
2090    #[async_test]
2091    async fn test_begin_init_on_ready(driver: DefaultDriver) {
2092        // set up the channels and worker
2093        let (host, guest) = connected_async_channels(16384);
2094        let guest_queue = Queue::new(guest).unwrap();
2095
2096        let test_guest_mem = GuestMemory::allocate(1024);
2097        let controller = ScsiController::new();
2098
2099        let _worker = TestWorker::start(
2100            controller.clone(),
2101            driver.clone(),
2102            test_guest_mem,
2103            host,
2104            None,
2105        );
2106
2107        let mut guest = test_helpers::TestGuest {
2108            queue: guest_queue,
2109            transaction_id: 0,
2110        };
2111
2112        guest.perform_protocol_negotiation().await;
2113
2114        // Protocol negotiation done out of order
2115        let negotiate_packet = storvsp_protocol::Packet {
2116            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2117            flags: 0,
2118            status: storvsp_protocol::NtStatus::SUCCESS,
2119        };
2120        guest
2121            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2122            .await;
2123
2124        guest
2125            .verify_completion(|p| {
2126                parse_guest_completion_check_flags_status(
2127                    p,
2128                    0,
2129                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
2130                )
2131            })
2132            .await;
2133    }
2134
2135    #[async_test]
2136    async fn test_hot_add_remove(driver: DefaultDriver) {
2137        // set up channels and worker.
2138        let (host, guest) = connected_async_channels(16 * 1024);
2139        let guest_queue = Queue::new(guest).unwrap();
2140
2141        let test_guest_mem = GuestMemory::allocate(16384);
2142        // create a controller with no disk yet.
2143        let controller = ScsiController::new();
2144
2145        let test_worker = TestWorker::start(
2146            controller.clone(),
2147            driver.clone(),
2148            test_guest_mem.clone(),
2149            host,
2150            None,
2151        );
2152
2153        let mut guest = test_helpers::TestGuest {
2154            queue: guest_queue,
2155            transaction_id: 0,
2156        };
2157
2158        guest.perform_protocol_negotiation().await;
2159
2160        // Verify no LUNs are reported initially.
2161        let mut lun_list_buffer: [u8; 256] = [0; 256];
2162        let mut disk_count = 0;
2163        guest
2164            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
2165            .await;
2166        guest
2167            .verify_completion(|p| {
2168                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
2169            })
2170            .await;
2171        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
2172        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
2173        assert_eq!(lun_list_size, disk_count as u32 * 8);
2174
2175        // Set up a buffer for writes.
2176        const IO_LEN: usize = 4 * 1024;
2177        let write_buf = [7u8; IO_LEN];
2178        let write_gpa = 4 * 1024u64;
2179        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
2180
2181        guest
2182            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
2183            .await;
2184        guest
2185            .verify_completion(|p| {
2186                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
2187            })
2188            .await;
2189
2190        // Add some disks while the guest is running.
2191        for lun in 0..4 {
2192            let disk = scsidisk::SimpleScsiDisk::new(
2193                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
2194                Default::default(),
2195            );
2196            controller
2197                .attach(
2198                    ScsiPath {
2199                        path: 0,
2200                        target: 0,
2201                        lun,
2202                    },
2203                    ScsiControllerDisk::new(Arc::new(disk)),
2204                )
2205                .unwrap();
2206            guest
2207                .verify_completion(test_helpers::parse_guest_enumerate_bus)
2208                .await;
2209
2210            disk_count += 1;
2211            guest
2212                .send_report_luns_packet(ScsiPath::default(), 0, 256)
2213                .await;
2214            guest
2215                .verify_completion(|p| {
2216                    test_helpers::parse_guest_completed_io_check_tx_len(
2217                        p,
2218                        SrbStatus::SUCCESS,
2219                        Some((disk_count + 1) * 8),
2220                    )
2221                })
2222                .await;
2223            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
2224            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
2225            assert_eq!(lun_list_size, disk_count as u32 * 8);
2226
2227            guest
2228                .send_write_packet(
2229                    ScsiPath {
2230                        path: 0,
2231                        target: 0,
2232                        lun,
2233                    },
2234                    write_gpa,
2235                    1,
2236                    IO_LEN,
2237                )
2238                .await;
2239            guest
2240                .verify_completion(|p| {
2241                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
2242                })
2243                .await;
2244        }
2245
2246        // Remove all disks while the guest is running.
2247        for lun in 0..4 {
2248            controller
2249                .remove(ScsiPath {
2250                    path: 0,
2251                    target: 0,
2252                    lun,
2253                })
2254                .unwrap();
2255            guest
2256                .verify_completion(test_helpers::parse_guest_enumerate_bus)
2257                .await;
2258
2259            disk_count -= 1;
2260            guest
2261                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
2262                .await;
2263            guest
2264                .verify_completion(|p| {
2265                    test_helpers::parse_guest_completed_io_check_tx_len(
2266                        p,
2267                        SrbStatus::SUCCESS,
2268                        Some((disk_count + 1) * 8),
2269                    )
2270                })
2271                .await;
2272            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
2273            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
2274            assert_eq!(lun_list_size, disk_count as u32 * 8);
2275
2276            guest
2277                .send_write_packet(
2278                    ScsiPath {
2279                        path: 0,
2280                        target: 0,
2281                        lun,
2282                    },
2283                    write_gpa,
2284                    1,
2285                    IO_LEN,
2286                )
2287                .await;
2288            guest
2289                .verify_completion(|p| {
2290                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
2291                })
2292                .await;
2293        }
2294
2295        guest.verify_graceful_close(test_worker).await;
2296    }
2297
2298    #[async_test]
2299    pub async fn test_async_disk(driver: DefaultDriver) {
2300        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
2301        let controller = ScsiController::new();
2302        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
2303            device,
2304            Default::default(),
2305        )));
2306        controller
2307            .attach(
2308                ScsiPath {
2309                    path: 0,
2310                    target: 0,
2311                    lun: 0,
2312                },
2313                disk,
2314            )
2315            .unwrap();
2316
2317        let (host, guest) = connected_async_channels(16 * 1024);
2318        let guest_queue = Queue::new(guest).unwrap();
2319
2320        let mut guest = test_helpers::TestGuest {
2321            queue: guest_queue,
2322            transaction_id: 0,
2323        };
2324
2325        let test_guest_mem = GuestMemory::allocate(16384);
2326        let worker = TestWorker::start(
2327            controller.clone(),
2328            &driver,
2329            test_guest_mem.clone(),
2330            host,
2331            None,
2332        );
2333
2334        let negotiate_packet = storvsp_protocol::Packet {
2335            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
2336            flags: 0,
2337            status: storvsp_protocol::NtStatus::SUCCESS,
2338        };
2339        guest
2340            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2341            .await;
2342        guest.verify_completion(parse_guest_completion).await;
2343
2344        let version_packet = storvsp_protocol::Packet {
2345            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
2346            flags: 0,
2347            status: storvsp_protocol::NtStatus::SUCCESS,
2348        };
2349        let version = storvsp_protocol::ProtocolVersion {
2350            major_minor: storvsp_protocol::VERSION_BLUE,
2351            reserved: 0,
2352        };
2353        guest
2354            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
2355            .await;
2356        guest.verify_completion(parse_guest_completion).await;
2357
2358        let properties_packet = storvsp_protocol::Packet {
2359            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
2360            flags: 0,
2361            status: storvsp_protocol::NtStatus::SUCCESS,
2362        };
2363        guest
2364            .send_data_packet_sync(&[properties_packet.as_bytes()])
2365            .await;
2366        guest.verify_completion(parse_guest_completion).await;
2367
2368        let negotiate_packet = storvsp_protocol::Packet {
2369            operation: storvsp_protocol::Operation::END_INITIALIZATION,
2370            flags: 0,
2371            status: storvsp_protocol::NtStatus::SUCCESS,
2372        };
2373        guest
2374            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
2375            .await;
2376        guest.verify_completion(parse_guest_completion).await;
2377
2378        const IO_LEN: usize = 4 * 1024;
2379        let write_buf = [7u8; IO_LEN];
2380        let write_gpa = 4 * 1024u64;
2381        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
2382        guest
2383            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
2384            .await;
2385        guest
2386            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
2387            .await;
2388
2389        let read_gpa = 8 * 1024u64;
2390        guest
2391            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
2392            .await;
2393        guest
2394            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
2395            .await;
2396        let mut read_buf = [0u8; IO_LEN];
2397        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
2398        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
2399            assert_eq!(b1, b2);
2400        }
2401
2402        guest.verify_graceful_close(worker).await;
2403    }
2404}