xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec_convert.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cpu/vec/vec_base.h>
4 #include <ATen/cpu/vec/vec_n.h>
5 
6 namespace at::vec {
7 inline namespace CPU_CAPABILITY {
8 
9 template <
10     typename dst_t,
11     int dst_n,
12     typename src_t,
13     int src_n,
14     typename Enabled = void>
15 struct VecConvert {
applyVecConvert16   static inline VectorizedN<dst_t, dst_n> apply(
17       const VectorizedN<src_t, src_n>& src) {
18     constexpr int count = std::min(
19         VectorizedN<src_t, src_n>::size(), VectorizedN<dst_t, dst_n>::size());
20     __at_align__ src_t src_buf[VectorizedN<src_t, src_n>::size()];
21     src.store(src_buf);
22     __at_align__ dst_t dst_buf[VectorizedN<dst_t, dst_n>::size()];
23     for (int i = 0; i < count; i++) {
24       dst_buf[i] = static_cast<dst_t>(src_buf[i]);
25     }
26     return VectorizedN<dst_t, dst_n>::loadu(dst_buf, count);
27   }
28 };
29 
30 template <typename dst_t, typename src_t>
31 inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>>
convert(const Vectorized<src_t> & src)32 convert(const Vectorized<src_t>& src) {
33   return src;
34 }
35 
36 template <typename dst_t, typename src_t>
37 inline std::enable_if_t<!std::is_same_v<dst_t, src_t>, Vectorized<dst_t>>
convert(const Vectorized<src_t> & src)38 convert(const Vectorized<src_t>& src) {
39   return VecConvert<dst_t, 1, src_t, 1>::apply(src);
40 }
41 
42 template <
43     typename dst_t,
44     int dst_n,
45     typename src_t,
46     int src_n,
47     std::enable_if_t<dst_n != 1, int> = 0>
convert(const VectorizedN<src_t,src_n> & src)48 inline VectorizedN<dst_t, dst_n> convert(const VectorizedN<src_t, src_n>& src) {
49   return VecConvert<dst_t, dst_n, src_t, src_n>::apply(src);
50 }
51 
52 template <
53     typename dst_t,
54     int dst_n,
55     typename src_t,
56     int src_n,
57     bool keep = false,
58     std::enable_if_t<dst_n == 1, int> = 0>
59 inline std::conditional_t<keep, VectorizedN<dst_t, 1>, Vectorized<dst_t>>
convert(const VectorizedN<src_t,src_n> & src)60 convert(const VectorizedN<src_t, src_n>& src) {
61   return VecConvert<dst_t, dst_n, src_t, src_n>::apply(src);
62 }
63 
64 } // namespace CPU_CAPABILITY
65 } // namespace at::vec
66