1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use proc_macro2::{Ident, TokenStream};
16 use quote::{format_ident, quote};
17 
18 use crate::backends::intermediate::{
19     ComputedOffset, ComputedOffsetId, ComputedValue, ComputedValueId,
20 };
21 
22 /// This trait is implemented on computed quantities (offsets and values) that can be retrieved via a function call
23 pub trait Declarable {
get_name(&self) -> String24     fn get_name(&self) -> String;
25 
get_ident(&self) -> Ident26     fn get_ident(&self) -> Ident {
27         format_ident!("try_get_{}", self.get_name())
28     }
29 
call_fn(&self) -> TokenStream30     fn call_fn(&self) -> TokenStream {
31         let fn_name = self.get_ident();
32         quote! { self.#fn_name()? }
33     }
34 
declare_fn(&self, body: TokenStream) -> TokenStream35     fn declare_fn(&self, body: TokenStream) -> TokenStream {
36         let fn_name = self.get_ident();
37         quote! {
38             #[inline]
39             fn #fn_name(&self) -> Result<usize, ParseError> {
40                 #body
41             }
42         }
43     }
44 }
45 
46 impl Declarable for ComputedValueId<'_> {
get_name(&self) -> String47     fn get_name(&self) -> String {
48         match self {
49             ComputedValueId::FieldSize(field) => format!("{field}_size"),
50             ComputedValueId::FieldElementSize(field) => format!("{field}_element_size"),
51             ComputedValueId::FieldCount(field) => format!("{field}_count"),
52             ComputedValueId::Custom(i) => format!("custom_value_{i}"),
53         }
54     }
55 }
56 
57 impl Declarable for ComputedOffsetId<'_> {
get_name(&self) -> String58     fn get_name(&self) -> String {
59         match self {
60             ComputedOffsetId::HeaderStart => "header_start_offset".to_string(),
61             ComputedOffsetId::PacketEnd => "packet_end_offset".to_string(),
62             ComputedOffsetId::FieldOffset(field) => format!("{field}_offset"),
63             ComputedOffsetId::FieldEndOffset(field) => format!("{field}_end_offset"),
64             ComputedOffsetId::Custom(i) => format!("custom_offset_{i}"),
65             ComputedOffsetId::TrailerStart => "trailer_start_offset".to_string(),
66         }
67     }
68 }
69 
70 /// This trait is implemented on computed expressions that are computed on-demand (i.e. not via a function call)
71 pub trait Computable {
compute(&self) -> TokenStream72     fn compute(&self) -> TokenStream;
73 }
74 
75 impl Computable for ComputedValue<'_> {
compute(&self) -> TokenStream76     fn compute(&self) -> TokenStream {
77         match self {
78             ComputedValue::Constant(k) => quote! { Ok(#k) },
79             ComputedValue::CountStructsUpToSize { base_id, size, struct_type } => {
80                 let base_offset = base_id.call_fn();
81                 let size = size.call_fn();
82                 let struct_type = format_ident!("{struct_type}View");
83                 quote! {
84                     let mut cnt = 0;
85                     let mut view = self.buf.offset(#base_offset)?;
86                     let mut remaining_size = #size;
87                     while remaining_size > 0 {
88                         let next_struct_size = #struct_type::try_parse(view)?.try_get_size()?;
89                         if next_struct_size > remaining_size {
90                             return Err(ParseError::OutOfBoundsAccess);
91                         }
92                         remaining_size -= next_struct_size;
93                         view = view.offset(next_struct_size * 8)?;
94                         cnt += 1;
95                     }
96                     Ok(cnt)
97                 }
98             }
99             ComputedValue::SizeOfNStructs { base_id, n, struct_type } => {
100                 let base_offset = base_id.call_fn();
101                 let n = n.call_fn();
102                 let struct_type = format_ident!("{struct_type}View");
103                 quote! {
104                     let mut view = self.buf.offset(#base_offset)?;
105                     let mut size = 0;
106                     for _ in 0..#n {
107                         let next_struct_size = #struct_type::try_parse(view)?.try_get_size()?;
108                         size += next_struct_size;
109                         view = view.offset(next_struct_size * 8)?;
110                     }
111                     Ok(size)
112                 }
113             }
114             ComputedValue::Product(x, y) => {
115                 let x = x.call_fn();
116                 let y = y.call_fn();
117                 quote! { #x.checked_mul(#y).ok_or(ParseError::ArithmeticOverflow) }
118             }
119             ComputedValue::Divide(x, y) => {
120                 let x = x.call_fn();
121                 let y = y.call_fn();
122                 quote! {
123                     if #y == 0 || #x % #y != 0 {
124                         return Err(ParseError::DivisionFailure)
125                     }
126                     Ok(#x / #y)
127                 }
128             }
129             ComputedValue::Difference(x, y) => {
130                 let x = x.call_fn();
131                 let y = y.call_fn();
132                 quote! {
133                    let bit_difference = #x.checked_sub(#y).ok_or(ParseError::ArithmeticOverflow)?;
134                    if bit_difference % 8 != 0 {
135                        return Err(ParseError::DivisionFailure);
136                    }
137                    Ok(bit_difference / 8)
138                 }
139             }
140             ComputedValue::ValueAt { offset, width } => {
141                 let offset = offset.call_fn();
142                 quote! { self.buf.offset(#offset)?.slice(#width)?.try_parse() }
143             }
144         }
145     }
146 }
147 
148 impl Computable for ComputedOffset<'_> {
compute(&self) -> TokenStream149     fn compute(&self) -> TokenStream {
150         match self {
151             ComputedOffset::ConstantPlusOffsetInBits(base_id, offset) => {
152                 let base_id = base_id.call_fn();
153                 quote! { #base_id.checked_add_signed(#offset as isize).ok_or(ParseError::ArithmeticOverflow) }
154             }
155             ComputedOffset::SumWithOctets(x, y) => {
156                 let x = x.call_fn();
157                 let y = y.call_fn();
158                 quote! {
159                     #x.checked_add(#y.checked_mul(8).ok_or(ParseError::ArithmeticOverflow)?)
160                       .ok_or(ParseError::ArithmeticOverflow)
161                 }
162             }
163             ComputedOffset::Alias(alias) => {
164                 let alias = alias.call_fn();
165                 quote! { Ok(#alias) }
166             }
167         }
168     }
169 }
170