xref: /aosp_15_r20/external/mesa3d/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 use crate::compiler::nir::*;
2 use crate::pipe::screen::*;
3 use crate::util::disk_cache::*;
4 
5 use libc_rust_gen::malloc;
6 use mesa_rust_gen::*;
7 use mesa_rust_util::serialize::*;
8 use mesa_rust_util::string::*;
9 
10 use std::ffi::CString;
11 use std::fmt::Debug;
12 use std::os::raw::c_char;
13 use std::os::raw::c_void;
14 use std::ptr;
15 use std::slice;
16 
/// NUL-terminated name under which the main source is handed to the CLC compiler
/// (used as the `source.name` of `clc_compile_args` in `SPIRVBin::from_clc`).
const INPUT_STR: *const c_char = b"input.cl\0".as_ptr().cast();
18 
/// Specialization-constant placeholder; `None` is its only variant.
pub enum SpecConstant {
    /// The single variant of this enum.
    None,
}
22 
/// An owned SPIR-V binary plus, when available, the metadata parsed out of it.
///
/// Both members are released in the `Drop` implementation (`clc_free_spirv` and
/// `clc_free_parsed_spirv`).
pub struct SPIRVBin {
    /// The SPIR-V module itself (raw data pointer and size).
    spirv: clc_binary,
    /// Result of `clc_parse_spirv`; `None` when the binary was not parsed or parsing failed.
    info: Option<clc_parsed_spirv>,
}

// Safety: SPIRVBin is not mutable and is therefore Send and Sync, needed due to `clc_binary::data`
// (a raw pointer, which prevents the compiler from deriving the two auto traits on its own).
unsafe impl Send for SPIRVBin {}
unsafe impl Sync for SPIRVBin {}
31 
/// Owned description of one kernel argument, built from a `clc_kernel_arg` in
/// `SPIRVBin::args` or read back from a blob in `SPIRVKernelArg::deserialize`.
#[derive(PartialEq, Eq, Hash, Clone)]
pub struct SPIRVKernelArg {
    /// Argument name.
    pub name: String,
    /// Name of the argument's type.
    pub type_name: String,
    /// Access qualifier bits of the argument.
    pub access_qualifier: clc_kernel_arg_access_qualifier,
    /// Address-space qualifier of the argument.
    pub address_qualifier: clc_kernel_arg_address_qualifier,
    /// Type qualifier bits of the argument.
    pub type_qualifier: clc_kernel_arg_type_qualifier,
}
40 
/// A named, in-memory header that is passed to the CLC compiler alongside the main source.
pub struct CLCHeader<'a> {
    /// Name of the header.
    pub name: CString,
    /// Borrowed contents of the header.
    pub source: &'a CString,
}

impl Debug for CLCHeader<'_> {
    /// Writes `[name]:` followed by the header source on the next line; bytes that are not
    /// valid UTF-8 are replaced lossily.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.name.to_string_lossy();
        let source = self.source.to_string_lossy();

        write!(f, "[{name}]:\n{source}")
    }
}
54 
callback_impl(data: *mut c_void, msg: *const c_char)55 unsafe fn callback_impl(data: *mut c_void, msg: *const c_char) {
56     let data = data as *mut Vec<String>;
57     let msgs = unsafe { data.as_mut() }.unwrap();
58     msgs.push(c_string_to_string(msg));
59 }
60 
spirv_msg_callback(data: *mut c_void, msg: *const c_char)61 unsafe extern "C" fn spirv_msg_callback(data: *mut c_void, msg: *const c_char) {
62     unsafe {
63         callback_impl(data, msg);
64     }
65 }
66 
spirv_to_nir_msg_callback( data: *mut c_void, dbg_level: nir_spirv_debug_level, _offset: usize, msg: *const c_char, )67 unsafe extern "C" fn spirv_to_nir_msg_callback(
68     data: *mut c_void,
69     dbg_level: nir_spirv_debug_level,
70     _offset: usize,
71     msg: *const c_char,
72 ) {
73     if dbg_level >= nir_spirv_debug_level::NIR_SPIRV_DEBUG_LEVEL_WARNING {
74         unsafe {
75             callback_impl(data, msg);
76         }
77     }
78 }
79 
create_clc_logger(msgs: &mut Vec<String>) -> clc_logger80 fn create_clc_logger(msgs: &mut Vec<String>) -> clc_logger {
81     clc_logger {
82         priv_: ptr::from_mut(msgs).cast(),
83         error: Some(spirv_msg_callback),
84         warning: Some(spirv_msg_callback),
85     }
86 }
87 
88 impl SPIRVBin {
from_clc( source: &CString, args: &[CString], headers: &[CLCHeader], cache: &Option<DiskCache>, features: clc_optional_features, spirv_extensions: &[CString], address_bits: u32, ) -> (Option<Self>, String)89     pub fn from_clc(
90         source: &CString,
91         args: &[CString],
92         headers: &[CLCHeader],
93         cache: &Option<DiskCache>,
94         features: clc_optional_features,
95         spirv_extensions: &[CString],
96         address_bits: u32,
97     ) -> (Option<Self>, String) {
98         let mut hash_key = None;
99         let has_includes = args.iter().any(|a| a.as_bytes()[0..2] == *b"-I");
100 
101         let mut spirv_extensions: Vec<_> = spirv_extensions.iter().map(|s| s.as_ptr()).collect();
102         spirv_extensions.push(ptr::null());
103 
104         if let Some(cache) = cache {
105             if !has_includes {
106                 let mut key = Vec::new();
107 
108                 key.extend_from_slice(source.as_bytes());
109                 args.iter()
110                     .for_each(|a| key.extend_from_slice(a.as_bytes()));
111                 headers.iter().for_each(|h| {
112                     key.extend_from_slice(h.name.as_bytes());
113                     key.extend_from_slice(h.source.as_bytes());
114                 });
115 
116                 // Safety: clc_optional_features is a struct of bools and contains no padding.
117                 // Sadly we can't guarentee this.
118                 key.extend(unsafe { as_byte_slice(slice::from_ref(&features)) });
119 
120                 let mut key = cache.gen_key(&key);
121                 if let Some(data) = cache.get(&mut key) {
122                     return (Some(Self::from_bin(&data)), String::from(""));
123                 }
124 
125                 hash_key = Some(key);
126             }
127         }
128 
129         let c_headers: Vec<_> = headers
130             .iter()
131             .map(|h| clc_named_value {
132                 name: h.name.as_ptr(),
133                 value: h.source.as_ptr(),
134             })
135             .collect();
136 
137         let c_args: Vec<_> = args.iter().map(|a| a.as_ptr()).collect();
138 
139         let args = clc_compile_args {
140             headers: c_headers.as_ptr(),
141             num_headers: c_headers.len() as u32,
142             source: clc_named_value {
143                 name: INPUT_STR,
144                 value: source.as_ptr(),
145             },
146             args: c_args.as_ptr(),
147             num_args: c_args.len() as u32,
148             spirv_version: clc_spirv_version::CLC_SPIRV_VERSION_MAX,
149             features: features,
150             use_llvm_spirv_target: false,
151             allowed_spirv_extensions: spirv_extensions.as_ptr(),
152             address_bits: address_bits,
153         };
154         let mut msgs: Vec<String> = Vec::new();
155         let logger = create_clc_logger(&mut msgs);
156         let mut out = clc_binary::default();
157 
158         let res = unsafe { clc_compile_c_to_spirv(&args, &logger, &mut out) };
159 
160         let res = if res {
161             let spirv = SPIRVBin {
162                 spirv: out,
163                 info: None,
164             };
165 
166             // add cache entry
167             if !has_includes {
168                 if let Some(mut key) = hash_key {
169                     cache.as_ref().unwrap().put(spirv.to_bin(), &mut key);
170                 }
171             }
172 
173             Some(spirv)
174         } else {
175             None
176         };
177 
178         (res, msgs.join(""))
179     }
180 
181     // TODO cache linking, parsing is around 25% of link time
link(spirvs: &[&SPIRVBin], library: bool) -> (Option<Self>, String)182     pub fn link(spirvs: &[&SPIRVBin], library: bool) -> (Option<Self>, String) {
183         let bins: Vec<_> = spirvs.iter().map(|s| ptr::from_ref(&s.spirv)).collect();
184 
185         let linker_args = clc_linker_args {
186             in_objs: bins.as_ptr(),
187             num_in_objs: bins.len() as u32,
188             create_library: library as u32,
189         };
190 
191         let mut msgs: Vec<String> = Vec::new();
192         let logger = create_clc_logger(&mut msgs);
193 
194         let mut out = clc_binary::default();
195         let res = unsafe { clc_link_spirv(&linker_args, &logger, &mut out) };
196 
197         let info = if !library && res {
198             let mut pspirv = clc_parsed_spirv::default();
199             let res = unsafe { clc_parse_spirv(&out, &logger, &mut pspirv) };
200             res.then_some(pspirv)
201         } else {
202             None
203         };
204 
205         let res = res.then_some(SPIRVBin {
206             spirv: out,
207             info: info,
208         });
209         (res, msgs.join(""))
210     }
211 
validate(&self, options: &clc_validator_options) -> (bool, String)212     pub fn validate(&self, options: &clc_validator_options) -> (bool, String) {
213         let mut msgs: Vec<String> = Vec::new();
214         let logger = create_clc_logger(&mut msgs);
215         let res = unsafe { clc_validate_spirv(&self.spirv, &logger, options) };
216 
217         (res, msgs.join(""))
218     }
219 
clone_on_validate(&self, options: &clc_validator_options) -> (Option<Self>, String)220     pub fn clone_on_validate(&self, options: &clc_validator_options) -> (Option<Self>, String) {
221         let (res, msgs) = self.validate(options);
222         (res.then(|| self.clone()), msgs)
223     }
224 
kernel_infos(&self) -> &[clc_kernel_info]225     fn kernel_infos(&self) -> &[clc_kernel_info] {
226         match self.info {
227             Some(info) if info.num_kernels > 0 => unsafe {
228                 slice::from_raw_parts(info.kernels, info.num_kernels as usize)
229             },
230             _ => &[],
231         }
232     }
233 
kernel_info(&self, name: &str) -> Option<&clc_kernel_info>234     pub fn kernel_info(&self, name: &str) -> Option<&clc_kernel_info> {
235         self.kernel_infos()
236             .iter()
237             .find(|i| c_string_to_string(i.name) == name)
238     }
239 
kernels(&self) -> Vec<String>240     pub fn kernels(&self) -> Vec<String> {
241         self.kernel_infos()
242             .iter()
243             .map(|i| i.name)
244             .map(c_string_to_string)
245             .collect()
246     }
247 
args(&self, name: &str) -> Vec<SPIRVKernelArg>248     pub fn args(&self, name: &str) -> Vec<SPIRVKernelArg> {
249         match self.kernel_info(name) {
250             Some(info) if info.num_args > 0 => {
251                 unsafe { slice::from_raw_parts(info.args, info.num_args) }
252                     .iter()
253                     .map(|a| SPIRVKernelArg {
254                         name: c_string_to_string(a.name),
255                         type_name: c_string_to_string(a.type_name),
256                         access_qualifier: clc_kernel_arg_access_qualifier(a.access_qualifier),
257                         address_qualifier: a.address_qualifier,
258                         type_qualifier: clc_kernel_arg_type_qualifier(a.type_qualifier),
259                     })
260                     .collect()
261             }
262             _ => Vec::new(),
263         }
264     }
265 
get_spirv_capabilities() -> spirv_capabilities266     fn get_spirv_capabilities() -> spirv_capabilities {
267         spirv_capabilities {
268             Addresses: true,
269             Float16: true,
270             Float16Buffer: true,
271             Float64: true,
272             GenericPointer: true,
273             Groups: true,
274             GroupNonUniformShuffle: true,
275             GroupNonUniformShuffleRelative: true,
276             Int8: true,
277             Int16: true,
278             Int64: true,
279             Kernel: true,
280             ImageBasic: true,
281             ImageReadWrite: true,
282             Linkage: true,
283             LiteralSampler: true,
284             SampledBuffer: true,
285             Sampled1D: true,
286             Vector16: true,
287             ..Default::default()
288         }
289     }
290 
get_spirv_options( library: bool, clc_shader: *const nir_shader, address_bits: u32, caps: &spirv_capabilities, log: Option<&mut Vec<String>>, ) -> spirv_to_nir_options291     fn get_spirv_options(
292         library: bool,
293         clc_shader: *const nir_shader,
294         address_bits: u32,
295         caps: &spirv_capabilities,
296         log: Option<&mut Vec<String>>,
297     ) -> spirv_to_nir_options {
298         let global_addr_format;
299         let offset_addr_format;
300 
301         if address_bits == 32 {
302             global_addr_format = nir_address_format::nir_address_format_32bit_global;
303             offset_addr_format = nir_address_format::nir_address_format_32bit_offset;
304         } else {
305             global_addr_format = nir_address_format::nir_address_format_64bit_global;
306             offset_addr_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
307         }
308 
309         let debug = log.map(|log| spirv_to_nir_options__bindgen_ty_1 {
310             func: Some(spirv_to_nir_msg_callback),
311             private_data: ptr::from_mut(log).cast(),
312         });
313 
314         spirv_to_nir_options {
315             create_library: library,
316             environment: nir_spirv_execution_environment::NIR_SPIRV_OPENCL,
317             clc_shader: clc_shader,
318             float_controls_execution_mode: float_controls::FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP32
319                 as u32,
320 
321             printf: true,
322             capabilities: caps,
323             constant_addr_format: global_addr_format,
324             global_addr_format: global_addr_format,
325             shared_addr_format: offset_addr_format,
326             temp_addr_format: offset_addr_format,
327             debug: debug.unwrap_or_default(),
328 
329             ..Default::default()
330         }
331     }
332 
to_nir( &self, entry_point: &str, nir_options: *const nir_shader_compiler_options, libclc: &NirShader, spec_constants: &mut [nir_spirv_specialization], address_bits: u32, log: Option<&mut Vec<String>>, ) -> Option<NirShader>333     pub fn to_nir(
334         &self,
335         entry_point: &str,
336         nir_options: *const nir_shader_compiler_options,
337         libclc: &NirShader,
338         spec_constants: &mut [nir_spirv_specialization],
339         address_bits: u32,
340         log: Option<&mut Vec<String>>,
341     ) -> Option<NirShader> {
342         let c_entry = CString::new(entry_point.as_bytes()).unwrap();
343         let spirv_caps = Self::get_spirv_capabilities();
344         let spirv_options =
345             Self::get_spirv_options(false, libclc.get_nir(), address_bits, &spirv_caps, log);
346 
347         let nir = unsafe {
348             spirv_to_nir(
349                 self.spirv.data.cast(),
350                 self.spirv.size / 4,
351                 spec_constants.as_mut_ptr(),
352                 spec_constants.len() as u32,
353                 gl_shader_stage::MESA_SHADER_KERNEL,
354                 c_entry.as_ptr(),
355                 &spirv_options,
356                 nir_options,
357             )
358         };
359 
360         NirShader::new(nir)
361     }
362 
get_lib_clc(screen: &PipeScreen) -> Option<NirShader>363     pub fn get_lib_clc(screen: &PipeScreen) -> Option<NirShader> {
364         let nir_options = screen.nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE);
365         let address_bits = screen.compute_param(pipe_compute_cap::PIPE_COMPUTE_CAP_ADDRESS_BITS);
366         let spirv_caps = Self::get_spirv_capabilities();
367         let spirv_options =
368             Self::get_spirv_options(false, ptr::null(), address_bits, &spirv_caps, None);
369         let shader_cache = DiskCacheBorrowed::as_ptr(&screen.shader_cache());
370 
371         NirShader::new(unsafe {
372             nir_load_libclc_shader(
373                 address_bits,
374                 shader_cache,
375                 &spirv_options,
376                 nir_options,
377                 true,
378             )
379         })
380     }
381 
to_bin(&self) -> &[u8]382     pub fn to_bin(&self) -> &[u8] {
383         unsafe { slice::from_raw_parts(self.spirv.data.cast(), self.spirv.size) }
384     }
385 
from_bin(bin: &[u8]) -> Self386     pub fn from_bin(bin: &[u8]) -> Self {
387         unsafe {
388             let ptr = malloc(bin.len());
389             ptr::copy_nonoverlapping(bin.as_ptr(), ptr.cast(), bin.len());
390             let spirv = clc_binary {
391                 data: ptr,
392                 size: bin.len(),
393             };
394 
395             let mut pspirv = clc_parsed_spirv::default();
396 
397             let info = if clc_parse_spirv(&spirv, ptr::null(), &mut pspirv) {
398                 Some(pspirv)
399             } else {
400                 None
401             };
402 
403             SPIRVBin {
404                 spirv: spirv,
405                 info: info,
406             }
407         }
408     }
409 
spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type>410     pub fn spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type> {
411         let info = self.info?;
412         if info.num_spec_constants == 0 {
413             return None;
414         }
415 
416         let spec_constants =
417             unsafe { slice::from_raw_parts(info.spec_constants, info.num_spec_constants as usize) };
418 
419         spec_constants
420             .iter()
421             .find(|sc| sc.id == spec_id)
422             .map(|sc| sc.type_)
423     }
424 
print(&self)425     pub fn print(&self) {
426         unsafe {
427             clc_dump_spirv(&self.spirv, stderr_ptr());
428         }
429     }
430 }
431 
432 impl Clone for SPIRVBin {
clone(&self) -> Self433     fn clone(&self) -> Self {
434         Self::from_bin(self.to_bin())
435     }
436 }
437 
438 impl Drop for SPIRVBin {
drop(&mut self)439     fn drop(&mut self) {
440         unsafe {
441             clc_free_spirv(&mut self.spirv);
442             if let Some(info) = &mut self.info {
443                 clc_free_parsed_spirv(info);
444             }
445         }
446     }
447 }
448 
449 impl SPIRVKernelArg {
serialize(&self, blob: &mut blob)450     pub fn serialize(&self, blob: &mut blob) {
451         let name_arr = self.name.as_bytes();
452         let type_name_arr = self.type_name.as_bytes();
453 
454         unsafe {
455             blob_write_uint32(blob, self.access_qualifier.0);
456             blob_write_uint32(blob, self.type_qualifier.0);
457 
458             blob_write_uint16(blob, name_arr.len() as u16);
459             blob_write_uint16(blob, type_name_arr.len() as u16);
460 
461             blob_write_bytes(blob, name_arr.as_ptr().cast(), name_arr.len());
462             blob_write_bytes(blob, type_name_arr.as_ptr().cast(), type_name_arr.len());
463 
464             blob_write_uint8(blob, self.address_qualifier as u8);
465         }
466     }
467 
deserialize(blob: &mut blob_reader) -> Option<Self>468     pub fn deserialize(blob: &mut blob_reader) -> Option<Self> {
469         unsafe {
470             let access_qualifier = blob_read_uint32(blob);
471             let type_qualifier = blob_read_uint32(blob);
472 
473             let name_len = blob_read_uint16(blob) as usize;
474             let type_len = blob_read_uint16(blob) as usize;
475 
476             let name = slice::from_raw_parts(blob_read_bytes(blob, name_len).cast(), name_len);
477             let type_name = slice::from_raw_parts(blob_read_bytes(blob, type_len).cast(), type_len);
478 
479             let address_qualifier = match blob_read_uint8(blob) {
480                 0 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE,
481                 1 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT,
482                 2 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL,
483                 3 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL,
484                 _ => return None,
485             };
486 
487             Some(Self {
488                 name: String::from_utf8_unchecked(name.to_owned()),
489                 type_name: String::from_utf8_unchecked(type_name.to_owned()),
490                 access_qualifier: clc_kernel_arg_access_qualifier(access_qualifier),
491                 address_qualifier: address_qualifier,
492                 type_qualifier: clc_kernel_arg_type_qualifier(type_qualifier),
493             })
494         }
495     }
496 }
497 
498 pub trait CLCSpecConstantType {
size(self) -> u8499     fn size(self) -> u8;
500 }
501 
502 impl CLCSpecConstantType for clc_spec_constant_type {
size(self) -> u8503     fn size(self) -> u8 {
504         match self {
505             Self::CLC_SPEC_CONSTANT_INT64
506             | Self::CLC_SPEC_CONSTANT_UINT64
507             | Self::CLC_SPEC_CONSTANT_DOUBLE => 8,
508             Self::CLC_SPEC_CONSTANT_INT32
509             | Self::CLC_SPEC_CONSTANT_UINT32
510             | Self::CLC_SPEC_CONSTANT_FLOAT => 4,
511             Self::CLC_SPEC_CONSTANT_INT16 | Self::CLC_SPEC_CONSTANT_UINT16 => 2,
512             Self::CLC_SPEC_CONSTANT_INT8
513             | Self::CLC_SPEC_CONSTANT_UINT8
514             | Self::CLC_SPEC_CONSTANT_BOOL => 1,
515             Self::CLC_SPEC_CONSTANT_UNKNOWN => 0,
516         }
517     }
518 }
519 
/// Textual rendering of a kernel's attributes.
pub trait SpirvKernelInfo {
    /// The vector type hint attribute, if the kernel has one.
    fn vec_type_hint(&self) -> Option<String>;
    /// The required work-group size attribute, if the kernel has one.
    fn local_size(&self) -> Option<String>;
    /// The work-group size hint attribute, if the kernel has one.
    fn local_size_hint(&self) -> Option<String>;

    /// Joins all present attributes with `,` in the order vector type hint, required
    /// work-group size, work-group size hint; absent attributes are skipped.
    fn attribute_str(&self) -> String {
        [
            self.vec_type_hint(),
            self.local_size(),
            self.local_size_hint(),
        ]
        .into_iter()
        .flatten()
        .collect::<Vec<_>>()
        .join(",")
    }
}
536 
537 impl SpirvKernelInfo for clc_kernel_info {
vec_type_hint(&self) -> Option<String>538     fn vec_type_hint(&self) -> Option<String> {
539         if ![1, 2, 3, 4, 8, 16].contains(&self.vec_hint_size) {
540             return None;
541         }
542         let cltype = match self.vec_hint_type {
543             clc_vec_hint_type::CLC_VEC_HINT_TYPE_CHAR => "uchar",
544             clc_vec_hint_type::CLC_VEC_HINT_TYPE_SHORT => "ushort",
545             clc_vec_hint_type::CLC_VEC_HINT_TYPE_INT => "uint",
546             clc_vec_hint_type::CLC_VEC_HINT_TYPE_LONG => "ulong",
547             clc_vec_hint_type::CLC_VEC_HINT_TYPE_HALF => "half",
548             clc_vec_hint_type::CLC_VEC_HINT_TYPE_FLOAT => "float",
549             clc_vec_hint_type::CLC_VEC_HINT_TYPE_DOUBLE => "double",
550         };
551 
552         Some(format!("vec_type_hint({}{})", cltype, self.vec_hint_size))
553     }
554 
local_size(&self) -> Option<String>555     fn local_size(&self) -> Option<String> {
556         if self.local_size == [0; 3] {
557             return None;
558         }
559         Some(format!(
560             "reqd_work_group_size({},{},{})",
561             self.local_size[0], self.local_size[1], self.local_size[2]
562         ))
563     }
564 
local_size_hint(&self) -> Option<String>565     fn local_size_hint(&self) -> Option<String> {
566         if self.local_size_hint == [0; 3] {
567             return None;
568         }
569         Some(format!(
570             "work_group_size_hint({},{},{})",
571             self.local_size_hint[0], self.local_size_hint[1], self.local_size_hint[2]
572         ))
573     }
574 }
575