1 // Copyright 2023 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 use proc_macro2::{Ident, TokenStream}; 16 use quote::{format_ident, quote}; 17 18 use crate::backends::intermediate::{ 19 ComputedOffset, ComputedOffsetId, ComputedValue, ComputedValueId, 20 }; 21 22 /// This trait is implemented on computed quantities (offsets and values) that can be retrieved via a function call 23 pub trait Declarable { get_name(&self) -> String24 fn get_name(&self) -> String; 25 get_ident(&self) -> Ident26 fn get_ident(&self) -> Ident { 27 format_ident!("try_get_{}", self.get_name()) 28 } 29 call_fn(&self) -> TokenStream30 fn call_fn(&self) -> TokenStream { 31 let fn_name = self.get_ident(); 32 quote! { self.#fn_name()? } 33 } 34 declare_fn(&self, body: TokenStream) -> TokenStream35 fn declare_fn(&self, body: TokenStream) -> TokenStream { 36 let fn_name = self.get_ident(); 37 quote! { 38 #[inline] 39 fn #fn_name(&self) -> Result<usize, ParseError> { 40 #body 41 } 42 } 43 } 44 } 45 46 impl Declarable for ComputedValueId<'_> { get_name(&self) -> String47 fn get_name(&self) -> String { 48 match self { 49 ComputedValueId::FieldSize(field) => format!("{field}_size"), 50 ComputedValueId::FieldElementSize(field) => format!("{field}_element_size"), 51 ComputedValueId::FieldCount(field) => format!("{field}_count"), 52 ComputedValueId::Custom(i) => format!("custom_value_{i}"), 53 } 54 } 55 } 56 57 impl Declarable for ComputedOffsetId<'_> { get_name(&self) -> String58 fn get_name(&self) -> String { 59 match self { 60 ComputedOffsetId::HeaderStart => "header_start_offset".to_string(), 61 ComputedOffsetId::PacketEnd => "packet_end_offset".to_string(), 62 ComputedOffsetId::FieldOffset(field) => format!("{field}_offset"), 63 ComputedOffsetId::FieldEndOffset(field) => format!("{field}_end_offset"), 64 ComputedOffsetId::Custom(i) => format!("custom_offset_{i}"), 65 ComputedOffsetId::TrailerStart => "trailer_start_offset".to_string(), 66 } 67 } 68 } 69 70 /// This trait is implemented on computed expressions that are computed on-demand (i.e. not via a function call) 71 pub trait Computable { compute(&self) -> TokenStream72 fn compute(&self) -> TokenStream; 73 } 74 75 impl Computable for ComputedValue<'_> { compute(&self) -> TokenStream76 fn compute(&self) -> TokenStream { 77 match self { 78 ComputedValue::Constant(k) => quote! { Ok(#k) }, 79 ComputedValue::CountStructsUpToSize { base_id, size, struct_type } => { 80 let base_offset = base_id.call_fn(); 81 let size = size.call_fn(); 82 let struct_type = format_ident!("{struct_type}View"); 83 quote! { 84 let mut cnt = 0; 85 let mut view = self.buf.offset(#base_offset)?; 86 let mut remaining_size = #size; 87 while remaining_size > 0 { 88 let next_struct_size = #struct_type::try_parse(view)?.try_get_size()?; 89 if next_struct_size > remaining_size { 90 return Err(ParseError::OutOfBoundsAccess); 91 } 92 remaining_size -= next_struct_size; 93 view = view.offset(next_struct_size * 8)?; 94 cnt += 1; 95 } 96 Ok(cnt) 97 } 98 } 99 ComputedValue::SizeOfNStructs { base_id, n, struct_type } => { 100 let base_offset = base_id.call_fn(); 101 let n = n.call_fn(); 102 let struct_type = format_ident!("{struct_type}View"); 103 quote! { 104 let mut view = self.buf.offset(#base_offset)?; 105 let mut size = 0; 106 for _ in 0..#n { 107 let next_struct_size = #struct_type::try_parse(view)?.try_get_size()?; 108 size += next_struct_size; 109 view = view.offset(next_struct_size * 8)?; 110 } 111 Ok(size) 112 } 113 } 114 ComputedValue::Product(x, y) => { 115 let x = x.call_fn(); 116 let y = y.call_fn(); 117 quote! { #x.checked_mul(#y).ok_or(ParseError::ArithmeticOverflow) } 118 } 119 ComputedValue::Divide(x, y) => { 120 let x = x.call_fn(); 121 let y = y.call_fn(); 122 quote! { 123 if #y == 0 || #x % #y != 0 { 124 return Err(ParseError::DivisionFailure) 125 } 126 Ok(#x / #y) 127 } 128 } 129 ComputedValue::Difference(x, y) => { 130 let x = x.call_fn(); 131 let y = y.call_fn(); 132 quote! { 133 let bit_difference = #x.checked_sub(#y).ok_or(ParseError::ArithmeticOverflow)?; 134 if bit_difference % 8 != 0 { 135 return Err(ParseError::DivisionFailure); 136 } 137 Ok(bit_difference / 8) 138 } 139 } 140 ComputedValue::ValueAt { offset, width } => { 141 let offset = offset.call_fn(); 142 quote! { self.buf.offset(#offset)?.slice(#width)?.try_parse() } 143 } 144 } 145 } 146 } 147 148 impl Computable for ComputedOffset<'_> { compute(&self) -> TokenStream149 fn compute(&self) -> TokenStream { 150 match self { 151 ComputedOffset::ConstantPlusOffsetInBits(base_id, offset) => { 152 let base_id = base_id.call_fn(); 153 quote! { #base_id.checked_add_signed(#offset as isize).ok_or(ParseError::ArithmeticOverflow) } 154 } 155 ComputedOffset::SumWithOctets(x, y) => { 156 let x = x.call_fn(); 157 let y = y.call_fn(); 158 quote! { 159 #x.checked_add(#y.checked_mul(8).ok_or(ParseError::ArithmeticOverflow)?) 160 .ok_or(ParseError::ArithmeticOverflow) 161 } 162 } 163 ComputedOffset::Alias(alias) => { 164 let alias = alias.call_fn(); 165 quote! { Ok(#alias) } 166 } 167 } 168 } 169 } 170