mesh_protobuf/protofile/
writer.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Code to write .proto files from descriptors.
5
6#![cfg(feature = "std")]
7
8use super::FieldDescriptor;
9use super::FieldType;
10use super::MessageDescriptor;
11use super::OneofDescriptor;
12use super::TopLevelDescriptor;
13use crate::protofile::FieldKind;
14use crate::protofile::MessageDescription;
15use crate::protofile::SequenceType;
16use alloc::borrow::Cow;
17use alloc::boxed::Box;
18use alloc::collections::VecDeque;
19use alloc::format;
20use alloc::string::String;
21use alloc::vec::Vec;
22use heck::ToUpperCamelCase;
23use std::collections::HashMap;
24use std::io;
25use std::io::Write;
26use std::path::Path;
27use std::path::PathBuf;
28
29/// A type used to write protobuf descriptors to `.proto`-format files.
30pub struct DescriptorWriter<'a> {
31    descriptors: Vec<&'a TopLevelDescriptor<'a>>,
32    file_heading: &'a str,
33}
34
35impl<'a> DescriptorWriter<'a> {
36    /// Returns a new object for writing the `.proto` files described by
37    /// `descriptors`.
38    ///
39    /// `descriptors` only needs to contain the roots of the protobuf
40    /// message graph; any other message types referred to by the types in
41    /// `descriptors` will be found and written to `.proto` files as well.
42    pub fn new(descriptors: impl IntoIterator<Item = &'a MessageDescription<'a>>) -> Self {
43        // First find all the descriptors starting with the provided roots.
44        let mut descriptors = referenced_descriptors(descriptors);
45
46        // Sort the descriptors to get a consistent order from run to run and build to build.
47        descriptors.sort_by_key(|desc| (desc.package, desc.message.name));
48
49        Self {
50            descriptors,
51            file_heading: "",
52        }
53    }
54
55    /// Sets the file heading written to each file.
56    pub fn file_heading(&mut self, file_heading: &'a str) -> &mut Self {
57        self.file_heading = file_heading;
58        self
59    }
60
61    /// Writes the `.proto` files to writers returned by `f`.
62    pub fn write<W: Write>(&self, mut f: impl FnMut(&str) -> io::Result<W>) -> io::Result<()> {
63        let mut descriptors = self.descriptors.iter().copied().peekable();
64        while let Some(&first) = descriptors.peek() {
65            let file = f(&package_proto_file(first.package))?;
66            let mut writer = PackageWriter::new(first.package, Box::new(file));
67            write!(
68                writer,
69                "{file_heading}// Autogenerated, do not edit.\n\nsyntax = \"proto3\";\npackage {proto_package};\n",
70                file_heading = self.file_heading,
71                proto_package = first.package,
72            )?;
73            writer.nl_next();
74
75            // Collect imports.
76            let mut imports = Vec::new();
77            let n = {
78                let mut descriptors = descriptors.clone();
79                let mut n = 0;
80                while descriptors
81                    .peek()
82                    .is_some_and(|d| d.package == first.package)
83                {
84                    let desc = descriptors.next().unwrap();
85                    desc.message.collect_imports(&mut writer, &mut imports)?;
86                    n += 1;
87                }
88                n
89            };
90
91            imports.sort();
92            imports.dedup();
93            for import in imports {
94                writeln!(writer, "import \"{import}\";")?;
95            }
96
97            writer.nl_next();
98
99            // Collect messages.
100            for desc in (&mut descriptors).take(n) {
101                desc.message.fmt(&mut writer)?;
102            }
103        }
104        Ok(())
105    }
106
107    /// Writes the `.proto` files to disk, rooted at `path`.
108    ///
109    /// Returns the paths of written files.
110    pub fn write_to_path(&self, path: impl AsRef<Path>) -> io::Result<Vec<PathBuf>> {
111        let mut paths = Vec::new();
112        self.write(|name| {
113            let path = path.as_ref().join(name);
114            if let Some(parent) = path.parent() {
115                fs_err::create_dir_all(parent)?;
116            }
117            let file = fs_err::File::create(&path)?;
118            paths.push(path);
119            Ok(file)
120        })?;
121        Ok(paths)
122    }
123}
124
/// A writer for a single package's `.proto` file that tracks indentation and
/// pending blank lines on top of an underlying writer.
struct PackageWriter<'a, 'w> {
    /// Destination for the formatted output.
    writer: Box<dyn 'w + Write>,
    /// When set, a blank line is emitted before the next non-newline output.
    needs_nl: bool,
    /// When set, the current indentation is emitted before the next output.
    needs_indent: bool,
    /// The current indentation prefix (two spaces per nesting level).
    indent: String,
    /// The package this file describes.
    package: &'a str,
}

impl<'a, 'w> PackageWriter<'a, 'w> {
    /// Creates a writer for `package` with no indentation and nothing pending.
    fn new(package: &'a str, writer: Box<dyn 'w + Write>) -> Self {
        Self {
            writer,
            needs_nl: false,
            needs_indent: false,
            indent: String::new(),
            package,
        }
    }

    /// Increases the indentation by one level.
    fn indent(&mut self) {
        self.indent += "  ";
    }

    /// Decreases the indentation by one level.
    ///
    /// Also cancels any pending blank line so that a closing brace directly
    /// follows the last item of the block being closed. Must be paired with an
    /// earlier call to [`Self::indent`].
    fn unindent(&mut self) {
        self.indent.truncate(self.indent.len() - 2);
        self.needs_nl = false;
    }

    /// Requests a blank line before the next non-newline output.
    fn nl_next(&mut self) {
        self.needs_nl = true;
    }
}

impl Write for PackageWriter<'_, '_> {
    /// Writes at most one line of `buf`, applying any pending blank line and
    /// indentation first.
    ///
    /// Returning after each line (rather than passing the whole buffer
    /// through) lets callers such as `write_all` loop, so every line of a
    /// multi-line buffer gets indented, not just the first.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let Some(&first) = buf.first() else {
            // Nothing to write; leave pending state for real output.
            return Ok(0);
        };
        if first == b'\n' {
            // An explicit newline replaces a pending blank line and is never
            // preceded by indentation (no trailing whitespace on empty lines).
            self.writer.write_all(b"\n")?;
            self.needs_nl = false;
            self.needs_indent = true;
            return Ok(1);
        }
        if self.needs_nl {
            self.writer.write_all(b"\n")?;
            self.needs_nl = false;
        }
        if self.needs_indent {
            self.writer.write_all(self.indent.as_bytes())?;
            self.needs_indent = false;
        }
        // Stop after the first newline so that the remainder of the buffer
        // comes back through this method and gets its own indentation.
        let (line, ends_line) = match buf.iter().position(|&b| b == b'\n') {
            Some(newline) => (&buf[..=newline], true),
            None => (buf, false),
        };
        self.writer.write_all(line)?;
        self.needs_indent = ends_line;
        Ok(line.len())
    }

    /// Flushes the underlying writer.
    fn flush(&mut self) -> io::Result<()> {
        self.writer.flush()
    }
}
185
186/// Computes the referenced descriptors from a set of descriptors.
187fn referenced_descriptors<'a>(
188    descriptors: impl IntoIterator<Item = &'a MessageDescription<'a>>,
189) -> Vec<&'a TopLevelDescriptor<'a>> {
190    // Deduplicate by package and name. TODO: ensure duplicates match.
191    let mut descriptors =
192        HashMap::from_iter(descriptors.into_iter().copied().filter_map(|d| match d {
193            MessageDescription::Internal(tld) => Some(((tld.package, tld.message.name), tld)),
194            MessageDescription::External { .. } => None,
195        }));
196
197    let mut queue = VecDeque::from_iter(descriptors.values().copied());
198
199    fn process_field_type<'a>(
200        field_type: &FieldType<'a>,
201        descriptors: &mut HashMap<(&'a str, &'a str), &'a TopLevelDescriptor<'a>>,
202        queue: &mut VecDeque<&'a TopLevelDescriptor<'a>>,
203    ) {
204        match field_type.kind {
205            FieldKind::Message(tld) => {
206                if let MessageDescription::Internal(tld) = tld() {
207                    if descriptors
208                        .insert((tld.package, tld.message.name), tld)
209                        .is_none()
210                    {
211                        queue.push_back(tld);
212                    }
213                }
214            }
215            FieldKind::Tuple(tys) => {
216                for ty in tys {
217                    process_field_type(ty, descriptors, queue);
218                }
219            }
220            FieldKind::KeyValue(tys) => {
221                for ty in tys {
222                    process_field_type(ty, descriptors, queue);
223                }
224            }
225            FieldKind::Builtin(_) | FieldKind::Local(_) | FieldKind::External { .. } => {}
226        }
227    }
228
229    fn process_message<'a>(
230        message: &MessageDescriptor<'a>,
231        descriptors: &mut HashMap<(&'a str, &'a str), &'a TopLevelDescriptor<'a>>,
232        queue: &mut VecDeque<&'a TopLevelDescriptor<'a>>,
233    ) {
234        for field in message
235            .fields
236            .iter()
237            .chain(message.oneofs.iter().flat_map(|oneof| oneof.variants))
238        {
239            process_field_type(&field.field_type, descriptors, queue);
240        }
241        for inner in message.messages {
242            process_message(inner, descriptors, queue);
243        }
244    }
245
246    while let Some(tld) = queue.pop_front() {
247        process_message(tld.message, &mut descriptors, &mut queue);
248    }
249
250    descriptors.values().copied().collect()
251}
252
/// Returns the name of the `.proto` file that holds `package`'s messages,
/// i.e. the package name followed by `.proto`.
fn package_proto_file(package: &str) -> String {
    [package, ".proto"].concat()
}
256
257impl<'a> MessageDescriptor<'a> {
258    fn collect_imports(
259        &self,
260        w: &mut PackageWriter<'a, '_>,
261        imports: &mut Vec<Cow<'a, str>>,
262    ) -> io::Result<()> {
263        for message in self.messages {
264            message.collect_imports(w, imports)?;
265        }
266        for oneof in self.oneofs {
267            for field in oneof.variants {
268                field.field_type.collect_imports(w, imports)?;
269            }
270        }
271        for field in self.fields {
272            field.field_type.collect_imports(w, imports)?;
273        }
274        Ok(())
275    }
276
277    fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
278        if !self.comment.is_empty() {
279            for line in self.comment.split('\n') {
280                writeln!(w, "//{line}")?;
281            }
282        }
283        writeln!(w, "message {} {{", self.name)?;
284        w.indent();
285        for message in self.messages {
286            message.fmt(w)?;
287        }
288        for oneof in self.oneofs {
289            oneof.fmt_nested_messages(w)?;
290        }
291        for field in self.fields {
292            field.fmt_nested_message(w)?;
293        }
294        for oneof in self.oneofs {
295            oneof.fmt(w)?;
296        }
297        for field in self.fields {
298            field.fmt(w)?;
299        }
300        w.unindent();
301        writeln!(w, "}}")?;
302        w.nl_next();
303        Ok(())
304    }
305}
306
307impl<'a> FieldType<'a> {
308    fn collect_imports(
309        &self,
310        w: &mut PackageWriter<'a, '_>,
311        imports: &mut Vec<Cow<'a, str>>,
312    ) -> io::Result<()> {
313        match self.kind {
314            FieldKind::Builtin(_) | FieldKind::Local(_) => {}
315            FieldKind::External { import_path, .. } => {
316                imports.push(import_path.into());
317            }
318            FieldKind::Message(f) => match f() {
319                MessageDescription::Internal(tld) => {
320                    if w.package != tld.package {
321                        imports.push(package_proto_file(tld.package).into());
322                    }
323                }
324                MessageDescription::External {
325                    name: _,
326                    import_path,
327                } => {
328                    imports.push(import_path.into());
329                }
330            },
331            FieldKind::Tuple(field_types) => {
332                for field_type in field_types {
333                    field_type.collect_imports(w, imports)?;
334                }
335            }
336            FieldKind::KeyValue(field_types) => {
337                for field_type in field_types {
338                    field_type.collect_imports(w, imports)?;
339                }
340            }
341        }
342        Ok(())
343    }
344}
345
346impl FieldDescriptor<'_> {
347    fn fmt_nested_message(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
348        match self.field_type.kind {
349            FieldKind::Tuple(field_types) => {
350                self.fmt_tuple_message(
351                    w,
352                    field_types,
353                    (1..=field_types.len()).map(|i| format!("field{i}")),
354                )?;
355            }
356            FieldKind::KeyValue(field_types) => {
357                self.fmt_tuple_message(w, field_types, ["key", "value"])?;
358            }
359            FieldKind::Builtin(_)
360            | FieldKind::Local(_)
361            | FieldKind::External { .. }
362            | FieldKind::Message(_) => {}
363        }
364        Ok(())
365    }
366
367    fn fmt_tuple_message(
368        &self,
369        w: &mut PackageWriter<'_, '_>,
370        field_types: &[FieldType<'_>],
371        names: impl IntoIterator<Item = impl AsRef<str>>,
372    ) -> Result<(), io::Error> {
373        let fields = field_types
374            .iter()
375            .enumerate()
376            .zip(names)
377            .map(|((i, field_type), name)| (field_type, i as u32 + 1, name))
378            .collect::<Vec<_>>();
379        let fields = fields
380            .iter()
381            .map(|&(ty, number, ref name)| FieldDescriptor::new("", *ty, name.as_ref(), number))
382            .collect::<Vec<_>>();
383        MessageDescriptor::new(&self.name.to_upper_camel_case(), "", &fields, &[], &[]).fmt(w)?;
384        Ok(())
385    }
386
387    fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
388        if !self.comment.is_empty() {
389            for line in self.comment.split('\n') {
390                writeln!(w, "//{}", line.trim_end())?;
391            }
392        }
393
394        let is_message = match self.field_type.kind {
395            FieldKind::Builtin(_) => false,
396            FieldKind::Local(_)
397            | FieldKind::External { .. }
398            | FieldKind::Message(_)
399            | FieldKind::Tuple(_)
400            | FieldKind::KeyValue { .. } => true,
401        };
402
403        match self.field_type.sequence_type {
404            // Message fields are implicitly optional.
405            Some(SequenceType::Optional) if !is_message => write!(w, "optional ")?,
406            None | Some(SequenceType::Optional) => {}
407            Some(SequenceType::Repeated) => write!(w, "repeated ")?,
408            Some(SequenceType::Map(key)) => write!(w, "map<{key}, ")?,
409        };
410        match self.field_type.kind {
411            FieldKind::Builtin(name) | FieldKind::Local(name) => write!(w, "{}", name)?,
412            FieldKind::External { name, .. } => write!(w, ".{}", name)?,
413            FieldKind::Message(tld) => match tld() {
414                MessageDescription::Internal(tld) => {
415                    write!(w, ".{}.{}", tld.package, tld.message.name)?;
416                }
417                MessageDescription::External {
418                    name,
419                    import_path: _,
420                } => {
421                    write!(w, ".{name}")?;
422                }
423            },
424            FieldKind::Tuple(_) | FieldKind::KeyValue(_) => {
425                write!(w, "{}", self.name.to_upper_camel_case())?
426            }
427        }
428        if matches!(self.field_type.sequence_type, Some(SequenceType::Map(_))) {
429            write!(w, ">")?;
430        }
431        write!(w, " {} = {};", self.name, self.field_number)?;
432        if !self.field_type.annotation.is_empty() {
433            write!(w, " // {}", self.field_type.annotation)?;
434        }
435        writeln!(w)
436    }
437}
438
439impl OneofDescriptor<'_> {
440    fn fmt_nested_messages(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
441        for variant in self.variants {
442            if variant.field_type.is_sequence() {
443                FieldDescriptor {
444                    field_type: FieldType::tuple(&[variant.field_type]),
445                    ..*variant
446                }
447                .fmt_nested_message(w)?;
448            } else {
449                variant.fmt_nested_message(w)?;
450            }
451        }
452        Ok(())
453    }
454
455    fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
456        writeln!(w, "oneof {} {{", self.name)?;
457        w.indent();
458        for variant in self.variants {
459            if variant.field_type.is_sequence() {
460                FieldDescriptor {
461                    field_type: FieldType::tuple(&[variant.field_type]),
462                    ..*variant
463                }
464                .fmt(w)?;
465            } else {
466                variant.fmt(w)?;
467            }
468        }
469        w.unindent();
470        writeln!(w, "}}")?;
471        w.nl_next();
472        Ok(())
473    }
474}
475
// End-to-end test of `DescriptorWriter`: derives `Protobuf` for a pair of
// types and compares the generated `.proto` text against a snapshot.
#[cfg(test)]
mod tests {
    use super::DescriptorWriter;
    use crate::Protobuf;
    use crate::protofile::message_description;
    use alloc::string::String;
    use alloc::vec::Vec;
    use core::cell::RefCell;
    use expect_test::expect;
    use std::collections::HashMap;
    use std::io::Write;

    // Root message for the test. Its fields cover scalars, tuples, unit,
    // vectors, nested vectors, maps, arrays, a self reference, and a reference
    // to `Bar`; the snapshot in `test` shows how each one is rendered.
    //
    // NOTE: the doc comments (`///` and `/** */`) on this type and its fields
    // are part of the test input: they show up as comments in the expected
    // output. Only plain `//` comments are safe to add here.
    /// Comment on this guy.
    #[derive(Protobuf)]
    #[mesh(package = "test")]
    struct Foo {
        /// Doc comment
        #[mesh(1)]
        x: u32,
        #[mesh(2)]
        t: (u32,),
        #[mesh(3)]
        t2: (),
        #[mesh(4)]
        bar: (u32, ()),
        /// Another doc comment
        /// (multi-line)
        #[mesh(5)]
        y: Vec<u32>,
        /**
        multi
        line
        */
        #[mesh(6)]
        b: (),
        #[mesh(7)]
        repeated_self: Vec<Foo>,
        #[mesh(8)]
        e: Bar,
        #[mesh(9)]
        nested_repeat: Vec<Vec<u32>>,
        #[mesh(10)]
        proto_map: HashMap<String, (u32,)>,
        #[mesh(11)]
        vec_map: HashMap<u32, Vec<u32>>,
        #[mesh(12)]
        bad_array: [u32; 3],
        #[mesh(13)]
        wrapped_array: [String; 3],
    }

    // Enum reached only through `Foo::e`, so it also checks that referenced
    // types are discovered from the roots. Rendered as a message containing a
    // `oneof` (see the snapshot in `test`).
    #[derive(Protobuf)]
    #[mesh(package = "test")]
    enum Bar {
        #[mesh(1)]
        This,
        #[mesh(2)]
        This2(),
        #[mesh(3, transparent)]
        That(u32),
        #[mesh(4)]
        Other {
            #[mesh(1)]
            hi: bool,
            #[mesh(2)]
            hello: u32,
        },
        #[mesh(5, transparent)]
        Repeat(Vec<u32>),
        #[mesh(6, transparent)]
        DoubleRepeat(Vec<Vec<u32>>),
    }

    // Wraps a writer in a `RefCell` so that a shared reference to it can be
    // used as a `Write` implementation. This lets the test hand a writer to
    // `DescriptorWriter::write` while keeping ownership of the buffer.
    struct BorrowedWriter<T>(RefCell<T>);

    impl<T: Write> Write for &BorrowedWriter<T> {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.borrow_mut().write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.0.borrow_mut().flush()
        }
    }

    // Writes the descriptors rooted at `Foo` into an in-memory buffer and
    // checks the result against the snapshot below.
    #[test]
    fn test() {
        let writer = BorrowedWriter(RefCell::new(Vec::<u8>::new()));
        // The file name is ignored: every requested file goes to the same
        // buffer.
        DescriptorWriter::new(&[message_description::<Foo>()])
            .write(|_name| Ok(&writer))
            .unwrap();
        let s = String::from_utf8(writer.0.into_inner()).unwrap();
        let expected = expect!([r#"
            // Autogenerated, do not edit.

            syntax = "proto3";
            package test;

            import "google/protobuf/empty.proto";
            import "google/protobuf/wrappers.proto";

            message Bar {
              message Other {
                bool hi = 1;
                uint32 hello = 2;
              }

              message Repeat {
                repeated uint32 field1 = 1;
              }

              message DoubleRepeat {
                message Field1 {
                  repeated uint32 field1 = 1;
                }

                repeated Field1 field1 = 1;
              }

              oneof variant {
                .google.protobuf.Empty this = 1;
                .google.protobuf.Empty this2 = 2;
                uint32 that = 3;
                Other other = 4;
                Repeat repeat = 5;
                DoubleRepeat double_repeat = 6;
              }
            }

            // Comment on this guy.
            message Foo {
              message Bar {
                uint32 field1 = 1;
                .google.protobuf.Empty field2 = 2;
              }

              message NestedRepeat {
                repeated uint32 field1 = 1;
              }

              message VecMap {
                uint32 key = 1;
                repeated uint32 value = 2;
              }

              message WrappedArray {
                repeated string field1 = 1;
              }

              // Doc comment
              uint32 x = 1;
              .google.protobuf.UInt32Value t = 2;
              .google.protobuf.Empty t2 = 3;
              Bar bar = 4;
              // Another doc comment
              // (multi-line)
              repeated uint32 y = 5;
              //
              //        multi
              //        line
              //
              .google.protobuf.Empty b = 6;
              repeated .test.Foo repeated_self = 7;
              .test.Bar e = 8;
              repeated NestedRepeat nested_repeat = 9;
              map<string, .google.protobuf.UInt32Value> proto_map = 10;
              repeated VecMap vec_map = 11;
              repeated uint32 bad_array = 12; // packed repr only
              WrappedArray wrapped_array = 13;
            }
        "#]);
        expected.assert_eq(&s);
    }
}