#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <iostream>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_aminmax_native.h>
#include <ATen/ops/_assert_async_native.h>
#include <ATen/ops/_functional_assert_async_native.h>
#include <ATen/ops/_print_native.h>
#include <ATen/ops/_assert_scalar_native.h>
#include <ATen/ops/_functional_assert_scalar_native.h>
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
#include <ATen/ops/_unique.h>
#include <ATen/ops/allclose_native.h>
#include <ATen/ops/aminmax.h>
#include <ATen/ops/argsort_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/clamp.h>
#include <ATen/ops/clamp_max.h>
#include <ATen/ops/clamp_max_native.h>
#include <ATen/ops/clamp_min.h>
#include <ATen/ops/clamp_min_native.h>
#include <ATen/ops/clamp_native.h>
#include <ATen/ops/clip_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/eq.h>
#include <ATen/ops/fill.h>
#include <ATen/ops/imag.h>
#include <ATen/ops/index.h>
#include <ATen/ops/is_nonzero_native.h>
#include <ATen/ops/isclose.h>
#include <ATen/ops/isclose_native.h>
#include <ATen/ops/isfinite.h>
#include <ATen/ops/isfinite_native.h>
#include <ATen/ops/isin.h>
#include <ATen/ops/isin_native.h>
#include <ATen/ops/isinf.h>
#include <ATen/ops/isinf_native.h>
#include <ATen/ops/isnan_native.h>
#include <ATen/ops/isneginf_native.h>
#include <ATen/ops/isposinf_native.h>
#include <ATen/ops/isreal_native.h>
#include <ATen/ops/max.h>
#include <ATen/ops/max_native.h>
#include <ATen/ops/min.h>
#include <ATen/ops/min_native.h>
#include <ATen/ops/mode.h>
#include <ATen/ops/mode_native.h>
#include <ATen/ops/ne.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/real.h>
#include <ATen/ops/result_type_native.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/where.h>
#include <ATen/ops/where_native.h>
#include <ATen/ops/zeros_like.h>

#include <iostream>
#include <utility>
#endif

namespace at::meta {

static inline void check_for_unsupported_isin_dtype(const ScalarType type) {
  // Bail out for dtypes unsupported by the sorting algorithm to keep the interface consistent.
  TORCH_CHECK(type != ScalarType::Bool &&
      type != ScalarType::BFloat16 &&
      type != ScalarType::ComplexFloat &&
      type != ScalarType::ComplexDouble,
      "Unsupported input type encountered for isin(): ", type);
}

TORCH_META_FUNC(clamp) (
const Tensor& self,
const OptionalScalarRef min,
const OptionalScalarRef max) {
  if (!min && !max) {
    TORCH_CHECK(false, "torch.clamp: At least one of 'min' or 'max' must not be None");
  }
  //Manual type promotion, since scalars have to participate in it
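  // Illustrative example (not part of the logic below): clamping a Long tensor
  // with a floating-point scalar bound promotes the computation to the default
  // floating dtype, so the result is a Float tensor; an in-place clamp_ on the
  // Long tensor is rejected by the check below, because the promoted result
  // cannot be cast back to the integer output.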
  ScalarType result_type = self.scalar_type();
  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
  //Floating is the highest supported
  if (!isFloatingType(result_type)) {
    at::native::ResultTypeState state = {};
    state = at::native::update_result_type_state(self, state);

    if (min) {
      state = at::native::update_result_type_state(min.get(), state);
    }
    if (max) {
      state = at::native::update_result_type_state(max.get(), state);
    }
    result_type = at::native::result_type(state);
    //disallow type promoting inplace op
    TORCH_CHECK((result_type == self.scalar_type()) ||
       (!(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))),
       "result type ", result_type, " can't be cast to the desired output type ",
       self.dtype());
  }
  //make sure scalars weren't complex
  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
  build_unary_op(maybe_get_output(), self.to(result_type));
}

TORCH_META_FUNC2(clamp, Tensor) (
const Tensor& self,
const OptionalTensorRef min,
const OptionalTensorRef max) {
  TORCH_CHECK(min || max, "torch.clamp: At least one of 'min' or 'max' must not be None");
  TORCH_CHECK(!isComplexType(self.scalar_type()), "clamp is not supported for complex types");
  #define CLAMP_CONFIG()                    \
    TensorIteratorConfig()                  \
      .set_check_mem_overlap(true)          \
      .add_output(maybe_get_output())       \
      .add_const_input(self)                \
      .promote_inputs_to_common_dtype(true) \
      .cast_common_dtype_to_outputs(true)   \
      .enforce_safe_casting_to_output(true)

  if (min && max) {
    build(CLAMP_CONFIG().add_const_input(*min).add_const_input(*max));
  } else if (min) {
    build(CLAMP_CONFIG().add_const_input(*min));
  } else if (max) {
    build(CLAMP_CONFIG().add_const_input(*max));
  }
}


TORCH_META_FUNC(clamp_max) (
  const Tensor& self,
  const Scalar& max
) {
  // We could wrap max into a tensor and dispatch to the tensor overload,
  // but relu is implemented via clamp_min, so for perf and uniformity reasons
  // we take the faster (but still correct) path here.
  ScalarType result_type = self.scalar_type();
  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
  TORCH_CHECK(!max.isComplex(), "clamp is not supported for complex types");
  //Floating is the highest supported
  if (!isFloatingType(result_type)) {
    auto result_type = at::native::result_type(self, max);
    TORCH_CHECK((result_type == self.scalar_type()) ||
       (!(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))),
       "result type ", result_type, " can't be cast to the desired output type ",
       self.dtype());
    build_unary_op(maybe_get_output(), self.to(result_type));
  } else {
    build_borrowing_unary_op(maybe_get_output(), self);
  }
}

TORCH_META_FUNC2(clamp_max, Tensor) (
  const Tensor& self,
  const Tensor& max
) {
  build_borrowing_binary_op(maybe_get_output(), self, max);
}


TORCH_META_FUNC(clamp_min) (
  const Tensor& self,
  const Scalar& min
) {
  ScalarType result_type = self.scalar_type();
  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
  TORCH_CHECK(!min.isComplex(), "clamp is not supported for complex types");
  //Floating is the highest supported
  if (!isFloatingType(result_type)) {
    auto result_type = at::native::result_type(self, min);
    TORCH_CHECK((result_type == self.scalar_type() ||
       !(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))),
       "result type ", result_type, " can't be cast to the desired output type ",
       self.dtype());
    build_unary_op(maybe_get_output(), self.to(result_type));
  } else {
    build_borrowing_unary_op(maybe_get_output(), self);
  }
}

TORCH_META_FUNC2(clamp_min, Tensor) (
  const Tensor& self,
  const Tensor& min
) {
  build_borrowing_binary_op(maybe_get_output(), self, min);
}

TORCH_META_FUNC2(isin, Tensor_Tensor) (
  const Tensor& elements, const Tensor& test_elements, bool /*assume_unique*/, bool /*invert*/
) {
  check_for_unsupported_isin_dtype(elements.scalar_type());
  check_for_unsupported_isin_dtype(test_elements.scalar_type());
  set_output_raw_strided(0, elements.sizes(), {}, TensorOptions(elements.device()).dtype(ScalarType::Bool));
}

TORCH_META_FUNC2(isin, Tensor_Scalar) (
  const Tensor& elements, const c10::Scalar& test_elements, bool /*assume_unique*/, bool /*invert*/
) {
  check_for_unsupported_isin_dtype(elements.scalar_type());
  check_for_unsupported_isin_dtype(test_elements.type());
  set_output_raw_strided(0, elements.sizes(), {}, TensorOptions(elements.device()).dtype(ScalarType::Bool));
}

TORCH_META_FUNC2(isin, Scalar_Tensor) (
  const c10::Scalar& elements, const Tensor& test_elements, bool /*assume_unique*/, bool /*invert*/
) {
  check_for_unsupported_isin_dtype(elements.type());
  check_for_unsupported_isin_dtype(test_elements.scalar_type());
  set_output_raw_strided(0, {0}, {}, TensorOptions(test_elements.device()).dtype(ScalarType::Bool));
}

TORCH_META_FUNC(isposinf) (const Tensor& self) {
  TORCH_CHECK(!self.is_complex(), "isposinf does not support complex inputs.");
  TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
              "isposinf does not support non-boolean outputs.");
  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
}

TORCH_META_FUNC(isneginf) (const Tensor& self) {
  TORCH_CHECK(!self.is_complex(), "isneginf does not support complex inputs.");
  TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
              "isneginf does not support non-boolean outputs.");
  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
}

static void check_unsupported_complex(const char* name, const Tensor& self) {
  TORCH_CHECK(!self.is_complex(), name, ": does not support complex input");
}

TORCH_PRECOMPUTE_META_FUNC2(max, dim)
(const Tensor& self, int64_t dim, bool keepdim) {
  dim = maybe_wrap_dim(dim, self.dim());
  at::native::zero_numel_check_dims(self, dim, "max()");
  check_unsupported_complex("max()", self);
  resize_reduction_with_indices(*this, self, dim, keepdim, self.scalar_type());
  return TORCH_PRECOMPUTE_STRUCT2(max, dim)()
      .set_dim(maybe_wrap_dim(dim, self.dim()));
}

TORCH_PRECOMPUTE_META_FUNC2(min, dim)(const Tensor& self, int64_t dim, bool keepdim) {
  dim = maybe_wrap_dim(dim, self.dim());
  at::native::zero_numel_check_dims(self, dim, "min()");
  check_unsupported_complex("min()", self);
  resize_reduction_with_indices(*this, self, dim, keepdim, self.scalar_type());
  return TORCH_PRECOMPUTE_STRUCT2(min, dim)()
      .set_dim(maybe_wrap_dim(dim, self.dim()));
}

} // namespace at::meta

namespace at::native {

DEFINE_DISPATCH(where_kernel); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(max_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(min_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(isposinf_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(isneginf_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(mode_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(clamp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(clamp_scalar_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(clamp_min_scalar_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(clamp_max_scalar_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(isin_default_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
  return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
}

// Note [closeness]
// A number A is close to B when either:
//
// (1) A is equal to B, with NaNs comparing equal when equal_nan is true.
// (2) The error abs(A - B) is finite and no greater than the max error
//      (atol + abs(rtol * B)).
//
// Note that this is consistent with NumPy's isclose but divergent from
// Python's isclose, which computes the max error symmetrically as
// max(rtol * max(abs(A), abs(B)), atol).
// TODO: use bitwise operator overloads once we add them
// TODO: revisit complex inputs and equal_nan=true after
//  https://github.com/numpy/numpy/issues/15959 is resolved
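// Worked example (illustrative, with rtol=1e-5 and atol=1e-8):
// A=1.0 and B=1.000001 are close because abs(A - B) = 1e-6 is finite and no
// greater than atol + abs(rtol * B) ~= 1.0e-5, whereas A=1.0 and B=1.001 are
// not, since abs(A - B) = 1e-3 exceeds that bound.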
Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
  TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type());
  TORCH_CHECK(!(self.is_quantized() || other.is_quantized()),
    "isclose is not supported for quantized inputs.");

  // Checks that rtol and atol are non-negative
  // Note: consistent with Python's isclose but divergent from NumPy's, which
  //  allows negative atol and rtol.
  TORCH_CHECK(rtol >= 0, "rtol must be greater than or equal to zero, but got ", rtol);
  TORCH_CHECK(atol >= 0, "atol must be greater than or equal to zero, but got ", atol);

  // Computes equality closeness
  Tensor close = self == other;
  if (equal_nan && (self.is_floating_point() || self.is_complex())) {
    // For CompositeCompliance, if `other` is a CCT and `self` is a regular Tensor,
    // then we can't perform inplace op into `self` with `other`.
    // NOTE: Inplacing into `close` is fine because it is generated from
    // out-of-place with args `self` and `other`. So if either of them is
    // a CCT then `close` will also be a `CCT`.
    if (isTensorSubclassLike(other)) {
      close.__ior__(self.isnan().bitwise_and(other.isnan()));
    } else {
      close.__ior__(self.isnan().__iand__(other.isnan()));
    }
  }

  // In case of zero tolerances the closeness inequality degenerates to an equality check.
  // In this case, the short-circuit prevents false positives as detailed in the paragraph below.
  if (rtol == 0 && atol == 0){
      return close;
  }

  // Note [closeness error computation]
  // atol and rtol are provided as doubles, so the computation
  // rtol * other will produce a float or complex tensor.
  // When the difference (self - other) is compared to it then the
  // tensor representing the difference will also be cast to float or complex.
  // However, since (self - other) in uint8 is very likely to produce a
  // negative value, this moves the cast forward so the difference is
  // always computed in a float or complex type.
  // If the values of the integer tensors cannot be exactly represented
  // by the default scalar type then this may cause an incorrect result.
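  // For example, with uint8 inputs self=3 and other=5, self - other would wrap
  // around to 254 in uint8 arithmetic; casting `other` to the default floating
  // dtype first makes the subtraction promote to float and yield -2.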

  // Computes allowed and actual error
  Tensor cast_self, cast_other;
  cast_self = self.scalar_type() == at::kBool ? self.to(at::get_default_dtype()) : self;
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    cast_other = other.to(at::get_default_dtype());
  } else {
    cast_other = other;
  }

  Tensor allowed_error = atol + (rtol * cast_other).abs();
  Tensor actual_error = (cast_self - cast_other).abs();

  // Computes finite closeness
  close.__ior__(at::isfinite(actual_error).__iand__(actual_error <= allowed_error));

  return close;
}

Tensor isnan(const Tensor& self) {
  return self != self;
}

Tensor isreal(const Tensor& self) {
  // Note: Integral and Floating tensor values are always real
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) ||
      c10::isFloatingType(self.scalar_type())) {
    return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
  }

  return at::imag(self) == 0;
}


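// Dispatch helper for isinf/isfinite below: covers the standard floating-point
// dtypes plus Half and BFloat16, and additionally Float8_e5m2 on non-mobile builds.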
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...)                          \
        AT_DISPATCH_FLOATING_TYPES_AND3( kHalf, kBFloat16, kFloat8_e5m2, \
            TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...)           \
        AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, \
            TYPE, NAME, __VA_ARGS__)
#endif


Tensor isinf(const Tensor &self) {
  // Note: Integral tensor values are never infinite
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    return at::zeros_like(self, at::kBool, at::MemoryFormat::Preserve);
  }

  // Note: a complex value is infinite when either part is infinite
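  // e.g. both (inf + 1i) and (1 - inf*i) are reported as infinite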
  if (self.is_complex()) {
    return at::isinf(at::real(self)).__ior__
          (at::isinf(at::imag(self)));
  }

  return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isinf", [&]() {
    return self.abs() == std::numeric_limits<scalar_t>::infinity();
  });
}

Tensor isfinite(const Tensor& self) {
  // Note: Integral tensor values are always finite
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
  }

  // Note: a complex value is finite iff both parts are finite
  if (self.is_complex()) {
    return at::isfinite(at::real(self)).__iand__(at::isfinite(at::imag(self)));
  }

  return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isfinite", [&]() {
    return (self == self) * (self.abs() != std::numeric_limits<scalar_t>::infinity());
  });
}

void _assert_async_cpu(const Tensor& self) {
  TORCH_CHECK(native::is_nonzero(self), "Expected Tensor with single nonzero value, but got zero");
}

void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) {
  TORCH_CHECK(native::is_nonzero(self), assert_msg != "" ? assert_msg : "Assertion failed");
}

void _assert_scalar(const Scalar& scalar, c10::string_view assert_msg) {
  TORCH_SYM_CHECK(scalar.toSymBool(), assert_msg != "" ? assert_msg : "Assertion failed");
}

Tensor _functional_assert_scalar(const Scalar& scalar, c10::string_view assert_msg, const Tensor& dep_token) {
  _assert_scalar(scalar, assert_msg);
  return dep_token.clone();
}

Tensor _functional_assert_async_msg_cpu(
  const Tensor& self,
  c10::string_view assert_msg,
  const Tensor& dep_token) {
  _assert_async_msg_cpu(self, assert_msg);
  return dep_token.clone();
}

void _print(c10::string_view s) {
  std::cout << s << "\n";
}

// Sorting-based algorithm for isin(); used when the number of test elements is large.
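// Illustrative trace (assume_unique=true, invert=false) with elements=[2, 5]
// and test_elements=[5, 7]: the concatenation [2, 5, 5, 7] is already sorted,
// the adjacent-duplicate mask is [false, true, false, false], and keeping the
// first elements.numel() entries gives the result [false, true], i.e. only 5
// is a member of test_elements.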
static void isin_sorting(
    const Tensor& elements,
    const Tensor& test_elements,
    bool assume_unique,
    bool invert,
    const Tensor& out) {
  // 1. Concatenate unique elements with unique test elements in 1D form. If
  //    assume_unique is true, skip calls to unique().
  Tensor elements_flat, test_elements_flat, unique_order;
  if (assume_unique) {
    elements_flat = elements.ravel();
    test_elements_flat = test_elements.ravel();
  } else {
    std::tie(elements_flat, unique_order) = at::_unique(
        elements, /*sorted=*/ false, /*return_inverse=*/ true);
    std::tie(test_elements_flat, std::ignore) = at::_unique(test_elements, /*sorted=*/ false);
  }

  // 2. Stable sort all elements, maintaining order indices to reverse the
  //    operation. Stable sort is necessary to keep elements before test
  //    elements within the sorted list.
  Tensor all_elements = at::cat({std::move(elements_flat), std::move(test_elements_flat)});
  auto [sorted_elements, sorted_order] = all_elements.sort(
      /*stable=*/ true, /*dim=*/ 0, /*descending=*/ false);

  // 3. Create a mask for locations of adjacent duplicate values within the
  //    sorted list. Duplicate values are in both elements and test elements.
  Tensor duplicate_mask = at::empty_like(sorted_elements, TensorOptions(ScalarType::Bool));
  Tensor sorted_except_first = sorted_elements.slice(0, 1, at::indexing::None);
  Tensor sorted_except_last = sorted_elements.slice(0, 0, -1);
  duplicate_mask.slice(0, 0, -1).copy_(
    invert ? sorted_except_first.ne(sorted_except_last) : sorted_except_first.eq(sorted_except_last));
  duplicate_mask.index_put_({-1}, invert);

  // 4. Reorder the mask to match the pre-sorted element order.
  Tensor mask = at::empty_like(duplicate_mask);
  mask.index_copy_(0, sorted_order, duplicate_mask);

  // 5. Index the mask to match the pre-unique element order. If
  //    assume_unique is true, just take the first N items of the mask,
  //    where N is the original number of elements.
  if (assume_unique) {
    out.copy_(mask.slice(0, 0, elements.numel()).view_as(out));
  } else {
    out.copy_(at::index(mask, {std::optional<Tensor>(unique_order)}));
  }
}

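// Returns the device of the first input that is not on the CPU, falling back
// to the CPU when every input lives there. Used by the where() overloads below
// to pick an output device while still tolerating wrapped CPU scalars.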
template<typename... Args>
Device out_device(Args&... inps){
  for (const auto& i : {inps...}){
    if (i.device() != at::kCPU) {
      return i.device();
    }
  }
  return at::kCPU;
}


Tensor& where_self_out(const Tensor& condition, const Tensor& self, const Tensor& other, Tensor& out) {
  const auto result_type = at::native::result_type(self, other);
  TORCH_CHECK(out.scalar_type() == result_type, "Expected out type to be ", result_type, " but got ", out.scalar_type());

  auto self_ = self.scalar_type() != result_type ? self.to(result_type): self;
  auto other_ = other.scalar_type() != result_type ? other.to(result_type): other;
  auto condition_ = condition;
  auto device = out_device(condition, self_, other_);
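  // 0-dim CPU tensors (e.g. wrapped scalars) are moved to the common non-CPU
  // device below; tensors with more dimensions on a mismatched device are left
  // as-is, and TensorIterator reports the mismatch when the iterator is built.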
  if (device != at::kCPU) { // allow CPU scalars on non-cpu device
    if (condition.device() != device && condition.ndimension() == 0) {
      condition_ = condition.to(device);
    }
    if (self_.device() != device && self_.ndimension() == 0) {
        self_ = self_.to(device);
    }
    if (other_.device() != device && other_.ndimension() == 0) {
        other_ = other_.to(device);
    }
  }
  if (condition_.scalar_type() == ScalarType::Byte) {
    TORCH_WARN_ONCE("where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
    condition_ = condition_.to(kBool);
  }
  TORCH_CHECK(condition_.scalar_type() == kBool, "where expected condition to be a boolean tensor, but got a tensor with dtype ", condition_.scalar_type());
  // if there's still a device mismatch, let tensoriterator error out with it
  auto iter = at::TensorIteratorConfig()
    .check_all_same_dtype(false)
    .add_output(out)
    .add_const_input(condition_)
    .add_const_input(self_)
    .add_const_input(other_)
    .build();
  where_kernel(iter.device_type(), iter);
  return out;
}


Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
  auto device = out_device(condition, self, other);
  auto result_type = at::native::result_type(self, other);
  Tensor ret = at::empty({0}, self.options().dtype(result_type).device(device));
  at::native::where_self_out(condition, self, other, ret);
  return ret;
}

Tensor where(const Tensor& condition, const Scalar& self, const Tensor& other) {
  auto result_type = at::native::result_type(other, self);
  auto self_converted = at::scalar_tensor(self, other.options().dtype(result_type));
  auto other_converted = other.to(result_type);
  return at::where(condition, self_converted, other_converted);
}

Tensor where(const Tensor& condition, const Tensor& self, const Scalar& other) {
  auto result_type = at::native::result_type(self, other);
  auto other_converted = at::scalar_tensor(other, self.options().dtype(result_type));
  auto self_converted = self.to(result_type);
  return at::where(condition, self_converted, other_converted);
}

Tensor where(const Tensor& condition, const Scalar& self, const Scalar& other) {
  auto result_type = at::native::result_type(self, other);
  const Tensor& other_t = at::scalar_tensor(other, condition.options().dtype(result_type));
  const Tensor& self_t = at::scalar_tensor(self, condition.options().dtype(result_type));
  return at::where(condition, self_t, other_t);
}

std::vector<Tensor> where(const Tensor& condition) {
  return condition.nonzero_numpy();
}

std::tuple<Tensor, Tensor> mode(const Tensor& self, int64_t dim, bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  return at::native::mode_out(self, dim, keepdim, values, indices);
}

std::tuple<Tensor &,Tensor &> mode_out(const Tensor& self, int64_t dim, bool keepdim,
                                       Tensor& values, Tensor& indices) {
  TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
              "mode only supports CPU AND CUDA device type, got: ", self.device().type());
  TORCH_CHECK(self.layout() == Layout::Strided,
              "mode only supports strided layout, got: ", self.layout());
  TORCH_CHECK(self.device() == values.device(),
              "expected device '", self.device(), "' but got '",
              values.device(), "' for values output");
  TORCH_CHECK(self.device() == indices.device(),
              "expected device '", self.device(), "' but got '",
              indices.device(), "' for indices output");
  TORCH_CHECK(self.scalar_type() == values.scalar_type(),
              "expected scalar type '", self.scalar_type(), "' but got '",
              values.scalar_type(), "' for values output");
  TORCH_CHECK(indices.scalar_type() == ScalarType::Long,
              "expected scalar type '", ScalarType::Long, "' but got '",
              indices.scalar_type(), "' for indices output");
  dim = maybe_wrap_dim(dim, self.dim());
  if (self.numel() == 0) {
    auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, "mode()");
    resize_output(values, sizes);
    resize_output(indices, sizes);
    return std::tie(values, indices);
  }
  else if (_dimreduce_return_trivial_no_ident(values, self, dim, keepdim, "mode")) {
    AT_ASSERT(values.dim() == 0);
    indices.resize_({}).fill_(0);
    return std::forward_as_tuple(values, indices);
  } else {
    auto result = [&]() {
      NoNamesGuard guard;
      mode_stub(self.device().type(), values, indices, self, dim, keepdim);
      return std::tuple<Tensor &,Tensor &>{values, indices};
    }();
    namedinference::propagate_names_for_reduction(std::get<0>(result), self, dim, keepdim);
    namedinference::propagate_names_for_reduction(std::get<1>(result), self, dim, keepdim);
    return result;
  }
}

template <class Stub>
void minmax_out_impl(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    const Tensor& values,
    const Tensor& indices,
    Stub& stub) {
  NoNamesGuard guard;
  if (self.numel() > 0) {
    if (self.numel() == 1 && self.dim() == 0) {
      values.fill_(self);
      indices.fill_(0);
    } else {
      stub(self.device().type(), values, indices, self, dim, keepdim);
    }
  }
}

TORCH_IMPL_FUNC(max_out)
(const Tensor& self,
 int64_t dim,
 bool keepdim,
 const Tensor& values,
 const Tensor& indices) {
  minmax_out_impl(self, dim, keepdim, values, indices, max_stub);
}

TORCH_IMPL_FUNC(min_out)
(const Tensor& self,
 int64_t dim,
 bool keepdim,
 const Tensor& values,
 const Tensor& indices) {
  minmax_out_impl(self, dim, keepdim, values, indices, min_stub);
}

std::tuple<Tensor, Tensor> qmax(const Tensor& self, int64_t dim, bool keepdim) {
  TORCH_CHECK(self.qscheme() == at::kPerTensorAffine, "Max operator for quantized tensors only works for per tensor quantized tensors. "
  "Please open an issue on https://github.com/pytorch/pytorch/issues if you need per channel quantized tensor support.");
  Tensor max_indices = at::empty({0}, self.options().dtype(kLong));
  Tensor max = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
  at::max_outf(self.int_repr(), dim, keepdim, max, max_indices);
  // TODO: qscheme
  return std::tuple<Tensor, Tensor>(
      at::_make_per_tensor_quantized_tensor(max, self.q_scale(), self.q_zero_point()), max_indices);
}

std::tuple<Tensor, Tensor> qmin(const Tensor& self, int64_t dim, bool keepdim) {
  TORCH_CHECK(self.qscheme() == at::kPerTensorAffine, "Min operator for quantized tensors only works for per tensor quantized tensors. "
  "Please open an issue on https://github.com/pytorch/pytorch/issues if you need per channel quantized tensor support.");
  Tensor min_indices = at::empty({0}, self.options().dtype(kLong));
  Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
  at::min_outf(self.int_repr(), dim, keepdim, min, min_indices);
  return std::tuple<Tensor, Tensor>(
      at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
}

// DEPRECATED: Use at::aminmax instead
std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdim) {
  TORCH_WARN_ONCE("_aminmax is deprecated as of PyTorch 1.11 and will be removed in a future release. Use aminmax instead."
                  " This warning will only appear once per process.");
  return at::aminmax(self, dim, keepdim);
}

TORCH_IMPL_FUNC(clamp_out)
(
 const Tensor& /*self*/,
 const OptionalScalarRef min,
 const OptionalScalarRef max,
 const Tensor& result) {
  using at::native::detail::ClampLimits;
  if (min && max) {
    if (min.get().toDouble() != min.get().toDouble() ||
        max.get().toDouble() != max.get().toDouble()) {
      at::fill_(const_cast<Tensor&>(result), std::numeric_limits<double>::quiet_NaN());
    } else {
      clamp_scalar_stub(device_type(), *this, min.get(), max.get());
    }
  } else if (max) {
    clamp_max_scalar_stub(device_type(), *this, max.get());
  } else if (min) {
    clamp_min_scalar_stub(device_type(), *this, min.get());
  }
}

TORCH_IMPL_FUNC(clamp_Tensor_out)
(const Tensor& self, const OptionalTensorRef min,
                  const OptionalTensorRef max, const Tensor&) {
  if (min && max) {
    clamp_stub(device_type(), *this);
  } else if (min) {
    maximum_stub(device_type(), *this);
  } else if (max) {
    minimum_stub(device_type(), *this);
  }
}

TORCH_IMPL_FUNC(clamp_max_out)
(const Tensor& self, const Scalar& max, const Tensor& result) {
  if (max.toDouble() != max.toDouble()) {
//TODO this is not great, building TI again is expensive, but I can't use
//fill_stub because fill is not structured
//this is a corner case anyway
    at::fill_(const_cast<Tensor&>(result), wrapped_scalar_tensor(max));
  } else {
    clamp_max_scalar_stub(device_type(), *this, max);
  }
}

TORCH_IMPL_FUNC(clamp_max_Tensor_out)
(const Tensor& self, const Tensor& max, const Tensor& result) {
  minimum_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(clamp_min_out)
(const Tensor& self, const Scalar& min, const Tensor& result) {
  if (min.toDouble() != min.toDouble()) {
    at::fill_(const_cast<Tensor&>(result), min);
  } else {
    clamp_min_scalar_stub(device_type(), *this, min);
  }
}

TORCH_IMPL_FUNC(clamp_min_Tensor_out)
(const Tensor& self, const Tensor& min, const Tensor& result) {
  maximum_stub(device_type(), *this);
}

// Implements the "clip" alias for clamp
Tensor& clip_out(const Tensor& self, const std::optional<Scalar>& min, const std::optional<Scalar>& max, Tensor& result) {
  return at::clamp_outf(self, min, max, result);
}

Tensor& clip_out(const Tensor& self, const std::optional<Tensor>& min, const std::optional<Tensor>& max, Tensor& result) {
  return at::clamp_outf(self, min, max, result);
}

Tensor clip(const Tensor& self, const std::optional<Scalar>& min, const std::optional<Scalar>& max) {
  return at::clamp(self, min, max);
}

Tensor clip(const Tensor& self, const std::optional<Tensor>& min, const std::optional<Tensor>& max) {
  return at::clamp(self, min, max);
}

Tensor& clip_(Tensor& self, const std::optional<Scalar>& min, const std::optional<Scalar>& max) {
  return at::clamp_(self, min, max);
}

Tensor& clip_(Tensor& self, const std::optional<Tensor>& min, const std::optional<Tensor>& max) {
  return at::clamp_(self, min, max);
}

// Named tensor overloads

std::tuple<Tensor, Tensor> min(const Tensor& self, Dimname dim, bool keepdim) {
  return at::min(self, dimname_to_position(self, dim), keepdim);
}
std::tuple<Tensor &,Tensor &> min_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& min, Tensor& min_indices) {
  return at::min_out(min, min_indices, self, dimname_to_position(self, dim), keepdim);
}
std::tuple<Tensor, Tensor> max(const Tensor& self, Dimname dim, bool keepdim) {
  return at::max(self, dimname_to_position(self, dim), keepdim);
}
std::tuple<Tensor&, Tensor&> max_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& max, Tensor& max_indices) {
  return at::max_out(max, max_indices, self, dimname_to_position(self, dim), keepdim);
}
Tensor argsort(const Tensor& /*self*/, Dimname /*dim*/, bool /*keepdim*/) {
  reportNYIDimnameOverload("argsort");
}
std::tuple<Tensor, Tensor> mode(const Tensor& self, Dimname dim, bool keepdim) {
  return at::mode(self, dimname_to_position(self, dim), keepdim);
}
std::tuple<Tensor &,Tensor &> mode_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& values, Tensor& indices) {
  return at::mode_out(values, indices, self, dimname_to_position(self, dim), keepdim);
}

TORCH_IMPL_FUNC(isin_Tensor_Tensor_out) (
  const Tensor& elements, const Tensor& test_elements, bool assume_unique, bool invert, const Tensor& out
) {
  if (elements.numel() == 0) {
    return;
  }

  // Heuristic taken from numpy's implementation.
  // See https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/arraysetops.py#L575
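  // e.g. for roughly 1000 elements the threshold is about 27 test elements:
  // below that the brute-force default stub is used, otherwise the sort-based
  // path (isin_sorting) handles it.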
  if (test_elements.numel() < static_cast<int64_t>(
        10.0f * std::pow(static_cast<double>(elements.numel()), 0.145))) {
    out.fill_(invert);
    isin_default_stub(elements.device().type(), elements, test_elements, invert, out);
  } else {
    isin_sorting(elements, test_elements, assume_unique, invert, out);
  }
}

TORCH_IMPL_FUNC(isin_Tensor_Scalar_out) (
  const Tensor& elements, const c10::Scalar& test_elements, bool assume_unique, bool invert, const Tensor& out
) {
  // redispatch to eq / ne
  if (invert) {
    at::ne_out(const_cast<Tensor&>(out), elements, test_elements);
  } else {
    at::eq_out(const_cast<Tensor&>(out), elements, test_elements);
  }
}

TORCH_IMPL_FUNC(isin_Scalar_Tensor_out) (
  const c10::Scalar& elements, const Tensor& test_elements, bool assume_unique, bool invert, const Tensor& out
) {
  // redispatch
  at::isin_out(const_cast<Tensor&>(out), wrapped_scalar_tensor(elements, test_elements.device()),
    test_elements, assume_unique, invert);
}

TORCH_IMPL_FUNC(isposinf_out) (const Tensor& self, const Tensor& result) {
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    result.fill_(false);
  } else {
    isposinf_stub(device_type(), *this);
  }
}

TORCH_IMPL_FUNC(isneginf_out) (const Tensor& self, const Tensor& result) {
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    result.fill_(false);
  } else {
    isneginf_stub(device_type(), *this);
  }
}

} // namespace at::native