xref: /aosp_15_r20/external/mesa3d/src/intel/vulkan/grl/gpu/morton_radix_sort.h (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 //
2 // Copyright (C) 2009-2021 Intel Corporation
3 //
4 // SPDX-License-Identifier: MIT
5 //
6 //
7 
8 #pragma once
9 
10 #include "common.h"
11 #include "libs/lsc_intrinsics.h"
12 
13 /* ============================================================================= */
14 /* ============================== LSB RADIX SORT =============================== */
15 /* ============================================================================= */
16 
17 #define RADIX_BINS 256
18 #define SCATTER_WG_SIZE 512
19 #define MORTON_LSB_SORT_NO_SHIFT_THRESHOLD 0xFFFFFFFF // turn off, because current hierarchy build requires full sort
20 
get_thread_range(uint numItems,uint numGroups,uint taskID)21 uint2 get_thread_range( uint numItems, uint numGroups, uint taskID )
22 {
23     uint items_per_group = (numItems / numGroups);
24     uint remainder = numItems - (items_per_group * numGroups);
25     uint startID = taskID * items_per_group  + min(taskID, remainder);
26     uint endID   = startID + items_per_group + ((taskID < remainder) ? 1 : 0);
27 
28     return (uint2)(startID,endID);
29 }
30 
// Per-work-group histogram pass for one 8-bit digit of the 64-bit morton keys.
// For the item range [startID, endID) owned by this task, counts how many keys
// fall into each of the RADIX_BINS (256) bins for the byte selected by
// 'iteration', then publishes the counts to global_histogram[RADIX_BINS*taskID..].
//
// shift_primID selects a variant that right-shifts each key by globals->shift
// first, so the low bits encoding the primitive ID do not participate in the
// sort (fewer digit passes when the morton shift exceeds 8 bits).
GRL_INLINE void sort_morton_codes_bin_items_taskID_func(global struct Globals* globals,
                                                 global uint* global_histogram,
                                                 global uchar* input,
                                                 local uint* histogram,
                                                 uint iteration,
                                                 uint numGroups,
                                                 uint numItems,
                                                 bool shift_primID,
                                                 uint taskID,
                                                 uint startID,
                                                 uint endID)
{
    const uint shift = globals->shift;

    // Zero the shared per-WG histogram, strided across the work group.
    for (uint i = get_local_id(0); i < RADIX_BINS; i += get_local_size(0))
        histogram[i] = 0;

    barrier(CLK_LOCAL_MEM_FENCE);

    if (shift_primID)
    {
        for (uint i = startID + get_local_id(0); i < endID; i += get_local_size(0))
        {
            // Read input as ulong to make bitshift, so the bits representing primID are not being
            // taken into account during sorting, which would result in smaller sort loops for
            // cases where morton shift are bigger than 8 bits
            ulong* ptr_ul = (ulong*)&input[8 * i];
            ulong code = *ptr_ul;
            uchar* ptr = (uchar*)&code;  // byte view of the local copy; read after the shift below
            code >>= shift;

            // Digit = byte 'iteration' of the shifted key; bump its bin count.
            uchar bin = ptr[iteration];
            atomic_inc_local(&histogram[bin]);
        }
    }
    else
    {
        for (uint i = startID + get_local_id(0); i < endID; i += get_local_size(0))
        {
            // Unshifted path: the digit is byte 'iteration' of the 8-byte key.
            uchar bin = input[8 * i + iteration];
            atomic_inc_local(&histogram[bin]);
        }
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Publish this work-group's histogram for the subsequent reduce pass.
    for (uint i = get_local_id(0); i < RADIX_BINS; i += get_local_size(0))
        global_histogram[RADIX_BINS * taskID + i] = histogram[i];
}
80 
// Wrapper around the histogram pass: handles the iteration-skip logic for the
// primID-shifting variant, optionally clears the per-bin wg_flags (used by the
// wide reduce kernels), computes this work-group's item range, and delegates to
// sort_morton_codes_bin_items_taskID_func.
GRL_INLINE void sort_morton_codes_bin_items_func(global struct Globals* globals,
    global uint* global_histogram,
    global uint* wg_flags,
    global uchar* input,
    local uint* histogram,
    uint iteration,
    uint numGroups,
    uint numItems,
    bool shift_primID,
    bool update_wg_flags)
{
    if (shift_primID)
    {
        // This check is present in other LSB sort functions as well, its purpose is
        // to skip first n iterations where n is the difference between max iterations
        // and actually needed iterations to sort without primIDs
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;

        // iteration needs to be adjusted to reflect the skipped cycles
        iteration -= req_iterations;
    }

    const uint taskID = get_group_id(0);

    // Group 0 resets the per-bin "last group" election flags before the wide
    // reduce kernels use them.
    if (taskID == 0 && update_wg_flags)
    {
        for (uint i = get_local_id(0); i < RADIX_BINS; i += get_local_size(0))
            wg_flags[i] = 0;
    }

    uint2 ids = get_thread_range(numItems, numGroups, taskID);
    uint startID = ids.x;
    uint endID = ids.y;

    sort_morton_codes_bin_items_taskID_func(globals, global_histogram, input, histogram, iteration, numGroups, numItems, shift_primID,
                                            taskID, startID, endID);
}
120 
121 __attribute__((reqd_work_group_size(512, 1, 1)))
122 void kernel
sort_morton_codes_bin_items(global struct Globals * globals,global uint * global_histogram,global uint * wg_flags,global uchar * input,uint iteration,uint numGroups,uint update_wg_flags)123 sort_morton_codes_bin_items(
124     global struct Globals* globals,
125     global uint* global_histogram,
126     global uint* wg_flags,
127     global uchar* input,
128     uint iteration,
129     uint numGroups,
130     uint update_wg_flags
131 )
132 {
133     local uint histogram[RADIX_BINS];
134     const uint numItems = globals->numPrimitives;
135     if(numItems < MORTON_LSB_SORT_NO_SHIFT_THRESHOLD)
136         sort_morton_codes_bin_items_func(globals, global_histogram, wg_flags, input, histogram, iteration, numGroups, numItems, false, update_wg_flags);
137     else
138         sort_morton_codes_bin_items_func(globals, global_histogram, wg_flags, input, histogram, iteration, numGroups, numItems, true, update_wg_flags);
139 }
140 
141 
// Turns the per-task histograms into scatter offsets.  Runs with one work-item
// per radix bin (WG size 256):
//   1) serial exclusive scan over tasks for each bin (in-place in
//      global_histogram), leaving each lane holding its bin's grand total;
//   2) work-group-wide exclusive scan over the 256 bin totals, stored in the
//      extra row at global_histogram[RADIX_BINS * numTasks ..].
GRL_INLINE void sort_morton_codes_reduce_bins_func(global struct Globals* globals,
                                                   global uint* global_histogram,
                                                   local uint* partials,
                                                   uint numTasks,
                                                   uint iteration,
                                                   bool shift_primID)
{
    const uint localID = get_local_id(0);

    if (shift_primID)
    {
        // Skip passes made unnecessary by excluding the primID bits
        // (see sort_morton_codes_bin_items_func for the rationale).
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;
    }

    // Serial exclusive scan over tasks for this lane's bin: afterwards,
    // global_histogram[RADIX_BINS*j + bin] = items of this bin in tasks 0..j-1,
    // and t = total item count of the bin.
    uint t = 0;
    for (uint j = 0; j < numTasks; j++)
    {
        const uint count = load_uint_L1C_L3C(&global_histogram[RADIX_BINS * j + localID], 0);
        store_uint_L1WB_L3WB(&global_histogram[RADIX_BINS * j + localID], 0, t);
        t += count;
    }

    // each lane now contains the number of elements in the corresponding bin
    //     prefix sum this for use in the subsequent scattering pass.
    uint global_count = t;

    // Two-level exclusive scan: stage per-subgroup totals in SLM, then each
    // lane adds the totals of all preceding subgroups to its in-subgroup scan.
    partials[get_sub_group_id()] = sub_group_reduce_add(global_count);

    barrier(CLK_LOCAL_MEM_FENCE);

    uint lane = get_sub_group_local_id();
    uint p = partials[lane];
    p = (lane < get_sub_group_id()) ? p : 0;

    global_count = sub_group_reduce_add(p) + sub_group_scan_exclusive_add(global_count);

    // Store the cross-bin exclusive offsets in the extra totals row.
    store_uint_L1WB_L3WB(&global_histogram[RADIX_BINS * numTasks + localID], 0, global_count);
}
182 
183 GRL_ANNOTATE_IGC_DO_NOT_SPILL
184 __attribute__((reqd_work_group_size(256, 1, 1)))
185 void kernel
sort_morton_codes_reduce_bins(global struct Globals * globals,uint numTasks,global uint * global_histogram,uint iteration)186 sort_morton_codes_reduce_bins(global struct Globals* globals,
187     uint numTasks,
188     global uint* global_histogram,
189     uint iteration)
190 {
191     local uint partials[RADIX_BINS];
192     const uint numItems = globals->numPrimitives;
193     if (numItems < MORTON_LSB_SORT_NO_SHIFT_THRESHOLD)
194         sort_morton_codes_reduce_bins_func(globals, global_histogram, partials, numTasks, iteration, false);
195     else
196         sort_morton_codes_reduce_bins_func(globals, global_histogram, partials, numTasks, iteration, true);
197 }
198 
199 
200 #if 1
// Scatter pass of the LSB radix sort.  Each work-group moves the keys in its
// item range [startID, endID) to their sorted positions for the current 8-bit
// digit.  LSB radix sort requires each pass to be stable, so instead of
// per-element atomics on the output offsets, each WG-sized chunk builds a
// per-bin bit vector of which lanes hold a key of that bin and derives each
// lane's slot from popcounts of those bits.
GRL_INLINE void sort_morton_codes_scatter_items_func(
    global struct Globals* globals,
    global uint* global_histogram,
    global ulong* input,
    global ulong* output,
    local uint* local_offset,
    local uint* flags,
    uint iteration,
    uint numGroups,
    uint numItems,
    bool shift_primID,
    bool update_morton_sort_in_flight)
{
    const uint gID = get_local_id(0) + get_group_id(0) * get_local_size(0);

    const uint global_shift = globals->shift;
    const uint localID = get_local_id(0);
    const uint taskID = get_group_id(0);

    // First global work-item resets the cross-dispatch completion counter used
    // by the merged/wide-reduce kernels.
    if (gID == 0 && update_morton_sort_in_flight)
        globals->morton_sort_in_flight = 0;

    uint2 ids = get_thread_range(numItems, numGroups, taskID);
    uint startID = ids.x;
    uint endID = ids.y;

    if (shift_primID)
    {
        // Skip passes made unnecessary by excluding the primID bits
        // (mirrors the check in the binning and reduce kernels).
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;

        iteration -= req_iterations;
    }

    const uint shift = 8 * iteration;

    // load the global bin counts, and add each bin's global prefix
    //   to the local prefix
    {
        uint global_prefix = 0, local_prefix = 0;
        if (localID < RADIX_BINS)
        {
            local_prefix = global_histogram[RADIX_BINS * taskID + localID];
            global_prefix = global_histogram[RADIX_BINS * numGroups + localID];
            local_offset[localID] = global_prefix + local_prefix;
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }


    // move elements in WG-sized chunks.   The elements need to be moved sequentially (can't use atomics)
    //   because relative order has to be preserved for LSB radix sort to work

    // For each bin, a bit vector indicating which elements are in the bin
    for (uint block_base = startID; block_base < endID; block_base += get_local_size(0))
    {
        // initialize bit vectors (4 words per work-item per step)
        for (uint i = 4 * localID; i < RADIX_BINS * SCATTER_WG_SIZE / 32; i += 4 * get_local_size(0))
        {
            flags[i + 0] = 0;
            flags[i + 1] = 0;
            flags[i + 2] = 0;
            flags[i + 3] = 0;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        // read sort key, determine which bin it goes into, scatter into the bit vector
        //  and pre-load the local offset
        uint ID = localID + block_base;
        ulong key = 0;
        uint bin_offset = 0;
        uint bin = 0;
        uint bin_word = localID / 32;            // which 32-bit word of the bin's bit vector
        uint bin_bit = 1 << (localID % 32);      // this lane's bit within that word

        if (ID < endID)
        {
            key = input[ID];

            if (shift_primID)
                bin = ((key >> global_shift) >> shift) & (RADIX_BINS - 1);
            else
                bin = (key >> shift) & (RADIX_BINS - 1);

            // Set this lane's bit in its bin's bit vector (SCATTER_WG_SIZE/32
            // words per bin).
            atomic_add_local(&flags[(SCATTER_WG_SIZE / 32) * bin + bin_word], bin_bit);
            bin_offset = local_offset[bin];
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        if (ID < endID)
        {
            // each key reads the bit-vectors for its bin,
            //    - Computes local prefix sum to determine its output location
            //    - Computes number of items added to its bin (last thread adjusts bin position)
            uint prefix = 0;
            uint count = 0;
            for (uint i = 0; i < (SCATTER_WG_SIZE / 32); i++)
            {
                uint bits = flags[(SCATTER_WG_SIZE / 32) * bin + i];
                uint bc = popcount(bits);                    // keys of this bin in word i
                uint pc = popcount(bits & (bin_bit - 1));    // keys of this bin before this lane
                prefix += (i < bin_word) ? bc : 0;
                prefix += (i == bin_word) ? pc : 0;

                count += bc;
            }

            // store the key in its proper place..
            output[prefix + bin_offset] = key;

            // last item for each bin adjusts local offset for next outer loop iteration
            if (prefix == count - 1)
                local_offset[bin] += count;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

    }

    // Reference scalar implementation, kept for documentation:
    /* uint local_offset[RADIX_BINS];   */
    /* uint offset_global = 0; */
    /* for (int i=0;i<RADIX_BINS;i++) */
    /*   { */
    /*     const uint count_global = global_histogram[RADIX_BINS*numTasks+i]; */
    /*     const uint offset_local  = global_histogram[RADIX_BINS*taskID+i]; */
    /*     local_offset[i] = offset_global + offset_local; */
    /*     offset_global += count_global; */
    /*   } */

    /* for (uint ID=startID;ID<endID;ID++) */
    /* { */
    /*   const uint bin = (input[ID] >> shift) & (RADIX_BINS-1); */
    /*   const uint offset = local_offset[bin]; */
    /*   output[offset] = input[ID]; */
    /*   local_offset[bin]++; */
    /* } */
}
342 
343 #else
344 
// Disabled (#else branch of the '#if 1' above): older subgroup-based scatter,
// kept for reference.  Chooses input/output buffers by iteration parity and
// allocates output slots with a per-bin SLM atomic.
// NOTE(review): the per-element atomic_add_local here does not obviously
// guarantee the stable ordering the active implementation provides --
// presumably why it was replaced; confirm before re-enabling.
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(16, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16))) void kernel
sort_morton_codes_scatter_items(
    global struct Globals* globals,
    uint shift,
    global uint* global_histogram,
    global char* input0,
    global char* input1,
    unsigned int input0_offset,
    unsigned int input1_offset,
    uint iteration)
{
    const uint numItems = globals->numPrimitives;
    const uint local_size = get_local_size(0);
    const uint taskID = get_group_id(0);
    const uint numTasks = get_num_groups(0);
    const uint localID = get_local_id(0);
    const uint globalID = get_local_id(0) + get_group_id(0) * get_local_size(0);
    const uint subgroupLocalID = get_sub_group_local_id();
    const uint subgroup_size = get_sub_group_size();

    const uint startID = (taskID + 0) * numItems / numTasks;
    const uint endID = (taskID + 1) * numItems / numTasks;

    // Ping-pong buffers: even iterations read input0, write input1; odd swap.
    global ulong* input = (global ulong*)((iteration % 2) == 0 ? input0 + input0_offset : input1 + input1_offset);
    global ulong* output = (global ulong*)((iteration % 2) == 0 ? input1 + input1_offset : input0 + input0_offset);

    // Build per-bin starting offsets: cross-bin exclusive scan (off/prefix_sum)
    // plus this task's per-bin offset.
    local uint local_offset[RADIX_BINS];
    uint off = 0;
    for (int i = subgroupLocalID; i < RADIX_BINS; i += subgroup_size)
    {
        const uint count = global_histogram[RADIX_BINS * numTasks + i];
        const uint offset_task = global_histogram[RADIX_BINS * taskID + i];
        const uint sum = sub_group_reduce_add(count);
        const uint prefix_sum = sub_group_scan_exclusive_add(count);
        local_offset[i] = off + offset_task + prefix_sum;
        off += sum;
    }

    for (uint ID = startID + subgroupLocalID; ID < endID; ID += subgroup_size)
    {
        const uint bin = (input[ID] >> shift) & (RADIX_BINS - 1);
        const uint offset = atomic_add_local(&local_offset[bin], 1);
        output[offset] = input[ID];
    }

    // Reference scalar implementation, kept for documentation:
    /* uint local_offset[RADIX_BINS];   */
    /* uint offset_global = 0; */
    /* for (int i=0;i<RADIX_BINS;i++) */
    /*   { */
    /*     const uint count_global = global_histogram[RADIX_BINS*numTasks+i]; */
    /*     const uint offset_local  = global_histogram[RADIX_BINS*taskID+i]; */
    /*     local_offset[i] = offset_global + offset_local; */
    /*     offset_global += count_global; */
    /*   } */

    /* for (uint ID=startID;ID<endID;ID++) */
    /* { */
    /*   const uint bin = (input[ID] >> shift) & (RADIX_BINS-1); */
    /*   const uint offset = local_offset[bin]; */
    /*   output[offset] = input[ID]; */
    /*   local_offset[bin]++; */
    /* } */
}
410 #endif
411 
412 #if 1
413 GRL_ANNOTATE_IGC_DO_NOT_SPILL
414 __attribute__((reqd_work_group_size(SCATTER_WG_SIZE, 1, 1)))
415 void kernel
sort_morton_codes_scatter_items(global struct Globals * globals,global uint * global_histogram,global ulong * input,global ulong * output,uint iteration,uint numGroups,uint update_morton_sort_in_flight)416 sort_morton_codes_scatter_items(
417     global struct Globals *globals,
418     global uint *global_histogram,
419     global ulong *input,
420     global ulong *output,
421     uint iteration,
422     uint numGroups,
423     uint update_morton_sort_in_flight)
424 {
425     local uint local_offset[RADIX_BINS];
426     local uint flags[RADIX_BINS*SCATTER_WG_SIZE/32];
427     const uint numItems = globals->numPrimitives;
428     if (numItems < MORTON_LSB_SORT_NO_SHIFT_THRESHOLD)
429         sort_morton_codes_scatter_items_func(globals, global_histogram, input, output, local_offset,
430                                              flags, iteration, numGroups, numItems, false, update_morton_sort_in_flight);
431     else
432         sort_morton_codes_scatter_items_func(globals, global_histogram, input, output, local_offset,
433                                              flags, iteration, numGroups, numItems, true, update_morton_sort_in_flight);
434 }
435 
436 #else
437 
// Disabled (#else branch of the '#if 1' above): duplicate of the older
// subgroup-based scatter kernel, kept for reference.  See the NOTE on the
// other disabled copy regarding the ordering guarantees of the per-element
// atomic slot allocation.
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(16, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16))) void kernel
sort_morton_codes_scatter_items(
    global struct Globals *globals,
    uint shift,
    global uint *global_histogram,
    global char *input0,
    global char *input1,
    unsigned int input0_offset,
    unsigned int input1_offset,
    uint iteration)
{
    const uint numItems = globals->numPrimitives;
    const uint local_size = get_local_size(0);
    const uint taskID = get_group_id(0);
    const uint numTasks = get_num_groups(0);
    const uint localID = get_local_id(0);
    const uint globalID = get_local_id(0) + get_group_id(0)*get_local_size(0);
    const uint subgroupLocalID = get_sub_group_local_id();
    const uint subgroup_size = get_sub_group_size();

    const uint startID = (taskID + 0) * numItems / numTasks;
    const uint endID = (taskID + 1) * numItems / numTasks;

    // Ping-pong buffers selected by iteration parity.
    global ulong *input = (global ulong *)((iteration % 2) == 0 ? input0 + input0_offset : input1 + input1_offset);
    global ulong *output = (global ulong *)((iteration % 2) == 0 ? input1 + input1_offset : input0 + input0_offset);

    // Per-bin starting offsets: cross-bin exclusive scan plus this task's
    // per-bin offset from the reduce pass.
    local uint local_offset[RADIX_BINS];
    uint off = 0;
    for (int i = subgroupLocalID; i < RADIX_BINS; i += subgroup_size)
    {
        const uint count = global_histogram[RADIX_BINS * numTasks + i];
        const uint offset_task = global_histogram[RADIX_BINS * taskID + i];
        const uint sum = sub_group_reduce_add(count);
        const uint prefix_sum = sub_group_scan_exclusive_add(count);
        local_offset[i] = off + offset_task + prefix_sum;
        off += sum;
    }

    for (uint ID = startID + subgroupLocalID; ID < endID; ID += subgroup_size)
    {
        const uint bin = (input[ID] >> shift) & (RADIX_BINS - 1);
        const uint offset = atomic_add_local(&local_offset[bin], 1);
        output[offset] = input[ID];
    }

    // Reference scalar implementation, kept for documentation:
    /* uint local_offset[RADIX_BINS];   */
    /* uint offset_global = 0; */
    /* for (int i=0;i<RADIX_BINS;i++) */
    /*   { */
    /*     const uint count_global = global_histogram[RADIX_BINS*numTasks+i]; */
    /*     const uint offset_local  = global_histogram[RADIX_BINS*taskID+i]; */
    /*     local_offset[i] = offset_global + offset_local; */
    /*     offset_global += count_global; */
    /*   } */

    /* for (uint ID=startID;ID<endID;ID++) */
    /* { */
    /*   const uint bin = (input[ID] >> shift) & (RADIX_BINS-1); */
    /*   const uint offset = local_offset[bin]; */
    /*   output[offset] = input[ID]; */
    /*   local_offset[bin]++; */
    /* } */
}
503 #endif
504 
// Fused binning + reduction kernel: every work-group histograms its own item
// range, then the last group to finish (elected by atomically incrementing
// globals->morton_sort_in_flight) performs the cross-task and cross-bin
// exclusive scans that sort_morton_codes_reduce_bins would otherwise do in a
// separate dispatch.
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(512, 1, 1)))
__attribute__((intel_reqd_sub_group_size(MAX_HW_SIMD_WIDTH)))
void kernel
sort_morton_codes_merged(
    global struct Globals* globals,
    global uint* global_histogram,
    global uchar* input,
    uint iteration,
    uint numGroups
)
{
    const uint numItems = globals->numPrimitives;
    const uint taskID = get_group_id(0);
    const uint loc_id = get_local_id(0);
    const uint lane = get_sub_group_local_id();

    uint2 ids = get_thread_range(numItems, numGroups, taskID);
    uint startID = ids.x;
    uint endID = ids.y;

    local uint histogram[RADIX_BINS];
    local uint hist_tmp[RADIX_BINS];

    if (numItems < MORTON_LSB_SORT_NO_SHIFT_THRESHOLD)
    {
        sort_morton_codes_bin_items_taskID_func(globals, global_histogram, input, histogram, iteration, numGroups, numItems, false,
            taskID, startID, endID);
    }
    else
    {
        // Shifted path: skip passes not needed once primID bits are excluded,
        // and remap iteration to the reduced pass count.
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;

        iteration -= req_iterations;

        sort_morton_codes_bin_items_taskID_func(globals, global_histogram, input, histogram, iteration, numGroups, numItems, true,
            taskID, startID, endID);
    }

    // Elect the last group to finish binning; the pre-increment value equals
    // numGroups-1 only for that group.
    uint last_group = 0;
    if (loc_id == 0)
        last_group = atomic_inc_global(&globals->morton_sort_in_flight);

    // Make this group's histogram globally visible before the last group reads it.
    write_mem_fence(CLK_GLOBAL_MEM_FENCE);
    barrier(CLK_LOCAL_MEM_FENCE);

    last_group = work_group_broadcast(last_group, 0);

    // Only the first RADIX_BINS lanes of the last group reduce (one lane per bin).
    bool isLastGroup = (loc_id < RADIX_BINS) && (last_group == numGroups - 1);

    uint global_count = 0;

    if (isLastGroup)
    {
        // Serial exclusive scan over tasks for this lane's bin; this group's
        // own counts come from SLM instead of global memory.
        for (uint j = 0; j < numGroups; j++)
        {
            const uint count = (j == taskID) ? histogram[loc_id] : load_uint_L1C_L3C(&global_histogram[RADIX_BINS * j + loc_id], 0);
            store_uint_L1WB_L3WB(&global_histogram[RADIX_BINS * j + loc_id], 0, global_count);
            global_count += count;
        }

        // Stage per-subgroup totals for the cross-bin scan.
        hist_tmp[get_sub_group_id()] = (get_sub_group_id() < MAX_HW_SIMD_WIDTH) ? sub_group_reduce_add(global_count) : 0;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    if (isLastGroup)
    {
        // Cross-bin exclusive scan: totals of preceding subgroups plus the
        // in-subgroup exclusive scan.
        uint p = hist_tmp[lane];
        p = (lane < get_sub_group_id()) ? p : 0;

        global_count = sub_group_reduce_add(p) + sub_group_scan_exclusive_add(global_count);

        // Store cross-bin offsets in the extra totals row.
        store_uint_L1WB_L3WB(&global_histogram[RADIX_BINS * numGroups + loc_id], 0, global_count);
    }
}
583 
584 #if 0
// Disabled (#if 0): older subgroup-based binning kernel kept for reference.
// Contains two variants: an SLM histogram updated with atomics (active inner
// branch) and a per-lane private histogram reduced via subgroup adds
// (inactive inner branch).
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(16, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16))) void kernel
sort_morton_codes_bin_items(
    global struct Globals* globals,
    uint shift,
    global uint* global_histogram,
    global char* input0,
    global char* input1,
    unsigned int input0_offset,
    unsigned int input1_offset,
    uint iteration)
{
    const uint numItems = globals->numPrimitives;
    const uint local_size = get_local_size(0);
    const uint taskID = get_group_id(0);
    const uint numTasks = get_num_groups(0);
    const uint localID = get_local_id(0);
    const uint globalID = get_local_id(0) + get_group_id(0) * get_local_size(0);
    const uint subgroupLocalID = get_sub_group_local_id();
    const uint subgroup_size = get_sub_group_size();

    const uint startID = (taskID + 0) * numItems / numTasks;
    const uint endID = (taskID + 1) * numItems / numTasks;

    // Ping-pong input selected by iteration parity.
    global ulong* input = (global ulong*)((iteration % 2) == 0 ? input0 + input0_offset : input1 + input1_offset);

#if 1
    // Variant A: shared SLM histogram with atomic increments.
    local uint histogram[RADIX_BINS];
    for (uint i = subgroupLocalID; i < RADIX_BINS; i += subgroup_size)
        histogram[i] = 0;

    for (uint ID = startID + subgroupLocalID; ID < endID; ID += subgroup_size)
    {
        const uint bin = ((uint)(input[ID] >> (ulong)shift)) & (RADIX_BINS - 1);
        atomic_add(&histogram[bin], 1);
    }

    for (uint i = subgroupLocalID; i < RADIX_BINS; i += subgroup_size)
        global_histogram[RADIX_BINS * taskID + i] = histogram[i];

#else
    // Variant B: private per-lane histograms, reduced across the subgroup.
    uint histogram[RADIX_BINS];
    for (int i = 0; i < RADIX_BINS; i++)
        histogram[i] = 0;

    for (uint ID = startID + subgroupLocalID; ID < endID; ID += subgroup_size)
    {
        const uint bin = ((uint)(input[ID] >> (ulong)shift)) & (RADIX_BINS - 1);
        histogram[bin]++;
    }

    for (uint i = 0; i < RADIX_BINS; i++)
    {
        const uint reduced_counter = sub_group_reduce_add(histogram[i]);
        global_histogram[RADIX_BINS * taskID + i] = reduced_counter;
    }
#endif
}
644 
645 #endif
646 
647 #define WG_SIZE_WIDE 256
648 #define SG_SIZE_SCAN 16
649 
650 // Fast implementation of work_group_scan_exclusive using SLM for WG size 256 and SG size 16
// Fast implementation of work_group_scan_exclusive using SLM for WG size 256 and SG size 16
// Two-level scan: each subgroup scans its SG_SIZE_SCAN values and writes its
// inclusive total to tmp; a single subgroup-wide scan over those totals then
// gives each HW thread the base offset to add to its in-subgroup exclusive
// scan.  'tmp' must hold at least NUM_HW_THREADS_IN_WG uints.
GRL_INLINE uint work_group_scan_exclusive_add_opt(local uint* tmp, uint val)
{
    const uint hw_thread_in_wg_id = get_local_id(0) / SG_SIZE_SCAN;
    const uint sg_local_id = get_local_id(0) % SG_SIZE_SCAN;
    const uint NUM_HW_THREADS_IN_WG = WG_SIZE_WIDE / SG_SIZE_SCAN;  // documents tmp's required size

    uint acc = sub_group_scan_exclusive_add(val);
    uint acc2 = acc + val;  // inclusive scan value for this lane

    // The last lane's inclusive value equals this subgroup's total.
    tmp[hw_thread_in_wg_id] = sub_group_broadcast(acc2, SG_SIZE_SCAN - 1);
    barrier(CLK_LOCAL_MEM_FENCE);
    // Every subgroup re-scans all subgroup totals; each thread picks the
    // exclusive-scan entry for its own subgroup.
    uint loaded_val = tmp[sg_local_id];
    uint wgs_acc = sub_group_scan_exclusive_add(loaded_val);
    uint acc_for_this_hw_thread = sub_group_broadcast(wgs_acc, hw_thread_in_wg_id);
    return acc + acc_for_this_hw_thread;
}
667 
// Wide reduce algorithm is divided into 2 kernels:
// 1. First, partial exclusive add scans are made within each work group using SLM.
//    Then, the last work group for each histogram bin performs an exclusive add scan along the bins using the separate histogram_partials buffer.
//    The last work group is determined using global atomics on the wg_flags buffer.
// 2. The second kernel globally adds the values from histogram_partials to the histogram buffer where the partial sums are.
//    Then, the last work group performs one more work-group scan and add so the histogram buffer values are adjusted with the global ones.
// Kernel 1 of the wide reduce (see comment above): per-work-group exclusive
// scans of the per-task histograms, with one work-group column per radix bin
// and one WG_SIZE_WIDE-sized slice of tasks per scanGroupID.  Block sums are
// staged in global_histogram_partials; the last group per bin (elected via
// atomics on wg_flags) then exclusive-scans those block sums.
GRL_INLINE void sort_morton_codes_reduce_bins_wide_partial_sum_func(
    global struct Globals* globals,
    global uint* global_histogram,
    global uint* global_histogram_partials,
    global uint* wg_flags,
    local uint* exclusive_scan_tmp,
    uint numTasks,
    uint numGroups,
    uint iteration,
    bool shift_primID)
{
    if (shift_primID)
    {
        // Skip passes made unnecessary by excluding the primID bits.
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;

        iteration -= req_iterations;
    }

    // groupID selects the radix bin this group works on; scanGroupID selects
    // which slice of tasks it scans.
    const uint groupID = get_group_id(0) % RADIX_BINS;
    const uint scanGroupID = get_group_id(0) / RADIX_BINS;
    uint localID = get_local_id(0);
    uint globalID = localID + (scanGroupID * WG_SIZE_WIDE);
    // Number of scan slices; presumably numGroups is a multiple of
    // WG_SIZE_WIDE here -- TODO confirm against the dispatch code.
    const uint lastGroup = (numGroups / WG_SIZE_WIDE);
    const uint endID = min(numTasks, (uint)(scanGroupID * WG_SIZE_WIDE + WG_SIZE_WIDE)) - 1;

    uint temp = 0;
    uint last_count = 0;
    if (globalID < numTasks)
    {
        temp = global_histogram[RADIX_BINS * globalID + groupID];

        // Store the last value of the work group, it is either last element of histogram or last item in work group
        if (globalID == endID)
            last_count = temp;
    }

    uint val = work_group_scan_exclusive_add_opt(exclusive_scan_tmp, temp);

    // '<= numTasks' so the totals row at index numTasks is written as well.
    if (globalID <= numTasks)
    {
        global_histogram[RADIX_BINS * globalID + groupID] = val;

        // Store the block sum value to separate buffer
        if (globalID == endID)
            global_histogram_partials[scanGroupID * WG_SIZE_WIDE + groupID] = val + last_count;
    }

    // Make sure that global_histogram_partials is updated in all work groups
    write_mem_fence(CLK_GLOBAL_MEM_FENCE);
    // NOTE(review): barrier(0) passes no memory-fence flags -- presumably only
    // an execution barrier is wanted since the fence above orders the global
    // stores; confirm against the OpenCL barrier() specification.
    barrier(0);

    // Now, wait for the last group for each histogram bin, so we know that
    // all work groups already updated the global_histogram_partials buffer
    uint last_group = 0;
    if (localID == 0)
        last_group = atomic_inc_global(&wg_flags[groupID]);

    last_group = work_group_broadcast(last_group, 0);
    bool isLastGroup = (last_group == lastGroup - 1);

    // Each of the last groups computes the scan exclusive add for each partial sum we have
    if (isLastGroup)
    {
        uint temp1 = 0;
        if (localID < lastGroup)
            temp1 = global_histogram_partials[localID * WG_SIZE_WIDE + groupID];

        uint val2 = work_group_scan_exclusive_add_opt(exclusive_scan_tmp, temp1);

        if (localID < lastGroup)
            global_histogram_partials[localID * WG_SIZE_WIDE + groupID] = val2;
    }
}
749 
GRL_INLINE void sort_morton_codes_reduce_bins_wide_add_reduce_func(
    global struct Globals* globals,
    global uint* global_histogram,
    global uint* global_histogram_partials,
    local uint* partials,
    uint numTasks,
    uint numGroups,
    uint iteration,
    bool shift_primID)
{
    // Second phase of the wide (multi-work-group) histogram reduction:
    //  1) add each scan group's partial block sum back into the per-task
    //     histogram rows, and
    //  2) have the single last-finishing work group convert the final
    //     per-bin totals into an exclusive prefix sum across all bins.
    if (shift_primID)
    {
        // When primID shifting is active, the first 'sort_iterations' radix
        // passes are skipped; remap 'iteration' onto the effective pass index.
        const uint req_iterations = globals->sort_iterations;
        if (iteration < req_iterations)
            return;

        iteration -= req_iterations;
    }

    // Work groups are laid out as (scan group) x (histogram bin):
    // groupID selects the radix bin, scanGroupID selects the slice of tasks.
    const uint groupID = get_group_id(0) % RADIX_BINS;
    const uint scanGroupID = get_group_id(0) / RADIX_BINS;
    uint localID = get_local_id(0);
    uint globalID = localID + (scanGroupID * WG_SIZE_WIDE);

    // Add this scan group's block-sum base to the per-task histogram entries.
    // Scan group 0 is skipped: the base value of an exclusive scan is 0.
    if (scanGroupID > 0 && globalID <= numTasks)
    {
        // NOTE(review): this read uses a RADIX_BINS stride, while the producer
        // (..._partial_sum_func) writes global_histogram_partials with a
        // WG_SIZE_WIDE stride; the two only agree if WG_SIZE_WIDE == RADIX_BINS.
        // Confirm against the macro definitions before changing either side.
        uint add_val = global_histogram_partials[scanGroupID * RADIX_BINS + groupID];
        atomic_add_global(&global_histogram[globalID * RADIX_BINS + groupID], add_val);
    }

    // Elect the last work group to finish: thread 0 of each group increments a
    // global counter, and the group that observes (numGroups - 1) is last.
    uint last_group = 0;
    if (localID == 0)
        last_group = atomic_inc_global(&globals->morton_sort_in_flight);

    last_group = work_group_broadcast(last_group, 0);
    bool isLastGroup = (last_group == numGroups - 1);

    // The last group turns row 'numTasks' of the histogram (the per-bin
    // totals) into an exclusive prefix sum over the RADIX_BINS bins.
    if (isLastGroup)
    {
        // Discard stale cached lines so the totals written by other groups
        // are observed.
        mem_fence_gpu_invalidate();

        uint global_count = global_histogram[numTasks * RADIX_BINS + localID];

        // Stage one partial total per sub-group in local memory.
        partials[get_sub_group_id()] = sub_group_reduce_add(global_count);

        barrier(CLK_LOCAL_MEM_FENCE);

        // Sum the totals of all *preceding* sub-groups, then add the in-sub-group
        // exclusive scan to get the global exclusive prefix for this lane.
        // NOTE(review): partials[lane] is read for every lane of a sub-group, so
        // this assumes WG_SIZE_WIDE / SG_SIZE_SCAN <= SG_SIZE_SCAN — confirm.
        uint lane = get_sub_group_local_id();
        uint p = partials[lane];
        p = (lane < get_sub_group_id()) ? p : 0;

        global_count = sub_group_reduce_add(p) + sub_group_scan_exclusive_add(global_count);

        // Write-back with L1/L3 write-through so the scatter phase sees it.
        store_uint_L1WB_L3WB(&global_histogram[numTasks * RADIX_BINS + localID], 0, global_count);
    }
}
812 
813 
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(WG_SIZE_WIDE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(SG_SIZE_SCAN)))
void kernel
sort_morton_codes_reduce_bins_wide_partial_sum(
    global struct Globals* globals,
    uint numTasks,
    uint numGroups,
    global uint* global_histogram,
    global uint* global_histogram_partials,
    global uint* wg_flags,
    uint iteration)
{
    // Scratch for the work-group-wide exclusive scan (one slot per sub-group).
    local uint exclusive_scan_tmp[WG_SIZE_WIDE / SG_SIZE_SCAN];

    // PrimID shifting engages only once the primitive count reaches the
    // threshold; with the threshold at 0xFFFFFFFF it is effectively disabled.
    const bool shift_primID =
        (globals->numPrimitives >= MORTON_LSB_SORT_NO_SHIFT_THRESHOLD);

    sort_morton_codes_reduce_bins_wide_partial_sum_func(
        globals, global_histogram, global_histogram_partials, wg_flags,
        exclusive_scan_tmp, numTasks, numGroups, iteration, shift_primID);
}
835 
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(WG_SIZE_WIDE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(SG_SIZE_SCAN)))
void kernel
sort_morton_codes_reduce_bins_wide_add_reduce(
    global struct Globals* globals,
    uint numTasks,
    uint numGroups,
    global uint* global_histogram,
    global uint* global_histogram_partials,
    uint iteration)
{
    // Local staging buffer for the per-sub-group partial sums.
    local uint partials[RADIX_BINS];

    // PrimID shifting engages only once the primitive count reaches the
    // threshold; with the threshold at 0xFFFFFFFF it is effectively disabled.
    const bool shift_primID =
        (globals->numPrimitives >= MORTON_LSB_SORT_NO_SHIFT_THRESHOLD);

    sort_morton_codes_reduce_bins_wide_add_reduce_func(
        globals, global_histogram, global_histogram_partials, partials,
        numTasks, numGroups, iteration, shift_primID);
}
856