// xref: /aosp_15_r20/development/tools/external_crates/name_and_version_proc_macros/src/lib.rs (revision 90c8c64db3049935a07c6143d7fd006e26f8ecca)
// Copyright (C) 2023 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
//! Derive the NameAndVersionMap trait for a struct with a suitable map field.
16 
17 use syn::{parse_macro_input, DeriveInput, Error};
18 
19 /// Derive the NameAndVersionMap trait for a struct with a suitable map field.
20 #[proc_macro_derive(NameAndVersionMap)]
derive_name_and_version_map(input: proc_macro::TokenStream) -> proc_macro::TokenStream21 pub fn derive_name_and_version_map(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
22     let input = parse_macro_input!(input as DeriveInput);
23     name_and_version_map::expand(input).unwrap_or_else(Error::into_compile_error).into()
24 }
25 
26 mod name_and_version_map {
27     use proc_macro2::TokenStream;
28     use quote::quote;
29     use syn::{
30         Data, DataStruct, DeriveInput, Error, Field, GenericArgument, PathArguments, Result, Type,
31     };
32 
expand(input: DeriveInput) -> Result<TokenStream>33     pub(crate) fn expand(input: DeriveInput) -> Result<TokenStream> {
34         let name = &input.ident;
35         let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
36 
37         let mapfield = get_map_field(get_struct(&input)?)?;
38         let mapfield_name = mapfield
39             .ident
40             .as_ref()
41             .ok_or(Error::new_spanned(mapfield, "mapfield ident is none"))?;
42         let (_, value_type) = get_map_type(&mapfield.ty)?;
43 
44         let expanded = quote! {
45             #[automatically_derived]
46             impl #impl_generics NameAndVersionMap for #name #ty_generics #where_clause {
47                 type Value = #value_type;
48 
49                 fn map_field(&self) -> &BTreeMap<NameAndVersion, Self::Value> {
50                     self.#mapfield_name.map_field()
51                 }
52 
53                 fn map_field_mut(&mut self) -> &mut BTreeMap<NameAndVersion, Self::Value> {
54                     self.#mapfield_name.map_field_mut()
55                 }
56 
57                 fn insert_or_error(&mut self, key: NameAndVersion, val: Self::Value) -> Result<(), name_and_version::Error> {
58                     self.#mapfield_name.insert_or_error(key, val)
59                 }
60 
61                 fn num_crates(&self) -> usize {
62                     self.#mapfield_name.num_crates()
63                 }
64 
65                 fn get_versions<'a, 'b>(&'a self, name: &'b str) -> Box<dyn Iterator<Item = (&'a NameAndVersion, &'a Self::Value)> + 'a> {
66                     self.#mapfield_name.get_versions(name)
67                 }
68 
69                 fn get_versions_mut<'a, 'b>(&'a mut self, name: &'b str) -> Box<dyn Iterator<Item = (&'a NameAndVersion, &'a mut Self::Value)> + 'a> {
70                     self.#mapfield_name.get_versions_mut(name)
71                 }
72 
73                 fn filter_versions<'a: 'b, 'b, F: Fn(&mut dyn Iterator<Item = (&'b NameAndVersion, &'b Self::Value)>,
74                 ) -> HashSet<Version> + 'a>(
75                     &'a self,
76                     f: F,
77                 ) -> Box<dyn Iterator<Item =(&'a NameAndVersion, &'a Self::Value)> + 'a> {
78                     self.#mapfield_name.filter_versions(f)
79                 }
80             }
81         };
82 
83         Ok(expanded)
84     }
85 
get_struct(input: &DeriveInput) -> Result<&DataStruct>86     fn get_struct(input: &DeriveInput) -> Result<&DataStruct> {
87         match &input.data {
88             Data::Struct(strukt) => Ok(strukt),
89             _ => Err(Error::new_spanned(input, "Not a struct")),
90         }
91     }
92 
get_map_field(strukt: &DataStruct) -> Result<&Field>93     fn get_map_field(strukt: &DataStruct) -> Result<&Field> {
94         for field in &strukt.fields {
95             if let Ok((syn::Type::Path(path), _value_type)) = get_map_type(&field.ty) {
96                 if path.path.segments.len() == 1 && path.path.segments[0].ident == "NameAndVersion"
97                 {
98                     return Ok(field);
99                 }
100             }
101         }
102         Err(Error::new_spanned(strukt.struct_token, "No field of type NameAndVersionMap"))
103     }
104 
get_map_type(typ: &Type) -> Result<(&Type, &Type)>105     fn get_map_type(typ: &Type) -> Result<(&Type, &Type)> {
106         if let syn::Type::Path(path) = &typ {
107             if path.path.segments.len() == 1 && path.path.segments[0].ident == "BTreeMap" {
108                 if let PathArguments::AngleBracketed(args) = &path.path.segments[0].arguments {
109                     if args.args.len() == 2 {
110                         return Ok((get_type(&args.args[0])?, get_type(&args.args[1])?));
111                     }
112                 }
113             }
114         }
115         Err(Error::new_spanned(typ, "Must be BTreeMap"))
116     }
117 
get_type(arg: &GenericArgument) -> Result<&Type>118     fn get_type(arg: &GenericArgument) -> Result<&Type> {
119         if let GenericArgument::Type(typ) = arg {
120             return Ok(typ);
121         }
122         Err(Error::new_spanned(arg, "Could not extract argument type"))
123     }
124 }
125