#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <iostream>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_aminmax_native.h>
#include <ATen/ops/_assert_async_native.h>
#include <ATen/ops/_functional_assert_async_native.h>
#include <ATen/ops/_print_native.h>
#include <ATen/ops/_assert_scalar_native.h>
#include <ATen/ops/_functional_assert_scalar_native.h>
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
#include <ATen/ops/_unique.h>
#include <ATen/ops/allclose_native.h>
#include <ATen/ops/aminmax.h>
#include <ATen/ops/argsort_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/clamp.h>
#include <ATen/ops/clamp_max.h>
#include <ATen/ops/clamp_max_native.h>
#include <ATen/ops/clamp_min.h>
#include <ATen/ops/clamp_min_native.h>
#include <ATen/ops/clamp_native.h>
#include <ATen/ops/clip_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/eq.h>
#include <ATen/ops/fill.h>
#include <ATen/ops/imag.h>
#include <ATen/ops/index.h>
#include <ATen/ops/is_nonzero_native.h>
#include <ATen/ops/isclose.h>
#include <ATen/ops/isclose_native.h>
#include <ATen/ops/isfinite.h>
#include <ATen/ops/isfinite_native.h>
#include <ATen/ops/isin.h>
#include <ATen/ops/isin_native.h>
#include <ATen/ops/isinf.h>
#include <ATen/ops/isinf_native.h>
#include <ATen/ops/isnan_native.h>
#include <ATen/ops/isneginf_native.h>
#include <ATen/ops/isposinf_native.h>
#include <ATen/ops/isreal_native.h>
#include <ATen/ops/max.h>
#include <ATen/ops/max_native.h>
#include <ATen/ops/min.h>
#include <ATen/ops/min_native.h>
#include <ATen/ops/mode.h>
#include <ATen/ops/mode_native.h>
#include <ATen/ops/ne.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/real.h>
#include <ATen/ops/result_type_native.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/where.h>
#include <ATen/ops/where_native.h>
#include <ATen/ops/zeros_like.h>

#include <iostream>
#include <utility>
#endif

namespace at::meta {

static inline void check_for_unsupported_isin_dtype(const ScalarType type) {
  // Bail out for dtypes unsupported by the sorting algorithm to keep the interface consistent.
  TORCH_CHECK(type != ScalarType::Bool &&
      type != ScalarType::BFloat16 &&
      type != ScalarType::ComplexFloat &&
      type != ScalarType::ComplexDouble,
      "Unsupported input type encountered for isin(): ", type);
}

TORCH_META_FUNC(clamp) (
    const Tensor& self,
    const OptionalScalarRef min,
    const OptionalScalarRef max) {
  if (!min && !max) {
    TORCH_CHECK(false, "torch.clamp: At least one of 'min' or 'max' must not be None");
  }
  // Manual type promotion, since scalars have to participate in it
  ScalarType result_type = self.scalar_type();
  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
  // Floating is the highest supported
  if (!isFloatingType(result_type)) {
    at::native::ResultTypeState state = {};
    state = at::native::update_result_type_state(self, state);

    if (min) {
      state = at::native::update_result_type_state(min.get(), state);
    }
    if (max) {
      state = at::native::update_result_type_state(max.get(), state);
    }
    result_type = at::native::result_type(state);
    // disallow type promoting inplace op
    TORCH_CHECK((result_type == self.scalar_type()) ||
        (!(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))),
        "result type ", result_type, " can't be cast to the desired output type ",
        self.dtype());
  }
  // make sure scalars weren't complex
  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
  build_unary_op(maybe_get_output(), self.to(result_type));
}

TORCH_META_FUNC2(clamp, Tensor) (
    const Tensor& self,
    const OptionalTensorRef min,
    const OptionalTensorRef max) {
  TORCH_CHECK(min || max, "torch.clamp: At least one of 'min' or 'max' must not be None");
  TORCH_CHECK(!isComplexType(self.scalar_type()), "clamp is not supported for complex types");
#define CLAMP_CONFIG()                          \
  TensorIteratorConfig()                        \
      .set_check_mem_overlap(true)              \
      .add_output(maybe_get_output())           \
      .add_const_input(self)                    \
      .promote_inputs_to_common_dtype(true)     \
      .cast_common_dtype_to_outputs(true)       \
      .enforce_safe_casting_to_output(true)

  if (min && max) {
    build(CLAMP_CONFIG().add_const_input(*min).add_const_input(*max));
  } else if (min) {
    build(CLAMP_CONFIG().add_const_input(*min));
  } else if (max) {
    build(CLAMP_CONFIG().add_const_input(*max));
  }
}


TORCH_META_FUNC(clamp_max) (
    const Tensor& self,
    const Scalar& max
) {
  // We could wrap max into a tensor and send it to the tensor overload,
  // but relu is implemented via clamp_min, so for perf and uniformity reasons
  // do the faster but still correct thing here.
  ScalarType result_type = self.scalar_type();
  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
  TORCH_CHECK(!max.isComplex(), "clamp is not supported for complex types");
  // Floating is the highest supported
  if (!isFloatingType(result_type)) {
    auto result_type = at::native::result_type(self, max);
    TORCH_CHECK((result_type == self.scalar_type()) ||
        (!(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))),
        "result type ", result_type, " can't be cast to the desired output type ",
        self.dtype());
    build_unary_op(maybe_get_output(), self.to(result_type));
  } else {
    build_borrowing_unary_op(maybe_get_output(), self);
  }
}

TORCH_META_FUNC2(clamp_max, Tensor) (
    const Tensor& self,
    const Tensor& max
) {
  build_borrowing_binary_op(maybe_get_output(), self, max);
}


TORCH_META_FUNC(clamp_min) (
    const Tensor& self,
    const Scalar& min
) {
  ScalarType result_type = self.scalar_type();
  TORCH_CHECK(!isComplexType(result_type), "clamp is not supported for complex types");
  TORCH_CHECK(!min.isComplex(), "clamp is not supported for complex types");
  // Floating is the highest supported
  if (!isFloatingType(result_type)) {
    auto result_type = at::native::result_type(self, min);
    TORCH_CHECK((result_type == self.scalar_type() ||
        !(maybe_get_output().defined()) || !(maybe_get_output().is_same(self))),
        "result type ", result_type, " can't be cast to the desired output type ",
        self.dtype());
    build_unary_op(maybe_get_output(), self.to(result_type));
  } else {
    build_borrowing_unary_op(maybe_get_output(), self);
  }
}

TORCH_META_FUNC2(clamp_min, Tensor) (
    const Tensor& self,
    const Tensor& min
) {
  build_borrowing_binary_op(maybe_get_output(), self, min);
}

TORCH_META_FUNC2(isin, Tensor_Tensor) (
    const Tensor& elements, const Tensor& test_elements, bool /*assume_unique*/, bool /*invert*/
) {
  check_for_unsupported_isin_dtype(elements.scalar_type());
  check_for_unsupported_isin_dtype(test_elements.scalar_type());
  set_output_raw_strided(0, elements.sizes(), {}, TensorOptions(elements.device()).dtype(ScalarType::Bool));
}

TORCH_META_FUNC2(isin, Tensor_Scalar) (
    const Tensor& elements, const c10::Scalar& test_elements, bool /*assume_unique*/, bool /*invert*/
) {
  check_for_unsupported_isin_dtype(elements.scalar_type());
  check_for_unsupported_isin_dtype(test_elements.type());
  set_output_raw_strided(0, elements.sizes(), {}, TensorOptions(elements.device()).dtype(ScalarType::Bool));
}

TORCH_META_FUNC2(isin, Scalar_Tensor) (
    const c10::Scalar& elements, const Tensor& test_elements, bool /*assume_unique*/, bool /*invert*/
) {
  check_for_unsupported_isin_dtype(elements.type());
  check_for_unsupported_isin_dtype(test_elements.scalar_type());
  set_output_raw_strided(0, {0}, {}, TensorOptions(test_elements.device()).dtype(ScalarType::Bool));
}

TORCH_META_FUNC(isposinf) (const Tensor& self) {
  TORCH_CHECK(!self.is_complex(), "isposinf does not support complex inputs.");
  TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
      "isposinf does not support non-boolean outputs.");
  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
}

TORCH_META_FUNC(isneginf) (const Tensor& self) {
  TORCH_CHECK(!self.is_complex(), "isneginf does not support complex inputs.");
  TORCH_CHECK(maybe_get_output().defined() ? maybe_get_output().dtype() == at::kBool : true,
      "isneginf does not support non-boolean outputs.");
  build_borrowing_unary_force_boolean_op(maybe_get_output(), self);
}

static void check_unsupported_complex(const char* name, const Tensor& self) {
  TORCH_CHECK(!self.is_complex(), name, ": does not support complex input");
}

TORCH_PRECOMPUTE_META_FUNC2(max, dim)
(const Tensor& self, int64_t dim, bool keepdim) {
  dim = maybe_wrap_dim(dim, self.dim());
  at::native::zero_numel_check_dims(self, dim, "max()");
  check_unsupported_complex("max()", self);
  resize_reduction_with_indices(*this, self, dim, keepdim, self.scalar_type());
  return TORCH_PRECOMPUTE_STRUCT2(max, dim)()
      .set_dim(maybe_wrap_dim(dim, self.dim()));
}

TORCH_PRECOMPUTE_META_FUNC2(min, dim)(const Tensor& self, int64_t dim, bool keepdim) {
  dim = maybe_wrap_dim(dim, self.dim());
  at::native::zero_numel_check_dims(self, dim, "min()");
  check_unsupported_complex("min()", self);
  resize_reduction_with_indices(*this, self, dim, keepdim, self.scalar_type());
  return TORCH_PRECOMPUTE_STRUCT2(min, dim)()
      .set_dim(maybe_wrap_dim(dim, self.dim()));
}

} // namespace at::meta

namespace at::native {

DEFINE_DISPATCH(where_kernel); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(max_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(min_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(isposinf_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(isneginf_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(mode_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(clamp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(clamp_scalar_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(clamp_min_scalar_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(clamp_max_scalar_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(isin_default_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

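// allclose is the scalar reduction of isclose: true iff every element of
// `self` is close to the corresponding element of `other` (see Note [closeness]).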
bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
  return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
}

// Note [closeness]
// A number A is close to B when either:
//
// (1) A is equal to B, with NaNs comparing equal when equal_nan is true.
// (2) The error abs(A - B) is finite and less than the max error
//     (atol + abs(rtol * B)).
//
// Note that this is consistent with NumPy's isclose but divergent from
// Python's isclose, which computes the max error symmetrically as
// max(rtol * max(abs(A), abs(B)), atol).
// TODO: use bitwise operator overloads once we add them
// TODO: revisit complex inputs and equal_nan=true after
// https://github.com/numpy/numpy/issues/15959 is resolved
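// Example (illustrative): with rtol=1e-5 and atol=1e-8, A=1.0 and B=1.00001
// are close because abs(A - B) = 1e-5 is at most atol + rtol * abs(B) ~= 1.001e-5,
// while A=1.0 and B=1.0001 are not, since abs(A - B) = 1e-4 exceeds that bound.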
Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
  TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type());
  TORCH_CHECK(!(self.is_quantized() || other.is_quantized()),
              "isclose is not supported for quantized inputs.");

  // Checks that rtol and atol are non-negative
  // Note: consistent with Python's isclose but divergent from NumPy's, which
  // allows negative atol and rtol.
  TORCH_CHECK(rtol >= 0, "rtol must be greater than or equal to zero, but got ", rtol);
  TORCH_CHECK(atol >= 0, "atol must be greater than or equal to zero, but got ", atol);

  // Computes equality closeness
  Tensor close = self == other;
  if (equal_nan && (self.is_floating_point() || self.is_complex())) {
    // For CompositeCompliance, if `other` is a CCT and `self` is a regular Tensor,
    // then we can't perform inplace op into `self` with `other`.
    // NOTE: Inplacing into `close` is fine because it is generated from
    // out-of-place with args `self` and `other`. So if either of them is
    // a CCT then `close` will also be a `CCT`.
    if (isTensorSubclassLike(other)) {
      close.__ior__(self.isnan().bitwise_and(other.isnan()));
    } else {
      close.__ior__(self.isnan().__iand__(other.isnan()));
    }
  }

  // In case of zero tolerances the closeness inequality degenerates to an equality check.
  // In this case, the short-circuit prevents false positives as detailed in the paragraph below.
  if (rtol == 0 && atol == 0) {
    return close;
  }

  // Note [closeness error computation]
  // atol and rtol are provided as doubles, so the computation
  // rtol * other will produce a float or complex tensor.
  // When the difference (self - other) is compared to it then the
  // tensor representing the difference will also be cast to float or complex.
  // However, since (self - other) in uint8 is very likely to produce a
  // negative value, this moves the cast forward so the difference is
  // always computed in a float or complex type.
  // If the values of the integer tensors cannot be exactly represented
  // by the default scalar type then this may cause an incorrect result.

  // Computes allowed and actual error
  Tensor cast_self, cast_other;
  cast_self = self.scalar_type() == at::kBool ? self.to(at::get_default_dtype()) : self;
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    cast_other = other.to(at::get_default_dtype());
  } else {
    cast_other = other;
  }

  Tensor allowed_error = atol + (rtol * cast_other).abs();
  Tensor actual_error = (cast_self - cast_other).abs();

  // Computes finite closeness
  close.__ior__(at::isfinite(actual_error).__iand__(actual_error <= allowed_error));

  return close;
}

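// NaN is the only value that does not compare equal to itself, so
// `self != self` is true exactly at the NaN locations of `self`.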
Tensor isnan(const Tensor& self) {
  return self != self;
}

Tensor isreal(const Tensor& self) {
  // Note: Integral and Floating tensor values are always real
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) ||
      c10::isFloatingType(self.scalar_type())) {
    return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
  }

  return at::imag(self) == 0;
}


#if !defined(C10_MOBILE)
#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...)                          \
  AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e5m2,        \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...)                          \
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,                      \
      TYPE, NAME, __VA_ARGS__)
#endif


Tensor isinf(const Tensor &self) {
  // Note: Integral tensor values are never infinite
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    return at::zeros_like(self, at::kBool, at::MemoryFormat::Preserve);
  }

  // Note: a complex value is infinite when either part is infinite
  if (self.is_complex()) {
    return at::isinf(at::real(self)).__ior__
        (at::isinf(at::imag(self)));
  }

  return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isinf", [&]() {
    return self.abs() == std::numeric_limits<scalar_t>::infinity();
  });
}

Tensor isfinite(const Tensor& self) {
  // Note: Integral tensor values are always finite
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
  }

  // Note: a complex value is finite iff both parts are finite
  if (self.is_complex()) {
    return at::isfinite(at::real(self)).__iand__(at::isfinite(at::imag(self)));
  }

  return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isfinite", [&]() {
    return (self == self) * (self.abs() != std::numeric_limits<scalar_t>::infinity());
  });
}

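// CPU implementation of _assert_async: the check runs synchronously and
// throws unless `self` holds a single nonzero value.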
void _assert_async_cpu(const Tensor& self) {
  TORCH_CHECK(native::is_nonzero(self), "Expected Tensor with single nonzero value, but got zero");
}

void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) {
  TORCH_CHECK(native::is_nonzero(self), assert_msg != "" ? assert_msg : "Assertion failed");
}

void _assert_scalar(const Scalar& scalar, c10::string_view assert_msg) {
  TORCH_SYM_CHECK(scalar.toSymBool(), assert_msg != "" ? assert_msg : "Assertion failed");
}

Tensor _functional_assert_scalar(const Scalar& scalar, c10::string_view assert_msg, const Tensor& dep_token) {
  _assert_scalar(scalar, assert_msg);
  return dep_token.clone();
}

Tensor _functional_assert_async_msg_cpu(
    const Tensor& self,
    c10::string_view assert_msg,
    const Tensor& dep_token) {
  _assert_async_msg_cpu(self, assert_msg);
  return dep_token.clone();
}

void _print(c10::string_view s) {
  std::cout << s << "\n";
}

// Sorting-based algorithm for isin(); used when the number of test elements is large.
static void isin_sorting(
    const Tensor& elements,
    const Tensor& test_elements,
    bool assume_unique,
    bool invert,
    const Tensor& out) {
  // 1. Concatenate unique elements with unique test elements in 1D form. If
  //    assume_unique is true, skip calls to unique().
  Tensor elements_flat, test_elements_flat, unique_order;
  if (assume_unique) {
    elements_flat = elements.ravel();
    test_elements_flat = test_elements.ravel();
  } else {
    std::tie(elements_flat, unique_order) = at::_unique(
        elements, /*sorted=*/ false, /*return_inverse=*/ true);
    std::tie(test_elements_flat, std::ignore) = at::_unique(test_elements, /*sorted=*/ false);
  }

  // 2. Stable sort all elements, maintaining order indices to reverse the
  //    operation. Stable sort is necessary to keep elements before test
  //    elements within the sorted list.
  Tensor all_elements = at::cat({std::move(elements_flat), std::move(test_elements_flat)});
  auto [sorted_elements, sorted_order] = all_elements.sort(
      /*stable=*/ true, /*dim=*/ 0, /*descending=*/ false);

  // 3. Create a mask for locations of adjacent duplicate values within the
  //    sorted list. Duplicate values are in both elements and test elements.
  Tensor duplicate_mask = at::empty_like(sorted_elements, TensorOptions(ScalarType::Bool));
  Tensor sorted_except_first = sorted_elements.slice(0, 1, at::indexing::None);
  Tensor sorted_except_last = sorted_elements.slice(0, 0, -1);
  duplicate_mask.slice(0, 0, -1).copy_(
      invert ? sorted_except_first.ne(sorted_except_last) : sorted_except_first.eq(sorted_except_last));
  duplicate_mask.index_put_({-1}, invert);

  // 4. Reorder the mask to match the pre-sorted element order.
  Tensor mask = at::empty_like(duplicate_mask);
  mask.index_copy_(0, sorted_order, duplicate_mask);

  // 5. Index the mask to match the pre-unique element order. If
  //    assume_unique is true, just take the first N items of the mask,
  //    where N is the original number of elements.
  if (assume_unique) {
    out.copy_(mask.slice(0, 0, elements.numel()).view_as(out));
  } else {
    out.copy_(at::index(mask, {std::optional<Tensor>(unique_order)}));
  }
}

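// Returns the device of the first non-CPU input, or kCPU if all inputs live
// on the CPU. Used by where() so that CPU scalars can be mixed with tensors
// on another device.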
template<typename... Args>
Device out_device(Args&... inps) {
  for (const auto& i : {inps...}) {
    if (i.device() != at::kCPU) {
      return i.device();
    }
  }
  return at::kCPU;
}


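// Computes where(condition, self, other) into a preallocated `out`: self and
// other are promoted to their common dtype, 0-dim inputs on a mismatched
// device are moved over so CPU scalars can be used with device tensors, and a
// uint8 condition is converted to bool with a deprecation warning.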
Tensor& where_self_out(const Tensor& condition, const Tensor& self, const Tensor& other, Tensor& out) {
  const auto result_type = at::native::result_type(self, other);
  TORCH_CHECK(out.scalar_type() == result_type, "Expected out type to be ", result_type, " but got ", out.scalar_type());

  auto self_ = self.scalar_type() != result_type ? self.to(result_type) : self;
  auto other_ = other.scalar_type() != result_type ? other.to(result_type) : other;
  auto condition_ = condition;
  auto device = out_device(condition, self_, other_);
  if (device != at::kCPU) { // allow CPU scalars on non-cpu device
    if (condition.device() != device && condition.ndimension() == 0) {
      condition_ = condition.to(device);
    }
    if (self_.device() != device && self_.ndimension() == 0) {
      self_ = self_.to(device);
    }
    if (other_.device() != device && other_.ndimension() == 0) {
      other_ = other_.to(device);
    }
  }
  if (condition_.scalar_type() == ScalarType::Byte) {
    TORCH_WARN_ONCE("where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
    condition_ = condition_.to(kBool);
  }
  TORCH_CHECK(condition_.scalar_type() == kBool, "where expected condition to be a boolean tensor, but got a tensor with dtype ", condition_.scalar_type());
  // if there's still a device mismatch, let tensoriterator error out with it
  auto iter = at::TensorIteratorConfig()
      .check_all_same_dtype(false)
      .add_output(out)
      .add_const_input(condition_)
      .add_const_input(self_)
      .add_const_input(other_)
      .build();
  where_kernel(iter.device_type(), iter);
  return out;
}


Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
  auto device = out_device(condition, self, other);
  auto result_type = at::native::result_type(self, other);
  Tensor ret = at::empty({0}, self.options().dtype(result_type).device(device));
  at::native::where_self_out(condition, self, other, ret);
  return ret;
}

Tensor where(const Tensor& condition, const Scalar& self, const Tensor& other) {
  auto result_type = at::native::result_type(other, self);
  auto self_converted = at::scalar_tensor(self, other.options().dtype(result_type));
  auto other_converted = other.to(result_type);
  return at::where(condition, self_converted, other_converted);
}

Tensor where(const Tensor& condition, const Tensor& self, const Scalar& other) {
  auto result_type = at::native::result_type(self, other);
  auto other_converted = at::scalar_tensor(other, self.options().dtype(result_type));
  auto self_converted = self.to(result_type);
  return at::where(condition, self_converted, other_converted);
}

Tensor where(const Tensor& condition, const Scalar& self, const Scalar& other) {
  auto result_type = at::native::result_type(self, other);
  const Tensor& other_t = at::scalar_tensor(other, condition.options().dtype(result_type));
  const Tensor& self_t = at::scalar_tensor(self, condition.options().dtype(result_type));
  return at::where(condition, self_t, other_t);
}

std::vector<Tensor> where(const Tensor& condition) {
  return condition.nonzero_numpy();
}

std::tuple<Tensor, Tensor> mode(const Tensor& self, int64_t dim, bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  return at::native::mode_out(self, dim, keepdim, values, indices);
}

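// Computes the mode (most frequently occurring value) and its index along
// `dim`, writing into the user-provided `values` and `indices` tensors after
// validating the input layout and the output devices and dtypes.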
std::tuple<Tensor &,Tensor &> mode_out(const Tensor& self, int64_t dim, bool keepdim,
                                       Tensor& values, Tensor& indices) {
  TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
              "mode only supports CPU AND CUDA device type, got: ", self.device().type());
  TORCH_CHECK(self.layout() == Layout::Strided,
              "mode only supports strided layout, got: ", self.layout());
  TORCH_CHECK(self.device() == values.device(),
              "expected device '", self.device(), "' but got '",
              values.device(), "' for values output");
  TORCH_CHECK(self.device() == indices.device(),
              "expected device '", self.device(), "' but got '",
              indices.device(), "' for indices output");
  TORCH_CHECK(self.scalar_type() == values.scalar_type(),
              "expected scalar type '", self.scalar_type(), "' but got '",
              values.scalar_type(), "' for values output");
  TORCH_CHECK(indices.scalar_type() == ScalarType::Long,
              "expected scalar type '", ScalarType::Long, "' but got '",
              indices.scalar_type(), "' for indices output");
  dim = maybe_wrap_dim(dim, self.dim());
  if (self.numel() == 0) {
    auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, "mode()");
    resize_output(values, sizes);
    resize_output(indices, sizes);
    return std::tie(values, indices);
  }
  else if (_dimreduce_return_trivial_no_ident(values, self, dim, keepdim, "mode")) {
    AT_ASSERT(values.dim() == 0);
    indices.resize_({}).fill_(0);
    return std::forward_as_tuple(values, indices);
  } else {
    auto result = [&]() {
      NoNamesGuard guard;
      mode_stub(self.device().type(), values, indices, self, dim, keepdim);
      return std::tuple<Tensor &,Tensor &>{values, indices};
    }();
    namedinference::propagate_names_for_reduction(std::get<0>(result), self, dim, keepdim);
    namedinference::propagate_names_for_reduction(std::get<1>(result), self, dim, keepdim);
    return result;
  }
}

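// Shared helper for the structured max.dim / min.dim kernels: a 0-dim input is
// handled by filling the outputs directly, everything else is dispatched to
// the given stub (empty inputs are left to the already-sized outputs).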
template <class Stub>
void minmax_out_impl(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    const Tensor& values,
    const Tensor& indices,
    Stub& stub) {
  NoNamesGuard guard;
  if (self.numel() > 0) {
    if (self.numel() == 1 && self.dim() == 0) {
      values.fill_(self);
      indices.fill_(0);
    } else {
      stub(self.device().type(), values, indices, self, dim, keepdim);
    }
  }
}

TORCH_IMPL_FUNC(max_out)
(const Tensor& self,
 int64_t dim,
 bool keepdim,
 const Tensor& values,
 const Tensor& indices) {
  minmax_out_impl(self, dim, keepdim, values, indices, max_stub);
}

TORCH_IMPL_FUNC(min_out)
(const Tensor& self,
 int64_t dim,
 bool keepdim,
 const Tensor& values,
 const Tensor& indices) {
  minmax_out_impl(self, dim, keepdim, values, indices, min_stub);
}

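// Quantized max/min along a dimension: the reduction runs on the underlying
// integer representation and the result is re-quantized with the input's
// scale and zero point.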
std::tuple<Tensor, Tensor> qmax(const Tensor& self, int64_t dim, bool keepdim) {
  TORCH_CHECK(self.qscheme() == at::kPerTensorAffine, "Max operator for quantized tensors only works for per tensor quantized tensors. "
      "Please open an issue on https://github.com/pytorch/pytorch/issues if you need per channel quantized tensor support.");
  Tensor max_indices = at::empty({0}, self.options().dtype(kLong));
  Tensor max = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
  at::max_outf(self.int_repr(), dim, keepdim, max, max_indices);
  // TODO: qscheme
  return std::tuple<Tensor, Tensor>(
      at::_make_per_tensor_quantized_tensor(max, self.q_scale(), self.q_zero_point()), max_indices);
}

std::tuple<Tensor, Tensor> qmin(const Tensor& self, int64_t dim, bool keepdim) {
  TORCH_CHECK(self.qscheme() == at::kPerTensorAffine, "Min operator for quantized tensors only works for per tensor quantized tensors. "
      "Please open an issue on https://github.com/pytorch/pytorch/issues if you need per channel quantized tensor support.");
  Tensor min_indices = at::empty({0}, self.options().dtype(kLong));
  Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
  at::min_outf(self.int_repr(), dim, keepdim, min, min_indices);
  return std::tuple<Tensor, Tensor>(
      at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
}

// DEPRECATED: Use at::aminmax instead
std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdim) {
  TORCH_WARN_ONCE("_aminmax is deprecated as of PyTorch 1.11 and will be removed in a future release. Use aminmax instead."
                  " This warning will only appear once per process.");
  return at::aminmax(self, dim, keepdim);
}

TORCH_IMPL_FUNC(clamp_out)
(
 const Tensor& /*self*/,
 const OptionalScalarRef min,
 const OptionalScalarRef max,
 const Tensor& result) {
  using at::native::detail::ClampLimits;
  if (min && max) {
    if (min.get().toDouble() != min.get().toDouble() ||
        max.get().toDouble() != max.get().toDouble()) {
      at::fill_(const_cast<Tensor&>(result), std::numeric_limits<double>::quiet_NaN());
    } else {
      clamp_scalar_stub(device_type(), *this, min.get(), max.get());
    }
  } else if (max) {
    clamp_max_scalar_stub(device_type(), *this, max.get());
  } else if (min) {
    clamp_min_scalar_stub(device_type(), *this, min.get());
  }
}

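// Tensor-overload clamp: the meta function already built the TensorIterator
// with whichever bounds are present, so the impl only selects the matching
// kernel (clamp for both bounds, maximum for min-only, minimum for max-only).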
TORCH_IMPL_FUNC(clamp_Tensor_out)
(const Tensor& self, const OptionalTensorRef min,
 const OptionalTensorRef max, const Tensor&) {
  if (min && max) {
    clamp_stub(device_type(), *this);
  } else if (min) {
    maximum_stub(device_type(), *this);
  } else if (max) {
    minimum_stub(device_type(), *this);
  }
}

TORCH_IMPL_FUNC(clamp_max_out)
(const Tensor& self, const Scalar& max, const Tensor& result) {
  if (max.toDouble() != max.toDouble()) {
    // TODO: this is not great, building TI again is expensive, but I can't use
    // fill_stub because fill is not structured
    // this is a corner case anyway
    at::fill_(const_cast<Tensor&>(result), wrapped_scalar_tensor(max));
  } else {
    clamp_max_scalar_stub(device_type(), *this, max);
  }
}

TORCH_IMPL_FUNC(clamp_max_Tensor_out)
(const Tensor& self, const Tensor& max, const Tensor& result) {
  minimum_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(clamp_min_out)
(const Tensor& self, const Scalar& min, const Tensor& result) {
  if (min.toDouble() != min.toDouble()) {
    at::fill_(const_cast<Tensor&>(result), min);
  } else {
    clamp_min_scalar_stub(device_type(), *this, min);
  }
}

TORCH_IMPL_FUNC(clamp_min_Tensor_out)
(const Tensor& self, const Tensor& min, const Tensor& result) {
  maximum_stub(device_type(), *this);
}

// Implements the "clip" alias for clamp
Tensor& clip_out(const Tensor& self, const std::optional<Scalar>& min, const std::optional<Scalar>& max, Tensor& result) {
  return at::clamp_outf(self, min, max, result);
}

Tensor& clip_out(const Tensor& self, const std::optional<Tensor>& min, const std::optional<Tensor>& max, Tensor& result) {
  return at::clamp_outf(self, min, max, result);
}

Tensor clip(const Tensor& self, const std::optional<Scalar>& min, const std::optional<Scalar>& max) {
  return at::clamp(self, min, max);
}

Tensor clip(const Tensor& self, const std::optional<Tensor>& min, const std::optional<Tensor>& max) {
  return at::clamp(self, min, max);
}

Tensor& clip_(Tensor& self, const std::optional<Scalar>& min, const std::optional<Scalar>& max) {
  return at::clamp_(self, min, max);
}

Tensor& clip_(Tensor& self, const std::optional<Tensor>& min, const std::optional<Tensor>& max) {
  return at::clamp_(self, min, max);
}

// Named tensor overloads

std::tuple<Tensor, Tensor> min(const Tensor& self, Dimname dim, bool keepdim) {
  return at::min(self, dimname_to_position(self, dim), keepdim);
}
std::tuple<Tensor &,Tensor &> min_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& min, Tensor& min_indices) {
  return at::min_out(min, min_indices, self, dimname_to_position(self, dim), keepdim);
}
std::tuple<Tensor, Tensor> max(const Tensor& self, Dimname dim, bool keepdim) {
  return at::max(self, dimname_to_position(self, dim), keepdim);
}
std::tuple<Tensor&, Tensor&> max_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& max, Tensor& max_indices) {
  return at::max_out(max, max_indices, self, dimname_to_position(self, dim), keepdim);
}
Tensor argsort(const Tensor& /*self*/, Dimname /*dim*/, bool /*keepdim*/) {
  reportNYIDimnameOverload("argsort");
}
std::tuple<Tensor, Tensor> mode(const Tensor& self, Dimname dim, bool keepdim) {
  return at::mode(self, dimname_to_position(self, dim), keepdim);
}
std::tuple<Tensor &,Tensor &> mode_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& values, Tensor& indices) {
  return at::mode_out(values, indices, self, dimname_to_position(self, dim), keepdim);
}

TORCH_IMPL_FUNC(isin_Tensor_Tensor_out) (
  const Tensor& elements, const Tensor& test_elements, bool assume_unique, bool invert, const Tensor& out
) {
  if (elements.numel() == 0) {
    return;
  }

  // Heuristic taken from numpy's implementation.
  // See https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/arraysetops.py#L575
  if (test_elements.numel() < static_cast<int64_t>(
        10.0f * std::pow(static_cast<double>(elements.numel()), 0.145))) {
    out.fill_(invert);
    isin_default_stub(elements.device().type(), elements, test_elements, invert, out);
  } else {
    isin_sorting(elements, test_elements, assume_unique, invert, out);
  }
}

TORCH_IMPL_FUNC(isin_Tensor_Scalar_out) (
  const Tensor& elements, const c10::Scalar& test_elements, bool assume_unique, bool invert, const Tensor& out
) {
  // redispatch to eq / ne
  if (invert) {
    at::ne_out(const_cast<Tensor&>(out), elements, test_elements);
  } else {
    at::eq_out(const_cast<Tensor&>(out), elements, test_elements);
  }
}

TORCH_IMPL_FUNC(isin_Scalar_Tensor_out) (
  const c10::Scalar& elements, const Tensor& test_elements, bool assume_unique, bool invert, const Tensor& out
) {
  // redispatch
  at::isin_out(const_cast<Tensor&>(out), wrapped_scalar_tensor(elements, test_elements.device()),
      test_elements, assume_unique, invert);
}

TORCH_IMPL_FUNC(isposinf_out) (const Tensor& self, const Tensor& result) {
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    result.fill_(false);
  } else {
    isposinf_stub(device_type(), *this);
  }
}

TORCH_IMPL_FUNC(isneginf_out) (const Tensor& self, const Tensor& result) {
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    result.fill_(false);
  } else {
    isneginf_stub(device_type(), *this);
  }
}

} // namespace at::native