1 //! Support for deriving the `Sequence` trait on structs for the purposes of 2 //! decoding/encoding ASN.1 `SEQUENCE` types as mapped to struct fields. 3 4 mod field; 5 6 use crate::{default_lifetime, TypeAttrs}; 7 use field::SequenceField; 8 use proc_macro2::TokenStream; 9 use quote::quote; 10 use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam}; 11 12 /// Derive the `Sequence` trait for a struct 13 pub(crate) struct DeriveSequence { 14 /// Name of the sequence struct. 15 ident: Ident, 16 17 /// Generics of the struct. 18 generics: Generics, 19 20 /// Fields of the struct. 21 fields: Vec<SequenceField>, 22 } 23 24 impl DeriveSequence { 25 /// Parse [`DeriveInput`]. new(input: DeriveInput) -> syn::Result<Self>26 pub fn new(input: DeriveInput) -> syn::Result<Self> { 27 let data = match input.data { 28 syn::Data::Struct(data) => data, 29 _ => abort!( 30 input.ident, 31 "can't derive `Sequence` on this type: only `struct` types are allowed", 32 ), 33 }; 34 35 let type_attrs = TypeAttrs::parse(&input.attrs)?; 36 37 let fields = data 38 .fields 39 .iter() 40 .map(|field| SequenceField::new(field, &type_attrs)) 41 .collect::<syn::Result<_>>()?; 42 43 Ok(Self { 44 ident: input.ident, 45 generics: input.generics.clone(), 46 fields, 47 }) 48 } 49 50 /// Lower the derived output into a [`TokenStream`]. to_tokens(&self) -> TokenStream51 pub fn to_tokens(&self) -> TokenStream { 52 let ident = &self.ident; 53 let mut generics = self.generics.clone(); 54 55 // Use the first lifetime parameter as lifetime for Decode/Encode lifetime 56 // if none found, add one. 57 let lifetime = generics 58 .lifetimes() 59 .next() 60 .map(|lt| lt.lifetime.clone()) 61 .unwrap_or_else(|| { 62 let lt = default_lifetime(); 63 generics 64 .params 65 .insert(0, GenericParam::Lifetime(LifetimeParam::new(lt.clone()))); 66 lt 67 }); 68 69 // We may or may not have inserted a lifetime. 70 let (_, ty_generics, where_clause) = self.generics.split_for_impl(); 71 let (impl_generics, _, _) = generics.split_for_impl(); 72 73 let mut decode_body = Vec::new(); 74 let mut decode_result = Vec::new(); 75 let mut encoded_lengths = Vec::new(); 76 let mut encode_fields = Vec::new(); 77 78 for field in &self.fields { 79 decode_body.push(field.to_decode_tokens()); 80 decode_result.push(&field.ident); 81 82 let field = field.to_encode_tokens(); 83 encoded_lengths.push(quote!(#field.encoded_len()?)); 84 encode_fields.push(quote!(#field.encode(writer)?;)); 85 } 86 87 quote! { 88 impl #impl_generics ::der::DecodeValue<#lifetime> for #ident #ty_generics #where_clause { 89 fn decode_value<R: ::der::Reader<#lifetime>>( 90 reader: &mut R, 91 header: ::der::Header, 92 ) -> ::der::Result<Self> { 93 use ::der::{Decode as _, DecodeValue as _, Reader as _}; 94 95 reader.read_nested(header.length, |reader| { 96 #(#decode_body)* 97 98 Ok(Self { 99 #(#decode_result),* 100 }) 101 }) 102 } 103 } 104 105 impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause { 106 fn value_len(&self) -> ::der::Result<::der::Length> { 107 use ::der::Encode as _; 108 109 [ 110 #(#encoded_lengths),* 111 ] 112 .into_iter() 113 .try_fold(::der::Length::ZERO, |acc, len| acc + len) 114 } 115 116 fn encode_value(&self, writer: &mut impl ::der::Writer) -> ::der::Result<()> { 117 use ::der::Encode as _; 118 #(#encode_fields)* 119 Ok(()) 120 } 121 } 122 123 impl #impl_generics ::der::Sequence<#lifetime> for #ident #ty_generics #where_clause {} 124 } 125 } 126 } 127 128 #[cfg(test)] 129 mod tests { 130 use super::DeriveSequence; 131 use crate::{Asn1Type, TagMode}; 132 use syn::parse_quote; 133 134 /// X.509 SPKI `AlgorithmIdentifier`. 135 #[test] algorithm_identifier_example()136 fn algorithm_identifier_example() { 137 let input = parse_quote! { 138 #[derive(Sequence)] 139 pub struct AlgorithmIdentifier<'a> { 140 pub algorithm: ObjectIdentifier, 141 pub parameters: Option<Any<'a>>, 142 } 143 }; 144 145 let ir = DeriveSequence::new(input).unwrap(); 146 assert_eq!(ir.ident, "AlgorithmIdentifier"); 147 assert_eq!( 148 ir.generics.lifetimes().next().unwrap().lifetime.to_string(), 149 "'a" 150 ); 151 assert_eq!(ir.fields.len(), 2); 152 153 let algorithm_field = &ir.fields[0]; 154 assert_eq!(algorithm_field.ident, "algorithm"); 155 assert_eq!(algorithm_field.attrs.asn1_type, None); 156 assert_eq!(algorithm_field.attrs.context_specific, None); 157 assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit); 158 159 let parameters_field = &ir.fields[1]; 160 assert_eq!(parameters_field.ident, "parameters"); 161 assert_eq!(parameters_field.attrs.asn1_type, None); 162 assert_eq!(parameters_field.attrs.context_specific, None); 163 assert_eq!(parameters_field.attrs.tag_mode, TagMode::Explicit); 164 } 165 166 /// X.509 `SubjectPublicKeyInfo`. 167 #[test] spki_example()168 fn spki_example() { 169 let input = parse_quote! { 170 #[derive(Sequence)] 171 pub struct SubjectPublicKeyInfo<'a> { 172 pub algorithm: AlgorithmIdentifier<'a>, 173 174 #[asn1(type = "BIT STRING")] 175 pub subject_public_key: &'a [u8], 176 } 177 }; 178 179 let ir = DeriveSequence::new(input).unwrap(); 180 assert_eq!(ir.ident, "SubjectPublicKeyInfo"); 181 assert_eq!( 182 ir.generics.lifetimes().next().unwrap().lifetime.to_string(), 183 "'a" 184 ); 185 assert_eq!(ir.fields.len(), 2); 186 187 let algorithm_field = &ir.fields[0]; 188 assert_eq!(algorithm_field.ident, "algorithm"); 189 assert_eq!(algorithm_field.attrs.asn1_type, None); 190 assert_eq!(algorithm_field.attrs.context_specific, None); 191 assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit); 192 193 let subject_public_key_field = &ir.fields[1]; 194 assert_eq!(subject_public_key_field.ident, "subject_public_key"); 195 assert_eq!( 196 subject_public_key_field.attrs.asn1_type, 197 Some(Asn1Type::BitString) 198 ); 199 assert_eq!(subject_public_key_field.attrs.context_specific, None); 200 assert_eq!(subject_public_key_field.attrs.tag_mode, TagMode::Explicit); 201 } 202 203 /// PKCS#8v2 `OneAsymmetricKey`. 204 /// 205 /// ```text 206 /// OneAsymmetricKey ::= SEQUENCE { 207 /// version Version, 208 /// privateKeyAlgorithm PrivateKeyAlgorithmIdentifier, 209 /// privateKey PrivateKey, 210 /// attributes [0] Attributes OPTIONAL, 211 /// ..., 212 /// [[2: publicKey [1] PublicKey OPTIONAL ]], 213 /// ... 214 /// } 215 /// 216 /// Version ::= INTEGER { v1(0), v2(1) } (v1, ..., v2) 217 /// 218 /// PrivateKeyAlgorithmIdentifier ::= AlgorithmIdentifier 219 /// 220 /// PrivateKey ::= OCTET STRING 221 /// 222 /// Attributes ::= SET OF Attribute 223 /// 224 /// PublicKey ::= BIT STRING 225 /// ``` 226 #[test] pkcs8_example()227 fn pkcs8_example() { 228 let input = parse_quote! { 229 #[derive(Sequence)] 230 pub struct OneAsymmetricKey<'a> { 231 pub version: u8, 232 pub private_key_algorithm: AlgorithmIdentifier<'a>, 233 #[asn1(type = "OCTET STRING")] 234 pub private_key: &'a [u8], 235 #[asn1(context_specific = "0", extensible = "true", optional = "true")] 236 pub attributes: Option<SetOf<Any<'a>, 1>>, 237 #[asn1( 238 context_specific = "1", 239 extensible = "true", 240 optional = "true", 241 type = "BIT STRING" 242 )] 243 pub public_key: Option<&'a [u8]>, 244 } 245 }; 246 247 let ir = DeriveSequence::new(input).unwrap(); 248 assert_eq!(ir.ident, "OneAsymmetricKey"); 249 assert_eq!( 250 ir.generics.lifetimes().next().unwrap().lifetime.to_string(), 251 "'a" 252 ); 253 assert_eq!(ir.fields.len(), 5); 254 255 let version_field = &ir.fields[0]; 256 assert_eq!(version_field.ident, "version"); 257 assert_eq!(version_field.attrs.asn1_type, None); 258 assert_eq!(version_field.attrs.context_specific, None); 259 assert_eq!(version_field.attrs.extensible, false); 260 assert_eq!(version_field.attrs.optional, false); 261 assert_eq!(version_field.attrs.tag_mode, TagMode::Explicit); 262 263 let algorithm_field = &ir.fields[1]; 264 assert_eq!(algorithm_field.ident, "private_key_algorithm"); 265 assert_eq!(algorithm_field.attrs.asn1_type, None); 266 assert_eq!(algorithm_field.attrs.context_specific, None); 267 assert_eq!(algorithm_field.attrs.extensible, false); 268 assert_eq!(algorithm_field.attrs.optional, false); 269 assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit); 270 271 let private_key_field = &ir.fields[2]; 272 assert_eq!(private_key_field.ident, "private_key"); 273 assert_eq!( 274 private_key_field.attrs.asn1_type, 275 Some(Asn1Type::OctetString) 276 ); 277 assert_eq!(private_key_field.attrs.context_specific, None); 278 assert_eq!(private_key_field.attrs.extensible, false); 279 assert_eq!(private_key_field.attrs.optional, false); 280 assert_eq!(private_key_field.attrs.tag_mode, TagMode::Explicit); 281 282 let attributes_field = &ir.fields[3]; 283 assert_eq!(attributes_field.ident, "attributes"); 284 assert_eq!(attributes_field.attrs.asn1_type, None); 285 assert_eq!( 286 attributes_field.attrs.context_specific, 287 Some("0".parse().unwrap()) 288 ); 289 assert_eq!(attributes_field.attrs.extensible, true); 290 assert_eq!(attributes_field.attrs.optional, true); 291 assert_eq!(attributes_field.attrs.tag_mode, TagMode::Explicit); 292 293 let public_key_field = &ir.fields[4]; 294 assert_eq!(public_key_field.ident, "public_key"); 295 assert_eq!(public_key_field.attrs.asn1_type, Some(Asn1Type::BitString)); 296 assert_eq!( 297 public_key_field.attrs.context_specific, 298 Some("1".parse().unwrap()) 299 ); 300 assert_eq!(public_key_field.attrs.extensible, true); 301 assert_eq!(public_key_field.attrs.optional, true); 302 assert_eq!(public_key_field.attrs.tag_mode, TagMode::Explicit); 303 } 304 305 /// `IMPLICIT` tagged example 306 #[test] implicit_example()307 fn implicit_example() { 308 let input = parse_quote! { 309 #[asn1(tag_mode = "IMPLICIT")] 310 pub struct ImplicitSequence<'a> { 311 #[asn1(context_specific = "0", type = "BIT STRING")] 312 bit_string: BitString<'a>, 313 314 #[asn1(context_specific = "1", type = "GeneralizedTime")] 315 time: GeneralizedTime, 316 317 #[asn1(context_specific = "2", type = "UTF8String")] 318 utf8_string: String, 319 } 320 }; 321 322 let ir = DeriveSequence::new(input).unwrap(); 323 assert_eq!(ir.ident, "ImplicitSequence"); 324 assert_eq!( 325 ir.generics.lifetimes().next().unwrap().lifetime.to_string(), 326 "'a" 327 ); 328 assert_eq!(ir.fields.len(), 3); 329 330 let bit_string = &ir.fields[0]; 331 assert_eq!(bit_string.ident, "bit_string"); 332 assert_eq!(bit_string.attrs.asn1_type, Some(Asn1Type::BitString)); 333 assert_eq!( 334 bit_string.attrs.context_specific, 335 Some("0".parse().unwrap()) 336 ); 337 assert_eq!(bit_string.attrs.tag_mode, TagMode::Implicit); 338 339 let time = &ir.fields[1]; 340 assert_eq!(time.ident, "time"); 341 assert_eq!(time.attrs.asn1_type, Some(Asn1Type::GeneralizedTime)); 342 assert_eq!(time.attrs.context_specific, Some("1".parse().unwrap())); 343 assert_eq!(time.attrs.tag_mode, TagMode::Implicit); 344 345 let utf8_string = &ir.fields[2]; 346 assert_eq!(utf8_string.ident, "utf8_string"); 347 assert_eq!(utf8_string.attrs.asn1_type, Some(Asn1Type::Utf8String)); 348 assert_eq!( 349 utf8_string.attrs.context_specific, 350 Some("2".parse().unwrap()) 351 ); 352 assert_eq!(utf8_string.attrs.tag_mode, TagMode::Implicit); 353 } 354 } 355