// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use std::str::FromStr; use darling::{ ast::{self, Fields}, FromDeriveInput, FromField, FromVariant, }; use proc_macro2::{Literal, TokenStream}; use quote::quote; use syn::{ parse_macro_input, parse_quote, Attribute, DeriveInput, Expr, Generics, Ident, Index, Lit, Path, }; enum Operation { Size, Encode, Decode, } impl Operation { fn path(&self) -> Path { match self { Operation::Size => parse_quote! { mls_rs_codec::MlsSize }, Operation::Encode => parse_quote! { mls_rs_codec::MlsEncode }, Operation::Decode => parse_quote! { mls_rs_codec::MlsDecode }, } } fn call(&self) -> TokenStream { match self { Operation::Size => quote! { mls_encoded_len }, Operation::Encode => quote! { mls_encode }, Operation::Decode => quote! { mls_decode }, } } fn extras(&self) -> TokenStream { match self { Operation::Size => quote! {}, Operation::Encode => quote! { , writer }, Operation::Decode => quote! { reader }, } } fn is_result(&self) -> bool { match self { Operation::Size => false, Operation::Encode => true, Operation::Decode => true, } } } #[derive(Debug, FromField)] #[darling(attributes(mls_codec))] struct MlsFieldReceiver { ident: Option, with: Option, } impl MlsFieldReceiver { pub fn call_tokens(&self, index: Index) -> TokenStream { if let Some(ref ident) = self.ident { quote! { &self.#ident } } else { quote! { &self.#index } } } pub fn name(&self, index: Index) -> TokenStream { if let Some(ref ident) = self.ident { quote! {#ident: } } else { quote! { #index: } } } } #[derive(Debug, FromVariant)] #[darling(attributes(mls_codec))] struct MlsVariantReceiver { ident: Ident, discriminant: Option, fields: ast::Fields, } #[derive(FromDeriveInput)] #[darling(attributes(mls_codec), forward_attrs(repr))] struct MlsInputReceiver { attrs: Vec, ident: Ident, generics: Generics, data: ast::Data, } impl MlsInputReceiver { fn handle_input(&self, operation: Operation) -> TokenStream { match self.data { ast::Data::Struct(ref s) => struct_impl(s, operation), ast::Data::Enum(ref e) => enum_impl(&self.ident, &self.attrs, e, operation), } } } fn repr_ident(attrs: &[Attribute]) -> Option { let repr_path = attrs .iter() .filter(|attr| matches!(attr.style, syn::AttrStyle::Outer)) .find(|attr| attr.path().is_ident("repr")) .map(|repr| repr.parse_args()) .transpose() .ok() .flatten(); let Some(Expr::Path(path)) = repr_path else { return None; }; path.path .segments .iter() .find(|s| s.ident != "C") .map(|path| path.ident.clone()) } /// Provides the discriminant for a given variant. If the variant does not specify a suffix /// and a `repr_ident` is provided, it will be appended to number. fn discriminant_for_variant( variant: &MlsVariantReceiver, repr_ident: &Option, ) -> TokenStream { let discriminant = variant .discriminant .clone() .expect("Enum discriminants must be explicitly defined"); let Expr::Lit(lit_expr) = &discriminant else { return quote! {#discriminant}; }; let Lit::Int(lit_int) = &lit_expr.lit else { return quote! {#discriminant}; }; if lit_int.suffix().is_empty() { // This is dirty and there is probably a better way of doing this but I'm way too much of a noob at // proc macros to pull it off... // TODO: Add proper support for correctly ignoring transparent, packed and modifiers let str = format!( "{}{}", lit_int.base10_digits(), &repr_ident.clone().expect("Expected a repr(u*) to be provided or for the variant's discriminant to be defined with suffixed literals.") ); Literal::from_str(&str) .map(|l| quote! {#l}) .ok() .unwrap_or_else(|| quote! {#discriminant}) } else { quote! {#discriminant} } } fn enum_impl( ident: &Ident, attrs: &[Attribute], variants: &[MlsVariantReceiver], operation: Operation, ) -> TokenStream { let handle_error = operation.is_result().then_some(quote! { ? }); let path = operation.path(); let call = operation.call(); let extras = operation.extras(); let enum_name = &ident; let repr_ident = repr_ident(attrs); if matches!(operation, Operation::Decode) { let cases = variants.iter().map(|variant| { let variant_name = &variant.ident; let discriminant = discriminant_for_variant(variant, &repr_ident); // TODO: Support more than 1 field match variant.fields.len() { 0 => quote! { #discriminant => Ok(#enum_name::#variant_name), }, 1 =>{ let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path); quote! { #discriminant => Ok(#enum_name::#variant_name(#path::#call(#extras) #handle_error)), } }, _ => panic!("Enum discriminants with more than 1 field are not currently supported") } }); return quote! { let discriminant = #path::#call(#extras)#handle_error; match discriminant { #(#cases)* _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant), } }; } let cases = variants.iter().map(|variant| { let variant_name = &variant.ident; let discriminant = discriminant_for_variant(variant, &repr_ident); let (parameter, field) = if variant.fields.is_empty() { (None, None) } else { let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path); let start = match operation { Operation::Size => Some(quote! { + }), Operation::Encode => Some(quote! {;}), Operation::Decode => None, }; ( Some(quote! {(ref val)}), Some(quote! { #start #path::#call (val #extras) #handle_error }), ) }; let discrim = quote! { #path::#call (&#discriminant #extras) #handle_error }; quote! { #enum_name::#variant_name #parameter => { #discrim #field }} }); let enum_impl = quote! { match self { #(#cases)* } }; if operation.is_result() { quote! { Ok(#enum_impl) } } else { enum_impl } } fn struct_impl(s: &Fields, operation: Operation) -> TokenStream { let recurse = s.fields.iter().enumerate().map(|(index, field)| { let (call_tokens, field_name) = match operation { Operation::Size | Operation::Encode => { (field.call_tokens(Index::from(index)), quote! {}) } Operation::Decode => (quote! {}, field.name(Index::from(index))), }; let handle_error = operation.is_result().then_some(quote! { ? }); let path = field.with.clone().unwrap_or(operation.path()); let call = operation.call(); let extras = operation.extras(); quote! { #field_name #path::#call (#call_tokens #extras) #handle_error } }); match operation { Operation::Size => quote! { 0 #(+ #recurse)* }, Operation::Encode => quote! { #(#recurse;)* Ok(()) }, Operation::Decode => quote! { Ok(Self { #(#recurse,)* }) }, } } fn derive_impl( input: proc_macro::TokenStream, trait_name: TokenStream, function_def: TokenStream, internals: F, ) -> proc_macro::TokenStream where F: FnOnce(&MlsInputReceiver) -> TokenStream, { let input = parse_macro_input!(input as DeriveInput); let input = MlsInputReceiver::from_derive_input(&input).unwrap(); let name = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); // Generate an expression to sum up the heap size of each field. let function_impl = internals(&input); let expanded = quote! { // The generated impl. impl #impl_generics #trait_name for #name #ty_generics #where_clause { #function_def { #function_impl } } }; // Hand the output tokens back to the compiler. proc_macro::TokenStream::from(expanded) } #[proc_macro_derive(MlsSize, attributes(mls_codec))] pub fn derive_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let trait_name = quote! { mls_rs_codec::MlsSize }; let function_def = quote! {fn mls_encoded_len(&self) -> usize }; derive_impl(input, trait_name, function_def, |input| { input.handle_input(Operation::Size) }) } #[proc_macro_derive(MlsEncode, attributes(mls_codec))] pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let trait_name = quote! { mls_rs_codec::MlsEncode }; let function_def = quote! { fn mls_encode(&self, writer: &mut mls_rs_codec::Vec) -> Result<(), mls_rs_codec::Error> }; derive_impl(input, trait_name, function_def, |input| { input.handle_input(Operation::Encode) }) } #[proc_macro_derive(MlsDecode, attributes(mls_codec))] pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let trait_name = quote! { mls_rs_codec::MlsDecode }; let function_def = quote! { fn mls_decode(reader: &mut &[u8]) -> Result }; derive_impl(input, trait_name, function_def, |input| { input.handle_input(Operation::Decode) }) }