#![allow(unused_imports)] use std::{cmp, convert::TryFrom}; use proc_macro2::{Ident, Span, TokenStream, TokenTree}; use quote::{quote, quote_spanned, ToTokens}; use syn::{ parse::{Parse, ParseStream, Parser}, punctuated::Punctuated, spanned::Spanned, Result, *, }; macro_rules! bail { ($msg:expr $(,)?) => { return Err(Error::new(Span::call_site(), &$msg[..])) }; ( $msg:expr => $span_to_blame:expr $(,)? ) => { return Err(Error::new_spanned(&$span_to_blame, $msg)) }; } pub trait Derivable { fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result; fn implies_trait(_crate_name: &TokenStream) -> Option { None } fn asserts( _input: &DeriveInput, _crate_name: &TokenStream, ) -> Result { Ok(quote!()) } fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> { Ok(()) } fn trait_impl( _input: &DeriveInput, _crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)> { Ok((quote!(), quote!())) } fn requires_where_clause() -> bool { true } fn explicit_bounds_attribute_name() -> Option<&'static str> { None } } pub struct Pod; impl Derivable for Pod { fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result { Ok(syn::parse_quote!(#crate_name::Pod)) } fn asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result { let repr = get_repr(&input.attrs)?; let completly_packed = repr.packed == Some(1) || repr.repr == Repr::Transparent; if !completly_packed && !input.generics.params.is_empty() { bail!("\ Pod requires cannot be derived for non-packed types containing \ generic parameters because the padding requirements can't be verified \ for generic non-packed structs\ " => input.generics.params.first().unwrap()); } match &input.data { Data::Struct(_) => { let assert_no_padding = if !completly_packed { Some(generate_assert_no_padding(input)?) } else { None }; let assert_fields_are_pod = generate_fields_are_trait(input, Self::ident(input, crate_name)?)?; Ok(quote!( #assert_no_padding #assert_fields_are_pod )) } Data::Enum(_) => bail!("Deriving Pod is not supported for enums"), Data::Union(_) => bail!("Deriving Pod is not supported for unions"), } } fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> { let repr = get_repr(attributes)?; match repr.repr { Repr::C => Ok(()), Repr::Transparent => Ok(()), _ => { bail!("Pod requires the type to be #[repr(C)] or #[repr(transparent)]") } } } } pub struct AnyBitPattern; impl Derivable for AnyBitPattern { fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result { Ok(syn::parse_quote!(#crate_name::AnyBitPattern)) } fn implies_trait(crate_name: &TokenStream) -> Option { Some(quote!(#crate_name::Zeroable)) } fn asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result { match &input.data { Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern` Data::Struct(_) => { generate_fields_are_trait(input, Self::ident(input, crate_name)?) } Data::Enum(_) => { bail!("Deriving AnyBitPattern is not supported for enums") } } } } pub struct Zeroable; impl Derivable for Zeroable { fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result { Ok(syn::parse_quote!(#crate_name::Zeroable)) } fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> { let repr = get_repr(attributes)?; match ty { Data::Struct(_) => Ok(()), Data::Enum(DataEnum { variants, .. }) => { if !repr.repr.is_integer() { bail!("Zeroable requires the enum to be an explicit #[repr(Int)]") } if variants.iter().any(|variant| !variant.fields.is_empty()) { bail!("Only fieldless enums are supported for Zeroable") } let iter = VariantDiscriminantIterator::new(variants.iter()); let mut has_zero_variant = false; for res in iter { let discriminant = res?; if discriminant == 0 { has_zero_variant = true; break; } } if !has_zero_variant { bail!("No variant's discriminant is 0") } Ok(()) } Data::Union(_) => Ok(()), } } fn asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result { match &input.data { Data::Union(_) => Ok(quote!()), // unions are always `Zeroable` Data::Struct(_) => { generate_fields_are_trait(input, Self::ident(input, crate_name)?) } Data::Enum(_) => Ok(quote!()), } } fn explicit_bounds_attribute_name() -> Option<&'static str> { Some("zeroable") } } pub struct NoUninit; impl Derivable for NoUninit { fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result { Ok(syn::parse_quote!(#crate_name::NoUninit)) } fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> { let repr = get_repr(attributes)?; match ty { Data::Struct(_) => match repr.repr { Repr::C | Repr::Transparent => Ok(()), _ => bail!("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"), }, Data::Enum(_) => if repr.repr.is_integer() { Ok(()) } else { bail!("NoUninit requires the enum to be an explicit #[repr(Int)]") }, Data::Union(_) => bail!("NoUninit can only be derived on enums and structs") } } fn asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result { if !input.generics.params.is_empty() { bail!("NoUninit cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs"); } match &input.data { Data::Struct(DataStruct { .. }) => { let assert_no_padding = generate_assert_no_padding(&input)?; let assert_fields_are_no_padding = generate_fields_are_trait(&input, Self::ident(input, crate_name)?)?; Ok(quote!( #assert_no_padding #assert_fields_are_no_padding )) } Data::Enum(DataEnum { variants, .. }) => { if variants.iter().any(|variant| !variant.fields.is_empty()) { bail!("Only fieldless enums are supported for NoUninit") } else { Ok(quote!()) } } Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */ } } fn trait_impl( _input: &DeriveInput, _crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)> { Ok((quote!(), quote!())) } } pub struct CheckedBitPattern; impl Derivable for CheckedBitPattern { fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result { Ok(syn::parse_quote!(#crate_name::CheckedBitPattern)) } fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> { let repr = get_repr(attributes)?; match ty { Data::Struct(_) => match repr.repr { Repr::C | Repr::Transparent => Ok(()), _ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"), }, Data::Enum(DataEnum { variants,.. }) => { if !enum_has_fields(variants.iter()){ if repr.repr.is_integer() { Ok(()) } else { bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]") } } else if matches!(repr.repr, Repr::Rust) { bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout") } else { Ok(()) } } Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs") } } fn asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result { if !input.generics.params.is_empty() { bail!("CheckedBitPattern cannot be derived for structs containing generic parameters"); } match &input.data { Data::Struct(DataStruct { .. }) => { let assert_fields_are_maybe_pod = generate_fields_are_trait(&input, Self::ident(input, crate_name)?)?; Ok(assert_fields_are_maybe_pod) } Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed * OK by NoUninit */ Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */ } } fn trait_impl( input: &DeriveInput, crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)> { match &input.data { Data::Struct(DataStruct { fields, .. }) => { generate_checked_bit_pattern_struct( &input.ident, fields, &input.attrs, crate_name, ) } Data::Enum(DataEnum { variants, .. }) => { generate_checked_bit_pattern_enum(input, variants, crate_name) } Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */ } } } pub struct TransparentWrapper; impl TransparentWrapper { fn get_wrapper_type( attributes: &[Attribute], fields: &Fields, ) -> Option { let transparent_param = get_simple_attr(attributes, "transparent"); transparent_param.map(|ident| ident.to_token_stream()).or_else(|| { let mut types = get_field_types(&fields); let first_type = types.next(); if let Some(_) = types.next() { // can't guess param type if there is more than one field return None; } else { first_type.map(|ty| ty.to_token_stream()) } }) } } impl Derivable for TransparentWrapper { fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result { let fields = get_struct_fields(input)?; let ty = match Self::get_wrapper_type(&input.attrs, &fields) { Some(ty) => ty, None => bail!( "\ when deriving TransparentWrapper for a struct with more than one field \ you need to specify the transparent field using #[transparent(T)]\ " ), }; Ok(syn::parse_quote!(#crate_name::TransparentWrapper<#ty>)) } fn asserts( input: &DeriveInput, crate_name: &TokenStream, ) -> Result { let (impl_generics, _ty_generics, where_clause) = input.generics.split_for_impl(); let fields = get_struct_fields(input)?; let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) { Some(wrapped_type) => wrapped_type.to_string(), None => unreachable!(), /* other code will already reject this derive */ }; let mut wrapped_field_ty = None; let mut nonwrapped_field_tys = vec![]; for field in fields.iter() { let field_ty = &field.ty; if field_ty.to_token_stream().to_string() == wrapped_type { if wrapped_field_ty.is_some() { bail!( "TransparentWrapper can only have one field of the wrapped type" ); } wrapped_field_ty = Some(field_ty); } else { nonwrapped_field_tys.push(field_ty); } } if let Some(wrapped_field_ty) = wrapped_field_ty { Ok(quote!( const _: () = { #[repr(transparent)] #[allow(clippy::multiple_bound_locations)] struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause; fn assert_zeroable() {} #[allow(clippy::multiple_bound_locations)] fn check #impl_generics () #where_clause { #( assert_zeroable::<#nonwrapped_field_tys>(); )* } }; )) } else { bail!("TransparentWrapper must have one field of the wrapped type") } } fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> { let repr = get_repr(attributes)?; match repr.repr { Repr::Transparent => Ok(()), _ => { bail!( "TransparentWrapper requires the struct to be #[repr(transparent)]" ) } } } fn requires_where_clause() -> bool { false } } pub struct Contiguous; impl Derivable for Contiguous { fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result { Ok(syn::parse_quote!(#crate_name::Contiguous)) } fn trait_impl( input: &DeriveInput, _crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)> { let repr = get_repr(&input.attrs)?; let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() { integer_ty } else { bail!("Contiguous requires the enum to be #[repr(Int)]"); }; let variants = get_enum_variants(input)?; if enum_has_fields(variants.clone()) { return Err(Error::new_spanned( &input, "Only fieldless enums are supported", )); } let mut variants_with_discriminator = VariantDiscriminantIterator::new(variants); let (min, max, count) = variants_with_discriminator.try_fold( (i64::max_value(), i64::min_value(), 0), |(min, max, count), res| { let discriminator = res?; Ok::<_, Error>(( i64::min(min, discriminator), i64::max(max, discriminator), count + 1, )) }, )?; if max - min != count - 1 { bail! { "Contiguous requires the enum discriminants to be contiguous", } } let min_lit = LitInt::new(&format!("{}", min), input.span()); let max_lit = LitInt::new(&format!("{}", max), input.span()); // `from_integer` and `into_integer` are usually provided by the trait's // default implementation. We override this implementation because it // goes through `transmute_copy`, which can lead to inefficient assembly as seen in https://github.com/Lokathor/bytemuck/issues/175 . Ok(( quote!(), quote! { type Int = #integer_ty; #[allow(clippy::missing_docs_in_private_items)] const MIN_VALUE: #integer_ty = #min_lit; #[allow(clippy::missing_docs_in_private_items)] const MAX_VALUE: #integer_ty = #max_lit; #[inline] fn from_integer(value: Self::Int) -> Option { #[allow(clippy::manual_range_contains)] if Self::MIN_VALUE <= value && value <= Self::MAX_VALUE { Some(unsafe { ::core::mem::transmute(value) }) } else { None } } #[inline] fn into_integer(self) -> Self::Int { self as #integer_ty } }, )) } } fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> { if let Data::Struct(DataStruct { fields, .. }) = &input.data { Ok(fields) } else { bail!("deriving this trait is only supported for structs") } } fn get_fields(input: &DeriveInput) -> Result { match &input.data { Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()), Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())), Data::Enum(_) => bail!("deriving this trait is not supported for enums"), } } fn get_enum_variants<'a>( input: &'a DeriveInput, ) -> Result + Clone + 'a> { if let Data::Enum(DataEnum { variants, .. }) = &input.data { Ok(variants.iter()) } else { bail!("deriving this trait is only supported for enums") } } fn get_field_types<'a>( fields: &'a Fields, ) -> impl Iterator + 'a { fields.iter().map(|field| &field.ty) } fn generate_checked_bit_pattern_struct( input_ident: &Ident, fields: &Fields, attrs: &[Attribute], crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)> { let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span()); let repr = get_repr(attrs)?; let field_names = fields .iter() .enumerate() .map(|(i, field)| { field.ident.clone().unwrap_or_else(|| { Ident::new(&format!("field{}", i), input_ident.span()) }) }) .collect::>(); let field_tys = fields.iter().map(|field| &field.ty).collect::>(); let field_name = &field_names[..]; let field_ty = &field_tys[..]; let derive_dbg = quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]); Ok(( quote! { #[doc = #GENERATED_TYPE_DOCUMENTATION] #repr #[derive(Clone, Copy, #crate_name::AnyBitPattern)] #derive_dbg #[allow(missing_docs)] pub struct #bits_ty { #(#field_name: <#field_ty as #crate_name::CheckedBitPattern>::Bits,)* } }, quote! { type Bits = #bits_ty; #[inline] #[allow(clippy::double_comparisons, unused)] fn is_valid_bit_pattern(bits: &#bits_ty) -> bool { #(<#field_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(&{ bits.#field_name }) && )* true } }, )) } fn generate_checked_bit_pattern_enum( input: &DeriveInput, variants: &Punctuated, crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)> { if enum_has_fields(variants.iter()) { generate_checked_bit_pattern_enum_with_fields(input, variants, crate_name) } else { generate_checked_bit_pattern_enum_without_fields(input, variants) } } fn generate_checked_bit_pattern_enum_without_fields( input: &DeriveInput, variants: &Punctuated, ) -> Result<(TokenStream, TokenStream)> { let span = input.span(); let mut variants_with_discriminant = VariantDiscriminantIterator::new(variants.iter()); let (min, max, count) = variants_with_discriminant.try_fold( (i64::max_value(), i64::min_value(), 0), |(min, max, count), res| { let discriminant = res?; Ok::<_, Error>(( i64::min(min, discriminant), i64::max(max, discriminant), count + 1, )) }, )?; let check = if count == 0 { quote_spanned!(span => false) } else if max - min == count - 1 { // contiguous range let min_lit = LitInt::new(&format!("{}", min), span); let max_lit = LitInt::new(&format!("{}", max), span); quote!(*bits >= #min_lit && *bits <= #max_lit) } else { // not contiguous range, check for each let variant_lits = VariantDiscriminantIterator::new(variants.iter()) .map(|res| { let variant = res?; Ok(LitInt::new(&format!("{}", variant), span)) }) .collect::>>()?; // count is at least 1 let first = &variant_lits[0]; let rest = &variant_lits[1..]; quote!(matches!(*bits, #first #(| #rest )*)) }; let repr = get_repr(&input.attrs)?; let integer = repr.repr.as_integer().unwrap(); // should be checked in attr check already Ok(( quote!(), quote! { type Bits = #integer; #[inline] #[allow(clippy::double_comparisons)] fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { #check } }, )) } fn generate_checked_bit_pattern_enum_with_fields( input: &DeriveInput, variants: &Punctuated, crate_name: &TokenStream, ) -> Result<(TokenStream, TokenStream)> { let representation = get_repr(&input.attrs)?; let vis = &input.vis; let derive_dbg = quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]); match representation.repr { Repr::Rust => unreachable!(), repr @ (Repr::C | Repr::CWithDiscriminant(_)) => { let integer = match repr { Repr::C => quote!(::core::ffi::c_int), Repr::CWithDiscriminant(integer) => quote!(#integer), _ => unreachable!(), }; let input_ident = &input.ident; let bits_repr = Representation { repr: Repr::C, ..representation }; // the enum manually re-configured as the actual tagged union it // represents, thus circumventing the requirements rust imposes on // the tag even when using #[repr(C)] enum layout // see: https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields let bits_ty_ident = Ident::new(&format!("{input_ident}Bits"), input.span()); // the variants union part of the tagged union. These get put into a union // which gets the AnyBitPattern derive applied to it, thus checking // that the fields of the union obey the requriements of AnyBitPattern. // The types that actually go in the union are one more level of // indirection deep: we generate new structs for each variant // (`variant_struct_definitions`) which themselves have the // `CheckedBitPattern` derive applied, thus generating // `{variant_struct_ident}Bits` structs, which are the ones that go // into this union. let variants_union_ident = Ident::new(&format!("{}Variants", input.ident), input.span()); let variant_struct_idents = variants.iter().map(|v| { Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span()) }); let variant_struct_definitions = variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| { let fields = v.fields.iter().map(|v| &v.ty); quote! { #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)] #[repr(C)] #vis struct #variant_struct_ident(#(#fields),*); } }); let union_fields = variant_struct_idents .clone() .zip(variants.iter()) .map(|(variant_struct_ident, v)| { let variant_struct_bits_ident = Ident::new(&format!("{variant_struct_ident}Bits"), input.span()); let field_ident = &v.ident; quote! { #field_ident: #variant_struct_bits_ident } }); let variant_checks = variant_struct_idents .clone() .zip(VariantDiscriminantIterator::new(variants.iter())) .zip(variants.iter()) .map(|((variant_struct_ident, discriminant), v)| -> Result<_> { let discriminant = discriminant?; let discriminant = LitInt::new(&discriminant.to_string(), v.span()); let ident = &v.ident; Ok(quote! { #discriminant => { let payload = unsafe { &bits.payload.#ident }; <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload) } }) }) .collect::>>()?; Ok(( quote! { #[doc = #GENERATED_TYPE_DOCUMENTATION] #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)] #derive_dbg #bits_repr #vis struct #bits_ty_ident { tag: #integer, payload: #variants_union_ident, } #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)] #[repr(C)] #[allow(non_snake_case)] #vis union #variants_union_ident { #(#union_fields,)* } #[cfg(not(target_arch = "spirv"))] impl ::core::fmt::Debug for #variants_union_ident { fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident)); ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct) } } #(#variant_struct_definitions)* }, quote! { type Bits = #bits_ty_ident; #[inline] #[allow(clippy::double_comparisons)] fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { match bits.tag { #(#variant_checks)* _ => false, } } }, )) } Repr::Transparent => { if variants.len() != 1 { bail!("enums with more than one variant cannot be transparent") } let variant = &variants[0]; let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span()); let fields = variant.fields.iter().map(|v| &v.ty); Ok(( quote! { #[doc = #GENERATED_TYPE_DOCUMENTATION] #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)] #[repr(C)] #vis struct #bits_ty(#(#fields),*); }, quote! { type Bits = <#bits_ty as #crate_name::CheckedBitPattern>::Bits; #[inline] #[allow(clippy::double_comparisons)] fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { <#bits_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(bits) } }, )) } Repr::Integer(integer) => { let bits_repr = Representation { repr: Repr::C, ..representation }; let input_ident = &input.ident; // the enum manually re-configured as the union it represents. such a // union is the union of variants as a repr(c) struct with the // discriminator type inserted at the beginning. in our case we // union the `Bits` representation of each variant rather than the variant // itself, which we generate via a nested `CheckedBitPattern` derive // on the `variant_struct_definitions` generated below. // // see: https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields let bits_ty_ident = Ident::new(&format!("{input_ident}Bits"), input.span()); let variant_struct_idents = variants.iter().map(|v| { Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span()) }); let variant_struct_definitions = variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| { let fields = v.fields.iter().map(|v| &v.ty); // adding the discriminant repr integer as first field, as described above quote! { #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)] #[repr(C)] #vis struct #variant_struct_ident(#integer, #(#fields),*); } }); let union_fields = variant_struct_idents .clone() .zip(variants.iter()) .map(|(variant_struct_ident, v)| { let variant_struct_bits_ident = Ident::new(&format!("{variant_struct_ident}Bits"), input.span()); let field_ident = &v.ident; quote! { #field_ident: #variant_struct_bits_ident } }); let variant_checks = variant_struct_idents .clone() .zip(VariantDiscriminantIterator::new(variants.iter())) .zip(variants.iter()) .map(|((variant_struct_ident, discriminant), v)| -> Result<_> { let discriminant = discriminant?; let discriminant = LitInt::new(&discriminant.to_string(), v.span()); let ident = &v.ident; Ok(quote! { #discriminant => { let payload = unsafe { &bits.#ident }; <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload) } }) }) .collect::>>()?; Ok(( quote! { #[doc = #GENERATED_TYPE_DOCUMENTATION] #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)] #bits_repr #[allow(non_snake_case)] #vis union #bits_ty_ident { __tag: #integer, #(#union_fields,)* } #[cfg(not(target_arch = "spirv"))] impl ::core::fmt::Debug for #bits_ty_ident { fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident)); ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag }); ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct) } } #(#variant_struct_definitions)* }, quote! { type Bits = #bits_ty_ident; #[inline] #[allow(clippy::double_comparisons)] fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { match unsafe { bits.__tag } { #(#variant_checks)* _ => false, } } }, )) } } } /// Check that a struct has no padding by asserting that the size of the struct /// is equal to the sum of the size of it's fields fn generate_assert_no_padding(input: &DeriveInput) -> Result { let struct_type = &input.ident; let span = input.ident.span(); let fields = get_fields(input)?; let mut field_types = get_field_types(&fields); let size_sum = if let Some(first) = field_types.next() { let size_first = quote_spanned!(span => ::core::mem::size_of::<#first>()); let size_rest = quote_spanned!(span => #( + ::core::mem::size_of::<#field_types>() )*); quote_spanned!(span => #size_first #size_rest) } else { quote_spanned!(span => 0) }; Ok(quote_spanned! {span => const _: fn() = || { #[doc(hidden)] struct TypeWithoutPadding([u8; #size_sum]); let _ = ::core::mem::transmute::<#struct_type, TypeWithoutPadding>; };}) } /// Check that all fields implement a given trait fn generate_fields_are_trait( input: &DeriveInput, trait_: syn::Path, ) -> Result { let (impl_generics, _ty_generics, where_clause) = input.generics.split_for_impl(); let fields = get_fields(input)?; let span = input.span(); let field_types = get_field_types(&fields); Ok(quote_spanned! {span => #(const _: fn() = || { #[allow(clippy::missing_const_for_fn)] #[doc(hidden)] fn check #impl_generics () #where_clause { fn assert_impl() {} assert_impl::<#field_types>(); } };)* }) } fn get_ident_from_stream(tokens: TokenStream) -> Option { match tokens.into_iter().next() { Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()), Some(TokenTree::Ident(ident)) => Some(ident), _ => None, } } /// get a simple #[foo(bar)] attribute, returning "bar" fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option { for attr in attributes { if let (AttrStyle::Outer, Meta::List(list)) = (&attr.style, &attr.meta) { if list.path.is_ident(attr_name) { if let Some(ident) = get_ident_from_stream(list.tokens.clone()) { return Some(ident); } } } } None } fn get_repr(attributes: &[Attribute]) -> Result { attributes .iter() .filter_map(|attr| { if attr.path().is_ident("repr") { Some(attr.parse_args::()) } else { None } }) .try_fold(Representation::default(), |a, b| { let b = b?; Ok(Representation { repr: match (a.repr, b.repr) { (a, Repr::Rust) => a, (Repr::Rust, b) => b, _ => bail!("conflicting representation hints"), }, packed: match (a.packed, b.packed) { (a, None) => a, (None, b) => b, _ => bail!("conflicting representation hints"), }, align: match (a.align, b.align) { (Some(a), Some(b)) => Some(cmp::max(a, b)), (a, None) => a, (None, b) => b, }, }) }) } mk_repr! { U8 => u8, I8 => i8, U16 => u16, I16 => i16, U32 => u32, I32 => i32, U64 => u64, I64 => i64, I128 => i128, U128 => u128, Usize => usize, Isize => isize, } // where macro_rules! mk_repr {( $( $Xn:ident => $xn:ident ),* $(,)? ) => ( #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum IntegerRepr { $($Xn),* } impl<'a> TryFrom<&'a str> for IntegerRepr { type Error = &'a str; fn try_from(value: &'a str) -> std::result::Result { match value { $( stringify!($xn) => Ok(Self::$Xn), )* _ => Err(value), } } } impl ToTokens for IntegerRepr { fn to_tokens(&self, tokens: &mut TokenStream) { match self { $( Self::$Xn => tokens.extend(quote!($xn)), )* } } } )} use mk_repr; #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum Repr { Rust, C, Transparent, Integer(IntegerRepr), CWithDiscriminant(IntegerRepr), } impl Repr { fn is_integer(&self) -> bool { matches!(self, Self::Integer(..)) } fn as_integer(&self) -> Option { if let Self::Integer(v) = self { Some(*v) } else { None } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] struct Representation { packed: Option, align: Option, repr: Repr, } impl Default for Representation { fn default() -> Self { Self { packed: None, align: None, repr: Repr::Rust } } } impl Parse for Representation { fn parse(input: ParseStream<'_>) -> Result { let mut ret = Representation::default(); while !input.is_empty() { let keyword = input.parse::()?; // preƫmptively call `.to_string()` *once* (rather than on `is_ident()`) let keyword_str = keyword.to_string(); let new_repr = match keyword_str.as_str() { "C" => Repr::C, "transparent" => Repr::Transparent, "packed" => { ret.packed = Some(if input.peek(token::Paren) { let contents; parenthesized!(contents in input); LitInt::base10_parse::(&contents.parse()?)? } else { 1 }); let _: Option = input.parse()?; continue; } "align" => { let contents; parenthesized!(contents in input); let new_align = LitInt::base10_parse::(&contents.parse()?)?; ret.align = Some( ret .align .map_or(new_align, |old_align| cmp::max(old_align, new_align)), ); let _: Option = input.parse()?; continue; } ident => { let primitive = IntegerRepr::try_from(ident) .map_err(|_| input.error("unrecognized representation hint"))?; Repr::Integer(primitive) } }; ret.repr = match (ret.repr, new_repr) { (Repr::Rust, new_repr) => { // This is the first explicit repr. new_repr } (Repr::C, Repr::Integer(integer)) | (Repr::Integer(integer), Repr::C) => { // Both the C repr and an integer repr have been specified // -> merge into a C wit discriminant. Repr::CWithDiscriminant(integer) } (_, _) => { return Err(input.error("duplicate representation hint")); } }; let _: Option = input.parse()?; } Ok(ret) } } impl ToTokens for Representation { fn to_tokens(&self, tokens: &mut TokenStream) { let mut meta = Punctuated::<_, Token![,]>::new(); match self.repr { Repr::Rust => {} Repr::C => meta.push(quote!(C)), Repr::Transparent => meta.push(quote!(transparent)), Repr::Integer(primitive) => meta.push(quote!(#primitive)), Repr::CWithDiscriminant(primitive) => { meta.push(quote!(C)); meta.push(quote!(#primitive)); } } if let Some(packed) = self.packed.as_ref() { let lit = LitInt::new(&packed.to_string(), Span::call_site()); meta.push(quote!(packed(#lit))); } if let Some(align) = self.align.as_ref() { let lit = LitInt::new(&align.to_string(), Span::call_site()); meta.push(quote!(align(#lit))); } tokens.extend(quote!( #[repr(#meta)] )); } } fn enum_has_fields<'a>( mut variants: impl Iterator, ) -> bool { variants.any(|v| matches!(v.fields, Fields::Named(_) | Fields::Unnamed(_))) } struct VariantDiscriminantIterator<'a, I: Iterator + 'a> { inner: I, last_value: i64, } impl<'a, I: Iterator + 'a> VariantDiscriminantIterator<'a, I> { fn new(inner: I) -> Self { VariantDiscriminantIterator { inner, last_value: -1 } } } impl<'a, I: Iterator + 'a> Iterator for VariantDiscriminantIterator<'a, I> { type Item = Result; fn next(&mut self) -> Option { let variant = self.inner.next()?; if let Some((_, discriminant)) = &variant.discriminant { let discriminant_value = match parse_int_expr(discriminant) { Ok(value) => value, Err(e) => return Some(Err(e)), }; self.last_value = discriminant_value; } else { self.last_value += 1; } Some(Ok(self.last_value)) } } fn parse_int_expr(expr: &Expr) -> Result { match expr { Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => { parse_int_expr(expr).map(|int| -int) } Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse(), Expr::Lit(ExprLit { lit: Lit::Byte(byte), .. }) => Ok(byte.value().into()), _ => bail!("Not an integer expression"), } } #[cfg(test)] mod tests { use syn::parse_quote; use super::{get_repr, IntegerRepr, Repr, Representation}; #[test] fn parse_basic_repr() { let attr = parse_quote!(#[repr(C)]); let repr = get_repr(&[attr]).unwrap(); assert_eq!(repr, Representation { repr: Repr::C, ..Default::default() }); let attr = parse_quote!(#[repr(transparent)]); let repr = get_repr(&[attr]).unwrap(); assert_eq!( repr, Representation { repr: Repr::Transparent, ..Default::default() } ); let attr = parse_quote!(#[repr(u8)]); let repr = get_repr(&[attr]).unwrap(); assert_eq!( repr, Representation { repr: Repr::Integer(IntegerRepr::U8), ..Default::default() } ); let attr = parse_quote!(#[repr(packed)]); let repr = get_repr(&[attr]).unwrap(); assert_eq!(repr, Representation { packed: Some(1), ..Default::default() }); let attr = parse_quote!(#[repr(packed(1))]); let repr = get_repr(&[attr]).unwrap(); assert_eq!(repr, Representation { packed: Some(1), ..Default::default() }); let attr = parse_quote!(#[repr(packed(2))]); let repr = get_repr(&[attr]).unwrap(); assert_eq!(repr, Representation { packed: Some(2), ..Default::default() }); let attr = parse_quote!(#[repr(align(2))]); let repr = get_repr(&[attr]).unwrap(); assert_eq!(repr, Representation { align: Some(2), ..Default::default() }); } #[test] fn parse_advanced_repr() { let attr = parse_quote!(#[repr(align(4), align(2))]); let repr = get_repr(&[attr]).unwrap(); assert_eq!(repr, Representation { align: Some(4), ..Default::default() }); let attr1 = parse_quote!(#[repr(align(1))]); let attr2 = parse_quote!(#[repr(align(4))]); let attr3 = parse_quote!(#[repr(align(2))]); let repr = get_repr(&[attr1, attr2, attr3]).unwrap(); assert_eq!(repr, Representation { align: Some(4), ..Default::default() }); let attr = parse_quote!(#[repr(C, u8)]); let repr = get_repr(&[attr]).unwrap(); assert_eq!( repr, Representation { repr: Repr::CWithDiscriminant(IntegerRepr::U8), ..Default::default() } ); let attr = parse_quote!(#[repr(u8, C)]); let repr = get_repr(&[attr]).unwrap(); assert_eq!( repr, Representation { repr: Repr::CWithDiscriminant(IntegerRepr::U8), ..Default::default() } ); } } pub fn bytemuck_crate_name(input: &DeriveInput) -> TokenStream { const ATTR_NAME: &'static str = "crate"; let mut crate_name = quote!(::bytemuck); for attr in &input.attrs { if !attr.path().is_ident("bytemuck") { continue; } attr.parse_nested_meta(|meta| { if meta.path.is_ident(ATTR_NAME) { let expr: syn::Expr = meta.value()?.parse()?; let mut value = &expr; while let syn::Expr::Group(e) = value { value = &e.expr; } if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit), .. }) = value { let suffix = lit.suffix(); if !suffix.is_empty() { bail!(format!("Unexpected suffix `{}` on string literal", suffix)) } let path: syn::Path = match lit.parse() { Ok(path) => path, Err(_) => { bail!(format!("Failed to parse path: {:?}", lit.value())) } }; crate_name = path.into_token_stream(); } else { bail!( "Expected bytemuck `crate` attribute to be a string: `crate = \"...\"`", ) } } Ok(()) }).unwrap(); } return crate_name; } const GENERATED_TYPE_DOCUMENTATION: &str = " `bytemuck`-generated type for internal purposes only.";