1#![forbid(unsafe_code)]
17
18pub mod backing;
19mod protocol;
20pub mod resolver;
21pub mod single_file_backing;
22
23use async_trait::async_trait;
24use inspect::Inspect;
25use inspect::InspectMut;
26use std::io::IoSlice;
27use task_control::Cancelled;
28use task_control::StopTask;
29use thiserror::Error;
30use vmbus_async::async_dgram::AsyncRecvExt;
31use vmbus_channel::bus::OfferParams;
32use vmbus_channel::gpadl_ring::GpadlRingMem;
33use vmbus_channel::simple::SimpleVmbusDevice;
34use vmcore::save_restore::SavedStateNotSupported;
35use zerocopy::FromBytes;
36use zerocopy::IntoBytes;
37
38#[derive(InspectMut)]
40pub struct VmbfsDevice {
41 #[inspect(mut)]
42 backing: Box<dyn backing::VmbfsIo>,
43}
44
45impl VmbfsDevice {
46 pub fn new(backing: Box<dyn backing::VmbfsIo>) -> Self {
48 Self { backing }
49 }
50}
51
52#[async_trait]
53impl SimpleVmbusDevice for VmbfsDevice {
54 type SavedState = SavedStateNotSupported;
55 type Runner = VmbfsChannel;
56
57 fn offer(&self) -> OfferParams {
58 OfferParams {
59 interface_name: "vmbfs".to_owned(),
60 channel_type: vmbus_channel::bus::ChannelType::Device { pipe_packets: true },
61 instance_id: protocol::IMC_INSTANCE,
64 interface_id: protocol::INTERFACE_TYPE,
65 ..OfferParams::default()
66 }
67 }
68
69 fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>) {
70 req.respond().merge(self).merge(runner);
71 }
72
73 fn open(
74 &mut self,
75 channel: vmbus_channel::RawAsyncChannel<GpadlRingMem>,
76 _guest_memory: guestmem::GuestMemory,
77 ) -> Result<Self::Runner, vmbus_channel::channel::ChannelOpenError> {
78 Ok(VmbfsChannel {
79 state: State::VersionRequest,
80 pipe: vmbus_async::pipe::MessagePipe::new(channel)?,
81 buf: vec![0; protocol::MAX_MESSAGE_SIZE],
82 })
83 }
84
85 async fn run(
86 &mut self,
87 stop: &mut StopTask<'_>,
88 runner: &mut Self::Runner,
89 ) -> Result<(), Cancelled> {
90 stop.until_stopped(runner.process(self)).await
91 }
92
93 fn supports_save_restore(
94 &mut self,
95 ) -> Option<
96 &mut dyn vmbus_channel::simple::SaveRestoreSimpleVmbusDevice<
97 SavedState = Self::SavedState,
98 Runner = Self::Runner,
99 >,
100 > {
101 None
102 }
103}
104
105#[doc(hidden)] #[derive(InspectMut)]
107pub struct VmbfsChannel {
108 state: State,
109 #[inspect(mut)]
110 pipe: vmbus_async::pipe::MessagePipe<GpadlRingMem>,
111 buf: Vec<u8>,
112}
113
114#[derive(Inspect)]
115enum State {
116 VersionRequest,
117 Ready,
118}
119
120#[derive(Debug)]
121enum Request {
122 Version(protocol::Version),
123 GetFileInfo(String),
124 ReadFile {
125 byte_count: u32,
126 offset: u64,
127 path: String,
128 },
129}
130
131impl VmbfsChannel {
132 async fn process(&mut self, dev: &mut VmbfsDevice) {
133 match self.process_inner(dev).await {
134 Ok(()) => {}
135 Err(err) => {
136 tracing::error!(error = &err as &dyn std::error::Error, "vmbfs failed");
137 }
138 }
139 }
140
141 async fn process_inner(&mut self, dev: &mut VmbfsDevice) -> Result<(), DeviceError> {
142 loop {
143 self.pipe
144 .wait_write_ready(protocol::MAX_MESSAGE_SIZE)
145 .await
146 .map_err(DeviceError::Pipe)?;
147
148 match self.state {
149 State::VersionRequest => match self.read_message().await? {
150 Request::Version(version) => {
151 let ok = match version {
152 protocol::Version::WIN10 => true,
153 version => {
154 tracing::debug!(?version, "unsupported version");
155 false
156 }
157 };
158 self.pipe
159 .try_send_vectored(&[
160 IoSlice::new(
161 protocol::MessageHeader {
162 message_type: protocol::MessageType::VERSION_RESPONSE,
163 reserved: 0,
164 }
165 .as_bytes(),
166 ),
167 IoSlice::new(
168 protocol::VersionResponse {
169 status: if ok {
170 protocol::VersionStatus::SUPPORTED
171 } else {
172 protocol::VersionStatus::UNSUPPORTED
173 },
174 }
175 .as_bytes(),
176 ),
177 ])
178 .map_err(DeviceError::Pipe)?;
179
180 if ok {
181 self.state = State::Ready;
182 }
183 }
184 _ => return Err(DeviceError::UnexpectedMessage),
185 },
186 State::Ready => match self.read_message().await? {
187 Request::GetFileInfo(path) => {
188 self.handle_get_file_info(dev, &path)?;
189 }
190 Request::ReadFile {
191 byte_count,
192 offset,
193 path,
194 } => self.handle_read_file(dev, &path, offset, byte_count)?,
195 _ => return Err(DeviceError::UnexpectedMessage),
196 },
197 }
198 }
199 }
200
201 fn handle_get_file_info(
202 &mut self,
203 dev: &mut VmbfsDevice,
204 path: &str,
205 ) -> Result<(), DeviceError> {
206 let response = match dev.backing.file_info(path) {
207 Ok(info) => protocol::GetFileInfoResponse {
208 status: protocol::Status::SUCCESS,
209 flags: protocol::FileInfoFlags::new().with_directory(info.directory),
212 file_size: info.file_size,
213 },
214 Err(err) => protocol::GetFileInfoResponse {
215 status: err.to_protocol(),
216 flags: protocol::FileInfoFlags::new(),
217 file_size: 0,
218 },
219 };
220 self.pipe
221 .try_send_vectored(&[
222 IoSlice::new(
223 protocol::MessageHeader {
224 message_type: protocol::MessageType::GET_FILE_INFO_RESPONSE,
225 reserved: 0,
226 }
227 .as_bytes(),
228 ),
229 IoSlice::new(response.as_bytes()),
230 ])
231 .map_err(DeviceError::Pipe)?;
232 Ok(())
233 }
234
235 fn handle_read_file(
236 &mut self,
237 dev: &mut VmbfsDevice,
238 path: &str,
239 offset: u64,
240 byte_count: u32,
241 ) -> Result<(), DeviceError> {
242 if byte_count > protocol::MAX_READ_SIZE as u32 {
243 return Err(DeviceError::ReadTooLarge);
244 }
245 let buf = &mut self.buf[..byte_count as usize];
246 let (status, buf) = match dev.backing.read_file(path, offset, buf) {
247 Ok(()) => (protocol::Status::SUCCESS, &*buf),
248 Err(err) => (err.to_protocol(), &[] as _),
249 };
250 self.pipe
251 .try_send_vectored(&[
252 IoSlice::new(
253 protocol::MessageHeader {
254 message_type: protocol::MessageType::READ_FILE_RESPONSE,
255 reserved: 0,
256 }
257 .as_bytes(),
258 ),
259 IoSlice::new(protocol::ReadFileResponse { status }.as_bytes()),
260 IoSlice::new(buf),
261 ])
262 .map_err(DeviceError::Pipe)?;
263 Ok(())
264 }
265
266 async fn read_message(&mut self) -> Result<Request, DeviceError> {
267 let n = self
268 .pipe
269 .recv(&mut self.buf)
270 .await
271 .map_err(DeviceError::Pipe)?;
272
273 let buf = &self.buf[..n];
274 let (header, buf) =
275 protocol::MessageHeader::read_from_prefix(buf).map_err(|_| DeviceError::TooShort)?; let request = match header.message_type {
278 protocol::MessageType::VERSION_REQUEST => {
279 let version = protocol::VersionRequest::read_from_prefix(buf)
280 .map_err(|_| DeviceError::TooShort)?
281 .0; Request::Version(version.requested_version)
283 }
284 protocol::MessageType::GET_FILE_INFO_REQUEST => Request::GetFileInfo(parse_path(buf)?),
285 protocol::MessageType::READ_FILE_REQUEST => {
286 let (read, buf) = protocol::ReadFileRequest::read_from_prefix(buf)
287 .map_err(|_| DeviceError::TooShort)?; Request::ReadFile {
289 byte_count: read.byte_count,
290 offset: read.offset.get(),
291 path: parse_path(buf)?,
292 }
293 }
294 ty => return Err(DeviceError::InvalidMessageType(ty)),
295 };
296
297 tracing::trace!(?request, "message");
298 Ok(request)
299 }
300}
301
302fn parse_path(buf: &[u8]) -> Result<String, DeviceError> {
303 let buf = <[u16]>::ref_from_bytes(buf).map_err(|_| DeviceError::Unaligned)?; if buf.contains(&0) {
305 return Err(DeviceError::NullTerminatorInPath);
306 }
307 let path = String::from_utf16(buf).map_err(|_| DeviceError::InvalidUtf16Path)?;
308 Ok(path.replace('\\', "/"))
309}
310
311#[derive(Debug, Error)]
312enum DeviceError {
313 #[error("vmbus pipe error")]
314 Pipe(#[source] std::io::Error),
315 #[error("message too short")]
316 TooShort,
317 #[error("unaligned message")]
318 Unaligned,
319 #[error("invalid utf-16 path")]
320 InvalidUtf16Path,
321 #[error("null terminator in path")]
322 NullTerminatorInPath,
323 #[error("unexpected message")]
324 UnexpectedMessage,
325 #[error("invalid message type: {0:#x?}")]
326 InvalidMessageType(protocol::MessageType),
327 #[error("read too large")]
328 ReadTooLarge,
329}