1#![expect(missing_docs)]
11#![forbid(unsafe_code)]
12
13use async_trait::async_trait;
14use disk_backend::Disk;
15use disk_backend::DiskError;
16use disk_backend::DiskIo;
17use disk_backend::pr;
18use disk_backend::pr::ReservationType;
19use disk_backend::resolve::ResolveDiskParameters;
20use disk_backend::resolve::ResolvedDisk;
21use disk_backend_resources::DiskWithReservationsHandle;
22use inspect::Inspect;
23use parking_lot::Mutex;
24use scsi_buffers::RequestBuffers;
25use std::future::Future;
26use std::num::NonZeroU64;
27use std::num::Wrapping;
28use thiserror::Error;
29use vm_resource::AsyncResolveResource;
30use vm_resource::ResolveError;
31use vm_resource::ResourceResolver;
32use vm_resource::declare_static_async_resolver;
33use vm_resource::kind::DiskHandleKind;
34
/// Resource resolver that turns a [`DiskWithReservationsHandle`] into a
/// [`ResolvedDisk`] wrapping a [`DiskWithReservations`].
pub struct DiskWithReservationsResolver;
// Declares `DiskWithReservationsResolver` as a static async resolver for the
// `(DiskHandleKind, DiskWithReservationsHandle)` resource pair.
declare_static_async_resolver!(
    DiskWithReservationsResolver,
    (DiskHandleKind, DiskWithReservationsHandle)
);
40
/// Errors produced by [`DiskWithReservationsResolver`] while resolving a
/// [`DiskWithReservationsHandle`].
#[derive(Debug, Error)]
pub enum ResolvePrDiskError {
    // The inner disk resource carried by the handle could not be resolved.
    #[error("failed to resolve inner disk")]
    Resolve(#[source] ResolveError),
    // `ResolvedDisk::new` rejected the wrapped disk.
    #[error("invalid disk")]
    InvalidDisk(#[source] disk_backend::InvalidDisk),
}
48
49#[async_trait]
50impl AsyncResolveResource<DiskHandleKind, DiskWithReservationsHandle>
51 for DiskWithReservationsResolver
52{
53 type Output = ResolvedDisk;
54 type Error = ResolvePrDiskError;
55
56 async fn resolve(
57 &self,
58 resolver: &ResourceResolver,
59 rsrc: DiskWithReservationsHandle,
60 input: ResolveDiskParameters<'_>,
61 ) -> Result<Self::Output, Self::Error> {
62 let inner = resolver
63 .resolve(rsrc.0, input)
64 .await
65 .map_err(ResolvePrDiskError::Resolve)?;
66
67 ResolvedDisk::new(DiskWithReservations::new(inner.0))
68 .map_err(ResolvePrDiskError::InvalidDisk)
69 }
70}
71
/// A disk wrapper that forwards all I/O to an inner [`Disk`] and emulates
/// persistent reservations (register / reserve / release / clear / preempt)
/// itself.
///
/// Reservation state is held in memory in this struct and tracks a single
/// registration key. Nothing in this wrapper writes that state to the inner
/// disk.
#[derive(Inspect)]
pub struct DiskWithReservations {
    /// The wrapped disk that services reads, writes, unmaps and cache flushes.
    inner: Disk,
    /// Reservation bookkeeping. It sits behind a mutex so the `&self`
    /// reservation methods can mutate it.
    #[inspect(flatten)]
    state: Mutex<ReservationState>,
}
82
/// In-memory persistent reservation state for [`DiskWithReservations`].
///
/// The derived `Default` is the initial state: generation 0, no registered
/// key, no reservation, and the persist-through-power-loss flag cleared.
#[derive(Default, Debug, Inspect)]
struct ReservationState {
    /// Generation counter returned in reservation reports. `Wrapping` makes
    /// increments wrap around instead of overflowing.
    generation: Wrapping<u32>,
    /// The single registered reservation key, if any. `NonZeroU64` means a key
    /// of 0 can never be stored as "registered".
    registered_key: Option<NonZeroU64>,
    /// Type of the currently held reservation, or `None` if there is none.
    reservation_type: Option<ReservationType>,
    /// Persist-through-power-loss flag as last set by a registration. In this
    /// file it is only stored and reported.
    persist_through_power_loss: bool,
}
90
91impl DiskWithReservations {
92 pub fn new(inner: Disk) -> Self {
94 Self {
95 inner,
96 state: Default::default(),
97 }
98 }
99}
100
101impl DiskIo for DiskWithReservations {
102 fn disk_type(&self) -> &str {
103 "prwrap"
104 }
105
106 fn sector_count(&self) -> u64 {
107 self.inner.sector_count()
108 }
109
110 fn sector_size(&self) -> u32 {
111 self.inner.sector_size()
112 }
113
114 fn disk_id(&self) -> Option<[u8; 16]> {
115 self.inner.disk_id()
116 }
117
118 fn physical_sector_size(&self) -> u32 {
119 self.inner.physical_sector_size()
120 }
121
122 fn is_fua_respected(&self) -> bool {
123 self.inner.is_fua_respected()
124 }
125
126 fn is_read_only(&self) -> bool {
127 self.inner.is_read_only()
128 }
129
130 fn unmap(
131 &self,
132 sector: u64,
133 count: u64,
134 block_level_only: bool,
135 ) -> impl Future<Output = Result<(), DiskError>> + Send {
136 self.inner.unmap(sector, count, block_level_only)
137 }
138
139 fn unmap_behavior(&self) -> disk_backend::UnmapBehavior {
140 self.inner.unmap_behavior()
141 }
142
143 fn optimal_unmap_sectors(&self) -> u32 {
144 self.inner.optimal_unmap_sectors()
145 }
146
147 fn pr(&self) -> Option<&dyn pr::PersistentReservation> {
148 Some(self)
149 }
150
151 async fn read_vectored(
152 &self,
153 buffers: &RequestBuffers<'_>,
154 sector: u64,
155 ) -> Result<(), DiskError> {
156 self.inner.read_vectored(buffers, sector).await
157 }
158
159 async fn write_vectored(
160 &self,
161 buffers: &RequestBuffers<'_>,
162 sector: u64,
163 fua: bool,
164 ) -> Result<(), DiskError> {
165 self.inner.write_vectored(buffers, sector, fua).await
166 }
167
168 fn sync_cache(&self) -> impl Future<Output = Result<(), DiskError>> + Send {
169 self.inner.sync_cache()
170 }
171}
172
173#[async_trait]
174impl pr::PersistentReservation for DiskWithReservations {
175 fn capabilities(&self) -> pr::ReservationCapabilities {
176 pr::ReservationCapabilities {
177 write_exclusive: true,
178 exclusive_access: true,
179 write_exclusive_registrants_only: true,
180 exclusive_access_registrants_only: true,
181 write_exclusive_all_registrants: false,
182 exclusive_access_all_registrants: false,
183 persist_through_power_loss: true,
184 }
185 }
186
187 async fn report(&self) -> Result<pr::ReservationReport, DiskError> {
188 tracing::info!("reading full status");
189 let state = self.state.lock();
190 let report = pr::ReservationReport {
191 generation: state.generation.0,
192 reservation_type: state.reservation_type,
193 persist_through_power_loss: state.persist_through_power_loss,
194 controllers: state
195 .registered_key
196 .iter()
197 .map(|&key| pr::RegisteredController {
198 key: key.get(),
199 host_id: vec![0; 8],
200 controller_id: 0,
201 holds_reservation: state.reservation_type.is_some(),
202 })
203 .collect(),
204 };
205 Ok(report)
206 }
207
208 async fn register(
209 &self,
210 current_key: Option<u64>,
211 new_key: u64,
212 ptpl: Option<bool>,
213 ) -> Result<(), DiskError> {
214 let mut state = self.state.lock();
215 if let Some(current_key) = current_key {
216 if state.registered_key != NonZeroU64::new(current_key) {
217 return Err(DiskError::ReservationConflict);
218 }
219 }
220 let new_key = NonZeroU64::new(new_key);
221 state.registered_key = new_key;
222 if new_key.is_none() {
223 state.reservation_type = None;
224 }
225 if let Some(ptpl) = ptpl {
226 state.persist_through_power_loss = ptpl;
227 }
228 state.generation += 1;
229 Ok(())
230 }
231
232 async fn reserve(&self, key: u64, reservation_type: ReservationType) -> Result<(), DiskError> {
233 let mut state = self.state.lock();
234 if state.registered_key.is_none()
235 || state.registered_key != NonZeroU64::new(key)
236 || (state.reservation_type.is_some()
237 && state.reservation_type != Some(reservation_type))
238 {
239 return Err(DiskError::ReservationConflict);
240 }
241 state.reservation_type = Some(reservation_type);
242 Ok(())
243 }
244
245 async fn release(&self, key: u64, reservation_type: ReservationType) -> Result<(), DiskError> {
246 let mut state = self.state.lock();
247 if state.registered_key.is_none() || state.registered_key != NonZeroU64::new(key) {
248 return Err(DiskError::ReservationConflict);
249 }
250
251 if state.reservation_type.is_some() {
252 if state.reservation_type != Some(reservation_type) {
253 return Err(DiskError::InvalidInput);
254 }
255 state.reservation_type = None;
256 }
257 Ok(())
258 }
259
260 async fn clear(&self, key: u64) -> Result<(), DiskError> {
261 let mut state = self.state.lock();
262 if state.registered_key.is_none() || state.registered_key != NonZeroU64::new(key) {
263 return Err(DiskError::ReservationConflict);
264 }
265 state.registered_key = None;
266 state.reservation_type = None;
267 state.generation += 1;
268 Ok(())
269 }
270
271 async fn preempt(
272 &self,
273 current_key: u64,
274 preempt_key: u64,
275 reservation_type: ReservationType,
276 _abort: bool,
277 ) -> Result<(), DiskError> {
278 let mut state = self.state.lock();
279 if state.registered_key.is_none() || state.registered_key != NonZeroU64::new(current_key) {
280 return Err(DiskError::ReservationConflict);
281 }
282 if state.registered_key != NonZeroU64::new(preempt_key)
283 || (state.reservation_type.is_some()
284 && state.reservation_type != Some(reservation_type))
285 {
286 return Err(DiskError::InvalidInput);
287 }
288
289 state.reservation_type = None;
290 state.generation += 1;
291 Ok(())
292 }
293}