1#![expect(missing_docs)]
45#![forbid(unsafe_code)]
46
47#[cfg(feature = "ioperf")]
48pub mod ioperf;
49
50#[cfg(feature = "test")]
51pub mod test_helpers;
52
53#[cfg(not(feature = "test"))]
54mod test_helpers;
55
56pub mod resolver;
57mod save_restore;
58
59use crate::ring::gparange::MultiPagedRangeBuf;
60use anyhow::Context as _;
61use async_trait::async_trait;
62use fast_select::FastSelect;
63use futures::FutureExt;
64use futures::StreamExt;
65use futures::select_biased;
66use guestmem::AccessError;
67use guestmem::GuestMemory;
68use guestmem::MemoryRead;
69use guestmem::MemoryWrite;
70use guestmem::ranges::PagedRange;
71use guid::Guid;
72use inspect::Inspect;
73use inspect::InspectMut;
74use inspect_counters::Counter;
75use inspect_counters::Histogram;
76use oversized_box::OversizedBox;
77use parking_lot::Mutex;
78use parking_lot::RwLock;
79use ring::OutgoingPacketType;
80use scsi::AdditionalSenseCode;
81use scsi::ScsiOp;
82use scsi::ScsiStatus;
83use scsi::srb::SrbStatus;
84use scsi::srb::SrbStatusAndFlags;
85use scsi_buffers::RequestBuffers;
86use scsi_core::AsyncScsiDisk;
87use scsi_core::Request;
88use scsi_core::ScsiResult;
89use scsi_defs as scsi;
90use scsidisk::illegal_request_sense;
91use slab::Slab;
92use std::collections::hash_map::Entry;
93use std::collections::hash_map::HashMap;
94use std::fmt::Debug;
95use std::future::Future;
96use std::future::poll_fn;
97use std::pin::Pin;
98use std::sync::Arc;
99use std::sync::atomic::AtomicU32;
100use std::sync::atomic::Ordering::Relaxed;
101use std::task::Context;
102use std::task::Poll;
103use storvsp_resources::ScsiPath;
104use task_control::AsyncRun;
105use task_control::InspectTask;
106use task_control::StopTask;
107use task_control::TaskControl;
108use thiserror::Error;
109use tracing_helpers::ErrorValueExt;
110use unicycle::FuturesUnordered;
111use vmbus_async::queue;
112use vmbus_async::queue::ExternalDataError;
113use vmbus_async::queue::IncomingPacket;
114use vmbus_async::queue::OutgoingPacket;
115use vmbus_async::queue::Queue;
116use vmbus_channel::RawAsyncChannel;
117use vmbus_channel::bus::ChannelType;
118use vmbus_channel::bus::OfferParams;
119use vmbus_channel::bus::OpenRequest;
120use vmbus_channel::channel::ChannelControl;
121use vmbus_channel::channel::ChannelOpenError;
122use vmbus_channel::channel::DeviceResources;
123use vmbus_channel::channel::RestoreControl;
124use vmbus_channel::channel::SaveRestoreVmbusDevice;
125use vmbus_channel::channel::VmbusDevice;
126use vmbus_channel::gpadl_ring::GpadlRingMem;
127use vmbus_channel::gpadl_ring::gpadl_channel;
128use vmbus_core::protocol::UserDefinedData;
129use vmbus_ring as ring;
130use vmbus_ring::RingMem;
131use vmcore::save_restore::RestoreError;
132use vmcore::save_restore::SaveError;
133use vmcore::save_restore::SavedStateBlob;
134use vmcore::vm_task::VmTaskDriver;
135use vmcore::vm_task::VmTaskDriverSource;
136use zerocopy::FromBytes;
137use zerocopy::FromZeros;
138use zerocopy::Immutable;
139use zerocopy::IntoBytes;
140use zerocopy::KnownLayout;
141
/// Default poll-mode queue depth.
///
/// NOTE(review): not referenced in this part of the file; presumably the
/// initial value of `ScsiControllerState::poll_mode_queue_depth` — confirm.
const DEFAULT_POLL_MODE_QUEUE_DEPTH: u32 = 1;

/// A storvsp SCSI controller device exposed over VMBus.
pub struct StorageDevice {
    instance_id: Guid,
    ide_path: Option<ScsiPath>,
    /// Channel workers; reported under `channels` by `inspect_mut`.
    workers: Vec<WorkerAndDriver>,
    /// Controller state (attached disks, poll-mode queue depth) shared with
    /// the workers.
    controller: Arc<ScsiControllerState>,
    resources: DeviceResources,
    driver_source: VmTaskDriverSource,
    max_sub_channel_count: u16,
    /// Protocol negotiation state shared by the channels.
    protocol: Arc<Protocol>,
    io_queue_depth: u32,
}

/// A channel worker task together with the driver it runs on.
#[derive(Inspect)]
struct WorkerAndDriver {
    #[inspect(flatten)]
    worker: TaskControl<WorkerState, Worker>,
    driver: VmTaskDriver,
}

/// Task-control state for a channel worker. It carries no data; the
/// per-channel state lives in [`Worker`].
struct WorkerState;
168
169impl InspectMut for StorageDevice {
170 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
171 let mut resp = req.respond();
172
173 let disks = self.controller.disks.read();
174 for (path, controller_disk) in disks.iter() {
175 resp.child(&format!("disks/{}", path), |req| {
176 controller_disk.disk.inspect(req);
177 });
178 }
179
180 resp.fields(
181 "channels",
182 self.workers
183 .iter()
184 .filter(|task| task.worker.has_state())
185 .enumerate(),
186 )
187 .field(
188 "poll_mode_queue_depth",
189 inspect::AtomicMut(&self.controller.poll_mode_queue_depth),
190 );
191 }
192}
193
/// A channel worker: the VMBus queue plus the state used to service it.
struct Worker<T: RingMem = GpadlRingMem> {
    inner: WorkerInner,
    /// Controller rescan notifications; the sending side is registered with
    /// the controller in `Worker::new`.
    rescan_notification: futures::channel::mpsc::Receiver<()>,
    fast_select: FastSelect,
    queue: Queue<T>,
}

/// Protocol negotiation state shared between the primary channel and its
/// subchannels.
struct Protocol {
    /// Current negotiation state, written by the primary channel's
    /// `process_primary` loop.
    state: RwLock<ProtocolState>,
    /// Notified once `state` becomes `ProtocolState::Ready`; subchannel
    /// workers wait on it before processing requests.
    ready: event_listener::Event,
}

/// Per-channel request-processing state.
struct WorkerInner {
    protocol: Arc<Protocol>,
    /// Body size (after the packet header) used for packets sent by
    /// `send_packet`; starts at the V1 size and is updated to the negotiated
    /// version's maximum.
    request_size: usize,
    controller: Arc<ScsiControllerState>,
    /// Channel index; 0 is the primary channel that performs negotiation.
    channel_index: u16,
    scsi_queue: Arc<ScsiCommandQueue>,
    /// In-flight SCSI operation futures.
    scsi_requests: FuturesUnordered<ScsiRequest>,
    /// Per-request state, keyed by the request ID carried by each
    /// `ScsiRequest`.
    scsi_requests_states: Slab<ScsiRequestState>,
    /// Recycled request buffers, reused by `parse_packet`.
    full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
    /// Recycled (empty) storage from completed operation futures.
    future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
    channel_control: ChannelControl,
    /// Upper bound on outstanding requests for this channel (at least 1).
    max_io_queue_depth: usize,
    stats: WorkerStats,
}

/// Per-worker I/O statistics, exposed through inspect.
#[derive(Debug, Default, Inspect)]
struct WorkerStats {
    /// Requests submitted for execution.
    ios_submitted: Counter,
    /// Requests completed back to the guest.
    ios_completed: Counter,
    /// Number of times `poll_process_ready` ran.
    wakes: Counter,
    wakes_spurious: Counter,
    /// Submissions per wake, sampled only for wakes that did some work.
    per_wake_submissions: Histogram<10>,
    /// Completions per wake, sampled only for wakes that did some work.
    per_wake_completions: Histogram<10>,
}
231
/// A storvsp protocol version the device can negotiate.
///
/// The discriminant is the wire-format `major_minor` value. The variant
/// declaration order determines the derived `Ord`, which feature checks such
/// as `version >= Version::Win8` rely on, so do not reorder variants.
#[repr(u16)]
#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
enum Version {
    Win6 = storvsp_protocol::VERSION_WIN6,
    Win7 = storvsp_protocol::VERSION_WIN7,
    Win8 = storvsp_protocol::VERSION_WIN8,
    Blue = storvsp_protocol::VERSION_BLUE,
}

/// A `major_minor` protocol version value that matches no known [`Version`].
#[derive(Debug, Error)]
#[error("protocol version {0:#x} not supported")]
struct UnsupportedVersion(u16);
244
245impl Version {
246 fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
247 let version = match major_minor {
248 storvsp_protocol::VERSION_WIN6 => Self::Win6,
249 storvsp_protocol::VERSION_WIN7 => Self::Win7,
250 storvsp_protocol::VERSION_WIN8 => Self::Win8,
251 storvsp_protocol::VERSION_BLUE => Self::Blue,
252 version => return Err(UnsupportedVersion(version)),
253 };
254 assert_eq!(version as u16, major_minor);
255 Ok(version)
256 }
257
258 fn max_request_size(&self) -> usize {
259 match self {
260 Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
261 Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
262 }
263 }
264}
265
/// Overall protocol state shared by all channels.
#[derive(Copy, Clone)]
enum ProtocolState {
    /// Negotiation is still in progress on the primary channel.
    Init(InitState),
    /// Negotiation finished; channels may process requests.
    Ready {
        version: Version,
        subchannel_count: u16,
    },
}

/// Negotiation step on the primary channel, named for the packet expected
/// next.
#[derive(Copy, Clone, Debug)]
enum InitState {
    /// Waiting for `BEGIN_INITIALIZATION`.
    Begin,
    /// Waiting for an acceptable `QUERY_PROTOCOL_VERSION`.
    QueryVersion,
    /// Waiting for `QUERY_PROPERTIES` with `version` already negotiated.
    QueryProperties {
        version: Version,
    },
    /// Waiting for `END_INITIALIZATION`. `subchannel_count` is `None` while a
    /// `CREATE_SUB_CHANNELS` request would still be accepted.
    EndInitialization {
        version: Version,
        subchannel_count: Option<u16>,
    },
}
287
288type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];
296type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;
297
298const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;
303
304struct ScsiRequest {
305 request_id: usize,
306 future: Option<ScsiOpFuture>,
307}
308
309impl ScsiRequest {
310 fn new(
311 request_id: usize,
312 future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
313 ) -> Self {
314 Self {
315 request_id,
316 future: Some(future.into()),
317 }
318 }
319}
320
321impl Future for ScsiRequest {
322 type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);
323
324 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
325 let this = self.get_mut();
326 let future = this.future.as_mut().unwrap().as_mut();
327 let result = std::task::ready!(future.poll(cx));
328 let future = this.future.take().unwrap();
330 Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
331 }
332}
333
/// Errors that end a channel worker's processing loop.
#[derive(Debug, Error)]
enum WorkerError {
    /// A packet from the guest could not be parsed.
    #[error("packet error")]
    PacketError(#[source] PacketError),
    /// The VMBus queue failed.
    #[error("queue error")]
    Queue(#[source] queue::Error),
    /// A write found the ring full even though space was expected to be
    /// available.
    #[error("queue should have enough space but no longer does")]
    NotEnoughSpace,
}

/// Reasons an incoming packet is rejected by `parse_packet`.
#[derive(Debug, Error)]
enum PacketError {
    /// The packet carries no transaction ID.
    #[error("Not transactional")]
    NotTransactional,
    /// The header's operation is not one this device handles.
    #[error("Unrecognized operation {0:?}")]
    UnrecognizedOperation(storvsp_protocol::Operation),
    /// A completion packet was received where a data packet was expected.
    #[error("Invalid packet type")]
    InvalidPacketType,
    /// The request's external ranges do not cover its data transfer length.
    #[error("Invalid data transfer length")]
    InvalidDataTransferLength,
    /// Reading the packet body failed.
    // NOTE(review): the message embeds the source error and also exposes it
    // via `#[source]`, so error-chain printers may show it twice.
    #[error("Access error: {0}")]
    Access(#[source] AccessError),
    /// The packet's external data ranges could not be read.
    #[error("Range error")]
    Range(#[source] ExternalDataError),
}
359
360#[derive(Debug, Default, Clone)]
361struct Range {
362 len: usize,
363 is_write: bool,
364}
365
366impl Range {
367 fn new(buf: &MultiPagedRangeBuf, request: &storvsp_protocol::ScsiRequest) -> Option<Self> {
368 let len = request.data_transfer_length as usize;
369 let is_write = request.data_in != 0;
370 if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
373 return None;
374 }
375 Some(Self { len, is_write })
376 }
377
378 fn buffer<'a>(
379 &'a self,
380 buf: &'a MultiPagedRangeBuf,
381 guest_memory: &'a GuestMemory,
382 ) -> RequestBuffers<'a> {
383 let mut range = buf.first().unwrap_or_else(PagedRange::empty);
384 range.truncate(self.len);
385 RequestBuffers::new(guest_memory, range, self.is_write)
386 }
387}
388
/// A parsed storvsp packet.
#[derive(Debug)]
struct Packet {
    /// The operation-specific contents.
    data: PacketData,
    /// Transaction ID of the incoming VMBus packet; echoed in its completion.
    transaction_id: u64,
    /// Length of the body after the header, capped at `SCSI_REQUEST_LEN_MAX`;
    /// `send_completion` sizes the completion body to match.
    request_size: usize,
}

/// Operation-specific contents of a parsed packet.
#[derive(Debug)]
enum PacketData {
    BeginInitialization,
    EndInitialization,
    /// The requested `major_minor` protocol version.
    QueryProtocolVersion(u16),
    QueryProperties,
    /// The requested number of subchannels.
    CreateSubChannels(u16),
    /// A SCSI request with its validated data buffer.
    ExecuteScsi(Arc<ScsiRequestAndRange>),
    ResetBus,
    ResetAdapter,
    ResetLun,
}

/// A range error marker type.
///
/// NOTE(review): not referenced in this part of the file.
#[derive(Debug)]
pub struct RangeError;
411
412fn parse_packet<T: RingMem>(
413 packet: &IncomingPacket<'_, T>,
414 pool: &mut Vec<Arc<ScsiRequestAndRange>>,
415) -> Result<Packet, PacketError> {
416 let packet = match packet {
417 IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
418 IncomingPacket::Data(packet) => packet,
419 };
420 let transaction_id = packet
421 .transaction_id()
422 .ok_or(PacketError::NotTransactional)?;
423
424 let mut reader = packet.reader();
425 let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
426 let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
430 let data = match header.operation {
431 storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
432 storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
433 storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
434 let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
435 reader
436 .read(version.as_mut_bytes())
437 .map_err(PacketError::Access)?;
438 PacketData::QueryProtocolVersion(version.major_minor)
439 }
440 storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
441 storvsp_protocol::Operation::EXECUTE_SRB => {
442 let mut full_request = pool.pop().unwrap_or_else(|| {
443 Arc::new(ScsiRequestAndRange {
444 external_data: Range::default(),
445 external_data_buf: MultiPagedRangeBuf::new(),
446 request: storvsp_protocol::ScsiRequest::new_zeroed(),
447 request_size,
448 })
449 });
450
451 {
452 let full_request = Arc::get_mut(&mut full_request).unwrap();
453 let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
454 reader.read(request_buf).map_err(PacketError::Access)?;
455
456 full_request.external_data_buf.clear();
457 packet
458 .read_external_ranges(&mut full_request.external_data_buf)
459 .map_err(PacketError::Range)?;
460
461 full_request.external_data =
462 Range::new(&full_request.external_data_buf, &full_request.request)
463 .ok_or(PacketError::InvalidDataTransferLength)?;
464 }
465
466 PacketData::ExecuteScsi(full_request)
467 }
468 storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
469 storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
470 storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
471 storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
472 let mut sub_channel_count: u16 = 0;
473 reader
474 .read(sub_channel_count.as_mut_bytes())
475 .map_err(PacketError::Access)?;
476 PacketData::CreateSubChannels(sub_channel_count)
477 }
478 _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
479 };
480
481 if let PacketData::ExecuteScsi(_) = data {
482 tracing::trace!(transaction_id, ?data, "parse_packet");
483 } else {
484 tracing::debug!(transaction_id, ?data, "parse_packet");
485 }
486
487 Ok(Packet {
488 data,
489 request_size,
490 transaction_id,
491 })
492}
493
494impl WorkerInner {
495 fn send_vmbus_packet<M: RingMem>(
496 &mut self,
497 writer: &mut queue::WriteBatch<'_, M>,
498 packet_type: OutgoingPacketType<'_>,
499 request_size: usize,
500 transaction_id: u64,
501 operation: storvsp_protocol::Operation,
502 status: storvsp_protocol::NtStatus,
503 payload: &[u8],
504 ) -> Result<(), WorkerError> {
505 let header = storvsp_protocol::Packet {
506 operation,
507 flags: 0,
508 status,
509 };
510
511 let packet_size = size_of_val(&header) + request_size;
512
513 let len = size_of_val(&header) + size_of_val(payload);
518 let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
519 let (payload_bytes, padding_bytes) = if len > packet_size {
520 (&payload[..packet_size - size_of_val(&header)], &[][..])
521 } else {
522 (payload, &padding[..packet_size - len])
523 };
524 assert_eq!(
525 size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
526 packet_size
527 );
528 writer
529 .try_write(&OutgoingPacket {
530 transaction_id,
531 packet_type,
532 payload: &[header.as_bytes(), payload_bytes, padding_bytes],
533 })
534 .map_err(|err| match err {
535 queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
536 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
537 })
538 }
539
540 fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
541 &mut self,
542 writer: &mut queue::WriteHalf<'_, M>,
543 operation: storvsp_protocol::Operation,
544 status: storvsp_protocol::NtStatus,
545 payload: &P,
546 ) -> Result<(), WorkerError> {
547 self.send_vmbus_packet(
548 &mut writer.batched(),
549 OutgoingPacketType::InBandNoCompletion,
550 self.request_size,
551 0,
552 operation,
553 status,
554 payload.as_bytes(),
555 )
556 }
557
558 fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
559 &mut self,
560 writer: &mut queue::WriteHalf<'_, M>,
561 packet: &Packet,
562 status: storvsp_protocol::NtStatus,
563 payload: &P,
564 ) -> Result<(), WorkerError> {
565 self.send_vmbus_packet(
566 &mut writer.batched(),
567 OutgoingPacketType::Completion,
568 packet.request_size,
569 packet.transaction_id,
570 storvsp_protocol::Operation::COMPLETE_IO,
571 status,
572 payload.as_bytes(),
573 )
574 }
575}
576
/// Executes SCSI requests against the controller's attached disks.
struct ScsiCommandQueue {
    /// Controller whose `disks` map is consulted for each request.
    controller: Arc<ScsiControllerState>,
    /// Guest memory used to access each request's data buffer.
    mem: GuestMemory,
    /// When set, replaces the guest-supplied path ID in the disk lookup (but
    /// not in `REPORT_LUNS` filtering).
    force_path_id: Option<u8>,
}

impl ScsiCommandQueue {
    /// Executes one SCSI request and returns its result.
    ///
    /// `REPORT_LUNS` is always answered by the controller itself. Any other
    /// opcode is forwarded to the disk at the request's path/target/LUN when
    /// one is attached. With no disk attached, `INQUIRY` returns a minimal
    /// "logical unit not present" response, and every other opcode fails with
    /// `CHECK_CONDITION` / `INVALID_LUN`.
    ///
    /// NOTE(review): this future is boxed into the fixed-size
    /// `ScsiOpStorage`, so changes here can affect whether it still fits.
    async fn execute_scsi(&self, full_request: &ScsiRequestAndRange) -> ScsiResult {
        let request = &full_request.request;
        let op = ScsiOp(request.payload[0]);
        let external_data = full_request
            .external_data
            .buffer(&full_request.external_data_buf, &self.mem);

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            "execute_scsi start...",
        );

        // A configured `force_path_id` overrides the guest's path ID for the
        // disk lookup.
        let path_id = self.force_path_id.unwrap_or(request.path_id);

        // The read guard is a temporary, so the `disks` lock is released at
        // the end of this statement, before the disk call below awaits.
        let controller_disk = self
            .controller
            .disks
            .read()
            .get(&ScsiPath {
                path: path_id,
                target: request.target_id,
                lun: request.lun,
            })
            .cloned();

        let result = match op {
            ScsiOp::REPORT_LUNS => {
                const HEADER_SIZE: usize = size_of::<scsi::LunList>();
                // Collect the LUNs attached at the request's path/target.
                // NOTE(review): this matches on the guest-supplied
                // `request.path_id`, not on the `force_path_id`-adjusted
                // `path_id` used for the disk lookup above; presumably
                // intentional — confirm before changing.
                let mut luns: Vec<u8> = self
                    .controller
                    .disks
                    .read()
                    .iter()
                    .flat_map(|(path, _)| {
                        if request.path_id == path.path && request.target_id == path.target {
                            Some(path.lun)
                        } else {
                            None
                        }
                    })
                    .collect();
                luns.sort_unstable();
                // `data[0]` holds the list header; each LUN gets one u64
                // entry after it.
                let mut data: Vec<u64> = vec![0; luns.len() + 1];
                let header = scsi::LunList {
                    length: (luns.len() as u32 * 8).into(),
                    reserved: [0; 4],
                };
                data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
                for (i, lun) in luns.iter().enumerate() {
                    // The LUN goes big-endian in the entry's first two bytes;
                    // the rest of the entry stays zero.
                    data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
                }
                // Return as much of the list as fits. A buffer too small for
                // the header completes successfully with nothing transferred.
                if external_data.len() >= HEADER_SIZE {
                    let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
                    external_data.writer().write(&data.as_bytes()[..tx]).map_or(
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        },
                        |_| ScsiResult {
                            scsi_status: ScsiStatus::GOOD,
                            srb_status: SrbStatus::SUCCESS,
                            tx,
                            sense_data: None,
                        },
                    )
                } else {
                    ScsiResult {
                        scsi_status: ScsiStatus::GOOD,
                        srb_status: SrbStatus::SUCCESS,
                        tx: 0,
                        sense_data: None,
                    }
                }
            }
            // Every other opcode goes to the attached disk, if there is one.
            _ if controller_disk.is_some() => {
                let mut cdb = [0; 16];
                cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
                controller_disk
                    .unwrap()
                    .disk
                    .execute_scsi(
                        &external_data,
                        &Request {
                            cdb,
                            srb_flags: request.srb_flags,
                        },
                    )
                    .await
            }
            // No disk attached: answer a standard INQUIRY with a "not present"
            // device type.
            ScsiOp::INQUIRY => {
                let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
                    .unwrap()
                    .0;
                // Reject buffers smaller than the allocation length, non
                // data-in transfers, and allocation lengths too small for the
                // inquiry header.
                if external_data.len() < cdb.allocation_length.get() as usize
                    || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
                    || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
                {
                    ScsiResult {
                        scsi_status: ScsiStatus::CHECK_CONDITION,
                        srb_status: SrbStatus::INVALID_REQUEST,
                        tx: 0,
                        sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
                    }
                } else {
                    let enable_vpd = cdb.flags.vpd();
                    if enable_vpd || cdb.page_code != 0 {
                        // Vital product data pages are not served without a
                        // disk.
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        }
                    } else {
                        const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
                        let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
                        data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;

                        // Only LUN 0 keeps the template's identification
                        // strings.
                        if request.lun != 0 {
                            data.vendor_id = [0; 8];
                            data.product_id = [0; 16];
                            data.product_revision_level = [0; 4];
                        }

                        let datab = data.as_bytes();
                        let tx = std::cmp::min(
                            cdb.allocation_length.get() as usize,
                            size_of::<scsi::InquiryData>(),
                        );
                        external_data.writer().write(&datab[..tx]).map_or(
                            ScsiResult {
                                scsi_status: ScsiStatus::CHECK_CONDITION,
                                srb_status: SrbStatus::INVALID_REQUEST,
                                tx: 0,
                                sense_data: Some(illegal_request_sense(
                                    AdditionalSenseCode::INVALID_CDB,
                                )),
                            },
                            |_| ScsiResult {
                                scsi_status: ScsiStatus::GOOD,
                                srb_status: SrbStatus::SUCCESS,
                                tx,
                                sense_data: None,
                            },
                        )
                    }
                }
            }
            // No disk and not a controller-handled opcode.
            _ => ScsiResult {
                scsi_status: ScsiStatus::CHECK_CONDITION,
                srb_status: SrbStatus::INVALID_LUN,
                tx: 0,
                sense_data: None,
            },
        };

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            result = ?result,
            "execute_scsi completed.",
        );
        result
    }
}
763
764impl<T: RingMem + 'static> Worker<T> {
765 fn new(
766 controller: Arc<ScsiControllerState>,
767 channel: RawAsyncChannel<T>,
768 channel_index: u16,
769 mem: GuestMemory,
770 channel_control: ChannelControl,
771 io_queue_depth: u32,
772 protocol: Arc<Protocol>,
773 force_path_id: Option<u8>,
774 ) -> anyhow::Result<Self> {
775 let queue = Queue::new(channel)?;
776 #[expect(clippy::disallowed_methods)] let (source, target) = futures::channel::mpsc::channel(1);
778 controller.add_rescan_notification_source(source);
779
780 let max_io_queue_depth = io_queue_depth.max(1) as usize;
781 Ok(Self {
782 inner: WorkerInner {
783 protocol,
784 request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
785 controller: controller.clone(),
786 channel_index,
787 scsi_queue: Arc::new(ScsiCommandQueue {
788 controller,
789 mem,
790 force_path_id,
791 }),
792 scsi_requests: FuturesUnordered::new(),
793 scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
794 channel_control,
795 max_io_queue_depth,
796 future_pool: Vec::new(),
797 full_request_pool: Vec::new(),
798 stats: Default::default(),
799 },
800 queue,
801 rescan_notification: target,
802 fast_select: FastSelect::new(),
803 })
804 }
805
806 async fn wait_for_scsi_requests_complete(&mut self) {
807 tracing::debug!(
808 channel_index = self.inner.channel_index,
809 "wait for IOs completed..."
810 );
811 while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
812 self.inner.scsi_requests_states.remove(id);
813 }
814 }
815}
816
817impl InspectTask<Worker> for WorkerState {
818 fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
819 if let Some(worker) = worker {
820 let mut resp = req.respond();
821 if worker.inner.channel_index == 0 {
822 let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
823 ProtocolState::Init(state) => match state {
824 InitState::Begin => ("begin_init", None, None),
825 InitState::QueryVersion => ("query_version", None, None),
826 InitState::QueryProperties { version } => {
827 ("query_properties", Some(version), None)
828 }
829 InitState::EndInitialization {
830 version,
831 subchannel_count,
832 } => ("end_init", Some(version), subchannel_count),
833 },
834 ProtocolState::Ready {
835 version,
836 subchannel_count,
837 } => ("ready", Some(version), Some(subchannel_count)),
838 };
839 resp.field("state", state)
840 .field("version", version)
841 .field("subchannel_count", subchannel_count);
842 }
843 resp.field("pending_packets", worker.inner.scsi_requests_states.len())
844 .fields("io", worker.inner.scsi_requests_states.iter())
845 .field("stats", &worker.inner.stats)
846 .field("ring", &worker.queue)
847 .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
848 }
849 }
850}
851
impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
    /// Runs a channel worker until it is stopped or its processing fails.
    ///
    /// Channel 0 performs protocol negotiation via `process_primary`. Other
    /// channels wait until the protocol is `ProtocolState::Ready`, then process
    /// requests with the negotiated version. A processing error is logged and
    /// the task returns `Ok(())`; only cancellation through `stop` is
    /// propagated as an error.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        let fut = async {
            if worker.inner.channel_index == 0 {
                worker.process_primary().await
            } else {
                let protocol_version = loop {
                    // Create the listener before checking the state, so a
                    // `ready` notification sent between the check and the
                    // await below is not missed.
                    let listener = worker.inner.protocol.ready.listen();
                    if let ProtocolState::Ready { version, .. } =
                        *worker.inner.protocol.state.read()
                    {
                        break version;
                    }
                    tracing::debug!("subchannel waiting for initialization to end");
                    listener.await
                };
                worker
                    .inner
                    .process_ready(&mut worker.queue, protocol_version)
                    .await
            }
        };

        match stop.until_stopped(fut).await? {
            Ok(_) => {}
            Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
        }
        Ok(())
    }
}
888
889impl WorkerInner {
890 async fn next_packet<'a, M: RingMem>(
893 &mut self,
894 reader: &'a mut queue::ReadHalf<'a, M>,
895 ) -> Result<Packet, WorkerError> {
896 let packet = reader.read().await.map_err(WorkerError::Queue)?;
897 let stor_packet =
898 parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
899 Ok(stor_packet)
900 }
901
902 fn poll_for_ring_space<M: RingMem>(
909 &mut self,
910 cx: &mut Context<'_>,
911 writer: &mut queue::WriteHalf<'_, M>,
912 ) -> Poll<Result<(), WorkerError>> {
913 writer
914 .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
915 .map_err(WorkerError::Queue)
916 }
917}
918
919const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
920 size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
921);
922
impl<T: RingMem> Worker<T> {
    /// Runs the primary channel: drives protocol negotiation, then processes
    /// requests once the protocol is `ProtocolState::Ready`.
    ///
    /// During negotiation, each `InitState` accepts one kind of packet. In
    /// `EndInitialization`, a single `CreateSubChannels` request is also
    /// accepted before `EndInitialization`. Any other packet is completed with
    /// `INVALID_DEVICE_STATE`, and the state is left unchanged.
    ///
    /// Once ready, controller rescan notifications are forwarded to the guest
    /// as `ENUMERATE_BUS` (for protocol Win7 and later) while requests are
    /// processed.
    ///
    /// # Errors
    ///
    /// Returns any [`WorkerError`] from reading, parsing, or sending packets,
    /// or from `process_ready`.
    async fn process_primary(&mut self) -> Result<(), WorkerError> {
        loop {
            let current_state = *self.inner.protocol.state.read();
            match current_state {
                ProtocolState::Ready { version, .. } => {
                    break loop {
                        // Request processing is polled first. When a rescan
                        // notification fires instead, that `process_ready`
                        // future is dropped and a new one is started on the
                        // next iteration.
                        select_biased! {
                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
                                if version >= Version::Win7
                                {
                                    tracing::debug!("rescan notification received, sending ENUMERATE_BUS");
                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
                                }
                            }
                        }
                    };
                }
                ProtocolState::Init(state) => {
                    let (mut reader, mut writer) = self.queue.split();

                    // Wait until a maximum-size packet fits in the outgoing
                    // ring before reading, because completions are written
                    // with a non-blocking `try_write`.
                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;

                    tracing::debug!(?state, "process_primary");
                    match state {
                        InitState::Begin => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::BeginInitialization = packet.data {
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &(),
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::QueryVersion);
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryVersion => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
                                if let Ok(version) = Version::parse(major_minor) {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    // Later unsolicited packets use the
                                    // negotiated version's request size.
                                    self.inner.request_size = version.max_request_size();
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryProperties { version });

                                    tracelimit::info_ratelimited!(
                                        ?version,
                                        "scsi version negotiated"
                                    );
                                } else {
                                    // Unknown version: report a mismatch and
                                    // let the guest try another one.
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryVersion);
                                }
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryProperties { version } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProperties = packet.data {
                                // Subchannels are only offered from Win8 on.
                                let multi_channel_supported = version >= Version::Win8;

                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &storvsp_protocol::ChannelProperties {
                                        // 0x40000 bytes = 256 KiB.
                                        max_transfer_bytes: 0x40000,
                                        flags: {
                                            if multi_channel_supported {
                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
                                            } else {
                                                0
                                            }
                                        },
                                        maximum_sub_channel_count: if multi_channel_supported {
                                            self.inner.channel_control.max_subchannels()
                                        } else {
                                            0
                                        },
                                        reserved: 0,
                                        reserved2: 0,
                                        reserved3: [0, 0],
                                    },
                                )?;
                                // Without multi-channel support the subchannel
                                // count is fixed at 0, so CreateSubChannels is
                                // no longer accepted.
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::EndInitialization {
                                        version,
                                        subchannel_count: if multi_channel_supported {
                                            None
                                        } else {
                                            Some(0)
                                        },
                                    });
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            match packet.data {
                                // Accepted only while no subchannel count has
                                // been fixed yet.
                                PacketData::CreateSubChannels(sub_channel_count)
                                    if subchannel_count.is_none() =>
                                {
                                    if let Err(err) = self
                                        .inner
                                        .channel_control
                                        .enable_subchannels(sub_channel_count)
                                    {
                                        tracelimit::warn_ratelimited!(
                                            ?err,
                                            "cannot enable subchannels"
                                        );
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
                                            &(),
                                        )?;
                                    } else {
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::SUCCESS,
                                            &(),
                                        )?;
                                        *self.inner.protocol.state.write() =
                                            ProtocolState::Init(InitState::EndInitialization {
                                                version,
                                                subchannel_count: Some(sub_channel_count),
                                            });
                                    }
                                }
                                PacketData::EndInitialization => {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &(),
                                    )?;
                                    // Discard one rescan notification queued
                                    // before initialization ended, if any.
                                    // NOTE(review): presumably because the
                                    // guest enumerates the bus itself after
                                    // init — confirm.
                                    self.rescan_notification.try_next().ok();
                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
                                        version,
                                        subchannel_count: subchannel_count.unwrap_or(0),
                                    };
                                    // Wake every subchannel worker waiting for
                                    // the protocol to become ready.
                                    self.inner.protocol.ready.notify(usize::MAX);
                                }
                                _ => {
                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                        &(),
                                    )?;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
1137
1138fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
1139 match srb_status {
1140 SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
1141 SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
1142 SrbStatus::INVALID_LUN
1143 | SrbStatus::INVALID_TARGET_ID
1144 | SrbStatus::NO_DEVICE
1145 | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
1146 SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
1147 SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
1148 SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
1149 storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
1150 }
1151 SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
1152 SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
1153 SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
1154 _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
1155 }
1156}
1157
1158impl WorkerInner {
1159 async fn process_ready<M: RingMem>(
1161 &mut self,
1162 queue: &mut Queue<M>,
1163 protocol_version: Version,
1164 ) -> Result<(), WorkerError> {
1165 self.request_size = protocol_version.max_request_size();
1166 poll_fn(|cx| self.poll_process_ready(cx, queue)).await
1167 }
1168
1169 fn poll_process_ready<M: RingMem>(
1171 &mut self,
1172 cx: &mut Context<'_>,
1173 queue: &mut Queue<M>,
1174 ) -> Poll<Result<(), WorkerError>> {
1175 self.stats.wakes.increment();
1176
1177 let (mut reader, mut writer) = queue.split();
1178 let mut total_completions = 0;
1179 let mut total_submissions = 0;
1180 let poll_mode_queue_depth = self.controller.poll_mode_queue_depth.load(Relaxed) as usize;
1181
1182 loop {
1183 'outer: while !self.scsi_requests_states.is_empty() {
1185 {
1186 let mut batch = writer.batched();
1187 loop {
1188 if !batch
1192 .can_write(MAX_VMBUS_PACKET_SIZE)
1193 .map_err(WorkerError::Queue)?
1194 {
1195 break;
1197 }
1198 if let Poll::Ready(Some((request_id, result, future))) =
1199 self.scsi_requests.poll_next_unpin(cx)
1200 {
1201 self.future_pool.push(future);
1202 self.handle_completion(&mut batch, request_id, result)?;
1203 total_completions += 1;
1204 } else {
1205 tracing::trace!("out of completions");
1206 break 'outer;
1207 }
1208 }
1209 }
1210
1211 if self.poll_for_ring_space(cx, &mut writer).is_pending() {
1213 tracing::trace!("out of ring space");
1214 break;
1215 }
1216 }
1217
1218 let mut submissions = 0;
1219 'outer: loop {
1221 if self.scsi_requests_states.len() >= self.max_io_queue_depth {
1222 break;
1223 }
1224 let mut batch = if self.scsi_requests_states.len() < poll_mode_queue_depth {
1225 if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
1226 batch.map_err(WorkerError::Queue)?
1227 } else {
1228 tracing::trace!("out of incoming packets");
1229 break;
1230 }
1231 } else {
1232 match reader.try_read_batch() {
1233 Ok(batch) => batch,
1234 Err(queue::TryReadError::Empty) => {
1235 tracing::trace!(
1236 pending_io_count = self.scsi_requests_states.len(),
1237 "out of incoming packets, keeping interrupts masked"
1238 );
1239 break;
1240 }
1241 Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
1242 }
1243 };
1244
1245 let mut packets = batch.packets();
1246 loop {
1247 if self.scsi_requests_states.len() >= self.max_io_queue_depth {
1248 break 'outer;
1249 }
1250 if self.poll_for_ring_space(cx, &mut writer).is_pending() {
1254 tracing::trace!("out of ring space");
1255 break 'outer;
1256 }
1257
1258 let packet = if let Some(packet) = packets.next() {
1259 packet.map_err(WorkerError::Queue)?
1260 } else {
1261 break;
1262 };
1263
1264 if self.handle_packet(&mut writer, &packet)? {
1265 submissions += 1;
1266 }
1267 }
1268 }
1269
1270 if submissions == 0 {
1272 break;
1274 }
1275 total_submissions += submissions;
1276 }
1277
1278 if total_submissions != 0 || total_completions != 0 {
1279 self.stats.ios_submitted.add(total_submissions);
1280 self.stats
1281 .per_wake_submissions
1282 .add_sample(total_submissions);
1283 self.stats
1284 .per_wake_completions
1285 .add_sample(total_completions);
1286 self.stats.ios_completed.add(total_completions);
1287 } else {
1288 self.stats.wakes_spurious.increment();
1289 }
1290
1291 Poll::Pending
1292 }
1293
1294 fn handle_completion<M: RingMem>(
1295 &mut self,
1296 writer: &mut queue::WriteBatch<'_, M>,
1297 request_id: usize,
1298 result: ScsiResult,
1299 ) -> Result<(), WorkerError> {
1300 let state = self.scsi_requests_states.remove(request_id);
1301 let request_size = state.request.request_size;
1302
1303 assert_eq!(
1305 Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
1306 1
1307 );
1308 self.full_request_pool.push(state.request);
1309
1310 let status = convert_srb_status_to_nt_status(result.srb_status);
1311 let mut payload = [0; 0x14];
1312 if let Some(sense) = result.sense_data {
1313 payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
1314 tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
1315 };
1316 let response = storvsp_protocol::ScsiRequest {
1317 length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
1318 scsi_status: result.scsi_status,
1319 srb_status: SrbStatusAndFlags::new()
1320 .with_status(result.srb_status)
1321 .with_autosense_valid(result.sense_data.is_some()),
1322 data_transfer_length: result.tx as u32,
1323 cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
1324 sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
1325 payload,
1326 ..storvsp_protocol::ScsiRequest::new_zeroed()
1327 };
1328 self.send_vmbus_packet(
1329 writer,
1330 OutgoingPacketType::Completion,
1331 request_size,
1332 state.transaction_id,
1333 storvsp_protocol::Operation::COMPLETE_IO,
1334 status,
1335 response.as_bytes(),
1336 )?;
1337 Ok(())
1338 }
1339
1340 fn handle_packet<M: RingMem>(
1341 &mut self,
1342 writer: &mut queue::WriteHalf<'_, M>,
1343 packet: &IncomingPacket<'_, M>,
1344 ) -> Result<bool, WorkerError> {
1345 let packet =
1346 parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
1347 let submitted_io = match packet.data {
1348 PacketData::ExecuteScsi(request) => {
1349 self.push_scsi_request(packet.transaction_id, request);
1350 true
1351 }
1352 PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
1353 self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
1355 false
1356 }
1357 PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
1358 if let Err(err) = self
1359 .channel_control
1360 .enable_subchannels(new_subchannel_count)
1361 {
1362 tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
1363 self.send_completion(
1364 writer,
1365 &packet,
1366 storvsp_protocol::NtStatus::INVALID_PARAMETER,
1367 &(),
1368 )?;
1369 false
1370 } else {
1371 if let ProtocolState::Ready {
1373 subchannel_count, ..
1374 } = &mut *self.protocol.state.write()
1375 {
1376 *subchannel_count = new_subchannel_count;
1377 } else {
1378 unreachable!()
1379 }
1380
1381 self.send_completion(
1382 writer,
1383 &packet,
1384 storvsp_protocol::NtStatus::SUCCESS,
1385 &(),
1386 )?;
1387 false
1388 }
1389 }
1390 _ => {
1391 tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
1392 self.send_completion(
1393 writer,
1394 &packet,
1395 storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
1396 &(),
1397 )?;
1398 false
1399 }
1400 };
1401 Ok(submitted_io)
1402 }
1403
1404 fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
1405 let scsi_queue = self.scsi_queue.clone();
1406 let scsi_request_state = ScsiRequestState {
1407 transaction_id,
1408 request: full_request.clone(),
1409 };
1410 let request_id = self.scsi_requests_states.insert(scsi_request_state);
1411 let future = self
1412 .future_pool
1413 .pop()
1414 .unwrap_or_else(|| OversizedBox::new(()));
1415 let future = OversizedBox::refill(future, async move {
1416 scsi_queue.execute_scsi(full_request.as_ref()).await
1417 });
1418 let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
1419 self.scsi_requests.push(request);
1420 }
1421}
1422
1423impl<T: RingMem> Drop for Worker<T> {
1424 fn drop(&mut self) {
1425 self.inner
1426 .controller
1427 .remove_rescan_notification_source(&self.rescan_notification);
1428 }
1429}
1430
1431#[derive(Debug, Error)]
1432#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
1433pub struct ScsiPathInUse(pub ScsiPath);
1434
1435#[derive(Debug, Error)]
1436#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
1437pub struct ScsiPathNotInUse(ScsiPath);
1438
/// Bookkeeping for one in-flight SCSI request, kept in `scsi_requests_states`
/// until its completion is sent.
#[derive(Clone)]
struct ScsiRequestState {
    /// Transaction ID of the guest packet; echoed back in the completion packet.
    transaction_id: u64,
    /// The parsed request, shared with the request future while it executes.
    request: Arc<ScsiRequestAndRange>,
}

/// A parsed guest SCSI request plus the description of its data buffer.
#[derive(Debug)]
struct ScsiRequestAndRange {
    // NOTE(review): presumably the guest data buffer range for the transfer;
    // confirm against `parse_packet`.
    external_data: Range,
    /// Backing storage for the paged GPA ranges of the external data.
    external_data_buf: MultiPagedRangeBuf,
    /// The SCSI request body as sent by the guest (CDB, address, lengths).
    request: storvsp_protocol::ScsiRequest,
    /// Request size recorded at parse time; passed to `send_vmbus_packet`
    /// when the completion is written.
    request_size: usize,
}
1452
1453impl Inspect for ScsiRequestState {
1454 fn inspect(&self, req: inspect::Request<'_>) {
1455 req.respond()
1456 .field("transaction_id", self.transaction_id)
1457 .display(
1458 "address",
1459 &ScsiPath {
1460 path: self.request.request.path_id,
1461 target: self.request.request.target_id,
1462 lun: self.request.request.lun,
1463 },
1464 )
1465 .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1466 }
1467}
1468
1469impl StorageDevice {
1470 pub fn build_scsi(
1472 driver_source: &VmTaskDriverSource,
1473 controller: &ScsiController,
1474 instance_id: Guid,
1475 max_sub_channel_count: u16,
1476 io_queue_depth: u32,
1477 ) -> Self {
1478 Self::build_inner(
1479 driver_source,
1480 controller,
1481 instance_id,
1482 None,
1483 max_sub_channel_count,
1484 io_queue_depth,
1485 )
1486 }
1487
1488 pub fn build_ide(
1491 driver_source: &VmTaskDriverSource,
1492 channel_id: u8,
1493 device_id: u8,
1494 disk: ScsiControllerDisk,
1495 io_queue_depth: u32,
1496 ) -> Self {
1497 let path = ScsiPath {
1498 path: channel_id,
1499 target: device_id,
1500 lun: 0,
1501 };
1502
1503 let controller = ScsiController::new();
1504 controller.attach(path, disk).unwrap();
1505
1506 let instance_id = Guid {
1509 data1: channel_id.into(),
1510 data2: device_id.into(),
1511 data3: 0x8899,
1512 data4: [0; 8],
1513 };
1514 Self::build_inner(
1515 driver_source,
1516 &controller,
1517 instance_id,
1518 Some(path),
1519 0,
1520 io_queue_depth,
1521 )
1522 }
1523
1524 fn build_inner(
1525 driver_source: &VmTaskDriverSource,
1526 controller: &ScsiController,
1527 instance_id: Guid,
1528 ide_path: Option<ScsiPath>,
1529 max_sub_channel_count: u16,
1530 io_queue_depth: u32,
1531 ) -> Self {
1532 let workers = (0..max_sub_channel_count + 1)
1533 .map(|channel_index| WorkerAndDriver {
1534 worker: TaskControl::new(WorkerState),
1535 driver: driver_source
1536 .builder()
1537 .target_vp(0)
1538 .run_on_target(true)
1539 .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1540 })
1541 .collect();
1542
1543 Self {
1544 instance_id,
1545 ide_path,
1546 workers,
1547 controller: controller.state.clone(),
1548 resources: Default::default(),
1549 max_sub_channel_count,
1550 driver_source: driver_source.clone(),
1551 protocol: Arc::new(Protocol {
1552 state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1553 ready: Default::default(),
1554 }),
1555 io_queue_depth,
1556 }
1557 }
1558
1559 fn new_worker(
1560 &mut self,
1561 open_request: &OpenRequest,
1562 channel_index: u16,
1563 ) -> anyhow::Result<&mut Worker> {
1564 let controller = self.controller.clone();
1565
1566 let target_vp = open_request.open_data.target_vp.unwrap_or_default();
1569 let driver = self
1570 .driver_source
1571 .builder()
1572 .target_vp(target_vp)
1573 .run_on_target(true)
1574 .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1575
1576 let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1577 .context("failed to create vmbus channel")?;
1578
1579 let channel_control = self.resources.channel_control.clone();
1580
1581 tracing::debug!(
1582 target_vp = open_request.open_data.target_vp,
1583 channel_index,
1584 "packet processing starting...",
1585 );
1586
1587 let force_path_id = self.ide_path.map(|p| p.path);
1591
1592 let worker = Worker::new(
1593 controller,
1594 channel,
1595 channel_index,
1596 self.resources
1597 .offer_resources
1598 .guest_memory(open_request)
1599 .clone(),
1600 channel_control,
1601 self.io_queue_depth,
1602 self.protocol.clone(),
1603 force_path_id,
1604 )
1605 .map_err(RestoreError::Other)?;
1606
1607 self.workers[channel_index as usize]
1608 .driver
1609 .retarget_vp(target_vp);
1610
1611 Ok(self.workers[channel_index as usize].worker.insert(
1612 &driver,
1613 format!("storvsp worker {}-{}", self.instance_id, channel_index),
1614 worker,
1615 ))
1616 }
1617}
1618
/// A disk that can be attached to a [`ScsiController`] at some [`ScsiPath`].
///
/// Cloning is cheap: it shares the underlying disk through an `Arc`.
#[derive(Clone)]
pub struct ScsiControllerDisk {
    /// The SCSI disk implementation that executes requests for this path.
    disk: Arc<dyn AsyncScsiDisk>,
}
1624
1625impl ScsiControllerDisk {
1626 pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1628 Self { disk }
1629 }
1630}
1631
/// State shared between a [`ScsiController`] handle and the storage devices
/// and workers built from it.
struct ScsiControllerState {
    /// Attached disks, keyed by SCSI address.
    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
    /// Senders signalled (best effort) whenever a disk is attached or removed.
    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
    /// Outstanding-request count below which workers read the ring with a
    /// waker-registering read; at or above it they use a non-blocking read.
    poll_mode_queue_depth: AtomicU32,
}
1637
/// A set of SCSI disks addressable by [`ScsiPath`], to which disks can be
/// attached and removed at runtime.
pub struct ScsiController {
    /// Shared state; cloned into each storage device built from this controller.
    state: Arc<ScsiControllerState>,
}
1641
1642impl ScsiController {
1643 pub fn new() -> Self {
1644 Self::new_with_poll_mode_queue_depth(None)
1645 }
1646
1647 pub fn new_with_poll_mode_queue_depth(poll_mode_queue_depth: Option<u32>) -> Self {
1648 Self {
1649 state: Arc::new(ScsiControllerState {
1650 disks: Default::default(),
1651 rescan_notification_source: Mutex::new(Vec::new()),
1652 poll_mode_queue_depth: AtomicU32::new(
1653 poll_mode_queue_depth.unwrap_or(DEFAULT_POLL_MODE_QUEUE_DEPTH),
1654 ),
1655 }),
1656 }
1657 }
1658
1659 pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1660 match self.state.disks.write().entry(path) {
1661 Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1662 Entry::Vacant(entry) => entry.insert(disk),
1663 };
1664 for source in self.state.rescan_notification_source.lock().iter_mut() {
1665 source.try_send(()).ok();
1668 }
1669 Ok(())
1670 }
1671
1672 pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1673 match self.state.disks.write().entry(path) {
1674 Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1675 Entry::Occupied(entry) => {
1676 entry.remove();
1677 }
1678 }
1679 for source in self.state.rescan_notification_source.lock().iter_mut() {
1680 source.try_send(()).ok();
1683 }
1684 Ok(())
1685 }
1686}
1687
1688impl ScsiControllerState {
1689 fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1690 self.rescan_notification_source.lock().push(source);
1691 }
1692
1693 fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1694 let mut sources = self.rescan_notification_source.lock();
1695 if let Some(index) = sources
1696 .iter()
1697 .position(|source| source.is_connected_to(target))
1698 {
1699 sources.remove(index);
1700 }
1701 }
1702}
1703
1704#[async_trait]
1705impl VmbusDevice for StorageDevice {
1706 fn offer(&self) -> OfferParams {
1707 if let Some(path) = self.ide_path {
1708 let offer_properties = storvsp_protocol::OfferProperties {
1709 path_id: path.path,
1710 target_id: path.target,
1711 flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
1712 ..FromZeros::new_zeroed()
1713 };
1714 let mut user_defined = UserDefinedData::new_zeroed();
1715 offer_properties
1716 .write_to_prefix(&mut user_defined[..])
1717 .unwrap();
1718 OfferParams {
1719 interface_name: "ide-accel".to_owned(),
1720 instance_id: self.instance_id,
1721 interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
1722 channel_type: ChannelType::Interface { user_defined },
1723 ..Default::default()
1724 }
1725 } else {
1726 OfferParams {
1727 interface_name: "scsi".to_owned(),
1728 instance_id: self.instance_id,
1729 interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
1730 ..Default::default()
1731 }
1732 }
1733 }
1734
1735 fn max_subchannels(&self) -> u16 {
1736 self.max_sub_channel_count
1737 }
1738
1739 fn install(&mut self, resources: DeviceResources) {
1740 self.resources = resources;
1741 }
1742
1743 async fn open(
1744 &mut self,
1745 channel_index: u16,
1746 open_request: &OpenRequest,
1747 ) -> Result<(), ChannelOpenError> {
1748 tracing::debug!(channel_index, "scsi open channel");
1749 self.new_worker(open_request, channel_index)?;
1750 self.workers[channel_index as usize].worker.start();
1751 Ok(())
1752 }
1753
1754 async fn close(&mut self, channel_index: u16) {
1755 tracing::debug!(channel_index, "scsi close channel");
1756 let worker = &mut self.workers[channel_index as usize].worker;
1757 worker.stop().await;
1758 if worker.state_mut().is_some() {
1759 worker
1760 .state_mut()
1761 .unwrap()
1762 .wait_for_scsi_requests_complete()
1763 .await;
1764 worker.remove();
1765 }
1766 if channel_index == 0 {
1767 *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
1768 }
1769 }
1770
1771 async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
1772 self.workers[channel_index as usize]
1773 .driver
1774 .retarget_vp(target_vp);
1775 }
1776
1777 fn start(&mut self) {
1778 for task in self
1779 .workers
1780 .iter_mut()
1781 .filter(|task| task.worker.has_state() && !task.worker.is_running())
1782 {
1783 task.worker.start();
1784 }
1785 }
1786
1787 async fn stop(&mut self) {
1788 tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
1789 for task in self
1790 .workers
1791 .iter_mut()
1792 .filter(|task| task.worker.has_state() && task.worker.is_running())
1793 {
1794 task.worker.stop().await;
1795 }
1796 }
1797
1798 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1799 Some(self)
1800 }
1801}
1802
#[async_trait]
impl SaveRestoreVmbusDevice for StorageDevice {
    /// Saves the device state and wraps it in an opaque blob.
    // `self.save()` is not a recursive call: it is used synchronously with `?`,
    // so it resolves to the inherent `StorageDevice::save` (inherent methods
    // take precedence over trait methods), presumably defined in the
    // `save_restore` module.
    async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
        Ok(SavedStateBlob::new(self.save()?))
    }

    /// Parses the blob back into typed saved state and restores from it.
    // As with `save`, `self.restore(...)` is intended to resolve to the inherent
    // method; the type `state.parse()` produces is inferred from that
    // method's parameter. NOTE(review): confirm the inherent signature in
    // `save_restore`.
    async fn restore(
        &mut self,
        control: RestoreControl<'_>,
        state: SavedStateBlob,
    ) -> Result<(), RestoreError> {
        self.restore(control, state.parse()?).await
    }
}
1817
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestWorker;
    use crate::test_helpers::parse_guest_completion;
    use crate::test_helpers::parse_guest_completion_check_flags_status;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use scsi::srb::SrbStatus;
    use test_with_tracing::test;
    use vmbus_channel::connected_async_channels;

    // Test-only `Clone`, so a test can hand one controller handle to the worker
    // while keeping another to attach or remove disks.
    impl Clone for ScsiController {
        fn clone(&self) -> Self {
            ScsiController {
                state: self.state.clone(),
            }
        }
    }

    /// Negotiates the protocol, writes a buffer through the channel, and reads
    /// it back into a different guest buffer, checking the data round-trips.
    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        let controller = ScsiController::new();
        let disk = scsidisk::SimpleScsiDisk::new(
            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
            Default::default(),
        );
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                ScsiControllerDisk::new(Arc::new(disk)),
            )
            .unwrap();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Fill IO_LEN bytes of guest memory at write_gpa with 7s and write them.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        // Read the same data back into a separate guest buffer and compare.
        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        guest.verify_graceful_close(test_worker).await;
    }

    /// Sends QUERY_PROTOCOL_VERSION packets of several sizes with an
    /// unsupported version. Checks the completion length for each size and
    /// that the status is REVISION_MISMATCH.
    #[async_test]
    async fn test_packet_sizes(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let header = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };

        // Payload: a version no implementation supports (`!0`), zero-padded so
        // the packet can be truncated to each test length.
        let mut buf = [0u8; 128];
        storvsp_protocol::ProtocolVersion {
            major_minor: !0,
            reserved: 0,
        }
        .write_to_prefix(&mut buf[..])
        .unwrap();
        // Each pair is (total length sent, expected completion length).
        // NOTE(review): the expected values look like padding to 8 bytes, capped at 64.
        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
            guest
                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
                .await;

            guest
                .verify_completion(|packet| {
                    let IncomingPacket::Completion(packet) = packet else {
                        unreachable!()
                    };
                    assert_eq!(packet.reader().len(), resp_len);
                    assert_eq!(
                        packet
                            .reader()
                            .read_plain::<storvsp_protocol::Packet>()
                            .unwrap()
                            .status,
                        storvsp_protocol::NtStatus::REVISION_MISMATCH
                    );
                    Ok(())
                })
                .await;
        }
    }

    /// Sending END_INITIALIZATION before BEGIN_INITIALIZATION is rejected with
    /// INVALID_DEVICE_STATE.
    #[async_test]
    async fn test_wrong_first_packet(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // Out-of-order first packet.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
                );
                Ok(())
            })
            .await;
    }

    /// An operation the worker does not recognize (REMOVE_DEVICE) terminates the
    /// worker with `PacketError::UnrecognizedOperation`.
    #[async_test]
    async fn test_unrecognized_operation(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        match worker.teardown().await {
            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
                storvsp_protocol::Operation::REMOVE_DEVICE,
            ))) => {}
            result => panic!("Worker failed with unexpected result {:?}!", result),
        }
    }

    /// Requesting a subchannel from this test worker is rejected with
    /// INVALID_PARAMETER.
    #[async_test]
    async fn test_too_many_subchannels(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        // Ask for one subchannel (u16 count follows the header).
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_PARAMETER
                );
                Ok(())
            })
            .await;
    }

    /// After negotiation completes, a repeated BEGIN_INITIALIZATION is rejected
    /// with INVALID_DEVICE_STATE and flags 0.
    #[async_test]
    async fn test_begin_init_on_ready(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Already in the ready state; restarting initialization is invalid.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|p| {
                parse_guest_completion_check_flags_status(
                    p,
                    0,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                )
            })
            .await;
    }

    /// Hot-adds LUNs 0..4 and then removes them. After each change it checks:
    /// - the guest receives a bus-enumeration notification;
    /// - REPORT LUNS reflects the new count;
    /// - writes to the LUN succeed while attached and fail with INVALID_LUN
    ///   once removed.
    #[async_test]
    async fn test_hot_add_remove(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        let controller = ScsiController::new();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // With no disks, REPORT LUNS transfers just the 8-byte header and the
        // big-endian LUN list length at offset 0 is zero.
        let mut lun_list_buffer: [u8; 256] = [0; 256];
        let mut disk_count = 0;
        guest
            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
            })
            .await;
        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
        assert_eq!(lun_list_size, disk_count as u32 * 8);

        // A write to the default path fails while nothing is attached.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();

        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
            })
            .await;

        // Hot-add LUNs one at a time. Each add should trigger a rescan
        // notification and grow REPORT LUNS by one 8-byte entry.
        for lun in 0..4 {
            let disk = scsidisk::SimpleScsiDisk::new(
                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
                Default::default(),
            );
            controller
                .attach(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    ScsiControllerDisk::new(Arc::new(disk)),
                )
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count += 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 256)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
                })
                .await;
        }

        // Remove them again. Each removal should trigger a rescan
        // notification, shrink REPORT LUNS, and make writes to that LUN fail.
        for lun in 0..4 {
            controller
                .remove(ScsiPath {
                    path: 0,
                    target: 0,
                    lun,
                })
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count -= 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
                })
                .await;
        }

        guest.verify_graceful_close(test_worker).await;
    }

    /// Performs the full negotiation sequence by hand, then a write/read round
    /// trip against a small RAM-backed disk.
    #[async_test]
    pub async fn test_async_disk(driver: DefaultDriver) {
        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
        let controller = ScsiController::new();
        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
            device,
            Default::default(),
        )));
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                disk,
            )
            .unwrap();

        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let test_guest_mem = GuestMemory::allocate(16384);
        let worker = TestWorker::start(
            controller.clone(),
            &driver,
            test_guest_mem.clone(),
            host,
            None,
        );

        // Negotiation: BEGIN_INITIALIZATION, QUERY_PROTOCOL_VERSION,
        // QUERY_PROPERTIES, END_INITIALIZATION.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        // Write, then read back into a separate buffer and compare.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        guest.verify_graceful_close(worker).await;
    }
}