xref: /aosp_15_r20/external/executorch/backends/vulkan/test/glsl/indexing_utils.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
/*
 * Tensor dim indices. Dims are enumerated in (W, H, C, N) order, which is the
 * reverse of the usual (N, C, H, W) order.
 */
#define W_DIM 0 // width
#define H_DIM 1 // height
#define C_DIM 2 // channels

/*
 * Texture axis that the batch (N) dim of a 4D tensor is laid out along.
 *
 * Batches are stored by concatenating them along the channels dim. Channels is
 * index 2 in (W, H, C, N) order and maps onto the depth (z) axis of the
 * texture, which is likewise index 2 in (x, y, z) order; hence the value 2.
 */
#define BATCH_AXIS 2
26 
27 //
28 // Basic Indexing Utility Macros and Functions
29 //
30 
/*
 * Rounds the integer x up to the nearest multiple of 4. Values that are already
 * a multiple of 4 are returned unchanged.
 *
 * x is parenthesized so that arguments built from operators with lower
 * precedence than `+` (shifts, ternaries, bitwise ops) are aligned as a whole.
 * `& -4` clears the two low bits, which relies on two's-complement ints.
 */
#define alignup4(x) (((x) + 3) & -4)
35 
36 //
37 // (w, h, c, n) Tensor Index <-> Contiguous Buffer Index Conversion
38 //
39 
/*
 * Maps a (w, h, c, n) tensor index to the contiguous-buffer offsets of the four
 * elements held by the texel that starts at that index.
 *
 * Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of the tensor, and the
 *        dim that is packed along a texel.
 * Returns: an ivec4 whose i-th component is the buffer index of the i-th texel
 *          element, i.e. base + i * stride(packed_dim).
 *
 * Offsets assume a contiguous layout where W varies fastest, then H, then C,
 * then N (a row-major NCHW buffer).
 */
ivec4 get_texel_nchw_buffer_ixs(ivec4 idx, ivec4 sizes, int packed_dim) {
  // Contiguous strides for each of the (W, H, C, N) dims, built up cumulatively.
  ivec4 stride = ivec4(1);
  stride.y = sizes.x;
  stride.z = stride.y * sizes.y;
  stride.w = stride.z * sizes.z;

  // Buffer offset of the texel's first element.
  ivec4 weighted = idx * stride;
  int first = weighted.x + weighted.y + weighted.z + weighted.w;

  // Successive texel elements advance along the packed dim.
  int step = stride[packed_dim];
  return ivec4(first, first + step, first + 2 * step, first + 3 * step);
}
55 
56 //
57 // (w, h, c, n) Tensor Index <-> (x, y, z) Texture Position Conversion
58 //
59 
/*
 * Checks whether an (x, y, z) texel position lies outside the image texture
 * that backs a tensor.
 *
 * Input: (x, y, z) texel position, (W, H, C, N) sizes of the tensor, and the
 *        dim that is packed along a texel (4 elements per texel).
 * Returns: true if any component of pos is >= the texture extent along that
 *          axis, false otherwise.
 *
 * The texture extent is derived from the tensor sizes as follows:
 *   - The packed dim is padded up to a multiple of 4 and divided by 4.
 *   - Batches are concatenated along BATCH_AXIS, so that axis spans
 *     sizes[BATCH_AXIS] * N.
 * NOTE(review): assumes N >= 1 (lower-rank tensors padded with 1s); confirm at
 * call sites.
 */
bool pos_out_of_bounds(ivec3 pos, ivec4 sizes, int packed_dim) {
  // Align packed dim to next multiple of 4 to account for texel padding
  sizes[packed_dim] = alignup4(sizes[packed_dim]);

  ivec3 max_pos = sizes.xyz;
  // N batches, each sizes[BATCH_AXIS] long, are stacked along BATCH_AXIS.
  // The extent is therefore the product, not the sum (size + N * size). This
  // matches to_tensor_idx, which decodes n = pos[BATCH_AXIS] / size; that value
  // must stay < N.
  max_pos[BATCH_AXIS] *= sizes.w;
  // The packed dim holds 4 elements per texel.
  max_pos[packed_dim] /= 4;
  return any(greaterThanEqual(pos, max_pos));
}
75 
/*
 * Converts an (x, y, z) texel position to a (w, h, c, n) tensor index.
 *
 * Input: (x, y, z) texel position, (W, H, C, N) sizes of the tensor, and the
 *        dim that is packed along a texel (4 elements per texel).
 * Returns: the (w, h, c, n) tensor index corresponding to the first element of
 *          the texel at the specified position.
 */
ivec4 to_tensor_idx(ivec3 pos, ivec4 sizes, int packed_dim) {
  // Align packed dim to next multiple of 4 to account for texel padding
  sizes[packed_dim] = alignup4(sizes[packed_dim]);

  // Packed dim contains 4 elements per texel
  pos[packed_dim] *= 4;

  // Construct the initial tensor index via swizzling. The batch component
  // starts as a copy of the position along BATCH_AXIS.
#if BATCH_AXIS == 2
  ivec4 tensor_idx = pos.xyzz;
#elif BATCH_AXIS == 1
  ivec4 tensor_idx = pos.xyzy;
#elif BATCH_AXIS == 0
  ivec4 tensor_idx = pos.xyzx;
#else
#error BATCH_AXIS must be 0, 1, or 2
#endif

  // Split the position along BATCH_AXIS into (batch index, offset within batch).
  tensor_idx[3] /= sizes[BATCH_AXIS];
  tensor_idx[BATCH_AXIS] %= sizes[BATCH_AXIS];

  return tensor_idx;
}
104