1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use std::collections::HashMap;
16 
17 use proc_macro2::TokenStream;
18 use quote::{format_ident, quote};
19 
20 use crate::{
21     ast,
22     backends::{
23         intermediate::{ComputedValue, ComputedValueId, PacketOrStruct, Schema},
24         rust_no_allocation::utils::get_integer_type,
25     },
26 };
27 
/// Normalizes the name of a child-carrying field.
///
/// The reserved identifiers `_body_` and `_payload_` are both mapped to
/// `_child_`; every other identifier is returned unchanged.
fn standardize_child(id: &str) -> &str {
    if matches!(id, "_body_" | "_payload_") {
        "_child_"
    } else {
        id
    }
}
34 
generate_packet_serializer( id: &str, parent_id: Option<&str>, fields: &[ast::Field], schema: &Schema, curr_schema: &PacketOrStruct, children: &HashMap<&str, Vec<&str>>, ) -> TokenStream35 pub fn generate_packet_serializer(
36     id: &str,
37     parent_id: Option<&str>,
38     fields: &[ast::Field],
39     schema: &Schema,
40     curr_schema: &PacketOrStruct,
41     children: &HashMap<&str, Vec<&str>>,
42 ) -> TokenStream {
43     let id_ident = format_ident!("{id}Builder");
44 
45     let builder_fields = fields
46         .iter()
47         .filter_map(|field| {
48             match &field.desc {
49                 ast::FieldDesc::Padding { .. }
50                 | ast::FieldDesc::Flag { .. }
51                 | ast::FieldDesc::Reserved { .. }
52                 | ast::FieldDesc::FixedScalar { .. }
53                 | ast::FieldDesc::FixedEnum { .. }
54                 | ast::FieldDesc::ElementSize { .. }
55                 | ast::FieldDesc::Count { .. }
56                 | ast::FieldDesc::Size { .. } => {
57                     // no-op, no getter generated for this type
58                     None
59                 }
60                 ast::FieldDesc::Group { .. } => unreachable!(),
61                 ast::FieldDesc::Checksum { .. } => {
62                     unimplemented!("checksums not yet supported with this backend")
63                 }
64                 ast::FieldDesc::Body | ast::FieldDesc::Payload { .. } => {
65                     let type_ident = format_ident!("{id}Child");
66                     Some(("_child_", quote! { #type_ident }))
67                 }
68                 ast::FieldDesc::Array { id, width, type_id, .. } => {
69                     let element_type = if let Some(width) = width {
70                         get_integer_type(*width)
71                     } else if let Some(type_id) = type_id {
72                         if schema.enums.contains_key(type_id.as_str()) {
73                             format_ident!("{type_id}")
74                         } else {
75                             format_ident!("{type_id}Builder")
76                         }
77                     } else {
78                         unreachable!();
79                     };
80                     Some((id.as_str(), quote! { Box<[#element_type]> }))
81                 }
82                 ast::FieldDesc::Scalar { id, width } => {
83                     let id_type = get_integer_type(*width);
84                     Some((id.as_str(), quote! { #id_type }))
85                 }
86                 ast::FieldDesc::Typedef { id, type_id } => {
87                     let type_ident = if schema.enums.contains_key(type_id.as_str()) {
88                         format_ident!("{type_id}")
89                     } else {
90                         format_ident!("{type_id}Builder")
91                     };
92                     Some((id.as_str(), quote! { #type_ident }))
93                 }
94             }
95         })
96         .map(|(id, typ)| {
97             let id_ident = format_ident!("{id}");
98             quote! { pub #id_ident: #typ }
99         });
100 
101     let mut has_child = false;
102 
103     let serializer = fields.iter().map(|field| {
104         match &field.desc {
105             ast::FieldDesc::Checksum { .. } | ast::FieldDesc::Group { .. } | ast::FieldDesc::Flag { .. } => unimplemented!(),
106             ast::FieldDesc::Padding { size, .. } => {
107                 quote! {
108                     if (most_recent_array_size_in_bits > #size * 8) {
109                         return Err(SerializeError::NegativePadding);
110                     }
111                     writer.write_bits((#size * 8 - most_recent_array_size_in_bits) as usize, || Ok(0u64))?;
112                 }
113             },
114             ast::FieldDesc::Size { field_id, width } => {
115                 let field_id = standardize_child(field_id);
116                 let field_ident = format_ident!("{field_id}");
117 
118                 // if the element-size is fixed, we can directly multiply
119                 if let Some(ComputedValue::Constant(element_width)) = curr_schema.computed_values.get(&ComputedValueId::FieldElementSize(field_id)) {
120                     return quote! {
121                         writer.write_bits(
122                             #width,
123                             || u64::try_from(self.#field_ident.len() * #element_width).or(Err(SerializeError::IntegerConversionFailure))
124                         )?;
125                     }
126                 }
127 
128                 // if the field is "countable", loop over it to sum up the size
129                 if curr_schema.computed_values.contains_key(&ComputedValueId::FieldCount(field_id)) {
130                     return quote! {
131                         writer.write_bits(#width, || {
132                             let size_in_bits = self.#field_ident.iter().map(|elem| elem.size_in_bits()).fold(Ok(0), |total, next| {
133                                 let total: u64 = total?;
134                                 let next = u64::try_from(next?).or(Err(SerializeError::IntegerConversionFailure))?;
135                                 total.checked_add(next).ok_or(SerializeError::IntegerConversionFailure)
136                             })?;
137                             if size_in_bits % 8 != 0 {
138                                 return Err(SerializeError::AlignmentError);
139                             }
140                             Ok(size_in_bits / 8)
141                         })?;
142                     }
143                 }
144 
145                 // otherwise, try to get the size directly
146                 quote! {
147                     writer.write_bits(#width, || {
148                         let size_in_bits: u64 = self.#field_ident.size_in_bits()?.try_into().or(Err(SerializeError::IntegerConversionFailure))?;
149                         if size_in_bits % 8 != 0 {
150                             return Err(SerializeError::AlignmentError);
151                         }
152                         Ok(size_in_bits / 8)
153                     })?;
154                 }
155             }
156             ast::FieldDesc::Count { field_id, width } => {
157                 let field_ident = format_ident!("{field_id}");
158                 quote! { writer.write_bits(#width, || u64::try_from(self.#field_ident.len()).or(Err(SerializeError::IntegerConversionFailure)))?; }
159             }
160             ast::FieldDesc::ElementSize { field_id, width } => {
161                 // TODO(aryarahul) - add validation for elementsize against all the other elements
162                 let field_ident = format_ident!("{field_id}");
163                 quote! {
164                     let get_element_size = || Ok(if let Some(field) = self.#field_ident.get(0) {
165                         let size_in_bits = field.size_in_bits()?;
166                         if size_in_bits % 8 == 0 {
167                             (size_in_bits / 8) as u64
168                         } else {
169                             return Err(SerializeError::AlignmentError);
170                         }
171                     } else {
172                         0
173                     });
174                     writer.write_bits(#width, || get_element_size() )?;
175                 }
176             }
177             ast::FieldDesc::Reserved { width, .. } => {
178                 quote!{ writer.write_bits(#width, || Ok(0u64))?; }
179             }
180             ast::FieldDesc::Scalar { width, id } => {
181                 let field_ident = format_ident!("{id}");
182                 quote! { writer.write_bits(#width, || Ok(self.#field_ident))?; }
183             }
184             ast::FieldDesc::FixedScalar { width, value } => {
185                 let width = quote! { #width };
186                 let value = {
187                     let value = *value as u64;
188                     quote! { #value }
189                 };
190                 quote!{ writer.write_bits(#width, || Ok(#value))?; }
191             }
192             ast::FieldDesc::FixedEnum { enum_id, tag_id } => {
193                 let width = {
194                     let width = schema.enums[enum_id.as_str()].width;
195                     quote! { #width }
196                 };
197                 let value = {
198                     let enum_ident = format_ident!("{}", enum_id);
199                     let tag_ident = format_ident!("{tag_id}");
200                     quote! { #enum_ident::#tag_ident.value() }
201                 };
202                 quote!{ writer.write_bits(#width, || Ok(#value))?; }
203             }
204             ast::FieldDesc::Body | ast::FieldDesc::Payload { .. } => {
205                 has_child = true;
206                 quote! { self._child_.serialize(writer)?; }
207             }
208             ast::FieldDesc::Array { width, id, .. } => {
209                 let id_ident = format_ident!("{id}");
210                 if let Some(width) = width {
211                     quote! {
212                         for elem in self.#id_ident.iter() {
213                             writer.write_bits(#width, || Ok(*elem))?;
214                         }
215                         let most_recent_array_size_in_bits = #width * self.#id_ident.len();
216                     }
217                 } else {
218                     quote! {
219                         let mut most_recent_array_size_in_bits = 0;
220                         for elem in self.#id_ident.iter() {
221                             most_recent_array_size_in_bits += elem.size_in_bits()?;
222                             elem.serialize(writer)?;
223                         }
224                      }
225                 }
226             }
227             ast::FieldDesc::Typedef { id, .. } => {
228                 let id_ident = format_ident!("{id}");
229                 quote! { self.#id_ident.serialize(writer)?; }
230             }
231         }
232     }).collect::<Vec<_>>();
233 
234     let variant_names = children.get(id).into_iter().flatten().collect::<Vec<_>>();
235 
236     let variants = variant_names.iter().map(|name| {
237         let name_ident = format_ident!("{name}");
238         let variant_ident = format_ident!("{name}Builder");
239         quote! { #name_ident(#variant_ident) }
240     });
241 
242     let variant_serializers = variant_names.iter().map(|name| {
243         let name_ident = format_ident!("{name}");
244         quote! {
245             Self::#name_ident(x) => {
246                 x.serialize(writer)?;
247             }
248         }
249     });
250 
251     let children_enum = if has_child {
252         let enum_ident = format_ident!("{id}Child");
253         quote! {
254             #[derive(Debug, Clone, PartialEq, Eq)]
255             pub enum #enum_ident {
256                 RawData(Box<[u8]>),
257                 #(#variants),*
258             }
259 
260             impl Serializable for #enum_ident {
261                 fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError> {
262                     match self {
263                         Self::RawData(data) => {
264                             for byte in data.iter() {
265                                 writer.write_bits(8, || Ok(*byte as u64))?;
266                             }
267                         },
268                         #(#variant_serializers),*
269                     }
270                     Ok(())
271                 }
272             }
273         }
274     } else {
275         quote! {}
276     };
277 
278     let parent_type_converter = if let Some(parent_id) = parent_id {
279         let parent_enum_ident = format_ident!("{parent_id}Child");
280         let variant_ident = format_ident!("{id}");
281         Some(quote! {
282             impl From<#id_ident> for #parent_enum_ident {
283                 fn from(x: #id_ident) -> Self {
284                     Self::#variant_ident(x)
285                 }
286             }
287         })
288     } else {
289         None
290     };
291 
292     let owned_packet_ident = format_ident!("Owned{id}View");
293 
294     quote! {
295       #[derive(Debug, Clone, PartialEq, Eq)]
296       pub struct #id_ident {
297           #(#builder_fields),*
298       }
299 
300       impl Builder for #id_ident {
301         type OwnedPacket = #owned_packet_ident;
302       }
303 
304       impl Serializable for #id_ident {
305           fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError> {
306             #(#serializer)*
307             Ok(())
308           }
309       }
310 
311       #parent_type_converter
312 
313       #children_enum
314     }
315 }
316