1 use crate::compiler::nir::*;
2 use crate::pipe::screen::*;
3 use crate::util::disk_cache::*;
4
5 use libc_rust_gen::malloc;
6 use mesa_rust_gen::*;
7 use mesa_rust_util::serialize::*;
8 use mesa_rust_util::string::*;
9
10 use std::ffi::CString;
11 use std::fmt::Debug;
12 use std::os::raw::c_char;
13 use std::os::raw::c_void;
14 use std::ptr;
15 use std::slice;
16
/// NUL-terminated name (`input.cl`) under which the main source is handed to the compiler.
const INPUT_STR: *const c_char = b"input.cl\0".as_ptr() as *const c_char;
18
/// Specialization constant marker type.
pub enum SpecConstant {
    /// The only variant that currently exists.
    None,
}
22
23 pub struct SPIRVBin {
24 spirv: clc_binary,
25 info: Option<clc_parsed_spirv>,
26 }
27
28 // Safety: SPIRVBin is not mutable and is therefore Send and Sync, needed due to `clc_binary::data`
29 unsafe impl Send for SPIRVBin {}
30 unsafe impl Sync for SPIRVBin {}
31
32 #[derive(PartialEq, Eq, Hash, Clone)]
33 pub struct SPIRVKernelArg {
34 pub name: String,
35 pub type_name: String,
36 pub access_qualifier: clc_kernel_arg_access_qualifier,
37 pub address_qualifier: clc_kernel_arg_address_qualifier,
38 pub type_qualifier: clc_kernel_arg_type_qualifier,
39 }
40
/// A named, in-memory header.
pub struct CLCHeader<'a> {
    /// Name of the header.
    pub name: CString,
    /// Borrowed contents of the header.
    pub source: &'a CString,
}

impl<'a> Debug for CLCHeader<'a> {
    /// Writes `[name]:` followed by the header source on the next line.
    ///
    /// Both strings are converted lossily, so invalid UTF-8 does not make formatting fail.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "[{}]:\n{}",
            self.name.to_string_lossy(),
            self.source.to_string_lossy()
        )
    }
}
54
callback_impl(data: *mut c_void, msg: *const c_char)55 unsafe fn callback_impl(data: *mut c_void, msg: *const c_char) {
56 let data = data as *mut Vec<String>;
57 let msgs = unsafe { data.as_mut() }.unwrap();
58 msgs.push(c_string_to_string(msg));
59 }
60
spirv_msg_callback(data: *mut c_void, msg: *const c_char)61 unsafe extern "C" fn spirv_msg_callback(data: *mut c_void, msg: *const c_char) {
62 unsafe {
63 callback_impl(data, msg);
64 }
65 }
66
spirv_to_nir_msg_callback( data: *mut c_void, dbg_level: nir_spirv_debug_level, _offset: usize, msg: *const c_char, )67 unsafe extern "C" fn spirv_to_nir_msg_callback(
68 data: *mut c_void,
69 dbg_level: nir_spirv_debug_level,
70 _offset: usize,
71 msg: *const c_char,
72 ) {
73 if dbg_level >= nir_spirv_debug_level::NIR_SPIRV_DEBUG_LEVEL_WARNING {
74 unsafe {
75 callback_impl(data, msg);
76 }
77 }
78 }
79
create_clc_logger(msgs: &mut Vec<String>) -> clc_logger80 fn create_clc_logger(msgs: &mut Vec<String>) -> clc_logger {
81 clc_logger {
82 priv_: ptr::from_mut(msgs).cast(),
83 error: Some(spirv_msg_callback),
84 warning: Some(spirv_msg_callback),
85 }
86 }
87
88 impl SPIRVBin {
from_clc( source: &CString, args: &[CString], headers: &[CLCHeader], cache: &Option<DiskCache>, features: clc_optional_features, spirv_extensions: &[CString], address_bits: u32, ) -> (Option<Self>, String)89 pub fn from_clc(
90 source: &CString,
91 args: &[CString],
92 headers: &[CLCHeader],
93 cache: &Option<DiskCache>,
94 features: clc_optional_features,
95 spirv_extensions: &[CString],
96 address_bits: u32,
97 ) -> (Option<Self>, String) {
98 let mut hash_key = None;
99 let has_includes = args.iter().any(|a| a.as_bytes()[0..2] == *b"-I");
100
101 let mut spirv_extensions: Vec<_> = spirv_extensions.iter().map(|s| s.as_ptr()).collect();
102 spirv_extensions.push(ptr::null());
103
104 if let Some(cache) = cache {
105 if !has_includes {
106 let mut key = Vec::new();
107
108 key.extend_from_slice(source.as_bytes());
109 args.iter()
110 .for_each(|a| key.extend_from_slice(a.as_bytes()));
111 headers.iter().for_each(|h| {
112 key.extend_from_slice(h.name.as_bytes());
113 key.extend_from_slice(h.source.as_bytes());
114 });
115
116 // Safety: clc_optional_features is a struct of bools and contains no padding.
117 // Sadly we can't guarentee this.
118 key.extend(unsafe { as_byte_slice(slice::from_ref(&features)) });
119
120 let mut key = cache.gen_key(&key);
121 if let Some(data) = cache.get(&mut key) {
122 return (Some(Self::from_bin(&data)), String::from(""));
123 }
124
125 hash_key = Some(key);
126 }
127 }
128
129 let c_headers: Vec<_> = headers
130 .iter()
131 .map(|h| clc_named_value {
132 name: h.name.as_ptr(),
133 value: h.source.as_ptr(),
134 })
135 .collect();
136
137 let c_args: Vec<_> = args.iter().map(|a| a.as_ptr()).collect();
138
139 let args = clc_compile_args {
140 headers: c_headers.as_ptr(),
141 num_headers: c_headers.len() as u32,
142 source: clc_named_value {
143 name: INPUT_STR,
144 value: source.as_ptr(),
145 },
146 args: c_args.as_ptr(),
147 num_args: c_args.len() as u32,
148 spirv_version: clc_spirv_version::CLC_SPIRV_VERSION_MAX,
149 features: features,
150 use_llvm_spirv_target: false,
151 allowed_spirv_extensions: spirv_extensions.as_ptr(),
152 address_bits: address_bits,
153 };
154 let mut msgs: Vec<String> = Vec::new();
155 let logger = create_clc_logger(&mut msgs);
156 let mut out = clc_binary::default();
157
158 let res = unsafe { clc_compile_c_to_spirv(&args, &logger, &mut out) };
159
160 let res = if res {
161 let spirv = SPIRVBin {
162 spirv: out,
163 info: None,
164 };
165
166 // add cache entry
167 if !has_includes {
168 if let Some(mut key) = hash_key {
169 cache.as_ref().unwrap().put(spirv.to_bin(), &mut key);
170 }
171 }
172
173 Some(spirv)
174 } else {
175 None
176 };
177
178 (res, msgs.join(""))
179 }
180
181 // TODO cache linking, parsing is around 25% of link time
link(spirvs: &[&SPIRVBin], library: bool) -> (Option<Self>, String)182 pub fn link(spirvs: &[&SPIRVBin], library: bool) -> (Option<Self>, String) {
183 let bins: Vec<_> = spirvs.iter().map(|s| ptr::from_ref(&s.spirv)).collect();
184
185 let linker_args = clc_linker_args {
186 in_objs: bins.as_ptr(),
187 num_in_objs: bins.len() as u32,
188 create_library: library as u32,
189 };
190
191 let mut msgs: Vec<String> = Vec::new();
192 let logger = create_clc_logger(&mut msgs);
193
194 let mut out = clc_binary::default();
195 let res = unsafe { clc_link_spirv(&linker_args, &logger, &mut out) };
196
197 let info = if !library && res {
198 let mut pspirv = clc_parsed_spirv::default();
199 let res = unsafe { clc_parse_spirv(&out, &logger, &mut pspirv) };
200 res.then_some(pspirv)
201 } else {
202 None
203 };
204
205 let res = res.then_some(SPIRVBin {
206 spirv: out,
207 info: info,
208 });
209 (res, msgs.join(""))
210 }
211
validate(&self, options: &clc_validator_options) -> (bool, String)212 pub fn validate(&self, options: &clc_validator_options) -> (bool, String) {
213 let mut msgs: Vec<String> = Vec::new();
214 let logger = create_clc_logger(&mut msgs);
215 let res = unsafe { clc_validate_spirv(&self.spirv, &logger, options) };
216
217 (res, msgs.join(""))
218 }
219
clone_on_validate(&self, options: &clc_validator_options) -> (Option<Self>, String)220 pub fn clone_on_validate(&self, options: &clc_validator_options) -> (Option<Self>, String) {
221 let (res, msgs) = self.validate(options);
222 (res.then(|| self.clone()), msgs)
223 }
224
kernel_infos(&self) -> &[clc_kernel_info]225 fn kernel_infos(&self) -> &[clc_kernel_info] {
226 match self.info {
227 Some(info) if info.num_kernels > 0 => unsafe {
228 slice::from_raw_parts(info.kernels, info.num_kernels as usize)
229 },
230 _ => &[],
231 }
232 }
233
kernel_info(&self, name: &str) -> Option<&clc_kernel_info>234 pub fn kernel_info(&self, name: &str) -> Option<&clc_kernel_info> {
235 self.kernel_infos()
236 .iter()
237 .find(|i| c_string_to_string(i.name) == name)
238 }
239
kernels(&self) -> Vec<String>240 pub fn kernels(&self) -> Vec<String> {
241 self.kernel_infos()
242 .iter()
243 .map(|i| i.name)
244 .map(c_string_to_string)
245 .collect()
246 }
247
args(&self, name: &str) -> Vec<SPIRVKernelArg>248 pub fn args(&self, name: &str) -> Vec<SPIRVKernelArg> {
249 match self.kernel_info(name) {
250 Some(info) if info.num_args > 0 => {
251 unsafe { slice::from_raw_parts(info.args, info.num_args) }
252 .iter()
253 .map(|a| SPIRVKernelArg {
254 name: c_string_to_string(a.name),
255 type_name: c_string_to_string(a.type_name),
256 access_qualifier: clc_kernel_arg_access_qualifier(a.access_qualifier),
257 address_qualifier: a.address_qualifier,
258 type_qualifier: clc_kernel_arg_type_qualifier(a.type_qualifier),
259 })
260 .collect()
261 }
262 _ => Vec::new(),
263 }
264 }
265
get_spirv_capabilities() -> spirv_capabilities266 fn get_spirv_capabilities() -> spirv_capabilities {
267 spirv_capabilities {
268 Addresses: true,
269 Float16: true,
270 Float16Buffer: true,
271 Float64: true,
272 GenericPointer: true,
273 Groups: true,
274 GroupNonUniformShuffle: true,
275 GroupNonUniformShuffleRelative: true,
276 Int8: true,
277 Int16: true,
278 Int64: true,
279 Kernel: true,
280 ImageBasic: true,
281 ImageReadWrite: true,
282 Linkage: true,
283 LiteralSampler: true,
284 SampledBuffer: true,
285 Sampled1D: true,
286 Vector16: true,
287 ..Default::default()
288 }
289 }
290
get_spirv_options( library: bool, clc_shader: *const nir_shader, address_bits: u32, caps: &spirv_capabilities, log: Option<&mut Vec<String>>, ) -> spirv_to_nir_options291 fn get_spirv_options(
292 library: bool,
293 clc_shader: *const nir_shader,
294 address_bits: u32,
295 caps: &spirv_capabilities,
296 log: Option<&mut Vec<String>>,
297 ) -> spirv_to_nir_options {
298 let global_addr_format;
299 let offset_addr_format;
300
301 if address_bits == 32 {
302 global_addr_format = nir_address_format::nir_address_format_32bit_global;
303 offset_addr_format = nir_address_format::nir_address_format_32bit_offset;
304 } else {
305 global_addr_format = nir_address_format::nir_address_format_64bit_global;
306 offset_addr_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
307 }
308
309 let debug = log.map(|log| spirv_to_nir_options__bindgen_ty_1 {
310 func: Some(spirv_to_nir_msg_callback),
311 private_data: ptr::from_mut(log).cast(),
312 });
313
314 spirv_to_nir_options {
315 create_library: library,
316 environment: nir_spirv_execution_environment::NIR_SPIRV_OPENCL,
317 clc_shader: clc_shader,
318 float_controls_execution_mode: float_controls::FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP32
319 as u32,
320
321 printf: true,
322 capabilities: caps,
323 constant_addr_format: global_addr_format,
324 global_addr_format: global_addr_format,
325 shared_addr_format: offset_addr_format,
326 temp_addr_format: offset_addr_format,
327 debug: debug.unwrap_or_default(),
328
329 ..Default::default()
330 }
331 }
332
to_nir( &self, entry_point: &str, nir_options: *const nir_shader_compiler_options, libclc: &NirShader, spec_constants: &mut [nir_spirv_specialization], address_bits: u32, log: Option<&mut Vec<String>>, ) -> Option<NirShader>333 pub fn to_nir(
334 &self,
335 entry_point: &str,
336 nir_options: *const nir_shader_compiler_options,
337 libclc: &NirShader,
338 spec_constants: &mut [nir_spirv_specialization],
339 address_bits: u32,
340 log: Option<&mut Vec<String>>,
341 ) -> Option<NirShader> {
342 let c_entry = CString::new(entry_point.as_bytes()).unwrap();
343 let spirv_caps = Self::get_spirv_capabilities();
344 let spirv_options =
345 Self::get_spirv_options(false, libclc.get_nir(), address_bits, &spirv_caps, log);
346
347 let nir = unsafe {
348 spirv_to_nir(
349 self.spirv.data.cast(),
350 self.spirv.size / 4,
351 spec_constants.as_mut_ptr(),
352 spec_constants.len() as u32,
353 gl_shader_stage::MESA_SHADER_KERNEL,
354 c_entry.as_ptr(),
355 &spirv_options,
356 nir_options,
357 )
358 };
359
360 NirShader::new(nir)
361 }
362
get_lib_clc(screen: &PipeScreen) -> Option<NirShader>363 pub fn get_lib_clc(screen: &PipeScreen) -> Option<NirShader> {
364 let nir_options = screen.nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE);
365 let address_bits = screen.compute_param(pipe_compute_cap::PIPE_COMPUTE_CAP_ADDRESS_BITS);
366 let spirv_caps = Self::get_spirv_capabilities();
367 let spirv_options =
368 Self::get_spirv_options(false, ptr::null(), address_bits, &spirv_caps, None);
369 let shader_cache = DiskCacheBorrowed::as_ptr(&screen.shader_cache());
370
371 NirShader::new(unsafe {
372 nir_load_libclc_shader(
373 address_bits,
374 shader_cache,
375 &spirv_options,
376 nir_options,
377 true,
378 )
379 })
380 }
381
to_bin(&self) -> &[u8]382 pub fn to_bin(&self) -> &[u8] {
383 unsafe { slice::from_raw_parts(self.spirv.data.cast(), self.spirv.size) }
384 }
385
from_bin(bin: &[u8]) -> Self386 pub fn from_bin(bin: &[u8]) -> Self {
387 unsafe {
388 let ptr = malloc(bin.len());
389 ptr::copy_nonoverlapping(bin.as_ptr(), ptr.cast(), bin.len());
390 let spirv = clc_binary {
391 data: ptr,
392 size: bin.len(),
393 };
394
395 let mut pspirv = clc_parsed_spirv::default();
396
397 let info = if clc_parse_spirv(&spirv, ptr::null(), &mut pspirv) {
398 Some(pspirv)
399 } else {
400 None
401 };
402
403 SPIRVBin {
404 spirv: spirv,
405 info: info,
406 }
407 }
408 }
409
spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type>410 pub fn spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type> {
411 let info = self.info?;
412 if info.num_spec_constants == 0 {
413 return None;
414 }
415
416 let spec_constants =
417 unsafe { slice::from_raw_parts(info.spec_constants, info.num_spec_constants as usize) };
418
419 spec_constants
420 .iter()
421 .find(|sc| sc.id == spec_id)
422 .map(|sc| sc.type_)
423 }
424
print(&self)425 pub fn print(&self) {
426 unsafe {
427 clc_dump_spirv(&self.spirv, stderr_ptr());
428 }
429 }
430 }
431
432 impl Clone for SPIRVBin {
clone(&self) -> Self433 fn clone(&self) -> Self {
434 Self::from_bin(self.to_bin())
435 }
436 }
437
438 impl Drop for SPIRVBin {
drop(&mut self)439 fn drop(&mut self) {
440 unsafe {
441 clc_free_spirv(&mut self.spirv);
442 if let Some(info) = &mut self.info {
443 clc_free_parsed_spirv(info);
444 }
445 }
446 }
447 }
448
449 impl SPIRVKernelArg {
serialize(&self, blob: &mut blob)450 pub fn serialize(&self, blob: &mut blob) {
451 let name_arr = self.name.as_bytes();
452 let type_name_arr = self.type_name.as_bytes();
453
454 unsafe {
455 blob_write_uint32(blob, self.access_qualifier.0);
456 blob_write_uint32(blob, self.type_qualifier.0);
457
458 blob_write_uint16(blob, name_arr.len() as u16);
459 blob_write_uint16(blob, type_name_arr.len() as u16);
460
461 blob_write_bytes(blob, name_arr.as_ptr().cast(), name_arr.len());
462 blob_write_bytes(blob, type_name_arr.as_ptr().cast(), type_name_arr.len());
463
464 blob_write_uint8(blob, self.address_qualifier as u8);
465 }
466 }
467
deserialize(blob: &mut blob_reader) -> Option<Self>468 pub fn deserialize(blob: &mut blob_reader) -> Option<Self> {
469 unsafe {
470 let access_qualifier = blob_read_uint32(blob);
471 let type_qualifier = blob_read_uint32(blob);
472
473 let name_len = blob_read_uint16(blob) as usize;
474 let type_len = blob_read_uint16(blob) as usize;
475
476 let name = slice::from_raw_parts(blob_read_bytes(blob, name_len).cast(), name_len);
477 let type_name = slice::from_raw_parts(blob_read_bytes(blob, type_len).cast(), type_len);
478
479 let address_qualifier = match blob_read_uint8(blob) {
480 0 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE,
481 1 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT,
482 2 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL,
483 3 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL,
484 _ => return None,
485 };
486
487 Some(Self {
488 name: String::from_utf8_unchecked(name.to_owned()),
489 type_name: String::from_utf8_unchecked(type_name.to_owned()),
490 access_qualifier: clc_kernel_arg_access_qualifier(access_qualifier),
491 address_qualifier: address_qualifier,
492 type_qualifier: clc_kernel_arg_type_qualifier(type_qualifier),
493 })
494 }
495 }
496 }
497
/// Size query for specialization constant types.
pub trait CLCSpecConstantType {
    /// Returns the size of the type as a `u8`.
    fn size(self) -> u8;
}
501
502 impl CLCSpecConstantType for clc_spec_constant_type {
size(self) -> u8503 fn size(self) -> u8 {
504 match self {
505 Self::CLC_SPEC_CONSTANT_INT64
506 | Self::CLC_SPEC_CONSTANT_UINT64
507 | Self::CLC_SPEC_CONSTANT_DOUBLE => 8,
508 Self::CLC_SPEC_CONSTANT_INT32
509 | Self::CLC_SPEC_CONSTANT_UINT32
510 | Self::CLC_SPEC_CONSTANT_FLOAT => 4,
511 Self::CLC_SPEC_CONSTANT_INT16 | Self::CLC_SPEC_CONSTANT_UINT16 => 2,
512 Self::CLC_SPEC_CONSTANT_INT8
513 | Self::CLC_SPEC_CONSTANT_UINT8
514 | Self::CLC_SPEC_CONSTANT_BOOL => 1,
515 Self::CLC_SPEC_CONSTANT_UNKNOWN => 0,
516 }
517 }
518 }
519
/// Textual kernel attributes.
pub trait SpirvKernelInfo {
    /// The vector type hint attribute, if there is one.
    fn vec_type_hint(&self) -> Option<String>;
    /// The local size attribute, if there is one.
    fn local_size(&self) -> Option<String>;
    /// The local size hint attribute, if there is one.
    fn local_size_hint(&self) -> Option<String>;

    /// Joins all present attributes with `,`, in the order vector type hint, local size, local
    /// size hint. Returns an empty string if none is present.
    fn attribute_str(&self) -> String {
        let mut present = Vec::with_capacity(3);

        present.extend(self.vec_type_hint());
        present.extend(self.local_size());
        present.extend(self.local_size_hint());

        present.join(",")
    }
}
536
537 impl SpirvKernelInfo for clc_kernel_info {
vec_type_hint(&self) -> Option<String>538 fn vec_type_hint(&self) -> Option<String> {
539 if ![1, 2, 3, 4, 8, 16].contains(&self.vec_hint_size) {
540 return None;
541 }
542 let cltype = match self.vec_hint_type {
543 clc_vec_hint_type::CLC_VEC_HINT_TYPE_CHAR => "uchar",
544 clc_vec_hint_type::CLC_VEC_HINT_TYPE_SHORT => "ushort",
545 clc_vec_hint_type::CLC_VEC_HINT_TYPE_INT => "uint",
546 clc_vec_hint_type::CLC_VEC_HINT_TYPE_LONG => "ulong",
547 clc_vec_hint_type::CLC_VEC_HINT_TYPE_HALF => "half",
548 clc_vec_hint_type::CLC_VEC_HINT_TYPE_FLOAT => "float",
549 clc_vec_hint_type::CLC_VEC_HINT_TYPE_DOUBLE => "double",
550 };
551
552 Some(format!("vec_type_hint({}{})", cltype, self.vec_hint_size))
553 }
554
local_size(&self) -> Option<String>555 fn local_size(&self) -> Option<String> {
556 if self.local_size == [0; 3] {
557 return None;
558 }
559 Some(format!(
560 "reqd_work_group_size({},{},{})",
561 self.local_size[0], self.local_size[1], self.local_size[2]
562 ))
563 }
564
local_size_hint(&self) -> Option<String>565 fn local_size_hint(&self) -> Option<String> {
566 if self.local_size_hint == [0; 3] {
567 return None;
568 }
569 Some(format!(
570 "work_group_size_hint({},{},{})",
571 self.local_size_hint[0], self.local_size_hint[1], self.local_size_hint[2]
572 ))
573 }
574 }
575