1 use std::convert::TryFrom;
2 use std::fmt;
3 
4 use anyhow::{anyhow, bail, Error};
5 use proc_macro2::{Span, TokenStream};
6 use quote::{quote, ToTokens, TokenStreamExt};
7 use syn::{parse_str, Expr, ExprLit, Ident, Index, Lit, LitByteStr, Meta, MetaNameValue, Path};
8 
9 use crate::field::{bool_attr, set_option, tag_attr, Label};
10 
/// A scalar protobuf field.
#[derive(Clone)]
pub struct Field {
    /// The scalar protobuf type declared by the field's attributes.
    pub ty: Ty,
    /// How the field is laid out: plain, optional, required, repeated or packed.
    pub kind: Kind,
    /// The field tag, passed to the `::prost::encoding` functions when encoding.
    pub tag: u32,
}
18 
19 impl Field {
new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error>20     pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
21         let mut ty = None;
22         let mut label = None;
23         let mut packed = None;
24         let mut default = None;
25         let mut tag = None;
26 
27         let mut unknown_attrs = Vec::new();
28 
29         for attr in attrs {
30             if let Some(t) = Ty::from_attr(attr)? {
31                 set_option(&mut ty, t, "duplicate type attributes")?;
32             } else if let Some(p) = bool_attr("packed", attr)? {
33                 set_option(&mut packed, p, "duplicate packed attributes")?;
34             } else if let Some(t) = tag_attr(attr)? {
35                 set_option(&mut tag, t, "duplicate tag attributes")?;
36             } else if let Some(l) = Label::from_attr(attr) {
37                 set_option(&mut label, l, "duplicate label attributes")?;
38             } else if let Some(d) = DefaultValue::from_attr(attr)? {
39                 set_option(&mut default, d, "duplicate default attributes")?;
40             } else {
41                 unknown_attrs.push(attr);
42             }
43         }
44 
45         let ty = match ty {
46             Some(ty) => ty,
47             None => return Ok(None),
48         };
49 
50         match unknown_attrs.len() {
51             0 => (),
52             1 => bail!("unknown attribute: {:?}", unknown_attrs[0]),
53             _ => bail!("unknown attributes: {:?}", unknown_attrs),
54         }
55 
56         let tag = match tag.or(inferred_tag) {
57             Some(tag) => tag,
58             None => bail!("missing tag attribute"),
59         };
60 
61         let has_default = default.is_some();
62         let default = default.map_or_else(
63             || Ok(DefaultValue::new(&ty)),
64             |lit| DefaultValue::from_lit(&ty, lit),
65         )?;
66 
67         let kind = match (label, packed, has_default) {
68             (None, Some(true), _)
69             | (Some(Label::Optional), Some(true), _)
70             | (Some(Label::Required), Some(true), _) => {
71                 bail!("packed attribute may only be applied to repeated fields");
72             }
73             (Some(Label::Repeated), Some(true), _) if !ty.is_numeric() => {
74                 bail!("packed attribute may only be applied to numeric types");
75             }
76             (Some(Label::Repeated), _, true) => {
77                 bail!("repeated fields may not have a default value");
78             }
79 
80             (None, _, _) => Kind::Plain(default),
81             (Some(Label::Optional), _, _) => Kind::Optional(default),
82             (Some(Label::Required), _, _) => Kind::Required(default),
83             (Some(Label::Repeated), packed, false) if packed.unwrap_or_else(|| ty.is_numeric()) => {
84                 Kind::Packed
85             }
86             (Some(Label::Repeated), _, false) => Kind::Repeated,
87         };
88 
89         Ok(Some(Field { ty, kind, tag }))
90     }
91 
new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error>92     pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> {
93         if let Some(mut field) = Field::new(attrs, None)? {
94             match field.kind {
95                 Kind::Plain(default) => {
96                     field.kind = Kind::Required(default);
97                     Ok(Some(field))
98                 }
99                 Kind::Optional(..) => bail!("invalid optional attribute on oneof field"),
100                 Kind::Required(..) => bail!("invalid required attribute on oneof field"),
101                 Kind::Packed | Kind::Repeated => bail!("invalid repeated attribute on oneof field"),
102             }
103         } else {
104             Ok(None)
105         }
106     }
107 
encode(&self, ident: TokenStream) -> TokenStream108     pub fn encode(&self, ident: TokenStream) -> TokenStream {
109         let module = self.ty.module();
110         let encode_fn = match self.kind {
111             Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode),
112             Kind::Repeated => quote!(encode_repeated),
113             Kind::Packed => quote!(encode_packed),
114         };
115         let encode_fn = quote!(::prost::encoding::#module::#encode_fn);
116         let tag = self.tag;
117 
118         match self.kind {
119             Kind::Plain(ref default) => {
120                 let default = default.typed();
121                 quote! {
122                     if #ident != #default {
123                         #encode_fn(#tag, &#ident, buf);
124                     }
125                 }
126             }
127             Kind::Optional(..) => quote! {
128                 if let ::core::option::Option::Some(ref value) = #ident {
129                     #encode_fn(#tag, value, buf);
130                 }
131             },
132             Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
133                 #encode_fn(#tag, &#ident, buf);
134             },
135         }
136     }
137 
138     /// Returns an expression which evaluates to the result of merging a decoded
139     /// scalar value into the field.
merge(&self, ident: TokenStream) -> TokenStream140     pub fn merge(&self, ident: TokenStream) -> TokenStream {
141         let module = self.ty.module();
142         let merge_fn = match self.kind {
143             Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(merge),
144             Kind::Repeated | Kind::Packed => quote!(merge_repeated),
145         };
146         let merge_fn = quote!(::prost::encoding::#module::#merge_fn);
147 
148         match self.kind {
149             Kind::Plain(..) | Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
150                 #merge_fn(wire_type, #ident, buf, ctx)
151             },
152             Kind::Optional(..) => quote! {
153                 #merge_fn(wire_type,
154                           #ident.get_or_insert_with(::core::default::Default::default),
155                           buf,
156                           ctx)
157             },
158         }
159     }
160 
161     /// Returns an expression which evaluates to the encoded length of the field.
encoded_len(&self, ident: TokenStream) -> TokenStream162     pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
163         let module = self.ty.module();
164         let encoded_len_fn = match self.kind {
165             Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encoded_len),
166             Kind::Repeated => quote!(encoded_len_repeated),
167             Kind::Packed => quote!(encoded_len_packed),
168         };
169         let encoded_len_fn = quote!(::prost::encoding::#module::#encoded_len_fn);
170         let tag = self.tag;
171 
172         match self.kind {
173             Kind::Plain(ref default) => {
174                 let default = default.typed();
175                 quote! {
176                     if #ident != #default {
177                         #encoded_len_fn(#tag, &#ident)
178                     } else {
179                         0
180                     }
181                 }
182             }
183             Kind::Optional(..) => quote! {
184                 #ident.as_ref().map_or(0, |value| #encoded_len_fn(#tag, value))
185             },
186             Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
187                 #encoded_len_fn(#tag, &#ident)
188             },
189         }
190     }
191 
clear(&self, ident: TokenStream) -> TokenStream192     pub fn clear(&self, ident: TokenStream) -> TokenStream {
193         match self.kind {
194             Kind::Plain(ref default) | Kind::Required(ref default) => {
195                 let default = default.typed();
196                 match self.ty {
197                     Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
198                     _ => quote!(#ident = #default),
199                 }
200             }
201             Kind::Optional(_) => quote!(#ident = ::core::option::Option::None),
202             Kind::Repeated | Kind::Packed => quote!(#ident.clear()),
203         }
204     }
205 
206     /// Returns an expression which evaluates to the default value of the field.
default(&self) -> TokenStream207     pub fn default(&self) -> TokenStream {
208         match self.kind {
209             Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(),
210             Kind::Optional(_) => quote!(::core::option::Option::None),
211             Kind::Repeated | Kind::Packed => quote!(::prost::alloc::vec::Vec::new()),
212         }
213     }
214 
215     /// An inner debug wrapper, around the base type.
debug_inner(&self, wrap_name: TokenStream) -> TokenStream216     fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream {
217         if let Ty::Enumeration(ref ty) = self.ty {
218             quote! {
219                 struct #wrap_name<'a>(&'a i32);
220                 impl<'a> ::core::fmt::Debug for #wrap_name<'a> {
221                     fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
222                         let res: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(*self.0);
223                         match res {
224                             Err(_) => ::core::fmt::Debug::fmt(&self.0, f),
225                             Ok(en) => ::core::fmt::Debug::fmt(&en, f),
226                         }
227                     }
228                 }
229             }
230         } else {
231             quote! {
232                 #[allow(non_snake_case)]
233                 fn #wrap_name<T>(v: T) -> T { v }
234             }
235         }
236     }
237 
238     /// Returns a fragment for formatting the field `ident` in `Debug`.
debug(&self, wrapper_name: TokenStream) -> TokenStream239     pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream {
240         let wrapper = self.debug_inner(quote!(Inner));
241         let inner_ty = self.ty.rust_type();
242         match self.kind {
243             Kind::Plain(_) | Kind::Required(_) => self.debug_inner(wrapper_name),
244             Kind::Optional(_) => quote! {
245                 struct #wrapper_name<'a>(&'a ::core::option::Option<#inner_ty>);
246                 impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
247                     fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
248                         #wrapper
249                         ::core::fmt::Debug::fmt(&self.0.as_ref().map(Inner), f)
250                     }
251                 }
252             },
253             Kind::Repeated | Kind::Packed => {
254                 quote! {
255                     struct #wrapper_name<'a>(&'a ::prost::alloc::vec::Vec<#inner_ty>);
256                     impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
257                         fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
258                             let mut vec_builder = f.debug_list();
259                             for v in self.0 {
260                                 #wrapper
261                                 vec_builder.entry(&Inner(v));
262                             }
263                             vec_builder.finish()
264                         }
265                     }
266                 }
267             }
268         }
269     }
270 
271     /// Returns methods to embed in the message.
methods(&self, ident: &TokenStream) -> Option<TokenStream>272     pub fn methods(&self, ident: &TokenStream) -> Option<TokenStream> {
273         let mut ident_str = ident.to_string();
274         if ident_str.starts_with("r#") {
275             ident_str = ident_str[2..].to_owned();
276         }
277 
278         // Prepend `get_` for getter methods of tuple structs.
279         let get = match syn::parse_str::<Index>(&ident_str) {
280             Ok(index) => {
281                 let get = Ident::new(&format!("get_{}", index.index), Span::call_site());
282                 quote!(#get)
283             }
284             Err(_) => quote!(#ident),
285         };
286 
287         if let Ty::Enumeration(ref ty) = self.ty {
288             let set = Ident::new(&format!("set_{}", ident_str), Span::call_site());
289             let set_doc = format!("Sets `{}` to the provided enum value.", ident_str);
290             Some(match self.kind {
291                 Kind::Plain(ref default) | Kind::Required(ref default) => {
292                     let get_doc = format!(
293                         "Returns the enum value of `{}`, \
294                          or the default if the field is set to an invalid enum value.",
295                         ident_str,
296                     );
297                     quote! {
298                         #[doc=#get_doc]
299                         pub fn #get(&self) -> #ty {
300                             ::core::convert::TryFrom::try_from(self.#ident).unwrap_or(#default)
301                         }
302 
303                         #[doc=#set_doc]
304                         pub fn #set(&mut self, value: #ty) {
305                             self.#ident = value as i32;
306                         }
307                     }
308                 }
309                 Kind::Optional(ref default) => {
310                     let get_doc = format!(
311                         "Returns the enum value of `{}`, \
312                          or the default if the field is unset or set to an invalid enum value.",
313                         ident_str,
314                     );
315                     quote! {
316                         #[doc=#get_doc]
317                         pub fn #get(&self) -> #ty {
318                             self.#ident.and_then(|x| {
319                                 let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
320                                 result.ok()
321                             }).unwrap_or(#default)
322                         }
323 
324                         #[doc=#set_doc]
325                         pub fn #set(&mut self, value: #ty) {
326                             self.#ident = ::core::option::Option::Some(value as i32);
327                         }
328                     }
329                 }
330                 Kind::Repeated | Kind::Packed => {
331                     let iter_doc = format!(
332                         "Returns an iterator which yields the valid enum values contained in `{}`.",
333                         ident_str,
334                     );
335                     let push = Ident::new(&format!("push_{}", ident_str), Span::call_site());
336                     let push_doc = format!("Appends the provided enum value to `{}`.", ident_str);
337                     quote! {
338                         #[doc=#iter_doc]
339                         pub fn #get(&self) -> ::core::iter::FilterMap<
340                             ::core::iter::Cloned<::core::slice::Iter<i32>>,
341                             fn(i32) -> ::core::option::Option<#ty>,
342                         > {
343                             self.#ident.iter().cloned().filter_map(|x| {
344                                 let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
345                                 result.ok()
346                             })
347                         }
348                         #[doc=#push_doc]
349                         pub fn #push(&mut self, value: #ty) {
350                             self.#ident.push(value as i32);
351                         }
352                     }
353                 }
354             })
355         } else if let Kind::Optional(ref default) = self.kind {
356             let ty = self.ty.rust_ref_type();
357 
358             let match_some = if self.ty.is_numeric() {
359                 quote!(::core::option::Option::Some(val) => val,)
360             } else {
361                 quote!(::core::option::Option::Some(ref val) => &val[..],)
362             };
363 
364             let get_doc = format!(
365                 "Returns the value of `{0}`, or the default value if `{0}` is unset.",
366                 ident_str,
367             );
368 
369             Some(quote! {
370                 #[doc=#get_doc]
371                 pub fn #get(&self) -> #ty {
372                     match self.#ident {
373                         #match_some
374                         ::core::option::Option::None => #default,
375                     }
376                 }
377             })
378         } else {
379             None
380         }
381     }
382 }
383 
/// A scalar protobuf field type.
///
/// Several protobuf types share one Rust representation (see `Ty::rust_type`);
/// they differ only in wire encoding.
#[derive(Clone, PartialEq, Eq)]
pub enum Ty {
    /// `double`, stored as `f64`.
    Double,
    /// `float`, stored as `f32`.
    Float,
    /// `int32`, stored as `i32`.
    Int32,
    /// `int64`, stored as `i64`.
    Int64,
    /// `uint32`, stored as `u32`.
    Uint32,
    /// `uint64`, stored as `u64`.
    Uint64,
    /// `sint32`, stored as `i32`.
    Sint32,
    /// `sint64`, stored as `i64`.
    Sint64,
    /// `fixed32`, stored as `u32`.
    Fixed32,
    /// `fixed64`, stored as `u64`.
    Fixed64,
    /// `sfixed32`, stored as `i32`.
    Sfixed32,
    /// `sfixed64`, stored as `i64`.
    Sfixed64,
    /// `bool`.
    Bool,
    /// `string`, stored as an owned `String`.
    String,
    /// `bytes`, stored in the Rust container selected by [`BytesTy`].
    Bytes(BytesTy),
    /// An `enum` field, stored as `i32`; the path names the Rust enum type.
    Enumeration(Path),
}
404 
/// The Rust container used to store a protobuf `bytes` field.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BytesTy {
    /// `::prost::alloc::vec::Vec<u8>` (selected by `bytes` or `bytes = "vec"`).
    Vec,
    /// `::prost::bytes::Bytes` (selected by `bytes = "bytes"`).
    Bytes,
}
410 
411 impl BytesTy {
try_from_str(s: &str) -> Result<Self, Error>412     fn try_from_str(s: &str) -> Result<Self, Error> {
413         match s {
414             "vec" => Ok(BytesTy::Vec),
415             "bytes" => Ok(BytesTy::Bytes),
416             _ => bail!("Invalid bytes type: {}", s),
417         }
418     }
419 
rust_type(&self) -> TokenStream420     fn rust_type(&self) -> TokenStream {
421         match self {
422             BytesTy::Vec => quote! { ::prost::alloc::vec::Vec<u8> },
423             BytesTy::Bytes => quote! { ::prost::bytes::Bytes },
424         }
425     }
426 }
427 
428 impl Ty {
from_attr(attr: &Meta) -> Result<Option<Ty>, Error>429     pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> {
430         let ty = match *attr {
431             Meta::Path(ref name) if name.is_ident("float") => Ty::Float,
432             Meta::Path(ref name) if name.is_ident("double") => Ty::Double,
433             Meta::Path(ref name) if name.is_ident("int32") => Ty::Int32,
434             Meta::Path(ref name) if name.is_ident("int64") => Ty::Int64,
435             Meta::Path(ref name) if name.is_ident("uint32") => Ty::Uint32,
436             Meta::Path(ref name) if name.is_ident("uint64") => Ty::Uint64,
437             Meta::Path(ref name) if name.is_ident("sint32") => Ty::Sint32,
438             Meta::Path(ref name) if name.is_ident("sint64") => Ty::Sint64,
439             Meta::Path(ref name) if name.is_ident("fixed32") => Ty::Fixed32,
440             Meta::Path(ref name) if name.is_ident("fixed64") => Ty::Fixed64,
441             Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32,
442             Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
443             Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
444             Meta::Path(ref name) if name.is_ident("string") => Ty::String,
445             Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
446             Meta::NameValue(MetaNameValue {
447                 ref path,
448                 value:
449                     Expr::Lit(ExprLit {
450                         lit: Lit::Str(ref l),
451                         ..
452                     }),
453                 ..
454             }) if path.is_ident("bytes") => Ty::Bytes(BytesTy::try_from_str(&l.value())?),
455             Meta::NameValue(MetaNameValue {
456                 ref path,
457                 value:
458                     Expr::Lit(ExprLit {
459                         lit: Lit::Str(ref l),
460                         ..
461                     }),
462                 ..
463             }) if path.is_ident("enumeration") => Ty::Enumeration(parse_str::<Path>(&l.value())?),
464             Meta::List(ref meta_list) if meta_list.path.is_ident("enumeration") => {
465                 Ty::Enumeration(meta_list.parse_args::<Path>()?)
466             }
467             _ => return Ok(None),
468         };
469         Ok(Some(ty))
470     }
471 
from_str(s: &str) -> Result<Ty, Error>472     pub fn from_str(s: &str) -> Result<Ty, Error> {
473         let enumeration_len = "enumeration".len();
474         let error = Err(anyhow!("invalid type: {}", s));
475         let ty = match s.trim() {
476             "float" => Ty::Float,
477             "double" => Ty::Double,
478             "int32" => Ty::Int32,
479             "int64" => Ty::Int64,
480             "uint32" => Ty::Uint32,
481             "uint64" => Ty::Uint64,
482             "sint32" => Ty::Sint32,
483             "sint64" => Ty::Sint64,
484             "fixed32" => Ty::Fixed32,
485             "fixed64" => Ty::Fixed64,
486             "sfixed32" => Ty::Sfixed32,
487             "sfixed64" => Ty::Sfixed64,
488             "bool" => Ty::Bool,
489             "string" => Ty::String,
490             "bytes" => Ty::Bytes(BytesTy::Vec),
491             s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
492                 let s = &s[enumeration_len..].trim();
493                 match s.chars().next() {
494                     Some('<') | Some('(') => (),
495                     _ => return error,
496                 }
497                 match s.chars().next_back() {
498                     Some('>') | Some(')') => (),
499                     _ => return error,
500                 }
501 
502                 Ty::Enumeration(parse_str::<Path>(s[1..s.len() - 1].trim())?)
503             }
504             _ => return error,
505         };
506         Ok(ty)
507     }
508 
509     /// Returns the type as it appears in protobuf field declarations.
as_str(&self) -> &'static str510     pub fn as_str(&self) -> &'static str {
511         match *self {
512             Ty::Double => "double",
513             Ty::Float => "float",
514             Ty::Int32 => "int32",
515             Ty::Int64 => "int64",
516             Ty::Uint32 => "uint32",
517             Ty::Uint64 => "uint64",
518             Ty::Sint32 => "sint32",
519             Ty::Sint64 => "sint64",
520             Ty::Fixed32 => "fixed32",
521             Ty::Fixed64 => "fixed64",
522             Ty::Sfixed32 => "sfixed32",
523             Ty::Sfixed64 => "sfixed64",
524             Ty::Bool => "bool",
525             Ty::String => "string",
526             Ty::Bytes(..) => "bytes",
527             Ty::Enumeration(..) => "enum",
528         }
529     }
530 
531     // TODO: rename to 'owned_type'.
rust_type(&self) -> TokenStream532     pub fn rust_type(&self) -> TokenStream {
533         match self {
534             Ty::String => quote!(::prost::alloc::string::String),
535             Ty::Bytes(ty) => ty.rust_type(),
536             _ => self.rust_ref_type(),
537         }
538     }
539 
540     // TODO: rename to 'ref_type'
rust_ref_type(&self) -> TokenStream541     pub fn rust_ref_type(&self) -> TokenStream {
542         match *self {
543             Ty::Double => quote!(f64),
544             Ty::Float => quote!(f32),
545             Ty::Int32 => quote!(i32),
546             Ty::Int64 => quote!(i64),
547             Ty::Uint32 => quote!(u32),
548             Ty::Uint64 => quote!(u64),
549             Ty::Sint32 => quote!(i32),
550             Ty::Sint64 => quote!(i64),
551             Ty::Fixed32 => quote!(u32),
552             Ty::Fixed64 => quote!(u64),
553             Ty::Sfixed32 => quote!(i32),
554             Ty::Sfixed64 => quote!(i64),
555             Ty::Bool => quote!(bool),
556             Ty::String => quote!(&str),
557             Ty::Bytes(..) => quote!(&[u8]),
558             Ty::Enumeration(..) => quote!(i32),
559         }
560     }
561 
module(&self) -> Ident562     pub fn module(&self) -> Ident {
563         match *self {
564             Ty::Enumeration(..) => Ident::new("int32", Span::call_site()),
565             _ => Ident::new(self.as_str(), Span::call_site()),
566         }
567     }
568 
569     /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
is_numeric(&self) -> bool570     pub fn is_numeric(&self) -> bool {
571         !matches!(self, Ty::String | Ty::Bytes(..))
572     }
573 }
574 
575 impl fmt::Debug for Ty {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result576     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
577         f.write_str(self.as_str())
578     }
579 }
580 
581 impl fmt::Display for Ty {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result582     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
583         f.write_str(self.as_str())
584     }
585 }
586 
/// Scalar Protobuf field types.
///
/// The single-valued kinds carry the field's default value. It is either given
/// explicitly with a `default` attribute or derived from the field's type.
#[derive(Clone)]
pub enum Kind {
    /// A plain proto3 scalar field; skipped when encoding if equal to its default.
    Plain(DefaultValue),
    /// An optional scalar field, stored as `Option<T>`.
    Optional(DefaultValue),
    /// A required proto2 scalar field.
    Required(DefaultValue),
    /// A repeated scalar field.
    Repeated,
    /// A packed repeated scalar field.
    Packed,
}
601 
/// Scalar Protobuf field default value.
#[derive(Clone, Debug)]
pub enum DefaultValue {
    /// Default for a `double` field.
    F64(f64),
    /// Default for a `float` field.
    F32(f32),
    /// Default for `int32`, `sint32` and `sfixed32` fields.
    I32(i32),
    /// Default for `int64`, `sint64` and `sfixed64` fields.
    I64(i64),
    /// Default for `uint32` and `fixed32` fields.
    U32(u32),
    /// Default for `uint64` and `fixed64` fields.
    U64(u64),
    /// Default for a `bool` field.
    Bool(bool),
    /// Default for a `string` field.
    String(String),
    /// Default for a `bytes` field.
    Bytes(Vec<u8>),
    /// Default for an enumeration field: the tokens `EnumPath::Variant`.
    Enumeration(TokenStream),
    /// A path to a constant, used for the special float defaults
    /// (`inf`, `-inf`, `nan`).
    Path(Path),
}
617 
618 impl DefaultValue {
from_attr(attr: &Meta) -> Result<Option<Lit>, Error>619     pub fn from_attr(attr: &Meta) -> Result<Option<Lit>, Error> {
620         if !attr.path().is_ident("default") {
621             Ok(None)
622         } else if let Meta::NameValue(MetaNameValue {
623             value: Expr::Lit(ExprLit { ref lit, .. }),
624             ..
625         }) = *attr
626         {
627             Ok(Some(lit.clone()))
628         } else {
629             bail!("invalid default value attribute: {:?}", attr)
630         }
631     }
632 
from_lit(ty: &Ty, lit: Lit) -> Result<DefaultValue, Error>633     pub fn from_lit(ty: &Ty, lit: Lit) -> Result<DefaultValue, Error> {
634         let is_i32 = *ty == Ty::Int32 || *ty == Ty::Sint32 || *ty == Ty::Sfixed32;
635         let is_i64 = *ty == Ty::Int64 || *ty == Ty::Sint64 || *ty == Ty::Sfixed64;
636 
637         let is_u32 = *ty == Ty::Uint32 || *ty == Ty::Fixed32;
638         let is_u64 = *ty == Ty::Uint64 || *ty == Ty::Fixed64;
639 
640         let empty_or_is = |expected, actual: &str| expected == actual || actual.is_empty();
641 
642         let default = match lit {
643             Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => {
644                 DefaultValue::I32(lit.base10_parse()?)
645             }
646             Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => {
647                 DefaultValue::I64(lit.base10_parse()?)
648             }
649             Lit::Int(ref lit) if is_u32 && empty_or_is("u32", lit.suffix()) => {
650                 DefaultValue::U32(lit.base10_parse()?)
651             }
652             Lit::Int(ref lit) if is_u64 && empty_or_is("u64", lit.suffix()) => {
653                 DefaultValue::U64(lit.base10_parse()?)
654             }
655 
656             Lit::Float(ref lit) if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) => {
657                 DefaultValue::F32(lit.base10_parse()?)
658             }
659             Lit::Int(ref lit) if *ty == Ty::Float => DefaultValue::F32(lit.base10_parse()?),
660 
661             Lit::Float(ref lit) if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) => {
662                 DefaultValue::F64(lit.base10_parse()?)
663             }
664             Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?),
665 
666             Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value),
667             Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()),
668             Lit::ByteStr(ref lit)
669                 if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) =>
670             {
671                 DefaultValue::Bytes(lit.value())
672             }
673 
674             Lit::Str(ref lit) => {
675                 let value = lit.value();
676                 let value = value.trim();
677 
678                 if let Ty::Enumeration(ref path) = *ty {
679                     let variant = Ident::new(value, Span::call_site());
680                     return Ok(DefaultValue::Enumeration(quote!(#path::#variant)));
681                 }
682 
683                 // Parse special floating point values.
684                 if *ty == Ty::Float {
685                     match value {
686                         "inf" => {
687                             return Ok(DefaultValue::Path(parse_str::<Path>(
688                                 "::core::f32::INFINITY",
689                             )?));
690                         }
691                         "-inf" => {
692                             return Ok(DefaultValue::Path(parse_str::<Path>(
693                                 "::core::f32::NEG_INFINITY",
694                             )?));
695                         }
696                         "nan" => {
697                             return Ok(DefaultValue::Path(parse_str::<Path>("::core::f32::NAN")?));
698                         }
699                         _ => (),
700                     }
701                 }
702                 if *ty == Ty::Double {
703                     match value {
704                         "inf" => {
705                             return Ok(DefaultValue::Path(parse_str::<Path>(
706                                 "::core::f64::INFINITY",
707                             )?));
708                         }
709                         "-inf" => {
710                             return Ok(DefaultValue::Path(parse_str::<Path>(
711                                 "::core::f64::NEG_INFINITY",
712                             )?));
713                         }
714                         "nan" => {
715                             return Ok(DefaultValue::Path(parse_str::<Path>("::core::f64::NAN")?));
716                         }
717                         _ => (),
718                     }
719                 }
720 
721                 // Rust doesn't have a negative literals, so they have to be parsed specially.
722                 if let Some(Ok(lit)) = value.strip_prefix('-').map(syn::parse_str::<Lit>) {
723                     match lit {
724                         Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => {
725                             // Initially parse into an i64, so that i32::MIN does not overflow.
726                             let value: i64 = -lit.base10_parse()?;
727                             return Ok(i32::try_from(value).map(DefaultValue::I32)?);
728                         }
729                         Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => {
730                             // Initially parse into an i128, so that i64::MIN does not overflow.
731                             let value: i128 = -lit.base10_parse()?;
732                             return Ok(i64::try_from(value).map(DefaultValue::I64)?);
733                         }
734                         Lit::Float(ref lit)
735                             if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) =>
736                         {
737                             return Ok(DefaultValue::F32(-lit.base10_parse()?));
738                         }
739                         Lit::Float(ref lit)
740                             if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) =>
741                         {
742                             return Ok(DefaultValue::F64(-lit.base10_parse()?));
743                         }
744                         Lit::Int(ref lit) if *ty == Ty::Float && lit.suffix().is_empty() => {
745                             return Ok(DefaultValue::F32(-lit.base10_parse()?));
746                         }
747                         Lit::Int(ref lit) if *ty == Ty::Double && lit.suffix().is_empty() => {
748                             return Ok(DefaultValue::F64(-lit.base10_parse()?));
749                         }
750                         _ => (),
751                     }
752                 }
753                 match syn::parse_str::<Lit>(value) {
754                     Ok(Lit::Str(_)) => (),
755                     Ok(lit) => return DefaultValue::from_lit(ty, lit),
756                     _ => (),
757                 }
758                 bail!("invalid default value: {}", quote!(#value));
759             }
760             _ => bail!("invalid default value: {}", quote!(#lit)),
761         };
762 
763         Ok(default)
764     }
765 
new(ty: &Ty) -> DefaultValue766     pub fn new(ty: &Ty) -> DefaultValue {
767         match *ty {
768             Ty::Float => DefaultValue::F32(0.0),
769             Ty::Double => DefaultValue::F64(0.0),
770             Ty::Int32 | Ty::Sint32 | Ty::Sfixed32 => DefaultValue::I32(0),
771             Ty::Int64 | Ty::Sint64 | Ty::Sfixed64 => DefaultValue::I64(0),
772             Ty::Uint32 | Ty::Fixed32 => DefaultValue::U32(0),
773             Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0),
774 
775             Ty::Bool => DefaultValue::Bool(false),
776             Ty::String => DefaultValue::String(String::new()),
777             Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
778             Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
779         }
780     }
781 
owned(&self) -> TokenStream782     pub fn owned(&self) -> TokenStream {
783         match *self {
784             DefaultValue::String(ref value) if value.is_empty() => {
785                 quote!(::prost::alloc::string::String::new())
786             }
787             DefaultValue::String(ref value) => quote!(#value.into()),
788             DefaultValue::Bytes(ref value) if value.is_empty() => {
789                 quote!(::core::default::Default::default())
790             }
791             DefaultValue::Bytes(ref value) => {
792                 let lit = LitByteStr::new(value, Span::call_site());
793                 quote!(#lit.as_ref().into())
794             }
795 
796             ref other => other.typed(),
797         }
798     }
799 
typed(&self) -> TokenStream800     pub fn typed(&self) -> TokenStream {
801         if let DefaultValue::Enumeration(_) = *self {
802             quote!(#self as i32)
803         } else {
804             quote!(#self)
805         }
806     }
807 }
808 
809 impl ToTokens for DefaultValue {
to_tokens(&self, tokens: &mut TokenStream)810     fn to_tokens(&self, tokens: &mut TokenStream) {
811         match *self {
812             DefaultValue::F64(value) => value.to_tokens(tokens),
813             DefaultValue::F32(value) => value.to_tokens(tokens),
814             DefaultValue::I32(value) => value.to_tokens(tokens),
815             DefaultValue::I64(value) => value.to_tokens(tokens),
816             DefaultValue::U32(value) => value.to_tokens(tokens),
817             DefaultValue::U64(value) => value.to_tokens(tokens),
818             DefaultValue::Bool(value) => value.to_tokens(tokens),
819             DefaultValue::String(ref value) => value.to_tokens(tokens),
820             DefaultValue::Bytes(ref value) => {
821                 let byte_str = LitByteStr::new(value, Span::call_site());
822                 tokens.append_all(quote!(#byte_str as &[u8]));
823             }
824             DefaultValue::Enumeration(ref value) => value.to_tokens(tokens),
825             DefaultValue::Path(ref value) => value.to_tokens(tokens),
826         }
827     }
828 }
829