1 // Copyright 2023 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 use crate::backends::rust_legacy::{mask_bits, types, ToIdent, ToUpperCamelCase}; 16 use crate::{analyzer, ast}; 17 use quote::{format_ident, quote}; 18 19 /// A single bit-field value. 20 struct BitField { 21 value: proc_macro2::TokenStream, // An expression which produces a value. 22 field_type: types::Integer, // The type of the value. 23 shift: usize, // A bit-shift to apply to `value`. 24 } 25 26 pub struct FieldSerializer<'a> { 27 scope: &'a analyzer::Scope<'a>, 28 schema: &'a analyzer::Schema, 29 endianness: ast::EndiannessValue, 30 packet_name: &'a str, 31 span: &'a proc_macro2::Ident, 32 chunk: Vec<BitField>, 33 code: Vec<proc_macro2::TokenStream>, 34 shift: usize, 35 } 36 37 impl<'a> FieldSerializer<'a> { new( scope: &'a analyzer::Scope<'a>, schema: &'a analyzer::Schema, endianness: ast::EndiannessValue, packet_name: &'a str, span: &'a proc_macro2::Ident, ) -> FieldSerializer<'a>38 pub fn new( 39 scope: &'a analyzer::Scope<'a>, 40 schema: &'a analyzer::Schema, 41 endianness: ast::EndiannessValue, 42 packet_name: &'a str, 43 span: &'a proc_macro2::Ident, 44 ) -> FieldSerializer<'a> { 45 FieldSerializer { 46 scope, 47 schema, 48 endianness, 49 packet_name, 50 span, 51 chunk: Vec::new(), 52 code: Vec::new(), 53 shift: 0, 54 } 55 } 56 add(&mut self, field: &ast::Field)57 pub fn add(&mut self, field: &ast::Field) { 58 match &field.desc { 59 _ if field.cond.is_some() => self.add_optional_field(field), 60 _ if self.scope.is_bitfield(field) => self.add_bit_field(field), 61 ast::FieldDesc::Array { id, width, .. } => self.add_array_field( 62 id, 63 *width, 64 self.schema.padded_size(field.key), 65 self.scope.get_type_declaration(field), 66 ), 67 ast::FieldDesc::Typedef { id, type_id } => { 68 self.add_typedef_field(id, type_id); 69 } 70 ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => { 71 self.add_payload_field(); 72 } 73 // Padding field handled in serialization of associated array field. 74 ast::FieldDesc::Padding { .. } => (), 75 _ => todo!("Cannot yet serialize {field:?}"), 76 } 77 } 78 add_optional_field(&mut self, field: &ast::Field)79 fn add_optional_field(&mut self, field: &ast::Field) { 80 self.code.push(match &field.desc { 81 ast::FieldDesc::Scalar { id, width } => { 82 let name = id; 83 let id = id.to_ident(); 84 let backing_type = types::Integer::new(*width); 85 let write = types::put_uint(self.endianness, "e!(*#id), *width, self.span); 86 87 let range_check = (backing_type.width > *width).then(|| { 88 let packet_name = &self.packet_name; 89 let max_value = mask_bits(*width, "u64"); 90 91 quote! { 92 if *#id > #max_value { 93 return Err(EncodeError::InvalidScalarValue { 94 packet: #packet_name, 95 field: #name, 96 value: *#id as u64, 97 maximum_value: #max_value as u64, 98 }) 99 } 100 } 101 }); 102 103 quote! { 104 if let Some(#id) = &self.#id { 105 #range_check 106 #write 107 } 108 } 109 } 110 ast::FieldDesc::Typedef { id, type_id } => match &self.scope.typedef[type_id].desc { 111 ast::DeclDesc::Enum { width, .. } => { 112 let id = id.to_ident(); 113 let backing_type = types::Integer::new(*width); 114 let write = types::put_uint( 115 self.endianness, 116 "e!(#backing_type::from(#id)), 117 *width, 118 self.span, 119 ); 120 quote! { 121 if let Some(#id) = &self.#id { 122 #write 123 } 124 } 125 } 126 ast::DeclDesc::Struct { .. } => { 127 let id = id.to_ident(); 128 let span = self.span; 129 quote! { 130 if let Some(#id) = &self.#id { 131 #id.write_to(#span)?; 132 } 133 } 134 } 135 _ => unreachable!(), 136 }, 137 _ => unreachable!(), 138 }) 139 } 140 add_bit_field(&mut self, field: &ast::Field)141 fn add_bit_field(&mut self, field: &ast::Field) { 142 let width = self.schema.field_size(field.key).static_().unwrap(); 143 let shift = self.shift; 144 145 match &field.desc { 146 ast::FieldDesc::Flag { optional_field_id, set_value, .. } => { 147 let optional_field_id = optional_field_id.to_ident(); 148 let cond_value_present = 149 syn::parse_str::<syn::LitInt>(&format!("{}", set_value)).unwrap(); 150 let cond_value_absent = 151 syn::parse_str::<syn::LitInt>(&format!("{}", 1 - set_value)).unwrap(); 152 self.chunk.push(BitField { 153 value: quote! { 154 if self.#optional_field_id.is_some() { 155 #cond_value_present 156 } else { 157 #cond_value_absent 158 } 159 }, 160 field_type: types::Integer::new(1), 161 shift, 162 }); 163 } 164 ast::FieldDesc::Scalar { id, width } => { 165 let field_name = id.to_ident(); 166 let field_type = types::Integer::new(*width); 167 if field_type.width > *width { 168 let packet_name = &self.packet_name; 169 let max_value = mask_bits(*width, "u64"); 170 self.code.push(quote! { 171 if self.#field_name > #max_value { 172 return Err(EncodeError::InvalidScalarValue { 173 packet: #packet_name, 174 field: #id, 175 value: self.#field_name as u64, 176 maximum_value: #max_value, 177 }) 178 } 179 }); 180 } 181 self.chunk.push(BitField { value: quote!(self.#field_name), field_type, shift }); 182 } 183 ast::FieldDesc::FixedEnum { enum_id, tag_id, .. } => { 184 let field_type = types::Integer::new(width); 185 let enum_id = enum_id.to_ident(); 186 let tag_id = format_ident!("{}", tag_id.to_upper_camel_case()); 187 self.chunk.push(BitField { 188 value: quote!(#field_type::from(#enum_id::#tag_id)), 189 field_type, 190 shift, 191 }); 192 } 193 ast::FieldDesc::FixedScalar { value, .. } => { 194 let field_type = types::Integer::new(width); 195 let value = proc_macro2::Literal::usize_unsuffixed(*value); 196 self.chunk.push(BitField { value: quote!(#value), field_type, shift }); 197 } 198 ast::FieldDesc::Typedef { id, .. } => { 199 let field_name = id.to_ident(); 200 let field_type = types::Integer::new(width); 201 self.chunk.push(BitField { 202 value: quote!(#field_type::from(self.#field_name)), 203 field_type, 204 shift, 205 }); 206 } 207 ast::FieldDesc::Reserved { .. } => { 208 // Nothing to do here. 209 } 210 ast::FieldDesc::Size { field_id, width, .. } => { 211 let packet_name = &self.packet_name; 212 let max_value = mask_bits(*width, "usize"); 213 214 let decl = self.scope.typedef.get(self.packet_name).unwrap(); 215 let value_field = self 216 .scope 217 .iter_fields(decl) 218 .find(|field| match &field.desc { 219 ast::FieldDesc::Payload { .. } => field_id == "_payload_", 220 ast::FieldDesc::Body { .. } => field_id == "_body_", 221 _ => field.id() == Some(field_id), 222 }) 223 .unwrap(); 224 225 let field_name = field_id.to_ident(); 226 let field_type = types::Integer::new(*width); 227 // TODO: size modifier 228 229 let value_field_decl = self.scope.get_type_declaration(value_field); 230 231 let field_size_name = format_ident!("{field_id}_size"); 232 let array_size = match (&value_field.desc, value_field_decl.map(|decl| &decl.desc)) 233 { 234 (ast::FieldDesc::Payload { size_modifier: Some(size_modifier) }, _) => { 235 let size_modifier = proc_macro2::Literal::usize_unsuffixed( 236 size_modifier 237 .parse::<usize>() 238 .expect("failed to parse the size modifier"), 239 ); 240 if let ast::DeclDesc::Packet { .. } = &decl.desc { 241 quote! { (self.child.get_total_size() + #size_modifier) } 242 } else { 243 quote! { (self.payload.len() + #size_modifier) } 244 } 245 } 246 (ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. }, _) => { 247 if let ast::DeclDesc::Packet { .. } = &decl.desc { 248 quote! { self.child.get_total_size() } 249 } else { 250 quote! { self.payload.len() } 251 } 252 } 253 (ast::FieldDesc::Array { width: Some(width), .. }, _) 254 | (ast::FieldDesc::Array { .. }, Some(ast::DeclDesc::Enum { width, .. })) => { 255 let byte_width = syn::Index::from(width / 8); 256 if byte_width.index == 1 { 257 quote! { self.#field_name.len() } 258 } else { 259 quote! { (self.#field_name.len() * #byte_width) } 260 } 261 } 262 (ast::FieldDesc::Array { .. }, _) => { 263 self.code.push(quote! { 264 let #field_size_name = self.#field_name 265 .iter() 266 .map(|elem| elem.get_size()) 267 .sum::<usize>(); 268 }); 269 quote! { #field_size_name } 270 } 271 _ => panic!("Unexpected size field: {field:?}"), 272 }; 273 274 self.code.push(quote! { 275 if #array_size > #max_value { 276 return Err(EncodeError::SizeOverflow { 277 packet: #packet_name, 278 field: #field_id, 279 size: #array_size, 280 maximum_size: #max_value, 281 }) 282 } 283 }); 284 285 self.chunk.push(BitField { 286 value: quote!(#array_size as #field_type), 287 field_type, 288 shift, 289 }); 290 } 291 ast::FieldDesc::Count { field_id, width, .. } => { 292 let field_name = field_id.to_ident(); 293 let field_type = types::Integer::new(*width); 294 if field_type.width > *width { 295 let packet_name = &self.packet_name; 296 let max_value = mask_bits(*width, "usize"); 297 self.code.push(quote! { 298 if self.#field_name.len() > #max_value { 299 return Err(EncodeError::CountOverflow { 300 packet: #packet_name, 301 field: #field_id, 302 count: self.#field_name.len(), 303 maximum_count: #max_value, 304 }) 305 } 306 }); 307 } 308 self.chunk.push(BitField { 309 value: quote!(self.#field_name.len() as #field_type), 310 field_type, 311 shift, 312 }); 313 } 314 _ => todo!("{field:?}"), 315 } 316 317 self.shift += width; 318 if self.shift % 8 == 0 { 319 self.pack_bit_fields() 320 } 321 } 322 pack_bit_fields(&mut self)323 fn pack_bit_fields(&mut self) { 324 assert_eq!(self.shift % 8, 0); 325 let chunk_type = types::Integer::new(self.shift); 326 let values = self 327 .chunk 328 .drain(..) 329 .map(|BitField { mut value, field_type, shift }| { 330 if field_type.width != chunk_type.width { 331 // We will be combining values with `|`, so we 332 // need to cast them first. 333 value = quote! { (#value as #chunk_type) }; 334 } 335 if shift > 0 { 336 let op = quote!(<<); 337 let shift = proc_macro2::Literal::usize_unsuffixed(shift); 338 value = quote! { (#value #op #shift) }; 339 } 340 value 341 }) 342 .collect::<Vec<_>>(); 343 344 match values.as_slice() { 345 [] => { 346 let span = format_ident!("{}", self.span); 347 let count = syn::Index::from(self.shift / 8); 348 self.code.push(quote! { 349 #span.put_bytes(0, #count); 350 }); 351 } 352 [value] => { 353 let put = types::put_uint(self.endianness, value, self.shift, self.span); 354 self.code.push(quote! { 355 #put; 356 }); 357 } 358 _ => { 359 let put = types::put_uint(self.endianness, "e!(value), self.shift, self.span); 360 self.code.push(quote! { 361 let value = #(#values)|*; 362 #put; 363 }); 364 } 365 } 366 367 self.shift = 0; 368 } 369 add_array_field( &mut self, id: &str, width: Option<usize>, padding_size: Option<usize>, decl: Option<&ast::Decl>, )370 fn add_array_field( 371 &mut self, 372 id: &str, 373 width: Option<usize>, 374 padding_size: Option<usize>, 375 decl: Option<&ast::Decl>, 376 ) { 377 let span = format_ident!("{}", self.span); 378 let serialize = match width { 379 Some(width) => { 380 let value = quote!(*elem); 381 types::put_uint(self.endianness, &value, width, self.span) 382 } 383 None => { 384 if let Some(ast::DeclDesc::Enum { width, .. }) = decl.map(|decl| &decl.desc) { 385 let element_type = types::Integer::new(*width); 386 types::put_uint( 387 self.endianness, 388 "e!(#element_type::from(elem)), 389 *width, 390 self.span, 391 ) 392 } else { 393 quote! { 394 elem.write_to(#span)? 395 } 396 } 397 } 398 }; 399 400 let packet_name = self.packet_name; 401 let name = id; 402 let id = id.to_ident(); 403 404 if let Some(padding_size) = padding_size { 405 let padding_octets = padding_size / 8; 406 let element_width = match &width { 407 Some(width) => Some(*width), 408 None => self.schema.decl_size(decl.unwrap().key).static_(), 409 }; 410 411 let array_size = match element_width { 412 Some(element_width) => { 413 let element_size = proc_macro2::Literal::usize_unsuffixed(element_width / 8); 414 quote! { self.#id.len() * #element_size } 415 } 416 _ => { 417 quote! { self.#id.iter().fold(0, |size, elem| size + elem.get_size()) } 418 } 419 }; 420 421 self.code.push(quote! { 422 let array_size = #array_size; 423 if array_size > #padding_octets { 424 return Err(EncodeError::SizeOverflow { 425 packet: #packet_name, 426 field: #name, 427 size: array_size, 428 maximum_size: #padding_octets, 429 }) 430 } 431 for elem in &self.#id { 432 #serialize; 433 } 434 #span.put_bytes(0, #padding_octets - array_size); 435 }); 436 } else { 437 self.code.push(quote! { 438 for elem in &self.#id { 439 #serialize; 440 } 441 }); 442 } 443 } 444 add_typedef_field(&mut self, id: &str, type_id: &str)445 fn add_typedef_field(&mut self, id: &str, type_id: &str) { 446 assert_eq!(self.shift, 0, "Typedef field does not start on an octet boundary"); 447 let decl = self.scope.typedef[type_id]; 448 if let ast::DeclDesc::Struct { parent_id: Some(_), .. } = &decl.desc { 449 panic!("Derived struct used in typedef field"); 450 } 451 452 let id = id.to_ident(); 453 let span = format_ident!("{}", self.span); 454 455 self.code.push(match &decl.desc { 456 ast::DeclDesc::Checksum { .. } => todo!(), 457 ast::DeclDesc::CustomField { width: Some(width), .. } => { 458 let backing_type = types::Integer::new(*width); 459 let put_uint = types::put_uint( 460 self.endianness, 461 "e! { #backing_type::from(self.#id) }, 462 *width, 463 self.span, 464 ); 465 quote! { 466 #put_uint; 467 } 468 } 469 ast::DeclDesc::Struct { .. } => quote! { 470 self.#id.write_to(#span)?; 471 }, 472 _ => unreachable!(), 473 }); 474 } 475 add_payload_field(&mut self)476 fn add_payload_field(&mut self) { 477 if self.shift != 0 && self.endianness == ast::EndiannessValue::BigEndian { 478 panic!("Payload field does not start on an octet boundary"); 479 } 480 481 let decl = self.scope.typedef[self.packet_name]; 482 let is_packet = matches!(&decl.desc, ast::DeclDesc::Packet { .. }); 483 484 let child_ids = self 485 .scope 486 .iter_children(decl) 487 .map(|child| child.id().unwrap().to_ident()) 488 .collect::<Vec<_>>(); 489 490 let span = format_ident!("{}", self.span); 491 if self.shift == 0 { 492 if is_packet { 493 let packet_data_child = format_ident!("{}DataChild", self.packet_name); 494 self.code.push(quote! { 495 match &self.child { 496 #(#packet_data_child::#child_ids(child) => child.write_to(#span)?,)* 497 #packet_data_child::Payload(payload) => #span.put_slice(payload), 498 #packet_data_child::None => {}, 499 } 500 }) 501 } else { 502 self.code.push(quote! { 503 #span.put_slice(&self.payload); 504 }); 505 } 506 } else { 507 todo!("Shifted payloads"); 508 } 509 } 510 } 511 512 impl quote::ToTokens for FieldSerializer<'_> { to_tokens(&self, tokens: &mut proc_macro2::TokenStream)513 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { 514 let code = &self.code; 515 tokens.extend(quote! { 516 #(#code)* 517 }); 518 } 519 } 520