1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use crate::backends::rust_legacy::{mask_bits, types, ToIdent, ToUpperCamelCase};
16 use crate::{analyzer, ast};
17 use quote::{format_ident, quote};
18 
19 /// A single bit-field value.
20 struct BitField {
21     value: proc_macro2::TokenStream, // An expression which produces a value.
22     field_type: types::Integer,      // The type of the value.
23     shift: usize,                    // A bit-shift to apply to `value`.
24 }
25 
26 pub struct FieldSerializer<'a> {
27     scope: &'a analyzer::Scope<'a>,
28     schema: &'a analyzer::Schema,
29     endianness: ast::EndiannessValue,
30     packet_name: &'a str,
31     span: &'a proc_macro2::Ident,
32     chunk: Vec<BitField>,
33     code: Vec<proc_macro2::TokenStream>,
34     shift: usize,
35 }
36 
37 impl<'a> FieldSerializer<'a> {
new( scope: &'a analyzer::Scope<'a>, schema: &'a analyzer::Schema, endianness: ast::EndiannessValue, packet_name: &'a str, span: &'a proc_macro2::Ident, ) -> FieldSerializer<'a>38     pub fn new(
39         scope: &'a analyzer::Scope<'a>,
40         schema: &'a analyzer::Schema,
41         endianness: ast::EndiannessValue,
42         packet_name: &'a str,
43         span: &'a proc_macro2::Ident,
44     ) -> FieldSerializer<'a> {
45         FieldSerializer {
46             scope,
47             schema,
48             endianness,
49             packet_name,
50             span,
51             chunk: Vec::new(),
52             code: Vec::new(),
53             shift: 0,
54         }
55     }
56 
add(&mut self, field: &ast::Field)57     pub fn add(&mut self, field: &ast::Field) {
58         match &field.desc {
59             _ if field.cond.is_some() => self.add_optional_field(field),
60             _ if self.scope.is_bitfield(field) => self.add_bit_field(field),
61             ast::FieldDesc::Array { id, width, .. } => self.add_array_field(
62                 id,
63                 *width,
64                 self.schema.padded_size(field.key),
65                 self.scope.get_type_declaration(field),
66             ),
67             ast::FieldDesc::Typedef { id, type_id } => {
68                 self.add_typedef_field(id, type_id);
69             }
70             ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => {
71                 self.add_payload_field();
72             }
73             // Padding field handled in serialization of associated array field.
74             ast::FieldDesc::Padding { .. } => (),
75             _ => todo!("Cannot yet serialize {field:?}"),
76         }
77     }
78 
add_optional_field(&mut self, field: &ast::Field)79     fn add_optional_field(&mut self, field: &ast::Field) {
80         self.code.push(match &field.desc {
81             ast::FieldDesc::Scalar { id, width } => {
82                 let name = id;
83                 let id = id.to_ident();
84                 let backing_type = types::Integer::new(*width);
85                 let write = types::put_uint(self.endianness, &quote!(*#id), *width, self.span);
86 
87                 let range_check = (backing_type.width > *width).then(|| {
88                     let packet_name = &self.packet_name;
89                     let max_value = mask_bits(*width, "u64");
90 
91                     quote! {
92                         if *#id > #max_value {
93                             return Err(EncodeError::InvalidScalarValue {
94                                 packet: #packet_name,
95                                 field: #name,
96                                 value: *#id as u64,
97                                 maximum_value: #max_value as u64,
98                             })
99                         }
100                     }
101                 });
102 
103                 quote! {
104                     if let Some(#id) = &self.#id {
105                         #range_check
106                         #write
107                     }
108                 }
109             }
110             ast::FieldDesc::Typedef { id, type_id } => match &self.scope.typedef[type_id].desc {
111                 ast::DeclDesc::Enum { width, .. } => {
112                     let id = id.to_ident();
113                     let backing_type = types::Integer::new(*width);
114                     let write = types::put_uint(
115                         self.endianness,
116                         &quote!(#backing_type::from(#id)),
117                         *width,
118                         self.span,
119                     );
120                     quote! {
121                         if let Some(#id) = &self.#id {
122                             #write
123                         }
124                     }
125                 }
126                 ast::DeclDesc::Struct { .. } => {
127                     let id = id.to_ident();
128                     let span = self.span;
129                     quote! {
130                         if let Some(#id) = &self.#id {
131                             #id.write_to(#span)?;
132                         }
133                     }
134                 }
135                 _ => unreachable!(),
136             },
137             _ => unreachable!(),
138         })
139     }
140 
add_bit_field(&mut self, field: &ast::Field)141     fn add_bit_field(&mut self, field: &ast::Field) {
142         let width = self.schema.field_size(field.key).static_().unwrap();
143         let shift = self.shift;
144 
145         match &field.desc {
146             ast::FieldDesc::Flag { optional_field_id, set_value, .. } => {
147                 let optional_field_id = optional_field_id.to_ident();
148                 let cond_value_present =
149                     syn::parse_str::<syn::LitInt>(&format!("{}", set_value)).unwrap();
150                 let cond_value_absent =
151                     syn::parse_str::<syn::LitInt>(&format!("{}", 1 - set_value)).unwrap();
152                 self.chunk.push(BitField {
153                     value: quote! {
154                         if self.#optional_field_id.is_some() {
155                             #cond_value_present
156                         } else {
157                             #cond_value_absent
158                         }
159                     },
160                     field_type: types::Integer::new(1),
161                     shift,
162                 });
163             }
164             ast::FieldDesc::Scalar { id, width } => {
165                 let field_name = id.to_ident();
166                 let field_type = types::Integer::new(*width);
167                 if field_type.width > *width {
168                     let packet_name = &self.packet_name;
169                     let max_value = mask_bits(*width, "u64");
170                     self.code.push(quote! {
171                         if self.#field_name > #max_value {
172                             return Err(EncodeError::InvalidScalarValue {
173                                 packet: #packet_name,
174                                 field: #id,
175                                 value: self.#field_name as u64,
176                                 maximum_value: #max_value,
177                             })
178                         }
179                     });
180                 }
181                 self.chunk.push(BitField { value: quote!(self.#field_name), field_type, shift });
182             }
183             ast::FieldDesc::FixedEnum { enum_id, tag_id, .. } => {
184                 let field_type = types::Integer::new(width);
185                 let enum_id = enum_id.to_ident();
186                 let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
187                 self.chunk.push(BitField {
188                     value: quote!(#field_type::from(#enum_id::#tag_id)),
189                     field_type,
190                     shift,
191                 });
192             }
193             ast::FieldDesc::FixedScalar { value, .. } => {
194                 let field_type = types::Integer::new(width);
195                 let value = proc_macro2::Literal::usize_unsuffixed(*value);
196                 self.chunk.push(BitField { value: quote!(#value), field_type, shift });
197             }
198             ast::FieldDesc::Typedef { id, .. } => {
199                 let field_name = id.to_ident();
200                 let field_type = types::Integer::new(width);
201                 self.chunk.push(BitField {
202                     value: quote!(#field_type::from(self.#field_name)),
203                     field_type,
204                     shift,
205                 });
206             }
207             ast::FieldDesc::Reserved { .. } => {
208                 // Nothing to do here.
209             }
210             ast::FieldDesc::Size { field_id, width, .. } => {
211                 let packet_name = &self.packet_name;
212                 let max_value = mask_bits(*width, "usize");
213 
214                 let decl = self.scope.typedef.get(self.packet_name).unwrap();
215                 let value_field = self
216                     .scope
217                     .iter_fields(decl)
218                     .find(|field| match &field.desc {
219                         ast::FieldDesc::Payload { .. } => field_id == "_payload_",
220                         ast::FieldDesc::Body { .. } => field_id == "_body_",
221                         _ => field.id() == Some(field_id),
222                     })
223                     .unwrap();
224 
225                 let field_name = field_id.to_ident();
226                 let field_type = types::Integer::new(*width);
227                 // TODO: size modifier
228 
229                 let value_field_decl = self.scope.get_type_declaration(value_field);
230 
231                 let field_size_name = format_ident!("{field_id}_size");
232                 let array_size = match (&value_field.desc, value_field_decl.map(|decl| &decl.desc))
233                 {
234                     (ast::FieldDesc::Payload { size_modifier: Some(size_modifier) }, _) => {
235                         let size_modifier = proc_macro2::Literal::usize_unsuffixed(
236                             size_modifier
237                                 .parse::<usize>()
238                                 .expect("failed to parse the size modifier"),
239                         );
240                         if let ast::DeclDesc::Packet { .. } = &decl.desc {
241                             quote! { (self.child.get_total_size() + #size_modifier) }
242                         } else {
243                             quote! { (self.payload.len() + #size_modifier) }
244                         }
245                     }
246                     (ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. }, _) => {
247                         if let ast::DeclDesc::Packet { .. } = &decl.desc {
248                             quote! { self.child.get_total_size() }
249                         } else {
250                             quote! { self.payload.len() }
251                         }
252                     }
253                     (ast::FieldDesc::Array { width: Some(width), .. }, _)
254                     | (ast::FieldDesc::Array { .. }, Some(ast::DeclDesc::Enum { width, .. })) => {
255                         let byte_width = syn::Index::from(width / 8);
256                         if byte_width.index == 1 {
257                             quote! { self.#field_name.len() }
258                         } else {
259                             quote! { (self.#field_name.len() * #byte_width) }
260                         }
261                     }
262                     (ast::FieldDesc::Array { .. }, _) => {
263                         self.code.push(quote! {
264                             let #field_size_name = self.#field_name
265                                 .iter()
266                                 .map(|elem| elem.get_size())
267                                 .sum::<usize>();
268                         });
269                         quote! { #field_size_name }
270                     }
271                     _ => panic!("Unexpected size field: {field:?}"),
272                 };
273 
274                 self.code.push(quote! {
275                     if #array_size > #max_value {
276                         return Err(EncodeError::SizeOverflow {
277                             packet: #packet_name,
278                             field: #field_id,
279                             size: #array_size,
280                             maximum_size: #max_value,
281                         })
282                     }
283                 });
284 
285                 self.chunk.push(BitField {
286                     value: quote!(#array_size as #field_type),
287                     field_type,
288                     shift,
289                 });
290             }
291             ast::FieldDesc::Count { field_id, width, .. } => {
292                 let field_name = field_id.to_ident();
293                 let field_type = types::Integer::new(*width);
294                 if field_type.width > *width {
295                     let packet_name = &self.packet_name;
296                     let max_value = mask_bits(*width, "usize");
297                     self.code.push(quote! {
298                         if self.#field_name.len() > #max_value {
299                             return Err(EncodeError::CountOverflow {
300                                 packet: #packet_name,
301                                 field: #field_id,
302                                 count: self.#field_name.len(),
303                                 maximum_count: #max_value,
304                             })
305                         }
306                     });
307                 }
308                 self.chunk.push(BitField {
309                     value: quote!(self.#field_name.len() as #field_type),
310                     field_type,
311                     shift,
312                 });
313             }
314             _ => todo!("{field:?}"),
315         }
316 
317         self.shift += width;
318         if self.shift % 8 == 0 {
319             self.pack_bit_fields()
320         }
321     }
322 
pack_bit_fields(&mut self)323     fn pack_bit_fields(&mut self) {
324         assert_eq!(self.shift % 8, 0);
325         let chunk_type = types::Integer::new(self.shift);
326         let values = self
327             .chunk
328             .drain(..)
329             .map(|BitField { mut value, field_type, shift }| {
330                 if field_type.width != chunk_type.width {
331                     // We will be combining values with `|`, so we
332                     // need to cast them first.
333                     value = quote! { (#value as #chunk_type) };
334                 }
335                 if shift > 0 {
336                     let op = quote!(<<);
337                     let shift = proc_macro2::Literal::usize_unsuffixed(shift);
338                     value = quote! { (#value #op #shift) };
339                 }
340                 value
341             })
342             .collect::<Vec<_>>();
343 
344         match values.as_slice() {
345             [] => {
346                 let span = format_ident!("{}", self.span);
347                 let count = syn::Index::from(self.shift / 8);
348                 self.code.push(quote! {
349                     #span.put_bytes(0, #count);
350                 });
351             }
352             [value] => {
353                 let put = types::put_uint(self.endianness, value, self.shift, self.span);
354                 self.code.push(quote! {
355                     #put;
356                 });
357             }
358             _ => {
359                 let put = types::put_uint(self.endianness, &quote!(value), self.shift, self.span);
360                 self.code.push(quote! {
361                     let value = #(#values)|*;
362                     #put;
363                 });
364             }
365         }
366 
367         self.shift = 0;
368     }
369 
add_array_field( &mut self, id: &str, width: Option<usize>, padding_size: Option<usize>, decl: Option<&ast::Decl>, )370     fn add_array_field(
371         &mut self,
372         id: &str,
373         width: Option<usize>,
374         padding_size: Option<usize>,
375         decl: Option<&ast::Decl>,
376     ) {
377         let span = format_ident!("{}", self.span);
378         let serialize = match width {
379             Some(width) => {
380                 let value = quote!(*elem);
381                 types::put_uint(self.endianness, &value, width, self.span)
382             }
383             None => {
384                 if let Some(ast::DeclDesc::Enum { width, .. }) = decl.map(|decl| &decl.desc) {
385                     let element_type = types::Integer::new(*width);
386                     types::put_uint(
387                         self.endianness,
388                         &quote!(#element_type::from(elem)),
389                         *width,
390                         self.span,
391                     )
392                 } else {
393                     quote! {
394                         elem.write_to(#span)?
395                     }
396                 }
397             }
398         };
399 
400         let packet_name = self.packet_name;
401         let name = id;
402         let id = id.to_ident();
403 
404         if let Some(padding_size) = padding_size {
405             let padding_octets = padding_size / 8;
406             let element_width = match &width {
407                 Some(width) => Some(*width),
408                 None => self.schema.decl_size(decl.unwrap().key).static_(),
409             };
410 
411             let array_size = match element_width {
412                 Some(element_width) => {
413                     let element_size = proc_macro2::Literal::usize_unsuffixed(element_width / 8);
414                     quote! { self.#id.len() * #element_size }
415                 }
416                 _ => {
417                     quote! { self.#id.iter().fold(0, |size, elem| size + elem.get_size()) }
418                 }
419             };
420 
421             self.code.push(quote! {
422                 let array_size = #array_size;
423                 if array_size > #padding_octets {
424                     return Err(EncodeError::SizeOverflow {
425                         packet: #packet_name,
426                         field: #name,
427                         size: array_size,
428                         maximum_size: #padding_octets,
429                     })
430                 }
431                 for elem in &self.#id {
432                     #serialize;
433                 }
434                 #span.put_bytes(0, #padding_octets - array_size);
435             });
436         } else {
437             self.code.push(quote! {
438                 for elem in &self.#id {
439                     #serialize;
440                 }
441             });
442         }
443     }
444 
add_typedef_field(&mut self, id: &str, type_id: &str)445     fn add_typedef_field(&mut self, id: &str, type_id: &str) {
446         assert_eq!(self.shift, 0, "Typedef field does not start on an octet boundary");
447         let decl = self.scope.typedef[type_id];
448         if let ast::DeclDesc::Struct { parent_id: Some(_), .. } = &decl.desc {
449             panic!("Derived struct used in typedef field");
450         }
451 
452         let id = id.to_ident();
453         let span = format_ident!("{}", self.span);
454 
455         self.code.push(match &decl.desc {
456             ast::DeclDesc::Checksum { .. } => todo!(),
457             ast::DeclDesc::CustomField { width: Some(width), .. } => {
458                 let backing_type = types::Integer::new(*width);
459                 let put_uint = types::put_uint(
460                     self.endianness,
461                     &quote! { #backing_type::from(self.#id) },
462                     *width,
463                     self.span,
464                 );
465                 quote! {
466                     #put_uint;
467                 }
468             }
469             ast::DeclDesc::Struct { .. } => quote! {
470                 self.#id.write_to(#span)?;
471             },
472             _ => unreachable!(),
473         });
474     }
475 
add_payload_field(&mut self)476     fn add_payload_field(&mut self) {
477         if self.shift != 0 && self.endianness == ast::EndiannessValue::BigEndian {
478             panic!("Payload field does not start on an octet boundary");
479         }
480 
481         let decl = self.scope.typedef[self.packet_name];
482         let is_packet = matches!(&decl.desc, ast::DeclDesc::Packet { .. });
483 
484         let child_ids = self
485             .scope
486             .iter_children(decl)
487             .map(|child| child.id().unwrap().to_ident())
488             .collect::<Vec<_>>();
489 
490         let span = format_ident!("{}", self.span);
491         if self.shift == 0 {
492             if is_packet {
493                 let packet_data_child = format_ident!("{}DataChild", self.packet_name);
494                 self.code.push(quote! {
495                     match &self.child {
496                         #(#packet_data_child::#child_ids(child) => child.write_to(#span)?,)*
497                         #packet_data_child::Payload(payload) => #span.put_slice(payload),
498                         #packet_data_child::None => {},
499                     }
500                 })
501             } else {
502                 self.code.push(quote! {
503                     #span.put_slice(&self.payload);
504                 });
505             }
506         } else {
507             todo!("Shifted payloads");
508         }
509     }
510 }
511 
512 impl quote::ToTokens for FieldSerializer<'_> {
to_tokens(&self, tokens: &mut proc_macro2::TokenStream)513     fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
514         let code = &self.code;
515         tokens.extend(quote! {
516             #(#code)*
517         });
518     }
519 }
520