xref: /aosp_15_r20/external/mesa3d/src/compiler/rust/proc/as_slice.rs (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 // Copyright © 2023 Collabora, Ltd.
2 // SPDX-License-Identifier: MIT
3 
4 use proc_macro::TokenStream;
5 use proc_macro2::{Span, TokenStream as TokenStream2};
6 use syn::*;
7 
expr_as_usize(expr: &syn::Expr) -> usize8 fn expr_as_usize(expr: &syn::Expr) -> usize {
9     let lit = match expr {
10         syn::Expr::Lit(lit) => lit,
11         _ => panic!("Expected a literal, found an expression"),
12     };
13     let lit_int = match &lit.lit {
14         syn::Lit::Int(i) => i,
15         _ => panic!("Expected a literal integer"),
16     };
17     assert!(lit.attrs.is_empty());
18     lit_int
19         .base10_parse()
20         .expect("Failed to parse integer literal")
21 }
22 
count_type(ty: &Type, slice_type: &str) -> usize23 fn count_type(ty: &Type, slice_type: &str) -> usize {
24     match ty {
25         syn::Type::Array(a) => {
26             let elems = count_type(a.elem.as_ref(), slice_type);
27             if elems > 0 {
28                 elems * expr_as_usize(&a.len)
29             } else {
30                 0
31             }
32         }
33         syn::Type::Path(p) => {
34             if p.qself.is_none() && p.path.is_ident(slice_type) {
35                 1
36             } else {
37                 0
38             }
39         }
40         _ => 0,
41     }
42 }
43 
get_attr(field: &Field, attr_name: &str) -> Option<String>44 fn get_attr(field: &Field, attr_name: &str) -> Option<String> {
45     for attr in &field.attrs {
46         if let Meta::List(ml) = &attr.meta {
47             if ml.path.is_ident(attr_name) {
48                 return Some(format!("{}", ml.tokens));
49             }
50         }
51     }
52     None
53 }
54 
derive_as_slice( input: TokenStream, slice_type: &str, attr_name: &str, attr_type: &str, ) -> TokenStream55 pub fn derive_as_slice(
56     input: TokenStream,
57     slice_type: &str,
58     attr_name: &str,
59     attr_type: &str,
60 ) -> TokenStream {
61     let DeriveInput {
62         attrs, ident, data, ..
63     } = parse_macro_input!(input);
64 
65     match data {
66         Data::Struct(s) => {
67             let mut has_repr_c = false;
68             for attr in attrs {
69                 match attr.meta {
70                     Meta::List(ml) => {
71                         if ml.path.is_ident("repr")
72                             && format!("{}", ml.tokens) == "C"
73                         {
74                             has_repr_c = true;
75                         }
76                     }
77                     _ => (),
78                 }
79             }
80             assert!(has_repr_c, "Struct must be declared #[repr(C)]");
81 
82             let mut first = None;
83             let mut count = 0_usize;
84             let mut found_last = false;
85             let mut attrs = TokenStream2::new();
86 
87             if let Fields::Named(named) = s.fields {
88                 for f in named.named {
89                     let f_count = count_type(&f.ty, slice_type);
90                     let f_attr = get_attr(&f, &attr_name);
91 
92                     if f_count > 0 {
93                         assert!(
94                             !found_last,
95                             "All fields of type {slice_type} must be consecutive",
96                         );
97 
98                         let attr_type =
99                             Ident::new(attr_type, Span::call_site());
100                         let f_attr = if let Some(s) = f_attr {
101                             let s = syn::parse_str::<Ident>(&s).unwrap();
102                             quote! { #attr_type::#s, }
103                         } else {
104                             quote! { #attr_type::DEFAULT, }
105                         };
106 
107                         first.get_or_insert(f.ident);
108                         for _ in 0..f_count {
109                             attrs.extend(f_attr.clone());
110                         }
111                         count += f_count;
112                     } else {
113                         assert!(
114                             f_attr.is_none(),
115                             "{attr_name} attribute is only allowed on {slice_type}"
116                         );
117                         if !first.is_none() {
118                             found_last = true;
119                         }
120                     }
121                 }
122             } else {
123                 panic!("Fields are not named");
124             }
125 
126             let slice_type = Ident::new(slice_type, Span::call_site());
127             let attr_type = Ident::new(attr_type, Span::call_site());
128             if let Some(first) = first {
129                 quote! {
130                     impl compiler::as_slice::AsSlice<#slice_type> for #ident {
131                         type Attr = #attr_type;
132 
133                         fn as_slice(&self) -> &[#slice_type] {
134                             unsafe {
135                                 let first = &self.#first as *const #slice_type;
136                                 std::slice::from_raw_parts(first, #count)
137                             }
138                         }
139 
140                         fn as_mut_slice(&mut self) -> &mut [#slice_type] {
141                             unsafe {
142                                 let first =
143                                     &mut self.#first as *mut #slice_type;
144                                 std::slice::from_raw_parts_mut(first, #count)
145                             }
146                         }
147 
148                         fn attrs(&self) -> AttrList<Self::Attr> {
149                             static ATTRS: [#attr_type; #count] = [#attrs];
150                             AttrList::Array(&ATTRS)
151                         }
152                     }
153                 }
154             } else {
155                 quote! {
156                     impl compiler::as_slice::AsSlice<#slice_type> for #ident {
157                         type Attr = #attr_type;
158 
159                         fn as_slice(&self) -> &[#slice_type] {
160                             &[]
161                         }
162 
163                         fn as_mut_slice(&mut self) -> &mut [#slice_type] {
164                             &mut []
165                         }
166 
167                         fn attrs(&self) -> AttrList<Self::Attr> {
168                             AttrList::Uniform(#attr_type::DEFAULT)
169                         }
170                     }
171                 }
172             }
173             .into()
174         }
175         Data::Enum(e) => {
176             let mut as_slice_cases = TokenStream2::new();
177             let mut as_mut_slice_cases = TokenStream2::new();
178             let mut types_cases = TokenStream2::new();
179             let slice_type = Ident::new(slice_type, Span::call_site());
180             let attr_type = Ident::new(attr_type, Span::call_site());
181             for v in e.variants {
182                 let case = v.ident;
183                 as_slice_cases.extend(quote! {
184                     #ident::#case(x) => compiler::as_slice::AsSlice::<#slice_type>::as_slice(x),
185                 });
186                 as_mut_slice_cases.extend(quote! {
187                     #ident::#case(x) => compiler::as_slice::AsSlice::<#slice_type>::as_mut_slice(x),
188                 });
189                 types_cases.extend(quote! {
190                     #ident::#case(x) => compiler::as_slice::AsSlice::<#slice_type>::attrs(x),
191                 });
192             }
193             quote! {
194                 impl compiler::as_slice::AsSlice<#slice_type> for #ident {
195                     type Attr = #attr_type;
196 
197                     fn as_slice(&self) -> &[#slice_type] {
198                         match self {
199                             #as_slice_cases
200                         }
201                     }
202 
203                     fn as_mut_slice(&mut self) -> &mut [#slice_type] {
204                         match self {
205                             #as_mut_slice_cases
206                         }
207                     }
208 
209                     fn attrs(&self) -> AttrList<Self::Attr> {
210                         match self {
211                             #types_cases
212                         }
213                     }
214                 }
215             }
216             .into()
217         }
218         _ => panic!("Not a struct type"),
219     }
220 }
221