// disk_blockdevice/lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![cfg(target_os = "linux")]
6
7//! Implements the [`DiskIo`] trait for virtual disks backed by a raw block
8//! device.
9
10// UNSAFETY: Issuing IOs and calling ioctls.
11#![expect(unsafe_code)]
12
13mod ioctl;
14mod nvme;
15pub mod resolver;
16
17use anyhow::Context;
18use blocking::unblock;
19use disk_backend::DiskError;
20use disk_backend::DiskIo;
21use disk_backend::UnmapBehavior;
22use disk_backend::pr::PersistentReservation;
23use disk_backend::pr::ReservationCapabilities;
24use disk_backend::pr::ReservationReport;
25use disk_backend::pr::ReservationType;
26use fs_err::PathExt;
27use guestmem::MemoryRead;
28use guestmem::MemoryWrite;
29use inspect::Inspect;
30use io_uring::opcode;
31use io_uring::types;
32use nvme::check_nvme_status;
33use nvme_spec::nvm;
34use pal::unix::affinity;
35use pal_async::driver::Driver;
36use scsi_buffers::BounceBuffer;
37use scsi_buffers::BounceBufferTracker;
38use scsi_buffers::RequestBuffers;
39use std::fmt::Debug;
40use std::fs;
41use std::os::unix::io::AsRawFd;
42use std::os::unix::prelude::FileTypeExt;
43use std::os::unix::prelude::MetadataExt;
44use std::path::Path;
45use std::path::PathBuf;
46use std::str::FromStr;
47use std::sync::Arc;
48use std::sync::atomic::AtomicU64;
49use std::sync::atomic::Ordering;
50use thiserror::Error;
51use uevent::CallbackHandle;
52use uevent::UeventListener;
53
54/// Opens a file for use with [`BlockDevice`] or
55/// [`disk_backend_resources::BlockDeviceDiskHandle`].
56pub fn open_file_for_block(
57    path: &Path,
58    read_only: bool,
59    direct: bool,
60) -> std::io::Result<fs::File> {
61    use std::os::unix::prelude::*;
62
63    tracing::debug!(?path, read_only, direct, "open_file_for_block");
64    let mut opts = fs::OpenOptions::new();
65    opts.read(true).write(!read_only);
66    if direct {
67        opts.custom_flags(libc::O_DIRECT);
68    }
69    opts.open(path)
70}
71
/// A bounce buffer that may or may not be tracked by a
/// [`BounceBufferTracker`].
enum MaybeBounceBuffer<'a> {
    /// A buffer whose memory usage is accounted against a shared tracker.
    Tracked(scsi_buffers::TrackedBounceBuffer<'a>),
    /// A standalone buffer with no usage tracking.
    Untracked(BounceBuffer),
}
78
79impl MaybeBounceBuffer<'_> {
80    fn io_vecs(&self) -> &[scsi_buffers::IoBuffer<'_>] {
81        match self {
82            Self::Tracked(t) => t.buffer.io_vecs(),
83            Self::Untracked(b) => b.io_vecs(),
84        }
85    }
86
87    fn as_mut_bytes(&mut self) -> &mut [u8] {
88        match self {
89            Self::Tracked(t) => t.buffer.as_mut_bytes(),
90            Self::Untracked(b) => b.as_mut_bytes(),
91        }
92    }
93}
94
/// A storvsp disk backed by a raw block device.
#[derive(Inspect)]
#[inspect(extra = "BlockDevice::inspect_extra")]
pub struct BlockDevice {
    // Shared handle to the backing file/device; cloned into blocking
    // closures for ioctl calls.
    file: Arc<fs::File>,
    // Logical sector size in bytes.
    sector_size: u32,
    // Physical sector size in bytes (clamped to at least `sector_size`).
    physical_sector_size: u32,
    // log2(sector_size), used to convert between bytes and sectors.
    sector_shift: u32,
    // Current sector count; updated when a resize event is handled.
    sector_count: AtomicU64,
    // Unmap granularity in sectors; 0 when discard is not supported.
    optimal_unmap_sectors: u32,
    read_only: bool,
    #[inspect(skip)]
    driver: Box<dyn Driver>,
    #[inspect(flatten)]
    device_type: DeviceType,
    // Whether persistent reservations are supported (NVMe with rescap only).
    supports_pr: bool,
    // Whether FUA writes are honored, per sysfs `queue/fua`.
    supports_fua: bool,
    // Keeps the uevent resize callback registered for this device's lifetime.
    #[inspect(skip)]
    _uevent_filter: Option<CallbackHandle>,
    // Incremented on each resize uevent; compared against `resized_acked`.
    resize_epoch: Arc<ResizeEpoch>,
    // Last resize epoch that has been applied to `sector_count`.
    resized_acked: AtomicU64,
    #[inspect(skip)]
    bounce_buffer_tracker: Option<Arc<BounceBufferTracker>>,
    // Force bounce buffering even for IOs that are already sector-aligned.
    always_bounce: bool,
}
120
/// Tracks pending device resize events. `epoch` counts resize notifications
/// received; `event` wakes tasks waiting for a size change.
#[derive(Inspect, Debug, Default)]
#[inspect(transparent)]
struct ResizeEpoch {
    epoch: AtomicU64,
    #[inspect(skip)]
    event: event_listener::Event,
}
128
/// The kind of backing device, determined once at construction.
#[derive(Debug, Copy, Clone, Inspect)]
#[inspect(tag = "device_type")]
enum DeviceType {
    /// A regular file. The sector count is captured at open time so that
    /// writes cannot extend the file.
    File {
        sector_count: u64,
    },
    /// A block device that was not identified as NVMe.
    UnknownBlock,
    /// An NVMe namespace, identified via the sysfs `nsid` attribute.
    NVMe {
        ns_id: u32,
        rescap: nvm::ReservationCapabilities,
    },
}
141
142impl BlockDevice {
143    fn inspect_extra(&self, resp: &mut inspect::Response<'_>) {
144        match self.device_type {
145            DeviceType::NVMe { .. } => {
146                resp.field_mut_with("interrupt_aggregation", |new_value| {
147                    self.inspect_interrupt_coalescing(new_value)
148                });
149            }
150            DeviceType::UnknownBlock => {}
151            DeviceType::File { .. } => {}
152        }
153    }
154
155    fn inspect_interrupt_coalescing(&self, new_value: Option<&str>) -> anyhow::Result<String> {
156        let coalescing = if let Some(new_value) = new_value {
157            let coalescing = (|| {
158                let (threshold, time) = new_value.split_once(' ')?;
159                Some(
160                    nvme::InterruptCoalescing::new()
161                        .with_aggregation_threshold(threshold.parse().ok()?)
162                        .with_aggregation_time(time.parse().ok()?),
163                )
164            })()
165            .context("expected `<aggregation_threshold> <aggregation_time>`")?;
166            nvme::nvme_set_features_interrupt_coalescing(&self.file, coalescing)?;
167            coalescing
168        } else if let Ok(coalescing) = nvme::nvme_get_features_interrupt_coalescing(&self.file) {
169            coalescing
170        } else {
171            return Ok("not supported".into());
172        };
173        Ok(format!(
174            "{} {}",
175            coalescing.aggregation_threshold(),
176            coalescing.aggregation_time()
177        ))
178    }
179}
180
/// New device error
#[derive(Debug, Error)]
pub enum NewDeviceError {
    /// An ioctl or IO against the backing device failed.
    #[error("block device ioctl error")]
    IoctlError(#[from] DiskError),
    /// Reading device metadata (sysfs attributes or NVMe identify data) failed.
    #[error("failed to read device metadata")]
    DeviceMetadata(#[source] anyhow::Error),
    /// The backing file is neither a regular file nor a block device.
    #[error("invalid file type, not a file or block device")]
    InvalidFileType,
    /// The reported disk size is invalid.
    #[error("invalid disk size {0:#x}")]
    InvalidDiskSize(u64),
    /// The provided driver cannot issue io-uring operations.
    #[error("driver does not support io-uring")]
    NoIoUring,
}
195
196impl BlockDevice {
197    /// Constructs a new `BlockDevice` backed by the specified file.
198    ///
199    /// # Arguments
200    /// * `file` - The backing device opened for raw access.
201    /// * `read_only` - Indicates whether the device is opened for read-only access.
202    /// * `driver` - The async driver to use for issuing IOs (must support io-uring).
203    /// * `always_bounce` - Whether to always use bounce buffers for IOs, even for those that are aligned.
204    pub async fn new(
205        file: fs::File,
206        read_only: bool,
207        driver: impl Driver,
208        uevent_listener: Option<&UeventListener>,
209        bounce_buffer_tracker: Option<Arc<BounceBufferTracker>>,
210        always_bounce: bool,
211    ) -> Result<BlockDevice, NewDeviceError> {
212        if !driver.io_uring_probe(opcode::Read::CODE) {
213            return Err(NewDeviceError::NoIoUring);
214        }
215        assert!(driver.io_uring_probe(opcode::Write::CODE));
216        assert!(driver.io_uring_probe(opcode::Readv::CODE));
217        assert!(driver.io_uring_probe(opcode::Writev::CODE));
218        assert!(driver.io_uring_probe(opcode::Fsync::CODE));
219
220        let metadata = file.metadata().map_err(DiskError::Io)?;
221
222        let mut uevent_filter = None;
223        let resize_epoch = Arc::new(ResizeEpoch::default());
224
225        let devmeta = if metadata.file_type().is_block_device() {
226            let rdev = metadata.rdev();
227            let (major, minor) = (libc::major(rdev), libc::minor(rdev));
228
229            // Register for resize events.
230            if let Some(uevent_listener) = uevent_listener {
231                let resize_epoch = resize_epoch.clone();
232                uevent_filter = Some(
233                    uevent_listener
234                        .add_block_resize_callback(major, minor, {
235                            move || {
236                                tracing::info!(major, minor, "disk resized");
237                                resize_epoch.epoch.fetch_add(1, Ordering::SeqCst);
238                                resize_epoch.event.notify(usize::MAX);
239                            }
240                        })
241                        .await,
242                );
243            }
244
245            DeviceMetadata::from_block_device(&file, major, minor)
246                .map_err(NewDeviceError::DeviceMetadata)?
247        } else if metadata.file_type().is_file() {
248            DeviceMetadata::from_file(&metadata).map_err(NewDeviceError::DeviceMetadata)?
249        } else {
250            return Err(NewDeviceError::InvalidFileType);
251        };
252
253        let sector_size = devmeta.logical_block_size;
254        let sector_shift = sector_size.trailing_zeros();
255        let physical_sector_size = devmeta.physical_block_size.max(sector_size);
256        let sector_count = devmeta.disk_size >> sector_shift;
257        let unmap_granularity = devmeta.discard_granularity >> sector_shift;
258        let file = Arc::new(file);
259        let device = BlockDevice {
260            file,
261            sector_size,
262            physical_sector_size,
263            sector_shift: sector_size.trailing_zeros(),
264            sector_count: sector_count.into(),
265            optimal_unmap_sectors: unmap_granularity,
266            read_only,
267            driver: Box::new(driver),
268            device_type: devmeta.device_type,
269            supports_pr: devmeta.supports_pr,
270            supports_fua: devmeta.fua,
271            _uevent_filter: uevent_filter,
272            resize_epoch,
273            resized_acked: 0.into(),
274            bounce_buffer_tracker,
275            always_bounce,
276        };
277
278        Ok(device)
279    }
280
281    /// Use a box to avoid embedding a large `TrackedBounceBuffer` directly in
282    /// the calling future.
283    async fn acquire_bounce_buffer(&self, size: usize) -> Box<MaybeBounceBuffer<'_>> {
284        Box::new(if let Some(tracker) = &self.bounce_buffer_tracker {
285            MaybeBounceBuffer::Tracked(
286                tracker
287                    .acquire_bounce_buffers(size, affinity::get_cpu_number() as usize)
288                    .await,
289            )
290        } else {
291            MaybeBounceBuffer::Untracked(BounceBuffer::new(size))
292        })
293    }
294
295    fn handle_resize(&self) {
296        if let Err(err) = self.handle_resize_inner() {
297            tracing::error!(
298                error = &err as &dyn std::error::Error,
299                "failed to update disk size"
300            );
301        }
302    }
303
304    fn handle_resize_inner(&self) -> std::io::Result<()> {
305        let mut acked = self.resized_acked.load(Ordering::SeqCst);
306        loop {
307            let epoch = self.resize_epoch.epoch.load(Ordering::SeqCst);
308            if acked == epoch {
309                break Ok(());
310            }
311
312            let size_in_bytes = ioctl::query_block_device_size_in_bytes(&self.file)?;
313
314            let new_sector_count = size_in_bytes / self.sector_size as u64;
315            let original_sector_count = self.sector_count.load(Ordering::SeqCst);
316
317            tracing::debug!(original_sector_count, new_sector_count, "resize");
318            if original_sector_count != new_sector_count {
319                tracing::info!(
320                    original_sector_count,
321                    new_sector_count,
322                    "Disk size updating..."
323                );
324                self.sector_count.store(new_sector_count, Ordering::SeqCst);
325            }
326
327            acked = self
328                .resized_acked
329                .compare_exchange(acked, epoch, Ordering::SeqCst, Ordering::SeqCst)
330                .unwrap_or_else(|x| x);
331        }
332    }
333
334    fn map_io_error(&self, err: std::io::Error) -> DiskError {
335        if !matches!(self.device_type, DeviceType::File { .. }) {
336            match err.raw_os_error() {
337                Some(libc::EBADE) => return DiskError::ReservationConflict,
338                Some(libc::ENOSPC) => return DiskError::IllegalBlock,
339                _ => {}
340            }
341        }
342        DiskError::Io(err)
343    }
344}
345
/// Device geometry and capability information gathered at construction time.
struct DeviceMetadata {
    device_type: DeviceType,
    // Total size in bytes; validated to be a multiple of `logical_block_size`.
    disk_size: u64,
    // Logical block (sector) size in bytes; validated power of two >= 512.
    logical_block_size: u32,
    // Physical block size in bytes; validated power of two.
    physical_block_size: u32,
    // Discard granularity in bytes; 0 when discard is unsupported.
    discard_granularity: u32,
    // Whether persistent reservations are supported.
    supports_pr: bool,
    // Whether the device honors FUA writes.
    fua: bool,
}
355
impl DeviceMetadata {
    /// Gathers metadata for a block device from its sysfs node at
    /// `/sys/dev/block/<major>:<minor>`, issuing NVMe identify commands when
    /// the device turns out to be an NVMe namespace.
    fn from_block_device(file: &fs::File, major: u32, minor: u32) -> anyhow::Result<Self> {
        // Ensure the sysfs path exists.
        let devpath = PathBuf::from(format!("/sys/dev/block/{major}:{minor}"));
        devpath
            .fs_err_metadata()
            .context("could not find sysfs path for block device")?;

        let mut supports_pr = false;

        // Check for NVMe by looking for the namespace ID.
        let device_type = match fs_err::read_to_string(devpath.join("nsid")) {
            Ok(ns_id) => {
                let ns_id = ns_id
                    .trim()
                    .parse()
                    .context("failed to parse NVMe namespace ID")?;

                // PR support requires both the controller reporting the
                // reservations capability and a nonzero namespace rescap.
                let rescap = nvme::nvme_identify_namespace_data(file, ns_id)?.rescap;
                let oncs = nvme::nvme_identify_controller_data(file)?.oncs;
                tracing::debug!(rescap = ?rescap, oncs = ?oncs, "get identify data");
                supports_pr = oncs.reservations() && u8::from(rescap) != 0;
                Some(DeviceType::NVMe { ns_id, rescap })
            }
            // No `nsid` attribute: not an NVMe namespace.
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => None,
            Err(err) => Err(err).context("failed to read NVMe namespace ID")?,
        };

        // Fall back to unknown.
        let device_type = device_type.unwrap_or(DeviceType::UnknownBlock);

        // Reads and parses a single sysfs attribute, labeling errors with
        // `msg` for diagnostics.
        fn read_val<T: FromStr>(devpath: &Path, path: &str, msg: &str) -> anyhow::Result<T>
        where
            T::Err: 'static + std::error::Error + Send + Sync,
        {
            fs_err::read_to_string(devpath.join(path))
                .with_context(|| format!("failed to read {msg}"))?
                .trim()
                .parse()
                .with_context(|| format!("failed to parse {msg}"))
        }

        let logical_block_size = read_val(&devpath, "queue/logical_block_size", "sector size")?;
        let physical_block_size = read_val(
            &devpath,
            "queue/physical_block_size",
            "physical sector size",
        )?;

        // sys/dev/block/*/*/size shows the size in 512-byte
        // sectors irrespective of the block device
        let disk_size = read_val::<u64>(&devpath, "size", "disk size")? * 512;
        let discard_granularity =
            read_val(&devpath, "queue/discard_granularity", "discard granularity")?;

        let fua = read_val::<u8>(&devpath, "queue/fua", "fua")? != 0;

        Self {
            device_type,
            disk_size,
            logical_block_size,
            physical_block_size,
            discard_granularity,
            supports_pr,
            fua,
        }
        .validate()
    }

    /// Derives metadata for a regular file backing, using a fixed 512-byte
    /// sector size and the filesystem's preferred IO size as the physical
    /// sector size.
    fn from_file(metadata: &fs::Metadata) -> anyhow::Result<Self> {
        let logical_block_size = 512;
        Self {
            device_type: DeviceType::File {
                // Capture the sector count now so writes can't extend the file.
                sector_count: metadata.len() / logical_block_size as u64,
            },
            disk_size: metadata.size(),
            logical_block_size,
            physical_block_size: metadata.blksize() as u32,
            discard_granularity: 0,
            supports_pr: false,
            fua: false,
        }
        .validate()
    }

    /// Sanity-checks the gathered geometry, returning `self` on success.
    fn validate(self) -> anyhow::Result<Self> {
        let Self {
            device_type: _,
            disk_size,
            logical_block_size,
            physical_block_size,
            discard_granularity,
            supports_pr: _,
            fua: _,
        } = self;
        if logical_block_size < 512 || !logical_block_size.is_power_of_two() {
            anyhow::bail!("invalid sector size {logical_block_size}");
        }
        if !physical_block_size.is_power_of_two() {
            anyhow::bail!("invalid physical sector size {physical_block_size}");
        }
        if disk_size % logical_block_size as u64 != 0 {
            anyhow::bail!("invalid disk size {disk_size:#x}");
        }
        if discard_granularity % logical_block_size != 0 {
            anyhow::bail!("invalid discard granularity {discard_granularity}");
        }
        Ok(self)
    }
}
466
impl DiskIo for BlockDevice {
    fn disk_type(&self) -> &str {
        "block_device"
    }

    fn sector_count(&self) -> u64 {
        // Lazily apply any resize event that has not been acknowledged yet
        // before reporting the count.
        if self.resize_epoch.epoch.load(Ordering::Relaxed)
            != self.resized_acked.load(Ordering::Relaxed)
        {
            self.handle_resize();
        }
        self.sector_count.load(Ordering::Relaxed)
    }

    fn sector_size(&self) -> u32 {
        self.sector_size
    }

    fn disk_id(&self) -> Option<[u8; 16]> {
        None
    }

    fn physical_sector_size(&self) -> u32 {
        self.physical_sector_size
    }

    fn is_fua_respected(&self) -> bool {
        self.supports_fua
    }

    fn is_read_only(&self) -> bool {
        self.read_only
    }

    fn pr(&self) -> Option<&dyn PersistentReservation> {
        if self.supports_pr { Some(self) } else { None }
    }

    async fn eject(&self) -> Result<(), DiskError> {
        let file = self.file.clone();
        // Run the blocking ioctls on a worker thread: unlock the door, then
        // eject the media.
        unblock(move || {
            ioctl::lockdoor(&file, false)?;
            ioctl::eject(&file)
        })
        .await
        .map_err(|err| self.map_io_error(err))?;
        Ok(())
    }

    async fn read_vectored(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
    ) -> Result<(), DiskError> {
        let io_size = buffers.len();
        tracing::trace!(sector, io_size, "read_vectored");

        let mut bounce_buffer = None;
        let locked;
        // Bounce when configured to always do so, or when the guest buffers
        // are not sector-aligned.
        let should_bounce = self.always_bounce || !buffers.is_aligned(self.sector_size() as usize);
        let io_vecs = if !should_bounce {
            locked = buffers.lock(true)?;
            locked.io_vecs()
        } else {
            tracing::trace!("double buffering IO");

            bounce_buffer
                .insert(self.acquire_bounce_buffer(buffers.len()).await)
                .io_vecs()
        };

        // SAFETY: `io_vecs` and the underlying locked pages are locals
        // in this `async fn`--they are part of the same state machine as
        // the returned future and will not be freed before it completes
        // or is dropped (which aborts).
        let bytes_read = unsafe {
            self.driver.io_uring_submit(
                opcode::Readv::new(
                    types::Fd(self.file.as_raw_fd()),
                    io_vecs.as_ptr().cast(),
                    io_vecs.len() as u32,
                )
                .offset((sector * self.sector_size() as u64) as _)
                .build(),
            )
        }
        .await
        .map_err(|err| self.map_io_error(err))?;
        tracing::trace!(bytes_read, "read_vectored");
        // A short read is treated as an out-of-range request.
        if bytes_read != io_size as i32 {
            return Err(DiskError::IllegalBlock);
        }

        // When the IO was bounced, copy the data back into guest memory.
        if let Some(mut bounce_buffer) = bounce_buffer {
            buffers.writer().write(bounce_buffer.as_mut_bytes())?;
        }
        Ok(())
    }

    async fn write_vectored(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
        fua: bool,
    ) -> Result<(), DiskError> {
        let io_size = buffers.len();
        tracing::trace!(sector, io_size, "write_vectored");

        // Ensure the write doesn't extend the file.
        if let DeviceType::File { sector_count } = self.device_type {
            if sector + (io_size as u64 >> self.sector_shift) > sector_count {
                return Err(DiskError::IllegalBlock);
            }
        }

        let mut bounce_buffer;
        let locked;
        // Bounce when configured to always do so, or when the guest buffers
        // are not sector-aligned.
        let should_bounce = self.always_bounce || !buffers.is_aligned(self.sector_size() as usize);
        let io_vecs = if !should_bounce {
            locked = buffers.lock(false)?;
            locked.io_vecs()
        } else {
            tracing::trace!("double buffering IO");
            bounce_buffer = self.acquire_bounce_buffer(buffers.len()).await;
            buffers.reader().read(bounce_buffer.as_mut_bytes())?;
            bounce_buffer.io_vecs()
        };

        // SAFETY: `io_vecs` and the underlying locked pages are locals
        // in this `async fn`--they are part of the same state machine as
        // the returned future and will not be freed before it completes
        // or is dropped (which aborts).
        let bytes_written = unsafe {
            self.driver.io_uring_submit(
                opcode::Writev::new(
                    types::Fd(self.file.as_raw_fd()),
                    io_vecs.as_ptr().cast::<libc::iovec>(),
                    io_vecs.len() as _,
                )
                .offset((sector * self.sector_size() as u64) as _)
                // FUA requests are issued with per-IO data sync semantics.
                .rw_flags(if fua { libc::RWF_DSYNC } else { 0 })
                .build(),
            )
        }
        .await
        .map_err(|err| self.map_io_error(err))?;
        tracing::trace!(bytes_written, "write_vectored");
        // A short write is treated as an out-of-range request.
        if bytes_written != io_size as i32 {
            return Err(DiskError::IllegalBlock);
        }

        Ok(())
    }

    async fn sync_cache(&self) -> Result<(), DiskError> {
        // SAFETY: No data buffers.
        unsafe {
            self.driver
                .io_uring_submit(opcode::Fsync::new(types::Fd(self.file.as_raw_fd())).build())
        }
        .await
        .map_err(|err| self.map_io_error(err))?;
        Ok(())
    }

    async fn wait_resize(&self, sector_count: u64) -> u64 {
        loop {
            // Register the listener before sampling the count so a resize
            // between the check and the await is not missed.
            let listen = self.resize_epoch.event.listen();
            let current = self.sector_count();
            if current != sector_count {
                break current;
            }
            listen.await;
        }
    }

    async fn unmap(
        &self,
        sector_offset: u64,
        sector_count: u64,
        _block_level_only: bool,
    ) -> Result<(), DiskError> {
        let file = self.file.clone();
        let file_offset = sector_offset << self.sector_shift;
        let length = sector_count << self.sector_shift;
        tracing::debug!(file = ?file, file_offset, length, "unmap_async");
        match unblock(move || ioctl::discard(&file, file_offset, length)).await {
            Ok(()) => {}
            // A failed discard for an out-of-range request is reported as an
            // illegal block rather than a generic IO error.
            Err(_) if sector_offset + sector_count > self.sector_count() => {
                return Err(DiskError::IllegalBlock);
            }
            Err(err) => return Err(self.map_io_error(err)),
        }
        Ok(())
    }

    fn unmap_behavior(&self) -> UnmapBehavior {
        // A zero granularity means the device reported no discard support.
        if self.optimal_unmap_sectors == 0 {
            UnmapBehavior::Ignored
        } else {
            UnmapBehavior::Unspecified
        }
    }

    fn optimal_unmap_sectors(&self) -> u32 {
        self.optimal_unmap_sectors
    }
}
675
#[async_trait::async_trait]
impl PersistentReservation for BlockDevice {
    fn capabilities(&self) -> ReservationCapabilities {
        match &self.device_type {
            &DeviceType::NVMe { rescap, .. } => {
                nvme_common::from_nvme_reservation_capabilities(rescap)
            }
            // `pr()` only returns this trait when `supports_pr` is set, which
            // only happens for NVMe devices.
            DeviceType::File { .. } | DeviceType::UnknownBlock => unreachable!(),
        }
    }

    async fn report(&self) -> Result<ReservationReport, DiskError> {
        assert!(matches!(self.device_type, DeviceType::NVMe { .. }));
        self.nvme_persistent_reservation_report()
            .await
            .map_err(|err| self.map_io_error(err))
    }

    async fn register(
        &self,
        current_key: Option<u64>,
        new_key: u64,
        ptpl: Option<bool>,
    ) -> Result<(), DiskError> {
        assert!(matches!(self.device_type, DeviceType::NVMe { .. }));

        // The Linux kernel interface to register does not allow ptpl to be
        // configured. We could manually issue an NVMe command, but this code
        // path is not really used anyway.
        if ptpl == Some(false) {
            tracing::warn!("ignoring guest request to disable persist through power loss");
        }

        let file = self.file.clone();
        unblock(move || {
            ioctl::pr_register(
                &file,
                current_key.unwrap_or(0),
                new_key,
                // With no current key, ask the kernel to ignore the existing
                // registration key.
                if current_key.is_none() {
                    ioctl::PR_FL_IGNORE_KEY
                } else {
                    0
                },
            )
        })
        .await
        .and_then(check_nvme_status)
        .map_err(|err| self.map_io_error(err))?;
        Ok(())
    }

    async fn reserve(&self, key: u64, reservation_type: ReservationType) -> Result<(), DiskError> {
        assert!(matches!(self.device_type, DeviceType::NVMe { .. }));
        let file = self.file.clone();
        unblock(move || ioctl::pr_reserve(&file, reservation_type, key))
            .await
            .and_then(check_nvme_status)
            .map_err(|err| self.map_io_error(err))?;
        Ok(())
    }

    async fn release(&self, key: u64, reservation_type: ReservationType) -> Result<(), DiskError> {
        assert!(matches!(self.device_type, DeviceType::NVMe { .. }));
        let file = self.file.clone();
        unblock(move || ioctl::pr_release(&file, reservation_type, key))
            .await
            .and_then(check_nvme_status)
            .map_err(|err| self.map_io_error(err))?;
        Ok(())
    }

    async fn clear(&self, key: u64) -> Result<(), DiskError> {
        assert!(matches!(self.device_type, DeviceType::NVMe { .. }));
        let file = self.file.clone();
        unblock(move || ioctl::pr_clear(&file, key))
            .await
            .and_then(check_nvme_status)
            .map_err(|err| self.map_io_error(err))?;
        Ok(())
    }

    async fn preempt(
        &self,
        current_key: u64,
        preempt_key: u64,
        reservation_type: ReservationType,
        abort: bool,
    ) -> Result<(), DiskError> {
        assert!(matches!(self.device_type, DeviceType::NVMe { .. }));
        let file = self.file.clone();
        unblock(move || {
            ioctl::pr_preempt(&file, reservation_type, current_key, preempt_key, abort)
        })
        .await
        .and_then(check_nvme_status)
        .map_err(|err| self.map_io_error(err))?;
        Ok(())
    }
}
776
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use guestmem::GuestMemory;
    use hvdef::HV_PAGE_SIZE;
    use hvdef::HV_PAGE_SIZE_USIZE;
    use once_cell::sync::OnceCell;
    use pal_async::async_test;
    use pal_uring::IoUringPool;
    use pal_uring::PoolClient;
    use scsi_buffers::OwnedRequestBuffers;

    // Detects kernels on which these tests are known to hang.
    fn is_buggy_kernel() -> bool {
        // 5.13 kernels seem to have a bug with io_uring where tests hang.
        let output = String::from_utf8(
            std::process::Command::new("uname")
                .arg("-r")
                .output()
                .unwrap()
                .stdout,
        )
        .unwrap();

        output.contains("5.13")
    }

    // Creates a 64 KiB temp-file-backed BlockDevice on a shared io-uring pool.
    fn new_block_device() -> Result<BlockDevice, NewDeviceError> {
        // TODO: switch to std::sync::OnceLock once `get_or_try_init` is stable
        static POOL: OnceCell<PoolClient> = OnceCell::new();

        let client = POOL
            .get_or_try_init(|| {
                let pool = IoUringPool::new("test", 16)?;
                let client = pool.client().clone();
                std::thread::spawn(|| pool.run());
                Ok(client)
            })
            .map_err(|err| NewDeviceError::IoctlError(DiskError::Io(err)))?;

        let test_file = tempfile::tempfile().unwrap();
        test_file.set_len(1024 * 64).unwrap();
        block_on(BlockDevice::new(
            test_file.try_clone().unwrap(),
            false,
            client.initiator().clone(),
            None,
            None,
            false,
        ))
    }

    // Creates a test device, skipping the test when io-uring is unavailable
    // or the kernel is known to be buggy.
    macro_rules! get_block_device_or_skip {
        () => {
            match new_block_device() {
                Ok(pool) => {
                    if is_buggy_kernel() {
                        println!("Test case skipped (buggy kernel version)");
                        return;
                    }

                    pool
                }
                Err(NewDeviceError::IoctlError(DiskError::Io(err)))
                    if err.raw_os_error() == Some(libc::ENOSYS) =>
                {
                    println!("Test case skipped (no IO-Uring support)");
                    return;
                }
                Err(err) => panic!("{}", err),
            }
        };
    }

    // Writes four pages, reads them back into different pages, and verifies
    // the round trip.
    async fn run_async_disk_io(fua: bool) {
        let disk = get_block_device_or_skip!();

        let test_guest_mem = GuestMemory::allocate(0x8000);
        test_guest_mem
            .write_at(0, &(0..0x8000).map(|x| x as u8).collect::<Vec<_>>())
            .unwrap();

        let write_buffers = OwnedRequestBuffers::new(&[3, 2, 1, 0]);
        disk.write_vectored(&write_buffers.buffer(&test_guest_mem), 0, fua)
            .await
            .unwrap();

        if !fua {
            disk.sync_cache().await.unwrap();
        }

        let read_buffers = OwnedRequestBuffers::new(&[7, 6, 5, 4]);
        disk.read_vectored(&read_buffers.buffer(&test_guest_mem), 0)
            .await
            .unwrap();

        let mut source = vec![0u8; 4 * HV_PAGE_SIZE_USIZE];
        test_guest_mem.read_at(0, &mut source).unwrap();
        let mut target = vec![0u8; 4 * HV_PAGE_SIZE_USIZE];
        test_guest_mem
            .read_at(4 * HV_PAGE_SIZE, &mut target)
            .unwrap();
        assert_eq!(source, target);
    }

    #[async_test]
    async fn test_async_disk_io() {
        run_async_disk_io(false).await;
    }

    #[async_test]
    async fn test_async_disk_io_fua() {
        run_async_disk_io(true).await;
    }

    // Same round trip as above but with buffers offset by 512 bytes, to
    // exercise the bounce-buffer (unaligned) path.
    async fn run_async_disk_io_unaligned(fua: bool) {
        let disk = get_block_device_or_skip!();

        let test_guest_mem = GuestMemory::allocate(0x8000);
        test_guest_mem
            .write_at(0, &(0..0x8000).map(|x| x as u8).collect::<Vec<_>>())
            .unwrap();

        let write_buffers =
            OwnedRequestBuffers::new_unaligned(&[0, 1, 2, 3], 512, 3 * HV_PAGE_SIZE_USIZE);

        disk.write_vectored(&write_buffers.buffer(&test_guest_mem), 0, fua)
            .await
            .unwrap();

        if !fua {
            disk.sync_cache().await.unwrap();
        }

        let read_buffers =
            OwnedRequestBuffers::new_unaligned(&[4, 5, 6, 7], 512, 3 * HV_PAGE_SIZE_USIZE);
        disk.read_vectored(&read_buffers.buffer(&test_guest_mem), 0)
            .await
            .unwrap();

        let mut source = vec![0u8; 3 * HV_PAGE_SIZE_USIZE];
        test_guest_mem.read_at(512, &mut source).unwrap();
        let mut target = vec![0u8; 3 * HV_PAGE_SIZE_USIZE];
        test_guest_mem
            .read_at(4 * HV_PAGE_SIZE + 512, &mut target)
            .unwrap();
        assert_eq!(source, target);
    }

    #[async_test]
    async fn test_async_disk_io_unaligned() {
        run_async_disk_io_unaligned(false).await;
    }

    #[async_test]
    async fn test_async_disk_io_unaligned_fua() {
        run_async_disk_io_unaligned(true).await;
    }

    // Verifies that a write far past end of device fails with IllegalBlock.
    #[async_test]
    async fn test_illegal_lba() {
        let disk = get_block_device_or_skip!();
        let gm = GuestMemory::allocate(512);
        match disk
            .write_vectored(
                &OwnedRequestBuffers::linear(0, 512, true).buffer(&gm),
                i64::MAX as u64 / 512,
                false,
            )
            .await
        {
            Err(DiskError::IllegalBlock) => {}
            r => panic!("unexpected result: {:?}", r),
        }
    }
}