#pragma once

#include <ATen/Parallel.h>
#include <ATen/NumericUtils.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/ReductionType.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>
#include <ATen/native/cpu/utils.h>

namespace at::native {
inline namespace CPU_CAPABILITY {

using namespace vec;

#define AT_DISPATCH_REDUCTION_TYPES(op, ...)                                   \
  [&] {                                                                        \
    switch (op) {                                                              \
      case ReductionType::SUM: {                                               \
        static constexpr auto reduce = ReductionType::SUM;                     \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case ReductionType::MEAN: {                                              \
        static constexpr auto reduce = ReductionType::MEAN;                    \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case ReductionType::MIN: {                                               \
        static constexpr auto reduce = ReductionType::MIN;                     \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case ReductionType::MAX: {                                               \
        static constexpr auto reduce = ReductionType::MAX;                     \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case ReductionType::PROD: {                                              \
        static constexpr auto reduce = ReductionType::PROD;                    \
        return __VA_ARGS__();                                                  \
      }                                                                        \
    }                                                                          \
  }()
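
// Illustrative usage sketch (not part of the original header): the macro
// dispatches a runtime ReductionType to a lambda in which `reduce` is a
// compile-time constant. The names `op`, `out_ptr`, and `n` below are
// placeholders for this example.
//
//   void fill_identity(ReductionType op, float* out_ptr, int64_t n) {
//     AT_DISPATCH_REDUCTION_TYPES(op, [&]() {
//       // `reduce` is `static constexpr` inside this lambda.
//       init<float, reduce>(out_ptr, n, /*include_self=*/false);
//     });
//   }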

template <typename scalar_t, ReductionType reduce>
inline vec_scalar_t<scalar_t> init_value() {
  using acc_t = vec_scalar_t<scalar_t>;
  acc_t val;
  if (reduce == ReductionType::SUM ||
      reduce == ReductionType::MEAN) {
    val = static_cast<acc_t>(0);
  } else if (reduce == ReductionType::PROD) {
    val = static_cast<acc_t>(1);
  } else if (reduce == ReductionType::MAX) {
    val = -std::numeric_limits<acc_t>::infinity();
  } else {
    TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
    val = std::numeric_limits<acc_t>::infinity();
  }
  return val;
}

template <typename scalar_t, ReductionType reduce>
inline vec_scalar_t<scalar_t> init_value(const std::optional<Scalar>& initial) {
  using acc_t = vec_scalar_t<scalar_t>;
  if (initial.has_value()) {
    return initial.value().to<acc_t>();
  } else {
    return init_value<scalar_t, reduce>();
  }
}
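
// Sketch of the identities produced above (added for clarity, not in the
// original source):
//   init_value<float, ReductionType::SUM>()  == 0.0f
//   init_value<float, ReductionType::PROD>() == 1.0f
//   init_value<float, ReductionType::MAX>()  == -inf
//   init_value<float, ReductionType::MIN>()  == +inf
// Passing an explicit `initial` Scalar overrides the identity.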

template <typename scalar_t>
inline void init(scalar_t* out, int64_t size, const vec_scalar_t<scalar_t>& val) {
  using Vec = Vectorized<vec_scalar_t<scalar_t>>;
  map<scalar_t>(
      [val](Vec x) { return Vec(val); },
      out,
      out,
      size);
}

template <typename scalar_t, ReductionType reduce>
inline void init(scalar_t* out, int64_t size, const std::optional<Scalar>& initial) {
  using acc_t = vec_scalar_t<scalar_t>;
  acc_t val = init_value<scalar_t, reduce>(initial);
  init(out, size, val);
}

// overload with `include_self`, used by scatter_reduce
template <typename scalar_t, ReductionType reduce>
inline void init(scalar_t* out, int64_t size, bool include_self = false) {
  using acc_t = vec_scalar_t<scalar_t>;
  if (!include_self) {
    acc_t val = init_value<scalar_t, reduce>();
    init(out, size, val);
  }
}
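
// Illustrative sketch (assumed example, not from the original source):
// seeding an output row with the reduction identity before accumulation.
// `row` and `row_len` are placeholder names.
//
//   float* row = /* hypothetical output row */;
//   int64_t row_len = /* hypothetical row length */;
//   init<float, ReductionType::MAX>(row, row_len, /*include_self=*/false);
//   // every element of `row` is now -inf, ready for `update`.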

template <typename scalar_t, ReductionType reduce>
inline void _init(scalar_t* self_ptr, at::opmath_type<scalar_t>* buffer_ptr, int64_t size, bool include_self) {
  if (!include_self) {
    init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
  } else {
    vec::convert(self_ptr, buffer_ptr, size);
  }
}
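
// Note (added for clarity): for reduced-precision inputs (e.g. BFloat16),
// `_init` prepares the higher-precision accumulation buffer. It either seeds
// it with the reduction identity (include_self == false) or widens the
// existing `self` values into it (include_self == true).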

template <typename scalar_t>
inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
_max(const scalar_t& x, const scalar_t& y) {
  return at::_isnan(y) ? y : std::max(x, y);
}

template <typename scalar_t>
inline Vectorized<scalar_t> _max(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
  // vec::maximum propagates NaN
  return vec::maximum(x, y);
}

template <typename vec_t>
inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
_max(const vec_t& x, const vec_t& y) {
  // vec::maximum propagates NaN
  return maximum(x, y);
}

template <typename scalar_t>
inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
_min(const scalar_t& x, const scalar_t& y) {
  return at::_isnan(y) ? y : std::min(x, y);
}

template <typename scalar_t>
inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
  // vec::minimum propagates NaN
  return vec::minimum(x, y);
}

template <typename vec_t>
inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
_min(const vec_t& x, const vec_t& y) {
  // vec::minimum propagates NaN
  return minimum(x, y);
}
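
// NaN-propagation examples (added for clarity, not in the original source):
//   _max(1.0f, 2.0f)               -> 2.0f
//   _max(1.0f, NAN)                -> NaN (a NaN `y` always wins in the scalar overload)
//   _min(Vectorized<float>(NAN),
//        Vectorized<float>(1.0f))  -> all-NaN lanes (vec::minimum propagates NaN)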

template <typename scalar_t, typename accumut, typename Op,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map_acc(
    const Op& vec_fun,
    accumut* output_data,
    const accumut* input_data,
    const scalar_t* input_data2,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  using aVec = vec::Vectorized<accumut>;
  int64_t d = 0;
  constexpr int64_t kVecSize = Vec::size();
  constexpr int64_t kaVecSize = aVec::size();
  for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
    Vec data2_vec = Vec::loadu(input_data2 + d);
    auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
    aVec input_vec0 = aVec::loadu(input_data + d);
    aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
    vec_fun(input_vec0, data2_avec0).store(output_data + d);
    vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
  }
  if (size - d > 0) {
    int64_t tail_size = size - d;
    Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
    auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
    if (tail_size > kaVecSize) {
      aVec input_vec0 = aVec::loadu(input_data + d);
      aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
      vec_fun(input_vec0, data2_avec0).store(output_data + d);
      vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
    } else {
      aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
      vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
    }
  }
}
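
// In effect (added description, not in the original source), for every i in
// [0, size):
//   output_data[i] = vec_fun(input_data[i], accumut(input_data2[i]))
// i.e. a binary map whose second operand is a reduced floating point type
// (BFloat16/Half) widened to the accumulation type on the fly. One scalar_t
// vector widens into two accumut vectors, hence the paired loads and stores
// above.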

// for Max and Min, propagate NaN:
template <typename T, ReductionType reduce>
inline T update(const T& x, const T& y) {
  if (reduce == ReductionType::SUM ||
      reduce == ReductionType::MEAN) {
    return x + y;
  } else if (reduce == ReductionType::PROD) {
    return x * y;
  } else if (reduce == ReductionType::MAX) {
    return _max(x, y);
  } else {
    TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
    return _min(x, y);
  }
}

template <typename scalar_t, ReductionType reduce>
inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
  using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
  map2<scalar_t>(
      [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
      out,
      out,
      data,
      K);
}

template <typename scalar_t, ReductionType reduce,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void update(at::opmath_type<scalar_t>* out, const scalar_t* data, int64_t K) {
  using opmath_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<opmath_t>;
  map_acc<scalar_t, opmath_t>(
      [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
      out,
      out,
      data,
      K);
}
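
// Illustrative accumulation sketch (assumed example, not from the original
// source): reduce `m` rows of length `K` into `acc` with SUM semantics.
// `rows`, `m`, `K`, and `acc` are placeholder names.
//
//   constexpr auto reduce = ReductionType::SUM;
//   init<float, reduce>(acc, K, /*include_self=*/false);    // acc[:] = 0
//   for (int64_t i = 0; i < m; ++i) {
//     update<float, reduce>(acc, rows + i * K, K);           // acc += rows[i]
//   }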

template <typename scalar_t, ReductionType reduce>
inline void write(scalar_t* out, int64_t count, int64_t K) {
  using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
  if (reduce == ReductionType::MEAN) {
    if (count > 0) {
      vec::map<scalar_t>(
          [count](Vec x) { return x / Vec(count); },
          out,
          out,
          K);
    }
  }
}
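
// Finalization note (added for clarity): `write` only does work for MEAN,
// where the accumulated sums in `out` are divided by `count`; for SUM, PROD,
// MAX and MIN it is a no-op. Continuing the sketch above:
//
//   write<float, ReductionType::MEAN>(acc, m, K);   // acc[j] /= m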

} // namespace CPU_CAPABILITY
} // namespace at::native