xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/Types.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // @lint-ignore-every CLANGTIDY bugprone-branch-clone
4 
5 #ifdef USE_VULKAN_API
6 
7 #include <cstddef>
8 #include <cstdint>
9 
10 #include <ATen/native/vulkan/api/vk_api.h>
11 
12 #include <ATen/native/vulkan/api/Exception.h>
13 
// Texel format used for `Float` entries of the scalar type list: 16-bit float
// texels when the build defines USE_VULKAN_FP16_INFERENCE, 32-bit float texels
// otherwise.
#if defined(USE_VULKAN_FP16_INFERENCE)
#define VK_FORMAT_FLOAT4 VK_FORMAT_R16G16B16A16_SFLOAT
#else
#define VK_FORMAT_FLOAT4 VK_FORMAT_R32G32B32A32_SFLOAT
#endif // USE_VULKAN_FP16_INFERENCE
19 
// X-macro listing the supported scalar types. APPLY is expanded once per entry
// as APPLY(ctype, vkformat, name), where:
//   ctype    - C++ type associated with the scalar type
//   vkformat - VkFormat associated with the scalar type
//   name     - identifier of the scalar type
// The entry order is significant: it fixes the enumerator values of
// `ScalarType`, which is generated from this list.
#define VK_FORALL_SCALAR_TYPES(APPLY)               \
  APPLY(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte)     \
  APPLY(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char)      \
  APPLY(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int)  \
  APPLY(bool, VK_FORMAT_R8G8B8A8_SINT, Bool)        \
  APPLY(float, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \
  APPLY(float, VK_FORMAT_FLOAT4, Float)             \
  APPLY(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8)     \
  APPLY(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8)   \
  APPLY(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32)
30 
31 namespace at {
32 namespace native {
33 namespace vulkan {
34 namespace api {
35 
36 //
37 // Scalar Types
38 //
39 
40 enum class ScalarType : int8_t {
41 #define DEFINE_ENUM_VAL_(ctype, vkformat, name) name,
42   VK_FORALL_SCALAR_TYPES(DEFINE_ENUM_VAL_)
43 #undef DEFINE_ENUM_VAL_
44       Undefined,
45   NumOptions
46 };
47 
48 #define DEFINE_CONSTANT(ctype, vkformat, name) \
49   constexpr ScalarType k##name = ScalarType::name;
50 
VK_FORALL_SCALAR_TYPES(DEFINE_CONSTANT)51 VK_FORALL_SCALAR_TYPES(DEFINE_CONSTANT)
52 #undef DEFINE_CONSTANT
53 
54 /*
55  * Given a `ScalarType`, return the corresponding `VkFormat` that should be used
56  * for image texture storage. The `ScalarType` to `VkFormat` mapping is dictated
57  * by the `VK_FORALL_SCALAR_TYPE` macro in `api/Types.h`
58  */
59 inline VkFormat to_vkformat(const ScalarType t) {
60 #define CASE_VK_FORMAT(ctype, vkformat, name) \
61   case ScalarType::name:                      \
62     return vkformat;
63 
64   switch (t) {
65     VK_FORALL_SCALAR_TYPES(CASE_VK_FORMAT)
66     default:
67       VK_THROW("Unknown ScalarType: ", t);
68   }
69 #undef CASE_VK_FORMAT
70 }
71 
72 /*
73  * Given a `VkFormat`, return the `ScalarType` that best represents the data
74  * type of invidivual elements in an image texture of the `VkFormat`. Note that
75  * this mapping is different from the `to_vkformat()` function, since different
76  * `ScalarType`s may use the same `VkFormat`.
77  */
element_scalartype(const VkFormat vkformat)78 inline ScalarType element_scalartype(const VkFormat vkformat) {
79   switch (vkformat) {
80     case VK_FORMAT_R8G8B8A8_SINT:
81       return kChar;
82     case VK_FORMAT_R8G8B8A8_UINT:
83       return kByte;
84     case VK_FORMAT_R32G32B32A32_SINT:
85       return kInt;
86     case VK_FORMAT_R32G32B32A32_SFLOAT:
87       return kFloat;
88     case VK_FORMAT_R16G16B16A16_SFLOAT:
89       return kHalf;
90     default:
91       VK_THROW("No corresponding scalar type for unknown VkFormat: ", vkformat);
92   }
93 }
94 
95 /*
96  * Given a ScalarType, return `sizeof(ctype)` where ctype is the C type
97  * corresponding to the ScalarType. The C type to ScalarType mapping is dictated
98  * by the VK_FORALL_SCALAR_TYPE macro in api/Types.h
99  */
element_size(const ScalarType t)100 inline size_t element_size(const ScalarType t) {
101 #define CASE_ELEMENTSIZE_CASE(ctype, vkformat, name) \
102   case ScalarType::name:                             \
103     return sizeof(ctype);
104 
105   switch (t) {
106     VK_FORALL_SCALAR_TYPES(CASE_ELEMENTSIZE_CASE)
107     default:
108       VK_THROW("Unknown ScalarType: ", t);
109   }
110 #undef CASE_ELEMENTSIZE_CASE
111 }
112 
to_string(const ScalarType t)113 inline const char* to_string(const ScalarType t) {
114 #define CASE_TO_STRING(ctype, vkformat, name) \
115   case ScalarType::name:                      \
116     return #name;
117 
118   switch (t) {
119     VK_FORALL_SCALAR_TYPES(CASE_TO_STRING)
120     default:
121       return "UNKNOWN_SCALAR_TYPE";
122   }
123 #undef CASE_TO_STRING
124 }
125 
126 inline std::ostream& operator<<(std::ostream& os, const ScalarType dtype) {
127   return os << to_string(dtype);
128 }
129 
130 //
131 // Map ScalarTypes to C++ types
132 //
133 
134 template <ScalarType N>
135 struct ScalarTypeToCType;
136 
137 #define SPECIALIZE_ScalarTypeToCType(ctype, vkformat, scalar_type) \
138   template <>                                                      \
139   struct ScalarTypeToCType<                                        \
140       ::at::native::vulkan::api::ScalarType::scalar_type> {        \
141     using type = ctype;                                            \
142   };
143 
144 VK_FORALL_SCALAR_TYPES(SPECIALIZE_ScalarTypeToCType)
145 
146 #undef SPECIALIZE_ScalarTypeToCPPType
147 
148 //
149 // GPU Storage Options
150 //
151 
/**
 * Describes which kind of GPU memory backs a particular tensor's data.
 *
 *   BUFFER     - an SSBO (Shader Storage Buffer Object)
 *   TEXTURE_3D - a 3-dimensional image texture
 *   TEXTURE_2D - a 2-dimensional image texture
 *   UNKNOWN    - not expected to be used
 */
enum class StorageType {
  BUFFER = 0,
  TEXTURE_3D = 1,
  TEXTURE_2D = 2,
  UNKNOWN = 3,
};
168 
/**
 * Describes how tensor data is laid out in GPU memory. Each enumerator names
 * the dimension that is tightly packed:
 *   - for tensors stored as image textures, loading one texel retrieves 4
 *     consecutive elements along the named dimension;
 *   - for tensors stored as buffers, the named dimension has a stride of 1.
 *
 * Compute shaders use this qualifier to convert between logical tensor
 * coordinates and physical texel coordinates. For buffer-backed tensors the
 * tensor's strides are expected to be used instead, to convert between
 * logical tensor coordinates and linear access indices.
 */
enum class GPUMemoryLayout : uint32_t {
  TENSOR_WIDTH_PACKED = 0,
  TENSOR_HEIGHT_PACKED = 1,
  TENSOR_CHANNELS_PACKED = 2,
};
187 
188 } // namespace api
189 } // namespace vulkan
190 } // namespace native
191 } // namespace at
192 
193 #endif /* USE_VULKAN_API */
194