1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 #include <executorch/backends/vulkan/runtime/api/api.h>
12
13 namespace vkcompute {
14
/*
 * Maps a semantic dimension name to an integer that corresponds to its
 * innermost ordering in a 4D tensor in NCHW format. In a way, it is the
 * "negative index" associated with a dim. For instance: in a NCHW tensor, Width
 * is the innermost dimension, so it corresponds to -1, height is the next
 * innermost, so it corresponds to -2, and so on.
 */
enum DimIndex : int32_t {
  DIM_LAST = -1, // innermost dimension
  DIM_2ND_LAST = -2, // second innermost dimension
  DIM_3RD_LAST = -3, // third innermost dimension
  DIM_4TH_LAST = -4, // fourth innermost dimension
};
28
29 constexpr DimIndex kWidth4D = DimIndex::DIM_LAST;
30 constexpr DimIndex kHeight4D = DimIndex::DIM_2ND_LAST;
31 constexpr DimIndex kChannel4D = DimIndex::DIM_3RD_LAST;
32 constexpr DimIndex kBatch4D = DimIndex::DIM_4TH_LAST;
33
normalize_to_dim_index(const api::vTensor & v_in,int32_t dim)34 inline DimIndex normalize_to_dim_index(const api::vTensor& v_in, int32_t dim) {
35 return dim < 0 ? static_cast<DimIndex>(dim)
36 : static_cast<DimIndex>(dim - v_in.dim());
37 }
38
/*
 * Named dimension positions for a 1D tensor. The only dimension, Length, is
 * assigned position 1.
 */
struct Dim1D {
  static constexpr uint32_t Length{1u};
};
45
/*
 * Named dimension positions for the weight tensor of a 2D convolution.
 * Positions are 1-based: Width = 1, Height = 2, InChannels = 3 and
 * OutChannels = 4.
 * NOTE(review): presumably counted from the innermost dimension outward,
 * matching the "N-th innermost" convention used elsewhere in this file -
 * confirm.
 */
struct DimConv2DKernel {
  static constexpr uint32_t Width{1u};
  static constexpr uint32_t Height{2u};
  static constexpr uint32_t InChannels{3u};
  static constexpr uint32_t OutChannels{4u};
};
55
/*
 * Named dimension positions for the weight tensor of a 2D transposed
 * convolution. Identical to the regular 2D convolution kernel layout except
 * that the channel positions are swapped: OutChannels = 3 and InChannels = 4.
 */
struct DimTConv2DKernel {
  static constexpr uint32_t Width{1u};
  static constexpr uint32_t Height{2u};
  static constexpr uint32_t OutChannels{3u};
  static constexpr uint32_t InChannels{4u};
};
65
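/*
 * The functions below safely return the size of the dimension at the N-th
 * innermost index, where the index is given as a (negative) DimIndex. If the
 * dimensionality of the size array is not sufficient then 1 will be returned.
 * The DimIndex constants above (kWidth4D, kHeight4D, ...) are intended to be
 * used with these functions.
 * NOTE(review): the Dim* structs above hold positive uint32_t positions rather
 * than DimIndex values, so they do not match these signatures as written -
 * confirm whether they are still meant to be used with these functions.
 */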
72
dim_at(const std::vector<int64_t> & sizes,DimIndex dim_index)73 inline int32_t dim_at(const std::vector<int64_t>& sizes, DimIndex dim_index) {
74 const uint32_t dims = sizes.size();
75 // Recall that dim_index is a negative index.
76 return dims < -dim_index
77 ? 1
78 : utils::safe_downcast<int32_t>(sizes[dims + dim_index]);
79 }
80
81 template <DimIndex DI>
dim_at(const std::vector<int64_t> & sizes)82 int32_t dim_at(const std::vector<int64_t>& sizes) {
83 return dim_at(sizes, DI);
84 }
85
86 template <DimIndex DI>
dim_at(const api::vTensor & v_in)87 int32_t dim_at(const api::vTensor& v_in) {
88 return dim_at(v_in.sizes(), DI);
89 }
90
dim_at(const api::vTensor & v_in,DimIndex dim_index)91 inline int32_t dim_at(const api::vTensor& v_in, DimIndex dim_index) {
92 return dim_at(v_in.sizes(), dim_index);
93 }
94
95 inline std::ostream& operator<<(std::ostream& os, DimIndex dim_index) {
96 switch (dim_index) {
97 case kWidth4D:
98 os << "kWidth4D";
99 break;
100 case kHeight4D:
101 os << "kHeight4D";
102 break;
103 case kChannel4D:
104 os << "kChannel4D";
105 break;
106 case kBatch4D:
107 os << "kBatch4D";
108 break;
109 default:
110 os << "kDim4DUnknown";
111 break;
112 }
113 return os;
114 }
115 } // namespace vkcompute
116