flowey_core/
patch.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::node::FlowNodeBase;
5use crate::node::NodeHandle;
6use crate::node::WriteVar;
7use std::collections::BTreeMap;
8use std::sync::OnceLock;
9
10pub type PatchFn = fn(&mut PatchManager<'_>);
11
12// A patchfn that does nothing. Can be useful when writing logic that
13// conditionally applies patches.
14pub fn noop_patchfn(_: &mut PatchManager<'_>) {}
15
16enum PatchEvent {
17    Swap {
18        from_old_node: NodeHandle,
19        with_new_node: NodeHandle,
20    },
21    InjectSideEffect {
22        from_old_node: NodeHandle,
23        with_new_node: NodeHandle,
24        side_effect_var: String,
25        req: Box<[u8]>,
26    },
27}
28
29trait PatchManagerBackend {
30    fn new_side_effect_var(&mut self) -> String;
31    fn on_patch_event(&mut self, event: PatchEvent);
32}
33
34/// Passed to patch functions
35pub struct PatchManager<'a> {
36    backend: &'a mut dyn PatchManagerBackend,
37}
38
39impl PatchManager<'_> {
40    pub fn hook<N: FlowNodeBase>(&mut self) -> PatchHook<'_, N> {
41        PatchHook {
42            backend: self.backend,
43            _kind: std::marker::PhantomData,
44        }
45    }
46}
47
48/// Patch operations in the context of a particular Node.
49pub struct PatchHook<'a, N: FlowNodeBase> {
50    backend: &'a mut dyn PatchManagerBackend,
51    _kind: std::marker::PhantomData<N>,
52}
53
54impl<N> PatchHook<'_, N>
55where
56    N: FlowNodeBase + 'static,
57{
58    /// Swap out the target Node's implementation with a different
59    /// implementation.
60    pub fn swap_with<M>(&mut self) -> &mut Self
61    where
62        M: 'static,
63        // use the type system to enforce that patch nodes have an identical
64        // request type
65        M: FlowNodeBase<Request = N::Request>,
66    {
67        self.backend.on_patch_event(PatchEvent::Swap {
68            from_old_node: NodeHandle::from_type::<N>(),
69            with_new_node: NodeHandle::from_type::<M>(),
70        });
71        self
72    }
73
74    /// Inject a side-effect dependency, which runs before any other steps in
75    /// the Node.
76    pub fn inject_side_effect<T, M>(
77        &mut self,
78        f: impl FnOnce(WriteVar<T>) -> M::Request,
79    ) -> &mut Self
80    where
81        T: serde::Serialize + serde::de::DeserializeOwned,
82        M: 'static,
83        M: FlowNodeBase,
84    {
85        let backing_var = self.backend.new_side_effect_var();
86        let req = f(crate::node::thin_air_write_runtime_var(backing_var.clone()));
87
88        self.backend.on_patch_event(PatchEvent::InjectSideEffect {
89            from_old_node: NodeHandle::from_type::<N>(),
90            with_new_node: NodeHandle::from_type::<M>(),
91            side_effect_var: backing_var,
92            req: serde_json::to_vec(&req).map(Into::into).unwrap(),
93        });
94        self
95    }
96}
97
98pub fn patchfn_by_modpath() -> &'static BTreeMap<String, PatchFn> {
99    static MODPATH_LOOKUP: OnceLock<BTreeMap<String, PatchFn>> = OnceLock::new();
100
101    MODPATH_LOOKUP.get_or_init(|| {
102        let mut lookup = BTreeMap::new();
103        for (f, module_path, fn_name) in private::PATCH_FNS {
104            let existing = lookup.insert(format!("{}::{}", module_path, fn_name), *f);
105            // Rust would've errored out at module defn time with a duplicate fn name error
106            assert!(existing.is_none());
107        }
108        lookup
109    })
110}
111
112/// [`PatchResolver`]
113#[derive(Debug, Clone)]
114pub struct ResolvedPatches {
115    pub swap: BTreeMap<NodeHandle, NodeHandle>,
116    pub inject_side_effect: BTreeMap<NodeHandle, Vec<(NodeHandle, String, Box<[u8]>)>>,
117}
118
119impl ResolvedPatches {
120    pub fn build() -> PatchResolver {
121        PatchResolver {
122            side_effect_var_idx: 0,
123            swap: BTreeMap::default(),
124            inject_side_effect: BTreeMap::new(),
125        }
126    }
127}
128
129/// Helper method to resolve multiple patches into a single [`ResolvedPatches`]
130#[derive(Debug)]
131pub struct PatchResolver {
132    side_effect_var_idx: usize,
133    swap: BTreeMap<NodeHandle, NodeHandle>,
134    inject_side_effect: BTreeMap<NodeHandle, Vec<(NodeHandle, String, Box<[u8]>)>>,
135}
136
137impl PatchResolver {
138    pub fn apply_patchfn(&mut self, patchfn: PatchFn) {
139        patchfn(&mut PatchManager { backend: self });
140    }
141
142    pub fn finalize(self) -> ResolvedPatches {
143        let Self {
144            swap,
145            mut inject_side_effect,
146            side_effect_var_idx: _,
147        } = self;
148
149        // take into account the interaction between swaps and injected effects
150        for (from, to) in &swap {
151            let injected = inject_side_effect.remove(from);
152            if let Some(injected) = injected {
153                inject_side_effect.insert(*to, injected);
154            }
155        }
156
157        ResolvedPatches {
158            swap,
159            inject_side_effect,
160        }
161    }
162}
163
164impl PatchManagerBackend for PatchResolver {
165    fn new_side_effect_var(&mut self) -> String {
166        self.side_effect_var_idx += 1;
167        format!("patch_side_effect:{}", self.side_effect_var_idx)
168    }
169
170    fn on_patch_event(&mut self, event: PatchEvent) {
171        match event {
172            PatchEvent::Swap {
173                from_old_node,
174                with_new_node,
175            } => {
176                let existing = self.swap.insert(from_old_node, with_new_node);
177                // FUTURE: add some better error reporting / logging to
178                // allow doing this, albeit with a warning
179                assert!(
180                    existing.is_none(),
181                    "cannot double-patch the same node combo"
182                );
183            }
184            PatchEvent::InjectSideEffect {
185                from_old_node,
186                with_new_node,
187                side_effect_var,
188                req,
189            } => {
190                self.inject_side_effect
191                    .entry(from_old_node)
192                    .or_default()
193                    .push((with_new_node, side_effect_var, req));
194            }
195        }
196    }
197}
198
199#[doc(hidden)]
200pub mod private {
201    use super::PatchFn;
202    pub use linkme;
203
204    #[linkme::distributed_slice]
205    pub static PATCH_FNS: [(PatchFn, &'static str, &'static str)] = [..];
206
207    /// Register a patch function which can be used when emitting flows.
208    ///
209    /// The function must conform to the signature of [`PatchFn`]
210    #[macro_export]
211    macro_rules! register_patch {
212        ($patchfn:ident) => {
213            const _: () = {
214                use $crate::node::private::linkme;
215
216                #[linkme::distributed_slice($crate::patch::private::PATCH_FNS)]
217                #[linkme(crate = linkme)]
218                pub static PATCH_FNS: ($crate::patch::PatchFn, &'static str, &'static str) =
219                    ($patchfn, module_path!(), stringify!($patchfn));
220            };
221        };
222    }
223}