xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #if defined(CPU_CAPABILITY_AVX512)
4 #include <ATen/cpu/vec/vec512/vec512.h>
5 #else
6 #include <ATen/cpu/vec/vec256/vec256.h>
7 #endif
8 
9 namespace at::vec {
10 // See Note [CPU_CAPABILITY namespace]
11 inline namespace CPU_CAPABILITY {
12 
convert_to_bool(Vectorized<int8_t> x)13 inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) {
14   __at_align__ bool buffer[x.size()];
15   x.ne(Vectorized<int8_t>(0)).store(buffer);
16 
17   Vectorized<bool> ret;
18   static_assert(x.size() == ret.size());
19   std::memcpy(ret, buffer, ret.size() * sizeof(bool));
20   return ret;
21 }
22 
23 template <>
loadu(const void * ptr)24 inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) {
25   // See NOTE [Loading boolean values]
26   return convert_to_bool(Vectorized<int8_t>::loadu(ptr));
27 }
28 
29 template <>
loadu(const void * ptr,int64_t count)30 inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr, int64_t count) {
31   // See NOTE [Loading boolean values]
32   return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
33 }
34 
35 template <typename VT>
36 struct VecHoldType { using hold_type = typename VT::value_type; };
37 
38 template <>
39 struct VecHoldType<Vectorized<BFloat16>> { using hold_type = BFloat16; };
40 
41 template <>
42 struct VecHoldType<Vectorized<Half>> {using hold_type = Half; };
43 
44 template <typename VT>
45 using vechold_type = typename VecHoldType<VT>::hold_type;
46 
47 }} // namespace at::vec::CPU_CAPABILITY
48