//
// Copyright (C) 2009-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
//
//

#pragma once

#include "common.h"
#include "morton_msb_radix_bitonic_sort_shared.h"

#include "libs/lsc_intrinsics.h"

///////////////////////////////////////////////////////////////////////////////
//
// Configuration switches
//
///////////////////////////////////////////////////////////////////////////////

#define DEBUG 0
#define MERGE_BLS_WITHIN_SG 0

///////////////////////////////////////////////////////////////////////////////


#if DEBUG
#define DEBUG_CODE(A) A
#else
#define DEBUG_CODE(A)
#endif

#define BOTTOM_LEVEL_SORT_WG_SIZE 512

// this kernel is only used in metakernels for debugging, to print that execution reached a given point
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(1, 1, 1)))
void kernel debug_print_kernel(uint variable)
{
    if (get_local_id(0) == 0)
        printf("I'm here! %d\n", variable);
}

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(1, 1, 1)))
void kernel check_bls_sort(global struct Globals* globals, global ulong* input)
{
    uint prims_num = globals->numPrimitives;

    printf("in check_bls_sort kernel. Values count: %d\n", prims_num);

    ulong left = input[0];
    ulong right;
    for (int i = 0; i < prims_num - 1; i++)
    {
        right = input[i + 1];
        printf("sorted val: %llu\n", left);
        if (left > right)
        {
            printf("element %d is bigger than element %d: %llu > %llu\n", i, i + 1, left, right);
        }
        left = right;
    }
}

// workgroup-wide inclusive add scan built from subgroup scans (tmp must hold WG_SIZE/SG_SIZE uints)
inline uint wg_scan_inclusive_add_opt(local uint* tmp, uint val, uint SG_SIZE, uint WG_SIZE)
{
    const uint hw_thread_in_wg_id = get_local_id(0) / SG_SIZE;
    const uint sg_local_id = get_local_id(0) % SG_SIZE;
    const uint NUM_HW_THREADS_IN_WG = WG_SIZE / SG_SIZE;

    uint acc = sub_group_scan_inclusive_add(val);
    if (NUM_HW_THREADS_IN_WG == 1)
    {
        return acc;
    }
    tmp[hw_thread_in_wg_id] = sub_group_broadcast(acc, SG_SIZE - 1);
    barrier(CLK_LOCAL_MEM_FENCE);

    uint loaded_val = sg_local_id < NUM_HW_THREADS_IN_WG ? tmp[sg_local_id] : 0;
    uint wgs_acc = sub_group_scan_exclusive_add(loaded_val);
    uint acc_for_this_hw_thread = sub_group_broadcast(wgs_acc, hw_thread_in_wg_id);
    // for more than 256 workitems in SIMD16 the per-subgroup totals won't fit into a single
    // 16-lane subgroup, so we need additional iterations; likewise for more than 64 workitems in SIMD8
    uint num_iterations = (NUM_HW_THREADS_IN_WG + SG_SIZE - 1) / SG_SIZE;
    for (int i = 1; i < num_iterations; i++)
    {
        // tmp[] has to be added back in because the "exclusive" scan drops the last element's contribution
        uint prev_max_sum = sub_group_broadcast(wgs_acc, SG_SIZE - 1) + tmp[(i * SG_SIZE) - 1];
        // scan the next chunk of per-subgroup totals
        loaded_val = (sg_local_id + i * SG_SIZE) < NUM_HW_THREADS_IN_WG ? tmp[sg_local_id + i * SG_SIZE] : 0;
        wgs_acc = sub_group_scan_exclusive_add(loaded_val);
        wgs_acc += prev_max_sum;
        uint new_acc_for_this_hw_thread = sub_group_broadcast(wgs_acc, hw_thread_in_wg_id % SG_SIZE);
        if (hw_thread_in_wg_id >= i * SG_SIZE)
            acc_for_this_hw_thread = new_acc_for_this_hw_thread;
    }
    return acc + acc_for_this_hw_thread;
}
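
// Worked example of the scan above, assuming SG_SIZE == 16 and WG_SIZE == 512:
// NUM_HW_THREADS_IN_WG == 32, so every subgroup stores its total into tmp[0..31],
// but a single 16-lane subgroup can only scan 16 of those totals at once, hence
// num_iterations == 2: the second pass scans tmp[16..31] and offsets the result
// by prev_max_sum, the running total over tmp[0..15].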

struct MSBDispatchArgs
{
    global struct MSBRadixContext* context;
    uint num_of_wgs; // the number of workgroups dispatched for this context
    ulong* wg_key_start; // where the keys to process start for the current workgroup
    ulong* wg_key_end;
    uint shift_bit;
};




struct MSBDispatchArgs get_msb_dispatch_args(global struct VContextScheduler* scheduler)
{
    global struct MSBDispatchQueue* queue = &scheduler->msb_queue;

    uint group = get_group_id(0);
    struct MSBDispatchRecord record;

    // TODO_OPT: Load this entire prefix array into SLM instead of searching,
    //    or use sub-group ops
    uint i = 0;
    while (i < queue->num_records)
    {
        uint n = queue->records[i].wgs_to_dispatch;

        if (group < n)
        {
            record = queue->records[i];
            break;
        }

        group -= n;
        i++;
    }

    uint context_id = i;
    global struct MSBRadixContext* context = &scheduler->contexts[context_id];

    // move to ulongs to avoid uint overflow
    ulong group_id_in_dispatch = group;
    ulong start_offset = context->start_offset;
    ulong num_keys = context->num_keys;
    ulong wgs_to_dispatch = record.wgs_to_dispatch;

    struct MSBDispatchArgs args;
    args.context = context;
    args.num_of_wgs = record.wgs_to_dispatch;
    args.wg_key_start = context->keys_in + start_offset + (group_id_in_dispatch * num_keys / wgs_to_dispatch);
    args.wg_key_end = context->keys_in + start_offset + ((group_id_in_dispatch + 1) * num_keys / wgs_to_dispatch);
    args.shift_bit = MSB_SHIFT_BYTE_START_OFFSET - context->iteration * MSB_BITS_PER_ITERATION;
    return args;
}
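
// Example of the record search above: if the queue holds two records dispatching
// [3, 2] workgroups, global groups 0..2 resolve to record 0 and groups 3..4 to
// record 1 (with 'group' rebased within the record). The context's keys are then
// split evenly, workgroup g covering [g*num_keys/wgs, (g+1)*num_keys/wgs).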




void BLSDispatchQueue_push(global struct BLSDispatchQueue* queue, struct BLSDispatchRecord* record)
{
    uint new_idx = atomic_inc_global(&queue->num_records);
    queue->records[new_idx] = *record;
    DEBUG_CODE(printf("adding bls of size: %d\n", record->count));
}




void DO_CountSort(struct BLSDispatchRecord dispatchRecord, local ulong* SLM_shared, global ulong* output)
{
    uint tid = get_local_id(0);

    global ulong* in = ((global ulong*)(dispatchRecord.keys_in)) + dispatchRecord.start_offset;

    ulong a = tid < dispatchRecord.count ? in[tid] : ULONG_MAX;

    SLM_shared[tid] = a;

    uint counter = 0;

    barrier(CLK_LOCAL_MEM_FENCE);

    ulong curr = SLM_shared[get_sub_group_local_id()];

    for (uint i = 16; i < dispatchRecord.count; i += 16)
    {
        ulong next = SLM_shared[i + get_sub_group_local_id()];

        for (uint j = 0; j < 16; j++)
        {
            // some older drivers have a bug when shuffling ulong, so we shuffle it as 2x uint
            uint2 curr_as_uint2 = as_uint2(curr);
            uint2 sg_curr_as_uint2 = (uint2)(sub_group_broadcast(curr_as_uint2.x, j), sub_group_broadcast(curr_as_uint2.y, j));
            ulong c = as_ulong(sg_curr_as_uint2);
            if (c < a)
                counter++;
        }

        curr = next;
    }


    // last iteration
    for (uint j = 0; j < 16; j++)
    {
        // some older drivers have a bug when shuffling ulong, so we shuffle it as 2x uint
        uint2 curr_as_uint2 = as_uint2(curr);
        uint2 sg_curr_as_uint2 = (uint2)(sub_group_broadcast(curr_as_uint2.x, j), sub_group_broadcast(curr_as_uint2.y, j));
        ulong c = as_ulong(sg_curr_as_uint2);
        if (c < a)
            counter++;
    }

    // write each element to its sorted position
    if (tid < dispatchRecord.count)
        output[dispatchRecord.start_offset + counter] = a;
}
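
// Note on DO_CountSort: each workitem ranks its key by counting, over the whole
// record, how many keys compare strictly smaller, then stores the key at
// output[rank]. This only sorts correctly when keys are unique (duplicates
// would get the same rank and collide on one slot); the morton codes sorted
// here are presumably de-duplicated upstream, e.g. by encoding the primitive
// index in their low bits.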

void DO_Bitonic(struct BLSDispatchRecord dispatchRecord, local ulong* SLM_shared, global ulong* output)
{
    uint lid = get_local_id(0);
    uint elements_to_sort = BOTTOM_LEVEL_SORT_THRESHOLD;
    while ((elements_to_sort >> 1) >= dispatchRecord.count && (elements_to_sort >> 1) >= BOTTOM_LEVEL_SORT_WG_SIZE)
    {
        elements_to_sort >>= 1;
    }

    for (int i = 0; i < elements_to_sort / BOTTOM_LEVEL_SORT_WG_SIZE; i++)
    {
        uint tid = lid + i * BOTTOM_LEVEL_SORT_WG_SIZE;

        if (tid >= dispatchRecord.count)
            SLM_shared[tid] = ULONG_MAX;
        else
            SLM_shared[tid] = ((global ulong*)(dispatchRecord.keys_in))[dispatchRecord.start_offset + tid];
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    uint k_iterations = elements_to_sort;
    while ((k_iterations >> 1) >= dispatchRecord.count && k_iterations != 0)
    {
        k_iterations >>= 1;
    }

    for (unsigned int k = 2; k <= k_iterations; k *= 2)
    {
        for (unsigned int j = k / 2; j > 0; j /= 2)
        {
            // this loop is needed when the workgroup cannot be made big enough, so each workitem handles multiple elements
            for (uint i = 0; i < elements_to_sort / BOTTOM_LEVEL_SORT_WG_SIZE; i++)
            {
                uint tid = lid + i * BOTTOM_LEVEL_SORT_WG_SIZE;
                unsigned int ixj = tid ^ j;
                if (ixj > tid)
                {
                    if ((tid & k) == 0)
                    {
                        if (SLM_shared[tid] > SLM_shared[ixj])
                        {
                            ulong tmp = SLM_shared[tid];
                            SLM_shared[tid] = SLM_shared[ixj];
                            SLM_shared[ixj] = tmp;
                        }
                    }
                    else
                    {
                        if (SLM_shared[tid] < SLM_shared[ixj])
                        {
                            ulong tmp = SLM_shared[tid];
                            SLM_shared[tid] = SLM_shared[ixj];
                            SLM_shared[ixj] = tmp;
                        }
                    }
                }
            }

            barrier(CLK_LOCAL_MEM_FENCE);
        }
    }

    for (int i = 0; i < elements_to_sort / BOTTOM_LEVEL_SORT_WG_SIZE; i++)
    {
        uint tid = lid + i * BOTTOM_LEVEL_SORT_WG_SIZE;

        if (tid < dispatchRecord.count)
            output[dispatchRecord.start_offset + tid] = SLM_shared[tid];
    }
}
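
// Sketch of the network above for count == 8 (padded to a power of two with
// ULONG_MAX): stage k==2 sorts pairs with alternating direction, k==4 turns
// them into sorted runs of four (again alternating), and the final k==8 stage
// sorts the whole range ascending; the ULONG_MAX padding keys sink to the tail
// and are never written back out.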




void DO_Create_Separate_BLS_Work(global struct VContextScheduler* scheduler, global struct MSBRadixContext* context, global ulong* input)
{
    uint lid = get_local_id(0);

    uint start = context->start[lid];
    uint count = context->count[lid];
    uint start_offset = context->start_offset + start;

    struct BLSDispatchRecord record;
    record.start_offset = start_offset;
    record.count = count;
    record.keys_in = context->keys_out;

    if (count == 0) // no elements, so nothing to do
    {
    }
    else if (count == 1) // single element, so just write it out
    {
        input[start_offset] = ((global ulong*)record.keys_in)[start_offset];
    }
    else if (count <= BOTTOM_LEVEL_SORT_THRESHOLD)
    {
        BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
    }
}
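
// Bins with count > BOTTOM_LEVEL_SORT_THRESHOLD are deliberately skipped here:
// for those, DO_Create_Work pushes an MSBStackEntry instead, so the range gets
// another MSB radix pass rather than a bottom-level sort.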




// We try to merge small BLSes into a larger one within the sub_group
void DO_Create_SG_Merged_BLS_Work_Parallel(global struct VContextScheduler* scheduler, global struct MSBRadixContext* context, global ulong* input)
{
    uint lid = get_local_id(0);
    uint sid = get_sub_group_local_id();

    uint create_msb_work = context->count[lid] > BOTTOM_LEVEL_SORT_THRESHOLD ? 1 : 0;

    uint start = context->start[lid];
    uint count = context->count[lid];
    uint ctx_start_offset = context->start_offset;

    if (sid == 0 || create_msb_work) // these SIMD lanes are the beginning of a merged BLS
    {
        struct BLSDispatchRecord record;
        if (create_msb_work)
        {
            record.start_offset = ctx_start_offset + start + count;
            record.count = 0;
        }
        else // SIMD lane 0 case
        {
            record.start_offset = ctx_start_offset + start;
            record.count = count;
        }

        record.keys_in = context->keys_out;

        uint loop_idx = 1;
        while (sid + loop_idx < 16) // loop over the subgroup
        {
            uint _create_msb_work = intel_sub_group_shuffle_down(create_msb_work, 0u, loop_idx);
            uint _count = intel_sub_group_shuffle_down(count, 0u, loop_idx);
            uint _start = intel_sub_group_shuffle_down(start, 0u, loop_idx);

            if (_create_msb_work) // found the next MSB work, so the range of merges ends here
                break;

            // nothing more will fit, so push the current record
            if (record.count + _count > BOTTOM_LEVEL_SORT_MERGING_THRESHOLD)
            {
                if (record.count == 1)
                {
                    input[record.start_offset] = record.keys_in[record.start_offset];
                }
                else if (record.count > 1)
                {
                    BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
                }
                record.start_offset = ctx_start_offset + _start;
                record.count = _count;
            }
            else
            {
                record.count += _count;
            }
            loop_idx++;
        }
        // if we have any elements left, then schedule them
        if (record.count == 1) // only one element, so just write it out
        {
            input[record.start_offset] = record.keys_in[record.start_offset];
        }
        else if (record.count > 1)
        {
            BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
        }
    }
}
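
// Example of the merging above, assuming SIMD16 and per-lane counts
// {5, 3, <big>, 7, ...} where <big> exceeds BOTTOM_LEVEL_SORT_THRESHOLD:
// lane 0 opens a record, absorbs lane 1 (record.count == 8) and stops at
// lane 2, which produces MSB work instead; lane 2 itself opens an empty record
// just past its own range and continues merging from lane 3 onward.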




// We try to merge small BLSes into a larger one within the sub_group
void DO_Create_SG_Merged_BLS_Work(global struct VContextScheduler* scheduler, global struct MSBRadixContext* context, global ulong* input)
{
    uint lid = get_local_id(0);
    uint sid = get_sub_group_local_id();

    uint create_msb_work = context->count[lid] > BOTTOM_LEVEL_SORT_THRESHOLD ? 1 : 0;

    uint start = context->start[lid];
    uint count = context->count[lid];
    uint ctx_start_offset = context->start_offset;

    if (sid == 0)
    {
        struct BLSDispatchRecord record;
        record.start_offset = ctx_start_offset + start;
        record.count = 0;
        record.keys_in = context->keys_out;

        for (int i = 0; i < 16; i++)
        {
            uint _create_msb_work = sub_group_broadcast(create_msb_work, i);
            uint _count = sub_group_broadcast(count, i);
            uint _start = sub_group_broadcast(start, i);
            if (_create_msb_work)
            {
                if (record.count == 1) // only one element, so just write it out
                {
                    input[record.start_offset] = record.keys_in[record.start_offset];
                }
                else if (record.count > 1)
                {
                    BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
                }
                record.start_offset = ctx_start_offset + _start + _count;
                record.count = 0;
                continue;
            }
            // nothing more will fit, so push the current record
            if (record.count + _count > BOTTOM_LEVEL_SORT_MERGING_THRESHOLD)
            {
                BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
                record.start_offset = ctx_start_offset + _start;
                record.count = _count;
            }
            else
            {
                record.count += _count;
            }
        }
        // if we have any elements left, then schedule them
        if (record.count == 1) // only one element, so just write it out
        {
            input[record.start_offset] = record.keys_in[record.start_offset];
        }
        else if (record.count > 1)
        {
            BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
        }
    }
}




void DO_Create_Work(global struct VContextScheduler* scheduler, global struct MSBRadixContext* context, global ulong* input, local uint* slm_for_wg_scan, uint sg_size, uint wg_size)
{
    uint lid = get_local_id(0);

    uint iteration = context->iteration + 1;
    uint start = context->start[lid];
    uint count = context->count[lid];
    uint start_offset = context->start_offset + start;

    uint create_msb_work = count > BOTTOM_LEVEL_SORT_THRESHOLD ? 1 : 0;

#if MERGE_BLS_WITHIN_SG
    DO_Create_SG_Merged_BLS_Work_Parallel(scheduler, context, input);
#else
    DO_Create_Separate_BLS_Work(scheduler, context, input);
#endif

    uint new_entry_id = wg_scan_inclusive_add_opt(slm_for_wg_scan, create_msb_work, sg_size, wg_size); //work_group_scan_inclusive_add(create_msb_work);
    uint stack_begin_entry;
    // the last workitem in the wg holds the total number of new entries
    if (lid == (MSB_RADIX_NUM_BINS - 1))
    {
        stack_begin_entry = atomic_add_global(&scheduler->msb_stack.num_entries, new_entry_id);
    }
    stack_begin_entry = work_group_broadcast(stack_begin_entry, (MSB_RADIX_NUM_BINS - 1));
    new_entry_id += stack_begin_entry - 1;


    if (create_msb_work)
    {
        scheduler->msb_stack.entries[new_entry_id].start_offset = start_offset;
        scheduler->msb_stack.entries[new_entry_id].count = count;
        scheduler->msb_stack.entries[new_entry_id].iteration = iteration;
    }

    if (lid == 0) {
        DEBUG_CODE(printf("num of new bls: %d\n", scheduler->next_bls_queue->num_records));
    }
}
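
// Example of the stack allocation above: with create_msb_work flags
// {0, 1, 1, 0, ...} the inclusive scan assigns the flagged lanes the 1-based
// ranks 1 and 2; the last lane reserves the total number of slots with a
// single global atomic, and adding "stack_begin_entry - 1" turns each rank
// into an absolute, 0-based index into msb_stack.entries.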


struct BatchedBLSDispatchEntry
{
    /////////////////////////////////////////////////////////////
    //  State data used for communication with command streamer
    //  NOTE: This part must match definition in 'msb_radix_bitonic_sort.grl'
    /////////////////////////////////////////////////////////////
    qword p_data_buffer;
    qword num_elements; // number of elements in p_data_buffer
};


GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(BOTTOM_LEVEL_SORT_WG_SIZE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
void kernel sort_morton_codes_batched_BLS_dispatch(global struct BatchedBLSDispatchEntry* bls_dispatches)
{
    uint dispatch_id = get_group_id(0);
    uint lid = get_local_id(0);

    local ulong SLM_shared[BOTTOM_LEVEL_SORT_THRESHOLD];

    struct BatchedBLSDispatchEntry dispatchArgs = bls_dispatches[dispatch_id];
    struct BLSDispatchRecord dispatchRecord;
    dispatchRecord.start_offset = 0;
    dispatchRecord.count = dispatchArgs.num_elements;
    dispatchRecord.keys_in = (ulong*)dispatchArgs.p_data_buffer;

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_batched_BLS_dispatch for %d elements\n", dispatchRecord.count));

    if (dispatchRecord.count > 1)
        DO_Bitonic(dispatchRecord, SLM_shared, (global ulong*)dispatchRecord.keys_in);
}




GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(BOTTOM_LEVEL_SORT_WG_SIZE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
void kernel sort_morton_codes_bottom_level_single_wg(global struct Globals* globals, global ulong* input, global ulong* output)
{
    uint lid = get_local_id(0);

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_bottom_level_single_wg for %d elements\n", globals->numPrimitives));

    local ulong SLM_shared[BOTTOM_LEVEL_SORT_THRESHOLD];

    struct BLSDispatchRecord dispatchRecord;
    dispatchRecord.start_offset = 0;
    dispatchRecord.count = globals->numPrimitives;
    dispatchRecord.keys_in = (ulong*)input;

    //TODO: count or bitonic here?
    //DO_Bitonic(dispatchRecord, SLM_shared, output);
    DO_CountSort(dispatchRecord, SLM_shared, output);
}




// This kernel initializes the first context to start up the whole execution
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(MSB_RADIX_NUM_BINS, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
void kernel sort_morton_codes_msb_begin(
    global struct Globals* globals,
    global struct VContextScheduler* scheduler,
    global ulong* buf0,
    global ulong* buf1)
{
    uint lid = get_local_id(0);
    uint gid = get_group_id(0);

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_msb_begin\n"));

    scheduler->contexts[gid].count[lid] = 0;

    if (gid == 0 && lid == 0)
    {
        global struct MSBRadixContext* context = &scheduler->contexts[lid];
        const uint num_prims = globals->numPrimitives;

        scheduler->bls_queue0.num_records = 0;
        scheduler->bls_queue1.num_records = 0;

        scheduler->curr_bls_queue = &scheduler->bls_queue1;
        scheduler->next_bls_queue = &scheduler->bls_queue0;

        context->start_offset = 0;
        context->num_wgs_in_flight = 0;
        context->num_keys = num_prims;
        context->iteration = 0;
        context->keys_in = buf0;
        context->keys_out = buf1;

        uint msb_wgs_to_dispatch = (num_prims + MSB_WG_SORT_ELEMENTS_THRESHOLD - 1) / MSB_WG_SORT_ELEMENTS_THRESHOLD;
        scheduler->msb_queue.records[0].wgs_to_dispatch = msb_wgs_to_dispatch;

        scheduler->num_wgs_msb = msb_wgs_to_dispatch;
        scheduler->num_wgs_bls = 0;
        scheduler->msb_stack.num_entries = 0;
        scheduler->msb_queue.num_records = 1;
    }
}




__attribute__((reqd_work_group_size(MSB_RADIX_NUM_VCONTEXTS, 1, 1)))
kernel void
scheduler(global struct VContextScheduler* scheduler, global ulong* buf0, global ulong* buf1)
{
    uint lid = get_local_id(0);

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_scheduler\n"));

    uint context_idx = lid;

    const uint num_of_stack_entries = scheduler->msb_stack.num_entries;

    uint msb_wgs_to_dispatch = 0;
    if (lid < num_of_stack_entries)
    {
        struct MSBStackEntry entry = scheduler->msb_stack.entries[(num_of_stack_entries - 1) - lid];
        global struct MSBRadixContext* context = &scheduler->contexts[lid];
        context->start_offset = entry.start_offset;
        context->num_wgs_in_flight = 0;
        context->num_keys = entry.count;
        context->iteration = entry.iteration;
        context->keys_in = entry.iteration % 2 == 0 ? buf0 : buf1;
        context->keys_out = entry.iteration % 2 == 0 ? buf1 : buf0;

        msb_wgs_to_dispatch = (entry.count + MSB_WG_SORT_ELEMENTS_THRESHOLD - 1) / MSB_WG_SORT_ELEMENTS_THRESHOLD;
        scheduler->msb_queue.records[lid].wgs_to_dispatch = msb_wgs_to_dispatch;
    }

    msb_wgs_to_dispatch = work_group_reduce_add(msb_wgs_to_dispatch); // TODO: if the compiler implementation is slow, consider writing this manually

    if (lid == 0)
    {
        // swap queues for the next iteration
        struct BLSDispatchQueue* tmp = scheduler->curr_bls_queue;
        scheduler->curr_bls_queue = scheduler->next_bls_queue;
        scheduler->next_bls_queue = tmp;

        scheduler->next_bls_queue->num_records = 0;

        scheduler->num_wgs_bls = scheduler->curr_bls_queue->num_records;
        scheduler->num_wgs_msb = msb_wgs_to_dispatch;

        if (num_of_stack_entries < MSB_RADIX_NUM_VCONTEXTS)
        {
            scheduler->msb_queue.num_records = num_of_stack_entries;
            scheduler->msb_stack.num_entries = 0;
        }
        else
        {
            scheduler->msb_queue.num_records = MSB_RADIX_NUM_VCONTEXTS;
            scheduler->msb_stack.num_entries -= MSB_RADIX_NUM_VCONTEXTS;
        }
    }

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_scheduler finished, to spawn %d MSB wgs in %d contexts and %d BLS wgs, MSB records on stack %d\n",
        scheduler->num_wgs_msb, scheduler->msb_queue.num_records, scheduler->num_wgs_bls, scheduler->msb_stack.num_entries));
}
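
// The two BLS queues ping-pong across iterations: records pushed into
// next_bls_queue while creating work become curr_bls_queue here, to be drained
// by sort_morton_codes_bottom_level, while the emptied queue collects the
// records produced by the next round.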




// this is the lowest-level sub-task, which returns the sorted codes
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(BOTTOM_LEVEL_SORT_WG_SIZE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
void kernel sort_morton_codes_bottom_level(global struct VContextScheduler* scheduler, global ulong* output)
{
    uint lid = get_local_id(0);

    DEBUG_CODE(if (get_group_id(0) == 0 && lid == 0) printf("running sort_morton_codes_bottom_level\n"));

    local struct BLSDispatchRecord l_dispatchRecord;
    if (lid == 0)
    {
        uint record_idx = get_group_id(0);
        l_dispatchRecord = scheduler->curr_bls_queue->records[record_idx];
        //l_dispatchRecord = BLSDispatchQueue_pop((global struct BLSDispatchQueue*)scheduler->curr_bls_queue);
        atomic_dec_global(&scheduler->num_wgs_bls);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    struct BLSDispatchRecord dispatchRecord = l_dispatchRecord;

    local ulong SLM_shared[BOTTOM_LEVEL_SORT_THRESHOLD];

    // right now we use only the count sort here
    // TODO: maybe implement something else
    if (1)
    {
        //DO_Bitonic(dispatchRecord, SLM_shared, output);
        DO_CountSort(dispatchRecord, SLM_shared, output);
    }
}




#define MSB_COUNT_WG_SIZE MSB_RADIX_NUM_BINS
#define MSB_COUNT_SG_SIZE 16

// count how many elements per bucket we have
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(MSB_COUNT_WG_SIZE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(MSB_COUNT_SG_SIZE)))
void kernel sort_morton_codes_msb_count_items(global struct VContextScheduler* scheduler)
{
    uint lid = get_local_id(0);
    uint lsz = MSB_RADIX_NUM_BINS;

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_msb_count_items\n"));

    local uint bucket_count[MSB_RADIX_NUM_BINS];
    local uint finish_count;
    bucket_count[lid] = 0;
    if (lid == 0)
    {
        finish_count = 0;
    }

    struct MSBDispatchArgs dispatchArgs = get_msb_dispatch_args(scheduler);

    global struct MSBRadixContext* context = dispatchArgs.context;

    global ulong* key_start = (global ulong*)dispatchArgs.wg_key_start + lid;
    global ulong* key_end = (global ulong*)dispatchArgs.wg_key_end;
    uint shift_bit = dispatchArgs.shift_bit;
    uchar shift_byte = shift_bit / 8; // which byte of the key holds the current radix digit
    barrier(CLK_LOCAL_MEM_FENCE);

    global uchar* ks = (global uchar*)key_start;
    ks += shift_byte;
    global uchar* ke = (global uchar*)key_end;
    ke += shift_byte;
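
    // Reading only the byte at offset shift_byte of each 8-byte key is
    // equivalent to (key >> shift_bit) & 0xFF on a little-endian device
    // (shift_bit is a multiple of 8), so the counting pass touches one byte
    // per key instead of the full ulong.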

    // double buffering on value loading
    if (ks < ke)
    {
        uchar bucket_id = *ks;
        ks += lsz * sizeof(ulong);

        for (global uchar* k = ks; k < ke; k += lsz * sizeof(ulong))
        {
            uchar next_bucket_id = *k;
            atomic_inc_local(&bucket_count[bucket_id]);
            bucket_id = next_bucket_id;
        }

        atomic_inc_local(&bucket_count[bucket_id]);

    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // update the global counters for the context
    uint count = bucket_count[lid];
    if (count > 0)
        atomic_add_global(&context->count[lid], bucket_count[lid]);

    mem_fence_gpu_invalidate();
    work_group_barrier(0);

    bool final_wg = true;
    // count the WGs which have reached the end
    if (dispatchArgs.num_of_wgs > 1)
    {
        if (lid == 0)
            finish_count = atomic_inc_global(&context->num_wgs_in_flight) + 1;

        barrier(CLK_LOCAL_MEM_FENCE);

        final_wg = finish_count == dispatchArgs.num_of_wgs;
    }

    local uint partial_dispatches[MSB_COUNT_WG_SIZE / MSB_COUNT_SG_SIZE];
    // if this is the last wg for the current dispatch, update the context
    if (final_wg)
    {
        // the code below does work_group_scan_exclusive_add(context->count[lid]);
        {
            uint lane_val = context->count[lid];
            uint sg_result = sub_group_scan_inclusive_add(lane_val);

            partial_dispatches[get_sub_group_id()] = sub_group_broadcast(sg_result, MSB_COUNT_SG_SIZE - 1);
            barrier(CLK_LOCAL_MEM_FENCE);

            uint slm_result = sub_group_scan_exclusive_add(partial_dispatches[get_sub_group_local_id()]);
            slm_result = sub_group_broadcast(slm_result, get_sub_group_id());
            uint result = slm_result + sg_result - lane_val; // inclusive scan minus own value == exclusive scan
            context->start[lid] = result; //work_group_scan_exclusive_add(context->count[lid]);
        }

        context->count[lid] = 0;
        if (lid == 0)
            context->num_wgs_in_flight = 0;
    }
}




// sort elements into their buckets
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(MSB_RADIX_NUM_BINS, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
void kernel sort_morton_codes_msb_bin_items(
    global struct VContextScheduler* scheduler, global ulong* input)
{
    uint lid = get_local_id(0);
    uint lsz = get_local_size(0);

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_msb_bin_items\n"));

    local uint finish_count;
    if (lid == 0)
    {
        finish_count = 0;
    }

    struct MSBDispatchArgs dispatchArgs = get_msb_dispatch_args(scheduler);
    global struct MSBRadixContext* context = dispatchArgs.context;

    global ulong* key_start = (global ulong*)dispatchArgs.wg_key_start + lid;
    global ulong* key_end = (global ulong*)dispatchArgs.wg_key_end;
    uint shift_bit = dispatchArgs.shift_bit;

    barrier(CLK_LOCAL_MEM_FENCE);

    global ulong* sorted_keys = (global ulong*)context->keys_out + context->start_offset;

#if MSB_RADIX_NUM_BINS == MSB_WG_SORT_ELEMENTS_THRESHOLD // special case meaning that we process exactly 1 element per workitem
    // here we do local counting first, then move to the global counters

    local uint slm_counters[MSB_RADIX_NUM_BINS];
    slm_counters[lid] = 0;

    barrier(CLK_LOCAL_MEM_FENCE);

    uint place_in_slm_bucket;
    uint bucket_id;
    ulong val;

    bool active_lane = key_start < key_end;

    if (active_lane)
    {
        val = *key_start;

        bucket_id = (val >> (ulong)shift_bit) & (MSB_RADIX_NUM_BINS - 1);
        place_in_slm_bucket = atomic_inc_local(&slm_counters[bucket_id]);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // overwrite slm_counters with the global counters - counters with 0 elements can be skipped since we won't use them anyway
    if (slm_counters[lid])
        slm_counters[lid] = atomic_add_global(&context->count[lid], slm_counters[lid]);

    barrier(CLK_LOCAL_MEM_FENCE);

    uint id_in_bucket = slm_counters[bucket_id] + place_in_slm_bucket; //atomic_inc_global(&context->count[bucket_id]);

    if (active_lane)
        sorted_keys[context->start[bucket_id] + id_in_bucket] = val;
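
    // Worked example of the fast path above: if three lanes of this workgroup
    // land in bucket B, atomic_inc_local hands them ranks 0..2 within the
    // workgroup; lane lid == B then publishes the bucket's global base (the
    // previous context->count[B]) into slm_counters[B] with one
    // atomic_add_global, and each lane stores at start[B] + base + rank.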
#else
    // double buffering on value loading
    if (key_start < key_end)
    {
        ulong val = *key_start;
        key_start += lsz;

        for (global ulong* k = key_start; k < key_end; k += lsz)
        {
            ulong next_val = *k;
            uint bucket_id = (val >> (ulong)shift_bit) & (MSB_RADIX_NUM_BINS - 1);
            uint id_in_bucket = atomic_inc_global(&context->count[bucket_id]);

            //printf("dec: %llu, val: %llX bucket_id: %X", *k, *k, bucket_id);
            sorted_keys[context->start[bucket_id] + id_in_bucket] = val;

            val = next_val;
        }

        uint bucket_id = (val >> (ulong)shift_bit) & (MSB_RADIX_NUM_BINS - 1);
        uint id_in_bucket = atomic_inc_global(&context->count[bucket_id]);

        sorted_keys[context->start[bucket_id] + id_in_bucket] = val;
    }
#endif

    // make sure all groups' "counters" and "starts" are visible to the final workgroup
    mem_fence_gpu_invalidate();
    work_group_barrier(0);

    bool final_wg = true;
    // count the WGs which have reached the end
    if (dispatchArgs.num_of_wgs > 1)
    {
        if (lid == 0)
            finish_count = atomic_inc_global(&context->num_wgs_in_flight) + 1;

        barrier(CLK_LOCAL_MEM_FENCE);

        final_wg = finish_count == dispatchArgs.num_of_wgs;
    }

    local uint slm_for_wg_funcs[MSB_COUNT_WG_SIZE / MSB_COUNT_SG_SIZE];
    // if this is the last wg for the current dispatch, then prepare the sub-tasks
    if (final_wg)
    {
        DO_Create_Work(scheduler, context, input, slm_for_wg_funcs, 16, MSB_RADIX_NUM_BINS);

        // clear the context's counters for future execution
        context->count[lid] = 0;
    }

}