xref: /aosp_15_r20/external/mesa3d/src/gallium/frontends/rusticl/core/kernel.rs (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 use crate::api::icd::*;
2 use crate::core::device::*;
3 use crate::core::event::*;
4 use crate::core::memory::*;
5 use crate::core::platform::*;
6 use crate::core::program::*;
7 use crate::core::queue::*;
8 use crate::impl_cl_type_trait;
9 
10 use mesa_rust::compiler::clc::*;
11 use mesa_rust::compiler::nir::*;
12 use mesa_rust::nir_pass;
13 use mesa_rust::pipe::context::RWFlags;
14 use mesa_rust::pipe::resource::*;
15 use mesa_rust::pipe::screen::ResourceType;
16 use mesa_rust_gen::*;
17 use mesa_rust_util::math::*;
18 use mesa_rust_util::serialize::*;
19 use rusticl_opencl_gen::*;
20 use spirv::SpirvKernelInfo;
21 
22 use std::cmp;
23 use std::collections::HashMap;
24 use std::convert::TryInto;
25 use std::fmt::Debug;
26 use std::fmt::Display;
27 use std::ops::Index;
28 use std::os::raw::c_void;
29 use std::ptr;
30 use std::slice;
31 use std::sync::Arc;
32 use std::sync::Mutex;
33 use std::sync::MutexGuard;
34 
35 // ugh, we are not allowed to take refs, so...
36 #[derive(Clone)]
37 pub enum KernelArgValue {
38     None,
39     Buffer(Arc<Buffer>),
40     Constant(Vec<u8>),
41     Image(Arc<Image>),
42     LocalMem(usize),
43     Sampler(Arc<Sampler>),
44 }
45 
46 #[repr(u8)]
47 #[derive(Hash, PartialEq, Eq, Clone, Copy)]
48 pub enum KernelArgType {
49     Constant(/* size */ u16), // for anything passed by value
50     Image,
51     RWImage,
52     Sampler,
53     Texture,
54     MemGlobal,
55     MemConstant,
56     MemLocal,
57 }
58 
59 impl KernelArgType {
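    // Note: (de)serialization uses a one-byte tag per variant, with Constant additionally
    // storing its size as a u16; deserialize() and serialize() below must stay in sync.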
60     fn deserialize(blob: &mut blob_reader) -> Option<Self> {
61         Some(match unsafe { blob_read_uint8(blob) } {
62             0 => {
63                 let size = unsafe { blob_read_uint16(blob) };
64                 KernelArgType::Constant(size)
65             }
66             1 => KernelArgType::Image,
67             2 => KernelArgType::RWImage,
68             3 => KernelArgType::Sampler,
69             4 => KernelArgType::Texture,
70             5 => KernelArgType::MemGlobal,
71             6 => KernelArgType::MemConstant,
72             7 => KernelArgType::MemLocal,
73             _ => return None,
74         })
75     }
76 
77     fn serialize(&self, blob: &mut blob) {
78         unsafe {
79             match self {
80                 KernelArgType::Constant(size) => {
81                     blob_write_uint8(blob, 0);
82                     blob_write_uint16(blob, *size)
83                 }
84                 KernelArgType::Image => blob_write_uint8(blob, 1),
85                 KernelArgType::RWImage => blob_write_uint8(blob, 2),
86                 KernelArgType::Sampler => blob_write_uint8(blob, 3),
87                 KernelArgType::Texture => blob_write_uint8(blob, 4),
88                 KernelArgType::MemGlobal => blob_write_uint8(blob, 5),
89                 KernelArgType::MemConstant => blob_write_uint8(blob, 6),
90                 KernelArgType::MemLocal => blob_write_uint8(blob, 7),
91             };
92         }
93     }
94 
95     fn is_opaque(&self) -> bool {
96         matches!(
97             self,
98             KernelArgType::Image
99                 | KernelArgType::RWImage
100                 | KernelArgType::Texture
101                 | KernelArgType::Sampler
102         )
103     }
104 }
105 
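/// Argument slots of the compiled kernel: either one of the API-level kernel arguments
/// (referenced by index) or a hidden argument added during lowering, such as the constant
/// buffer address, the printf buffer, work offsets or the image format/order arrays.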
106 #[derive(Hash, PartialEq, Eq, Clone)]
107 enum CompiledKernelArgType {
108     APIArg(u32),
109     ConstantBuffer,
110     GlobalWorkOffsets,
111     GlobalWorkSize,
112     PrintfBuffer,
113     InlineSampler((cl_addressing_mode, cl_filter_mode, bool)),
114     FormatArray,
115     OrderArray,
116     WorkDim,
117     WorkGroupOffsets,
118     NumWorkgroups,
119 }
120 
121 #[derive(Hash, PartialEq, Eq, Clone)]
122 pub struct KernelArg {
123     spirv: spirv::SPIRVKernelArg,
124     pub kind: KernelArgType,
125     pub dead: bool,
126 }
127 
128 impl KernelArg {
129     fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
130         let nir_arg_map: HashMap<_, _> = nir
131             .variables_with_mode(
132                 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
133             )
134             .map(|v| (v.data.location, v))
135             .collect();
136         let mut res = Vec::new();
137 
138         for (i, s) in spirv.iter().enumerate() {
139             let nir = nir_arg_map.get(&(i as i32)).unwrap();
140             let kind = match s.address_qualifier {
141                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
142                     if unsafe { glsl_type_is_sampler(nir.type_) } {
143                         KernelArgType::Sampler
144                     } else {
145                         let size = unsafe { glsl_get_cl_size(nir.type_) } as u16;
146                         // NIR types of non-opaque types are never zero-sized
147                         KernelArgType::Constant(size)
148                     }
149                 }
150                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
151                     KernelArgType::MemConstant
152                 }
153                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
154                     KernelArgType::MemLocal
155                 }
156                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
157                     if unsafe { glsl_type_is_image(nir.type_) } {
158                         let access = nir.data.access();
159                         if access == gl_access_qualifier::ACCESS_NON_WRITEABLE.0 {
160                             KernelArgType::Texture
161                         } else if access == gl_access_qualifier::ACCESS_NON_READABLE.0 {
162                             KernelArgType::Image
163                         } else {
164                             KernelArgType::RWImage
165                         }
166                     } else {
167                         KernelArgType::MemGlobal
168                     }
169                 }
170             };
171 
172             res.push(Self {
173                 spirv: s.clone(),
174                 // we'll update it later in the 2nd pass
175                 kind: kind,
176                 dead: true,
177             });
178         }
179         res
180     }
181 
182     fn serialize(args: &[Self], blob: &mut blob) {
183         unsafe {
184             blob_write_uint16(blob, args.len() as u16);
185 
186             for arg in args {
187                 arg.spirv.serialize(blob);
188                 blob_write_uint8(blob, arg.dead.into());
189                 arg.kind.serialize(blob);
190             }
191         }
192     }
193 
194     fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
195         unsafe {
196             let len = blob_read_uint16(blob) as usize;
197             let mut res = Vec::with_capacity(len);
198 
199             for _ in 0..len {
200                 let spirv = spirv::SPIRVKernelArg::deserialize(blob)?;
201                 let dead = blob_read_uint8(blob) != 0;
202                 let kind = KernelArgType::deserialize(blob)?;
203 
204                 res.push(Self {
205                     spirv: spirv,
206                     kind: kind,
207                     dead: dead,
208                 });
209             }
210 
211             Some(res)
212         }
213     }
214 }
215 
216 #[derive(Hash, PartialEq, Eq, Clone)]
217 struct CompiledKernelArg {
218     kind: CompiledKernelArgType,
219     /// The binding for image/sampler args, the offset into the input buffer
220     /// for anything else.
221     offset: u32,
222     dead: bool,
223 }
224 
225 impl CompiledKernelArg {
226     fn assign_locations(compiled_args: &mut [Self], nir: &mut NirShader) {
227         for var in nir.variables_with_mode(
228             nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
229         ) {
230             let arg = &mut compiled_args[var.data.location as usize];
231             let t = var.type_;
232 
233             arg.dead = false;
234             arg.offset = if unsafe {
235                 glsl_type_is_image(t) || glsl_type_is_texture(t) || glsl_type_is_sampler(t)
236             } {
237                 var.data.binding
238             } else {
239                 var.data.driver_location
240             };
241         }
242     }
243 
244     fn serialize(args: &[Self], blob: &mut blob) {
245         unsafe {
246             blob_write_uint16(blob, args.len() as u16);
247             for arg in args {
248                 blob_write_uint32(blob, arg.offset);
249                 blob_write_uint8(blob, arg.dead.into());
250                 match arg.kind {
251                     CompiledKernelArgType::ConstantBuffer => blob_write_uint8(blob, 0),
252                     CompiledKernelArgType::GlobalWorkOffsets => blob_write_uint8(blob, 1),
253                     CompiledKernelArgType::PrintfBuffer => blob_write_uint8(blob, 2),
254                     CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm)) => {
255                         blob_write_uint8(blob, 3);
256                         blob_write_uint8(blob, norm.into());
257                         blob_write_uint32(blob, addr_mode);
258                         blob_write_uint32(blob, filter_mode)
259                     }
260                     CompiledKernelArgType::FormatArray => blob_write_uint8(blob, 4),
261                     CompiledKernelArgType::OrderArray => blob_write_uint8(blob, 5),
262                     CompiledKernelArgType::WorkDim => blob_write_uint8(blob, 6),
263                     CompiledKernelArgType::WorkGroupOffsets => blob_write_uint8(blob, 7),
264                     CompiledKernelArgType::NumWorkgroups => blob_write_uint8(blob, 8),
265                     CompiledKernelArgType::GlobalWorkSize => blob_write_uint8(blob, 9),
266                     CompiledKernelArgType::APIArg(idx) => {
267                         blob_write_uint8(blob, 10);
268                         blob_write_uint32(blob, idx)
269                     }
270                 };
271             }
272         }
273     }
274 
275     fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
276         unsafe {
277             let len = blob_read_uint16(blob) as usize;
278             let mut res = Vec::with_capacity(len);
279 
280             for _ in 0..len {
281                 let offset = blob_read_uint32(blob);
282                 let dead = blob_read_uint8(blob) != 0;
283 
284                 let kind = match blob_read_uint8(blob) {
285                     0 => CompiledKernelArgType::ConstantBuffer,
286                     1 => CompiledKernelArgType::GlobalWorkOffsets,
287                     2 => CompiledKernelArgType::PrintfBuffer,
288                     3 => {
289                         let norm = blob_read_uint8(blob) != 0;
290                         let addr_mode = blob_read_uint32(blob);
291                         let filter_mode = blob_read_uint32(blob);
292                         CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm))
293                     }
294                     4 => CompiledKernelArgType::FormatArray,
295                     5 => CompiledKernelArgType::OrderArray,
296                     6 => CompiledKernelArgType::WorkDim,
297                     7 => CompiledKernelArgType::WorkGroupOffsets,
298                     8 => CompiledKernelArgType::NumWorkgroups,
299                     9 => CompiledKernelArgType::GlobalWorkSize,
300                     10 => {
301                         let idx = blob_read_uint32(blob);
302                         CompiledKernelArgType::APIArg(idx)
303                     }
304                     _ => return None,
305                 };
306 
307                 res.push(Self {
308                     kind: kind,
309                     offset: offset,
310                     dead: dead,
311                 });
312             }
313 
314             Some(res)
315         }
316     }
317 }
318 
319 #[derive(Clone, PartialEq, Eq, Hash)]
320 pub struct KernelInfo {
321     pub args: Vec<KernelArg>,
322     pub attributes_string: String,
323     work_group_size: [usize; 3],
324     work_group_size_hint: [u32; 3],
325     subgroup_size: usize,
326     num_subgroups: usize,
327 }
328 
329 struct CSOWrapper {
330     cso_ptr: *mut c_void,
331     dev: &'static Device,
332 }
333 
334 impl CSOWrapper {
335     fn new(dev: &'static Device, nir: &NirShader) -> Self {
336         let cso_ptr = dev
337             .helper_ctx()
338             .create_compute_state(nir, nir.shared_size());
339 
340         Self {
341             cso_ptr: cso_ptr,
342             dev: dev,
343         }
344     }
345 
346     fn get_cso_info(&self) -> pipe_compute_state_object_info {
347         self.dev.helper_ctx().compute_state_info(self.cso_ptr)
348     }
349 }
350 
351 impl Drop for CSOWrapper {
352     fn drop(&mut self) {
353         self.dev.helper_ctx().delete_compute_state(self.cso_ptr);
354     }
355 }
356 
357 enum KernelDevStateVariant {
358     Cso(CSOWrapper),
359     Nir(NirShader),
360 }
361 
362 #[derive(Debug, PartialEq)]
363 enum NirKernelVariant {
364     /// Can be used under any circumstance.
365     Default,
366 
367     /// Optimized variant making the following assumptions:
368     ///  - global_id_offsets are 0
369     ///  - workgroup_offsets are 0
370     ///  - local_size is info.local_size_hint
371     Optimized,
372 }
373 
374 impl Display for NirKernelVariant {
375     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376         // this simply prints the enum name, so that's fine
377         Debug::fmt(self, f)
378     }
379 }
380 
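/// All builds of a kernel for one device: the default variant plus an optional optimized
/// variant, indexed via [`NirKernelVariant`] (falling back to the default build when no
/// optimized one exists).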
381 pub struct NirKernelBuilds {
382     default_build: NirKernelBuild,
383     optimized: Option<NirKernelBuild>,
384     /// merged info with worst case values
385     info: pipe_compute_state_object_info,
386 }
387 
388 impl Index<NirKernelVariant> for NirKernelBuilds {
389     type Output = NirKernelBuild;
390 
391     fn index(&self, index: NirKernelVariant) -> &Self::Output {
392         match index {
393             NirKernelVariant::Default => &self.default_build,
394             NirKernelVariant::Optimized => self.optimized.as_ref().unwrap_or(&self.default_build),
395         }
396     }
397 }
398 
399 impl NirKernelBuilds {
400     fn new(default_build: NirKernelBuild, optimized: Option<NirKernelBuild>) -> Self {
401         let mut info = default_build.info;
402         if let Some(build) = &optimized {
403             info.max_threads = cmp::min(info.max_threads, build.info.max_threads);
404             info.simd_sizes &= build.info.simd_sizes;
405             info.private_memory = cmp::max(info.private_memory, build.info.private_memory);
406             info.preferred_simd_size =
407                 cmp::max(info.preferred_simd_size, build.info.preferred_simd_size);
408         }
409 
410         Self {
411             default_build: default_build,
412             optimized: optimized,
413             info: info,
414         }
415     }
416 }
417 
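/// A single compiled build of a kernel for one device: the shader (kept as a CSO when the
/// driver supports shareable shaders, otherwise as NIR), plus the resources and compiled
/// argument layout needed at launch time.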
418 struct NirKernelBuild {
419     nir_or_cso: KernelDevStateVariant,
420     constant_buffer: Option<Arc<PipeResource>>,
421     info: pipe_compute_state_object_info,
422     shared_size: u64,
423     printf_info: Option<NirPrintfInfo>,
424     compiled_args: Vec<CompiledKernelArg>,
425 }
426 
427 // SAFETY: `CSOWrapper` is only safe to use if the device supports `PIPE_CAP_SHAREABLE_SHADERS` and
428 //         we make sure to set `nir_or_cso` to `KernelDevStateVariant::Cso` only if that's the case.
429 unsafe impl Send for NirKernelBuild {}
430 unsafe impl Sync for NirKernelBuild {}
431 
432 impl NirKernelBuild {
433     fn new(dev: &'static Device, mut out: CompilationResult) -> Self {
434         let cso = CSOWrapper::new(dev, &out.nir);
435         let info = cso.get_cso_info();
436         let cb = Self::create_nir_constant_buffer(dev, &out.nir);
437         let shared_size = out.nir.shared_size() as u64;
438         let printf_info = out.nir.take_printf_info();
439 
440         let nir_or_cso = if !dev.shareable_shaders() {
441             KernelDevStateVariant::Nir(out.nir)
442         } else {
443             KernelDevStateVariant::Cso(cso)
444         };
445 
446         NirKernelBuild {
447             nir_or_cso: nir_or_cso,
448             constant_buffer: cb,
449             info: info,
450             shared_size: shared_size,
451             printf_info: printf_info,
452             compiled_args: out.compiled_args,
453         }
454     }
455 
456     fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<Arc<PipeResource>> {
457         let buf = nir.get_constant_buffer();
458         let len = buf.len() as u32;
459 
460         if len > 0 {
461             // TODO bind as constant buffer
462             let res = dev
463                 .screen()
464                 .resource_create_buffer(len, ResourceType::Normal, PIPE_BIND_GLOBAL)
465                 .unwrap();
466 
467             dev.helper_ctx()
468                 .exec(|ctx| ctx.buffer_subdata(&res, 0, buf.as_ptr().cast(), len))
469                 .wait();
470 
471             Some(Arc::new(res))
472         } else {
473             None
474         }
475     }
476 }
477 
478 pub struct Kernel {
479     pub base: CLObjectBase<CL_INVALID_KERNEL>,
480     pub prog: Arc<Program>,
481     pub name: String,
482     values: Mutex<Vec<Option<KernelArgValue>>>,
483     builds: HashMap<&'static Device, Arc<NirKernelBuilds>>,
484     pub kernel_info: Arc<KernelInfo>,
485 }
486 
487 impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
488 
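/// Copies `vals` into a fixed-size `[T; 3]`, filling the remaining entries with `val` and
/// returning `CL_OUT_OF_RESOURCES` if a value does not fit into `T`.
/// For example (hypothetical usage), `create_kernel_arr::<u32>(&[256, 4], 1)` yields `Ok([256, 4, 1])`.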
489 fn create_kernel_arr<T>(vals: &[usize], val: T) -> CLResult<[T; 3]>
490 where
491     T: std::convert::TryFrom<usize> + Copy,
492     <T as std::convert::TryFrom<usize>>::Error: std::fmt::Debug,
493 {
494     let mut res = [val; 3];
495     for (i, v) in vals.iter().enumerate() {
496         res[i] = (*v).try_into().ok().ok_or(CL_OUT_OF_RESOURCES)?;
497     }
498 
499     Ok(res)
500 }
501 
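/// Result of compiling one kernel variant that can be round-tripped through the shader
/// cache: the lowered NIR plus its compiled argument layout.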
502 #[derive(Clone)]
503 struct CompilationResult {
504     nir: NirShader,
505     compiled_args: Vec<CompiledKernelArg>,
506 }
507 
508 impl CompilationResult {
509     fn deserialize(reader: &mut blob_reader, d: &Device) -> Option<Self> {
510         let nir = NirShader::deserialize(
511             reader,
512             d.screen()
513                 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
514         )?;
515         let compiled_args = CompiledKernelArg::deserialize(reader)?;
516 
517         Some(Self {
518             nir: nir,
519             compiled_args,
520         })
521     }
522 
523     fn serialize(&self, blob: &mut blob) {
524         self.nir.serialize(blob);
525         CompiledKernelArg::serialize(&self.compiled_args, blob);
526     }
527 }
528 
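// Standard NIR optimization loop: repeats the passes below until none of them reports
// further progress.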
529 fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
530     let nir_options = unsafe {
531         &*dev
532             .screen
533             .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
534     };
535 
536     while {
537         let mut progress = false;
538 
539         progress |= nir_pass!(nir, nir_copy_prop);
540         progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
541         progress |= nir_pass!(nir, nir_opt_dead_write_vars);
542 
543         if nir_options.lower_to_scalar {
544             nir_pass!(
545                 nir,
546                 nir_lower_alu_to_scalar,
547                 nir_options.lower_to_scalar_filter,
548                 ptr::null(),
549             );
550             nir_pass!(nir, nir_lower_phis_to_scalar, false);
551         }
552 
553         progress |= nir_pass!(nir, nir_opt_deref);
554         if has_explicit_types {
555             progress |= nir_pass!(nir, nir_opt_memcpy);
556         }
557         progress |= nir_pass!(nir, nir_opt_dce);
558         progress |= nir_pass!(nir, nir_opt_undef);
559         progress |= nir_pass!(nir, nir_opt_constant_folding);
560         progress |= nir_pass!(nir, nir_opt_cse);
561         nir_pass!(nir, nir_split_var_copies);
562         progress |= nir_pass!(nir, nir_lower_var_copies);
563         progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
564         nir_pass!(nir, nir_lower_alu);
565         progress |= nir_pass!(nir, nir_opt_phi_precision);
566         progress |= nir_pass!(nir, nir_opt_algebraic);
567         progress |= nir_pass!(
568             nir,
569             nir_opt_if,
570             nir_opt_if_options::nir_opt_if_optimize_phi_true_false,
571         );
572         progress |= nir_pass!(nir, nir_opt_dead_cf);
573         progress |= nir_pass!(nir, nir_opt_remove_phis);
574         // we don't want to be too aggressive here, but it kills a bit of CFG
575         progress |= nir_pass!(nir, nir_opt_peephole_select, 8, true, true);
576         progress |= nir_pass!(
577             nir,
578             nir_lower_vec3_to_vec4,
579             nir_variable_mode::nir_var_mem_generic | nir_variable_mode::nir_var_uniform,
580         );
581 
582         if nir_options.max_unroll_iterations != 0 {
583             progress |= nir_pass!(nir, nir_opt_loop_unroll);
584         }
585         nir.sweep_mem();
586         progress
587     } {}
588 }
589 
590 /// # Safety
591 ///
592 /// Only safe to call when `var` is a valid pointer to a valid [`nir_variable`]
593 unsafe extern "C" fn can_remove_var(var: *mut nir_variable, _: *mut c_void) -> bool {
594     // SAFETY: It is the caller's responsibility to provide a valid and aligned pointer
595     let var_type = unsafe { (*var).type_ };
596     // SAFETY: `nir_variable`'s type invariant guarantees that the `type_` field is valid and
597     // properly aligned.
598     unsafe {
599         !glsl_type_is_image(var_type)
600             && !glsl_type_is_texture(var_type)
601             && !glsl_type_is_sampler(var_type)
602     }
603 }
604 
605 const DV_OPTS: nir_remove_dead_variables_options = nir_remove_dead_variables_options {
606     can_remove_var: Some(can_remove_var),
607     can_remove_var_data: ptr::null_mut(),
608 };
609 
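// First compilation stage: applies CL-specific fixups (fp16 denorms, rtne rounding), runs an
// initial optimization loop, inlines libclc, lowers printf, and pairs the SPIR-V argument
// info with the NIR to produce the [`KernelArg`] list.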
610 fn compile_nir_to_args(
611     dev: &Device,
612     mut nir: NirShader,
613     args: &[spirv::SPIRVKernelArg],
614     lib_clc: &NirShader,
615 ) -> (Vec<KernelArg>, NirShader) {
616     // this is a hack until we support fp16 properly and check for denorms inside vstore/vload_half
617     nir.preserve_fp16_denorms();
618 
619     // Set to rtne for now until drivers are able to report their preferred rounding mode, which also
620     // matches what we report via the API.
621     nir.set_fp_rounding_mode_rtne();
622 
623     nir_pass!(nir, nir_scale_fdiv);
624     nir.set_workgroup_size_variable_if_zero();
625     nir.structurize();
626     while {
627         let mut progress = false;
628         nir_pass!(nir, nir_split_var_copies);
629         progress |= nir_pass!(nir, nir_copy_prop);
630         progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
631         progress |= nir_pass!(nir, nir_opt_dead_write_vars);
632         progress |= nir_pass!(nir, nir_opt_deref);
633         progress |= nir_pass!(nir, nir_opt_dce);
634         progress |= nir_pass!(nir, nir_opt_undef);
635         progress |= nir_pass!(nir, nir_opt_constant_folding);
636         progress |= nir_pass!(nir, nir_opt_cse);
637         progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
638         progress |= nir_pass!(nir, nir_opt_algebraic);
639         progress
640     } {}
641     nir.inline(lib_clc);
642     nir.cleanup_functions();
643     // that should free up tons of memory
644     nir.sweep_mem();
645 
646     nir_pass!(nir, nir_dedup_inline_samplers);
647 
648     let printf_opts = nir_lower_printf_options {
649         ptr_bit_size: 0,
650         use_printf_base_identifier: false,
651         max_buffer_size: dev.printf_buffer_size() as u32,
652     };
653     nir_pass!(nir, nir_lower_printf, &printf_opts);
654 
655     opt_nir(&mut nir, dev, false);
656 
657     (KernelArg::from_spirv_nir(args, &mut nir), nir)
658 }
659 
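// Variant-independent lowering shared by all kernel variants: assigns locations to inline
// samplers, removes dead variables, lowers images/samplers and constant memory, extracts
// constant initializers and lowers system values.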
660 fn compile_nir_prepare_for_variants(
661     dev: &Device,
662     nir: &mut NirShader,
663     compiled_args: &mut Vec<CompiledKernelArg>,
664 ) {
665     // assign locations for inline samplers.
666     // IMPORTANT: this needs to happen before nir_remove_dead_variables.
667     let mut last_loc = -1;
668     for v in nir
669         .variables_with_mode(nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image)
670     {
671         if unsafe { !glsl_type_is_sampler(v.type_) } {
672             last_loc = v.data.location;
673             continue;
674         }
675         let s = unsafe { v.data.anon_1.sampler };
676         if s.is_inline_sampler() != 0 {
677             last_loc += 1;
678             v.data.location = last_loc;
679 
680             compiled_args.push(CompiledKernelArg {
681                 kind: CompiledKernelArgType::InlineSampler(Sampler::nir_to_cl(
682                     s.addressing_mode(),
683                     s.filter_mode(),
684                     s.normalized_coordinates(),
685                 )),
686                 offset: 0,
687                 dead: true,
688             });
689         } else {
690             last_loc = v.data.location;
691         }
692     }
693 
694     nir_pass!(
695         nir,
696         nir_remove_dead_variables,
697         nir_variable_mode::nir_var_uniform
698             | nir_variable_mode::nir_var_image
699             | nir_variable_mode::nir_var_mem_constant
700             | nir_variable_mode::nir_var_mem_shared
701             | nir_variable_mode::nir_var_function_temp,
702         &DV_OPTS,
703     );
704 
705     nir_pass!(nir, nir_lower_readonly_images_to_tex, true);
706     nir_pass!(
707         nir,
708         nir_lower_cl_images,
709         !dev.images_as_deref(),
710         !dev.samplers_as_deref(),
711     );
712 
713     nir_pass!(
714         nir,
715         nir_lower_vars_to_explicit_types,
716         nir_variable_mode::nir_var_mem_constant,
717         Some(glsl_get_cl_type_size_align),
718     );
719 
720     // has to run before adding internal kernel arguments
721     nir.extract_constant_initializers();
722 
723     // needed to convert variables to load intrinsics
724     nir_pass!(nir, nir_lower_system_values);
725 
726     // Run here so we can decide if it makes sense to compile a variant, e.g. read system values.
727     nir.gather_info();
728 }
729 
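// Compiles a single [`NirKernelVariant`]: for the optimized variant the work-group size hint
// is baked in and the base-offset system values are assumed to be zero. Hidden uniform
// arguments are added for the system values the kernel actually reads, explicit types and IO
// are lowered, and the NIR is finalized for the driver.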
730 fn compile_nir_variant(
731     res: &mut CompilationResult,
732     dev: &Device,
733     variant: NirKernelVariant,
734     args: &[KernelArg],
735     name: &str,
736 ) {
737     let mut lower_state = rusticl_lower_state::default();
738     let compiled_args = &mut res.compiled_args;
739     let nir = &mut res.nir;
740 
741     let address_bits_ptr_type;
742     let address_bits_base_type;
743     let global_address_format;
744     let shared_address_format;
745 
746     if dev.address_bits() == 64 {
747         address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
748         address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
749         global_address_format = nir_address_format::nir_address_format_64bit_global;
750         shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
751     } else {
752         address_bits_ptr_type = unsafe { glsl_uint_type() };
753         address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
754         global_address_format = nir_address_format::nir_address_format_32bit_global;
755         shared_address_format = nir_address_format::nir_address_format_32bit_offset;
756     }
757 
758     let nir_options = unsafe {
759         &*dev
760             .screen
761             .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
762     };
763 
764     if variant == NirKernelVariant::Optimized {
765         let wgsh = nir.workgroup_size_hint();
766         if wgsh != [0; 3] {
767             nir.set_workgroup_size(wgsh);
768         }
769     }
770 
771     let mut compute_options = nir_lower_compute_system_values_options::default();
772     compute_options.set_has_global_size(true);
773     if variant != NirKernelVariant::Optimized {
774         compute_options.set_has_base_global_invocation_id(true);
775         compute_options.set_has_base_workgroup_id(true);
776     }
777     nir_pass!(nir, nir_lower_compute_system_values, &compute_options);
778     nir.gather_info();
779 
780     let mut add_var = |nir: &mut NirShader,
781                        var_loc: &mut usize,
782                        kind: CompiledKernelArgType,
783                        glsl_type: *const glsl_type,
784                        name: &str| {
785         *var_loc = compiled_args.len();
786         compiled_args.push(CompiledKernelArg {
787             kind: kind,
788             offset: 0,
789             dead: true,
790         });
791         nir.add_var(
792             nir_variable_mode::nir_var_uniform,
793             glsl_type,
794             *var_loc,
795             name,
796         );
797     };
798 
799     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
800         debug_assert_ne!(variant, NirKernelVariant::Optimized);
801         add_var(
802             nir,
803             &mut lower_state.base_global_invoc_id_loc,
804             CompiledKernelArgType::GlobalWorkOffsets,
805             unsafe { glsl_vector_type(address_bits_base_type, 3) },
806             "base_global_invocation_id",
807         )
808     }
809 
810     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_GROUP_SIZE) {
811         add_var(
812             nir,
813             &mut lower_state.global_size_loc,
814             CompiledKernelArgType::GlobalWorkSize,
815             unsafe { glsl_vector_type(address_bits_base_type, 3) },
816             "global_size",
817         )
818     }
819 
820     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_WORKGROUP_ID) {
821         debug_assert_ne!(variant, NirKernelVariant::Optimized);
822         add_var(
823             nir,
824             &mut lower_state.base_workgroup_id_loc,
825             CompiledKernelArgType::WorkGroupOffsets,
826             unsafe { glsl_vector_type(address_bits_base_type, 3) },
827             "base_workgroup_id",
828         );
829     }
830 
831     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_NUM_WORKGROUPS) {
832         add_var(
833             nir,
834             &mut lower_state.num_workgroups_loc,
835             CompiledKernelArgType::NumWorkgroups,
836             unsafe { glsl_vector_type(glsl_base_type::GLSL_TYPE_UINT, 3) },
837             "num_workgroups",
838         );
839     }
840 
841     if nir.has_constant() {
842         add_var(
843             nir,
844             &mut lower_state.const_buf_loc,
845             CompiledKernelArgType::ConstantBuffer,
846             address_bits_ptr_type,
847             "constant_buffer_addr",
848         );
849     }
850     if nir.has_printf() {
851         add_var(
852             nir,
853             &mut lower_state.printf_buf_loc,
854             CompiledKernelArgType::PrintfBuffer,
855             address_bits_ptr_type,
856             "printf_buffer_addr",
857         );
858     }
859 
860     if nir.num_images() > 0 || nir.num_textures() > 0 {
861         let count = nir.num_images() + nir.num_textures();
862 
863         add_var(
864             nir,
865             &mut lower_state.format_arr_loc,
866             CompiledKernelArgType::FormatArray,
867             unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
868             "image_formats",
869         );
870 
871         add_var(
872             nir,
873             &mut lower_state.order_arr_loc,
874             CompiledKernelArgType::OrderArray,
875             unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
876             "image_orders",
877         );
878     }
879 
880     if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_WORK_DIM) {
881         add_var(
882             nir,
883             &mut lower_state.work_dim_loc,
884             CompiledKernelArgType::WorkDim,
885             unsafe { glsl_uint8_t_type() },
886             "work_dim",
887         );
888     }
889 
890     // need to run after first opt loop and remove_dead_variables to get rid of unnecessary scratch
891     // memory
892     nir_pass!(
893         nir,
894         nir_lower_vars_to_explicit_types,
895         nir_variable_mode::nir_var_mem_shared
896             | nir_variable_mode::nir_var_function_temp
897             | nir_variable_mode::nir_var_shader_temp
898             | nir_variable_mode::nir_var_uniform
899             | nir_variable_mode::nir_var_mem_global
900             | nir_variable_mode::nir_var_mem_generic,
901         Some(glsl_get_cl_type_size_align),
902     );
903 
904     opt_nir(nir, dev, true);
905     nir_pass!(nir, nir_lower_memcpy);
906 
907     // we might have got rid of more function_temp or shared memory
908     nir.reset_scratch_size();
909     nir.reset_shared_size();
910     nir_pass!(
911         nir,
912         nir_remove_dead_variables,
913         nir_variable_mode::nir_var_function_temp | nir_variable_mode::nir_var_mem_shared,
914         &DV_OPTS,
915     );
916     nir_pass!(
917         nir,
918         nir_lower_vars_to_explicit_types,
919         nir_variable_mode::nir_var_function_temp
920             | nir_variable_mode::nir_var_mem_shared
921             | nir_variable_mode::nir_var_mem_generic,
922         Some(glsl_get_cl_type_size_align),
923     );
924 
925     nir_pass!(
926         nir,
927         nir_lower_explicit_io,
928         nir_variable_mode::nir_var_mem_global | nir_variable_mode::nir_var_mem_constant,
929         global_address_format,
930     );
931 
932     nir_pass!(nir, rusticl_lower_intrinsics, &mut lower_state);
933     nir_pass!(
934         nir,
935         nir_lower_explicit_io,
936         nir_variable_mode::nir_var_mem_shared
937             | nir_variable_mode::nir_var_function_temp
938             | nir_variable_mode::nir_var_uniform,
939         shared_address_format,
940     );
941 
942     if nir_options.lower_int64_options.0 != 0 {
943         nir_pass!(nir, nir_lower_int64);
944     }
945 
946     if nir_options.lower_uniforms_to_ubo {
947         nir_pass!(nir, rusticl_lower_inputs);
948     }
949 
950     nir_pass!(nir, nir_lower_convert_alu_types, None);
951 
952     opt_nir(nir, dev, true);
953 
954     /* before passing it into drivers, assign locations as drivers might remove nir_variables or
955      * other things we depend on
956      */
957     CompiledKernelArg::assign_locations(compiled_args, nir);
958 
959     /* update the has_variable_shared_mem info as we might have DCEed all of them */
960     nir.set_has_variable_shared_mem(compiled_args.iter().any(|arg| {
961         if let CompiledKernelArgType::APIArg(idx) = arg.kind {
962             args[idx as usize].kind == KernelArgType::MemLocal && !arg.dead
963         } else {
964             false
965         }
966     }));
967 
968     if Platform::dbg().nir {
969         eprintln!("=== Printing nir variant '{variant}' for '{name}' before driver finalization");
970         nir.print();
971     }
972 
973     if dev.screen.finalize_nir(nir) {
974         if Platform::dbg().nir {
975             eprintln!(
976                 "=== Printing nir variant '{variant}' for '{name}' after driver finalization"
977             );
978             nir.print();
979         }
980     }
981 
982     nir_pass!(nir, nir_opt_dce);
983     nir.sweep_mem();
984 }
985 
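// Compiles the default variant and, when the kernel reads the global invocation ID or has a
// work-group size hint (and variants aren't disabled via the debug flags), an additional
// optimized variant.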
986 fn compile_nir_remaining(
987     dev: &Device,
988     mut nir: NirShader,
989     args: &[KernelArg],
990     name: &str,
991 ) -> (CompilationResult, Option<CompilationResult>) {
992     // add all API kernel args
993     let mut compiled_args: Vec<_> = (0..args.len())
994         .map(|idx| CompiledKernelArg {
995             kind: CompiledKernelArgType::APIArg(idx as u32),
996             offset: 0,
997             dead: true,
998         })
999         .collect();
1000 
1001     compile_nir_prepare_for_variants(dev, &mut nir, &mut compiled_args);
1002     if Platform::dbg().nir {
1003         eprintln!("=== Printing nir for '{name}' before specialization");
1004         nir.print();
1005     }
1006 
1007     let mut default_build = CompilationResult {
1008         nir: nir,
1009         compiled_args: compiled_args,
1010     };
1011 
1012     // check if we even want to compile a variant before cloning the compilation state
1013     let has_wgs_hint = default_build.nir.workgroup_size_variable()
1014         && default_build.nir.workgroup_size_hint() != [0; 3];
1015     let has_offsets = default_build
1016         .nir
1017         .reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_INVOCATION_ID);
1018 
1019     let mut optimized = (!Platform::dbg().no_variants && (has_offsets || has_wgs_hint))
1020         .then(|| default_build.clone());
1021 
1022     compile_nir_variant(
1023         &mut default_build,
1024         dev,
1025         NirKernelVariant::Default,
1026         args,
1027         name,
1028     );
1029     if let Some(optimized) = &mut optimized {
1030         compile_nir_variant(optimized, dev, NirKernelVariant::Optimized, args, name);
1031     }
1032 
1033     (default_build, optimized)
1034 }
1035 
1036 pub struct SPIRVToNirResult {
1037     pub kernel_info: KernelInfo,
1038     pub nir_kernel_builds: NirKernelBuilds,
1039 }
1040 
1041 impl SPIRVToNirResult {
1042     fn new(
1043         dev: &'static Device,
1044         kernel_info: &clc_kernel_info,
1045         args: Vec<KernelArg>,
1046         default_build: CompilationResult,
1047         optimized: Option<CompilationResult>,
1048     ) -> Self {
1049         // TODO: we _should_ be able to parse them out of the SPIR-V, but clc doesn't handle
1050         //       indirections yet.
1051         let nir = &default_build.nir;
1052         let wgs = nir.workgroup_size();
1053         let subgroup_size = nir.subgroup_size();
1054         let num_subgroups = nir.num_subgroups();
1055 
1056         let default_build = NirKernelBuild::new(dev, default_build);
1057         let optimized = optimized.map(|opt| NirKernelBuild::new(dev, opt));
1058 
1059         let kernel_info = KernelInfo {
1060             args: args,
1061             attributes_string: kernel_info.attribute_str(),
1062             work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
1063             work_group_size_hint: kernel_info.local_size_hint,
1064             subgroup_size: subgroup_size as usize,
1065             num_subgroups: num_subgroups as usize,
1066         };
1067 
1068         Self {
1069             kernel_info: kernel_info,
1070             nir_kernel_builds: NirKernelBuilds::new(default_build, optimized),
1071         }
1072     }
1073 
1074     fn deserialize(bin: &[u8], d: &'static Device, kernel_info: &clc_kernel_info) -> Option<Self> {
1075         let mut reader = blob_reader::default();
1076         unsafe {
1077             blob_reader_init(&mut reader, bin.as_ptr().cast(), bin.len());
1078         }
1079 
1080         let args = KernelArg::deserialize(&mut reader)?;
1081         let default_build = CompilationResult::deserialize(&mut reader, d)?;
1082 
1083         let optimized = match unsafe { blob_read_uint8(&mut reader) } {
1084             0 => None,
1085             _ => Some(CompilationResult::deserialize(&mut reader, d)?),
1086         };
1087 
1088         Some(SPIRVToNirResult::new(
1089             d,
1090             kernel_info,
1091             args,
1092             default_build,
1093             optimized,
1094         ))
1095     }
1096 
1097     // we can't use Self here as the nir shader might be compiled to a cso already and we can't
1098     // cache that.
1099     fn serialize(
1100         blob: &mut blob,
1101         args: &[KernelArg],
1102         default_build: &CompilationResult,
1103         optimized: &Option<CompilationResult>,
1104     ) {
1105         KernelArg::serialize(args, blob);
1106         default_build.serialize(blob);
1107         match optimized {
1108             Some(variant) => {
1109                 unsafe { blob_write_uint8(blob, 1) };
1110                 variant.serialize(blob);
1111             }
1112             None => unsafe {
1113                 blob_write_uint8(blob, 0);
1114             },
1115         }
1116     }
1117 }
1118 
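// Builds the per-device NIR variants for the kernel `name`: results are looked up in the
// screen's shader cache when available and freshly compiled (and cached) otherwise.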
1119 pub(super) fn convert_spirv_to_nir(
1120     build: &ProgramBuild,
1121     name: &str,
1122     args: &[spirv::SPIRVKernelArg],
1123     dev: &'static Device,
1124 ) -> SPIRVToNirResult {
1125     let cache = dev.screen().shader_cache();
1126     let key = build.hash_key(dev, name);
1127     let spirv_info = build.spirv_info(name, dev).unwrap();
1128 
1129     cache
1130         .as_ref()
1131         .and_then(|cache| cache.get(&mut key?))
1132         .and_then(|entry| SPIRVToNirResult::deserialize(&entry, dev, spirv_info))
1133         .unwrap_or_else(|| {
1134             let nir = build.to_nir(name, dev);
1135 
1136             if Platform::dbg().nir {
1137                 eprintln!("=== Printing nir for '{name}' after spirv_to_nir");
1138                 nir.print();
1139             }
1140 
1141             let (mut args, nir) = compile_nir_to_args(dev, nir, args, &dev.lib_clc);
1142             let (default_build, optimized) = compile_nir_remaining(dev, nir, &args, name);
1143 
1144             for build in [Some(&default_build), optimized.as_ref()].into_iter() {
1145                 let Some(build) = build else {
1146                     continue;
1147                 };
1148 
1149                 for arg in &build.compiled_args {
1150                     if let CompiledKernelArgType::APIArg(idx) = arg.kind {
1151                         args[idx as usize].dead &= arg.dead;
1152                     }
1153                 }
1154             }
1155 
1156             if let Some(cache) = cache {
1157                 let mut blob = blob::default();
1158                 unsafe {
1159                     blob_init(&mut blob);
1160                     SPIRVToNirResult::serialize(&mut blob, &args, &default_build, &optimized);
1161                     let bin = slice::from_raw_parts(blob.data, blob.size);
1162                     cache.put(bin, &mut key.unwrap());
1163                     blob_finish(&mut blob);
1164                 }
1165             }
1166 
1167             SPIRVToNirResult::new(dev, spirv_info, args, default_build, optimized)
1168         })
1169 }
1170 
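/// Splits the first `S` bytes off the front of `buf`, advancing it, and returns them as a
/// fixed-size array reference (e.g. `extract::<4>(&mut buf)` takes the first 4 bytes).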
1171 fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
1172     let val;
1173     (val, *buf) = (*buf).split_at(S);
1174     // we split off S bytes and convert to [u8; S], so this should be safe
1175     // use split_array_ref once it's stable
1176     val.try_into().unwrap()
1177 }
1178 
1179 impl Kernel {
1180     pub fn new(name: String, prog: Arc<Program>, prog_build: &ProgramBuild) -> Arc<Kernel> {
1181         let kernel_info = Arc::clone(prog_build.kernel_info.get(&name).unwrap());
1182         let builds = prog_build
1183             .builds
1184             .iter()
1185             .filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, k.clone())))
1186             .collect();
1187 
1188         let values = vec![None; kernel_info.args.len()];
1189         Arc::new(Self {
1190             base: CLObjectBase::new(RusticlTypes::Kernel),
1191             prog: prog,
1192             name: name,
1193             values: Mutex::new(values),
1194             builds: builds,
1195             kernel_info: kernel_info,
1196         })
1197     }
1198 
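    // Heuristic local size selection: for each dimension take the gcd of the grid size with
    // the remaining thread budget (capped by the per-dimension limit); if the resulting block
    // under-fills a subgroup and threads remain, fold one whole grid dimension into the block.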
1199     pub fn suggest_local_size(
1200         &self,
1201         d: &Device,
1202         work_dim: usize,
1203         grid: &mut [usize],
1204         block: &mut [usize],
1205     ) {
1206         let mut threads = self.max_threads_per_block(d);
1207         let dim_threads = d.max_block_sizes();
1208         let subgroups = self.preferred_simd_size(d);
1209 
1210         for i in 0..work_dim {
1211             let t = cmp::min(threads, dim_threads[i]);
1212             let gcd = gcd(t, grid[i]);
1213 
1214             block[i] = gcd;
1215             grid[i] /= gcd;
1216 
1217             // update limits
1218             threads /= block[i];
1219         }
1220 
1221         // if we didn't fill the subgroup we can do a bit better if we have threads remaining
1222         let total_threads = block.iter().take(work_dim).product::<usize>();
1223         if threads != 1 && total_threads < subgroups {
1224             for i in 0..work_dim {
1225                 if grid[i] * total_threads < threads && grid[i] * block[i] <= dim_threads[i] {
1226                     block[i] *= grid[i];
1227                     grid[i] = 1;
1228                     // can only do it once as nothing is cleanly divisible
1229                     break;
1230                 }
1231             }
1232         }
1233     }
1234 
1235     fn optimize_local_size(&self, d: &Device, grid: &mut [usize; 3], block: &mut [u32; 3]) {
1236         if !block.contains(&0) {
1237             for i in 0..3 {
1238                 // we already made sure everything is fine
1239                 grid[i] /= block[i] as usize;
1240             }
1241             return;
1242         }
1243 
1244         let mut usize_block = [0usize; 3];
1245         for i in 0..3 {
1246             usize_block[i] = block[i] as usize;
1247         }
1248 
1249         self.suggest_local_size(d, 3, grid, &mut usize_block);
1250 
1251         for i in 0..3 {
1252             block[i] = usize_block[i] as u32;
1253         }
1254     }
1255 
1256     // the painful part is that host threads are allowed to modify the kernel object once it was
1257     // enqueued, so return a closure with all required data included.
1258     pub fn launch(
1259         self: &Arc<Self>,
1260         q: &Arc<Queue>,
1261         work_dim: u32,
1262         block: &[usize],
1263         grid: &[usize],
1264         offsets: &[usize],
1265     ) -> CLResult<EventSig> {
1266         // Clone all the data we need to execute this kernel
1267         let kernel_info = Arc::clone(&self.kernel_info);
1268         let arg_values = self.arg_values().clone();
1269         let nir_kernel_builds = Arc::clone(&self.builds[q.device]);
1270 
1271         // operations we want to report errors to the clients
1272         let mut block = create_kernel_arr::<u32>(block, 1)?;
1273         let mut grid = create_kernel_arr::<usize>(grid, 1)?;
1274         let offsets = create_kernel_arr::<usize>(offsets, 0)?;
1275 
1276         let api_grid = grid;
1277 
1278         self.optimize_local_size(q.device, &mut grid, &mut block);
1279 
1280         Ok(Box::new(move |q, ctx| {
1281             let hw_max_grid: Vec<usize> = q
1282                 .device
1283                 .max_grid_size()
1284                 .into_iter()
1285                 .map(|val| val.try_into().unwrap_or(usize::MAX))
1286                 // clamped as pipe_launch_grid::grid is only u32
1287                 .map(|val| cmp::min(val, u32::MAX as usize))
1288                 .collect();
1289 
1290             let variant = if offsets == [0; 3]
1291                 && grid[0] <= hw_max_grid[0]
1292                 && grid[1] <= hw_max_grid[1]
1293                 && grid[2] <= hw_max_grid[2]
1294                 && block == kernel_info.work_group_size_hint
1295             {
1296                 NirKernelVariant::Optimized
1297             } else {
1298                 NirKernelVariant::Default
1299             };
1300 
1301             let nir_kernel_build = &nir_kernel_builds[variant];
1302             let mut workgroup_id_offset_loc = None;
1303             let mut input = Vec::new();
1304             // Set it once so we get the alignment padding right
1305             let static_local_size: u64 = nir_kernel_build.shared_size;
1306             let mut variable_local_size: u64 = static_local_size;
1307             let printf_size = q.device.printf_buffer_size() as u32;
1308             let mut samplers = Vec::new();
1309             let mut iviews = Vec::new();
1310             let mut sviews = Vec::new();
1311             let mut tex_formats: Vec<u16> = Vec::new();
1312             let mut tex_orders: Vec<u16> = Vec::new();
1313             let mut img_formats: Vec<u16> = Vec::new();
1314             let mut img_orders: Vec<u16> = Vec::new();
1315 
1316             let null_ptr;
1317             let null_ptr_v3;
1318             if q.device.address_bits() == 64 {
1319                 null_ptr = [0u8; 8].as_slice();
1320                 null_ptr_v3 = [0u8; 24].as_slice();
1321             } else {
1322                 null_ptr = [0u8; 4].as_slice();
1323                 null_ptr_v3 = [0u8; 12].as_slice();
1324             };
1325 
1326             let mut resource_info = Vec::new();
1327             fn add_global<'a>(
1328                 q: &Queue,
1329                 input: &mut Vec<u8>,
1330                 resource_info: &mut Vec<(&'a PipeResource, usize)>,
1331                 res: &'a PipeResource,
1332                 offset: usize,
1333             ) {
1334                 resource_info.push((res, input.len()));
1335                 if q.device.address_bits() == 64 {
1336                     let offset: u64 = offset as u64;
1337                     input.extend_from_slice(&offset.to_ne_bytes());
1338                 } else {
1339                     let offset: u32 = offset as u32;
1340                     input.extend_from_slice(&offset.to_ne_bytes());
1341                 }
1342             }
1343 
1344             fn add_sysval(q: &Queue, input: &mut Vec<u8>, vals: &[usize; 3]) {
1345                 if q.device.address_bits() == 64 {
1346                     input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u64)) });
1347                 } else {
1348                     input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u32)) });
1349                 }
1350             }
1351 
1352             let mut printf_buf = None;
1353             if nir_kernel_build.printf_info.is_some() {
1354                 let buf = q
1355                     .device
1356                     .screen
1357                     .resource_create_buffer(printf_size, ResourceType::Staging, PIPE_BIND_GLOBAL)
1358                     .unwrap();
1359 
1360                 let init_data: [u8; 1] = [4];
1361                 ctx.buffer_subdata(&buf, 0, init_data.as_ptr().cast(), init_data.len() as u32);
1362 
1363                 printf_buf = Some(buf);
1364             }
1365 
1366             for arg in &nir_kernel_build.compiled_args {
1367                 let is_opaque = if let CompiledKernelArgType::APIArg(idx) = arg.kind {
1368                     kernel_info.args[idx as usize].kind.is_opaque()
1369                 } else {
1370                     false
1371                 };
1372 
1373                 if !is_opaque && arg.offset as usize > input.len() {
1374                     input.resize(arg.offset as usize, 0);
1375                 }
1376 
1377                 match arg.kind {
1378                     CompiledKernelArgType::APIArg(idx) => {
1379                         let api_arg = &kernel_info.args[idx as usize];
1380                         if api_arg.dead {
1381                             continue;
1382                         }
1383 
1384                         let Some(value) = &arg_values[idx as usize] else {
1385                             continue;
1386                         };
1387 
1388                         match value {
1389                             KernelArgValue::Constant(c) => input.extend_from_slice(c),
1390                             KernelArgValue::Buffer(buffer) => {
1391                                 let res = buffer.get_res_of_dev(q.device)?;
1392                                 add_global(q, &mut input, &mut resource_info, res, buffer.offset);
1393                             }
1394                             KernelArgValue::Image(image) => {
1395                                 let res = image.get_res_of_dev(q.device)?;
1396 
1397                                 // If the resource is a buffer, the image was created from a buffer,
1398                                 // so use the strides and dimensions of the image instead.
1399                                 let app_img_info = if res.as_ref().is_buffer()
1400                                     && image.mem_type == CL_MEM_OBJECT_IMAGE2D
1401                                 {
1402                                     Some(AppImgInfo::new(
1403                                         image.image_desc.row_pitch()?
1404                                             / image.image_elem_size as u32,
1405                                         image.image_desc.width()?,
1406                                         image.image_desc.height()?,
1407                                     ))
1408                                 } else {
1409                                     None
1410                                 };
1411 
1412                                 let format = image.pipe_format;
1413                                 let (formats, orders) = if api_arg.kind == KernelArgType::Image {
1414                                     iviews.push(res.pipe_image_view(
1415                                         format,
1416                                         false,
1417                                         image.pipe_image_host_access(),
1418                                         app_img_info.as_ref(),
1419                                     ));
1420                                     (&mut img_formats, &mut img_orders)
1421                                 } else if api_arg.kind == KernelArgType::RWImage {
1422                                     iviews.push(res.pipe_image_view(
1423                                         format,
1424                                         true,
1425                                         image.pipe_image_host_access(),
1426                                         app_img_info.as_ref(),
1427                                     ));
1428                                     (&mut img_formats, &mut img_orders)
1429                                 } else {
1430                                     sviews.push((res.clone(), format, app_img_info));
1431                                     (&mut tex_formats, &mut tex_orders)
1432                                 };
1433 
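                                     // Channel data type and order are passed to the kernel through
                                     // per-binding arrays (consumed by FormatArray/OrderArray below),
                                     // so pad both arrays up to this binding slot before appending.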
1434                                 let binding = arg.offset as usize;
1435                                 assert!(binding >= formats.len());
1436 
1437                                 formats.resize(binding, 0);
1438                                 orders.resize(binding, 0);
1439 
1440                                 formats.push(image.image_format.image_channel_data_type as u16);
1441                                 orders.push(image.image_format.image_channel_order as u16);
1442                             }
1443                             KernelArgValue::LocalMem(size) => {
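                                 // Local memory arguments receive an offset into the variable part
                                 // of shared memory: align the running offset to the (capped)
                                 // power-of-two of the requested size, write it as a pointer-sized
                                 // value, then reserve the requested bytes.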
1444                                 // TODO 32 bit
1445                                 let pot = cmp::min(*size, 0x80);
1446                                 variable_local_size = variable_local_size
1447                                     .next_multiple_of(pot.next_power_of_two() as u64);
1448                                 if q.device.address_bits() == 64 {
1449                                     let variable_local_size: [u8; 8] =
1450                                         variable_local_size.to_ne_bytes();
1451                                     input.extend_from_slice(&variable_local_size);
1452                                 } else {
1453                                     let variable_local_size: [u8; 4] =
1454                                         (variable_local_size as u32).to_ne_bytes();
1455                                     input.extend_from_slice(&variable_local_size);
1456                                 }
1457                                 variable_local_size += *size as u64;
1458                             }
1459                             KernelArgValue::Sampler(sampler) => {
1460                                 samplers.push(sampler.pipe());
1461                             }
1462                             KernelArgValue::None => {
1463                                 assert!(
1464                                     api_arg.kind == KernelArgType::MemGlobal
1465                                         || api_arg.kind == KernelArgType::MemConstant
1466                                 );
1467                                 input.extend_from_slice(null_ptr);
1468                             }
1469                         }
1470                     }
1471                     CompiledKernelArgType::ConstantBuffer => {
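                         // The compiled kernel's constant data lives in a dedicated buffer; bind
                         // it like any other global buffer at offset 0.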
1472                         assert!(nir_kernel_build.constant_buffer.is_some());
1473                         let res = nir_kernel_build.constant_buffer.as_ref().unwrap();
1474                         add_global(q, &mut input, &mut resource_info, res, 0);
1475                     }
1476                     CompiledKernelArgType::GlobalWorkOffsets => {
1477                         add_sysval(q, &mut input, &offsets);
1478                     }
1479                     CompiledKernelArgType::WorkGroupOffsets => {
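                         // Remember where these offsets live in the input buffer; they get patched
                         // for every chunk when the grid is split across multiple launches below.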
1480                         workgroup_id_offset_loc = Some(input.len());
1481                         input.extend_from_slice(null_ptr_v3);
1482                     }
1483                     CompiledKernelArgType::GlobalWorkSize => {
1484                         add_sysval(q, &mut input, &api_grid);
1485                     }
1486                     CompiledKernelArgType::PrintfBuffer => {
1487                         let res = printf_buf.as_ref().unwrap();
1488                         add_global(q, &mut input, &mut resource_info, res, 0);
1489                     }
1490                     CompiledKernelArgType::InlineSampler(cl) => {
1491                         samplers.push(Sampler::cl_to_pipe(cl));
1492                     }
1493                     CompiledKernelArgType::FormatArray => {
1494                         input.extend_from_slice(unsafe { as_byte_slice(&tex_formats) });
1495                         input.extend_from_slice(unsafe { as_byte_slice(&img_formats) });
1496                     }
1497                     CompiledKernelArgType::OrderArray => {
1498                         input.extend_from_slice(unsafe { as_byte_slice(&tex_orders) });
1499                         input.extend_from_slice(unsafe { as_byte_slice(&img_orders) });
1500                     }
1501                     CompiledKernelArgType::WorkDim => {
1502                         input.extend_from_slice(&[work_dim as u8; 1]);
1503                     }
1504                     CompiledKernelArgType::NumWorkgroups => {
1505                         input.extend_from_slice(unsafe {
1506                             as_byte_slice(&[grid[0] as u32, grid[1] as u32, grid[2] as u32])
1507                         });
1508                     }
1509                 }
1510             }
1511 
1512             // subtract the shader's static local_size as we only request memory on top of that.
1513             variable_local_size -= static_local_size;
1514 
1515             let mut sviews: Vec<_> = sviews
1516                 .iter()
1517                 .map(|(s, f, aii)| ctx.create_sampler_view(s, *f, aii.as_ref()))
1518                 .collect();
1519             let samplers: Vec<_> = samplers
1520                 .iter()
1521                 .map(|s| ctx.create_sampler_state(s))
1522                 .collect();
1523 
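             // `globals` collects pointers into `input` at the spots where add_global wrote a
             // buffer offset; set_global_binding lets the driver patch in the actual buffer
             // addresses there.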
1524             let mut resources = Vec::with_capacity(resource_info.len());
1525             let mut globals: Vec<*mut u32> = Vec::with_capacity(resource_info.len());
1526             for (res, offset) in resource_info {
1527                 resources.push(res);
1528                 globals.push(unsafe { input.as_mut_ptr().byte_add(offset) }.cast());
1529             }
1530 
1531             let temp_cso;
1532             let cso = match &nir_kernel_build.nir_or_cso {
1533                 KernelDevStateVariant::Cso(cso) => cso,
1534                 KernelDevStateVariant::Nir(nir) => {
1535                     temp_cso = CSOWrapper::new(q.device, nir);
1536                     &temp_cso
1537                 }
1538             };
1539 
1540             ctx.bind_compute_state(cso.cso_ptr);
1541             ctx.bind_sampler_states(&samplers);
1542             ctx.set_sampler_views(&mut sviews);
1543             ctx.set_shader_images(&iviews);
1544             ctx.set_global_binding(resources.as_slice(), &mut globals);
1545 
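             // The requested grid may exceed what the hardware can launch at once, so split it
             // into chunks of at most hw_max_grid workgroups per dimension and launch each chunk
             // separately, patching the workgroup id offset for every chunk.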
1546             for z in 0..grid[2].div_ceil(hw_max_grid[2]) {
1547                 for y in 0..grid[1].div_ceil(hw_max_grid[1]) {
1548                     for x in 0..grid[0].div_ceil(hw_max_grid[0]) {
1549                         if let Some(workgroup_id_offset_loc) = workgroup_id_offset_loc {
1550                             let this_offsets =
1551                                 [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]];
1552 
1553                             if q.device.address_bits() == 64 {
1554                                 let val = this_offsets.map(|v| v as u64);
1555                                 input[workgroup_id_offset_loc..workgroup_id_offset_loc + 24]
1556                                     .copy_from_slice(unsafe { as_byte_slice(&val) });
1557                             } else {
1558                                 let val = this_offsets.map(|v| v as u32);
1559                                 input[workgroup_id_offset_loc..workgroup_id_offset_loc + 12]
1560                                     .copy_from_slice(unsafe { as_byte_slice(&val) });
1561                             }
1562                         }
1563 
1564                         let this_grid = [
1565                             cmp::min(hw_max_grid[0], grid[0] - hw_max_grid[0] * x) as u32,
1566                             cmp::min(hw_max_grid[1], grid[1] - hw_max_grid[1] * y) as u32,
1567                             cmp::min(hw_max_grid[2], grid[2] - hw_max_grid[2] * z) as u32,
1568                         ];
1569 
1570                         ctx.update_cb0(&input)?;
1571                         ctx.launch_grid(work_dim, block, this_grid, variable_local_size as u32);
1572 
1573                         if Platform::dbg().sync_every_event {
1574                             ctx.flush().wait();
1575                         }
1576                     }
1577                 }
1578             }
1579 
1580             ctx.clear_global_binding(globals.len() as u32);
1581             ctx.clear_shader_images(iviews.len() as u32);
1582             ctx.clear_sampler_views(sviews.len() as u32);
1583             ctx.clear_sampler_states(samplers.len() as u32);
1584 
1585             ctx.bind_compute_state(ptr::null_mut());
1586 
1587             ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER);
1588 
1589             samplers.iter().for_each(|s| ctx.delete_sampler_state(*s));
1590             sviews.iter().for_each(|v| ctx.sampler_view_destroy(*v));
1591 
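             // Read the printf buffer back: the leading u32 is the number of bytes written
             // (including the 4 byte header), the rest is the raw printf data to format.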
1592             if let Some(printf_buf) = &printf_buf {
1593                 let tx = ctx
1594                     .buffer_map(printf_buf, 0, printf_size as i32, RWFlags::RD)
1595                     .ok_or(CL_OUT_OF_RESOURCES)?;
1596                 let mut buf: &[u8] =
1597                     unsafe { slice::from_raw_parts(tx.ptr().cast(), printf_size as usize) };
1598                 let length = u32::from_ne_bytes(*extract(&mut buf));
1599 
1600                 // update our slice to make sure we don't go out of bounds
1601                 buf = &buf[0..(length - 4) as usize];
1602                 if let Some(pf) = &nir_kernel_build.printf_info {
1603                     pf.u_printf(buf)
1604                 }
1605             }
1606 
1607             Ok(())
1608         }))
1609     }
1610 
1611     pub fn arg_values(&self) -> MutexGuard<Vec<Option<KernelArgValue>>> {
1612         self.values.lock().unwrap()
1613     }
1614 
1615     pub fn set_kernel_arg(&self, idx: usize, arg: KernelArgValue) -> CLResult<()> {
1616         self.values
1617             .lock()
1618             .unwrap()
1619             .get_mut(idx)
1620             .ok_or(CL_INVALID_ARG_INDEX)?
1621             .replace(arg);
1622         Ok(())
1623     }
1624 
1625     pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
1626         let aq = self.kernel_info.args[idx as usize].spirv.access_qualifier;
1627 
1628         if aq
1629             == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
1630                 | clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE
1631         {
1632             CL_KERNEL_ARG_ACCESS_READ_WRITE
1633         } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ {
1634             CL_KERNEL_ARG_ACCESS_READ_ONLY
1635         } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE {
1636             CL_KERNEL_ARG_ACCESS_WRITE_ONLY
1637         } else {
1638             CL_KERNEL_ARG_ACCESS_NONE
1639         }
1640     }
1641 
1642     pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
1643         match self.kernel_info.args[idx as usize].spirv.address_qualifier {
1644             clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
1645                 CL_KERNEL_ARG_ADDRESS_PRIVATE
1646             }
1647             clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
1648                 CL_KERNEL_ARG_ADDRESS_CONSTANT
1649             }
1650             clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
1651                 CL_KERNEL_ARG_ADDRESS_LOCAL
1652             }
1653             clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
1654                 CL_KERNEL_ARG_ADDRESS_GLOBAL
1655             }
1656         }
1657     }
1658 
1659     pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
1660         let tq = self.kernel_info.args[idx as usize].spirv.type_qualifier;
1661         let zero = clc_kernel_arg_type_qualifier(0);
1662         let mut res = CL_KERNEL_ARG_TYPE_NONE;
1663 
1664         if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_CONST != zero {
1665             res |= CL_KERNEL_ARG_TYPE_CONST;
1666         }
1667 
1668         if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_RESTRICT != zero {
1669             res |= CL_KERNEL_ARG_TYPE_RESTRICT;
1670         }
1671 
1672         if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_VOLATILE != zero {
1673             res |= CL_KERNEL_ARG_TYPE_VOLATILE;
1674         }
1675 
1676         res.into()
1677     }
1678 
1679     pub fn work_group_size(&self) -> [usize; 3] {
1680         self.kernel_info.work_group_size
1681     }
1682 
1683     pub fn num_subgroups(&self) -> usize {
1684         self.kernel_info.num_subgroups
1685     }
1686 
1687     pub fn subgroup_size(&self) -> usize {
1688         self.kernel_info.subgroup_size
1689     }
1690 
1691     pub fn arg_name(&self, idx: cl_uint) -> &String {
1692         &self.kernel_info.args[idx as usize].spirv.name
1693     }
1694 
1695     pub fn arg_type_name(&self, idx: cl_uint) -> &String {
1696         &self.kernel_info.args[idx as usize].spirv.type_name
1697     }
1698 
1699     pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
1700         self.builds.get(dev).unwrap().info.private_memory as cl_ulong
1701     }
1702 
1703     pub fn max_threads_per_block(&self, dev: &Device) -> usize {
1704         self.builds.get(dev).unwrap().info.max_threads as usize
1705     }
1706 
1707     pub fn preferred_simd_size(&self, dev: &Device) -> usize {
1708         self.builds.get(dev).unwrap().info.preferred_simd_size as usize
1709     }
1710 
1711     pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
1712         // TODO include args
1713         // this is purely informational so it shouldn't even matter
1714         self.builds.get(dev).unwrap()[NirKernelVariant::Default].shared_size as cl_ulong
1715     }
1716 
1717     pub fn has_svm_devs(&self) -> bool {
1718         self.prog.devs.iter().any(|dev| dev.svm_supported())
1719     }
1720 
1721     pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
1722         SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes)
1723             .map(|bit| 1 << bit)
1724             .collect()
1725     }
1726 
1727     pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1728         let subgroup_size = self.subgroup_size_for_block(dev, block);
1729         if subgroup_size == 0 {
1730             return 0;
1731         }
1732 
1733         let threads: usize = block.iter().product();
1734         threads.div_ceil(subgroup_size)
1735     }
1736 
1737     pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1738         let subgroup_sizes = self.subgroup_sizes(dev);
1739         if subgroup_sizes.is_empty() {
1740             return 0;
1741         }
1742 
1743         if subgroup_sizes.len() == 1 {
1744             return subgroup_sizes[0];
1745         }
1746 
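         // Pad the block to three dimensions so the driver query below always sees a full
         // 3D local size.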
1747         let block = [
1748             *block.first().unwrap_or(&1) as u32,
1749             *block.get(1).unwrap_or(&1) as u32,
1750             *block.get(2).unwrap_or(&1) as u32,
1751         ];
1752 
1753         // TODO: this _might_ bite us somewhere, but I think it probably doesn't matter
1754         match &self.builds.get(dev).unwrap()[NirKernelVariant::Default].nir_or_cso {
1755             KernelDevStateVariant::Cso(cso) => {
1756                 dev.helper_ctx()
1757                     .compute_state_subgroup_size(cso.cso_ptr, &block) as usize
1758             }
1759             _ => {
1760                 panic!()
1761             }
1762         }
1763     }
1764 }
1765 
1766 impl Clone for Kernel {
1767     fn clone(&self) -> Self {
1768         Self {
1769             base: CLObjectBase::new(RusticlTypes::Kernel),
1770             prog: self.prog.clone(),
1771             name: self.name.clone(),
1772             values: Mutex::new(self.arg_values().clone()),
1773             builds: self.builds.clone(),
1774             kernel_info: self.kernel_info.clone(),
1775         }
1776     }
1777 }
1778