// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::backends::rust::{mask_bits, types, ToIdent, ToUpperCamelCase};
use crate::{analyzer, ast};
use quote::{format_ident, quote};

/// Generate a range check for a scalar value backed by a Rust type
/// that exceeds the actual size of the PDL field.
fn range_check(
    value: proc_macro2::TokenStream,
    width: usize,
    packet_name: &str,
    field_name: &str,
) -> proc_macro2::TokenStream {
    let max_value = mask_bits(width, "u64");
    quote! {
        if #value > #max_value {
            return Err(EncodeError::InvalidScalarValue {
                packet: #packet_name,
                field: #field_name,
                value: #value as u64,
                maximum_value: #max_value as u64,
            })
        }
    }
}

/// Represents the computed size of a packet,
/// composed of constant and variable size fields.
struct RuntimeSize {
    constant: usize,
    variable: Vec<proc_macro2::TokenStream>,
}

impl RuntimeSize {
    fn payload_size() -> Self {
        RuntimeSize { constant: 0, variable: vec![quote! { self.payload.len() }] }
    }
}

impl std::ops::AddAssign<&RuntimeSize> for RuntimeSize {
    fn add_assign(&mut self, other: &RuntimeSize) {
        self.constant += other.constant;
        self.variable.extend_from_slice(&other.variable)
    }
}

impl quote::ToTokens for RuntimeSize {
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        let constant = proc_macro2::Literal::usize_unsuffixed(self.constant);
        tokens.extend(match self {
            RuntimeSize { variable, .. } if variable.is_empty() => quote! { #constant },
            RuntimeSize { variable, constant: 0 } => quote! { #(#variable)+* },
            RuntimeSize { variable, .. } => quote! { #constant + #(#variable)+* },
        })
    }
}

/// Represents part of a compound bit-field.
struct BitField {
    value: proc_macro2::TokenStream,
    field_type: types::Integer,
    shift: usize,
}

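/// Generation context for the packet serializer: accumulates the
/// encoding statements and the computed packet size while the fields
/// of a declaration are visited in order.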
struct Encoder {
    endianness: ast::EndiannessValue,
    buf: proc_macro2::Ident,
    packet_name: String,
    packet_size: RuntimeSize,
    payload_size: RuntimeSize,
    tokens: proc_macro2::TokenStream,
    bit_shift: usize,
    bit_fields: Vec<BitField>,
}

impl Encoder {
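    /// Create an encoder for the packet `packet_name`, writing to the
    /// buffer identified by `buf`.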
    pub fn new(
        endianness: ast::EndiannessValue,
        packet_name: &str,
        buf: proc_macro2::Ident,
        payload_size: RuntimeSize,
    ) -> Self {
        Encoder {
            buf,
            packet_name: packet_name.to_owned(),
            endianness,
            packet_size: RuntimeSize { constant: 0, variable: vec![] },
            payload_size,
            tokens: quote! {},
            bit_shift: 0,
            bit_fields: vec![],
        }
    }

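    /// Generate the serialization of a typedef field and add its size
    /// to the computed packet size.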
    fn encode_typedef_field(
        &mut self,
        scope: &analyzer::Scope<'_>,
        schema: &analyzer::Schema,
        id: &str,
        type_id: &str,
    ) {
        assert_eq!(self.bit_shift, 0, "Typedef field does not start on an octet boundary");

        let decl = scope.typedef[type_id];
        let id = id.to_ident();
        let buf = &self.buf;

        self.tokens.extend(match &decl.desc {
            ast::DeclDesc::Checksum { .. } => todo!(),
            ast::DeclDesc::CustomField { width: Some(width), .. } => {
                let backing_type = types::Integer::new(*width);
                let put_uint = types::put_uint(
                    self.endianness,
                    &quote! { #backing_type::from(self.#id) },
                    *width,
                    &self.buf,
                );
                quote! {
                    #put_uint;
                }
            }
            ast::DeclDesc::Struct { .. } | ast::DeclDesc::CustomField { .. } => quote! {
                self.#id.encode(#buf)?;
            },
            _ => todo!("{:?}", decl),
        });

        match schema.decl_size(decl.key) {
            analyzer::Size::Static(s) => self.packet_size.constant += s / 8,
            _ => self.packet_size.variable.push(quote! { self.#id.encoded_len() }),
        }
    }

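    /// Generate the serialization of an optional field. The field is
    /// only written to the buffer when its value is present.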
    fn encode_optional_field(
        &mut self,
        scope: &analyzer::Scope<'_>,
        _schema: &analyzer::Schema,
        field: &ast::Field,
    ) {
        assert_eq!(self.bit_shift, 0, "Optional field does not start on an octet boundary");

        self.tokens.extend(match &field.desc {
            ast::FieldDesc::Scalar { id, width } => {
                let field_name = id;
                let id = id.to_ident();
                let backing_type = types::Integer::new(*width);
                let put_uint = types::put_uint(self.endianness, &quote!(*#id), *width, &self.buf);
                let range_check = (backing_type.width > *width)
                    .then(|| range_check(quote! { *#id }, *width, &self.packet_name, field_name));
                quote! {
                    if let Some(#id) = &self.#id {
                        #range_check
                        #put_uint;
                    }
                }
            }
            ast::FieldDesc::Typedef { id, type_id } => match &scope.typedef[type_id].desc {
                ast::DeclDesc::Enum { width, .. } => {
                    let id = id.to_ident();
                    let backing_type = types::Integer::new(*width);
                    let put_uint = types::put_uint(
                        self.endianness,
                        &quote!(#backing_type::from(#id)),
                        *width,
                        &self.buf,
                    );

                    quote! {
                        if let Some(#id) = &self.#id {
                            #put_uint;
                        }
                    }
                }
                ast::DeclDesc::Struct { .. } => {
                    let id = id.to_ident();
                    let buf = &self.buf;

                    quote! {
                        if let Some(#id) = &self.#id {
                            #id.encode(#buf)?;
                        }
                    }
                }
                _ => unreachable!(),
            },
            _ => unreachable!(),
        });

        self.packet_size.variable.push(match &field.desc {
            ast::FieldDesc::Scalar { id, width } => {
                let id = id.to_ident();
                let size = width / 8;
                quote! { if self.#id.is_some() { #size } else { 0 } }
            }
            ast::FieldDesc::Typedef { id, type_id } => match &scope.typedef[type_id].desc {
                ast::DeclDesc::Enum { width, .. } => {
                    let id = id.to_ident();
                    let size = width / 8;
                    quote! { if self.#id.is_some() { #size } else { 0 } }
                }
                ast::DeclDesc::Struct { .. } => {
                    let id = id.to_ident();
                    let type_id = type_id.to_ident();
                    quote! {
                        &self.#id
                            .as_ref()
                            .map(#type_id::encoded_len)
                            .unwrap_or(0)
                    }
                }
                _ => unreachable!(),
            },
            _ => unreachable!(),
        })
    }

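    /// Generate the serialization of a field that is part of a compound
    /// bit-field. The value is accumulated in `bit_fields` and flushed
    /// by `pack_bit_fields` once the bit shift reaches an octet boundary.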
    fn encode_bit_field(
        &mut self,
        scope: &analyzer::Scope<'_>,
        schema: &analyzer::Schema,
        field: &ast::Field,
    ) {
        let width = schema.field_size(field.key).static_().unwrap();
        let shift = self.bit_shift;

        match &field.desc {
            ast::FieldDesc::Flag { optional_field_id, set_value, .. } => {
                let optional_field_id = optional_field_id.to_ident();
                let cond_value_present =
                    syn::parse_str::<syn::LitInt>(&format!("{}", set_value)).unwrap();
                let cond_value_absent =
                    syn::parse_str::<syn::LitInt>(&format!("{}", 1 - set_value)).unwrap();
                self.bit_fields.push(BitField {
                    value: quote! {
                        if self.#optional_field_id.is_some() {
                            #cond_value_present
                        } else {
                            #cond_value_absent
                        }
                    },
                    field_type: types::Integer::new(1),
                    shift,
                });
            }
            ast::FieldDesc::Scalar { id, width } => {
                let field_name = id;
                let field_id = id.to_ident();
                let field_type = types::Integer::new(*width);
                if field_type.width > *width {
                    self.tokens.extend(range_check(
                        quote! { self.#field_id() },
                        *width,
                        &self.packet_name,
                        field_name,
                    ));
                }
                self.bit_fields.push(BitField {
                    value: quote! { self.#field_id() },
                    field_type,
                    shift,
                });
            }
            ast::FieldDesc::FixedEnum { enum_id, tag_id, .. } => {
                let field_type = types::Integer::new(width);
                let enum_id = enum_id.to_ident();
                let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
                self.bit_fields.push(BitField {
                    value: quote!(#field_type::from(#enum_id::#tag_id)),
                    field_type,
                    shift,
                });
            }
            ast::FieldDesc::FixedScalar { value, .. } => {
                let field_type = types::Integer::new(width);
                let value = proc_macro2::Literal::usize_unsuffixed(*value);
                self.bit_fields.push(BitField { value: quote!(#value), field_type, shift });
            }
            ast::FieldDesc::Typedef { id, .. } => {
                let id = id.to_ident();
                let field_type = types::Integer::new(width);
                self.bit_fields.push(BitField {
                    value: quote!(#field_type::from(self.#id())),
                    field_type,
                    shift,
                });
            }
            ast::FieldDesc::Reserved { .. } => {
                // Nothing to do here.
            }
            ast::FieldDesc::Size { field_id, width, .. } => {
                let packet_name = &self.packet_name;
                let max_value = mask_bits(*width, "usize");

                let decl = scope.typedef.get(&self.packet_name).unwrap();
                let value_field = scope
                    .iter_fields(decl)
                    .find(|field| match &field.desc {
                        ast::FieldDesc::Payload { .. } => field_id == "_payload_",
                        ast::FieldDesc::Body { .. } => field_id == "_body_",
                        _ => field.id() == Some(field_id),
                    })
                    .unwrap();

                let field_name = field_id.to_ident();
                let field_type = types::Integer::new(*width);
                // TODO: size modifier

                let value_field_decl = scope.get_type_declaration(value_field);
                let array_size = match (&value_field.desc, value_field_decl.map(|decl| &decl.desc))
                {
                    (ast::FieldDesc::Payload { size_modifier: Some(size_modifier) }, _) => {
                        let size_modifier = proc_macro2::Literal::usize_unsuffixed(
                            size_modifier
                                .parse::<usize>()
                                .expect("failed to parse the size modifier"),
                        );
                        let payload_size = &self.payload_size;
                        quote! { (#payload_size + #size_modifier) }
                    }
                    (ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. }, _) => {
                        let payload_size = &self.payload_size;
                        quote! { #payload_size }
                    }
                    (ast::FieldDesc::Array { width: Some(width), .. }, _)
                    | (ast::FieldDesc::Array { .. }, Some(ast::DeclDesc::Enum { width, .. })) => {
                        let size = width / 8;
                        if size == 1 {
                            quote! { self.#field_name.len() }
                        } else {
                            let size = proc_macro2::Literal::usize_unsuffixed(size);
                            quote! { (self.#field_name.len() * #size) }
                        }
                    }
                    (ast::FieldDesc::Array { .. }, _) => {
                        let field_size_name = format_ident!("{field_id}_size");
                        self.tokens.extend(quote! {
                            let #field_size_name = self.#field_name
                                .iter()
                                .map(Packet::encoded_len)
                                .sum::<usize>();
                        });
                        quote! { #field_size_name }
                    }
                    _ => panic!("Unexpected size field: {field:?}"),
                };

                self.tokens.extend(quote! {
                    if #array_size > #max_value {
                        return Err(EncodeError::SizeOverflow {
                            packet: #packet_name,
                            field: #field_id,
                            size: #array_size,
                            maximum_size: #max_value,
                        })
                    }
                });

                self.bit_fields.push(BitField {
                    value: quote!(#array_size as #field_type),
                    field_type,
                    shift,
                });
            }
            ast::FieldDesc::ElementSize { field_id, width, .. } => {
                let field_name = field_id.to_ident();
                let field_type = types::Integer::new(*width);
                let field_element_size_name = format_ident!("{field_id}_element_size");
                let packet_name = &self.packet_name;
                let max_value = mask_bits(*width, "usize");
                self.tokens.extend(quote! {
                    let #field_element_size_name = self.#field_name
                        .get(0)
                        .map_or(0, Packet::encoded_len);

                    for (element_index, element) in self.#field_name.iter().enumerate() {
                        if element.encoded_len() != #field_element_size_name {
                            return Err(EncodeError::InvalidArrayElementSize {
                                packet: #packet_name,
                                field: #field_id,
                                size: element.encoded_len(),
                                expected_size: #field_element_size_name,
                                element_index,
                            })
                        }
                    }
                    if #field_element_size_name > #max_value {
                        return Err(EncodeError::SizeOverflow {
                            packet: #packet_name,
                            field: #field_id,
                            size: #field_element_size_name,
                            maximum_size: #max_value,
                        })
                    }
                    let #field_element_size_name = #field_element_size_name as #field_type;
                });
                self.bit_fields.push(BitField {
                    value: quote!(#field_element_size_name),
                    field_type,
                    shift,
                });
            }
            ast::FieldDesc::Count { field_id, width, .. } => {
                let field_name = field_id.to_ident();
                let field_type = types::Integer::new(*width);
                if field_type.width > *width {
                    let packet_name = &self.packet_name;
                    let max_value = mask_bits(*width, "usize");
                    self.tokens.extend(quote! {
                        if self.#field_name.len() > #max_value {
                            return Err(EncodeError::CountOverflow {
                                packet: #packet_name,
                                field: #field_id,
                                count: self.#field_name.len(),
                                maximum_count: #max_value,
                            })
                        }
                    });
                }
                self.bit_fields.push(BitField {
                    value: quote!(self.#field_name.len() as #field_type),
                    field_type,
                    shift,
                });
            }
            _ => todo!("{field:?}"),
        }

        self.bit_shift += width;
        if self.bit_shift % 8 == 0 {
            self.pack_bit_fields()
        }
    }

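    /// Pack the accumulated bit-field values into a single integer and
    /// generate the statement writing it to the output buffer.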
    fn pack_bit_fields(&mut self) {
        assert_eq!(self.bit_shift % 8, 0);
        let chunk_type = types::Integer::new(self.bit_shift);
        let values = self
            .bit_fields
            .drain(..)
            .map(|BitField { mut value, field_type, shift }| {
                if field_type.width != chunk_type.width {
                    // We will be combining values with `|`, so we
                    // need to cast them first.
                    value = quote! { (#value as #chunk_type) };
                }
                if shift > 0 {
                    let op = quote!(<<);
                    let shift = proc_macro2::Literal::usize_unsuffixed(shift);
                    value = quote! { (#value #op #shift) };
                }
                value
            })
            .collect::<Vec<_>>();

        self.tokens.extend(match values.as_slice() {
            [] => {
                let buf = format_ident!("{}", self.buf);
                let count = proc_macro2::Literal::usize_unsuffixed(self.bit_shift / 8);
                quote! {
                    #buf.put_bytes(0, #count);
                }
            }
            [value] => {
                let put = types::put_uint(self.endianness, value, self.bit_shift, &self.buf);
                quote! {
                    #put;
                }
            }
            _ => {
                let put =
                    types::put_uint(self.endianness, &quote!(value), self.bit_shift, &self.buf);
                quote! {
                    let value = #(#values)|*;
                    #put;
                }
            }
        });

        self.packet_size.constant += self.bit_shift / 8;
        self.bit_shift = 0;
    }

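    /// Generate the serialization of an array field, with optional
    /// zero padding up to the declared padding size.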
    fn encode_array_field(
        &mut self,
        _scope: &analyzer::Scope<'_>,
        schema: &analyzer::Schema,
        id: &str,
        width: Option<usize>,
        padding_size: Option<usize>,
        decl: Option<&ast::Decl>,
    ) {
        assert_eq!(self.bit_shift, 0, "Array field does not start on an octet boundary");

        let buf = &self.buf;

        // Code to encode one array element.
        let put_element = match width {
            Some(width) => {
                let value = quote!(*elem);
                types::put_uint(self.endianness, &value, width, &self.buf)
            }
            None => {
                if let Some(ast::DeclDesc::Enum { width, .. }) = decl.map(|decl| &decl.desc) {
                    let element_type = types::Integer::new(*width);
                    types::put_uint(
                        self.endianness,
                        &quote!(#element_type::from(elem)),
                        *width,
                        &self.buf,
                    )
                } else {
                    quote! {
                        elem.encode(#buf)?
                    }
                }
            }
        };

        let packet_name = &self.packet_name;
        let field_name = id;
        let id = id.to_ident();

        let element_width = match &width {
            Some(width) => Some(*width),
            None => schema.decl_size(decl.unwrap().key).static_(),
        };

        let array_size = match element_width {
            Some(8) => quote! { self.#id.len() },
            Some(element_width) => {
                let element_size = proc_macro2::Literal::usize_unsuffixed(element_width / 8);
                quote! { (self.#id.len() * #element_size) }
            }
            _ => {
                quote! {
                    self.#id
                        .iter()
                        .map(Packet::encoded_len)
                        .sum::<usize>()
                }
            }
        };

        self.tokens.extend(if let Some(padding_size) = padding_size {
            let padding_octets = padding_size / 8;
            quote! {
                let array_size = #array_size;
                if array_size > #padding_octets {
                    return Err(EncodeError::SizeOverflow {
                        packet: #packet_name,
                        field: #field_name,
                        size: array_size,
                        maximum_size: #padding_octets,
                    })
                }
                for elem in &self.#id {
                    #put_element;
                }
                #buf.put_bytes(0, #padding_octets - array_size);
            }
        } else {
            quote! {
                for elem in &self.#id {
                    #put_element;
                }
            }
        });

        if let Some(padding_size) = padding_size {
            self.packet_size.constant += padding_size / 8;
        } else {
            self.packet_size.variable.push(array_size);
        }
    }

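    /// Generate the serialization of a single field, dispatching on the
    /// field kind.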
    fn encode_field(
        &mut self,
        scope: &analyzer::Scope<'_>,
        schema: &analyzer::Schema,
        payload: &proc_macro2::TokenStream,
        field: &ast::Field,
    ) {
        match &field.desc {
            _ if field.cond.is_some() => self.encode_optional_field(scope, schema, field),
            _ if scope.is_bitfield(field) => self.encode_bit_field(scope, schema, field),
            ast::FieldDesc::Array { id, width, .. } => self.encode_array_field(
                scope,
                schema,
                id,
                *width,
                schema.padded_size(field.key),
                scope.get_type_declaration(field),
            ),
            ast::FieldDesc::Typedef { id, type_id } => {
                self.encode_typedef_field(scope, schema, id, type_id)
            }
            ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => {
                self.tokens.extend(payload.clone());
                self.packet_size += &self.payload_size
            }
            // Padding field handled in serialization of associated array field.
            ast::FieldDesc::Padding { .. } => (),
            _ => todo!("Cannot yet serialize {field:?}"),
        }
    }
}

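/// Generate the serialization of `decl` and, recursively, of its parent
/// declarations. Returns the encoder statements and the encoded length
/// expression.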
fn encode_with_parents(
    scope: &analyzer::Scope<'_>,
    schema: &analyzer::Schema,
    endianness: ast::EndiannessValue,
    buf: proc_macro2::Ident,
    decl: &ast::Decl,
    payload_size: RuntimeSize,
    payload: proc_macro2::TokenStream,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let packet_name = decl.id().unwrap();
    let mut encoder = Encoder::new(endianness, packet_name, buf.clone(), payload_size);
    for field in decl.fields() {
        encoder.encode_field(scope, schema, &payload, field);
    }

    match scope.get_parent(decl) {
        Some(parent_decl) => encode_with_parents(
            scope,
            schema,
            endianness,
            buf,
            parent_decl,
            encoder.packet_size,
            encoder.tokens,
        ),
        None => {
            let packet_size = encoder.packet_size;
            (encoder.tokens, quote! { #packet_size })
        }
    }
}

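/// Generate the packet serializer for `decl`, including the fields
/// inherited from parent declarations. Returns the encoder statements
/// and the encoded length expression.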
pub fn encode(
    scope: &analyzer::Scope<'_>,
    schema: &analyzer::Schema,
    endianness: ast::EndiannessValue,
    buf: proc_macro2::Ident,
    decl: &ast::Decl,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    encode_with_parents(
        scope,
        schema,
        endianness,
        buf.clone(),
        decl,
        RuntimeSize::payload_size(),
        quote! { #buf.put_slice(&self.payload); },
    )
}

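/// Generate the serializer for a packet declaration that has a parent:
/// returns the statements encoding the packet's own fields, the
/// statements encoding the parent fields (which delegate back through
/// `encode_partial`), and the encoded length expression.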
pub fn encode_partial(
    scope: &analyzer::Scope<'_>,
    schema: &analyzer::Schema,
    endianness: ast::EndiannessValue,
    buf: proc_macro2::Ident,
    decl: &ast::Decl,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let parent_decl = scope.get_parent(decl).unwrap();

    let mut encoder =
        Encoder::new(endianness, decl.id().unwrap(), buf.clone(), RuntimeSize::payload_size());

    for field in decl.fields() {
        encoder.encode_field(scope, schema, &quote! { #buf.put_slice(&self.payload); }, field);
    }

    let (encode_parents, encoded_len) = encode_with_parents(
        scope,
        schema,
        endianness,
        buf,
        parent_decl,
        encoder.packet_size,
        quote! { self.encode_partial(buf)?; },
    );

    (encoder.tokens, encode_parents, encoded_len)
}