1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use crate::backends::rust::{mask_bits, types, ToIdent, ToUpperCamelCase};
16 use crate::{analyzer, ast};
17 use quote::{format_ident, quote};
18 
size_field_ident(id: &str) -> proc_macro2::Ident19 fn size_field_ident(id: &str) -> proc_macro2::Ident {
20     format_ident!("{}_size", id.trim_matches('_'))
21 }
22 
23 /// A single bit-field.
24 struct BitField<'a> {
25     shift: usize, // The shift to apply to this field.
26     field: &'a ast::Field,
27 }
28 
29 pub struct FieldParser<'a> {
30     scope: &'a analyzer::Scope<'a>,
31     schema: &'a analyzer::Schema,
32     endianness: ast::EndiannessValue,
33     decl: &'a ast::Decl,
34     packet_name: &'a str,
35     span: &'a proc_macro2::Ident,
36     chunk: Vec<BitField<'a>>,
37     tokens: proc_macro2::TokenStream,
38     shift: usize,
39     offset: usize,
40 }
41 
42 impl<'a> FieldParser<'a> {
new( scope: &'a analyzer::Scope<'a>, schema: &'a analyzer::Schema, endianness: ast::EndiannessValue, packet_name: &'a str, span: &'a proc_macro2::Ident, ) -> FieldParser<'a>43     pub fn new(
44         scope: &'a analyzer::Scope<'a>,
45         schema: &'a analyzer::Schema,
46         endianness: ast::EndiannessValue,
47         packet_name: &'a str,
48         span: &'a proc_macro2::Ident,
49     ) -> FieldParser<'a> {
50         FieldParser {
51             scope,
52             schema,
53             endianness,
54             decl: scope.typedef[packet_name],
55             packet_name,
56             span,
57             chunk: Vec::new(),
58             tokens: quote! {},
59             shift: 0,
60             offset: 0,
61         }
62     }
63 
add(&mut self, field: &'a ast::Field)64     pub fn add(&mut self, field: &'a ast::Field) {
65         match &field.desc {
66             _ if field.cond.is_some() => self.add_optional_field(field),
67             _ if self.scope.is_bitfield(field) => self.add_bit_field(field),
68             ast::FieldDesc::Padding { .. } => (),
69             ast::FieldDesc::Array { id, width, type_id, size, .. } => self.add_array_field(
70                 id,
71                 *width,
72                 type_id.as_deref(),
73                 *size,
74                 self.schema.padded_size(field.key),
75                 self.scope.get_type_declaration(field),
76             ),
77             ast::FieldDesc::Typedef { id, type_id } => self.add_typedef_field(id, type_id),
78             ast::FieldDesc::Payload { size_modifier, .. } => {
79                 self.add_payload_field(size_modifier.as_deref())
80             }
81             ast::FieldDesc::Body { .. } => self.add_payload_field(None),
82             _ => todo!("{field:?}"),
83         }
84     }
85 
add_optional_field(&mut self, field: &'a ast::Field)86     fn add_optional_field(&mut self, field: &'a ast::Field) {
87         let cond_id = field.cond.as_ref().unwrap().id.to_ident();
88         let cond_value = syn::parse_str::<syn::LitInt>(&format!(
89             "{}",
90             field.cond.as_ref().unwrap().value.unwrap()
91         ))
92         .unwrap();
93 
94         self.tokens.extend(match &field.desc {
95             ast::FieldDesc::Scalar { id, width } => {
96                 let id = id.to_ident();
97                 let value = types::get_uint(self.endianness, *width, self.span);
98                 quote! {
99                     let #id = (#cond_id == #cond_value).then(|| #value);
100                 }
101             }
102             ast::FieldDesc::Typedef { id, type_id } => match &self.scope.typedef[type_id].desc {
103                 ast::DeclDesc::Enum { width, .. } => {
104                     let name = id;
105                     let type_name = type_id;
106                     let id = id.to_ident();
107                     let type_id = type_id.to_ident();
108                     let decl_id = &self.packet_name;
109                     let value = types::get_uint(self.endianness, *width, self.span);
110                     quote! {
111                         let #id = (#cond_id == #cond_value)
112                             .then(||
113                                 #type_id::try_from(#value).map_err(|unknown_val| {
114                                     DecodeError::InvalidEnumValueError {
115                                         obj: #decl_id,
116                                         field: #name,
117                                         value: unknown_val as u64,
118                                         type_: #type_name,
119                                     }
120                                 }))
121                             .transpose()?;
122                     }
123                 }
124                 ast::DeclDesc::Struct { .. } => {
125                     let id = id.to_ident();
126                     let type_id = type_id.to_ident();
127                     let span = self.span;
128                     quote! {
129                         let #id = (#cond_id == #cond_value)
130                             .then(|| #type_id::decode_mut(&mut #span))
131                             .transpose()?;
132                     }
133                 }
134                 _ => unreachable!(),
135             },
136             _ => unreachable!(),
137         })
138     }
139 
    /// Appends a bit-field to the current chunk. Once the chunk ends on an
    /// octet boundary, emits the decoding code for the whole chunk.
    ///
    /// Bit-fields are accumulated in `self.chunk` until their total width
    /// (`self.shift`) is a multiple of 8. At that point the generated code:
    /// - checks that enough bytes remain;
    /// - reads the chunk with `types::get_uint`;
    /// - extracts each field by shifting, masking and casting.
    ///
    /// A chunk holding a single field is read directly into that field,
    /// with no intermediate local variable.
    fn add_bit_field(&mut self, field: &'a ast::Field) {
        self.chunk.push(BitField { shift: self.shift, field });
        self.shift += self.schema.field_size(field.key).static_().unwrap();
        if self.shift % 8 != 0 {
            // Not octet-aligned yet: wait for further bit-fields.
            return;
        }

        // Size of the completed chunk in octets, and the octet offset
        // just past it.
        let size = self.shift / 8;
        let end_offset = self.offset + size;

        let wanted = proc_macro2::Literal::usize_unsuffixed(size);
        self.check_size(self.span, &quote!(#wanted));

        let chunk_type = types::Integer::new(self.shift);
        // TODO(mgeisler): generate Rust variable names which cannot
        // conflict with PDL field names. An option would be to start
        // Rust variable names with `_`, but that has a special
        // semantic in Rust.
        let chunk_name = format_ident!("chunk");

        let get = types::get_uint(self.endianness, self.shift, self.span);
        if self.chunk.len() > 1 {
            // Multiple values: we read into a local variable.
            self.tokens.extend(quote! {
                let #chunk_name = #get;
            });
        }

        // With a single field, its shift is 0 and it spans the whole
        // chunk, so the read expression is used directly and no masking
        // is needed.
        let single_value = self.chunk.len() == 1;
        for BitField { shift, field } in self.chunk.drain(..) {
            let mut v = if single_value {
                // Single value: read directly.
                quote! { #get }
            } else {
                // Multiple values: read from `chunk_name`.
                quote! { #chunk_name }
            };

            if shift > 0 {
                let shift = proc_macro2::Literal::usize_unsuffixed(shift);
                v = quote! { (#v >> #shift) }
            }

            let width = self.schema.field_size(field.key).static_().unwrap();
            let value_type = types::Integer::new(width);
            if !single_value && width < value_type.width {
                // Mask value if we grabbed more than `width` and if
                // `as #value_type` doesn't already do the masking.
                let mask = mask_bits(width, "u64");
                v = quote! { (#v & #mask) };
            }

            // Narrow the chunk integer to the field's own integer type.
            if value_type.width < chunk_type.width {
                v = quote! { #v as #value_type };
            }

            self.tokens.extend(match &field.desc {
                ast::FieldDesc::Scalar { id, .. }
                | ast::FieldDesc::Flag { id, .. } => {
                    let id = id.to_ident();
                    quote! {
                        let #id = #v;
                    }
                }
                // Fixed fields bind no variable: the generated code only
                // verifies the decoded value against the expected one.
                ast::FieldDesc::FixedEnum { enum_id, tag_id, .. } => {
                    let enum_id = enum_id.to_ident();
                    let tag_id = tag_id.to_upper_camel_case().to_ident();
                    quote! {
                        let fixed_value = #v;
                        if fixed_value != #value_type::from(#enum_id::#tag_id)  {
                            return Err(DecodeError::InvalidFixedValue {
                                expected: #value_type::from(#enum_id::#tag_id) as u64,
                                actual: fixed_value as u64,
                            });
                        }
                    }
                }
                ast::FieldDesc::FixedScalar { value, .. } => {
                    let value = proc_macro2::Literal::usize_unsuffixed(*value);
                    quote! {
                        let fixed_value = #v;
                        if fixed_value != #value {
                            return Err(DecodeError::InvalidFixedValue {
                                expected: #value,
                                actual: fixed_value as u64,
                            });
                        }
                    }
                }
                // Typedef bit-fields are converted with `try_from`; an
                // unknown value becomes `InvalidEnumValueError`.
                ast::FieldDesc::Typedef { id, type_id } => {
                    let field_name = id;
                    let type_name = type_id;
                    let packet_name = &self.packet_name;
                    let id = id.to_ident();
                    let type_id = type_id.to_ident();
                    quote! {
                        let #id = #type_id::try_from(#v).map_err(|unknown_val| DecodeError::InvalidEnumValueError {
                            obj: #packet_name,
                            field: #field_name,
                            value: unknown_val as u64,
                            type_: #type_name,
                        })?;
                    }
                }
                ast::FieldDesc::Reserved { .. } => {
                    if single_value {
                        // Nothing was read for this chunk: skip its bytes.
                        let span = self.span;
                        let size = proc_macro2::Literal::usize_unsuffixed(size);
                        quote! {
                            #span.advance(#size);
                        }
                    } else {
                        //  Otherwise we don't need anything: we will
                        //  have advanced past the reserved field when
                        //  reading the chunk above.
                        quote! {}
                    }
                }
                // Size, element-size and count fields are stored as
                // `usize` locals, which the array and payload code reads.
                ast::FieldDesc::Size { field_id, .. } => {
                    let id = size_field_ident(field_id);
                    quote! {
                        let #id = #v as usize;
                    }
                }
                ast::FieldDesc::ElementSize { field_id, .. } => {
                    let id = format_ident!("{field_id}_element_size");
                    quote! {
                        let #id = #v as usize;
                    }
                }
                ast::FieldDesc::Count { field_id, .. } => {
                    let id = format_ident!("{field_id}_count");
                    quote! {
                        let #id = #v as usize;
                    }
                }
                _ => todo!(),
            });
        }

        // The chunk is complete: start a new one at the next octet.
        self.offset = end_offset;
        self.shift = 0;
    }
283 
find_count_field(&self, id: &str) -> Option<proc_macro2::Ident>284     fn find_count_field(&self, id: &str) -> Option<proc_macro2::Ident> {
285         match self.decl.array_size(id)?.desc {
286             ast::FieldDesc::Count { .. } => Some(format_ident!("{id}_count")),
287             _ => None,
288         }
289     }
290 
find_size_field(&self, id: &str) -> Option<proc_macro2::Ident>291     fn find_size_field(&self, id: &str) -> Option<proc_macro2::Ident> {
292         match self.decl.array_size(id)?.desc {
293             ast::FieldDesc::Size { .. } => Some(size_field_ident(id)),
294             _ => None,
295         }
296     }
297 
find_element_size_field(&self, id: &str) -> Option<proc_macro2::Ident>298     fn find_element_size_field(&self, id: &str) -> Option<proc_macro2::Ident> {
299         self.decl.fields().find_map(|field| match &field.desc {
300             ast::FieldDesc::ElementSize { field_id, .. } if field_id == id => {
301                 Some(format_ident!("{id}_element_size"))
302             }
303             _ => None,
304         })
305     }
306 
payload_field_offset_from_end(&self) -> Option<usize>307     fn payload_field_offset_from_end(&self) -> Option<usize> {
308         let decl = self.scope.typedef[self.packet_name];
309         let mut fields = decl.fields();
310         fields.find(|f| {
311             matches!(f.desc, ast::FieldDesc::Body { .. } | ast::FieldDesc::Payload { .. })
312         })?;
313 
314         let mut offset = 0;
315         for field in fields {
316             if let Some(width) =
317                 self.schema.padded_size(field.key).or(self.schema.field_size(field.key).static_())
318             {
319                 offset += width;
320             } else {
321                 return None;
322             }
323         }
324 
325         Some(offset)
326     }
327 
check_size(&mut self, span: &proc_macro2::Ident, wanted: &proc_macro2::TokenStream)328     fn check_size(&mut self, span: &proc_macro2::Ident, wanted: &proc_macro2::TokenStream) {
329         let packet_name = &self.packet_name;
330         self.tokens.extend(quote! {
331             if #span.remaining() < #wanted {
332                 return Err(DecodeError::InvalidLengthError {
333                     obj: #packet_name,
334                     wanted: #wanted,
335                     got: #span.remaining(),
336                 });
337             }
338         });
339     }
340 
add_array_field( &mut self, id: &str, width: Option<usize>, type_id: Option<&str>, size: Option<usize>, padding_size: Option<usize>, decl: Option<&ast::Decl>, )341     fn add_array_field(
342         &mut self,
343         id: &str,
344         // `width`: the width in bits of the array elements (if Some).
345         width: Option<usize>,
346         // `type_id`: the enum type of the array elements (if Some).
347         // Mutually exclusive with `width`.
348         type_id: Option<&str>,
349         // `size`: the size of the array in number of elements (if
350         // known). If None, the array is a Vec with a dynamic size.
351         size: Option<usize>,
352         padding_size: Option<usize>,
353         decl: Option<&ast::Decl>,
354     ) {
355         enum ElementWidth {
356             Static(usize),               // Static size in bytes.
357             Dynamic(proc_macro2::Ident), // Dynamic size in bytes.
358             Unknown,
359         }
360         let element_width = if let Some(w) =
361             width.or_else(|| self.schema.total_size(decl.unwrap().key).static_())
362         {
363             assert_eq!(w % 8, 0, "Array element size ({w}) is not a multiple of 8");
364             ElementWidth::Static(w / 8)
365         } else if let Some(element_size_field) = self.find_element_size_field(id) {
366             ElementWidth::Dynamic(element_size_field)
367         } else {
368             ElementWidth::Unknown
369         };
370 
371         // The "shape" of the array, i.e., the number of elements
372         // given via a static count, a count field, a size field, or
373         // unknown.
374         enum ArrayShape {
375             Static(usize),                  // Static count
376             CountField(proc_macro2::Ident), // Count based on count field
377             SizeField(proc_macro2::Ident),  // Count based on size and field
378             Unknown,                        // Variable count based on remaining bytes
379         }
380         let array_shape = if let Some(count) = size {
381             ArrayShape::Static(count)
382         } else if let Some(count_field) = self.find_count_field(id) {
383             ArrayShape::CountField(count_field)
384         } else if let Some(size_field) = self.find_size_field(id) {
385             ArrayShape::SizeField(size_field)
386         } else {
387             ArrayShape::Unknown
388         };
389 
390         // TODO size modifier
391 
392         let span = match padding_size {
393             Some(padding_size) => {
394                 let span = self.span;
395                 let padding_octets = padding_size / 8;
396                 self.check_size(span, &quote!(#padding_octets));
397                 self.tokens.extend(quote! {
398                     let (mut head, tail) = #span.split_at(#padding_octets);
399                     #span = tail;
400                 });
401                 format_ident!("head")
402             }
403             None => self.span.clone(),
404         };
405 
406         let field_name = id;
407         let packet_name = self.packet_name;
408         let id = id.to_ident();
409 
410         let parse_element = self.parse_array_element(&span, width, type_id, decl);
411         match (element_width, &array_shape) {
412             (ElementWidth::Unknown, ArrayShape::SizeField(size_field)) => {
413                 // The element width is not known, but the array full
414                 // octet size is known by size field. Parse elements
415                 // item by item as a vector.
416                 self.check_size(&span, &quote!(#size_field));
417                 let parse_element =
418                     self.parse_array_element(&format_ident!("head"), width, type_id, decl);
419                 self.tokens.extend(quote! {
420                     let (mut head, tail) = #span.split_at(#size_field);
421                     #span = tail;
422                     let mut #id = Vec::new();
423                     while !head.is_empty() {
424                         #id.push(#parse_element?);
425                     }
426                 });
427             }
428             (ElementWidth::Unknown, ArrayShape::Static(count)) => {
429                 // The element width is not known, but the array
430                 // element count is known statically. Parse elements
431                 // item by item as an array.
432                 let count = proc_macro2::Literal::usize_unsuffixed(*count);
433                 self.tokens.extend(quote! {
434                     // TODO(mgeisler): use
435                     // https://doc.rust-lang.org/std/array/fn.try_from_fn.html
436                     // when stabilized.
437                     let mut #id = Vec::with_capacity(#count);
438                     for _ in 0..#count {
439                         #id.push(#parse_element?)
440                     }
441                     let #id = #id
442                         .try_into()
443                         .map_err(|_| DecodeError::InvalidPacketError)?;
444                 });
445             }
446             (ElementWidth::Unknown, ArrayShape::CountField(count_field)) => {
447                 // The element width is not known, but the array
448                 // element count is known by the count field. Parse
449                 // elements item by item as a vector.
450                 self.tokens.extend(quote! {
451                     let #id = (0..#count_field)
452                         .map(|_| #parse_element)
453                         .collect::<Result<Vec<_>, DecodeError>>()?;
454                 });
455             }
456             (ElementWidth::Unknown, ArrayShape::Unknown) => {
457                 // Neither the count not size is known, parse elements
458                 // until the end of the span.
459                 self.tokens.extend(quote! {
460                     let mut #id = Vec::new();
461                     while !#span.is_empty() {
462                         #id.push(#parse_element?);
463                     }
464                 });
465             }
466             (ElementWidth::Static(element_width), ArrayShape::Static(count)) => {
467                 // The element width is known, and the array element
468                 // count is known statically.
469                 let count = proc_macro2::Literal::usize_unsuffixed(*count);
470                 // This creates a nicely formatted size.
471                 let array_size = if element_width == 1 {
472                     quote!(#count)
473                 } else {
474                     let element_width = proc_macro2::Literal::usize_unsuffixed(element_width);
475                     quote!(#count * #element_width)
476                 };
477                 self.check_size(&span, &quote! { #array_size });
478                 self.tokens.extend(quote! {
479                     // TODO(mgeisler): use
480                     // https://doc.rust-lang.org/std/array/fn.try_from_fn.html
481                     // when stabilized.
482                     let mut #id = Vec::with_capacity(#count);
483                     for _ in 0..#count {
484                         #id.push(#parse_element?)
485                     }
486                     let #id = #id
487                         .try_into()
488                         .map_err(|_| DecodeError::InvalidPacketError)?;
489                 });
490             }
491             (ElementWidth::Static(element_width), ArrayShape::CountField(count_field)) => {
492                 // The element width is known, and the array element
493                 // count is known dynamically by the count field.
494                 self.check_size(&span, &quote!(#count_field * #element_width));
495                 self.tokens.extend(quote! {
496                     let #id = (0..#count_field)
497                         .map(|_| #parse_element)
498                         .collect::<Result<Vec<_>, DecodeError>>()?;
499                 });
500             }
501             (ElementWidth::Static(element_width), ArrayShape::SizeField(_))
502             | (ElementWidth::Static(element_width), ArrayShape::Unknown) => {
503                 // The element width is known, and the array full size
504                 // is known by size field, or unknown (in which case
505                 // it is the remaining span length).
506                 let array_size = if let ArrayShape::SizeField(size_field) = &array_shape {
507                     self.check_size(&span, &quote!(#size_field));
508                     quote!(#size_field)
509                 } else {
510                     quote!(#span.remaining())
511                 };
512                 let count_field = format_ident!("{id}_count");
513                 let array_count = if element_width != 1 {
514                     let element_width = proc_macro2::Literal::usize_unsuffixed(element_width);
515                     self.tokens.extend(quote! {
516                         if #array_size % #element_width != 0 {
517                             return Err(DecodeError::InvalidArraySize {
518                                 array: #array_size,
519                                 element: #element_width,
520                             });
521                         }
522                         let #count_field = #array_size / #element_width;
523                     });
524                     quote!(#count_field)
525                 } else {
526                     array_size
527                 };
528 
529                 self.tokens.extend(quote! {
530                     let mut #id = Vec::with_capacity(#array_count);
531                     for _ in 0..#array_count {
532                         #id.push(#parse_element?);
533                     }
534                 });
535             }
536             (ElementWidth::Dynamic(element_size_field), ArrayShape::Static(count)) => {
537                 // The element width is known, and the array element
538                 // count is known statically.
539                 let array_size = if *count == 1 {
540                     quote!(#element_size_field)
541                 } else {
542                     quote!(#count * #element_size_field)
543                 };
544 
545                 self.check_size(&span, &array_size);
546 
547                 let parse_element =
548                     self.parse_array_element(&format_ident!("chunk"), width, type_id, decl);
549 
550                 self.tokens.extend(quote! {
551                     // TODO: use
552                     // https://doc.rust-lang.org/std/array/fn.try_from_fn.html
553                     // when stabilized.
554                     let #id = #span.chunks(#element_size_field)
555                         .take(#count)
556                         .map(|mut chunk| #parse_element.and_then(|value| {
557                             if chunk.is_empty() {
558                                 Ok(value)
559                             } else {
560                                 Err(DecodeError::TrailingBytesInArray {
561                                     obj: #packet_name,
562                                     field: #field_name,
563                                 })
564                             }
565                          }))
566                         .collect::<Result<Vec<_>, DecodeError>>()?;
567                     #span = &#span[#array_size..];
568                     let #id = #id
569                         .try_into()
570                         .map_err(|_| DecodeError::InvalidPacketError)?;
571                 });
572             }
573             (ElementWidth::Dynamic(element_size_field), ArrayShape::CountField(count_field)) => {
574                 // The element width is known, and the array element
575                 // count is known dynamically by the count field.
576                 self.check_size(&span, &quote!(#count_field * #element_size_field));
577 
578                 let parse_element =
579                     self.parse_array_element(&format_ident!("chunk"), width, type_id, decl);
580 
581                 self.tokens.extend(quote! {
582                     let #id = #span.chunks(#element_size_field)
583                         .take(#count_field)
584                         .map(|mut chunk| #parse_element.and_then(|value| {
585                             if chunk.is_empty() {
586                                 Ok(value)
587                             } else {
588                                 Err(DecodeError::TrailingBytesInArray {
589                                     obj: #packet_name,
590                                     field: #field_name,
591                                 })
592                             }
593                          }))
594                         .collect::<Result<Vec<_>, DecodeError>>()?;
595                     #span = &#span[(#element_size_field * #count_field)..];
596                 });
597             }
598             (ElementWidth::Dynamic(element_size_field), ArrayShape::SizeField(_))
599             | (ElementWidth::Dynamic(element_size_field), ArrayShape::Unknown) => {
600                 // The element width is known, and the array full size
601                 // is known by size field, or unknown (in which case
602                 // it is the remaining span length).
603                 let array_size = if let ArrayShape::SizeField(size_field) = &array_shape {
604                     self.check_size(&span, &quote!(#size_field));
605                     quote!(#size_field)
606                 } else {
607                     quote!(#span.remaining())
608                 };
609                 self.tokens.extend(quote! {
610                     if #array_size % #element_size_field != 0 {
611                         return Err(DecodeError::InvalidArraySize {
612                             array: #array_size,
613                             element: #element_size_field,
614                         });
615                     }
616                 });
617 
618                 let parse_element =
619                     self.parse_array_element(&format_ident!("chunk"), width, type_id, decl);
620 
621                 self.tokens.extend(quote! {
622                     let #id = #span.chunks(#element_size_field)
623                         .take(#array_size / #element_size_field)
624                         .map(|mut chunk| #parse_element.and_then(|value| {
625                             if chunk.is_empty() {
626                                 Ok(value)
627                             } else {
628                                 Err(DecodeError::TrailingBytesInArray {
629                                     obj: #packet_name,
630                                     field: #field_name,
631                                 })
632                             }
633                          }))
634                         .collect::<Result<Vec<_>, DecodeError>>()?;
635                     #span = &#span[#array_size..];
636                 });
637             }
638         }
639     }
640 
641     /// Parse typedef fields.
642     ///
643     /// This is only for non-enum fields: enums are parsed via
644     /// add_bit_field.
add_typedef_field(&mut self, id: &str, type_id: &str)645     fn add_typedef_field(&mut self, id: &str, type_id: &str) {
646         assert_eq!(self.shift, 0, "Typedef field does not start on an octet boundary");
647 
648         let decl = self.scope.typedef[type_id];
649         let span = self.span;
650         let id = id.to_ident();
651         let type_id = type_id.to_ident();
652 
653         self.tokens.extend(match self.schema.decl_size(decl.key) {
654             analyzer::Size::Unknown | analyzer::Size::Dynamic => quote! {
655                 let (#id, mut #span) = #type_id::decode(#span)?;
656             },
657             analyzer::Size::Static(width) => {
658                 assert_eq!(width % 8, 0, "Typedef field type size is not a multiple of 8");
659                 match &decl.desc {
660                     ast::DeclDesc::Checksum { .. } => todo!(),
661                     ast::DeclDesc::CustomField { .. } if [8, 16, 32, 64].contains(&width) => {
662                         let get_uint = types::get_uint(self.endianness, width, span);
663                         quote! {
664                             let #id = #get_uint.into();
665                         }
666                     }
667                     ast::DeclDesc::CustomField { .. } => {
668                         let get_uint = types::get_uint(self.endianness, width, span);
669                         quote! {
670                             let #id = (#get_uint)
671                                 .try_into()
672                                 .unwrap(); // Value is masked and conversion must succeed.
673                         }
674                     }
675                     ast::DeclDesc::Struct { .. } => {
676                         quote! {
677                             let (#id, mut #span) = #type_id::decode(#span)?;
678                         }
679                     }
680                     _ => unreachable!(),
681                 }
682             }
683         });
684     }
685 
686     /// Parse body and payload fields.
add_payload_field(&mut self, size_modifier: Option<&str>)687     fn add_payload_field(&mut self, size_modifier: Option<&str>) {
688         let span = self.span;
689         let payload_size_field = self.decl.payload_size();
690         let offset_from_end = self.payload_field_offset_from_end();
691 
692         if self.shift != 0 {
693             todo!("Unexpected non byte aligned payload");
694         }
695 
696         if let Some(ast::FieldDesc::Size { field_id, .. }) = &payload_size_field.map(|f| &f.desc) {
697             // The payload or body has a known size. Consume the
698             // payload and update the span in case fields are placed
699             // after the payload.
700             let size_field = size_field_ident(field_id);
701             if let Some(size_modifier) = size_modifier {
702                 let size_modifier = proc_macro2::Literal::usize_unsuffixed(
703                     size_modifier.parse::<usize>().expect("failed to parse the size modifier"),
704                 );
705                 let packet_name = &self.packet_name;
706                 // Push code to check that the size is greater than the size
707                 // modifier. Required to safely substract the modifier from the
708                 // size.
709                 self.tokens.extend(quote! {
710                     if #size_field < #size_modifier {
711                         return Err(DecodeError::InvalidLengthError {
712                             obj: #packet_name,
713                             wanted: #size_modifier,
714                             got: #size_field,
715                         });
716                     }
717                     let #size_field = #size_field - #size_modifier;
718                 });
719             }
720             self.check_size(self.span, &quote!(#size_field ));
721             self.tokens.extend(quote! {
722                 let payload = #span[..#size_field].to_vec();
723                 #span.advance(#size_field);
724             });
725         } else if offset_from_end == Some(0) {
726             // The payload or body is the last field of a packet,
727             // consume the remaining span.
728             self.tokens.extend(quote! {
729                 let payload = #span.to_vec();
730                 #span.advance(payload.len());
731             });
732         } else if let Some(offset_from_end) = offset_from_end {
733             // The payload or body is followed by fields of static
734             // size. Consume the span that is not reserved for the
735             // following fields.
736             assert_eq!(
737                 offset_from_end % 8,
738                 0,
739                 "Payload field offset from end of packet is not a multiple of 8"
740             );
741             let offset_from_end = proc_macro2::Literal::usize_unsuffixed(offset_from_end / 8);
742             self.check_size(self.span, &quote!(#offset_from_end));
743             self.tokens.extend(quote! {
744                 let payload = #span[..#span.len() - #offset_from_end].to_vec();
745                 #span.advance(payload.len());
746             });
747         }
748 
749         let decl = self.scope.typedef[self.packet_name];
750         if let ast::DeclDesc::Struct { .. } = &decl.desc {
751             self.tokens.extend(quote! {
752                 let payload = Vec::from(payload);
753             });
754         }
755     }
756 
757     /// Parse a single array field element from `span`.
parse_array_element( &self, span: &proc_macro2::Ident, width: Option<usize>, type_id: Option<&str>, decl: Option<&ast::Decl>, ) -> proc_macro2::TokenStream758     fn parse_array_element(
759         &self,
760         span: &proc_macro2::Ident,
761         width: Option<usize>,
762         type_id: Option<&str>,
763         decl: Option<&ast::Decl>,
764     ) -> proc_macro2::TokenStream {
765         if let Some(width) = width {
766             let get_uint = types::get_uint(self.endianness, width, span);
767             return quote! {
768                 Ok::<_, DecodeError>(#get_uint)
769             };
770         }
771 
772         if let Some(ast::DeclDesc::Enum { id, width, .. }) = decl.map(|decl| &decl.desc) {
773             let get_uint = types::get_uint(self.endianness, *width, span);
774             let type_id = id.to_ident();
775             let packet_name = &self.packet_name;
776             return quote! {
777                 #type_id::try_from(#get_uint).map_err(|unknown_val| DecodeError::InvalidEnumValueError {
778                     obj: #packet_name,
779                     field: "", // TODO(mgeisler): fill out or remove
780                     value: unknown_val as u64,
781                     type_: #id,
782                 })
783             };
784         }
785 
786         let type_id = type_id.unwrap().to_ident();
787         quote! {
788             #type_id::decode_mut(&mut #span)
789         }
790     }
791 }
792 
793 impl quote::ToTokens for FieldParser<'_> {
to_tokens(&self, tokens: &mut proc_macro2::TokenStream)794     fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
795         tokens.extend(self.tokens.clone());
796     }
797 }
798 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzer;
    use crate::ast;
    use crate::parser::parse_inline;

    /// Parse a string fragment as a PDL file and run the analyzer on it.
    ///
    /// # Panics
    ///
    /// Panics on parse errors and on analyzer errors.
    pub fn parse_str(text: &str) -> ast::File {
        let mut db = ast::SourceDatabase::new();
        let parsed = parse_inline(&mut db, "stdin", text.to_owned()).expect("parse error");
        analyzer::analyze(&parsed).expect("analyzer error")
    }

    /// Build a `FieldParser` for packet `P` declared in `code` and check the
    /// size and count fields it reports for `field_id`.
    ///
    /// `size` and `count` are the expected identifier names, if any.
    fn check_fields(code: &str, field_id: &str, size: Option<&str>, count: Option<&str>) {
        let file = parse_str(code);
        let scope = analyzer::Scope::new(&file).unwrap();
        let schema = analyzer::Schema::new(&file);
        let span = format_ident!("bytes");
        let parser = FieldParser::new(&scope, &schema, file.endianness.value, "P", &span);
        assert_eq!(parser.find_size_field(field_id), size.map(|name| format_ident!("{}", name)));
        assert_eq!(parser.find_count_field(field_id), count.map(|name| format_ident!("{}", name)));
    }

    #[test]
    fn test_find_fields_static() {
        let code = "
              little_endian_packets
              packet P {
                a: 24[3],
              }
            ";
        check_fields(code, "a", None, None);
    }

    #[test]
    fn test_find_fields_dynamic_count() {
        let code = "
              little_endian_packets
              packet P {
                _count_(b): 24,
                b: 16[],
              }
            ";
        check_fields(code, "b", None, Some("b_count"));
    }

    #[test]
    fn test_find_fields_dynamic_size() {
        let code = "
              little_endian_packets
              packet P {
                _size_(c): 8,
                c: 24[],
              }
            ";
        check_fields(code, "c", Some("c_size"), None);
    }
}
870