1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
// Index of the width dim when a tensor's sizes/indices are stored in
// (W, H, C, N) order
#define W_DIM 0
// Index of the height dim when a tensor's sizes/indices are stored in
// (W, H, C, N) order
#define H_DIM 1
// Index of the channels dim when a tensor's sizes/indices are stored in
// (W, H, C, N) order
#define C_DIM 2

/*
 * Describes which texture axis the "batches" dimension runs along in a 4D
 * texture. Must be an axis of an (x, y, z) texel position, i.e. 0, 1 or 2.
 *
 * Currently it is set to 2 since we represent batches by concatenating along
 * the channels dim, which has index 2 in (W, H, C, N) order and maps to the
 * depth dimension of a texture, which also corresponds to index 2 in (x, y, z)
 * order.
 *
 * Because N batches are concatenated along this axis, the texture's extent
 * along it is N times the tensor's size along the corresponding dim.
 */
#define BATCH_AXIS 2
26
27 //
28 // Basic Indexing Utility Macros and Functions
29 //
30
31 /*
32 * Aligns input to the next multiple of 4
33 */
34 #define alignup4(x) ((x + 3) & -4)
35
36 //
37 // (w, h, c, n) Tensor Index <-> Contiguous Buffer Index Conversion
38 //
39
40 /*
41 * Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of a tensor, which dim
42 * is packed along a texel
43 * Output: A ivec4 containing the buffer indices corresponding to each texel
44 * element.
45 */
get_texel_nchw_buffer_ixs(ivec4 idx,ivec4 sizes,int packed_dim)46 ivec4 get_texel_nchw_buffer_ixs(ivec4 idx, ivec4 sizes, int packed_dim) {
47 ivec4 strides =
48 ivec4(1, sizes.x, sizes.x * sizes.y, sizes.x * sizes.y * sizes.z);
49
50 int base_i = idx.x * strides.x + idx.y * strides.y + idx.z * strides.z +
51 idx.w * strides.w;
52
53 return base_i + ivec4(0, 1, 2, 3) * strides[packed_dim];
54 }
55
56 //
57 // (w, h, c, n) Tensor Index <-> (x, y, z) Texture Position Conversion
58 //
59
60 /*
61 * Input: (x, y, z) texel position, (W, H, C, N) sizes of the tensor, which dim
62 * is packed along a texel
63 * Output: Whether the texel position is outside the bounds of the image texture
64 * given the size and packed dimension of the tensor.
65 */
pos_out_of_bounds(ivec3 pos,ivec4 sizes,int packed_dim)66 bool pos_out_of_bounds(ivec3 pos, ivec4 sizes, int packed_dim) {
67 // Align packed dim to next multiple of 4 to account for texel padding
68 sizes[packed_dim] = alignup4(sizes[packed_dim]);
69
70 ivec3 max_pos = sizes.xyz;
71 max_pos[BATCH_AXIS] += sizes.w * sizes[BATCH_AXIS];
72 max_pos[packed_dim] /= 4;
73 return (any(greaterThanEqual(pos, max_pos)));
74 }
75
76 /*
77 * Input: (x, y, z) texel position, (W, H, C, N) sizes of the tensor,
78 * which dim is packed along a texel
79 * Returns: the (w, h, c, n) tensor index cooresponding to the first element of
80 * the texel at the specified position
81 */
to_tensor_idx(ivec3 pos,ivec4 sizes,int packed_dim)82 ivec4 to_tensor_idx(ivec3 pos, ivec4 sizes, int packed_dim) {
83 // Align packed dim to next multiple of 4 to account for texel padding
84 sizes[packed_dim] = alignup4(sizes[packed_dim]);
85
86 // Packed dim contains 4 elements per texel
87 pos[packed_dim] *= 4;
88 // Construct the initial tensor index via swizzling
89 #if BATCH_AXIS == 2
90 ivec4 tensor_idx = pos.xyzz;
91 #endif
92 #if BATCH_AXIS == 1
93 ivec4 tensor_idx = pos.xyzy;
94 #endif
95 #if BATCH_AXIS == 0
96 ivec4 tensor_idx = pos.xyzx;
97 #endif
98 // Adjust the axis that the batch dim runs along
99 tensor_idx[3] /= sizes[BATCH_AXIS];
100 tensor_idx[BATCH_AXIS] %= sizes[BATCH_AXIS];
101
102 return tensor_idx;
103 }
104