1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use std::iter::empty;
16
17 use proc_macro2::TokenStream;
18 use quote::{format_ident, quote};
19
20 use crate::ast;
21
22 use crate::backends::intermediate::{
23 ComputedOffsetId, ComputedValueId, PacketOrStruct, PacketOrStructLength, Schema,
24 };
25
26 use super::computed_values::{Computable, Declarable};
27 use super::utils::get_integer_type;
28
generate_packet( id: &str, fields: &[ast::Field], parent_id: Option<&str>, schema: &Schema, curr_schema: &PacketOrStruct, ) -> Result<TokenStream, String>29 pub fn generate_packet(
30 id: &str,
31 fields: &[ast::Field],
32 parent_id: Option<&str>,
33 schema: &Schema,
34 curr_schema: &PacketOrStruct,
35 ) -> Result<TokenStream, String> {
36 let id_ident = format_ident!("{id}View");
37
38 let needs_external = matches!(curr_schema.length, PacketOrStructLength::NeedsExternal);
39
40 let length_getter = if needs_external {
41 ComputedOffsetId::PacketEnd.declare_fn(quote! { Ok(self.buf.get_size_in_bits()) })
42 } else {
43 quote! {}
44 };
45
46 let computed_getters = empty()
47 .chain(
48 curr_schema.computed_offsets.iter().map(|(decl, defn)| decl.declare_fn(defn.compute())),
49 )
50 .chain(
51 curr_schema.computed_values.iter().map(|(decl, defn)| decl.declare_fn(defn.compute())),
52 );
53
54 let field_getters = fields.iter().map(|field| {
55 match &field.desc {
56 ast::FieldDesc::Padding { .. }
57 | ast::FieldDesc::Flag { .. }
58 | ast::FieldDesc::Reserved { .. }
59 | ast::FieldDesc::FixedScalar { .. }
60 | ast::FieldDesc::FixedEnum { .. }
61 | ast::FieldDesc::ElementSize { .. }
62 | ast::FieldDesc::Count { .. }
63 | ast::FieldDesc::Size { .. } => {
64 // no-op, no getter generated for this type
65 quote! {}
66 }
67 ast::FieldDesc::Group { .. } => unreachable!(),
68 ast::FieldDesc::Checksum { .. } => {
69 unimplemented!("checksums not yet supported with this backend")
70 }
71 ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body => {
72 let name = if matches!(field.desc, ast::FieldDesc::Payload { .. }) { "_payload_"} else { "_body_"};
73 let payload_start_offset = ComputedOffsetId::FieldOffset(name).call_fn();
74 let payload_end_offset = ComputedOffsetId::FieldEndOffset(name).call_fn();
75 quote! {
76 fn try_get_payload(&self) -> Result<SizedBitSlice<'a>, ParseError> {
77 let payload_start_offset = #payload_start_offset;
78 let payload_end_offset = #payload_end_offset;
79 self.buf.offset(payload_start_offset)?.slice(payload_end_offset - payload_start_offset)
80 }
81
82 fn try_get_raw_payload(&self) -> Result<impl Iterator<Item = Result<u8, ParseError>> + '_, ParseError> {
83 let view = self.try_get_payload()?;
84 let count = (view.get_size_in_bits() + 7) / 8;
85 Ok((0..count).map(move |i| Ok(view.offset(i*8)?.slice(8.min(view.get_size_in_bits() - i*8))?.try_parse()?)))
86 }
87
88 pub fn get_raw_payload(&self) -> impl Iterator<Item = u8> + '_ {
89 self.try_get_raw_payload().unwrap().map(|x| x.unwrap())
90 }
91 }
92 }
93 ast::FieldDesc::Array { id, width, type_id, .. } => {
94 let (elem_type, return_type) = if let Some(width) = width {
95 let ident = get_integer_type(*width);
96 (ident.clone(), quote!{ #ident })
97 } else if let Some(type_id) = type_id {
98 if schema.enums.contains_key(type_id.as_str()) {
99 let ident = format_ident!("{}", type_id);
100 (ident.clone(), quote! { #ident })
101 } else {
102 let ident = format_ident!("{}View", type_id);
103 (ident.clone(), quote! { #ident<'a> })
104 }
105 } else {
106 unreachable!()
107 };
108
109 let try_getter_name = format_ident!("try_get_{id}_iter");
110 let getter_name = format_ident!("get_{id}_iter");
111
112 let start_offset = ComputedOffsetId::FieldOffset(id).call_fn();
113 let count = ComputedValueId::FieldCount(id).call_fn();
114
115 let element_size_known = curr_schema
116 .computed_values
117 .contains_key(&ComputedValueId::FieldElementSize(id));
118
119 let body = if element_size_known {
120 let element_size = ComputedValueId::FieldElementSize(id).call_fn();
121 let parsed_curr_view = if width.is_some() {
122 quote! { curr_view.try_parse() }
123 } else {
124 quote! { #elem_type::try_parse(curr_view.into()) }
125 };
126 quote! {
127 let view = self.buf.offset(#start_offset)?;
128 let count = #count;
129 let element_size = #element_size;
130 Ok((0..count).map(move |i| {
131 let curr_view = view.offset(element_size.checked_mul(i * 8).ok_or(ParseError::ArithmeticOverflow)?)?
132 .slice(element_size.checked_mul(8).ok_or(ParseError::ArithmeticOverflow)?)?;
133 #parsed_curr_view
134 }))
135 }
136 } else {
137 quote! {
138 let mut view = self.buf.offset(#start_offset)?;
139 let count = #count;
140 Ok((0..count).map(move |i| {
141 let parsed = #elem_type::try_parse(view.into())?;
142 view = view.offset(parsed.try_get_size()? * 8)?;
143 Ok(parsed)
144 }))
145 }
146 };
147
148 quote! {
149 fn #try_getter_name(&self) -> Result<impl Iterator<Item = Result<#return_type, ParseError>> + 'a, ParseError> {
150 #body
151 }
152
153 #[inline]
154 pub fn #getter_name(&self) -> impl Iterator<Item = #return_type> + 'a {
155 self.#try_getter_name().unwrap().map(|x| x.unwrap())
156 }
157 }
158 }
159 ast::FieldDesc::Scalar { id, width } => {
160 let try_getter_name = format_ident!("try_get_{id}");
161 let getter_name = format_ident!("get_{id}");
162 let offset = ComputedOffsetId::FieldOffset(id).call_fn();
163 let scalar_type = get_integer_type(*width);
164 quote! {
165 fn #try_getter_name(&self) -> Result<#scalar_type, ParseError> {
166 self.buf.offset(#offset)?.slice(#width)?.try_parse()
167 }
168
169 #[inline]
170 pub fn #getter_name(&self) -> #scalar_type {
171 self.#try_getter_name().unwrap()
172 }
173 }
174 }
175 ast::FieldDesc::Typedef { id, type_id } => {
176 let try_getter_name = format_ident!("try_get_{id}");
177 let getter_name = format_ident!("get_{id}");
178
179 let (type_ident, return_type) = if schema.enums.contains_key(type_id.as_str()) {
180 let ident = format_ident!("{type_id}");
181 (ident.clone(), quote! { #ident })
182 } else {
183 let ident = format_ident!("{}View", type_id);
184 (ident.clone(), quote! { #ident<'a> })
185 };
186 let offset = ComputedOffsetId::FieldOffset(id).call_fn();
187 let end_offset_known = curr_schema
188 .computed_offsets
189 .contains_key(&ComputedOffsetId::FieldEndOffset(id));
190 let sliced_view = if end_offset_known {
191 let end_offset = ComputedOffsetId::FieldEndOffset(id).call_fn();
192 quote! { self.buf.offset(#offset)?.slice(#end_offset.checked_sub(#offset).ok_or(ParseError::ArithmeticOverflow)?)? }
193 } else {
194 quote! { self.buf.offset(#offset)? }
195 };
196
197 quote! {
198 fn #try_getter_name(&self) -> Result<#return_type, ParseError> {
199 #type_ident::try_parse(#sliced_view.into())
200 }
201
202 #[inline]
203 pub fn #getter_name(&self) -> #return_type {
204 self.#try_getter_name().unwrap()
205 }
206 }
207 }
208 }
209 });
210
211 let backing_buffer = if needs_external {
212 quote! { SizedBitSlice<'a> }
213 } else {
214 quote! { BitSlice<'a> }
215 };
216
217 let parent_ident = match parent_id {
218 Some(parent) => format_ident!("{parent}View"),
219 None => match curr_schema.length {
220 PacketOrStructLength::Static(_) => format_ident!("BitSlice"),
221 PacketOrStructLength::Dynamic => format_ident!("BitSlice"),
222 PacketOrStructLength::NeedsExternal => format_ident!("SizedBitSlice"),
223 },
224 };
225
226 let buffer_extractor = if parent_id.is_some() {
227 quote! { parent.try_get_payload().unwrap().into() }
228 } else {
229 quote! { parent }
230 };
231
232 let field_validators = fields.iter().map(|field| match &field.desc {
233 ast::FieldDesc::Checksum { .. } => unimplemented!(),
234 ast::FieldDesc::Group { .. } => unreachable!(),
235 ast::FieldDesc::Padding { .. }
236 | ast::FieldDesc::Flag { .. }
237 | ast::FieldDesc::Size { .. }
238 | ast::FieldDesc::Count { .. }
239 | ast::FieldDesc::ElementSize { .. }
240 | ast::FieldDesc::Body
241 | ast::FieldDesc::FixedScalar { .. }
242 | ast::FieldDesc::FixedEnum { .. }
243 | ast::FieldDesc::Reserved { .. } => {
244 quote! {}
245 }
246 ast::FieldDesc::Payload { .. } => {
247 quote! {
248 self.try_get_payload()?;
249 self.try_get_raw_payload()?;
250 }
251 }
252 ast::FieldDesc::Array { id, .. } => {
253 let iter_ident = format_ident!("try_get_{id}_iter");
254 quote! {
255 for elem in self.#iter_ident()? {
256 elem?;
257 }
258 }
259 }
260 ast::FieldDesc::Scalar { id, .. } | ast::FieldDesc::Typedef { id, .. } => {
261 let getter_ident = format_ident!("try_get_{id}");
262 quote! { self.#getter_ident()?; }
263 }
264 });
265
266 let packet_end_offset = ComputedOffsetId::PacketEnd.call_fn();
267
268 let owned_id_ident = format_ident!("Owned{id_ident}");
269 let builder_ident = format_ident!("{id}Builder");
270
271 Ok(quote! {
272 #[derive(Clone, Copy, Debug)]
273 pub struct #id_ident<'a> {
274 buf: #backing_buffer,
275 }
276
277 impl<'a> #id_ident<'a> {
278 #length_getter
279
280 #(#computed_getters)*
281
282 #(#field_getters)*
283
284 #[inline]
285 fn try_get_header_start_offset(&self) -> Result<usize, ParseError> {
286 Ok(0)
287 }
288
289 #[inline]
290 fn try_get_size(&self) -> Result<usize, ParseError> {
291 let size = #packet_end_offset;
292 if size % 8 != 0 {
293 return Err(ParseError::MisalignedPayload);
294 }
295 Ok(size / 8)
296 }
297
298 fn validate(&self) -> Result<(), ParseError> {
299 #(#field_validators)*
300 Ok(())
301 }
302 }
303
304 impl<'a> Packet<'a> for #id_ident<'a> {
305 type Parent = #parent_ident<'a>;
306 type Owned = #owned_id_ident;
307 type Builder = #builder_ident;
308
309 fn try_parse_from_buffer(buf: impl Into<SizedBitSlice<'a>>) -> Result<Self, ParseError> {
310 let out = Self { buf: buf.into().into() };
311 out.validate()?;
312 Ok(out)
313 }
314
315 fn try_parse(parent: #parent_ident<'a>) -> Result<Self, ParseError> {
316 let out = Self { buf: #buffer_extractor };
317 out.validate()?;
318 Ok(out)
319 }
320
321 fn to_owned_packet(&self) -> #owned_id_ident {
322 #owned_id_ident {
323 buf: self.buf.backing.to_owned().into(),
324 start_bit_offset: self.buf.start_bit_offset,
325 end_bit_offset: self.buf.end_bit_offset,
326 }
327 }
328 }
329
330 #[derive(Debug)]
331 pub struct #owned_id_ident {
332 buf: Box<[u8]>,
333 start_bit_offset: usize,
334 end_bit_offset: usize,
335 }
336
337 impl OwnedPacket for #owned_id_ident {
338 fn try_parse(buf: Box<[u8]>) -> Result<Self, ParseError> {
339 #id_ident::try_parse_from_buffer(&buf[..])?;
340 let end_bit_offset = buf.len() * 8;
341 Ok(Self { buf, start_bit_offset: 0, end_bit_offset })
342 }
343 }
344
345 impl #owned_id_ident {
346 pub fn view<'a>(&'a self) -> #id_ident<'a> {
347 #id_ident {
348 buf: SizedBitSlice(BitSlice {
349 backing: &self.buf[..],
350 start_bit_offset: self.start_bit_offset,
351 end_bit_offset: self.end_bit_offset,
352 })
353 .into(),
354 }
355 }
356 }
357
358 impl<'a> From<&'a #owned_id_ident> for #id_ident<'a> {
359 fn from(x: &'a #owned_id_ident) -> Self {
360 x.view()
361 }
362 }
363 })
364 }
365