xref: /aosp_15_r20/external/crosvm/hypervisor/hypervisor_test_macro/src/lib.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2024 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #![warn(missing_docs)]
6 #![recursion_limit = "128"]
7 
8 //! Macros for hypervisor tests
9 
10 use std::collections::hash_map::DefaultHasher;
11 use std::hash::Hash;
12 use std::hash::Hasher;
13 use std::sync::atomic::AtomicU64;
14 
15 use proc_macro::TokenStream;
16 use proc_macro2::Span;
17 use proc_macro2::TokenStream as TokenStream2;
18 use quote::quote;
19 use rand::Rng;
20 use syn::parse::Parse;
21 use syn::parse_macro_input;
22 use syn::Error;
23 use syn::Ident;
24 use syn::LitStr;
25 use syn::Token;
26 use syn::Visibility;
27 
28 /// Embed the compiled assembly as an array.
29 ///
30 /// This macro will generate a module with the given `$name` and provides a `data` function in the
31 /// module to allow accessing the compiled machine code as an array.
32 ///
33 /// Note that this macro uses [`std::arch::global_asm`], so we can only use this macro in a global
34 /// scope, outside a function.
35 ///
36 /// # Example
37 ///
38 /// Given the following x86 assembly:
39 /// ```Text
40 /// 0:  01 d8                   add    eax,ebx
41 /// 2:  f4                      hlt
42 /// ```
43 ///
44 /// ```rust
45 /// # use hypervisor_test_macro::global_asm_data;
46 /// global_asm_data!(
47 ///     my_code,
48 ///     ".code64",
49 ///     "add eax, ebx",
50 ///     "hlt",
51 /// );
52 /// # fn main() {
53 /// assert_eq!([0x01, 0xd8, 0xf4], my_code::data());
54 /// # }
55 /// ```
56 ///
57 /// It is supported to pass arbitrary supported [`std::arch::global_asm`] operands and options.
58 /// ```rust
59 /// # use hypervisor_test_macro::global_asm_data;
60 /// fn f() {}
61 /// global_asm_data!(
62 ///     my_code1,
63 ///     ".global {0}",
64 ///     ".code64",
65 ///     "add eax, ebx",
66 ///     "hlt",
67 ///     sym f,
68 /// );
69 /// global_asm_data!(
70 ///     my_code2,
71 ///     ".code64",
72 ///     "add eax, ebx",
73 ///     "hlt",
74 ///     options(raw),
75 /// );
76 /// # fn main() {
77 /// assert_eq!([0x01, 0xd8, 0xf4], my_code1::data());
78 /// assert_eq!([0x01, 0xd8, 0xf4], my_code2::data());
79 /// # }
80 /// ```
81 ///
82 /// It is also supported to specify the visibility of the generated module. Note that the below
83 /// example won't work if the `pub` in the macro is missing.
84 /// ```rust
85 /// # use hypervisor_test_macro::global_asm_data;
86 /// mod my_mod {
87 ///     // This use is needed to import the global_asm_data macro to this module.
88 ///     use super::*;
89 ///
90 ///     global_asm_data!(
91 ///         // pub is needed so that my_mod::my_code is visible to the outer scope.
92 ///         pub my_code,
93 ///         ".code64",
94 ///         "add eax, ebx",
95 ///         "hlt",
96 ///     );
97 /// }
98 /// # fn main() {
99 /// assert_eq!([0x01, 0xd8, 0xf4], my_mod::my_code::data());
100 /// # }
101 /// ```
102 #[proc_macro]
global_asm_data(item: TokenStream) -> TokenStream103 pub fn global_asm_data(item: TokenStream) -> TokenStream {
104     let args = parse_macro_input!(item as GlobalAsmDataArgs);
105     global_asm_data_impl(args).unwrap_or_else(|e| e.to_compile_error().into())
106 }
107 
108 struct GlobalAsmDataArgs {
109     visibility: Visibility,
110     mod_name: Ident,
111     global_asm_strings: Vec<LitStr>,
112     global_asm_rest_args: TokenStream2,
113 }
114 
115 impl Parse for GlobalAsmDataArgs {
parse(input: syn::parse::ParseStream) -> syn::Result<Self>116     fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
117         // The first argument is visibilty + identifier, e.g. my_code or pub my_code. The identifier
118         // will be used as the name of the gnerated module.
119         let visibility: Visibility = input.parse()?;
120         let mod_name: Ident = input.parse()?;
121         // There must be following arguments, so we consume the first argument separator here.
122         input.parse::<Token![,]>()?;
123 
124         // Retrieve the input assemblies, which are a list of comma separated string literals. We
125         // need to obtain the list of assemblies explicitly, so that we can insert the begin tag and
126         // the end tag to the global_asm! call when we generate the result code.
127         let mut global_asm_strings = vec![];
128         loop {
129             let lookahead = input.lookahead1();
130             if !lookahead.peek(LitStr) {
131                 // If the upcoming tokens are not string literal, we hit the end of the input
132                 // assemblies.
133                 break;
134             }
135             global_asm_strings.push(input.parse::<LitStr>()?);
136 
137             if input.is_empty() {
138                 // In case the current string literal is the last argument.
139                 break;
140             }
141             input.parse::<Token![,]>()?;
142             if input.is_empty() {
143                 // In case the current string literal is the last argument with a trailing comma.
144                 break;
145             }
146         }
147 
148         // We store the rest of the arguments, and we will forward them as is to global_asm!.
149         let global_asm_rest_args: TokenStream2 = input.parse()?;
150         Ok(Self {
151             visibility,
152             mod_name,
153             global_asm_strings,
154             global_asm_rest_args,
155         })
156     }
157 }
158 
159 static COUNTER: AtomicU64 = AtomicU64::new(0);
160 
161 fn global_asm_data_impl(
162     GlobalAsmDataArgs {
163         visibility,
164         mod_name,
165         global_asm_strings,
166         global_asm_rest_args,
167     }: GlobalAsmDataArgs,
168 ) -> Result<TokenStream, Error> {
169     let span = Span::call_site();
170 
171     // Generate the unique tags based on the macro input, code location and a random number to avoid
172     // symbol collision.
173     let tag_base_name = {
174         let content_id = {
175             let mut hasher = DefaultHasher::new();
176             span.source_text().hash(&mut hasher);
177             hasher.finish()
178         };
179         let location_id = format!(
180             "{}_{}_{}_{}",
181             span.start().line,
182             span.start().column,
183             span.end().line,
184             span.end().column
185         );
186         let rand_id = rand::thread_rng().gen::<u64>();
187         let static_counter_id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
188         let prefix = "crosvm_hypervisor_test_macro_global_asm_data";
189         format!(
190             "{}_{}_{}_{}_{}_{}",
191             prefix, mod_name, content_id, location_id, static_counter_id, rand_id
192         )
193     };
194     let start_tag = format!("{}_start", tag_base_name);
195     let end_tag = format!("{}_end", tag_base_name);
196 
197     let global_directive = LitStr::new(&format!(".global {}, {}", start_tag, end_tag), span);
198     let start_tag_asm = LitStr::new(&format!("{}:", start_tag), span);
199     let end_tag_asm = LitStr::new(&format!("{}:", end_tag), span);
200     let start_tag_ident = Ident::new(&start_tag, span);
201     let end_tag_ident = Ident::new(&end_tag, span);
202 
203     Ok(quote! {
204         #visibility mod #mod_name {
205             use super::*;
206 
207             extern {
208                 static #start_tag_ident: u8;
209                 static #end_tag_ident: u8;
210             }
211 
212             std::arch::global_asm!(
213                 #global_directive,
214                 #start_tag_asm,
215                 #(#global_asm_strings),*,
216                 #end_tag_asm,
217                 #global_asm_rest_args
218             );
219             pub fn data() -> &'static [u8] {
220                 // SAFETY:
221                 // * The extern statics are u8, and any arbitrary bit patterns are valid for u8.
222                 // * The data starting from start to end is valid u8.
223                 // * Without unsafe block, one can't mutate the value between start and end. In
224                 //   addition, it is likely that the data is written to a readonly block, and can't
225                 //   be mutated at all.
226                 // * The address shouldn't be too large, and won't wrap around.
227                 unsafe {
228                     let ptr = std::ptr::addr_of!(#start_tag_ident);
229                     let len = std::ptr::addr_of!(#end_tag_ident).offset_from(ptr);
230                     std::slice::from_raw_parts(
231                         ptr,
232                         len.try_into().expect("length must be positive")
233                     )
234                 }
235             }
236         }
237     }
238     .into())
239 }
240