//
// Copyright (C) 2009-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
//
//

#pragma once

#include "common.h"
#include "morton_msb_radix_bitonic_sort_shared.h"

#include "libs/lsc_intrinsics.h"

///////////////////////////////////////////////////////////////////////////////
//
// Configuration switches
//
///////////////////////////////////////////////////////////////////////////////

#define DEBUG 0
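// When enabled, DO_Create_Work merges small bottom-level-sort (BLS) ranges within a
// sub-group before enqueuing them (see DO_Create_SG_Merged_BLS_Work_Parallel); when
// disabled, every non-empty bucket becomes its own BLS dispatch record.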
#define MERGE_BLS_WITHIN_SG 0

///////////////////////////////////////////////////////////////////////////////


#if DEBUG
#define DEBUG_CODE(A) A
#else
#define DEBUG_CODE(A)
#endif

#define BOTTOM_LEVEL_SORT_WG_SIZE 512

// this kernel is only placed in the metakernel for debugging, to print that execution reached this point
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(1, 1, 1)))
void kernel debug_print_kernel(uint variable)
{
    if (get_local_id(0) == 0)
        printf("I'm here! %d\n", variable);
}

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(1, 1, 1)))
void kernel check_bls_sort(global struct Globals* globals, global ulong* input)
{
    uint prims_num = globals->numPrimitives;

    printf("in check_bls_sort kernel. Values count: %d\n", prims_num);

    ulong left = input[0];
    ulong right;
    for (int i = 0; i + 1 < prims_num; i++)
    {
        right = input[i + 1];
        printf("sorted val: %llu\n", left);
        if (left > right)
        {
            printf("element %d is bigger than %d: %llu > %llu\n", i, i + 1, left, right);
        }
        left = right;
    }
}

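// Work-group inclusive add-scan built from sub-group scans: each sub-group scans its own
// values, writes its total to SLM, and a scan over those totals yields the offset each
// sub-group adds to its local result. For example, with WG_SIZE = 512 and SG_SIZE = 16
// there are 32 sub-group totals; a single sub-group covers only 16 of them per pass, so
// num_iterations below would evaluate to 2.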
inline uint wg_scan_inclusive_add_opt(local uint* tmp, uint val, uint SG_SIZE, uint WG_SIZE)
{
    const uint hw_thread_in_wg_id = get_local_id(0) / SG_SIZE;
    const uint sg_local_id = get_local_id(0) % SG_SIZE;
    const uint NUM_HW_THREADS_IN_WG = WG_SIZE / SG_SIZE;

    uint acc = sub_group_scan_inclusive_add(val);
    if (NUM_HW_THREADS_IN_WG == 1)
    {
        return acc;
    }
    tmp[hw_thread_in_wg_id] = sub_group_broadcast(acc, SG_SIZE - 1);
    barrier(CLK_LOCAL_MEM_FENCE);

    uint loaded_val = sg_local_id < NUM_HW_THREADS_IN_WG ? tmp[sg_local_id] : 0;
    uint wgs_acc = sub_group_scan_exclusive_add(loaded_val);
    uint acc_for_this_hw_thread = sub_group_broadcast(wgs_acc, hw_thread_in_wg_id);
    // with more than 256 work items in SIMD16 (or more than 64 in SIMD8) the per-sub-group
    // totals no longer fit into a single sub-group, so additional iterations are needed
    uint num_iterations = (NUM_HW_THREADS_IN_WG + SG_SIZE - 1) / SG_SIZE;
    for (int i = 1; i < num_iterations; i++)
    {
        // add tmp[] explicitly: the scan above is "exclusive", so the last element is not included
        uint prev_max_sum = sub_group_broadcast(wgs_acc, SG_SIZE - 1) + tmp[(i * SG_SIZE) - 1];
        loaded_val = (sg_local_id + i * SG_SIZE) < NUM_HW_THREADS_IN_WG ? tmp[sg_local_id + i * SG_SIZE] : 0;
        wgs_acc = sub_group_scan_exclusive_add(loaded_val);
        wgs_acc += prev_max_sum;
        uint new_acc_for_this_hw_thread = sub_group_broadcast(wgs_acc, hw_thread_in_wg_id % SG_SIZE);
        if (hw_thread_in_wg_id >= i * SG_SIZE)
            acc_for_this_hw_thread = new_acc_for_this_hw_thread;
    }
    return acc + acc_for_this_hw_thread;
}

struct MSBDispatchArgs
{
    global struct MSBRadixContext* context;
    uint num_of_wgs;     // the number of workgroups that were dispatched for this context
    ulong* wg_key_start; // where the keys to process start for the current workgroup
    ulong* wg_key_end;
    uint shift_bit;
};




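// Maps the flat group id of the current dispatch onto a (context, dispatch record) pair by
// walking the queue of per-context workgroup counts, then splits that context's key range
// evenly across the workgroups assigned to it.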
struct MSBDispatchArgs get_msb_dispatch_args(global struct VContextScheduler* scheduler)
{
    global struct MSBDispatchQueue* queue = &scheduler->msb_queue;

    uint group = get_group_id(0);
    struct MSBDispatchRecord record;

    // TODO_OPT: Load this entire prefix array into SLM instead of searching,
    //           or use sub-group ops
    uint i = 0;
    while (i < queue->num_records)
    {
        uint n = queue->records[i].wgs_to_dispatch;

        if (group < n)
        {
            record = queue->records[i];
            break;
        }

        group -= n;
        i++;
    }

    uint context_id = i;
    global struct MSBRadixContext* context = &scheduler->contexts[context_id];

    // use ulongs to avoid uint overflow
    ulong group_id_in_dispatch = group;
    ulong start_offset = context->start_offset;
    ulong num_keys = context->num_keys;
    ulong wgs_to_dispatch = record.wgs_to_dispatch;

    struct MSBDispatchArgs args;
    args.context = context;
    args.num_of_wgs = record.wgs_to_dispatch;
    args.wg_key_start = context->keys_in + start_offset + (group_id_in_dispatch * num_keys / wgs_to_dispatch);
    args.wg_key_end = context->keys_in + start_offset + ((group_id_in_dispatch + 1) * num_keys / wgs_to_dispatch);
    args.shift_bit = MSB_SHIFT_BYTE_START_OFFSET - context->iteration * MSB_BITS_PER_ITERATION;
    return args;
}




void BLSDispatchQueue_push(global struct BLSDispatchQueue* queue, struct BLSDispatchRecord* record)
{
    uint new_idx = atomic_inc_global(&queue->num_records);
    queue->records[new_idx] = *record;
    DEBUG_CODE(printf("adding bls of size: %d\n", record->count));
}




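// Rank-based counting sort for one small range held entirely in SLM: every work item owns one
// key and counts how many keys in the range are smaller than its own (walking the range 16 keys
// at a time via sub-group broadcasts); that count is the key's output position. This assumes the
// keys within the range are unique, since equal keys would compute the same position.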
void DO_CountSort(struct BLSDispatchRecord dispatchRecord, local ulong* SLM_shared, global ulong* output)
{
    uint tid = get_local_id(0);

    global ulong* in = ((global ulong*)(dispatchRecord.keys_in)) + dispatchRecord.start_offset;

    ulong a = tid < dispatchRecord.count ? in[tid] : ULONG_MAX;

    SLM_shared[tid] = a;

    uint counter = 0;

    barrier(CLK_LOCAL_MEM_FENCE);

    ulong curr = SLM_shared[get_sub_group_local_id()];

    for (uint i = 16; i < dispatchRecord.count; i += 16)
    {
        ulong next = SLM_shared[i + get_sub_group_local_id()];

        for (uint j = 0; j < 16; j++)
        {
            // some older drivers have a bug when shuffling ulong values, so we shuffle them as 2x uint instead
            uint2 curr_as_uint2 = as_uint2(curr);
            uint2 sg_curr_as_uint2 = (uint2)(sub_group_broadcast(curr_as_uint2.x, j), sub_group_broadcast(curr_as_uint2.y, j));
            ulong c = as_ulong(sg_curr_as_uint2);
            if (c < a)
                counter++;
        }

        curr = next;
    }


    // last iteration
    for (uint j = 0; j < 16; j++)
    {
        // some older drivers have a bug when shuffling ulong values, so we shuffle them as 2x uint instead
        uint2 curr_as_uint2 = as_uint2(curr);
        uint2 sg_curr_as_uint2 = (uint2)(sub_group_broadcast(curr_as_uint2.x, j), sub_group_broadcast(curr_as_uint2.y, j));
        ulong c = as_ulong(sg_curr_as_uint2);
        if (c < a)
            counter++;
    }

    // store the element at its sorted position
    if (tid < dispatchRecord.count)
        output[dispatchRecord.start_offset + counter] = a;
}

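// Bitonic sort of one BLS range in SLM. The range is padded with ULONG_MAX up to a power-of-two
// size (no smaller than BOTTOM_LEVEL_SORT_WG_SIZE), and because the workgroup has only
// BOTTOM_LEVEL_SORT_WG_SIZE work items, each work item may own several SLM slots
// (elements_to_sort / BOTTOM_LEVEL_SORT_WG_SIZE) and processes all of them in every
// compare-exchange step. The k-loop only needs to cover the smallest power of two that holds
// the real element count; the remaining slots contain only padding.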
void DO_Bitonic(struct BLSDispatchRecord dispatchRecord, local ulong* SLM_shared, global ulong* output)
{
    uint lid = get_local_id(0);
    uint elements_to_sort = BOTTOM_LEVEL_SORT_THRESHOLD;
    while ((elements_to_sort >> 1) >= dispatchRecord.count && elements_to_sort >> 1 >= BOTTOM_LEVEL_SORT_WG_SIZE)
    {
        elements_to_sort >>= 1;
    }

    for (int i = 0; i < elements_to_sort / BOTTOM_LEVEL_SORT_WG_SIZE; i++)
    {
        uint tid = lid + i * BOTTOM_LEVEL_SORT_WG_SIZE;

        if (tid >= dispatchRecord.count)
            SLM_shared[tid] = ULONG_MAX;
        else
            SLM_shared[tid] = ((global ulong*)(dispatchRecord.keys_in))[dispatchRecord.start_offset + tid];
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    uint k_iterations = elements_to_sort;
    while (k_iterations >> 1 >= dispatchRecord.count && k_iterations != 0)
    {
        k_iterations >>= 1;
    }

    for (unsigned int k = 2; k <= k_iterations; k *= 2)
    {
        for (unsigned int j = k / 2; j > 0; j /= 2)
        {
            // this loop is needed because the workgroup can be smaller than the number of elements,
            // so each work item handles multiple SLM slots
            for (uint i = 0; i < elements_to_sort / BOTTOM_LEVEL_SORT_WG_SIZE; i++)
            {
                uint tid = lid + i * BOTTOM_LEVEL_SORT_WG_SIZE;
                unsigned int ixj = tid ^ j;
                if (ixj > tid)
                {
                    if ((tid & k) == 0)
                    {
                        if (SLM_shared[tid] > SLM_shared[ixj])
                        {
                            ulong tmp = SLM_shared[tid];
                            SLM_shared[tid] = SLM_shared[ixj];
                            SLM_shared[ixj] = tmp;
                        }
                    }
                    else
                    {
                        if (SLM_shared[tid] < SLM_shared[ixj])
                        {
                            ulong tmp = SLM_shared[tid];
                            SLM_shared[tid] = SLM_shared[ixj];
                            SLM_shared[ixj] = tmp;
                        }
                    }
                }
            }

            barrier(CLK_LOCAL_MEM_FENCE);
        }
    }

    for (int i = 0; i < elements_to_sort / BOTTOM_LEVEL_SORT_WG_SIZE; i++)
    {
        uint tid = lid + i * BOTTOM_LEVEL_SORT_WG_SIZE;

        if (tid < dispatchRecord.count)
            output[dispatchRecord.start_offset + tid] = SLM_shared[tid];
    }
}




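// Per-bucket follow-up work after binning: each work item looks at one of the MSB_RADIX_NUM_BINS
// buckets of the finished context and either writes a single element straight to the output or
// pushes a BLS dispatch record for the next scheduler iteration. Buckets larger than
// BOTTOM_LEVEL_SORT_THRESHOLD are handled in DO_Create_Work as new MSB stack entries instead.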
void DO_Create_Separate_BLS_Work(global struct VContextScheduler* scheduler, global struct MSBRadixContext* context, global ulong* input)
{
    uint lid = get_local_id(0);

    uint start = context->start[lid];
    uint count = context->count[lid];
    uint start_offset = context->start_offset + start;

    struct BLSDispatchRecord record;
    record.start_offset = start_offset;
    record.count = count;
    record.keys_in = context->keys_out;

    if (count == 0) // no elements, so nothing to do
    {
    }
    else if (count == 1) // single element, so just write it out
    {
        input[start_offset] = ((global ulong*)record.keys_in)[start_offset];
    }
    else if (count <= BOTTOM_LEVEL_SORT_THRESHOLD)
    {
        BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
    }
}




// We try to merge small BLS into a larger one within the sub-group
void DO_Create_SG_Merged_BLS_Work_Parallel(global struct VContextScheduler* scheduler, global struct MSBRadixContext* context, global ulong* input)
{
    uint lid = get_local_id(0);
    uint sid = get_sub_group_local_id();

    uint create_msb_work = context->count[lid] > BOTTOM_LEVEL_SORT_THRESHOLD ? 1 : 0;

    uint start = context->start[lid];
    uint count = context->count[lid];
    uint ctx_start_offset = context->start_offset;

    if (sid == 0 || create_msb_work) // these SIMD lanes are the beginning of a merged BLS
    {
        struct BLSDispatchRecord record;
        if (create_msb_work)
        {
            record.start_offset = ctx_start_offset + start + count;
            record.count = 0;
        }
        else // SIMD lane 0 case
        {
            record.start_offset = ctx_start_offset + start;
            record.count = count;
        }

        record.keys_in = context->keys_out;

        uint loop_idx = 1;
        while (sid + loop_idx < 16) // loop over the rest of the sub-group
        {
            uint _create_msb_work = intel_sub_group_shuffle_down(create_msb_work, 0u, loop_idx);
            uint _count = intel_sub_group_shuffle_down(count, 0u, loop_idx);
            uint _start = intel_sub_group_shuffle_down(start, 0u, loop_idx);

            if (_create_msb_work) // found the next MSB work, so the merge range ends here
                break;

            // push the current record, since nothing more will fit into it
            if (record.count + _count > BOTTOM_LEVEL_SORT_MERGING_THRESHOLD)
            {
                if (record.count == 1)
                {
                    input[record.start_offset] = record.keys_in[record.start_offset];
                }
                else if (record.count > 1)
                {
                    BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
                }
                record.start_offset = ctx_start_offset + _start;
                record.count = _count;
            }
            else
            {
                record.count += _count;
            }
            loop_idx++;
        }
        // if we have any elements left, then schedule them
        if (record.count == 1) // only one element, so just write it out
        {
            input[record.start_offset] = record.keys_in[record.start_offset];
        }
        else if (record.count > 1)
        {
            BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
        }
    }
}




// We try to merge small BLS into a larger one within the sub-group
void DO_Create_SG_Merged_BLS_Work(global struct VContextScheduler* scheduler, global struct MSBRadixContext* context, global ulong* input)
{
    uint lid = get_local_id(0);
    uint sid = get_sub_group_local_id();

    uint create_msb_work = context->count[lid] > BOTTOM_LEVEL_SORT_THRESHOLD ? 1 : 0;

    uint start = context->start[lid];
    uint count = context->count[lid];
    uint ctx_start_offset = context->start_offset;

    if (sid == 0)
    {
        struct BLSDispatchRecord record;
        record.start_offset = ctx_start_offset + start;
        record.count = 0;
        record.keys_in = context->keys_out;

        for (int i = 0; i < 16; i++)
        {
            uint _create_msb_work = sub_group_broadcast(create_msb_work, i);
            uint _count = sub_group_broadcast(count, i);
            uint _start = sub_group_broadcast(start, i);
            if (_create_msb_work)
            {
                if (record.count == 1) // only one element, so just write it out
                {
                    input[record.start_offset] = record.keys_in[record.start_offset];
                }
                else if (record.count > 1)
                {
                    BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
                }
                record.start_offset = ctx_start_offset + _start + _count;
                record.count = 0;
                continue;
            }
            // push the current record, since nothing more will fit into it
            if (record.count + _count > BOTTOM_LEVEL_SORT_MERGING_THRESHOLD)
            {
                BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
                record.start_offset = ctx_start_offset + _start;
                record.count = _count;
            }
            else
            {
                record.count += _count;
            }
        }
        // if we have any elements left, then schedule them
        if (record.count == 1) // only one element, so just write it out
        {
            input[record.start_offset] = record.keys_in[record.start_offset];
        }
        else if (record.count > 1)
        {
            BLSDispatchQueue_push((global struct BLSDispatchQueue*)scheduler->next_bls_queue, &record);
        }
    }
}




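// Called by the last workgroup of a finished MSB binning pass. It first emits BLS work for the
// small buckets (merged or separate, depending on MERGE_BLS_WITHIN_SG), then work-group-scans
// the "bucket is still too large" flags to reserve contiguous slots on the global MSB stack for
// the buckets that need another radix iteration.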
void DO_Create_Work(global struct VContextScheduler* scheduler, global struct MSBRadixContext* context, global ulong* input, local uint* slm_for_wg_scan, uint sg_size, uint wg_size)
{
    uint lid = get_local_id(0);

    uint iteration = context->iteration + 1;
    uint start = context->start[lid];
    uint count = context->count[lid];
    uint start_offset = context->start_offset + start;

    uint create_msb_work = count > BOTTOM_LEVEL_SORT_THRESHOLD ? 1 : 0;

#if MERGE_BLS_WITHIN_SG
    DO_Create_SG_Merged_BLS_Work_Parallel(scheduler, context, input);
#else
    DO_Create_Separate_BLS_Work(scheduler, context, input);
#endif

    uint new_entry_id = wg_scan_inclusive_add_opt(slm_for_wg_scan, create_msb_work, sg_size, wg_size); //work_group_scan_inclusive_add(create_msb_work);
    uint stack_begin_entry;
    // the last work item in the WG holds the total number of new entries
    if (lid == (MSB_RADIX_NUM_BINS - 1))
    {
        stack_begin_entry = atomic_add_global(&scheduler->msb_stack.num_entries, new_entry_id);
    }
    stack_begin_entry = work_group_broadcast(stack_begin_entry, (MSB_RADIX_NUM_BINS - 1));
    new_entry_id += stack_begin_entry - 1;


    if (create_msb_work)
    {
        scheduler->msb_stack.entries[new_entry_id].start_offset = start_offset;
        scheduler->msb_stack.entries[new_entry_id].count = count;
        scheduler->msb_stack.entries[new_entry_id].iteration = iteration;
    }

    if (lid == 0) {
        DEBUG_CODE(printf("num of new bls: %d\n", scheduler->next_bls_queue->num_records));
    }
}


struct BatchedBLSDispatchEntry
{
    /////////////////////////////////////////////////////////////
    // State data used for communication with command streamer
    // NOTE: This part must match definition in 'msb_radix_bitonic_sort.grl'
    /////////////////////////////////////////////////////////////
    qword p_data_buffer;
    qword num_elements; // number of elements in p_data_buffer
};


GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(BOTTOM_LEVEL_SORT_WG_SIZE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
void kernel sort_morton_codes_batched_BLS_dispatch(global struct BatchedBLSDispatchEntry* bls_dispatches)
{
    uint dispatch_id = get_group_id(0);
    uint lid = get_local_id(0);

    local ulong SLM_shared[BOTTOM_LEVEL_SORT_THRESHOLD];

    struct BatchedBLSDispatchEntry dispatchArgs = bls_dispatches[dispatch_id];
    struct BLSDispatchRecord dispatchRecord;
    dispatchRecord.start_offset = 0;
    dispatchRecord.count = dispatchArgs.num_elements;
    dispatchRecord.keys_in = (ulong*)dispatchArgs.p_data_buffer;

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_batched_BLS_dispatch for %d elements\n", dispatchRecord.count));

    if (dispatchRecord.count > 1)
        DO_Bitonic(dispatchRecord, SLM_shared, (global ulong*)dispatchRecord.keys_in);
}




GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(BOTTOM_LEVEL_SORT_WG_SIZE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
void kernel sort_morton_codes_bottom_level_single_wg(global struct Globals* globals, global ulong* input, global ulong* output)
{
    uint lid = get_local_id(0);

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_bottom_level_single_wg for %d elements\n", globals->numPrimitives));

    local ulong SLM_shared[BOTTOM_LEVEL_SORT_THRESHOLD];

    struct BLSDispatchRecord dispatchRecord;
    dispatchRecord.start_offset = 0;
    dispatchRecord.count = globals->numPrimitives;
    dispatchRecord.keys_in = (ulong*)input;

    // TODO: count sort or bitonic sort here?
    //DO_Bitonic(dispatchRecord, SLM_shared, output);
    DO_CountSort(dispatchRecord, SLM_shared, output);
}




// This kernel initializes the first context to start up the whole execution
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(MSB_RADIX_NUM_BINS, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
void kernel sort_morton_codes_msb_begin(
    global struct Globals* globals,
    global struct VContextScheduler* scheduler,
    global ulong* buf0,
    global ulong* buf1)
{
    uint lid = get_local_id(0);
    uint gid = get_group_id(0);

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_msb_begin\n"));

    scheduler->contexts[gid].count[lid] = 0;

    if (gid == 0 && lid == 0)
    {
        global struct MSBRadixContext* context = &scheduler->contexts[lid];
        const uint num_prims = globals->numPrimitives;

        scheduler->bls_queue0.num_records = 0;
        scheduler->bls_queue1.num_records = 0;

        scheduler->curr_bls_queue = &scheduler->bls_queue1;
        scheduler->next_bls_queue = &scheduler->bls_queue0;

        context->start_offset = 0;
        context->num_wgs_in_flight = 0;
        context->num_keys = num_prims;
        context->iteration = 0;
        context->keys_in = buf0;
        context->keys_out = buf1;

        uint msb_wgs_to_dispatch = (num_prims + MSB_WG_SORT_ELEMENTS_THRESHOLD - 1) / MSB_WG_SORT_ELEMENTS_THRESHOLD;
        scheduler->msb_queue.records[0].wgs_to_dispatch = msb_wgs_to_dispatch;

        scheduler->num_wgs_msb = msb_wgs_to_dispatch;
        scheduler->num_wgs_bls = 0;
        scheduler->msb_stack.num_entries = 0;
        scheduler->msb_queue.num_records = 1;
    }
}




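// One work item per virtual context: pops up to MSB_RADIX_NUM_VCONTEXTS entries from the MSB
// stack into the contexts (ping-ponging keys_in/keys_out based on iteration parity), computes
// how many MSB workgroups the next pass needs, and swaps the current/next BLS queues so that
// the BLS records produced by the previous pass are the ones dispatched next.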
__attribute__((reqd_work_group_size(MSB_RADIX_NUM_VCONTEXTS, 1, 1)))
kernel void
scheduler(global struct VContextScheduler* scheduler, global ulong* buf0, global ulong* buf1)
{
    uint lid = get_local_id(0);

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_scheduler\n"));

    uint context_idx = lid;

    const uint num_of_stack_entries = scheduler->msb_stack.num_entries;

    uint msb_wgs_to_dispatch = 0;
    if (lid < num_of_stack_entries)
    {
        struct MSBStackEntry entry = scheduler->msb_stack.entries[(num_of_stack_entries - 1) - lid];
        global struct MSBRadixContext* context = &scheduler->contexts[lid];
        context->start_offset = entry.start_offset;
        context->num_wgs_in_flight = 0;
        context->num_keys = entry.count;
        context->iteration = entry.iteration;
        context->keys_in = entry.iteration % 2 == 0 ? buf0 : buf1;
        context->keys_out = entry.iteration % 2 == 0 ? buf1 : buf0;

        msb_wgs_to_dispatch = (entry.count + MSB_WG_SORT_ELEMENTS_THRESHOLD - 1) / MSB_WG_SORT_ELEMENTS_THRESHOLD;
        scheduler->msb_queue.records[lid].wgs_to_dispatch = msb_wgs_to_dispatch;
    }

    msb_wgs_to_dispatch = work_group_reduce_add(msb_wgs_to_dispatch); // TODO: if the compiler implementation is slow, consider writing this manually

    if (lid == 0)
    {
        // swap queues for the next iteration
        struct BLSDispatchQueue* tmp = scheduler->curr_bls_queue;
        scheduler->curr_bls_queue = scheduler->next_bls_queue;
        scheduler->next_bls_queue = tmp;

        scheduler->next_bls_queue->num_records = 0;

        scheduler->num_wgs_bls = scheduler->curr_bls_queue->num_records;
        scheduler->num_wgs_msb = msb_wgs_to_dispatch;

        if (num_of_stack_entries < MSB_RADIX_NUM_VCONTEXTS)
        {
            scheduler->msb_queue.num_records = num_of_stack_entries;
            scheduler->msb_stack.num_entries = 0;
        }
        else
        {
            scheduler->msb_queue.num_records = MSB_RADIX_NUM_VCONTEXTS;
            scheduler->msb_stack.num_entries -= MSB_RADIX_NUM_VCONTEXTS;
        }
    }

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_scheduler finished, to spawn %d MSB wgs in %d contexts and %d BLS wgs, MSB records on stack %d\n",
        scheduler->num_wgs_msb, scheduler->msb_queue.num_records, scheduler->num_wgs_bls, scheduler->msb_stack.num_entries));
}




// this is the lowest-level sub-task, which produces the final sorted codes
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(BOTTOM_LEVEL_SORT_WG_SIZE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
void kernel sort_morton_codes_bottom_level(global struct VContextScheduler* scheduler, global ulong* output)
{
    uint lid = get_local_id(0);

    DEBUG_CODE(if (get_group_id(0) == 0 && lid == 0) printf("running sort_morton_codes_bottom_level\n"));

    local struct BLSDispatchRecord l_dispatchRecord;
    if (lid == 0)
    {
        uint record_idx = get_group_id(0);
        l_dispatchRecord = scheduler->curr_bls_queue->records[record_idx];
        //l_dispatchRecord = BLSDispatchQueue_pop((global struct BLSDispatchQueue*)scheduler->curr_bls_queue);
        atomic_dec_global(&scheduler->num_wgs_bls);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    struct BLSDispatchRecord dispatchRecord = l_dispatchRecord;

    local ulong SLM_shared[BOTTOM_LEVEL_SORT_THRESHOLD];

    // right now only the count sort is used here
    // TODO: maybe implement something else
    if (1)
    {
        //DO_Bitonic(dispatchRecord, SLM_shared, output);
        DO_CountSort(dispatchRecord, SLM_shared, output);
    }
}




#define MSB_COUNT_WG_SIZE MSB_RADIX_NUM_BINS
#define MSB_COUNT_SG_SIZE 16

// count how many elements per bucket we have
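// Each workgroup strides over its slice of the context's keys, reading only the byte that the
// current radix iteration examines, and accumulates per-bucket counts in SLM before adding them
// to the context's global counters. The last workgroup to finish converts the counts into
// per-bucket start offsets with a work-group exclusive scan.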
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(MSB_COUNT_WG_SIZE, 1, 1)))
__attribute__((intel_reqd_sub_group_size(MSB_COUNT_SG_SIZE)))
void kernel sort_morton_codes_msb_count_items(global struct VContextScheduler* scheduler)
{
    uint lid = get_local_id(0);
    uint lsz = MSB_RADIX_NUM_BINS;

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_msb_count_items\n"));

    local uint bucket_count[MSB_RADIX_NUM_BINS];
    local uint finish_count;
    bucket_count[lid] = 0;
    if (lid == 0)
    {
        finish_count = 0;
    }

    struct MSBDispatchArgs dispatchArgs = get_msb_dispatch_args(scheduler);

    global struct MSBRadixContext* context = dispatchArgs.context;

    global ulong* key_start = (global ulong*)dispatchArgs.wg_key_start + lid;
    global ulong* key_end = (global ulong*)dispatchArgs.wg_key_end;
    uint shift_bit = dispatchArgs.shift_bit;
    uchar shift_byte = shift_bit / 8; // byte offset of the radix digit examined in this iteration
    barrier(CLK_LOCAL_MEM_FENCE);

    global uchar* ks = (global uchar*)key_start;
    ks += shift_byte;
    global uchar* ke = (global uchar*)key_end;
    ke += shift_byte;

    // double buffering on value loading
    if (ks < ke)
    {
        uchar bucket_id = *ks;
        ks += lsz * sizeof(ulong);

        for (global uchar* k = ks; k < ke; k += lsz * sizeof(ulong))
        {
            uchar next_bucket_id = *k;
            atomic_inc_local(&bucket_count[bucket_id]);
            bucket_id = next_bucket_id;
        }

        atomic_inc_local(&bucket_count[bucket_id]);

    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // update the context's global counters
    uint count = bucket_count[lid];
    if (count > 0)
        atomic_add_global(&context->count[lid], bucket_count[lid]);

    mem_fence_gpu_invalidate();
    work_group_barrier(0);

    bool final_wg = true;
    // count the WGs that have reached this point
    if (dispatchArgs.num_of_wgs > 1)
    {
        if (lid == 0)
            finish_count = atomic_inc_global(&context->num_wgs_in_flight) + 1;

        barrier(CLK_LOCAL_MEM_FENCE);

        final_wg = finish_count == dispatchArgs.num_of_wgs;
    }

    local uint partial_dispatches[MSB_COUNT_WG_SIZE / MSB_COUNT_SG_SIZE];
    // if this is the last WG for the current dispatch, update the context
    if (final_wg)
    {
        // the code below does work_group_scan_exclusive_add(context->count[lid]);
        {
            uint lane_val = context->count[lid];
            uint sg_result = sub_group_scan_inclusive_add(lane_val);

            partial_dispatches[get_sub_group_id()] = sub_group_broadcast(sg_result, MSB_COUNT_SG_SIZE - 1);
            barrier(CLK_LOCAL_MEM_FENCE);

            uint slm_result = sub_group_scan_exclusive_add(partial_dispatches[get_sub_group_local_id()]);
            slm_result = sub_group_broadcast(slm_result, get_sub_group_id());
            uint result = slm_result + sg_result - lane_val;
            context->start[lid] = result; //work_group_scan_exclusive_add(context->count[lid]);
        }

        context->count[lid] = 0;
        if (lid == 0)
            context->num_wgs_in_flight = 0;
    }
}




// sort elements into appropriate buckets
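// Two binning paths: when MSB_RADIX_NUM_BINS == MSB_WG_SORT_ELEMENTS_THRESHOLD each work item
// owns exactly one key and positions are first reserved in SLM counters before one global atomic
// per bucket; otherwise keys are streamed with one global atomic per key. The last workgroup of
// the dispatch then calls DO_Create_Work to schedule the follow-up BLS/MSB passes.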
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(MSB_RADIX_NUM_BINS, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
void kernel sort_morton_codes_msb_bin_items(
    global struct VContextScheduler* scheduler, global ulong* input)
{
    uint lid = get_local_id(0);
    uint lsz = get_local_size(0);

    DEBUG_CODE(if (lid == 0) printf("running sort_morton_codes_msb_bin_items\n"));

    local uint finish_count;
    if (lid == 0)
    {
        finish_count = 0;
    }

    struct MSBDispatchArgs dispatchArgs = get_msb_dispatch_args(scheduler);
    global struct MSBRadixContext* context = dispatchArgs.context;

    global ulong* key_start = (global ulong*)dispatchArgs.wg_key_start + lid;
    global ulong* key_end = (global ulong*)dispatchArgs.wg_key_end;
    uint shift_bit = dispatchArgs.shift_bit;

    barrier(CLK_LOCAL_MEM_FENCE);

    global ulong* sorted_keys = (global ulong*)context->keys_out + context->start_offset;

#if MSB_RADIX_NUM_BINS == MSB_WG_SORT_ELEMENTS_THRESHOLD // special case: exactly 1 element per work item
    // do local counting first, then move to global

    local uint slm_counters[MSB_RADIX_NUM_BINS];
    slm_counters[lid] = 0;

    barrier(CLK_LOCAL_MEM_FENCE);

    uint place_in_slm_bucket;
    uint bucket_id;
    ulong val;

    bool active_lane = key_start < key_end;

    if (active_lane)
    {
        val = *key_start;

        bucket_id = (val >> (ulong)shift_bit) & (MSB_RADIX_NUM_BINS - 1);
        place_in_slm_bucket = atomic_inc_local(&slm_counters[bucket_id]);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // overwrite slm_counters with the global counters; counters with 0 elements can be skipped since they won't be used anyway
    if (slm_counters[lid])
        slm_counters[lid] = atomic_add_global(&context->count[lid], slm_counters[lid]);

    barrier(CLK_LOCAL_MEM_FENCE);

    if (active_lane)
    {
        uint id_in_bucket = slm_counters[bucket_id] + place_in_slm_bucket; //atomic_inc_global(&context->count[bucket_id]);
        sorted_keys[context->start[bucket_id] + id_in_bucket] = val;
    }
#else
    // double buffering on value loading
    if (key_start < key_end)
    {
        ulong val = *key_start;
        key_start += lsz;

        for (global ulong* k = key_start; k < key_end; k += lsz)
        {
            ulong next_val = *k;
            uint bucket_id = (val >> (ulong)shift_bit) & (MSB_RADIX_NUM_BINS - 1);
            uint id_in_bucket = atomic_inc_global(&context->count[bucket_id]);

            //printf("dec: %llu, val: %llX bucket_id: %X", *k, *k, bucket_id);
            sorted_keys[context->start[bucket_id] + id_in_bucket] = val;

            val = next_val;
        }

        uint bucket_id = (val >> (ulong)shift_bit) & (MSB_RADIX_NUM_BINS - 1);
        uint id_in_bucket = atomic_inc_global(&context->count[bucket_id]);

        sorted_keys[context->start[bucket_id] + id_in_bucket] = val;
    }
#endif

    // make sure all groups' "counters" and "starts" are visible to the final workgroup
    mem_fence_gpu_invalidate();
    work_group_barrier(0);

    bool final_wg = true;
    // count the WGs that have reached this point
    if (dispatchArgs.num_of_wgs > 1)
    {
        if (lid == 0)
            finish_count = atomic_inc_global(&context->num_wgs_in_flight) + 1;

        barrier(CLK_LOCAL_MEM_FENCE);

        final_wg = finish_count == dispatchArgs.num_of_wgs;
    }

    local uint slm_for_wg_funcs[MSB_COUNT_WG_SIZE / MSB_COUNT_SG_SIZE];
    // if this is the last WG for the current dispatch, prepare the sub-tasks
    if (final_wg)
    {
        DO_Create_Work(scheduler, context, input, slm_for_wg_funcs, 16, MSB_RADIX_NUM_BINS);

        // clear the context's counters for future execution
        context->count[lid] = 0;
    }

}