1 //! Support for deriving the `Sequence` trait on structs for the purposes of
2 //! decoding/encoding ASN.1 `SEQUENCE` types as mapped to struct fields.
3 
4 mod field;
5 
6 use crate::{default_lifetime, TypeAttrs};
7 use field::SequenceField;
8 use proc_macro2::TokenStream;
9 use quote::quote;
10 use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam};
11 
12 /// Derive the `Sequence` trait for a struct
13 pub(crate) struct DeriveSequence {
14     /// Name of the sequence struct.
15     ident: Ident,
16 
17     /// Generics of the struct.
18     generics: Generics,
19 
20     /// Fields of the struct.
21     fields: Vec<SequenceField>,
22 }
23 
24 impl DeriveSequence {
25     /// Parse [`DeriveInput`].
new(input: DeriveInput) -> syn::Result<Self>26     pub fn new(input: DeriveInput) -> syn::Result<Self> {
27         let data = match input.data {
28             syn::Data::Struct(data) => data,
29             _ => abort!(
30                 input.ident,
31                 "can't derive `Sequence` on this type: only `struct` types are allowed",
32             ),
33         };
34 
35         let type_attrs = TypeAttrs::parse(&input.attrs)?;
36 
37         let fields = data
38             .fields
39             .iter()
40             .map(|field| SequenceField::new(field, &type_attrs))
41             .collect::<syn::Result<_>>()?;
42 
43         Ok(Self {
44             ident: input.ident,
45             generics: input.generics.clone(),
46             fields,
47         })
48     }
49 
50     /// Lower the derived output into a [`TokenStream`].
to_tokens(&self) -> TokenStream51     pub fn to_tokens(&self) -> TokenStream {
52         let ident = &self.ident;
53         let mut generics = self.generics.clone();
54 
55         // Use the first lifetime parameter as lifetime for Decode/Encode lifetime
56         // if none found, add one.
57         let lifetime = generics
58             .lifetimes()
59             .next()
60             .map(|lt| lt.lifetime.clone())
61             .unwrap_or_else(|| {
62                 let lt = default_lifetime();
63                 generics
64                     .params
65                     .insert(0, GenericParam::Lifetime(LifetimeParam::new(lt.clone())));
66                 lt
67             });
68 
69         // We may or may not have inserted a lifetime.
70         let (_, ty_generics, where_clause) = self.generics.split_for_impl();
71         let (impl_generics, _, _) = generics.split_for_impl();
72 
73         let mut decode_body = Vec::new();
74         let mut decode_result = Vec::new();
75         let mut encoded_lengths = Vec::new();
76         let mut encode_fields = Vec::new();
77 
78         for field in &self.fields {
79             decode_body.push(field.to_decode_tokens());
80             decode_result.push(&field.ident);
81 
82             let field = field.to_encode_tokens();
83             encoded_lengths.push(quote!(#field.encoded_len()?));
84             encode_fields.push(quote!(#field.encode(writer)?;));
85         }
86 
87         quote! {
88             impl #impl_generics ::der::DecodeValue<#lifetime> for #ident #ty_generics #where_clause {
89                 fn decode_value<R: ::der::Reader<#lifetime>>(
90                     reader: &mut R,
91                     header: ::der::Header,
92                 ) -> ::der::Result<Self> {
93                     use ::der::{Decode as _, DecodeValue as _, Reader as _};
94 
95                     reader.read_nested(header.length, |reader| {
96                         #(#decode_body)*
97 
98                         Ok(Self {
99                             #(#decode_result),*
100                         })
101                     })
102                 }
103             }
104 
105             impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause {
106                 fn value_len(&self) -> ::der::Result<::der::Length> {
107                     use ::der::Encode as _;
108 
109                     [
110                         #(#encoded_lengths),*
111                     ]
112                         .into_iter()
113                         .try_fold(::der::Length::ZERO, |acc, len| acc + len)
114                 }
115 
116                 fn encode_value(&self, writer: &mut impl ::der::Writer) -> ::der::Result<()> {
117                     use ::der::Encode as _;
118                     #(#encode_fields)*
119                     Ok(())
120                 }
121             }
122 
123             impl #impl_generics ::der::Sequence<#lifetime> for #ident #ty_generics #where_clause {}
124         }
125     }
126 }
127 
#[cfg(test)]
mod tests {
    use super::DeriveSequence;
    use crate::{Asn1Type, TagMode};
    use syn::parse_quote;

    /// Name of the first lifetime parameter declared on the parsed struct.
    ///
    /// Panics if the struct declares no lifetime, which indicates a broken test
    /// fixture rather than a bug under test.
    fn first_lifetime(ir: &DeriveSequence) -> String {
        ir.generics
            .lifetimes()
            .next()
            .expect("test struct should declare a lifetime")
            .lifetime
            .to_string()
    }

    /// X.509 SPKI `AlgorithmIdentifier`.
    #[test]
    fn algorithm_identifier_example() {
        let input = parse_quote! {
            #[derive(Sequence)]
            pub struct AlgorithmIdentifier<'a> {
                pub algorithm: ObjectIdentifier,
                pub parameters: Option<Any<'a>>,
            }
        };

        let ir = DeriveSequence::new(input).unwrap();
        assert_eq!(ir.ident, "AlgorithmIdentifier");
        assert_eq!(first_lifetime(&ir), "'a");
        assert_eq!(ir.fields.len(), 2);

        let algorithm_field = &ir.fields[0];
        assert_eq!(algorithm_field.ident, "algorithm");
        assert_eq!(algorithm_field.attrs.asn1_type, None);
        assert_eq!(algorithm_field.attrs.context_specific, None);
        assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit);

        let parameters_field = &ir.fields[1];
        assert_eq!(parameters_field.ident, "parameters");
        assert_eq!(parameters_field.attrs.asn1_type, None);
        assert_eq!(parameters_field.attrs.context_specific, None);
        assert_eq!(parameters_field.attrs.tag_mode, TagMode::Explicit);
    }

    /// X.509 `SubjectPublicKeyInfo`.
    #[test]
    fn spki_example() {
        let input = parse_quote! {
            #[derive(Sequence)]
            pub struct SubjectPublicKeyInfo<'a> {
                pub algorithm: AlgorithmIdentifier<'a>,

                #[asn1(type = "BIT STRING")]
                pub subject_public_key: &'a [u8],
            }
        };

        let ir = DeriveSequence::new(input).unwrap();
        assert_eq!(ir.ident, "SubjectPublicKeyInfo");
        assert_eq!(first_lifetime(&ir), "'a");
        assert_eq!(ir.fields.len(), 2);

        let algorithm_field = &ir.fields[0];
        assert_eq!(algorithm_field.ident, "algorithm");
        assert_eq!(algorithm_field.attrs.asn1_type, None);
        assert_eq!(algorithm_field.attrs.context_specific, None);
        assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit);

        let subject_public_key_field = &ir.fields[1];
        assert_eq!(subject_public_key_field.ident, "subject_public_key");
        assert_eq!(
            subject_public_key_field.attrs.asn1_type,
            Some(Asn1Type::BitString)
        );
        assert_eq!(subject_public_key_field.attrs.context_specific, None);
        assert_eq!(subject_public_key_field.attrs.tag_mode, TagMode::Explicit);
    }

    /// PKCS#8v2 `OneAsymmetricKey`.
    ///
    /// ```text
    /// OneAsymmetricKey ::= SEQUENCE {
    ///     version                   Version,
    ///     privateKeyAlgorithm       PrivateKeyAlgorithmIdentifier,
    ///     privateKey                PrivateKey,
    ///     attributes            [0] Attributes OPTIONAL,
    ///     ...,
    ///     [[2: publicKey        [1] PublicKey OPTIONAL ]],
    ///     ...
    ///   }
    ///
    /// Version ::= INTEGER { v1(0), v2(1) } (v1, ..., v2)
    ///
    /// PrivateKeyAlgorithmIdentifier ::= AlgorithmIdentifier
    ///
    /// PrivateKey ::= OCTET STRING
    ///
    /// Attributes ::= SET OF Attribute
    ///
    /// PublicKey ::= BIT STRING
    /// ```
    #[test]
    fn pkcs8_example() {
        let input = parse_quote! {
            #[derive(Sequence)]
            pub struct OneAsymmetricKey<'a> {
                pub version: u8,
                pub private_key_algorithm: AlgorithmIdentifier<'a>,
                #[asn1(type = "OCTET STRING")]
                pub private_key: &'a [u8],
                #[asn1(context_specific = "0", extensible = "true", optional = "true")]
                pub attributes: Option<SetOf<Any<'a>, 1>>,
                #[asn1(
                    context_specific = "1",
                    extensible = "true",
                    optional = "true",
                    type = "BIT STRING"
                )]
                pub public_key: Option<&'a [u8]>,
            }
        };

        let ir = DeriveSequence::new(input).unwrap();
        assert_eq!(ir.ident, "OneAsymmetricKey");
        assert_eq!(first_lifetime(&ir), "'a");
        assert_eq!(ir.fields.len(), 5);

        let version_field = &ir.fields[0];
        assert_eq!(version_field.ident, "version");
        assert_eq!(version_field.attrs.asn1_type, None);
        assert_eq!(version_field.attrs.context_specific, None);
        assert!(!version_field.attrs.extensible);
        assert!(!version_field.attrs.optional);
        assert_eq!(version_field.attrs.tag_mode, TagMode::Explicit);

        let algorithm_field = &ir.fields[1];
        assert_eq!(algorithm_field.ident, "private_key_algorithm");
        assert_eq!(algorithm_field.attrs.asn1_type, None);
        assert_eq!(algorithm_field.attrs.context_specific, None);
        assert!(!algorithm_field.attrs.extensible);
        assert!(!algorithm_field.attrs.optional);
        assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit);

        let private_key_field = &ir.fields[2];
        assert_eq!(private_key_field.ident, "private_key");
        assert_eq!(
            private_key_field.attrs.asn1_type,
            Some(Asn1Type::OctetString)
        );
        assert_eq!(private_key_field.attrs.context_specific, None);
        assert!(!private_key_field.attrs.extensible);
        assert!(!private_key_field.attrs.optional);
        assert_eq!(private_key_field.attrs.tag_mode, TagMode::Explicit);

        let attributes_field = &ir.fields[3];
        assert_eq!(attributes_field.ident, "attributes");
        assert_eq!(attributes_field.attrs.asn1_type, None);
        assert_eq!(
            attributes_field.attrs.context_specific,
            Some("0".parse().unwrap())
        );
        assert!(attributes_field.attrs.extensible);
        assert!(attributes_field.attrs.optional);
        assert_eq!(attributes_field.attrs.tag_mode, TagMode::Explicit);

        let public_key_field = &ir.fields[4];
        assert_eq!(public_key_field.ident, "public_key");
        assert_eq!(public_key_field.attrs.asn1_type, Some(Asn1Type::BitString));
        assert_eq!(
            public_key_field.attrs.context_specific,
            Some("1".parse().unwrap())
        );
        assert!(public_key_field.attrs.extensible);
        assert!(public_key_field.attrs.optional);
        assert_eq!(public_key_field.attrs.tag_mode, TagMode::Explicit);
    }

    /// `IMPLICIT` tagged example
    #[test]
    fn implicit_example() {
        let input = parse_quote! {
            #[asn1(tag_mode = "IMPLICIT")]
            pub struct ImplicitSequence<'a> {
                #[asn1(context_specific = "0", type = "BIT STRING")]
                bit_string: BitString<'a>,

                #[asn1(context_specific = "1", type = "GeneralizedTime")]
                time: GeneralizedTime,

                #[asn1(context_specific = "2", type = "UTF8String")]
                utf8_string: String,
            }
        };

        let ir = DeriveSequence::new(input).unwrap();
        assert_eq!(ir.ident, "ImplicitSequence");
        assert_eq!(first_lifetime(&ir), "'a");
        assert_eq!(ir.fields.len(), 3);

        let bit_string = &ir.fields[0];
        assert_eq!(bit_string.ident, "bit_string");
        assert_eq!(bit_string.attrs.asn1_type, Some(Asn1Type::BitString));
        assert_eq!(
            bit_string.attrs.context_specific,
            Some("0".parse().unwrap())
        );
        assert_eq!(bit_string.attrs.tag_mode, TagMode::Implicit);

        let time = &ir.fields[1];
        assert_eq!(time.ident, "time");
        assert_eq!(time.attrs.asn1_type, Some(Asn1Type::GeneralizedTime));
        assert_eq!(time.attrs.context_specific, Some("1".parse().unwrap()));
        assert_eq!(time.attrs.tag_mode, TagMode::Implicit);

        let utf8_string = &ir.fields[2];
        assert_eq!(utf8_string.ident, "utf8_string");
        assert_eq!(utf8_string.attrs.asn1_type, Some(Asn1Type::Utf8String));
        assert_eq!(
            utf8_string.attrs.context_specific,
            Some("2".parse().unwrap())
        );
        assert_eq!(utf8_string.attrs.tag_mode, TagMode::Implicit);
    }

    /// A struct without any lifetime parameter still lowers to tokens, and
    /// lowering does not leak the synthesized lifetime into the stored
    /// generics.
    #[test]
    fn lifetime_less_example() {
        let input = parse_quote! {
            #[derive(Sequence)]
            pub struct Plain {
                pub version: u8,
            }
        };

        let ir = DeriveSequence::new(input).unwrap();
        assert_eq!(ir.ident, "Plain");
        assert!(ir.generics.params.is_empty());
        assert_eq!(ir.fields.len(), 1);

        let tokens = ir.to_tokens();
        assert!(!tokens.is_empty());
        assert!(ir.generics.params.is_empty());
    }
}
355