//
// Copyright (C) 2009-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
//
//

#pragma once

#include "common.h"
#include "libs/lsc_intrinsics.h"

/* ============================================================================= */
/* ============================== LSB RADIX SORT =============================== */
/* ============================================================================= */

#define RADIX_BINS 256
#define SCATTER_WG_SIZE 512
#define MORTON_LSB_SORT_NO_SHIFT_THRESHOLD 0xFFFFFFFF // turned off, because the current hierarchy build requires a full sort

uint2 get_thread_range( uint numItems, uint numGroups, uint taskID )
{
    uint items_per_group = (numItems / numGroups);
    uint remainder = numItems - (items_per_group * numGroups);
    uint startID = taskID * items_per_group + min(taskID, remainder);
    uint endID = startID + items_per_group + ((taskID < remainder) ? 1 : 0);

    return (uint2)(startID, endID);
}
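
// A minimal worked example (illustrative comment only): for numItems = 10 and
// numGroups = 4, items_per_group = 2 and remainder = 2, so the first two tasks
// get one extra item each:
//   taskID 0 -> [0, 3)
//   taskID 1 -> [3, 6)
//   taskID 2 -> [6, 8)
//   taskID 3 -> [8, 10)
// Every item is covered exactly once and group sizes differ by at most one.
/* uint2 r = get_thread_range(10, 4, 2); */
/* // r.x == 6, r.y == 8 */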

GRL_INLINE void sort_morton_codes_bin_items_taskID_func(global struct Globals* globals,
                                                        global uint* global_histogram,
                                                        global uchar* input,
                                                        local uint* histogram,
                                                        uint iteration,
                                                        uint numGroups,
                                                        uint numItems,
                                                        bool shift_primID,
                                                        uint taskID,
                                                        uint startID,
                                                        uint endID)
{
    const uint shift = globals->shift;

    for (uint i = get_local_id(0); i < RADIX_BINS; i += get_local_size(0))
        histogram[i] = 0;

    barrier(CLK_LOCAL_MEM_FENCE);

    if (shift_primID)
    {
        for (uint i = startID + get_local_id(0); i < endID; i += get_local_size(0))
        {
            // Read the input as a ulong so it can be shifted; this way the bits
            // representing the primID are not taken into account during sorting,
            // which allows fewer sort iterations in cases where the morton shift
            // is bigger than 8 bits
            ulong* ptr_ul = (ulong*)&input[8 * i];
            ulong code = *ptr_ul;
            uchar* ptr = (uchar*)&code;
            code >>= shift;

            uchar bin = ptr[iteration];
            atomic_inc_local(&histogram[bin]);
        }
    }
    else
    {
        for (uint i = startID + get_local_id(0); i < endID; i += get_local_size(0))
        {
            uchar bin = input[8 * i + iteration];
            atomic_inc_local(&histogram[bin]);
        }
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    for (uint i = get_local_id(0); i < RADIX_BINS; i += get_local_size(0))
        global_histogram[RADIX_BINS * taskID + i] = histogram[i];
}
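
// Layout note (illustrative, inferred from how the kernels below consume the data):
// global_histogram is treated as a row-major [numGroups + 1][RADIX_BINS] array.
// Rows 0..numGroups-1 hold the per-work-group bin counts written above; the reduce
// kernels later overwrite them with per-group exclusive offsets and write the
// global per-bin exclusive prefix into row numGroups (a.k.a. numTasks), e.g.:
//   global_histogram[RADIX_BINS * taskID    + bin]  -> this group's count/offset for 'bin'
//   global_histogram[RADIX_BINS * numGroups + bin]  -> global base offset of 'bin'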

GRL_INLINE void sort_morton_codes_bin_items_func(global struct Globals* globals,
                                                 global uint* global_histogram,
                                                 global uint* wg_flags,
                                                 global uchar* input,
                                                 local uint* histogram,
                                                 uint iteration,
                                                 uint numGroups,
                                                 uint numItems,
                                                 bool shift_primID,
                                                 bool update_wg_flags)
{
    if (shift_primID)
    {
        // This check is present in the other LSB sort functions as well; its purpose
        // is to skip the first n iterations, where n is the difference between the
        // maximum number of iterations and the number of iterations actually needed
        // to sort without primIDs
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;

        // iteration needs to be adjusted to reflect the skipped cycles
        iteration -= req_iterations;
    }

    const uint taskID = get_group_id(0);

    if (taskID == 0 && update_wg_flags)
    {
        for (uint i = get_local_id(0); i < RADIX_BINS; i += get_local_size(0))
            wg_flags[i] = 0;
    }

    uint2 ids = get_thread_range(numItems, numGroups, taskID);
    uint startID = ids.x;
    uint endID = ids.y;

    sort_morton_codes_bin_items_taskID_func(globals, global_histogram, input, histogram, iteration, numGroups, numItems, shift_primID,
                                            taskID, startID, endID);
}
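
// Illustrative example of the iteration-skipping logic above (how sort_iterations
// is derived lives elsewhere, so the pass count here is an assumption): with 8
// sort passes dispatched and globals->sort_iterations == 2, passes 0 and 1 return
// immediately, and pass k (k >= 2) behaves as logical pass k - 2, so the shifted
// 64-bit keys are still processed starting from byte 0.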

__attribute__((reqd_work_group_size(512, 1, 1)))
void kernel
sort_morton_codes_bin_items(
    global struct Globals* globals,
    global uint* global_histogram,
    global uint* wg_flags,
    global uchar* input,
    uint iteration,
    uint numGroups,
    uint update_wg_flags
    )
{
    local uint histogram[RADIX_BINS];
    const uint numItems = globals->numPrimitives;
    if (numItems < MORTON_LSB_SORT_NO_SHIFT_THRESHOLD)
        sort_morton_codes_bin_items_func(globals, global_histogram, wg_flags, input, histogram, iteration, numGroups, numItems, false, update_wg_flags);
    else
        sort_morton_codes_bin_items_func(globals, global_histogram, wg_flags, input, histogram, iteration, numGroups, numItems, true, update_wg_flags);
}


GRL_INLINE void sort_morton_codes_reduce_bins_func(global struct Globals* globals,
                                                   global uint* global_histogram,
                                                   local uint* partials,
                                                   uint numTasks,
                                                   uint iteration,
                                                   bool shift_primID)
{
    const uint localID = get_local_id(0);

    if (shift_primID)
    {
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;
    }

    uint t = 0;
    for (uint j = 0; j < numTasks; j++)
    {
        const uint count = load_uint_L1C_L3C(&global_histogram[RADIX_BINS * j + localID], 0);
        store_uint_L1WB_L3WB(&global_histogram[RADIX_BINS * j + localID], 0, t);
        t += count;
    }

    // each lane now contains the number of elements in the corresponding bin;
    // prefix-sum this for use in the subsequent scattering pass
    uint global_count = t;

    partials[get_sub_group_id()] = sub_group_reduce_add(global_count);

    barrier(CLK_LOCAL_MEM_FENCE);

    uint lane = get_sub_group_local_id();
    uint p = partials[lane];
    p = (lane < get_sub_group_id()) ? p : 0;

    global_count = sub_group_reduce_add(p) + sub_group_scan_exclusive_add(global_count);

    store_uint_L1WB_L3WB(&global_histogram[RADIX_BINS * numTasks + localID], 0, global_count);
}
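
// How the two-level exclusive scan above works (commented walkthrough, assuming
// the 256-wide work group is split into 16-lane subgroups): each subgroup first
// reduces its 16 bin totals into partials[sub_group_id]. After the barrier,
// lane l of every subgroup reads partials[l] and zeroes entries from subgroups
// at or beyond its own, so sub_group_reduce_add(p) yields the sum over all
// preceding subgroups; adding the intra-subgroup
// sub_group_scan_exclusive_add(global_count) gives each bin its exclusive
// offset across the whole work group.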

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(256, 1, 1)))
void kernel
sort_morton_codes_reduce_bins(global struct Globals* globals,
                              uint numTasks,
                              global uint* global_histogram,
                              uint iteration)
{
    local uint partials[RADIX_BINS];
    const uint numItems = globals->numPrimitives;
    if (numItems < MORTON_LSB_SORT_NO_SHIFT_THRESHOLD)
        sort_morton_codes_reduce_bins_func(globals, global_histogram, partials, numTasks, iteration, false);
    else
        sort_morton_codes_reduce_bins_func(globals, global_histogram, partials, numTasks, iteration, true);
}


#if 1
GRL_INLINE void sort_morton_codes_scatter_items_func(
    global struct Globals* globals,
    global uint* global_histogram,
    global ulong* input,
    global ulong* output,
    local uint* local_offset,
    local uint* flags,
    uint iteration,
    uint numGroups,
    uint numItems,
    bool shift_primID,
    bool update_morton_sort_in_flight)
{
    const uint gID = get_local_id(0) + get_group_id(0) * get_local_size(0);

    const uint global_shift = globals->shift;
    const uint localID = get_local_id(0);
    const uint taskID = get_group_id(0);

    if (gID == 0 && update_morton_sort_in_flight)
        globals->morton_sort_in_flight = 0;

    uint2 ids = get_thread_range(numItems, numGroups, taskID);
    uint startID = ids.x;
    uint endID = ids.y;

    if (shift_primID)
    {
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;

        iteration -= req_iterations;
    }

    const uint shift = 8 * iteration;

    // load the global bin counts, and add each bin's global prefix
    // to the local prefix
    {
        uint global_prefix = 0, local_prefix = 0;
        if (localID < RADIX_BINS)
        {
            local_prefix = global_histogram[RADIX_BINS * taskID + localID];
            global_prefix = global_histogram[RADIX_BINS * numGroups + localID];
            local_offset[localID] = global_prefix + local_prefix;
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }


    // move elements in WG-sized chunks. The elements need to be moved sequentially (we can't use atomics),
    // because relative order has to be preserved for LSB radix sort to work

    // For each bin, keep a bit vector indicating which elements are in that bin
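    // Commented worked example of the ranking scheme below: SCATTER_WG_SIZE = 512
    // threads means each bin owns 512/32 = 16 uint words of flags. Suppose lanes
    // 3, 5, and 40 of a chunk all hash to bin 7. They set bits 3 and 5 of word 0
    // and bit 8 of word 1 in bin 7's slice. Lane 40 then computes
    //   prefix = popcount(word 0) + popcount(word 1 & (bit 8 - 1)) = 2 + 0 = 2,
    // i.e. it is the third element of bin 7 in this chunk, and count = 3 lets the
    // last element (prefix == count - 1) bump local_offset[7] for the next chunk.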
    for (uint block_base = startID; block_base < endID; block_base += get_local_size(0))
    {
        // initialize bit vectors
        for (uint i = 4 * localID; i < RADIX_BINS * SCATTER_WG_SIZE / 32; i += 4 * get_local_size(0))
        {
            flags[i + 0] = 0;
            flags[i + 1] = 0;
            flags[i + 2] = 0;
            flags[i + 3] = 0;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        // read the sort key, determine which bin it goes into, scatter it into the
        // bit vector, and pre-load the local offset
        uint ID = localID + block_base;
        ulong key = 0;
        uint bin_offset = 0;
        uint bin = 0;
        uint bin_word = localID / 32;
        uint bin_bit = 1 << (localID % 32);

        if (ID < endID)
        {
            key = input[ID];

            if (shift_primID)
                bin = ((key >> global_shift) >> shift) & (RADIX_BINS - 1);
            else
                bin = (key >> shift) & (RADIX_BINS - 1);

            atomic_add_local(&flags[(SCATTER_WG_SIZE / 32) * bin + bin_word], bin_bit);
            bin_offset = local_offset[bin];
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        if (ID < endID)
        {
            // each key reads the bit vectors for its bin, and
            // - computes a local prefix sum to determine its output location
            // - computes the number of items added to its bin (the last thread adjusts the bin position)
            uint prefix = 0;
            uint count = 0;
            for (uint i = 0; i < (SCATTER_WG_SIZE / 32); i++)
            {
                uint bits = flags[(SCATTER_WG_SIZE / 32) * bin + i];
                uint bc = popcount(bits);
                uint pc = popcount(bits & (bin_bit - 1));
                prefix += (i < bin_word) ? bc : 0;
                prefix += (i == bin_word) ? pc : 0;

                count += bc;
            }

            // store the key in its proper place
            output[prefix + bin_offset] = key;

            // the last item for each bin adjusts the local offset for the next outer loop iteration
            if (prefix == count - 1)
                local_offset[bin] += count;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

    }

    /* uint local_offset[RADIX_BINS]; */
    /* uint offset_global = 0; */
    /* for (int i=0;i<RADIX_BINS;i++) */
    /* { */
    /*     const uint count_global = global_histogram[RADIX_BINS*numTasks+i]; */
    /*     const uint offset_local = global_histogram[RADIX_BINS*taskID+i]; */
    /*     local_offset[i] = offset_global + offset_local; */
    /*     offset_global += count_global; */
    /* } */

    /* for (uint ID=startID;ID<endID;ID++) */
    /* { */
    /*     const uint bin = (input[ID] >> shift) & (RADIX_BINS-1); */
    /*     const uint offset = local_offset[bin]; */
    /*     output[offset] = input[ID]; */
    /*     local_offset[bin]++; */
    /* } */
}
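
// Usage note (an assumption drawn from the disabled variant below, which
// ping-pongs input0/input1 by iteration parity): callers are expected to swap
// the input and output buffers between passes, so after an even number of
// passes the sorted keys land back in the original buffer.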

#else

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(16, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16))) void kernel
sort_morton_codes_scatter_items(
    global struct Globals* globals,
    uint shift,
    global uint* global_histogram,
    global char* input0,
    global char* input1,
    unsigned int input0_offset,
    unsigned int input1_offset,
    uint iteration)
{
    const uint numItems = globals->numPrimitives;
    const uint local_size = get_local_size(0);
    const uint taskID = get_group_id(0);
    const uint numTasks = get_num_groups(0);
    const uint localID = get_local_id(0);
    const uint globalID = get_local_id(0) + get_group_id(0) * get_local_size(0);
    const uint subgroupLocalID = get_sub_group_local_id();
    const uint subgroup_size = get_sub_group_size();

    const uint startID = (taskID + 0) * numItems / numTasks;
    const uint endID = (taskID + 1) * numItems / numTasks;

    global ulong* input = (global ulong*)((iteration % 2) == 0 ? input0 + input0_offset : input1 + input1_offset);
    global ulong* output = (global ulong*)((iteration % 2) == 0 ? input1 + input1_offset : input0 + input0_offset);

    local uint local_offset[RADIX_BINS];
    uint off = 0;
    for (int i = subgroupLocalID; i < RADIX_BINS; i += subgroup_size)
    {
        const uint count = global_histogram[RADIX_BINS * numTasks + i];
        const uint offset_task = global_histogram[RADIX_BINS * taskID + i];
        const uint sum = sub_group_reduce_add(count);
        const uint prefix_sum = sub_group_scan_exclusive_add(count);
        local_offset[i] = off + offset_task + prefix_sum;
        off += sum;
    }

    for (uint ID = startID + subgroupLocalID; ID < endID; ID += subgroup_size)
    {
        const uint bin = (input[ID] >> shift) & (RADIX_BINS - 1);
        const uint offset = atomic_add_local(&local_offset[bin], 1);
        output[offset] = input[ID];
    }

    /* uint local_offset[RADIX_BINS]; */
    /* uint offset_global = 0; */
    /* for (int i=0;i<RADIX_BINS;i++) */
    /* { */
    /*     const uint count_global = global_histogram[RADIX_BINS*numTasks+i]; */
    /*     const uint offset_local = global_histogram[RADIX_BINS*taskID+i]; */
    /*     local_offset[i] = offset_global + offset_local; */
    /*     offset_global += count_global; */
    /* } */

    /* for (uint ID=startID;ID<endID;ID++) */
    /* { */
    /*     const uint bin = (input[ID] >> shift) & (RADIX_BINS-1); */
    /*     const uint offset = local_offset[bin]; */
    /*     output[offset] = input[ID]; */
    /*     local_offset[bin]++; */
    /* } */
}
#endif

#if 1
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(SCATTER_WG_SIZE, 1, 1)))
void kernel
sort_morton_codes_scatter_items(
    global struct Globals *globals,
    global uint *global_histogram,
    global ulong *input,
    global ulong *output,
    uint iteration,
    uint numGroups,
    uint update_morton_sort_in_flight)
{
    local uint local_offset[RADIX_BINS];
    local uint flags[RADIX_BINS * SCATTER_WG_SIZE / 32];
    const uint numItems = globals->numPrimitives;
    if (numItems < MORTON_LSB_SORT_NO_SHIFT_THRESHOLD)
        sort_morton_codes_scatter_items_func(globals, global_histogram, input, output, local_offset,
                                             flags, iteration, numGroups, numItems, false, update_morton_sort_in_flight);
    else
        sort_morton_codes_scatter_items_func(globals, global_histogram, input, output, local_offset,
                                             flags, iteration, numGroups, numItems, true, update_morton_sort_in_flight);
}

#else

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(16, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16))) void kernel
sort_morton_codes_scatter_items(
    global struct Globals *globals,
    uint shift,
    global uint *global_histogram,
    global char *input0,
    global char *input1,
    unsigned int input0_offset,
    unsigned int input1_offset,
    uint iteration)
{
    const uint numItems = globals->numPrimitives;
    const uint local_size = get_local_size(0);
    const uint taskID = get_group_id(0);
    const uint numTasks = get_num_groups(0);
    const uint localID = get_local_id(0);
    const uint globalID = get_local_id(0) + get_group_id(0)*get_local_size(0);
    const uint subgroupLocalID = get_sub_group_local_id();
    const uint subgroup_size = get_sub_group_size();

    const uint startID = (taskID + 0) * numItems / numTasks;
    const uint endID = (taskID + 1) * numItems / numTasks;

    global ulong *input = (global ulong *)((iteration % 2) == 0 ? input0 + input0_offset : input1 + input1_offset);
    global ulong *output = (global ulong *)((iteration % 2) == 0 ? input1 + input1_offset : input0 + input0_offset);

    local uint local_offset[RADIX_BINS];
    uint off = 0;
    for (int i = subgroupLocalID; i < RADIX_BINS; i += subgroup_size)
    {
        const uint count = global_histogram[RADIX_BINS * numTasks + i];
        const uint offset_task = global_histogram[RADIX_BINS * taskID + i];
        const uint sum = sub_group_reduce_add(count);
        const uint prefix_sum = sub_group_scan_exclusive_add(count);
        local_offset[i] = off + offset_task + prefix_sum;
        off += sum;
    }

    for (uint ID = startID + subgroupLocalID; ID < endID; ID += subgroup_size)
    {
        const uint bin = (input[ID] >> shift) & (RADIX_BINS - 1);
        const uint offset = atomic_add_local(&local_offset[bin], 1);
        output[offset] = input[ID];
    }

    /* uint local_offset[RADIX_BINS]; */
    /* uint offset_global = 0; */
    /* for (int i=0;i<RADIX_BINS;i++) */
    /* { */
    /*     const uint count_global = global_histogram[RADIX_BINS*numTasks+i]; */
    /*     const uint offset_local = global_histogram[RADIX_BINS*taskID+i]; */
    /*     local_offset[i] = offset_global + offset_local; */
    /*     offset_global += count_global; */
    /* } */

    /* for (uint ID=startID;ID<endID;ID++) */
    /* { */
    /*     const uint bin = (input[ID] >> shift) & (RADIX_BINS-1); */
    /*     const uint offset = local_offset[bin]; */
    /*     output[offset] = input[ID]; */
    /*     local_offset[bin]++; */
    /* } */
}
#endif

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(512, 1, 1)))
__attribute__((intel_reqd_sub_group_size(MAX_HW_SIMD_WIDTH)))
void kernel
sort_morton_codes_merged(
    global struct Globals* globals,
    global uint* global_histogram,
    global uchar* input,
    uint iteration,
    uint numGroups
    )
{
    const uint numItems = globals->numPrimitives;
    const uint taskID = get_group_id(0);
    const uint loc_id = get_local_id(0);
    const uint lane = get_sub_group_local_id();

    uint2 ids = get_thread_range(numItems, numGroups, taskID);
    uint startID = ids.x;
    uint endID = ids.y;

    local uint histogram[RADIX_BINS];
    local uint hist_tmp[RADIX_BINS];

    if (numItems < MORTON_LSB_SORT_NO_SHIFT_THRESHOLD)
    {
        sort_morton_codes_bin_items_taskID_func(globals, global_histogram, input, histogram, iteration, numGroups, numItems, false,
                                                taskID, startID, endID);
    }
    else
    {
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;

        iteration -= req_iterations;

        sort_morton_codes_bin_items_taskID_func(globals, global_histogram, input, histogram, iteration, numGroups, numItems, true,
                                                taskID, startID, endID);
    }

    uint last_group = 0;
    if (loc_id == 0)
        last_group = atomic_inc_global(&globals->morton_sort_in_flight);

    write_mem_fence(CLK_GLOBAL_MEM_FENCE);
    barrier(CLK_LOCAL_MEM_FENCE);

    last_group = work_group_broadcast(last_group, 0);

    bool isLastGroup = (loc_id < RADIX_BINS) && (last_group == numGroups - 1);
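
    // Last-group election (commented explanation): every work group atomically
    // increments morton_sort_in_flight after publishing its histogram, so exactly
    // one group observes the value numGroups - 1. That group knows all other
    // groups' counts are globally visible, so it alone folds the per-group
    // counts into exclusive offsets below.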

    uint global_count = 0;

    if (isLastGroup)
    {
        for (uint j = 0; j < numGroups; j++)
        {
            const uint count = (j == taskID) ? histogram[loc_id] : load_uint_L1C_L3C(&global_histogram[RADIX_BINS * j + loc_id], 0);
            store_uint_L1WB_L3WB(&global_histogram[RADIX_BINS * j + loc_id], 0, global_count);
            global_count += count;
        }

        hist_tmp[get_sub_group_id()] = (get_sub_group_id() < MAX_HW_SIMD_WIDTH) ? sub_group_reduce_add(global_count) : 0;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    if (isLastGroup)
    {
        uint p = hist_tmp[lane];
        p = (lane < get_sub_group_id()) ? p : 0;

        global_count = sub_group_reduce_add(p) + sub_group_scan_exclusive_add(global_count);

        store_uint_L1WB_L3WB(&global_histogram[RADIX_BINS * numGroups + loc_id], 0, global_count);
    }
}

#if 0
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(16, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16))) void kernel
sort_morton_codes_bin_items(
    global struct Globals* globals,
    uint shift,
    global uint* global_histogram,
    global char* input0,
    global char* input1,
    unsigned int input0_offset,
    unsigned int input1_offset,
    uint iteration)
{
    const uint numItems = globals->numPrimitives;
    const uint local_size = get_local_size(0);
    const uint taskID = get_group_id(0);
    const uint numTasks = get_num_groups(0);
    const uint localID = get_local_id(0);
    const uint globalID = get_local_id(0) + get_group_id(0) * get_local_size(0);
    const uint subgroupLocalID = get_sub_group_local_id();
    const uint subgroup_size = get_sub_group_size();

    const uint startID = (taskID + 0) * numItems / numTasks;
    const uint endID = (taskID + 1) * numItems / numTasks;

    global ulong* input = (global ulong*)((iteration % 2) == 0 ? input0 + input0_offset : input1 + input1_offset);

#if 1
    local uint histogram[RADIX_BINS];
    for (uint i = subgroupLocalID; i < RADIX_BINS; i += subgroup_size)
        histogram[i] = 0;

    for (uint ID = startID + subgroupLocalID; ID < endID; ID += subgroup_size)
    {
        const uint bin = ((uint)(input[ID] >> (ulong)shift)) & (RADIX_BINS - 1);
        atomic_add(&histogram[bin], 1);
    }

    for (uint i = subgroupLocalID; i < RADIX_BINS; i += subgroup_size)
        global_histogram[RADIX_BINS * taskID + i] = histogram[i];

#else
    uint histogram[RADIX_BINS];
    for (int i = 0; i < RADIX_BINS; i++)
        histogram[i] = 0;

    for (uint ID = startID + subgroupLocalID; ID < endID; ID += subgroup_size)
    {
        const uint bin = ((uint)(input[ID] >> (ulong)shift)) & (RADIX_BINS - 1);
        histogram[bin]++;
    }

    for (uint i = 0; i < RADIX_BINS; i++)
    {
        const uint reduced_counter = sub_group_reduce_add(histogram[i]);
        global_histogram[RADIX_BINS * taskID + i] = reduced_counter;
    }
#endif
}

#endif

#define WG_SIZE_WIDE 256
#define SG_SIZE_SCAN 16

// Fast implementation of work_group_scan_exclusive_add using SLM, for WG size 256 and SG size 16
GRL_INLINE uint work_group_scan_exclusive_add_opt(local uint* tmp, uint val)
{
    const uint hw_thread_in_wg_id = get_local_id(0) / SG_SIZE_SCAN;
    const uint sg_local_id = get_local_id(0) % SG_SIZE_SCAN;
    const uint NUM_HW_THREADS_IN_WG = WG_SIZE_WIDE / SG_SIZE_SCAN;

    uint acc = sub_group_scan_exclusive_add(val);
    uint acc2 = acc + val;

    tmp[hw_thread_in_wg_id] = sub_group_broadcast(acc2, SG_SIZE_SCAN - 1);
    barrier(CLK_LOCAL_MEM_FENCE);
    uint loaded_val = tmp[sg_local_id];
    uint wgs_acc = sub_group_scan_exclusive_add(loaded_val);
    uint acc_for_this_hw_thread = sub_group_broadcast(wgs_acc, hw_thread_in_wg_id);
    return acc + acc_for_this_hw_thread;
}
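
// Worked example (comment only), with SG_SIZE_SCAN = 16 and WG_SIZE_WIDE = 256:
// if every lane passes val = 1, the intra-subgroup scan gives acc = 0..15 and
// each subgroup writes acc2's top value, 16, into tmp[hw_thread_in_wg_id].
// After the barrier, the first 16 lanes of each subgroup scan tmp, producing
// wgs_acc = 0, 16, 32, ..., and broadcasting entry hw_thread_in_wg_id gives the
// subgroup's base offset. acc + base is then the work-group-wide exclusive
// prefix, 0..255, at the cost of two subgroup scans and one SLM round trip.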

// The wide reduce algorithm is divided into 2 kernels:
// 1. First, partial exclusive add scans are made within each work group using SLM.
//    Then, the last work group for each histogram bin performs an exclusive add scan along the bins, using the separate histogram_partials buffer.
//    The last work group is determined using global atomics on the wg_flags buffer.
// 2. The second kernel globally adds the values from histogram_partials to the histogram buffer, where the partial sums are.
//    Then, the last work group performs one more work-group scan and add, so the histogram buffer values are adjusted with the global ones.
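// Data-flow sketch of the two kernels (a commented summary of the code below,
// assuming one column per bin): for bin b, kernel 1 scans the column
// global_histogram[RADIX_BINS * t + b] over tasks t in WG_SIZE_WIDE-sized
// blocks and stores each block's total in global_histogram_partials; the last
// block to finish (elected via wg_flags[b]) then exclusive-scans those block
// totals. Kernel 2 adds the scanned block totals back into every block except
// the first, completing the per-bin exclusive offsets across all tasks.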
GRL_INLINE void sort_morton_codes_reduce_bins_wide_partial_sum_func(
    global struct Globals* globals,
    global uint* global_histogram,
    global uint* global_histogram_partials,
    global uint* wg_flags,
    local uint* exclusive_scan_tmp,
    uint numTasks,
    uint numGroups,
    uint iteration,
    bool shift_primID)
{
    if (shift_primID)
    {
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;

        iteration -= req_iterations;
    }

    const uint groupID = get_group_id(0) % RADIX_BINS;
    const uint scanGroupID = get_group_id(0) / RADIX_BINS;
    uint localID = get_local_id(0);
    uint globalID = localID + (scanGroupID * WG_SIZE_WIDE);
    const uint lastGroup = (numGroups / WG_SIZE_WIDE);
    const uint endID = min(numTasks, (uint)(scanGroupID * WG_SIZE_WIDE + WG_SIZE_WIDE)) - 1;

    uint temp = 0;
    uint last_count = 0;
    if (globalID < numTasks)
    {
        temp = global_histogram[RADIX_BINS * globalID + groupID];

        // Store the last value of the work group; it is either the last element of the histogram or the last item in the work group
        if (globalID == endID)
            last_count = temp;
    }

    uint val = work_group_scan_exclusive_add_opt(exclusive_scan_tmp, temp);

    if (globalID <= numTasks)
    {
        global_histogram[RADIX_BINS * globalID + groupID] = val;

        // Store the block sum value to a separate buffer
        if (globalID == endID)
            global_histogram_partials[scanGroupID * WG_SIZE_WIDE + groupID] = val + last_count;
    }

    // Make sure that global_histogram_partials is updated in all work groups
    write_mem_fence(CLK_GLOBAL_MEM_FENCE);
    barrier(0);

    // Now, wait for the last group for each histogram bin, so we know that
    // all work groups have already updated the global_histogram_partials buffer
    uint last_group = 0;
    if (localID == 0)
        last_group = atomic_inc_global(&wg_flags[groupID]);

    last_group = work_group_broadcast(last_group, 0);
    bool isLastGroup = (last_group == lastGroup - 1);

    // Each of the last groups computes the exclusive add scan over the partial sums we have
    if (isLastGroup)
    {
        uint temp1 = 0;
        if (localID < lastGroup)
            temp1 = global_histogram_partials[localID * WG_SIZE_WIDE + groupID];

        uint val2 = work_group_scan_exclusive_add_opt(exclusive_scan_tmp, temp1);

        if (localID < lastGroup)
            global_histogram_partials[localID * WG_SIZE_WIDE + groupID] = val2;
    }
}

GRL_INLINE void sort_morton_codes_reduce_bins_wide_add_reduce_func(
    global struct Globals* globals,
    global uint* global_histogram,
    global uint* global_histogram_partials,
    local uint* partials,
    uint numTasks,
    uint numGroups,
    uint iteration,
    bool shift_primID)
{
    if (shift_primID)
    {
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;

        iteration -= req_iterations;
    }

    const uint groupID = get_group_id(0) % RADIX_BINS;
    const uint scanGroupID = get_group_id(0) / RADIX_BINS;
    const uint lastGroup = (numGroups / WG_SIZE_WIDE);
    uint localID = get_local_id(0);
    uint globalID = localID + (scanGroupID * WG_SIZE_WIDE);
    const uint endID = min(numTasks, (uint)(scanGroupID * WG_SIZE_WIDE + WG_SIZE_WIDE)) - 1;

    // Add the global sums to the partials; skip the first scanGroupID, as the first add
    // value is 0 in the case of exclusive add scans
    if (scanGroupID > 0 && globalID <= numTasks)
    {
        uint add_val = global_histogram_partials[scanGroupID * RADIX_BINS + groupID];
        atomic_add_global(&global_histogram[globalID * RADIX_BINS + groupID], add_val);
    }

    // Wait for the last group
    uint last_group = 0;
    if (localID == 0)
        last_group = atomic_inc_global(&globals->morton_sort_in_flight);

    last_group = work_group_broadcast(last_group, 0);
    bool isLastGroup = (last_group == numGroups - 1);

    // Do the exclusive scan within all bins with global data now
    if (isLastGroup)
    {
        mem_fence_gpu_invalidate();

        uint global_count = global_histogram[numTasks * RADIX_BINS + localID];

        partials[get_sub_group_id()] = sub_group_reduce_add(global_count);

        barrier(CLK_LOCAL_MEM_FENCE);

        uint lane = get_sub_group_local_id();
        uint p = partials[lane];
        p = (lane < get_sub_group_id()) ? p : 0;

        global_count = sub_group_reduce_add(p) + sub_group_scan_exclusive_add(global_count);

        store_uint_L1WB_L3WB(&global_histogram[numTasks * RADIX_BINS + localID], 0, global_count);
    }
}


GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(WG_SIZE_WIDE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(SG_SIZE_SCAN)))
void kernel
sort_morton_codes_reduce_bins_wide_partial_sum(
    global struct Globals* globals,
    uint numTasks,
    uint numGroups,
    global uint* global_histogram,
    global uint* global_histogram_partials,
    global uint* wg_flags,
    uint iteration)
{
    local uint exclusive_scan_tmp[WG_SIZE_WIDE / SG_SIZE_SCAN];

    const uint numItems = globals->numPrimitives;
    if (numItems < MORTON_LSB_SORT_NO_SHIFT_THRESHOLD)
        sort_morton_codes_reduce_bins_wide_partial_sum_func(globals, global_histogram, global_histogram_partials, wg_flags, exclusive_scan_tmp, numTasks, numGroups, iteration, false);
    else
        sort_morton_codes_reduce_bins_wide_partial_sum_func(globals, global_histogram, global_histogram_partials, wg_flags, exclusive_scan_tmp, numTasks, numGroups, iteration, true);
}

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(WG_SIZE_WIDE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(SG_SIZE_SCAN)))
void kernel
sort_morton_codes_reduce_bins_wide_add_reduce(
    global struct Globals* globals,
    uint numTasks,
    uint numGroups,
    global uint* global_histogram,
    global uint* global_histogram_partials,
    uint iteration)
{
    local uint partials[RADIX_BINS];

    const uint numItems = globals->numPrimitives;
    if (numItems < MORTON_LSB_SORT_NO_SHIFT_THRESHOLD)
        sort_morton_codes_reduce_bins_wide_add_reduce_func(globals, global_histogram, global_histogram_partials, partials, numTasks, numGroups, iteration, false);
    else
        sort_morton_codes_reduce_bins_wide_add_reduce_func(globals, global_histogram, global_histogram_partials, partials, numTasks, numGroups, iteration, true);
}