xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/backends/vulkan/runtime/api/api.h>
12 
13 namespace vkcompute {
14 
/*
 * Maps a semantic dimension name to a negative, innermost-relative index into
 * a 4D tensor in NCHW format. In a way, it is the "negative index" associated
 * with a dim, counted from the end of the tensor's sizes array. For instance:
 * in an NCHW tensor, Width is the innermost dimension, so it corresponds to -1;
 * Height is the next innermost, so it corresponds to -2; and so on.
 *
 * DimIndex is kept as an unscoped enum with an int32_t underlying type so that
 * it converts implicitly to an integer; index arithmetic on it (comparisons,
 * adding it to a dimension count) depends on that.
 */
enum DimIndex : int32_t {
  DIM_LAST = -1,
  DIM_2ND_LAST = -2,
  DIM_3RD_LAST = -3,
  DIM_4TH_LAST = -4,
};
28 
29 constexpr DimIndex kWidth4D = DimIndex::DIM_LAST;
30 constexpr DimIndex kHeight4D = DimIndex::DIM_2ND_LAST;
31 constexpr DimIndex kChannel4D = DimIndex::DIM_3RD_LAST;
32 constexpr DimIndex kBatch4D = DimIndex::DIM_4TH_LAST;
33 
normalize_to_dim_index(const api::vTensor & v_in,int32_t dim)34 inline DimIndex normalize_to_dim_index(const api::vTensor& v_in, int32_t dim) {
35   return dim < 0 ? static_cast<DimIndex>(dim)
36                  : static_cast<DimIndex>(dim - v_in.dim());
37 }
38 
/*
 * Semantic dimension names for a 1D tensor. The value is a positive offset,
 * where 1 denotes the innermost (and here the only) dimension.
 */
struct Dim1D {
  static constexpr uint32_t Length{1u};
};
45 
/*
 * Semantic dimension names for a 2D convolution kernel, given as positive
 * offsets counted from the innermost dimension (1 = innermost).
 */
struct DimConv2DKernel {
  static constexpr uint32_t Width{1u};        // innermost
  static constexpr uint32_t Height{2u};
  static constexpr uint32_t InChannels{3u};
  static constexpr uint32_t OutChannels{4u};  // outermost
};
55 
/*
 * Semantic dimension names for a 2D transposed convolution kernel, given as
 * positive offsets counted from the innermost dimension (1 = innermost).
 * The spatial dims match DimConv2DKernel, but the channel dims are swapped:
 * OutChannels sits inside InChannels.
 */
struct DimTConv2DKernel {
  static constexpr uint32_t Width{1u};        // innermost
  static constexpr uint32_t Height{2u};
  static constexpr uint32_t OutChannels{3u};
  static constexpr uint32_t InChannels{4u};   // outermost
};
65 
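/*
 * The functions below safely return the size of the dimension at the N-th
 * innermost index, given as a negative DimIndex (e.g. kWidth4D == -1). If the
 * size array has too few dimensions to contain that index, then 1 is returned.
 * The structs above are intended to be used with these functions; note,
 * however, that they store positive innermost offsets (1, 2, ...) as uint32_t,
 * whereas these functions take a negative DimIndex, so struct values must be
 * negated and converted to DimIndex before being passed here.
 */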
72 
dim_at(const std::vector<int64_t> & sizes,DimIndex dim_index)73 inline int32_t dim_at(const std::vector<int64_t>& sizes, DimIndex dim_index) {
74   const uint32_t dims = sizes.size();
75   // Recall that dim_index is a negative index.
76   return dims < -dim_index
77       ? 1
78       : utils::safe_downcast<int32_t>(sizes[dims + dim_index]);
79 }
80 
81 template <DimIndex DI>
dim_at(const std::vector<int64_t> & sizes)82 int32_t dim_at(const std::vector<int64_t>& sizes) {
83   return dim_at(sizes, DI);
84 }
85 
86 template <DimIndex DI>
dim_at(const api::vTensor & v_in)87 int32_t dim_at(const api::vTensor& v_in) {
88   return dim_at(v_in.sizes(), DI);
89 }
90 
dim_at(const api::vTensor & v_in,DimIndex dim_index)91 inline int32_t dim_at(const api::vTensor& v_in, DimIndex dim_index) {
92   return dim_at(v_in.sizes(), dim_index);
93 }
94 
95 inline std::ostream& operator<<(std::ostream& os, DimIndex dim_index) {
96   switch (dim_index) {
97     case kWidth4D:
98       os << "kWidth4D";
99       break;
100     case kHeight4D:
101       os << "kHeight4D";
102       break;
103     case kChannel4D:
104       os << "kChannel4D";
105       break;
106     case kBatch4D:
107       os << "kBatch4D";
108       break;
109     default:
110       os << "kDim4DUnknown";
111       break;
112   }
113   return os;
114 }
115 } // namespace vkcompute
116