1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7#[cfg(feature = "ioperf")]
8pub mod ioperf;
9
10#[cfg(feature = "test")]
11pub mod test_helpers;
12
13#[cfg(not(feature = "test"))]
14mod test_helpers;
15
16pub mod resolver;
17mod save_restore;
18
19use crate::ring::gparange::GpnList;
20use crate::ring::gparange::MultiPagedRangeBuf;
21use anyhow::Context as _;
22use async_trait::async_trait;
23use fast_select::FastSelect;
24use futures::FutureExt;
25use futures::StreamExt;
26use futures::select_biased;
27use guestmem::AccessError;
28use guestmem::GuestMemory;
29use guestmem::MemoryRead;
30use guestmem::MemoryWrite;
31use guestmem::ranges::PagedRange;
32use guid::Guid;
33use inspect::Inspect;
34use inspect::InspectMut;
35use inspect_counters::Counter;
36use inspect_counters::Histogram;
37use oversized_box::OversizedBox;
38use parking_lot::Mutex;
39use parking_lot::RwLock;
40use ring::OutgoingPacketType;
41use scsi::AdditionalSenseCode;
42use scsi::ScsiOp;
43use scsi::ScsiStatus;
44use scsi::srb::SrbStatus;
45use scsi::srb::SrbStatusAndFlags;
46use scsi_buffers::RequestBuffers;
47use scsi_core::AsyncScsiDisk;
48use scsi_core::Request;
49use scsi_core::ScsiResult;
50use scsi_defs as scsi;
51use scsidisk::illegal_request_sense;
52use slab::Slab;
53use std::collections::hash_map::Entry;
54use std::collections::hash_map::HashMap;
55use std::fmt::Debug;
56use std::future::Future;
57use std::future::poll_fn;
58use std::pin::Pin;
59use std::sync::Arc;
60use std::task::Context;
61use std::task::Poll;
62use storvsp_resources::ScsiPath;
63use task_control::AsyncRun;
64use task_control::InspectTask;
65use task_control::StopTask;
66use task_control::TaskControl;
67use thiserror::Error;
68use tracing_helpers::ErrorValueExt;
69use unicycle::FuturesUnordered;
70use vmbus_async::queue;
71use vmbus_async::queue::ExternalDataError;
72use vmbus_async::queue::IncomingPacket;
73use vmbus_async::queue::OutgoingPacket;
74use vmbus_async::queue::Queue;
75use vmbus_channel::RawAsyncChannel;
76use vmbus_channel::bus::ChannelType;
77use vmbus_channel::bus::OfferParams;
78use vmbus_channel::bus::OpenRequest;
79use vmbus_channel::channel::ChannelControl;
80use vmbus_channel::channel::ChannelOpenError;
81use vmbus_channel::channel::DeviceResources;
82use vmbus_channel::channel::RestoreControl;
83use vmbus_channel::channel::SaveRestoreVmbusDevice;
84use vmbus_channel::channel::VmbusDevice;
85use vmbus_channel::gpadl_ring::GpadlRingMem;
86use vmbus_channel::gpadl_ring::gpadl_channel;
87use vmbus_core::protocol::UserDefinedData;
88use vmbus_ring as ring;
89use vmbus_ring::RingMem;
90use vmcore::save_restore::RestoreError;
91use vmcore::save_restore::SaveError;
92use vmcore::save_restore::SavedStateBlob;
93use vmcore::vm_task::VmTaskDriver;
94use vmcore::vm_task::VmTaskDriverSource;
95use zerocopy::FromBytes;
96use zerocopy::FromZeros;
97use zerocopy::Immutable;
98use zerocopy::IntoBytes;
99use zerocopy::KnownLayout;
100
/// A vmbus storage (storvsp) device backed by a SCSI controller's disks.
pub struct StorageDevice {
    instance_id: Guid,
    ide_path: Option<ScsiPath>,
    /// Channel workers and their drivers; inspected under "channels".
    workers: Vec<WorkerAndDriver>,
    /// Controller state holding the attached disks, keyed by [`ScsiPath`].
    controller: Arc<ScsiControllerState>,
    resources: DeviceResources,
    driver_source: VmTaskDriverSource,
    max_sub_channel_count: u16,
    /// Protocol negotiation state shared with every channel worker.
    protocol: Arc<Protocol>,
    io_queue_depth: u32,
}
112
/// A channel worker task paired with the task driver it runs on.
#[derive(Inspect)]
struct WorkerAndDriver {
    #[inspect(flatten)]
    worker: TaskControl<WorkerState, Worker>,
    driver: VmTaskDriver,
}

/// Task-control state for a [`Worker`]. It carries no data; its `AsyncRun`
/// and `InspectTask` impls run and inspect the worker.
struct WorkerState;
121
122impl InspectMut for StorageDevice {
123 fn inspect_mut(&mut self, req: inspect::Request<'_>) {
124 let mut resp = req.respond();
125
126 let disks = self.controller.disks.read();
127 for (path, controller_disk) in disks.iter() {
128 resp.child(&format!("disks/{}", path), |req| {
129 controller_disk.disk.inspect(req);
130 });
131 }
132
133 resp.fields(
134 "channels",
135 self.workers
136 .iter()
137 .filter(|task| task.worker.has_state())
138 .enumerate(),
139 );
140 }
141}
142
/// A per-channel worker: the protocol/IO state plus the channel's ring queue.
struct Worker<T: RingMem = GpadlRingMem> {
    inner: WorkerInner,
    /// Rescan notifications from the controller (registered in `Worker::new`).
    /// On the primary channel they become ENUMERATE_BUS packets once the
    /// protocol is ready.
    rescan_notification: futures::channel::mpsc::Receiver<()>,
    fast_select: FastSelect,
    queue: Queue<T>,
}

/// Protocol negotiation state shared by a device's channel workers.
struct Protocol {
    state: RwLock<ProtocolState>,
    /// Notified after `state` becomes `ProtocolState::Ready`. Subchannel
    /// workers wait on it before processing IO.
    ready: event_listener::Event,
}
155
/// Worker state other than the ring queue. Its methods take the queue (or
/// the queue's read/write halves) as a separate argument.
struct WorkerInner {
    /// Negotiation state shared with the device's other channels.
    protocol: Arc<Protocol>,
    /// Request body size used by `send_packet`. It starts at
    /// `SCSI_REQUEST_LEN_V1` and is set from the negotiated version.
    request_size: usize,
    controller: Arc<ScsiControllerState>,
    /// Index of this worker's channel; 0 is the primary channel, which runs
    /// protocol negotiation.
    channel_index: u16,
    scsi_queue: Arc<ScsiCommandQueue>,
    /// In-flight SCSI operations; each resolves to its request ID, result,
    /// and recyclable future storage.
    scsi_requests: FuturesUnordered<ScsiRequest>,
    /// Per-request state keyed by request ID. Its length is the number of
    /// outstanding IOs.
    scsi_requests_states: Slab<ScsiRequestState>,
    /// Parsed requests returned after completion, reused by `parse_packet`.
    full_request_pool: Vec<Arc<ScsiRequestAndRange>>,
    /// Future storage returned by completed `ScsiRequest`s, for reuse.
    future_pool: Vec<OversizedBox<(), ScsiOpStorage>>,
    channel_control: ChannelControl,
    /// Once this many IOs are outstanding, no more packets are read.
    max_io_queue_depth: usize,
    stats: WorkerStats,
}

/// IO and wakeup statistics for a channel worker.
#[derive(Debug, Default, Inspect)]
struct WorkerStats {
    ios_submitted: Counter,
    ios_completed: Counter,
    /// Calls to `poll_process_ready`.
    wakes: Counter,
    /// Wakes that neither submitted nor completed any IO.
    wakes_spurious: Counter,
    per_wake_submissions: Histogram<10>,
    per_wake_completions: Histogram<10>,
}
180
/// Supported storvsp protocol versions. Each discriminant is the version's
/// wire value. The derived ordering backs feature checks such as
/// `version >= Version::Win8`.
#[repr(u16)]
#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)]
enum Version {
    Win6 = storvsp_protocol::VERSION_WIN6,
    Win7 = storvsp_protocol::VERSION_WIN7,
    Win8 = storvsp_protocol::VERSION_WIN8,
    Blue = storvsp_protocol::VERSION_BLUE,
}

/// Returned by `Version::parse` for a version value that is not supported.
#[derive(Debug, Error)]
#[error("protocol version {0:#x} not supported")]
struct UnsupportedVersion(u16);
193
194impl Version {
195 fn parse(major_minor: u16) -> Result<Self, UnsupportedVersion> {
196 let version = match major_minor {
197 storvsp_protocol::VERSION_WIN6 => Self::Win6,
198 storvsp_protocol::VERSION_WIN7 => Self::Win7,
199 storvsp_protocol::VERSION_WIN8 => Self::Win8,
200 storvsp_protocol::VERSION_BLUE => Self::Blue,
201 version => return Err(UnsupportedVersion(version)),
202 };
203 assert_eq!(version as u16, major_minor);
204 Ok(version)
205 }
206
207 fn max_request_size(&self) -> usize {
208 match self {
209 Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2,
210 Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1,
211 }
212 }
213}
214
/// Where the primary channel is in protocol negotiation.
#[derive(Copy, Clone)]
enum ProtocolState {
    /// Negotiation still in progress.
    Init(InitState),
    /// Negotiation finished; channels process IO.
    Ready {
        version: Version,
        subchannel_count: u16,
    },
}

/// Negotiation steps, in the order `process_primary` advances through them.
#[derive(Copy, Clone, Debug)]
enum InitState {
    /// Waiting for BEGIN_INITIALIZATION.
    Begin,
    /// Waiting for a QUERY_PROTOCOL_VERSION with a supported version.
    QueryVersion,
    /// Version negotiated; waiting for QUERY_PROPERTIES.
    QueryProperties {
        version: Version,
    },
    /// Waiting for END_INITIALIZATION.
    EndInitialization {
        version: Version,
        /// `None` while a CREATE_SUB_CHANNELS request is still accepted.
        /// Otherwise it holds the enabled count, or `Some(0)` when
        /// multi-channel is unsupported.
        subchannel_count: Option<u16>,
    },
}
236
/// Backing storage for a boxed SCSI operation future: `SCSI_REQUEST_STACK_SIZE`
/// bytes expressed as `u64` words.
type ScsiOpStorage = [u64; SCSI_REQUEST_STACK_SIZE / 8];
/// A pinned, type-erased SCSI operation future whose allocation is sized by
/// `ScsiOpStorage`, so the storage can be recycled through `future_pool`.
type ScsiOpFuture = Pin<OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>>;

/// The disk's async stack size plus 272 bytes.
/// NOTE(review): the extra bytes are presumably headroom for the wrapping
/// `execute_scsi` future's own state; confirm.
const SCSI_REQUEST_STACK_SIZE: usize = scsi_core::ASYNC_SCSI_DISK_STACK_SIZE + 272;
252
253struct ScsiRequest {
254 request_id: usize,
255 future: Option<ScsiOpFuture>,
256}
257
258impl ScsiRequest {
259 fn new(
260 request_id: usize,
261 future: OversizedBox<dyn Future<Output = ScsiResult> + Send, ScsiOpStorage>,
262 ) -> Self {
263 Self {
264 request_id,
265 future: Some(future.into()),
266 }
267 }
268}
269
270impl Future for ScsiRequest {
271 type Output = (usize, ScsiResult, OversizedBox<(), ScsiOpStorage>);
272
273 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
274 let this = self.get_mut();
275 let future = this.future.as_mut().unwrap().as_mut();
276 let result = std::task::ready!(future.poll(cx));
277 let future = this.future.take().unwrap();
279 Poll::Ready((this.request_id, result, OversizedBox::empty_pinned(future)))
280 }
281}
282
/// Errors that stop a channel worker's packet processing.
#[derive(Debug, Error)]
enum WorkerError {
    /// An incoming packet could not be parsed.
    #[error("packet error")]
    PacketError(#[source] PacketError),
    /// The vmbus queue reported an error.
    #[error("queue error")]
    Queue(#[source] queue::Error),
    /// A write failed because the ring was full, even though space had been
    /// checked beforehand.
    #[error("queue should have enough space but no longer does")]
    NotEnoughSpace,
}

/// Reasons `parse_packet` rejects an incoming packet.
#[derive(Debug, Error)]
enum PacketError {
    #[error("Not transactional")]
    NotTransactional,
    #[error("Unrecognized operation {0:?}")]
    UnrecognizedOperation(storvsp_protocol::Operation),
    /// A completion packet arrived where only data packets are expected.
    #[error("Invalid packet type")]
    InvalidPacketType,
    /// The external ranges do not fit the request's data transfer length
    /// (see `Range::new`).
    #[error("Invalid data transfer length")]
    InvalidDataTransferLength,
    #[error("Access error: {0}")]
    Access(#[source] AccessError),
    #[error("Range error")]
    Range(#[source] ExternalDataError),
}
308
/// Guest memory described by a SCSI request's external data ranges.
#[derive(Debug, Default, Clone)]
struct Range {
    /// The ranges read from the vmbus packet (at most one is accepted).
    buf: MultiPagedRangeBuf<GpnList>,
    /// The request's `data_transfer_length`.
    len: usize,
    /// True when the request's `data_in` is nonzero; passed to
    /// `RequestBuffers::new`.
    is_write: bool,
}
315
316impl Range {
317 fn new(
318 buf: MultiPagedRangeBuf<GpnList>,
319 request: &storvsp_protocol::ScsiRequest,
320 ) -> Option<Self> {
321 let len = request.data_transfer_length as usize;
322 let is_write = request.data_in != 0;
323 if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) {
326 return None;
327 }
328 Some(Self { buf, len, is_write })
329 }
330
331 fn buffer<'a>(&'a self, guest_memory: &'a GuestMemory) -> RequestBuffers<'a> {
332 let mut range = self.buf.first().unwrap_or_else(PagedRange::empty);
333 range.truncate(self.len);
334 RequestBuffers::new(guest_memory, range, self.is_write)
335 }
336}
337
/// A parsed storvsp packet together with what is needed to complete it.
#[derive(Debug)]
struct Packet {
    data: PacketData,
    /// Transaction ID that the completion must echo back.
    transaction_id: u64,
    /// Size of the request body following the header. Completions are sized
    /// to match it.
    request_size: usize,
}

/// The operation-specific contents of a parsed storvsp packet.
#[derive(Debug)]
enum PacketData {
    BeginInitialization,
    EndInitialization,
    /// Requested protocol version (`major_minor`).
    QueryProtocolVersion(u16),
    QueryProperties,
    /// Requested number of subchannels.
    CreateSubChannels(u16),
    ExecuteScsi(Arc<ScsiRequestAndRange>),
    ResetBus,
    ResetAdapter,
    ResetLun,
}

/// Unit error type. NOTE(review): not referenced within this part of the file.
#[derive(Debug)]
pub struct RangeError;
360
361fn parse_packet<T: RingMem>(
362 packet: &IncomingPacket<'_, T>,
363 pool: &mut Vec<Arc<ScsiRequestAndRange>>,
364) -> Result<Packet, PacketError> {
365 let packet = match packet {
366 IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType),
367 IncomingPacket::Data(packet) => packet,
368 };
369 let transaction_id = packet
370 .transaction_id()
371 .ok_or(PacketError::NotTransactional)?;
372
373 let mut reader = packet.reader();
374 let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?;
375 let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX);
379 let data = match header.operation {
380 storvsp_protocol::Operation::BEGIN_INITIALIZATION => PacketData::BeginInitialization,
381 storvsp_protocol::Operation::END_INITIALIZATION => PacketData::EndInitialization,
382 storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => {
383 let mut version = storvsp_protocol::ProtocolVersion::new_zeroed();
384 reader
385 .read(version.as_mut_bytes())
386 .map_err(PacketError::Access)?;
387 PacketData::QueryProtocolVersion(version.major_minor)
388 }
389 storvsp_protocol::Operation::QUERY_PROPERTIES => PacketData::QueryProperties,
390 storvsp_protocol::Operation::EXECUTE_SRB => {
391 let mut full_request = pool.pop().unwrap_or_else(|| {
392 Arc::new(ScsiRequestAndRange {
393 external_data: Range::default(),
394 request: storvsp_protocol::ScsiRequest::new_zeroed(),
395 request_size,
396 })
397 });
398
399 {
400 let full_request = Arc::get_mut(&mut full_request).unwrap();
401 let request_buf = &mut full_request.request.as_mut_bytes()[..request_size];
402 reader.read(request_buf).map_err(PacketError::Access)?;
403
404 let buf = packet.read_external_ranges().map_err(PacketError::Range)?;
405
406 full_request.external_data = Range::new(buf, &full_request.request)
407 .ok_or(PacketError::InvalidDataTransferLength)?;
408 }
409
410 PacketData::ExecuteScsi(full_request)
411 }
412 storvsp_protocol::Operation::RESET_LUN => PacketData::ResetLun,
413 storvsp_protocol::Operation::RESET_ADAPTER => PacketData::ResetAdapter,
414 storvsp_protocol::Operation::RESET_BUS => PacketData::ResetBus,
415 storvsp_protocol::Operation::CREATE_SUB_CHANNELS => {
416 let mut sub_channel_count: u16 = 0;
417 reader
418 .read(sub_channel_count.as_mut_bytes())
419 .map_err(PacketError::Access)?;
420 PacketData::CreateSubChannels(sub_channel_count)
421 }
422 _ => return Err(PacketError::UnrecognizedOperation(header.operation)),
423 };
424
425 if let PacketData::ExecuteScsi(_) = data {
426 tracing::trace!(transaction_id, ?data, "parse_packet");
427 } else {
428 tracing::debug!(transaction_id, ?data, "parse_packet");
429 }
430
431 Ok(Packet {
432 data,
433 request_size,
434 transaction_id,
435 })
436}
437
438impl WorkerInner {
439 fn send_vmbus_packet<M: RingMem>(
440 &mut self,
441 writer: &mut queue::WriteBatch<'_, M>,
442 packet_type: OutgoingPacketType<'_>,
443 request_size: usize,
444 transaction_id: u64,
445 operation: storvsp_protocol::Operation,
446 status: storvsp_protocol::NtStatus,
447 payload: &[u8],
448 ) -> Result<(), WorkerError> {
449 let header = storvsp_protocol::Packet {
450 operation,
451 flags: 0,
452 status,
453 };
454
455 let packet_size = size_of_val(&header) + request_size;
456
457 let len = size_of_val(&header) + size_of_val(payload);
462 let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX];
463 let (payload_bytes, padding_bytes) = if len > packet_size {
464 (&payload[..packet_size - size_of_val(&header)], &[][..])
465 } else {
466 (payload, &padding[..packet_size - len])
467 };
468 assert_eq!(
469 size_of_val(&header) + payload_bytes.len() + padding_bytes.len(),
470 packet_size
471 );
472 writer
473 .try_write(&OutgoingPacket {
474 transaction_id,
475 packet_type,
476 payload: &[header.as_bytes(), payload_bytes, padding_bytes],
477 })
478 .map_err(|err| match err {
479 queue::TryWriteError::Full(_) => WorkerError::NotEnoughSpace,
480 queue::TryWriteError::Queue(err) => WorkerError::Queue(err),
481 })
482 }
483
484 fn send_packet<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
485 &mut self,
486 writer: &mut queue::WriteHalf<'_, M>,
487 operation: storvsp_protocol::Operation,
488 status: storvsp_protocol::NtStatus,
489 payload: &P,
490 ) -> Result<(), WorkerError> {
491 self.send_vmbus_packet(
492 &mut writer.batched(),
493 OutgoingPacketType::InBandNoCompletion,
494 self.request_size,
495 0,
496 operation,
497 status,
498 payload.as_bytes(),
499 )
500 }
501
502 fn send_completion<M: RingMem, P: IntoBytes + Immutable + KnownLayout>(
503 &mut self,
504 writer: &mut queue::WriteHalf<'_, M>,
505 packet: &Packet,
506 status: storvsp_protocol::NtStatus,
507 payload: &P,
508 ) -> Result<(), WorkerError> {
509 self.send_vmbus_packet(
510 &mut writer.batched(),
511 OutgoingPacketType::Completion,
512 packet.request_size,
513 packet.transaction_id,
514 storvsp_protocol::Operation::COMPLETE_IO,
515 status,
516 payload.as_bytes(),
517 )
518 }
519}
520
/// Executes SCSI requests against the controller's disks.
struct ScsiCommandQueue {
    controller: Arc<ScsiControllerState>,
    /// Guest memory that the requests' data buffers refer to.
    mem: GuestMemory,
    /// If set, replaces the request's path ID when looking up the target disk.
    force_path_id: Option<u8>,
}
526
impl ScsiCommandQueue {
    /// Executes one SCSI request and returns its result.
    ///
    /// Requests are dispatched in this order:
    /// - REPORT_LUNS is always answered by the controller, from its disk map.
    /// - Any other op for an attached disk is forwarded to that disk.
    /// - INQUIRY for a missing LUN gets a "logical unit not present" response.
    /// - Everything else fails with `INVALID_LUN`.
    async fn execute_scsi(
        &self,
        external_data: &Range,
        request: &storvsp_protocol::ScsiRequest,
    ) -> ScsiResult {
        // The first CDB byte is the operation code.
        let op = ScsiOp(request.payload[0]);
        let external_data = external_data.buffer(&self.mem);

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            "execute_scsi start...",
        );

        let path_id = self.force_path_id.unwrap_or(request.path_id);

        let controller_disk = self
            .controller
            .disks
            .read()
            .get(&ScsiPath {
                path: path_id,
                target: request.target_id,
                lun: request.lun,
            })
            .cloned();

        let result = match op {
            ScsiOp::REPORT_LUNS => {
                const HEADER_SIZE: usize = size_of::<scsi::LunList>();
                // Collect the LUNs attached at the request's path and target.
                // NOTE(review): this compares against `request.path_id`, not
                // the (possibly forced) `path_id` used for the disk lookup
                // above; confirm this is intended.
                let mut luns: Vec<u8> = self
                    .controller
                    .disks
                    .read()
                    .iter()
                    .flat_map(|(path, _)| {
                        if request.path_id == path.path && request.target_id == path.target {
                            Some(path.lun)
                        } else {
                            None
                        }
                    })
                    .collect();
                luns.sort_unstable();
                // One 8-byte header followed by one 8-byte entry per LUN.
                let mut data: Vec<u64> = vec![0; luns.len() + 1];
                let header = scsi::LunList {
                    length: (luns.len() as u32 * 8).into(),
                    reserved: [0; 4],
                };
                data.as_mut_bytes()[..HEADER_SIZE].copy_from_slice(header.as_bytes());
                // Each entry holds the LUN number big-endian in its first two bytes.
                for (i, lun) in luns.iter().enumerate() {
                    data[i + 1].as_mut_bytes()[..2].copy_from_slice(&(*lun as u16).to_be_bytes());
                }
                // Copy as much of the list as fits. A buffer smaller than the
                // header succeeds with nothing transferred.
                if external_data.len() >= HEADER_SIZE {
                    let tx = std::cmp::min(external_data.len(), data.as_bytes().len());
                    external_data.writer().write(&data.as_bytes()[..tx]).map_or(
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        },
                        |_| ScsiResult {
                            scsi_status: ScsiStatus::GOOD,
                            srb_status: SrbStatus::SUCCESS,
                            tx,
                            sense_data: None,
                        },
                    )
                } else {
                    ScsiResult {
                        scsi_status: ScsiStatus::GOOD,
                        srb_status: SrbStatus::SUCCESS,
                        tx: 0,
                        sense_data: None,
                    }
                }
            }
            _ if controller_disk.is_some() => {
                // Forward the 16-byte CDB to the attached disk.
                let mut cdb = [0; 16];
                cdb.copy_from_slice(&request.payload[0..storvsp_protocol::CDB16GENERIC_LENGTH]);
                controller_disk
                    .unwrap()
                    .disk
                    .execute_scsi(
                        &external_data,
                        &Request {
                            cdb,
                            srb_flags: request.srb_flags,
                        },
                    )
                    .await
            }
            ScsiOp::INQUIRY => {
                // No disk at this path: answer INQUIRY on the disk's behalf.
                let cdb = scsi::CdbInquiry::ref_from_prefix(&request.payload)
                    .unwrap()
                    .0;
                // Reject a buffer smaller than the allocation length, a
                // non-data-in request, or an allocation length that cannot
                // hold even the inquiry header.
                if external_data.len() < cdb.allocation_length.get() as usize
                    || request.data_in != storvsp_protocol::SCSI_IOCTL_DATA_IN
                    || (cdb.allocation_length.get() as usize) < size_of::<scsi::InquiryDataHeader>()
                {
                    ScsiResult {
                        scsi_status: ScsiStatus::CHECK_CONDITION,
                        srb_status: SrbStatus::INVALID_REQUEST,
                        tx: 0,
                        sense_data: Some(illegal_request_sense(AdditionalSenseCode::INVALID_CDB)),
                    }
                } else {
                    let enable_vpd = cdb.flags.vpd();
                    if enable_vpd || cdb.page_code != 0 {
                        // VPD pages are not supported for missing LUNs.
                        ScsiResult {
                            scsi_status: ScsiStatus::CHECK_CONDITION,
                            srb_status: SrbStatus::INVALID_REQUEST,
                            tx: 0,
                            sense_data: Some(illegal_request_sense(
                                AdditionalSenseCode::INVALID_CDB,
                            )),
                        }
                    } else {
                        // Standard inquiry data from the template, marked
                        // "logical unit not present".
                        const LOGICAL_UNIT_NOT_PRESENT_DEVICE: u8 = 0x7F;
                        let mut data = scsidisk::INQUIRY_DATA_TEMPLATE;
                        data.header.device_type = LOGICAL_UNIT_NOT_PRESENT_DEVICE;

                        // Identification strings are cleared for nonzero LUNs.
                        if request.lun != 0 {
                            data.vendor_id = [0; 8];
                            data.product_id = [0; 16];
                            data.product_revision_level = [0; 4];
                        }

                        let datab = data.as_bytes();
                        let tx = std::cmp::min(
                            cdb.allocation_length.get() as usize,
                            size_of::<scsi::InquiryData>(),
                        );
                        external_data.writer().write(&datab[..tx]).map_or(
                            ScsiResult {
                                scsi_status: ScsiStatus::CHECK_CONDITION,
                                srb_status: SrbStatus::INVALID_REQUEST,
                                tx: 0,
                                sense_data: Some(illegal_request_sense(
                                    AdditionalSenseCode::INVALID_CDB,
                                )),
                            },
                            |_| ScsiResult {
                                scsi_status: ScsiStatus::GOOD,
                                srb_status: SrbStatus::SUCCESS,
                                tx,
                                sense_data: None,
                            },
                        )
                    }
                }
            }
            // Any other op addressed to a missing disk.
            _ => ScsiResult {
                scsi_status: ScsiStatus::CHECK_CONDITION,
                srb_status: SrbStatus::INVALID_LUN,
                tx: 0,
                sense_data: None,
            },
        };

        tracing::trace!(
            path_id = request.path_id,
            target_id = request.target_id,
            lun = request.lun,
            op = ?op,
            result = ?result,
            "execute_scsi completed.",
        );
        result
    }
}
708
709impl<T: RingMem + 'static> Worker<T> {
710 fn new(
711 controller: Arc<ScsiControllerState>,
712 channel: RawAsyncChannel<T>,
713 channel_index: u16,
714 mem: GuestMemory,
715 channel_control: ChannelControl,
716 io_queue_depth: u32,
717 protocol: Arc<Protocol>,
718 force_path_id: Option<u8>,
719 ) -> anyhow::Result<Self> {
720 let queue = Queue::new(channel)?;
721 #[expect(clippy::disallowed_methods)] let (source, target) = futures::channel::mpsc::channel(1);
723 controller.add_rescan_notification_source(source);
724
725 let max_io_queue_depth = io_queue_depth.max(1) as usize;
726 Ok(Self {
727 inner: WorkerInner {
728 protocol,
729 request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1,
730 controller: controller.clone(),
731 channel_index,
732 scsi_queue: Arc::new(ScsiCommandQueue {
733 controller,
734 mem,
735 force_path_id,
736 }),
737 scsi_requests: FuturesUnordered::new(),
738 scsi_requests_states: Slab::with_capacity(max_io_queue_depth),
739 channel_control,
740 max_io_queue_depth,
741 future_pool: Vec::new(),
742 full_request_pool: Vec::new(),
743 stats: Default::default(),
744 },
745 queue,
746 rescan_notification: target,
747 fast_select: FastSelect::new(),
748 })
749 }
750
751 async fn wait_for_scsi_requests_complete(&mut self) {
752 tracing::debug!(
753 channel_index = self.inner.channel_index,
754 "wait for IOs completed..."
755 );
756 while let Some((id, _, _)) = self.inner.scsi_requests.next().await {
757 self.inner.scsi_requests_states.remove(id);
758 }
759 }
760}
761
762impl InspectTask<Worker> for WorkerState {
763 fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Worker>) {
764 if let Some(worker) = worker {
765 let mut resp = req.respond();
766 if worker.inner.channel_index == 0 {
767 let (state, version, subchannel_count) = match *worker.inner.protocol.state.read() {
768 ProtocolState::Init(state) => match state {
769 InitState::Begin => ("begin_init", None, None),
770 InitState::QueryVersion => ("query_version", None, None),
771 InitState::QueryProperties { version } => {
772 ("query_properties", Some(version), None)
773 }
774 InitState::EndInitialization {
775 version,
776 subchannel_count,
777 } => ("end_init", Some(version), subchannel_count),
778 },
779 ProtocolState::Ready {
780 version,
781 subchannel_count,
782 } => ("ready", Some(version), Some(subchannel_count)),
783 };
784 resp.field("state", state)
785 .field("version", version)
786 .field("subchannel_count", subchannel_count);
787 }
788 resp.field("pending_packets", worker.inner.scsi_requests_states.len())
789 .fields("io", worker.inner.scsi_requests_states.iter())
790 .field("stats", &worker.inner.stats)
791 .field("ring", &worker.queue)
792 .field("max_io_queue_depth", worker.inner.max_io_queue_depth);
793 }
794 }
795}
796
impl<T: 'static + Send + Sync + RingMem> AsyncRun<Worker<T>> for WorkerState {
    /// Runs the worker until it is stopped.
    ///
    /// The primary channel (index 0) runs protocol negotiation followed by IO
    /// processing. Subchannels wait for the protocol to become ready, then
    /// process IO with the negotiated version. A processing error is logged
    /// and ends the run with `Ok(())`. Only cancellation is returned as an
    /// error.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        worker: &mut Worker<T>,
    ) -> Result<(), task_control::Cancelled> {
        let fut = async {
            if worker.inner.channel_index == 0 {
                worker.process_primary().await
            } else {
                let protocol_version = loop {
                    // Start listening before checking the state, so a `ready`
                    // notification sent between the check and the await is
                    // not missed.
                    let listener = worker.inner.protocol.ready.listen();
                    if let ProtocolState::Ready { version, .. } =
                        *worker.inner.protocol.state.read()
                    {
                        break version;
                    }
                    tracing::debug!("subchannel waiting for initialization to end");
                    listener.await
                };
                worker
                    .inner
                    .process_ready(&mut worker.queue, protocol_version)
                    .await
            }
        };

        match stop.until_stopped(fut).await? {
            Ok(_) => {}
            Err(e) => tracing::error!(error = e.as_error(), "process_packets error"),
        }
        Ok(())
    }
}
833
834impl WorkerInner {
835 async fn next_packet<'a, M: RingMem>(
838 &mut self,
839 reader: &'a mut queue::ReadHalf<'a, M>,
840 ) -> Result<Packet, WorkerError> {
841 let packet = reader.read().await.map_err(WorkerError::Queue)?;
842 let stor_packet =
843 parse_packet(&packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
844 Ok(stor_packet)
845 }
846
847 fn poll_for_ring_space<M: RingMem>(
854 &mut self,
855 cx: &mut Context<'_>,
856 writer: &mut queue::WriteHalf<'_, M>,
857 ) -> Poll<Result<(), WorkerError>> {
858 writer
859 .poll_ready(cx, MAX_VMBUS_PACKET_SIZE)
860 .map_err(WorkerError::Queue)
861 }
862}
863
/// Ring space needed by the largest packet this device sends: a storvsp
/// header plus a maximum-size request body, sent in-band.
const MAX_VMBUS_PACKET_SIZE: usize = ring::PacketSize::in_band(
    size_of::<storvsp_protocol::Packet>() + storvsp_protocol::SCSI_REQUEST_LEN_MAX,
);
867
impl<T: RingMem> Worker<T> {
    /// Runs the primary channel. It first drives storvsp protocol negotiation.
    /// Once the protocol is ready, it processes IO and forwards rescan
    /// notifications to the guest.
    ///
    /// Each loop iteration reads the shared [`ProtocolState`]. In an `Init`
    /// state it handles exactly one packet and may advance the state along
    /// `Begin` → `QueryVersion` → `QueryProperties` → `EndInitialization` →
    /// `Ready`. Packets that arrive out of order are completed with
    /// `INVALID_DEVICE_STATE`, and the state does not change.
    ///
    /// # Errors
    ///
    /// Returns a [`WorkerError`] if reading, parsing, or responding to a packet
    /// fails, or if IO processing fails.
    async fn process_primary(&mut self) -> Result<(), WorkerError> {
        loop {
            let current_state = *self.inner.protocol.state.read();
            match current_state {
                ProtocolState::Ready { version, .. } => {
                    // Process IO until that returns. Meanwhile, turn each rescan
                    // notification into an ENUMERATE_BUS packet (Win7 and later).
                    // `select_biased!` polls the IO branch first.
                    break loop {
                        select_biased! {
                            r = self.inner.process_ready(&mut self.queue, version).fuse() => break r,
                            _ = self.fast_select.select((self.rescan_notification.select_next_some(),)).fuse() => {
                                if version >= Version::Win7
                                {
                                    self.inner.send_packet(&mut self.queue.split().1, storvsp_protocol::Operation::ENUMERATE_BUS, storvsp_protocol::NtStatus::SUCCESS, &())?;
                                }
                            }
                        }
                    };
                }
                ProtocolState::Init(state) => {
                    let (mut reader, mut writer) = self.queue.split();

                    // Wait for room for a response before consuming the next
                    // packet, so the completion below can be written.
                    poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?;

                    tracing::debug!(?state, "process_primary");
                    match state {
                        InitState::Begin => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::BeginInitialization = packet.data {
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &(),
                                )?;
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::QueryVersion);
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryVersion => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProtocolVersion(major_minor) = packet.data {
                                if let Ok(version) = Version::parse(major_minor) {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    // Later unsolicited packets use the negotiated size.
                                    self.inner.request_size = version.max_request_size();
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryProperties { version });

                                    tracelimit::info_ratelimited!(
                                        ?version,
                                        "scsi version negotiated"
                                    );
                                } else {
                                    // Unsupported version: report the mismatch and let
                                    // the guest try another version.
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::REVISION_MISMATCH,
                                        &storvsp_protocol::ProtocolVersion {
                                            major_minor,
                                            reserved: 0,
                                        },
                                    )?;
                                    *self.inner.protocol.state.write() =
                                        ProtocolState::Init(InitState::QueryVersion);
                                }
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::QueryProperties { version } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            if let PacketData::QueryProperties = packet.data {
                                // Subchannels are only offered from Win8 onward.
                                let multi_channel_supported = version >= Version::Win8;

                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::SUCCESS,
                                    &storvsp_protocol::ChannelProperties {
                                        max_transfer_bytes: 0x40000, // 256 KiB
                                        flags: {
                                            if multi_channel_supported {
                                                storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL
                                            } else {
                                                0
                                            }
                                        },
                                        maximum_sub_channel_count: if multi_channel_supported {
                                            self.inner.channel_control.max_subchannels()
                                        } else {
                                            0
                                        },
                                        reserved: 0,
                                        reserved2: 0,
                                        reserved3: [0, 0],
                                    },
                                )?;
                                // With multi-channel support, a subchannel request is
                                // still allowed (`None`); otherwise the count is fixed at 0.
                                *self.inner.protocol.state.write() =
                                    ProtocolState::Init(InitState::EndInitialization {
                                        version,
                                        subchannel_count: if multi_channel_supported {
                                            None
                                        } else {
                                            Some(0)
                                        },
                                    });
                            } else {
                                tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                self.inner.send_completion(
                                    &mut writer,
                                    &packet,
                                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                    &(),
                                )?;
                            }
                        }
                        InitState::EndInitialization {
                            version,
                            subchannel_count,
                        } => {
                            let packet = self.inner.next_packet(&mut reader).await?;
                            match packet.data {
                                // Only one successful subchannel request is accepted;
                                // a later one falls through to "unexpected packet order".
                                PacketData::CreateSubChannels(sub_channel_count)
                                    if subchannel_count.is_none() =>
                                {
                                    if let Err(err) = self
                                        .inner
                                        .channel_control
                                        .enable_subchannels(sub_channel_count)
                                    {
                                        tracelimit::warn_ratelimited!(
                                            ?err,
                                            "cannot enable subchannels"
                                        );
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::INVALID_PARAMETER,
                                            &(),
                                        )?;
                                    } else {
                                        self.inner.send_completion(
                                            &mut writer,
                                            &packet,
                                            storvsp_protocol::NtStatus::SUCCESS,
                                            &(),
                                        )?;
                                        *self.inner.protocol.state.write() =
                                            ProtocolState::Init(InitState::EndInitialization {
                                                version,
                                                subchannel_count: Some(sub_channel_count),
                                            });
                                    }
                                }
                                PacketData::EndInitialization => {
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::SUCCESS,
                                        &(),
                                    )?;
                                    // Discard any rescan notification that is already
                                    // pending. NOTE(review): presumably because the guest
                                    // enumerates the bus itself after initialization;
                                    // confirm.
                                    self.rescan_notification.try_next().ok();
                                    *self.inner.protocol.state.write() = ProtocolState::Ready {
                                        version,
                                        subchannel_count: subchannel_count.unwrap_or(0),
                                    };
                                    // Wake every subchannel waiting in `run` for readiness.
                                    self.inner.protocol.ready.notify(usize::MAX);
                                }
                                _ => {
                                    tracelimit::warn_ratelimited!(?state, data = ?packet.data, "unexpected packet order");
                                    self.inner.send_completion(
                                        &mut writer,
                                        &packet,
                                        storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                                        &(),
                                    )?;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
1081
1082fn convert_srb_status_to_nt_status(srb_status: SrbStatus) -> storvsp_protocol::NtStatus {
1083 match srb_status {
1084 SrbStatus::BUSY => storvsp_protocol::NtStatus::DEVICE_BUSY,
1085 SrbStatus::SUCCESS => storvsp_protocol::NtStatus::SUCCESS,
1086 SrbStatus::INVALID_LUN
1087 | SrbStatus::INVALID_TARGET_ID
1088 | SrbStatus::NO_DEVICE
1089 | SrbStatus::NO_HBA => storvsp_protocol::NtStatus::DEVICE_DOES_NOT_EXIST,
1090 SrbStatus::COMMAND_TIMEOUT | SrbStatus::TIMEOUT => storvsp_protocol::NtStatus::IO_TIMEOUT,
1091 SrbStatus::SELECTION_TIMEOUT => storvsp_protocol::NtStatus::DEVICE_NOT_CONNECTED,
1092 SrbStatus::BAD_FUNCTION | SrbStatus::BAD_SRB_BLOCK_LENGTH => {
1093 storvsp_protocol::NtStatus::INVALID_DEVICE_REQUEST
1094 }
1095 SrbStatus::DATA_OVERRUN => storvsp_protocol::NtStatus::BUFFER_OVERFLOW,
1096 SrbStatus::REQUEST_FLUSHED => storvsp_protocol::NtStatus::UNSUCCESSFUL,
1097 SrbStatus::ABORTED => storvsp_protocol::NtStatus::CANCELLED,
1098 _ => storvsp_protocol::NtStatus::IO_DEVICE_ERROR,
1099 }
1100}
1101
1102impl WorkerInner {
1103 async fn process_ready<M: RingMem>(
1105 &mut self,
1106 queue: &mut Queue<M>,
1107 protocol_version: Version,
1108 ) -> Result<(), WorkerError> {
1109 self.request_size = protocol_version.max_request_size();
1110 poll_fn(|cx| self.poll_process_ready(cx, queue)).await
1111 }
1112
/// One wakeup's worth of IO processing for a ready channel.
///
/// Each pass first sends completions for finished requests while there is
/// ring space. It then reads and submits new packets until one of these
/// happens:
/// - the IO queue depth limit is reached;
/// - the ring is out of space;
/// - the incoming packets run out.
///
/// The pass repeats while new work was submitted, because new submissions
/// may complete immediately. Per-wake statistics are recorded afterwards.
///
/// Returns `Poll::Ready` only with an error; otherwise it always returns
/// `Poll::Pending`, with wakers registered by the polls made along the way.
fn poll_process_ready<M: RingMem>(
    &mut self,
    cx: &mut Context<'_>,
    queue: &mut Queue<M>,
) -> Poll<Result<(), WorkerError>> {
    self.stats.wakes.increment();

    let (mut reader, mut writer) = queue.split();
    let mut total_completions = 0;
    let mut total_submissions = 0;

    loop {
        // Send completions while IOs are outstanding.
        'outer: while !self.scsi_requests_states.is_empty() {
            {
                let mut batch = writer.batched();
                loop {
                    // Check for room before taking a finished request out of
                    // `scsi_requests`, so its completion can always be written.
                    if !batch
                        .can_write(MAX_VMBUS_PACKET_SIZE)
                        .map_err(WorkerError::Queue)?
                    {
                        // This batch is full; wait for space below.
                        break;
                    }
                    if let Poll::Ready(Some((request_id, result, future))) =
                        self.scsi_requests.poll_next_unpin(cx)
                    {
                        self.future_pool.push(future);
                        self.handle_completion(&mut batch, request_id, result)?;
                        total_completions += 1;
                    } else {
                        tracing::trace!("out of completions");
                        break 'outer;
                    }
                }
            }

            if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                tracing::trace!("out of ring space");
                break;
            }
        }

        // Read and submit new requests.
        let mut submissions = 0;
        'outer: loop {
            if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                break;
            }
            // With no IO outstanding, poll the reader so this task is woken
            // when a packet arrives. With IO outstanding, only try a
            // non-blocking read; a completion will wake the task again.
            let mut batch = if self.scsi_requests_states.is_empty() {
                if let Poll::Ready(batch) = reader.poll_read_batch(cx) {
                    batch.map_err(WorkerError::Queue)?
                } else {
                    tracing::trace!("out of incoming packets");
                    break;
                }
            } else {
                match reader.try_read_batch() {
                    Ok(batch) => batch,
                    Err(queue::TryReadError::Empty) => {
                        tracing::trace!(
                            pending_io_count = self.scsi_requests_states.len(),
                            "out of incoming packets, keeping interrupts masked"
                        );
                        break;
                    }
                    Err(queue::TryReadError::Queue(err)) => Err(WorkerError::Queue(err))?,
                }
            };

            let mut packets = batch.packets();
            loop {
                if self.scsi_requests_states.len() >= self.max_io_queue_depth {
                    break 'outer;
                }
                // Make sure any response to this packet can be written before
                // consuming it.
                if self.poll_for_ring_space(cx, &mut writer).is_pending() {
                    tracing::trace!("out of ring space");
                    break 'outer;
                }

                let packet = if let Some(packet) = packets.next() {
                    packet.map_err(WorkerError::Queue)?
                } else {
                    break;
                };

                // `handle_packet` returns true when it submitted an IO.
                if self.handle_packet(&mut writer, &packet)? {
                    submissions += 1;
                }
            }
        }

        // Nothing new was submitted, so no new completions can be pending.
        if submissions == 0 {
            break;
        }
        total_submissions += submissions;
    }

    if total_submissions != 0 || total_completions != 0 {
        self.stats.ios_submitted.add(total_submissions);
        self.stats
            .per_wake_submissions
            .add_sample(total_submissions);
        self.stats
            .per_wake_completions
            .add_sample(total_completions);
        self.stats.ios_completed.add(total_completions);
    } else {
        self.stats.wakes_spurious.increment();
    }

    Poll::Pending
}
1236
1237 fn handle_completion<M: RingMem>(
1238 &mut self,
1239 writer: &mut queue::WriteBatch<'_, M>,
1240 request_id: usize,
1241 result: ScsiResult,
1242 ) -> Result<(), WorkerError> {
1243 let state = self.scsi_requests_states.remove(request_id);
1244 let request_size = state.request.request_size;
1245
1246 assert_eq!(
1248 Arc::strong_count(&state.request) + Arc::weak_count(&state.request),
1249 1
1250 );
1251 self.full_request_pool.push(state.request);
1252
1253 let status = convert_srb_status_to_nt_status(result.srb_status);
1254 let mut payload = [0; 0x14];
1255 if let Some(sense) = result.sense_data {
1256 payload[..size_of_val(&sense)].copy_from_slice(sense.as_bytes());
1257 tracing::trace!(sense_info = ?payload, sense_key = payload[2], asc = payload[12], "execute_scsi");
1258 };
1259 let response = storvsp_protocol::ScsiRequest {
1260 length: size_of::<storvsp_protocol::ScsiRequest>() as u16,
1261 scsi_status: result.scsi_status,
1262 srb_status: SrbStatusAndFlags::new()
1263 .with_status(result.srb_status)
1264 .with_autosense_valid(result.sense_data.is_some()),
1265 data_transfer_length: result.tx as u32,
1266 cdb_length: storvsp_protocol::CDB16GENERIC_LENGTH as u8,
1267 sense_info_ex_length: storvsp_protocol::VMSCSI_SENSE_BUFFER_SIZE as u8,
1268 payload,
1269 ..storvsp_protocol::ScsiRequest::new_zeroed()
1270 };
1271 self.send_vmbus_packet(
1272 writer,
1273 OutgoingPacketType::Completion,
1274 request_size,
1275 state.transaction_id,
1276 storvsp_protocol::Operation::COMPLETE_IO,
1277 status,
1278 response.as_bytes(),
1279 )?;
1280 Ok(())
1281 }
1282
1283 fn handle_packet<M: RingMem>(
1284 &mut self,
1285 writer: &mut queue::WriteHalf<'_, M>,
1286 packet: &IncomingPacket<'_, M>,
1287 ) -> Result<bool, WorkerError> {
1288 let packet =
1289 parse_packet(packet, &mut self.full_request_pool).map_err(WorkerError::PacketError)?;
1290 let submitted_io = match packet.data {
1291 PacketData::ExecuteScsi(request) => {
1292 self.push_scsi_request(packet.transaction_id, request);
1293 true
1294 }
1295 PacketData::ResetAdapter | PacketData::ResetBus | PacketData::ResetLun => {
1296 self.send_completion(writer, &packet, storvsp_protocol::NtStatus::SUCCESS, &())?;
1298 false
1299 }
1300 PacketData::CreateSubChannels(new_subchannel_count) if self.channel_index == 0 => {
1301 if let Err(err) = self
1302 .channel_control
1303 .enable_subchannels(new_subchannel_count)
1304 {
1305 tracelimit::warn_ratelimited!(?err, "cannot create subchannels");
1306 self.send_completion(
1307 writer,
1308 &packet,
1309 storvsp_protocol::NtStatus::INVALID_PARAMETER,
1310 &(),
1311 )?;
1312 false
1313 } else {
1314 if let ProtocolState::Ready {
1316 subchannel_count, ..
1317 } = &mut *self.protocol.state.write()
1318 {
1319 *subchannel_count = new_subchannel_count;
1320 } else {
1321 unreachable!()
1322 }
1323
1324 self.send_completion(
1325 writer,
1326 &packet,
1327 storvsp_protocol::NtStatus::SUCCESS,
1328 &(),
1329 )?;
1330 false
1331 }
1332 }
1333 _ => {
1334 tracelimit::warn_ratelimited!(data = ?packet.data, "unexpected packet on ready");
1335 self.send_completion(
1336 writer,
1337 &packet,
1338 storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
1339 &(),
1340 )?;
1341 false
1342 }
1343 };
1344 Ok(submitted_io)
1345 }
1346
1347 fn push_scsi_request(&mut self, transaction_id: u64, full_request: Arc<ScsiRequestAndRange>) {
1348 let scsi_queue = self.scsi_queue.clone();
1349 let scsi_request_state = ScsiRequestState {
1350 transaction_id,
1351 request: full_request.clone(),
1352 };
1353 let request_id = self.scsi_requests_states.insert(scsi_request_state);
1354 let future = self
1355 .future_pool
1356 .pop()
1357 .unwrap_or_else(|| OversizedBox::new(()));
1358 let future = OversizedBox::refill(future, async move {
1359 scsi_queue
1360 .execute_scsi(&full_request.external_data, &full_request.request)
1361 .await
1362 });
1363 let request = ScsiRequest::new(request_id, oversized_box::coerce!(future));
1364 self.scsi_requests.push(request);
1365 }
1366}
1367
1368impl<T: RingMem> Drop for Worker<T> {
1369 fn drop(&mut self) {
1370 self.inner
1371 .controller
1372 .remove_rescan_notification_source(&self.rescan_notification);
1373 }
1374}
1375
/// Error returned by [`ScsiController::attach`] when a disk is already
/// attached at the requested [`ScsiPath`].
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is already in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathInUse(pub ScsiPath);
1379
/// Error returned by [`ScsiController::remove`] when no disk is attached at
/// the requested [`ScsiPath`].
#[derive(Debug, Error)]
#[error("SCSI path {}:{}:{} is not in use", self.0.path, self.0.target, self.0.lun)]
pub struct ScsiPathNotInUse(ScsiPath);
1383
/// Bookkeeping for one in-flight SCSI request, kept in the worker's
/// `scsi_requests_states` (keyed by request ID) until its completion is sent.
#[derive(Clone)]
struct ScsiRequestState {
    /// Guest transaction ID, echoed back in the completion packet.
    transaction_id: u64,
    /// The parsed request; a clone is also held by the executing future.
    request: Arc<ScsiRequestAndRange>,
}
1389
/// A parsed guest SCSI request together with the data range it refers to.
#[derive(Debug)]
struct ScsiRequestAndRange {
    /// Data buffer for the request; passed to `execute_scsi` alongside the
    /// request itself.
    external_data: Range,
    /// The storvsp SCSI request as parsed from the guest packet.
    request: storvsp_protocol::ScsiRequest,
    /// Passed to `send_vmbus_packet` when the request is completed;
    /// presumably the negotiated request packet size — confirm in
    /// `parse_packet`.
    request_size: usize,
}
1396
1397impl Inspect for ScsiRequestState {
1398 fn inspect(&self, req: inspect::Request<'_>) {
1399 req.respond()
1400 .field("transaction_id", self.transaction_id)
1401 .display(
1402 "address",
1403 &ScsiPath {
1404 path: self.request.request.path_id,
1405 target: self.request.request.target_id,
1406 lun: self.request.request.lun,
1407 },
1408 )
1409 .display_debug("operation", &ScsiOp(self.request.request.payload[0]));
1410 }
1411}
1412
1413impl StorageDevice {
1414 pub fn build_scsi(
1416 driver_source: &VmTaskDriverSource,
1417 controller: &ScsiController,
1418 instance_id: Guid,
1419 max_sub_channel_count: u16,
1420 io_queue_depth: u32,
1421 ) -> Self {
1422 Self::build_inner(
1423 driver_source,
1424 controller,
1425 instance_id,
1426 None,
1427 max_sub_channel_count,
1428 io_queue_depth,
1429 )
1430 }
1431
1432 pub fn build_ide(
1435 driver_source: &VmTaskDriverSource,
1436 channel_id: u8,
1437 device_id: u8,
1438 disk: ScsiControllerDisk,
1439 io_queue_depth: u32,
1440 ) -> Self {
1441 let path = ScsiPath {
1442 path: channel_id,
1443 target: device_id,
1444 lun: 0,
1445 };
1446
1447 let controller = ScsiController::new();
1448 controller.attach(path, disk).unwrap();
1449
1450 let instance_id = Guid {
1453 data1: channel_id.into(),
1454 data2: device_id.into(),
1455 data3: 0x8899,
1456 data4: [0; 8],
1457 };
1458 Self::build_inner(
1459 driver_source,
1460 &controller,
1461 instance_id,
1462 Some(path),
1463 0,
1464 io_queue_depth,
1465 )
1466 }
1467
1468 fn build_inner(
1469 driver_source: &VmTaskDriverSource,
1470 controller: &ScsiController,
1471 instance_id: Guid,
1472 ide_path: Option<ScsiPath>,
1473 max_sub_channel_count: u16,
1474 io_queue_depth: u32,
1475 ) -> Self {
1476 let workers = (0..max_sub_channel_count + 1)
1477 .map(|channel_index| WorkerAndDriver {
1478 worker: TaskControl::new(WorkerState),
1479 driver: driver_source
1480 .builder()
1481 .target_vp(0)
1482 .run_on_target(true)
1483 .build(format!("storvsp-{}-{}", instance_id, channel_index)),
1484 })
1485 .collect();
1486
1487 Self {
1488 instance_id,
1489 ide_path,
1490 workers,
1491 controller: controller.state.clone(),
1492 resources: Default::default(),
1493 max_sub_channel_count,
1494 driver_source: driver_source.clone(),
1495 protocol: Arc::new(Protocol {
1496 state: RwLock::new(ProtocolState::Init(InitState::Begin)),
1497 ready: Default::default(),
1498 }),
1499 io_queue_depth,
1500 }
1501 }
1502
1503 fn new_worker(
1504 &mut self,
1505 open_request: &OpenRequest,
1506 channel_index: u16,
1507 ) -> anyhow::Result<&mut Worker> {
1508 let controller = self.controller.clone();
1509
1510 let driver = self
1511 .driver_source
1512 .builder()
1513 .target_vp(open_request.open_data.target_vp)
1514 .run_on_target(true)
1515 .build(format!("storvsp-{}-{}", self.instance_id, channel_index));
1516
1517 let channel = gpadl_channel(&driver, &self.resources, open_request, channel_index)
1518 .context("failed to create vmbus channel")?;
1519
1520 let channel_control = self.resources.channel_control.clone();
1521
1522 tracing::debug!(
1523 target_vp = open_request.open_data.target_vp,
1524 channel_index,
1525 "packet processing starting...",
1526 );
1527
1528 let force_path_id = self.ide_path.map(|p| p.path);
1532
1533 let worker = Worker::new(
1534 controller,
1535 channel,
1536 channel_index,
1537 self.resources
1538 .offer_resources
1539 .guest_memory(open_request)
1540 .clone(),
1541 channel_control,
1542 self.io_queue_depth,
1543 self.protocol.clone(),
1544 force_path_id,
1545 )
1546 .map_err(RestoreError::Other)?;
1547
1548 self.workers[channel_index as usize]
1549 .driver
1550 .retarget_vp(open_request.open_data.target_vp);
1551
1552 Ok(self.workers[channel_index as usize].worker.insert(
1553 &driver,
1554 format!("storvsp worker {}-{}", self.instance_id, channel_index),
1555 worker,
1556 ))
1557 }
1558}
1559
/// A disk that can be attached to a [`ScsiController`] at a [`ScsiPath`].
#[derive(Clone)]
pub struct ScsiControllerDisk {
    /// The shared SCSI disk implementation backing this attachment.
    disk: Arc<dyn AsyncScsiDisk>,
}
1565
1566impl ScsiControllerDisk {
1567 pub fn new(disk: Arc<dyn AsyncScsiDisk>) -> Self {
1569 Self { disk }
1570 }
1571}
1572
/// State shared (via `Arc`) between a [`ScsiController`] and the storage
/// devices and workers built from it.
struct ScsiControllerState {
    /// Attached disks, keyed by SCSI address.
    disks: RwLock<HashMap<ScsiPath, ScsiControllerDisk>>,
    /// Senders signaled (best-effort, via `try_send`) whenever a disk is
    /// attached or removed.
    rescan_notification_source: Mutex<Vec<futures::channel::mpsc::Sender<()>>>,
}
1577
/// A set of SCSI disks, addressed by [`ScsiPath`], that storvsp devices expose
/// to the guest. Disks can be attached and removed at runtime.
pub struct ScsiController {
    state: Arc<ScsiControllerState>,
}
1581
1582impl ScsiController {
1583 pub fn new() -> Self {
1584 Self {
1585 state: Arc::new(ScsiControllerState {
1586 disks: Default::default(),
1587 rescan_notification_source: Mutex::new(Vec::new()),
1588 }),
1589 }
1590 }
1591
1592 pub fn attach(&self, path: ScsiPath, disk: ScsiControllerDisk) -> Result<(), ScsiPathInUse> {
1593 match self.state.disks.write().entry(path) {
1594 Entry::Occupied(_) => return Err(ScsiPathInUse(path)),
1595 Entry::Vacant(entry) => entry.insert(disk),
1596 };
1597 for source in self.state.rescan_notification_source.lock().iter_mut() {
1598 source.try_send(()).ok();
1601 }
1602 Ok(())
1603 }
1604
1605 pub fn remove(&self, path: ScsiPath) -> Result<(), ScsiPathNotInUse> {
1606 match self.state.disks.write().entry(path) {
1607 Entry::Vacant(_) => return Err(ScsiPathNotInUse(path)),
1608 Entry::Occupied(entry) => {
1609 entry.remove();
1610 }
1611 }
1612 for source in self.state.rescan_notification_source.lock().iter_mut() {
1613 source.try_send(()).ok();
1616 }
1617 Ok(())
1618 }
1619}
1620
1621impl ScsiControllerState {
1622 fn add_rescan_notification_source(&self, source: futures::channel::mpsc::Sender<()>) {
1623 self.rescan_notification_source.lock().push(source);
1624 }
1625
1626 fn remove_rescan_notification_source(&self, target: &futures::channel::mpsc::Receiver<()>) {
1627 let mut sources = self.rescan_notification_source.lock();
1628 if let Some(index) = sources
1629 .iter()
1630 .position(|source| source.is_connected_to(target))
1631 {
1632 sources.remove(index);
1633 }
1634 }
1635}
1636
1637#[async_trait]
1638impl VmbusDevice for StorageDevice {
1639 fn offer(&self) -> OfferParams {
1640 if let Some(path) = self.ide_path {
1641 let offer_properties = storvsp_protocol::OfferProperties {
1642 path_id: path.path,
1643 target_id: path.target,
1644 flags: storvsp_protocol::OFFER_PROPERTIES_FLAG_IDE_DEVICE,
1645 ..FromZeros::new_zeroed()
1646 };
1647 let mut user_defined = UserDefinedData::new_zeroed();
1648 offer_properties
1649 .write_to_prefix(&mut user_defined[..])
1650 .unwrap();
1651 OfferParams {
1652 interface_name: "ide-accel".to_owned(),
1653 instance_id: self.instance_id,
1654 interface_id: storvsp_protocol::IDE_ACCELERATOR_INTERFACE_ID,
1655 channel_type: ChannelType::Interface { user_defined },
1656 ..Default::default()
1657 }
1658 } else {
1659 OfferParams {
1660 interface_name: "scsi".to_owned(),
1661 instance_id: self.instance_id,
1662 interface_id: storvsp_protocol::SCSI_INTERFACE_ID,
1663 ..Default::default()
1664 }
1665 }
1666 }
1667
1668 fn max_subchannels(&self) -> u16 {
1669 self.max_sub_channel_count
1670 }
1671
1672 fn install(&mut self, resources: DeviceResources) {
1673 self.resources = resources;
1674 }
1675
1676 async fn open(
1677 &mut self,
1678 channel_index: u16,
1679 open_request: &OpenRequest,
1680 ) -> Result<(), ChannelOpenError> {
1681 tracing::debug!(channel_index, "scsi open channel");
1682 self.new_worker(open_request, channel_index)?;
1683 self.workers[channel_index as usize].worker.start();
1684 Ok(())
1685 }
1686
1687 async fn close(&mut self, channel_index: u16) {
1688 tracing::debug!(channel_index, "scsi close channel");
1689 let worker = &mut self.workers[channel_index as usize].worker;
1690 worker.stop().await;
1691 if worker.state_mut().is_some() {
1692 worker
1693 .state_mut()
1694 .unwrap()
1695 .wait_for_scsi_requests_complete()
1696 .await;
1697 worker.remove();
1698 }
1699 if channel_index == 0 {
1700 *self.protocol.state.write() = ProtocolState::Init(InitState::Begin);
1701 }
1702 }
1703
1704 async fn retarget_vp(&mut self, channel_index: u16, target_vp: u32) {
1705 self.workers[channel_index as usize]
1706 .driver
1707 .retarget_vp(target_vp);
1708 }
1709
1710 fn start(&mut self) {
1711 for task in self
1712 .workers
1713 .iter_mut()
1714 .filter(|task| task.worker.has_state() && !task.worker.is_running())
1715 {
1716 task.worker.start();
1717 }
1718 }
1719
1720 async fn stop(&mut self) {
1721 tracing::debug!(instance_id = ?self.instance_id, "StorageDevice stopping...");
1722 for task in self
1723 .workers
1724 .iter_mut()
1725 .filter(|task| task.worker.has_state() && task.worker.is_running())
1726 {
1727 task.worker.stop().await;
1728 }
1729 }
1730
1731 fn supports_save_restore(&mut self) -> Option<&mut dyn SaveRestoreVmbusDevice> {
1732 Some(self)
1733 }
1734}
1735
1736#[async_trait]
1737impl SaveRestoreVmbusDevice for StorageDevice {
1738 async fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
1739 Ok(SavedStateBlob::new(self.save()?))
1740 }
1741
1742 async fn restore(
1743 &mut self,
1744 control: RestoreControl<'_>,
1745 state: SavedStateBlob,
1746 ) -> Result<(), RestoreError> {
1747 self.restore(control, state.parse()?).await
1748 }
1749}
1750
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestWorker;
    use crate::test_helpers::parse_guest_completion;
    use crate::test_helpers::parse_guest_completion_check_flags_status;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use scsi::srb::SrbStatus;
    use test_with_tracing::test;
    use vmbus_channel::connected_async_channels;

    // Test-only `Clone`, so a test can keep a handle to the controller it
    // hands to the worker (e.g. to hot-add/remove disks).
    impl Clone for ScsiController {
        fn clone(&self) -> Self {
            ScsiController {
                state: self.state.clone(),
            }
        }
    }

    #[async_test]
    async fn test_channel_working(driver: DefaultDriver) {
        // Set up the channels, a single RAM disk at 0:0:0, and the worker.
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        let controller = ScsiController::new();
        let disk = scsidisk::SimpleScsiDisk::new(
            disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
            Default::default(),
        );
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                ScsiControllerDisk::new(Arc::new(disk)),
            )
            .unwrap();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // Write a known pattern from one GPA, then read it back into another.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;

        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        assert_eq!(read_buf, write_buf);

        guest.verify_graceful_close(test_worker).await;
    }

    #[async_test]
    async fn test_packet_sizes(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        let header = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };

        // An unsupported version (`!0`) followed by zero padding.
        let mut buf = [0u8; 128];
        storvsp_protocol::ProtocolVersion {
            major_minor: !0,
            reserved: 0,
        }
        .write_to_prefix(&mut buf[..])
        .unwrap();

        // Send the version query at several total packet lengths (`len`); each
        // must be rejected with REVISION_MISMATCH in a completion whose
        // payload length is `resp_len`.
        for &(len, resp_len) in &[(48, 48), (50, 56), (56, 56), (64, 64), (72, 64)] {
            guest
                .send_data_packet_sync(&[header.as_bytes(), &buf[..len - size_of_val(&header)]])
                .await;

            guest
                .verify_completion(|packet| {
                    let IncomingPacket::Completion(packet) = packet else {
                        unreachable!()
                    };
                    assert_eq!(packet.reader().len(), resp_len);
                    assert_eq!(
                        packet
                            .reader()
                            .read_plain::<storvsp_protocol::Packet>()
                            .unwrap()
                            .status,
                        storvsp_protocol::NtStatus::REVISION_MISMATCH
                    );
                    Ok(())
                })
                .await;
        }
    }

    #[async_test]
    async fn test_wrong_first_packet(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // END_INITIALIZATION before BEGIN_INITIALIZATION is out of order.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE
                );
                Ok(())
            })
            .await;
    }

    #[async_test]
    async fn test_unrecognized_operation(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        // REMOVE_DEVICE is not a recognized operation, so the worker fails.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::REMOVE_DEVICE,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        match worker.teardown().await {
            Err(WorkerError::PacketError(PacketError::UnrecognizedOperation(
                storvsp_protocol::Operation::REMOVE_DEVICE,
            ))) => {}
            result => panic!("Worker failed with unexpected result {:?}!", result),
        }
    }

    #[async_test]
    async fn test_too_many_subchannels(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;

        guest.verify_completion(parse_guest_completion).await;

        // Ask for one subchannel; the test worker is expected to reject it
        // with INVALID_PARAMETER (presumably it has no subchannel support).
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::CREATE_SUB_CHANNELS,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes(), 1_u16.as_bytes()])
            .await;

        guest
            .verify_completion(|packet| {
                let IncomingPacket::Completion(packet) = packet else {
                    unreachable!()
                };
                assert_eq!(
                    packet
                        .reader()
                        .read_plain::<storvsp_protocol::Packet>()
                        .unwrap()
                        .status,
                    storvsp_protocol::NtStatus::INVALID_PARAMETER
                );
                Ok(())
            })
            .await;
    }

    #[async_test]
    async fn test_begin_init_on_ready(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16384);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(1024);
        let controller = ScsiController::new();

        let _worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem,
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // BEGIN_INITIALIZATION after negotiation has finished is rejected.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;

        guest
            .verify_completion(|p| {
                parse_guest_completion_check_flags_status(
                    p,
                    0,
                    storvsp_protocol::NtStatus::INVALID_DEVICE_STATE,
                )
            })
            .await;
    }

    #[async_test]
    async fn test_hot_add_remove(driver: DefaultDriver) {
        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let test_guest_mem = GuestMemory::allocate(16384);
        // Start with no disks attached.
        let controller = ScsiController::new();

        let test_worker = TestWorker::start(
            controller.clone(),
            driver.clone(),
            test_guest_mem.clone(),
            host,
            None,
        );

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        guest.perform_protocol_negotiation().await;

        // REPORT LUNS with no disks: 8 bytes transferred, empty LUN list.
        let mut lun_list_buffer: [u8; 256] = [0; 256];
        let mut disk_count = 0;
        guest
            .send_report_luns_packet(ScsiPath::default(), 0, lun_list_buffer.len())
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io_check_tx_len(p, SrbStatus::SUCCESS, Some(8))
            })
            .await;
        test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
        let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
        assert_eq!(lun_list_size, disk_count as u32 * 8);

        // Writing to a path with no disk fails with INVALID_LUN.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();

        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| {
                test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
            })
            .await;

        // Hot-add disks one at a time; each attach triggers a bus rescan
        // notification, a larger LUN list, and a working path.
        for lun in 0..4 {
            let disk = scsidisk::SimpleScsiDisk::new(
                disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(),
                Default::default(),
            );
            controller
                .attach(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    ScsiControllerDisk::new(Arc::new(disk)),
                )
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count += 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 256)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS)
                })
                .await;
        }

        // Hot-remove them again; each removal shrinks the LUN list and makes
        // the path fail with INVALID_LUN.
        for lun in 0..4 {
            controller
                .remove(ScsiPath {
                    path: 0,
                    target: 0,
                    lun,
                })
                .unwrap();
            guest
                .verify_completion(test_helpers::parse_guest_enumerate_bus)
                .await;

            disk_count -= 1;
            guest
                .send_report_luns_packet(ScsiPath::default(), 0, 4096)
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io_check_tx_len(
                        p,
                        SrbStatus::SUCCESS,
                        Some((disk_count + 1) * 8),
                    )
                })
                .await;
            test_guest_mem.read_at(0, &mut lun_list_buffer).unwrap();
            let lun_list_size = u32::from_be_bytes(lun_list_buffer[0..4].try_into().unwrap());
            assert_eq!(lun_list_size, disk_count as u32 * 8);

            guest
                .send_write_packet(
                    ScsiPath {
                        path: 0,
                        target: 0,
                        lun,
                    },
                    write_gpa,
                    1,
                    IO_LEN,
                )
                .await;
            guest
                .verify_completion(|p| {
                    test_helpers::parse_guest_completed_io(p, SrbStatus::INVALID_LUN)
                })
                .await;
        }

        guest.verify_graceful_close(test_worker).await;
    }

    #[async_test]
    pub async fn test_async_disk(driver: DefaultDriver) {
        let device = disklayer_ram::ram_disk(64 * 1024, false).unwrap();
        let controller = ScsiController::new();
        let disk = ScsiControllerDisk::new(Arc::new(scsidisk::SimpleScsiDisk::new(
            device,
            Default::default(),
        )));
        controller
            .attach(
                ScsiPath {
                    path: 0,
                    target: 0,
                    lun: 0,
                },
                disk,
            )
            .unwrap();

        let (host, guest) = connected_async_channels(16 * 1024);
        let guest_queue = Queue::new(guest).unwrap();

        let mut guest = test_helpers::TestGuest {
            queue: guest_queue,
            transaction_id: 0,
        };

        let test_guest_mem = GuestMemory::allocate(16384);
        let worker = TestWorker::start(
            controller.clone(),
            &driver,
            test_guest_mem.clone(),
            host,
            None,
        );

        // Perform the negotiation sequence step by step.
        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let version_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        let version = storvsp_protocol::ProtocolVersion {
            major_minor: storvsp_protocol::VERSION_BLUE,
            reserved: 0,
        };
        guest
            .send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let properties_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[properties_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        let negotiate_packet = storvsp_protocol::Packet {
            operation: storvsp_protocol::Operation::END_INITIALIZATION,
            flags: 0,
            status: storvsp_protocol::NtStatus::SUCCESS,
        };
        guest
            .send_data_packet_sync(&[negotiate_packet.as_bytes()])
            .await;
        guest.verify_completion(parse_guest_completion).await;

        // Write a pattern and read it back through the disk.
        const IO_LEN: usize = 4 * 1024;
        let write_buf = [7u8; IO_LEN];
        let write_gpa = 4 * 1024u64;
        test_guest_mem.write_at(write_gpa, &write_buf).unwrap();
        guest
            .send_write_packet(ScsiPath::default(), write_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;

        let read_gpa = 8 * 1024u64;
        guest
            .send_read_packet(ScsiPath::default(), read_gpa, 1, IO_LEN)
            .await;
        guest
            .verify_completion(|p| test_helpers::parse_guest_completed_io(p, SrbStatus::SUCCESS))
            .await;
        let mut read_buf = [0u8; IO_LEN];
        test_guest_mem.read_at(read_gpa, &mut read_buf).unwrap();
        assert_eq!(read_buf, write_buf);

        guest.verify_graceful_close(worker).await;
    }
}