1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use std::iter::empty;
16 
17 use proc_macro2::TokenStream;
18 use quote::{format_ident, quote};
19 
20 use crate::ast;
21 
22 use crate::backends::intermediate::{
23     ComputedOffsetId, ComputedValueId, PacketOrStruct, PacketOrStructLength, Schema,
24 };
25 
26 use super::computed_values::{Computable, Declarable};
27 use super::utils::get_integer_type;
28 
generate_packet( id: &str, fields: &[ast::Field], parent_id: Option<&str>, schema: &Schema, curr_schema: &PacketOrStruct, ) -> Result<TokenStream, String>29 pub fn generate_packet(
30     id: &str,
31     fields: &[ast::Field],
32     parent_id: Option<&str>,
33     schema: &Schema,
34     curr_schema: &PacketOrStruct,
35 ) -> Result<TokenStream, String> {
36     let id_ident = format_ident!("{id}View");
37 
38     let needs_external = matches!(curr_schema.length, PacketOrStructLength::NeedsExternal);
39 
40     let length_getter = if needs_external {
41         ComputedOffsetId::PacketEnd.declare_fn(quote! { Ok(self.buf.get_size_in_bits()) })
42     } else {
43         quote! {}
44     };
45 
46     let computed_getters = empty()
47         .chain(
48             curr_schema.computed_offsets.iter().map(|(decl, defn)| decl.declare_fn(defn.compute())),
49         )
50         .chain(
51             curr_schema.computed_values.iter().map(|(decl, defn)| decl.declare_fn(defn.compute())),
52         );
53 
54     let field_getters = fields.iter().map(|field| {
55         match &field.desc {
56             ast::FieldDesc::Padding { .. }
57             | ast::FieldDesc::Flag { .. }
58             | ast::FieldDesc::Reserved { .. }
59             | ast::FieldDesc::FixedScalar { .. }
60             | ast::FieldDesc::FixedEnum { .. }
61             | ast::FieldDesc::ElementSize { .. }
62             | ast::FieldDesc::Count { .. }
63             | ast::FieldDesc::Size { .. } => {
64                 // no-op, no getter generated for this type
65                 quote! {}
66             }
67             ast::FieldDesc::Group { .. } => unreachable!(),
68             ast::FieldDesc::Checksum { .. } => {
69                 unimplemented!("checksums not yet supported with this backend")
70             }
71             ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body => {
72                 let name = if matches!(field.desc, ast::FieldDesc::Payload { .. }) { "_payload_"} else { "_body_"};
73                 let payload_start_offset = ComputedOffsetId::FieldOffset(name).call_fn();
74                 let payload_end_offset = ComputedOffsetId::FieldEndOffset(name).call_fn();
75                 quote! {
76                     fn try_get_payload(&self) -> Result<SizedBitSlice<'a>, ParseError> {
77                         let payload_start_offset = #payload_start_offset;
78                         let payload_end_offset = #payload_end_offset;
79                         self.buf.offset(payload_start_offset)?.slice(payload_end_offset - payload_start_offset)
80                     }
81 
82                     fn try_get_raw_payload(&self) -> Result<impl Iterator<Item = Result<u8, ParseError>> + '_, ParseError> {
83                         let view = self.try_get_payload()?;
84                         let count = (view.get_size_in_bits() + 7) / 8;
85                         Ok((0..count).map(move |i| Ok(view.offset(i*8)?.slice(8.min(view.get_size_in_bits() - i*8))?.try_parse()?)))
86                     }
87 
88                     pub fn get_raw_payload(&self) -> impl Iterator<Item = u8> + '_ {
89                         self.try_get_raw_payload().unwrap().map(|x| x.unwrap())
90                     }
91                 }
92             }
93             ast::FieldDesc::Array { id, width, type_id, .. } => {
94                 let (elem_type, return_type) = if let Some(width) = width {
95                     let ident = get_integer_type(*width);
96                     (ident.clone(), quote!{ #ident })
97                 } else if let Some(type_id) = type_id {
98                     if schema.enums.contains_key(type_id.as_str()) {
99                         let ident = format_ident!("{}", type_id);
100                         (ident.clone(), quote! { #ident })
101                     } else {
102                         let ident = format_ident!("{}View", type_id);
103                         (ident.clone(), quote! { #ident<'a> })
104                     }
105                 } else {
106                     unreachable!()
107                 };
108 
109                 let try_getter_name = format_ident!("try_get_{id}_iter");
110                 let getter_name = format_ident!("get_{id}_iter");
111 
112                 let start_offset = ComputedOffsetId::FieldOffset(id).call_fn();
113                 let count = ComputedValueId::FieldCount(id).call_fn();
114 
115                 let element_size_known = curr_schema
116                     .computed_values
117                     .contains_key(&ComputedValueId::FieldElementSize(id));
118 
119                 let body = if element_size_known {
120                     let element_size = ComputedValueId::FieldElementSize(id).call_fn();
121                     let parsed_curr_view = if width.is_some() {
122                         quote! { curr_view.try_parse() }
123                     } else {
124                         quote! { #elem_type::try_parse(curr_view.into()) }
125                     };
126                     quote! {
127                         let view = self.buf.offset(#start_offset)?;
128                         let count = #count;
129                         let element_size = #element_size;
130                         Ok((0..count).map(move |i| {
131                             let curr_view = view.offset(element_size.checked_mul(i * 8).ok_or(ParseError::ArithmeticOverflow)?)?
132                                     .slice(element_size.checked_mul(8).ok_or(ParseError::ArithmeticOverflow)?)?;
133                             #parsed_curr_view
134                         }))
135                     }
136                 } else {
137                     quote! {
138                         let mut view = self.buf.offset(#start_offset)?;
139                         let count = #count;
140                         Ok((0..count).map(move |i| {
141                             let parsed = #elem_type::try_parse(view.into())?;
142                             view = view.offset(parsed.try_get_size()? * 8)?;
143                             Ok(parsed)
144                         }))
145                     }
146                 };
147 
148                 quote! {
149                     fn #try_getter_name(&self) -> Result<impl Iterator<Item = Result<#return_type, ParseError>> + 'a, ParseError> {
150                         #body
151                     }
152 
153                     #[inline]
154                     pub fn #getter_name(&self) -> impl Iterator<Item = #return_type> + 'a {
155                         self.#try_getter_name().unwrap().map(|x| x.unwrap())
156                     }
157                 }
158             }
159             ast::FieldDesc::Scalar { id, width } => {
160                 let try_getter_name = format_ident!("try_get_{id}");
161                 let getter_name = format_ident!("get_{id}");
162                 let offset = ComputedOffsetId::FieldOffset(id).call_fn();
163                 let scalar_type = get_integer_type(*width);
164                 quote! {
165                     fn #try_getter_name(&self) -> Result<#scalar_type, ParseError> {
166                         self.buf.offset(#offset)?.slice(#width)?.try_parse()
167                     }
168 
169                     #[inline]
170                     pub fn #getter_name(&self) -> #scalar_type {
171                         self.#try_getter_name().unwrap()
172                     }
173                 }
174             }
175             ast::FieldDesc::Typedef { id, type_id } => {
176                 let try_getter_name = format_ident!("try_get_{id}");
177                 let getter_name = format_ident!("get_{id}");
178 
179                 let (type_ident, return_type) = if schema.enums.contains_key(type_id.as_str()) {
180                     let ident = format_ident!("{type_id}");
181                     (ident.clone(), quote! { #ident })
182                 } else {
183                     let ident = format_ident!("{}View", type_id);
184                     (ident.clone(), quote! { #ident<'a> })
185                 };
186                 let offset = ComputedOffsetId::FieldOffset(id).call_fn();
187                 let end_offset_known = curr_schema
188                     .computed_offsets
189                     .contains_key(&ComputedOffsetId::FieldEndOffset(id));
190                 let sliced_view = if end_offset_known {
191                     let end_offset = ComputedOffsetId::FieldEndOffset(id).call_fn();
192                     quote! { self.buf.offset(#offset)?.slice(#end_offset.checked_sub(#offset).ok_or(ParseError::ArithmeticOverflow)?)? }
193                 } else {
194                     quote! { self.buf.offset(#offset)? }
195                 };
196 
197                 quote! {
198                     fn #try_getter_name(&self) -> Result<#return_type, ParseError> {
199                         #type_ident::try_parse(#sliced_view.into())
200                     }
201 
202                     #[inline]
203                     pub fn #getter_name(&self) -> #return_type {
204                         self.#try_getter_name().unwrap()
205                     }
206                 }
207             }
208         }
209     });
210 
211     let backing_buffer = if needs_external {
212         quote! { SizedBitSlice<'a> }
213     } else {
214         quote! { BitSlice<'a> }
215     };
216 
217     let parent_ident = match parent_id {
218         Some(parent) => format_ident!("{parent}View"),
219         None => match curr_schema.length {
220             PacketOrStructLength::Static(_) => format_ident!("BitSlice"),
221             PacketOrStructLength::Dynamic => format_ident!("BitSlice"),
222             PacketOrStructLength::NeedsExternal => format_ident!("SizedBitSlice"),
223         },
224     };
225 
226     let buffer_extractor = if parent_id.is_some() {
227         quote! { parent.try_get_payload().unwrap().into() }
228     } else {
229         quote! { parent }
230     };
231 
232     let field_validators = fields.iter().map(|field| match &field.desc {
233         ast::FieldDesc::Checksum { .. } => unimplemented!(),
234         ast::FieldDesc::Group { .. } => unreachable!(),
235         ast::FieldDesc::Padding { .. }
236         | ast::FieldDesc::Flag { .. }
237         | ast::FieldDesc::Size { .. }
238         | ast::FieldDesc::Count { .. }
239         | ast::FieldDesc::ElementSize { .. }
240         | ast::FieldDesc::Body
241         | ast::FieldDesc::FixedScalar { .. }
242         | ast::FieldDesc::FixedEnum { .. }
243         | ast::FieldDesc::Reserved { .. } => {
244             quote! {}
245         }
246         ast::FieldDesc::Payload { .. } => {
247             quote! {
248                 self.try_get_payload()?;
249                 self.try_get_raw_payload()?;
250             }
251         }
252         ast::FieldDesc::Array { id, .. } => {
253             let iter_ident = format_ident!("try_get_{id}_iter");
254             quote! {
255                 for elem in self.#iter_ident()? {
256                     elem?;
257                 }
258             }
259         }
260         ast::FieldDesc::Scalar { id, .. } | ast::FieldDesc::Typedef { id, .. } => {
261             let getter_ident = format_ident!("try_get_{id}");
262             quote! { self.#getter_ident()?; }
263         }
264     });
265 
266     let packet_end_offset = ComputedOffsetId::PacketEnd.call_fn();
267 
268     let owned_id_ident = format_ident!("Owned{id_ident}");
269     let builder_ident = format_ident!("{id}Builder");
270 
271     Ok(quote! {
272         #[derive(Clone, Copy, Debug)]
273         pub struct #id_ident<'a> {
274             buf: #backing_buffer,
275         }
276 
277         impl<'a> #id_ident<'a> {
278             #length_getter
279 
280             #(#computed_getters)*
281 
282             #(#field_getters)*
283 
284             #[inline]
285             fn try_get_header_start_offset(&self) -> Result<usize, ParseError> {
286                 Ok(0)
287             }
288 
289             #[inline]
290             fn try_get_size(&self) -> Result<usize, ParseError> {
291                 let size = #packet_end_offset;
292                 if size % 8 != 0 {
293                     return Err(ParseError::MisalignedPayload);
294                 }
295                 Ok(size / 8)
296             }
297 
298             fn validate(&self) -> Result<(), ParseError> {
299                 #(#field_validators)*
300                 Ok(())
301             }
302         }
303 
304         impl<'a> Packet<'a> for #id_ident<'a> {
305             type Parent = #parent_ident<'a>;
306             type Owned = #owned_id_ident;
307             type Builder = #builder_ident;
308 
309             fn try_parse_from_buffer(buf: impl Into<SizedBitSlice<'a>>) -> Result<Self, ParseError> {
310                 let out = Self { buf: buf.into().into() };
311                 out.validate()?;
312                 Ok(out)
313             }
314 
315             fn try_parse(parent: #parent_ident<'a>) -> Result<Self, ParseError> {
316                 let out = Self { buf: #buffer_extractor };
317                 out.validate()?;
318                 Ok(out)
319             }
320 
321             fn to_owned_packet(&self) -> #owned_id_ident {
322                 #owned_id_ident {
323                     buf: self.buf.backing.to_owned().into(),
324                     start_bit_offset: self.buf.start_bit_offset,
325                     end_bit_offset: self.buf.end_bit_offset,
326                 }
327             }
328         }
329 
330         #[derive(Debug)]
331         pub struct #owned_id_ident {
332             buf: Box<[u8]>,
333             start_bit_offset: usize,
334             end_bit_offset: usize,
335         }
336 
337         impl OwnedPacket for #owned_id_ident {
338             fn try_parse(buf: Box<[u8]>) -> Result<Self, ParseError> {
339                 #id_ident::try_parse_from_buffer(&buf[..])?;
340                 let end_bit_offset = buf.len() * 8;
341                 Ok(Self { buf, start_bit_offset: 0, end_bit_offset })
342             }
343         }
344 
345         impl #owned_id_ident {
346             pub fn view<'a>(&'a self) -> #id_ident<'a> {
347                 #id_ident {
348                     buf: SizedBitSlice(BitSlice {
349                         backing: &self.buf[..],
350                         start_bit_offset: self.start_bit_offset,
351                         end_bit_offset: self.end_bit_offset,
352                     })
353                     .into(),
354                 }
355             }
356         }
357 
358         impl<'a> From<&'a #owned_id_ident> for #id_ident<'a> {
359             fn from(x: &'a #owned_id_ident) -> Self {
360                 x.view()
361             }
362         }
363     })
364 }
365