1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use std::collections::HashMap;
16 
17 use proc_macro2::TokenStream;
18 use quote::{format_ident, quote};
19 use serde::Deserialize;
20 
21 use crate::{ast, parser::parse_inline, quote_block};
22 
/// One entry of the canonical test-vector file: a packet name and its test cases.
#[derive(Deserialize)]
struct PacketTest {
    /// Name of the packet declaration the test cases apply to.
    packet: String,
    /// Test cases for `packet`.
    tests: Box<[PacketTestCase]>,
}
28 
/// A single test case: serialized bytes and the field values they correspond to.
#[derive(Deserialize)]
struct PacketTestCase {
    /// Serialized packet as a hex string (decoded with `hex_str_to_byte_vector` in the
    /// generated tests).
    packed: String,
    /// Expected field values, keyed by field name.
    unpacked: UnpackedTestFields,
    /// Optional name of a child packet that the parsed packet is specialized into.
    packet: Option<String>,
}
35 
/// Expected field values keyed by field name. The key `payload` holds the raw payload bytes.
#[derive(Deserialize)]
struct UnpackedTestFields(HashMap<String, Field>);
38 
/// Value of a single test-vector field: a scalar, a nested struct, or a list.
///
/// Deserialized as `untagged`, so serde tries the variants in declaration order.
#[derive(Deserialize)]
#[serde(untagged)]
enum Field {
    Number(usize),
    Struct(UnpackedTestFields),
    List(Box<[ListEntry]>),
}
47 
/// Element of a list-valued field: either a scalar or a nested struct.
///
/// Deserialized as `untagged`, so serde tries the variants in declaration order.
#[derive(Deserialize)]
#[serde(untagged)]
enum ListEntry {
    Number(usize),
    Struct(UnpackedTestFields),
}
55 
generate_matchers( base: TokenStream, value: &UnpackedTestFields, filter_fields: &dyn Fn(&str) -> Result<bool, String>, curr_type: &str, type_lookup: &HashMap<&str, HashMap<&str, Option<&str>>>, ) -> Result<TokenStream, String>56 fn generate_matchers(
57     base: TokenStream,
58     value: &UnpackedTestFields,
59     filter_fields: &dyn Fn(&str) -> Result<bool, String>,
60     curr_type: &str,
61     type_lookup: &HashMap<&str, HashMap<&str, Option<&str>>>,
62 ) -> Result<TokenStream, String> {
63     let mut out = vec![];
64 
65     for (field_name, field_value) in value.0.iter() {
66         if !filter_fields(field_name)? {
67             continue;
68         }
69         let getter_ident = format_ident!("get_{field_name}");
70         match field_value {
71             Field::Number(num) => {
72                 let num = *num as u64;
73                 if let Some(field_type) = type_lookup[curr_type][field_name.as_str()] {
74                     let field_ident = format_ident!("{field_type}");
75                     out.push(quote! { assert_eq!(#base.#getter_ident(), #field_ident::new(#num as _).unwrap()); });
76                 } else {
77                     out.push(quote! { assert_eq!(u64::from(#base.#getter_ident()), #num); });
78                 }
79             }
80             Field::List(lst) => {
81                 if field_name == "payload" {
82                     let reference = lst
83                         .iter()
84                         .map(|val| match val {
85                             ListEntry::Number(val) => *val as u8,
86                             _ => unreachable!(),
87                         })
88                         .collect::<Vec<_>>();
89                     out.push(quote! {
90                         assert_eq!(#base.get_raw_payload().collect::<Vec<_>>(), vec![#(#reference),*]);
91                     })
92                 } else {
93                     let get_iter_ident = format_ident!("get_{field_name}_iter");
94                     let vec_ident = format_ident!("{field_name}_vec");
95                     out.push(
96                         quote! { let #vec_ident = #base.#get_iter_ident().collect::<Vec<_>>(); },
97                     );
98 
99                     for (i, val) in lst.iter().enumerate() {
100                         let list_elem = quote! { #vec_ident[#i] };
101                         out.push(match val {
102                             ListEntry::Number(num) => {
103                                 if let Some(field_type) = type_lookup[curr_type][field_name.as_str()] {
104                                     let field_ident = format_ident!("{field_type}");
105                                     quote! { assert_eq!(#list_elem, #field_ident::new(#num as _).unwrap()); }
106                                 } else {
107                                     quote! { assert_eq!(u64::from(#list_elem), #num as u64); }
108                                 }
109                             }
110                             ListEntry::Struct(fields) => {
111                                 generate_matchers(list_elem, fields, &|_| Ok(true), type_lookup[curr_type][field_name.as_str()].unwrap(), type_lookup)?
112                             }
113                         })
114                     }
115                 }
116             }
117             Field::Struct(fields) => {
118                 out.push(generate_matchers(
119                     quote! { #base.#getter_ident() },
120                     fields,
121                     &|_| Ok(true),
122                     type_lookup[curr_type][field_name.as_str()].unwrap(),
123                     type_lookup,
124                 )?);
125             }
126         }
127     }
128     Ok(quote! { { #(#out)* } })
129 }
130 
generate_builder( curr_type: &str, child_type: Option<&str>, type_lookup: &HashMap<&str, HashMap<&str, Option<&str>>>, value: &UnpackedTestFields, ) -> TokenStream131 fn generate_builder(
132     curr_type: &str,
133     child_type: Option<&str>,
134     type_lookup: &HashMap<&str, HashMap<&str, Option<&str>>>,
135     value: &UnpackedTestFields,
136 ) -> TokenStream {
137     let builder_ident = format_ident!("{curr_type}Builder");
138     let child_ident = format_ident!("{curr_type}Child");
139 
140     let curr_fields = &type_lookup[curr_type];
141 
142     let fields = value.0.iter().filter_map(|(field_name, field_value)| {
143         let curr_field_info = curr_fields.get(field_name.as_str());
144 
145         if let Some(curr_field_info) = curr_field_info {
146             let field_name_ident = if field_name == "payload" {
147                 format_ident!("_child_")
148             } else {
149                 format_ident!("{field_name}")
150             };
151             let val = match field_value {
152                 Field::Number(val) => {
153                     if let Some(field) = curr_field_info {
154                         let field_ident = format_ident!("{field}");
155                         quote! { #field_ident::new(#val as _).unwrap() }
156                     } else {
157                         quote! { (#val as u64).try_into().unwrap() }
158                     }
159                 }
160                 Field::Struct(fields) => {
161                     generate_builder(curr_field_info.unwrap(), None, type_lookup, fields)
162                 }
163                 Field::List(lst) => {
164                     let elems = lst.iter().map(|entry| match entry {
165                         ListEntry::Number(val) => {
166                             if let Some(field) = curr_field_info {
167                                 let field_ident = format_ident!("{field}");
168                                 quote! { #field_ident::new(#val as _).unwrap() }
169                             } else {
170                                 quote! { (#val as u64).try_into().unwrap() }
171                             }
172                         }
173                         ListEntry::Struct(fields) => {
174                             generate_builder(curr_field_info.unwrap(), None, type_lookup, fields)
175                         }
176                     });
177                     quote! { vec![#(#elems),*].into_boxed_slice() }
178                 }
179             };
180 
181             Some(if field_name == "payload" {
182                 quote! { #field_name_ident: #child_ident::RawData(#val) }
183             } else {
184                 quote! { #field_name_ident: #val }
185             })
186         } else {
187             None
188         }
189     });
190 
191     let child_field = if let Some(child_type) = child_type {
192         let child_builder = generate_builder(child_type, None, type_lookup, value);
193         Some(quote! {
194             _child_: #child_builder.into(),
195         })
196     } else {
197         None
198     };
199 
200     quote! {
201         #builder_ident {
202             #child_field
203             #(#fields),*
204         }
205     }
206 }
207 
generate_test_file() -> Result<String, String>208 pub fn generate_test_file() -> Result<String, String> {
209     let mut out = String::new();
210 
211     out.push_str(include_str!("test_preamble.rs"));
212 
213     let file = include_str!("../../../tests/canonical/le_test_vectors.json");
214     let test_vectors: Box<[_]> =
215         serde_json::from_str(file).map_err(|_| "could not parse test vectors")?;
216 
217     let pdl = include_str!("../../../tests/canonical/le_rust_noalloc_test_file.pdl");
218     let ast = parse_inline(&mut ast::SourceDatabase::new(), "test.pdl", pdl.to_owned())
219         .expect("could not parse reference PDL");
220     let packet_lookup =
221         ast.declarations
222             .iter()
223             .filter_map(|decl| match &decl.desc {
224                 ast::DeclDesc::Packet { id, fields, .. }
225                 | ast::DeclDesc::Struct { id, fields, .. } => Some((
226                     id.as_str(),
227                     fields
228                         .iter()
229                         .filter_map(|field| match &field.desc {
230                             ast::FieldDesc::Body { .. } | ast::FieldDesc::Payload { .. } => {
231                                 Some(("payload", None))
232                             }
233                             ast::FieldDesc::Array { id, type_id, .. } => match type_id {
234                                 Some(type_id) => Some((id.as_str(), Some(type_id.as_str()))),
235                                 None => Some((id.as_str(), None)),
236                             },
237                             ast::FieldDesc::Typedef { id, type_id, .. } => {
238                                 Some((id.as_str(), Some(type_id.as_str())))
239                             }
240                             ast::FieldDesc::Scalar { id, .. } => Some((id.as_str(), None)),
241                             _ => None,
242                         })
243                         .collect::<HashMap<_, _>>(),
244                 )),
245                 _ => None,
246             })
247             .collect::<HashMap<_, _>>();
248 
249     for PacketTest { packet, tests } in test_vectors.iter() {
250         if !pdl.contains(packet) {
251             // huge brain hack to skip unsupported test vectors
252             continue;
253         }
254 
255         for (i, PacketTestCase { packed, unpacked, packet: sub_packet }) in tests.iter().enumerate()
256         {
257             if let Some(sub_packet) = sub_packet {
258                 if !pdl.contains(sub_packet) {
259                     // huge brain hack to skip unsupported test vectors
260                     continue;
261                 }
262             }
263 
264             let test_name_ident = format_ident!("test_{packet}_{i}");
265             let packet_ident = format_ident!("{packet}_instance");
266             let packet_view = format_ident!("{packet}View");
267 
268             let mut leaf_packet = packet;
269 
270             let specialization = if let Some(sub_packet) = sub_packet {
271                 let sub_packet_ident = format_ident!("{}_instance", sub_packet);
272                 let sub_packet_view_ident = format_ident!("{}View", sub_packet);
273 
274                 leaf_packet = sub_packet;
275                 quote! { let #sub_packet_ident = #sub_packet_view_ident::try_parse(#packet_ident).unwrap(); }
276             } else {
277                 quote! {}
278             };
279 
280             let leaf_packet_ident = format_ident!("{leaf_packet}_instance");
281 
282             let packet_matchers = generate_matchers(
283                 quote! { #packet_ident },
284                 unpacked,
285                 &|field| {
286                     Ok(packet_lookup
287                         .get(packet.as_str())
288                         .ok_or(format!("could not find packet {packet}"))?
289                         .contains_key(field))
290                 },
291                 packet,
292                 &packet_lookup,
293             )?;
294 
295             let sub_packet_matchers = generate_matchers(
296                 quote! { #leaf_packet_ident },
297                 unpacked,
298                 &|field| {
299                     Ok(packet_lookup
300                         .get(leaf_packet.as_str())
301                         .ok_or(format!("could not find packet {packet}"))?
302                         .contains_key(field))
303                 },
304                 sub_packet.as_ref().unwrap_or(packet),
305                 &packet_lookup,
306             )?;
307 
308             out.push_str(&quote_block! {
309               #[test]
310               fn #test_name_ident() {
311                 let base = hex_str_to_byte_vector(#packed);
312                 let #packet_ident = #packet_view::try_parse(SizedBitSlice::from(&base[..]).into()).unwrap();
313 
314                 #specialization
315 
316                 #packet_matchers
317                 #sub_packet_matchers
318               }
319             });
320 
321             let builder = generate_builder(packet, sub_packet.as_deref(), &packet_lookup, unpacked);
322 
323             let test_name_ident = format_ident!("test_{packet}_builder_{i}");
324             out.push_str(&quote_block! {
325               #[test]
326               fn #test_name_ident() {
327                 let packed = hex_str_to_byte_vector(#packed);
328                 let serialized = #builder.to_vec().unwrap();
329                 assert_eq!(packed, serialized);
330               }
331             });
332         }
333     }
334 
335     Ok(out)
336 }
337