1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use std::collections::{btree_map::Entry, BTreeMap, HashMap};
16 
17 use crate::ast;
18 
/// Intermediate representation produced by [`generate`] from a parsed file.
///
/// Both maps are keyed by declaration identifier; the keys are borrowed from
/// the AST, hence the `'a` lifetime.
pub struct Schema<'a> {
    /// Layout information for every packet and struct declaration.
    ///
    /// Enum declarations are registered here as well (with empty offset/value
    /// maps and a static length equal to the enum width), so that a type
    /// identifier can be looked up without knowing what kind of type it names.
    pub packets_and_structs: HashMap<&'a str, PacketOrStruct<'a>>,
    /// Tags and width of every enum declaration.
    pub enums: HashMap<&'a str, Enum<'a>>,
}
23 
/// Layout description of a single packet or struct.
///
/// Positions inside the packet are named by [`ComputedOffsetId`]s and
/// quantities (sizes, counts, constants) by [`ComputedValueId`]s; the two maps
/// give the definition of each name in terms of other names.
pub struct PacketOrStruct<'a> {
    /// Definition of every named position, expressed relative to other positions.
    pub computed_offsets: BTreeMap<ComputedOffsetId<'a>, ComputedOffset<'a>>,
    /// Definition of every named quantity.
    pub computed_values: BTreeMap<ComputedValueId<'a>, ComputedValue<'a>>,
    /// whether the parse of this packet needs to know its length,
    /// or if the packet can determine its own length
    pub length: PacketOrStructLength,
}
31 
/// How the total length of a packet or struct is determined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketOrStructLength {
    /// The length is a compile-time constant, measured in BITS.
    Static(usize),
    /// The length is not constant, but can be derived while parsing from the
    /// contents of the packet itself.
    Dynamic,
    /// The length cannot be derived from the contents alone: the parser must be
    /// told where the packet ends.
    NeedsExternal,
}
37 
/// An enum declaration as recorded in the [`Schema`].
pub struct Enum<'a> {
    /// The tags of the enum, borrowed from the AST declaration.
    pub tags: &'a [ast::Tag],
    /// Declared width of the enum; it is used as a bit count when laying out
    /// fields of this enum type.
    pub width: usize,
}
42 
/// Name of a quantity stored in [`PacketOrStruct::computed_values`].
///
/// The `&str` payload is the identifier of the field the quantity belongs to.
/// Note that the derived `Ord` follows the declaration order of the variants
/// and determines the iteration order of maps keyed by this type.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum ComputedValueId<'a> {
    // Total size of a field.
    // needed for array fields + varlength structs - note that this is in OCTETS, not BITS
    // this always works since array entries are either structs (which are byte-aligned) or integer-octet-width scalars
    FieldSize(&'a str),

    // Size of a single element of an array field.
    // needed for arrays with fixed element size (otherwise codegen will loop!)
    FieldElementSize(&'a str), // note that this is in OCTETS, not BITS
    // Number of elements of an array field.
    FieldCount(&'a str),

    // Anonymous quantity identified by a counter rather than by a field name.
    Custom(u16),
}
55 
/// Name of a position stored in [`PacketOrStruct::computed_offsets`].
///
/// Note that the derived `Ord` follows the declaration order of the variants
/// and determines the iteration order of maps keyed by this type.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum ComputedOffsetId<'a> {
    // these quantities are known by the runtime
    HeaderStart,

    // if the packet needs its length, this will be supplied. otherwise it will be computed
    PacketEnd,

    // these quantities will be computed and stored in computed_offsets
    FieldOffset(&'a str),    // needed for all fields, measured in BITS
    FieldEndOffset(&'a str), // needed only for Payload + Body fields, as well as variable-size structs (not arrays), measured in BITS
    // anonymous position identified by a counter; one is created for the position after each field
    Custom(u16),
    // end of the single field whose length is inferred from the packet length;
    // everything from here to PacketEnd is the trailer
    TrailerStart,
}
70 
/// Definition of a quantity named by a [`ComputedValueId`].
#[derive(PartialEq, Eq, Debug, PartialOrd, Ord)]
pub enum ComputedValue<'a> {
    /// A literal value known when the schema is built.
    Constant(usize),
    /// The number of consecutive `struct_type` elements, starting at
    /// `base_id`, that make up `size` octets. Used for arrays whose elements
    /// have no fixed size, so the count cannot be obtained by a division.
    CountStructsUpToSize {
        base_id: ComputedOffsetId<'a>,
        size: ComputedValueId<'a>,
        struct_type: &'a str,
    },
    /// The total size of `n` consecutive `struct_type` elements starting at
    /// `base_id`. Used where each element determines its own length.
    SizeOfNStructs {
        base_id: ComputedOffsetId<'a>,
        n: ComputedValueId<'a>,
        struct_type: &'a str,
    },
    /// The product of two quantities (e.g. element count times element size).
    Product(ComputedValueId<'a>, ComputedValueId<'a>),
    /// The first quantity divided by the second (e.g. total size by element size).
    Divide(ComputedValueId<'a>, ComputedValueId<'a>),
    /// The distance from the second offset to the first one.
    /// NOTE(review): offsets are in bits while this is stored under size ids
    /// documented as octets; presumably the consumer converts - confirm.
    Difference(ComputedOffsetId<'a>, ComputedOffsetId<'a>),
    /// A value read from the packet itself: `width` bits located at `offset`.
    ValueAt {
        offset: ComputedOffsetId<'a>,
        width: usize,
    },
}
92 
93 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
94 pub enum ComputedOffset<'a> {
95     ConstantPlusOffsetInBits(ComputedOffsetId<'a>, i64),
96     SumWithOctets(ComputedOffsetId<'a>, ComputedValueId<'a>),
97     Alias(ComputedOffsetId<'a>),
98 }
99 
generate(file: &ast::File) -> Result<Schema, String>100 pub fn generate(file: &ast::File) -> Result<Schema, String> {
101     let mut schema = Schema { packets_and_structs: HashMap::new(), enums: HashMap::new() };
102     match file.endianness.value {
103         ast::EndiannessValue::LittleEndian => {}
104         _ => unimplemented!("Only little_endian endianness supported"),
105     };
106 
107     for decl in &file.declarations {
108         process_decl(&mut schema, decl);
109     }
110 
111     Ok(schema)
112 }
113 
process_decl<'a>(schema: &mut Schema<'a>, decl: &'a ast::Decl)114 fn process_decl<'a>(schema: &mut Schema<'a>, decl: &'a ast::Decl) {
115     match &decl.desc {
116         ast::DeclDesc::Enum { id, tags, width, .. } => process_enum(schema, id, tags, *width),
117         ast::DeclDesc::Packet { id, fields, .. } | ast::DeclDesc::Struct { id, fields, .. } => {
118             process_packet_or_struct(schema, id, fields)
119         }
120         ast::DeclDesc::Group { .. } => todo!(),
121         _ => unimplemented!("type {decl:?} not supported"),
122     }
123 }
124 
process_enum<'a>(schema: &mut Schema<'a>, id: &'a str, tags: &'a [ast::Tag], width: usize)125 fn process_enum<'a>(schema: &mut Schema<'a>, id: &'a str, tags: &'a [ast::Tag], width: usize) {
126     schema.enums.insert(id, Enum { tags, width });
127     schema.packets_and_structs.insert(
128         id,
129         PacketOrStruct {
130             computed_offsets: BTreeMap::new(),
131             computed_values: BTreeMap::new(),
132             length: PacketOrStructLength::Static(width),
133         },
134     );
135 }
136 
process_packet_or_struct<'a>(schema: &mut Schema<'a>, id: &'a str, fields: &'a [ast::Field])137 fn process_packet_or_struct<'a>(schema: &mut Schema<'a>, id: &'a str, fields: &'a [ast::Field]) {
138     schema.packets_and_structs.insert(id, compute_getters(schema, fields));
139 }
140 
/// Computes the layout of one packet or struct from its field list.
///
/// Fields are visited in declaration order while the current position is
/// tracked as a [`ComputedOffsetId`]. For every field the function records
/// where the field starts (and, for some kinds, where it ends) in
/// `computed_offsets`, records the sizes and counts the field defines or needs
/// in `computed_values`, and finally registers a fresh `Custom(n)` offset for
/// the position immediately after the field.
///
/// At most one field may have a length that can only be inferred from the end
/// of the packet. The end of that field is tied to `TrailerStart`, and once
/// all fields are processed `TrailerStart` is defined as `PacketEnd` minus the
/// (constant) length of everything that follows the field.
///
/// `schema` supplies the lengths of the types referenced by typedef and array
/// fields, and the widths of enums; those types must already be present in it.
///
/// # Panics
///
/// Panics when:
/// - a `Flag`, `Group` or `Checksum` field, or a size modifier, is encountered;
/// - more than one field would need the externally supplied packet length;
/// - an array element type cannot determine its own length;
/// - a fixed array element width is not a multiple of 8 bits;
/// - a padding field does not follow an array whose extent is known;
/// - a referenced type or enum is missing from `schema`;
/// - the fields following the externally sized field do not have a constant
///   total length.
fn compute_getters<'a>(schema: &Schema<'a>, fields: &'a [ast::Field]) -> PacketOrStruct<'a> {
    // start offset of the previous field, kept only when that field was an
    // array that a directly following padding field is allowed to pad
    let mut prev_pos_id = None;
    // offset at which the field currently being processed starts
    let mut curr_pos_id = ComputedOffsetId::HeaderStart;
    let mut computed_values = BTreeMap::new();
    let mut computed_offsets = BTreeMap::new();

    // counter shared by the Custom(..) value ids and the Custom(..) offset ids
    let mut cnt = 0;

    // Custom(0) names the constant 1; it serves as the element count when the
    // size of a single dynamically sized struct has to be computed
    let one_id = ComputedValueId::Custom(cnt);
    let one_val = ComputedValue::Constant(1);
    cnt += 1;
    computed_values.insert(one_id, one_val);

    // becomes true once the end of some field has been tied to TrailerStart,
    // i.e. once parsing depends on the externally supplied packet length
    let mut needs_length = false;

    for field in fields {
        // populate this only if we are an array with a knowable size
        let mut next_prev_pos_id = None;

        // definition of the position right after this field
        let next_pos = match &field.desc {
            ast::FieldDesc::Reserved { width } => {
                ComputedOffset::ConstantPlusOffsetInBits(curr_pos_id, *width as i64)
            }
            ast::FieldDesc::Scalar { id, width } => {
                computed_offsets
                    .insert(ComputedOffsetId::FieldOffset(id), ComputedOffset::Alias(curr_pos_id));
                ComputedOffset::ConstantPlusOffsetInBits(curr_pos_id, *width as i64)
            }
            ast::FieldDesc::FixedScalar { width, .. } => {
                let offset = *width;
                ComputedOffset::ConstantPlusOffsetInBits(curr_pos_id, offset as i64)
            }
            ast::FieldDesc::FixedEnum { enum_id, .. } => {
                // the enum must already be registered, otherwise the lookup panics
                let offset = schema.enums[enum_id.as_str()].width;
                ComputedOffset::ConstantPlusOffsetInBits(curr_pos_id, offset as i64)
            }
            // size, count and element-size fields do not get a FieldOffset of
            // their own; they define a value of the field they refer to, read
            // from the current position
            ast::FieldDesc::Size { field_id, width } => {
                computed_values.insert(
                    ComputedValueId::FieldSize(field_id),
                    ComputedValue::ValueAt { offset: curr_pos_id, width: *width },
                );
                ComputedOffset::ConstantPlusOffsetInBits(curr_pos_id, *width as i64)
            }
            ast::FieldDesc::Count { field_id, width } => {
                computed_values.insert(
                    ComputedValueId::FieldCount(field_id.as_str()),
                    ComputedValue::ValueAt { offset: curr_pos_id, width: *width },
                );
                ComputedOffset::ConstantPlusOffsetInBits(curr_pos_id, *width as i64)
            }
            ast::FieldDesc::ElementSize { field_id, width } => {
                computed_values.insert(
                    ComputedValueId::FieldElementSize(field_id),
                    ComputedValue::ValueAt { offset: curr_pos_id, width: *width },
                );
                ComputedOffset::ConstantPlusOffsetInBits(curr_pos_id, *width as i64)
            }
            ast::FieldDesc::Flag { .. } => unimplemented!(),
            ast::FieldDesc::Group { .. } => {
                unimplemented!("this should be removed by the linter...")
            }
            ast::FieldDesc::Checksum { .. } => unimplemented!("checksum not supported"),
            ast::FieldDesc::Body => {
                computed_offsets.insert(
                    ComputedOffsetId::FieldOffset("_body_"),
                    ComputedOffset::Alias(curr_pos_id),
                );
                let computed_size_id = ComputedValueId::FieldSize("_body_");
                // if an earlier field already defined the body size, the body
                // ends that many octets later; otherwise it extends up to the
                // trailer and the packet length must be supplied externally
                let end_offset = if computed_values.contains_key(&computed_size_id) {
                    ComputedOffset::SumWithOctets(curr_pos_id, computed_size_id)
                } else {
                    if needs_length {
                        panic!("only one variable-length field can exist")
                    }
                    needs_length = true;
                    ComputedOffset::Alias(ComputedOffsetId::TrailerStart)
                };
                computed_offsets.insert(ComputedOffsetId::FieldEndOffset("_body_"), end_offset);
                end_offset
            }
            ast::FieldDesc::Payload { size_modifier } => {
                if size_modifier.is_some() {
                    unimplemented!("size modifiers not supported")
                }
                computed_offsets.insert(
                    ComputedOffsetId::FieldOffset("_payload_"),
                    ComputedOffset::Alias(curr_pos_id),
                );
                let computed_size_id = ComputedValueId::FieldSize("_payload_");
                // same rule as for the body: explicit size if one was defined
                // earlier, otherwise everything up to the trailer
                let end_offset = if computed_values.contains_key(&computed_size_id) {
                    ComputedOffset::SumWithOctets(curr_pos_id, computed_size_id)
                } else {
                    if needs_length {
                        panic!("only one variable-length field can exist")
                    }
                    needs_length = true;
                    ComputedOffset::Alias(ComputedOffsetId::TrailerStart)
                };
                computed_offsets.insert(ComputedOffsetId::FieldEndOffset("_payload_"), end_offset);
                end_offset
            }
            ast::FieldDesc::Array {
                id,
                width,
                type_id,
                size_modifier,
                size: statically_known_count,
            } => {
                if size_modifier.is_some() {
                    unimplemented!("size modifiers not supported")
                }

                computed_offsets
                    .insert(ComputedOffsetId::FieldOffset(id), ComputedOffset::Alias(curr_pos_id));

                // there are a few parameters to consider when parsing arrays
                // 1: the count of elements
                // 2: the total byte size (possibly by subtracting out the len of the trailer)
                // 3: whether the structs know their own lengths
                // parsing is possible if we know (1 OR 2) AND 3

                if let Some(count) = statically_known_count {
                    computed_values
                        .insert(ComputedValueId::FieldCount(id), ComputedValue::Constant(*count));
                }

                // element width in bits, when it is a constant: either the
                // static length of the element type or the scalar width
                let statically_known_width_in_bits = if let Some(type_id) = type_id {
                    if let PacketOrStructLength::Static(len) =
                        schema.packets_and_structs[type_id.as_str()].length
                    {
                        Some(len)
                    } else {
                        None
                    }
                } else if let Some(width) = width {
                    Some(*width)
                } else {
                    // an array is expected to have either an element type or an element width
                    unreachable!()
                };

                // whether the count is known *prior* to parsing the field
                let is_count_known = computed_values.contains_key(&ComputedValueId::FieldCount(id));
                // whether the total field size is explicitly specified
                let is_total_size_known =
                    computed_values.contains_key(&ComputedValueId::FieldSize(id));

                // constant element size in octets, when there is one; elements
                // of constant size must occupy a whole number of octets
                let element_size = if let Some(type_id) = type_id {
                    match schema.packets_and_structs[type_id.as_str()].length {
                        PacketOrStructLength::Static(width) => {
                            assert!(width % 8 == 0);
                            Some(width / 8)
                        }
                        PacketOrStructLength::Dynamic => None,
                        PacketOrStructLength::NeedsExternal => None,
                    }
                } else if let Some(width) = width {
                    assert!(width % 8 == 0);
                    Some(width / 8)
                } else {
                    unreachable!()
                };
                // note: this replaces an element size defined by an earlier
                // element-size field, if there was one
                if let Some(element_size) = element_size {
                    computed_values.insert(
                        ComputedValueId::FieldElementSize(id),
                        ComputedValue::Constant(element_size),
                    );
                }

                // whether we can know the length of each element in the array by greedy parsing,
                // or, for elements that need an external length, from a known element size
                let structs_know_length = if let Some(type_id) = type_id {
                    match schema.packets_and_structs[type_id.as_str()].length {
                        PacketOrStructLength::Static(_) => true,
                        PacketOrStructLength::Dynamic => true,
                        PacketOrStructLength::NeedsExternal => {
                            computed_values.contains_key(&ComputedValueId::FieldElementSize(id))
                        }
                    }
                } else {
                    width.is_some()
                };

                if !structs_know_length {
                    panic!("structs need to know their own length, if they live in an array")
                }

                // end of the array as a constant offset, set only on the fast path below
                let mut out = None;
                if let Some(count) = statically_known_count {
                    if let Some(width) = statically_known_width_in_bits {
                        // the fast path, if the count and width are statically known, is to just immediately multiply
                        // otherwise this becomes a dynamic computation
                        assert!(width % 8 == 0);
                        computed_values.insert(
                            ComputedValueId::FieldSize(id),
                            ComputedValue::Constant(count * width / 8),
                        );
                        out = Some(ComputedOffset::ConstantPlusOffsetInBits(
                            curr_pos_id,
                            (count * width) as i64,
                        ));
                    }
                }

                // note: the count defined here refers to FieldSize(id), which may only be defined further below
                // this cannot create a cycle: FieldCount(id) is derived from FieldSize(id) only when the count
                // was NOT already known, whereas FieldSize(id) is derived from FieldCount(id) (below) only when
                // the count WAS already known
                if !is_count_known {
                    // the count is not known statically, or from earlier in the packet
                    // thus, we must compute it from the total size of the field, known either explicitly or implicitly via the trailer
                    // the fast path is to do a divide, but otherwise we need to loop over the TLVs
                    computed_values.insert(
                        ComputedValueId::FieldCount(id),
                        if computed_values.contains_key(&ComputedValueId::FieldElementSize(id)) {
                            ComputedValue::Divide(
                                ComputedValueId::FieldSize(id),
                                ComputedValueId::FieldElementSize(id),
                            )
                        } else {
                            // no element size is known here, which implies the array has an
                            // element type (scalar arrays always get a constant element size above)
                            ComputedValue::CountStructsUpToSize {
                                base_id: curr_pos_id,
                                size: ComputedValueId::FieldSize(id),
                                struct_type: type_id.as_ref().unwrap(),
                            }
                        },
                    );
                }

                if let Some(out) = out {
                    // we are paddable if the total size is known
                    next_prev_pos_id = Some(curr_pos_id);
                    out
                } else if is_total_size_known {
                    // we are paddable if the total size is known
                    next_prev_pos_id = Some(curr_pos_id);
                    ComputedOffset::SumWithOctets(curr_pos_id, ComputedValueId::FieldSize(id))
                } else if is_count_known {
                    // we are paddable if the total count is known, since structs know their lengths
                    next_prev_pos_id = Some(curr_pos_id);

                    computed_values.insert(
                        ComputedValueId::FieldSize(id),
                        if computed_values.contains_key(&ComputedValueId::FieldElementSize(id)) {
                            ComputedValue::Product(
                                ComputedValueId::FieldCount(id),
                                ComputedValueId::FieldElementSize(id),
                            )
                        } else {
                            ComputedValue::SizeOfNStructs {
                                base_id: curr_pos_id,
                                n: ComputedValueId::FieldCount(id),
                                struct_type: type_id.as_ref().unwrap(),
                            }
                        },
                    );
                    ComputedOffset::SumWithOctets(curr_pos_id, ComputedValueId::FieldSize(id))
                } else {
                    // we can try to infer the total size if we are still in the header
                    // however, we are not paddable in this case
                    next_prev_pos_id = None;

                    if needs_length {
                        panic!("either the total size, or the count of elements in an array, must be known")
                    }
                    // now we are in the trailer
                    computed_values.insert(
                        ComputedValueId::FieldSize(id),
                        ComputedValue::Difference(ComputedOffsetId::TrailerStart, curr_pos_id),
                    );
                    needs_length = true;
                    ComputedOffset::Alias(ComputedOffsetId::TrailerStart)
                }
            }
            ast::FieldDesc::Padding { size } => {
                // the padded extent is measured from the START of the preceding
                // array, not from the current position
                // NOTE(review): `size` is added as a bit count here, like the widths in the
                // other arms; if the AST expresses the padding size in octets this needs
                // to be scaled by 8 - confirm against the AST definition
                if let Some(prev_pos_id) = prev_pos_id {
                    ComputedOffset::ConstantPlusOffsetInBits(prev_pos_id, *size as i64)
                } else {
                    panic!("padding must follow array field with known total size")
                }
            }
            ast::FieldDesc::Typedef { id, type_id } => {
                computed_offsets
                    .insert(ComputedOffsetId::FieldOffset(id), ComputedOffset::Alias(curr_pos_id));

                match schema.packets_and_structs[type_id.as_str()].length {
                    PacketOrStructLength::Static(len) => {
                        ComputedOffset::ConstantPlusOffsetInBits(curr_pos_id, len as i64)
                    }
                    PacketOrStructLength::Dynamic => {
                        // the struct determines its own length: its size is that of
                        // exactly one struct of this type parsed at the current position
                        computed_values.insert(
                            ComputedValueId::FieldSize(id),
                            ComputedValue::SizeOfNStructs {
                                base_id: curr_pos_id,
                                n: one_id,
                                struct_type: type_id,
                            },
                        );
                        ComputedOffset::SumWithOctets(curr_pos_id, ComputedValueId::FieldSize(id))
                    }
                    PacketOrStructLength::NeedsExternal => {
                        let end_offset = if let Entry::Vacant(entry) =
                            computed_values.entry(ComputedValueId::FieldSize(id))
                        {
                            // its size is presently unknown
                            if needs_length {
                                panic!(
                                        "cannot have multiple variable-length fields in a single packet/struct"
                                    )
                            }
                            // we are now in the trailer
                            entry.insert(ComputedValue::Difference(
                                ComputedOffsetId::TrailerStart,
                                curr_pos_id,
                            ));
                            needs_length = true;
                            ComputedOffset::Alias(ComputedOffsetId::TrailerStart)
                        } else {
                            // an earlier field already defined the size of this field
                            ComputedOffset::SumWithOctets(
                                curr_pos_id,
                                ComputedValueId::FieldSize(id),
                            )
                        };
                        computed_offsets.insert(ComputedOffsetId::FieldEndOffset(id), end_offset);
                        end_offset
                    }
                }

                // it is possible to size a struct in this variant of PDL, even though the linter doesn't allow it
            }
        };

        // allocate a fresh anonymous offset for the position after this field
        // and make it the current position for the next field
        prev_pos_id = next_prev_pos_id;
        curr_pos_id = ComputedOffsetId::Custom(cnt);
        cnt += 1;
        computed_offsets.insert(curr_pos_id, next_pos);
    }

    // TODO(aryarahul): simplify compute graph to improve trailer resolution?

    // we are now at the end of the packet
    let length = if needs_length {
        // if we needed the length, use the PacketEnd and length to reconstruct the TrailerStart
        // the trailer length is found by walking back from the final position to
        // TrailerStart, which only succeeds if every step is a constant offset
        let trailer_length =
            compute_length_to_goal(&computed_offsets, curr_pos_id, ComputedOffsetId::TrailerStart)
                .expect("trailers should have deterministic length");
        computed_offsets.insert(
            ComputedOffsetId::TrailerStart,
            ComputedOffset::ConstantPlusOffsetInBits(ComputedOffsetId::PacketEnd, -trailer_length),
        );
        PacketOrStructLength::NeedsExternal
    } else {
        // otherwise, try to reconstruct the full length, if possible
        let full_length =
            compute_length_to_goal(&computed_offsets, curr_pos_id, ComputedOffsetId::HeaderStart);
        if let Some(full_length) = full_length {
            // every step back to the header is a constant: the length is static
            computed_offsets.insert(
                ComputedOffsetId::PacketEnd,
                ComputedOffset::ConstantPlusOffsetInBits(
                    ComputedOffsetId::HeaderStart,
                    full_length,
                ),
            );
            PacketOrStructLength::Static(full_length as usize)
        } else {
            // some field has a computed size: the packet ends wherever the last field ends
            computed_offsets
                .insert(ComputedOffsetId::PacketEnd, ComputedOffset::Alias(curr_pos_id));
            PacketOrStructLength::Dynamic
        }
    };

    PacketOrStruct { computed_values, computed_offsets, length }
}
512 
compute_length_to_goal( computed_offsets: &BTreeMap<ComputedOffsetId, ComputedOffset>, start: ComputedOffsetId, goal: ComputedOffsetId, ) -> Option<i64>513 fn compute_length_to_goal(
514     computed_offsets: &BTreeMap<ComputedOffsetId, ComputedOffset>,
515     start: ComputedOffsetId,
516     goal: ComputedOffsetId,
517 ) -> Option<i64> {
518     let mut out = 0;
519     let mut pos = start;
520     while pos != goal {
521         match computed_offsets.get(&pos).ok_or_else(|| format!("key {pos:?} not found")).unwrap() {
522             ComputedOffset::ConstantPlusOffsetInBits(base_id, offset) => {
523                 out += offset;
524                 pos = *base_id;
525             }
526             ComputedOffset::Alias(alias) => pos = *alias,
527             ComputedOffset::SumWithOctets { .. } => return None,
528         }
529     }
530     Some(out)
531 }
532