1 // Copyright (C) 2023 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 //! Derive the NameAndVersionMap trait for a struct with a suitable map field.
16
17 use syn::{parse_macro_input, DeriveInput, Error};
18
19 /// Derive the NameAndVersionMap trait for a struct with a suitable map field.
20 #[proc_macro_derive(NameAndVersionMap)]
derive_name_and_version_map(input: proc_macro::TokenStream) -> proc_macro::TokenStream21 pub fn derive_name_and_version_map(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
22 let input = parse_macro_input!(input as DeriveInput);
23 name_and_version_map::expand(input).unwrap_or_else(Error::into_compile_error).into()
24 }
25
26 mod name_and_version_map {
27 use proc_macro2::TokenStream;
28 use quote::quote;
29 use syn::{
30 Data, DataStruct, DeriveInput, Error, Field, GenericArgument, PathArguments, Result, Type,
31 };
32
expand(input: DeriveInput) -> Result<TokenStream>33 pub(crate) fn expand(input: DeriveInput) -> Result<TokenStream> {
34 let name = &input.ident;
35 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
36
37 let mapfield = get_map_field(get_struct(&input)?)?;
38 let mapfield_name = mapfield
39 .ident
40 .as_ref()
41 .ok_or(Error::new_spanned(mapfield, "mapfield ident is none"))?;
42 let (_, value_type) = get_map_type(&mapfield.ty)?;
43
44 let expanded = quote! {
45 #[automatically_derived]
46 impl #impl_generics NameAndVersionMap for #name #ty_generics #where_clause {
47 type Value = #value_type;
48
49 fn map_field(&self) -> &BTreeMap<NameAndVersion, Self::Value> {
50 self.#mapfield_name.map_field()
51 }
52
53 fn map_field_mut(&mut self) -> &mut BTreeMap<NameAndVersion, Self::Value> {
54 self.#mapfield_name.map_field_mut()
55 }
56
57 fn insert_or_error(&mut self, key: NameAndVersion, val: Self::Value) -> Result<(), name_and_version::Error> {
58 self.#mapfield_name.insert_or_error(key, val)
59 }
60
61 fn num_crates(&self) -> usize {
62 self.#mapfield_name.num_crates()
63 }
64
65 fn get_versions<'a, 'b>(&'a self, name: &'b str) -> Box<dyn Iterator<Item = (&'a NameAndVersion, &'a Self::Value)> + 'a> {
66 self.#mapfield_name.get_versions(name)
67 }
68
69 fn get_versions_mut<'a, 'b>(&'a mut self, name: &'b str) -> Box<dyn Iterator<Item = (&'a NameAndVersion, &'a mut Self::Value)> + 'a> {
70 self.#mapfield_name.get_versions_mut(name)
71 }
72
73 fn filter_versions<'a: 'b, 'b, F: Fn(&mut dyn Iterator<Item = (&'b NameAndVersion, &'b Self::Value)>,
74 ) -> HashSet<Version> + 'a>(
75 &'a self,
76 f: F,
77 ) -> Box<dyn Iterator<Item =(&'a NameAndVersion, &'a Self::Value)> + 'a> {
78 self.#mapfield_name.filter_versions(f)
79 }
80 }
81 };
82
83 Ok(expanded)
84 }
85
get_struct(input: &DeriveInput) -> Result<&DataStruct>86 fn get_struct(input: &DeriveInput) -> Result<&DataStruct> {
87 match &input.data {
88 Data::Struct(strukt) => Ok(strukt),
89 _ => Err(Error::new_spanned(input, "Not a struct")),
90 }
91 }
92
get_map_field(strukt: &DataStruct) -> Result<&Field>93 fn get_map_field(strukt: &DataStruct) -> Result<&Field> {
94 for field in &strukt.fields {
95 if let Ok((syn::Type::Path(path), _value_type)) = get_map_type(&field.ty) {
96 if path.path.segments.len() == 1 && path.path.segments[0].ident == "NameAndVersion"
97 {
98 return Ok(field);
99 }
100 }
101 }
102 Err(Error::new_spanned(strukt.struct_token, "No field of type NameAndVersionMap"))
103 }
104
get_map_type(typ: &Type) -> Result<(&Type, &Type)>105 fn get_map_type(typ: &Type) -> Result<(&Type, &Type)> {
106 if let syn::Type::Path(path) = &typ {
107 if path.path.segments.len() == 1 && path.path.segments[0].ident == "BTreeMap" {
108 if let PathArguments::AngleBracketed(args) = &path.path.segments[0].arguments {
109 if args.args.len() == 2 {
110 return Ok((get_type(&args.args[0])?, get_type(&args.args[1])?));
111 }
112 }
113 }
114 }
115 Err(Error::new_spanned(typ, "Must be BTreeMap"))
116 }
117
get_type(arg: &GenericArgument) -> Result<&Type>118 fn get_type(arg: &GenericArgument) -> Result<&Type> {
119 if let GenericArgument::Type(typ) = arg {
120 return Ok(typ);
121 }
122 Err(Error::new_spanned(arg, "Could not extract argument type"))
123 }
124 }
125