1use crate::Options;
7use crate::load;
8use anyhow::Context as _;
9use futures::StreamExt as _;
10use guestmem::GuestMemory;
11use hvdef::HvError;
12use hvdef::Vtl;
13use pal_async::DefaultDriver;
14use std::sync::Arc;
15use virt::PartitionCapabilities;
16use virt::Processor;
17use virt::StopVpSource;
18use virt::VpIndex;
19use virt::io::CpuIo;
20use virt::vp::AccessVpState as _;
21use vm_topology::memory::MemoryLayout;
22use vm_topology::processor::ProcessorTopology;
23use vm_topology::processor::TopologyBuilder;
24use vmcore::vmtime::VmTime;
25use vmcore::vmtime::VmTimeKeeper;
26use vmcore::vmtime::VmTimeSource;
27use zerocopy::TryFromBytes as _;
28
/// Guest physical address of the TMK command port.
///
/// An MMIO write to this address carries the guest physical address of a
/// `tmk_protocol::Command`, which the VMM then reads and executes.
pub const COMMAND_ADDRESS: u64 = 0xFFFF_0000;
30
31pub struct CommonState {
32    pub driver: DefaultDriver,
33    pub opts: Options,
34    pub processor_topology: ProcessorTopology,
35    pub memory_layout: MemoryLayout,
36}
37
38pub struct RunContext<'a> {
39    pub state: &'a CommonState,
40    pub vmtime_source: &'a VmTimeSource,
41}
42
43#[derive(Debug, Clone)]
44pub enum TestResult {
45    Passed,
46    Failed,
47    Faulted {
48        vp_index: VpIndex,
49        reason: String,
50        regs: Option<Box<virt::vp::Registers>>,
51    },
52}
53
54impl CommonState {
55    pub async fn new(driver: DefaultDriver, opts: Options) -> anyhow::Result<Self> {
56        #[cfg(guest_arch = "x86_64")]
57        let processor_topology = TopologyBuilder::new_x86()
58            .x2apic(vm_topology::processor::x86::X2ApicState::Supported)
59            .build(1)
60            .context("failed to build processor topology")?;
61
62        #[cfg(guest_arch = "aarch64")]
63        let processor_topology = TopologyBuilder::new_aarch64(
64            vm_topology::processor::arch::GicInfo {
65                gic_distributor_base: 0xff000000,
66                gic_redistributors_base: 0xff020000,
67            },
68            0,
69        )
70        .build(1)
71        .context("failed to build processor topology")?;
72
73        let ram_size = 0x400000;
74        let memory_layout = MemoryLayout::new(ram_size, &[], None).context("bad memory layout")?;
75
76        Ok(Self {
77            driver,
78            opts,
79            processor_topology,
80            memory_layout,
81        })
82    }
83
84    pub async fn for_each_test(
85        &mut self,
86        mut f: impl AsyncFnMut(&mut RunContext<'_>, &load::TestInfo) -> anyhow::Result<TestResult>,
87    ) -> anyhow::Result<()> {
88        let tmk = fs_err::File::open(&self.opts.tmk).context("failed to open tmk")?;
89        let available_tests = load::enumerate_tests(&tmk)?;
90        let tests = if self.opts.tests.is_empty() {
91            available_tests
92        } else {
93            self.opts
94                .tests
95                .iter()
96                .map(|name| {
97                    available_tests
98                        .iter()
99                        .find(|test| test.name == *name)
100                        .cloned()
101                        .with_context(|| format!("test {} not found", name))
102                })
103                .collect::<anyhow::Result<Vec<_>>>()?
104        };
105        let mut success = true;
106        for test in &tests {
107            tracing::info!(target: "test", name = test.name, "test started");
108
109            let mut vmtime_keeper = VmTimeKeeper::new(&self.driver, VmTime::from_100ns(0));
110            let vmtime_source = vmtime_keeper.builder().build(&self.driver).await.unwrap();
111            let mut ctx = RunContext {
112                state: self,
113                vmtime_source: &vmtime_source,
114            };
115
116            vmtime_keeper.start().await;
117
118            let r = f(&mut ctx, test)
119                .await
120                .with_context(|| format!("failed to run test {}", test.name))?;
121
122            vmtime_keeper.stop().await;
123
124            match r {
125                TestResult::Passed => {
126                    tracing::info!(target: "test", name = test.name, "test passed");
127                }
128                TestResult::Failed => {
129                    tracing::error!(target: "test", name = test.name, reason = "explicit failure", "test failed");
130                    success = false;
131                }
132                TestResult::Faulted {
133                    vp_index,
134                    reason,
135                    regs,
136                } => {
137                    tracing::error!(
138                        target: "test",
139                        name = test.name,
140                        vp_index = vp_index.index(),
141                        reason,
142                        regs = format_args!("{:#x?}", regs),
143                        "test failed"
144                    );
145                    success = false;
146                }
147            }
148        }
149        if !success {
150            anyhow::bail!("some tests failed");
151        }
152        Ok(())
153    }
154}
155
156impl RunContext<'_> {
157    pub async fn run(
158        &mut self,
159        guest_memory: &GuestMemory,
160        caps: &PartitionCapabilities,
161        test: &load::TestInfo,
162        start_vp: impl AsyncFnOnce(&mut Self, RunnerBuilder) -> anyhow::Result<()>,
163    ) -> anyhow::Result<TestResult> {
164        let (event_send, mut event_recv) = mesh::channel();
165
166        let tmk = fs_err::File::open(&self.state.opts.tmk).context("failed to open tmk")?;
168        let regs = {
169            #[cfg(guest_arch = "x86_64")]
170            {
171                load::load_x86(
172                    &self.state.memory_layout,
173                    guest_memory,
174                    &self.state.processor_topology,
175                    caps,
176                    &tmk,
177                    test,
178                )?
179            }
180            #[cfg(guest_arch = "aarch64")]
181            {
182                load::load_aarch64(
183                    &self.state.memory_layout,
184                    guest_memory,
185                    &self.state.processor_topology,
186                    caps,
187                    &tmk,
188                    test,
189                )?
190            }
191        };
192
193        start_vp(
194            self,
195            RunnerBuilder::new(
196                VpIndex::BSP,
197                Arc::clone(®s),
198                guest_memory.clone(),
199                event_send.clone(),
200            ),
201        )
202        .await?;
203
204        let event = event_recv.next().await.unwrap();
205        let r = match event {
206            VpEvent::TestComplete { success } => {
207                if success {
208                    TestResult::Passed
209                } else {
210                    TestResult::Failed
211                }
212            }
213            VpEvent::Halt {
214                vp_index,
215                reason,
216                regs,
217            } => TestResult::Faulted {
218                vp_index,
219                reason,
220                regs,
221            },
222        };
223
224        Ok(r)
225    }
226}
227
228enum VpEvent {
229    TestComplete {
230        success: bool,
231    },
232    Halt {
233        vp_index: VpIndex,
234        reason: String,
235        regs: Option<Box<virt::vp::Registers>>,
236    },
237}
238
239struct IoHandler<'a> {
240    guest_memory: &'a GuestMemory,
241    event_send: &'a mesh::Sender<VpEvent>,
242    stop: &'a StopVpSource,
243}
244
/// Zero-extends the bytes of `d`, taken in native byte order, into a `u64`.
///
/// Only the first 8 bytes are used. Wider accesses (e.g. a 16-byte MMIO
/// write) are truncated instead of panicking on an out-of-range slice.
fn widen(d: &[u8]) -> u64 {
    let mut v = [0; 8];
    let n = d.len().min(v.len());
    v[..n].copy_from_slice(&d[..n]);
    u64::from_ne_bytes(v)
}
250
251impl CpuIo for IoHandler<'_> {
252    fn is_mmio(&self, _address: u64) -> bool {
253        false
254    }
255
256    fn acknowledge_pic_interrupt(&self) -> Option<u8> {
257        None
258    }
259
260    fn handle_eoi(&self, irq: u32) {
261        tracing::info!(irq, "eoi");
262    }
263
264    fn signal_synic_event(&self, vtl: Vtl, connection_id: u32, flag: u16) -> hvdef::HvResult<()> {
265        let _ = (vtl, connection_id, flag);
266        Err(HvError::InvalidConnectionId)
267    }
268
269    fn post_synic_message(
270        &self,
271        vtl: Vtl,
272        connection_id: u32,
273        secure: bool,
274        message: &[u8],
275    ) -> hvdef::HvResult<()> {
276        let _ = (vtl, connection_id, secure, message);
277        Err(HvError::InvalidConnectionId)
278    }
279
280    async fn read_mmio(&self, vp: VpIndex, address: u64, data: &mut [u8]) {
281        tracing::info!(vp = vp.index(), address, "read mmio");
282        data.fill(!0);
283    }
284
285    async fn write_mmio(&self, vp: VpIndex, address: u64, data: &[u8]) {
286        if address == COMMAND_ADDRESS {
287            let p = widen(data);
288            let r = self.handle_command(p);
289            if let Err(e) = r {
290                tracing::error!(
291                    error = e.as_ref() as &dyn std::error::Error,
292                    p,
293                    "failed to handle command"
294                );
295            }
296        } else {
297            tracing::info!(vp = vp.index(), address, data = widen(data), "write mmio");
298        }
299    }
300
301    async fn read_io(&self, vp: VpIndex, port: u16, data: &mut [u8]) {
302        tracing::info!(vp = vp.index(), port, "read io");
303        data.fill(!0);
304    }
305
306    async fn write_io(&self, vp: VpIndex, port: u16, data: &[u8]) {
307        tracing::info!(vp = vp.index(), port, data = widen(data), "write io");
308    }
309
310    #[track_caller]
311    fn fatal_error(&self, error: Box<dyn std::error::Error + Send + Sync>) -> virt::VpHaltReason {
312        tracing::error!(
313            err = error.as_ref() as &dyn std::error::Error,
314            "fatal error"
315        );
316        virt::VpHaltReason::TripleFault { vtl: Vtl::Vtl0 }
317    }
318}
319
320impl IoHandler<'_> {
321    fn read_str(&self, s: tmk_protocol::StrDescriptor) -> anyhow::Result<String> {
322        let mut buf = vec![0; s.len as usize];
323        self.guest_memory
324            .read_at(s.gpa, &mut buf)
325            .context("failed to read string")?;
326        String::from_utf8(buf).context("string not utf-8")
327    }
328
329    fn handle_command(&self, gpa: u64) -> anyhow::Result<()> {
330        let buf = self
331            .guest_memory
332            .read_plain::<[u8; size_of::<tmk_protocol::Command>()]>(gpa)
333            .context("failed to read command")?;
334        let cmd = tmk_protocol::Command::try_read_from_bytes(&buf)
335            .ok()
336            .context("bad command")?;
337        match cmd {
338            tmk_protocol::Command::Log(s) => {
339                let message = self.read_str(s)?;
340                tracing::info!(target: "tmk", message);
341            }
342            tmk_protocol::Command::Panic {
343                message,
344                filename,
345                line,
346            } => {
347                let message = self.read_str(message)?;
348                let location = if filename.len > 0 {
349                    Some(format!("{}:{}", self.read_str(filename)?, line))
350                } else {
351                    None
352                };
353                tracing::error!(target: "tmk", location, panic = message);
354                self.event_send
355                    .send(VpEvent::TestComplete { success: false });
356                self.stop.stop();
357            }
358            tmk_protocol::Command::Complete { success } => {
359                self.event_send.send(VpEvent::TestComplete { success });
360                self.stop.stop();
361            }
362        }
363        Ok(())
364    }
365}
366
367pub struct RunnerBuilder {
368    vp_index: VpIndex,
369    regs: Arc<virt::InitialRegs>,
370    guest_memory: GuestMemory,
371    event_send: mesh::Sender<VpEvent>,
372}
373
374impl RunnerBuilder {
375    fn new(
376        vp_index: VpIndex,
377        regs: Arc<virt::InitialRegs>,
378        guest_memory: GuestMemory,
379        event_send: mesh::Sender<VpEvent>,
380    ) -> Self {
381        Self {
382            vp_index,
383            regs,
384            guest_memory,
385            event_send,
386        }
387    }
388
389    pub fn build<P: Processor>(&mut self, mut vp: P) -> anyhow::Result<Runner<'_, P>> {
390        {
391            let mut state = vp.access_state(Vtl::Vtl0);
392            #[cfg(guest_arch = "x86_64")]
393            {
394                let virt::x86::X86InitialRegs {
395                    registers,
396                    mtrrs,
397                    pat,
398                } = self.regs.as_ref();
399                state.set_registers(registers)?;
400                state.set_mtrrs(mtrrs)?;
401                state.set_pat(pat)?;
402            }
403            #[cfg(guest_arch = "aarch64")]
404            {
405                let virt::aarch64::Aarch64InitialRegs {
406                    registers,
407                    system_registers,
408                } = self.regs.as_ref();
409                state.set_registers(registers)?;
410                state.set_system_registers(system_registers)?;
411            }
412            state.commit()?;
413        }
414        Ok(Runner {
415            vp,
416            vp_index: self.vp_index,
417            guest_memory: &self.guest_memory,
418            event_send: &self.event_send,
419        })
420    }
421}
422
423pub struct Runner<'a, P> {
424    vp: P,
425    vp_index: VpIndex,
426    guest_memory: &'a GuestMemory,
427    event_send: &'a mesh::Sender<VpEvent>,
428}
429
430impl<P: Processor> Runner<'_, P> {
431    pub async fn run_vp(&mut self) {
432        let stop = StopVpSource::new();
433        let Err(err) = self
434            .vp
435            .run_vp(
436                stop.checker(),
437                &IoHandler {
438                    guest_memory: self.guest_memory,
439                    event_send: self.event_send,
440                    stop: &stop,
441                },
442            )
443            .await;
444        let regs = self
445            .vp
446            .access_state(Vtl::Vtl0)
447            .registers()
448            .map(Box::new)
449            .ok();
450        self.event_send.send(VpEvent::Halt {
451            vp_index: self.vp_index,
452            reason: format!("{:?}", err),
453            regs,
454        });
455    }
456}