1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7#[cfg(feature = "ioperf")]
8pub mod ioperf;
9
10#[cfg(feature = "test")]
11pub mod test_helpers;
12
13#[cfg(not(feature = "test"))]
14mod test_helpers;
15
16pub mod resolver;
17mod save_restore;
18
19use crate::ring::gparange::MultiPagedRangeBuf;
20use anyhow::Context as _;
21use async_trait::async_trait;
22use fast_select::FastSelect;
23use futures::FutureExt;
24use futures::StreamExt;
25use futures::select_biased;
26use guestmem::AccessError;
27use guestmem::GuestMemory;
28use guestmem::MemoryRead;
29use guestmem::MemoryWrite;
30use guestmem::ranges::PagedRange;
31use guid::Guid;
32use inspect::Inspect;
33use inspect::InspectMut;
34use inspect_counters::Counter;
35use inspect_counters::Histogram;
36use oversized_box::OversizedBox;
37use parking_lot::Mutex;
38use parking_lot::RwLock;
39use ring::OutgoingPacketType;
40use scsi::AdditionalSenseCode;
41use scsi::ScsiOp;
42use scsi::ScsiStatus;
43use scsi::srb::SrbStatus;
44use scsi::srb::SrbStatusAndFlags;
45use scsi_buffers::RequestBuffers;
46use scsi_core::AsyncScsiDisk;
47use scsi_core::Request;
48use scsi_core::ScsiResult;
49use scsi_defs as scsi;
50use scsidisk::illegal_request_sense;
51use slab::Slab;
52use std::collections::hash_map::Entry;
53use std::collections::hash_map::HashMap;
54use std::fmt::Debug;
55use std::future::Future;
56use std::future::poll_fn;
57use std::pin::Pin;
58use std::sync::Arc;
59use std::sync::atomic::AtomicU32;
60use std::sync::atomic::Ordering::Relaxed;
61use std::task::Context;
62use std::task::Poll;
63use storvsp_resources::ScsiPath;
64use task_control::AsyncRun;
65use task_control::InspectTask;
66use task_control::StopTask;
67use task_control::TaskControl;
68use thiserror::Error;
69use tracing_helpers::ErrorValueExt;
70use unicycle::FuturesUnordered;
71use vmbus_async::queue;
72use vmbus_async::queue::ExternalDataError;
73use vmbus_async::queue::IncomingPacket;
74use vmbus_async::queue::OutgoingPacket;
75use vmbus_async::queue::Queue;
76use vmbus_channel::RawAsyncChannel;
77use vmbus_channel::bus::ChannelType;
78use vmbus_channel::bus::OfferParams;
79use vmbus_channel::bus::OpenRequest;
80use vmbus_channel::channel::ChannelControl;
81use vmbus_channel::channel::ChannelOpenError;
82use vmbus_channel::channel::DeviceResources;
83use vmbus_channel::channel::RestoreControl;
84use vmbus_channel::channel::SaveRestoreVmbusDevice;
85use vmbus_channel::channel::VmbusDevice;
86use vmbus_channel::gpadl_ring::GpadlRingMem;
87use vmbus_channel::gpadl_ring::gpadl_channel;
88use vmbus_core::protocol::UserDefinedData;
89use vmbus_ring as ring;
90use vmbus_ring::RingMem;
91use vmcore::save_restore::RestoreError;
92use vmcore::save_restore::SaveError;
93use vmcore::save_restore::SavedStateBlob;
94use vmcore::vm_task::VmTaskDriver;
95use vmcore::vm_task::VmTaskDriverSource;
96use zerocopy::FromBytes;
97use zerocopy::FromZeros;
98use zerocopy::Immutable;
99use zerocopy::IntoBytes;
100use zerocopy::KnownLayout;
101
/// Default for the controller's `poll_mode_queue_depth` setting.
// NOTE(review): not referenced in the visible portion of this file; presumably it seeds
// `ScsiControllerState::poll_mode_queue_depth` elsewhere — confirm.
const DEFAULT_POLL_MODE_QUEUE_DEPTH: u32 = 1;
107
/// A storvsp (VMBus SCSI) storage device.
///
/// Holds the per-channel workers and shares the controller state (attached
/// disks and tunables) and the protocol negotiation state with them via `Arc`s.
pub struct StorageDevice {
    instance_id: Guid,
    // NOTE(review): presumably set when this device serves an IDE path — TODO confirm.
    ide_path: Option<ScsiPath>,
    /// Per-channel workers; inspected as "channels".
    workers: Vec<WorkerAndDriver>,
    /// Attached disks and controller-wide knobs shared with every worker.
    controller: Arc<ScsiControllerState>,
    resources: DeviceResources,
    driver_source: VmTaskDriverSource,
    max_sub_channel_count: u16,
    /// Negotiation state shared by the primary channel and all subchannels.
    protocol: Arc<Protocol>,
    io_queue_depth: u32,
}
119
/// A channel worker task paired with its VM task driver.
#[derive(Inspect)]
struct WorkerAndDriver {
    #[inspect(flatten)]
    worker: TaskControl<WorkerState, Worker>,
    driver: VmTaskDriver,
}

/// Task-control state for a channel worker; it carries no data of its own.
struct WorkerState;
128
129impl InspectMut for StorageDevice {
130 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
131 let mut resp = req.respond();
132
133 let disks = self.controller.disks.read();
134 for (path, controller_disk) in disks.iter() {
135 resp.child(&format!("disks/{}", path), |req| {
136 controller_disk.disk.inspect(req);
137 });
138 }
139
140 resp.fields(
141 "channels",
142 self.workers
143 .iter()
144 .filter(|task| task.worker.has_state())
145 .enumerate(),
146 )
147 .field(
148 "poll_mode_queue_depth",
149 inspect::AtomicMut(&self.controller.poll_mode_queue_depth),
150 );
151 }
152}
153
/// Per-channel worker: the VMBus ring queue plus the state used to service it.
struct Worker<T: RingMem = GpadlRingMem> {
    inner: WorkerInner,
    /// Receives controller rescan notifications (registered in `Worker::new`).
    /// After negotiation, the primary channel answers each one by sending
    /// `ENUMERATE_BUS` to guests at `Win7` or later.
    rescan_notification: futures::channel::mpsc::Receiver<()>,
    fast_select: FastSelect,
    queue: Queue<T>,
}
160
/// Protocol negotiation state shared by all channels of one device.
struct Protocol {
    /// Current negotiation state; advanced only by the primary channel.
    state: RwLock<ProtocolState>,
    /// Notified (for all listeners) once `state` becomes `Ready`; subchannel
    /// workers wait on it before processing packets.
    ready: event_listener::Event,
}
166
/// Per-channel packet-processing state.
///
/// Kept apart from the ring queue: methods such as `process_ready` receive the
/// queue as a separate argument.
struct WorkerInner {
    protocol: Arc<Protocol>,
    /// Request payload size used for outgoing packets. Starts at the V1 size
    /// and is updated when a protocol version is negotiated.
    request_size: usize,
    controller: Arc<ScsiControllerState>,
    /// Index of this worker's channel; 0 is the primary channel.
    channel_index: u16,
    scsi_queue: Arc<ScsiCommandQueue>,
    /// In-flight SCSI operations.
    scsi_requests: FuturesUnordered<ScsiRequest>,
    /// Per-request state, keyed by the request id carried in `ScsiRequest`.
    scsi_requests_states: Slab<ScsiRequestState>,
    /// Recycled request objects reused by `parse_packet`.
    full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
    /// Recycled future storage from completed requests.
    future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
    channel_control: ChannelControl,
    /// Maximum number of requests in flight at once (always at least 1).
    max_io_queue_depth: usize,
    stats: WorkerStats,
}

/// I/O statistics for a channel worker, exposed through inspect.
#[derive(Debug, Default, Inspect)]
struct WorkerStats {
    ios_submitted: Counter,
    ios_completed: Counter,
    /// Number of times `poll_process_ready` ran.
    wakes: Counter,
    /// Wakes that neither submitted nor completed any I/O.
    wakes_spurious: Counter,
    per_wake_submissions: Histogram<10>,
    per_wake_completions: Histogram<10>,
}
191
/// Supported storvsp protocol versions.
///
/// Each discriminant is the version's wire-format `major_minor` value.
/// Ordering comparisons (e.g. `version >= Version::Win7`) gate
/// version-dependent features.
#[repr(u16)]
#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
enum Version {
    Win6 = storvsp_protocol::VERSION_WIN6,
    Win7 = storvsp_protocol::VERSION_WIN7,
    Win8 = storvsp_protocol::VERSION_WIN8,
    Blue = storvsp_protocol::VERSION_BLUE,
}

/// A guest-requested protocol version that is not one of [`Version`].
#[derive(Debug, Error)]
#[error("protocol version {0:#x} not supported")]
struct UnsupportedVersion(u16);
204
205impl Version {
206 fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
207 let version = match major_minor {
208 storvsp_protocol::VERSION_WIN6 => Self::Win6,
209 storvsp_protocol::VERSION_WIN7 => Self::Win7,
210 storvsp_protocol::VERSION_WIN8 => Self::Win8,
211 storvsp_protocol::VERSION_BLUE => Self::Blue,
212 version => return Err(UnsupportedVersion(version)),
213 };
214 assert_eq!(version as u16, major_minor);
215 Ok(version)
216 }
217
218 fn max_request_size(&self) -> usize {
219 match self {
220 Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
221 Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
222 }
223 }
224}
225
/// Negotiation state shared by all channels of the device.
#[derive(Copy, Clone)]
enum ProtocolState {
    /// The primary channel is still negotiating.
    Init(InitState),
    /// Negotiation finished; I/O is processed on all channels.
    Ready {
        version: Version,
        /// Number of subchannels the guest created (0 if none).
        subchannel_count: u16,
    },
}

/// Negotiation steps driven by the primary channel; each names the packet
/// expected next.
#[derive(Copy, Clone, Debug)]
enum InitState {
    /// Expecting `BEGIN_INITIALIZATION`.
    Begin,
    /// Expecting `QUERY_PROTOCOL_VERSION`.
    QueryVersion,
    /// Expecting `QUERY_PROPERTIES`.
    QueryProperties {
        version: Version,
    },
    /// Expecting `END_INITIALIZATION`, optionally preceded by `CREATE_SUB_CHANNELS`.
    EndInitialization {
        version: Version,
        /// `None` while subchannels may still be requested. `Some` once the
        /// count is fixed; `Some(0)` when multi-channel is unsupported.
        subchannel_count: Option<u16>,
    },
}
247
/// Inline storage for one SCSI operation future, `SCSI_REQUEST_STACK_SIZE`
/// bytes expressed as `u64` words.
// NOTE(review): `/ 8` rounds down if the byte count is ever not a multiple of 8 — confirm
// OversizedBox rejects futures that would not fit.
type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];
/// A pinned, type-erased SCSI operation future living in `ScsiOpStorage`.
type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;

/// Bytes reserved per SCSI operation future: the disk's async stack budget plus
/// 272 bytes of extra headroom (presumably for this crate's wrapping future — TODO confirm).
const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;
263
264struct ScsiRequest {
265 request_id: usize,
266 future: Option<ScsiOpFuture>,
267}
268
269impl ScsiRequest {
270 fn new(
271 request_id: usize,
272 future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
273 ) -> Self {
274 Self {
275 request_id,
276 future: Some(future.into()),
277 }
278 }
279}
280
281impl Future for ScsiRequest {
282 type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);
283
284 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
285 let this = self.get_mut();
286 let future = this.future.as_mut().unwrap().as_mut();
287 let result = std::task::ready!(future.poll(cx));
288 let future = this.future.take().unwrap();
290 Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
291 }
292}
293
/// Errors that end a channel worker's processing loop (logged by `run`).
#[derive(Debug, Error)]
enum WorkerError {
    /// A guest packet could not be parsed.
    #[error("packet error")]
    PacketError(#[source] PacketError),
    /// The underlying VMBus queue failed.
    #[error("queue error")]
    Queue(#[source] queue::Error),
    /// A write found the ring full even though space had been checked beforehand.
    #[error("queue should have enough space but no longer does")]
    NotEnoughSpace,
}

/// Reasons an incoming packet could not be parsed.
#[derive(Debug, Error)]
enum PacketError {
    /// The packet carried no transaction ID, so no completion could be sent.
    #[error("Not transactional")]
    NotTransactional,
    /// The storvsp header named an operation this device does not handle.
    #[error("Unrecognized operation {0:?}")]
    UnrecognizedOperation(storvsp_protocol::Operation),
    /// The packet was a completion rather than a data packet.
    #[error("Invalid packet type")]
    InvalidPacketType,
    /// The external buffer does not cover the request's data transfer length.
    #[error("Invalid data transfer length")]
    InvalidDataTransferLength,
    /// Reading the packet body failed.
    #[error("Access error: {0}")]
    Access(#[source] AccessError),
    /// Reading the packet's external ranges failed.
    #[error("Range error")]
    Range(#[source] ExternalDataError),
}
319
320#[derive(Debug, Default, Clone)]
321struct Range {
322 len: usize,
323 is_write: bool,
324}
325
326impl Range {
327 fn new(buf: &MultiPagedRangeBuf, request: &storvsp_protocol::ScsiRequest) -> Option<Self> {
328 let len = request.data_transfer_length as usize;
329 let is_write = request.data_in != 0;
330 if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
333 return None;
334 }
335 Some(Self { len, is_write })
336 }
337
338 fn buffer<'a>(
339 &'a self,
340 buf: &'a MultiPagedRangeBuf,
341 guest_memory: &'a GuestMemory,
342 ) -> RequestBuffers<'a> {
343 let mut range = buf.first().unwrap_or_else(PagedRange::empty);
344 range.truncate(self.len);
345 RequestBuffers::new(guest_memory, range, self.is_write)
346 }
347}
348
/// A parsed storvsp packet from the guest.
#[derive(Debug)]
struct Packet {
    data: PacketData,
    /// Transaction ID echoed in the completion.
    transaction_id: u64,
    /// Bytes following the storvsp header in the incoming packet, capped at
    /// `SCSI_REQUEST_LEN_MAX`; `send_completion` sizes its reply to match.
    request_size: usize,
}

/// The operation-specific contents of a [`Packet`].
#[derive(Debug)]
enum PacketData {
    BeginInitialization,
    EndInitialization,
    /// The guest's requested `major_minor` protocol version.
    QueryProtocolVersion(u16),
    QueryProperties,
    /// Number of subchannels the guest asks to create.
    CreateSubChannels(u16),
    /// A SCSI request together with its validated data buffer.
    ExecuteScsi(Arc<ScsiRequestAndRange>),
    ResetBus,
    ResetAdapter,
    ResetLun,
}

// NOTE(review): not referenced in the visible portion of this file — confirm it is still used.
#[derive(Debug)]
pub struct RangeError;
371
372fn parse_packet<T: RingMem>(
373 packet: &IncomingPacket<'_, T>,
374 pool: &mut Vec<Arc<ScsiRequestAndRange>>,
375) -> Result<Packet, PacketError> {
376 let packet = match packet {
377 IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
378 IncomingPacket::Data(packet) => packet,
379 };
380 let transaction_id = packet
381 .transaction_id()
382 .ok_or(PacketError::NotTransactional)?;
383
384 let mut reader = packet.reader();
385 let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
386 let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
390 let data = match header.operation {
391 storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
392 storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
393 storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
394 let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
395 reader
396 .read(version.as_mut_bytes())
397 .map_err(PacketError::Access)?;
398 PacketData::QueryProtocolVersion(version.major_minor)
399 }
400 storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
401 storvsp_protocol::Operation::EXECUTE_SRB => {
402 let mut full_request = pool.pop().unwrap_or_else(|| {
403 Arc::new(ScsiRequestAndRange {
404 external_data: Range::default(),
405 external_data_buf: MultiPagedRangeBuf::new(),
406 request: storvsp_protocol::ScsiRequest::new_zeroed(),
407 request_size,
408 })
409 });
410
411 {
412 let full_request = Arc::get_mut(&mut full_request).unwrap();
413 let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
414 reader.read(request_buf).map_err(PacketError::Access)?;
415
416 full_request.external_data_buf.clear();
417 packet
418 .read_external_ranges(&mut full_request.external_data_buf)
419 .map_err(PacketError::Range)?;
420
421 full_request.external_data =
422 Range::new(&full_request.external_data_buf, &full_request.request)
423 .ok_or(PacketError::InvalidDataTransferLength)?;
424 }
425
426 PacketData::ExecuteScsi(full_request)
427 }
428 storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
429 storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
430 storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
431 storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
432 let mut sub_channel_count: u16 = 0;
433 reader
434 .read(sub_channel_count.as_mut_bytes())
435 .map_err(PacketError::Access)?;
436 PacketData::CreateSubChannels(sub_channel_count)
437 }
438 _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
439 };
440
441 if let PacketData::ExecuteScsi(_) = data {
442 tracing::trace!(transaction_id, ?data, "parse_packet");
443 } else {
444 tracing::debug!(transaction_id, ?data, "parse_packet");
445 }
446
447 Ok(Packet {
448 data,
449 request_size,
450 transaction_id,
451 })
452}
453
454impl WorkerInner {
455 fn send_vmbus_packet<M: RingMem>(
456 &mut self,
457 writer: &mut queue::WriteBatch<'_, M>,
458 packet_type: OutgoingPacketType<'_>,
459 request_size: usize,
460 transaction_id: u64,
461 operation: storvsp_protocol::Operation,
462 status: storvsp_protocol::NtStatus,
463 payload: &[u8],
464 ) -> Result<(), WorkerError> {
465 let header = storvsp_protocol::Packet {
466 operation,
467 flags: 0,
468 status,
469 };
470
471 let packet_size = size_of_val(&header) + request_size;
472
473 let len = size_of_val(&header) + size_of_val(payload);
478 let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
479 let (payload_bytes, padding_bytes) = if len > packet_size {
480 (&payload[..packet_size - size_of_val(&header)], &[][..])
481 } else {
482 (payload, &padding[..packet_size - len])
483 };
484 assert_eq!(
485 size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
486 packet_size
487 );
488 writer
489 .try_write(&OutgoingPacket {
490 transaction_id,
491 packet_type,
492 payload: &[header.as_bytes(), payload_bytes, padding_bytes],
493 })
494 .map_err(|err| match err {
495 queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
496 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
497 })
498 }
499
500 fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
501 &mut self,
502 writer: &mut queue::WriteHalf<'_, M>,
503 operation: storvsp_protocol::Operation,
504 status: storvsp_protocol::NtStatus,
505 payload: &P,
506 ) -> Result<(), WorkerError> {
507 self.send_vmbus_packet(
508 &mut writer.batched(),
509 OutgoingPacketType::InBandNoCompletion,
510 self.request_size,
511 0,
512 operation,
513 status,
514 payload.as_bytes(),
515 )
516 }
517
518 fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
519 &mut self,
520 writer: &mut queue::WriteHalf<'_, M>,
521 packet: &Packet,
522 status: storvsp_protocol::NtStatus,
523 payload: &P,
524 ) -> Result<(), WorkerError> {
525 self.send_vmbus_packet(
526 &mut writer.batched(),
527 OutgoingPacketType::Completion,
528 packet.request_size,
529 packet.transaction_id,
530 storvsp_protocol::Operation::COMPLETE_IO,
531 status,
532 payload.as_bytes(),
533 )
534 }
535}
536
/// Executes SCSI commands against the controller's disks on behalf of a channel.
struct ScsiCommandQueue {
    controller: Arc<ScsiControllerState>,
    /// Guest memory used to reach request data buffers.
    mem: GuestMemory,
    /// When set, replaces the guest-supplied path id for disk lookups.
    force_path_id: Option<u8>,
}

impl ScsiCommandQueue {
    /// Executes one SCSI request and returns its result.
    ///
    /// Dispatch:
    /// - `REPORT_LUNS` is always answered here.
    /// - Any other command is forwarded to the disk attached at the request's
    ///   path/target/LUN, if there is one.
    /// - Without a disk, `INQUIRY` gets a "logical unit not present" reply and
    ///   everything else fails with `INVALID_LUN`.
    async fn execute_scsi(&self, full_request: &ScsiRequestAndRange) -> ScsiResult {
        let request = &full_request.request;
        // The operation code is the first CDB byte.
        let op = ScsiOp(request.payload[0]);
        let external_data = full_request
            .external_data
            .buffer(&full_request.external_data_buf, &self.mem);

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            "execute_scsi start...",
        );

        let path_id = self.force_path_id.unwrap_or(request.path_id);

        let controller_disk = self
            .controller
            .disks
            .read()
            .get(&ScsiPath {
                path: path_id,
                target: request.target_id,
                lun: request.lun,
            })
            .cloned();

        let result = match op {
            ScsiOp::REPORT_LUNS => {
                const HEADER_SIZE: usize = size_of::<scsi::LunList>();
                let mut luns: Vec<u8> = self
                    .controller
                    .disks
                    .read()
                    .iter()
                    .flat_map(|(path, _)| {
                        // Compares against the guest-supplied path id, not the possibly forced
                        // `path_id`. NOTE(review): presumably to avoid reporting duplicate LUNs
                        // when a path is forced — confirm.
                        if request.path_id == path.path && request.target_id == path.target {
                            Some(path.lun)
                        } else {
                            None
                        }
                    })
                    .collect();
                luns.sort_unstable();
                // Entry 0 holds the `LunList` header; each following 8-byte entry holds one LUN.
                let mut data: Vec<u64> = vec![0; luns.len() + 1];
                let header = scsi::LunList {
                    length: (luns.len() as u32 * 8).into(),
                    reserved: [0; 4],
                };
                data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
                for (i, lun) in luns.iter().enumerate() {
                    // The LUN goes in the first two bytes of its entry, big-endian.
                    data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
                }
                if external_data.len() >= HEADER_SIZE {
                    // Copy as much of the list as the guest buffer can hold.
                    let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
                    external_data.writer().write(&data.as_bytes()[..tx]).map_or(
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        },
                        |_| ScsiResult {
                            scsi_status: ScsiStatus::GOOD,
                            srb_status: SrbStatus::SUCCESS,
                            tx,
                            sense_data: None,
                        },
                    )
                } else {
                    // A buffer too small for the header succeeds with nothing transferred.
                    ScsiResult {
                        scsi_status: ScsiStatus::GOOD,
                        srb_status: SrbStatus::SUCCESS,
                        tx: 0,
                        sense_data: None,
                    }
                }
            }
            _ if controller_disk.is_some() => {
                // Forward everything else to the attached disk.
                let mut cdb = [0; 16];
                cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
                controller_disk
                    .unwrap()
                    .disk
                    .execute_scsi(
                        &external_data,
                        &Request {
                            cdb,
                            srb_flags: request.srb_flags,
                        },
                    )
                    .await
            }
            ScsiOp::INQUIRY => {
                // No disk at this address: answer standard INQUIRY with a
                // "logical unit not present" device type.
                let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
                    .unwrap()
                    .0;
                if external_data.len() < cdb.allocation_length.get() as usize
                    || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
                    || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
                {
                    ScsiResult {
                        scsi_status: ScsiStatus::CHECK_CONDITION,
                        srb_status: SrbStatus::INVALID_REQUEST,
                        tx: 0,
                        sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
                    }
                } else {
                    let enable_vpd = cdb.flags.vpd();
                    if enable_vpd || cdb.page_code != 0 {
                        // VPD pages are not emulated for a missing LUN.
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        }
                    } else {
                        const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
                        let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
                        data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;

                        if request.lun != 0 {
                            // Identification strings are blanked for nonzero LUNs.
                            data.vendor_id = [0; 8];
                            data.product_id = [0; 16];
                            data.product_revision_level = [0; 4];
                        }

                        let datab = data.as_bytes();
                        let tx = std::cmp::min(
                            cdb.allocation_length.get() as usize,
                            size_of::<scsi::InquiryData>(),
                        );
                        external_data.writer().write(&datab[..tx]).map_or(
                            ScsiResult {
                                scsi_status: ScsiStatus::CHECK_CONDITION,
                                srb_status: SrbStatus::INVALID_REQUEST,
                                tx: 0,
                                sense_data: Some(illegal_request_sense(
                                    AdditionalSenseCode::INVALID_CDB,
                                )),
                            },
                            |_| ScsiResult {
                                scsi_status: ScsiStatus::GOOD,
                                srb_status: SrbStatus::SUCCESS,
                                tx,
                                sense_data: None,
                            },
                        )
                    }
                }
            }
            // Any other command addressed to a missing LUN.
            _ => ScsiResult {
                scsi_status: ScsiStatus::CHECK_CONDITION,
                srb_status: SrbStatus::INVALID_LUN,
                tx: 0,
                sense_data: None,
            },
        };

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            result = ?result,
            "execute_scsi completed.",
        );
        result
    }
}
723
724impl<T: RingMem + 'static> Worker<T> {
725 fn new(
726 controller: Arc<ScsiControllerState>,
727 channel: RawAsyncChannel<T>,
728 channel_index: u16,
729 mem: GuestMemory,
730 channel_control: ChannelControl,
731 io_queue_depth: u32,
732 protocol: Arc<Protocol>,
733 force_path_id: Option<u8>,
734 ) -> anyhow::Result<Self> {
735 let queue = Queue::new(channel)?;
736 #[expect(clippy::disallowed_methods)] let (source, target) = futures::channel::mpsc::channel(1);
738 controller.add_rescan_notification_source(source);
739
740 let max_io_queue_depth = io_queue_depth.max(1) as usize;
741 Ok(Self {
742 inner: WorkerInner {
743 protocol,
744 request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
745 controller: controller.clone(),
746 channel_index,
747 scsi_queue: Arc::new(ScsiCommandQueue {
748 controller,
749 mem,
750 force_path_id,
751 }),
752 scsi_requests: FuturesUnordered::new(),
753 scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
754 channel_control,
755 max_io_queue_depth,
756 future_pool: Vec::new(),
757 full_request_pool: Vec::new(),
758 stats: Default::default(),
759 },
760 queue,
761 rescan_notification: target,
762 fast_select: FastSelect::new(),
763 })
764 }
765
766 async fn wait_for_scsi_requests_complete(&mut self) {
767 tracing::debug!(
768 channel_index = self.inner.channel_index,
769 "wait for IOs completed..."
770 );
771 while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
772 self.inner.scsi_requests_states.remove(id);
773 }
774 }
775}
776
777impl InspectTask<Worker> for WorkerState {
778 fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
779 if let Some(worker) = worker {
780 let mut resp = req.respond();
781 if worker.inner.channel_index == 0 {
782 let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
783 ProtocolState::Init(state) => match state {
784 InitState::Begin => ("begin_init", None, None),
785 InitState::QueryVersion => ("query_version", None, None),
786 InitState::QueryProperties { version } => {
787 ("query_properties", Some(version), None)
788 }
789 InitState::EndInitialization {
790 version,
791 subchannel_count,
792 } => ("end_init", Some(version), subchannel_count),
793 },
794 ProtocolState::Ready {
795 version,
796 subchannel_count,
797 } => ("ready", Some(version), Some(subchannel_count)),
798 };
799 resp.field("state", state)
800 .field("version", version)
801 .field("subchannel_count", subchannel_count);
802 }
803 resp.field("pending_packets", worker.inner.scsi_requests_states.len())
804 .fields("io", worker.inner.scsi_requests_states.iter())
805 .field("stats", &worker.inner.stats)
806 .field("ring", &worker.queue)
807 .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
808 }
809 }
810}
811
impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
    /// Runs the worker until it is stopped or its processing loop ends.
    ///
    /// Channel 0 (the primary) drives protocol negotiation and then I/O.
    /// Subchannels wait for negotiation to reach `Ready`, then process I/O at
    /// the negotiated version.
    ///
    /// A worker error is logged rather than returned; only cancellation through
    /// `stop` is reported to the caller.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        let fut = async {
            if worker.inner.channel_index == 0 {
                worker.process_primary().await
            } else {
                let protocol_version = loop {
                    // Register the listener before checking the state, so that a
                    // `ready` notification sent between the check and the await
                    // is not missed.
                    let listener = worker.inner.protocol.ready.listen();
                    if let ProtocolState::Ready { version, .. } =
                        *worker.inner.protocol.state.read()
                    {
                        break version;
                    }
                    tracing::debug!("subchannel waiting for initialization to end");
                    listener.await
                };
                worker
                    .inner
                    .process_ready(&mut worker.queue, protocol_version)
                    .await
            }
        };

        match stop.until_stopped(fut).await? {
            Ok(_) => {}
            Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
        }
        Ok(())
    }
}
848
849impl WorkerInner {
850 async fn next_packet<'a, M: RingMem>(
853 &mut self,
854 reader: &'a mut queue::ReadHalf<'a, M>,
855 ) -> Result<Packet, WorkerError> {
856 let packet = reader.read().await.map_err(WorkerError::Queue)?;
857 let stor_packet =
858 parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
859 Ok(stor_packet)
860 }
861
862 fn poll_for_ring_space<M: RingMem>(
869 &mut self,
870 cx: &mut Context<'_>,
871 writer: &mut queue::WriteHalf<'_, M>,
872 ) -> Poll<Result<(), WorkerError>> {
873 writer
874 .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
875 .map_err(WorkerError::Queue)
876 }
877}
878
/// Ring space needed for the largest in-band storvsp packet: the storvsp header
/// plus the largest SCSI request payload.
const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
    size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
);
882
impl<T: RingMem> Worker<T> {
    /// Drives the primary channel (channel 0).
    ///
    /// Negotiation:
    /// - walks the sequence tracked in the shared `ProtocolState`, completing
    ///   each expected packet;
    /// - completes any out-of-order packet with `INVALID_DEVICE_STATE`.
    ///
    /// Once `Ready`:
    /// - processes I/O;
    /// - for `Win7` and later, turns controller rescan notifications into
    ///   `ENUMERATE_BUS` packets to the guest.
    ///
    /// Returns when `process_ready` finishes or a send fails.
    async fn process_primary(&mut self) -> Result<(), WorkerError> {
        loop {
            let current_state = *self.inner.protocol.state.read();
            match current_state {
                ProtocolState::Ready { version, .. } => {
                    // The `process_ready` future is recreated on every iteration, so handling a
                    // rescan notification does not end I/O processing.
                    break loop {
                        select_biased! {
                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
                                if version >= Version::Win7
                                {
                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
                                }
                            }
                        }
                    };
                }
                ProtocolState::Init(state) => {
                    let (mut reader, mut writer) = self.queue.split();

                    // Ensure a completion can be written before consuming the next packet.
                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;

                    tracing::debug!(?state, "process_primary");
                    match state {
                        InitState::Begin => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::BeginInitialization = packet.data {
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &(),
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::QueryVersion);
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryVersion => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
                                if let Ok(version) = Version::parse(major_minor) {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    // From here on, outgoing packets use this version's request size.
                                    self.inner.request_size = version.max_request_size();
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryProperties { version });

                                    tracelimit::info_ratelimited!(
                                        ?version,
                                        "scsi version negotiated"
                                    );
                                } else {
                                    // Unsupported version: reject it and let the guest try another.
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryVersion);
                                }
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryProperties { version } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProperties = packet.data {
                                let multi_channel_supported = version >= Version::Win8;

                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &storvsp_protocol::ChannelProperties {
                                        max_transfer_bytes: 0x40000, // 256 KiB
                                        flags: {
                                            if multi_channel_supported {
                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
                                            } else {
                                                0
                                            }
                                        },
                                        maximum_sub_channel_count: if multi_channel_supported {
                                            self.inner.channel_control.max_subchannels()
                                        } else {
                                            0
                                        },
                                        reserved: 0,
                                        reserved2: 0,
                                        reserved3: [0, 0],
                                    },
                                )?;
                                // Without multi-channel support the subchannel count is fixed at 0;
                                // otherwise the guest may still send CREATE_SUB_CHANNELS.
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::EndInitialization {
                                        version,
                                        subchannel_count: if multi_channel_supported {
                                            None
                                        } else {
                                            Some(0)
                                        },
                                    });
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            match packet.data {
                                // Accepted only while the subchannel count is still open.
                                PacketData::CreateSubChannels(sub_channel_count)
                                    if subchannel_count.is_none() =>
                                {
                                    if let Err(err) = self
                                        .inner
                                        .channel_control
                                        .enable_subchannels(sub_channel_count)
                                    {
                                        tracelimit::warn_ratelimited!(
                                            ?err,
                                            "cannot enable subchannels"
                                        );
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
                                            &(),
                                        )?;
                                    } else {
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::SUCCESS,
                                            &(),
                                        )?;
                                        *self.inner.protocol.state.write() =
                                            ProtocolState::Init(InitState::EndInitialization {
                                                version,
                                                subchannel_count: Some(sub_channel_count),
                                            });
                                    }
                                }
                                PacketData::EndInitialization => {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &(),
                                    )?;
                                    // Discard any rescan notification queued before `Ready`.
                                    // NOTE(review): presumably because the guest enumerates the
                                    // bus itself after initialization — confirm.
                                    self.rescan_notification.try_next().ok();
                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
                                        version,
                                        subchannel_count: subchannel_count.unwrap_or(0),
                                    };
                                    // Release subchannel workers waiting in `run`.
                                    self.inner.protocol.ready.notify(usize::MAX);
                                }
                                _ => {
                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                        &(),
                                    )?;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
1096
1097fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
1098 match srb_status {
1099 SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
1100 SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
1101 SrbStatus::INVALID_LUN
1102 | SrbStatus::INVALID_TARGET_ID
1103 | SrbStatus::NO_DEVICE
1104 | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
1105 SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
1106 SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
1107 SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
1108 storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
1109 }
1110 SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
1111 SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
1112 SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
1113 _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
1114 }
1115}
1116
1117impl WorkerInner {
1118 async fn process_ready<M: RingMem>(
1120 &mut self,
1121 queue: &mut Queue<M>,
1122 protocol_version: Version,
1123 ) -> Result<(), WorkerError> {
1124 self.request_size = protocol_version.max_request_size();
1125 poll_fn(|cx| self.poll_process_ready(cx, queue)).await
1126 }
1127
    /// Polls the steady-state I/O loop once.
    ///
    /// Each pass sends completions for finished requests, then submits newly
    /// arrived packets. Passes repeat while new requests were submitted.
    ///
    /// Returns `Poll::Pending` unless an error propagates via `?`. The ring and
    /// the in-flight request futures are polled with `cx`.
    ///
    /// Reading depends on `poll_mode_queue_depth`:
    /// - below it, the ring is read with `poll_read_batch`;
    /// - at or above it, only `try_read_batch` is used (the trace message calls
    ///   this "keeping interrupts masked").
    fn poll_process_ready<M: RingMem>(
        &mut self,
        cx: &mut Context<'_>,
        queue: &mut Queue<M>,
    ) -> Poll<Result<(), WorkerError>> {
        self.stats.wakes.increment();

        let (mut reader, mut writer) = queue.split();
        let mut total_completions = 0;
        let mut total_submissions = 0;
        let poll_mode_queue_depth = self.controller.poll_mode_queue_depth.load(Relaxed) as usize;

        loop {
            // Phase 1: collect finished requests while their completions fit in the ring.
            'outer: while !self.scsi_requests_states.is_empty() {
                {
                    let mut batch = writer.batched();
                    loop {
                        // Check for room before taking a completion, so that a finished
                        // request is never taken without space to report it.
                        if !batch
                            .can_write(MAX_VMBUS_PACKET_SIZE)
                            .map_err(WorkerError::Queue)?
                        {
                            break;
                        }
                        if let Poll::Ready(Some((request_id, result, future))) =
                            self.scsi_requests.poll_next_unpin(cx)
                        {
                            // Recycle the future storage for later requests.
                            self.future_pool.push(future);
                            self.handle_completion(&mut batch, request_id, result)?;
                            total_completions += 1;
                        } else {
                            tracing::trace!("out of completions");
                            break 'outer;
                        }
                    }
                }

                // The batch ran out of room; wait for ring space before collecting more.
                if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                    tracing::trace!("out of ring space");
                    break;
                }
            }

            let mut submissions = 0;
            // Phase 2: read and submit new packets, up to the I/O queue depth.
            'outer: loop {
                if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                    break;
                }
                // Below the poll-mode threshold, poll the ring for packets. At or above it,
                // only try a read. NOTE(review): presumably the in-flight requests' completions
                // are relied on to wake this task in that case — confirm.
                let mut batch = if self.scsi_requests_states.len() < poll_mode_queue_depth {
                    if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
                        batch.map_err(WorkerError::Queue)?
                    } else {
                        tracing::trace!("out of incoming packets");
                        break;
                    }
                } else {
                    match reader.try_read_batch() {
                        Ok(batch) => batch,
                        Err(queue::TryReadError::Empty) => {
                            tracing::trace!(
                                pending_io_count = self.scsi_requests_states.len(),
                                "out of incoming packets, keeping interrupts masked"
                            );
                            break;
                        }
                        Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
                    }
                };

                let mut packets = batch.packets();
                loop {
                    if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                        break 'outer;
                    }
                    // Each packet may need a reply; ensure room before consuming it.
                    if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                        tracing::trace!("out of ring space");
                        break 'outer;
                    }

                    let packet = if let Some(packet) = packets.next() {
                        packet.map_err(WorkerError::Queue)?
                    } else {
                        break;
                    };

                    // NOTE(review): `true` presumably means a SCSI request was queued — confirm
                    // against `handle_packet`.
                    if self.handle_packet(&mut writer, &packet)? {
                        submissions += 1;
                    }
                }
            }

            // With no new submissions there is nothing more to do this wake; otherwise go
            // around again so the new requests are polled.
            if submissions == 0 {
                break;
            }
            total_submissions += submissions;
        }

        if total_submissions != 0 || total_completions != 0 {
            self.stats.ios_submitted.add(total_submissions);
            self.stats
                .per_wake_submissions
                .add_sample(total_submissions);
            self.stats
                .per_wake_completions
                .add_sample(total_completions);
            self.stats.ios_completed.add(total_completions);
        } else {
            self.stats.wakes_spurious.increment();
        }

        Poll::Pending
    }
1252
1253 fn handle_completion<M: RingMem>(
1254 &mut self,
1255 writer: &mut queue::WriteBatch<'_, M>,
1256 request_id: usize,
1257 result: ScsiResult,
1258 ) -> Result<(), WorkerError> {
1259 let state = self.scsi_requests_states.remove(request_id);
1260 let request_size = state.request.request_size;
1261
1262 assert_eq!(
1264 Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
1265 1
1266 );
1267 self.full_request_pool.push(state.request);
1268
1269 let status = convert_srb_status_to_nt_status(result.srb_status);
1270 let mut payload = [0; 0x14];
1271 if let Some(sense) = result.sense_data {
1272 payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
1273 tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
1274 };
1275 let response = storvsp_protocol::ScsiRequest {
1276 length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
1277 scsi_status: result.scsi_status,
1278 srb_status: SrbStatusAndFlags::new()
1279 .with_status(result.srb_status)
1280 .with_autosense_valid(result.sense_data.is_some()),
1281 data_transfer_length: result.tx as u32,
1282 cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
1283 sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
1284 payload,
1285 ..storvsp_protocol::ScsiRequest::new_zeroed()
1286 };
1287 self.send_vmbus_packet(
1288 writer,
1289 OutgoingPacketType::Completion,
1290 request_size,
1291 state.transaction_id,
1292 storvsp_protocol::Operation::COMPLETE_IO,
1293 status,
1294 response.as_bytes(),
1295 )?;
1296 Ok(())
1297 }
1298
1299 fn handle_packet<M: RingMem>(
1300 &mut self,
1301 writer: &mut queue::WriteHalf<'_, M>,
1302 packet: &IncomingPacket<'_, M>,
1303 ) -> Result<bool, WorkerError> {
1304 let packet =
1305 parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
1306 let submitted_io = match packet.data {
1307 PacketData::ExecuteScsi(request) => {
1308 self.push_scsi_request(packet.transaction_id, request);
1309 true
1310 }
1311 PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
1312 self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
1314 false
1315 }
1316 PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
1317 if let Err(err) = self
1318 .channel_control
1319 .enable_subchannels(new_subchannel_count)
1320 {
1321 tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
1322 self.send_completion(
1323 writer,
1324 &packet,
1325 storvsp_protocol::NtStatus::INVALID_PARAMETER,
1326 &(),
1327 )?;
1328 false
1329 } else {
1330 if let ProtocolState::Ready {
1332 subchannel_count, ..
1333 } = &mut *self.protocol.state.write()
1334 {
1335 *subchannel_count = new_subchannel_count;
1336 } else {
1337 unreachable!()
1338 }
1339
1340 self.send_completion(
1341 writer,
1342 &packet,
1343 storvsp_protocol::NtStatus::SUCCESS,
1344 &(),
1345 )?;
1346 false
1347 }
1348 }
1349 _ => {
1350 tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
1351 self.send_completion(
1352 writer,
1353 &packet,
1354 storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
1355 &(),
1356 )?;
1357 false
1358 }
1359 };
1360 Ok(submitted_io)
1361 }
1362
1363 fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
1364 let scsi_queue = self.scsi_queue.clone();
1365 let scsi_request_state = ScsiRequestState {
1366 transaction_id,
1367 request: full_request.clone(),
1368 };
1369 let request_id = self.scsi_requests_states.insert(scsi_request_state);
1370 let future = self
1371 .future_pool
1372 .pop()
1373 .unwrap_or_else(|| OversizedBox::new(()));
1374 let future = OversizedBox::refill(future, async move {
1375 scsi_queue.execute_scsi(full_request.as_ref()).await
1376 });
1377 let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
1378 self.scsi_requests.push(request);
1379 }
1380}
1381
1382impl<T: RingMem> Drop for Worker<T> {
1383 fn drop(&mut self) {
1384 self.inner
1385 .controller
1386 .remove_rescan_notification_source(&self.rescan_notification);
1387 }
1388}
1389
1390#[derive(Debug, Error)]
1391#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
1392pub struct ScsiPathInUse(pub ScsiPath);
1393
1394#[derive(Debug, Error)]
1395#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
1396pub struct ScsiPathNotInUse(ScsiPath);
1397
1398#[derive(Clone)]
1399struct ScsiRequestState {
1400 transaction_id: u64,
1401 request: Arc<ScsiRequestAndRange>,
1402}
1403
/// A parsed SCSI request together with the guest data-buffer description it
/// refers to.
#[derive(Debug)]
struct ScsiRequestAndRange {
    /// Description of the request's external (guest memory) data buffer.
    /// NOTE(review): exact semantics of `Range` are defined elsewhere — confirm.
    external_data: Range,
    /// Presumably backing storage for the paged GPA ranges of `external_data`
    /// — TODO confirm against `parse_packet`.
    external_data_buf: MultiPagedRangeBuf,
    /// The storvsp `ScsiRequest` carried by the guest's packet.
    request: storvsp_protocol::ScsiRequest,
    /// Request size used when sending the completion packet for this request.
    request_size: usize,
}
1411
1412impl Inspect for ScsiRequestState {
1413 fn inspect(&self, req: inspect::Request<'_>) {
1414 req.respond()
1415 .field("transaction_id", self.transaction_id)
1416 .display(
1417 "address",
1418 &ScsiPath {
1419 path: self.request.request.path_id,
1420 target: self.request.request.target_id,
1421 lun: self.request.request.lun,
1422 },
1423 )
1424 .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1425 }
1426}
1427
1428impl StorageDevice {
1429 pub fn build_scsi(
1431 driver_source: &VmTaskDriverSource,
1432 controller: &ScsiController,
1433 instance_id: Guid,
1434 max_sub_channel_count: u16,
1435 io_queue_depth: u32,
1436 ) -> Self {
1437 Self::build_inner(
1438 driver_source,
1439 controller,
1440 instance_id,
1441 None,
1442 max_sub_channel_count,
1443 io_queue_depth,
1444 )
1445 }
1446
1447 pub fn build_ide(
1450 driver_source: &VmTaskDriverSource,
1451 channel_id: u8,
1452 device_id: u8,
1453 disk: ScsiControllerDisk,
1454 io_queue_depth: u32,
1455 ) -> Self {
1456 let path = ScsiPath {
1457 path: channel_id,
1458 target: device_id,
1459 lun: 0,
1460 };
1461
1462 let controller = ScsiController::new();
1463 controller.attach(path, disk).unwrap();
1464
1465 let instance_id = Guid {
1468 data1: channel_id.into(),
1469 data2: device_id.into(),
1470 data3: 0x8899,
1471 data4: [0; 8],
1472 };
1473 Self::build_inner(
1474 driver_source,
1475 &controller,
1476 instance_id,
1477 Some(path),
1478 0,
1479 io_queue_depth,
1480 )
1481 }
1482
1483 fn build_inner(
1484 driver_source: &VmTaskDriverSource,
1485 controller: &ScsiController,
1486 instance_id: Guid,
1487 ide_path: Option<ScsiPath>,
1488 max_sub_channel_count: u16,
1489 io_queue_depth: u32,
1490 ) -> Self {
1491 let workers = (0..max_sub_channel_count + 1)
1492 .map(|channel_index| WorkerAndDriver {
1493 worker: TaskControl::new(WorkerState),
1494 driver: driver_source
1495 .builder()
1496 .target_vp(0)
1497 .run_on_target(true)
1498 .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1499 })
1500 .collect();
1501
1502 Self {
1503 instance_id,
1504 ide_path,
1505 workers,
1506 controller: controller.state.clone(),
1507 resources: Default::default(),
1508 max_sub_channel_count,
1509 driver_source: driver_source.clone(),
1510 protocol: Arc::new(Protocol {
1511 state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1512 ready: Default::default(),
1513 }),
1514 io_queue_depth,
1515 }
1516 }
1517
1518 fn new_worker(
1519 &mut self,
1520 open_request: &OpenRequest,
1521 channel_index: u16,
1522 ) -> anyhow::Result<&mut Worker> {
1523 let controller = self.controller.clone();
1524
1525 let driver = self
1526 .driver_source
1527 .builder()
1528 .target_vp(open_request.open_data.target_vp)
1529 .run_on_target(true)
1530 .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1531
1532 let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1533 .context("failed to create vmbus channel")?;
1534
1535 let channel_control = self.resources.channel_control.clone();
1536
1537 tracing::debug!(
1538 target_vp = open_request.open_data.target_vp,
1539 channel_index,
1540 "packet processing starting...",
1541 );
1542
1543 let force_path_id = self.ide_path.map(|p| p.path);
1547
1548 let worker = Worker::new(
1549 controller,
1550 channel,
1551 channel_index,
1552 self.resources
1553 .offer_resources
1554 .guest_memory(open_request)
1555 .clone(),
1556 channel_control,
1557 self.io_queue_depth,
1558 self.protocol.clone(),
1559 force_path_id,
1560 )
1561 .map_err(RestoreError::Other)?;
1562
1563 self.workers[channel_index as usize]
1564 .driver
1565 .retarget_vp(open_request.open_data.target_vp);
1566
1567 Ok(self.workers[channel_index as usize].worker.insert(
1568 &driver,
1569 format!("storvsp worker {}-{}", self.instance_id, channel_index),
1570 worker,
1571 ))
1572 }
1573}
1574
/// A disk that can be attached to a [`ScsiController`].
///
/// Cloning is cheap: clones share the same underlying [`AsyncScsiDisk`]
/// through an `Arc`.
#[derive(Clone)]
pub struct ScsiControllerDisk {
    disk: Arc<dyn AsyncScsiDisk>,
}
1580
1581impl ScsiControllerDisk {
1582 pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1584 Self { disk }
1585 }
1586}
1587
/// State shared by all handles to one [`ScsiController`] and by the workers
/// serving it.
struct ScsiControllerState {
    /// Attached disks, keyed by SCSI address.
    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
    /// Senders signaled (best effort) whenever a disk is attached or removed.
    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
    /// Outstanding-I/O count at or above which workers only read incoming
    /// packets that are already queued, instead of waiting for a wakeup.
    poll_mode_queue_depth: AtomicU32,
}
1593
/// A set of disks addressed by [`ScsiPath`] that storvsp devices expose to
/// the guest.
///
/// Disks may be attached and removed at runtime; registered rescan listeners
/// are signaled on each change.
pub struct ScsiController {
    state: Arc<ScsiControllerState>,
}
1597
1598impl ScsiController {
1599 pub fn new() -> Self {
1600 Self::new_with_poll_mode_queue_depth(None)
1601 }
1602
1603 pub fn new_with_poll_mode_queue_depth(poll_mode_queue_depth: Option<u32>) -> Self {
1604 Self {
1605 state: Arc::new(ScsiControllerState {
1606 disks: Default::default(),
1607 rescan_notification_source: Mutex::new(Vec::new()),
1608 poll_mode_queue_depth: AtomicU32::new(
1609 poll_mode_queue_depth.unwrap_or(DEFAULT_POLL_MODE_QUEUE_DEPTH),
1610 ),
1611 }),
1612 }
1613 }
1614
1615 pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1616 match self.state.disks.write().entry(path) {
1617 Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1618 Entry::Vacant(entry) => entry.insert(disk),
1619 };
1620 for source in self.state.rescan_notification_source.lock().iter_mut() {
1621 source.try_send(()).ok();
1624 }
1625 Ok(())
1626 }
1627
1628 pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1629 match self.state.disks.write().entry(path) {
1630 Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1631 Entry::Occupied(entry) => {
1632 entry.remove();
1633 }
1634 }
1635 for source in self.state.rescan_notification_source.lock().iter_mut() {
1636 source.try_send(()).ok();
1639 }
1640 Ok(())
1641 }
1642}
1643
1644impl ScsiControllerState {
1645 fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1646 self.rescan_notification_source.lock().push(source);
1647 }
1648
1649 fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1650 let mut sources = self.rescan_notification_source.lock();
1651 if let Some(index) = sources
1652 .iter()
1653 .position(|source| source.is_connected_to(target))
1654 {
1655 sources.remove(index);
1656 }
1657 }
1658}
1659
1660#[async_trait]
1661impl VmbusDevice for StorageDevice {
1662 fn offer(&self) -> OfferParams {
1663 if let Some(path) = self.ide_path {
1664 let offer_properties = storvsp_protocol::OfferProperties {
1665 path_id: path.path,
1666 target_id: path.target,
1667 flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
1668 ..FromZeros::new_zeroed()
1669 };
1670 let mut user_defined = UserDefinedData::new_zeroed();
1671 offer_properties
1672 .write_to_prefix(&mut user_defined[..])
1673 .unwrap();
1674 OfferParams {
1675 interface_name: "ide-accel".to_owned(),
1676 instance_id: self.instance_id,
1677 interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
1678 channel_type: ChannelType::Interface { user_defined },
1679 ..Default::default()
1680 }
1681 } else {
1682 OfferParams {
1683 interface_name: "scsi".to_owned(),
1684 instance_id: self.instance_id,
1685 interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
1686 ..Default::default()
1687 }
1688 }
1689 }
1690
1691 fn max_subchannels(&self) -> u16 {
1692 self.max_sub_channel_count
1693 }
1694
1695 fn install(&mut self, resources: DeviceResources) {
1696 self.resources = resources;
1697 }
1698
1699 async fn open(
1700 &mut self,
1701 channel_index: u16,
1702 open_request: &OpenRequest,
1703 ) -> Result<(), ChannelOpenError> {
1704 tracing::debug!(channel_index, "scsi open channel");
1705 self.new_worker(open_request, channel_index)?;
1706 self.workers[channel_index as usize].worker.start();
1707 Ok(())
1708 }
1709
1710 async fn close(&mut self, channel_index: u16) {
1711 tracing::debug!(channel_index, "scsi close channel");
1712 let worker = &mut self.workers[channel_index as usize].worker;
1713 worker.stop().await;
1714 if worker.state_mut().is_some() {
1715 worker
1716 .state_mut()
1717 .unwrap()
1718 .wait_for_scsi_requests_complete()
1719 .await;
1720 worker.remove();
1721 }
1722 if channel_index == 0 {
1723 *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
1724 }
1725 }
1726
1727 async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
1728 self.workers[channel_index as usize]
1729 .driver
1730 .retarget_vp(target_vp);
1731 }
1732
1733 fn start(&mut self) {
1734 for task in self
1735 .workers
1736 .iter_mut()
1737 .filter(|task| task.worker.has_state() && !task.worker.is_running())
1738 {
1739 task.worker.start();
1740 }
1741 }
1742
1743 async fn stop(&mut self) {
1744 tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
1745 for task in self
1746 .workers
1747 .iter_mut()
1748 .filter(|task| task.worker.has_state() && task.worker.is_running())
1749 {
1750 task.worker.stop().await;
1751 }
1752 }
1753
1754 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1755 Some(self)
1756 }
1757}
1758
1759#[async_trait]
1760impl SaveRestoreVmbusDevice for StorageDevice {
1761 async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1762 Ok(SavedStateBlob::new(self.save()?))
1763 }
1764
1765 async fn restore(
1766 &mut self,
1767 control: RestoreControl<'_>,
1768 state: SavedStateBlob,
1769 ) -> Result<(), RestoreError> {
1770 self.restore(control, state.parse()?).await
1771 }
1772}
1773
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestWorker;
    use crate::test_helpers::parse_guest_completion;
    use crate::test_helpers::parse_guest_completion_check_flags_status;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use scsi::srb::SrbStatus;
    use test_with_tracing::test;
    use vmbus_channel::connected_async_channels;

    // Test-only `Clone`: the clone shares the same `Arc`'d controller state,
    // so disks attached or removed through one handle are visible through the
    // other (e.g. the copy handed to `TestWorker::start`).
    impl Clone for ScsiController {
        fn clone(&self) -> Self {
            ScsiController {
                state: self.state.clone(),
            }
        }
    }

    // Negotiates the protocol, writes a 4 KiB buffer of 7s to the default
    // path, reads it back into a different GPA, and checks the data matches.
    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        let controller = ScsiController::new();
        let disk = scsidisk::SimpleScsiDisk::new(
            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
            Default::default(),
        );
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                ScsiControllerDisk::new(Arc::new(disk)),
            )
            .unwrap();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Stage the write data in guest memory, then issue the write.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        // Read the same data back into a separate GPA.
        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        guest.verify_graceful_close(test_worker).await;
    }

    // Sends QUERY_PROTOCOL_VERSION packets of several total lengths carrying
    // an unsupported version (`!0`). Each must be answered with
    // REVISION_MISMATCH and a completion of the tabulated length.
    #[async_test]
    async fn test_packet_sizes(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let header = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };

        let mut buf = [0u8; 128];
        storvsp_protocol::ProtocolVersion {
            major_minor: !0,
            reserved: 0,
        }
        .write_to_prefix(&mut buf[..])
        .unwrap();
        // (request length, expected completion length). The table suggests
        // the completion length is the request length rounded up to 8 bytes
        // and capped at 64.
        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
            guest
                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
                .await;

            guest
                .verify_completion(|packet| {
                    let IncomingPacket::Completion(packet) = packet else {
                        unreachable!()
                    };
                    assert_eq!(packet.reader().len(), resp_len);
                    assert_eq!(
                        packet
                            .reader()
                            .read_plain::<storvsp_protocol::Packet>()
                            .unwrap()
                            .status,
                        storvsp_protocol::NtStatus::REVISION_MISMATCH
                    );
                    Ok(())
                })
                .await;
        }
    }

    // Opening negotiation with END_INITIALIZATION instead of
    // BEGIN_INITIALIZATION must be rejected with INVALID_DEVICE_STATE.
    #[async_test]
    async fn test_wrong_first_packet(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            // Deliberately out of order: negotiation has not begun yet.
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
                );
                Ok(())
            })
            .await;
    }

    // A REMOVE_DEVICE operation is not recognized: the worker must terminate
    // with `PacketError::UnrecognizedOperation(REMOVE_DEVICE)`.
    #[async_test]
    async fn test_unrecognized_operation(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            // An operation the worker does not handle.
            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        match worker.teardown().await {
            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
                storvsp_protocol::Operation::REMOVE_DEVICE,
            ))) => {}
            result => panic!("Worker failed with unexpected result {:?}!", result),
        }
    }

    // Performs negotiation up to QUERY_PROPERTIES, then requests one
    // subchannel. The test expects INVALID_PARAMETER, i.e. enabling the
    // subchannel fails in this test setup.
    #[async_test]
    async fn test_too_many_subchannels(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            // The u16 after the header is the requested subchannel count.
            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_PARAMETER
                );
                Ok(())
            })
            .await;
    }

    // After negotiation has completed, another BEGIN_INITIALIZATION must be
    // rejected with INVALID_DEVICE_STATE (flags 0).
    #[async_test]
    async fn test_begin_init_on_ready(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        let negotiate_packet = storvsp_protocol::Packet {
            // Re-initialization attempt on an already-ready channel.
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|p| {
                parse_guest_completion_check_flags_status(
                    p,
                    0,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                )
            })
            .await;
    }

    // Starts with no disks, then hot-adds LUNs 0..4 and hot-removes them.
    // After each change it expects a bus-enumeration notification, checks the
    // REPORT LUNS list length, and checks that a write to the affected LUN
    // succeeds (after add) or fails with INVALID_LUN (after remove).
    #[async_test]
    async fn test_hot_add_remove(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        // No disks are attached initially.
        let controller = ScsiController::new();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // With no disks, REPORT LUNS transfers only the 8-byte header and
        // reports an empty LUN list.
        let mut lun_list_buffer: [u8; 256] = [0; 256];
        let mut disk_count = 0;
        guest
            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
            })
            .await;
        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
        assert_eq!(lun_list_size, disk_count as u32 * 8);

        // A write to a path with no disk must fail with INVALID_LUN.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();

        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
            })
            .await;

        // Hot-add LUNs 0..4 one at a time.
        for lun in 0..4 {
            let disk = scsidisk::SimpleScsiDisk::new(
                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
                Default::default(),
            );
            controller
                .attach(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    ScsiControllerDisk::new(Arc::new(disk)),
                )
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count += 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 256)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
                })
                .await;
        }

        // Hot-remove the same LUNs one at a time.
        for lun in 0..4 {
            controller
                .remove(ScsiPath {
                    path: 0,
                    target: 0,
                    lun,
                })
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count -= 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
                })
                .await;
        }

        guest.verify_graceful_close(test_worker).await;
    }

    // Same write/read round-trip as `test_channel_working`, but the disk is
    // attached before the channel exists and the negotiation steps
    // (BEGIN_INITIALIZATION, QUERY_PROTOCOL_VERSION, QUERY_PROPERTIES,
    // END_INITIALIZATION) are sent explicitly.
    #[async_test]
    pub async fn test_async_disk(driver: DefaultDriver) {
        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
        let controller = ScsiController::new();
        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
            device,
            Default::default(),
        )));
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                disk,
            )
            .unwrap();

        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let test_guest_mem = GuestMemory::allocate(16384);
        let worker = TestWorker::start(
            controller.clone(),
            &driver,
            test_guest_mem.clone(),
            host,
            None,
        );

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        guest.verify_graceful_close(worker).await;
    }
}