xref: /aosp_15_r20/system/authgraph/derive/src/lib.rs (revision 4185b0660fbe514985fdcf75410317caad8afad1)
// Copyright 2022, The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Derive macro for `AsCborValue`.
16*4185b066SAndroid Build Coastguard Worker use proc_macro2::TokenStream;
17*4185b066SAndroid Build Coastguard Worker use quote::{quote, quote_spanned};
18*4185b066SAndroid Build Coastguard Worker use syn::{
19*4185b066SAndroid Build Coastguard Worker     parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, GenericParam,
20*4185b066SAndroid Build Coastguard Worker     Generics, Index,
21*4185b066SAndroid Build Coastguard Worker };
22*4185b066SAndroid Build Coastguard Worker 
23*4185b066SAndroid Build Coastguard Worker /// Derive macro that implements the `AsCborValue` trait.  Using this macro requires
24*4185b066SAndroid Build Coastguard Worker /// that `AsCborValue`, `CborError` and `cbor_type_error` are locally `use`d.
25*4185b066SAndroid Build Coastguard Worker #[proc_macro_derive(AsCborValue)]
derive_as_cbor_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream26*4185b066SAndroid Build Coastguard Worker pub fn derive_as_cbor_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27*4185b066SAndroid Build Coastguard Worker     let input = parse_macro_input!(input as DeriveInput);
28*4185b066SAndroid Build Coastguard Worker     derive_as_cbor_value_internal(&input)
29*4185b066SAndroid Build Coastguard Worker }
30*4185b066SAndroid Build Coastguard Worker 
derive_as_cbor_value_internal(input: &DeriveInput) -> proc_macro::TokenStream31*4185b066SAndroid Build Coastguard Worker fn derive_as_cbor_value_internal(input: &DeriveInput) -> proc_macro::TokenStream {
32*4185b066SAndroid Build Coastguard Worker     let name = &input.ident;
33*4185b066SAndroid Build Coastguard Worker 
34*4185b066SAndroid Build Coastguard Worker     // Add a bound `T: AsCborValue` for every type parameter `T`.
35*4185b066SAndroid Build Coastguard Worker     let generics = add_trait_bounds(&input.generics);
36*4185b066SAndroid Build Coastguard Worker     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
37*4185b066SAndroid Build Coastguard Worker 
38*4185b066SAndroid Build Coastguard Worker     let from_val = from_val_struct(&input.data);
39*4185b066SAndroid Build Coastguard Worker     let to_val = to_val_struct(&input.data);
40*4185b066SAndroid Build Coastguard Worker 
41*4185b066SAndroid Build Coastguard Worker     let expanded = quote! {
42*4185b066SAndroid Build Coastguard Worker         // The generated impl
43*4185b066SAndroid Build Coastguard Worker         impl #impl_generics AsCborValue for #name #ty_generics #where_clause {
44*4185b066SAndroid Build Coastguard Worker             fn from_cbor_value(value: ciborium::value::Value) -> Result<Self, CborError> {
45*4185b066SAndroid Build Coastguard Worker                 #from_val
46*4185b066SAndroid Build Coastguard Worker             }
47*4185b066SAndroid Build Coastguard Worker             fn to_cbor_value(self) -> Result<ciborium::value::Value, CborError> {
48*4185b066SAndroid Build Coastguard Worker                 #to_val
49*4185b066SAndroid Build Coastguard Worker             }
50*4185b066SAndroid Build Coastguard Worker         }
51*4185b066SAndroid Build Coastguard Worker     };
52*4185b066SAndroid Build Coastguard Worker 
53*4185b066SAndroid Build Coastguard Worker     expanded.into()
54*4185b066SAndroid Build Coastguard Worker }
55*4185b066SAndroid Build Coastguard Worker 
56*4185b066SAndroid Build Coastguard Worker /// Add a bound `T: AsCborValue` for every type parameter `T`.
add_trait_bounds(generics: &Generics) -> Generics57*4185b066SAndroid Build Coastguard Worker fn add_trait_bounds(generics: &Generics) -> Generics {
58*4185b066SAndroid Build Coastguard Worker     let mut generics = generics.clone();
59*4185b066SAndroid Build Coastguard Worker     for param in &mut generics.params {
60*4185b066SAndroid Build Coastguard Worker         if let GenericParam::Type(ref mut type_param) = *param {
61*4185b066SAndroid Build Coastguard Worker             type_param.bounds.push(parse_quote!(AsCborValue));
62*4185b066SAndroid Build Coastguard Worker         }
63*4185b066SAndroid Build Coastguard Worker     }
64*4185b066SAndroid Build Coastguard Worker     generics
65*4185b066SAndroid Build Coastguard Worker }
66*4185b066SAndroid Build Coastguard Worker 
67*4185b066SAndroid Build Coastguard Worker /// Generate an expression to convert an instance of a compound type to `ciborium::value::Value`
to_val_struct(data: &Data) -> TokenStream68*4185b066SAndroid Build Coastguard Worker fn to_val_struct(data: &Data) -> TokenStream {
69*4185b066SAndroid Build Coastguard Worker     match *data {
70*4185b066SAndroid Build Coastguard Worker         Data::Struct(ref data) => {
71*4185b066SAndroid Build Coastguard Worker             match data.fields {
72*4185b066SAndroid Build Coastguard Worker                 Fields::Named(ref fields) => {
73*4185b066SAndroid Build Coastguard Worker                     // Expands to an expression like
74*4185b066SAndroid Build Coastguard Worker                     //
75*4185b066SAndroid Build Coastguard Worker                     //     {
76*4185b066SAndroid Build Coastguard Worker                     //         let mut v = Vec::new();
77*4185b066SAndroid Build Coastguard Worker                     //         v.try_reserve(3).map_err(|_e| CborError::AllocationFailed)?;
78*4185b066SAndroid Build Coastguard Worker                     //         v.push(AsCborValue::to_cbor_value(self.x)?);
79*4185b066SAndroid Build Coastguard Worker                     //         v.push(AsCborValue::to_cbor_value(self.y)?);
80*4185b066SAndroid Build Coastguard Worker                     //         v.push(AsCborValue::to_cbor_value(self.z)?);
81*4185b066SAndroid Build Coastguard Worker                     //         Ok(ciborium::value::Value::Array(v))
82*4185b066SAndroid Build Coastguard Worker                     //     }
83*4185b066SAndroid Build Coastguard Worker                     let nfields = fields.named.len();
84*4185b066SAndroid Build Coastguard Worker                     let recurse = fields.named.iter().map(|f| {
85*4185b066SAndroid Build Coastguard Worker                         let name = &f.ident;
86*4185b066SAndroid Build Coastguard Worker                         quote_spanned! {f.span()=>
87*4185b066SAndroid Build Coastguard Worker                             v.push(AsCborValue::to_cbor_value(self.#name)?)
88*4185b066SAndroid Build Coastguard Worker                         }
89*4185b066SAndroid Build Coastguard Worker                     });
90*4185b066SAndroid Build Coastguard Worker                     quote! {
91*4185b066SAndroid Build Coastguard Worker                         {
92*4185b066SAndroid Build Coastguard Worker                             let mut v = Vec::new();
93*4185b066SAndroid Build Coastguard Worker                             v.try_reserve(#nfields).map_err(|_e| CborError::AllocationFailed)?;
94*4185b066SAndroid Build Coastguard Worker                             #(#recurse; )*
95*4185b066SAndroid Build Coastguard Worker                             Ok(ciborium::value::Value::Array(v))
96*4185b066SAndroid Build Coastguard Worker                         }
97*4185b066SAndroid Build Coastguard Worker                     }
98*4185b066SAndroid Build Coastguard Worker                 }
99*4185b066SAndroid Build Coastguard Worker                 Fields::Unnamed(_) => unimplemented!(),
100*4185b066SAndroid Build Coastguard Worker                 Fields::Unit => unimplemented!(),
101*4185b066SAndroid Build Coastguard Worker             }
102*4185b066SAndroid Build Coastguard Worker         }
103*4185b066SAndroid Build Coastguard Worker         Data::Enum(_) => {
104*4185b066SAndroid Build Coastguard Worker             quote! {
105*4185b066SAndroid Build Coastguard Worker                 let v: ciborium::value::Integer = (self as i32).into();
106*4185b066SAndroid Build Coastguard Worker                 Ok(ciborium::value::Value::Integer(v))
107*4185b066SAndroid Build Coastguard Worker             }
108*4185b066SAndroid Build Coastguard Worker         }
109*4185b066SAndroid Build Coastguard Worker         Data::Union(_) => unimplemented!(),
110*4185b066SAndroid Build Coastguard Worker     }
111*4185b066SAndroid Build Coastguard Worker }
112*4185b066SAndroid Build Coastguard Worker 
113*4185b066SAndroid Build Coastguard Worker /// Generate an expression to convert a `ciborium::value::Value` into an instance of a compound
114*4185b066SAndroid Build Coastguard Worker /// type.
from_val_struct(data: &Data) -> TokenStream115*4185b066SAndroid Build Coastguard Worker fn from_val_struct(data: &Data) -> TokenStream {
116*4185b066SAndroid Build Coastguard Worker     match data {
117*4185b066SAndroid Build Coastguard Worker         Data::Struct(ref data) => {
118*4185b066SAndroid Build Coastguard Worker             match data.fields {
119*4185b066SAndroid Build Coastguard Worker                 Fields::Named(ref fields) => {
120*4185b066SAndroid Build Coastguard Worker                     // Expands to an expression like
121*4185b066SAndroid Build Coastguard Worker                     //
122*4185b066SAndroid Build Coastguard Worker                     //     let mut a = match value {
123*4185b066SAndroid Build Coastguard Worker                     //         ciborium::value::Value::Array(a) => a,
124*4185b066SAndroid Build Coastguard Worker                     //         _ => return cbor_type_error(&value, "arr"),
125*4185b066SAndroid Build Coastguard Worker                     //     };
126*4185b066SAndroid Build Coastguard Worker                     //     if a.len() != 3 {
127*4185b066SAndroid Build Coastguard Worker                     //         return Err(CborError::UnexpectedItem("arr", "arr len 3"));
128*4185b066SAndroid Build Coastguard Worker                     //     }
129*4185b066SAndroid Build Coastguard Worker                     //     // Fields specified in reverse order to reduce shifting.
130*4185b066SAndroid Build Coastguard Worker                     //     Ok(Self {
131*4185b066SAndroid Build Coastguard Worker                     //         z: <ZType>::from_cbor_value(a.remove(2))?,
132*4185b066SAndroid Build Coastguard Worker                     //         y: <YType>::from_cbor_value(a.remove(1))?,
133*4185b066SAndroid Build Coastguard Worker                     //         x: <XType>::from_cbor_value(a.remove(0))?,
134*4185b066SAndroid Build Coastguard Worker                     //     })
135*4185b066SAndroid Build Coastguard Worker                     //
136*4185b066SAndroid Build Coastguard Worker                     // but using fully qualified function call syntax.
137*4185b066SAndroid Build Coastguard Worker                     let nfields = fields.named.len();
138*4185b066SAndroid Build Coastguard Worker                     let recurse = fields.named.iter().enumerate().rev().map(|(i, f)| {
139*4185b066SAndroid Build Coastguard Worker                         let name = &f.ident;
140*4185b066SAndroid Build Coastguard Worker                         let index = Index::from(i);
141*4185b066SAndroid Build Coastguard Worker                         let typ = &f.ty;
142*4185b066SAndroid Build Coastguard Worker                         quote_spanned! {f.span()=>
143*4185b066SAndroid Build Coastguard Worker                                         #name: <#typ>::from_cbor_value(a.remove(#index))?
144*4185b066SAndroid Build Coastguard Worker                         }
145*4185b066SAndroid Build Coastguard Worker                     });
146*4185b066SAndroid Build Coastguard Worker                     quote! {
147*4185b066SAndroid Build Coastguard Worker                         let mut a = match value {
148*4185b066SAndroid Build Coastguard Worker                             ciborium::value::Value::Array(a) => a,
149*4185b066SAndroid Build Coastguard Worker                             _ => return cbor_type_error(&value, "arr"),
150*4185b066SAndroid Build Coastguard Worker                         };
151*4185b066SAndroid Build Coastguard Worker                         if a.len() != #nfields {
152*4185b066SAndroid Build Coastguard Worker                             return Err(CborError::UnexpectedItem(
153*4185b066SAndroid Build Coastguard Worker                                 "arr",
154*4185b066SAndroid Build Coastguard Worker                                 concat!("arr len ", stringify!(#nfields)),
155*4185b066SAndroid Build Coastguard Worker                             ));
156*4185b066SAndroid Build Coastguard Worker                         }
157*4185b066SAndroid Build Coastguard Worker                         // Fields specified in reverse order to reduce shifting.
158*4185b066SAndroid Build Coastguard Worker                         Ok(Self {
159*4185b066SAndroid Build Coastguard Worker                             #(#recurse, )*
160*4185b066SAndroid Build Coastguard Worker                         })
161*4185b066SAndroid Build Coastguard Worker                     }
162*4185b066SAndroid Build Coastguard Worker                 }
163*4185b066SAndroid Build Coastguard Worker                 Fields::Unnamed(_) => unimplemented!(),
164*4185b066SAndroid Build Coastguard Worker                 Fields::Unit => unimplemented!(),
165*4185b066SAndroid Build Coastguard Worker             }
166*4185b066SAndroid Build Coastguard Worker         }
167*4185b066SAndroid Build Coastguard Worker         Data::Enum(enum_data) => {
168*4185b066SAndroid Build Coastguard Worker             // This only copes with variants with no fields.
169*4185b066SAndroid Build Coastguard Worker             // Expands to an expression like:
170*4185b066SAndroid Build Coastguard Worker             //
171*4185b066SAndroid Build Coastguard Worker             //     use core::convert::TryInto;
172*4185b066SAndroid Build Coastguard Worker             //     let v: i32 = match value {
173*4185b066SAndroid Build Coastguard Worker             //         ciborium::value::Value::Integer(i) => i.try_into().map_err(|_| {
174*4185b066SAndroid Build Coastguard Worker             //             CborError::OutOfRangeIntegerValue
175*4185b066SAndroid Build Coastguard Worker             //         })?,
176*4185b066SAndroid Build Coastguard Worker             //         v => return cbor_type_error(&v, &"int"),
177*4185b066SAndroid Build Coastguard Worker             //     };
178*4185b066SAndroid Build Coastguard Worker             //     match v {
179*4185b066SAndroid Build Coastguard Worker             //         x if x == Self::Variant1 as i32 => Ok(Self::Variant1),
180*4185b066SAndroid Build Coastguard Worker             //         x if x == Self::Variant2 as i32 => Ok(Self::Variant2),
181*4185b066SAndroid Build Coastguard Worker             //         x if x == Self::Variant3 as i32 => Ok(Self::Variant3),
182*4185b066SAndroid Build Coastguard Worker             //         _ => Err( CborError::OutOfRangeIntegerValue),
183*4185b066SAndroid Build Coastguard Worker             //     }
184*4185b066SAndroid Build Coastguard Worker             let recurse = enum_data.variants.iter().map(|variant| {
185*4185b066SAndroid Build Coastguard Worker                 let vname = &variant.ident;
186*4185b066SAndroid Build Coastguard Worker                 quote_spanned! {variant.span()=>
187*4185b066SAndroid Build Coastguard Worker                                 x if x == Self::#vname as i32 => Ok(Self::#vname),
188*4185b066SAndroid Build Coastguard Worker                 }
189*4185b066SAndroid Build Coastguard Worker             });
190*4185b066SAndroid Build Coastguard Worker 
191*4185b066SAndroid Build Coastguard Worker             quote! {
192*4185b066SAndroid Build Coastguard Worker                 use core::convert::TryInto;
193*4185b066SAndroid Build Coastguard Worker                 // First get the int value as an `i32`.
194*4185b066SAndroid Build Coastguard Worker                 let v: i32 = match value {
195*4185b066SAndroid Build Coastguard Worker                     ciborium::value::Value::Integer(i) => i.try_into().map_err(|_| {
196*4185b066SAndroid Build Coastguard Worker                         CborError::OutOfRangeIntegerValue
197*4185b066SAndroid Build Coastguard Worker                     })?,
198*4185b066SAndroid Build Coastguard Worker                     v => return cbor_type_error(&v, &"int"),
199*4185b066SAndroid Build Coastguard Worker                 };
200*4185b066SAndroid Build Coastguard Worker                 // Now match against enum possibilities.
201*4185b066SAndroid Build Coastguard Worker                 match v {
202*4185b066SAndroid Build Coastguard Worker                     #(#recurse)*
203*4185b066SAndroid Build Coastguard Worker                     _ => Err(
204*4185b066SAndroid Build Coastguard Worker                         CborError::OutOfRangeIntegerValue
205*4185b066SAndroid Build Coastguard Worker                     ),
206*4185b066SAndroid Build Coastguard Worker                 }
207*4185b066SAndroid Build Coastguard Worker             }
208*4185b066SAndroid Build Coastguard Worker         }
209*4185b066SAndroid Build Coastguard Worker         Data::Union(_) => unimplemented!(),
210*4185b066SAndroid Build Coastguard Worker     }
211*4185b066SAndroid Build Coastguard Worker }
212