// aten/src/ATen/native/cpu/TensorCompareKernel.cpp
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/TensorCompare.h>

#include <numeric>
#include <iterator>
#include <algorithm>
#include <utility>
#include <vector>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorIterator.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/irange.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cpu/Loops.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/result_type.h>
#endif

namespace at::native { namespace {

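// Shared driver for dim-wise kernels that produce two outputs (e.g. values and
// indices, or min and max): resizes result1/result2 to self's shape with `dim`
// reduced to 1, builds a TensorIterator with `dim` squashed, runs `loop` over
// each reduced slice, and squeezes the reduced dim back out when !keepdim.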
template <typename scalar_t, typename scalar_t_2 = int64_t, typename loop1d_t>
static inline void compare_base_kernel_core(
    const Tensor& result1,
    const Tensor& result2,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    const loop1d_t& loop) {
  auto self_sizes = ensure_nonempty_vec(self.sizes().vec());
  self_sizes[dim] = 1;

  // result1 and result2 may be empty tensors; if they are not, reshape them
  // to match self's dims.
  if (!keepdim) {
    if (result1.ndimension() >= dim) {
      result1.unsqueeze_(dim);
    }
    if (result2.ndimension() >= dim) {
      result2.unsqueeze_(dim);
    }
  }

  at::native::resize_output(result1, self_sizes);
  at::native::resize_output(result2, self_sizes);

  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(self.sizes(), /*squash_dims=*/dim)
    .add_output(result1)
    .add_output(result2)
    .add_const_input(self)
    .build();

  iter.for_each(loop, /* grain_size */ 1);

  if (!keepdim) {
    result1.squeeze_(dim);
    result2.squeeze_(dim);
  }
}

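// Type-aware wrapper around compare_base_kernel_core: the iterator hands this
// loop raw byte pointers, which are cast to scalar_t/scalar_t_2 and passed to
// `f` once per reduced slice together with the stride along `dim`.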
template <typename scalar_t, typename scalar_t_2=int64_t, typename func_t>
static inline void compare_base_kernel(const Tensor& result1, const Tensor& result2,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    const func_t& f) {

  auto self_dim_stride = ensure_nonempty_stride(self, dim);

  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    auto* result1_data_bytes = data[0];
    auto* result2_data_bytes = data[1];
    const auto* self_data_bytes = data[2];
    for (const auto i C10_UNUSED : c10::irange(n)) {
      f((scalar_t*)result1_data_bytes,
        (scalar_t_2*)result2_data_bytes,
        (scalar_t*)self_data_bytes,
        self_dim_stride);
      result1_data_bytes += strides[0];
      result2_data_bytes += strides[1];
      self_data_bytes += strides[2];
    }
  };

  compare_base_kernel_core<scalar_t, scalar_t_2>(
      result1, result2, self, dim, keepdim, loop);
}

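// min along `dim`, writing both the value and its index. The comparison is
// written as !(value >= min_number) so that a NaN compares as the new minimum
// and ends the scan early (NaN propagates to the result).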
static void min_kernel_impl(
    const Tensor& result,
    const Tensor& indice,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  int64_t self_dim_size = ensure_nonempty_size(self, dim);

  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
    compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
      scalar_t* result_data, int64_t* indice_data,
      const scalar_t* self_data, auto self_dim_stride) {
        using value_t = typename c10::scalar_value_type<scalar_t>::type;
        value_t (*zabs_)(scalar_t) = zabs<scalar_t, value_t>;
        scalar_t min_number = c10::load(self_data);
        int64_t index = 0;
        for (const auto i : c10::irange(self_dim_size)) {
          scalar_t value = c10::load(&self_data[i * self_dim_stride]);
          if (!(zabs_(value) >= zabs_(min_number))) {
            min_number = value;
            index = i;
            if (_isnan<scalar_t>(value)) {
              break;
            }
          }
        }
        *result_data = min_number;
        *indice_data = index;
      }
    );
  });
}

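// max along `dim`, writing both the value and its index; mirror image of
// min_kernel_impl with the comparison flipped to !(value <= max_number).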
static void max_kernel_impl(
    const Tensor& result,
    const Tensor& indice,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  int64_t self_dim_size = ensure_nonempty_size(self, dim);

  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "max_cpu", [&] {
    compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
      scalar_t* result_data, int64_t* indice_data,
      const scalar_t* self_data, auto self_dim_stride) {
        using value_t = typename c10::scalar_value_type<scalar_t>::type;
        value_t (*zabs_)(scalar_t) = zabs<scalar_t, value_t>;
        scalar_t max_number = c10::load(self_data);
        int64_t index = 0;
        for (const auto i : c10::irange(self_dim_size)) {
          scalar_t value = c10::load(&self_data[i * self_dim_stride]);
          if (!(zabs_(value) <= zabs_(max_number))) {
            max_number = value;
            index = i;
            if (_isnan<scalar_t>(value)) {
              break;
            }
          }
        }
        *result_data = max_number;
        *indice_data = index;
      }
    );
  });
}

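// Single-pass min+max along `dim`. The asymmetric update (min first, max only
// in the else branch) is valid because a value below the running minimum can
// never also exceed the running maximum; a NaN is copied into both results and
// ends the scan.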
static void aminmax_kernel(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& min_result,
    Tensor& max_result) {
  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
  int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);

  TORCH_CHECK(min_result.scalar_type() == self.scalar_type() && max_result.scalar_type() == self.scalar_type(),
    "Expect min and max dtype ", self.scalar_type(),
    " but got ", min_result.scalar_type(), " and ", max_result.scalar_type());

  if (self.numel() == 1 && self.ndimension() == 0) {
    TORCH_CHECK(!self.is_complex(), "aminmax not implemented for ", self.scalar_type());
    min_result.resize_({});
    max_result.resize_({});
    min_result.fill_(self);
    max_result.fill_(self);
    return;
  }

  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] {
    compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
      scalar_t* min_result_data, scalar_t* max_result_data,
      const scalar_t* self_data, auto self_dim_stride) {
        scalar_t min_number = c10::load(self_data);
        scalar_t max_number = min_number;
        for (const auto i : c10::irange(self_dim_size)) {
          scalar_t value = c10::load(&self_data[i * self_dim_stride]);
          // note: comparison is written this way to handle NaN correctly
          if (!(value >= min_number)) {
            min_number = value;
            if (_isnan<scalar_t>(value)) {
              max_number = value;
              break;
            }
          } else if (!(value <= max_number)) {
            max_number = value;
          }
        }
        *min_result_data = min_number;
        *max_result_data = max_number;
      }
    );
  });
}

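// Ternary select used by torch.where: picks self_val where the condition
// holds, other_val elsewhere.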
static void where_kernel_impl(TensorIterator &iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool,
    iter.dtype(), "where_cpu", [&] {
      cpu_kernel(
        iter,
        [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
          return cond_val ? self_val : other_val;
        });
  });
}

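// True exactly where the input equals +infinity.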
static void isposinf_kernel_impl(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isposinf_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a) -> bool { return a == std::numeric_limits<scalar_t>::infinity(); });
  });
}

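// True exactly where the input equals -infinity.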
static void isneginf_kernel_impl(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isneginf_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a) -> bool { return a == -std::numeric_limits<scalar_t>::infinity(); });
  });
}

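// Mode along `dim`: for each slice, sorts (value, original index) pairs by
// value and scans runs of equal values, keeping the most frequent value and
// one of its original indices.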
static void mode_kernel_impl(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  auto self_dim_size = ensure_nonempty_size(self, dim);
  auto self_dim_stride = ensure_nonempty_stride(self, dim);

  AT_DISPATCH_ALL_TYPES_AND3(
      kHalf, kBFloat16, kBool, self.scalar_type(), "mode_cpu", [&] {
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* values_data_bytes = data[0];
          auto* indices_data_bytes = data[1];
          const auto* self_data_bytes = data[2];

          std::vector<std::pair<scalar_t, int64_t>> elements(self_dim_size);

          for (const auto k C10_UNUSED : c10::irange(n)) {
            scalar_t* values_data = (scalar_t*)values_data_bytes;
            int64_t* indices_data = (int64_t*)indices_data_bytes;
            const scalar_t* self_data = (scalar_t*)self_data_bytes;

            scalar_t mode = 0;
            int64_t modei = 0;
            int64_t temp_freq = 0;
            int64_t max_freq = 0;

            for (const auto i : c10::irange(self_dim_size)) {
              elements[i] = std::make_pair(c10::load(&self_data[i * self_dim_stride]), i);
            }

            // Even though, theoretically, we don't need to specify this lambda
            // (it's basically the same as std::less), omitting it degrades
            // performance, because std::less for std::pair uses up to three
            // comparisons instead of one.
            std::sort(
                elements.begin(),
                elements.end(),
                [=](const auto& i, const auto& j) {
                  return i.first < j.first;
                });

            for (const auto i : c10::irange(self_dim_size)) {
              temp_freq++;
              if ((i == self_dim_size - 1) ||
                  (elements[i].first != elements[i + 1].first)) {
                if (temp_freq > max_freq) {
                  mode = elements[i].first;
                  modei = elements[i].second;
                  max_freq = temp_freq;
                }
                temp_freq = 0;
              }
            }

            *values_data = mode;
            *indices_data = modei;

            values_data_bytes += strides[0];
            indices_data_bytes += strides[1];
            self_data_bytes += strides[2];
          }
        };

        compare_base_kernel_core<scalar_t>(
            values, indices, self, dim, keepdim, loop);
      });
}

// Default brute force implementation of isin(). Used when the number of test elements is small.
// Iterates through each element and checks it against each test element.
static void isin_default_kernel_cpu(
    const Tensor& elements,
    const Tensor& test_elements,
    bool invert,
    const Tensor& out) {
  // Since test_elements is not an input of the TensorIterator, type promotion
  // must be done manually.
  ScalarType common_type = at::result_type(elements, test_elements);
  Tensor promoted_elements = elements.to(common_type);
  Tensor test_elements_flat = test_elements.to(common_type).view(-1);
  auto test_elements_stride = test_elements_flat.stride(0);

  auto iter = TensorIteratorConfig()
    .add_output(out)
    .add_const_input(promoted_elements)
    .check_all_same_dtype(false)
    .build();
  // Dispatch based on promoted type.
  AT_DISPATCH_ALL_TYPES(iter.dtype(1), "isin_default_cpu", [&]() {
    cpu_kernel(iter, [&](scalar_t element_val) -> bool {
      const auto* test_element_data = test_elements_flat.const_data_ptr<scalar_t>();
      for (const auto j : c10::irange(test_elements_flat.numel())) {
        if (element_val == *(test_element_data + test_elements_stride * j)) {
          return !invert;
        }
      }
      return invert;
    });
  });
}

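// Clamp with tensor min/max bounds. The scalar path returns NaN whenever
// either bound is NaN (min != min detects NaN); otherwise it computes
// std::min(std::max(a, min), max).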
static void clamp_kernel_impl(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_cpu", [&]() {
    cpu_kernel_vec(iter,
      [](scalar_t a, scalar_t min, scalar_t max) -> scalar_t {
        if (min != min || max != max) {
            return std::numeric_limits<scalar_t>::quiet_NaN();
        } else {
            return std::min(std::max(a, min), max);
        }
      },
      [](Vectorized<scalar_t> a, Vectorized<scalar_t> min, Vectorized<scalar_t> max) {
        return vec::minimum(vec::maximum(a, min), max);
      });
  });
}

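// Clamp against scalar bounds: min/max are converted once per dtype dispatch
// and broadcast into Vectorized registers outside the element loop.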
static void clamp_scalar_kernel_impl(TensorIteratorBase& iter, const Scalar& min_, const Scalar& max_) {
  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_scalar_cpu", [&]() {
    const auto min = min_.to<scalar_t>();
    const auto max = max_.to<scalar_t>();
    const Vectorized<scalar_t> min_vec(min);
    const Vectorized<scalar_t> max_vec(max);
    cpu_kernel_vec(iter,
      [=](scalar_t a) -> scalar_t {
        return std::min(std::max(a, min), max);
      },
      [=](Vectorized<scalar_t> a) {
        return vec::clamp(a, min_vec, max_vec);
      });
  });
}

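// Upper-bound-only clamp (clamp_max with a scalar bound).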
static void clamp_max_scalar_kernel_impl(TensorIteratorBase& iter, Scalar max_) {
  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_max_scalar_cpu", [&]() {
    const auto max = max_.to<scalar_t>();
    const Vectorized<scalar_t> max_vec(max);
    cpu_kernel_vec(iter,
      [=](scalar_t a) -> scalar_t {
        return std::min(a, max);
      },
      [=](Vectorized<scalar_t> a) {
        return vec::clamp_max(a, max_vec);
      });
  });
}

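// Lower-bound-only clamp (clamp_min with a scalar bound).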
static void clamp_min_scalar_kernel_impl(TensorIteratorBase& iter, Scalar min_) {
  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_min_scalar_cpu", [&]() {
    const auto min = min_.to<scalar_t>();
    const Vectorized<scalar_t> min_vec(min);
    cpu_kernel_vec(iter,
      [=](scalar_t a) -> scalar_t {
        return std::max(a, min);
      },
      [=](Vectorized<scalar_t> a) {
        return vec::clamp_min(a, min_vec);
      });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(max_stub, &max_kernel_impl);
REGISTER_DISPATCH(min_stub, &min_kernel_impl);
REGISTER_DISPATCH(aminmax_stub, &aminmax_kernel);
REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
REGISTER_DISPATCH(isposinf_stub, &isposinf_kernel_impl);
REGISTER_DISPATCH(isneginf_stub, &isneginf_kernel_impl);
REGISTER_DISPATCH(mode_stub, &mode_kernel_impl);
REGISTER_DISPATCH(clamp_stub, &clamp_kernel_impl);
REGISTER_DISPATCH(clamp_scalar_stub, &clamp_scalar_kernel_impl);
REGISTER_DISPATCH(clamp_min_scalar_stub, &clamp_min_scalar_kernel_impl);
REGISTER_DISPATCH(clamp_max_scalar_stub, &clamp_max_scalar_kernel_impl);
REGISTER_DISPATCH(isin_default_stub, &isin_default_kernel_cpu);

} // namespace at::native