1//
2// Copyright (C) 2009-2021 Intel Corporation
3//
4// SPDX-License-Identifier: MIT
5//
6//
7
8#include "AABB.h"
9#include "GRLGen12.h"
10#include "api_interface.h"
11#include "common.h"
12#include "qbvh6.h"
13
14#define MAX_SPLITS_PER_INSTANCE 64
15#define NUM_REBRAID_BINS 32
16
17#define NUM_CHILDREN 6
18#define MAX_NODE_OFFSET 65535 // can't open nodes whose offsets exceed this
19
20// OCL/DPC++ *SHOULD* have a uniform keyword... but they don't... so I'm making my own
21#define uniform
22#define varying
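// These annotations expand to nothing; they only document whether a value is
// intended to be subgroup-uniform or per-lane (varying).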
23
24#define SGPRINT_UNIFORM(fmt,val)    {sub_group_barrier(CLK_LOCAL_MEM_FENCE); if( get_sub_group_local_id() == 0 ) { printf(fmt,val); }}
25
26#define SGPRINT_6x(prefix,fmt,type,val)  {\
27                                        type v0 = sub_group_broadcast( val, 0 );\
28                                        type v1 = sub_group_broadcast( val, 1 );\
29                                        type v2 = sub_group_broadcast( val, 2 );\
30                                        type v3 = sub_group_broadcast( val, 3 );\
31                                        type v4 = sub_group_broadcast( val, 4 );\
32                                        type v5 = sub_group_broadcast( val, 5 );\
33                                        sub_group_barrier(CLK_LOCAL_MEM_FENCE); \
34                                        if( get_sub_group_local_id() == 0 ) { \
35                                        printf(prefix fmt fmt fmt fmt fmt fmt "\n" , \
36                                            v0,v1,v2,v3,v4,v5);}}
37
38
39#define SGPRINT_16x(prefix,fmt,type,val)  {\
40                                        type v0 = sub_group_broadcast( val, 0 );\
41                                        type v1 = sub_group_broadcast( val, 1 );\
42                                        type v2 = sub_group_broadcast( val, 2 );\
43                                        type v3 = sub_group_broadcast( val, 3 );\
44                                        type v4 = sub_group_broadcast( val, 4 );\
45                                        type v5 = sub_group_broadcast( val, 5 );\
46                                        type v6 = sub_group_broadcast( val, 6 );\
47                                        type v7 = sub_group_broadcast( val, 7 );\
48                                        type v8 = sub_group_broadcast( val, 8 );\
49                                        type v9 = sub_group_broadcast( val, 9 );\
50                                        type v10 = sub_group_broadcast( val, 10 );\
51                                        type v11 = sub_group_broadcast( val, 11 );\
52                                        type v12 = sub_group_broadcast( val, 12 );\
53                                        type v13 = sub_group_broadcast( val, 13 );\
54                                        type v14 = sub_group_broadcast( val, 14 );\
55                                        type v15 = sub_group_broadcast( val, 15 );\
56                                        sub_group_barrier(CLK_LOCAL_MEM_FENCE); \
57                                        if( get_sub_group_local_id() == 0 ) { \
58                                        printf(prefix fmt fmt fmt fmt fmt fmt fmt fmt \
59                                                      fmt fmt fmt fmt fmt fmt fmt fmt"\n" , \
60                                            v0,v1,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14,v15);}}
61
62#if 1
63#define GRL_ATOMIC_INC(addr) atomic_add(addr, 1);
64#else
65#define GRL_ATOMIC_INC(addr) atomic_inc(addr);
66#endif
67
68#if 0
69#define LOOP_TRIPWIRE_INIT uint _loop_trip=0;
70
71#define LOOP_TRIPWIRE_INCREMENT(max_iterations,name) \
72    _loop_trip++;\
73    if ( _loop_trip > max_iterations )\
74    {\
75        printf( "@@@@@@@@@@@@@@@@@@@@ TRIPWIRE!!!!!!!!!!!\n" );\
76        printf( name"\n");\
77        break;\
78    }
79#else
80
81#define LOOP_TRIPWIRE_INIT
82#define LOOP_TRIPWIRE_INCREMENT(max_iterations,name)
83
84#endif
85
86
87
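// Subgroup-cooperative bounded "heap": each lane holds at most one entry,
// packed as (key << 16) | value, with lane_mask marking occupied lanes.
// Minimum/maximum selection is done with subgroup reductions rather than a
// conventional tree layout.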
88typedef struct SGHeap
89{
90    uint32_t key_value;
91    bool lane_mask;
92} SGHeap;
93
94GRL_INLINE void SGHeap_init(uniform SGHeap *h)
95{
96    h->lane_mask = false;
97    h->key_value = 0xbaadf00d;
98}
99
100GRL_INLINE bool SGHeap_full(uniform SGHeap *h)
101{
102    return sub_group_all(h->lane_mask);
103}
104GRL_INLINE bool SGHeap_empty(uniform SGHeap *h)
105{
106    return sub_group_all(!h->lane_mask);
107}
108
109GRL_INLINE bool SGHeap_get_lane_mask(uniform SGHeap *h)
110{
111    return h->lane_mask;
112}
113GRL_INLINE uint16_t SGHeap_get_lane_values(uniform SGHeap *h)
114{
115    return (h->key_value & 0xffff);
116}
117
118GRL_INLINE ushort isolate_lowest_bit( ushort m )
119{
120    return m & ~(m - 1);
121}
122
123
124// lane i receives the index of the ith set bit in mask.
125GRL_INLINE ushort subgroup_bit_rank( uniform ushort mask )
126{
127    varying ushort lane = get_sub_group_local_id();
128    ushort idx = 16;
129    for ( uint i = 0; i < NUM_CHILDREN; i++ )
130    {
131        ushort lo = isolate_lowest_bit( mask );
132        mask = mask ^ lo;
133        idx = (lane == i) ? lo : idx;
134    }
135
136    return ctz( idx );
137}
138
139// push a set of elements spread across a subgroup.  Return mask of elements that were not pushed
140GRL_INLINE uint16_t SGHeap_vectorized_push(uniform SGHeap *h, varying uint16_t key, varying uint16_t value, uniform ushort push_mask)
141{
142
143#if 0 // an attempt to make this algorithm branchless
144    varying uint key_value = (((uint)key) << 16) | ((uint)value);
145    uniform ushort free_mask = intel_sub_group_ballot( !h->lane_mask );
146
147    varying ushort free_slot_idx = subgroup_bit_prefix_exclusive( free_mask ); // for each heap slot, what is its position in a compacted list of free slots (prefix sum)
148    varying ushort push_idx      = subgroup_bit_prefix_exclusive( push_mask );  // for each lane, what is its position in a compacted list of pushing lanes (prefix sum)
149
150    uniform ushort num_pushes = min( popcount( free_mask ), popcount( push_mask ) );
151
152    varying ushort push_index = subgroup_bit_rank( push_mask ); // lane i gets the index of the i'th set bit in push_mask
153
154    varying uint shuffled = intel_sub_group_shuffle( key_value, intel_sub_group_shuffle( push_index, free_slot_idx ) );
155    varying bool pushed = false;
156    if ( !h->lane_mask && free_slot_idx < num_pushes )
157    {
158        h->lane_mask = true;
159        h->key_value = shuffled;
160        pushed = true;
161    }
162
163    return push_mask & intel_sub_group_ballot( push_idx >= num_pushes );
164#else
165
166    varying uint lane = get_sub_group_local_id();
167
168    varying uint key_value = (((uint)key) << 16) | ((uint)value);
169    uniform ushort free_mask = intel_sub_group_ballot(!h->lane_mask);
170
171    // TODO_OPT:  Look for some clever way to remove this loop
172    while (free_mask && push_mask)
173    {
174        // insert first active child into first available lane
175        uniform uint child_id = ctz(push_mask);
176        uniform uint victim_lane = ctz(free_mask);
177        uniform uint kv = sub_group_broadcast( key_value, child_id );
178        if (victim_lane == lane)
179        {
180            h->lane_mask = true;
181            h->key_value = kv;
182        }
183        push_mask ^= (1 << child_id);
184        free_mask ^= (1 << victim_lane);
185    }
186
187    return push_mask;
188
189#endif
190}
191
192// push an item onto a heap that is full except for one slot
193GRL_INLINE void SGHeap_push_and_fill(uniform SGHeap *h, uniform uint16_t key, uniform uint16_t value)
194{
195    uniform uint32_t key_value = (((uint)key) << 16) | value;
196    if (!h->lane_mask)
197    {
198        h->lane_mask = true;
199        h->key_value = key_value; // only one lane will be active at this point
200    }
201}
202
203// pop the min item from a full heap
204GRL_INLINE void SGHeap_full_pop_min(uniform SGHeap *h, uniform uint16_t *key_out, uniform uint16_t *value_out)
205{
206    varying uint lane = get_sub_group_local_id();
207    uniform uint kv = sub_group_reduce_min(h->key_value);
208    if (h->key_value == kv)
209        h->lane_mask = false;
210
211    *key_out   = (kv >> 16);
212    *value_out = (kv & 0xffff);
213}
214
215// pop the max item from a heap
216GRL_INLINE void SGHeap_pop_max(uniform SGHeap *h, uniform uint16_t *key_out, uniform uint16_t *value_out)
217{
218    uniform uint lane = get_sub_group_local_id();
219    uniform uint kv = sub_group_reduce_max(h->lane_mask ? h->key_value : 0);
220    if (h->key_value == kv)
221        h->lane_mask = false;
222
223    *key_out = (kv >> 16);
224    *value_out = (kv & 0xffff);
225}
226
227GRL_INLINE void SGHeap_printf( SGHeap* heap )
228{
229    uint key = heap->key_value >> 16;
230    uint value = heap->key_value & 0xffff;
231
232    if ( get_sub_group_local_id() == 0)
233        printf( "HEAP: \n" );
234    SGPRINT_16x( "  mask: ", "%6u ", bool, heap->lane_mask );
235    SGPRINT_16x( "  key : ", "0x%04x ", uint, key );
236    SGPRINT_16x( "  val : ", "0x%04x ", uint, value );
237
238}
239
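// Half the surface area of the world-space AABB of the transformed box.
// Only the extent is needed, so the translation terms drop out.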
240GRL_INLINE float transformed_aabb_halfArea(float3 lower, float3 upper, const float *Transform)
241{
242    // Compute transformed extent per 'transform_aabb'.  Various terms cancel
243    float3 Extent = upper - lower;
244    float ex = Extent.x * fabs(Transform[0]) + Extent.y * fabs(Transform[1]) + Extent.z * fabs(Transform[2]);
245    float ey = Extent.x * fabs(Transform[4]) + Extent.y * fabs(Transform[5]) + Extent.z * fabs(Transform[6]);
246    float ez = Extent.x * fabs(Transform[8]) + Extent.y * fabs(Transform[9]) + Extent.z * fabs(Transform[10]);
247
248    return (ex * ey) + (ey * ez) + (ex * ez);
249}
250
251GRL_INLINE uint16_t quantize_area(float relative_area)
252{
253    // clamp relative area at 0.25 (1/4 of root area)
254    //  and apply a non-linear distribution because most things in real scenes are small
255    relative_area = pow(min(1.0f, relative_area * 4.0f), 0.125f);
256    return convert_ushort_rtn( relative_area * 65535.0f );
257}
258
259GRL_INLINE varying uint16_t SUBGROUP_get_child_areas(uniform InternalNode *n,
260                                                 uniform const float *Transform,
261                                                 uniform float relative_area_scale)
262{
263    varying uint16_t area;
264    varying uint16_t lane = get_sub_group_local_id();
265    varying int exp_x = n->exp_x;
266    varying int exp_y = n->exp_y;
267    varying int exp_z = n->exp_z;
268
269    {
270        // decode the AABB positions.  Lower in the bottom 6 lanes, upper in the top
271        uniform uint8_t *px = &n->lower_x[0];
272        uniform uint8_t *py = &n->lower_y[0];
273        uniform uint8_t *pz = &n->lower_z[0];
274
275        varying float fx = convert_float(px[lane]);
276        varying float fy = convert_float(py[lane]);
277        varying float fz = convert_float(pz[lane]);
278        fx = n->lower[0] + bitShiftLdexp(fx, exp_x - 8);
279        fy = n->lower[1] + bitShiftLdexp(fy, exp_y - 8);
280        fz = n->lower[2] + bitShiftLdexp(fz, exp_z - 8);
281
282        // transform the AABBs to world space
283        varying float3 lower = (float3)(fx, fy, fz);
284        varying float3 upper = intel_sub_group_shuffle(lower, lane + 6);
285
286        {
287
288            // TODO_OPT:  This is only utilizing 6 lanes.
289            //  We might be able to do better by vectorizing the calculation differently
290            float a1 = transformed_aabb_halfArea( lower, upper, Transform );
291            float a2 = a1 * relative_area_scale;
292            area = quantize_area( a2 );
293        }
294    }
295
296    return area;
297}
298
299
300
301GRL_INLINE ushort get_child_area(
302    InternalNode* n,
303    ushort child,
304    const float* Transform,
305    float relative_area_scale )
306{
307    uint16_t area;
308    uint16_t lane = get_sub_group_local_id();
309    int exp_x = n->exp_x;
310    int exp_y = n->exp_y;
311    int exp_z = n->exp_z;
312
313    // decode the AABB positions.  Lower in the bottom 6 lanes, upper in the top
314    uint8_t* px = &n->lower_x[0];
315    uint8_t* py = &n->lower_y[0];
316    uint8_t* pz = &n->lower_z[0];
317
318    float3 lower, upper;
319    lower.x = convert_float( n->lower_x[child] );
320    lower.y = convert_float( n->lower_y[child] );
321    lower.z = convert_float( n->lower_z[child] );
322    upper.x = convert_float( n->upper_x[child] );
323    upper.y = convert_float( n->upper_y[child] );
324    upper.z = convert_float( n->upper_z[child] );
325
326    lower.x = bitShiftLdexp( lower.x, exp_x - 8 ); // NOTE:  the node's 'lower' field cancels out, so don't add it
327    lower.y = bitShiftLdexp( lower.y, exp_y - 8 ); //    see transform_aabb_halfArea
328    lower.z = bitShiftLdexp( lower.z, exp_z - 8 );
329    upper.x = bitShiftLdexp( upper.x, exp_x - 8 );
330    upper.y = bitShiftLdexp( upper.y, exp_y - 8 );
331    upper.z = bitShiftLdexp( upper.z, exp_z - 8 );
332
333    float a1 = transformed_aabb_halfArea( lower, upper, Transform );
334    float a2 = a1 * relative_area_scale;
335    area = quantize_area( a2 );
336
337    return area;
338}
339
340
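// Per-lane offset (in 64B blocks) from this node to child 'lane':
// childOffset plus an exclusive prefix sum of the child block increments.
// Only implemented for NUM_CHILDREN == 6.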
341GRL_INLINE varying int SUBGROUP_get_child_offsets(uniform InternalNode *n)
342{
343    varying uint lane = get_sub_group_local_id();
344    varying uint child = (lane < NUM_CHILDREN) ? lane : 0;
345
346    varying uint block_incr = InternalNode_GetChildBlockIncr( n, child );
347
348    //varying uint prefix = sub_group_scan_exclusive_add( block_incr );
349    varying uint prefix;
350    if ( NUM_CHILDREN == 6 )
351    {
352        prefix = block_incr + intel_sub_group_shuffle_up( 0u, block_incr, 1u );
353        prefix = prefix + intel_sub_group_shuffle_up( 0u, prefix, 2 );
354        prefix = prefix + intel_sub_group_shuffle_up( 0u, prefix, 4 );
355        prefix = prefix - block_incr;
356    }
357
358    return n->childOffset + prefix;
359}
360
361
362// compute the maximum number of leaf nodes that will be produced given 'num_splits' node openings
363GRL_INLINE uint get_num_nodes(uint num_splits, uint max_children)
364{
365    // each split consumes one node and replaces it with N nodes
366    //   there is initially one node
367    //  number of nodes is thus:  N*s + 1 - s ==> (N-1)*s + 1
368    return (max_children - 1) * num_splits + 1;
369}
370
371// compute the number of node openings that can be performed given a fixed extra node budget
372GRL_INLINE uint get_num_splits(uint num_nodes, uint max_children)
373{
374    // inverse of get_num_nodes:   x = (n-1)s + 1
375    //                             s = (x-1)/(n-1)
376    if (num_nodes == 0)
377        return 0;
378
379    return (num_nodes - 1) / (max_children - 1);
380}
381
382GRL_INLINE uint get_rebraid_bin_index(uint16_t quantized_area, uint NUM_BINS)
383{
384    // arrange bins in descending order by size
385    float relative_area = quantized_area * (1.0f/65535.0f);
386    relative_area = 1.0f - relative_area; // arrange bins largest to smallest
387    size_t bin = round(relative_area * (NUM_BINS - 1));
388    return bin;
389}
390
391GRL_INLINE global InternalNode *get_node(global BVHBase *base, int incr)
392{
393    global char *ptr = (((global char *)base) + BVH_ROOT_NODE_OFFSET); // NOTE: Assuming this will be hoisted out of inner loops
394
395    return (global InternalNode *)(ptr + incr * 64);
396}
397
398GRL_INLINE bool is_aabb_valid(float3 lower, float3 upper)
399{
400    return all(isfinite(lower)) &&
401           all(isfinite(upper)) &&
402           all(lower <= upper);
403}
404
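// A node may be opened only if it is an internal node and every valid child
// is itself an internal node.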
405GRL_INLINE bool is_node_openable(InternalNode *n)
406{
407    // TODO_OPT: Optimize me by fetching dwords instead of looping over bytes
408    // TODO_OPT:  Pre-compute openability and pack into the pad byte next to the nodeType field??
409    bool openable = n->nodeType == NODE_TYPE_INTERNAL;
410    if ( openable )
411    {
412        for ( uint i = 0; i < NUM_CHILDREN; i++ )
413        {
414            bool valid = InternalNode_IsChildValid( n, i );
415            uint childType = InternalNode_GetChildType( n, i );
416            openable = openable & (!valid || (childType == NODE_TYPE_INTERNAL));
417        }
418    }
419
420    return openable;
421}
422
423
424GRL_INLINE bool SUBGROUP_can_open_root(
425    uniform global BVHBase *bvh_base,
426    uniform const struct GRL_RAYTRACING_INSTANCE_DESC* instance
427    )
428{
429    if (bvh_base == 0 || GRL_get_InstanceMask(instance) == 0)
430        return false;
431
432    // TODO_OPT: SG-vectorize this AABB test
433    uniform float3 root_lower = AABB3f_load_lower(&bvh_base->Meta.bounds);
434    uniform float3 root_upper = AABB3f_load_upper(&bvh_base->Meta.bounds);
435    if (!is_aabb_valid(root_lower, root_upper))
436        return false;
437
438    uniform global InternalNode *node = get_node(bvh_base, 0);
439    if ( node->nodeType != NODE_TYPE_INTERNAL )
440        return false;
441
442    varying bool openable = true;
443    varying uint lane = get_sub_group_local_id();
444    if (lane < NUM_CHILDREN)
445    {
446        varying uint childType = InternalNode_GetChildType(node, lane);
447        varying bool valid = InternalNode_IsChildValid(node, lane);
448        openable = childType == NODE_TYPE_INTERNAL || !valid;
449    }
450
451    return sub_group_all(openable);
452}
453
454
455
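// Count how many node openings (splits) this instance would like, per area bin.
// The BLAS is walked from the root, always opening the largest remaining openable
// node (tracked in a subgroup heap keyed by quantized world-space area), up to
// MAX_SPLITS_PER_INSTANCE.  Returns per-lane counters: .x covers bins 0..15,
// .y covers bins 16..31.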
456GRL_INLINE
457varying uint2
458SUBGROUP_count_instance_splits(uniform global struct AABB3f *geometry_bounds,
459                               uniform global __const struct GRL_RAYTRACING_INSTANCE_DESC *instance)
460{
461    uniform global BVHBase *bvh_base = (global BVHBase *)instance->AccelerationStructure;
462    if (!SUBGROUP_can_open_root(bvh_base, instance))
463        return (uint2)(0, 0);
464
465    uniform float relative_area_scale = 1.0f / AABB3f_halfArea(geometry_bounds);
466    uniform float3 root_lower = AABB3f_load_lower(&bvh_base->Meta.bounds);
467    uniform float3 root_upper = AABB3f_load_upper(&bvh_base->Meta.bounds);
468
469    uniform uint16_t quantized_area = quantize_area(transformed_aabb_halfArea(root_lower, root_upper, instance->Transform) * relative_area_scale);
470    uniform uint16_t node_offs = 0;
471
472    uniform SGHeap heap;
473    uniform uint num_splits = 0;
474
475    SGHeap_init(&heap);
476    varying uint sg_split_counts_hi = 0; // per-lane bin counters: 'hi' holds bins 16..31, 'lo' holds bins 0..15
477    varying uint sg_split_counts_lo = 0;
478
479    uniform global InternalNode* node_array = get_node( bvh_base, 0 );
480
481    LOOP_TRIPWIRE_INIT;
482
483    while (1)
484    {
485        uniform global InternalNode* node = node_array + node_offs;
486
487        // count this split
488        uniform uint bin = get_rebraid_bin_index(quantized_area, NUM_REBRAID_BINS);
489        varying uint lane = get_sub_group_local_id();
490
491        sg_split_counts_hi += ((lane + 16) == bin) ? 1 : 0;
492        sg_split_counts_lo += (lane == bin) ? 1 : 0;
493
494        // open this node and push all of its openable children to heap
495        varying uint sg_offs = node_offs + SUBGROUP_get_child_offsets(node);
496        varying bool sg_openable = false;
497        if ( (lane < NUM_CHILDREN) && (sg_offs <= MAX_NODE_OFFSET) )
498            if (InternalNode_IsChildValid(node, lane))
499                sg_openable = is_node_openable( node_array + sg_offs);
500
501        uniform uint openable_children = intel_sub_group_ballot(sg_openable);
502
503        if ( openable_children )
504        {
505            varying uint16_t sg_area = SUBGROUP_get_child_areas( node, instance->Transform, relative_area_scale );
506
507            if ( !SGHeap_full( &heap ) )
508            {
509                openable_children = SGHeap_vectorized_push( &heap, sg_area, sg_offs, openable_children );
510            }
511
512            while ( openable_children )
513            {
514                // pop min element
515                uniform uint16_t min_area;
516                uniform uint16_t min_offs;
517                SGHeap_full_pop_min( &heap, &min_area, &min_offs );
518
519                // eliminate all children smaller than heap minimum
520                openable_children &= intel_sub_group_ballot( sg_area > min_area );
521
522                if ( openable_children )
523                {
524                    // if any children survived,
525                    // kick out heap minimum and replace with first child.. otherwise we will re-push the minimum
526                    uniform uint child_id = ctz( openable_children );
527                    openable_children ^= (1 << child_id);
528                    min_area = sub_group_broadcast( sg_area, child_id );
529                    min_offs = sub_group_broadcast( sg_offs, child_id );
530                }
531
532                // re-insert onto heap
533                SGHeap_push_and_fill( &heap, min_area, min_offs );
534
535                // repeat until all children are accounted for.  It is possible
536                //  for multiple children to fit in the heap, because heap minimum is now changed and we need to recompute it
537            }
538        }
539
540        num_splits++;
541        if (num_splits == MAX_SPLITS_PER_INSTANCE)
542            break;
543
544        if (SGHeap_empty(&heap))
545            break;
546
547        // get next node from heap
548        SGHeap_pop_max(&heap, &quantized_area, &node_offs);
549
550        LOOP_TRIPWIRE_INCREMENT( 500, "rebraid_count_splits" );
551
552    }
553
554    return (uint2)(sg_split_counts_lo, sg_split_counts_hi);
555}
556
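// Layout of the rebraid scratch buffer:  NUM_REBRAID_BINS global split counts,
// then NUM_REBRAID_BINS global instance counts, then NUM_REBRAID_BINS per-bin
// counts for each instance.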
557typedef struct RebraidBuffers
558{
559    global uint *bin_split_counts;    // [num_bins]
560    global uint *bin_instance_counts; // [num_bins]
561    global uint *instance_bin_counts; // num_instances * num_bins
562} RebraidBuffers;
563
564GRL_INLINE RebraidBuffers cast_rebraid_buffers(global uint *scratch, uint instanceID)
565{
566    RebraidBuffers b;
567    b.bin_split_counts = scratch;
568    b.bin_instance_counts = scratch + NUM_REBRAID_BINS;
569    b.instance_bin_counts = scratch + (2 + instanceID) * NUM_REBRAID_BINS;
570    return b;
571}
572
573///////////////////////////////////////////////////////////////////////////////////////////
574//                            Compute AABB
575//                  Dispatch one work item per instance
576///////////////////////////////////////////////////////////////////////////////////////////
577
578GRL_INLINE void rebraid_compute_AABB(
579                          global struct BVHBase* bvh,
580                          global __const struct GRL_RAYTRACING_INSTANCE_DESC *instance)
581{
582    // don't open null rtas
583    global BVHBase *bvh_base = (global BVHBase *)instance->AccelerationStructure;
584
585    struct AABB new_primref;
586    if (bvh_base != 0)
587    {
588        float3 root_lower = AABB3f_load_lower(&bvh_base->Meta.bounds);
589        float3 root_upper = AABB3f_load_upper(&bvh_base->Meta.bounds);
590        const float *Transform = instance->Transform;
591
592        if (is_aabb_valid(root_lower, root_upper))
593        {
594            new_primref = AABBfromAABB3f(transform_aabb(root_lower, root_upper, Transform));
595        }
596        else
597        {
598            // degenerate instance which might be updated to be non-degenerate
599            //  use AABB position to guide BVH construction
600            //
601            new_primref.lower.x = Transform[3];
602            new_primref.lower.y = Transform[7];
603            new_primref.lower.z = Transform[11];
604            new_primref.upper = new_primref.lower;
605        }
606    }
607    else
608    {
609        AABB_init(&new_primref);
610    }
611
612    struct AABB subgroup_bbox = AABB_sub_group_reduce(&new_primref);
613
614    if (get_sub_group_local_id() == 0)
615    {
616        AABB3f_atomic_merge_global_lu(&bvh->Meta.bounds, subgroup_bbox.lower.xyz, subgroup_bbox.upper.xyz );
617    }
618}
619
620GRL_ANNOTATE_IGC_DO_NOT_SPILL
621__attribute__((reqd_work_group_size(1, 1, 1)))
622__attribute__((intel_reqd_sub_group_size(16))) void kernel
623rebraid_computeAABB_DXR_instances(
624    global struct BVHBase* bvh,
625    global __const struct GRL_RAYTRACING_INSTANCE_DESC *instances)
626{
627    const uint instanceID = get_local_id(0) + get_group_id(0)*get_local_size(0);
628    rebraid_compute_AABB(bvh, instances + instanceID);
629}
630
631GRL_ANNOTATE_IGC_DO_NOT_SPILL
632__attribute__((reqd_work_group_size(1, 1, 1)))
633__attribute__((intel_reqd_sub_group_size(16))) void kernel
634rebraid_computeAABB_DXR_instances_indirect(
635    global struct BVHBase* bvh,
636    global __const struct GRL_RAYTRACING_INSTANCE_DESC *instances,
637    global struct IndirectBuildRangeInfo const * const indirect_data)
638{
639    const uint instanceID = get_local_id(0) + get_group_id(0)*get_local_size(0);
640    instances = (global __const struct GRL_RAYTRACING_INSTANCE_DESC*)
641        (((global char*)instances) + indirect_data->primitiveOffset);
642    rebraid_compute_AABB(bvh, instances + instanceID);
643}
644
645GRL_ANNOTATE_IGC_DO_NOT_SPILL
646__attribute__((reqd_work_group_size(1, 1, 1)))
647__attribute__((intel_reqd_sub_group_size(16))) void kernel
648rebraid_computeAABB_DXR_instances_pointers(
649    global struct BVHBase* bvh,
650    global void *instances_in)
651{
652    global const struct GRL_RAYTRACING_INSTANCE_DESC **instances =
653        (global const struct GRL_RAYTRACING_INSTANCE_DESC **)instances_in;
654
655    const uint instanceID = get_local_id(0) + get_group_id(0)*get_local_size(0);
656    rebraid_compute_AABB(bvh, instances[instanceID]);
657}
658
659GRL_ANNOTATE_IGC_DO_NOT_SPILL
660__attribute__((reqd_work_group_size(1, 1, 1)))
661__attribute__((intel_reqd_sub_group_size(16))) void kernel
662rebraid_computeAABB_DXR_instances_pointers_indirect(
663    global struct BVHBase* bvh,
664    global void *instances_in,
665    global struct IndirectBuildRangeInfo const * const indirect_data)
666{
667    instances_in = ((global char*)instances_in) + indirect_data->primitiveOffset;
668    global const struct GRL_RAYTRACING_INSTANCE_DESC **instances =
669        (global const struct GRL_RAYTRACING_INSTANCE_DESC **)instances_in;
670
671    const uint instanceID = get_local_id(0) + get_group_id(0)*get_local_size(0);
672    rebraid_compute_AABB(bvh, instances[instanceID]);
673}
674
675///////////////////////////////////////////////////////////////////////////////////////////
676//                            Init scratch:  Dispatch one work group
677///////////////////////////////////////////////////////////////////////////////////////////
678
679GRL_ANNOTATE_IGC_DO_NOT_SPILL
680__attribute__((reqd_work_group_size(64, 1, 1))) void kernel rebraid_init_scratch(global uint *scratch)
681{
682    scratch[get_local_id(0) + get_group_id(0)*get_local_size(0)] = 0;
683}
684
685GRL_ANNOTATE_IGC_DO_NOT_SPILL
686__attribute__((reqd_work_group_size(1, 1, 1))) void kernel rebraid_chase_instance_pointers(global struct GRL_RAYTRACING_INSTANCE_DESC *instances_out,
687                                                                                           global void *instance_buff)
688{
689    global const struct GRL_RAYTRACING_INSTANCE_DESC **instances_in =
690        (global const struct GRL_RAYTRACING_INSTANCE_DESC **)instance_buff;
691
692    instances_out[get_local_id(0)] = *instances_in[get_local_id(0)];
693}
694
695GRL_ANNOTATE_IGC_DO_NOT_SPILL
696__attribute__((reqd_work_group_size(1, 1, 1)))
697void kernel rebraid_chase_instance_pointers_indirect(
698    global struct GRL_RAYTRACING_INSTANCE_DESC*       instances_out,
699    global void*                                      instance_buff,
700    global struct IndirectBuildRangeInfo const* const indirect_data)
701{
702    instance_buff = ((global char*)instance_buff) + indirect_data->primitiveOffset;
703    global const struct GRL_RAYTRACING_INSTANCE_DESC**
704        instances_in = (global const struct GRL_RAYTRACING_INSTANCE_DESC**)instance_buff;
705
706    instances_out[get_local_id(0)] = *instances_in[get_local_id(0)];
707}
708
709///////////////////////////////////////////////////////////////////////////////////////////
710//                             Count splits
711///////////////////////////////////////////////////////////////////////////////////////////
712
713GRL_INLINE void DEBUG_SUBGROUP_print_split_counts( uniform uint instanceID, varying uint split_counts_lo, varying uint split_counts_hi )
714{
715    uniform uint vals[32] = {
716       sub_group_broadcast( split_counts_lo, 0 ),  sub_group_broadcast( split_counts_lo, 1 ),
717       sub_group_broadcast( split_counts_lo, 2 ),  sub_group_broadcast( split_counts_lo, 3 ),
718       sub_group_broadcast( split_counts_lo, 4 ),  sub_group_broadcast( split_counts_lo, 5 ),
719       sub_group_broadcast( split_counts_lo, 6 ),  sub_group_broadcast( split_counts_lo, 7 ),
720       sub_group_broadcast( split_counts_lo, 8 ),  sub_group_broadcast( split_counts_lo, 9 ),
721       sub_group_broadcast( split_counts_lo, 10 ), sub_group_broadcast( split_counts_lo, 11 ),
722       sub_group_broadcast( split_counts_lo, 12 ), sub_group_broadcast( split_counts_lo, 13 ),
723       sub_group_broadcast( split_counts_lo, 14 ), sub_group_broadcast( split_counts_lo, 15 ),
724
725       sub_group_broadcast( split_counts_hi, 0 ),  sub_group_broadcast( split_counts_hi, 1 ),
726       sub_group_broadcast( split_counts_hi, 2 ),  sub_group_broadcast( split_counts_hi, 3 ),
727       sub_group_broadcast( split_counts_hi, 4 ),  sub_group_broadcast( split_counts_hi, 5 ),
728       sub_group_broadcast( split_counts_hi, 6 ),  sub_group_broadcast( split_counts_hi, 7 ),
729       sub_group_broadcast( split_counts_hi, 8 ),  sub_group_broadcast( split_counts_hi, 9 ),
730       sub_group_broadcast( split_counts_hi, 10 ), sub_group_broadcast( split_counts_hi, 11 ),
731       sub_group_broadcast( split_counts_hi, 12 ), sub_group_broadcast( split_counts_hi, 13 ),
732       sub_group_broadcast( split_counts_hi, 14 ), sub_group_broadcast( split_counts_hi, 15 )
733    };
734
735    if ( get_sub_group_local_id() == 0 )
736    {
737        printf(
738            "Instance: %4u  "
739            "%2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u "
740            "%2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u %2u \n"
741            ,
742            instanceID,
743            vals[0], vals[1], vals[2], vals[3], vals[4], vals[5], vals[6], vals[7],
744            vals[8], vals[9], vals[10], vals[11], vals[12], vals[13], vals[14], vals[15],
745            vals[16], vals[17], vals[18], vals[19], vals[20], vals[21], vals[22], vals[23],
746            vals[24], vals[25], vals[26], vals[27], vals[28], vals[29], vals[30], vals[31]
747        );
748    }
749}
750
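// One subgroup per instance:  count the instance's desired splits per bin,
// block-write them into the instance's slot in scratch, then atomically add them
// to the global per-bin split totals and per-bin instance counts.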
751GRL_INLINE void do_rebraid_count_splits_SG(
752    uniform global struct BVHBase* bvh,
753    uniform global __const struct GRL_RAYTRACING_INSTANCE_DESC *instances,
754    uniform global uint *rebraid_scratch)
755{
756    uniform const uint instanceID = get_sub_group_global_id();
757    uniform RebraidBuffers buffers = cast_rebraid_buffers(rebraid_scratch,instanceID);
758
759    varying uint lane = get_sub_group_local_id();
760    varying uint2 splits = SUBGROUP_count_instance_splits(&bvh->Meta.bounds, instances + instanceID);
761    varying uint split_counts_lo = splits.x;
762    varying uint split_counts_hi = splits.y;
763
764    // write this instance's per-bin counts
765    global uint* counts = buffers.instance_bin_counts;
766    intel_sub_group_block_write2( counts, splits );
767
768    // update the per-bin split and instance counters
769    if (split_counts_lo > 0)
770    {
771        atomic_add(&buffers.bin_split_counts[lane], split_counts_lo);
772        GRL_ATOMIC_INC(&buffers.bin_instance_counts[lane]);
773    }
774    if (split_counts_hi > 0)
775    {
776        atomic_add(&buffers.bin_split_counts[lane + 16], split_counts_hi);
777        GRL_ATOMIC_INC(&buffers.bin_instance_counts[lane + 16]);
778    }
779}
780
781GRL_ANNOTATE_IGC_DO_NOT_SPILL
782__attribute__((reqd_work_group_size(16, 1, 1)))
783__attribute__((intel_reqd_sub_group_size(16))) void kernel
784rebraid_count_splits_SG(
785    uniform global struct BVHBase* bvh,
786    uniform global __const struct GRL_RAYTRACING_INSTANCE_DESC *instances,
787    uniform global uint *rebraid_scratch)
788{
789    do_rebraid_count_splits_SG(bvh, instances, rebraid_scratch);
790}
791
792GRL_ANNOTATE_IGC_DO_NOT_SPILL
793__attribute__((reqd_work_group_size(16, 1, 1)))
794__attribute__((intel_reqd_sub_group_size(16))) void kernel
795rebraid_count_splits_SG_indirect(
796    uniform global struct BVHBase* bvh,
797    uniform global __const struct GRL_RAYTRACING_INSTANCE_DESC *instances,
798    uniform global uint *rebraid_scratch,
799    global struct IndirectBuildRangeInfo const * const indirect_data)
800{
801    instances = (global __const struct GRL_RAYTRACING_INSTANCE_DESC*)
802        (((global char*)instances) + indirect_data->primitiveOffset);
803    do_rebraid_count_splits_SG(bvh, instances, rebraid_scratch);
804}
805
806
807#define HEAP_SIZE 16
808#define COUNT_SPLITS_WG_SIZE 16
809
810struct SLMHeapNode
811{
812    short offs;
813    ushort area;
814};
815
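// Fixed-capacity max-heap in local memory holding (quantized area, node offset)
// pairs in level order.  min_key caches the smallest resident key so that, once
// the heap is full, pushes that cannot displace anything are rejected cheaply.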
816struct SLMHeap
817{
818    struct SLMHeapNode nodes[HEAP_SIZE];
819    ushort size;
820    ushort min_key;
821};
822
823GRL_INLINE bool SLMHeapNode_Greater( struct SLMHeapNode a, struct SLMHeapNode b )
824{
825    return a.area > b.area;
826}
827
828GRL_INLINE ushort SLMHeapNode_UnpackKey( struct SLMHeapNode a )
829{
830    return a.area;
831}
832
833GRL_INLINE void SLMHeapNode_Unpack( struct SLMHeapNode a, ushort* area_out, short* offs_out )
834{
835    *area_out = a.area;
836    *offs_out = a.offs;
837}
838
839GRL_INLINE struct SLMHeapNode SLMHeapNode_Pack( ushort area, short offs )
840{
841    struct SLMHeapNode n;
842    n.offs = offs;
843    n.area = area;
844    return n;
845}
846
847
848GRL_INLINE void SLMHeap_Init( struct SLMHeap* heap )
849{
850    heap->size = 0;
851    heap->min_key = 0xffff;
852}
853
854GRL_INLINE bool SLMHeap_empty( struct SLMHeap* heap )
855{
856    return heap->size == 0;
857}
858
859GRL_INLINE bool SLMHeap_full( struct SLMHeap* heap )
860{
861    return heap->size == HEAP_SIZE;
862}
863
864
865GRL_INLINE void SLMHeap_push( struct SLMHeap* heap, ushort area, short offs )
866{
867    ushort insert_pos;
868    if ( SLMHeap_full( heap ) )
869    {
870        ushort current_min_key = heap->min_key;
871        if ( area <= current_min_key )
872            return; // don't push stuff that's smaller than the current minimum
873
874        // search for the minimum element
875        //  The heap is laid out in level order, so it is sufficient to search only the last half
876        ushort last_leaf = HEAP_SIZE - 1;
877        ushort first_leaf = (last_leaf / 2) + 1;
878
879        // as we search, keep track of what the new min-key will be so we can cull future pushes
880        ushort new_min_key = area;
881        ushort min_pos = 0;
882
883        do
884        {
885            ushort idx = first_leaf++;
886
887            ushort current_key = SLMHeapNode_UnpackKey( heap->nodes[idx] );
888            bool found_min_pos = (min_pos == 0) && (current_key == current_min_key);
889
890            if ( found_min_pos )
891                min_pos = idx;
892            else
893                new_min_key = min( current_key, new_min_key );
894
895        } while ( first_leaf != last_leaf );
896
897        heap->min_key = new_min_key;
898        insert_pos = min_pos;
899    }
900    else
901    {
902        insert_pos = heap->size++;
903        heap->min_key = min( area, heap->min_key );
904    }
905
906    heap->nodes[insert_pos] = SLMHeapNode_Pack( area, offs );
907
908    // heap-up
909    while ( insert_pos )
910    {
911        ushort parent = insert_pos / 2;
912
913        struct SLMHeapNode parent_node = heap->nodes[parent];
914        struct SLMHeapNode current_node = heap->nodes[insert_pos];
915        if ( SLMHeapNode_Greater( parent_node, current_node ) )
916            break;
917
918        heap->nodes[insert_pos]    = parent_node;
919        heap->nodes[parent]        = current_node;
920        insert_pos = parent;
921    }
922
923}
924
925bool SLMHeap_pop_max( struct SLMHeap* heap, ushort* area_out, short* offs_out )
926{
927    if ( SLMHeap_empty( heap ) )
928        return false;
929
930    SLMHeapNode_Unpack( heap->nodes[0], area_out, offs_out );
931
932    // heap down
933    ushort size = heap->size;
934    ushort idx = 0;
935    do
936    {
937        ushort left = 2 * idx + 1;
938        ushort right = 2 * idx + 2;
939        if ( left >= size )
940            break;
941
942        if ( right >= size )
943        {
944            heap->nodes[idx] = heap->nodes[left];
945            break;
946        }
947
948        struct SLMHeapNode left_node = heap->nodes[left];
949        struct SLMHeapNode right_node = heap->nodes[right];
950        bool go_left = SLMHeapNode_Greater( left_node, right_node );
951        heap->nodes[idx] = go_left ? left_node : right_node;
952        idx = go_left ? left : right;
953
954    } while ( 1 );
955
956    heap->size = size - 1;
957    return true;
958}
959
960void SLMHeap_Print( struct SLMHeap* heap )
961{
962    printf( " size=%u min=%u {", heap->size, heap->min_key );
963    for ( uint i = 0; i < heap->size; i++ )
964        printf( "%04x:%04x ", heap->nodes[i].area, heap->nodes[i].offs );
    printf( "}\n" );
965}
966
967
968GRL_INLINE bool can_open_root(
969    global struct BVHBase* bvh_base,
970    const struct GRL_RAYTRACING_INSTANCE_DESC* instance
971    )
972{
973    float3 root_lower = AABB3f_load_lower( &bvh_base->Meta.bounds );
974    float3 root_upper = AABB3f_load_upper( &bvh_base->Meta.bounds );
975    if ( !is_aabb_valid( root_lower, root_upper ) || GRL_get_InstanceMask(instance) == 0 )
976        return false;
977
978    global InternalNode* node = get_node( bvh_base, 0 );
979    if ( node->nodeType != NODE_TYPE_INTERNAL )
980        return false;
981
982    return is_node_openable( node );
983}
984
985
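// Scalar counterpart of SUBGROUP_count_instance_splits:  a single work-item walks
// the BLAS with an SLM heap, bumping the local per-bin split counter for every
// node it opens, until MAX_SPLITS_PER_INSTANCE is reached or the heap runs dry.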
986GRL_INLINE void count_instance_splits(
987    global struct AABB3f* geometry_bounds,
988    global __const struct GRL_RAYTRACING_INSTANCE_DESC* instance,
989    local ushort* bin_split_counts,
990    local struct SLMHeap* heap
991)
992{
993    global BVHBase* bvh_base = (global BVHBase*)instance->AccelerationStructure;
994
995    SLMHeap_Init( heap );
996
997    float relative_area_scale = 1.0f / AABB3f_halfArea( geometry_bounds );
998    float3 root_lower = AABB3f_load_lower( &bvh_base->Meta.bounds );
999    float3 root_upper = AABB3f_load_upper( &bvh_base->Meta.bounds );
1000
1001    ushort quantized_area = quantize_area( transformed_aabb_halfArea( root_lower, root_upper, instance->Transform ) * relative_area_scale );
1002    short  node_offs = 0;
1003    ushort num_splits = 0;
1004
1005    global InternalNode* node_array = get_node( bvh_base, 0 );
1006
1007    while ( 1 )
1008    {
1009        global InternalNode* node = node_array + node_offs;
1010
1011        // count this split
1012        uint bin = get_rebraid_bin_index( quantized_area, NUM_REBRAID_BINS );
1013        bin_split_counts[bin]++;
1014
1015        // open this node and push children to heap
1016
1017        // TODO_OPT:  Restructure this control flow to prevent different lanes from skipping different loop iterations and diverging
1018        // TODO_OPT:  Precompute openability masks in BLAS nodes at build time... one bit for self and N bits for each child
1019        int offs = node_offs + node->childOffset; // child offsets are relative to the current node, as in the subgroup path
1020        for ( ushort i = 0; i < NUM_CHILDREN; i++ )
1021        {
1022            if ( InternalNode_IsChildValid( node, i ) )
1023            {
1024                if ( offs >= SHRT_MIN && offs <= SHRT_MAX )
1025                {
1026                    if ( is_node_openable( node_array + offs ) )
1027                    {
1028                        ushort area = get_child_area( node, i, instance->Transform, relative_area_scale );
1029                        SLMHeap_push( heap, area, (short)offs );
1030                    }
1031                }
1032            }
1033            offs += InternalNode_GetChildBlockIncr( node, i );
1034        }
1035
1036        num_splits++;
1037        if ( num_splits == MAX_SPLITS_PER_INSTANCE )
1038            break;
1039
1040        if ( !SLMHeap_pop_max( heap, &quantized_area, &node_offs ) )
1041            break;
1042    }
1043
1044}
1045
1046GRL_ANNOTATE_IGC_DO_NOT_SPILL
1047__attribute__( (reqd_work_group_size( COUNT_SPLITS_WG_SIZE, 1, 1 )) )
1048void kernel
1049rebraid_count_splits(
1050    global struct BVHBase* bvh_base,
1051    global __const struct GRL_RAYTRACING_INSTANCE_DESC* instances,
1052    global uint* rebraid_scratch,
1053    uint num_instances
1054    )
1055{
1056    local struct SLMHeap heap[COUNT_SPLITS_WG_SIZE];
1057    local ushort split_counts[COUNT_SPLITS_WG_SIZE][NUM_REBRAID_BINS];
1058
1059    // initialize stuff
1060    // TODO_OPT:  transpose this and subgroup-vectorize it so that
1061    //     block-writes can be used
1062    for ( uint i = 0; i < NUM_REBRAID_BINS; i++ )
1063        split_counts[get_local_id( 0 )][i] = 0;
1064
1065
1066    // count splits for this thread's instance
1067    uniform uint base_instance = get_group_id( 0 ) * get_local_size( 0 );
1068    uint instanceID = base_instance + get_local_id( 0 );
1069
1070    if ( instanceID < num_instances )
1071    {
1072        global BVHBase* blas_base = (global BVHBase*)instances[instanceID].AccelerationStructure;
1073        if ( can_open_root( blas_base, &instances[instanceID] ) )
1074        {
1075            count_instance_splits( &bvh_base->Meta.bounds, // scale areas to the scene bounds, matching the subgroup path
1076                &instances[instanceID],
1077                &split_counts[get_local_id( 0 )][0],
1078                &heap[get_local_id(0)] );
1079        }
1080    }
1081
1082    barrier( CLK_LOCAL_MEM_FENCE );
1083
1084    RebraidBuffers buffers = cast_rebraid_buffers( rebraid_scratch, instanceID );
1085
1086
1087    // reduce bins
1088    for ( uint bin = get_local_id( 0 ); bin < NUM_REBRAID_BINS; bin += get_local_size( 0 ) )
1089    {
1090        // TODO_OPT:  There's probably a better way to arrange this computation
1091        uint bin_split_count = 0;
1092        uint bin_instance_count = 0;
1093        for ( uint i = 0; i < COUNT_SPLITS_WG_SIZE; i++ )
1094        {
1095            uint s = split_counts[i][bin];
1096            bin_split_count     += s;
1097            bin_instance_count  += (s > 0) ? 1 : 0;
1098        }
1099
1100        if ( bin_split_count > 0 )
1101        {
1102            atomic_add( &buffers.bin_split_counts[bin], bin_split_count );
1103            atomic_add( &buffers.bin_instance_counts[bin], bin_instance_count );
1104        }
1105    }
1106
1107    // write out bin counts for each instance
1108    for ( uniform uint i = get_sub_group_id(); i < COUNT_SPLITS_WG_SIZE; i += get_num_sub_groups() )
1109    {
1110        uniform uint iid = base_instance + i;
1111        if ( iid >= num_instances )
1112            break;
1113
1114        global uint* instance_bin_counts = cast_rebraid_buffers( rebraid_scratch, iid ).instance_bin_counts;
1115
1116        for ( uniform ushort j = 0; j < NUM_REBRAID_BINS; j += get_sub_group_size() )
1117        {
1118            uint count = split_counts[i][j + get_sub_group_local_id() ];
1119            intel_sub_group_block_write( instance_bin_counts + j, count );
1120        }
1121    }
1122
1123}
1124
1125
1126
1127
1128///////////////////////////////////////////////////////////////////////////////////////////
1129//                             Build PrimRefs
1130///////////////////////////////////////////////////////////////////////////////////////////
1131
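// Decide how many of the globally available splits this instance receives.
// Bins are consumed from largest area to smallest; any bin whose total demand fits
// in the remaining budget is granted in full, and the first bin that does not fit
// has the leftover budget divided evenly across the instances requesting splits
// there (the remainder going to lower-numbered instances).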
1132GRL_INLINE uint get_instance_split_count(RebraidBuffers buffers, uint instanceID, uint available_splits)
1133{
1134    global uint* instance_desired_split_count = buffers.instance_bin_counts;
1135    global uint *bin_split_counts = buffers.bin_split_counts;
1136    global uint *bin_instance_counts = buffers.bin_instance_counts;
1137
1138    uint total_splits = 0;
1139    uint remaining_available_splits = available_splits;
1140    uint max_bin = 0;
1141    uint desired_splits_this_bin = 0;
1142    uint instance_splits = 0;
1143
1144    do
1145    {
1146        // stop when we reach a level where we can't satisfy the demand
1147        desired_splits_this_bin = instance_desired_split_count[max_bin];
1148        uint total_bin_splits = bin_split_counts[max_bin];
1149
1150        if (total_bin_splits > remaining_available_splits)
1151            break;
1152
1153        // we have enough budget to give all instances everything they want at this level, so do it
1154        remaining_available_splits -= total_bin_splits;
1155        instance_splits += desired_splits_this_bin;
1156        desired_splits_this_bin = 0;
1157        max_bin++;
1158
1159    } while (max_bin < NUM_REBRAID_BINS);
1160
1161    if (max_bin < NUM_REBRAID_BINS)
1162    {
1163        // we have more split demand than we have splits available.  The current bin is the last one that gets any splits
1164        //   distribute the leftovers as evenly as possible to instances that want them
1165        if (desired_splits_this_bin > 0)
1166        {
1167            // this instance wants splits.  how many does it want?
1168            uint desired_total = instance_splits + desired_splits_this_bin;
1169
1170            // distribute to all instances as many as possible
1171            uint count = bin_instance_counts[max_bin];
1172            uint whole = remaining_available_splits / count;
1173            remaining_available_splits -= whole * count;
1174
1175            // distribute remainder to lower numbered instances
1176            size_t partial = (instanceID < remaining_available_splits) ? 1 : 0;
1177
1178            // give the instance its share.
1179            instance_splits += whole + partial;
1180            instance_splits = min(instance_splits, desired_total); // don't give it more than it needs
1181        }
1182    }
1183
1184    return instance_splits;
1185}
1186
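// Emit a single primref covering the instance's whole BLAS root AABB in world
// space.  An invalid root AABB or a zero instance mask produces a point primref
// at the instance translation with mask 0 and NO_NODE_OFFSET.  Also seeds
// 'centroid_bounds' with the primref's centroid.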
1187GRL_INLINE void build_unopened_primref(
1188    struct AABB3f* centroid_bounds,
1189    global __const BVHBase *bvh_base,
1190    global volatile uint *primref_counter,
1191    global struct AABB *primref_buffer,
1192    global __const float *Transform,
1193    uint instanceID,
1194    float matOverhead,
1195    ushort instanceMask)
1196{
1197    float3 root_lower = AABB3f_load_lower(&bvh_base->Meta.bounds);
1198    float3 root_upper = AABB3f_load_upper(&bvh_base->Meta.bounds);
1199
1200    struct AABB primRef;
1201    AABB_init( &primRef );
1202
1203    uint bvhoffset = (uint)BVH_ROOT_NODE_OFFSET;
1204    if (is_aabb_valid(root_lower, root_upper) && instanceMask != 0)
1205    {
1206        primRef = AABBfromAABB3f(compute_xfm_bbox(Transform, BVHBase_GetRootNode(bvh_base), XFM_BOX_NOT_REFINED_TAKE_CLIPBOX, &bvh_base->Meta.bounds, matOverhead));
1207    }
1208    else
1209    {
1210        primRef.lower.x = Transform[3];
1211        primRef.lower.y = Transform[7];
1212        primRef.lower.z = Transform[11];
1213        primRef.upper.xyz = primRef.lower.xyz;
1214
1215        instanceMask = 0;
1216        bvhoffset = NO_NODE_OFFSET;
1217    }
1218
1219    primRef.lower.w = as_float(instanceID | (instanceMask << 24));
1220    primRef.upper.w = as_float(bvhoffset);
1221
1222    float3 centroid = primRef.lower.xyz + primRef.upper.xyz;
1223    centroid_bounds->lower[0] = centroid.x;
1224    centroid_bounds->upper[0] = centroid.x;
1225    centroid_bounds->lower[1] = centroid.y;
1226    centroid_bounds->upper[1] = centroid.y;
1227    centroid_bounds->lower[2] = centroid.z;
1228    centroid_bounds->upper[2] = centroid.z;
1229
1230    uint place = GRL_ATOMIC_INC(primref_counter);
1231    primref_buffer[place] = primRef;
1232}
1233
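// For every lane with lane_mask set, emit a primref referencing the opened child
// node ('offset' is stored in upper.w as a byte offset within the BLAS) and extend
// the per-lane centroid bounds.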
1234GRL_INLINE void build_opened_primrefs(
1235    varying bool lane_mask,
1236    varying uint offset,
1237    varying InternalNode* node,
1238    varying struct AABB3f* centroid_bounds,
1239    uniform global BVHBase *bvh_base,
1240    uniform volatile global uint *primref_counter,
1241    uniform global struct AABB *primref_buffer,
1242    uniform uint instanceID,
1243    uniform const float *Transform,
1244    uniform float matOverhead,
1245    varying ushort instanceMask)
1246{
1247    // TODO_OPT: This function is often called with <= 6 active lanes
1248    //  If lanes are sparse, consider jumping to a sub-group vectorized variant...
1249
1250    if (lane_mask)
1251    {
1252        varying uint place = GRL_ATOMIC_INC(primref_counter);
1253
1254        struct AABB box = AABBfromAABB3f(compute_xfm_bbox(Transform, node, XFM_BOX_NOT_REFINED_CLIPPED, &bvh_base->Meta.bounds, matOverhead));
1255
1256        box.lower.w = as_float(instanceID | (instanceMask << 24));
1257        box.upper.w = as_float(offset * 64 + (uint)BVH_ROOT_NODE_OFFSET);
1258        primref_buffer[place] = box;
1259
1260        AABB3f_extend_point( centroid_bounds, box.lower.xyz + box.upper.xyz );
1261    }
1262}
1263
1264
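// Re-walk one instance's BLAS with its granted split budget.  Children of each
// opened node are either pushed onto the heap for further opening or turned into
// primrefs immediately; when the budget runs out, primrefs are emitted for
// everything still left in the heap.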
1265GRL_INLINE void SUBGROUP_open_nodes(
1266    uniform global struct AABB3f *geometry_bounds,
1267    uniform uint split_limit,
1268    uniform global __const struct GRL_RAYTRACING_INSTANCE_DESC *instance,
1269    uniform uint instanceID,
1270    uniform volatile global uint *primref_counter,
1271    uniform global struct AABB *primref_buffer,
1272    varying struct AABB3f* centroid_bounds,
1273    float transformOverhead)
1274{
1275    uniform SGHeap heap;
1276    SGHeap_init(&heap);
1277
1278    uniform float relative_area_scale = 1.0f / AABB3f_halfArea(geometry_bounds);
1279    uniform global BVHBase *bvh_base = (global BVHBase *)instance->AccelerationStructure;
1280
1281    uniform uint16_t node_offs = 0;
1282    varying uint lane = get_sub_group_local_id();
1283
1284    uniform InternalNode* node_array = get_node( bvh_base, 0 );
1285
1286    LOOP_TRIPWIRE_INIT;
1287
1288    while ( 1 )
1289    {
1290        uniform InternalNode *node = node_array + node_offs;
1291
1292        varying uint sg_offs   = node_offs + SUBGROUP_get_child_offsets(node);
1293        varying bool sg_valid = false;
1294        varying bool sg_openable = false;
1295        if (lane < NUM_CHILDREN)
1296        {
1297            sg_valid = InternalNode_IsChildValid(node, lane);
1298            if (sg_valid && (sg_offs <= MAX_NODE_OFFSET))
1299            {
1300                sg_openable = is_node_openable( node_array + sg_offs);
1301            }
1302        }
1303
1304        uniform uint16_t valid_children = intel_sub_group_ballot(sg_valid);
1305        uniform uint16_t openable_children = intel_sub_group_ballot(sg_openable);
1306        uniform uint16_t unopenable_children = valid_children & (~openable_children);
1307
1308        if ( openable_children )
1309        {
1310            varying uint16_t sg_area = SUBGROUP_get_child_areas( node, instance->Transform, relative_area_scale );
1311
1312            // try to push all openable children to the heap
1313            if ( !SGHeap_full( &heap ) )
1314            {
1315                openable_children = SGHeap_vectorized_push( &heap, sg_area, sg_offs, openable_children );
1316            }
1317
1318            // we have more openable children than will fit in the heap
1319            //  process these one by one.
1320            //          TODO: Try re-writing with sub_group_any() and see if compiler does a better job
1321            while ( openable_children )
1322            {
1323                // pop min element
1324                uniform uint16_t min_area;
1325                uniform uint16_t min_offs;
1326                SGHeap_full_pop_min( &heap, &min_area, &min_offs );
1327
1328                // eliminate all children smaller than heap minimum.
1329                // mark eliminated children as unopenable
1330                varying uint culled_children = openable_children & intel_sub_group_ballot( sg_area <= min_area );
1331                unopenable_children ^= culled_children;
1332                openable_children &= ~culled_children;
1333
1334                if ( openable_children )
1335                {
1336                    // if any children survived the purge
1337                    //   find the first such child and swap its offset for the one from the heap
1338                    //
1339                    uniform uint child_id = ctz( openable_children );
1340                    uniform uint16_t old_min_offs = min_offs;
1341                    min_area = sub_group_broadcast( sg_area, child_id );
1342                    min_offs = sub_group_broadcast( sg_offs, child_id );
1343
1344                    if ( lane == child_id )
1345                        sg_offs = old_min_offs;
1346
1347                    openable_children ^= (1 << child_id);
1348                    unopenable_children ^= (1 << child_id);
1349                }
1350
1351                SGHeap_push_and_fill( &heap, min_area, min_offs );
1352
1353            }
1354        }
1355
1356        if (unopenable_children)
1357        {
1358            varying bool sg_create_primref = ((1 << lane) & unopenable_children);
1359            build_opened_primrefs(sg_create_primref, sg_offs, node_array + sg_offs, centroid_bounds, bvh_base, primref_counter, primref_buffer, instanceID, instance->Transform, transformOverhead, GRL_get_InstanceMask(instance));
1360        }
1361
1362        --split_limit;
1363        if (split_limit == 0)
1364        {
1365            // split limit exceeded
1366            //  create primrefs for all remaining openable nodes in heap
1367            varying bool sg_mask = SGHeap_get_lane_mask(&heap);
1368            sg_offs = SGHeap_get_lane_values(&heap);
1369            build_opened_primrefs(sg_mask, sg_offs, node_array + sg_offs, centroid_bounds, bvh_base, primref_counter, primref_buffer, instanceID, instance->Transform, transformOverhead, GRL_get_InstanceMask(instance));
1370
1371            break;
1372        }
1373
1374
1375        // NOTE: the heap should never be empty.  If it is, the instance was given too many splits.
1376
1377        // get next node from heap
1378        uint16_t quantized_area;
1379        SGHeap_pop_max(&heap, &quantized_area, &node_offs);
1380
1381        LOOP_TRIPWIRE_INCREMENT( 500, "rebraid_build_primrefs" );
1382
1383    }
1384}
1385
1386
1387#define OPEN_QUEUE_SIZE 256
1388#define OPEN_QUEUE_NUM_SGS 16
1389
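// SLM queue of instances that were granted a nonzero split budget.  Entries are
// produced in the first phase of do_rebraid_build_primrefs and consumed one per
// subgroup via an atomic ticket counter.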
1390typedef struct OpenQueueEntry
1391{
1392    uint instanceID;
1393    ushort num_splits;
1394} OpenQueueEntry;
1395
1396typedef struct OpenQueue
1397{
1398    uint num_produced;
1399    uint num_consumed;
1400    OpenQueueEntry Q[OPEN_QUEUE_SIZE];
1401} OpenQueue;
1402
1403uniform uint SUBGROUP_GetNextQueueEntry( local OpenQueue* queue )
1404{
1405    uint next = 0;
1406    if ( get_sub_group_local_id() == 0 )
1407        next = GRL_ATOMIC_INC( &queue->num_consumed );
1408    return sub_group_broadcast( next, 0 );
1409}
1410
1411
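// Phase 1: one lane per instance computes its split budget; unsplit instances emit
// their primref directly, split instances are appended to the SLM queue.
// Phase 2: subgroups drain the queue and open the queued instances.
// Finally the per-lane centroid bounds are reduced and merged into the globals.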
1412GRL_INLINE void do_rebraid_build_primrefs(
1413    local struct AABB3f* SLM_CentroidBounds,
1414    local OpenQueue* SLM_Q,
1415    global struct Globals* globals,
1416    global struct BVHBase* base,
1417    global __const struct GRL_RAYTRACING_INSTANCE_DESC* instance_buffer,
1418    global uint* rebraid_scratch,
1419    global struct AABB* primref_buffer,
1420    uint extra_primref_count,
1421    uint num_instances)
1422{
1423    varying uint instanceID = get_sub_group_size() * get_sub_group_global_id() + get_sub_group_local_id();
1424
1425    uniform volatile global uint* primref_counter = &globals->numPrimitives;
1426    uniform RebraidBuffers buffers = cast_rebraid_buffers( rebraid_scratch, instanceID );
1427    uniform uint available_splits = get_num_splits( extra_primref_count, NUM_CHILDREN );
1428
1429
1430
1431    varying struct AABB3f centroidBounds;
1432    AABB3f_init( &centroidBounds );
1433
1434    if ( get_local_id( 0 ) == 0 )
1435    {
1436        SLM_Q->num_produced = 0;
1437        SLM_Q->num_consumed = 0;
1438        AABB3f_init( SLM_CentroidBounds );
1439    }
1440
1441    barrier( CLK_LOCAL_MEM_FENCE );
1442
1443    // assign a split budget to each instance.  Instances that receive no splits build their primrefs here in vectorized form; instances with splits are queued for opening below
1444    varying uint num_splits = 0;
1445    if ( instanceID < num_instances )
1446    {
1447        num_splits = get_instance_split_count( buffers, instanceID, available_splits );
1448        if ( num_splits == 0 )
1449        {
1450            varying global const struct GRL_RAYTRACING_INSTANCE_DESC* instance = instance_buffer + instanceID;
1451            varying global BVHBase* bvh_base = (global BVHBase*)instance->AccelerationStructure;
1452            if ( bvh_base != 0 )
1453            {
1454                build_unopened_primref( &centroidBounds, bvh_base, primref_counter, primref_buffer, instance->Transform, instanceID, 0.0f, GRL_get_InstanceMask(instance));
1455            }
1456        }
1457        else
1458        {
1459            // defer instances that need opening; they are processed one per subgroup below
1460            uint place = GRL_ATOMIC_INC( &SLM_Q->num_produced );
1461            SLM_Q->Q[place].instanceID = instanceID;
1462            SLM_Q->Q[place].num_splits = (ushort)num_splits;
1463        }
1464    }
1465
1466    barrier( CLK_LOCAL_MEM_FENCE );
1467
1468    // Phase 2: consume the queue of instances deferred for opening, one instance per subgroup (num_produced is stable after the barrier above)
1469    uniform uint num_produced = SLM_Q->num_produced;
1470    uniform uint next = SUBGROUP_GetNextQueueEntry( SLM_Q );
1471
1472    while ( next < num_produced )
1473    {
1474        uniform uint instanceID = SLM_Q->Q[next].instanceID;
1475        uniform uint num_splits = SLM_Q->Q[next].num_splits;
1476
1477        uniform global const struct GRL_RAYTRACING_INSTANCE_DESC* instance = instance_buffer + instanceID;
1478
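        // surface-area overhead introduced by the instance transform; forwarded to
        // SUBGROUP_open_nodes so the opening/primref logic can account for it
        // (computed only when FINE_TRANSFORM_NODE_BOX is enabled, otherwise zero)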
1479        float transformOverhead =
1480#if FINE_TRANSFORM_NODE_BOX
1481            transformation_bbox_surf_overhead(instance->Transform);
1482#else
1483            0.0f;
1484#endif
1485
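        // open up to num_splits nodes of this instance's BLAS and emit primrefs for the
        // resulting children, accumulating into the per-lane centroid bounds along the way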
1486        SUBGROUP_open_nodes(
1487            &base->Meta.bounds,
1488            num_splits,
1489            instance,
1490            instanceID,
1491            primref_counter,
1492            primref_buffer,
1493            &centroidBounds,
1494            transformOverhead);
1495
1496        next = SUBGROUP_GetNextQueueEntry( SLM_Q );
1497    }
1498
1499    // reduce the per-lane centroid bounds across the subgroup, then merge them atomically into the workgroup's SLM bounds
1500    struct AABB3f reduced = AABB3f_sub_group_reduce( &centroidBounds );
1501    if ( get_sub_group_local_id() == 0 )
1502        AABB3f_atomic_merge_localBB_nocheck( SLM_CentroidBounds, &reduced );
1503
1504    barrier( CLK_LOCAL_MEM_FENCE );
1505
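    // one work-item per workgroup folds the SLM result into the global centroid bounds;
    // note this assumes min/max atomics on float operands are available on this target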
1506    if( get_local_id(0) == 0 )
1507    {
1508        atomic_min( (global float*) (&globals->centroidBounds.lower) + 0, SLM_CentroidBounds->lower[0] );
1509        atomic_min( (global float*) (&globals->centroidBounds.lower) + 1, SLM_CentroidBounds->lower[1] );
1510        atomic_min( (global float*) (&globals->centroidBounds.lower) + 2, SLM_CentroidBounds->lower[2] );
1511        atomic_max( (global float*) (&globals->centroidBounds.upper) + 0, SLM_CentroidBounds->upper[0] );
1512        atomic_max( (global float*) (&globals->centroidBounds.upper) + 1, SLM_CentroidBounds->upper[1] );
1513        atomic_max( (global float*) (&globals->centroidBounds.upper) + 2, SLM_CentroidBounds->upper[2] );
1514    }
1515}
1516
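// Entry point for the direct build path: sets up the SLM queue and centroid-bounds
// scratch, then runs the shared implementation above.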
1517GRL_ANNOTATE_IGC_DO_NOT_SPILL
1518__attribute__( (reqd_work_group_size( OPEN_QUEUE_SIZE, 1, 1 )) )
1519__attribute__( (intel_reqd_sub_group_size( 16 )) )
1520void kernel rebraid_build_primrefs(
1521    global struct Globals* globals,
1522    global struct BVHBase* base,
1523    global __const struct GRL_RAYTRACING_INSTANCE_DESC* instance_buffer,
1524    global uint* rebraid_scratch,
1525    global struct AABB* primref_buffer,
1526    uint extra_primref_count,
1527    uint num_instances)
1528{
1529    local struct AABB3f SLM_CentroidBounds;
1530    local OpenQueue SLM_Q;
1531    do_rebraid_build_primrefs(
1532        &SLM_CentroidBounds,
1533        &SLM_Q,
1534        globals,
1535        base,
1536        instance_buffer,
1537        rebraid_scratch,
1538        primref_buffer,
1539        extra_primref_count,
1540        num_instances);
1541}
1542
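// Indirect build path: the instance count and the byte offset into the instance buffer
// are read from IndirectBuildRangeInfo at kernel execution time rather than being
// baked into the dispatch.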
1543GRL_ANNOTATE_IGC_DO_NOT_SPILL
1544__attribute__( (reqd_work_group_size( OPEN_QUEUE_SIZE, 1, 1 )) )
1545__attribute__( (intel_reqd_sub_group_size( 16 )) )
1546void kernel rebraid_build_primrefs_indirect(
1547    global struct Globals* globals,
1548    global struct BVHBase* base,
1549    global __const struct GRL_RAYTRACING_INSTANCE_DESC* instance_buffer,
1550    global uint* rebraid_scratch,
1551    global struct AABB* primref_buffer,
1552    global struct IndirectBuildRangeInfo const * const indirect_data,
1553    uint extra_primref_count )
1554{
1555    local struct AABB3f SLM_CentroidBounds;
1556    local OpenQueue SLM_Q;
1557
1558    instance_buffer = (global __const struct GRL_RAYTRACING_INSTANCE_DESC*)
1559        (((global char*)instance_buffer) + indirect_data->primitiveOffset);
1560
1561    do_rebraid_build_primrefs(
1562        &SLM_CentroidBounds,
1563        &SLM_Q,
1564        globals,
1565        base,
1566        instance_buffer,
1567        rebraid_scratch,
1568        primref_buffer,
1569        extra_primref_count,
1570        indirect_data->primitiveCount);
1571}
1572
1573
1574///////////////////////////////////////////////////////////////////////////////////////////
1575//                             Misc
1576///////////////////////////////////////////////////////////////////////////////////////////
1577
1578
1579
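// Standalone ISA test kernel: each lane writes whether the corresponding child of the
// given internal node is valid.  The xform/scale parameters are currently unused.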
1580GRL_ANNOTATE_IGC_DO_NOT_SPILL
1581__attribute__((reqd_work_group_size(16, 1, 1)))
1582__attribute__((intel_reqd_sub_group_size(16))) void kernel
1583ISA_TEST(global InternalNode *n, global uint *out, global float *xform, float scale)
1584{
1585
1586    out[get_sub_group_local_id()] = InternalNode_IsChildValid(n, get_sub_group_local_id());
1587}
1588
1589GRL_ANNOTATE_IGC_DO_NOT_SPILL
1590__attribute__( (reqd_work_group_size( 1, 1, 1 )) ) void kernel
1591DEBUG_PRINT(
1592    global struct Globals* globals,
1593    global __const struct GRL_RAYTRACING_INSTANCE_DESC* instance_buffer,
1594    global uint* rebraid_scratch,
1595    global struct AABB* primref_buffer,
1596    dword num_extra,
1597    dword input_instances )
1598{
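    // All diagnostic paths below are compiled out via "#if 0"; switch an individual block
    // to "#if 1" to validate primrefs or dump per-instance rebraid bin counts when debugging.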
1599#if 0
1600    // validate primrefs
1601    if ( (get_local_id(0) + get_group_id(0)*get_local_size(0)) == 0 )
1602    {
1603        uint refs = globals->numPrimitives;
1604        for ( uint i = 0; i < refs; i++ )
1605        {
1606            if ( any( primref_buffer[i].lower.xyz < globals->geometryBounds.lower.xyz ) ||
1607                 any( primref_buffer[i].upper.xyz > globals->geometryBounds.upper.xyz ) ||
1608                 any( isnan(primref_buffer[i].lower.xyz) ) ||
1609                any( isnan(primref_buffer[i].upper.xyz) ) )
1610            {
1611                struct AABB box = primref_buffer[i];
1612                printf( "BAD BOX:      %u  {%f,%f,%f} {%f,%f,%f} %u\n", as_uint( box.lower.w ),
1613                    box.lower.x, box.lower.y, box.lower.z,
1614                    box.upper.x, box.upper.y, box.upper.z,
1615                    as_uint( box.lower.w ) );
1616            }
1617
1618            const uint instIndex = PRIMREF_instanceID(&primref_buffer[i]);    // TODO: Refactor me.  We should not be using struct AABB for primRefs
1619            const uint rootByteOffset = as_uint( primref_buffer[i].upper.w ); // It should be struct PrimRef
1620            if ( instIndex >= input_instances )
1621                printf( "BAD INSTANCE INDEX: %u\n", i );
1622            else
1623            {
1624                global struct BVHBase* blas = (global struct BVHBase*)instance_buffer[instIndex].AccelerationStructure;
1625                if ( blas )
1626                {
1627                    struct InternalNode* start = BVHBase_GetInternalNodes( blas );
1628                    struct InternalNode* end = BVHBase_GetInternalNodesEnd( blas );
1629
1630                    InternalNode* entryPoint = (struct InternalNode*)((char*)instance_buffer[instIndex].AccelerationStructure + rootByteOffset);
1631                    if ( entryPoint < start || entryPoint >= end )
1632                        printf( "BAD ENTRYPOINT: %u\n", i );
1633                    if ( (rootByteOffset & 63) != 0 )
1634                        printf( "MISALIGNED ENTRYPOINT: %u\n", i );
1635
1636                }
1637            }
1638        }
1639    }
1640#endif
1641#if 0
1642    if ( (get_local_id(0) + get_group_id(0)*get_local_size(0)) == 0 )
1643        printf( "REBRAIDED: %u\n", globals->numPrimitives );
1644
1645    // print instance bin information
1646    if ( (get_local_id(0) + get_group_id(0)*get_local_size(0)) == 0 )
1647    {
1648        printf( "REBRAIDED: %u\n", globals->numPrimitives );
1649        for( uint i=0; i<231; i++  )
1650        {
1651            RebraidBuffers buffers = cast_rebraid_buffers( rebraid_scratch,i );
1652            printf( " ID:%4u ", i );
1653            for ( uint j = 0; j < NUM_REBRAID_BINS; j++ )
1654            {
1655                global uint* count = buffers.instance_bin_counts;
1656                printf( " %2u ", count[j] );
1657            }
1658            printf( "\n" );
1659        }
1660    }
1661#endif
1662#if 0
1663    if ( (get_local_id(0) + get_group_id(0)*get_local_size(0)) == 0 )
1664    {
1665        printf( "Instances: %u\n", globals->numPrimitives );
1666
1667        for ( uint i = 0; i < globals->numPrimitives; i++ )
1668        {
1669            if ( any( primref_buffer[i].lower.xyz < globals->geometryBounds.lower.xyz ) ||
1670                 any( primref_buffer[i].upper.xyz > globals->geometryBounds.upper.xyz ) )
1671            {
1672                struct AABB box = primref_buffer[i];
1673                printf( "      %u  {%f,%f,%f} {%f,%f,%f} %u\n", as_uint( box.lower.w ),
1674                    box.lower.x, box.lower.y, box.lower.z,
1675                    box.upper.x, box.upper.y, box.upper.z,
1676                    as_uint( box.lower.w ) );
1677            }
1678
1679        }
1680    }
1681#endif
1682}
1683
1684