1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use std::collections::HashMap;
16
17 use proc_macro2::TokenStream;
18 use quote::{format_ident, quote};
19 use serde::Deserialize;
20
21 use crate::{ast, parser::parse_inline, quote_block};
22
/// One entry of `le_test_vectors.json`: a packet name and its test cases.
#[derive(Deserialize)]
struct PacketTest {
    /// Name of the packet declaration these test cases exercise.
    packet: String,
    /// Individual test cases for `packet`.
    tests: Box<[PacketTestCase]>,
}
28
/// A single test case: a serialized packet and the field values it decodes to.
#[derive(Deserialize)]
struct PacketTestCase {
    /// Serialized packet, decoded in the generated tests by
    /// `hex_str_to_byte_vector` (presumably hex-encoded bytes).
    packed: String,
    /// Expected field values of the decoded packet.
    unpacked: UnpackedTestFields,
    /// Optional child packet name. When present, the generated tests
    /// specialize the parsed parent view to this child, and the builder test
    /// builds the parent with this child.
    packet: Option<String>,
}
35
/// Expected field values of a packet or struct, keyed by field name.
///
/// The generators treat the key `payload` specially: it describes the raw
/// payload bytes rather than a named field with its own getter.
#[derive(Deserialize)]
struct UnpackedTestFields(HashMap<String, Field>);
38
/// Expected value of one field: a scalar, a nested struct, or a list.
///
/// Because the enum is `#[serde(untagged)]`, serde tries the variants in
/// declaration order. A JSON integer that fits in `usize` becomes `Number`,
/// a JSON object becomes `Struct`, and a JSON array becomes `List`.
/// Reordering the variants can therefore change how input is decoded.
#[derive(Deserialize)]
#[serde(untagged)]
enum Field {
    Number(usize),
    Struct(UnpackedTestFields),
    List(Box<[ListEntry]>),
}
47
/// Element of a [`Field::List`]: either a scalar or a nested struct.
///
/// As with [`Field`], the variants are untagged and tried in declaration
/// order, so the order of the variants is significant.
#[derive(Deserialize)]
#[serde(untagged)]
enum ListEntry {
    Number(usize),
    Struct(UnpackedTestFields),
}
55
generate_matchers( base: TokenStream, value: &UnpackedTestFields, filter_fields: &dyn Fn(&str) -> Result<bool, String>, curr_type: &str, type_lookup: &HashMap<&str, HashMap<&str, Option<&str>>>, ) -> Result<TokenStream, String>56 fn generate_matchers(
57 base: TokenStream,
58 value: &UnpackedTestFields,
59 filter_fields: &dyn Fn(&str) -> Result<bool, String>,
60 curr_type: &str,
61 type_lookup: &HashMap<&str, HashMap<&str, Option<&str>>>,
62 ) -> Result<TokenStream, String> {
63 let mut out = vec![];
64
65 for (field_name, field_value) in value.0.iter() {
66 if !filter_fields(field_name)? {
67 continue;
68 }
69 let getter_ident = format_ident!("get_{field_name}");
70 match field_value {
71 Field::Number(num) => {
72 let num = *num as u64;
73 if let Some(field_type) = type_lookup[curr_type][field_name.as_str()] {
74 let field_ident = format_ident!("{field_type}");
75 out.push(quote! { assert_eq!(#base.#getter_ident(), #field_ident::new(#num as _).unwrap()); });
76 } else {
77 out.push(quote! { assert_eq!(u64::from(#base.#getter_ident()), #num); });
78 }
79 }
80 Field::List(lst) => {
81 if field_name == "payload" {
82 let reference = lst
83 .iter()
84 .map(|val| match val {
85 ListEntry::Number(val) => *val as u8,
86 _ => unreachable!(),
87 })
88 .collect::<Vec<_>>();
89 out.push(quote! {
90 assert_eq!(#base.get_raw_payload().collect::<Vec<_>>(), vec![#(#reference),*]);
91 })
92 } else {
93 let get_iter_ident = format_ident!("get_{field_name}_iter");
94 let vec_ident = format_ident!("{field_name}_vec");
95 out.push(
96 quote! { let #vec_ident = #base.#get_iter_ident().collect::<Vec<_>>(); },
97 );
98
99 for (i, val) in lst.iter().enumerate() {
100 let list_elem = quote! { #vec_ident[#i] };
101 out.push(match val {
102 ListEntry::Number(num) => {
103 if let Some(field_type) = type_lookup[curr_type][field_name.as_str()] {
104 let field_ident = format_ident!("{field_type}");
105 quote! { assert_eq!(#list_elem, #field_ident::new(#num as _).unwrap()); }
106 } else {
107 quote! { assert_eq!(u64::from(#list_elem), #num as u64); }
108 }
109 }
110 ListEntry::Struct(fields) => {
111 generate_matchers(list_elem, fields, &|_| Ok(true), type_lookup[curr_type][field_name.as_str()].unwrap(), type_lookup)?
112 }
113 })
114 }
115 }
116 }
117 Field::Struct(fields) => {
118 out.push(generate_matchers(
119 quote! { #base.#getter_ident() },
120 fields,
121 &|_| Ok(true),
122 type_lookup[curr_type][field_name.as_str()].unwrap(),
123 type_lookup,
124 )?);
125 }
126 }
127 }
128 Ok(quote! { { #(#out)* } })
129 }
130
generate_builder( curr_type: &str, child_type: Option<&str>, type_lookup: &HashMap<&str, HashMap<&str, Option<&str>>>, value: &UnpackedTestFields, ) -> TokenStream131 fn generate_builder(
132 curr_type: &str,
133 child_type: Option<&str>,
134 type_lookup: &HashMap<&str, HashMap<&str, Option<&str>>>,
135 value: &UnpackedTestFields,
136 ) -> TokenStream {
137 let builder_ident = format_ident!("{curr_type}Builder");
138 let child_ident = format_ident!("{curr_type}Child");
139
140 let curr_fields = &type_lookup[curr_type];
141
142 let fields = value.0.iter().filter_map(|(field_name, field_value)| {
143 let curr_field_info = curr_fields.get(field_name.as_str());
144
145 if let Some(curr_field_info) = curr_field_info {
146 let field_name_ident = if field_name == "payload" {
147 format_ident!("_child_")
148 } else {
149 format_ident!("{field_name}")
150 };
151 let val = match field_value {
152 Field::Number(val) => {
153 if let Some(field) = curr_field_info {
154 let field_ident = format_ident!("{field}");
155 quote! { #field_ident::new(#val as _).unwrap() }
156 } else {
157 quote! { (#val as u64).try_into().unwrap() }
158 }
159 }
160 Field::Struct(fields) => {
161 generate_builder(curr_field_info.unwrap(), None, type_lookup, fields)
162 }
163 Field::List(lst) => {
164 let elems = lst.iter().map(|entry| match entry {
165 ListEntry::Number(val) => {
166 if let Some(field) = curr_field_info {
167 let field_ident = format_ident!("{field}");
168 quote! { #field_ident::new(#val as _).unwrap() }
169 } else {
170 quote! { (#val as u64).try_into().unwrap() }
171 }
172 }
173 ListEntry::Struct(fields) => {
174 generate_builder(curr_field_info.unwrap(), None, type_lookup, fields)
175 }
176 });
177 quote! { vec![#(#elems),*].into_boxed_slice() }
178 }
179 };
180
181 Some(if field_name == "payload" {
182 quote! { #field_name_ident: #child_ident::RawData(#val) }
183 } else {
184 quote! { #field_name_ident: #val }
185 })
186 } else {
187 None
188 }
189 });
190
191 let child_field = if let Some(child_type) = child_type {
192 let child_builder = generate_builder(child_type, None, type_lookup, value);
193 Some(quote! {
194 _child_: #child_builder.into(),
195 })
196 } else {
197 None
198 };
199
200 quote! {
201 #builder_ident {
202 #child_field
203 #(#fields),*
204 }
205 }
206 }
207
generate_test_file() -> Result<String, String>208 pub fn generate_test_file() -> Result<String, String> {
209 let mut out = String::new();
210
211 out.push_str(include_str!("test_preamble.rs"));
212
213 let file = include_str!("../../../tests/canonical/le_test_vectors.json");
214 let test_vectors: Box<[_]> =
215 serde_json::from_str(file).map_err(|_| "could not parse test vectors")?;
216
217 let pdl = include_str!("../../../tests/canonical/le_rust_noalloc_test_file.pdl");
218 let ast = parse_inline(&mut ast::SourceDatabase::new(), "test.pdl", pdl.to_owned())
219 .expect("could not parse reference PDL");
220 let packet_lookup =
221 ast.declarations
222 .iter()
223 .filter_map(|decl| match &decl.desc {
224 ast::DeclDesc::Packet { id, fields, .. }
225 | ast::DeclDesc::Struct { id, fields, .. } => Some((
226 id.as_str(),
227 fields
228 .iter()
229 .filter_map(|field| match &field.desc {
230 ast::FieldDesc::Body { .. } | ast::FieldDesc::Payload { .. } => {
231 Some(("payload", None))
232 }
233 ast::FieldDesc::Array { id, type_id, .. } => match type_id {
234 Some(type_id) => Some((id.as_str(), Some(type_id.as_str()))),
235 None => Some((id.as_str(), None)),
236 },
237 ast::FieldDesc::Typedef { id, type_id, .. } => {
238 Some((id.as_str(), Some(type_id.as_str())))
239 }
240 ast::FieldDesc::Scalar { id, .. } => Some((id.as_str(), None)),
241 _ => None,
242 })
243 .collect::<HashMap<_, _>>(),
244 )),
245 _ => None,
246 })
247 .collect::<HashMap<_, _>>();
248
249 for PacketTest { packet, tests } in test_vectors.iter() {
250 if !pdl.contains(packet) {
251 // huge brain hack to skip unsupported test vectors
252 continue;
253 }
254
255 for (i, PacketTestCase { packed, unpacked, packet: sub_packet }) in tests.iter().enumerate()
256 {
257 if let Some(sub_packet) = sub_packet {
258 if !pdl.contains(sub_packet) {
259 // huge brain hack to skip unsupported test vectors
260 continue;
261 }
262 }
263
264 let test_name_ident = format_ident!("test_{packet}_{i}");
265 let packet_ident = format_ident!("{packet}_instance");
266 let packet_view = format_ident!("{packet}View");
267
268 let mut leaf_packet = packet;
269
270 let specialization = if let Some(sub_packet) = sub_packet {
271 let sub_packet_ident = format_ident!("{}_instance", sub_packet);
272 let sub_packet_view_ident = format_ident!("{}View", sub_packet);
273
274 leaf_packet = sub_packet;
275 quote! { let #sub_packet_ident = #sub_packet_view_ident::try_parse(#packet_ident).unwrap(); }
276 } else {
277 quote! {}
278 };
279
280 let leaf_packet_ident = format_ident!("{leaf_packet}_instance");
281
282 let packet_matchers = generate_matchers(
283 quote! { #packet_ident },
284 unpacked,
285 &|field| {
286 Ok(packet_lookup
287 .get(packet.as_str())
288 .ok_or(format!("could not find packet {packet}"))?
289 .contains_key(field))
290 },
291 packet,
292 &packet_lookup,
293 )?;
294
295 let sub_packet_matchers = generate_matchers(
296 quote! { #leaf_packet_ident },
297 unpacked,
298 &|field| {
299 Ok(packet_lookup
300 .get(leaf_packet.as_str())
301 .ok_or(format!("could not find packet {packet}"))?
302 .contains_key(field))
303 },
304 sub_packet.as_ref().unwrap_or(packet),
305 &packet_lookup,
306 )?;
307
308 out.push_str("e_block! {
309 #[test]
310 fn #test_name_ident() {
311 let base = hex_str_to_byte_vector(#packed);
312 let #packet_ident = #packet_view::try_parse(SizedBitSlice::from(&base[..]).into()).unwrap();
313
314 #specialization
315
316 #packet_matchers
317 #sub_packet_matchers
318 }
319 });
320
321 let builder = generate_builder(packet, sub_packet.as_deref(), &packet_lookup, unpacked);
322
323 let test_name_ident = format_ident!("test_{packet}_builder_{i}");
324 out.push_str("e_block! {
325 #[test]
326 fn #test_name_ident() {
327 let packed = hex_str_to_byte_vector(#packed);
328 let serialized = #builder.to_vec().unwrap();
329 assert_eq!(packed, serialized);
330 }
331 });
332 }
333 }
334
335 Ok(out)
336 }
337