1#![cfg_attr(not(test), expect(dead_code))]
10
11use crate::InitState;
12use crate::PacketError;
13use crate::Protocol;
14use crate::ProtocolState;
15use crate::ScsiController;
16use crate::ScsiPath;
17use crate::Worker;
18use crate::WorkerError;
19use guestmem::GuestMemory;
20use guestmem::MemoryRead;
21use guestmem::ranges::PagedRange;
22use pal_async::task::Spawn;
23use pal_async::task::Task;
24use parking_lot::RwLock;
25use scsi::ScsiOp;
26use scsi::srb::SrbStatus;
27use scsi_defs as scsi;
28use std::sync::Arc;
29use vmbus_async::queue::IncomingPacket;
30use vmbus_async::queue::OutgoingPacket;
31use vmbus_async::queue::Queue;
32use vmbus_channel::RawAsyncChannel;
33use vmbus_ring as ring;
34use vmbus_ring::FlatRingMem;
35use vmbus_ring::OutgoingPacketType;
36use vmbus_ring::PAGE_SIZE;
37use zerocopy::FromZeros;
38use zerocopy::IntoBytes;
39
40pub struct TestWorker {
41 task: Task<Result<(), WorkerError>>,
42}
43
44impl TestWorker {
45 pub(crate) async fn teardown(self) -> Result<(), WorkerError> {
46 self.task.await
47 }
48
49 #[cfg(feature = "test")]
53 pub async fn teardown_ignore(self) {
54 let _ = self.task.await;
55 }
56
57 #[cfg(feature = "test")]
60 pub async fn teardown_or_panic(self) {
61 match self.task.await {
62 Ok(()) => {}
63 Err(WorkerError::Queue(err)) if err.is_closed_error() => {}
64 Err(err) => {
65 panic!("Worker did not teardown gracefully: {:?}", err);
66 }
67 }
68 }
69
70 pub fn start<T: ring::RingMem + 'static + Sync>(
71 controller: ScsiController,
72 spawner: impl Spawn,
73 mem: GuestMemory,
74 channel: RawAsyncChannel<T>,
75 io_queue_depth: Option<u32>,
76 ) -> Self {
77 let task = spawner.spawn("test", async move {
78 let mut worker = Worker::new(
79 controller.state.clone(),
80 channel,
81 0,
82 mem,
83 Default::default(),
84 io_queue_depth.unwrap_or(256),
85 Arc::new(Protocol {
86 state: RwLock::new(ProtocolState::Init(InitState::Begin)),
87 ready: Default::default(),
88 }),
89 None,
90 )
91 .unwrap();
92 worker.process_primary().await
93 });
94
95 Self { task }
96 }
97}
98
99pub(crate) fn parse_guest_completion_check_flags_status<T: ring::RingMem>(
100 packet: &IncomingPacket<'_, T>,
101 flags: u32,
102 status: storvsp_protocol::NtStatus,
103) -> Result<(), PacketError> {
104 match packet {
105 IncomingPacket::Completion(compl) => {
106 let mut reader = compl.reader();
107 let header: storvsp_protocol::Packet =
108 reader.read_plain().map_err(PacketError::Access)?;
109 assert_eq!(header.flags, flags, "mismatched flags");
110 assert_eq!(header.status, status, "mismatched status");
111 assert_eq!(
112 header.operation,
113 storvsp_protocol::Operation::COMPLETE_IO,
114 "mismatched operation"
115 );
116 Ok(())
117 }
118 IncomingPacket::Data(_) => Err(PacketError::InvalidPacketType),
119 }
120}
121
122pub(crate) fn parse_guest_completion<T: ring::RingMem>(
123 packet: &IncomingPacket<'_, T>,
124) -> Result<(), PacketError> {
125 parse_guest_completion_check_flags_status(packet, 0, storvsp_protocol::NtStatus::SUCCESS)
126}
127
128pub(crate) fn parse_guest_completed_io<T: ring::RingMem>(
129 packet: &IncomingPacket<'_, T>,
130 expected_srb_status: SrbStatus,
131) -> Result<(), PacketError> {
132 parse_guest_completed_io_check_tx_len(packet, expected_srb_status, None)
133}
134
135pub(crate) fn parse_guest_completed_io_check_tx_len<T: ring::RingMem>(
136 packet: &IncomingPacket<'_, T>,
137 expected_srb_status: SrbStatus,
138 expected_data_tx_length: Option<usize>,
139) -> Result<(), PacketError> {
140 match packet {
141 IncomingPacket::Completion(compl) => {
142 let mut reader = compl.reader();
143 let header: storvsp_protocol::Packet =
144 reader.read_plain().map_err(PacketError::Access)?;
145 if header.operation != storvsp_protocol::Operation::COMPLETE_IO {
146 Err(PacketError::UnrecognizedOperation(header.operation))
147 } else {
148 if expected_srb_status == SrbStatus::SUCCESS {
149 assert_eq!(header.status, storvsp_protocol::NtStatus::SUCCESS);
150 if let Some(expected_data_tx_length) = expected_data_tx_length {
151 let payload: storvsp_protocol::ScsiRequest =
152 reader.read_plain().map_err(PacketError::Access)?;
153 assert_eq!(
154 payload.data_transfer_length as usize,
155 expected_data_tx_length
156 );
157 }
158 } else {
159 assert_ne!(header.status, storvsp_protocol::NtStatus::SUCCESS);
160 let payload: storvsp_protocol::ScsiRequest =
161 reader.read_plain().map_err(PacketError::Access)?;
162 assert_eq!(payload.srb_status.status(), expected_srb_status);
163 }
164 Ok(())
165 }
166 }
167 _ => Err(PacketError::InvalidPacketType),
168 }
169}
170
171pub(crate) fn parse_guest_enumerate_bus<T: ring::RingMem>(
172 packet: &IncomingPacket<'_, T>,
173) -> Result<(), PacketError> {
174 match packet {
175 IncomingPacket::Data(p) => {
176 let mut reader = p.reader();
177 let header: storvsp_protocol::Packet =
178 reader.read_plain().map_err(PacketError::Access)?;
179 if header.operation != storvsp_protocol::Operation::ENUMERATE_BUS {
180 Err(PacketError::UnrecognizedOperation(header.operation))
181 } else {
182 assert_eq!(header.status, storvsp_protocol::NtStatus::SUCCESS);
183 Ok(())
184 }
185 }
186 _ => Err(PacketError::InvalidPacketType),
187 }
188}
189
190pub struct TestGuest {
191 pub queue: Queue<FlatRingMem>,
192 pub transaction_id: u64,
193}
194
195impl TestGuest {
196 pub async fn send_data_packet_sync(&mut self, payload: &[&[u8]]) {
197 self.queue
198 .split()
199 .1
200 .write(OutgoingPacket {
201 packet_type: OutgoingPacketType::InBandWithCompletion,
202 transaction_id: self.transaction_id,
203 payload,
204 })
205 .await
206 .unwrap();
207
208 self.transaction_id += 1;
209 }
210
211 pub async fn send_gpa_direct_packet_sync(
212 &mut self,
213 payload: &[&[u8]],
214 gpa_start: u64,
215 byte_len: usize,
216 ) {
217 let start_page: u64 = gpa_start / PAGE_SIZE as u64;
218 let end_page: u64 = (gpa_start + (byte_len + PAGE_SIZE - 1) as u64) / PAGE_SIZE as u64;
219 let gpas: Vec<u64> = (start_page..end_page).collect();
220 let pages =
221 PagedRange::new(gpa_start as usize % PAGE_SIZE, byte_len, gpas.as_slice()).unwrap();
222 self.queue
223 .split()
224 .1
225 .write(OutgoingPacket {
226 packet_type: OutgoingPacketType::GpaDirect(&[pages]),
227 transaction_id: self.transaction_id,
228 payload,
229 })
230 .await
231 .unwrap();
232
233 self.transaction_id += 1;
234 }
235
236 pub async fn send_write_packet(
238 &mut self,
239 path: ScsiPath,
240 buf_gpa: u64,
241 block: u32,
242 byte_len: usize,
243 ) {
244 let write_packet = storvsp_protocol::Packet {
245 operation: storvsp_protocol::Operation::EXECUTE_SRB,
246 flags: 0,
247 status: storvsp_protocol::NtStatus::SUCCESS,
248 };
249
250 let cdb = scsi::Cdb10 {
251 operation_code: ScsiOp::WRITE,
252 logical_block: block.into(),
253 transfer_blocks: ((byte_len / 512) as u16).into(),
254 ..FromZeros::new_zeroed()
255 };
256
257 let mut scsi_req = storvsp_protocol::ScsiRequest {
258 target_id: path.target,
259 path_id: path.path,
260 lun: path.lun,
261 length: storvsp_protocol::SCSI_REQUEST_LEN_V2 as u16,
262 cdb_length: size_of::<scsi::Cdb10>() as u8,
263 data_transfer_length: byte_len as u32,
264 ..FromZeros::new_zeroed()
265 };
266
267 scsi_req.payload[0..10].copy_from_slice(cdb.as_bytes());
268
269 self.send_gpa_direct_packet_sync(
271 &[write_packet.as_bytes(), scsi_req.as_bytes()],
272 buf_gpa,
273 byte_len,
274 )
275 .await;
276 }
277
278 pub async fn send_read_packet(
280 &mut self,
281 path: ScsiPath,
282 read_gpa: u64,
283 block: u32,
284 byte_len: usize,
285 ) {
286 let read_packet = storvsp_protocol::Packet {
287 operation: storvsp_protocol::Operation::EXECUTE_SRB,
288 flags: 0,
289 status: storvsp_protocol::NtStatus::SUCCESS,
290 };
291
292 let cdb = scsi::Cdb10 {
293 operation_code: ScsiOp::READ,
294 logical_block: block.into(),
295 transfer_blocks: ((byte_len / 512) as u16).into(),
296 ..FromZeros::new_zeroed()
297 };
298
299 let mut scsi_req = storvsp_protocol::ScsiRequest {
300 target_id: path.target,
301 path_id: path.path,
302 lun: path.lun,
303 length: storvsp_protocol::SCSI_REQUEST_LEN_V2 as u16,
304 cdb_length: size_of::<scsi::Cdb10>() as u8,
305 data_transfer_length: byte_len as u32,
306 data_in: 1,
307 ..FromZeros::new_zeroed()
308 };
309
310 scsi_req.payload[0..10].copy_from_slice(cdb.as_bytes());
311
312 self.send_gpa_direct_packet_sync(
314 &[read_packet.as_bytes(), scsi_req.as_bytes()],
315 read_gpa,
316 byte_len,
317 )
318 .await;
319 }
320
321 pub async fn send_report_luns_packet(
322 &mut self,
323 path: ScsiPath,
324 data_buffer_gpa: u64,
325 data_buffer_len: usize,
326 ) {
327 let packet = storvsp_protocol::Packet {
328 operation: storvsp_protocol::Operation::EXECUTE_SRB,
329 flags: 0,
330 status: storvsp_protocol::NtStatus::SUCCESS,
331 };
332
333 let cdb = scsi::Cdb10 {
334 operation_code: ScsiOp::REPORT_LUNS,
335 ..FromZeros::new_zeroed()
336 };
337
338 let mut scsi_req = storvsp_protocol::ScsiRequest {
339 target_id: path.target,
340 path_id: path.path,
341 lun: path.lun,
342 length: storvsp_protocol::SCSI_REQUEST_LEN_V2 as u16,
343 cdb_length: size_of::<scsi::Cdb10>() as u8,
344 data_transfer_length: data_buffer_len as u32,
345 data_in: 1,
346 ..FromZeros::new_zeroed()
347 };
348
349 scsi_req.payload[0..10].copy_from_slice(cdb.as_bytes());
350
351 self.send_gpa_direct_packet_sync(
352 &[packet.as_bytes(), scsi_req.as_bytes()],
353 data_buffer_gpa,
354 data_buffer_len,
355 )
356 .await;
357 }
358
359 pub(crate) async fn verify_completion<F>(&mut self, f: F)
360 where
361 F: Clone + FnOnce(&IncomingPacket<'_, FlatRingMem>) -> Result<(), PacketError>,
362 {
363 let (mut reader, _) = self.queue.split();
364 let packet = reader.read().await.unwrap();
365 f(&packet).unwrap();
366 }
367
368 pub async fn perform_protocol_negotiation(&mut self) {
370 let negotiate_packet = storvsp_protocol::Packet {
371 operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
372 flags: 0,
373 status: storvsp_protocol::NtStatus::SUCCESS,
374 };
375 self.send_data_packet_sync(&[negotiate_packet.as_bytes()])
376 .await;
377 self.verify_completion(parse_guest_completion).await;
378
379 let version_packet = storvsp_protocol::Packet {
380 operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
381 flags: 0,
382 status: storvsp_protocol::NtStatus::SUCCESS,
383 };
384 let version = storvsp_protocol::ProtocolVersion {
385 major_minor: storvsp_protocol::VERSION_BLUE,
386 reserved: 0,
387 };
388 self.send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
389 .await;
390 self.verify_completion(parse_guest_completion).await;
391
392 let properties_packet = storvsp_protocol::Packet {
393 operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
394 flags: 0,
395 status: storvsp_protocol::NtStatus::SUCCESS,
396 };
397 self.send_data_packet_sync(&[properties_packet.as_bytes()])
398 .await;
399 self.verify_completion(parse_guest_completion).await;
400
401 let negotiate_packet = storvsp_protocol::Packet {
402 operation: storvsp_protocol::Operation::END_INITIALIZATION,
403 flags: 0,
404 status: storvsp_protocol::NtStatus::SUCCESS,
405 };
406 self.send_data_packet_sync(&[negotiate_packet.as_bytes()])
407 .await;
408 self.verify_completion(parse_guest_completion).await;
409 }
410
411 pub(crate) async fn verify_graceful_close(self, worker: TestWorker) {
412 drop(self);
413 match worker.task.await {
414 Err(WorkerError::Queue(err)) if err.is_closed_error() => (),
415 _ => panic!("Worker thread did not complete gracefully!"),
416 }
417 }
418}