xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/Utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cmath>
4 #include <numeric>
5 
6 #include <ATen/native/vulkan/api/vk_api.h>
7 
8 #include <ATen/native/vulkan/api/Exception.h>
9 
10 #ifdef USE_VULKAN_API
11 
12 // Compiler Macros
13 
// Marks a declaration as intentionally unused so the compiler does not warn
// about it. Copied from C10_UNUSED.
//  - MSVC (but not clang-cl): suppresses warnings 4100 and 4101 for the
//    statement that follows.
//  - Everything else: expands to the GNU `unused` attribute.
#if !defined(_MSC_VER) || defined(__clang__)
#define VK_UNUSED __attribute__((__unused__))
#else
#define VK_UNUSED __pragma(warning(suppress : 4100 4101))
#endif // !_MSC_VER || __clang__
20 
21 namespace at {
22 namespace native {
23 namespace vulkan {
24 namespace api {
25 namespace utils {
26 
27 //
28 // Hashing
29 //
30 
/**
 * Folds `value` into an existing hash `seed` and returns the combined hash.
 *
 * hash_combine is taken from c10/util/hash.h, which in turn is based on the
 * implementation from Boost.
 */
inline size_t hash_combine(size_t seed, size_t value) {
  // Mix the new value with shifted copies of the seed before folding it in.
  const size_t mixed = value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
  return seed ^ mixed;
}
38 
39 //
40 // Alignment
41 //
42 
// Rounds `number` down to a multiple of `multiple` by dividing and then
// multiplying back. `multiple` must be non-zero, since it is used as a divisor.
template <typename Type>
inline constexpr Type align_down(const Type& number, const Type& multiple) {
  const auto whole_multiples = number / multiple;
  return whole_multiples * multiple;
}
47 
// Rounds `number` up to a multiple of `multiple`: adds `multiple - 1` and then
// rounds the sum down to a multiple of `multiple`. `multiple` must be
// non-zero, since it is used as a divisor.
//
// The rounding is written out here (rather than forwarding the biased sum to a
// helper whose template argument is deduced) because `number + multiple - 1`
// is promoted to `int` when Type is narrower than `int`; deducing Type from
// that sum would then conflict with the Type of `multiple` and fail to
// compile. The result is cast back to Type explicitly.
template <typename Type>
inline constexpr Type align_up(const Type& number, const Type& multiple) {
  return static_cast<Type>(((number + multiple - 1) / multiple) * multiple);
}
52 
// Divides `numerator` by `denominator`, biasing the numerator by
// `denominator - 1` first so that any remainder pushes the quotient up by one.
// `denominator` must be non-zero.
template <typename Type>
inline constexpr Type div_up(const Type& numerator, const Type& denominator) {
  const auto biased_numerator = numerator + denominator - 1;
  return biased_numerator / denominator;
}
57 
58 //
59 // Casting Utilities
60 //
61 
62 namespace detail {
63 
/*
 * Tag-dispatch overload selected for unsigned T: an unsigned value can never
 * be below zero, so the argument is not inspected.
 */
template <typename T>
static constexpr bool is_negative(
    const T& /*value*/,
    std::true_type /*unsigned_tag*/) {
  return false;
}
73 
/*
 * Tag-dispatch overload selected when T is not unsigned: compares the value
 * against zero.
 */
template <typename T>
static constexpr bool is_negative(
    const T& x,
    std::false_type /*unsigned_tag*/) {
  return T(0) > x;
}
83 
/*
 * Returns true if x < 0. For unsigned T the answer is always false and no
 * comparison against zero is generated.
 */
template <typename T>
inline constexpr bool is_negative(const T& x) {
  if constexpr (std::is_unsigned_v<T>) {
    return false;
  } else {
    return x < T(0);
  }
}
91 
/*
 * Tag-dispatch overload for the case where neither Limit nor T is unsigned:
 * returns true if x is below the lowest value of Limit, using an ordinary
 * comparison.
 */
template <typename Limit, typename T>
static constexpr bool less_than_lowest(
    const T& x,
    std::false_type /*limit_is_unsigned*/,
    std::false_type /*x_is_unsigned*/) {
  using limit = std::numeric_limits<Limit>;
  return x < limit::lowest();
}
102 
/*
 * Tag-dispatch overload for a Limit that can contain negative values and an
 * unsigned x: x can never be below Limit's lowest value, so return false.
 */
template <typename Limit, typename T>
static constexpr bool less_than_lowest(
    const T& /*value*/,
    std::false_type /*limit_is_unsigned*/,
    std::true_type /*x_is_unsigned*/) {
  return false;
}
113 
/*
 * Tag-dispatch overload for an unsigned Limit and an x that can be negative:
 * x is below Limit's range exactly when x is negative.
 */
template <typename Limit, typename T>
static constexpr bool less_than_lowest(
    const T& x,
    std::true_type /*limit_is_unsigned*/,
    std::false_type /*x_is_unsigned*/) {
  return T(0) > x;
}
124 
/*
 * Tag-dispatch overload for the case where both Limit and T are unsigned:
 * neither can be negative, so return false.
 */
template <typename Limit, typename T>
static constexpr bool less_than_lowest(
    const T& /*value*/,
    std::true_type /*limit_is_unsigned*/,
    std::true_type /*x_is_unsigned*/) {
  return false;
}
135 
/*
 * Returns true if x is less than the lowest value of type Limit.
 *
 * The signedness of Limit and T selects the check at compile time:
 *  - T unsigned:                x can never be below any lowest value.
 *  - Limit unsigned, T signed:  x is too low exactly when it is negative.
 *  - otherwise:                 ordinary comparison against Limit's lowest.
 */
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(const T& x) {
  if constexpr (std::is_unsigned_v<T>) {
    return false;
  } else if constexpr (std::is_unsigned_v<Limit>) {
    return x < T(0);
  } else {
    return x < std::numeric_limits<Limit>::lowest();
  }
}
144 
// Suppress the sign-compare warning when compiling with GCC, as the latter
// does not account for the short-circuit rule before raising the warning;
// see https://godbolt.org/z/Tr3Msnz99
148 #ifdef __GNUC__
149 #pragma GCC diagnostic push
150 #pragma GCC diagnostic ignored "-Wsign-compare"
151 #endif
152 
/*
 * Returns true if x is greater than the greatest value of the type Limit.
 *
 * The comparison is only performed when T has more value digits than Limit;
 * otherwise the result is always false.
 */
template <typename Limit, typename T>
inline constexpr bool greater_than_max(const T& x) {
  using source = std::numeric_limits<T>;
  using target = std::numeric_limits<Limit>;
  if constexpr (source::digits > target::digits) {
    return x > target::max();
  } else {
    return false;
  }
}
162 
163 #ifdef __GNUC__
164 #pragma GCC diagnostic pop
165 #endif
166 
167 template <typename To, typename From>
168 std::enable_if_t<std::is_integral_v<From> && !std::is_same_v<From, bool>, bool>
overflows(From f)169 overflows(From f) {
170   using limit = std::numeric_limits<To>;
171   // Casting from signed to unsigned; allow for negative numbers to wrap using
172   // two's complement arithmetic.
173   if (!limit::is_signed && std::numeric_limits<From>::is_signed) {
174     return greater_than_max<To>(f) ||
175         (is_negative(f) && -static_cast<uint64_t>(f) > limit::max());
176   }
177   // standard case, check if f is outside the range of type To
178   else {
179     return less_than_lowest<To>(f) || greater_than_max<To>(f);
180   }
181 }
182 
// Returns true if the floating point value `f` lies outside the range
// [lowest, max] of type To. An infinite input is not treated as an overflow
// when To itself can represent infinity.
template <typename To, typename From>
std::enable_if_t<std::is_floating_point_v<From>, bool> overflows(From f) {
  using limit = std::numeric_limits<To>;
  const bool representable_infinity =
      limit::has_infinity && std::isinf(static_cast<double>(f));
  const bool out_of_range = f < limit::lowest() || f > limit::max();
  return !representable_infinity && out_of_range;
}
191 
192 template <typename To, typename From>
safe_downcast(const From & v)193 inline constexpr To safe_downcast(const From& v) {
194   VK_CHECK_COND(!overflows<To>(v), "Cast failed: out of range!");
195   return static_cast<To>(v);
196 }
197 
// Compile-time predicate: true when From is a signed type and To is an
// unsigned type. Note the template parameter order is <To, From>.
template <typename To, typename From>
inline constexpr bool is_signed_to_unsigned() {
  return std::is_unsigned_v<To> && std::is_signed_v<From>;
}
203 } // namespace detail
204 
205 template <
206     typename To,
207     typename From,
208     std::enable_if_t<detail::is_signed_to_unsigned<To, From>(), bool> = true>
safe_downcast(const From & v)209 inline constexpr To safe_downcast(const From& v) {
210   VK_CHECK_COND(v >= From{}, "Cast failed: negative signed to unsigned!");
211   return detail::safe_downcast<To, From>(v);
212 }
213 
214 template <
215     typename To,
216     typename From,
217     std::enable_if_t<!detail::is_signed_to_unsigned<To, From>(), bool> = true>
safe_downcast(const From & v)218 inline constexpr To safe_downcast(const From& v) {
219   return detail::safe_downcast<To, From>(v);
220 }
221 
222 //
223 // Vector Types
224 //
225 
226 namespace detail {
227 
/*
 * Fixed-size aggregate holding N elements of type Type in a plain C array.
 * It declares no constructors, so it can be brace-initialized, and elements
 * are accessed directly through the public `data` member.
 */
template <typename Type, uint32_t N>
struct vec final {
  // NOLINTNEXTLINE
  Type data[N];
};
233 
234 } // namespace detail
235 
236 template <uint32_t N>
237 using ivec = detail::vec<int32_t, N>;
238 using ivec2 = ivec<2u>;
239 using ivec3 = ivec<3u>;
240 using ivec4 = ivec<4u>;
241 
242 template <uint32_t N>
243 using uvec = detail::vec<uint32_t, N>;
244 using uvec2 = uvec<2u>;
245 using uvec3 = uvec<3u>;
246 using uvec4 = uvec<4u>;
247 
248 template <uint32_t N>
249 using vec = detail::vec<float, N>;
250 using vec2 = vec<2u>;
251 using vec3 = vec<3u>;
252 using vec4 = vec<4u>;
253 
254 // uvec3 is the type representing tensor extents. Useful for debugging.
255 inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
256   os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
257   return os;
258 }
259 
260 //
261 // std::vector<T> Handling
262 //
263 
/*
 * Looks up an element of `sizes` by position. Negative positions count from
 * the back, so an index of -1 retrieves the last element and -2 the one
 * before it. If the requested index is out of bounds in either direction,
 * the value 1 is returned instead of an element.
 */
template <typename T>
inline T val_at(const int64_t index, const std::vector<T>& sizes) {
  const int64_t ndim = static_cast<int64_t>(sizes.size());
  // Translate a negative index into the equivalent position from the front.
  const int64_t position = index < 0 ? ndim + index : index;
  if (position < 0 || position >= ndim) {
    return 1;
  }
  return sizes[position];
}
278 
279 inline ivec2 make_ivec2(
280     const std::vector<int64_t>& ints,
281     bool reverse = false) {
282   VK_CHECK_COND(ints.size() == 2);
283   if (reverse) {
284     return {safe_downcast<int32_t>(ints[1]), safe_downcast<int32_t>(ints[0])};
285   } else {
286     return {safe_downcast<int32_t>(ints[0]), safe_downcast<int32_t>(ints[1])};
287   }
288 }
289 
290 inline ivec4 make_ivec4(
291     const std::vector<int64_t>& ints,
292     bool reverse = false) {
293   VK_CHECK_COND(ints.size() == 4);
294   if (reverse) {
295     return {
296         safe_downcast<int32_t>(ints[3]),
297         safe_downcast<int32_t>(ints[2]),
298         safe_downcast<int32_t>(ints[1]),
299         safe_downcast<int32_t>(ints[0]),
300     };
301   } else {
302     return {
303         safe_downcast<int32_t>(ints[0]),
304         safe_downcast<int32_t>(ints[1]),
305         safe_downcast<int32_t>(ints[2]),
306         safe_downcast<int32_t>(ints[3]),
307     };
308   }
309 }
310 
make_ivec4_prepadded1(const std::vector<int64_t> & ints)311 inline ivec4 make_ivec4_prepadded1(const std::vector<int64_t>& ints) {
312   VK_CHECK_COND(ints.size() <= 4);
313 
314   ivec4 result = {1, 1, 1, 1};
315   size_t base = 4 - ints.size();
316   for (size_t i = 0; i < ints.size(); ++i) {
317     result.data[i + base] = safe_downcast<int32_t>(ints[i]);
318   }
319 
320   return result;
321 }
322 
make_ivec3(uvec3 ints)323 inline ivec3 make_ivec3(uvec3 ints) {
324   return {
325       safe_downcast<int32_t>(ints.data[0u]),
326       safe_downcast<int32_t>(ints.data[1u]),
327       safe_downcast<int32_t>(ints.data[2u])};
328 }
329 
330 /*
331  * Given an vector of up to 4 uint64_t representing the sizes of a tensor,
332  * constructs a uvec4 containing those elements in reverse order.
333  */
make_whcn_uvec4(const std::vector<int64_t> & arr)334 inline uvec4 make_whcn_uvec4(const std::vector<int64_t>& arr) {
335   uint32_t w = safe_downcast<uint32_t>(val_at(-1, arr));
336   uint32_t h = safe_downcast<uint32_t>(val_at(-2, arr));
337   uint32_t c = safe_downcast<uint32_t>(val_at(-3, arr));
338   uint32_t n = safe_downcast<uint32_t>(val_at(-4, arr));
339 
340   return {w, h, c, n};
341 }
342 
343 /*
344  * Given an vector of up to 4 int64_t representing the sizes of a tensor,
345  * constructs an ivec4 containing those elements in reverse order.
346  */
make_whcn_ivec4(const std::vector<int64_t> & arr)347 inline ivec4 make_whcn_ivec4(const std::vector<int64_t>& arr) {
348   int32_t w = val_at(-1, arr);
349   int32_t h = val_at(-2, arr);
350   int32_t c = val_at(-3, arr);
351   int32_t n = val_at(-4, arr);
352 
353   return {w, h, c, n};
354 }
355 
/*
 * Multiplies together all elements of a container of integral values,
 * accumulating into an int64_t that starts at 1; an empty container
 * therefore yields 1. Taken from `multiply_integers` in
 * <c10/util/accumulate.h>
 */
template <
    typename C,
    std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t multiply_integers(const C& container) {
  int64_t product = 1;
  for (const auto& factor : container) {
    product = product * factor;
  }
  return product;
}
371 
372 } // namespace utils
373 
374 inline bool operator==(const utils::uvec3& _1, const utils::uvec3& _2) {
375   return (
376       _1.data[0u] == _2.data[0u] && _1.data[1u] == _2.data[1u] &&
377       _1.data[2u] == _2.data[2u]);
378 }
379 
create_offset3d(const utils::uvec3 & offsets)380 inline VkOffset3D create_offset3d(const utils::uvec3& offsets) {
381   return VkOffset3D{
382       utils::safe_downcast<int32_t>(offsets.data[0u]),
383       static_cast<int32_t>(offsets.data[1u]),
384       static_cast<int32_t>(offsets.data[2u])};
385 }
386 
create_extent3d(const utils::uvec3 & extents)387 inline VkExtent3D create_extent3d(const utils::uvec3& extents) {
388   return VkExtent3D{extents.data[0u], extents.data[1u], extents.data[2u]};
389 }
390 
391 } // namespace api
392 } // namespace vulkan
393 } // namespace native
394 } // namespace at
395 
396 #endif /* USE_VULKAN_API */
397