Skip to main content

chipset/pit/
mod.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4pub mod resolver;
5
6use bitfield_struct::bitfield;
7use chipset_device::ChipsetDevice;
8use chipset_device::io::IoError;
9use chipset_device::io::IoResult;
10use chipset_device::pio::PortIoIntercept;
11use chipset_device::poll_device::PollDevice;
12use inspect::Inspect;
13use inspect::InspectMut;
14use open_enum::open_enum;
15use std::ops::RangeInclusive;
16use std::task::Context;
17use std::task::Poll;
18use std::time::Duration;
19use thiserror::Error;
20use vmcore::device_state::ChangeDeviceState;
21use vmcore::line_interrupt::LineInterrupt;
22use vmcore::vmtime::VmTime;
23use vmcore::vmtime::VmTimeAccess;
24
25#[rustfmt::skip]
26#[derive(Inspect)]
27#[bitfield(u8)]
28struct ControlWord {
29    #[bits(1)] bcd: bool,
30    #[inspect(with = "|x| Mode::from(*x)")]
31    #[bits(3)] mode: u8,
32    #[inspect(with = "|x| RwMode(*x)")]
33    #[bits(2)] rw: u8,
34    #[inspect(skip)] // Ignore `select` since it's not part of the persistent state.
35    #[bits(2)] select: u8,
36}
37
38#[rustfmt::skip]
39#[bitfield(u8)]
40struct StatusWord {
41    #[bits(1)] bcd: bool,
42    #[bits(3)] mode: u8,
43    #[bits(2)] rw: u8,
44    #[bits(1)] null: bool,
45    #[bits(1)] out: bool,
46}
47
48#[bitfield(u8)]
49struct ReadBackCommand {
50    reserved: bool,
51    counter0: bool,
52    counter1: bool,
53    counter2: bool,
54    status: bool,
55    count: bool,
56    #[bits(2)]
57    one: u8,
58}
59
/// Data port of counter 0; counters 1 and 2 follow on consecutive ports.
const PIT_TIMER_RANGE_START: u16 = 0x40;
/// Data port of counter 2, the last counter.
const PIT_TIMER_RANGE_END: u16 = PIT_TIMER_RANGE_START + 2;
/// Write-only control (mode/command) port, directly after the data ports.
const PIT_CONTROL_REGISTER: u16 = PIT_TIMER_RANGE_END + 1;
/// Port exposing counter 2's gate and output plus a refresh toggle bit.
const PIT_PORT61_REGISTER: u16 = 0x61;
64
65#[derive(Debug, Inspect)]
66struct Timer {
67    // Static configuration
68    enabled_at_reset: bool,
69
70    // Runtime glue
71    interrupt: Option<LineInterrupt>,
72
73    // Volatile state
74    #[inspect(flatten)]
75    state: TimerState,
76}
77
78#[derive(Copy, Clone, Debug, Inspect)]
79struct TimerState {
80    ce: u16,         // "counting element", i.e. the counter
81    cr: u16,         // count register, the new value
82    ol: Option<u16>, // the output latch
83    sl: Option<u8>,  // the status latch
84    state: CountState,
85    control: ControlWord,
86    out: bool,       // timer output
87    gate: bool,      // timer input
88    null: bool,      // cr has been set but not copied to ce yet
89    read_high: bool, // read the high counter byte next
90    cr_low: Option<u8>,
91}
92
93#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq)]
94enum CountState {
95    Inactive,
96    WaitingForGate,
97    Reloading,
98    Active,
99    Counting,
100}
101
102#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
103enum Mode {
104    TerminalCount = 0,
105    HardwareOneShot = 1,
106    RateGenerator = 2,
107    SquareWave = 3,
108    SoftwareStrobe = 4,
109    HardwareStrobe = 5,
110}
111
112impl From<u8> for Mode {
113    fn from(v: u8) -> Self {
114        match v {
115            0 => Mode::TerminalCount,
116            1 => Mode::HardwareOneShot,
117            2 | 6 => Mode::RateGenerator,
118            3 | 7 => Mode::SquareWave,
119            4 => Mode::SoftwareStrobe,
120            5 => Mode::HardwareStrobe,
121            _ => unreachable!(),
122        }
123    }
124}
125
126impl Mode {
127    /// Returns true for modes where counting stops when gate is low.
128    fn gate_stops_count(&self) -> bool {
129        match self {
130            Mode::TerminalCount | Mode::RateGenerator | Mode::SquareWave | Mode::SoftwareStrobe => {
131                true
132            }
133            Mode::HardwareOneShot | Mode::HardwareStrobe => false,
134        }
135    }
136}
137
138open_enum! {
139    #[derive(Inspect)]
140    #[inspect(debug)]
141    enum RwMode: u8 {
142        LOW = 1,
143        HIGH = 2,
144        LOW_HIGH = 3,
145    }
146}
147
/// Converts a four-digit packed BCD value to binary.
///
/// Nibbles are weighted 1000/100/10/1 from most to least significant; nibbles
/// above 9 are not rejected, so the largest result is 16665 (for `0xffff`).
const fn from_bcd(n: u16) -> u16 {
    let mut value: u16 = 0;
    let mut shift: u32 = 16;
    // Horner's scheme, most-significant digit first.
    while shift > 0 {
        shift -= 4;
        value = value * 10 + ((n >> shift) & 0xf);
    }
    value
}
151
/// Converts a binary value to four-digit packed BCD.
///
/// Only the low four decimal digits are kept, so the ten-thousands digit of
/// values of 10000 and above is discarded.
const fn to_bcd(n: u16) -> u16 {
    let mut packed: u16 = 0;
    let mut rest = n;
    let mut shift: u32 = 0;
    while shift < 16 {
        packed |= (rest % 10) << shift;
        rest /= 10;
        shift += 4;
    }
    packed
}
155
156/// Subtracts `n` from `ce` with wrap.
157///
158/// If `bcd`, then `ce` is in BCD format. The return value is always in binary
159/// format.
160fn counter_sub(ce: u16, n: u64, bcd: bool) -> u64 {
161    if bcd {
162        let ce = from_bcd(ce);
163        let n = (n % 10000) as u16;
164        (if ce >= n { ce - n } else { 10000 - (n - ce) }) as u64
165    } else {
166        ce.wrapping_sub(n as u16) as u64
167    }
168}
169
/// Nanoseconds per PIT input-clock tick, truncated to a whole number.
///
/// The PC PIT input clock runs at about 1.193182 MHz (~838.1 ns per tick);
/// the fractional nanoseconds are dropped.
const NANOS_PER_TICK: u64 = 1_000_000_000 / 1_193_182;
172
173impl Timer {
174    fn new(enabled_at_reset: bool, interrupt: Option<LineInterrupt>) -> Self {
175        Self {
176            enabled_at_reset,
177            interrupt,
178            state: TimerState::new(enabled_at_reset),
179        }
180    }
181
182    fn reset(&mut self) {
183        self.state = TimerState::new(self.enabled_at_reset);
184        self.sync_interrupt();
185    }
186
187    fn sync_interrupt(&mut self) {
188        if let Some(interrupt) = &self.interrupt {
189            interrupt.set_level(self.state.out);
190        }
191    }
192
193    fn set_out(&mut self, state: bool) {
194        if self.state.out != state {
195            self.state.out = state;
196            self.sync_interrupt();
197        }
198    }
199
200    fn load_ce(&mut self) {
201        self.state.ce = self.state.cr;
202        self.state.null = false;
203    }
204
205    /// Sets CE to the given value, wrapped. Stores CE in BCD
206    /// format if the PIT is in BCD mode.
207    fn set_ce(&mut self, ce: u64) {
208        if self.state.control.bcd() {
209            self.state.ce = to_bcd((ce % 10000) as u16);
210        } else {
211            self.state.ce = ce as u16;
212        }
213    }
214
215    fn evaluate(&mut self, mut ticks: u64) {
216        let mode = self.state.op_mode();
217        let bcd = self.state.control.bcd();
218        while ticks > 0 {
219            match self.state.state {
220                CountState::Inactive | CountState::WaitingForGate => break,
221                CountState::Reloading => {
222                    ticks -= 1;
223                    self.load_ce();
224                    self.state.state = CountState::Active;
225                }
226                CountState::Active => {
227                    if !self.state.gate && mode.gate_stops_count() {
228                        break;
229                    }
230                    // Counts per tick.
231                    let per = match mode {
232                        Mode::TerminalCount
233                        | Mode::HardwareOneShot
234                        | Mode::RateGenerator
235                        | Mode::SoftwareStrobe
236                        | Mode::HardwareStrobe => 1,
237                        Mode::SquareWave => {
238                            // Strip the low bit. This takes an extra tick when
239                            // out is high.
240                            if self.state.ce & 1 != 0 {
241                                self.state.ce &= !1;
242                                if self.state.out {
243                                    ticks -= 1;
244                                    continue;
245                                }
246                            }
247                            2
248                        }
249                    };
250                    if self.state.ce as u64 == per {
251                        // Terminal state.
252                        self.state.ce = 0;
253                        ticks -= 1;
254                        match mode {
255                            Mode::TerminalCount | Mode::HardwareOneShot => {
256                                self.set_out(true);
257                                self.state.state = CountState::Counting;
258                            }
259                            Mode::RateGenerator => {
260                                self.set_out(true);
261                                self.load_ce();
262                            }
263                            Mode::SquareWave => {
264                                self.set_out(!self.state.out);
265                                self.load_ce();
266                            }
267                            Mode::SoftwareStrobe | Mode::HardwareStrobe => {
268                                self.set_out(false);
269                                self.state.state = CountState::Counting;
270                            }
271                        }
272                    } else {
273                        if ticks >= counter_sub(self.state.ce, per, bcd) / per {
274                            // Decrement down to one tick before the terminal state.
275                            ticks -= counter_sub(self.state.ce, per, bcd) / per;
276                            self.state.ce = per as u16;
277                            if mode == Mode::RateGenerator {
278                                self.set_out(false);
279                            }
280                        } else {
281                            self.set_ce(counter_sub(self.state.ce, ticks * per, bcd));
282                            ticks = 0;
283                        }
284                    }
285                }
286                CountState::Counting => {
287                    if !self.state.gate && mode.gate_stops_count() {
288                        break;
289                    }
290                    self.set_ce(counter_sub(self.state.ce, ticks, bcd));
291                    ticks = 0;
292                    self.set_out(true);
293                }
294            }
295        }
296    }
297
298    fn set_control(&mut self, control: ControlWord) {
299        if control.rw() == 0 {
300            self.state.latch_counter();
301            return;
302        }
303
304        self.state.control = control.with_select(0);
305        self.state.ce = 0;
306        self.state.cr = 0;
307        self.state.cr_low = None;
308        self.state.read_high = false;
309        self.state.state = CountState::Inactive;
310        self.state.null = true;
311        self.set_out(match self.state.op_mode() {
312            Mode::TerminalCount => false,
313            Mode::HardwareOneShot => true,
314            Mode::RateGenerator | Mode::SquareWave => true,
315            Mode::SoftwareStrobe => true,
316            Mode::HardwareStrobe => true,
317        });
318    }
319
320    fn write(&mut self, n: u8) {
321        let n = n as u16;
322        match RwMode(self.state.control.rw()) {
323            RwMode::LOW => self.state.cr = n,
324            RwMode::HIGH => self.state.cr = n << 8,
325            RwMode::LOW_HIGH => {
326                if let Some(low) = self.state.cr_low {
327                    self.state.cr = (n << 8) | (low as u16);
328                } else {
329                    self.state.cr_low = Some(n as u8);
330                    // Wait for high to be set before taking any actions.
331                    return;
332                }
333            }
334            _ => unreachable!(),
335        }
336        self.state.null = true;
337        match self.state.op_mode() {
338            Mode::TerminalCount => {
339                self.state.state = CountState::Reloading;
340                self.set_out(false);
341            }
342            Mode::HardwareOneShot => {
343                self.state.state = CountState::WaitingForGate;
344            }
345            Mode::RateGenerator | Mode::SquareWave => {
346                if self.state.state != CountState::Active {
347                    self.state.state = CountState::Reloading;
348                }
349            }
350            Mode::SoftwareStrobe => {
351                self.state.state = CountState::Reloading;
352            }
353            Mode::HardwareStrobe => {
354                self.state.state = CountState::WaitingForGate;
355            }
356        }
357    }
358
359    fn read(&mut self) -> u8 {
360        if let Some(sl) = self.state.sl.take() {
361            return sl;
362        }
363        let value = self.state.ol.unwrap_or(self.state.ce);
364        let value = match RwMode(self.state.control.rw()) {
365            RwMode::LOW => value as u8,
366            RwMode::HIGH => (value >> 8) as u8,
367            RwMode::LOW_HIGH => {
368                self.state.read_high = !self.state.read_high;
369                if self.state.read_high {
370                    value as u8
371                } else {
372                    (value >> 8) as u8
373                }
374            }
375            _ => unreachable!(),
376        };
377        if !self.state.read_high {
378            self.state.ol = None;
379        }
380        value
381    }
382
383    fn set_gate(&mut self, gate: bool) {
384        match self.state.op_mode() {
385            Mode::TerminalCount => {}
386            Mode::HardwareOneShot => {
387                if !self.state.gate && gate && self.state.state == CountState::WaitingForGate {
388                    self.state.state = CountState::Reloading;
389                    self.set_out(false);
390                }
391            }
392            Mode::RateGenerator | Mode::SquareWave => {
393                if gate && !self.state.gate {
394                    if self.state.state == CountState::Active {
395                        self.state.state = CountState::Reloading;
396                    }
397                } else if !gate {
398                    self.set_out(true);
399                }
400            }
401            Mode::SoftwareStrobe => {}
402            Mode::HardwareStrobe => {
403                if !self.state.gate && gate && self.state.state == CountState::WaitingForGate {
404                    self.state.state = CountState::Reloading;
405                }
406            }
407        }
408        self.state.gate = gate;
409    }
410}
411
412impl TimerState {
413    fn new(enabled: bool) -> Self {
414        Self {
415            ce: 0,
416            cr: 0,
417            ol: None,
418            sl: None,
419            control: ControlWord::new().with_rw(1),
420            state: CountState::Inactive,
421            out: false,
422            null: true,
423            gate: enabled,
424            read_high: false,
425            cr_low: None,
426        }
427    }
428
429    fn op_mode(&self) -> Mode {
430        self.control.mode().into()
431    }
432
433    // Returns the number of ticks until the next interrupt will occur.
434    fn next_wakeup(&self) -> Option<u64> {
435        let mode = self.op_mode();
436        let bcd = self.control.bcd();
437        match self.state {
438            CountState::Inactive => None,
439            CountState::WaitingForGate => None,
440            CountState::Reloading | CountState::Active => {
441                if !self.gate && mode.gate_stops_count() {
442                    return None;
443                }
444                // Add an extra count for the reload cycle.
445                let (ce, extra) = if self.state == CountState::Reloading {
446                    (self.cr, 1)
447                } else {
448                    (self.ce, 0)
449                };
450                let v = {
451                    match mode {
452                        Mode::TerminalCount
453                        | Mode::HardwareOneShot
454                        | Mode::SoftwareStrobe
455                        | Mode::HardwareStrobe => {
456                            // Changing output in ce ticks.
457                            counter_sub(ce, 1, bcd) + 1
458                        }
459                        Mode::RateGenerator => {
460                            if ce == 1 {
461                                // Going high in 1 tick.
462                                1
463                            } else {
464                                // Going low in ce - 1 ticks.
465                                counter_sub(ce, 1, bcd)
466                            }
467                        }
468                        Mode::SquareWave => {
469                            // Inverts in ce / 2 ticks, rounding up if out is high.
470                            (counter_sub(ce, 2, bcd) + 2) / 2 + (self.out && ce & 1 != 0) as u64
471                        }
472                    }
473                };
474                Some(v + extra)
475            }
476            CountState::Counting => {
477                if self.out || (!self.gate && mode.gate_stops_count()) {
478                    None
479                } else {
480                    Some(1)
481                }
482            }
483        }
484    }
485
486    fn latch_status(&mut self) {
487        if self.sl.is_none() {
488            self.sl = Some(
489                StatusWord(self.control.0)
490                    .with_null(self.null)
491                    .with_out(self.out)
492                    .into(),
493            );
494        }
495    }
496
497    fn latch_counter(&mut self) {
498        if self.ol.is_none() {
499            self.ol = Some(self.ce);
500        }
501    }
502}
503
504#[derive(InspectMut)]
505pub struct PitDevice {
506    // Runtime glue
507    vmtime: VmTimeAccess,
508
509    // Sub-emulators
510    #[inspect(iter_by_index)]
511    timers: [Timer; { PIT_TIMER_RANGE_END - PIT_TIMER_RANGE_START + 1 } as usize],
512
513    // Runtime book-keeping
514    dram_refresh: bool, // just jitters back and forth
515
516    // Volatile state
517    last: VmTime,
518}
519
520impl PitDevice {
521    pub fn new(interrupt: LineInterrupt, vmtime: VmTimeAccess) -> Self {
522        PitDevice {
523            // Timers 1 and 2 are enabled by default. Timer 1's output is hooked
524            // up to the interrupt line.
525            timers: [
526                Timer::new(true, Some(interrupt)),
527                Timer::new(true, None),
528                Timer::new(false, None),
529            ],
530            last: vmtime.now(),
531            vmtime,
532            dram_refresh: false,
533        }
534    }
535
536    fn evaluate(&mut self, now: VmTime) {
537        // Accumulate an integer number of ticks.
538        //
539        // N.B. if self.last were set to now, then each call to evaluate
540        // would leak a portion of a tick, causing timers to expire late.
541        let delta = now.checked_sub(self.last).unwrap_or(Duration::ZERO);
542        let ticks = delta.as_nanos() as u64 / NANOS_PER_TICK;
543        self.last = self
544            .last
545            .wrapping_add(Duration::from_nanos(ticks * NANOS_PER_TICK));
546        self.timers[0].evaluate(ticks);
547        self.timers[1].evaluate(ticks);
548        self.timers[2].evaluate(ticks);
549    }
550
551    fn arm_wakeup(&mut self) {
552        // Request another tick if needed. This is only needed for timer 0 since
553        // that's the only one wired up to an interrupt.
554        if let Some(next) = self.timers[0].state.next_wakeup() {
555            // Delay waking up if the next wakeup is too soon to avoid spinning.
556            let next = next.max(20);
557            self.vmtime.set_timeout_if_before(
558                self.last
559                    .wrapping_add(Duration::from_nanos(next * NANOS_PER_TICK)),
560            );
561        }
562    }
563}
564
565impl ChangeDeviceState for PitDevice {
566    fn start(&mut self) {}
567
568    async fn stop(&mut self) {}
569
570    async fn reset(&mut self) {
571        for timer in &mut self.timers {
572            timer.reset();
573        }
574        self.last = self.vmtime.now();
575    }
576}
577
578impl ChipsetDevice for PitDevice {
579    fn supports_poll_device(&mut self) -> Option<&mut dyn PollDevice> {
580        Some(self)
581    }
582
583    fn supports_pio(&mut self) -> Option<&mut dyn PortIoIntercept> {
584        Some(self)
585    }
586}
587
588impl PollDevice for PitDevice {
589    fn poll_device(&mut self, cx: &mut Context<'_>) {
590        if let Poll::Ready(now) = self.vmtime.poll_timeout(cx) {
591            self.evaluate(now);
592            // Re-register the poll before arming the next wakeup rather than
593            // after so that a very short wakeup will still allow this function
594            // to return, hopefully avoiding livelock.
595            assert!(self.vmtime.poll_timeout(cx).is_pending());
596            self.arm_wakeup();
597        }
598    }
599}
600
601impl PortIoIntercept for PitDevice {
602    fn io_read(&mut self, io_port: u16, data: &mut [u8]) -> IoResult {
603        if data.len() != 1 {
604            return IoResult::Err(IoError::InvalidAccessSize);
605        }
606
607        self.evaluate(self.vmtime.now());
608        match io_port {
609            PIT_TIMER_RANGE_START..=PIT_TIMER_RANGE_END => {
610                let offset = io_port - PIT_TIMER_RANGE_START;
611                data[0] = self.timers[offset as usize].read();
612            }
613            PIT_CONTROL_REGISTER => {
614                tracelimit::warn_ratelimited!("reading from write-only command register!");
615                data[0] = !0;
616            }
617            PIT_PORT61_REGISTER => {
618                data[0] = ((self.timers[2].state.out as u8) << 5)
619                    | ((self.dram_refresh as u8) << 4)
620                    | self.timers[2].state.gate as u8;
621                // Cycle the DRAM refresh bit every read. PCAT uses this to
622                // validate that DRAM is working, but it's not practical or
623                // useful to make the timing accurate.
624                self.dram_refresh = !self.dram_refresh;
625            }
626            _ => return IoResult::Err(IoError::InvalidRegister),
627        }
628
629        self.arm_wakeup();
630        IoResult::Ok
631    }
632
633    fn io_write(&mut self, io_port: u16, data: &[u8]) -> IoResult {
634        let &[b] = data else {
635            return IoResult::Err(IoError::InvalidAccessSize);
636        };
637
638        self.evaluate(self.vmtime.now());
639
640        match io_port {
641            PIT_TIMER_RANGE_START..=PIT_TIMER_RANGE_END => {
642                let offset = io_port - PIT_TIMER_RANGE_START;
643                self.timers[offset as usize].write(b);
644            }
645            PIT_CONTROL_REGISTER => {
646                let control = ControlWord(b);
647                match control.select() {
648                    i @ 0..=2 => {
649                        tracing::trace!(timer = i, ?control, "control write");
650                        self.timers[i as usize].set_control(control);
651                    }
652                    3 => {
653                        let command = ReadBackCommand(b);
654                        tracing::trace!(?command, "read back");
655                        for (i, select) in
656                            [command.counter0(), command.counter1(), command.counter2()]
657                                .into_iter()
658                                .enumerate()
659                        {
660                            if select {
661                                if command.status() {
662                                    self.timers[i].state.latch_status();
663                                }
664                                if command.count() {
665                                    self.timers[i].state.latch_counter();
666                                }
667                            }
668                        }
669                    }
670                    _ => unreachable!(),
671                }
672            }
673            PIT_PORT61_REGISTER => {
674                self.timers[2].set_gate((b & 1) != 0);
675            }
676            _ => return IoResult::Err(IoError::InvalidRegister),
677        }
678
679        self.arm_wakeup();
680        IoResult::Ok
681    }
682
683    fn get_static_regions(&mut self) -> &[(&str, RangeInclusive<u16>)] {
684        &[
685            ("main", PIT_TIMER_RANGE_START..=PIT_CONTROL_REGISTER),
686            ("port61", PIT_PORT61_REGISTER..=PIT_PORT61_REGISTER),
687        ]
688    }
689}
690
691mod save_restore {
692    use super::*;
693    use vmcore::save_restore::RestoreError;
694    use vmcore::save_restore::SaveError;
695    use vmcore::save_restore::SaveRestore;
696
697    mod state {
698        use mesh::payload::Protobuf;
699        use vmcore::save_restore::SavedStateRoot;
700        use vmcore::vmtime::VmTime;
701
702        #[derive(Protobuf)]
703        #[mesh(package = "chipset.pit")]
704        pub enum SavedCountState {
705            #[mesh(1)]
706            Inactive,
707            #[mesh(2)]
708            WaitingForGate,
709            #[mesh(3)]
710            Reloading,
711            #[mesh(4)]
712            Active,
713            #[mesh(5)]
714            Counting,
715        }
716
717        #[derive(Protobuf)]
718        #[mesh(package = "chipset.pit")]
719        pub struct SavedTimerState {
720            #[mesh(1)]
721            pub ce: u16,
722            #[mesh(2)]
723            pub cr: u16,
724            #[mesh(3)]
725            pub ol: Option<u16>,
726            #[mesh(4)]
727            pub sl: Option<u8>,
728            #[mesh(5)]
729            pub state: SavedCountState,
730            #[mesh(6)]
731            pub control: u8,
732            #[mesh(7)]
733            pub out: bool,
734            #[mesh(8)]
735            pub gate: bool,
736            #[mesh(9)]
737            pub null: bool,
738            #[mesh(10)]
739            pub read_high: bool,
740            #[mesh(11)]
741            pub cr_low: Option<u8>,
742        }
743
744        #[derive(Protobuf, SavedStateRoot)]
745        #[mesh(package = "chipset.pit")]
746        pub struct SavedState {
747            #[mesh(1)]
748            pub timers: [SavedTimerState; 3],
749            #[mesh(2)]
750            pub last: VmTime,
751        }
752    }
753
754    #[derive(Debug, Error)]
755    enum PitDeviceRestoreError {
756        #[error("last tick time is after current time")]
757        InvalidLastTick,
758    }
759
760    impl SaveRestore for PitDevice {
761        type SavedState = state::SavedState;
762
763        fn save(&mut self) -> Result<Self::SavedState, SaveError> {
764            let Self {
765                vmtime: _,
766                timers,
767                dram_refresh: _,
768                last,
769            } = self;
770
771            Ok(state::SavedState {
772                timers: [&timers[0].state, &timers[1].state, &timers[2].state].map(|timer| {
773                    let &TimerState {
774                        ce,
775                        cr,
776                        ol,
777                        sl,
778                        state,
779                        control,
780                        out,
781                        gate,
782                        null,
783                        read_high,
784                        cr_low,
785                    } = timer;
786
787                    state::SavedTimerState {
788                        ce,
789                        cr,
790                        ol,
791                        sl,
792                        state: match state {
793                            CountState::Inactive => state::SavedCountState::Inactive,
794                            CountState::WaitingForGate => state::SavedCountState::WaitingForGate,
795                            CountState::Reloading => state::SavedCountState::Reloading,
796                            CountState::Active => state::SavedCountState::Active,
797                            CountState::Counting => state::SavedCountState::Counting,
798                        },
799                        control: control.into(),
800                        out,
801                        gate,
802                        null,
803                        read_high,
804                        cr_low,
805                    }
806                }),
807
808                last: *last,
809            })
810        }
811
812        fn restore(&mut self, state: Self::SavedState) -> Result<(), RestoreError> {
813            let state::SavedState { timers, last } = state;
814
815            for (timer, state) in self.timers.iter_mut().zip(timers) {
816                let state::SavedTimerState {
817                    ce,
818                    cr,
819                    ol,
820                    sl,
821                    state,
822                    control,
823                    out,
824                    gate,
825                    null,
826                    read_high,
827                    cr_low,
828                } = state;
829
830                timer.state = TimerState {
831                    ce,
832                    cr,
833                    ol,
834                    sl,
835                    state: match state {
836                        state::SavedCountState::Inactive => CountState::Inactive,
837                        state::SavedCountState::WaitingForGate => CountState::WaitingForGate,
838                        state::SavedCountState::Reloading => CountState::Reloading,
839                        state::SavedCountState::Active => CountState::Active,
840                        state::SavedCountState::Counting => CountState::Counting,
841                    },
842                    out,
843                    control: ControlWord::from(control), // no unused bits
844                    gate,
845                    null,
846                    read_high,
847                    cr_low,
848                };
849
850                timer.sync_interrupt();
851            }
852
853            self.last = last;
854            if last.is_after(self.vmtime.now()) {
855                return Err(RestoreError::InvalidSavedState(
856                    PitDeviceRestoreError::InvalidLastTick.into(),
857                ));
858            }
859
860            Ok(())
861        }
862    }
863}
864
#[cfg(test)]
mod tests {
    use super::ControlWord;
    use super::Mode;
    use super::RwMode;
    use super::Timer;
    use super::from_bcd;
    use super::to_bcd;

    /// Every four-digit value must survive a binary -> BCD -> binary round trip.
    #[test]
    fn test_bcd_comp() {
        (0..=9999u16).for_each(|value| {
            let packed = to_bcd(value);
            assert_eq!(from_bcd(packed), value, "{value} {packed}");
        });
    }

    /// Programs `timer` for `mode` with low-then-high access and writes the
    /// initial count `cr`, BCD-encoding it first when `bcd` is set.
    fn set_timer(timer: &mut Timer, mode: Mode, cr: u16, bcd: bool) {
        let control = ControlWord::new()
            .with_mode(mode as u8)
            .with_rw(RwMode::LOW_HIGH.0)
            .with_bcd(bcd);
        timer.set_control(control);
        let count = if bcd { to_bcd(cr) } else { cr };
        let [low, high] = count.to_le_bytes();
        timer.write(low);
        timer.write(high);
    }

    /// Drives the gate low and then high, producing a rising edge.
    fn pulse_gate(timer: &mut Timer) {
        timer.set_gate(false);
        timer.set_gate(true);
    }

    /// Asserts that `out` starts at `initial_out`, that the predicted wakeup
    /// is `expected_next`, and that `out` flips exactly on that tick while
    /// the prediction counts down by one on each earlier tick.
    fn check_invert(timer: &mut Timer, initial_out: bool, expected_next: u64) {
        let mode = Mode::from(timer.state.control.mode());
        assert_eq!(timer.state.out, initial_out, "{mode:?}");
        let n = timer.state.next_wakeup().unwrap();
        assert_eq!(n, expected_next, "{mode:?}");
        for i in 0..n - 1 {
            timer.evaluate(1);
            let remaining = timer.state.next_wakeup().unwrap();
            assert_eq!(remaining, n - i - 1, "{mode:?}, {i}");
            assert_eq!(timer.state.out, initial_out, "{mode:?}, {i}");
        }
        timer.evaluate(1);
        assert_eq!(timer.state.out, !initial_out, "{mode:?}, {n}");
    }

    /// Asserts that no wakeup is pending and that `out` stays put for a full
    /// 16-bit period.
    fn check_done(timer: &mut Timer) {
        assert!(timer.state.next_wakeup().is_none());
        let settled = timer.state.out;
        for _ in 0..0x10000 {
            timer.evaluate(1);
            assert_eq!(timer.state.out, settled);
        }
    }

    /// Walks a single counter through every mode and checks when `out` flips.
    fn test_output(bcd: bool) {
        let mut timer = Timer::new(true, None);
        // An initial count of 0 means 10^4 in BCD and 2^16 in binary.
        let max = if bcd { 10000 } else { 0x10000 };

        // Mode 0: a single low-to-high transition, then idle.
        set_timer(&mut timer, Mode::TerminalCount, 0, bcd);
        check_invert(&mut timer, false, max + 1);
        check_done(&mut timer);

        // Mode 1: idle until a gate rising edge, then behaves like mode 0.
        set_timer(&mut timer, Mode::HardwareOneShot, 0, bcd);
        check_done(&mut timer);
        pulse_gate(&mut timer);
        check_invert(&mut timer, false, max + 1);
        check_done(&mut timer);

        // Mode 2: periodic one-tick low pulse.
        set_timer(&mut timer, Mode::RateGenerator, 0, bcd);
        check_invert(&mut timer, true, max);
        check_invert(&mut timer, false, 1);
        check_invert(&mut timer, true, max - 1);

        // Mode 3: even count splits evenly between the two halves.
        set_timer(&mut timer, Mode::SquareWave, 0, bcd);
        check_invert(&mut timer, true, max / 2 + 1);
        check_invert(&mut timer, false, max / 2);
        check_invert(&mut timer, true, max / 2);

        // Mode 3: an odd count makes the high half one tick longer.
        set_timer(&mut timer, Mode::SquareWave, 1001, bcd);
        check_invert(&mut timer, true, 502);
        check_invert(&mut timer, false, 500);
        check_invert(&mut timer, true, 501);

        // Mode 4: a single one-tick low strobe, then idle.
        set_timer(&mut timer, Mode::SoftwareStrobe, 0, bcd);
        check_invert(&mut timer, true, max + 1);
        check_invert(&mut timer, false, 1);
        check_done(&mut timer);

        // Mode 5: like mode 4, but started by a gate rising edge.
        set_timer(&mut timer, Mode::HardwareStrobe, 0, bcd);
        check_done(&mut timer);
        pulse_gate(&mut timer);
        check_invert(&mut timer, true, max + 1);
        check_invert(&mut timer, false, 1);
        check_done(&mut timer);
    }

    #[test]
    fn test_binary() {
        test_output(false);
    }

    #[test]
    fn test_bcd() {
        test_output(true);
    }
}