1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use std::collections::HashMap;
16
17 use proc_macro2::TokenStream;
18 use quote::{format_ident, quote};
19
20 use crate::{
21 ast,
22 backends::{
23 intermediate::{ComputedValue, ComputedValueId, PacketOrStruct, Schema},
24 rust_no_allocation::utils::get_integer_type,
25 },
26 };
27
/// Maps the PDL names of a packet's body or payload field (`_body_`,
/// `_payload_`) to `_child_`, the name of the builder field that holds the
/// child; any other identifier is returned unchanged.
fn standardize_child(id: &str) -> &str {
    if matches!(id, "_body_" | "_payload_") {
        "_child_"
    } else {
        id
    }
}
34
generate_packet_serializer( id: &str, parent_id: Option<&str>, fields: &[ast::Field], schema: &Schema, curr_schema: &PacketOrStruct, children: &HashMap<&str, Vec<&str>>, ) -> TokenStream35 pub fn generate_packet_serializer(
36 id: &str,
37 parent_id: Option<&str>,
38 fields: &[ast::Field],
39 schema: &Schema,
40 curr_schema: &PacketOrStruct,
41 children: &HashMap<&str, Vec<&str>>,
42 ) -> TokenStream {
43 let id_ident = format_ident!("{id}Builder");
44
45 let builder_fields = fields
46 .iter()
47 .filter_map(|field| {
48 match &field.desc {
49 ast::FieldDesc::Padding { .. }
50 | ast::FieldDesc::Flag { .. }
51 | ast::FieldDesc::Reserved { .. }
52 | ast::FieldDesc::FixedScalar { .. }
53 | ast::FieldDesc::FixedEnum { .. }
54 | ast::FieldDesc::ElementSize { .. }
55 | ast::FieldDesc::Count { .. }
56 | ast::FieldDesc::Size { .. } => {
57 // no-op, no getter generated for this type
58 None
59 }
60 ast::FieldDesc::Group { .. } => unreachable!(),
61 ast::FieldDesc::Checksum { .. } => {
62 unimplemented!("checksums not yet supported with this backend")
63 }
64 ast::FieldDesc::Body | ast::FieldDesc::Payload { .. } => {
65 let type_ident = format_ident!("{id}Child");
66 Some(("_child_", quote! { #type_ident }))
67 }
68 ast::FieldDesc::Array { id, width, type_id, .. } => {
69 let element_type = if let Some(width) = width {
70 get_integer_type(*width)
71 } else if let Some(type_id) = type_id {
72 if schema.enums.contains_key(type_id.as_str()) {
73 format_ident!("{type_id}")
74 } else {
75 format_ident!("{type_id}Builder")
76 }
77 } else {
78 unreachable!();
79 };
80 Some((id.as_str(), quote! { Box<[#element_type]> }))
81 }
82 ast::FieldDesc::Scalar { id, width } => {
83 let id_type = get_integer_type(*width);
84 Some((id.as_str(), quote! { #id_type }))
85 }
86 ast::FieldDesc::Typedef { id, type_id } => {
87 let type_ident = if schema.enums.contains_key(type_id.as_str()) {
88 format_ident!("{type_id}")
89 } else {
90 format_ident!("{type_id}Builder")
91 };
92 Some((id.as_str(), quote! { #type_ident }))
93 }
94 }
95 })
96 .map(|(id, typ)| {
97 let id_ident = format_ident!("{id}");
98 quote! { pub #id_ident: #typ }
99 });
100
101 let mut has_child = false;
102
103 let serializer = fields.iter().map(|field| {
104 match &field.desc {
105 ast::FieldDesc::Checksum { .. } | ast::FieldDesc::Group { .. } | ast::FieldDesc::Flag { .. } => unimplemented!(),
106 ast::FieldDesc::Padding { size, .. } => {
107 quote! {
108 if (most_recent_array_size_in_bits > #size * 8) {
109 return Err(SerializeError::NegativePadding);
110 }
111 writer.write_bits((#size * 8 - most_recent_array_size_in_bits) as usize, || Ok(0u64))?;
112 }
113 },
114 ast::FieldDesc::Size { field_id, width } => {
115 let field_id = standardize_child(field_id);
116 let field_ident = format_ident!("{field_id}");
117
118 // if the element-size is fixed, we can directly multiply
119 if let Some(ComputedValue::Constant(element_width)) = curr_schema.computed_values.get(&ComputedValueId::FieldElementSize(field_id)) {
120 return quote! {
121 writer.write_bits(
122 #width,
123 || u64::try_from(self.#field_ident.len() * #element_width).or(Err(SerializeError::IntegerConversionFailure))
124 )?;
125 }
126 }
127
128 // if the field is "countable", loop over it to sum up the size
129 if curr_schema.computed_values.contains_key(&ComputedValueId::FieldCount(field_id)) {
130 return quote! {
131 writer.write_bits(#width, || {
132 let size_in_bits = self.#field_ident.iter().map(|elem| elem.size_in_bits()).fold(Ok(0), |total, next| {
133 let total: u64 = total?;
134 let next = u64::try_from(next?).or(Err(SerializeError::IntegerConversionFailure))?;
135 total.checked_add(next).ok_or(SerializeError::IntegerConversionFailure)
136 })?;
137 if size_in_bits % 8 != 0 {
138 return Err(SerializeError::AlignmentError);
139 }
140 Ok(size_in_bits / 8)
141 })?;
142 }
143 }
144
145 // otherwise, try to get the size directly
146 quote! {
147 writer.write_bits(#width, || {
148 let size_in_bits: u64 = self.#field_ident.size_in_bits()?.try_into().or(Err(SerializeError::IntegerConversionFailure))?;
149 if size_in_bits % 8 != 0 {
150 return Err(SerializeError::AlignmentError);
151 }
152 Ok(size_in_bits / 8)
153 })?;
154 }
155 }
156 ast::FieldDesc::Count { field_id, width } => {
157 let field_ident = format_ident!("{field_id}");
158 quote! { writer.write_bits(#width, || u64::try_from(self.#field_ident.len()).or(Err(SerializeError::IntegerConversionFailure)))?; }
159 }
160 ast::FieldDesc::ElementSize { field_id, width } => {
161 // TODO(aryarahul) - add validation for elementsize against all the other elements
162 let field_ident = format_ident!("{field_id}");
163 quote! {
164 let get_element_size = || Ok(if let Some(field) = self.#field_ident.get(0) {
165 let size_in_bits = field.size_in_bits()?;
166 if size_in_bits % 8 == 0 {
167 (size_in_bits / 8) as u64
168 } else {
169 return Err(SerializeError::AlignmentError);
170 }
171 } else {
172 0
173 });
174 writer.write_bits(#width, || get_element_size() )?;
175 }
176 }
177 ast::FieldDesc::Reserved { width, .. } => {
178 quote!{ writer.write_bits(#width, || Ok(0u64))?; }
179 }
180 ast::FieldDesc::Scalar { width, id } => {
181 let field_ident = format_ident!("{id}");
182 quote! { writer.write_bits(#width, || Ok(self.#field_ident))?; }
183 }
184 ast::FieldDesc::FixedScalar { width, value } => {
185 let width = quote! { #width };
186 let value = {
187 let value = *value as u64;
188 quote! { #value }
189 };
190 quote!{ writer.write_bits(#width, || Ok(#value))?; }
191 }
192 ast::FieldDesc::FixedEnum { enum_id, tag_id } => {
193 let width = {
194 let width = schema.enums[enum_id.as_str()].width;
195 quote! { #width }
196 };
197 let value = {
198 let enum_ident = format_ident!("{}", enum_id);
199 let tag_ident = format_ident!("{tag_id}");
200 quote! { #enum_ident::#tag_ident.value() }
201 };
202 quote!{ writer.write_bits(#width, || Ok(#value))?; }
203 }
204 ast::FieldDesc::Body | ast::FieldDesc::Payload { .. } => {
205 has_child = true;
206 quote! { self._child_.serialize(writer)?; }
207 }
208 ast::FieldDesc::Array { width, id, .. } => {
209 let id_ident = format_ident!("{id}");
210 if let Some(width) = width {
211 quote! {
212 for elem in self.#id_ident.iter() {
213 writer.write_bits(#width, || Ok(*elem))?;
214 }
215 let most_recent_array_size_in_bits = #width * self.#id_ident.len();
216 }
217 } else {
218 quote! {
219 let mut most_recent_array_size_in_bits = 0;
220 for elem in self.#id_ident.iter() {
221 most_recent_array_size_in_bits += elem.size_in_bits()?;
222 elem.serialize(writer)?;
223 }
224 }
225 }
226 }
227 ast::FieldDesc::Typedef { id, .. } => {
228 let id_ident = format_ident!("{id}");
229 quote! { self.#id_ident.serialize(writer)?; }
230 }
231 }
232 }).collect::<Vec<_>>();
233
234 let variant_names = children.get(id).into_iter().flatten().collect::<Vec<_>>();
235
236 let variants = variant_names.iter().map(|name| {
237 let name_ident = format_ident!("{name}");
238 let variant_ident = format_ident!("{name}Builder");
239 quote! { #name_ident(#variant_ident) }
240 });
241
242 let variant_serializers = variant_names.iter().map(|name| {
243 let name_ident = format_ident!("{name}");
244 quote! {
245 Self::#name_ident(x) => {
246 x.serialize(writer)?;
247 }
248 }
249 });
250
251 let children_enum = if has_child {
252 let enum_ident = format_ident!("{id}Child");
253 quote! {
254 #[derive(Debug, Clone, PartialEq, Eq)]
255 pub enum #enum_ident {
256 RawData(Box<[u8]>),
257 #(#variants),*
258 }
259
260 impl Serializable for #enum_ident {
261 fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError> {
262 match self {
263 Self::RawData(data) => {
264 for byte in data.iter() {
265 writer.write_bits(8, || Ok(*byte as u64))?;
266 }
267 },
268 #(#variant_serializers),*
269 }
270 Ok(())
271 }
272 }
273 }
274 } else {
275 quote! {}
276 };
277
278 let parent_type_converter = if let Some(parent_id) = parent_id {
279 let parent_enum_ident = format_ident!("{parent_id}Child");
280 let variant_ident = format_ident!("{id}");
281 Some(quote! {
282 impl From<#id_ident> for #parent_enum_ident {
283 fn from(x: #id_ident) -> Self {
284 Self::#variant_ident(x)
285 }
286 }
287 })
288 } else {
289 None
290 };
291
292 let owned_packet_ident = format_ident!("Owned{id}View");
293
294 quote! {
295 #[derive(Debug, Clone, PartialEq, Eq)]
296 pub struct #id_ident {
297 #(#builder_fields),*
298 }
299
300 impl Builder for #id_ident {
301 type OwnedPacket = #owned_packet_ident;
302 }
303
304 impl Serializable for #id_ident {
305 fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError> {
306 #(#serializer)*
307 Ok(())
308 }
309 }
310
311 #parent_type_converter
312
313 #children_enum
314 }
315 }
316