1 use anyhow::{bail, Error};
2 use proc_macro2::{Span, TokenStream};
3 use quote::quote;
4 use syn::punctuated::Punctuated;
5 use syn::{Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Token};
6 
7 use crate::field::{scalar, set_option, tag_attr};
8 
/// The Rust collection type that backs a protobuf map field.
#[derive(Debug, Clone)]
pub enum MapTy {
    /// A `HashMap`, selected by the `map` or `hash_map` attribute.
    HashMap,
    /// A `BTreeMap`, selected by the `btree_map` attribute.
    BTreeMap,
}
14 
15 impl MapTy {
from_str(s: &str) -> Option<MapTy>16     fn from_str(s: &str) -> Option<MapTy> {
17         match s {
18             "map" | "hash_map" => Some(MapTy::HashMap),
19             "btree_map" => Some(MapTy::BTreeMap),
20             _ => None,
21         }
22     }
23 
module(&self) -> Ident24     fn module(&self) -> Ident {
25         match *self {
26             MapTy::HashMap => Ident::new("hash_map", Span::call_site()),
27             MapTy::BTreeMap => Ident::new("btree_map", Span::call_site()),
28         }
29     }
30 
lib(&self) -> TokenStream31     fn lib(&self) -> TokenStream {
32         match self {
33             MapTy::HashMap => quote! { std },
34             MapTy::BTreeMap => quote! { prost::alloc },
35         }
36     }
37 }
38 
fake_scalar(ty: scalar::Ty) -> scalar::Field39 fn fake_scalar(ty: scalar::Ty) -> scalar::Field {
40     let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty));
41     scalar::Field {
42         ty,
43         kind,
44         tag: 0, // Not used here
45     }
46 }
47 
48 #[derive(Clone)]
49 pub struct Field {
50     pub map_ty: MapTy,
51     pub key_ty: scalar::Ty,
52     pub value_ty: ValueTy,
53     pub tag: u32,
54 }
55 
56 impl Field {
new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error>57     pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
58         let mut types = None;
59         let mut tag = None;
60 
61         for attr in attrs {
62             if let Some(t) = tag_attr(attr)? {
63                 set_option(&mut tag, t, "duplicate tag attributes")?;
64             } else if let Some(map_ty) = attr
65                 .path()
66                 .get_ident()
67                 .and_then(|i| MapTy::from_str(&i.to_string()))
68             {
69                 let (k, v): (String, String) = match attr {
70                     Meta::NameValue(MetaNameValue {
71                         value:
72                             Expr::Lit(ExprLit {
73                                 lit: Lit::Str(lit), ..
74                             }),
75                         ..
76                     }) => {
77                         let items = lit.value();
78                         let mut items = items.split(',').map(ToString::to_string);
79                         let k = items.next().unwrap();
80                         let v = match items.next() {
81                             Some(k) => k,
82                             None => bail!("invalid map attribute: must have key and value types"),
83                         };
84                         if items.next().is_some() {
85                             bail!("invalid map attribute: {:?}", attr);
86                         }
87                         (k, v)
88                     }
89                     Meta::List(meta_list) => {
90                         let nested = meta_list
91                             .parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?
92                             .into_iter()
93                             .collect::<Vec<_>>();
94                         if nested.len() != 2 {
95                             bail!("invalid map attribute: must contain key and value types");
96                         }
97                         (nested[0].to_string(), nested[1].to_string())
98                     }
99                     _ => return Ok(None),
100                 };
101                 set_option(
102                     &mut types,
103                     (map_ty, key_ty_from_str(&k)?, ValueTy::from_str(&v)?),
104                     "duplicate map type attribute",
105                 )?;
106             } else {
107                 return Ok(None);
108             }
109         }
110 
111         Ok(match (types, tag.or(inferred_tag)) {
112             (Some((map_ty, key_ty, value_ty)), Some(tag)) => Some(Field {
113                 map_ty,
114                 key_ty,
115                 value_ty,
116                 tag,
117             }),
118             _ => None,
119         })
120     }
121 
new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error>122     pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> {
123         Field::new(attrs, None)
124     }
125 
126     /// Returns a statement which encodes the map field.
encode(&self, ident: TokenStream) -> TokenStream127     pub fn encode(&self, ident: TokenStream) -> TokenStream {
128         let tag = self.tag;
129         let key_mod = self.key_ty.module();
130         let ke = quote!(::prost::encoding::#key_mod::encode);
131         let kl = quote!(::prost::encoding::#key_mod::encoded_len);
132         let module = self.map_ty.module();
133         match &self.value_ty {
134             ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
135                 let default = quote!(#ty::default() as i32);
136                 quote! {
137                     ::prost::encoding::#module::encode_with_default(
138                         #ke,
139                         #kl,
140                         ::prost::encoding::int32::encode,
141                         ::prost::encoding::int32::encoded_len,
142                         &(#default),
143                         #tag,
144                         &#ident,
145                         buf,
146                     );
147                 }
148             }
149             ValueTy::Scalar(value_ty) => {
150                 let val_mod = value_ty.module();
151                 let ve = quote!(::prost::encoding::#val_mod::encode);
152                 let vl = quote!(::prost::encoding::#val_mod::encoded_len);
153                 quote! {
154                     ::prost::encoding::#module::encode(
155                         #ke,
156                         #kl,
157                         #ve,
158                         #vl,
159                         #tag,
160                         &#ident,
161                         buf,
162                     );
163                 }
164             }
165             ValueTy::Message => quote! {
166                 ::prost::encoding::#module::encode(
167                     #ke,
168                     #kl,
169                     ::prost::encoding::message::encode,
170                     ::prost::encoding::message::encoded_len,
171                     #tag,
172                     &#ident,
173                     buf,
174                 );
175             },
176         }
177     }
178 
179     /// Returns an expression which evaluates to the result of merging a decoded key value pair
180     /// into the map.
merge(&self, ident: TokenStream) -> TokenStream181     pub fn merge(&self, ident: TokenStream) -> TokenStream {
182         let key_mod = self.key_ty.module();
183         let km = quote!(::prost::encoding::#key_mod::merge);
184         let module = self.map_ty.module();
185         match &self.value_ty {
186             ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
187                 let default = quote!(#ty::default() as i32);
188                 quote! {
189                     ::prost::encoding::#module::merge_with_default(
190                         #km,
191                         ::prost::encoding::int32::merge,
192                         #default,
193                         &mut #ident,
194                         buf,
195                         ctx,
196                     )
197                 }
198             }
199             ValueTy::Scalar(value_ty) => {
200                 let val_mod = value_ty.module();
201                 let vm = quote!(::prost::encoding::#val_mod::merge);
202                 quote!(::prost::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx))
203             }
204             ValueTy::Message => quote! {
205                 ::prost::encoding::#module::merge(
206                     #km,
207                     ::prost::encoding::message::merge,
208                     &mut #ident,
209                     buf,
210                     ctx,
211                 )
212             },
213         }
214     }
215 
216     /// Returns an expression which evaluates to the encoded length of the map.
encoded_len(&self, ident: TokenStream) -> TokenStream217     pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
218         let tag = self.tag;
219         let key_mod = self.key_ty.module();
220         let kl = quote!(::prost::encoding::#key_mod::encoded_len);
221         let module = self.map_ty.module();
222         match &self.value_ty {
223             ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
224                 let default = quote!(#ty::default() as i32);
225                 quote! {
226                     ::prost::encoding::#module::encoded_len_with_default(
227                         #kl,
228                         ::prost::encoding::int32::encoded_len,
229                         &(#default),
230                         #tag,
231                         &#ident,
232                     )
233                 }
234             }
235             ValueTy::Scalar(value_ty) => {
236                 let val_mod = value_ty.module();
237                 let vl = quote!(::prost::encoding::#val_mod::encoded_len);
238                 quote!(::prost::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident))
239             }
240             ValueTy::Message => quote! {
241                 ::prost::encoding::#module::encoded_len(
242                     #kl,
243                     ::prost::encoding::message::encoded_len,
244                     #tag,
245                     &#ident,
246                 )
247             },
248         }
249     }
250 
clear(&self, ident: TokenStream) -> TokenStream251     pub fn clear(&self, ident: TokenStream) -> TokenStream {
252         quote!(#ident.clear())
253     }
254 
255     /// Returns methods to embed in the message.
methods(&self, ident: &TokenStream) -> Option<TokenStream>256     pub fn methods(&self, ident: &TokenStream) -> Option<TokenStream> {
257         if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty {
258             let key_ty = self.key_ty.rust_type();
259             let key_ref_ty = self.key_ty.rust_ref_type();
260 
261             let get = Ident::new(&format!("get_{}", ident), Span::call_site());
262             let insert = Ident::new(&format!("insert_{}", ident), Span::call_site());
263             let take_ref = if self.key_ty.is_numeric() {
264                 quote!(&)
265             } else {
266                 quote!()
267             };
268 
269             let get_doc = format!(
270                 "Returns the enum value for the corresponding key in `{}`, \
271                  or `None` if the entry does not exist or it is not a valid enum value.",
272                 ident,
273             );
274             let insert_doc = format!("Inserts a key value pair into `{}`.", ident);
275             Some(quote! {
276                 #[doc=#get_doc]
277                 pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> {
278                     self.#ident.get(#take_ref key).cloned().and_then(|x| {
279                         let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
280                         result.ok()
281                     })
282                 }
283                 #[doc=#insert_doc]
284                 pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> {
285                     self.#ident.insert(key, value as i32).and_then(|x| {
286                         let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
287                         result.ok()
288                     })
289                 }
290             })
291         } else {
292             None
293         }
294     }
295 
296     /// Returns a newtype wrapper around the map, implementing nicer Debug
297     ///
298     /// The Debug tries to convert any enumerations met into the variants if possible, instead of
299     /// outputting the raw numbers.
debug(&self, wrapper_name: TokenStream) -> TokenStream300     pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream {
301         let type_name = match self.map_ty {
302             MapTy::HashMap => Ident::new("HashMap", Span::call_site()),
303             MapTy::BTreeMap => Ident::new("BTreeMap", Span::call_site()),
304         };
305 
306         // A fake field for generating the debug wrapper
307         let key_wrapper = fake_scalar(self.key_ty.clone()).debug(quote!(KeyWrapper));
308         let key = self.key_ty.rust_type();
309         let value_wrapper = self.value_ty.debug();
310         let libname = self.map_ty.lib();
311         let fmt = quote! {
312             fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
313                 #key_wrapper
314                 #value_wrapper
315                 let mut builder = f.debug_map();
316                 for (k, v) in self.0 {
317                     builder.entry(&KeyWrapper(k), &ValueWrapper(v));
318                 }
319                 builder.finish()
320             }
321         };
322         match &self.value_ty {
323             ValueTy::Scalar(ty) => {
324                 if let scalar::Ty::Bytes(_) = *ty {
325                     return quote! {
326                         struct #wrapper_name<'a>(&'a dyn ::core::fmt::Debug);
327                         impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
328                             fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
329                                 self.0.fmt(f)
330                             }
331                         }
332                     };
333                 }
334 
335                 let value = ty.rust_type();
336                 quote! {
337                     struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>);
338                     impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
339                         #fmt
340                     }
341                 }
342             }
343             ValueTy::Message => quote! {
344                 struct #wrapper_name<'a, V: 'a>(&'a ::#libname::collections::#type_name<#key, V>);
345                 impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V>
346                 where
347                     V: ::core::fmt::Debug + 'a,
348                 {
349                     #fmt
350                 }
351             },
352         }
353     }
354 }
355 
key_ty_from_str(s: &str) -> Result<scalar::Ty, Error>356 fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
357     let ty = scalar::Ty::from_str(s)?;
358     match ty {
359         scalar::Ty::Int32
360         | scalar::Ty::Int64
361         | scalar::Ty::Uint32
362         | scalar::Ty::Uint64
363         | scalar::Ty::Sint32
364         | scalar::Ty::Sint64
365         | scalar::Ty::Fixed32
366         | scalar::Ty::Fixed64
367         | scalar::Ty::Sfixed32
368         | scalar::Ty::Sfixed64
369         | scalar::Ty::Bool
370         | scalar::Ty::String => Ok(ty),
371         _ => bail!("invalid map key type: {}", s),
372     }
373 }
374 
375 /// A map value type.
376 #[derive(Clone, Debug, PartialEq, Eq)]
377 pub enum ValueTy {
378     Scalar(scalar::Ty),
379     Message,
380 }
381 
382 impl ValueTy {
from_str(s: &str) -> Result<ValueTy, Error>383     fn from_str(s: &str) -> Result<ValueTy, Error> {
384         if let Ok(ty) = scalar::Ty::from_str(s) {
385             Ok(ValueTy::Scalar(ty))
386         } else if s.trim() == "message" {
387             Ok(ValueTy::Message)
388         } else {
389             bail!("invalid map value type: {}", s);
390         }
391     }
392 
393     /// Returns a newtype wrapper around the ValueTy for nicer debug.
394     ///
395     /// If the contained value is enumeration, it tries to convert it to the variant. If not, it
396     /// just forwards the implementation.
debug(&self) -> TokenStream397     fn debug(&self) -> TokenStream {
398         match self {
399             ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(quote!(ValueWrapper)),
400             ValueTy::Message => quote!(
401                 fn ValueWrapper<T>(v: T) -> T {
402                     v
403                 }
404             ),
405         }
406     }
407 }
408