1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7#[cfg(feature = "ioperf")]
8pub mod ioperf;
9
10#[cfg(feature = "test")]
11pub mod test_helpers;
12
13#[cfg(not(feature = "test"))]
14mod test_helpers;
15
16pub mod resolver;
17mod save_restore;
18
19use crate::ring::gparange::GpnList;
20use crate::ring::gparange::MultiPagedRangeBuf;
21use anyhow::Context as _;
22use async_trait::async_trait;
23use fast_select::FastSelect;
24use futures::FutureExt;
25use futures::StreamExt;
26use futures::select_biased;
27use guestmem::AccessError;
28use guestmem::GuestMemory;
29use guestmem::MemoryRead;
30use guestmem::MemoryWrite;
31use guestmem::ranges::PagedRange;
32use guid::Guid;
33use inspect::Inspect;
34use inspect::InspectMut;
35use inspect_counters::Counter;
36use inspect_counters::Histogram;
37use oversized_box::OversizedBox;
38use parking_lot::Mutex;
39use parking_lot::RwLock;
40use ring::OutgoingPacketType;
41use scsi::AdditionalSenseCode;
42use scsi::ScsiOp;
43use scsi::ScsiStatus;
44use scsi::srb::SrbStatus;
45use scsi::srb::SrbStatusAndFlags;
46use scsi_buffers::RequestBuffers;
47use scsi_core::AsyncScsiDisk;
48use scsi_core::Request;
49use scsi_core::ScsiResult;
50use scsi_defs as scsi;
51use scsidisk::illegal_request_sense;
52use slab::Slab;
53use std::collections::hash_map::Entry;
54use std::collections::hash_map::HashMap;
55use std::fmt::Debug;
56use std::future::Future;
57use std::future::poll_fn;
58use std::pin::Pin;
59use std::sync::Arc;
60use std::sync::atomic::AtomicU32;
61use std::sync::atomic::Ordering::Relaxed;
62use std::task::Context;
63use std::task::Poll;
64use storvsp_resources::ScsiPath;
65use task_control::AsyncRun;
66use task_control::InspectTask;
67use task_control::StopTask;
68use task_control::TaskControl;
69use thiserror::Error;
70use tracing_helpers::ErrorValueExt;
71use unicycle::FuturesUnordered;
72use vmbus_async::queue;
73use vmbus_async::queue::ExternalDataError;
74use vmbus_async::queue::IncomingPacket;
75use vmbus_async::queue::OutgoingPacket;
76use vmbus_async::queue::Queue;
77use vmbus_channel::RawAsyncChannel;
78use vmbus_channel::bus::ChannelType;
79use vmbus_channel::bus::OfferParams;
80use vmbus_channel::bus::OpenRequest;
81use vmbus_channel::channel::ChannelControl;
82use vmbus_channel::channel::ChannelOpenError;
83use vmbus_channel::channel::DeviceResources;
84use vmbus_channel::channel::RestoreControl;
85use vmbus_channel::channel::SaveRestoreVmbusDevice;
86use vmbus_channel::channel::VmbusDevice;
87use vmbus_channel::gpadl_ring::GpadlRingMem;
88use vmbus_channel::gpadl_ring::gpadl_channel;
89use vmbus_core::protocol::UserDefinedData;
90use vmbus_ring as ring;
91use vmbus_ring::RingMem;
92use vmcore::save_restore::RestoreError;
93use vmcore::save_restore::SaveError;
94use vmcore::save_restore::SavedStateBlob;
95use vmcore::vm_task::VmTaskDriver;
96use vmcore::vm_task::VmTaskDriverSource;
97use zerocopy::FromBytes;
98use zerocopy::FromZeros;
99use zerocopy::Immutable;
100use zerocopy::IntoBytes;
101use zerocopy::KnownLayout;
102
/// Default poll-mode queue depth.
///
/// NOTE(review): presumably the initial value of
/// `ScsiControllerState::poll_mode_queue_depth`; it is not referenced in the
/// visible part of this file — confirm. In `poll_process_ready`, a worker with
/// fewer than `poll_mode_queue_depth` requests in flight waits for new packets
/// with `poll_read_batch`; otherwise it only uses `try_read_batch`.
const DEFAULT_POLL_MODE_QUEUE_DEPTH: u32 = 1;

/// A storvsp (VMBus synthetic SCSI) controller device.
pub struct StorageDevice {
    /// Device instance GUID.
    instance_id: Guid,
    /// NOTE(review): optional SCSI path for IDE-accelerated use — not used in
    /// the visible part of this file; confirm.
    ide_path: Option<ScsiPath>,
    /// Worker tasks (each with its driver); presumably one per channel.
    workers: Vec<WorkerAndDriver>,
    /// Controller state shared with the workers (attached disks, tunables).
    controller: Arc<ScsiControllerState>,
    /// VMBus device resources.
    resources: DeviceResources,
    /// Source of task drivers for the workers.
    driver_source: VmTaskDriverSource,
    /// Upper bound on the number of subchannels.
    max_sub_channel_count: u16,
    /// Protocol negotiation state shared with all workers.
    protocol: Arc<Protocol>,
    /// Per-channel IO queue depth handed to workers.
    io_queue_depth: u32,
}

/// A channel worker task paired with the driver it runs on.
#[derive(Inspect)]
struct WorkerAndDriver {
    /// The worker's inspect output is merged directly into this node.
    #[inspect(flatten)]
    worker: TaskControl<WorkerState, Worker>,
    driver: VmTaskDriver,
}

/// Task-control state type for [`Worker`]; it carries no data.
struct WorkerState;
129
130impl InspectMut for StorageDevice {
131 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
132 let mut resp = req.respond();
133
134 let disks = self.controller.disks.read();
135 for (path, controller_disk) in disks.iter() {
136 resp.child(&format!("disks/{}", path), |req| {
137 controller_disk.disk.inspect(req);
138 });
139 }
140
141 resp.fields(
142 "channels",
143 self.workers
144 .iter()
145 .filter(|task| task.worker.has_state())
146 .enumerate(),
147 )
148 .field(
149 "poll_mode_queue_depth",
150 inspect::AtomicMut(&self.controller.poll_mode_queue_depth),
151 );
152 }
153}
154
/// Per-channel worker: request-processing state plus the channel's ring queue.
struct Worker<T: RingMem = GpadlRingMem> {
    /// Protocol and IO state. Presumably kept apart from `queue` so that both
    /// can be borrowed mutably at once, as in
    /// `inner.process_ready(&mut queue, ..)`.
    inner: WorkerInner,
    /// Rescan requests from the controller. On the primary channel, each one
    /// triggers an `ENUMERATE_BUS` packet to the guest (Win7+ protocol).
    rescan_notification: futures::channel::mpsc::Receiver<()>,
    /// Helper used to poll `rescan_notification` in `process_primary`.
    fast_select: FastSelect,
    /// The channel's ring queue.
    queue: Queue<T>,
}

/// Protocol negotiation state shared by all of a device's channel workers.
struct Protocol {
    /// Current negotiation state. In this file, only the primary channel's
    /// negotiation loop writes it.
    state: RwLock<ProtocolState>,
    /// Notified when `state` becomes `Ready`; subchannel workers wait on it.
    ready: event_listener::Event,
}

/// Mutable per-channel processing state.
struct WorkerInner {
    /// Shared negotiation state.
    protocol: Arc<Protocol>,
    /// Body size (after the header) of packets this worker originates via
    /// `send_packet`. It starts as the V1 size and is updated from the
    /// negotiated version.
    request_size: usize,
    /// Shared controller state.
    controller: Arc<ScsiControllerState>,
    /// Channel index; 0 is the primary channel that runs negotiation.
    channel_index: u16,
    /// Executes SCSI commands against attached disks.
    scsi_queue: Arc<ScsiCommandQueue>,
    /// In-flight SCSI operations, yielding `(request_id, result, storage)`.
    scsi_requests: FuturesUnordered<ScsiRequest>,
    /// Per-request state keyed by the `request_id` of `scsi_requests`.
    scsi_requests_states: Slab<ScsiRequestState>,
    /// Recycled request allocations, consumed by `parse_packet`.
    full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
    /// Recycled storage from completed SCSI op futures.
    future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
    /// Channel control (subchannel enablement and limits).
    channel_control: ChannelControl,
    /// Maximum number of SCSI requests in flight at once (at least 1).
    max_io_queue_depth: usize,
    /// Counters exposed via inspect.
    stats: WorkerStats,
}

/// Per-worker counters exposed through inspect.
#[derive(Debug, Default, Inspect)]
struct WorkerStats {
    /// Requests for which `handle_packet` reported a submission.
    ios_submitted: Counter,
    /// Completions sent for finished SCSI requests.
    ios_completed: Counter,
    /// Calls to `poll_process_ready`.
    wakes: Counter,
    /// Wakes that neither submitted nor completed anything.
    wakes_spurious: Counter,
    /// Submissions per non-spurious wake.
    per_wake_submissions: Histogram<10>,
    /// Completions per non-spurious wake.
    per_wake_completions: Histogram<10>,
}
192
/// Negotiated storvsp protocol version.
///
/// The discriminants are the wire-format `major_minor` values. Variants are
/// declared oldest to newest, so the derived ordering supports feature checks
/// such as `version >= Version::Win8`.
#[repr(u16)]
#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
enum Version {
    Win6 = storvsp_protocol::VERSION_WIN6,
    Win7 = storvsp_protocol::VERSION_WIN7,
    Win8 = storvsp_protocol::VERSION_WIN8,
    Blue = storvsp_protocol::VERSION_BLUE,
}

/// A guest-requested protocol version that is not a [`Version`] variant;
/// carries the raw `major_minor` value.
#[derive(Debug, Error)]
#[error("protocol version {0:#x} not supported")]
struct UnsupportedVersion(u16);
205
206impl Version {
207 fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
208 let version = match major_minor {
209 storvsp_protocol::VERSION_WIN6 => Self::Win6,
210 storvsp_protocol::VERSION_WIN7 => Self::Win7,
211 storvsp_protocol::VERSION_WIN8 => Self::Win8,
212 storvsp_protocol::VERSION_BLUE => Self::Blue,
213 version => return Err(UnsupportedVersion(version)),
214 };
215 assert_eq!(version as u16, major_minor);
216 Ok(version)
217 }
218
219 fn max_request_size(&self) -> usize {
220 match self {
221 Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
222 Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
223 }
224 }
225}
226
/// Overall protocol state of a device, shared by all of its channels.
#[derive(Copy, Clone)]
enum ProtocolState {
    /// Negotiation in progress on the primary channel.
    Init(InitState),
    /// Negotiation finished; IO requests are processed.
    Ready {
        version: Version,
        subchannel_count: u16,
    },
}

/// Negotiation step: each variant names the packet the primary channel
/// expects next.
#[derive(Copy, Clone, Debug)]
enum InitState {
    /// Waiting for `BEGIN_INITIALIZATION`.
    Begin,
    /// Waiting for `QUERY_PROTOCOL_VERSION`.
    QueryVersion,
    /// Waiting for `QUERY_PROPERTIES`.
    QueryProperties {
        version: Version,
    },
    /// Waiting for `END_INITIALIZATION`. While `subchannel_count` is `None`,
    /// a `CREATE_SUB_CHANNELS` request is also accepted.
    EndInitialization {
        version: Version,
        subchannel_count: Option<u16>,
    },
}

/// Inline storage for a SCSI op future, held as `u64` words (so it is 8-byte
/// aligned) and sized by `SCSI_REQUEST_STACK_SIZE`.
type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];
/// A pinned, type-erased SCSI op future stored in [`ScsiOpStorage`].
type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;

/// Byte budget for one in-flight SCSI op future: the disk's
/// `ASYNC_SCSI_DISK_STACK_SIZE` plus 272 bytes.
/// NOTE(review): the 272 presumably covers this crate's wrapping future state
/// (e.g. `execute_scsi`); re-check it when that code changes.
const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;

/// An in-flight SCSI operation identified by its slab `request_id`.
struct ScsiRequest {
    request_id: usize,
    /// `Some` while running; taken on completion so its storage can be
    /// handed back.
    future: Option<ScsiOpFuture>,
}
269
270impl ScsiRequest {
271 fn new(
272 request_id: usize,
273 future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
274 ) -> Self {
275 Self {
276 request_id,
277 future: Some(future.into()),
278 }
279 }
280}
281
282impl Future for ScsiRequest {
283 type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);
284
285 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
286 let this = self.get_mut();
287 let future = this.future.as_mut().unwrap().as_mut();
288 let result = std::task::ready!(future.poll(cx));
289 let future = this.future.take().unwrap();
291 Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
292 }
293}
294
/// Errors that stop a channel worker.
#[derive(Debug, Error)]
enum WorkerError {
    /// A guest packet could not be parsed.
    #[error("packet error")]
    PacketError(#[source] PacketError),
    /// The ring queue failed.
    #[error("queue error")]
    Queue(#[source] queue::Error),
    /// A write found the ring full even though space should have been
    /// reserved beforehand.
    #[error("queue should have enough space but no longer does")]
    NotEnoughSpace,
}

/// Reasons a guest packet is rejected by `parse_packet`.
#[derive(Debug, Error)]
enum PacketError {
    /// The packet has no transaction ID.
    #[error("Not transactional")]
    NotTransactional,
    /// Unknown storvsp operation code.
    #[error("Unrecognized operation {0:?}")]
    UnrecognizedOperation(storvsp_protocol::Operation),
    /// A completion packet was received where a data packet was expected.
    #[error("Invalid packet type")]
    InvalidPacketType,
    /// The external data ranges do not cover the SRB's transfer length, or
    /// more than one range was supplied.
    #[error("Invalid data transfer length")]
    InvalidDataTransferLength,
    /// Reading packet contents failed.
    #[error("Access error: {0}")]
    Access(#[source] AccessError),
    /// Reading the packet's external ranges failed.
    #[error("Range error")]
    Range(#[source] ExternalDataError),
}

/// Validated guest data buffer for one SCSI request.
#[derive(Debug, Default, Clone)]
struct Range {
    /// Guest page ranges supplied with the packet (at most one range).
    buf: MultiPagedRangeBuf<GpnList>,
    /// The SRB's `data_transfer_length`, in bytes.
    len: usize,
    /// Set when the SRB's `data_in != 0`; passed as the `is_write` flag of
    /// `RequestBuffers::new` (presumably meaning guest memory is written).
    is_write: bool,
}
327
328impl Range {
329 fn new(
330 buf: MultiPagedRangeBuf<GpnList>,
331 request: &storvsp_protocol::ScsiRequest,
332 ) -> Option<Self> {
333 let len = request.data_transfer_length as usize;
334 let is_write = request.data_in != 0;
335 if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
338 return None;
339 }
340 Some(Self { buf, len, is_write })
341 }
342
343 fn buffer<'a>(&'a self, guest_memory: &'a GuestMemory) -> RequestBuffers<'a> {
344 let mut range = self.buf.first().unwrap_or_else(PagedRange::empty);
345 range.truncate(self.len);
346 RequestBuffers::new(guest_memory, range, self.is_write)
347 }
348}
349
/// A parsed guest packet.
#[derive(Debug)]
struct Packet {
    data: PacketData,
    /// Guest transaction ID, echoed in the completion.
    transaction_id: u64,
    /// Request body size as sent (capped at `SCSI_REQUEST_LEN_MAX`); also
    /// used as the body size of the completion.
    request_size: usize,
}

/// Decoded operation-specific contents of a guest packet.
#[derive(Debug)]
enum PacketData {
    BeginInitialization,
    EndInitialization,
    /// Requested `major_minor` protocol version.
    QueryProtocolVersion(u16),
    QueryProperties,
    /// Requested number of subchannels.
    CreateSubChannels(u16),
    /// SCSI request plus its validated data buffer.
    ExecuteScsi(Arc<ScsiRequestAndRange>),
    ResetBus,
    ResetAdapter,
    ResetLun,
}

/// NOTE(review): not referenced in the visible part of this file; confirm
/// whether this public type is still used.
#[derive(Debug)]
pub struct RangeError;
372
373fn parse_packet<T: RingMem>(
374 packet: &IncomingPacket<'_, T>,
375 pool: &mut Vec<Arc<ScsiRequestAndRange>>,
376) -> Result<Packet, PacketError> {
377 let packet = match packet {
378 IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
379 IncomingPacket::Data(packet) => packet,
380 };
381 let transaction_id = packet
382 .transaction_id()
383 .ok_or(PacketError::NotTransactional)?;
384
385 let mut reader = packet.reader();
386 let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
387 let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
391 let data = match header.operation {
392 storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
393 storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
394 storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
395 let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
396 reader
397 .read(version.as_mut_bytes())
398 .map_err(PacketError::Access)?;
399 PacketData::QueryProtocolVersion(version.major_minor)
400 }
401 storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
402 storvsp_protocol::Operation::EXECUTE_SRB => {
403 let mut full_request = pool.pop().unwrap_or_else(|| {
404 Arc::new(ScsiRequestAndRange {
405 external_data: Range::default(),
406 request: storvsp_protocol::ScsiRequest::new_zeroed(),
407 request_size,
408 })
409 });
410
411 {
412 let full_request = Arc::get_mut(&mut full_request).unwrap();
413 let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
414 reader.read(request_buf).map_err(PacketError::Access)?;
415
416 let buf = packet.read_external_ranges().map_err(PacketError::Range)?;
417
418 full_request.external_data = Range::new(buf, &full_request.request)
419 .ok_or(PacketError::InvalidDataTransferLength)?;
420 }
421
422 PacketData::ExecuteScsi(full_request)
423 }
424 storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
425 storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
426 storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
427 storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
428 let mut sub_channel_count: u16 = 0;
429 reader
430 .read(sub_channel_count.as_mut_bytes())
431 .map_err(PacketError::Access)?;
432 PacketData::CreateSubChannels(sub_channel_count)
433 }
434 _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
435 };
436
437 if let PacketData::ExecuteScsi(_) = data {
438 tracing::trace!(transaction_id, ?data, "parse_packet");
439 } else {
440 tracing::debug!(transaction_id, ?data, "parse_packet");
441 }
442
443 Ok(Packet {
444 data,
445 request_size,
446 transaction_id,
447 })
448}
449
450impl WorkerInner {
451 fn send_vmbus_packet<M: RingMem>(
452 &mut self,
453 writer: &mut queue::WriteBatch<'_, M>,
454 packet_type: OutgoingPacketType<'_>,
455 request_size: usize,
456 transaction_id: u64,
457 operation: storvsp_protocol::Operation,
458 status: storvsp_protocol::NtStatus,
459 payload: &[u8],
460 ) -> Result<(), WorkerError> {
461 let header = storvsp_protocol::Packet {
462 operation,
463 flags: 0,
464 status,
465 };
466
467 let packet_size = size_of_val(&header) + request_size;
468
469 let len = size_of_val(&header) + size_of_val(payload);
474 let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
475 let (payload_bytes, padding_bytes) = if len > packet_size {
476 (&payload[..packet_size - size_of_val(&header)], &[][..])
477 } else {
478 (payload, &padding[..packet_size - len])
479 };
480 assert_eq!(
481 size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
482 packet_size
483 );
484 writer
485 .try_write(&OutgoingPacket {
486 transaction_id,
487 packet_type,
488 payload: &[header.as_bytes(), payload_bytes, padding_bytes],
489 })
490 .map_err(|err| match err {
491 queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
492 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
493 })
494 }
495
496 fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
497 &mut self,
498 writer: &mut queue::WriteHalf<'_, M>,
499 operation: storvsp_protocol::Operation,
500 status: storvsp_protocol::NtStatus,
501 payload: &P,
502 ) -> Result<(), WorkerError> {
503 self.send_vmbus_packet(
504 &mut writer.batched(),
505 OutgoingPacketType::InBandNoCompletion,
506 self.request_size,
507 0,
508 operation,
509 status,
510 payload.as_bytes(),
511 )
512 }
513
514 fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
515 &mut self,
516 writer: &mut queue::WriteHalf<'_, M>,
517 packet: &Packet,
518 status: storvsp_protocol::NtStatus,
519 payload: &P,
520 ) -> Result<(), WorkerError> {
521 self.send_vmbus_packet(
522 &mut writer.batched(),
523 OutgoingPacketType::Completion,
524 packet.request_size,
525 packet.transaction_id,
526 storvsp_protocol::Operation::COMPLETE_IO,
527 status,
528 payload.as_bytes(),
529 )
530 }
531}
532
/// Executes SCSI commands against the controller's attached disks.
struct ScsiCommandQueue {
    /// Shared controller state; `disks` is looked up for each command.
    controller: Arc<ScsiControllerState>,
    /// Guest memory backing the requests' data buffers.
    mem: GuestMemory,
    /// When set, replaces the request's `path_id` for the disk lookup.
    force_path_id: Option<u8>,
}

impl ScsiCommandQueue {
    /// Executes one SCSI request and returns its result.
    ///
    /// Dispatch:
    /// - `REPORT_LUNS` is answered here from the controller's disk map, using
    ///   the request's original (not forced) path ID.
    /// - Any other op addressed to an attached disk is forwarded to that disk.
    /// - `INQUIRY` to an unattached LUN is answered here with device type 0x7F
    ///   (logical unit not present).
    /// - Everything else fails with `INVALID_LUN`.
    ///
    /// NOTE(review): this future is presumably stored in the fixed inline
    /// storage sized by `SCSI_REQUEST_STACK_SIZE`; restructuring this function
    /// can change its size — re-check that budget when editing.
    async fn execute_scsi(
        &self,
        external_data: &Range,
        request: &storvsp_protocol::ScsiRequest,
    ) -> ScsiResult {
        let op = ScsiOp(request.payload[0]);
        let external_data = external_data.buffer(&self.mem);

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            "execute_scsi start...",
        );

        let path_id = self.force_path_id.unwrap_or(request.path_id);

        // The entry is cloned out, so the `disks` read lock is released at the
        // end of this statement, before any await below.
        let controller_disk = self
            .controller
            .disks
            .read()
            .get(&ScsiPath {
                path: path_id,
                target: request.target_id,
                lun: request.lun,
            })
            .cloned();

        let result = match op {
            ScsiOp::REPORT_LUNS => {
                const HEADER_SIZE: usize = size_of::<scsi::LunList>();
                let mut luns: Vec<u8> = self
                    .controller
                    .disks
                    .read()
                    .iter()
                    .flat_map(|(path, _)| {
                        // Matches on the request's own path ID, not the forced
                        // one used for the disk lookup above.
                        if request.path_id == path.path && request.target_id == path.target {
                            Some(path.lun)
                        } else {
                            None
                        }
                    })
                    .collect();
                luns.sort_unstable();
                // One 8-byte header followed by one 8-byte entry per LUN.
                let mut data: Vec<u64> = vec![0; luns.len() + 1];
                let header = scsi::LunList {
                    length: (luns.len() as u32 * 8).into(),
                    reserved: [0; 4],
                };
                data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
                for (i, lun) in luns.iter().enumerate() {
                    // The LUN goes in the entry's first two bytes, big-endian.
                    data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
                }
                if external_data.len() >= HEADER_SIZE {
                    // Copy as much of the list as the buffer holds.
                    let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
                    external_data.writer().write(&data.as_bytes()[..tx]).map_or(
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        },
                        |_| ScsiResult {
                            scsi_status: ScsiStatus::GOOD,
                            srb_status: SrbStatus::SUCCESS,
                            tx,
                            sense_data: None,
                        },
                    )
                } else {
                    // Buffer too small for even the header: succeed with no data.
                    ScsiResult {
                        scsi_status: ScsiStatus::GOOD,
                        srb_status: SrbStatus::SUCCESS,
                        tx: 0,
                        sense_data: None,
                    }
                }
            }
            _ if controller_disk.is_some() => {
                let mut cdb = [0; 16];
                cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
                controller_disk
                    .unwrap()
                    .disk
                    .execute_scsi(
                        &external_data,
                        &Request {
                            cdb,
                            srb_flags: request.srb_flags,
                        },
                    )
                    .await
            }
            ScsiOp::INQUIRY => {
                let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
                    .unwrap()
                    .0;
                // Reject if the buffer is shorter than the allocation length, the
                // transfer is not data-in, or the allocation length cannot hold
                // the inquiry header.
                if external_data.len() < cdb.allocation_length.get() as usize
                    || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
                    || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
                {
                    ScsiResult {
                        scsi_status: ScsiStatus::CHECK_CONDITION,
                        srb_status: SrbStatus::INVALID_REQUEST,
                        tx: 0,
                        sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
                    }
                } else {
                    let enable_vpd = cdb.flags.vpd();
                    if enable_vpd || cdb.page_code != 0 {
                        // VPD pages are not served for absent LUNs.
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        }
                    } else {
                        const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
                        let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
                        data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;

                        // Only LUN 0 keeps the template's identification strings.
                        if request.lun != 0 {
                            data.vendor_id = [0; 8];
                            data.product_id = [0; 16];
                            data.product_revision_level = [0; 4];
                        }

                        let datab = data.as_bytes();
                        let tx = std::cmp::min(
                            cdb.allocation_length.get() as usize,
                            size_of::<scsi::InquiryData>(),
                        );
                        external_data.writer().write(&datab[..tx]).map_or(
                            ScsiResult {
                                scsi_status: ScsiStatus::CHECK_CONDITION,
                                srb_status: SrbStatus::INVALID_REQUEST,
                                tx: 0,
                                sense_data: Some(illegal_request_sense(
                                    AdditionalSenseCode::INVALID_CDB,
                                )),
                            },
                            |_| ScsiResult {
                                scsi_status: ScsiStatus::GOOD,
                                srb_status: SrbStatus::SUCCESS,
                                tx,
                                sense_data: None,
                            },
                        )
                    }
                }
            }
            // No disk at this path and not a command answered here.
            _ => ScsiResult {
                scsi_status: ScsiStatus::CHECK_CONDITION,
                srb_status: SrbStatus::INVALID_LUN,
                tx: 0,
                sense_data: None,
            },
        };

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            result = ?result,
            "execute_scsi completed.",
        );
        result
    }
}
720
721impl<T: RingMem + 'static> Worker<T> {
722 fn new(
723 controller: Arc<ScsiControllerState>,
724 channel: RawAsyncChannel<T>,
725 channel_index: u16,
726 mem: GuestMemory,
727 channel_control: ChannelControl,
728 io_queue_depth: u32,
729 protocol: Arc<Protocol>,
730 force_path_id: Option<u8>,
731 ) -> anyhow::Result<Self> {
732 let queue = Queue::new(channel)?;
733 #[expect(clippy::disallowed_methods)] let (source, target) = futures::channel::mpsc::channel(1);
735 controller.add_rescan_notification_source(source);
736
737 let max_io_queue_depth = io_queue_depth.max(1) as usize;
738 Ok(Self {
739 inner: WorkerInner {
740 protocol,
741 request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
742 controller: controller.clone(),
743 channel_index,
744 scsi_queue: Arc::new(ScsiCommandQueue {
745 controller,
746 mem,
747 force_path_id,
748 }),
749 scsi_requests: FuturesUnordered::new(),
750 scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
751 channel_control,
752 max_io_queue_depth,
753 future_pool: Vec::new(),
754 full_request_pool: Vec::new(),
755 stats: Default::default(),
756 },
757 queue,
758 rescan_notification: target,
759 fast_select: FastSelect::new(),
760 })
761 }
762
763 async fn wait_for_scsi_requests_complete(&mut self) {
764 tracing::debug!(
765 channel_index = self.inner.channel_index,
766 "wait for IOs completed..."
767 );
768 while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
769 self.inner.scsi_requests_states.remove(id);
770 }
771 }
772}
773
774impl InspectTask<Worker> for WorkerState {
775 fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
776 if let Some(worker) = worker {
777 let mut resp = req.respond();
778 if worker.inner.channel_index == 0 {
779 let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
780 ProtocolState::Init(state) => match state {
781 InitState::Begin => ("begin_init", None, None),
782 InitState::QueryVersion => ("query_version", None, None),
783 InitState::QueryProperties { version } => {
784 ("query_properties", Some(version), None)
785 }
786 InitState::EndInitialization {
787 version,
788 subchannel_count,
789 } => ("end_init", Some(version), subchannel_count),
790 },
791 ProtocolState::Ready {
792 version,
793 subchannel_count,
794 } => ("ready", Some(version), Some(subchannel_count)),
795 };
796 resp.field("state", state)
797 .field("version", version)
798 .field("subchannel_count", subchannel_count);
799 }
800 resp.field("pending_packets", worker.inner.scsi_requests_states.len())
801 .fields("io", worker.inner.scsi_requests_states.iter())
802 .field("stats", &worker.inner.stats)
803 .field("ring", &worker.queue)
804 .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
805 }
806 }
807}
808
impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
    /// Runs the worker until `stop` fires.
    ///
    /// The primary channel runs negotiation followed by the ready loop.
    /// Subchannels first wait until the protocol state is `Ready`, then run
    /// the ready loop with the negotiated version. Worker errors are logged
    /// and not returned; only cancellation is reported to the caller.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        let fut = async {
            if worker.inner.channel_index == 0 {
                worker.process_primary().await
            } else {
                let protocol_version = loop {
                    // Create the listener before checking the state, so a
                    // `ready` notification issued after the check still reaches
                    // this listener.
                    let listener = worker.inner.protocol.ready.listen();
                    if let ProtocolState::Ready { version, .. } =
                        *worker.inner.protocol.state.read()
                    {
                        break version;
                    }
                    tracing::debug!("subchannel waiting for initialization to end");
                    listener.await
                };
                worker
                    .inner
                    .process_ready(&mut worker.queue, protocol_version)
                    .await
            }
        };

        match stop.until_stopped(fut).await? {
            Ok(_) => {}
            Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
        }
        Ok(())
    }
}
845
846impl WorkerInner {
847 async fn next_packet<'a, M: RingMem>(
850 &mut self,
851 reader: &'a mut queue::ReadHalf<'a, M>,
852 ) -> Result<Packet, WorkerError> {
853 let packet = reader.read().await.map_err(WorkerError::Queue)?;
854 let stor_packet =
855 parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
856 Ok(stor_packet)
857 }
858
859 fn poll_for_ring_space<M: RingMem>(
866 &mut self,
867 cx: &mut Context<'_>,
868 writer: &mut queue::WriteHalf<'_, M>,
869 ) -> Poll<Result<(), WorkerError>> {
870 writer
871 .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
872 .map_err(WorkerError::Queue)
873 }
874}
875
876const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
877 size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
878);
879
impl<T: RingMem> Worker<T> {
    /// Drives the primary channel.
    ///
    /// While the protocol is in `Init`, it handles one negotiation packet per
    /// loop iteration. Out-of-order packets are answered with
    /// `INVALID_DEVICE_STATE` and the state is left unchanged. Once `Ready`,
    /// it runs the ready loop, racing it against rescan notifications. On a
    /// notification, it sends `ENUMERATE_BUS` to the guest (Win7 protocol or
    /// later) and restarts the ready loop. Returns the ready loop's result, or
    /// the first error.
    async fn process_primary(&mut self) -> Result<(), WorkerError> {
        loop {
            // Copy the state out so the lock is not held across awaits.
            let current_state = *self.inner.protocol.state.read();
            match current_state {
                ProtocolState::Ready { version, .. } => {
                    break loop {
                        // Biased: the ready loop is polled before the rescan
                        // notification on each wake.
                        select_biased! {
                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
                                if version >= Version::Win7
                                {
                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
                                }
                            }
                        }
                    };
                }
                ProtocolState::Init(state) => {
                    let (mut reader, mut writer) = self.queue.split();

                    // Reserve room for a maximum-size reply before reading, so
                    // the completions below should not hit `NotEnoughSpace`.
                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;

                    tracing::debug!(?state, "process_primary");
                    match state {
                        InitState::Begin => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::BeginInitialization = packet.data {
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &(),
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::QueryVersion);
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryVersion => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
                                if let Ok(version) = Version::parse(major_minor) {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    // From now on, originate packets at the
                                    // negotiated version's size.
                                    self.inner.request_size = version.max_request_size();
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryProperties { version });

                                    tracelimit::info_ratelimited!(
                                        ?version,
                                        "scsi version negotiated"
                                    );
                                } else {
                                    // Unsupported version: reject it and let the
                                    // guest try another.
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryVersion);
                                }
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryProperties { version } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProperties = packet.data {
                                // Subchannels are offered only from Win8 on.
                                let multi_channel_supported = version >= Version::Win8;

                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &storvsp_protocol::ChannelProperties {
                                        // 0x40000 bytes = 256 KiB.
                                        max_transfer_bytes: 0x40000,
                                        flags: {
                                            if multi_channel_supported {
                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
                                            } else {
                                                0
                                            }
                                        },
                                        maximum_sub_channel_count: if multi_channel_supported {
                                            self.inner.channel_control.max_subchannels()
                                        } else {
                                            0
                                        },
                                        reserved: 0,
                                        reserved2: 0,
                                        reserved3: [0, 0],
                                    },
                                )?;
                                // Without multi-channel support, the subchannel
                                // count is fixed at 0 and CREATE_SUB_CHANNELS is
                                // no longer accepted.
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::EndInitialization {
                                        version,
                                        subchannel_count: if multi_channel_supported {
                                            None
                                        } else {
                                            Some(0)
                                        },
                                    });
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            match packet.data {
                                // At most one successful CREATE_SUB_CHANNELS.
                                PacketData::CreateSubChannels(sub_channel_count)
                                    if subchannel_count.is_none() =>
                                {
                                    if let Err(err) = self
                                        .inner
                                        .channel_control
                                        .enable_subchannels(sub_channel_count)
                                    {
                                        tracelimit::warn_ratelimited!(
                                            ?err,
                                            "cannot enable subchannels"
                                        );
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
                                            &(),
                                        )?;
                                    } else {
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::SUCCESS,
                                            &(),
                                        )?;
                                        *self.inner.protocol.state.write() =
                                            ProtocolState::Init(InitState::EndInitialization {
                                                version,
                                                subchannel_count: Some(sub_channel_count),
                                            });
                                    }
                                }
                                PacketData::EndInitialization => {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &(),
                                    )?;
                                    // Discard any rescan notification already
                                    // pending before entering the ready state.
                                    self.rescan_notification.try_next().ok();
                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
                                        version,
                                        subchannel_count: subchannel_count.unwrap_or(0),
                                    };
                                    // Wake every subchannel worker waiting for
                                    // the ready state.
                                    self.inner.protocol.ready.notify(usize::MAX);
                                }
                                _ => {
                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                        &(),
                                    )?;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
1093
1094fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
1095 match srb_status {
1096 SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
1097 SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
1098 SrbStatus::INVALID_LUN
1099 | SrbStatus::INVALID_TARGET_ID
1100 | SrbStatus::NO_DEVICE
1101 | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
1102 SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
1103 SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
1104 SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
1105 storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
1106 }
1107 SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
1108 SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
1109 SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
1110 _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
1111 }
1112}
1113
1114impl WorkerInner {
1115 async fn process_ready<M: RingMem>(
1117 &mut self,
1118 queue: &mut Queue<M>,
1119 protocol_version: Version,
1120 ) -> Result<(), WorkerError> {
1121 self.request_size = protocol_version.max_request_size();
1122 poll_fn(|cx| self.poll_process_ready(cx, queue)).await
1123 }
1124
    /// One wake of the ready loop. Always returns `Pending`, except that
    /// errors propagate via `?` as `Ready(Err(..))`.
    ///
    /// Each pass of the outer loop:
    /// 1. Collects completions of in-flight requests. A completion is taken
    ///    only when the outgoing batch can fit a maximum-size packet.
    /// 2. Reads and handles new packets until `max_io_queue_depth` requests
    ///    are in flight, ring space runs out, or no packets remain.
    /// 3. Repeats if anything new was submitted, so the new IOs get polled.
    ///
    /// Per-wake statistics are recorded at the end.
    fn poll_process_ready<M: RingMem>(
        &mut self,
        cx: &mut Context<'_>,
        queue: &mut Queue<M>,
    ) -> Poll<Result<(), WorkerError>> {
        self.stats.wakes.increment();

        let (mut reader, mut writer) = queue.split();
        let mut total_completions = 0;
        let mut total_submissions = 0;
        // Below this many in-flight requests, new packets are awaited with
        // `poll_read_batch`; at or above it, only `try_read_batch` is used.
        let poll_mode_queue_depth = self.controller.poll_mode_queue_depth.load(Relaxed) as usize;

        loop {
            // Phase 1: harvest completions.
            'outer: while !self.scsi_requests_states.is_empty() {
                {
                    let mut batch = writer.batched();
                    loop {
                        // Check for room before taking a completion, so a
                        // finished request is never pulled without space to
                        // report it.
                        if !batch
                            .can_write(MAX_VMBUS_PACKET_SIZE)
                            .map_err(WorkerError::Queue)?
                        {
                            break;
                        }
                        if let Poll::Ready(Some((request_id, result, future))) =
                            self.scsi_requests.poll_next_unpin(cx)
                        {
                            // Keep the finished future's storage for reuse.
                            self.future_pool.push(future);
                            self.handle_completion(&mut batch, request_id, result)?;
                            total_completions += 1;
                        } else {
                            tracing::trace!("out of completions");
                            break 'outer;
                        }
                    }
                }

                // The batch ran out of room: wait for ring space (registering
                // the waker if pending) before collecting more.
                if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                    tracing::trace!("out of ring space");
                    break;
                }
            }

            let mut submissions = 0;
            // Phase 2: accept new requests.
            'outer: loop {
                if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                    break;
                }
                let mut batch = if self.scsi_requests_states.len() < poll_mode_queue_depth {
                    // Few requests in flight: poll the incoming ring, which
                    // registers for a wakeup when it is empty.
                    if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
                        batch.map_err(WorkerError::Queue)?
                    } else {
                        tracing::trace!("out of incoming packets");
                        break;
                    }
                } else {
                    // Many requests in flight: only read what is already there.
                    // Per the trace message, this keeps interrupts masked and
                    // relies on in-flight completions to wake the task.
                    match reader.try_read_batch() {
                        Ok(batch) => batch,
                        Err(queue::TryReadError::Empty) => {
                            tracing::trace!(
                                pending_io_count = self.scsi_requests_states.len(),
                                "out of incoming packets, keeping interrupts masked"
                            );
                            break;
                        }
                        Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
                    }
                };

                let mut packets = batch.packets();
                loop {
                    if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                        break 'outer;
                    }
                    // `handle_packet` takes the writer and may reply right away,
                    // so make sure a maximum-size packet fits first.
                    if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                        tracing::trace!("out of ring space");
                        break 'outer;
                    }

                    let packet = if let Some(packet) = packets.next() {
                        packet.map_err(WorkerError::Queue)?
                    } else {
                        break;
                    };

                    // A `true` result is counted as a new submission.
                    if self.handle_packet(&mut writer, &packet)? {
                        submissions += 1;
                    }
                }
            }

            if submissions == 0 {
                // Nothing new was started, so there is nothing new to poll.
                break;
            }
            total_submissions += submissions;
        }

        if total_submissions != 0 || total_completions != 0 {
            self.stats.ios_submitted.add(total_submissions);
            self.stats
                .per_wake_submissions
                .add_sample(total_submissions);
            self.stats
                .per_wake_completions
                .add_sample(total_completions);
            self.stats.ios_completed.add(total_completions);
        } else {
            self.stats.wakes_spurious.increment();
        }

        Poll::Pending
    }
1249
1250 fn handle_completion<M: RingMem>(
1251 &mut self,
1252 writer: &mut queue::WriteBatch<'_, M>,
1253 request_id: usize,
1254 result: ScsiResult,
1255 ) -> Result<(), WorkerError> {
1256 let state = self.scsi_requests_states.remove(request_id);
1257 let request_size = state.request.request_size;
1258
1259 assert_eq!(
1261 Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
1262 1
1263 );
1264 self.full_request_pool.push(state.request);
1265
1266 let status = convert_srb_status_to_nt_status(result.srb_status);
1267 let mut payload = [0; 0x14];
1268 if let Some(sense) = result.sense_data {
1269 payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
1270 tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
1271 };
1272 let response = storvsp_protocol::ScsiRequest {
1273 length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
1274 scsi_status: result.scsi_status,
1275 srb_status: SrbStatusAndFlags::new()
1276 .with_status(result.srb_status)
1277 .with_autosense_valid(result.sense_data.is_some()),
1278 data_transfer_length: result.tx as u32,
1279 cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
1280 sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
1281 payload,
1282 ..storvsp_protocol::ScsiRequest::new_zeroed()
1283 };
1284 self.send_vmbus_packet(
1285 writer,
1286 OutgoingPacketType::Completion,
1287 request_size,
1288 state.transaction_id,
1289 storvsp_protocol::Operation::COMPLETE_IO,
1290 status,
1291 response.as_bytes(),
1292 )?;
1293 Ok(())
1294 }
1295
1296 fn handle_packet<M: RingMem>(
1297 &mut self,
1298 writer: &mut queue::WriteHalf<'_, M>,
1299 packet: &IncomingPacket<'_, M>,
1300 ) -> Result<bool, WorkerError> {
1301 let packet =
1302 parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
1303 let submitted_io = match packet.data {
1304 PacketData::ExecuteScsi(request) => {
1305 self.push_scsi_request(packet.transaction_id, request);
1306 true
1307 }
1308 PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
1309 self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
1311 false
1312 }
1313 PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
1314 if let Err(err) = self
1315 .channel_control
1316 .enable_subchannels(new_subchannel_count)
1317 {
1318 tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
1319 self.send_completion(
1320 writer,
1321 &packet,
1322 storvsp_protocol::NtStatus::INVALID_PARAMETER,
1323 &(),
1324 )?;
1325 false
1326 } else {
1327 if let ProtocolState::Ready {
1329 subchannel_count, ..
1330 } = &mut *self.protocol.state.write()
1331 {
1332 *subchannel_count = new_subchannel_count;
1333 } else {
1334 unreachable!()
1335 }
1336
1337 self.send_completion(
1338 writer,
1339 &packet,
1340 storvsp_protocol::NtStatus::SUCCESS,
1341 &(),
1342 )?;
1343 false
1344 }
1345 }
1346 _ => {
1347 tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
1348 self.send_completion(
1349 writer,
1350 &packet,
1351 storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
1352 &(),
1353 )?;
1354 false
1355 }
1356 };
1357 Ok(submitted_io)
1358 }
1359
1360 fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
1361 let scsi_queue = self.scsi_queue.clone();
1362 let scsi_request_state = ScsiRequestState {
1363 transaction_id,
1364 request: full_request.clone(),
1365 };
1366 let request_id = self.scsi_requests_states.insert(scsi_request_state);
1367 let future = self
1368 .future_pool
1369 .pop()
1370 .unwrap_or_else(|| OversizedBox::new(()));
1371 let future = OversizedBox::refill(future, async move {
1372 scsi_queue
1373 .execute_scsi(&full_request.external_data, &full_request.request)
1374 .await
1375 });
1376 let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
1377 self.scsi_requests.push(request);
1378 }
1379}
1380
1381impl<T: RingMem> Drop for Worker<T> {
1382 fn drop(&mut self) {
1383 self.inner
1384 .controller
1385 .remove_rescan_notification_source(&self.rescan_notification);
1386 }
1387}
1388
/// Error returned by [`ScsiController::attach`] when a disk is already
/// attached at the requested path.
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathInUse(pub ScsiPath);

/// Error returned by [`ScsiController::remove`] when no disk is attached at
/// the requested path.
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathNotInUse(ScsiPath);
1396
/// Bookkeeping for an in-flight SCSI request, kept in the worker's request
/// table until its completion packet is sent.
#[derive(Clone)]
struct ScsiRequestState {
    /// VMBus transaction ID used for the completion packet.
    transaction_id: u64,
    /// The parsed request, shared with the future executing it.
    request: Arc<ScsiRequestAndRange>,
}

/// A SCSI request parsed from a guest packet, plus its external data range.
#[derive(Debug)]
struct ScsiRequestAndRange {
    /// External data range handed to `execute_scsi` along with the request.
    external_data: Range,
    /// The SCSI request parsed from the guest packet.
    request: storvsp_protocol::ScsiRequest,
    /// Request size passed to `send_vmbus_packet` when sending the
    /// completion.
    request_size: usize,
}
1409
1410impl Inspect for ScsiRequestState {
1411 fn inspect(&self, req: inspect::Request<'_>) {
1412 req.respond()
1413 .field("transaction_id", self.transaction_id)
1414 .display(
1415 "address",
1416 &ScsiPath {
1417 path: self.request.request.path_id,
1418 target: self.request.request.target_id,
1419 lun: self.request.request.lun,
1420 },
1421 )
1422 .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1423 }
1424}
1425
1426impl StorageDevice {
1427 pub fn build_scsi(
1429 driver_source: &VmTaskDriverSource,
1430 controller: &ScsiController,
1431 instance_id: Guid,
1432 max_sub_channel_count: u16,
1433 io_queue_depth: u32,
1434 ) -> Self {
1435 Self::build_inner(
1436 driver_source,
1437 controller,
1438 instance_id,
1439 None,
1440 max_sub_channel_count,
1441 io_queue_depth,
1442 )
1443 }
1444
1445 pub fn build_ide(
1448 driver_source: &VmTaskDriverSource,
1449 channel_id: u8,
1450 device_id: u8,
1451 disk: ScsiControllerDisk,
1452 io_queue_depth: u32,
1453 ) -> Self {
1454 let path = ScsiPath {
1455 path: channel_id,
1456 target: device_id,
1457 lun: 0,
1458 };
1459
1460 let controller = ScsiController::new();
1461 controller.attach(path, disk).unwrap();
1462
1463 let instance_id = Guid {
1466 data1: channel_id.into(),
1467 data2: device_id.into(),
1468 data3: 0x8899,
1469 data4: [0; 8],
1470 };
1471 Self::build_inner(
1472 driver_source,
1473 &controller,
1474 instance_id,
1475 Some(path),
1476 0,
1477 io_queue_depth,
1478 )
1479 }
1480
1481 fn build_inner(
1482 driver_source: &VmTaskDriverSource,
1483 controller: &ScsiController,
1484 instance_id: Guid,
1485 ide_path: Option<ScsiPath>,
1486 max_sub_channel_count: u16,
1487 io_queue_depth: u32,
1488 ) -> Self {
1489 let workers = (0..max_sub_channel_count + 1)
1490 .map(|channel_index| WorkerAndDriver {
1491 worker: TaskControl::new(WorkerState),
1492 driver: driver_source
1493 .builder()
1494 .target_vp(0)
1495 .run_on_target(true)
1496 .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1497 })
1498 .collect();
1499
1500 Self {
1501 instance_id,
1502 ide_path,
1503 workers,
1504 controller: controller.state.clone(),
1505 resources: Default::default(),
1506 max_sub_channel_count,
1507 driver_source: driver_source.clone(),
1508 protocol: Arc::new(Protocol {
1509 state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1510 ready: Default::default(),
1511 }),
1512 io_queue_depth,
1513 }
1514 }
1515
1516 fn new_worker(
1517 &mut self,
1518 open_request: &OpenRequest,
1519 channel_index: u16,
1520 ) -> anyhow::Result<&mut Worker> {
1521 let controller = self.controller.clone();
1522
1523 let driver = self
1524 .driver_source
1525 .builder()
1526 .target_vp(open_request.open_data.target_vp)
1527 .run_on_target(true)
1528 .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1529
1530 let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1531 .context("failed to create vmbus channel")?;
1532
1533 let channel_control = self.resources.channel_control.clone();
1534
1535 tracing::debug!(
1536 target_vp = open_request.open_data.target_vp,
1537 channel_index,
1538 "packet processing starting...",
1539 );
1540
1541 let force_path_id = self.ide_path.map(|p| p.path);
1545
1546 let worker = Worker::new(
1547 controller,
1548 channel,
1549 channel_index,
1550 self.resources
1551 .offer_resources
1552 .guest_memory(open_request)
1553 .clone(),
1554 channel_control,
1555 self.io_queue_depth,
1556 self.protocol.clone(),
1557 force_path_id,
1558 )
1559 .map_err(RestoreError::Other)?;
1560
1561 self.workers[channel_index as usize]
1562 .driver
1563 .retarget_vp(open_request.open_data.target_vp);
1564
1565 Ok(self.workers[channel_index as usize].worker.insert(
1566 &driver,
1567 format!("storvsp worker {}-{}", self.instance_id, channel_index),
1568 worker,
1569 ))
1570 }
1571}
1572
1573#[derive(Clone)]
1575pub struct ScsiControllerDisk {
1576 disk: Arc<dyn AsyncScsiDisk>,
1577}
1578
1579impl ScsiControllerDisk {
1580 pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1582 Self { disk }
1583 }
1584}
1585
/// State shared between a [`ScsiController`] handle and the storvsp workers.
struct ScsiControllerState {
    /// Attached disks, keyed by SCSI address.
    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
    /// One sender per registered worker. Each is signaled (best effort)
    /// when a disk is attached or removed.
    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
    /// Outstanding-IO count at or above which workers read new packets with
    /// `try_read_batch` instead of polling the ring for a wakeup.
    poll_mode_queue_depth: AtomicU32,
}

/// Handle for managing the disks attached to a storvsp SCSI controller.
pub struct ScsiController {
    state: Arc<ScsiControllerState>,
}
1595
1596impl ScsiController {
1597 pub fn new() -> Self {
1598 Self {
1599 state: Arc::new(ScsiControllerState {
1600 disks: Default::default(),
1601 rescan_notification_source: Mutex::new(Vec::new()),
1602 poll_mode_queue_depth: AtomicU32::new(DEFAULT_POLL_MODE_QUEUE_DEPTH),
1603 }),
1604 }
1605 }
1606
1607 pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1608 match self.state.disks.write().entry(path) {
1609 Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1610 Entry::Vacant(entry) => entry.insert(disk),
1611 };
1612 for source in self.state.rescan_notification_source.lock().iter_mut() {
1613 source.try_send(()).ok();
1616 }
1617 Ok(())
1618 }
1619
1620 pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1621 match self.state.disks.write().entry(path) {
1622 Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1623 Entry::Occupied(entry) => {
1624 entry.remove();
1625 }
1626 }
1627 for source in self.state.rescan_notification_source.lock().iter_mut() {
1628 source.try_send(()).ok();
1631 }
1632 Ok(())
1633 }
1634}
1635
1636impl ScsiControllerState {
1637 fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1638 self.rescan_notification_source.lock().push(source);
1639 }
1640
1641 fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1642 let mut sources = self.rescan_notification_source.lock();
1643 if let Some(index) = sources
1644 .iter()
1645 .position(|source| source.is_connected_to(target))
1646 {
1647 sources.remove(index);
1648 }
1649 }
1650}
1651
1652#[async_trait]
1653impl VmbusDevice for StorageDevice {
1654 fn offer(&self) -> OfferParams {
1655 if let Some(path) = self.ide_path {
1656 let offer_properties = storvsp_protocol::OfferProperties {
1657 path_id: path.path,
1658 target_id: path.target,
1659 flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
1660 ..FromZeros::new_zeroed()
1661 };
1662 let mut user_defined = UserDefinedData::new_zeroed();
1663 offer_properties
1664 .write_to_prefix(&mut user_defined[..])
1665 .unwrap();
1666 OfferParams {
1667 interface_name: "ide-accel".to_owned(),
1668 instance_id: self.instance_id,
1669 interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
1670 channel_type: ChannelType::Interface { user_defined },
1671 ..Default::default()
1672 }
1673 } else {
1674 OfferParams {
1675 interface_name: "scsi".to_owned(),
1676 instance_id: self.instance_id,
1677 interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
1678 ..Default::default()
1679 }
1680 }
1681 }
1682
1683 fn max_subchannels(&self) -> u16 {
1684 self.max_sub_channel_count
1685 }
1686
1687 fn install(&mut self, resources: DeviceResources) {
1688 self.resources = resources;
1689 }
1690
1691 async fn open(
1692 &mut self,
1693 channel_index: u16,
1694 open_request: &OpenRequest,
1695 ) -> Result<(), ChannelOpenError> {
1696 tracing::debug!(channel_index, "scsi open channel");
1697 self.new_worker(open_request, channel_index)?;
1698 self.workers[channel_index as usize].worker.start();
1699 Ok(())
1700 }
1701
1702 async fn close(&mut self, channel_index: u16) {
1703 tracing::debug!(channel_index, "scsi close channel");
1704 let worker = &mut self.workers[channel_index as usize].worker;
1705 worker.stop().await;
1706 if worker.state_mut().is_some() {
1707 worker
1708 .state_mut()
1709 .unwrap()
1710 .wait_for_scsi_requests_complete()
1711 .await;
1712 worker.remove();
1713 }
1714 if channel_index == 0 {
1715 *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
1716 }
1717 }
1718
1719 async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
1720 self.workers[channel_index as usize]
1721 .driver
1722 .retarget_vp(target_vp);
1723 }
1724
1725 fn start(&mut self) {
1726 for task in self
1727 .workers
1728 .iter_mut()
1729 .filter(|task| task.worker.has_state() && !task.worker.is_running())
1730 {
1731 task.worker.start();
1732 }
1733 }
1734
1735 async fn stop(&mut self) {
1736 tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
1737 for task in self
1738 .workers
1739 .iter_mut()
1740 .filter(|task| task.worker.has_state() && task.worker.is_running())
1741 {
1742 task.worker.stop().await;
1743 }
1744 }
1745
1746 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1747 Some(self)
1748 }
1749}
1750
1751#[async_trait]
1752impl SaveRestoreVmbusDevice for StorageDevice {
1753 async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1754 Ok(SavedStateBlob::new(self.save()?))
1755 }
1756
1757 async fn restore(
1758 &mut self,
1759 control: RestoreControl<'_>,
1760 state: SavedStateBlob,
1761 ) -> Result<(), RestoreError> {
1762 self.restore(control, state.parse()?).await
1763 }
1764}
1765
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestWorker;
    use crate::test_helpers::parse_guest_completion;
    use crate::test_helpers::parse_guest_completion_check_flags_status;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use scsi::srb::SrbStatus;
    use test_with_tracing::test;
    use vmbus_channel::connected_async_channels;

    // Test-only handle cloning: the clone shares the same controller state,
    // so disks attached through one handle are visible through the other.
    impl Clone for ScsiController {
        fn clone(&self) -> Self {
            ScsiController {
                state: self.state.clone(),
            }
        }
    }

    /// A write followed by a read of the same LBA round-trips the data.
    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        let controller = ScsiController::new();
        let disk = scsidisk::SimpleScsiDisk::new(
            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
            Default::default(),
        );
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                ScsiControllerDisk::new(Arc::new(disk)),
            )
            .unwrap();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        // Read back into a different guest buffer and compare.
        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        guest.verify_graceful_close(test_worker).await;
    }

    /// QUERY_PROTOCOL_VERSION packets of various lengths with an unsupported
    /// version get REVISION_MISMATCH completions of the expected length.
    #[async_test]
    async fn test_packet_sizes(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let header = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };

        let mut buf = [0u8; 128];
        storvsp_protocol::ProtocolVersion {
            major_minor: !0,
            reserved: 0,
        }
        .write_to_prefix(&mut buf[..])
        .unwrap();

        // Each entry is (total request length, expected completion length).
        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
            guest
                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
                .await;

            guest
                .verify_completion(|packet| {
                    let IncomingPacket::Completion(packet) = packet else {
                        unreachable!()
                    };
                    assert_eq!(packet.reader().len(), resp_len);
                    assert_eq!(
                        packet
                            .reader()
                            .read_plain::<storvsp_protocol::Packet>()
                            .unwrap()
                            .status,
                        storvsp_protocol::NtStatus::REVISION_MISMATCH
                    );
                    Ok(())
                })
                .await;
        }
    }

    /// Sending END_INITIALIZATION before BEGIN_INITIALIZATION is rejected.
    #[async_test]
    async fn test_wrong_first_packet(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
                );
                Ok(())
            })
            .await;
    }

    /// An operation the worker does not recognize terminates it with
    /// `PacketError::UnrecognizedOperation`.
    #[async_test]
    async fn test_unrecognized_operation(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        match worker.teardown().await {
            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
                storvsp_protocol::Operation::REMOVE_DEVICE,
            ))) => {}
            result => panic!("Worker failed with unexpected result {:?}!", result),
        }
    }

    /// Requesting a subchannel the test worker cannot provide completes
    /// with INVALID_PARAMETER (presumably because the test worker offers no
    /// subchannels).
    #[async_test]
    async fn test_too_many_subchannels(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        // Request one subchannel.
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_PARAMETER
                );
                Ok(())
            })
            .await;
    }

    /// BEGIN_INITIALIZATION after negotiation has finished is rejected with
    /// INVALID_DEVICE_STATE.
    #[async_test]
    async fn test_begin_init_on_ready(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|p| {
                parse_guest_completion_check_flags_status(
                    p,
                    0,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                )
            })
            .await;
    }

    /// Attaching and removing disks at runtime triggers bus enumeration
    /// notifications and is reflected in REPORT LUNS results and IO status.
    #[async_test]
    async fn test_hot_add_remove(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        let controller = ScsiController::new();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // With no disks attached, REPORT LUNS returns only the 8-byte header.
        let mut lun_list_buffer: [u8; 256] = [0; 256];
        let mut disk_count = 0;
        guest
            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
            })
            .await;
        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
        assert_eq!(lun_list_size, disk_count as u32 * 8);

        // IO to a LUN with no disk fails with INVALID_LUN.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();

        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
            })
            .await;

        // Hot-add LUNs 0..4 one at a time.
        for lun in 0..4 {
            let disk = scsidisk::SimpleScsiDisk::new(
                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
                Default::default(),
            );
            controller
                .attach(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    ScsiControllerDisk::new(Arc::new(disk)),
                )
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count += 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 256)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
                })
                .await;
        }

        // Hot-remove the same LUNs one at a time.
        for lun in 0..4 {
            controller
                .remove(ScsiPath {
                    path: 0,
                    target: 0,
                    lun,
                })
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count -= 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
                })
                .await;
        }

        guest.verify_graceful_close(test_worker).await;
    }

    /// Full manual protocol negotiation followed by a write/read round trip.
    #[async_test]
    async fn test_async_disk(driver: DefaultDriver) {
        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
        let controller = ScsiController::new();
        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
            device,
            Default::default(),
        )));
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                disk,
            )
            .unwrap();

        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let test_guest_mem = GuestMemory::allocate(16384);
        let worker = TestWorker::start(
            controller.clone(),
            &driver,
            test_guest_mem.clone(),
            host,
            None,
        );

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        guest.verify_graceful_close(worker).await;
    }
}
2419}