#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/ReduceAllOps.h>
#include <ATen/native/ReduceOpsUtils.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/OpMathType.h>

#include <ATen/native/cpu/Loops.h>
#include <ATen/native/cpu/zmath.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>

namespace at::native {
namespace {

using namespace vec;

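// Full reduction of the input into a 0-dim output, vectorized.
// at::parallel_reduce splits [0, numel) into grain-sized chunks; each chunk
// is reduced with the vectorized functor `vop` via vec::reduce_all, and the
// per-chunk partial results are combined with the scalar functor `op`.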
template <typename scalar_t, typename func_t, typename vec_func_t>
inline void reduce_all_impl_vec(
    Tensor& output,
    const Tensor& input,
    const scalar_t ident_v,
    func_t op,
    vec_func_t vop) {
  using Vec = Vectorized<opmath_type<scalar_t>>;
  const int64_t input_numel = input.numel();
  auto input_data = input.const_data_ptr<scalar_t>();
  // NOTE: parallel_reduce does not support the bool type
  scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
    [&](int64_t start, int64_t end, const scalar_t /*ident*/) -> scalar_t {
      scalar_t partial_out = vec::reduce_all<scalar_t>(
        [=](Vec x, Vec y) { return vop(x, y); },
        input_data + start,
        end - start);
      return partial_out;
    }, op);
  output.fill_(result);
}

// Scalar fallback for operations not supported in avx/avx2.
template <typename scalar_t, typename func_t>
inline void reduce_all_impl(
    Tensor& output,
    const Tensor& input,
    const scalar_t ident_v,
    func_t op) {
  const int64_t input_numel = input.numel();
  auto input_data = input.const_data_ptr<scalar_t>();
  scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
    [&](int64_t start, int64_t end, const scalar_t ident) -> scalar_t {
      scalar_t partial_out = ident;
      for (const auto i : c10::irange(start, end)) {
        partial_out = op(partial_out, input_data[i]);
      }
      return partial_out;
    }, op);
  output.fill_(result);
}

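// min_all_stub kernel: reduces the whole tensor to its minimum.
// bool inputs use a serial logical-AND loop, int64_t uses the scalar path
// (the vectorized int64_t reduction is slower), and all other dtypes use
// the vectorized path with upper_bound<scalar_t>() as the identity.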
static void min_all_kernel_impl(Tensor& result, const Tensor& input) {
  if (input.scalar_type() == ScalarType::Bool) {
    TensorIterator iter = TensorIteratorConfig()
      .add_input(input)
      .build();
    bool result_data = true;
    cpu_serial_kernel(iter, [&](const bool a) -> void {
      result_data = result_data && a;
    });
    result.fill_(result_data);
  } else if (input.scalar_type() == ScalarType::Long) {
    // for int64_t, the vectorized implementation has a performance issue,
    // so just use the scalar path
    reduce_all_impl<int64_t>(result, input, upper_bound<int64_t>(),
      [=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] {
      using Vec = Vectorized<opmath_type<scalar_t>>;
      reduce_all_impl_vec<scalar_t>(result, input, upper_bound<scalar_t>(),
        [=](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
        [=](Vec a, Vec b) -> Vec { return minimum(a, b); });
    });
  }
}

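// max_all_stub kernel: mirror of min_all_kernel_impl, reducing the whole
// tensor to its maximum (logical OR for bool, lower_bound<scalar_t>() as
// the identity, maximum() for the vectorized combine).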
static void max_all_kernel_impl(Tensor& result, const Tensor& input) {
  if (input.scalar_type() == ScalarType::Bool) {
    TensorIterator iter = TensorIteratorConfig()
      .add_input(input)
      .build();
    bool result_data = false;
    cpu_serial_kernel(iter, [&](const bool a) -> void {
      result_data = result_data || a;
    });
    result.fill_(result_data);
  } else if (input.scalar_type() == ScalarType::Long) {
    // for int64_t, the vectorized implementation has a performance issue,
    // so just use the scalar path
    reduce_all_impl<int64_t>(result, input, lower_bound<int64_t>(),
      [=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] {
      using Vec = Vectorized<opmath_type<scalar_t>>;
      reduce_all_impl_vec<scalar_t>(result, input, lower_bound<scalar_t>(),
        [=](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); },
        [=](Vec a, Vec b) -> Vec { return maximum(a, b); });
    });
  }
}

// Scalar fallback for operations not supported in avx/avx2: reduces the
// whole tensor into two outputs (e.g. a running min and max) in one pass.
template <typename scalar_t, typename func_t1, typename func_t2>
inline void reduce_all_impl_two_outputs(
    Tensor& output1,
    Tensor& output2,
    const Tensor& input,
    const std::pair<scalar_t, scalar_t>& ident_v,
    func_t1 reduce_chunk_func,
    func_t2 reduce_acc_func) {
  using scalar_t_pair = std::pair<scalar_t, scalar_t>;
  const int64_t input_numel = input.numel();
  auto input_data = input.const_data_ptr<scalar_t>();
  scalar_t_pair result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
    [&](int64_t start, int64_t end, const scalar_t_pair& ident) -> scalar_t_pair {
      scalar_t_pair partial_out(ident);
      for (const auto i : c10::irange(start, end)) {
        partial_out = reduce_chunk_func(partial_out, input_data[i]);
      }
      return partial_out;
    },
    reduce_acc_func
  );
  output1.fill_(result.first);
  output2.fill_(result.second);
}

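// Vectorized two-output full reduction: each parallel chunk is reduced with
// vec::reduce2_all using the two vector functors (one per output), and the
// per-chunk pairs are combined with the scalar reduce_acc_func.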
template <typename scalar_t, typename func_t, typename vec_func_t1, typename vec_func_t2>
inline void reduce_all_impl_vec_two_outputs(
    Tensor& output1,
    Tensor& output2,
    const Tensor& input,
    const std::pair<scalar_t, scalar_t>& ident_v,
    func_t reduce_acc_func,
    vec_func_t1 reduce_chunk_func1,
    vec_func_t2 reduce_chunk_func2) {
  using Vec = Vectorized<opmath_type<scalar_t>>;
  using scalar_t_pair = std::pair<scalar_t, scalar_t>;
  const int64_t input_numel = input.numel();
  auto input_data = input.const_data_ptr<scalar_t>();
  // NOTE: parallel_reduce does not support the bool type
  std::pair<scalar_t, scalar_t> result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
    [&](int64_t start, int64_t end, const scalar_t_pair& /* ident */) -> scalar_t_pair {
      scalar_t_pair partial_out = vec::reduce2_all<scalar_t>(
        [=](Vec x, Vec y) { return reduce_chunk_func1(x, y); },
        [=](Vec x, Vec y) { return reduce_chunk_func2(x, y); },
        input_data + start,
        end - start);
      return partial_out;
    },
    reduce_acc_func
  );
  output1.fill_(result.first);
  output2.fill_(result.second);
}

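// aminmax_allreduce_stub kernel: computes the minimum and maximum of the
// whole tensor in one pass. bool uses a serial AND/OR loop, int64_t uses the
// scalar two-output path, and all other dtypes use the vectorized two-output
// path with (upper_bound, lower_bound) as the (min, max) identities.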
static void aminmax_allreduce_kernel(
    const Tensor& input,
    Tensor& min_result,
    Tensor& max_result) {
  if (input.scalar_type() == ScalarType::Bool) {
    TensorIterator iter = TensorIteratorConfig()
      .add_input(input)
      .build();
    bool min_result_data = true;
    bool max_result_data = false;
    cpu_serial_kernel(iter, [&](const bool a) -> void {
      min_result_data = min_result_data && a;
      max_result_data = max_result_data || a;
    });
    min_result.fill_(min_result_data);
    max_result.fill_(max_result_data);
  } else if (input.scalar_type() == ScalarType::Long) {
    // for int64_t, the vectorized implementation has a performance issue,
    // so just use the scalar path
    using int64_t_pair = std::pair<int64_t, int64_t>;
    reduce_all_impl_two_outputs<int64_t>(min_result, max_result, input,
      int64_t_pair(upper_bound<int64_t>(), lower_bound<int64_t>()),
      // reduce over chunk
      [=](int64_t_pair a, int64_t b) -> int64_t_pair {
        return int64_t_pair(min_impl(a.first, b), max_impl(a.second, b));
      },
      // combine two inputs
      [=](int64_t_pair a, int64_t_pair b) -> int64_t_pair {
        return int64_t_pair(min_impl(a.first, b.first), max_impl(a.second, b.second));
      }
    );
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
      using Vec = Vectorized<opmath_type<scalar_t>>;
      using scalar_t_pair = std::pair<scalar_t, scalar_t>;
      reduce_all_impl_vec_two_outputs<scalar_t>(
        min_result,
        max_result,
        input,
        scalar_t_pair(upper_bound<scalar_t>(), lower_bound<scalar_t>()),
        [=](scalar_t_pair a, scalar_t_pair b) -> scalar_t_pair {
          return scalar_t_pair(
            min_impl(a.first, b.first), max_impl(a.second, b.second));
        },
        [=](Vec a, Vec b) -> Vec { return minimum(a, b); },
        [=](Vec a, Vec b) -> Vec { return maximum(a, b); }
      );
    });
  }
}

} // namespace

REGISTER_DISPATCH(min_all_stub, &min_all_kernel_impl);
REGISTER_DISPATCH(max_all_stub, &max_all_kernel_impl);
REGISTER_DISPATCH(aminmax_allreduce_stub, &aminmax_allreduce_kernel);

} // namespace at::native