1 use crate::api::icd::*;
2 use crate::core::device::*;
3 use crate::core::event::*;
4 use crate::core::memory::*;
5 use crate::core::platform::*;
6 use crate::core::program::*;
7 use crate::core::queue::*;
8 use crate::impl_cl_type_trait;
9
10 use mesa_rust::compiler::clc::*;
11 use mesa_rust::compiler::nir::*;
12 use mesa_rust::nir_pass;
13 use mesa_rust::pipe::context::RWFlags;
14 use mesa_rust::pipe::resource::*;
15 use mesa_rust::pipe::screen::ResourceType;
16 use mesa_rust_gen::*;
17 use mesa_rust_util::math::*;
18 use mesa_rust_util::serialize::*;
19 use rusticl_opencl_gen::*;
20 use spirv::SpirvKernelInfo;
21
22 use std::cmp;
23 use std::collections::HashMap;
24 use std::convert::TryInto;
25 use std::fmt::Debug;
26 use std::fmt::Display;
27 use std::ops::Index;
28 use std::os::raw::c_void;
29 use std::ptr;
30 use std::slice;
31 use std::sync::Arc;
32 use std::sync::Mutex;
33 use std::sync::MutexGuard;
34
35 // ugh, we are not allowed to take refs, so...
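// Arguments are therefore stored by value/ownership: Constant bytes get copied into the launch
// input buffer, Buffer/Image resolve to their per-device resource at launch time, LocalMem only
// carries the requested size in bytes, and Sampler is turned into a pipe sampler state (see
// Kernel::launch).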
36 #[derive(Clone)]
37 pub enum KernelArgValue {
38 None,
39 Buffer(Arc<Buffer>),
40 Constant(Vec<u8>),
41 Image(Arc<Image>),
42 LocalMem(usize),
43 Sampler(Arc<Sampler>),
44 }
45
46 #[repr(u8)]
47 #[derive(Hash, PartialEq, Eq, Clone, Copy)]
48 pub enum KernelArgType {
49 Constant(/* size */ u16), // for anything passed by value
50 Image,
51 RWImage,
52 Sampler,
53 Texture,
54 MemGlobal,
55 MemConstant,
56 MemLocal,
57 }
58
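// Serialized blob layout: a u8 tag (0 = Constant, 1 = Image, 2 = RWImage, 3 = Sampler,
// 4 = Texture, 5 = MemGlobal, 6 = MemConstant, 7 = MemLocal); Constant is followed by its size
// as a u16. serialize() and deserialize() below must stay in sync.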
59 impl KernelArgType {
60 fn deserialize(blob: &mut blob_reader) -> Option<Self> {
61 Some(match unsafe { blob_read_uint8(blob) } {
62 0 => {
63 let size = unsafe { blob_read_uint16(blob) };
64 KernelArgType::Constant(size)
65 }
66 1 => KernelArgType::Image,
67 2 => KernelArgType::RWImage,
68 3 => KernelArgType::Sampler,
69 4 => KernelArgType::Texture,
70 5 => KernelArgType::MemGlobal,
71 6 => KernelArgType::MemConstant,
72 7 => KernelArgType::MemLocal,
73 _ => return None,
74 })
75 }
76
77 fn serialize(&self, blob: &mut blob) {
78 unsafe {
79 match self {
80 KernelArgType::Constant(size) => {
81 blob_write_uint8(blob, 0);
82 blob_write_uint16(blob, *size)
83 }
84 KernelArgType::Image => blob_write_uint8(blob, 1),
85 KernelArgType::RWImage => blob_write_uint8(blob, 2),
86 KernelArgType::Sampler => blob_write_uint8(blob, 3),
87 KernelArgType::Texture => blob_write_uint8(blob, 4),
88 KernelArgType::MemGlobal => blob_write_uint8(blob, 5),
89 KernelArgType::MemConstant => blob_write_uint8(blob, 6),
90 KernelArgType::MemLocal => blob_write_uint8(blob, 7),
91 };
92 }
93 }
94
95 fn is_opaque(&self) -> bool {
96 matches!(
97 self,
98 KernelArgType::Image
99 | KernelArgType::RWImage
100 | KernelArgType::Texture
101 | KernelArgType::Sampler
102 )
103 }
104 }
105
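// APIArg(i) refers back to the i-th user-visible kernel argument; all other variants are
// internal arguments appended during compilation (lowered system values, constant/printf
// buffers, inline samplers and image format/order metadata).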
106 #[derive(Hash, PartialEq, Eq, Clone)]
107 enum CompiledKernelArgType {
108 APIArg(u32),
109 ConstantBuffer,
110 GlobalWorkOffsets,
111 GlobalWorkSize,
112 PrintfBuffer,
113 InlineSampler((cl_addressing_mode, cl_filter_mode, bool)),
114 FormatArray,
115 OrderArray,
116 WorkDim,
117 WorkGroupOffsets,
118 NumWorkgroups,
119 }
120
121 #[derive(Hash, PartialEq, Eq, Clone)]
122 pub struct KernelArg {
123 spirv: spirv::SPIRVKernelArg,
124 pub kind: KernelArgType,
125 pub dead: bool,
126 }
127
128 impl KernelArg {
129 fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
130 let nir_arg_map: HashMap<_, _> = nir
131 .variables_with_mode(
132 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
133 )
134 .map(|v| (v.data.location, v))
135 .collect();
136 let mut res = Vec::new();
137
138 for (i, s) in spirv.iter().enumerate() {
139 let nir = nir_arg_map.get(&(i as i32)).unwrap();
140 let kind = match s.address_qualifier {
141 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
142 if unsafe { glsl_type_is_sampler(nir.type_) } {
143 KernelArgType::Sampler
144 } else {
145 let size = unsafe { glsl_get_cl_size(nir.type_) } as u16;
146 // NIR types of non-opaque arguments are never zero-sized
147 KernelArgType::Constant(size)
148 }
149 }
150 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
151 KernelArgType::MemConstant
152 }
153 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
154 KernelArgType::MemLocal
155 }
156 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
157 if unsafe { glsl_type_is_image(nir.type_) } {
158 let access = nir.data.access();
159 if access == gl_access_qualifier::ACCESS_NON_WRITEABLE.0 {
160 KernelArgType::Texture
161 } else if access == gl_access_qualifier::ACCESS_NON_READABLE.0 {
162 KernelArgType::Image
163 } else {
164 KernelArgType::RWImage
165 }
166 } else {
167 KernelArgType::MemGlobal
168 }
169 }
170 };
171
172 res.push(Self {
173 spirv: s.clone(),
174 // we'll update it later in the 2nd pass
175 kind: kind,
176 dead: true,
177 });
178 }
179 res
180 }
181
182 fn serialize(args: &[Self], blob: &mut blob) {
183 unsafe {
184 blob_write_uint16(blob, args.len() as u16);
185
186 for arg in args {
187 arg.spirv.serialize(blob);
188 blob_write_uint8(blob, arg.dead.into());
189 arg.kind.serialize(blob);
190 }
191 }
192 }
193
194 fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
195 unsafe {
196 let len = blob_read_uint16(blob) as usize;
197 let mut res = Vec::with_capacity(len);
198
199 for _ in 0..len {
200 let spirv = spirv::SPIRVKernelArg::deserialize(blob)?;
201 let dead = blob_read_uint8(blob) != 0;
202 let kind = KernelArgType::deserialize(blob)?;
203
204 res.push(Self {
205 spirv: spirv,
206 kind: kind,
207 dead: dead,
208 });
209 }
210
211 Some(res)
212 }
213 }
214 }
215
216 #[derive(Hash, PartialEq, Eq, Clone)]
217 struct CompiledKernelArg {
218 kind: CompiledKernelArgType,
219 /// The binding for image/sampler args, the offset into the input buffer
220 /// for anything else.
221 offset: u32,
222 dead: bool,
223 }
224
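// Each entry is serialized as: offset (u32), dead (u8), then a u8 kind tag with an optional
// payload (InlineSampler stores norm/addressing/filter mode, APIArg stores the argument index);
// see serialize()/deserialize() below.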
225 impl CompiledKernelArg {
226 fn assign_locations(compiled_args: &mut [Self], nir: &mut NirShader) {
227 for var in nir.variables_with_mode(
228 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
229 ) {
230 let arg = &mut compiled_args[var.data.location as usize];
231 let t = var.type_;
232
233 arg.dead = false;
234 arg.offset = if unsafe {
235 glsl_type_is_image(t) || glsl_type_is_texture(t) || glsl_type_is_sampler(t)
236 } {
237 var.data.binding
238 } else {
239 var.data.driver_location
240 };
241 }
242 }
243
244 fn serialize(args: &[Self], blob: &mut blob) {
245 unsafe {
246 blob_write_uint16(blob, args.len() as u16);
247 for arg in args {
248 blob_write_uint32(blob, arg.offset);
249 blob_write_uint8(blob, arg.dead.into());
250 match arg.kind {
251 CompiledKernelArgType::ConstantBuffer => blob_write_uint8(blob, 0),
252 CompiledKernelArgType::GlobalWorkOffsets => blob_write_uint8(blob, 1),
253 CompiledKernelArgType::PrintfBuffer => blob_write_uint8(blob, 2),
254 CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm)) => {
255 blob_write_uint8(blob, 3);
256 blob_write_uint8(blob, norm.into());
257 blob_write_uint32(blob, addr_mode);
258 blob_write_uint32(blob, filter_mode)
259 }
260 CompiledKernelArgType::FormatArray => blob_write_uint8(blob, 4),
261 CompiledKernelArgType::OrderArray => blob_write_uint8(blob, 5),
262 CompiledKernelArgType::WorkDim => blob_write_uint8(blob, 6),
263 CompiledKernelArgType::WorkGroupOffsets => blob_write_uint8(blob, 7),
264 CompiledKernelArgType::NumWorkgroups => blob_write_uint8(blob, 8),
265 CompiledKernelArgType::GlobalWorkSize => blob_write_uint8(blob, 9),
266 CompiledKernelArgType::APIArg(idx) => {
267 blob_write_uint8(blob, 10);
268 blob_write_uint32(blob, idx)
269 }
270 };
271 }
272 }
273 }
274
275 fn deserialize(blob: &mut blob_reader) -> Option<Vec<Self>> {
276 unsafe {
277 let len = blob_read_uint16(blob) as usize;
278 let mut res = Vec::with_capacity(len);
279
280 for _ in 0..len {
281 let offset = blob_read_uint32(blob);
282 let dead = blob_read_uint8(blob) != 0;
283
284 let kind = match blob_read_uint8(blob) {
285 0 => CompiledKernelArgType::ConstantBuffer,
286 1 => CompiledKernelArgType::GlobalWorkOffsets,
287 2 => CompiledKernelArgType::PrintfBuffer,
288 3 => {
289 let norm = blob_read_uint8(blob) != 0;
290 let addr_mode = blob_read_uint32(blob);
291 let filter_mode = blob_read_uint32(blob);
292 CompiledKernelArgType::InlineSampler((addr_mode, filter_mode, norm))
293 }
294 4 => CompiledKernelArgType::FormatArray,
295 5 => CompiledKernelArgType::OrderArray,
296 6 => CompiledKernelArgType::WorkDim,
297 7 => CompiledKernelArgType::WorkGroupOffsets,
298 8 => CompiledKernelArgType::NumWorkgroups,
299 9 => CompiledKernelArgType::GlobalWorkSize,
300 10 => {
301 let idx = blob_read_uint32(blob);
302 CompiledKernelArgType::APIArg(idx)
303 }
304 _ => return None,
305 };
306
307 res.push(Self {
308 kind: kind,
309 offset: offset,
310 dead: dead,
311 });
312 }
313
314 Some(res)
315 }
316 }
317 }
318
319 #[derive(Clone, PartialEq, Eq, Hash)]
320 pub struct KernelInfo {
321 pub args: Vec<KernelArg>,
322 pub attributes_string: String,
323 work_group_size: [usize; 3],
324 work_group_size_hint: [u32; 3],
325 subgroup_size: usize,
326 num_subgroups: usize,
327 }
328
329 struct CSOWrapper {
330 cso_ptr: *mut c_void,
331 dev: &'static Device,
332 }
333
334 impl CSOWrapper {
335 fn new(dev: &'static Device, nir: &NirShader) -> Self {
336 let cso_ptr = dev
337 .helper_ctx()
338 .create_compute_state(nir, nir.shared_size());
339
340 Self {
341 cso_ptr: cso_ptr,
342 dev: dev,
343 }
344 }
345
346 fn get_cso_info(&self) -> pipe_compute_state_object_info {
347 self.dev.helper_ctx().compute_state_info(self.cso_ptr)
348 }
349 }
350
351 impl Drop for CSOWrapper {
352 fn drop(&mut self) {
353 self.dev.helper_ctx().delete_compute_state(self.cso_ptr);
354 }
355 }
356
357 enum KernelDevStateVariant {
358 Cso(CSOWrapper),
359 Nir(NirShader),
360 }
361
362 #[derive(Debug, PartialEq)]
363 enum NirKernelVariant {
364 /// Can be used under any circumstance.
365 Default,
366
367 /// Optimized variant making the following assumptions:
368 /// - global_id_offsets are 0
369 /// - workgroup_offsets are 0
370 /// - local_size is info.local_size_hint
371 Optimized,
372 }
373
374 impl Display for NirKernelVariant {
375 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 // this simply prints the enum name, so that's fine
377 Debug::fmt(self, f)
378 }
379 }
380
381 pub struct NirKernelBuilds {
382 default_build: NirKernelBuild,
383 optimized: Option<NirKernelBuild>,
384 /// merged info with worst case values
385 info: pipe_compute_state_object_info,
386 }
387
388 impl Index<NirKernelVariant> for NirKernelBuilds {
389 type Output = NirKernelBuild;
390
391 fn index(&self, index: NirKernelVariant) -> &Self::Output {
392 match index {
393 NirKernelVariant::Default => &self.default_build,
394 NirKernelVariant::Optimized => self.optimized.as_ref().unwrap_or(&self.default_build),
395 }
396 }
397 }
398
399 impl NirKernelBuilds {
400 fn new(default_build: NirKernelBuild, optimized: Option<NirKernelBuild>) -> Self {
401 let mut info = default_build.info;
402 if let Some(build) = &optimized {
403 info.max_threads = cmp::min(info.max_threads, build.info.max_threads);
404 info.simd_sizes &= build.info.simd_sizes;
405 info.private_memory = cmp::max(info.private_memory, build.info.private_memory);
406 info.preferred_simd_size =
407 cmp::max(info.preferred_simd_size, build.info.preferred_simd_size);
408 }
409
410 Self {
411 default_build: default_build,
412 optimized: optimized,
413 info: info,
414 }
415 }
416 }
417
418 struct NirKernelBuild {
419 nir_or_cso: KernelDevStateVariant,
420 constant_buffer: Option<Arc<PipeResource>>,
421 info: pipe_compute_state_object_info,
422 shared_size: u64,
423 printf_info: Option<NirPrintfInfo>,
424 compiled_args: Vec<CompiledKernelArg>,
425 }
426
427 // SAFETY: `CSOWrapper` is only safe to use if the device supports `PIPE_CAP_SHAREABLE_SHADERS` and
428 // we make sure to set `nir_or_cso` to `KernelDevStateVariant::Cso` only if that's the case.
429 unsafe impl Send for NirKernelBuild {}
430 unsafe impl Sync for NirKernelBuild {}
431
432 impl NirKernelBuild {
433 fn new(dev: &'static Device, mut out: CompilationResult) -> Self {
434 let cso = CSOWrapper::new(dev, &out.nir);
435 let info = cso.get_cso_info();
436 let cb = Self::create_nir_constant_buffer(dev, &out.nir);
437 let shared_size = out.nir.shared_size() as u64;
438 let printf_info = out.nir.take_printf_info();
439
440 let nir_or_cso = if !dev.shareable_shaders() {
441 KernelDevStateVariant::Nir(out.nir)
442 } else {
443 KernelDevStateVariant::Cso(cso)
444 };
445
446 NirKernelBuild {
447 nir_or_cso: nir_or_cso,
448 constant_buffer: cb,
449 info: info,
450 shared_size: shared_size,
451 printf_info: printf_info,
452 compiled_args: out.compiled_args,
453 }
454 }
455
456 fn create_nir_constant_buffer(dev: &Device, nir: &NirShader) -> Option<Arc<PipeResource>> {
457 let buf = nir.get_constant_buffer();
458 let len = buf.len() as u32;
459
460 if len > 0 {
461 // TODO bind as constant buffer
462 let res = dev
463 .screen()
464 .resource_create_buffer(len, ResourceType::Normal, PIPE_BIND_GLOBAL)
465 .unwrap();
466
467 dev.helper_ctx()
468 .exec(|ctx| ctx.buffer_subdata(&res, 0, buf.as_ptr().cast(), len))
469 .wait();
470
471 Some(Arc::new(res))
472 } else {
473 None
474 }
475 }
476 }
477
478 pub struct Kernel {
479 pub base: CLObjectBase<CL_INVALID_KERNEL>,
480 pub prog: Arc<Program>,
481 pub name: String,
482 values: Mutex<Vec<Option<KernelArgValue>>>,
483 builds: HashMap<&'static Device, Arc<NirKernelBuilds>>,
484 pub kernel_info: Arc<KernelInfo>,
485 }
486
487 impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
488
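/// Expands `vals` into a `[T; 3]`, filling the remaining dimensions with `val`. Illustrative
/// example: `create_kernel_arr::<u32>(&[64, 2], 1)` yields `Ok([64, 2, 1])`; values that don't
/// fit into `T` are reported as `CL_OUT_OF_RESOURCES`.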
489 fn create_kernel_arr<T>(vals: &[usize], val: T) -> CLResult<[T; 3]>
490 where
491 T: std::convert::TryFrom<usize> + Copy,
492 <T as std::convert::TryFrom<usize>>::Error: std::fmt::Debug,
493 {
494 let mut res = [val; 3];
495 for (i, v) in vals.iter().enumerate() {
496 res[i] = (*v).try_into().ok().ok_or(CL_OUT_OF_RESOURCES)?;
497 }
498
499 Ok(res)
500 }
501
502 #[derive(Clone)]
503 struct CompilationResult {
504 nir: NirShader,
505 compiled_args: Vec<CompiledKernelArg>,
506 }
507
508 impl CompilationResult {
509 fn deserialize(reader: &mut blob_reader, d: &Device) -> Option<Self> {
510 let nir = NirShader::deserialize(
511 reader,
512 d.screen()
513 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
514 )?;
515 let compiled_args = CompiledKernelArg::deserialize(reader)?;
516
517 Some(Self {
518 nir: nir,
519 compiled_args,
520 })
521 }
522
523 fn serialize(&self, blob: &mut blob) {
524 self.nir.serialize(blob);
525 CompiledKernelArg::serialize(&self.compiled_args, blob);
526 }
527 }
528
529 fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
530 let nir_options = unsafe {
531 &*dev
532 .screen
533 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
534 };
535
536 while {
537 let mut progress = false;
538
539 progress |= nir_pass!(nir, nir_copy_prop);
540 progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
541 progress |= nir_pass!(nir, nir_opt_dead_write_vars);
542
543 if nir_options.lower_to_scalar {
544 nir_pass!(
545 nir,
546 nir_lower_alu_to_scalar,
547 nir_options.lower_to_scalar_filter,
548 ptr::null(),
549 );
550 nir_pass!(nir, nir_lower_phis_to_scalar, false);
551 }
552
553 progress |= nir_pass!(nir, nir_opt_deref);
554 if has_explicit_types {
555 progress |= nir_pass!(nir, nir_opt_memcpy);
556 }
557 progress |= nir_pass!(nir, nir_opt_dce);
558 progress |= nir_pass!(nir, nir_opt_undef);
559 progress |= nir_pass!(nir, nir_opt_constant_folding);
560 progress |= nir_pass!(nir, nir_opt_cse);
561 nir_pass!(nir, nir_split_var_copies);
562 progress |= nir_pass!(nir, nir_lower_var_copies);
563 progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
564 nir_pass!(nir, nir_lower_alu);
565 progress |= nir_pass!(nir, nir_opt_phi_precision);
566 progress |= nir_pass!(nir, nir_opt_algebraic);
567 progress |= nir_pass!(
568 nir,
569 nir_opt_if,
570 nir_opt_if_options::nir_opt_if_optimize_phi_true_false,
571 );
572 progress |= nir_pass!(nir, nir_opt_dead_cf);
573 progress |= nir_pass!(nir, nir_opt_remove_phis);
574 // we don't want to be too aggressive here, but it kills a bit of CFG
575 progress |= nir_pass!(nir, nir_opt_peephole_select, 8, true, true);
576 progress |= nir_pass!(
577 nir,
578 nir_lower_vec3_to_vec4,
579 nir_variable_mode::nir_var_mem_generic | nir_variable_mode::nir_var_uniform,
580 );
581
582 if nir_options.max_unroll_iterations != 0 {
583 progress |= nir_pass!(nir, nir_opt_loop_unroll);
584 }
585 nir.sweep_mem();
586 progress
587 } {}
588 }
589
590 /// # Safety
591 ///
592 /// Only safe to call when `var` is a valid pointer to a valid [`nir_variable`]
593 unsafe extern "C" fn can_remove_var(var: *mut nir_variable, _: *mut c_void) -> bool {
594 // SAFETY: It is the caller's responsibility to provide a valid and aligned pointer
595 let var_type = unsafe { (*var).type_ };
596 // SAFETY: `nir_variable`'s type invariant guarantees that the `type_` field is valid and
597 // properly aligned.
598 unsafe {
599 !glsl_type_is_image(var_type)
600 && !glsl_type_is_texture(var_type)
601 && !glsl_type_is_sampler(var_type)
602 }
603 }
604
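// Options for nir_remove_dead_variables that keep image, texture and sampler variables alive;
// their bindings are still needed later when CompiledKernelArg::assign_locations runs.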
605 const DV_OPTS: nir_remove_dead_variables_options = nir_remove_dead_variables_options {
606 can_remove_var: Some(can_remove_var),
607 can_remove_var_data: ptr::null_mut(),
608 };
609
610 fn compile_nir_to_args(
611 dev: &Device,
612 mut nir: NirShader,
613 args: &[spirv::SPIRVKernelArg],
614 lib_clc: &NirShader,
615 ) -> (Vec<KernelArg>, NirShader) {
616 // this is a hack until we support fp16 properly and check for denorms inside vstore/vload_half
617 nir.preserve_fp16_denorms();
618
619 // Set to rtne for now until drivers are able to report their preferred rounding mode, which
620 // also matches what we report via the API.
621 nir.set_fp_rounding_mode_rtne();
622
623 nir_pass!(nir, nir_scale_fdiv);
624 nir.set_workgroup_size_variable_if_zero();
625 nir.structurize();
626 while {
627 let mut progress = false;
628 nir_pass!(nir, nir_split_var_copies);
629 progress |= nir_pass!(nir, nir_copy_prop);
630 progress |= nir_pass!(nir, nir_opt_copy_prop_vars);
631 progress |= nir_pass!(nir, nir_opt_dead_write_vars);
632 progress |= nir_pass!(nir, nir_opt_deref);
633 progress |= nir_pass!(nir, nir_opt_dce);
634 progress |= nir_pass!(nir, nir_opt_undef);
635 progress |= nir_pass!(nir, nir_opt_constant_folding);
636 progress |= nir_pass!(nir, nir_opt_cse);
637 progress |= nir_pass!(nir, nir_lower_vars_to_ssa);
638 progress |= nir_pass!(nir, nir_opt_algebraic);
639 progress
640 } {}
641 nir.inline(lib_clc);
642 nir.cleanup_functions();
643 // that should free up tons of memory
644 nir.sweep_mem();
645
646 nir_pass!(nir, nir_dedup_inline_samplers);
647
648 let printf_opts = nir_lower_printf_options {
649 ptr_bit_size: 0,
650 use_printf_base_identifier: false,
651 max_buffer_size: dev.printf_buffer_size() as u32,
652 };
653 nir_pass!(nir, nir_lower_printf, &printf_opts);
654
655 opt_nir(&mut nir, dev, false);
656
657 (KernelArg::from_spirv_nir(args, &mut nir), nir)
658 }
659
660 fn compile_nir_prepare_for_variants(
661 dev: &Device,
662 nir: &mut NirShader,
663 compiled_args: &mut Vec<CompiledKernelArg>,
664 ) {
665 // assign locations for inline samplers.
666 // IMPORTANT: this needs to happen before nir_remove_dead_variables.
667 let mut last_loc = -1;
668 for v in nir
669 .variables_with_mode(nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image)
670 {
671 if unsafe { !glsl_type_is_sampler(v.type_) } {
672 last_loc = v.data.location;
673 continue;
674 }
675 let s = unsafe { v.data.anon_1.sampler };
676 if s.is_inline_sampler() != 0 {
677 last_loc += 1;
678 v.data.location = last_loc;
679
680 compiled_args.push(CompiledKernelArg {
681 kind: CompiledKernelArgType::InlineSampler(Sampler::nir_to_cl(
682 s.addressing_mode(),
683 s.filter_mode(),
684 s.normalized_coordinates(),
685 )),
686 offset: 0,
687 dead: true,
688 });
689 } else {
690 last_loc = v.data.location;
691 }
692 }
693
694 nir_pass!(
695 nir,
696 nir_remove_dead_variables,
697 nir_variable_mode::nir_var_uniform
698 | nir_variable_mode::nir_var_image
699 | nir_variable_mode::nir_var_mem_constant
700 | nir_variable_mode::nir_var_mem_shared
701 | nir_variable_mode::nir_var_function_temp,
702 &DV_OPTS,
703 );
704
705 nir_pass!(nir, nir_lower_readonly_images_to_tex, true);
706 nir_pass!(
707 nir,
708 nir_lower_cl_images,
709 !dev.images_as_deref(),
710 !dev.samplers_as_deref(),
711 );
712
713 nir_pass!(
714 nir,
715 nir_lower_vars_to_explicit_types,
716 nir_variable_mode::nir_var_mem_constant,
717 Some(glsl_get_cl_type_size_align),
718 );
719
720 // has to run before adding internal kernel arguments
721 nir.extract_constant_initializers();
722
723 // needed to convert variables to load intrinsics
724 nir_pass!(nir, nir_lower_system_values);
725
726 // Run here so we can decide if it makes sense to compile a variant, e.g. read system values.
727 nir.gather_info();
728 }
729
730 fn compile_nir_variant(
731 res: &mut CompilationResult,
732 dev: &Device,
733 variant: NirKernelVariant,
734 args: &[KernelArg],
735 name: &str,
736 ) {
737 let mut lower_state = rusticl_lower_state::default();
738 let compiled_args = &mut res.compiled_args;
739 let nir = &mut res.nir;
740
741 let address_bits_ptr_type;
742 let address_bits_base_type;
743 let global_address_format;
744 let shared_address_format;
745
746 if dev.address_bits() == 64 {
747 address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
748 address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
749 global_address_format = nir_address_format::nir_address_format_64bit_global;
750 shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
751 } else {
752 address_bits_ptr_type = unsafe { glsl_uint_type() };
753 address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
754 global_address_format = nir_address_format::nir_address_format_32bit_global;
755 shared_address_format = nir_address_format::nir_address_format_32bit_offset;
756 }
757
758 let nir_options = unsafe {
759 &*dev
760 .screen
761 .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
762 };
763
764 if variant == NirKernelVariant::Optimized {
765 let wgsh = nir.workgroup_size_hint();
766 if wgsh != [0; 3] {
767 nir.set_workgroup_size(wgsh);
768 }
769 }
770
771 let mut compute_options = nir_lower_compute_system_values_options::default();
772 compute_options.set_has_global_size(true);
773 if variant != NirKernelVariant::Optimized {
774 compute_options.set_has_base_global_invocation_id(true);
775 compute_options.set_has_base_workgroup_id(true);
776 }
777 nir_pass!(nir, nir_lower_compute_system_values, &compute_options);
778 nir.gather_info();
779
780 let mut add_var = |nir: &mut NirShader,
781 var_loc: &mut usize,
782 kind: CompiledKernelArgType,
783 glsl_type: *const glsl_type,
784 name: &str| {
785 *var_loc = compiled_args.len();
786 compiled_args.push(CompiledKernelArg {
787 kind: kind,
788 offset: 0,
789 dead: true,
790 });
791 nir.add_var(
792 nir_variable_mode::nir_var_uniform,
793 glsl_type,
794 *var_loc,
795 name,
796 );
797 };
798
799 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
800 debug_assert_ne!(variant, NirKernelVariant::Optimized);
801 add_var(
802 nir,
803 &mut lower_state.base_global_invoc_id_loc,
804 CompiledKernelArgType::GlobalWorkOffsets,
805 unsafe { glsl_vector_type(address_bits_base_type, 3) },
806 "base_global_invocation_id",
807 )
808 }
809
810 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_GROUP_SIZE) {
811 add_var(
812 nir,
813 &mut lower_state.global_size_loc,
814 CompiledKernelArgType::GlobalWorkSize,
815 unsafe { glsl_vector_type(address_bits_base_type, 3) },
816 "global_size",
817 )
818 }
819
820 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_WORKGROUP_ID) {
821 debug_assert_ne!(variant, NirKernelVariant::Optimized);
822 add_var(
823 nir,
824 &mut lower_state.base_workgroup_id_loc,
825 CompiledKernelArgType::WorkGroupOffsets,
826 unsafe { glsl_vector_type(address_bits_base_type, 3) },
827 "base_workgroup_id",
828 );
829 }
830
831 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_NUM_WORKGROUPS) {
832 add_var(
833 nir,
834 &mut lower_state.num_workgroups_loc,
835 CompiledKernelArgType::NumWorkgroups,
836 unsafe { glsl_vector_type(glsl_base_type::GLSL_TYPE_UINT, 3) },
837 "num_workgroups",
838 );
839 }
840
841 if nir.has_constant() {
842 add_var(
843 nir,
844 &mut lower_state.const_buf_loc,
845 CompiledKernelArgType::ConstantBuffer,
846 address_bits_ptr_type,
847 "constant_buffer_addr",
848 );
849 }
850 if nir.has_printf() {
851 add_var(
852 nir,
853 &mut lower_state.printf_buf_loc,
854 CompiledKernelArgType::PrintfBuffer,
855 address_bits_ptr_type,
856 "printf_buffer_addr",
857 );
858 }
859
860 if nir.num_images() > 0 || nir.num_textures() > 0 {
861 let count = nir.num_images() + nir.num_textures();
862
863 add_var(
864 nir,
865 &mut lower_state.format_arr_loc,
866 CompiledKernelArgType::FormatArray,
867 unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
868 "image_formats",
869 );
870
871 add_var(
872 nir,
873 &mut lower_state.order_arr_loc,
874 CompiledKernelArgType::OrderArray,
875 unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
876 "image_orders",
877 );
878 }
879
880 if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_WORK_DIM) {
881 add_var(
882 nir,
883 &mut lower_state.work_dim_loc,
884 CompiledKernelArgType::WorkDim,
885 unsafe { glsl_uint8_t_type() },
886 "work_dim",
887 );
888 }
889
890 // needs to run after the first opt loop and remove_dead_variables to get rid of unnecessary
891 // scratch memory
892 nir_pass!(
893 nir,
894 nir_lower_vars_to_explicit_types,
895 nir_variable_mode::nir_var_mem_shared
896 | nir_variable_mode::nir_var_function_temp
897 | nir_variable_mode::nir_var_shader_temp
898 | nir_variable_mode::nir_var_uniform
899 | nir_variable_mode::nir_var_mem_global
900 | nir_variable_mode::nir_var_mem_generic,
901 Some(glsl_get_cl_type_size_align),
902 );
903
904 opt_nir(nir, dev, true);
905 nir_pass!(nir, nir_lower_memcpy);
906
907 // we might have got rid of more function_temp or shared memory
908 nir.reset_scratch_size();
909 nir.reset_shared_size();
910 nir_pass!(
911 nir,
912 nir_remove_dead_variables,
913 nir_variable_mode::nir_var_function_temp | nir_variable_mode::nir_var_mem_shared,
914 &DV_OPTS,
915 );
916 nir_pass!(
917 nir,
918 nir_lower_vars_to_explicit_types,
919 nir_variable_mode::nir_var_function_temp
920 | nir_variable_mode::nir_var_mem_shared
921 | nir_variable_mode::nir_var_mem_generic,
922 Some(glsl_get_cl_type_size_align),
923 );
924
925 nir_pass!(
926 nir,
927 nir_lower_explicit_io,
928 nir_variable_mode::nir_var_mem_global | nir_variable_mode::nir_var_mem_constant,
929 global_address_format,
930 );
931
932 nir_pass!(nir, rusticl_lower_intrinsics, &mut lower_state);
933 nir_pass!(
934 nir,
935 nir_lower_explicit_io,
936 nir_variable_mode::nir_var_mem_shared
937 | nir_variable_mode::nir_var_function_temp
938 | nir_variable_mode::nir_var_uniform,
939 shared_address_format,
940 );
941
942 if nir_options.lower_int64_options.0 != 0 {
943 nir_pass!(nir, nir_lower_int64);
944 }
945
946 if nir_options.lower_uniforms_to_ubo {
947 nir_pass!(nir, rusticl_lower_inputs);
948 }
949
950 nir_pass!(nir, nir_lower_convert_alu_types, None);
951
952 opt_nir(nir, dev, true);
953
954 /* before passing it into drivers, assign locations as drivers might remove nir_variables or
955 * other things we depend on
956 */
957 CompiledKernelArg::assign_locations(compiled_args, nir);
958
959 /* update the has_variable_shared_mem info as we might have DCEed all of them */
960 nir.set_has_variable_shared_mem(compiled_args.iter().any(|arg| {
961 if let CompiledKernelArgType::APIArg(idx) = arg.kind {
962 args[idx as usize].kind == KernelArgType::MemLocal && !arg.dead
963 } else {
964 false
965 }
966 }));
967
968 if Platform::dbg().nir {
969 eprintln!("=== Printing nir variant '{variant}' for '{name}' before driver finalization");
970 nir.print();
971 }
972
973 if dev.screen.finalize_nir(nir) {
974 if Platform::dbg().nir {
975 eprintln!(
976 "=== Printing nir variant '{variant}' for '{name}' after driver finalization"
977 );
978 nir.print();
979 }
980 }
981
982 nir_pass!(nir, nir_opt_dce);
983 nir.sweep_mem();
984 }
985
986 fn compile_nir_remaining(
987 dev: &Device,
988 mut nir: NirShader,
989 args: &[KernelArg],
990 name: &str,
991 ) -> (CompilationResult, Option<CompilationResult>) {
992 // add all API kernel args
993 let mut compiled_args: Vec<_> = (0..args.len())
994 .map(|idx| CompiledKernelArg {
995 kind: CompiledKernelArgType::APIArg(idx as u32),
996 offset: 0,
997 dead: true,
998 })
999 .collect();
1000
1001 compile_nir_prepare_for_variants(dev, &mut nir, &mut compiled_args);
1002 if Platform::dbg().nir {
1003 eprintln!("=== Printing nir for '{name}' before specialization");
1004 nir.print();
1005 }
1006
1007 let mut default_build = CompilationResult {
1008 nir: nir,
1009 compiled_args: compiled_args,
1010 };
1011
1012 // check if we even want to compile a variant before cloning the compilation state
1013 let has_wgs_hint = default_build.nir.workgroup_size_variable()
1014 && default_build.nir.workgroup_size_hint() != [0; 3];
1015 let has_offsets = default_build
1016 .nir
1017 .reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_INVOCATION_ID);
1018
1019 let mut optimized = (!Platform::dbg().no_variants && (has_offsets || has_wgs_hint))
1020 .then(|| default_build.clone());
1021
1022 compile_nir_variant(
1023 &mut default_build,
1024 dev,
1025 NirKernelVariant::Default,
1026 args,
1027 name,
1028 );
1029 if let Some(optimized) = &mut optimized {
1030 compile_nir_variant(optimized, dev, NirKernelVariant::Optimized, args, name);
1031 }
1032
1033 (default_build, optimized)
1034 }
1035
1036 pub struct SPIRVToNirResult {
1037 pub kernel_info: KernelInfo,
1038 pub nir_kernel_builds: NirKernelBuilds,
1039 }
1040
1041 impl SPIRVToNirResult {
1042 fn new(
1043 dev: &'static Device,
1044 kernel_info: &clc_kernel_info,
1045 args: Vec<KernelArg>,
1046 default_build: CompilationResult,
1047 optimized: Option<CompilationResult>,
1048 ) -> Self {
1049 // TODO: we _should_ be able to parse them out of the SPIR-V, but clc doesn't handle
1050 // indirections yet.
1051 let nir = &default_build.nir;
1052 let wgs = nir.workgroup_size();
1053 let subgroup_size = nir.subgroup_size();
1054 let num_subgroups = nir.num_subgroups();
1055
1056 let default_build = NirKernelBuild::new(dev, default_build);
1057 let optimized = optimized.map(|opt| NirKernelBuild::new(dev, opt));
1058
1059 let kernel_info = KernelInfo {
1060 args: args,
1061 attributes_string: kernel_info.attribute_str(),
1062 work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
1063 work_group_size_hint: kernel_info.local_size_hint,
1064 subgroup_size: subgroup_size as usize,
1065 num_subgroups: num_subgroups as usize,
1066 };
1067
1068 Self {
1069 kernel_info: kernel_info,
1070 nir_kernel_builds: NirKernelBuilds::new(default_build, optimized),
1071 }
1072 }
1073
1074 fn deserialize(bin: &[u8], d: &'static Device, kernel_info: &clc_kernel_info) -> Option<Self> {
1075 let mut reader = blob_reader::default();
1076 unsafe {
1077 blob_reader_init(&mut reader, bin.as_ptr().cast(), bin.len());
1078 }
1079
1080 let args = KernelArg::deserialize(&mut reader)?;
1081 let default_build = CompilationResult::deserialize(&mut reader, d)?;
1082
1083 let optimized = match unsafe { blob_read_uint8(&mut reader) } {
1084 0 => None,
1085 _ => Some(CompilationResult::deserialize(&mut reader, d)?),
1086 };
1087
1088 Some(SPIRVToNirResult::new(
1089 d,
1090 kernel_info,
1091 args,
1092 default_build,
1093 optimized,
1094 ))
1095 }
1096
1097 // we can't use Self here as the nir shader might be compiled to a cso already and we can't
1098 // cache that.
1099 fn serialize(
1100 blob: &mut blob,
1101 args: &[KernelArg],
1102 default_build: &CompilationResult,
1103 optimized: &Option<CompilationResult>,
1104 ) {
1105 KernelArg::serialize(args, blob);
1106 default_build.serialize(blob);
1107 match optimized {
1108 Some(variant) => {
1109 unsafe { blob_write_uint8(blob, 1) };
1110 variant.serialize(blob);
1111 }
1112 None => unsafe {
1113 blob_write_uint8(blob, 0);
1114 },
1115 }
1116 }
1117 }
1118
1119 pub(super) fn convert_spirv_to_nir(
1120 build: &ProgramBuild,
1121 name: &str,
1122 args: &[spirv::SPIRVKernelArg],
1123 dev: &'static Device,
1124 ) -> SPIRVToNirResult {
1125 let cache = dev.screen().shader_cache();
1126 let key = build.hash_key(dev, name);
1127 let spirv_info = build.spirv_info(name, dev).unwrap();
1128
1129 cache
1130 .as_ref()
1131 .and_then(|cache| cache.get(&mut key?))
1132 .and_then(|entry| SPIRVToNirResult::deserialize(&entry, dev, spirv_info))
1133 .unwrap_or_else(|| {
1134 let nir = build.to_nir(name, dev);
1135
1136 if Platform::dbg().nir {
1137 eprintln!("=== Printing nir for '{name}' after spirv_to_nir");
1138 nir.print();
1139 }
1140
1141 let (mut args, nir) = compile_nir_to_args(dev, nir, args, &dev.lib_clc);
1142 let (default_build, optimized) = compile_nir_remaining(dev, nir, &args, name);
1143
1144 for build in [Some(&default_build), optimized.as_ref()].into_iter() {
1145 let Some(build) = build else {
1146 continue;
1147 };
1148
1149 for arg in &build.compiled_args {
1150 if let CompiledKernelArgType::APIArg(idx) = arg.kind {
1151 args[idx as usize].dead &= arg.dead;
1152 }
1153 }
1154 }
1155
1156 if let Some(cache) = cache {
1157 let mut blob = blob::default();
1158 unsafe {
1159 blob_init(&mut blob);
1160 SPIRVToNirResult::serialize(&mut blob, &args, &default_build, &optimized);
1161 let bin = slice::from_raw_parts(blob.data, blob.size);
1162 cache.put(bin, &mut key.unwrap());
1163 blob_finish(&mut blob);
1164 }
1165 }
1166
1167 SPIRVToNirResult::new(dev, spirv_info, args, default_build, optimized)
1168 })
1169 }
1170
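/// Splits the first `S` bytes off the front of `buf` and returns them as a fixed-size array.
/// Used below as `u32::from_ne_bytes(*extract(&mut buf))` to read the 4-byte length header of
/// the printf buffer.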
1171 fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
1172 let val;
1173 (val, *buf) = (*buf).split_at(S);
1174 // we split off 4 bytes and convert to [u8; 4], so this should be safe
1175 // use split_array_ref once it's stable
1176 val.try_into().unwrap()
1177 }
1178
1179 impl Kernel {
1180 pub fn new(name: String, prog: Arc<Program>, prog_build: &ProgramBuild) -> Arc<Kernel> {
1181 let kernel_info = Arc::clone(prog_build.kernel_info.get(&name).unwrap());
1182 let builds = prog_build
1183 .builds
1184 .iter()
1185 .filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, k.clone())))
1186 .collect();
1187
1188 let values = vec![None; kernel_info.args.len()];
1189 Arc::new(Self {
1190 base: CLObjectBase::new(RusticlTypes::Kernel),
1191 prog: prog,
1192 name: name,
1193 values: Mutex::new(values),
1194 builds: builds,
1195 kernel_info: kernel_info,
1196 })
1197 }
1198
1199 pub fn suggest_local_size(
1200 &self,
1201 d: &Device,
1202 work_dim: usize,
1203 grid: &mut [usize],
1204 block: &mut [usize],
1205 ) {
1206 let mut threads = self.max_threads_per_block(d);
1207 let dim_threads = d.max_block_sizes();
1208 let subgroups = self.preferred_simd_size(d);
1209
1210 for i in 0..work_dim {
1211 let t = cmp::min(threads, dim_threads[i]);
1212 let gcd = gcd(t, grid[i]);
1213
1214 block[i] = gcd;
1215 grid[i] /= gcd;
1216
1217 // update limits
1218 threads /= block[i];
1219 }
1220
1221 // if we didn't fill a full subgroup, we can do a bit better as long as threads remain
1222 let total_threads = block.iter().take(work_dim).product::<usize>();
1223 if threads != 1 && total_threads < subgroups {
1224 for i in 0..work_dim {
1225 if grid[i] * total_threads < threads && grid[i] * block[i] <= dim_threads[i] {
1226 block[i] *= grid[i];
1227 grid[i] = 1;
1228 // can only do it once as nothing is cleanly divisible
1229 break;
1230 }
1231 }
1232 }
1233 }
1234
1235 fn optimize_local_size(&self, d: &Device, grid: &mut [usize; 3], block: &mut [u32; 3]) {
1236 if !block.contains(&0) {
1237 for i in 0..3 {
1238 // we already made sure everything is fine
1239 grid[i] /= block[i] as usize;
1240 }
1241 return;
1242 }
1243
1244 let mut usize_block = [0usize; 3];
1245 for i in 0..3 {
1246 usize_block[i] = block[i] as usize;
1247 }
1248
1249 self.suggest_local_size(d, 3, grid, &mut usize_block);
1250
1251 for i in 0..3 {
1252 block[i] = usize_block[i] as u32;
1253 }
1254 }
1255
1256 // The painful part is that host threads are allowed to modify the kernel object once it has been
1257 // enqueued, so return a closure with all required data included.
1258 pub fn launch(
1259 self: &Arc<Self>,
1260 q: &Arc<Queue>,
1261 work_dim: u32,
1262 block: &[usize],
1263 grid: &[usize],
1264 offsets: &[usize],
1265 ) -> CLResult<EventSig> {
1266 // Clone all the data we need to execute this kernel
1267 let kernel_info = Arc::clone(&self.kernel_info);
1268 let arg_values = self.arg_values().clone();
1269 let nir_kernel_builds = Arc::clone(&self.builds[q.device]);
1270
1271 // operations we want to report errors to the clients
1272 let mut block = create_kernel_arr::<u32>(block, 1)?;
1273 let mut grid = create_kernel_arr::<usize>(grid, 1)?;
1274 let offsets = create_kernel_arr::<usize>(offsets, 0)?;
1275
1276 let api_grid = grid;
1277
1278 self.optimize_local_size(q.device, &mut grid, &mut block);
1279
1280 Ok(Box::new(move |q, ctx| {
1281 let hw_max_grid: Vec<usize> = q
1282 .device
1283 .max_grid_size()
1284 .into_iter()
1285 .map(|val| val.try_into().unwrap_or(usize::MAX))
1286 // clamped as pipe_launch_grid::grid is only u32
1287 .map(|val| cmp::min(val, u32::MAX as usize))
1288 .collect();
1289
1290 let variant = if offsets == [0; 3]
1291 && grid[0] <= hw_max_grid[0]
1292 && grid[1] <= hw_max_grid[1]
1293 && grid[2] <= hw_max_grid[2]
1294 && block == kernel_info.work_group_size_hint
1295 {
1296 NirKernelVariant::Optimized
1297 } else {
1298 NirKernelVariant::Default
1299 };
1300
1301 let nir_kernel_build = &nir_kernel_builds[variant];
1302 let mut workgroup_id_offset_loc = None;
1303 let mut input = Vec::new();
1304 // Set it once so we get the alignment padding right
1305 let static_local_size: u64 = nir_kernel_build.shared_size;
1306 let mut variable_local_size: u64 = static_local_size;
1307 let printf_size = q.device.printf_buffer_size() as u32;
1308 let mut samplers = Vec::new();
1309 let mut iviews = Vec::new();
1310 let mut sviews = Vec::new();
1311 let mut tex_formats: Vec<u16> = Vec::new();
1312 let mut tex_orders: Vec<u16> = Vec::new();
1313 let mut img_formats: Vec<u16> = Vec::new();
1314 let mut img_orders: Vec<u16> = Vec::new();
1315
1316 let null_ptr;
1317 let null_ptr_v3;
1318 if q.device.address_bits() == 64 {
1319 null_ptr = [0u8; 8].as_slice();
1320 null_ptr_v3 = [0u8; 24].as_slice();
1321 } else {
1322 null_ptr = [0u8; 4].as_slice();
1323 null_ptr_v3 = [0u8; 12].as_slice();
1324 };
1325
1326 let mut resource_info = Vec::new();
1327 fn add_global<'a>(
1328 q: &Queue,
1329 input: &mut Vec<u8>,
1330 resource_info: &mut Vec<(&'a PipeResource, usize)>,
1331 res: &'a PipeResource,
1332 offset: usize,
1333 ) {
1334 resource_info.push((res, input.len()));
1335 if q.device.address_bits() == 64 {
1336 let offset: u64 = offset as u64;
1337 input.extend_from_slice(&offset.to_ne_bytes());
1338 } else {
1339 let offset: u32 = offset as u32;
1340 input.extend_from_slice(&offset.to_ne_bytes());
1341 }
1342 }
1343
1344 fn add_sysval(q: &Queue, input: &mut Vec<u8>, vals: &[usize; 3]) {
1345 if q.device.address_bits() == 64 {
1346 input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u64)) });
1347 } else {
1348 input.extend_from_slice(unsafe { as_byte_slice(&vals.map(|v| v as u32)) });
1349 }
1350 }
1351
1352 let mut printf_buf = None;
1353 if nir_kernel_build.printf_info.is_some() {
1354 let buf = q
1355 .device
1356 .screen
1357 .resource_create_buffer(printf_size, ResourceType::Staging, PIPE_BIND_GLOBAL)
1358 .unwrap();
1359
1360 let init_data: [u8; 1] = [4];
1361 ctx.buffer_subdata(&buf, 0, init_data.as_ptr().cast(), init_data.len() as u32);
1362
1363 printf_buf = Some(buf);
1364 }
1365
1366 for arg in &nir_kernel_build.compiled_args {
1367 let is_opaque = if let CompiledKernelArgType::APIArg(idx) = arg.kind {
1368 kernel_info.args[idx as usize].kind.is_opaque()
1369 } else {
1370 false
1371 };
1372
1373 if !is_opaque && arg.offset as usize > input.len() {
1374 input.resize(arg.offset as usize, 0);
1375 }
1376
1377 match arg.kind {
1378 CompiledKernelArgType::APIArg(idx) => {
1379 let api_arg = &kernel_info.args[idx as usize];
1380 if api_arg.dead {
1381 continue;
1382 }
1383
1384 let Some(value) = &arg_values[idx as usize] else {
1385 continue;
1386 };
1387
1388 match value {
1389 KernelArgValue::Constant(c) => input.extend_from_slice(c),
1390 KernelArgValue::Buffer(buffer) => {
1391 let res = buffer.get_res_of_dev(q.device)?;
1392 add_global(q, &mut input, &mut resource_info, res, buffer.offset);
1393 }
1394 KernelArgValue::Image(image) => {
1395 let res = image.get_res_of_dev(q.device)?;
1396
1397 // If resource is a buffer, the image was created from a buffer. Use
1398 // strides and dimensions of the image then.
1399 let app_img_info = if res.as_ref().is_buffer()
1400 && image.mem_type == CL_MEM_OBJECT_IMAGE2D
1401 {
1402 Some(AppImgInfo::new(
1403 image.image_desc.row_pitch()?
1404 / image.image_elem_size as u32,
1405 image.image_desc.width()?,
1406 image.image_desc.height()?,
1407 ))
1408 } else {
1409 None
1410 };
1411
1412 let format = image.pipe_format;
1413 let (formats, orders) = if api_arg.kind == KernelArgType::Image {
1414 iviews.push(res.pipe_image_view(
1415 format,
1416 false,
1417 image.pipe_image_host_access(),
1418 app_img_info.as_ref(),
1419 ));
1420 (&mut img_formats, &mut img_orders)
1421 } else if api_arg.kind == KernelArgType::RWImage {
1422 iviews.push(res.pipe_image_view(
1423 format,
1424 true,
1425 image.pipe_image_host_access(),
1426 app_img_info.as_ref(),
1427 ));
1428 (&mut img_formats, &mut img_orders)
1429 } else {
1430 sviews.push((res.clone(), format, app_img_info));
1431 (&mut tex_formats, &mut tex_orders)
1432 };
1433
1434 let binding = arg.offset as usize;
1435 assert!(binding >= formats.len());
1436
1437 formats.resize(binding, 0);
1438 orders.resize(binding, 0);
1439
1440 formats.push(image.image_format.image_channel_data_type as u16);
1441 orders.push(image.image_format.image_channel_order as u16);
1442 }
1443 KernelArgValue::LocalMem(size) => {
1444 // TODO 32 bit
1445 let pot = cmp::min(*size, 0x80);
1446 variable_local_size = variable_local_size
1447 .next_multiple_of(pot.next_power_of_two() as u64);
1448 if q.device.address_bits() == 64 {
1449 let variable_local_size: [u8; 8] =
1450 variable_local_size.to_ne_bytes();
1451 input.extend_from_slice(&variable_local_size);
1452 } else {
1453 let variable_local_size: [u8; 4] =
1454 (variable_local_size as u32).to_ne_bytes();
1455 input.extend_from_slice(&variable_local_size);
1456 }
1457 variable_local_size += *size as u64;
1458 }
1459 KernelArgValue::Sampler(sampler) => {
1460 samplers.push(sampler.pipe());
1461 }
1462 KernelArgValue::None => {
1463 assert!(
1464 api_arg.kind == KernelArgType::MemGlobal
1465 || api_arg.kind == KernelArgType::MemConstant
1466 );
1467 input.extend_from_slice(null_ptr);
1468 }
1469 }
1470 }
1471 CompiledKernelArgType::ConstantBuffer => {
1472 assert!(nir_kernel_build.constant_buffer.is_some());
1473 let res = nir_kernel_build.constant_buffer.as_ref().unwrap();
1474 add_global(q, &mut input, &mut resource_info, res, 0);
1475 }
1476 CompiledKernelArgType::GlobalWorkOffsets => {
1477 add_sysval(q, &mut input, &offsets);
1478 }
1479 CompiledKernelArgType::WorkGroupOffsets => {
1480 workgroup_id_offset_loc = Some(input.len());
1481 input.extend_from_slice(null_ptr_v3);
1482 }
1483 CompiledKernelArgType::GlobalWorkSize => {
1484 add_sysval(q, &mut input, &api_grid);
1485 }
1486 CompiledKernelArgType::PrintfBuffer => {
1487 let res = printf_buf.as_ref().unwrap();
1488 add_global(q, &mut input, &mut resource_info, res, 0);
1489 }
1490 CompiledKernelArgType::InlineSampler(cl) => {
1491 samplers.push(Sampler::cl_to_pipe(cl));
1492 }
1493 CompiledKernelArgType::FormatArray => {
1494 input.extend_from_slice(unsafe { as_byte_slice(&tex_formats) });
1495 input.extend_from_slice(unsafe { as_byte_slice(&img_formats) });
1496 }
1497 CompiledKernelArgType::OrderArray => {
1498 input.extend_from_slice(unsafe { as_byte_slice(&tex_orders) });
1499 input.extend_from_slice(unsafe { as_byte_slice(&img_orders) });
1500 }
1501 CompiledKernelArgType::WorkDim => {
1502 input.extend_from_slice(&[work_dim as u8; 1]);
1503 }
1504 CompiledKernelArgType::NumWorkgroups => {
1505 input.extend_from_slice(unsafe {
1506 as_byte_slice(&[grid[0] as u32, grid[1] as u32, grid[2] as u32])
1507 });
1508 }
1509 }
1510 }
1511
1512 // subtract the shader local_size as we only request something on top of that.
1513 variable_local_size -= static_local_size;
1514
1515 let mut sviews: Vec<_> = sviews
1516 .iter()
1517 .map(|(s, f, aii)| ctx.create_sampler_view(s, *f, aii.as_ref()))
1518 .collect();
1519 let samplers: Vec<_> = samplers
1520 .iter()
1521 .map(|s| ctx.create_sampler_state(s))
1522 .collect();
1523
1524 let mut resources = Vec::with_capacity(resource_info.len());
1525 let mut globals: Vec<*mut u32> = Vec::with_capacity(resource_info.len());
1526 for (res, offset) in resource_info {
1527 resources.push(res);
1528 globals.push(unsafe { input.as_mut_ptr().byte_add(offset) }.cast());
1529 }
1530
1531 let temp_cso;
1532 let cso = match &nir_kernel_build.nir_or_cso {
1533 KernelDevStateVariant::Cso(cso) => cso,
1534 KernelDevStateVariant::Nir(nir) => {
1535 temp_cso = CSOWrapper::new(q.device, nir);
1536 &temp_cso
1537 }
1538 };
1539
1540 ctx.bind_compute_state(cso.cso_ptr);
1541 ctx.bind_sampler_states(&samplers);
1542 ctx.set_sampler_views(&mut sviews);
1543 ctx.set_shader_images(&iviews);
1544 ctx.set_global_binding(resources.as_slice(), &mut globals);
1545
1546 for z in 0..grid[2].div_ceil(hw_max_grid[2]) {
1547 for y in 0..grid[1].div_ceil(hw_max_grid[1]) {
1548 for x in 0..grid[0].div_ceil(hw_max_grid[0]) {
1549 if let Some(workgroup_id_offset_loc) = workgroup_id_offset_loc {
1550 let this_offsets =
1551 [x * hw_max_grid[0], y * hw_max_grid[1], z * hw_max_grid[2]];
1552
1553 if q.device.address_bits() == 64 {
1554 let val = this_offsets.map(|v| v as u64);
1555 input[workgroup_id_offset_loc..workgroup_id_offset_loc + 24]
1556 .copy_from_slice(unsafe { as_byte_slice(&val) });
1557 } else {
1558 let val = this_offsets.map(|v| v as u32);
1559 input[workgroup_id_offset_loc..workgroup_id_offset_loc + 12]
1560 .copy_from_slice(unsafe { as_byte_slice(&val) });
1561 }
1562 }
1563
1564 let this_grid = [
1565 cmp::min(hw_max_grid[0], grid[0] - hw_max_grid[0] * x) as u32,
1566 cmp::min(hw_max_grid[1], grid[1] - hw_max_grid[1] * y) as u32,
1567 cmp::min(hw_max_grid[2], grid[2] - hw_max_grid[2] * z) as u32,
1568 ];
1569
1570 ctx.update_cb0(&input)?;
1571 ctx.launch_grid(work_dim, block, this_grid, variable_local_size as u32);
1572
1573 if Platform::dbg().sync_every_event {
1574 ctx.flush().wait();
1575 }
1576 }
1577 }
1578 }
1579
1580 ctx.clear_global_binding(globals.len() as u32);
1581 ctx.clear_shader_images(iviews.len() as u32);
1582 ctx.clear_sampler_views(sviews.len() as u32);
1583 ctx.clear_sampler_states(samplers.len() as u32);
1584
1585 ctx.bind_compute_state(ptr::null_mut());
1586
1587 ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER);
1588
1589 samplers.iter().for_each(|s| ctx.delete_sampler_state(*s));
1590 sviews.iter().for_each(|v| ctx.sampler_view_destroy(*v));
1591
1592 if let Some(printf_buf) = &printf_buf {
1593 let tx = ctx
1594 .buffer_map(printf_buf, 0, printf_size as i32, RWFlags::RD)
1595 .ok_or(CL_OUT_OF_RESOURCES)?;
1596 let mut buf: &[u8] =
1597 unsafe { slice::from_raw_parts(tx.ptr().cast(), printf_size as usize) };
1598 let length = u32::from_ne_bytes(*extract(&mut buf));
1599
1600 // update our slice to make sure we don't go out of bounds
1601 buf = &buf[0..(length - 4) as usize];
1602 if let Some(pf) = &nir_kernel_build.printf_info {
1603 pf.u_printf(buf)
1604 }
1605 }
1606
1607 Ok(())
1608 }))
1609 }
1610
1611 pub fn arg_values(&self) -> MutexGuard<Vec<Option<KernelArgValue>>> {
1612 self.values.lock().unwrap()
1613 }
1614
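// Minimal usage sketch (hypothetical index and value): a 4-byte scalar argument would be set as
// `kernel.set_kernel_arg(0, KernelArgValue::Constant(0u32.to_ne_bytes().to_vec()))?`.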
1615 pub fn set_kernel_arg(&self, idx: usize, arg: KernelArgValue) -> CLResult<()> {
1616 self.values
1617 .lock()
1618 .unwrap()
1619 .get_mut(idx)
1620 .ok_or(CL_INVALID_ARG_INDEX)?
1621 .replace(arg);
1622 Ok(())
1623 }
1624
1625 pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
1626 let aq = self.kernel_info.args[idx as usize].spirv.access_qualifier;
1627
1628 if aq
1629 == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
1630 | clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE
1631 {
1632 CL_KERNEL_ARG_ACCESS_READ_WRITE
1633 } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ {
1634 CL_KERNEL_ARG_ACCESS_READ_ONLY
1635 } else if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_WRITE {
1636 CL_KERNEL_ARG_ACCESS_WRITE_ONLY
1637 } else {
1638 CL_KERNEL_ARG_ACCESS_NONE
1639 }
1640 }
1641
1642 pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
1643 match self.kernel_info.args[idx as usize].spirv.address_qualifier {
1644 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
1645 CL_KERNEL_ARG_ADDRESS_PRIVATE
1646 }
1647 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_CONSTANT => {
1648 CL_KERNEL_ARG_ADDRESS_CONSTANT
1649 }
1650 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_LOCAL => {
1651 CL_KERNEL_ARG_ADDRESS_LOCAL
1652 }
1653 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_GLOBAL => {
1654 CL_KERNEL_ARG_ADDRESS_GLOBAL
1655 }
1656 }
1657 }
1658
1659 pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
1660 let tq = self.kernel_info.args[idx as usize].spirv.type_qualifier;
1661 let zero = clc_kernel_arg_type_qualifier(0);
1662 let mut res = CL_KERNEL_ARG_TYPE_NONE;
1663
1664 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_CONST != zero {
1665 res |= CL_KERNEL_ARG_TYPE_CONST;
1666 }
1667
1668 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_RESTRICT != zero {
1669 res |= CL_KERNEL_ARG_TYPE_RESTRICT;
1670 }
1671
1672 if tq & clc_kernel_arg_type_qualifier::CLC_KERNEL_ARG_TYPE_VOLATILE != zero {
1673 res |= CL_KERNEL_ARG_TYPE_VOLATILE;
1674 }
1675
1676 res.into()
1677 }
1678
1679 pub fn work_group_size(&self) -> [usize; 3] {
1680 self.kernel_info.work_group_size
1681 }
1682
1683 pub fn num_subgroups(&self) -> usize {
1684 self.kernel_info.num_subgroups
1685 }
1686
1687 pub fn subgroup_size(&self) -> usize {
1688 self.kernel_info.subgroup_size
1689 }
1690
1691 pub fn arg_name(&self, idx: cl_uint) -> &String {
1692 &self.kernel_info.args[idx as usize].spirv.name
1693 }
1694
1695 pub fn arg_type_name(&self, idx: cl_uint) -> &String {
1696 &self.kernel_info.args[idx as usize].spirv.type_name
1697 }
1698
1699 pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
1700 self.builds.get(dev).unwrap().info.private_memory as cl_ulong
1701 }
1702
1703 pub fn max_threads_per_block(&self, dev: &Device) -> usize {
1704 self.builds.get(dev).unwrap().info.max_threads as usize
1705 }
1706
1707 pub fn preferred_simd_size(&self, dev: &Device) -> usize {
1708 self.builds.get(dev).unwrap().info.preferred_simd_size as usize
1709 }
1710
1711 pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
1712 // TODO include args
1713 // this is purely informational so it shouldn't even matter
1714 self.builds.get(dev).unwrap()[NirKernelVariant::Default].shared_size as cl_ulong
1715 }
1716
1717 pub fn has_svm_devs(&self) -> bool {
1718 self.prog.devs.iter().any(|dev| dev.svm_supported())
1719 }
1720
1721 pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
1722 SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes)
1723 .map(|bit| 1 << bit)
1724 .collect()
1725 }
1726
1727 pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1728 let subgroup_size = self.subgroup_size_for_block(dev, block);
1729 if subgroup_size == 0 {
1730 return 0;
1731 }
1732
1733 let threads: usize = block.iter().product();
1734 threads.div_ceil(subgroup_size)
1735 }
1736
1737 pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
1738 let subgroup_sizes = self.subgroup_sizes(dev);
1739 if subgroup_sizes.is_empty() {
1740 return 0;
1741 }
1742
1743 if subgroup_sizes.len() == 1 {
1744 return subgroup_sizes[0];
1745 }
1746
1747 let block = [
1748 *block.first().unwrap_or(&1) as u32,
1749 *block.get(1).unwrap_or(&1) as u32,
1750 *block.get(2).unwrap_or(&1) as u32,
1751 ];
1752
1753 // TODO: this _might_ bite us somewhere, but I think it probably doesn't matter
1754 match &self.builds.get(dev).unwrap()[NirKernelVariant::Default].nir_or_cso {
1755 KernelDevStateVariant::Cso(cso) => {
1756 dev.helper_ctx()
1757 .compute_state_subgroup_size(cso.cso_ptr, &block) as usize
1758 }
1759 _ => {
1760 panic!()
1761 }
1762 }
1763 }
1764 }
1765
1766 impl Clone for Kernel {
1767 fn clone(&self) -> Self {
1768 Self {
1769 base: CLObjectBase::new(RusticlTypes::Kernel),
1770 prog: self.prog.clone(),
1771 name: self.name.clone(),
1772 values: Mutex::new(self.arg_values().clone()),
1773 builds: self.builds.clone(),
1774 kernel_info: self.kernel_info.clone(),
1775 }
1776 }
1777 }
1778