1 #pragma once
2
3 #include <ATen/cpu/vec/vec_base.h>
4 #include <ATen/cpu/vec/vec_n.h>
5
6 namespace at::vec {
7 inline namespace CPU_CAPABILITY {
8
9 template <
10 typename dst_t,
11 int dst_n,
12 typename src_t,
13 int src_n,
14 typename Enabled = void>
15 struct VecConvert {
applyVecConvert16 static inline VectorizedN<dst_t, dst_n> apply(
17 const VectorizedN<src_t, src_n>& src) {
18 constexpr int count = std::min(
19 VectorizedN<src_t, src_n>::size(), VectorizedN<dst_t, dst_n>::size());
20 __at_align__ src_t src_buf[VectorizedN<src_t, src_n>::size()];
21 src.store(src_buf);
22 __at_align__ dst_t dst_buf[VectorizedN<dst_t, dst_n>::size()];
23 for (int i = 0; i < count; i++) {
24 dst_buf[i] = static_cast<dst_t>(src_buf[i]);
25 }
26 return VectorizedN<dst_t, dst_n>::loadu(dst_buf, count);
27 }
28 };
29
30 template <typename dst_t, typename src_t>
31 inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>>
convert(const Vectorized<src_t> & src)32 convert(const Vectorized<src_t>& src) {
33 return src;
34 }
35
36 template <typename dst_t, typename src_t>
37 inline std::enable_if_t<!std::is_same_v<dst_t, src_t>, Vectorized<dst_t>>
convert(const Vectorized<src_t> & src)38 convert(const Vectorized<src_t>& src) {
39 return VecConvert<dst_t, 1, src_t, 1>::apply(src);
40 }
41
42 template <
43 typename dst_t,
44 int dst_n,
45 typename src_t,
46 int src_n,
47 std::enable_if_t<dst_n != 1, int> = 0>
convert(const VectorizedN<src_t,src_n> & src)48 inline VectorizedN<dst_t, dst_n> convert(const VectorizedN<src_t, src_n>& src) {
49 return VecConvert<dst_t, dst_n, src_t, src_n>::apply(src);
50 }
51
52 template <
53 typename dst_t,
54 int dst_n,
55 typename src_t,
56 int src_n,
57 bool keep = false,
58 std::enable_if_t<dst_n == 1, int> = 0>
59 inline std::conditional_t<keep, VectorizedN<dst_t, 1>, Vectorized<dst_t>>
convert(const VectorizedN<src_t,src_n> & src)60 convert(const VectorizedN<src_t, src_n>& src) {
61 return VecConvert<dst_t, dst_n, src_t, src_n>::apply(src);
62 }
63
64 } // namespace CPU_CAPABILITY
65 } // namespace at::vec
66