1 //! Support for deriving the `ValueOrd` trait on enums and structs.
2 //!
3 //! This trait is used in conjunction with ASN.1 `SET OF` types to determine
4 //! the lexicographical order of their DER encodings.
5 
// TODO(tarcieri): fuller enum support (only single-field tuple variants are handled)
7 
8 use crate::{FieldAttrs, TypeAttrs};
9 use proc_macro2::TokenStream;
10 use quote::quote;
11 use syn::{DeriveInput, Field, Ident, Lifetime, Variant};
12 
13 /// Derive the `Enumerated` trait for an enum.
14 pub(crate) struct DeriveValueOrd {
15     /// Name of the enum.
16     ident: Ident,
17 
18     /// Lifetime of the struct.
19     lifetime: Option<Lifetime>,
20 
21     /// Fields of structs or enum variants.
22     fields: Vec<ValueField>,
23 
24     /// Type of input provided (`enum` or `struct`).
25     input_type: InputType,
26 }
27 
28 impl DeriveValueOrd {
29     /// Parse [`DeriveInput`].
new(input: DeriveInput) -> syn::Result<Self>30     pub fn new(input: DeriveInput) -> syn::Result<Self> {
31         let ident = input.ident;
32         let type_attrs = TypeAttrs::parse(&input.attrs)?;
33 
34         // TODO(tarcieri): properly handle multiple lifetimes
35         let lifetime = input
36             .generics
37             .lifetimes()
38             .next()
39             .map(|lt| lt.lifetime.clone());
40 
41         let (fields, input_type) = match input.data {
42             syn::Data::Enum(data) => (
43                 data.variants
44                     .into_iter()
45                     .map(|variant| ValueField::new_enum(variant, &type_attrs))
46                     .collect::<syn::Result<_>>()?,
47                 InputType::Enum,
48             ),
49             syn::Data::Struct(data) => (
50                 data.fields
51                     .into_iter()
52                     .map(|field| ValueField::new_struct(field, &type_attrs))
53                     .collect::<syn::Result<_>>()?,
54                 InputType::Struct,
55             ),
56             _ => abort!(
57                 ident,
58                 "can't derive `ValueOrd` on this type: \
59                  only `enum` and `struct` types are allowed",
60             ),
61         };
62 
63         Ok(Self {
64             ident,
65             lifetime,
66             fields,
67             input_type,
68         })
69     }
70 
71     /// Lower the derived output into a [`TokenStream`].
to_tokens(&self) -> TokenStream72     pub fn to_tokens(&self) -> TokenStream {
73         let ident = &self.ident;
74 
75         // Lifetime parameters
76         // TODO(tarcieri): support multiple lifetimes
77         let lt_params = self
78             .lifetime
79             .as_ref()
80             .map(|lt| vec![lt.clone()])
81             .unwrap_or_default();
82 
83         let mut body = Vec::new();
84 
85         for field in &self.fields {
86             body.push(field.to_tokens());
87         }
88 
89         let body = match self.input_type {
90             InputType::Enum => {
91                 quote! {
92                     #[allow(unused_imports)]
93                     use ::der::ValueOrd;
94                     match (self, other) {
95                         #(#body)*
96                         _ => unreachable!(),
97                     }
98                 }
99             }
100             InputType::Struct => {
101                 quote! {
102                     #[allow(unused_imports)]
103                     use ::der::{DerOrd, ValueOrd};
104 
105                     #(#body)*
106 
107                     Ok(::core::cmp::Ordering::Equal)
108                 }
109             }
110         };
111 
112         quote! {
113             impl<#(#lt_params)*> ::der::ValueOrd for #ident<#(#lt_params)*> {
114                 fn value_cmp(&self, other: &Self) -> ::der::Result<::core::cmp::Ordering> {
115                     #body
116                 }
117             }
118         }
119     }
120 }
121 
/// The shape of the type the derive was applied to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InputType {
    /// The derive target is an `enum`.
    Enum,

    /// The derive target is a `struct`.
    Struct,
}
131 
132 struct ValueField {
133     /// Name of the field
134     ident: Ident,
135 
136     /// Field-level attributes.
137     attrs: FieldAttrs,
138 
139     is_enum: bool,
140 }
141 
142 impl ValueField {
143     /// Create from an `enum` variant.
new_enum(variant: Variant, type_attrs: &TypeAttrs) -> syn::Result<Self>144     fn new_enum(variant: Variant, type_attrs: &TypeAttrs) -> syn::Result<Self> {
145         let ident = variant.ident;
146 
147         let attrs = FieldAttrs::parse(&variant.attrs, type_attrs)?;
148         Ok(Self {
149             ident,
150             attrs,
151             is_enum: true,
152         })
153     }
154 
155     /// Create from a `struct` field.
new_struct(field: Field, type_attrs: &TypeAttrs) -> syn::Result<Self>156     fn new_struct(field: Field, type_attrs: &TypeAttrs) -> syn::Result<Self> {
157         let ident =
158             field.ident.as_ref().cloned().ok_or_else(|| {
159                 syn::Error::new_spanned(&field, "tuple structs are not supported")
160             })?;
161 
162         let attrs = FieldAttrs::parse(&field.attrs, type_attrs)?;
163         Ok(Self {
164             ident,
165             attrs,
166             is_enum: false,
167         })
168     }
169 
170     /// Lower to [`TokenStream`].
to_tokens(&self) -> TokenStream171     fn to_tokens(&self) -> TokenStream {
172         let ident = &self.ident;
173 
174         if self.is_enum {
175             let binding1 = quote!(Self::#ident(this));
176             let binding2 = quote!(Self::#ident(other));
177             quote! {
178                 (#binding1, #binding2) => this.value_cmp(other),
179             }
180         } else {
181             let mut binding1 = quote!(self.#ident);
182             let mut binding2 = quote!(other.#ident);
183 
184             if let Some(ty) = &self.attrs.asn1_type {
185                 binding1 = ty.encoder(&binding1);
186                 binding2 = ty.encoder(&binding2);
187             }
188 
189             quote! {
190                 match #binding1.der_cmp(&#binding2)? {
191                     ::core::cmp::Ordering::Equal => (),
192                     other => return Ok(other),
193                 }
194             }
195         }
196     }
197 }
198