vmm_core/partition_unit/
vp_set.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Virtual processor state management.
5
6use super::HaltReason;
7use super::HaltReasonReceiver;
8use super::InternalHaltReason;
9#[cfg(feature = "gdb")]
10use anyhow::Context as _;
11use async_trait::async_trait;
12use futures::FutureExt;
13use futures::StreamExt;
14use futures::future::JoinAll;
15use futures::future::TryJoinAll;
16use futures::stream::select_with_strategy;
17use futures_concurrency::future::Race;
18use futures_concurrency::stream::Merge;
19use guestmem::GuestMemory;
20use hvdef::Vtl;
21use inspect::Inspect;
22use mesh::rpc::Rpc;
23use mesh::rpc::RpcError;
24use mesh::rpc::RpcSend;
25use parking_lot::Mutex;
26use slab::Slab;
27use std::future::Future;
28use std::pin::Pin;
29use std::pin::pin;
30use std::sync::Arc;
31use std::task::Context;
32use std::task::Poll;
33use std::task::Waker;
34use thiserror::Error;
35use tracing::instrument;
36use virt::InitialRegs;
37use virt::Processor;
38use virt::StopVp;
39use virt::StopVpSource;
40use virt::VpHaltReason;
41use virt::VpIndex;
42use virt::VpStopped;
43use virt::io::CpuIo;
44use virt::vp::AccessVpState;
45use vm_topology::processor::TargetVpInfo;
46use vmcore::save_restore::ProtobufSaveRestore;
47use vmcore::save_restore::RestoreError;
48use vmcore::save_restore::SaveError;
49use vmcore::save_restore::SavedStateBlob;
50#[cfg(feature = "gdb")]
51use vmm_core_defs::debug_rpc::DebuggerVpState;
52
53const NUM_VTLS: usize = 3;
54
55/// Trait for controlling a VP on a bound partition.
56#[async_trait(?Send)]
57trait ControlVp: ProtobufSaveRestore {
58    /// Run the VP until `stop` says to stop.
59    async fn run_vp(
60        &mut self,
61        vtl_guest_memory: &[Option<GuestMemory>; NUM_VTLS],
62        stop: StopVp<'_>,
63    ) -> Result<StopReason, HaltReason>;
64
65    /// Inspect the VP.
66    fn inspect_vp(&mut self, gm: &[Option<GuestMemory>; NUM_VTLS], req: inspect::Request<'_>);
67
68    /// Sets the register state at first boot.
69    fn set_initial_regs(
70        &mut self,
71        vtl: Vtl,
72        state: &InitialRegs,
73        to_set: RegistersToSet,
74    ) -> Result<(), RegisterSetError>;
75
76    #[cfg(feature = "gdb")]
77    fn debug(&mut self) -> &mut dyn DebugVp;
78}
79
80enum StopReason {
81    OnRequest(VpStopped),
82    Cancel,
83}
84
/// Selects which part of the initial register state is applied to a VP.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RegistersToSet {
    /// Apply the complete initial register state.
    All,
    /// Skip the main register state and apply only the memory-type
    /// registers (MTRRs and, on x86_64, PAT).
    MtrrsOnly,
}
90
/// Debugger operations that can be performed on a bound VP.
#[cfg(feature = "gdb")]
trait DebugVp {
    /// Installs the debug state for `vtl`, or clears it when `state` is
    /// `None`.
    fn set_debug_state(
        &mut self,
        vtl: Vtl,
        state: Option<&virt::x86::DebugState>,
    ) -> anyhow::Result<()>;

    /// Writes debugger-supplied register state into `vtl`.
    fn set_vp_state(&mut self, vtl: Vtl, state: &DebuggerVpState) -> anyhow::Result<()>;

    /// Reads the register state of `vtl` in the debugger's format.
    fn get_vp_state(&mut self, vtl: Vtl) -> anyhow::Result<Box<DebuggerVpState>>;
}
103
104struct BoundVp<'a, T, U> {
105    vp: &'a mut T,
106    io: &'a U,
107    vp_index: VpIndex,
108}
109
110impl<T: ProtobufSaveRestore, U> ProtobufSaveRestore for BoundVp<'_, T, U> {
111    fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
112        self.vp.save()
113    }
114
115    fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
116        self.vp.restore(state)
117    }
118}
119
120#[async_trait(?Send)]
121impl<T, U> ControlVp for BoundVp<'_, T, U>
122where
123    T: Processor + ProtobufSaveRestore,
124    U: CpuIo,
125{
126    async fn run_vp(
127        &mut self,
128        vtl_guest_memory: &[Option<GuestMemory>; NUM_VTLS],
129        stop: StopVp<'_>,
130    ) -> Result<StopReason, HaltReason> {
131        let r = self.vp.run_vp(stop, self.io).await;
132        // Convert the inner error type to a generic one.
133        match r.unwrap_err() {
134            VpHaltReason::Stop(stop) => Ok(StopReason::OnRequest(stop)),
135            VpHaltReason::Cancel => Ok(StopReason::Cancel),
136            VpHaltReason::PowerOff => Err(HaltReason::PowerOff),
137            VpHaltReason::Reset => Err(HaltReason::Reset),
138            VpHaltReason::TripleFault { vtl } => {
139                let registers = self.vp.access_state(vtl).registers().ok().map(Arc::new);
140
141                tracing::error!(?vtl, vp = self.vp_index.index(), "triple fault");
142                self.trace_fault(
143                    vtl,
144                    vtl_guest_memory[vtl as usize].as_ref(),
145                    registers.as_deref(),
146                );
147                Err(HaltReason::TripleFault {
148                    vp: self.vp_index.index(),
149                    registers,
150                })
151            }
152            VpHaltReason::InvalidVmState(err) => {
153                tracing::error!(err = &err as &dyn std::error::Error, "invalid vm state");
154                Err(HaltReason::InvalidVmState {
155                    vp: self.vp_index.index(),
156                })
157            }
158            VpHaltReason::EmulationFailure(error) => {
159                tracing::error!(error, "emulation failure");
160                Err(HaltReason::VpError {
161                    vp: self.vp_index.index(),
162                })
163            }
164            VpHaltReason::Hypervisor(err) => {
165                tracing::error!(err = &err as &dyn std::error::Error, "fatal vp error");
166                Err(HaltReason::VpError {
167                    vp: self.vp_index.index(),
168                })
169            }
170            VpHaltReason::SingleStep => {
171                tracing::debug!("single step");
172                Err(HaltReason::SingleStep {
173                    vp: self.vp_index.index(),
174                })
175            }
176            VpHaltReason::HwBreak(breakpoint) => {
177                tracing::debug!(?breakpoint, "hardware breakpoint");
178                Err(HaltReason::HwBreakpoint {
179                    vp: self.vp_index.index(),
180                    breakpoint,
181                })
182            }
183        }
184    }
185
186    fn inspect_vp(
187        &mut self,
188        vtl_guest_memory: &[Option<GuestMemory>; NUM_VTLS],
189        req: inspect::Request<'_>,
190    ) {
191        let mut resp = req.respond();
192        resp.merge(&mut *self.vp);
193        for (name, vtl) in [
194            ("vtl0", Vtl::Vtl0),
195            ("vtl1", Vtl::Vtl1),
196            ("vtl2", Vtl::Vtl2),
197        ] {
198            if self.vp.vtl_inspectable(vtl) {
199                resp.field_mut(
200                    name,
201                    &mut inspect::adhoc_mut(|req| {
202                        self.inspect_vtl(vtl_guest_memory[vtl as usize].as_ref(), req, vtl)
203                    }),
204                );
205            }
206        }
207    }
208
209    fn set_initial_regs(
210        &mut self,
211        vtl: Vtl,
212        state: &InitialRegs,
213        to_set: RegistersToSet,
214    ) -> Result<(), RegisterSetError> {
215        let InitialRegs {
216            registers,
217            #[cfg(guest_arch = "x86_64")]
218            mtrrs,
219            #[cfg(guest_arch = "x86_64")]
220            pat,
221            #[cfg(guest_arch = "aarch64")]
222            system_registers,
223        } = state;
224        let mut access = self.vp.access_state(vtl);
225        // Only set the registers on the BSP.
226        if self.vp_index.is_bsp() && to_set == RegistersToSet::All {
227            access
228                .set_registers(registers)
229                .map_err(|err| RegisterSetError("registers", err.into()))?;
230
231            #[cfg(guest_arch = "aarch64")]
232            access
233                .set_system_registers(system_registers)
234                .map_err(|err| RegisterSetError("system_registers", err.into()))?;
235        }
236
237        // Set MTRRs and PAT on all VPs.
238        #[cfg(guest_arch = "x86_64")]
239        access
240            .set_mtrrs(mtrrs)
241            .map_err(|err| RegisterSetError("mtrrs", err.into()))?;
242        #[cfg(guest_arch = "x86_64")]
243        access
244            .set_pat(pat)
245            .map_err(|err| RegisterSetError("pat", err.into()))?;
246
247        Ok(())
248    }
249
250    #[cfg(feature = "gdb")]
251    fn debug(&mut self) -> &mut dyn DebugVp {
252        self
253    }
254}
255
256impl<T, U> BoundVp<'_, T, U>
257where
258    T: Processor + ProtobufSaveRestore,
259    U: CpuIo,
260{
261    fn inspect_vtl(&mut self, gm: Option<&GuestMemory>, req: inspect::Request<'_>, vtl: Vtl) {
262        let mut resp = req.respond();
263        resp.field("enabled", true)
264            .merge(self.vp.access_state(vtl).inspect_all());
265
266        let _ = gm;
267        #[cfg(all(guest_arch = "x86_64", feature = "gdb"))]
268        if let Some(gm) = gm {
269            let registers = self.vp.access_state(vtl).registers();
270            if let Ok(registers) = &registers {
271                resp.field_with("next_instruction", || {
272                    Some(
273                        vp_state::next_instruction(gm, self.debug(), vtl, registers).map_or_else(
274                            |err| format!("{:#}", err),
275                            |(instr, _)| instr.to_string(),
276                        ),
277                    )
278                })
279                .field_with("previous_instruction", || {
280                    Some(
281                        vp_state::previous_instruction(gm, self.debug(), vtl, registers)
282                            .map_or_else(|err| format!("{:#}", err), |instr| instr.to_string()),
283                    )
284                });
285            }
286        }
287    }
288
289    #[cfg(guest_arch = "x86_64")]
290    fn trace_fault(
291        &mut self,
292        vtl: Vtl,
293        guest_memory: Option<&GuestMemory>,
294        registers: Option<&virt::x86::vp::Registers>,
295    ) {
296        use cvm_tracing::CVM_CONFIDENTIAL;
297
298        #[cfg(not(feature = "gdb"))]
299        let _ = (guest_memory, vtl);
300
301        let Some(registers) = registers else {
302            return;
303        };
304
305        let virt::x86::vp::Registers {
306            rax,
307            rcx,
308            rdx,
309            rbx,
310            rsp,
311            rbp,
312            rsi,
313            rdi,
314            r8,
315            r9,
316            r10,
317            r11,
318            r12,
319            r13,
320            r14,
321            r15,
322            rip,
323            rflags,
324            cs,
325            ds,
326            es,
327            fs,
328            gs,
329            ss,
330            tr,
331            ldtr,
332            gdtr,
333            idtr,
334            cr0,
335            cr2,
336            cr3,
337            cr4,
338            cr8,
339            efer,
340        } = *registers;
341        tracing::error!(
342            CVM_CONFIDENTIAL,
343            vp = self.vp_index.index(),
344            ?vtl,
345            rax,
346            rcx,
347            rdx,
348            rbx,
349            rsp,
350            rbp,
351            rsi,
352            rdi,
353            r8,
354            r9,
355            r10,
356            r11,
357            r12,
358            r13,
359            r14,
360            r15,
361            rip,
362            rflags,
363            "triple fault register state",
364        );
365        tracing::error!(
366            CVM_CONFIDENTIAL,
367            ?vtl,
368            vp = self.vp_index.index(),
369            ?cs,
370            ?ds,
371            ?es,
372            ?fs,
373            ?gs,
374            ?ss,
375            ?tr,
376            ?ldtr,
377            ?gdtr,
378            ?idtr,
379            cr0,
380            cr2,
381            cr3,
382            cr4,
383            cr8,
384            efer,
385            "triple fault system register state",
386        );
387
388        #[cfg(feature = "gdb")]
389        if let Some(guest_memory) = guest_memory {
390            if let Ok((instr, bytes)) =
391                vp_state::next_instruction(guest_memory, self, vtl, registers)
392            {
393                tracing::error!(
394                    CVM_CONFIDENTIAL,
395                    instruction = instr.to_string(),
396                    ?bytes,
397                    "faulting instruction"
398                );
399            }
400        }
401    }
402
403    #[cfg(guest_arch = "aarch64")]
404    fn trace_fault(
405        &mut self,
406        _vtl: Vtl,
407        _guest_memory: Option<&GuestMemory>,
408        _registers: Option<&virt::aarch64::vp::Registers>,
409    ) {
410        // TODO
411    }
412}
413
#[cfg(feature = "gdb")]
impl<T, U> DebugVp for BoundVp<'_, T, U>
where
    T: Processor,
{
    fn set_debug_state(
        &mut self,
        vtl: Vtl,
        state: Option<&virt::x86::DebugState>,
    ) -> anyhow::Result<()> {
        let result = self.vp.set_debug_state(vtl, state);
        result.context("failed to set debug state")
    }

    /// Overwrites the debugger-visible x86_64 registers of `vtl`. Registers
    /// the debugger does not model keep their current values.
    #[cfg(guest_arch = "x86_64")]
    fn set_vp_state(&mut self, vtl: Vtl, state: &DebuggerVpState) -> anyhow::Result<()> {
        let mut access = self.vp.access_state(vtl);
        let DebuggerVpState::X86_64(state) = state else {
            anyhow::bail!("wrong architecture")
        };

        // Start from the current registers and overwrite the debugger-visible
        // ones; `gp` is in the same order `get_vp_state` produces.
        let mut regs = access.registers()?;
        [
            regs.rax, regs.rcx, regs.rdx, regs.rbx, regs.rsp, regs.rbp, regs.rsi, regs.rdi,
            regs.r8, regs.r9, regs.r10, regs.r11, regs.r12, regs.r13, regs.r14, regs.r15,
        ] = state.gp;
        regs.rip = state.rip;
        regs.rflags = state.rflags;
        regs.cs = state.cs;
        regs.ds = state.ds;
        regs.es = state.es;
        regs.fs = state.fs;
        regs.gs = state.gs;
        regs.ss = state.ss;
        regs.cr0 = state.cr0;
        regs.cr2 = state.cr2;
        regs.cr3 = state.cr3;
        regs.cr4 = state.cr4;
        regs.cr8 = state.cr8;
        regs.efer = state.efer;

        let mut msrs = access.virtual_msrs()?;
        msrs.kernel_gs_base = state.kernel_gs_base;

        access.set_registers(&regs)?;
        access.set_virtual_msrs(&msrs)?;
        access.commit()?;
        Ok(())
    }

    /// Reads the debugger-visible x86_64 registers of `vtl`.
    #[cfg(guest_arch = "x86_64")]
    fn get_vp_state(&mut self, vtl: Vtl) -> anyhow::Result<Box<DebuggerVpState>> {
        let mut access = self.vp.access_state(vtl);
        let regs = access.registers()?;
        let kernel_gs_base = access.virtual_msrs()?.kernel_gs_base;

        let gp = [
            regs.rax, regs.rcx, regs.rdx, regs.rbx, regs.rsp, regs.rbp, regs.rsi, regs.rdi,
            regs.r8, regs.r9, regs.r10, regs.r11, regs.r12, regs.r13, regs.r14, regs.r15,
        ];
        let snapshot = vmm_core_defs::debug_rpc::X86VpState {
            gp,
            rip: regs.rip,
            rflags: regs.rflags,
            cr0: regs.cr0,
            cr2: regs.cr2,
            cr3: regs.cr3,
            cr4: regs.cr4,
            cr8: regs.cr8,
            efer: regs.efer,
            kernel_gs_base,
            es: regs.es,
            cs: regs.cs,
            ss: regs.ss,
            ds: regs.ds,
            fs: regs.fs,
            gs: regs.gs,
        };
        Ok(Box::new(DebuggerVpState::X86_64(snapshot)))
    }

    /// Replaces the aarch64 general registers of `vtl` with the debugger's
    /// values and updates the debugger-visible system registers.
    #[cfg(guest_arch = "aarch64")]
    fn set_vp_state(&mut self, vtl: Vtl, state: &DebuggerVpState) -> anyhow::Result<()> {
        let DebuggerVpState::Aarch64(state) = state else {
            anyhow::bail!("wrong architecture")
        };
        let mut access = self.vp.access_state(vtl);

        // `x` holds x0..x28 followed by fp (x29) and lr (x30).
        let [
            x0,
            x1,
            x2,
            x3,
            x4,
            x5,
            x6,
            x7,
            x8,
            x9,
            x10,
            x11,
            x12,
            x13,
            x14,
            x15,
            x16,
            x17,
            x18,
            x19,
            x20,
            x21,
            x22,
            x23,
            x24,
            x25,
            x26,
            x27,
            x28,
            fp,
            lr,
        ] = state.x;
        let regs = virt::aarch64::vp::Registers {
            x0,
            x1,
            x2,
            x3,
            x4,
            x5,
            x6,
            x7,
            x8,
            x9,
            x10,
            x11,
            x12,
            x13,
            x14,
            x15,
            x16,
            x17,
            x18,
            x19,
            x20,
            x21,
            x22,
            x23,
            x24,
            x25,
            x26,
            x27,
            x28,
            fp,
            lr,
            sp_el0: state.sp_el0,
            sp_el1: state.sp_el1,
            pc: state.pc,
            cpsr: state.cpsr,
        };

        let mut sregs = access.system_registers()?;
        sregs.sctlr_el1 = state.sctlr_el1;
        sregs.tcr_el1 = state.tcr_el1;
        sregs.ttbr0_el1 = state.ttbr0_el1;
        sregs.ttbr1_el1 = state.ttbr1_el1;

        access.set_registers(&regs)?;
        access.set_system_registers(&sregs)?;
        access.commit()?;
        Ok(())
    }

    /// Reads the debugger-visible aarch64 registers of `vtl`.
    #[cfg(guest_arch = "aarch64")]
    fn get_vp_state(&mut self, vtl: Vtl) -> anyhow::Result<Box<DebuggerVpState>> {
        let mut access = self.vp.access_state(vtl);
        let regs = access.registers()?;
        let sregs = access.system_registers()?;

        let x = [
            regs.x0, regs.x1, regs.x2, regs.x3, regs.x4, regs.x5, regs.x6, regs.x7, regs.x8,
            regs.x9, regs.x10, regs.x11, regs.x12, regs.x13, regs.x14, regs.x15, regs.x16,
            regs.x17, regs.x18, regs.x19, regs.x20, regs.x21, regs.x22, regs.x23, regs.x24,
            regs.x25, regs.x26, regs.x27, regs.x28, regs.fp, regs.lr,
        ];
        let snapshot = vmm_core_defs::debug_rpc::Aarch64VpState {
            x,
            sp_el0: regs.sp_el0,
            sp_el1: regs.sp_el1,
            pc: regs.pc,
            cpsr: regs.cpsr,
            sctlr_el1: sregs.sctlr_el1,
            tcr_el1: sregs.tcr_el1,
            ttbr0_el1: sregs.ttbr0_el1,
            ttbr1_el1: sregs.ttbr1_el1,
        };
        Ok(Box::new(DebuggerVpState::Aarch64(snapshot)))
    }
}
587
588/// Tracks whether the VP should halt due to a guest-initiated condition (triple
589/// fault, etc.).
590#[derive(Inspect)]
591pub struct Halt {
592    #[inspect(flatten)]
593    state: Mutex<HaltState>,
594    #[inspect(skip)]
595    send: mesh::Sender<InternalHaltReason>,
596}
597
598#[derive(Default, Inspect)]
599struct HaltState {
600    halt_count: usize,
601    #[inspect(skip)]
602    wakers: Slab<Option<Waker>>,
603}
604
605impl Halt {
606    /// Returns a new halt object, plus a receiver to asynchronously receive the
607    /// reason for a halt.
608    pub fn new() -> (Self, HaltReasonReceiver) {
609        let (send, recv) = mesh::channel();
610        (
611            Self {
612                state: Default::default(),
613                send,
614            },
615            HaltReasonReceiver(recv),
616        )
617    }
618
619    /// Halts all VPs and sends the halt reason to the receiver returned by
620    /// [`Self::new()`].
621    ///
622    /// After this returns, it's guaranteed that any VPs that try to run again
623    /// will instead halt. So if this is called from a VP thread, it will ensure
624    /// that that VP will not resume.
625    pub fn halt(&self, reason: HaltReason) {
626        self.halt_internal(InternalHaltReason::Halt(reason));
627    }
628
629    /// Halts all VPs temporarily, resets their variable MTRRs to their initial
630    /// state, then resumes the VPs.
631    ///
632    /// This is used by the legacy BIOS, since it stomps over the variable MTRRs
633    /// in undesirable ways and is difficult to fix.
634    pub fn replay_mtrrs(&self) {
635        self.halt_internal(InternalHaltReason::ReplayMtrrs);
636    }
637
638    fn halt_internal(&self, reason: InternalHaltReason) {
639        // Set the VP halt state immediately and wake them up.
640        let mut inner = self.state.lock();
641        inner.halt_count += 1;
642        for waker in inner.wakers.iter_mut().filter_map(|x| x.1.take()) {
643            waker.wake();
644        }
645
646        // Send the halt reason asynchronously.
647        self.send.send(reason);
648    }
649
650    /// Clears a single halt reason. Must be called for each halt reason that
651    /// arrives in order to resume the VM.
652    fn clear_halt(&self) {
653        let mut inner = self.state.lock();
654        inner.halt_count = inner
655            .halt_count
656            .checked_sub(1)
657            .expect("too many halt clears");
658    }
659
660    fn is_halted(&self) -> bool {
661        self.state.lock().halt_count != 0
662    }
663
664    fn halted(&self) -> Halted<'_> {
665        Halted {
666            halt: self,
667            idx: None,
668        }
669    }
670}
671
672struct Halted<'a> {
673    halt: &'a Halt,
674    idx: Option<usize>,
675}
676
677impl Clone for Halted<'_> {
678    fn clone(&self) -> Self {
679        Self {
680            halt: self.halt,
681            idx: None,
682        }
683    }
684}
685
686impl Future for Halted<'_> {
687    type Output = ();
688
689    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
690        let mut halt = self.halt.state.lock();
691        if halt.halt_count != 0 {
692            return Poll::Ready(());
693        }
694
695        if let Some(idx) = self.idx {
696            halt.wakers[idx] = Some(cx.waker().clone());
697        } else {
698            self.idx = Some(halt.wakers.insert(Some(cx.waker().clone())));
699        }
700        Poll::Pending
701    }
702}
703
704impl Drop for Halted<'_> {
705    fn drop(&mut self) {
706        if let Some(idx) = self.idx {
707            self.halt.state.lock().wakers.remove(idx);
708        }
709    }
710}
711
712#[derive(Inspect)]
713struct Inner {
714    #[inspect(flatten)]
715    halt: Arc<Halt>,
716    #[inspect(skip)]
717    vtl_guest_memory: [Option<GuestMemory>; NUM_VTLS],
718}
719
720#[derive(Inspect)]
721pub struct VpSet {
722    #[inspect(flatten)]
723    inner: Arc<Inner>,
724    #[inspect(rename = "vp", iter_by_index)]
725    vps: Vec<Vp>,
726    #[inspect(skip)]
727    started: bool,
728}
729
730struct Vp {
731    send: mesh::Sender<VpEvent>,
732    done: mesh::OneshotReceiver<()>,
733    vp_info: TargetVpInfo,
734}
735
736impl Inspect for Vp {
737    fn inspect(&self, req: inspect::Request<'_>) {
738        req.respond()
739            .merge(&self.vp_info)
740            .merge(inspect::adhoc(|req| {
741                self.send
742                    .send(VpEvent::State(StateEvent::Inspect(req.defer())))
743            }));
744    }
745}
746
747impl VpSet {
748    pub fn new(vtl_guest_memory: [Option<GuestMemory>; NUM_VTLS], halt: Arc<Halt>) -> Self {
749        let inner = Inner {
750            vtl_guest_memory,
751            halt,
752        };
753        Self {
754            inner: Arc::new(inner),
755            vps: Vec::new(),
756            started: false,
757        }
758    }
759
760    /// Adds a VP and returns its runner.
761    pub fn add(&mut self, vp: TargetVpInfo) -> VpRunner {
762        assert!(!self.started);
763        let (send, recv) = mesh::channel();
764        let (done_send, done_recv) = mesh::oneshot();
765        self.vps.push(Vp {
766            send,
767            done: done_recv,
768            vp_info: vp,
769        });
770        let (cancel_send, cancel_recv) = mesh::channel();
771        VpRunner {
772            recv,
773            _done: done_send,
774            cancel_recv,
775            cancel_send,
776            inner: RunnerInner {
777                vp: vp.as_ref().vp_index,
778                inner: self.inner.clone(),
779                state: VpState::Stopped,
780            },
781        }
782    }
783
784    /// Starts all VPs.
785    pub fn start(&mut self) {
786        if !self.started {
787            for vp in &self.vps {
788                vp.send.send(VpEvent::Start);
789            }
790            self.started = true;
791        }
792    }
793
794    /// Initiates a halt to the VPs.
795    #[cfg_attr(not(feature = "gdb"), expect(dead_code))]
796    pub fn halt(&mut self, reason: HaltReason) {
797        self.inner.halt.halt(reason);
798    }
799
800    /// Resets the halt state for all VPs.
801    ///
802    /// The VPs must be stopped.
803    pub fn clear_halt(&mut self) {
804        assert!(!self.started);
805        self.inner.halt.clear_halt();
806    }
807
808    /// Stops all VPs.
809    pub async fn stop(&mut self) {
810        if self.started {
811            self.vps
812                .iter()
813                .map(|vp| {
814                    let (send, recv) = mesh::oneshot();
815                    vp.send.send(VpEvent::Stop(send));
816                    // Ignore VPs whose runners have been dropped.
817                    async { recv.await.ok() }
818                })
819                .collect::<JoinAll<_>>()
820                .await;
821            self.started = false;
822        }
823    }
824
825    pub async fn save(&mut self) -> Result<Vec<(VpIndex, SavedStateBlob)>, SaveError> {
826        assert!(!self.started);
827        self.vps
828            .iter()
829            .enumerate()
830            .map(async |(index, vp)| {
831                let data = vp
832                    .send
833                    .call(|x| VpEvent::State(StateEvent::Save(x)), ())
834                    .await
835                    .map_err(|err| SaveError::Other(RunnerGoneError(err).into()))
836                    .and_then(|x| x)
837                    .map_err(|err| SaveError::ChildError(format!("vp{index}"), Box::new(err)))?;
838                Ok((VpIndex::new(index as u32), data))
839            })
840            .collect::<TryJoinAll<_>>()
841            .await
842    }
843
844    pub async fn restore(
845        &mut self,
846        states: impl IntoIterator<Item = (VpIndex, SavedStateBlob)>,
847    ) -> Result<(), RestoreError> {
848        assert!(!self.started);
849        states
850            .into_iter()
851            .map(|(vp_index, data)| {
852                let vp = self.vps.get(vp_index.index() as usize);
853                async move {
854                    let vp = vp.ok_or_else(|| {
855                        RestoreError::UnknownEntryId(format!("vp{}", vp_index.index()))
856                    })?;
857                    vp.send
858                        .call(|x| VpEvent::State(StateEvent::Restore(x)), data)
859                        .await
860                        .map_err(|err| RestoreError::Other(RunnerGoneError(err).into()))
861                        .and_then(|x| x)
862                        .map_err(|err| {
863                            RestoreError::ChildError(
864                                format!("vp{}", vp_index.index()),
865                                Box::new(err),
866                            )
867                        })
868                }
869            })
870            .collect::<TryJoinAll<_>>()
871            .await?;
872
873        Ok(())
874    }
875
876    /// Tears down the VPs.
877    pub async fn teardown(self) {
878        self.vps
879            .into_iter()
880            .map(|vp| vp.done.map(drop))
881            .collect::<JoinAll<_>>()
882            .await;
883    }
884
885    pub async fn set_initial_regs(
886        &mut self,
887        vtl: Vtl,
888        initial_regs: Arc<InitialRegs>,
889        to_set: RegistersToSet,
890    ) -> Result<(), RegisterSetError> {
891        self.vps
892            .iter()
893            .map(|vp| {
894                let initial_regs = initial_regs.clone();
895                async move {
896                    vp.send
897                        .call(
898                            |x| VpEvent::State(StateEvent::SetInitialRegs(x)),
899                            (vtl, initial_regs, to_set),
900                        )
901                        .await
902                        .map_err(|err| {
903                            RegisterSetError("initial_regs", RunnerGoneError(err).into())
904                        })?
905                }
906            })
907            .collect::<TryJoinAll<_>>()
908            .await?;
909
910        Ok(())
911    }
912}
913
914/// Error returned when registers could not be set on a VP.
915#[derive(Debug, Error)]
916#[error("failed to set VP register set {0}")]
917pub struct RegisterSetError(&'static str, #[source] anyhow::Error);
918
919#[derive(Debug, Error)]
920#[error("the vp runner was dropped")]
921struct RunnerGoneError(#[source] RpcError);
922
#[cfg(feature = "gdb")]
impl VpSet {
    /// Looks up the control handle for `vp`.
    ///
    /// The index arrives through a debugger request. An out-of-range value is
    /// therefore returned as an error instead of panicking on slice indexing,
    /// which would take down the VMM.
    fn debugger_vp(&self, vp: VpIndex) -> anyhow::Result<&Vp> {
        self.vps
            .get(vp.index() as usize)
            .with_context(|| format!("no such vp: {}", vp.index()))
    }

    /// Set the debug state for a single VP.
    ///
    /// Fails if `vp` does not exist or its runner is gone.
    pub async fn set_debug_state(
        &self,
        vp: VpIndex,
        state: virt::x86::DebugState,
    ) -> anyhow::Result<()> {
        self.debugger_vp(vp)?
            .send
            .call(
                |x| VpEvent::State(StateEvent::Debug(DebugEvent::SetDebugState(x))),
                Some(state),
            )
            .await
            .map_err(RunnerGoneError)?
    }

    /// Clear the debug state for all VPs, one at a time, stopping at the
    /// first failure.
    pub async fn clear_debug_state(&self) -> anyhow::Result<()> {
        for vp in &self.vps {
            vp.send
                .call(
                    |x| VpEvent::State(StateEvent::Debug(DebugEvent::SetDebugState(x))),
                    None,
                )
                .await
                .map_err(RunnerGoneError)??;
        }
        Ok(())
    }

    /// Writes debugger-supplied register state to a single VP.
    ///
    /// Fails if `vp` does not exist or its runner is gone.
    pub async fn set_vp_state(
        &self,
        vp: VpIndex,
        state: Box<DebuggerVpState>,
    ) -> anyhow::Result<()> {
        self.debugger_vp(vp)?
            .send
            .call(
                |x| VpEvent::State(StateEvent::Debug(DebugEvent::SetVpState(x))),
                state,
            )
            .await
            .map_err(RunnerGoneError)?
    }

    /// Reads a single VP's register state in the debugger's format.
    ///
    /// Fails if `vp` does not exist or its runner is gone.
    pub async fn get_vp_state(&self, vp: VpIndex) -> anyhow::Result<Box<DebuggerVpState>> {
        self.debugger_vp(vp)?
            .send
            .call(
                |x| VpEvent::State(StateEvent::Debug(DebugEvent::GetVpState(x))),
                (),
            )
            .await
            .map_err(RunnerGoneError)?
    }

    /// Reads `len` bytes of guest virtual memory at `gva` from the point of
    /// view of `vp`.
    ///
    /// Fails if `vp` does not exist or its runner is gone.
    pub async fn read_virtual_memory(
        &self,
        vp: VpIndex,
        gva: u64,
        len: usize,
    ) -> anyhow::Result<Vec<u8>> {
        self.debugger_vp(vp)?
            .send
            .call(
                |x| VpEvent::State(StateEvent::Debug(DebugEvent::ReadVirtualMemory(x))),
                (gva, len),
            )
            .await
            .map_err(RunnerGoneError)?
    }

    /// Writes `data` to guest virtual memory at `gva` from the point of view
    /// of `vp`.
    ///
    /// Fails if `vp` does not exist or its runner is gone.
    pub async fn write_virtual_memory(
        &self,
        vp: VpIndex,
        gva: u64,
        data: Vec<u8>,
    ) -> anyhow::Result<()> {
        self.debugger_vp(vp)?
            .send
            .call(
                |x| VpEvent::State(StateEvent::Debug(DebugEvent::WriteVirtualMemory(x))),
                (gva, data),
            )
            .await
            .map_err(RunnerGoneError)?
    }
}
1013
/// A request delivered to a [`VpRunner`] over its event channel.
#[derive(Debug)]
enum VpEvent {
    /// Start running the VP. The runner asserts the VP is currently stopped.
    Start,
    /// Stop the VP. The sender is signaled once the runner has marked the VP
    /// stopped.
    Stop(mesh::OneshotSender<()>),
    /// A request to read or modify VP state. If the VP is running, it is
    /// stopped first so the request can be handled outside `run_vp`.
    State(StateEvent),
}
1020
/// A VP state request, handled by `RunnerInner::state_event` while the VP is
/// not inside `run_vp`.
#[derive(Debug)]
enum StateEvent {
    /// Respond with the runner's `VpState` merged with the VP's own inspect
    /// output.
    Inspect(inspect::Deferred),
    /// Set initial registers for the given VTL.
    SetInitialRegs(Rpc<(Vtl, Arc<InitialRegs>, RegistersToSet), Result<(), RegisterSetError>>),
    /// Save the VP's state.
    Save(Rpc<(), Result<SavedStateBlob, SaveError>>),
    /// Restore the VP's state from a previously saved blob.
    Restore(Rpc<SavedStateBlob, Result<(), RestoreError>>),
    /// A debugger request (only built with the `gdb` feature).
    #[cfg(feature = "gdb")]
    Debug(DebugEvent),
}
1030
/// Debugger requests handled by the VP runner. All of them operate on VTL0.
#[cfg(feature = "gdb")]
#[derive(Debug)]
enum DebugEvent {
    /// Set the VP's debug state, or clear it by passing `None`.
    SetDebugState(Rpc<Option<virt::x86::DebugState>, anyhow::Result<()>>),
    /// Overwrite the VP's register state.
    SetVpState(Rpc<Box<DebuggerVpState>, anyhow::Result<()>>),
    /// Read the VP's register state.
    GetVpState(Rpc<(), anyhow::Result<Box<DebuggerVpState>>>),
    /// Read guest virtual memory; input is `(gva, len)`.
    ReadVirtualMemory(Rpc<(u64, usize), anyhow::Result<Vec<u8>>>),
    /// Write guest virtual memory; input is `(gva, data)`.
    WriteVirtualMemory(Rpc<(u64, Vec<u8>), anyhow::Result<()>>),
}
1040
/// An object used to dispatch a virtual processor.
#[must_use]
pub struct VpRunner {
    /// Start, stop and state requests for this VP.
    recv: mesh::Receiver<VpEvent>,
    /// Cloned into each [`RunnerCanceller`] handed out by `canceller`.
    cancel_send: mesh::Sender<()>,
    /// Receives cancellation requests from [`RunnerCanceller::cancel`].
    cancel_recv: mesh::Receiver<()>,
    /// Never sent on here. Presumably its drop (when the runner is dropped)
    /// is what notifies the owner — TODO confirm against the constructor.
    _done: mesh::OneshotSender<()>,
    /// The VP index, shared partition state and current run state.
    inner: RunnerInner,
}
1050
1051/// An object that can cancel a pending call into [`VpRunner::run`].
1052pub struct RunnerCanceller(mesh::Sender<()>);
1053
1054impl RunnerCanceller {
1055    /// Requests that the current or next call to [`VpRunner::run`] return as
1056    /// soon as possible.
1057    pub fn cancel(&mut self) {
1058        self.0.send(());
1059    }
1060}
1061
/// Error returned when a VP run is cancelled.
///
/// The wrapped flag records whether the cancellation came from the user (via
/// [`RunnerCanceller`]) or from the VP itself.
#[derive(Debug)]
pub struct RunCancelled(bool);

impl RunCancelled {
    /// Returns `true` if the run was cancelled by the user, or `false` if it was
    /// cancelled by the VP itself.
    pub fn is_user_cancelled(&self) -> bool {
        let Self(by_user) = self;
        *by_user
    }
}
1073
/// Per-VP runner state that state requests are serviced against.
struct RunnerInner {
    /// Index of the VP this runner drives.
    vp: VpIndex,
    /// Shared state (halt tracking, per-VTL guest memory).
    inner: Arc<Inner>,
    /// Where this VP is in its start/stop/halt lifecycle.
    state: VpState,
}
1079
/// Run state of a single VP as tracked by its runner.
#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq)]
enum VpState {
    /// Not running; `VpEvent::Start` is only accepted in this state.
    Stopped,
    /// Started; the runner drives `run_vp` while in this state.
    Running,
    /// Started, but the partition is halted, so the VP is not being run;
    /// `VpEvent::Stop` is only accepted in this state while waiting.
    Halted,
}
1086
1087impl VpRunner {
1088    /// Runs the VP dispatch loop for `vp`, using `io` to handle CPU requests.
1089    ///
1090    /// Returns [`RunCancelled`] if [`RunnerCanceller::cancel`] was called, or
1091    /// if the VP returns [`VpHaltReason::Cancel`]. In this case, the call can
1092    /// be reissued, with the same or different `vp` object, to continue running
1093    /// the VP.
1094    ///
1095    /// Do not reissue this call if it returns `Ok`. Do not drop this future
1096    /// without awaiting it to completion.
1097    pub async fn run(
1098        &mut self,
1099        vp: &mut (impl Processor + ProtobufSaveRestore),
1100        io: &impl CpuIo,
1101    ) -> Result<(), RunCancelled> {
1102        let vp_index = self.inner.vp;
1103        self.run_inner(&mut BoundVp { vp, io, vp_index }).await
1104    }
1105
1106    /// Returns an object that can be used to cancel a `run` call.
1107    pub fn canceller(&self) -> RunnerCanceller {
1108        RunnerCanceller(self.cancel_send.clone())
1109    }
1110
1111    #[instrument(level = "debug", name = "run_vp", skip_all, fields(vp_index = self.inner.vp.index()))]
1112    async fn run_inner(&mut self, vp: &mut dyn ControlVp) -> Result<(), RunCancelled> {
1113        loop {
1114            // Wait for start.
1115            while self.inner.state != VpState::Running {
1116                let r = (self.recv.next().map(Ok), self.cancel_recv.next().map(Err))
1117                    .race()
1118                    .await
1119                    .map_err(|_| RunCancelled(true))?;
1120                match r {
1121                    Some(VpEvent::Start) => {
1122                        assert_eq!(self.inner.state, VpState::Stopped);
1123                        self.inner.state = VpState::Running;
1124                    }
1125                    Some(VpEvent::Stop(send)) => {
1126                        assert_eq!(self.inner.state, VpState::Halted);
1127                        self.inner.state = VpState::Stopped;
1128                        send.send(());
1129                    }
1130                    Some(VpEvent::State(event)) => self.inner.state_event(vp, event),
1131                    None => return Ok(()),
1132                }
1133            }
1134
1135            // If the VPs are already halted, wait for the next request without
1136            // running the VP even once.
1137            if self.inner.inner.halt.is_halted() {
1138                self.inner.state = VpState::Halted;
1139                continue;
1140            }
1141
1142            let mut stop_complete = None;
1143            let mut state_requests = Vec::new();
1144            let mut cancelled_by_user = None;
1145            {
1146                enum Event {
1147                    Vp(VpEvent),
1148                    Teardown,
1149                    Halt,
1150                    VpStopped(Result<StopReason, HaltReason>),
1151                    Cancel,
1152                }
1153
1154                let stop = StopVpSource::new();
1155
1156                let run_vp = vp
1157                    .run_vp(&self.inner.inner.vtl_guest_memory, stop.checker())
1158                    .into_stream()
1159                    .map(Event::VpStopped);
1160
1161                let halt = self
1162                    .inner
1163                    .inner
1164                    .halt
1165                    .halted()
1166                    .into_stream()
1167                    .map(|_| Event::Halt);
1168
1169                let recv = (&mut self.recv)
1170                    .map(Event::Vp)
1171                    .chain(async { Event::Teardown }.into_stream());
1172
1173                let cancel = (&mut self.cancel_recv).map(|()| Event::Cancel);
1174
1175                let s = (recv, halt, cancel).merge();
1176
1177                // Since `run_vp` will block the thread until it receives a
1178                // cancellation notification, always poll the other sources to
1179                // exhaustion before polling the future.
1180                let mut s = pin!(select_with_strategy(s, run_vp, |_: &mut ()| {
1181                    futures::stream::PollNext::Left
1182                }));
1183
1184                // Wait for stop or a VP failure.
1185                while let Some(event) = s.next().await {
1186                    match event {
1187                        Event::Vp(VpEvent::Start) => panic!("vp already started"),
1188                        Event::Vp(VpEvent::Stop(send)) => {
1189                            tracing::debug!("stopping VP");
1190                            stop.stop();
1191                            stop_complete = Some(send);
1192                        }
1193                        Event::Vp(VpEvent::State(event)) => {
1194                            // Stop the VP so that we can drop the run_vp future
1195                            // before manipulating state.
1196                            //
1197                            // FUTURE: This causes inspection delays during slow
1198                            // MMIO/PIO exit handling. Fix the backends to support
1199                            // calling inspect while run_vp is still alive (but
1200                            // suspended).
1201                            stop.stop();
1202                            state_requests.push(event);
1203                        }
1204                        Event::Halt => {
1205                            tracing::debug!("stopping VP due to halt");
1206                            stop.stop();
1207                        }
1208                        Event::Cancel => {
1209                            tracing::debug!("run cancelled externally");
1210                            stop.stop();
1211                            cancelled_by_user = Some(true);
1212                        }
1213                        Event::Teardown => {
1214                            tracing::debug!("tearing down");
1215                            stop.stop();
1216                        }
1217                        Event::VpStopped(r) => {
1218                            match r {
1219                                Ok(StopReason::OnRequest(VpStopped { .. })) => {
1220                                    assert!(stop.is_stopping(), "vp stopped without a reason");
1221                                    tracing::debug!("VP stopped on request");
1222                                }
1223                                Ok(StopReason::Cancel) => {
1224                                    tracing::debug!("run cancelled internally");
1225                                    cancelled_by_user = Some(false);
1226                                }
1227                                Err(halt_reason) => {
1228                                    tracing::debug!("VP halted");
1229                                    self.inner.inner.halt.halt(halt_reason);
1230                                }
1231                            }
1232                            break;
1233                        }
1234                    }
1235                }
1236            }
1237            for event in state_requests {
1238                self.inner.state_event(vp, event);
1239            }
1240
1241            if let Some(send) = stop_complete {
1242                self.inner.state = VpState::Stopped;
1243                send.send(());
1244            }
1245
1246            if let Some(by_user) = cancelled_by_user {
1247                return Err(RunCancelled(by_user));
1248            }
1249        }
1250    }
1251}
1252
1253impl RunnerInner {
1254    fn state_event(&mut self, vp: &mut dyn ControlVp, event: StateEvent) {
1255        match event {
1256            StateEvent::Inspect(deferred) => {
1257                deferred.respond(|resp| {
1258                    resp.field("state", self.state)
1259                        .merge(inspect::adhoc_mut(|req| {
1260                            vp.inspect_vp(&self.inner.vtl_guest_memory, req)
1261                        }));
1262                });
1263            }
1264            StateEvent::SetInitialRegs(rpc) => {
1265                rpc.handle_sync(|(vtl, state, to_set)| vp.set_initial_regs(vtl, &state, to_set))
1266            }
1267            StateEvent::Save(rpc) => rpc.handle_sync(|()| vp.save()),
1268            StateEvent::Restore(rpc) => rpc.handle_sync(|data| vp.restore(data)),
1269            #[cfg(feature = "gdb")]
1270            StateEvent::Debug(event) => match event {
1271                DebugEvent::SetDebugState(rpc) => {
1272                    rpc.handle_sync(|state| vp.debug().set_debug_state(Vtl::Vtl0, state.as_ref()))
1273                }
1274                DebugEvent::SetVpState(rpc) => {
1275                    rpc.handle_sync(|state| vp.debug().set_vp_state(Vtl::Vtl0, &state))
1276                }
1277                DebugEvent::GetVpState(rpc) => {
1278                    rpc.handle_sync(|()| vp.debug().get_vp_state(Vtl::Vtl0))
1279                }
1280                DebugEvent::ReadVirtualMemory(rpc) => rpc.handle_sync(|(gva, len)| {
1281                    let mut buf = vec![0; len];
1282                    vp_state::read_virtual_memory(
1283                        self.inner.vtl_guest_memory[0]
1284                            .as_ref()
1285                            .context("no guest memory for vtl0")?,
1286                        vp.debug(),
1287                        Vtl::Vtl0,
1288                        gva,
1289                        &mut buf,
1290                    )?;
1291                    Ok(buf)
1292                }),
1293                DebugEvent::WriteVirtualMemory(rpc) => rpc.handle_sync(|(gva, buf)| {
1294                    vp_state::write_virtual_memory(
1295                        self.inner.vtl_guest_memory[0]
1296                            .as_ref()
1297                            .context("no guest memory for vtl0")?,
1298                        vp.debug(),
1299                        Vtl::Vtl0,
1300                        gva,
1301                        &buf,
1302                    )?;
1303                    Ok(())
1304                }),
1305            },
1306        }
1307    }
1308}
1309
#[cfg(feature = "gdb")]
mod vp_state {
    //! Debugger helpers that translate guest virtual addresses through a VP's
    //! current register state, access guest memory, and decode x86
    //! instructions around the instruction pointer.

    use super::DebugVp;
    use anyhow::Context;
    use guestmem::GuestMemory;
    use hvdef::Vtl;
    use vmm_core_defs::debug_rpc::DebuggerVpState;

    /// Translates `gva` to a guest physical address using the register state
    /// that `debug` reports for `vtl`.
    ///
    /// Read/write/execute validation and privilege checks are disabled, and
    /// the walker is not asked to set page-table bits. The translation is
    /// therefore meant for debugger inspection only.
    fn translate_gva(
        guest_memory: &GuestMemory,
        debug: &mut dyn DebugVp,
        vtl: Vtl,
        gva: u64,
    ) -> anyhow::Result<u64> {
        let state = debug.get_vp_state(vtl).context("failed to get vp state")?;

        match &*state {
            DebuggerVpState::X86_64(state) => {
                let registers = virt_support_x86emu::translate::TranslationRegisters {
                    cr0: state.cr0,
                    cr4: state.cr4,
                    efer: state.efer,
                    cr3: state.cr3,
                    rflags: state.rflags,
                    ss: state.ss.into(),
                    // For debug translation, don't worry about accidentally reading
                    // page tables from shared memory.
                    encryption_mode: virt_support_x86emu::translate::EncryptionMode::None,
                };
                let flags = virt_support_x86emu::translate::TranslateFlags {
                    validate_execute: false,
                    validate_read: false,
                    validate_write: false,
                    override_smap: false,
                    enforce_smap: false,
                    privilege_check: virt_support_x86emu::translate::TranslatePrivilegeCheck::None,
                    set_page_table_bits: false,
                };
                Ok(virt_support_x86emu::translate::translate_gva_to_gpa(
                    guest_memory,
                    gva,
                    &registers,
                    flags,
                )?
                .gpa)
            }
            DebuggerVpState::Aarch64(state) => {
                let registers = virt_support_aarch64emu::translate::TranslationRegisters {
                    cpsr: state.cpsr.into(),
                    sctlr: state.sctlr_el1.into(),
                    tcr: state.tcr_el1.into(),
                    ttbr0: state.ttbr0_el1,
                    ttbr1: state.ttbr1_el1,
                    syndrome: 0,
                    // For debug translation, don't worry about accidentally reading
                    // page tables from shared memory.
                    encryption_mode: virt_support_aarch64emu::translate::EncryptionMode::None,
                };
                let flags = virt_support_aarch64emu::translate::TranslateFlags {
                    validate_execute: false,
                    validate_read: false,
                    validate_write: false,
                    privilege_check:
                        virt_support_aarch64emu::translate::TranslatePrivilegeCheck::None,
                    set_page_table_bits: false,
                };
                Ok(virt_support_aarch64emu::translate::translate_gva_to_gpa(
                    guest_memory,
                    gva,
                    &registers,
                    flags,
                )?)
            }
        }
    }

    /// Reads `buf.len()` bytes of guest virtual memory starting at `gva`.
    ///
    /// The range is split at 4 KiB boundaries, and each piece is translated
    /// separately. The range may therefore span pages that map to
    /// discontiguous guest physical addresses. Virtual addresses wrap around
    /// at the top of the 64-bit address space.
    pub(super) fn read_virtual_memory(
        guest_memory: &GuestMemory,
        debug: &mut dyn DebugVp,
        vtl: Vtl,
        gva: u64,
        buf: &mut [u8],
    ) -> Result<(), anyhow::Error> {
        let mut offset = 0;
        while offset < buf.len() {
            // Wrap instead of overflowing. `previous_instruction` deliberately
            // produces wrapped addresses, and a plain `+` panics in debug
            // builds.
            let gpa = translate_gva(guest_memory, debug, vtl, gva.wrapping_add(offset as u64))
                .context("failed to translate gva")?;
            // Stop at the end of the current 4 KiB page.
            let this_len = (buf.len() - offset).min(4096 - (gpa & 4095) as usize);
            guest_memory.read_at(gpa, &mut buf[offset..offset + this_len])?;
            offset += this_len;
        }
        Ok(())
    }

    /// Writes `buf` to guest virtual memory starting at `gva`.
    ///
    /// The range is split at 4 KiB boundaries, and each piece is translated
    /// separately. Virtual addresses wrap around at the top of the 64-bit
    /// address space. If a translation or write fails partway through, the
    /// pieces already written stay written.
    pub(super) fn write_virtual_memory(
        guest_memory: &GuestMemory,
        debug: &mut dyn DebugVp,
        vtl: Vtl,
        gva: u64,
        buf: &[u8],
    ) -> Result<(), anyhow::Error> {
        let mut offset = 0;
        while offset < buf.len() {
            // Wrap instead of overflowing (a plain `+` panics in debug builds).
            let gpa = translate_gva(guest_memory, debug, vtl, gva.wrapping_add(offset as u64))
                .context("failed to translate gva")?;
            // Stop at the end of the current 4 KiB page.
            let this_len = (buf.len() - offset).min(4096 - (gpa & 4095) as usize);
            guest_memory.write_at(gpa, &buf[offset..offset + this_len])?;
            offset += this_len;
        }
        Ok(())
    }

    /// Returns the decoder bitness implied by CR0.PE and EFER.LMA: 16 when
    /// protection is off, 64 when long mode is active, otherwise 32.
    ///
    /// NOTE(review): CS.L / CS.D are not consulted. Compatibility-mode code
    /// under a 64-bit OS is therefore decoded as 64-bit, and 16-bit
    /// protected-mode code as 32-bit.
    #[cfg(guest_arch = "x86_64")]
    fn bits(regs: &virt::x86::vp::Registers) -> u32 {
        if regs.cr0 & x86defs::X64_CR0_PE != 0 {
            if regs.efer & x86defs::X64_EFER_LMA != 0 {
                64
            } else {
                32
            }
        } else {
            16
        }
    }

    /// Converts an instruction pointer to a linear address. In 64-bit mode
    /// (per [`bits`]) `rip` is returned unchanged. Otherwise the CS base is
    /// added with wrapping.
    #[cfg(guest_arch = "x86_64")]
    fn linear_ip(regs: &virt::x86::vp::Registers, rip: u64) -> u64 {
        if bits(regs) == 64 {
            rip
        } else {
            // 32 or 16 bits
            regs.cs.base.wrapping_add(rip)
        }
    }

    /// Get the previous instruction for debugging purposes.
    ///
    /// Heuristic: reads the 16 bytes preceding RIP and tries decoding at each
    /// offset, from the longest candidate to the shortest. It returns the
    /// first valid instruction that ends exactly at RIP.
    #[cfg(guest_arch = "x86_64")]
    pub(super) fn previous_instruction(
        guest_memory: &GuestMemory,
        debug: &mut dyn DebugVp,
        vtl: Vtl,
        regs: &virt::x86::vp::Registers,
    ) -> anyhow::Result<iced_x86::Instruction> {
        let mut bytes = [0u8; 16];
        // Read 16 bytes before RIP.
        let rip = regs.rip.wrapping_sub(16);
        read_virtual_memory(guest_memory, debug, vtl, linear_ip(regs, rip), &mut bytes)
            .context("failed to read memory")?;
        let mut decoder = iced_x86::Decoder::new(bits(regs), &bytes, 0);

        // Try decoding at each byte until we find the instruction right before the current one.
        for offset in 0..16 {
            decoder.set_ip(rip.wrapping_add(offset));
            decoder.try_set_position(offset as usize).unwrap();
            let instr = decoder.decode();
            if !instr.is_invalid() && instr.next_ip() == regs.rip {
                return Ok(instr);
            }
        }
        Err(anyhow::anyhow!("could not find previous instruction"))
    }

    /// Get the next instruction for debugging purposes.
    ///
    /// Decodes the instruction at RIP and also returns the 16 raw bytes read
    /// from RIP. The read fails if any of those 16 bytes is unmapped, even
    /// when the instruction itself is shorter.
    #[cfg(guest_arch = "x86_64")]
    pub(super) fn next_instruction(
        guest_memory: &GuestMemory,
        debug: &mut dyn DebugVp,
        vtl: Vtl,
        regs: &virt::x86::vp::Registers,
    ) -> anyhow::Result<(iced_x86::Instruction, [u8; 16])> {
        let mut bytes = [0u8; 16];
        read_virtual_memory(
            guest_memory,
            debug,
            vtl,
            linear_ip(regs, regs.rip),
            &mut bytes,
        )
        .context("failed to read memory")?;
        let mut decoder = iced_x86::Decoder::new(bits(regs), &bytes, 0);
        decoder.set_ip(regs.rip);
        Ok((decoder.decode(), bytes))
    }
}
1494
/// A waker wrapper that, when woken, asks the partition to make VP `vp` yield
/// before forwarding the wake-up to `inner`.
struct VpWaker {
    /// Partition used to request the yield.
    partition: Arc<dyn RequestYield>,
    /// The VP to ask to yield.
    vp: VpIndex,
    /// The executor's original waker.
    inner: Waker,
}
1500
1501impl VpWaker {
1502    fn new(partition: Arc<dyn RequestYield>, vp: VpIndex, waker: Waker) -> Self {
1503        Self {
1504            partition,
1505            vp,
1506            inner: waker,
1507        }
1508    }
1509}
1510
1511impl std::task::Wake for VpWaker {
1512    fn wake_by_ref(self: &Arc<Self>) {
1513        self.partition.request_yield(self.vp);
1514        self.inner.wake_by_ref();
1515    }
1516
1517    fn wake(self: Arc<Self>) {
1518        self.wake_by_ref()
1519    }
1520}
1521
/// Trait for requesting that a VP yield in its [`virt::Processor::run_vp`]
/// call.
///
/// Used by [`block_on_vp`] so that waking the blocked future also interrupts
/// a running VP. Implemented for every [`virt::Partition`].
pub trait RequestYield: Send + Sync {
    /// Forces the run_vp call to yield to the scheduler (i.e. return
    /// Poll::Pending).
    fn request_yield(&self, vp_index: VpIndex);
}
1529
1530impl<T: virt::Partition> RequestYield for T {
1531    fn request_yield(&self, vp_index: VpIndex) {
1532        self.request_yield(vp_index)
1533    }
1534}
1535
1536/// Blocks on a future, where the future may run a VP (and so the associated
1537/// waker needs to ask the VP to yield).
1538pub fn block_on_vp<F: Future>(partition: Arc<dyn RequestYield>, vp: VpIndex, fut: F) -> F::Output {
1539    let mut fut = pin!(fut);
1540    pal_async::local::block_on(std::future::poll_fn(|cx| {
1541        let waker = Arc::new(VpWaker::new(partition.clone(), vp, cx.waker().clone())).into();
1542        let mut cx = Context::from_waker(&waker);
1543        fut.poll_unpin(&mut cx)
1544    }))
1545}