1 #pragma once
2
3 // @lint-ignore-every CLANGTIDY bugprone-branch-clone
4
5 #ifdef USE_VULKAN_API
6
#include <cstddef>
#include <cstdint>
#include <ostream>
9
10 #include <ATen/native/vulkan/api/vk_api.h>
11
12 #include <ATen/native/vulkan/api/Exception.h>
13
// `VK_FORMAT_FLOAT4` is the image texture format used for `Float` data:
// four 16-bit float components (VK_FORMAT_R16G16B16A16_SFLOAT) when
// USE_VULKAN_FP16_INFERENCE is defined, four 32-bit float components
// (VK_FORMAT_R32G32B32A32_SFLOAT) otherwise.
#ifdef USE_VULKAN_FP16_INFERENCE
#define VK_FORMAT_FLOAT4 VK_FORMAT_R16G16B16A16_SFLOAT
#else
#define VK_FORMAT_FLOAT4 VK_FORMAT_R32G32B32A32_SFLOAT
#endif /* USE_VULKAN_FP16_INFERENCE */

// X-macro listing every scalar type supported by the Vulkan backend. The
// argument `_` is invoked once per entry as `_(ctype, vkformat, name)`:
//   ctype    - C++ type associated with the scalar type; this is what
//              element_size() and ScalarTypeToCType report.
//   vkformat - VkFormat used for image texture storage (see to_vkformat()).
//   name     - enumerator name generated in `ScalarType`; a `k<name>`
//              constant is generated as well.
// The entry order determines the underlying values of the generated
// `ScalarType` enumerators, so reordering entries changes those values.
// NOTE(review): `Half` is listed with ctype `float` even though its texture
// format stores 16-bit floats, so its reported element size is sizeof(float);
// confirm callers depend on this before changing it.
#define VK_FORALL_SCALAR_TYPES(_) \
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \
_(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \
_(float, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \
_(float, VK_FORMAT_FLOAT4, Float) \
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32)
30
31 namespace at {
32 namespace native {
33 namespace vulkan {
34 namespace api {
35
36 //
37 // Scalar Types
38 //
39
40 enum class ScalarType : int8_t {
41 #define DEFINE_ENUM_VAL_(ctype, vkformat, name) name,
42 VK_FORALL_SCALAR_TYPES(DEFINE_ENUM_VAL_)
43 #undef DEFINE_ENUM_VAL_
44 Undefined,
45 NumOptions
46 };
47
48 #define DEFINE_CONSTANT(ctype, vkformat, name) \
49 constexpr ScalarType k##name = ScalarType::name;
50
VK_FORALL_SCALAR_TYPES(DEFINE_CONSTANT)51 VK_FORALL_SCALAR_TYPES(DEFINE_CONSTANT)
52 #undef DEFINE_CONSTANT
53
54 /*
55 * Given a `ScalarType`, return the corresponding `VkFormat` that should be used
56 * for image texture storage. The `ScalarType` to `VkFormat` mapping is dictated
57 * by the `VK_FORALL_SCALAR_TYPE` macro in `api/Types.h`
58 */
59 inline VkFormat to_vkformat(const ScalarType t) {
60 #define CASE_VK_FORMAT(ctype, vkformat, name) \
61 case ScalarType::name: \
62 return vkformat;
63
64 switch (t) {
65 VK_FORALL_SCALAR_TYPES(CASE_VK_FORMAT)
66 default:
67 VK_THROW("Unknown ScalarType: ", t);
68 }
69 #undef CASE_VK_FORMAT
70 }
71
72 /*
73 * Given a `VkFormat`, return the `ScalarType` that best represents the data
74 * type of invidivual elements in an image texture of the `VkFormat`. Note that
75 * this mapping is different from the `to_vkformat()` function, since different
76 * `ScalarType`s may use the same `VkFormat`.
77 */
element_scalartype(const VkFormat vkformat)78 inline ScalarType element_scalartype(const VkFormat vkformat) {
79 switch (vkformat) {
80 case VK_FORMAT_R8G8B8A8_SINT:
81 return kChar;
82 case VK_FORMAT_R8G8B8A8_UINT:
83 return kByte;
84 case VK_FORMAT_R32G32B32A32_SINT:
85 return kInt;
86 case VK_FORMAT_R32G32B32A32_SFLOAT:
87 return kFloat;
88 case VK_FORMAT_R16G16B16A16_SFLOAT:
89 return kHalf;
90 default:
91 VK_THROW("No corresponding scalar type for unknown VkFormat: ", vkformat);
92 }
93 }
94
95 /*
96 * Given a ScalarType, return `sizeof(ctype)` where ctype is the C type
97 * corresponding to the ScalarType. The C type to ScalarType mapping is dictated
98 * by the VK_FORALL_SCALAR_TYPE macro in api/Types.h
99 */
element_size(const ScalarType t)100 inline size_t element_size(const ScalarType t) {
101 #define CASE_ELEMENTSIZE_CASE(ctype, vkformat, name) \
102 case ScalarType::name: \
103 return sizeof(ctype);
104
105 switch (t) {
106 VK_FORALL_SCALAR_TYPES(CASE_ELEMENTSIZE_CASE)
107 default:
108 VK_THROW("Unknown ScalarType: ", t);
109 }
110 #undef CASE_ELEMENTSIZE_CASE
111 }
112
to_string(const ScalarType t)113 inline const char* to_string(const ScalarType t) {
114 #define CASE_TO_STRING(ctype, vkformat, name) \
115 case ScalarType::name: \
116 return #name;
117
118 switch (t) {
119 VK_FORALL_SCALAR_TYPES(CASE_TO_STRING)
120 default:
121 return "UNKNOWN_SCALAR_TYPE";
122 }
123 #undef CASE_TO_STRING
124 }
125
126 inline std::ostream& operator<<(std::ostream& os, const ScalarType dtype) {
127 return os << to_string(dtype);
128 }
129
130 //
131 // Map ScalarTypes to C++ types
132 //
133
134 template <ScalarType N>
135 struct ScalarTypeToCType;
136
137 #define SPECIALIZE_ScalarTypeToCType(ctype, vkformat, scalar_type) \
138 template <> \
139 struct ScalarTypeToCType< \
140 ::at::native::vulkan::api::ScalarType::scalar_type> { \
141 using type = ctype; \
142 };
143
144 VK_FORALL_SCALAR_TYPES(SPECIALIZE_ScalarTypeToCType)
145
146 #undef SPECIALIZE_ScalarTypeToCPPType
147
148 //
149 // GPU Storage Options
150 //
151
152 /**
153 * The enum below is used to describe what type of GPU memory will be used to
154 * store a particular tensor's data.
155 *
156 * BUFFER means that a SSBO (Shader Storage Buffer Object) will be used.
157 * TEXTURE_3D means that a 3-dimensional image texture will be used.
158 * TEXTURE_2D means that a 2-dimensional image texture will be used.
159 *
160 * UNKNOWN is not expected to be used.
161 */
162 enum class StorageType {
163 BUFFER,
164 TEXTURE_3D,
165 TEXTURE_2D,
166 UNKNOWN,
167 };
168
169 /**
170 * The enum below is used to describe how tensor data is laid out when stored in
171 * GPU memory. The name of the enum describes which dimension is tightly packed;
172 * so for tensors that are stored as image textures, loading a texel will
173 * retrieve 4 consecutive elements of the named dimension, and for tensors
174 * stored as buffers, the named dimension will have a stride of 1.
175 *
176 * The GPU memory layout qualifier will be used by compute shaders to determine
177 * how to convert between logical tensor coordinates and physical texel
178 * coordinates. For tensors that are stored as buffers, it is expected that the
179 * strides of the tensor will be used instead to convert between logical tensor
180 * coordinates and linear access indices.
181 */
182 enum class GPUMemoryLayout : uint32_t {
183 TENSOR_WIDTH_PACKED = 0u,
184 TENSOR_HEIGHT_PACKED = 1u,
185 TENSOR_CHANNELS_PACKED = 2u,
186 };
187
188 } // namespace api
189 } // namespace vulkan
190 } // namespace native
191 } // namespace at
192
193 #endif /* USE_VULKAN_API */
194