xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/prim_native_functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 
4 #ifndef AT_PER_OPERATOR_HEADERS
5 #include <ATen/NativeFunctions.h>
6 #else
7 #include <ATen/ops/is_nonzero_native.h>
8 #include <ATen/ops/_foobar_native.h>
9 #include <ATen/ops/_test_functorch_fallback_native.h>
10 #endif
11 
12 namespace at::native {
13 
is_nonzero(const Tensor & self)14 bool is_nonzero(const Tensor& self) {
15   auto n = self.numel();
16   TORCH_CHECK(n != 0, "Boolean value of Tensor with no values is ambiguous");
17   TORCH_CHECK(
18       n < 2, "Boolean value of Tensor with more than one value is ambiguous");
19 
20   Scalar localScalar = self.item();
21   if (localScalar.isFloatingPoint()) {
22     return localScalar.to<double>() != 0;
23   } else if (localScalar.isComplex()) {
24     return localScalar.to<c10::complex<double>>() !=
25         c10::complex<double>(0.0, 0.0);
26   } else if (localScalar.isIntegral(false)) {
27     return localScalar.to<int64_t>() != 0;
28   } else if (localScalar.isBoolean()) {
29     return localScalar.to<bool>();
30   }
31   TORCH_INTERNAL_ASSERT(false, "Expected non-Tensor backend scalar");
32 }
33 
34 
35 // Aux function used in the test TestPythonDispatch.test_kwarg_only_and_positional_default
36 // within test/test_python_dispatch.py
foobar(const Tensor & self,bool arg1,bool arg2,bool arg3)37 Tensor foobar(const Tensor& self, bool arg1, bool arg2, bool arg3) {
38   return self;
39 }
40 
41 // Aux function used to test functorch fallback warning
_test_functorch_fallback(const Tensor & self,const Tensor & other)42 Tensor _test_functorch_fallback(const Tensor& self, const Tensor& other) {
43   return self.clone();
44 }
45 
46 } // namespace at::meta
47