1 #pragma once 2 3 #if defined(CPU_CAPABILITY_AVX512) 4 #include <ATen/cpu/vec/vec512/vec512.h> 5 #else 6 #include <ATen/cpu/vec/vec256/vec256.h> 7 #endif 8 9 namespace at::vec { 10 // See Note [CPU_CAPABILITY namespace] 11 inline namespace CPU_CAPABILITY { 12 convert_to_bool(Vectorized<int8_t> x)13inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) { 14 __at_align__ bool buffer[x.size()]; 15 x.ne(Vectorized<int8_t>(0)).store(buffer); 16 17 Vectorized<bool> ret; 18 static_assert(x.size() == ret.size()); 19 std::memcpy(ret, buffer, ret.size() * sizeof(bool)); 20 return ret; 21 } 22 23 template <> loadu(const void * ptr)24inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) { 25 // See NOTE [Loading boolean values] 26 return convert_to_bool(Vectorized<int8_t>::loadu(ptr)); 27 } 28 29 template <> loadu(const void * ptr,int64_t count)30inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr, int64_t count) { 31 // See NOTE [Loading boolean values] 32 return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count)); 33 } 34 35 template <typename VT> 36 struct VecHoldType { using hold_type = typename VT::value_type; }; 37 38 template <> 39 struct VecHoldType<Vectorized<BFloat16>> { using hold_type = BFloat16; }; 40 41 template <> 42 struct VecHoldType<Vectorized<Half>> {using hold_type = Half; }; 43 44 template <typename VT> 45 using vechold_type = typename VecHoldType<VT>::hold_type; 46 47 }} // namespace at::vec::CPU_CAPABILITY 48