xref: /aosp_15_r20/system/authgraph/derive/src/lib.rs (revision 4185b0660fbe514985fdcf75410317caad8afad1)
1 // Copyright 2022, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! Derive macro for `AsCborValue`.
16 use proc_macro2::TokenStream;
17 use quote::{quote, quote_spanned};
18 use syn::{
19     parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, GenericParam,
20     Generics, Index,
21 };
22 
23 /// Derive macro that implements the `AsCborValue` trait.  Using this macro requires
24 /// that `AsCborValue`, `CborError` and `cbor_type_error` are locally `use`d.
25 #[proc_macro_derive(AsCborValue)]
derive_as_cbor_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream26 pub fn derive_as_cbor_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27     let input = parse_macro_input!(input as DeriveInput);
28     derive_as_cbor_value_internal(&input)
29 }
30 
derive_as_cbor_value_internal(input: &DeriveInput) -> proc_macro::TokenStream31 fn derive_as_cbor_value_internal(input: &DeriveInput) -> proc_macro::TokenStream {
32     let name = &input.ident;
33 
34     // Add a bound `T: AsCborValue` for every type parameter `T`.
35     let generics = add_trait_bounds(&input.generics);
36     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
37 
38     let from_val = from_val_struct(&input.data);
39     let to_val = to_val_struct(&input.data);
40 
41     let expanded = quote! {
42         // The generated impl
43         impl #impl_generics AsCborValue for #name #ty_generics #where_clause {
44             fn from_cbor_value(value: ciborium::value::Value) -> Result<Self, CborError> {
45                 #from_val
46             }
47             fn to_cbor_value(self) -> Result<ciborium::value::Value, CborError> {
48                 #to_val
49             }
50         }
51     };
52 
53     expanded.into()
54 }
55 
56 /// Add a bound `T: AsCborValue` for every type parameter `T`.
add_trait_bounds(generics: &Generics) -> Generics57 fn add_trait_bounds(generics: &Generics) -> Generics {
58     let mut generics = generics.clone();
59     for param in &mut generics.params {
60         if let GenericParam::Type(ref mut type_param) = *param {
61             type_param.bounds.push(parse_quote!(AsCborValue));
62         }
63     }
64     generics
65 }
66 
67 /// Generate an expression to convert an instance of a compound type to `ciborium::value::Value`
to_val_struct(data: &Data) -> TokenStream68 fn to_val_struct(data: &Data) -> TokenStream {
69     match *data {
70         Data::Struct(ref data) => {
71             match data.fields {
72                 Fields::Named(ref fields) => {
73                     // Expands to an expression like
74                     //
75                     //     {
76                     //         let mut v = Vec::new();
77                     //         v.try_reserve(3).map_err(|_e| CborError::AllocationFailed)?;
78                     //         v.push(AsCborValue::to_cbor_value(self.x)?);
79                     //         v.push(AsCborValue::to_cbor_value(self.y)?);
80                     //         v.push(AsCborValue::to_cbor_value(self.z)?);
81                     //         Ok(ciborium::value::Value::Array(v))
82                     //     }
83                     let nfields = fields.named.len();
84                     let recurse = fields.named.iter().map(|f| {
85                         let name = &f.ident;
86                         quote_spanned! {f.span()=>
87                             v.push(AsCborValue::to_cbor_value(self.#name)?)
88                         }
89                     });
90                     quote! {
91                         {
92                             let mut v = Vec::new();
93                             v.try_reserve(#nfields).map_err(|_e| CborError::AllocationFailed)?;
94                             #(#recurse; )*
95                             Ok(ciborium::value::Value::Array(v))
96                         }
97                     }
98                 }
99                 Fields::Unnamed(_) => unimplemented!(),
100                 Fields::Unit => unimplemented!(),
101             }
102         }
103         Data::Enum(_) => {
104             quote! {
105                 let v: ciborium::value::Integer = (self as i32).into();
106                 Ok(ciborium::value::Value::Integer(v))
107             }
108         }
109         Data::Union(_) => unimplemented!(),
110     }
111 }
112 
113 /// Generate an expression to convert a `ciborium::value::Value` into an instance of a compound
114 /// type.
from_val_struct(data: &Data) -> TokenStream115 fn from_val_struct(data: &Data) -> TokenStream {
116     match data {
117         Data::Struct(ref data) => {
118             match data.fields {
119                 Fields::Named(ref fields) => {
120                     // Expands to an expression like
121                     //
122                     //     let mut a = match value {
123                     //         ciborium::value::Value::Array(a) => a,
124                     //         _ => return cbor_type_error(&value, "arr"),
125                     //     };
126                     //     if a.len() != 3 {
127                     //         return Err(CborError::UnexpectedItem("arr", "arr len 3"));
128                     //     }
129                     //     // Fields specified in reverse order to reduce shifting.
130                     //     Ok(Self {
131                     //         z: <ZType>::from_cbor_value(a.remove(2))?,
132                     //         y: <YType>::from_cbor_value(a.remove(1))?,
133                     //         x: <XType>::from_cbor_value(a.remove(0))?,
134                     //     })
135                     //
136                     // but using fully qualified function call syntax.
137                     let nfields = fields.named.len();
138                     let recurse = fields.named.iter().enumerate().rev().map(|(i, f)| {
139                         let name = &f.ident;
140                         let index = Index::from(i);
141                         let typ = &f.ty;
142                         quote_spanned! {f.span()=>
143                                         #name: <#typ>::from_cbor_value(a.remove(#index))?
144                         }
145                     });
146                     quote! {
147                         let mut a = match value {
148                             ciborium::value::Value::Array(a) => a,
149                             _ => return cbor_type_error(&value, "arr"),
150                         };
151                         if a.len() != #nfields {
152                             return Err(CborError::UnexpectedItem(
153                                 "arr",
154                                 concat!("arr len ", stringify!(#nfields)),
155                             ));
156                         }
157                         // Fields specified in reverse order to reduce shifting.
158                         Ok(Self {
159                             #(#recurse, )*
160                         })
161                     }
162                 }
163                 Fields::Unnamed(_) => unimplemented!(),
164                 Fields::Unit => unimplemented!(),
165             }
166         }
167         Data::Enum(enum_data) => {
168             // This only copes with variants with no fields.
169             // Expands to an expression like:
170             //
171             //     use core::convert::TryInto;
172             //     let v: i32 = match value {
173             //         ciborium::value::Value::Integer(i) => i.try_into().map_err(|_| {
174             //             CborError::OutOfRangeIntegerValue
175             //         })?,
176             //         v => return cbor_type_error(&v, &"int"),
177             //     };
178             //     match v {
179             //         x if x == Self::Variant1 as i32 => Ok(Self::Variant1),
180             //         x if x == Self::Variant2 as i32 => Ok(Self::Variant2),
181             //         x if x == Self::Variant3 as i32 => Ok(Self::Variant3),
182             //         _ => Err( CborError::OutOfRangeIntegerValue),
183             //     }
184             let recurse = enum_data.variants.iter().map(|variant| {
185                 let vname = &variant.ident;
186                 quote_spanned! {variant.span()=>
187                                 x if x == Self::#vname as i32 => Ok(Self::#vname),
188                 }
189             });
190 
191             quote! {
192                 use core::convert::TryInto;
193                 // First get the int value as an `i32`.
194                 let v: i32 = match value {
195                     ciborium::value::Value::Integer(i) => i.try_into().map_err(|_| {
196                         CborError::OutOfRangeIntegerValue
197                     })?,
198                     v => return cbor_type_error(&v, &"int"),
199                 };
200                 // Now match against enum possibilities.
201                 match v {
202                     #(#recurse)*
203                     _ => Err(
204                         CborError::OutOfRangeIntegerValue
205                     ),
206                 }
207             }
208         }
209         Data::Union(_) => unimplemented!(),
210     }
211 }
212