1#![expect(missing_docs)]
40#![forbid(unsafe_code)]
41
42#[cfg(feature = "ioperf")]
43pub mod ioperf;
44
45#[cfg(feature = "test")]
46pub mod test_helpers;
47
48#[cfg(not(feature = "test"))]
49mod test_helpers;
50
51pub mod resolver;
52mod save_restore;
53
54use crate::ring::gparange::MultiPagedRangeBuf;
55use anyhow::Context as _;
56use async_trait::async_trait;
57use fast_select::FastSelect;
58use futures::FutureExt;
59use futures::StreamExt;
60use futures::select_biased;
61use guestmem::AccessError;
62use guestmem::GuestMemory;
63use guestmem::MemoryRead;
64use guestmem::MemoryWrite;
65use guestmem::ranges::PagedRange;
66use guid::Guid;
67use inspect::Inspect;
68use inspect::InspectMut;
69use inspect_counters::Counter;
70use inspect_counters::Histogram;
71use oversized_box::OversizedBox;
72use parking_lot::Mutex;
73use parking_lot::RwLock;
74use ring::OutgoingPacketType;
75use scsi::AdditionalSenseCode;
76use scsi::ScsiOp;
77use scsi::ScsiStatus;
78use scsi::srb::SrbStatus;
79use scsi::srb::SrbStatusAndFlags;
80use scsi_buffers::RequestBuffers;
81use scsi_core::AsyncScsiDisk;
82use scsi_core::Request;
83use scsi_core::ScsiResult;
84use scsi_defs as scsi;
85use scsidisk::illegal_request_sense;
86use slab::Slab;
87use std::collections::hash_map::Entry;
88use std::collections::hash_map::HashMap;
89use std::fmt::Debug;
90use std::future::Future;
91use std::future::poll_fn;
92use std::pin::Pin;
93use std::sync::Arc;
94use std::sync::atomic::AtomicU32;
95use std::sync::atomic::Ordering::Relaxed;
96use std::task::Context;
97use std::task::Poll;
98use storvsp_resources::ScsiPath;
99use task_control::AsyncRun;
100use task_control::InspectTask;
101use task_control::StopTask;
102use task_control::TaskControl;
103use thiserror::Error;
104use tracing_helpers::ErrorValueExt;
105use unicycle::FuturesUnordered;
106use vmbus_async::queue;
107use vmbus_async::queue::ExternalDataError;
108use vmbus_async::queue::IncomingPacket;
109use vmbus_async::queue::OutgoingPacket;
110use vmbus_async::queue::Queue;
111use vmbus_channel::RawAsyncChannel;
112use vmbus_channel::bus::ChannelType;
113use vmbus_channel::bus::OfferParams;
114use vmbus_channel::bus::OpenRequest;
115use vmbus_channel::channel::ChannelControl;
116use vmbus_channel::channel::ChannelOpenError;
117use vmbus_channel::channel::DeviceResources;
118use vmbus_channel::channel::RestoreControl;
119use vmbus_channel::channel::SaveRestoreVmbusDevice;
120use vmbus_channel::channel::VmbusDevice;
121use vmbus_channel::gpadl_ring::GpadlRingMem;
122use vmbus_channel::gpadl_ring::gpadl_channel;
123use vmbus_core::protocol::UserDefinedData;
124use vmbus_ring as ring;
125use vmbus_ring::RingMem;
126use vmcore::save_restore::RestoreError;
127use vmcore::save_restore::SaveError;
128use vmcore::save_restore::SavedStateBlob;
129use vmcore::vm_task::VmTaskDriver;
130use vmcore::vm_task::VmTaskDriverSource;
131use zerocopy::FromBytes;
132use zerocopy::FromZeros;
133use zerocopy::Immutable;
134use zerocopy::IntoBytes;
135use zerocopy::KnownLayout;
136
/// Default poll-mode queue depth.
///
/// NOTE(review): presumably the initial value of the controller's
/// `poll_mode_queue_depth`. That is the number of in-flight requests below which
/// a worker waits for new packets with interrupts armed; at or above it, the
/// worker only polls the ring. Confirm against the controller constructor.
const DEFAULT_POLL_MODE_QUEUE_DEPTH: u32 = 1;
142
143pub struct StorageDevice {
144 instance_id: Guid,
145 ide_path: Option<ScsiPath>,
146 workers: Vec<WorkerAndDriver>,
147 controller: Arc<ScsiControllerState>,
148 resources: DeviceResources,
149 driver_source: VmTaskDriverSource,
150 max_sub_channel_count: u16,
151 protocol: Arc<Protocol>,
152 io_queue_depth: u32,
153}
154
155#[derive(Inspect)]
156struct WorkerAndDriver {
157 #[inspect(flatten)]
158 worker: TaskControl<WorkerState, Worker>,
159 driver: VmTaskDriver,
160}
161
162struct WorkerState;
163
164impl InspectMut for StorageDevice {
165 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
166 let mut resp = req.respond();
167
168 let disks = self.controller.disks.read();
169 for (path, controller_disk) in disks.iter() {
170 resp.child(&format!("disks/{}", path), |req| {
171 controller_disk.disk.inspect(req);
172 });
173 }
174
175 resp.fields(
176 "channels",
177 self.workers
178 .iter()
179 .filter(|task| task.worker.has_state())
180 .enumerate(),
181 )
182 .field(
183 "poll_mode_queue_depth",
184 inspect::AtomicMut(&self.controller.poll_mode_queue_depth),
185 );
186 }
187}
188
189struct Worker<T: RingMem = GpadlRingMem> {
190 inner: WorkerInner,
191 rescan_notification: futures::channel::mpsc::Receiver<()>,
192 fast_select: FastSelect,
193 queue: Queue<T>,
194}
195
196struct Protocol {
197 state: RwLock<ProtocolState>,
198 ready: event_listener::Event,
200}
201
202struct WorkerInner {
203 protocol: Arc<Protocol>,
204 request_size: usize,
205 controller: Arc<ScsiControllerState>,
206 channel_index: u16,
207 scsi_queue: Arc<ScsiCommandQueue>,
208 scsi_requests: FuturesUnordered<ScsiRequest>,
209 scsi_requests_states: Slab<ScsiRequestState>,
210 full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
211 future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
212 channel_control: ChannelControl,
213 max_io_queue_depth: usize,
214 stats: WorkerStats,
215}
216
217#[derive(Debug, Default, Inspect)]
218struct WorkerStats {
219 ios_submitted: Counter,
220 ios_completed: Counter,
221 wakes: Counter,
222 wakes_spurious: Counter,
223 per_wake_submissions: Histogram<10>,
224 per_wake_completions: Histogram<10>,
225}
226
227#[repr(u16)]
228#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
229enum Version {
230 Win6 = storvsp_protocol::VERSION_WIN6,
231 Win7 = storvsp_protocol::VERSION_WIN7,
232 Win8 = storvsp_protocol::VERSION_WIN8,
233 Blue = storvsp_protocol::VERSION_BLUE,
234}
235
236#[derive(Debug, Error)]
237#[error("protocol version {0:#x} not supported")]
238struct UnsupportedVersion(u16);
239
240impl Version {
241 fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
242 let version = match major_minor {
243 storvsp_protocol::VERSION_WIN6 => Self::Win6,
244 storvsp_protocol::VERSION_WIN7 => Self::Win7,
245 storvsp_protocol::VERSION_WIN8 => Self::Win8,
246 storvsp_protocol::VERSION_BLUE => Self::Blue,
247 version => return Err(UnsupportedVersion(version)),
248 };
249 assert_eq!(version as u16, major_minor);
250 Ok(version)
251 }
252
253 fn max_request_size(&self) -> usize {
254 match self {
255 Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
256 Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
257 }
258 }
259}
260
261#[derive(Copy, Clone)]
262enum ProtocolState {
263 Init(InitState),
264 Ready {
265 version: Version,
266 subchannel_count: u16,
267 },
268}
269
270#[derive(Copy, Clone, Debug)]
271enum InitState {
272 Begin,
273 QueryVersion,
274 QueryProperties {
275 version: Version,
276 },
277 EndInitialization {
278 version: Version,
279 subchannel_count: Option<u16>,
280 },
281}
282
283type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];
291type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;
292
293const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;
298
299struct ScsiRequest {
300 request_id: usize,
301 future: Option<ScsiOpFuture>,
302}
303
304impl ScsiRequest {
305 fn new(
306 request_id: usize,
307 future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
308 ) -> Self {
309 Self {
310 request_id,
311 future: Some(future.into()),
312 }
313 }
314}
315
316impl Future for ScsiRequest {
317 type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);
318
319 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
320 let this = self.get_mut();
321 let future = this.future.as_mut().unwrap().as_mut();
322 let result = std::task::ready!(future.poll(cx));
323 let future = this.future.take().unwrap();
325 Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
326 }
327}
328
329#[derive(Debug, Error)]
330enum WorkerError {
331 #[error("packet error")]
332 PacketError(#[source] PacketError),
333 #[error("queue error")]
334 Queue(#[source] queue::Error),
335 #[error("queue should have enough space but no longer does")]
336 NotEnoughSpace,
337}
338
339#[derive(Debug, Error)]
340enum PacketError {
341 #[error("Not transactional")]
342 NotTransactional,
343 #[error("Unrecognized operation {0:?}")]
344 UnrecognizedOperation(storvsp_protocol::Operation),
345 #[error("Invalid packet type")]
346 InvalidPacketType,
347 #[error("Invalid data transfer length")]
348 InvalidDataTransferLength,
349 #[error("Access error: {0}")]
350 Access(#[source] AccessError),
351 #[error("Range error")]
352 Range(#[source] ExternalDataError),
353}
354
355#[derive(Debug, Default, Clone)]
356struct Range {
357 len: usize,
358 is_write: bool,
359}
360
361impl Range {
362 fn new(buf: &MultiPagedRangeBuf, request: &storvsp_protocol::ScsiRequest) -> Option<Self> {
363 let len = request.data_transfer_length as usize;
364 let is_write = request.data_in != 0;
365 if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
368 return None;
369 }
370 Some(Self { len, is_write })
371 }
372
373 fn buffer<'a>(
374 &'a self,
375 buf: &'a MultiPagedRangeBuf,
376 guest_memory: &'a GuestMemory,
377 ) -> RequestBuffers<'a> {
378 let mut range = buf.first().unwrap_or_else(PagedRange::empty);
379 range.truncate(self.len);
380 RequestBuffers::new(guest_memory, range, self.is_write)
381 }
382}
383
384#[derive(Debug)]
385struct Packet {
386 data: PacketData,
387 transaction_id: u64,
388 request_size: usize,
389}
390
391#[derive(Debug)]
392enum PacketData {
393 BeginInitialization,
394 EndInitialization,
395 QueryProtocolVersion(u16),
396 QueryProperties,
397 CreateSubChannels(u16),
398 ExecuteScsi(Arc<ScsiRequestAndRange>),
399 ResetBus,
400 ResetAdapter,
401 ResetLun,
402}
403
404#[derive(Debug)]
405pub struct RangeError;
406
407fn parse_packet<T: RingMem>(
408 packet: &IncomingPacket<'_, T>,
409 pool: &mut Vec<Arc<ScsiRequestAndRange>>,
410) -> Result<Packet, PacketError> {
411 let packet = match packet {
412 IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
413 IncomingPacket::Data(packet) => packet,
414 };
415 let transaction_id = packet
416 .transaction_id()
417 .ok_or(PacketError::NotTransactional)?;
418
419 let mut reader = packet.reader();
420 let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
421 let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
425 let data = match header.operation {
426 storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
427 storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
428 storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
429 let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
430 reader
431 .read(version.as_mut_bytes())
432 .map_err(PacketError::Access)?;
433 PacketData::QueryProtocolVersion(version.major_minor)
434 }
435 storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
436 storvsp_protocol::Operation::EXECUTE_SRB => {
437 let mut full_request = pool.pop().unwrap_or_else(|| {
438 Arc::new(ScsiRequestAndRange {
439 external_data: Range::default(),
440 external_data_buf: MultiPagedRangeBuf::new(),
441 request: storvsp_protocol::ScsiRequest::new_zeroed(),
442 request_size,
443 })
444 });
445
446 {
447 let full_request = Arc::get_mut(&mut full_request).unwrap();
448 let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
449 reader.read(request_buf).map_err(PacketError::Access)?;
450
451 full_request.external_data_buf.clear();
452 packet
453 .read_external_ranges(&mut full_request.external_data_buf)
454 .map_err(PacketError::Range)?;
455
456 full_request.external_data =
457 Range::new(&full_request.external_data_buf, &full_request.request)
458 .ok_or(PacketError::InvalidDataTransferLength)?;
459 }
460
461 PacketData::ExecuteScsi(full_request)
462 }
463 storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
464 storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
465 storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
466 storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
467 let mut sub_channel_count: u16 = 0;
468 reader
469 .read(sub_channel_count.as_mut_bytes())
470 .map_err(PacketError::Access)?;
471 PacketData::CreateSubChannels(sub_channel_count)
472 }
473 _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
474 };
475
476 if let PacketData::ExecuteScsi(_) = data {
477 tracing::trace!(transaction_id, ?data, "parse_packet");
478 } else {
479 tracing::debug!(transaction_id, ?data, "parse_packet");
480 }
481
482 Ok(Packet {
483 data,
484 request_size,
485 transaction_id,
486 })
487}
488
489impl WorkerInner {
490 fn send_vmbus_packet<M: RingMem>(
491 &mut self,
492 writer: &mut queue::WriteBatch<'_, M>,
493 packet_type: OutgoingPacketType<'_>,
494 request_size: usize,
495 transaction_id: u64,
496 operation: storvsp_protocol::Operation,
497 status: storvsp_protocol::NtStatus,
498 payload: &[u8],
499 ) -> Result<(), WorkerError> {
500 let header = storvsp_protocol::Packet {
501 operation,
502 flags: 0,
503 status,
504 };
505
506 let packet_size = size_of_val(&header) + request_size;
507
508 let len = size_of_val(&header) + size_of_val(payload);
513 let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
514 let (payload_bytes, padding_bytes) = if len > packet_size {
515 (&payload[..packet_size - size_of_val(&header)], &[][..])
516 } else {
517 (payload, &padding[..packet_size - len])
518 };
519 assert_eq!(
520 size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
521 packet_size
522 );
523 writer
524 .try_write(&OutgoingPacket {
525 transaction_id,
526 packet_type,
527 payload: &[header.as_bytes(), payload_bytes, padding_bytes],
528 })
529 .map_err(|err| match err {
530 queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
531 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
532 })
533 }
534
535 fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
536 &mut self,
537 writer: &mut queue::WriteHalf<'_, M>,
538 operation: storvsp_protocol::Operation,
539 status: storvsp_protocol::NtStatus,
540 payload: &P,
541 ) -> Result<(), WorkerError> {
542 self.send_vmbus_packet(
543 &mut writer.batched(),
544 OutgoingPacketType::InBandNoCompletion,
545 self.request_size,
546 0,
547 operation,
548 status,
549 payload.as_bytes(),
550 )
551 }
552
553 fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
554 &mut self,
555 writer: &mut queue::WriteHalf<'_, M>,
556 packet: &Packet,
557 status: storvsp_protocol::NtStatus,
558 payload: &P,
559 ) -> Result<(), WorkerError> {
560 self.send_vmbus_packet(
561 &mut writer.batched(),
562 OutgoingPacketType::Completion,
563 packet.request_size,
564 packet.transaction_id,
565 storvsp_protocol::Operation::COMPLETE_IO,
566 status,
567 payload.as_bytes(),
568 )
569 }
570}
571
572struct ScsiCommandQueue {
573 controller: Arc<ScsiControllerState>,
574 mem: GuestMemory,
575 force_path_id: Option<u8>,
576}
577
578impl ScsiCommandQueue {
579 async fn execute_scsi(&self, full_request: &ScsiRequestAndRange) -> ScsiResult {
580 let request = &full_request.request;
581 let op = ScsiOp(request.payload[0]);
582 let external_data = full_request
583 .external_data
584 .buffer(&full_request.external_data_buf, &self.mem);
585
586 tracing::trace!(
587 path_id = request.path_id,
588 target_id = request.target_id,
589 lun = request.lun,
590 op = ?op,
591 "execute_scsi start...",
592 );
593
594 let path_id = self.force_path_id.unwrap_or(request.path_id);
595
596 let controller_disk = self
597 .controller
598 .disks
599 .read()
600 .get(&ScsiPath {
601 path: path_id,
602 target: request.target_id,
603 lun: request.lun,
604 })
605 .cloned();
606
607 let result = match op {
608 ScsiOp::REPORT_LUNS => {
609 const HEADER_SIZE: usize = size_of::<scsi::LunList>();
610 let mut luns: Vec<u8> = self
611 .controller
612 .disks
613 .read()
614 .iter()
615 .flat_map(|(path, _)| {
616 if request.path_id == path.path && request.target_id == path.target {
619 Some(path.lun)
620 } else {
621 None
622 }
623 })
624 .collect();
625 luns.sort_unstable();
626 let mut data: Vec<u64> = vec![0; luns.len() + 1];
627 let header = scsi::LunList {
628 length: (luns.len() as u32 * 8).into(),
629 reserved: [0; 4],
630 };
631 data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
632 for (i, lun) in luns.iter().enumerate() {
633 data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
634 }
635 if external_data.len() >= HEADER_SIZE {
636 let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
637 external_data.writer().write(&data.as_bytes()[..tx]).map_or(
638 ScsiResult {
639 scsi_status: ScsiStatus::CHECK_CONDITION,
640 srb_status: SrbStatus::INVALID_REQUEST,
641 tx: 0,
642 sense_data: Some(illegal_request_sense(
643 AdditionalSenseCode::INVALID_CDB,
644 )),
645 },
646 |_| ScsiResult {
647 scsi_status: ScsiStatus::GOOD,
648 srb_status: SrbStatus::SUCCESS,
649 tx,
650 sense_data: None,
651 },
652 )
653 } else {
654 ScsiResult {
655 scsi_status: ScsiStatus::GOOD,
656 srb_status: SrbStatus::SUCCESS,
657 tx: 0,
658 sense_data: None,
659 }
660 }
661 }
662 _ if controller_disk.is_some() => {
663 let mut cdb = [0; 16];
664 cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
665 controller_disk
666 .unwrap()
667 .disk
668 .execute_scsi(
669 &external_data,
670 &Request {
671 cdb,
672 srb_flags: request.srb_flags,
673 },
674 )
675 .await
676 }
677 ScsiOp::INQUIRY => {
678 let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
679 .unwrap()
680 .0; if external_data.len() < cdb.allocation_length.get() as usize
682 || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
683 || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
684 {
685 ScsiResult {
686 scsi_status: ScsiStatus::CHECK_CONDITION,
687 srb_status: SrbStatus::INVALID_REQUEST,
688 tx: 0,
689 sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
690 }
691 } else {
692 let enable_vpd = cdb.flags.vpd();
693 if enable_vpd || cdb.page_code != 0 {
694 ScsiResult {
696 scsi_status: ScsiStatus::CHECK_CONDITION,
697 srb_status: SrbStatus::INVALID_REQUEST,
698 tx: 0,
699 sense_data: Some(illegal_request_sense(
700 AdditionalSenseCode::INVALID_CDB,
701 )),
702 }
703 } else {
704 const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
705 let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
706 data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;
707
708 if request.lun != 0 {
709 data.vendor_id = [0; 8];
711 data.product_id = [0; 16];
712 data.product_revision_level = [0; 4];
713 }
714
715 let datab = data.as_bytes();
716 let tx = std::cmp::min(
717 cdb.allocation_length.get() as usize,
718 size_of::<scsi::InquiryData>(),
719 );
720 external_data.writer().write(&datab[..tx]).map_or(
721 ScsiResult {
722 scsi_status: ScsiStatus::CHECK_CONDITION,
723 srb_status: SrbStatus::INVALID_REQUEST,
724 tx: 0,
725 sense_data: Some(illegal_request_sense(
726 AdditionalSenseCode::INVALID_CDB,
727 )),
728 },
729 |_| ScsiResult {
730 scsi_status: ScsiStatus::GOOD,
731 srb_status: SrbStatus::SUCCESS,
732 tx,
733 sense_data: None,
734 },
735 )
736 }
737 }
738 }
739 _ => ScsiResult {
740 scsi_status: ScsiStatus::CHECK_CONDITION,
741 srb_status: SrbStatus::INVALID_LUN,
742 tx: 0,
743 sense_data: None,
744 },
745 };
746
747 tracing::trace!(
748 path_id = request.path_id,
749 target_id = request.target_id,
750 lun = request.lun,
751 op = ?op,
752 result = ?result,
753 "execute_scsi completed.",
754 );
755 result
756 }
757}
758
759impl<T: RingMem + 'static> Worker<T> {
760 fn new(
761 controller: Arc<ScsiControllerState>,
762 channel: RawAsyncChannel<T>,
763 channel_index: u16,
764 mem: GuestMemory,
765 channel_control: ChannelControl,
766 io_queue_depth: u32,
767 protocol: Arc<Protocol>,
768 force_path_id: Option<u8>,
769 ) -> anyhow::Result<Self> {
770 let queue = Queue::new(channel)?;
771 #[expect(clippy::disallowed_methods)] let (source, target) = futures::channel::mpsc::channel(1);
773 controller.add_rescan_notification_source(source);
774
775 let max_io_queue_depth = io_queue_depth.max(1) as usize;
776 Ok(Self {
777 inner: WorkerInner {
778 protocol,
779 request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
780 controller: controller.clone(),
781 channel_index,
782 scsi_queue: Arc::new(ScsiCommandQueue {
783 controller,
784 mem,
785 force_path_id,
786 }),
787 scsi_requests: FuturesUnordered::new(),
788 scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
789 channel_control,
790 max_io_queue_depth,
791 future_pool: Vec::new(),
792 full_request_pool: Vec::new(),
793 stats: Default::default(),
794 },
795 queue,
796 rescan_notification: target,
797 fast_select: FastSelect::new(),
798 })
799 }
800
801 async fn wait_for_scsi_requests_complete(&mut self) {
802 tracing::debug!(
803 channel_index = self.inner.channel_index,
804 "wait for IOs completed..."
805 );
806 while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
807 self.inner.scsi_requests_states.remove(id);
808 }
809 }
810}
811
812impl InspectTask<Worker> for WorkerState {
813 fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
814 if let Some(worker) = worker {
815 let mut resp = req.respond();
816 if worker.inner.channel_index == 0 {
817 let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
818 ProtocolState::Init(state) => match state {
819 InitState::Begin => ("begin_init", None, None),
820 InitState::QueryVersion => ("query_version", None, None),
821 InitState::QueryProperties { version } => {
822 ("query_properties", Some(version), None)
823 }
824 InitState::EndInitialization {
825 version,
826 subchannel_count,
827 } => ("end_init", Some(version), subchannel_count),
828 },
829 ProtocolState::Ready {
830 version,
831 subchannel_count,
832 } => ("ready", Some(version), Some(subchannel_count)),
833 };
834 resp.field("state", state)
835 .field("version", version)
836 .field("subchannel_count", subchannel_count);
837 }
838 resp.field("pending_packets", worker.inner.scsi_requests_states.len())
839 .fields("io", worker.inner.scsi_requests_states.iter())
840 .field("stats", &worker.inner.stats)
841 .field("ring", &worker.queue)
842 .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
843 }
844 }
845}
846
847impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
848 async fn run(
849 &mut self,
850 stop: &mut StopTask<'_>,
851 worker: &mut Worker<T>,
852 ) -> Result<(), task_control::Cancelled> {
853 let fut = async {
854 if worker.inner.channel_index == 0 {
855 worker.process_primary().await
856 } else {
857 let protocol_version = loop {
860 let listener = worker.inner.protocol.ready.listen();
861 if let ProtocolState::Ready { version, .. } =
862 *worker.inner.protocol.state.read()
863 {
864 break version;
865 }
866 tracing::debug!("subchannel waiting for initialization to end");
867 listener.await
868 };
869 worker
870 .inner
871 .process_ready(&mut worker.queue, protocol_version)
872 .await
873 }
874 };
875
876 match stop.until_stopped(fut).await? {
877 Ok(_) => {}
878 Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
879 }
880 Ok(())
881 }
882}
883
884impl WorkerInner {
885 async fn next_packet<'a, M: RingMem>(
888 &mut self,
889 reader: &'a mut queue::ReadHalf<'a, M>,
890 ) -> Result<Packet, WorkerError> {
891 let packet = reader.read().await.map_err(WorkerError::Queue)?;
892 let stor_packet =
893 parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
894 Ok(stor_packet)
895 }
896
897 fn poll_for_ring_space<M: RingMem>(
904 &mut self,
905 cx: &mut Context<'_>,
906 writer: &mut queue::WriteHalf<'_, M>,
907 ) -> Poll<Result<(), WorkerError>> {
908 writer
909 .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
910 .map_err(WorkerError::Queue)
911 }
912}
913
914const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
915 size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
916);
917
impl<T: RingMem> Worker<T> {
    /// Drives the primary channel: the initialization handshake, then I/O.
    ///
    /// Handshake order enforced here:
    /// 1. BEGIN_INITIALIZATION.
    /// 2. QUERY_PROTOCOL_VERSION, repeated until a supported version is offered.
    /// 3. QUERY_PROPERTIES.
    /// 4. Optionally CREATE_SUB_CHANNELS, only if multi-channel was advertised.
    /// 5. END_INITIALIZATION.
    ///
    /// A packet that arrives out of order is completed with
    /// `INVALID_DEVICE_STATE`, and the state does not advance.
    ///
    /// Once the state is `Ready`, I/O is processed via `process_ready`. Rescan
    /// notifications are forwarded to the guest as `ENUMERATE_BUS` when the
    /// version is Win7 or later.
    ///
    /// Returns `process_ready`'s result, or the first error.
    async fn process_primary(&mut self) -> Result<(), WorkerError> {
        loop {
            // Copy the state out so the read lock is released before any await.
            let current_state = *self.inner.protocol.state.read();
            match current_state {
                ProtocolState::Ready { version, .. } => {
                    // `select_biased!` polls I/O processing before the rescan
                    // notification. After a rescan is handled, the inner loop
                    // restarts `process_ready`; its state lives in `self.inner`.
                    break loop {
                        select_biased! {
                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
                                if version >= Version::Win7
                                {
                                    tracing::debug!("rescan notification received, sending ENUMERATE_BUS");
                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
                                }
                            }
                        }
                    };
                }
                ProtocolState::Init(state) => {
                    let (mut reader, mut writer) = self.queue.split();

                    // Wait for room for a maximum-sized packet before reading a
                    // request. The completions below use a non-blocking write
                    // that fails with NotEnoughSpace if the ring is full.
                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;

                    tracing::debug!(?state, "process_primary");
                    match state {
                        // Expect BEGIN_INITIALIZATION.
                        InitState::Begin => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::BeginInitialization = packet.data {
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &(),
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::QueryVersion);
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        // Expect QUERY_PROTOCOL_VERSION. An unsupported version
                        // is answered with REVISION_MISMATCH and the state stays
                        // QueryVersion, so the guest can offer another version.
                        InitState::QueryVersion => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
                                if let Ok(version) = Version::parse(major_minor) {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    // From here on, outgoing packets are sized for the negotiated version.
                                    self.inner.request_size = version.max_request_size();
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryProperties { version });

                                    tracelimit::info_ratelimited!(
                                        ?version,
                                        "scsi version negotiated"
                                    );
                                } else {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryVersion);
                                }
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        // Expect QUERY_PROPERTIES. Multi-channel support is
                        // advertised only for Win8 and later. Otherwise the
                        // subchannel count is fixed at Some(0), so a later
                        // CREATE_SUB_CHANNELS is rejected.
                        InitState::QueryProperties { version } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProperties = packet.data {
                                let multi_channel_supported = version >= Version::Win8;

                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &storvsp_protocol::ChannelProperties {
                                        // 0x40000 bytes = 256 KiB.
                                        max_transfer_bytes: 0x40000,
                                        flags: {
                                            if multi_channel_supported {
                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
                                            } else {
                                                0
                                            }
                                        },
                                        maximum_sub_channel_count: if multi_channel_supported {
                                            self.inner.channel_control.max_subchannels()
                                        } else {
                                            0
                                        },
                                        reserved: 0,
                                        reserved2: 0,
                                        reserved3: [0, 0],
                                    },
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::EndInitialization {
                                        version,
                                        subchannel_count: if multi_channel_supported {
                                            None
                                        } else {
                                            Some(0)
                                        },
                                    });
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        // Expect END_INITIALIZATION. CREATE_SUB_CHANNELS is also
                        // accepted while no subchannel count has been recorded.
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            match packet.data {
                                PacketData::CreateSubChannels(sub_channel_count)
                                    if subchannel_count.is_none() =>
                                {
                                    if let Err(err) = self
                                        .inner
                                        .channel_control
                                        .enable_subchannels(sub_channel_count)
                                    {
                                        tracelimit::warn_ratelimited!(
                                            ?err,
                                            "cannot enable subchannels"
                                        );
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
                                            &(),
                                        )?;
                                    } else {
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::SUCCESS,
                                            &(),
                                        )?;
                                        *self.inner.protocol.state.write() =
                                            ProtocolState::Init(InitState::EndInitialization {
                                                version,
                                                subchannel_count: Some(sub_channel_count),
                                            });
                                    }
                                }
                                PacketData::EndInitialization => {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &(),
                                    )?;
                                    // Consume one pending rescan notification, if any.
                                    // NOTE(review): presumably because the guest enumerates
                                    // the bus itself after initialization; confirm.
                                    self.rescan_notification.try_next().ok();
                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
                                        version,
                                        subchannel_count: subchannel_count.unwrap_or(0),
                                    };
                                    // Wake every subchannel worker waiting in `run` for the Ready state.
                                    self.inner.protocol.ready.notify(usize::MAX);
                                }
                                _ => {
                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                        &(),
                                    )?;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
1132
1133fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
1134 match srb_status {
1135 SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
1136 SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
1137 SrbStatus::INVALID_LUN
1138 | SrbStatus::INVALID_TARGET_ID
1139 | SrbStatus::NO_DEVICE
1140 | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
1141 SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
1142 SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
1143 SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
1144 storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
1145 }
1146 SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
1147 SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
1148 SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
1149 _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
1150 }
1151}
1152
1153impl WorkerInner {
1154 async fn process_ready<M: RingMem>(
1156 &mut self,
1157 queue: &mut Queue<M>,
1158 protocol_version: Version,
1159 ) -> Result<(), WorkerError> {
1160 self.request_size = protocol_version.max_request_size();
1161 poll_fn(|cx| self.poll_process_ready(cx, queue)).await
1162 }
1163
1164 fn poll_process_ready<M: RingMem>(
1166 &mut self,
1167 cx: &mut Context<'_>,
1168 queue: &mut Queue<M>,
1169 ) -> Poll<Result<(), WorkerError>> {
1170 self.stats.wakes.increment();
1171
1172 let (mut reader, mut writer) = queue.split();
1173 let mut total_completions = 0;
1174 let mut total_submissions = 0;
1175 let poll_mode_queue_depth = self.controller.poll_mode_queue_depth.load(Relaxed) as usize;
1176
1177 loop {
1178 'outer: while !self.scsi_requests_states.is_empty() {
1180 {
1181 let mut batch = writer.batched();
1182 loop {
1183 if !batch
1187 .can_write(MAX_VMBUS_PACKET_SIZE)
1188 .map_err(WorkerError::Queue)?
1189 {
1190 break;
1192 }
1193 if let Poll::Ready(Some((request_id, result, future))) =
1194 self.scsi_requests.poll_next_unpin(cx)
1195 {
1196 self.future_pool.push(future);
1197 self.handle_completion(&mut batch, request_id, result)?;
1198 total_completions += 1;
1199 } else {
1200 tracing::trace!("out of completions");
1201 break 'outer;
1202 }
1203 }
1204 }
1205
1206 if self.poll_for_ring_space(cx, &mut writer).is_pending() {
1208 tracing::trace!("out of ring space");
1209 break;
1210 }
1211 }
1212
1213 let mut submissions = 0;
1214 'outer: loop {
1216 if self.scsi_requests_states.len() >= self.max_io_queue_depth {
1217 break;
1218 }
1219 let mut batch = if self.scsi_requests_states.len() < poll_mode_queue_depth {
1220 if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
1221 batch.map_err(WorkerError::Queue)?
1222 } else {
1223 tracing::trace!("out of incoming packets");
1224 break;
1225 }
1226 } else {
1227 match reader.try_read_batch() {
1228 Ok(batch) => batch,
1229 Err(queue::TryReadError::Empty) => {
1230 tracing::trace!(
1231 pending_io_count = self.scsi_requests_states.len(),
1232 "out of incoming packets, keeping interrupts masked"
1233 );
1234 break;
1235 }
1236 Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
1237 }
1238 };
1239
1240 let mut packets = batch.packets();
1241 loop {
1242 if self.scsi_requests_states.len() >= self.max_io_queue_depth {
1243 break 'outer;
1244 }
1245 if self.poll_for_ring_space(cx, &mut writer).is_pending() {
1249 tracing::trace!("out of ring space");
1250 break 'outer;
1251 }
1252
1253 let packet = if let Some(packet) = packets.next() {
1254 packet.map_err(WorkerError::Queue)?
1255 } else {
1256 break;
1257 };
1258
1259 if self.handle_packet(&mut writer, &packet)? {
1260 submissions += 1;
1261 }
1262 }
1263 }
1264
1265 if submissions == 0 {
1267 break;
1269 }
1270 total_submissions += submissions;
1271 }
1272
1273 if total_submissions != 0 || total_completions != 0 {
1274 self.stats.ios_submitted.add(total_submissions);
1275 self.stats
1276 .per_wake_submissions
1277 .add_sample(total_submissions);
1278 self.stats
1279 .per_wake_completions
1280 .add_sample(total_completions);
1281 self.stats.ios_completed.add(total_completions);
1282 } else {
1283 self.stats.wakes_spurious.increment();
1284 }
1285
1286 Poll::Pending
1287 }
1288
1289 fn handle_completion<M: RingMem>(
1290 &mut self,
1291 writer: &mut queue::WriteBatch<'_, M>,
1292 request_id: usize,
1293 result: ScsiResult,
1294 ) -> Result<(), WorkerError> {
1295 let state = self.scsi_requests_states.remove(request_id);
1296 let request_size = state.request.request_size;
1297
1298 assert_eq!(
1300 Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
1301 1
1302 );
1303 self.full_request_pool.push(state.request);
1304
1305 let status = convert_srb_status_to_nt_status(result.srb_status);
1306 let mut payload = [0; 0x14];
1307 if let Some(sense) = result.sense_data {
1308 payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
1309 tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
1310 };
1311 let response = storvsp_protocol::ScsiRequest {
1312 length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
1313 scsi_status: result.scsi_status,
1314 srb_status: SrbStatusAndFlags::new()
1315 .with_status(result.srb_status)
1316 .with_autosense_valid(result.sense_data.is_some()),
1317 data_transfer_length: result.tx as u32,
1318 cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
1319 sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
1320 payload,
1321 ..storvsp_protocol::ScsiRequest::new_zeroed()
1322 };
1323 self.send_vmbus_packet(
1324 writer,
1325 OutgoingPacketType::Completion,
1326 request_size,
1327 state.transaction_id,
1328 storvsp_protocol::Operation::COMPLETE_IO,
1329 status,
1330 response.as_bytes(),
1331 )?;
1332 Ok(())
1333 }
1334
1335 fn handle_packet<M: RingMem>(
1336 &mut self,
1337 writer: &mut queue::WriteHalf<'_, M>,
1338 packet: &IncomingPacket<'_, M>,
1339 ) -> Result<bool, WorkerError> {
1340 let packet =
1341 parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
1342 let submitted_io = match packet.data {
1343 PacketData::ExecuteScsi(request) => {
1344 self.push_scsi_request(packet.transaction_id, request);
1345 true
1346 }
1347 PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
1348 self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
1350 false
1351 }
1352 PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
1353 if let Err(err) = self
1354 .channel_control
1355 .enable_subchannels(new_subchannel_count)
1356 {
1357 tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
1358 self.send_completion(
1359 writer,
1360 &packet,
1361 storvsp_protocol::NtStatus::INVALID_PARAMETER,
1362 &(),
1363 )?;
1364 false
1365 } else {
1366 if let ProtocolState::Ready {
1368 subchannel_count, ..
1369 } = &mut *self.protocol.state.write()
1370 {
1371 *subchannel_count = new_subchannel_count;
1372 } else {
1373 unreachable!()
1374 }
1375
1376 self.send_completion(
1377 writer,
1378 &packet,
1379 storvsp_protocol::NtStatus::SUCCESS,
1380 &(),
1381 )?;
1382 false
1383 }
1384 }
1385 _ => {
1386 tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
1387 self.send_completion(
1388 writer,
1389 &packet,
1390 storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
1391 &(),
1392 )?;
1393 false
1394 }
1395 };
1396 Ok(submitted_io)
1397 }
1398
1399 fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
1400 let scsi_queue = self.scsi_queue.clone();
1401 let scsi_request_state = ScsiRequestState {
1402 transaction_id,
1403 request: full_request.clone(),
1404 };
1405 let request_id = self.scsi_requests_states.insert(scsi_request_state);
1406 let future = self
1407 .future_pool
1408 .pop()
1409 .unwrap_or_else(|| OversizedBox::new(()));
1410 let future = OversizedBox::refill(future, async move {
1411 scsi_queue.execute_scsi(full_request.as_ref()).await
1412 });
1413 let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
1414 self.scsi_requests.push(request);
1415 }
1416}
1417
1418impl<T: RingMem> Drop for Worker<T> {
1419 fn drop(&mut self) {
1420 self.inner
1421 .controller
1422 .remove_rescan_notification_source(&self.rescan_notification);
1423 }
1424}
1425
1426#[derive(Debug, Error)]
1427#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
1428pub struct ScsiPathInUse(pub ScsiPath);
1429
1430#[derive(Debug, Error)]
1431#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
1432pub struct ScsiPathNotInUse(ScsiPath);
1433
/// Bookkeeping for one in-flight SCSI request, keyed in the worker's slab by
/// request id.
#[derive(Clone)]
struct ScsiRequestState {
    // Guest transaction id of the request packet; reused as the transaction id
    // of the `COMPLETE_IO` completion.
    transaction_id: u64,
    // Parsed request, shared with the executing future. Returned to the
    // worker's `full_request_pool` once the request completes.
    request: Arc<ScsiRequestAndRange>,
}

/// A parsed guest SCSI request together with its data-buffer description.
#[derive(Debug)]
struct ScsiRequestAndRange {
    // Description of the guest data buffer for this request.
    // NOTE(review): `Range` is defined elsewhere; confirm its exact meaning.
    external_data: Range,
    // Backing storage for the buffer's page ranges (presumably what
    // `external_data` refers into; confirm).
    external_data_buf: MultiPagedRangeBuf,
    // The protocol request as received; `payload[0]` is reported as the SCSI
    // opcode by the `Inspect` impl.
    request: storvsp_protocol::ScsiRequest,
    // Request size passed to `send_vmbus_packet` when completing this request.
    request_size: usize,
}
1447
1448impl Inspect for ScsiRequestState {
1449 fn inspect(&self, req: inspect::Request<'_>) {
1450 req.respond()
1451 .field("transaction_id", self.transaction_id)
1452 .display(
1453 "address",
1454 &ScsiPath {
1455 path: self.request.request.path_id,
1456 target: self.request.request.target_id,
1457 lun: self.request.request.lun,
1458 },
1459 )
1460 .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1461 }
1462}
1463
1464impl StorageDevice {
1465 pub fn build_scsi(
1467 driver_source: &VmTaskDriverSource,
1468 controller: &ScsiController,
1469 instance_id: Guid,
1470 max_sub_channel_count: u16,
1471 io_queue_depth: u32,
1472 ) -> Self {
1473 Self::build_inner(
1474 driver_source,
1475 controller,
1476 instance_id,
1477 None,
1478 max_sub_channel_count,
1479 io_queue_depth,
1480 )
1481 }
1482
1483 pub fn build_ide(
1486 driver_source: &VmTaskDriverSource,
1487 channel_id: u8,
1488 device_id: u8,
1489 disk: ScsiControllerDisk,
1490 io_queue_depth: u32,
1491 ) -> Self {
1492 let path = ScsiPath {
1493 path: channel_id,
1494 target: device_id,
1495 lun: 0,
1496 };
1497
1498 let controller = ScsiController::new();
1499 controller.attach(path, disk).unwrap();
1500
1501 let instance_id = Guid {
1504 data1: channel_id.into(),
1505 data2: device_id.into(),
1506 data3: 0x8899,
1507 data4: [0; 8],
1508 };
1509 Self::build_inner(
1510 driver_source,
1511 &controller,
1512 instance_id,
1513 Some(path),
1514 0,
1515 io_queue_depth,
1516 )
1517 }
1518
1519 fn build_inner(
1520 driver_source: &VmTaskDriverSource,
1521 controller: &ScsiController,
1522 instance_id: Guid,
1523 ide_path: Option<ScsiPath>,
1524 max_sub_channel_count: u16,
1525 io_queue_depth: u32,
1526 ) -> Self {
1527 let workers = (0..max_sub_channel_count + 1)
1528 .map(|channel_index| WorkerAndDriver {
1529 worker: TaskControl::new(WorkerState),
1530 driver: driver_source
1531 .builder()
1532 .target_vp(0)
1533 .run_on_target(true)
1534 .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1535 })
1536 .collect();
1537
1538 Self {
1539 instance_id,
1540 ide_path,
1541 workers,
1542 controller: controller.state.clone(),
1543 resources: Default::default(),
1544 max_sub_channel_count,
1545 driver_source: driver_source.clone(),
1546 protocol: Arc::new(Protocol {
1547 state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1548 ready: Default::default(),
1549 }),
1550 io_queue_depth,
1551 }
1552 }
1553
1554 fn new_worker(
1555 &mut self,
1556 open_request: &OpenRequest,
1557 channel_index: u16,
1558 ) -> anyhow::Result<&mut Worker> {
1559 let controller = self.controller.clone();
1560
1561 let target_vp = open_request.open_data.target_vp.unwrap_or_default();
1564 let driver = self
1565 .driver_source
1566 .builder()
1567 .target_vp(target_vp)
1568 .run_on_target(true)
1569 .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1570
1571 let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1572 .context("failed to create vmbus channel")?;
1573
1574 let channel_control = self.resources.channel_control.clone();
1575
1576 tracing::debug!(
1577 target_vp = open_request.open_data.target_vp,
1578 channel_index,
1579 "packet processing starting...",
1580 );
1581
1582 let force_path_id = self.ide_path.map(|p| p.path);
1586
1587 let worker = Worker::new(
1588 controller,
1589 channel,
1590 channel_index,
1591 self.resources
1592 .offer_resources
1593 .guest_memory(open_request)
1594 .clone(),
1595 channel_control,
1596 self.io_queue_depth,
1597 self.protocol.clone(),
1598 force_path_id,
1599 )
1600 .map_err(RestoreError::Other)?;
1601
1602 self.workers[channel_index as usize]
1603 .driver
1604 .retarget_vp(target_vp);
1605
1606 Ok(self.workers[channel_index as usize].worker.insert(
1607 &driver,
1608 format!("storvsp worker {}-{}", self.instance_id, channel_index),
1609 worker,
1610 ))
1611 }
1612}
1613
/// A disk that can be attached to a [`ScsiController`] at a [`ScsiPath`].
///
/// Cloning shares the same underlying disk (it is held in an `Arc`).
#[derive(Clone)]
pub struct ScsiControllerDisk {
    disk: Arc<dyn AsyncScsiDisk>,
}
1619
1620impl ScsiControllerDisk {
1621 pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1623 Self { disk }
1624 }
1625}
1626
/// State shared between a [`ScsiController`] handle and the storvsp workers.
struct ScsiControllerState {
    // Attached disks, keyed by `path:target:lun`.
    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
    // One sender per registered listener; signalled (best-effort) by
    // `ScsiController::attach`/`remove` after the disk set changes.
    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
    // Outstanding-IO threshold at or above which workers read new packets
    // without re-arming interrupts (see `WorkerInner::poll_process_ready`).
    poll_mode_queue_depth: AtomicU32,
}

/// Handle for attaching and removing disks on a storvsp SCSI controller.
pub struct ScsiController {
    state: Arc<ScsiControllerState>,
}
1636
1637impl ScsiController {
1638 pub fn new() -> Self {
1639 Self::new_with_poll_mode_queue_depth(None)
1640 }
1641
1642 pub fn new_with_poll_mode_queue_depth(poll_mode_queue_depth: Option<u32>) -> Self {
1643 Self {
1644 state: Arc::new(ScsiControllerState {
1645 disks: Default::default(),
1646 rescan_notification_source: Mutex::new(Vec::new()),
1647 poll_mode_queue_depth: AtomicU32::new(
1648 poll_mode_queue_depth.unwrap_or(DEFAULT_POLL_MODE_QUEUE_DEPTH),
1649 ),
1650 }),
1651 }
1652 }
1653
1654 pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1655 match self.state.disks.write().entry(path) {
1656 Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1657 Entry::Vacant(entry) => entry.insert(disk),
1658 };
1659 for source in self.state.rescan_notification_source.lock().iter_mut() {
1660 source.try_send(()).ok();
1663 }
1664 Ok(())
1665 }
1666
1667 pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1668 match self.state.disks.write().entry(path) {
1669 Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1670 Entry::Occupied(entry) => {
1671 entry.remove();
1672 }
1673 }
1674 for source in self.state.rescan_notification_source.lock().iter_mut() {
1675 source.try_send(()).ok();
1678 }
1679 Ok(())
1680 }
1681}
1682
1683impl ScsiControllerState {
1684 fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1685 self.rescan_notification_source.lock().push(source);
1686 }
1687
1688 fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1689 let mut sources = self.rescan_notification_source.lock();
1690 if let Some(index) = sources
1691 .iter()
1692 .position(|source| source.is_connected_to(target))
1693 {
1694 sources.remove(index);
1695 }
1696 }
1697}
1698
1699#[async_trait]
1700impl VmbusDevice for StorageDevice {
1701 fn offer(&self) -> OfferParams {
1702 if let Some(path) = self.ide_path {
1703 let offer_properties = storvsp_protocol::OfferProperties {
1704 path_id: path.path,
1705 target_id: path.target,
1706 flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
1707 ..FromZeros::new_zeroed()
1708 };
1709 let mut user_defined = UserDefinedData::new_zeroed();
1710 offer_properties
1711 .write_to_prefix(&mut user_defined[..])
1712 .unwrap();
1713 OfferParams {
1714 interface_name: "ide-accel".to_owned(),
1715 instance_id: self.instance_id,
1716 interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
1717 channel_type: ChannelType::Interface { user_defined },
1718 ..Default::default()
1719 }
1720 } else {
1721 OfferParams {
1722 interface_name: "scsi".to_owned(),
1723 instance_id: self.instance_id,
1724 interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
1725 ..Default::default()
1726 }
1727 }
1728 }
1729
1730 fn max_subchannels(&self) -> u16 {
1731 self.max_sub_channel_count
1732 }
1733
1734 fn install(&mut self, resources: DeviceResources) {
1735 self.resources = resources;
1736 }
1737
1738 async fn open(
1739 &mut self,
1740 channel_index: u16,
1741 open_request: &OpenRequest,
1742 ) -> Result<(), ChannelOpenError> {
1743 tracing::debug!(channel_index, "scsi open channel");
1744 self.new_worker(open_request, channel_index)?;
1745 self.workers[channel_index as usize].worker.start();
1746 Ok(())
1747 }
1748
1749 async fn close(&mut self, channel_index: u16) {
1750 tracing::debug!(channel_index, "scsi close channel");
1751 let worker = &mut self.workers[channel_index as usize].worker;
1752 worker.stop().await;
1753 if worker.state_mut().is_some() {
1754 worker
1755 .state_mut()
1756 .unwrap()
1757 .wait_for_scsi_requests_complete()
1758 .await;
1759 worker.remove();
1760 }
1761 if channel_index == 0 {
1762 *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
1763 }
1764 }
1765
1766 async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
1767 self.workers[channel_index as usize]
1768 .driver
1769 .retarget_vp(target_vp);
1770 }
1771
1772 fn start(&mut self) {
1773 for task in self
1774 .workers
1775 .iter_mut()
1776 .filter(|task| task.worker.has_state() && !task.worker.is_running())
1777 {
1778 task.worker.start();
1779 }
1780 }
1781
1782 async fn stop(&mut self) {
1783 tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
1784 for task in self
1785 .workers
1786 .iter_mut()
1787 .filter(|task| task.worker.has_state() && task.worker.is_running())
1788 {
1789 task.worker.stop().await;
1790 }
1791 }
1792
1793 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1794 Some(self)
1795 }
1796}
1797
1798#[async_trait]
1799impl SaveRestoreVmbusDevice for StorageDevice {
1800 async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1801 Ok(SavedStateBlob::new(self.save()?))
1802 }
1803
1804 async fn restore(
1805 &mut self,
1806 control: RestoreControl<'_>,
1807 state: SavedStateBlob,
1808 ) -> Result<(), RestoreError> {
1809 self.restore(control, state.parse()?).await
1810 }
1811}
1812
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestWorker;
    use crate::test_helpers::parse_guest_completion;
    use crate::test_helpers::parse_guest_completion_check_flags_status;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use scsi::srb::SrbStatus;
    use test_with_tracing::test;
    use vmbus_channel::connected_async_channels;

    // Test-only clone: lets a test hand one controller handle to the worker
    // while keeping another to attach/remove disks. Both share the same state.
    impl Clone for ScsiController {
        fn clone(&self) -> Self {
            ScsiController {
                state: self.state.clone(),
            }
        }
    }

    /// End-to-end IO. After protocol negotiation, write 4 KiB from guest
    /// memory, read it back into a different guest address, and compare.
    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        let controller = ScsiController::new();
        let disk = scsidisk::SimpleScsiDisk::new(
            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
            Default::default(),
        );
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                ScsiControllerDisk::new(Arc::new(disk)),
            )
            .unwrap();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Write a 4 KiB pattern stored at guest address 4 KiB.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        // Read it back into guest address 8 KiB and compare.
        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        guest.verify_graceful_close(test_worker).await;
    }

    /// Sends QUERY_PROTOCOL_VERSION packets of several total lengths with an
    /// unsupported version (`!0`). Each must be answered with
    /// REVISION_MISMATCH, with the completion length listed next to each
    /// request length.
    #[async_test]
    async fn test_packet_sizes(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        // Held so the worker stays alive for the duration of the test.
        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let header = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };

        let mut buf = [0u8; 128];
        storvsp_protocol::ProtocolVersion {
            major_minor: !0,
            reserved: 0,
        }
        .write_to_prefix(&mut buf[..])
        .unwrap();

        // (total request length in bytes, expected completion length in bytes)
        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
            guest
                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
                .await;

            guest
                .verify_completion(|packet| {
                    let IncomingPacket::Completion(packet) = packet else {
                        unreachable!()
                    };
                    assert_eq!(packet.reader().len(), resp_len);
                    assert_eq!(
                        packet
                            .reader()
                            .read_plain::<storvsp_protocol::Packet>()
                            .unwrap()
                            .status,
                        storvsp_protocol::NtStatus::REVISION_MISMATCH
                    );
                    Ok(())
                })
                .await;
        }
    }

    /// Sending END_INITIALIZATION before BEGIN_INITIALIZATION is rejected
    /// with INVALID_DEVICE_STATE.
    #[async_test]
    async fn test_wrong_first_packet(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // Out-of-order first packet.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
                );
                Ok(())
            })
            .await;
    }

    /// A REMOVE_DEVICE packet terminates the worker with
    /// `PacketError::UnrecognizedOperation(REMOVE_DEVICE)`.
    #[async_test]
    async fn test_unrecognized_operation(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // An operation the worker does not recognize.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        match worker.teardown().await {
            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
                storvsp_protocol::Operation::REMOVE_DEVICE,
            ))) => {}
            result => panic!("Worker failed with unexpected result {:?}!", result),
        }
    }

    /// During negotiation (after QUERY_PROPERTIES), a CREATE_SUB_CHANNELS
    /// request for one subchannel is answered with INVALID_PARAMETER.
    /// NOTE(review): presumably because `TestWorker` offers no subchannels; confirm.
    #[async_test]
    async fn test_too_many_subchannels(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        // Request one subchannel (the u16 count follows the packet header).
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_PARAMETER
                );
                Ok(())
            })
            .await;
    }

    /// After a full negotiation, a second BEGIN_INITIALIZATION is rejected
    /// with INVALID_DEVICE_STATE and flags 0.
    #[async_test]
    async fn test_begin_init_on_ready(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Re-start negotiation on a channel that is already ready.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|p| {
                parse_guest_completion_check_flags_status(
                    p,
                    0,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                )
            })
            .await;
    }

    /// Hot-adds LUNs 0..4 and then hot-removes them while the channel is running.
    ///
    /// After each change the test checks three things:
    /// - the guest receives an enumerate-bus notification;
    /// - REPORT LUNS reflects the new disk count;
    /// - a write to the LUN succeeds when it is attached and fails with
    ///   INVALID_LUN when it is not.
    #[async_test]
    async fn test_hot_add_remove(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        // Start with no disks attached.
        let controller = ScsiController::new();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // With no disks: the LUN list (read back from GPA 0) is empty and the
        // transfer length is 8.
        let mut lun_list_buffer: [u8; 256] = [0; 256];
        let mut disk_count = 0;
        guest
            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
            })
            .await;
        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
        assert_eq!(lun_list_size, disk_count as u32 * 8);

        // A write with no disk attached fails with INVALID_LUN.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();

        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
            })
            .await;

        // Hot-add LUNs one at a time.
        for lun in 0..4 {
            let disk = scsidisk::SimpleScsiDisk::new(
                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
                Default::default(),
            );
            controller
                .attach(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    ScsiControllerDisk::new(Arc::new(disk)),
                )
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count += 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 256)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
                })
                .await;
        }

        // Hot-remove the same LUNs one at a time.
        for lun in 0..4 {
            controller
                .remove(ScsiPath {
                    path: 0,
                    target: 0,
                    lun,
                })
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count -= 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
                })
                .await;
        }

        guest.verify_graceful_close(test_worker).await;
    }

    /// Same write/read round trip as `test_channel_working`, but with each
    /// negotiation step spelled out explicitly: BEGIN_INITIALIZATION,
    /// QUERY_PROTOCOL_VERSION (VERSION_BLUE), QUERY_PROPERTIES, and
    /// END_INITIALIZATION.
    #[async_test]
    pub async fn test_async_disk(driver: DefaultDriver) {
        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
        let controller = ScsiController::new();
        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
            device,
            Default::default(),
        )));
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                disk,
            )
            .unwrap();

        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let test_guest_mem = GuestMemory::allocate(16384);
        let worker = TestWorker::start(
            controller.clone(),
            &driver,
            test_guest_mem.clone(),
            host,
            None,
        );

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        // Write 4 KiB from GPA 4 KiB, read it back to GPA 8 KiB, and compare.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        for (b1, b2) in read_buf.iter().zip(write_buf.iter()) {
            assert_eq!(b1, b2);
        }

        guest.verify_graceful_close(worker).await;
    }
}