storvsp/
test_helpers.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! StorVSP test helpers.
5//!
6//! These are used both by unit tests and by benchmarks.
7
8// Benchmarks do not use all the code here, but unit tests should.
9#![cfg_attr(not(test), expect(dead_code))]
10
11use crate::InitState;
12use crate::PacketError;
13use crate::Protocol;
14use crate::ProtocolState;
15use crate::ScsiController;
16use crate::ScsiPath;
17use crate::Worker;
18use crate::WorkerError;
19use guestmem::GuestMemory;
20use guestmem::MemoryRead;
21use guestmem::ranges::PagedRange;
22use pal_async::task::Spawn;
23use pal_async::task::Task;
24use parking_lot::RwLock;
25use scsi::ScsiOp;
26use scsi::srb::SrbStatus;
27use scsi_defs as scsi;
28use std::sync::Arc;
29use vmbus_async::queue::IncomingPacket;
30use vmbus_async::queue::OutgoingPacket;
31use vmbus_async::queue::Queue;
32use vmbus_channel::RawAsyncChannel;
33use vmbus_ring as ring;
34use vmbus_ring::FlatRingMem;
35use vmbus_ring::OutgoingPacketType;
36use vmbus_ring::PAGE_SIZE;
37use zerocopy::FromZeros;
38use zerocopy::IntoBytes;
39
40pub struct TestWorker {
41    task: Task<Result<(), WorkerError>>,
42}
43
44impl TestWorker {
45    pub(crate) async fn teardown(self) -> Result<(), WorkerError> {
46        self.task.await
47    }
48
49    /// Like `teardown`, but ignore the result. Nice for the fuzzer,
50    /// so that the `storvsp` crate doesn't need to expose `WorkerError`
51    /// as pub.
52    #[cfg(feature = "test")]
53    pub async fn teardown_ignore(self) {
54        let _ = self.task.await;
55    }
56
57    /// Like `teardown`, but panic if there's a failure. Nice for integration tests, so that the
58    /// `storvsp` crate doesn't need to expose `WorkerError` as pub.
59    #[cfg(feature = "test")]
60    pub async fn teardown_or_panic(self) {
61        match self.task.await {
62            Ok(()) => {}
63            Err(WorkerError::Queue(err)) if err.is_closed_error() => {}
64            Err(err) => {
65                panic!("Worker did not teardown gracefully: {:?}", err);
66            }
67        }
68    }
69
70    pub fn start<T: ring::RingMem + 'static + Sync>(
71        controller: ScsiController,
72        spawner: impl Spawn,
73        mem: GuestMemory,
74        channel: RawAsyncChannel<T>,
75        io_queue_depth: Option<u32>,
76    ) -> Self {
77        let task = spawner.spawn("test", async move {
78            let mut worker = Worker::new(
79                controller.state.clone(),
80                channel,
81                0,
82                mem,
83                Default::default(),
84                io_queue_depth.unwrap_or(256),
85                Arc::new(Protocol {
86                    state: RwLock::new(ProtocolState::Init(InitState::Begin)),
87                    ready: Default::default(),
88                }),
89                None,
90            )
91            .unwrap();
92            worker.process_primary().await
93        });
94
95        Self { task }
96    }
97}
98
99pub(crate) fn parse_guest_completion_check_flags_status<T: ring::RingMem>(
100    packet: &IncomingPacket<'_, T>,
101    flags: u32,
102    status: storvsp_protocol::NtStatus,
103) -> Result<(), PacketError> {
104    match packet {
105        IncomingPacket::Completion(compl) => {
106            let mut reader = compl.reader();
107            let header: storvsp_protocol::Packet =
108                reader.read_plain().map_err(PacketError::Access)?;
109            assert_eq!(header.flags, flags, "mismatched flags");
110            assert_eq!(header.status, status, "mismatched status");
111            assert_eq!(
112                header.operation,
113                storvsp_protocol::Operation::COMPLETE_IO,
114                "mismatched operation"
115            );
116            Ok(())
117        }
118        IncomingPacket::Data(_) => Err(PacketError::InvalidPacketType),
119    }
120}
121
122pub(crate) fn parse_guest_completion<T: ring::RingMem>(
123    packet: &IncomingPacket<'_, T>,
124) -> Result<(), PacketError> {
125    parse_guest_completion_check_flags_status(packet, 0, storvsp_protocol::NtStatus::SUCCESS)
126}
127
128pub(crate) fn parse_guest_completed_io<T: ring::RingMem>(
129    packet: &IncomingPacket<'_, T>,
130    expected_srb_status: SrbStatus,
131) -> Result<(), PacketError> {
132    parse_guest_completed_io_check_tx_len(packet, expected_srb_status, None)
133}
134
135pub(crate) fn parse_guest_completed_io_check_tx_len<T: ring::RingMem>(
136    packet: &IncomingPacket<'_, T>,
137    expected_srb_status: SrbStatus,
138    expected_data_tx_length: Option<usize>,
139) -> Result<(), PacketError> {
140    match packet {
141        IncomingPacket::Completion(compl) => {
142            let mut reader = compl.reader();
143            let header: storvsp_protocol::Packet =
144                reader.read_plain().map_err(PacketError::Access)?;
145            if header.operation != storvsp_protocol::Operation::COMPLETE_IO {
146                Err(PacketError::UnrecognizedOperation(header.operation))
147            } else {
148                if expected_srb_status == SrbStatus::SUCCESS {
149                    assert_eq!(header.status, storvsp_protocol::NtStatus::SUCCESS);
150                    if let Some(expected_data_tx_length) = expected_data_tx_length {
151                        let payload: storvsp_protocol::ScsiRequest =
152                            reader.read_plain().map_err(PacketError::Access)?;
153                        assert_eq!(
154                            payload.data_transfer_length as usize,
155                            expected_data_tx_length
156                        );
157                    }
158                } else {
159                    assert_ne!(header.status, storvsp_protocol::NtStatus::SUCCESS);
160                    let payload: storvsp_protocol::ScsiRequest =
161                        reader.read_plain().map_err(PacketError::Access)?;
162                    assert_eq!(payload.srb_status.status(), expected_srb_status);
163                }
164                Ok(())
165            }
166        }
167        _ => Err(PacketError::InvalidPacketType),
168    }
169}
170
171pub(crate) fn parse_guest_enumerate_bus<T: ring::RingMem>(
172    packet: &IncomingPacket<'_, T>,
173) -> Result<(), PacketError> {
174    match packet {
175        IncomingPacket::Data(p) => {
176            let mut reader = p.reader();
177            let header: storvsp_protocol::Packet =
178                reader.read_plain().map_err(PacketError::Access)?;
179            if header.operation != storvsp_protocol::Operation::ENUMERATE_BUS {
180                Err(PacketError::UnrecognizedOperation(header.operation))
181            } else {
182                assert_eq!(header.status, storvsp_protocol::NtStatus::SUCCESS);
183                Ok(())
184            }
185        }
186        _ => Err(PacketError::InvalidPacketType),
187    }
188}
189
190pub struct TestGuest {
191    pub queue: Queue<FlatRingMem>,
192    pub transaction_id: u64,
193}
194
195impl TestGuest {
196    pub async fn send_data_packet_sync(&mut self, payload: &[&[u8]]) {
197        self.queue
198            .split()
199            .1
200            .write(OutgoingPacket {
201                packet_type: OutgoingPacketType::InBandWithCompletion,
202                transaction_id: self.transaction_id,
203                payload,
204            })
205            .await
206            .unwrap();
207
208        self.transaction_id += 1;
209    }
210
211    pub async fn send_gpa_direct_packet_sync(
212        &mut self,
213        payload: &[&[u8]],
214        gpa_start: u64,
215        byte_len: usize,
216    ) {
217        let start_page: u64 = gpa_start / PAGE_SIZE as u64;
218        let end_page: u64 = (gpa_start + (byte_len + PAGE_SIZE - 1) as u64) / PAGE_SIZE as u64;
219        let gpas: Vec<u64> = (start_page..end_page).collect();
220        let pages =
221            PagedRange::new(gpa_start as usize % PAGE_SIZE, byte_len, gpas.as_slice()).unwrap();
222        self.queue
223            .split()
224            .1
225            .write(OutgoingPacket {
226                packet_type: OutgoingPacketType::GpaDirect(&[pages]),
227                transaction_id: self.transaction_id,
228                payload,
229            })
230            .await
231            .unwrap();
232
233        self.transaction_id += 1;
234    }
235
236    // This function assumes the sector size is 512.
237    pub async fn send_write_packet(
238        &mut self,
239        path: ScsiPath,
240        buf_gpa: u64,
241        block: u32,
242        byte_len: usize,
243    ) {
244        let write_packet = storvsp_protocol::Packet {
245            operation: storvsp_protocol::Operation::EXECUTE_SRB,
246            flags: 0,
247            status: storvsp_protocol::NtStatus::SUCCESS,
248        };
249
250        let cdb = scsi::Cdb10 {
251            operation_code: ScsiOp::WRITE,
252            logical_block: block.into(),
253            transfer_blocks: ((byte_len / 512) as u16).into(),
254            ..FromZeros::new_zeroed()
255        };
256
257        let mut scsi_req = storvsp_protocol::ScsiRequest {
258            target_id: path.target,
259            path_id: path.path,
260            lun: path.lun,
261            length: storvsp_protocol::SCSI_REQUEST_LEN_V2 as u16,
262            cdb_length: size_of::<scsi::Cdb10>() as u8,
263            data_transfer_length: byte_len as u32,
264            ..FromZeros::new_zeroed()
265        };
266
267        scsi_req.payload[0..10].copy_from_slice(cdb.as_bytes());
268
269        // send the gpa packet
270        self.send_gpa_direct_packet_sync(
271            &[write_packet.as_bytes(), scsi_req.as_bytes()],
272            buf_gpa,
273            byte_len,
274        )
275        .await;
276    }
277
278    // This function assumes the sector size is 512.
279    pub async fn send_read_packet(
280        &mut self,
281        path: ScsiPath,
282        read_gpa: u64,
283        block: u32,
284        byte_len: usize,
285    ) {
286        let read_packet = storvsp_protocol::Packet {
287            operation: storvsp_protocol::Operation::EXECUTE_SRB,
288            flags: 0,
289            status: storvsp_protocol::NtStatus::SUCCESS,
290        };
291
292        let cdb = scsi::Cdb10 {
293            operation_code: ScsiOp::READ,
294            logical_block: block.into(),
295            transfer_blocks: ((byte_len / 512) as u16).into(),
296            ..FromZeros::new_zeroed()
297        };
298
299        let mut scsi_req = storvsp_protocol::ScsiRequest {
300            target_id: path.target,
301            path_id: path.path,
302            lun: path.lun,
303            length: storvsp_protocol::SCSI_REQUEST_LEN_V2 as u16,
304            cdb_length: size_of::<scsi::Cdb10>() as u8,
305            data_transfer_length: byte_len as u32,
306            data_in: 1,
307            ..FromZeros::new_zeroed()
308        };
309
310        scsi_req.payload[0..10].copy_from_slice(cdb.as_bytes());
311
312        // send the gpa packet
313        self.send_gpa_direct_packet_sync(
314            &[read_packet.as_bytes(), scsi_req.as_bytes()],
315            read_gpa,
316            byte_len,
317        )
318        .await;
319    }
320
321    pub async fn send_report_luns_packet(
322        &mut self,
323        path: ScsiPath,
324        data_buffer_gpa: u64,
325        data_buffer_len: usize,
326    ) {
327        let packet = storvsp_protocol::Packet {
328            operation: storvsp_protocol::Operation::EXECUTE_SRB,
329            flags: 0,
330            status: storvsp_protocol::NtStatus::SUCCESS,
331        };
332
333        let cdb = scsi::Cdb10 {
334            operation_code: ScsiOp::REPORT_LUNS,
335            ..FromZeros::new_zeroed()
336        };
337
338        let mut scsi_req = storvsp_protocol::ScsiRequest {
339            target_id: path.target,
340            path_id: path.path,
341            lun: path.lun,
342            length: storvsp_protocol::SCSI_REQUEST_LEN_V2 as u16,
343            cdb_length: size_of::<scsi::Cdb10>() as u8,
344            data_transfer_length: data_buffer_len as u32,
345            data_in: 1,
346            ..FromZeros::new_zeroed()
347        };
348
349        scsi_req.payload[0..10].copy_from_slice(cdb.as_bytes());
350
351        self.send_gpa_direct_packet_sync(
352            &[packet.as_bytes(), scsi_req.as_bytes()],
353            data_buffer_gpa,
354            data_buffer_len,
355        )
356        .await;
357    }
358
359    pub(crate) async fn verify_completion<F>(&mut self, f: F)
360    where
361        F: Clone + FnOnce(&IncomingPacket<'_, FlatRingMem>) -> Result<(), PacketError>,
362    {
363        let (mut reader, _) = self.queue.split();
364        let packet = reader.read().await.unwrap();
365        f(&packet).unwrap();
366    }
367
368    // Send protocol negotiation packets for a test guest.
369    pub async fn perform_protocol_negotiation(&mut self) {
370        let negotiate_packet = storvsp_protocol::Packet {
371            operation: storvsp_protocol::Operation::BEGIN_INITIALIZATION,
372            flags: 0,
373            status: storvsp_protocol::NtStatus::SUCCESS,
374        };
375        self.send_data_packet_sync(&[negotiate_packet.as_bytes()])
376            .await;
377        self.verify_completion(parse_guest_completion).await;
378
379        let version_packet = storvsp_protocol::Packet {
380            operation: storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION,
381            flags: 0,
382            status: storvsp_protocol::NtStatus::SUCCESS,
383        };
384        let version = storvsp_protocol::ProtocolVersion {
385            major_minor: storvsp_protocol::VERSION_BLUE,
386            reserved: 0,
387        };
388        self.send_data_packet_sync(&[version_packet.as_bytes(), version.as_bytes()])
389            .await;
390        self.verify_completion(parse_guest_completion).await;
391
392        let properties_packet = storvsp_protocol::Packet {
393            operation: storvsp_protocol::Operation::QUERY_PROPERTIES,
394            flags: 0,
395            status: storvsp_protocol::NtStatus::SUCCESS,
396        };
397        self.send_data_packet_sync(&[properties_packet.as_bytes()])
398            .await;
399        self.verify_completion(parse_guest_completion).await;
400
401        let negotiate_packet = storvsp_protocol::Packet {
402            operation: storvsp_protocol::Operation::END_INITIALIZATION,
403            flags: 0,
404            status: storvsp_protocol::NtStatus::SUCCESS,
405        };
406        self.send_data_packet_sync(&[negotiate_packet.as_bytes()])
407            .await;
408        self.verify_completion(parse_guest_completion).await;
409    }
410
411    pub(crate) async fn verify_graceful_close(self, worker: TestWorker) {
412        drop(self);
413        match worker.task.await {
414            Err(WorkerError::Queue(err)) if err.is_closed_error() => (),
415            _ => panic!("Worker thread did not complete gracefully!"),
416        }
417    }
418}