1#![cfg(feature = "std")]
7
8use super::FieldDescriptor;
9use super::FieldType;
10use super::MessageDescriptor;
11use super::OneofDescriptor;
12use super::TopLevelDescriptor;
13use crate::protofile::FieldKind;
14use crate::protofile::MessageDescription;
15use crate::protofile::SequenceType;
16use alloc::borrow::Cow;
17use alloc::boxed::Box;
18use alloc::format;
19use alloc::string::String;
20use alloc::vec::Vec;
21use heck::ToUpperCamelCase;
22use std::collections::HashSet;
23use std::io;
24use std::io::Write;
25use std::path::Path;
26use std::path::PathBuf;
27
28pub struct DescriptorWriter<'a> {
30 descriptors: Vec<&'a TopLevelDescriptor<'a>>,
31 file_heading: &'a str,
32}
33
34impl<'a> DescriptorWriter<'a> {
35 pub fn new(descriptors: impl IntoIterator<Item = &'a MessageDescription<'a>>) -> Self {
42 let mut descriptors = referenced_descriptors(descriptors);
44
45 descriptors.sort_by_key(|desc| (desc.package, desc.message.name));
47 descriptors.dedup_by_key(|desc| (desc.package, desc.message.name));
49
50 Self {
51 descriptors,
52 file_heading: "",
53 }
54 }
55
56 pub fn file_heading(&mut self, file_heading: &'a str) -> &mut Self {
58 self.file_heading = file_heading;
59 self
60 }
61
62 pub fn write<W: Write>(&self, mut f: impl FnMut(&str) -> io::Result<W>) -> io::Result<()> {
64 let mut descriptors = self.descriptors.iter().copied().peekable();
65 while let Some(&first) = descriptors.peek() {
66 let file = f(&package_proto_file(first.package))?;
67 let mut writer = PackageWriter::new(first.package, Box::new(file));
68 write!(
69 writer,
70 "{file_heading}// Autogenerated, do not edit.\n\nsyntax = \"proto3\";\npackage {proto_package};\n",
71 file_heading = self.file_heading,
72 proto_package = first.package,
73 )?;
74 writer.nl_next();
75
76 let mut imports = Vec::new();
78 let n = {
79 let mut descriptors = descriptors.clone();
80 let mut n = 0;
81 while descriptors
82 .peek()
83 .is_some_and(|d| d.package == first.package)
84 {
85 let desc = descriptors.next().unwrap();
86 desc.message.collect_imports(&mut writer, &mut imports)?;
87 n += 1;
88 }
89 n
90 };
91
92 imports.sort();
93 imports.dedup();
94 for import in imports {
95 writeln!(writer, "import \"{import}\";")?;
96 }
97
98 writer.nl_next();
99
100 for desc in (&mut descriptors).take(n) {
102 desc.message.fmt(&mut writer)?;
103 }
104 }
105 Ok(())
106 }
107
108 pub fn write_to_path(&self, path: impl AsRef<Path>) -> io::Result<Vec<PathBuf>> {
112 let mut paths = Vec::new();
113 self.write(|name| {
114 let path = path.as_ref().join(name);
115 if let Some(parent) = path.parent() {
116 fs_err::create_dir_all(parent)?;
117 }
118 let file = fs_err::File::create(&path)?;
119 paths.push(path);
120 Ok(file)
121 })?;
122 Ok(paths)
123 }
124}
125
/// Line-oriented writer for a single package's `.proto` file.
///
/// Wraps the destination writer and tracks formatting state (indentation and
/// pending blank lines) so that callers can emit text without managing it.
struct PackageWriter<'a, 'w> {
    /// Package whose file is being written; message references into this
    /// package need no import.
    package: &'a str,
    /// Prefix written before the next chunk after a line break.
    indent: String,
    /// Whether a blank line should precede the next non-newline output.
    needs_nl: bool,
    /// Whether `indent` must be written before the next non-newline output.
    needs_indent: bool,
    /// Destination for the generated text.
    writer: Box<dyn 'w + Write>,
}
133
134impl<'a, 'w> PackageWriter<'a, 'w> {
135 fn new(package: &'a str, writer: Box<dyn 'w + Write>) -> Self {
136 Self {
137 writer,
138 needs_nl: false,
139 needs_indent: false,
140 indent: String::new(),
141 package,
142 }
143 }
144
145 fn indent(&mut self) {
146 self.indent += " ";
147 }
148
149 fn unindent(&mut self) {
150 self.indent.truncate(self.indent.len() - 2);
151 self.needs_nl = false;
152 }
153
154 fn nl_next(&mut self) {
155 self.needs_nl = true;
156 }
157}
158
159impl Write for PackageWriter<'_, '_> {
160 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
161 if buf.first() == Some(&b'\n') {
162 self.writer.write_all(b"\n")?;
163 self.needs_nl = false;
164 self.needs_indent = true;
165 return Ok(1);
166 }
167 if self.needs_nl {
168 self.writer.write_all(b"\n")?;
169 self.needs_nl = false;
170 }
171 if self.needs_indent {
172 self.writer.write_all(self.indent.as_bytes())?;
173 self.needs_indent = false;
174 }
175 self.writer.write_all(buf)?;
176 if buf.last() == Some(&b'\n') {
177 self.needs_indent = true;
178 }
179 Ok(buf.len())
180 }
181
182 fn flush(&mut self) -> io::Result<()> {
183 self.writer.flush()
184 }
185}
186
187fn referenced_descriptors<'a>(
189 descriptors: impl IntoIterator<Item = &'a MessageDescription<'a>>,
190) -> Vec<&'a TopLevelDescriptor<'a>> {
191 let mut descriptors =
192 Vec::from_iter(descriptors.into_iter().copied().filter_map(|d| match d {
193 MessageDescription::Internal(tld) => Some(tld),
194 MessageDescription::External { .. } => None,
195 }));
196 let mut inserted = HashSet::from_iter(descriptors.iter().copied());
197
198 fn process_field_type<'a>(
199 field_type: &FieldType<'a>,
200 descriptors: &mut Vec<&'a TopLevelDescriptor<'a>>,
201 inserted: &mut HashSet<&'a TopLevelDescriptor<'a>>,
202 ) {
203 match field_type.kind {
204 FieldKind::Message(tld) => {
205 if let MessageDescription::Internal(tld) = tld() {
206 if inserted.insert(tld) {
207 descriptors.push(tld);
208 }
209 }
210 }
211 FieldKind::Tuple(tys) => {
212 for ty in tys {
213 process_field_type(ty, descriptors, inserted);
214 }
215 }
216 FieldKind::KeyValue(tys) => {
217 for ty in tys {
218 process_field_type(ty, descriptors, inserted);
219 }
220 }
221 FieldKind::Builtin(_) | FieldKind::Local(_) | FieldKind::External { .. } => {}
222 }
223 }
224
225 fn process_message<'a>(
226 message: &MessageDescriptor<'a>,
227 descriptors: &mut Vec<&'a TopLevelDescriptor<'a>>,
228 inserted: &mut HashSet<&'a TopLevelDescriptor<'a>>,
229 ) {
230 for field in message
231 .fields
232 .iter()
233 .chain(message.oneofs.iter().flat_map(|oneof| oneof.variants))
234 {
235 process_field_type(&field.field_type, descriptors, inserted);
236 }
237 for inner in message.messages {
238 process_message(inner, descriptors, inserted);
239 }
240 }
241
242 let mut i = 0;
243 while let Some(&tld) = descriptors.get(i) {
244 process_message(tld.message, &mut descriptors, &mut inserted);
245 i += 1;
246 }
247
248 descriptors
249}
250
/// Returns the name of the generated `.proto` file for `package`.
///
/// The package name is used verbatim, so `foo.bar` maps to `foo.bar.proto`
/// rather than `foo/bar.proto`.
fn package_proto_file(package: &str) -> String {
    const SUFFIX: &str = ".proto";
    let mut file_name = String::with_capacity(package.len() + SUFFIX.len());
    file_name.push_str(package);
    file_name.push_str(SUFFIX);
    file_name
}
254
255impl<'a> MessageDescriptor<'a> {
256 fn collect_imports(
257 &self,
258 w: &mut PackageWriter<'a, '_>,
259 imports: &mut Vec<Cow<'a, str>>,
260 ) -> io::Result<()> {
261 for message in self.messages {
262 message.collect_imports(w, imports)?;
263 }
264 for oneof in self.oneofs {
265 for field in oneof.variants {
266 field.field_type.collect_imports(w, imports)?;
267 }
268 }
269 for field in self.fields {
270 field.field_type.collect_imports(w, imports)?;
271 }
272 Ok(())
273 }
274
275 fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
276 if !self.comment.is_empty() {
277 for line in self.comment.split('\n') {
278 writeln!(w, "//{line}")?;
279 }
280 }
281 writeln!(w, "message {} {{", self.name)?;
282 w.indent();
283 for message in self.messages {
284 message.fmt(w)?;
285 }
286 for oneof in self.oneofs {
287 oneof.fmt_nested_messages(w)?;
288 }
289 for field in self.fields {
290 field.fmt_nested_message(w)?;
291 }
292 for oneof in self.oneofs {
293 oneof.fmt(w)?;
294 }
295 for field in self.fields {
296 field.fmt(w)?;
297 }
298 w.unindent();
299 writeln!(w, "}}")?;
300 w.nl_next();
301 Ok(())
302 }
303}
304
305impl<'a> FieldType<'a> {
306 fn collect_imports(
307 &self,
308 w: &mut PackageWriter<'a, '_>,
309 imports: &mut Vec<Cow<'a, str>>,
310 ) -> io::Result<()> {
311 match self.kind {
312 FieldKind::Builtin(_) | FieldKind::Local(_) => {}
313 FieldKind::External { import_path, .. } => {
314 imports.push(import_path.into());
315 }
316 FieldKind::Message(f) => match f() {
317 MessageDescription::Internal(tld) => {
318 if w.package != tld.package {
319 imports.push(package_proto_file(tld.package).into());
320 }
321 }
322 MessageDescription::External {
323 name: _,
324 import_path,
325 } => {
326 imports.push(import_path.into());
327 }
328 },
329 FieldKind::Tuple(field_types) => {
330 for field_type in field_types {
331 field_type.collect_imports(w, imports)?;
332 }
333 }
334 FieldKind::KeyValue(field_types) => {
335 for field_type in field_types {
336 field_type.collect_imports(w, imports)?;
337 }
338 }
339 }
340 Ok(())
341 }
342}
343
344impl FieldDescriptor<'_> {
345 fn fmt_nested_message(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
346 match self.field_type.kind {
347 FieldKind::Tuple(field_types) => {
348 self.fmt_tuple_message(
349 w,
350 field_types,
351 (1..=field_types.len()).map(|i| format!("field{i}")),
352 )?;
353 }
354 FieldKind::KeyValue(field_types) => {
355 self.fmt_tuple_message(w, field_types, ["key", "value"])?;
356 }
357 FieldKind::Builtin(_)
358 | FieldKind::Local(_)
359 | FieldKind::External { .. }
360 | FieldKind::Message(_) => {}
361 }
362 Ok(())
363 }
364
365 fn fmt_tuple_message(
366 &self,
367 w: &mut PackageWriter<'_, '_>,
368 field_types: &[FieldType<'_>],
369 names: impl IntoIterator<Item = impl AsRef<str>>,
370 ) -> Result<(), io::Error> {
371 let fields = field_types
372 .iter()
373 .enumerate()
374 .zip(names)
375 .map(|((i, field_type), name)| (field_type, i as u32 + 1, name))
376 .collect::<Vec<_>>();
377 let fields = fields
378 .iter()
379 .map(|&(ty, number, ref name)| FieldDescriptor::new("", *ty, name.as_ref(), number))
380 .collect::<Vec<_>>();
381 MessageDescriptor::new(&self.name.to_upper_camel_case(), "", &fields, &[], &[]).fmt(w)?;
382 Ok(())
383 }
384
385 fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
386 if !self.comment.is_empty() {
387 for line in self.comment.split('\n') {
388 writeln!(w, "//{}", line.trim_end())?;
389 }
390 }
391
392 let is_message = match self.field_type.kind {
393 FieldKind::Builtin(_) => false,
394 FieldKind::Local(_)
395 | FieldKind::External { .. }
396 | FieldKind::Message(_)
397 | FieldKind::Tuple(_)
398 | FieldKind::KeyValue { .. } => true,
399 };
400
401 match self.field_type.sequence_type {
402 Some(SequenceType::Optional) if !is_message => write!(w, "optional ")?,
404 None | Some(SequenceType::Optional) => {}
405 Some(SequenceType::Repeated) => write!(w, "repeated ")?,
406 Some(SequenceType::Map(key)) => write!(w, "map<{key}, ")?,
407 };
408 match self.field_type.kind {
409 FieldKind::Builtin(name) | FieldKind::Local(name) => write!(w, "{}", name)?,
410 FieldKind::External { name, .. } => write!(w, ".{}", name)?,
411 FieldKind::Message(tld) => match tld() {
412 MessageDescription::Internal(tld) => {
413 write!(w, ".{}.{}", tld.package, tld.message.name)?;
414 }
415 MessageDescription::External {
416 name,
417 import_path: _,
418 } => {
419 write!(w, ".{name}")?;
420 }
421 },
422 FieldKind::Tuple(_) | FieldKind::KeyValue(_) => {
423 write!(w, "{}", self.name.to_upper_camel_case())?
424 }
425 }
426 if matches!(self.field_type.sequence_type, Some(SequenceType::Map(_))) {
427 write!(w, ">")?;
428 }
429 write!(w, " {} = {};", self.name, self.field_number)?;
430 if !self.field_type.annotation.is_empty() {
431 write!(w, " // {}", self.field_type.annotation)?;
432 }
433 writeln!(w)
434 }
435}
436
437impl OneofDescriptor<'_> {
438 fn fmt_nested_messages(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
439 for variant in self.variants {
440 if variant.field_type.is_sequence() {
441 FieldDescriptor {
442 field_type: FieldType::tuple(&[variant.field_type]),
443 ..*variant
444 }
445 .fmt_nested_message(w)?;
446 } else {
447 variant.fmt_nested_message(w)?;
448 }
449 }
450 Ok(())
451 }
452
453 fn fmt(&self, w: &mut PackageWriter<'_, '_>) -> io::Result<()> {
454 writeln!(w, "oneof {} {{", self.name)?;
455 w.indent();
456 for variant in self.variants {
457 if variant.field_type.is_sequence() {
458 FieldDescriptor {
459 field_type: FieldType::tuple(&[variant.field_type]),
460 ..*variant
461 }
462 .fmt(w)?;
463 } else {
464 variant.fmt(w)?;
465 }
466 }
467 w.unindent();
468 writeln!(w, "}}")?;
469 w.nl_next();
470 Ok(())
471 }
472}
473
#[cfg(test)]
mod tests {
    // Snapshot test of the `.proto` text produced by `DescriptorWriter`.

    use super::DescriptorWriter;
    use crate::Protobuf;
    use crate::protofile::message_description;
    use alloc::string::String;
    use alloc::vec::Vec;
    use core::cell::RefCell;
    use expect_test::expect;
    use std::collections::HashMap;
    use std::io::Write;

    // Covers most field shapes: scalars, single-element and unit tuples,
    // multi-element tuples, repeated and nested-repeated fields, a
    // self-reference, an enum in the same package, maps and fixed-size arrays.
    // The doc comments are part of the fixture: they appear in the snapshot.
    #[derive(Protobuf)]
    /// Comment on this guy.
    #[mesh(package = "test")]
    struct Foo {
        #[mesh(1)]
        /// Doc comment
        x: u32,
        #[mesh(2)]
        t: (u32,),
        #[mesh(3)]
        t2: (),
        #[mesh(4)]
        bar: (u32, ()),
        #[mesh(5)]
        /// Another doc comment
        /// (multi-line)
        y: Vec<u32>,
        #[mesh(6)]
        ///
        /// multi
        /// line
        ///
        b: (),
        #[mesh(7)]
        repeated_self: Vec<Foo>,
        #[mesh(8)]
        e: Bar,
        #[mesh(9)]
        nested_repeat: Vec<Vec<u32>>,
        #[mesh(10)]
        proto_map: HashMap<String, (u32,)>,
        #[mesh(11)]
        vec_map: HashMap<u32, Vec<u32>>,
        #[mesh(12)]
        bad_array: [u32; 3],
        #[mesh(13)]
        wrapped_array: [String; 3],
    }

    // Enum variants exercise oneof output, including the tuple wrapping of
    // sequence-typed (`Repeat`, `DoubleRepeat`) variants.
    #[derive(Protobuf)]
    #[mesh(package = "test")]
    enum Bar {
        #[mesh(1)]
        This,
        #[mesh(2)]
        This2(),
        #[mesh(3, transparent)]
        That(u32),
        #[mesh(4)]
        Other {
            #[mesh(1)]
            hi: bool,
            #[mesh(2)]
            hello: u32,
        },
        #[mesh(5, transparent)]
        Repeat(Vec<u32>),
        #[mesh(6, transparent)]
        DoubleRepeat(Vec<Vec<u32>>),
    }

    // Lets every writer returned from the `write` callback (a
    // `&BorrowedWriter`) append to one shared buffer that the test reads
    // afterwards.
    struct BorrowedWriter<T>(RefCell<T>);

    impl<T: Write> Write for &BorrowedWriter<T> {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.borrow_mut().write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.0.borrow_mut().flush()
        }
    }

    // Generates output starting from `Foo` only; `Bar` is pulled in through
    // field `e`. Both land in the `test` package, so everything is written to
    // the shared buffer as one file.
    #[test]
    fn test() {
        let writer = BorrowedWriter(RefCell::new(Vec::<u8>::new()));
        DescriptorWriter::new(&[message_description::<Foo>()])
            .write(|_name| Ok(&writer))
            .unwrap();
        let s = String::from_utf8(writer.0.into_inner()).unwrap();
        let expected = expect!([r#"
            // Autogenerated, do not edit.

            syntax = "proto3";
            package test;

            import "google/protobuf/empty.proto";
            import "google/protobuf/wrappers.proto";

            message Bar {
              message Other {
                bool hi = 1;
                uint32 hello = 2;
              }

              message Repeat {
                repeated uint32 field1 = 1;
              }

              message DoubleRepeat {
                message Field1 {
                  repeated uint32 field1 = 1;
                }

                repeated Field1 field1 = 1;
              }

              oneof variant {
                .google.protobuf.Empty this = 1;
                .google.protobuf.Empty this2 = 2;
                uint32 that = 3;
                Other other = 4;
                Repeat repeat = 5;
                DoubleRepeat double_repeat = 6;
              }
            }

            // Comment on this guy.
            message Foo {
              message Bar {
                uint32 field1 = 1;
                .google.protobuf.Empty field2 = 2;
              }

              message NestedRepeat {
                repeated uint32 field1 = 1;
              }

              message VecMap {
                uint32 key = 1;
                repeated uint32 value = 2;
              }

              message WrappedArray {
                repeated string field1 = 1;
              }

              // Doc comment
              uint32 x = 1;
              .google.protobuf.UInt32Value t = 2;
              .google.protobuf.Empty t2 = 3;
              Bar bar = 4;
              // Another doc comment
              // (multi-line)
              repeated uint32 y = 5;
              //
              // multi
              // line
              //
              .google.protobuf.Empty b = 6;
              repeated .test.Foo repeated_self = 7;
              .test.Bar e = 8;
              repeated NestedRepeat nested_repeat = 9;
              map<string, .google.protobuf.UInt32Value> proto_map = 10;
              repeated VecMap vec_map = 11;
              repeated uint32 bad_array = 12; // packed repr only
              WrappedArray wrapped_array = 13;
            }
        "#]);
        expected.assert_eq(&s);
    }
}