1#![cfg(feature = "std")]
7
8use super::FieldDescriptor;
9use super::FieldType;
10use super::MessageDescriptor;
11use super::OneofDescriptor;
12use super::TopLevelDescriptor;
13use crate::protofile::FieldKind;
14use crate::protofile::MessageDescription;
15use crate::protofile::SequenceType;
16use alloc::borrow::Cow;
17use alloc::boxed::Box;
18use alloc::collections::VecDeque;
19use alloc::format;
20use alloc::string::String;
21use alloc::vec::Vec;
22use heck::ToUpperCamelCase;
23use std::collections::HashMap;
24use std::io;
25use std::io::Write;
26use std::path::Path;
27use std::path::PathBuf;
28
29pub struct DescriptorWriter<'a> {
31 descriptors: Vec<&'a TopLevelDescriptor<'a>>,
32 file_heading: &'a str,
33}
34
35impl<'a> DescriptorWriter<'a> {
36 pub fn new(descriptors: impl IntoIterator<Item = &'a MessageDescription<'a>>) -> Self {
43 let mut descriptors = referenced_descriptors(descriptors);
45
46 descriptors.sort_by_key(|desc| (desc.package, desc.message.name));
48
49 Self {
50 descriptors,
51 file_heading: "",
52 }
53 }
54
55 pub fn file_heading(&mut self, file_heading: &'a str) -> &mut Self {
57 self.file_heading = file_heading;
58 self
59 }
60
61 pub fn write<W: Write>(&self, mut f: impl FnMut(&str) -> io::Result<W>) -> io::Result<()> {
63 let mut descriptors = self.descriptors.iter().copied().peekable();
64 while let Some(&first) = descriptors.peek() {
65 let file = f(&package_proto_file(first.package))?;
66 let mut writer = PackageWriter::new(first.package, Box::new(file));
67 write!(
68 writer,
69 "{file_heading}// Autogenerated, do not edit.\n\nsyntax = \"proto3\";\npackage {proto_package};\n",
70 file_heading = self.file_heading,
71 proto_package = first.package,
72 )?;
73 writer.nl_next();
74
75 let mut imports = Vec::new();
77 let n = {
78 let mut descriptors = descriptors.clone();
79 let mut n = 0;
80 while descriptors
81 .peek()
82 .is_some_and(|d| d.package == first.package)
83 {
84 let desc = descriptors.next().unwrap();
85 desc.message.collect_imports(&mut writer, &mut imports)?;
86 n += 1;
87 }
88 n
89 };
90
91 imports.sort();
92 imports.dedup();
93 for import in imports {
94 writeln!(writer, "import \"{import}\";")?;
95 }
96
97 writer.nl_next();
98
99 for desc in (&mut descriptors).take(n) {
101 desc.message.fmt(&mut writer)?;
102 }
103 }
104 Ok(())
105 }
106
107 pub fn write_to_path(&self, path: impl AsRef<Path>) -> io::Result<Vec<PathBuf>> {
111 let mut paths = Vec::new();
112 self.write(|name| {
113 let path = path.as_ref().join(name);
114 if let Some(parent) = path.parent() {
115 fs_err::create_dir_all(parent)?;
116 }
117 let file = fs_err::File::create(&path)?;
118 paths.push(path);
119 Ok(file)
120 })?;
121 Ok(paths)
122 }
123}
124
/// Writer for one generated `.proto` file.
///
/// It tracks the current indentation and whether a blank line is pending, so
/// the formatting code can emit text without managing whitespace itself.
struct PackageWriter<'a, 'w> {
    /// Underlying output sink.
    writer: Box<dyn 'w + Write>,
    /// A newline is pending. It is written before the next write, unless that
    /// write itself starts with a newline, which then replaces it. Set after a
    /// line has ended, this yields at most one blank line.
    needs_nl: bool,
    /// `indent` must be written before the next output (set after a newline).
    needs_indent: bool,
    /// Current indentation prefix.
    indent: String,
    /// Package declared by this file. Used to decide which message references
    /// need an import.
    package: &'a str,
}

impl<'a, 'w> PackageWriter<'a, 'w> {
    /// One level of indentation in the generated file.
    const INDENT: &'static str = "  ";

    /// Creates a writer for `package` that outputs to `writer`, starting at
    /// indentation level zero with no pending blank line.
    fn new(package: &'a str, writer: Box<dyn 'w + Write>) -> Self {
        Self {
            writer,
            needs_nl: false,
            needs_indent: false,
            indent: String::new(),
            package,
        }
    }

    /// Increases the indentation by one level for subsequent lines.
    fn indent(&mut self) {
        self.indent.push_str(Self::INDENT);
    }

    /// Decreases the indentation by one level and cancels any pending blank
    /// line, so a closing brace directly follows the last nested item.
    ///
    /// # Panics
    ///
    /// Panics if called more times than [`Self::indent`].
    fn unindent(&mut self) {
        let len = self
            .indent
            .len()
            .checked_sub(Self::INDENT.len())
            .expect("unindent called more times than indent");
        self.indent.truncate(len);
        self.needs_nl = false;
    }

    /// Requests a blank line before the next output. The request is collapsed
    /// if the next write is itself a newline.
    fn nl_next(&mut self) {
        self.needs_nl = true;
    }
}

impl Write for PackageWriter<'_, '_> {
    /// Forwards `buf`, first emitting any pending newline and indentation.
    ///
    /// A buffer that starts with `\n` consumes only that newline (returning 1).
    /// That newline satisfies any pending blank-line request instead of
    /// doubling it, and it is never indented.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match buf {
            [b'\n', ..] => {
                self.writer.write_all(b"\n")?;
                self.needs_nl = false;
                self.needs_indent = true;
                Ok(1)
            }
            _ => {
                if self.needs_nl {
                    self.writer.write_all(b"\n")?;
                    self.needs_nl = false;
                }
                if self.needs_indent {
                    self.writer.write_all(self.indent.as_bytes())?;
                    self.needs_indent = false;
                }
                self.writer.write_all(buf)?;
                // A chunk ending a line means the next chunk starts a new one.
                self.needs_indent = buf.ends_with(b"\n");
                Ok(buf.len())
            }
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        self.writer.flush()
    }
}
185
186fn referenced_descriptors<'a>(
188 descriptors: impl IntoIterator<Item = &'a MessageDescription<'a>>,
189) -> Vec<&'a TopLevelDescriptor<'a>> {
190 let mut descriptors =
192 HashMap::from_iter(descriptors.into_iter().copied().filter_map(|d| match d {
193 MessageDescription::Internal(tld) => Some(((tld.package, tld.message.name), tld)),
194 MessageDescription::External { .. } => None,
195 }));
196
197 let mut queue = VecDeque::from_iter(descriptors.values().copied());
198
199 fn process_field_type<'a>(
200 field_type: &FieldType<'a>,
201 descriptors: &mut HashMap<(&'a str, &'a str), &'a TopLevelDescriptor<'a>>,
202 queue: &mut VecDeque<&'a TopLevelDescriptor<'a>>,
203 ) {
204 match field_type.kind {
205 FieldKind::Message(tld) => {
206 if let MessageDescription::Internal(tld) = tld() {
207 if descriptors
208 .insert((tld.package, tld.message.name), tld)
209 .is_none()
210 {
211 queue.push_back(tld);
212 }
213 }
214 }
215 FieldKind::Tuple(tys) => {
216 for ty in tys {
217 process_field_type(ty, descriptors, queue);
218 }
219 }
220 FieldKind::KeyValue(tys) => {
221 for ty in tys {
222 process_field_type(ty, descriptors, queue);
223 }
224 }
225 FieldKind::Builtin(_) | FieldKind::Local(_) | FieldKind::External { .. } => {}
226 }
227 }
228
229 fn process_message<'a>(
230 message: &MessageDescriptor<'a>,
231 descriptors: &mut HashMap<(&'a str, &'a str), &'a TopLevelDescriptor<'a>>,
232 queue: &mut VecDeque<&'a TopLevelDescriptor<'a>>,
233 ) {
234 for field in message
235 .fields
236 .iter()
237 .chain(message.oneofs.iter().flat_map(|oneof| oneof.variants))
238 {
239 process_field_type(&field.field_type, descriptors, queue);
240 }
241 for inner in message.messages {
242 process_message(inner, descriptors, queue);
243 }
244 }
245
246 while let Some(tld) = queue.pop_front() {
247 process_message(tld.message, &mut descriptors, &mut queue);
248 }
249
250 descriptors.values().copied().collect()
251}
252
/// Returns the file name of the generated `.proto` file for `package`: the
/// package name followed by `.proto` (for example `foo.bar` -> `foo.bar.proto`).
fn package_proto_file(package: &str) -> String {
    [package, ".proto"].concat()
}
256
257impl<'a> MessageDescriptor<'a> {
258 fn collect_imports(
259 &self,
260 w: &mut PackageWriter<'a, '_>,
261 imports: &mut Vec<Cow<'a, str>>,
262 ) -> io::Result<()> {
263 for message in self.messages {
264 message.collect_imports(w, imports)?;
265 }
266 for oneof in self.oneofs {
267 for field in oneof.variants {
268 field.field_type.collect_imports(w, imports)?;
269 }
270 }
271 for field in self.fields {
272 field.field_type.collect_imports(w, imports)?;
273 }
274 Ok(())
275 }
276
277 fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
278 if !self.comment.is_empty() {
279 for line in self.comment.split('\n') {
280 writeln!(w, "//{line}")?;
281 }
282 }
283 writeln!(w, "message {} {{", self.name)?;
284 w.indent();
285 for message in self.messages {
286 message.fmt(w)?;
287 }
288 for oneof in self.oneofs {
289 oneof.fmt_nested_messages(w)?;
290 }
291 for field in self.fields {
292 field.fmt_nested_message(w)?;
293 }
294 for oneof in self.oneofs {
295 oneof.fmt(w)?;
296 }
297 for field in self.fields {
298 field.fmt(w)?;
299 }
300 w.unindent();
301 writeln!(w, "}}")?;
302 w.nl_next();
303 Ok(())
304 }
305}
306
307impl<'a> FieldType<'a> {
308 fn collect_imports(
309 &self,
310 w: &mut PackageWriter<'a, '_>,
311 imports: &mut Vec<Cow<'a, str>>,
312 ) -> io::Result<()> {
313 match self.kind {
314 FieldKind::Builtin(_) | FieldKind::Local(_) => {}
315 FieldKind::External { import_path, .. } => {
316 imports.push(import_path.into());
317 }
318 FieldKind::Message(f) => match f() {
319 MessageDescription::Internal(tld) => {
320 if w.package != tld.package {
321 imports.push(package_proto_file(tld.package).into());
322 }
323 }
324 MessageDescription::External {
325 name: _,
326 import_path,
327 } => {
328 imports.push(import_path.into());
329 }
330 },
331 FieldKind::Tuple(field_types) => {
332 for field_type in field_types {
333 field_type.collect_imports(w, imports)?;
334 }
335 }
336 FieldKind::KeyValue(field_types) => {
337 for field_type in field_types {
338 field_type.collect_imports(w, imports)?;
339 }
340 }
341 }
342 Ok(())
343 }
344}
345
346impl FieldDescriptor<'_> {
347 fn fmt_nested_message(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
348 match self.field_type.kind {
349 FieldKind::Tuple(field_types) => {
350 self.fmt_tuple_message(
351 w,
352 field_types,
353 (1..=field_types.len()).map(|i| format!("field{i}")),
354 )?;
355 }
356 FieldKind::KeyValue(field_types) => {
357 self.fmt_tuple_message(w, field_types, ["key", "value"])?;
358 }
359 FieldKind::Builtin(_)
360 | FieldKind::Local(_)
361 | FieldKind::External { .. }
362 | FieldKind::Message(_) => {}
363 }
364 Ok(())
365 }
366
367 fn fmt_tuple_message(
368 &self,
369 w: &mut PackageWriter<'_, '_>,
370 field_types: &[FieldType<'_>],
371 names: impl IntoIterator<Item = impl AsRef<str>>,
372 ) -> Result<(), io::Error> {
373 let fields = field_types
374 .iter()
375 .enumerate()
376 .zip(names)
377 .map(|((i, field_type), name)| (field_type, i as u32 + 1, name))
378 .collect::<Vec<_>>();
379 let fields = fields
380 .iter()
381 .map(|&(ty, number, ref name)| FieldDescriptor::new("", *ty, name.as_ref(), number))
382 .collect::<Vec<_>>();
383 MessageDescriptor::new(&self.name.to_upper_camel_case(), "", &fields, &[], &[]).fmt(w)?;
384 Ok(())
385 }
386
387 fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
388 if !self.comment.is_empty() {
389 for line in self.comment.split('\n') {
390 writeln!(w, "//{}", line.trim_end())?;
391 }
392 }
393
394 let is_message = match self.field_type.kind {
395 FieldKind::Builtin(_) => false,
396 FieldKind::Local(_)
397 | FieldKind::External { .. }
398 | FieldKind::Message(_)
399 | FieldKind::Tuple(_)
400 | FieldKind::KeyValue { .. } => true,
401 };
402
403 match self.field_type.sequence_type {
404 Some(SequenceType::Optional) if !is_message => write!(w, "optional ")?,
406 None | Some(SequenceType::Optional) => {}
407 Some(SequenceType::Repeated) => write!(w, "repeated ")?,
408 Some(SequenceType::Map(key)) => write!(w, "map<{key}, ")?,
409 };
410 match self.field_type.kind {
411 FieldKind::Builtin(name) | FieldKind::Local(name) => write!(w, "{}", name)?,
412 FieldKind::External { name, .. } => write!(w, ".{}", name)?,
413 FieldKind::Message(tld) => match tld() {
414 MessageDescription::Internal(tld) => {
415 write!(w, ".{}.{}", tld.package, tld.message.name)?;
416 }
417 MessageDescription::External {
418 name,
419 import_path: _,
420 } => {
421 write!(w, ".{name}")?;
422 }
423 },
424 FieldKind::Tuple(_) | FieldKind::KeyValue(_) => {
425 write!(w, "{}", self.name.to_upper_camel_case())?
426 }
427 }
428 if matches!(self.field_type.sequence_type, Some(SequenceType::Map(_))) {
429 write!(w, ">")?;
430 }
431 write!(w, " {} = {};", self.name, self.field_number)?;
432 if !self.field_type.annotation.is_empty() {
433 write!(w, " // {}", self.field_type.annotation)?;
434 }
435 writeln!(w)
436 }
437}
438
439impl OneofDescriptor<'_> {
440 fn fmt_nested_messages(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
441 for variant in self.variants {
442 if variant.field_type.is_sequence() {
443 FieldDescriptor {
444 field_type: FieldType::tuple(&[variant.field_type]),
445 ..*variant
446 }
447 .fmt_nested_message(w)?;
448 } else {
449 variant.fmt_nested_message(w)?;
450 }
451 }
452 Ok(())
453 }
454
455 fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
456 writeln!(w, "oneof {} {{", self.name)?;
457 w.indent();
458 for variant in self.variants {
459 if variant.field_type.is_sequence() {
460 FieldDescriptor {
461 field_type: FieldType::tuple(&[variant.field_type]),
462 ..*variant
463 }
464 .fmt(w)?;
465 } else {
466 variant.fmt(w)?;
467 }
468 }
469 w.unindent();
470 writeln!(w, "}}")?;
471 w.nl_next();
472 Ok(())
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::DescriptorWriter;
    use crate::Protobuf;
    use crate::protofile::message_description;
    use alloc::string::String;
    use alloc::vec::Vec;
    use core::cell::RefCell;
    use expect_test::expect;
    use std::collections::HashMap;
    use std::io::Write;

    // The doc comments on `Foo` and its fields are part of the fixture: the
    // derive copies them into the descriptors, and the expected output below
    // checks that they are emitted as `//` comments.
    #[derive(Protobuf)]
    /// Comment on this guy.
    #[mesh(package = "test")]
    struct Foo {
        #[mesh(1)]
        /// Doc comment
        x: u32,
        #[mesh(2)]
        t: (u32,),
        #[mesh(3)]
        t2: (),
        #[mesh(4)]
        bar: (u32, ()),
        #[mesh(5)]
        /// Another doc comment
        /// (multi-line)
        y: Vec<u32>,
        #[mesh(6)]
        ///
        /// multi
        /// line
        ///
        b: (),
        #[mesh(7)]
        repeated_self: Vec<Foo>,
        #[mesh(8)]
        e: Bar,
        #[mesh(9)]
        nested_repeat: Vec<Vec<u32>>,
        #[mesh(10)]
        proto_map: HashMap<String, (u32,)>,
        #[mesh(11)]
        vec_map: HashMap<u32, Vec<u32>>,
        #[mesh(12)]
        bad_array: [u32; 3],
        #[mesh(13)]
        wrapped_array: [String; 3],
    }

    #[derive(Protobuf)]
    #[mesh(package = "test")]
    enum Bar {
        #[mesh(1)]
        This,
        #[mesh(2)]
        This2(),
        #[mesh(3, transparent)]
        That(u32),
        #[mesh(4)]
        Other {
            #[mesh(1)]
            hi: bool,
            #[mesh(2)]
            hello: u32,
        },
        #[mesh(5, transparent)]
        Repeat(Vec<u32>),
        #[mesh(6, transparent)]
        DoubleRepeat(Vec<Vec<u32>>),
    }

    /// Shares one in-memory buffer across every file the writer opens, so all
    /// generated output can be inspected afterwards.
    struct BorrowedWriter<T>(RefCell<T>);

    impl<T: Write> Write for &BorrowedWriter<T> {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.borrow_mut().write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.0.borrow_mut().flush()
        }
    }

    #[test]
    fn test() {
        let writer = BorrowedWriter(RefCell::new(Vec::<u8>::new()));
        DescriptorWriter::new(&[message_description::<Foo>()])
            .write(|_name| Ok(&writer))
            .unwrap();
        let s = String::from_utf8(writer.0.into_inner()).unwrap();
        let expected = expect!([r#"
            // Autogenerated, do not edit.

            syntax = "proto3";
            package test;

            import "google/protobuf/empty.proto";
            import "google/protobuf/wrappers.proto";

            message Bar {
              message Other {
                bool hi = 1;
                uint32 hello = 2;
              }

              message Repeat {
                repeated uint32 field1 = 1;
              }

              message DoubleRepeat {
                message Field1 {
                  repeated uint32 field1 = 1;
                }

                repeated Field1 field1 = 1;
              }

              oneof variant {
                .google.protobuf.Empty this = 1;
                .google.protobuf.Empty this2 = 2;
                uint32 that = 3;
                Other other = 4;
                Repeat repeat = 5;
                DoubleRepeat double_repeat = 6;
              }
            }

            // Comment on this guy.
            message Foo {
              message Bar {
                uint32 field1 = 1;
                .google.protobuf.Empty field2 = 2;
              }

              message NestedRepeat {
                repeated uint32 field1 = 1;
              }

              message VecMap {
                uint32 key = 1;
                repeated uint32 value = 2;
              }

              message WrappedArray {
                repeated string field1 = 1;
              }

              // Doc comment
              uint32 x = 1;
              .google.protobuf.UInt32Value t = 2;
              .google.protobuf.Empty t2 = 3;
              Bar bar = 4;
              // Another doc comment
              // (multi-line)
              repeated uint32 y = 5;
              //
              // multi
              // line
              //
              .google.protobuf.Empty b = 6;
              repeated .test.Foo repeated_self = 7;
              .test.Bar e = 8;
              repeated NestedRepeat nested_repeat = 9;
              map<string, .google.protobuf.UInt32Value> proto_map = 10;
              repeated VecMap vec_map = 11;
              repeated uint32 bad_array = 12; // packed repr only
              WrappedArray wrapped_array = 13;
            }
        "#]);
        expected.assert_eq(&s);
    }
}