#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/TensorCompare.h>

#include <numeric>
#include <iterator>
#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorIterator.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/irange.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cpu/Loops.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/result_type.h>
#endif

namespace at::native { namespace {

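// Shared driver for the dim-reduction comparison kernels below (min, max,
// aminmax, mode). It resizes result1/result2 to self's reduced shape (size 1
// along `dim`), builds a TensorIterator with `dim` squashed so that `loop`
// handles one slice of `self` per output element, and squeezes the results
// afterwards when keepdim is false.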
template <typename scalar_t, typename scalar_t_2 = int64_t, typename loop1d_t>
static inline void compare_base_kernel_core(
    const Tensor& result1,
    const Tensor& result2,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    const loop1d_t& loop) {
  auto self_sizes = ensure_nonempty_vec(self.sizes().vec());
  self_sizes[dim] = 1;

  // result1 and result2 may be empty tensors; if not, reshape them
  // to self's reduced shape (size 1 along dim)
  if (!keepdim) {
    if (result1.ndimension() >= dim) {
      result1.unsqueeze_(dim);
    }
    if (result2.ndimension() >= dim) {
      result2.unsqueeze_(dim);
    }
  }

  at::native::resize_output(result1, self_sizes);
  at::native::resize_output(result2, self_sizes);

  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(self.sizes(), /*squash_dims=*/dim)
    .add_output(result1)
    .add_output(result2)
    .add_const_input(self)
    .build();

  iter.for_each(loop, /* grain_size */ 1);

  if (!keepdim) {
    result1.squeeze_(dim);
    result2.squeeze_(dim);
  }
}

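// Convenience wrapper around compare_base_kernel_core: adapts a typed
// per-slice functor `f(result1*, result2*, self*, self_dim_stride)` into the
// raw byte-pointer loop expected by the TensorIterator.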
template <typename scalar_t, typename scalar_t_2=int64_t, typename func_t>
static inline void compare_base_kernel(const Tensor& result1, const Tensor& result2,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    const func_t& f) {

  auto self_dim_stride = ensure_nonempty_stride(self, dim);

  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    auto* result1_data_bytes = data[0];
    auto* result2_data_bytes = data[1];
    const auto* self_data_bytes = data[2];
    for (const auto i C10_UNUSED : c10::irange(n)) {
      f((scalar_t*)result1_data_bytes,
        (scalar_t_2*)result2_data_bytes,
        (scalar_t*)self_data_bytes,
        self_dim_stride);
      result1_data_bytes += strides[0];
      result2_data_bytes += strides[1];
      self_data_bytes += strides[2];
    }
  };

  compare_base_kernel_core<scalar_t, scalar_t_2>(
      result1, result2, self, dim, keepdim, loop);
}

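// min along `dim`, also producing the index of the minimum. The comparison is
// written as `!(value >= min_number)` so that a NaN compares as the new
// minimum and is propagated; once a NaN is found the scan stops early.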
static void min_kernel_impl(
    const Tensor& result,
    const Tensor& indice,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  int64_t self_dim_size = ensure_nonempty_size(self, dim);

  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
    compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
      scalar_t* result_data, int64_t* indice_data,
      const scalar_t* self_data, auto self_dim_stride) {
        using value_t = typename c10::scalar_value_type<scalar_t>::type;
        value_t (*zabs_)(scalar_t) = zabs<scalar_t, value_t>;
        scalar_t min_number = c10::load(self_data);
        int64_t index = 0;
        for (const auto i : c10::irange(self_dim_size)) {
          scalar_t value = c10::load(&self_data[i * self_dim_stride]);
          if (!(zabs_(value) >= zabs_(min_number))) {
            min_number = value;
            index = i;
            if (_isnan<scalar_t>(value)) {
              break;
            }
          }
        }
        *result_data = min_number;
        *indice_data = index;
      }
    );
  });
}

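// max along `dim`, the mirror image of min_kernel_impl: `!(value <= max_number)`
// lets a NaN take over as the running max, and the scan stops once a NaN is
// found since nothing can displace it.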
static void max_kernel_impl(
    const Tensor& result,
    const Tensor& indice,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  int64_t self_dim_size = ensure_nonempty_size(self, dim);

  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "max_cpu", [&] {
    compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
      scalar_t* result_data, int64_t* indice_data,
      const scalar_t* self_data, auto self_dim_stride) {
        using value_t = typename c10::scalar_value_type<scalar_t>::type;
        value_t (*zabs_)(scalar_t) = zabs<scalar_t, value_t>;
        scalar_t max_number = c10::load(self_data);
        int64_t index = 0;
        for (const auto i : c10::irange(self_dim_size)) {
          scalar_t value = c10::load(&self_data[i * self_dim_stride]);
          if (!(zabs_(value) <= zabs_(max_number))) {
            max_number = value;
            index = i;
            if (_isnan<scalar_t>(value)) {
              break;
            }
          }
        }
        *result_data = max_number;
        *indice_data = index;
      }
    );
  });
}

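// Computes min and max along `dim` in a single pass, reusing
// compare_base_kernel with both outputs typed as scalar_t. A zero-dim input
// is handled separately by filling both outputs with the single element.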
static void aminmax_kernel(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& min_result,
    Tensor& max_result) {
  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
  int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);

  TORCH_CHECK(min_result.scalar_type() == self.scalar_type() && max_result.scalar_type() == self.scalar_type(),
    "Expect min and max dtype ", self.scalar_type(),
    " but got ", min_result.scalar_type(), " and ", max_result.scalar_type());

  if (self.numel() == 1 && self.ndimension() == 0) {
    TORCH_CHECK(!self.is_complex(), "aminmax not implemented for ", self.scalar_type());
    min_result.resize_({});
    max_result.resize_({});
    min_result.fill_(self);
    max_result.fill_(self);
    return;
  }

  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] {
    compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
      scalar_t* min_result_data, scalar_t* max_result_data,
      const scalar_t* self_data, auto self_dim_stride) {
        scalar_t min_number = c10::load(self_data);
        scalar_t max_number = min_number;
        for (const auto i : c10::irange(self_dim_size)) {
          scalar_t value = c10::load(&self_data[i * self_dim_stride]);
          // note: comparison is written this way to handle NaN correctly
          if (!(value >= min_number)) {
            min_number = value;
            if (_isnan<scalar_t>(value)) {
              max_number = value;
              break;
            }
          } else if (!(value <= max_number)) {
            max_number = value;
          }
        }
        *min_result_data = min_number;
        *max_result_data = max_number;
      }
    );
  });
}

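// where(cond, self, other): straightforward elementwise select over the
// already-broadcast TensorIterator operands.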
static void where_kernel_impl(TensorIterator &iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool,
    iter.dtype(), "where_cpu", [&] {
      cpu_kernel(
        iter,
        [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
          return cond_val ? self_val : other_val;
        });
  });
}

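// isposinf / isneginf: elementwise comparisons against +/-infinity, dispatched
// over floating-point types including Half and BFloat16.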
static void isposinf_kernel_impl(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isposinf_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a) -> bool { return a == std::numeric_limits<scalar_t>::infinity(); });
  });
}

static void isneginf_kernel_impl(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isneginf_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a) -> bool { return a == -std::numeric_limits<scalar_t>::infinity(); });
  });
}

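// mode along `dim`: for each slice, copies (value, original index) pairs,
// sorts them by value, and scans for the longest run of equal values. With
// a strict `>` comparison on run length, ties are resolved in favor of the
// smallest value; the reported index is one of the positions at which that
// value occurs.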
static void mode_kernel_impl(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  auto self_dim_size = ensure_nonempty_size(self, dim);
  auto self_dim_stride = ensure_nonempty_stride(self, dim);

  AT_DISPATCH_ALL_TYPES_AND3(
      kHalf, kBFloat16, kBool, self.scalar_type(), "mode_cpu", [&] {
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* values_data_bytes = data[0];
          auto* indices_data_bytes = data[1];
          const auto* self_data_bytes = data[2];

          std::vector<std::pair<scalar_t, int64_t>> elements(self_dim_size);

          for (const auto k C10_UNUSED : c10::irange(n)) {
            scalar_t* values_data = (scalar_t*)values_data_bytes;
            int64_t* indices_data = (int64_t*)indices_data_bytes;
            const scalar_t* self_data = (scalar_t*)self_data_bytes;

            scalar_t mode = 0;
            int64_t modei = 0;
            int64_t temp_freq = 0;
            int64_t max_freq = 0;

            for (const auto i : c10::irange(self_dim_size)) {
              elements[i] = std::make_pair(c10::load(&self_data[i * self_dim_stride]), i);
            }

            // Even though, theoretically, we could omit this lambda and rely
            // on std::less (the default), doing so degrades performance:
            // std::pair's operator< also compares the second member, using
            // up to 3 comparisons per element.
            std::sort(
                elements.begin(),
                elements.end(),
                [=](const auto& i, const auto& j) {
                  return i.first < j.first;
                });

            for (const auto i : c10::irange(self_dim_size)) {
              temp_freq++;
              if ((i == self_dim_size - 1) ||
                  (elements[i].first != elements[i + 1].first)) {
                if (temp_freq > max_freq) {
                  mode = elements[i].first;
                  modei = elements[i].second;
                  max_freq = temp_freq;
                }
                temp_freq = 0;
              }
            }

            *values_data = mode;
            *indices_data = modei;

            values_data_bytes += strides[0];
            indices_data_bytes += strides[1];
            self_data_bytes += strides[2];
          }
        };

        compare_base_kernel_core<scalar_t>(
            values, indices, self, dim, keepdim, loop);
      });
}

// Default brute force implementation of isin(). Used when the number of test elements is small.
// Iterates through each element and checks it against each test element.
static void isin_default_kernel_cpu(
    const Tensor& elements,
    const Tensor& test_elements,
    bool invert,
    const Tensor& out) {
  // Since test elements is not an input of the TensorIterator, type promotion
  // must be done manually.
  ScalarType common_type = at::result_type(elements, test_elements);
  Tensor promoted_elements = elements.to(common_type);
  Tensor test_elements_flat = test_elements.to(common_type).view(-1);
  auto test_elements_stride = test_elements_flat.stride(0);

  auto iter = TensorIteratorConfig()
    .add_output(out)
    .add_const_input(promoted_elements)
    .check_all_same_dtype(false)
    .build();
  // Dispatch based on promoted type.
  AT_DISPATCH_ALL_TYPES(iter.dtype(1), "isin_default_cpu", [&]() {
    cpu_kernel(iter, [&](scalar_t element_val) -> bool {
      const auto* test_element_data = test_elements_flat.const_data_ptr<scalar_t>();
      for (const auto j : c10::irange(test_elements_flat.numel())) {
        if (element_val == *(test_element_data + test_elements_stride * j)) {
          return !invert;
        }
      }
      return invert;
    });
  });
}

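// clamp with tensor min/max bounds. The scalar path returns NaN whenever
// either bound is NaN (`min != min` is the NaN test); the vectorized path
// relies on vec::minimum/vec::maximum, which likewise propagate NaN.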
static void clamp_kernel_impl(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_cpu", [&]() {
    cpu_kernel_vec(iter,
      [](scalar_t a, scalar_t min, scalar_t max) -> scalar_t {
        if (min != min || max != max) {
          return std::numeric_limits<scalar_t>::quiet_NaN();
        } else {
          return std::min(std::max(a, min), max);
        }
      },
      [](Vectorized<scalar_t> a, Vectorized<scalar_t> min, Vectorized<scalar_t> max) {
        return vec::minimum(vec::maximum(a, min), max);
      });
  });
}

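// Scalar-bound variants of clamp: the bounds are converted to scalar_t once
// per dtype dispatch and broadcast into Vectorized registers outside the
// element loop.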
static void clamp_scalar_kernel_impl(TensorIteratorBase& iter, const Scalar& min_, const Scalar& max_) {
  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_scalar_cpu", [&]() {
    const auto min = min_.to<scalar_t>();
    const auto max = max_.to<scalar_t>();
    const Vectorized<scalar_t> min_vec(min);
    const Vectorized<scalar_t> max_vec(max);
    cpu_kernel_vec(iter,
      [=](scalar_t a) -> scalar_t {
        return std::min(std::max(a, min), max);
      },
      [=](Vectorized<scalar_t> a) {
        return vec::clamp(a, min_vec, max_vec);
      });
  });
}

static void clamp_max_scalar_kernel_impl(TensorIteratorBase& iter, Scalar max_) {
  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_max_scalar_cpu", [&]() {
    const auto max = max_.to<scalar_t>();
    const Vectorized<scalar_t> max_vec(max);
    cpu_kernel_vec(iter,
      [=](scalar_t a) -> scalar_t {
        return std::min(a, max);
      },
      [=](Vectorized<scalar_t> a) {
        return vec::clamp_max(a, max_vec);
      });
  });
}

static void clamp_min_scalar_kernel_impl(TensorIteratorBase& iter, Scalar min_) {
  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_min_scalar_cpu", [&]() {
    const auto min = min_.to<scalar_t>();
    const Vectorized<scalar_t> min_vec(min);
    cpu_kernel_vec(iter,
      [=](scalar_t a) -> scalar_t {
        return std::max(a, min);
      },
      [=](Vectorized<scalar_t> a) {
        return vec::clamp_min(a, min_vec);
      });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(max_stub, &max_kernel_impl);
REGISTER_DISPATCH(min_stub, &min_kernel_impl);
REGISTER_DISPATCH(aminmax_stub, &aminmax_kernel);
REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
REGISTER_DISPATCH(isposinf_stub, &isposinf_kernel_impl);
REGISTER_DISPATCH(isneginf_stub, &isneginf_kernel_impl);
REGISTER_DISPATCH(mode_stub, &mode_kernel_impl);
REGISTER_DISPATCH(clamp_stub, &clamp_kernel_impl);
REGISTER_DISPATCH(clamp_scalar_stub, &clamp_scalar_kernel_impl);
REGISTER_DISPATCH(clamp_min_scalar_stub, &clamp_min_scalar_kernel_impl);
REGISTER_DISPATCH(clamp_max_scalar_stub, &clamp_max_scalar_kernel_impl);
REGISTER_DISPATCH(isin_default_stub, &isin_default_kernel_cpu);

} // namespace at::native