xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/ReduceAllOps.h>
#include <ATen/native/ReduceOpsUtils.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/OpMathType.h>

#include <ATen/native/cpu/Loops.h>
#include <ATen/native/cpu/zmath.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>

namespace at::native {
namespace {

using namespace vec;

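// Full reduction of `input` into a single scalar written into `output`:
// at::parallel_reduce splits the index range into GRAIN_SIZE chunks, each
// chunk is reduced with the vectorized `vop` via vec::reduce_all, and the
// per-chunk partial results are combined with the scalar `op`.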
template <typename scalar_t, typename func_t, typename vec_func_t>
inline void reduce_all_impl_vec(
    Tensor& output,
    const Tensor& input,
    const scalar_t ident_v,
    func_t op,
    vec_func_t vop) {
  using Vec = Vectorized<opmath_type<scalar_t>>;
  const int64_t input_numel = input.numel();
  auto input_data = input.const_data_ptr<scalar_t>();
  // NOTE: parallel_reduce does not support bool type
  scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
    [&](int64_t start, int64_t end, const scalar_t /*ident*/) -> scalar_t {
      scalar_t partial_out = vec::reduce_all<scalar_t>(
        [=](Vec x, Vec y) { return vop(x, y); },
        input_data + start,
        end - start);
      return partial_out;
    }, op);
  output.fill_(result);
}

// Scalar fallback for operations not supported in avx/avx2
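// Same chunked parallel_reduce strategy as above, but each chunk is reduced
// with a plain scalar loop instead of vectorized ops.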
template <typename scalar_t, typename func_t>
inline void reduce_all_impl(
    Tensor& output,
    const Tensor& input,
    const scalar_t ident_v,
    func_t op) {
  const int64_t input_numel = input.numel();
  auto input_data = input.const_data_ptr<scalar_t>();
  scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
    [&](int64_t start, int64_t end, const scalar_t ident) -> scalar_t {
      scalar_t partial_out = ident;
      for (const auto i : c10::irange(start, end)) {
         partial_out = op(partial_out, input_data[i]);
      }
      return partial_out;
    }, op);
  output.fill_(result);
}

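// Computes the min over all elements of `input` into `result`.
// bool: serial logical-AND scan over a TensorIterator (parallel_reduce does
// not support bool); int64_t: scalar reduce_all_impl; all other dispatched
// types: vectorized reduce_all_impl_vec.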
static void min_all_kernel_impl(Tensor& result, const Tensor& input) {
  if (input.scalar_type() == ScalarType::Bool) {
    TensorIterator iter = TensorIteratorConfig()
      .add_input(input)
      .build();
    bool result_data = true;
    cpu_serial_kernel(iter, [&](const bool a) -> void {
      result_data = result_data && a;
    });
    result.fill_(result_data);
  } else if (input.scalar_type() == ScalarType::Long) {
    // for int64_t, the vectorized implementation has performance issues,
    // so just use the scalar path
    reduce_all_impl<int64_t>(result, input, upper_bound<int64_t>(),
      [=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] {
      using Vec = Vectorized<opmath_type<scalar_t>>;
      reduce_all_impl_vec<scalar_t>(result, input, upper_bound<scalar_t>(),
        [=](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
        [=](Vec a, Vec b) -> Vec { return minimum(a, b); });
    });
  }
}

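// Computes the max over all elements of `input` into `result`; mirrors
// min_all_kernel_impl, using logical-OR for bool and max_impl / maximum
// for the other paths.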
static void max_all_kernel_impl(Tensor& result, const Tensor& input) {
  if (input.scalar_type() == ScalarType::Bool) {
    TensorIterator iter = TensorIteratorConfig()
      .add_input(input)
      .build();
    bool result_data = false;
    cpu_serial_kernel(iter, [&](const bool a) -> void {
      result_data = result_data || a;
    });
    result.fill_(result_data);
  } else if (input.scalar_type() == ScalarType::Long) {
    // for int64_t, the vectorized implementation has performance issues,
    // so just use the scalar path
    reduce_all_impl<int64_t>(result, input, lower_bound<int64_t>(),
      [=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] {
      using Vec = Vectorized<opmath_type<scalar_t>>;
      reduce_all_impl_vec<scalar_t>(result, input, lower_bound<scalar_t>(),
        [=](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); },
        [=](Vec a, Vec b) -> Vec { return maximum(a, b); });
    });
  }
}

// Scalar fallback for operations not supported in avx/avx2
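// Like reduce_all_impl, but accumulates a pair of values (e.g. a running min
// and max) so both outputs are produced in a single pass over the data.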
template <typename scalar_t, typename func_t1, typename func_t2>
inline void reduce_all_impl_two_outputs(
    Tensor& output1,
    Tensor& output2,
    const Tensor& input,
    const std::pair<scalar_t, scalar_t>& ident_v,
    func_t1 reduce_chunk_func,
    func_t2 reduce_acc_func) {
  using scalar_t_pair = std::pair<scalar_t, scalar_t>;
  const int64_t input_numel = input.numel();
  auto input_data = input.const_data_ptr<scalar_t>();
  scalar_t_pair result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
    [&](int64_t start, int64_t end, const scalar_t_pair& ident) -> scalar_t_pair {
      scalar_t_pair partial_out(ident);
      for (const auto i : c10::irange(start, end)) {
         partial_out = reduce_chunk_func(partial_out, input_data[i]);
      }
      return partial_out;
    },
    reduce_acc_func
  );
  output1.fill_(result.first);
  output2.fill_(result.second);
}

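// Vectorized variant of the two-output reduction: vec::reduce2_all applies
// both vectorized ops over each chunk in a single pass, and the scalar
// reduce_acc_func combines the per-chunk (first, second) pairs.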
template <typename scalar_t, typename func_t, typename vec_func_t1, typename vec_func_t2>
inline void reduce_all_impl_vec_two_outputs(
    Tensor& output1,
    Tensor& output2,
    const Tensor& input,
    const std::pair<scalar_t, scalar_t>& ident_v,
    func_t reduce_acc_func,
    vec_func_t1 reduce_chunk_func1,
    vec_func_t2 reduce_chunk_func2) {
  using Vec = Vectorized<opmath_type<scalar_t>>;
  using scalar_t_pair = std::pair<scalar_t, scalar_t>;
  const int64_t input_numel = input.numel();
  auto input_data = input.const_data_ptr<scalar_t>();
  // NOTE: parallel_reduce does not support bool type
  std::pair<scalar_t, scalar_t> result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
    [&](int64_t start, int64_t end, const scalar_t_pair& /* ident */) -> scalar_t_pair {
      scalar_t_pair partial_out = vec::reduce2_all<scalar_t>(
        [=](Vec x, Vec y) { return reduce_chunk_func1(x, y); },
        [=](Vec x, Vec y) { return reduce_chunk_func2(x, y); },
        input_data + start,
        end - start);
      return partial_out;
    },
    reduce_acc_func
  );
  output1.fill_(result.first);
  output2.fill_(result.second);
}

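// Computes min and max over all elements of `input` in a single pass,
// writing them into `min_result` and `max_result`. Dtype handling mirrors
// min_all_kernel_impl / max_all_kernel_impl.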
static void aminmax_allreduce_kernel(
    const Tensor& input,
    Tensor& min_result,
    Tensor& max_result) {
  if (input.scalar_type() == ScalarType::Bool) {
    TensorIterator iter = TensorIteratorConfig()
      .add_input(input)
      .build();
    bool min_result_data = true;
    bool max_result_data = false;
    cpu_serial_kernel(iter, [&](const bool a) -> void {
      min_result_data = min_result_data && a;
      max_result_data = max_result_data || a;
    });
    min_result.fill_(min_result_data);
    max_result.fill_(max_result_data);
  } else if (input.scalar_type() == ScalarType::Long) {
    // for int64_t, the vectorized implementation has performance issues,
    // so just use the scalar path
    using int64_t_pair = std::pair<int64_t, int64_t>;
    reduce_all_impl_two_outputs<int64_t>(min_result, max_result, input,
      int64_t_pair(upper_bound<int64_t>(), lower_bound<int64_t>()),
      // reduce over chunk
      [=](int64_t_pair a, int64_t b) -> int64_t_pair {
        return int64_t_pair(min_impl(a.first, b), max_impl(a.second, b));
      },
      // combine two inputs
      [=](int64_t_pair a, int64_t_pair b) -> int64_t_pair {
        return int64_t_pair(min_impl(a.first, b.first), max_impl(a.second, b.second));
      }
    );
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
      using Vec = Vectorized<opmath_type<scalar_t>>;
      using scalar_t_pair = std::pair<scalar_t, scalar_t>;
      reduce_all_impl_vec_two_outputs<scalar_t>(
        min_result,
        max_result,
        input,
        scalar_t_pair(upper_bound<scalar_t>(), lower_bound<scalar_t>()),
        [=](scalar_t_pair a, scalar_t_pair b) -> scalar_t_pair {
          return scalar_t_pair(
            min_impl(a.first, b.first), max_impl(a.second, b.second));
        },
        [=](Vec a, Vec b) -> Vec { return minimum(a, b); },
        [=](Vec a, Vec b) -> Vec { return maximum(a, b); }
      );
    });
  }
}

} // namespace

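// Register the CPU implementations with the corresponding full-reduction
// dispatch stubs.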
REGISTER_DISPATCH(min_all_stub, &min_all_kernel_impl);
REGISTER_DISPATCH(max_all_stub, &max_all_kernel_impl);
REGISTER_DISPATCH(aminmax_allreduce_stub, &aminmax_allreduce_kernel);

} // namespace at::native