mesh_protobuf/protofile/
writer.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Code to write .proto files from descriptors.
5
6#![cfg(feature = "std")]
7
8use super::FieldDescriptor;
9use super::FieldType;
10use super::MessageDescriptor;
11use super::OneofDescriptor;
12use super::TopLevelDescriptor;
13use crate::protofile::FieldKind;
14use crate::protofile::MessageDescription;
15use crate::protofile::SequenceType;
16use alloc::borrow::Cow;
17use alloc::boxed::Box;
18use alloc::format;
19use alloc::string::String;
20use alloc::vec::Vec;
21use heck::ToUpperCamelCase;
22use std::collections::HashSet;
23use std::io;
24use std::io::Write;
25use std::path::Path;
26use std::path::PathBuf;
27
28/// A type used to write protobuf descriptors to `.proto`-format files.
29pub struct DescriptorWriter<'a> {
30    descriptors: Vec<&'a TopLevelDescriptor<'a>>,
31    file_heading: &'a str,
32}
33
34impl<'a> DescriptorWriter<'a> {
35    /// Returns a new object for writing the `.proto` files described by
36    /// `descriptors`.
37    ///
38    /// `descriptors` only needs to contain the roots of the protobuf
39    /// message graph; any other message types referred to by the types in
40    /// `descriptors` will be found and written to `.proto` files as well.
41    pub fn new(descriptors: impl IntoIterator<Item = &'a MessageDescription<'a>>) -> Self {
42        // First find all the descriptors starting with the provided roots.
43        let mut descriptors = referenced_descriptors(descriptors);
44
45        // Sort the descriptors to get a consistent order from run to run and build to build.
46        descriptors.sort_by_key(|desc| (desc.package, desc.message.name));
47        // Deduplicate by package and name. TODO: ensure duplicates match.
48        descriptors.dedup_by_key(|desc| (desc.package, desc.message.name));
49
50        Self {
51            descriptors,
52            file_heading: "",
53        }
54    }
55
56    /// Sets the file heading written to each file.
57    pub fn file_heading(&mut self, file_heading: &'a str) -> &mut Self {
58        self.file_heading = file_heading;
59        self
60    }
61
62    /// Writes the `.proto` files to writers returned by `f`.
63    pub fn write<W: Write>(&self, mut f: impl FnMut(&str) -> io::Result<W>) -> io::Result<()> {
64        let mut descriptors = self.descriptors.iter().copied().peekable();
65        while let Some(&first) = descriptors.peek() {
66            let file = f(&package_proto_file(first.package))?;
67            let mut writer = PackageWriter::new(first.package, Box::new(file));
68            write!(
69                writer,
70                "{file_heading}// Autogenerated, do not edit.\n\nsyntax = \"proto3\";\npackage {proto_package};\n",
71                file_heading = self.file_heading,
72                proto_package = first.package,
73            )?;
74            writer.nl_next();
75
76            // Collect imports.
77            let mut imports = Vec::new();
78            let n = {
79                let mut descriptors = descriptors.clone();
80                let mut n = 0;
81                while descriptors
82                    .peek()
83                    .is_some_and(|d| d.package == first.package)
84                {
85                    let desc = descriptors.next().unwrap();
86                    desc.message.collect_imports(&mut writer, &mut imports)?;
87                    n += 1;
88                }
89                n
90            };
91
92            imports.sort();
93            imports.dedup();
94            for import in imports {
95                writeln!(writer, "import \"{import}\";")?;
96            }
97
98            writer.nl_next();
99
100            // Collect messages.
101            for desc in (&mut descriptors).take(n) {
102                desc.message.fmt(&mut writer)?;
103            }
104        }
105        Ok(())
106    }
107
108    /// Writes the `.proto` files to disk, rooted at `path`.
109    ///
110    /// Returns the paths of written files.
111    pub fn write_to_path(&self, path: impl AsRef<Path>) -> io::Result<Vec<PathBuf>> {
112        let mut paths = Vec::new();
113        self.write(|name| {
114            let path = path.as_ref().join(name);
115            if let Some(parent) = path.parent() {
116                fs_err::create_dir_all(parent)?;
117            }
118            let file = fs_err::File::create(&path)?;
119            paths.push(path);
120            Ok(file)
121        })?;
122        Ok(paths)
123    }
124}
125
/// A `Write` adapter that applies indentation and deferred blank lines
/// while the contents of a single package's `.proto` file are written.
struct PackageWriter<'a, 'w> {
    /// The underlying destination.
    writer: Box<dyn 'w + Write>,
    /// When set, a newline is emitted before the next non-newline output.
    needs_nl: bool,
    /// When set, the next non-newline output is prefixed with `indent`.
    needs_indent: bool,
    /// The current indentation prefix (two spaces per level).
    indent: String,
    /// The package whose file is being written.
    package: &'a str,
}

impl<'a, 'w> PackageWriter<'a, 'w> {
    /// Creates a writer for `package` with no indentation and no pending
    /// newline.
    fn new(package: &'a str, writer: Box<dyn 'w + Write>) -> Self {
        Self {
            writer,
            needs_nl: false,
            needs_indent: false,
            indent: String::new(),
            package,
        }
    }

    /// Increases the indentation by one level (two spaces).
    fn indent(&mut self) {
        self.indent.push_str("  ");
    }

    /// Decreases the indentation by one level and cancels any pending blank
    /// line, so that whatever is written next (a closing brace, in
    /// practice) directly follows the previous line.
    ///
    /// Calling this with no indentation in effect leaves the indentation
    /// empty rather than underflowing.
    fn unindent(&mut self) {
        let new_len = self.indent.len().saturating_sub(2);
        self.indent.truncate(new_len);
        self.needs_nl = false;
    }

    /// Requests a blank line before the next non-newline output.
    fn nl_next(&mut self) {
        self.needs_nl = true;
    }
}

impl Write for PackageWriter<'_, '_> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // A zero-length write must not consume the pending newline/indent.
        if buf.is_empty() {
            return Ok(0);
        }
        if buf[0] == b'\n' {
            // Emit only the newline and report a single consumed byte; the
            // caller (e.g. `write_all`) presents the remainder again. This
            // keeps blank lines free of indentation, and an explicit
            // newline also satisfies any pending `nl_next` request.
            self.writer.write_all(b"\n")?;
            self.needs_nl = false;
            self.needs_indent = true;
            return Ok(1);
        }
        if self.needs_nl {
            self.writer.write_all(b"\n")?;
            self.needs_nl = false;
        }
        if self.needs_indent {
            self.writer.write_all(self.indent.as_bytes())?;
            self.needs_indent = false;
        }
        self.writer.write_all(buf)?;
        // Output that ends a line means the next output starts a new one.
        if buf.last() == Some(&b'\n') {
            self.needs_indent = true;
        }
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        self.writer.flush()
    }
}
186
187/// Computes the referenced descriptors from a set of descriptors.
188fn referenced_descriptors<'a>(
189    descriptors: impl IntoIterator<Item = &'a MessageDescription<'a>>,
190) -> Vec<&'a TopLevelDescriptor<'a>> {
191    let mut descriptors =
192        Vec::from_iter(descriptors.into_iter().copied().filter_map(|d| match d {
193            MessageDescription::Internal(tld) => Some(tld),
194            MessageDescription::External { .. } => None,
195        }));
196    let mut inserted = HashSet::from_iter(descriptors.iter().copied());
197
198    fn process_field_type<'a>(
199        field_type: &FieldType<'a>,
200        descriptors: &mut Vec<&'a TopLevelDescriptor<'a>>,
201        inserted: &mut HashSet<&'a TopLevelDescriptor<'a>>,
202    ) {
203        match field_type.kind {
204            FieldKind::Message(tld) => {
205                if let MessageDescription::Internal(tld) = tld() {
206                    if inserted.insert(tld) {
207                        descriptors.push(tld);
208                    }
209                }
210            }
211            FieldKind::Tuple(tys) => {
212                for ty in tys {
213                    process_field_type(ty, descriptors, inserted);
214                }
215            }
216            FieldKind::KeyValue(tys) => {
217                for ty in tys {
218                    process_field_type(ty, descriptors, inserted);
219                }
220            }
221            FieldKind::Builtin(_) | FieldKind::Local(_) | FieldKind::External { .. } => {}
222        }
223    }
224
225    fn process_message<'a>(
226        message: &MessageDescriptor<'a>,
227        descriptors: &mut Vec<&'a TopLevelDescriptor<'a>>,
228        inserted: &mut HashSet<&'a TopLevelDescriptor<'a>>,
229    ) {
230        for field in message
231            .fields
232            .iter()
233            .chain(message.oneofs.iter().flat_map(|oneof| oneof.variants))
234        {
235            process_field_type(&field.field_type, descriptors, inserted);
236        }
237        for inner in message.messages {
238            process_message(inner, descriptors, inserted);
239        }
240    }
241
242    let mut i = 0;
243    while let Some(&tld) = descriptors.get(i) {
244        process_message(tld.message, &mut descriptors, &mut inserted);
245        i += 1;
246    }
247
248    descriptors
249}
250
/// Returns the name of the `.proto` file that holds `package`: the package
/// name followed by `.proto`.
fn package_proto_file(package: &str) -> String {
    const EXTENSION: &str = ".proto";
    let mut file_name = String::with_capacity(package.len() + EXTENSION.len());
    file_name.push_str(package);
    file_name.push_str(EXTENSION);
    file_name
}
254
255impl<'a> MessageDescriptor<'a> {
256    fn collect_imports(
257        &self,
258        w: &mut PackageWriter<'a, '_>,
259        imports: &mut Vec<Cow<'a, str>>,
260    ) -> io::Result<()> {
261        for message in self.messages {
262            message.collect_imports(w, imports)?;
263        }
264        for oneof in self.oneofs {
265            for field in oneof.variants {
266                field.field_type.collect_imports(w, imports)?;
267            }
268        }
269        for field in self.fields {
270            field.field_type.collect_imports(w, imports)?;
271        }
272        Ok(())
273    }
274
275    fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
276        if !self.comment.is_empty() {
277            for line in self.comment.split('\n') {
278                writeln!(w, "//{line}")?;
279            }
280        }
281        writeln!(w, "message {} {{", self.name)?;
282        w.indent();
283        for message in self.messages {
284            message.fmt(w)?;
285        }
286        for oneof in self.oneofs {
287            oneof.fmt_nested_messages(w)?;
288        }
289        for field in self.fields {
290            field.fmt_nested_message(w)?;
291        }
292        for oneof in self.oneofs {
293            oneof.fmt(w)?;
294        }
295        for field in self.fields {
296            field.fmt(w)?;
297        }
298        w.unindent();
299        writeln!(w, "}}")?;
300        w.nl_next();
301        Ok(())
302    }
303}
304
305impl<'a> FieldType<'a> {
306    fn collect_imports(
307        &self,
308        w: &mut PackageWriter<'a, '_>,
309        imports: &mut Vec<Cow<'a, str>>,
310    ) -> io::Result<()> {
311        match self.kind {
312            FieldKind::Builtin(_) | FieldKind::Local(_) => {}
313            FieldKind::External { import_path, .. } => {
314                imports.push(import_path.into());
315            }
316            FieldKind::Message(f) => match f() {
317                MessageDescription::Internal(tld) => {
318                    if w.package != tld.package {
319                        imports.push(package_proto_file(tld.package).into());
320                    }
321                }
322                MessageDescription::External {
323                    name: _,
324                    import_path,
325                } => {
326                    imports.push(import_path.into());
327                }
328            },
329            FieldKind::Tuple(field_types) => {
330                for field_type in field_types {
331                    field_type.collect_imports(w, imports)?;
332                }
333            }
334            FieldKind::KeyValue(field_types) => {
335                for field_type in field_types {
336                    field_type.collect_imports(w, imports)?;
337                }
338            }
339        }
340        Ok(())
341    }
342}
343
344impl FieldDescriptor<'_> {
345    fn fmt_nested_message(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
346        match self.field_type.kind {
347            FieldKind::Tuple(field_types) => {
348                self.fmt_tuple_message(
349                    w,
350                    field_types,
351                    (1..=field_types.len()).map(|i| format!("field{i}")),
352                )?;
353            }
354            FieldKind::KeyValue(field_types) => {
355                self.fmt_tuple_message(w, field_types, ["key", "value"])?;
356            }
357            FieldKind::Builtin(_)
358            | FieldKind::Local(_)
359            | FieldKind::External { .. }
360            | FieldKind::Message(_) => {}
361        }
362        Ok(())
363    }
364
365    fn fmt_tuple_message(
366        &self,
367        w: &mut PackageWriter<'_, '_>,
368        field_types: &[FieldType<'_>],
369        names: impl IntoIterator<Item = impl AsRef<str>>,
370    ) -> Result<(), io::Error> {
371        let fields = field_types
372            .iter()
373            .enumerate()
374            .zip(names)
375            .map(|((i, field_type), name)| (field_type, i as u32 + 1, name))
376            .collect::<Vec<_>>();
377        let fields = fields
378            .iter()
379            .map(|&(ty, number, ref name)| FieldDescriptor::new("", *ty, name.as_ref(), number))
380            .collect::<Vec<_>>();
381        MessageDescriptor::new(&self.name.to_upper_camel_case(), "", &fields, &[], &[]).fmt(w)?;
382        Ok(())
383    }
384
385    fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
386        if !self.comment.is_empty() {
387            for line in self.comment.split('\n') {
388                writeln!(w, "//{}", line.trim_end())?;
389            }
390        }
391
392        let is_message = match self.field_type.kind {
393            FieldKind::Builtin(_) => false,
394            FieldKind::Local(_)
395            | FieldKind::External { .. }
396            | FieldKind::Message(_)
397            | FieldKind::Tuple(_)
398            | FieldKind::KeyValue { .. } => true,
399        };
400
401        match self.field_type.sequence_type {
402            // Message fields are implicitly optional.
403            Some(SequenceType::Optional) if !is_message => write!(w, "optional ")?,
404            None | Some(SequenceType::Optional) => {}
405            Some(SequenceType::Repeated) => write!(w, "repeated ")?,
406            Some(SequenceType::Map(key)) => write!(w, "map<{key}, ")?,
407        };
408        match self.field_type.kind {
409            FieldKind::Builtin(name) | FieldKind::Local(name) => write!(w, "{}", name)?,
410            FieldKind::External { name, .. } => write!(w, ".{}", name)?,
411            FieldKind::Message(tld) => match tld() {
412                MessageDescription::Internal(tld) => {
413                    write!(w, ".{}.{}", tld.package, tld.message.name)?;
414                }
415                MessageDescription::External {
416                    name,
417                    import_path: _,
418                } => {
419                    write!(w, ".{name}")?;
420                }
421            },
422            FieldKind::Tuple(_) | FieldKind::KeyValue(_) => {
423                write!(w, "{}", self.name.to_upper_camel_case())?
424            }
425        }
426        if matches!(self.field_type.sequence_type, Some(SequenceType::Map(_))) {
427            write!(w, ">")?;
428        }
429        write!(w, " {} = {};", self.name, self.field_number)?;
430        if !self.field_type.annotation.is_empty() {
431            write!(w, " // {}", self.field_type.annotation)?;
432        }
433        writeln!(w)
434    }
435}
436
437impl OneofDescriptor<'_> {
438    fn fmt_nested_messages(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
439        for variant in self.variants {
440            if variant.field_type.is_sequence() {
441                FieldDescriptor {
442                    field_type: FieldType::tuple(&[variant.field_type]),
443                    ..*variant
444                }
445                .fmt_nested_message(w)?;
446            } else {
447                variant.fmt_nested_message(w)?;
448            }
449        }
450        Ok(())
451    }
452
453    fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
454        writeln!(w, "oneof {} {{", self.name)?;
455        w.indent();
456        for variant in self.variants {
457            if variant.field_type.is_sequence() {
458                FieldDescriptor {
459                    field_type: FieldType::tuple(&[variant.field_type]),
460                    ..*variant
461                }
462                .fmt(w)?;
463            } else {
464                variant.fmt(w)?;
465            }
466        }
467        w.unindent();
468        writeln!(w, "}}")?;
469        w.nl_next();
470        Ok(())
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::DescriptorWriter;
    use crate::Protobuf;
    use crate::protofile::message_description;
    use alloc::string::String;
    use alloc::vec::Vec;
    use core::cell::RefCell;
    use expect_test::expect;
    use std::collections::HashMap;
    use std::io::Write;

    // Fixture message. Its doc comments are part of the test: they are
    // expected to appear as `//` comments in the generated output, so they
    // must not be edited without updating the expected text in `test`.
    /// Comment on this guy.
    #[derive(Protobuf)]
    #[mesh(package = "test")]
    struct Foo {
        /// Doc comment
        #[mesh(1)]
        x: u32,
        #[mesh(2)]
        t: (u32,),
        #[mesh(3)]
        t2: (),
        #[mesh(4)]
        bar: (u32, ()),
        /// Another doc comment
        /// (multi-line)
        #[mesh(5)]
        y: Vec<u32>,
        /**
        multi
        line
        */
        #[mesh(6)]
        b: (),
        #[mesh(7)]
        repeated_self: Vec<Foo>,
        #[mesh(8)]
        e: Bar,
        #[mesh(9)]
        nested_repeat: Vec<Vec<u32>>,
        #[mesh(10)]
        proto_map: HashMap<String, (u32,)>,
        #[mesh(11)]
        vec_map: HashMap<u32, Vec<u32>>,
        #[mesh(12)]
        bad_array: [u32; 3],
        #[mesh(13)]
        wrapped_array: [String; 3],
    }

    // Fixture enum, reached only through `Foo::e`. The expected output
    // renders it as a message containing a `oneof`, with helper messages
    // for the struct-like and sequence-typed variants.
    #[derive(Protobuf)]
    #[mesh(package = "test")]
    enum Bar {
        #[mesh(1)]
        This,
        #[mesh(2)]
        This2(),
        #[mesh(3, transparent)]
        That(u32),
        #[mesh(4)]
        Other {
            #[mesh(1)]
            hi: bool,
            #[mesh(2)]
            hello: u32,
        },
        #[mesh(5, transparent)]
        Repeat(Vec<u32>),
        #[mesh(6, transparent)]
        DoubleRepeat(Vec<Vec<u32>>),
    }

    /// Wraps a writer in a `RefCell` so that a shared reference to it can be
    /// used as a `Write`. This lets the test hand `&writer` to
    /// `DescriptorWriter::write` and still take the buffer back afterwards.
    struct BorrowedWriter<T>(RefCell<T>);

    impl<T: Write> Write for &BorrowedWriter<T> {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.borrow_mut().write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.0.borrow_mut().flush()
        }
    }

    /// Writes the descriptors reachable from `Foo` into a single in-memory
    /// buffer and compares the result against the expected `.proto` text.
    #[test]
    fn test() {
        let writer = BorrowedWriter(RefCell::new(Vec::<u8>::new()));
        // Only `Foo` is passed as a root; `Bar` must be discovered through
        // the field that references it. The file name is ignored, so every
        // package would land in the same buffer.
        DescriptorWriter::new(&[message_description::<Foo>()])
            .write(|_name| Ok(&writer))
            .unwrap();
        let s = String::from_utf8(writer.0.into_inner()).unwrap();
        let expected = expect!([r#"
            // Autogenerated, do not edit.

            syntax = "proto3";
            package test;

            import "google/protobuf/empty.proto";
            import "google/protobuf/wrappers.proto";

            message Bar {
              message Other {
                bool hi = 1;
                uint32 hello = 2;
              }

              message Repeat {
                repeated uint32 field1 = 1;
              }

              message DoubleRepeat {
                message Field1 {
                  repeated uint32 field1 = 1;
                }

                repeated Field1 field1 = 1;
              }

              oneof variant {
                .google.protobuf.Empty this = 1;
                .google.protobuf.Empty this2 = 2;
                uint32 that = 3;
                Other other = 4;
                Repeat repeat = 5;
                DoubleRepeat double_repeat = 6;
              }
            }

            // Comment on this guy.
            message Foo {
              message Bar {
                uint32 field1 = 1;
                .google.protobuf.Empty field2 = 2;
              }

              message NestedRepeat {
                repeated uint32 field1 = 1;
              }

              message VecMap {
                uint32 key = 1;
                repeated uint32 value = 2;
              }

              message WrappedArray {
                repeated string field1 = 1;
              }

              // Doc comment
              uint32 x = 1;
              .google.protobuf.UInt32Value t = 2;
              .google.protobuf.Empty t2 = 3;
              Bar bar = 4;
              // Another doc comment
              // (multi-line)
              repeated uint32 y = 5;
              //
              //        multi
              //        line
              //
              .google.protobuf.Empty b = 6;
              repeated .test.Foo repeated_self = 7;
              .test.Bar e = 8;
              repeated NestedRepeat nested_repeat = 9;
              map<string, .google.protobuf.UInt32Value> proto_map = 10;
              repeated VecMap vec_map = 11;
              repeated uint32 bad_array = 12; // packed repr only
              WrappedArray wrapped_array = 13;
            }
        "#]);
        expected.assert_eq(&s);
    }
}