1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3
4 #ifndef AT_PER_OPERATOR_HEADERS
5 #include <ATen/NativeFunctions.h>
6 #else
7 #include <ATen/ops/is_nonzero_native.h>
8 #include <ATen/ops/_foobar_native.h>
9 #include <ATen/ops/_test_functorch_fallback_native.h>
10 #endif
11
12 namespace at::native {
13
is_nonzero(const Tensor & self)14 bool is_nonzero(const Tensor& self) {
15 auto n = self.numel();
16 TORCH_CHECK(n != 0, "Boolean value of Tensor with no values is ambiguous");
17 TORCH_CHECK(
18 n < 2, "Boolean value of Tensor with more than one value is ambiguous");
19
20 Scalar localScalar = self.item();
21 if (localScalar.isFloatingPoint()) {
22 return localScalar.to<double>() != 0;
23 } else if (localScalar.isComplex()) {
24 return localScalar.to<c10::complex<double>>() !=
25 c10::complex<double>(0.0, 0.0);
26 } else if (localScalar.isIntegral(false)) {
27 return localScalar.to<int64_t>() != 0;
28 } else if (localScalar.isBoolean()) {
29 return localScalar.to<bool>();
30 }
31 TORCH_INTERNAL_ASSERT(false, "Expected non-Tensor backend scalar");
32 }
33
34
35 // Aux function used in the test TestPythonDispatch.test_kwarg_only_and_positional_default
36 // within test/test_python_dispatch.py
foobar(const Tensor & self,bool arg1,bool arg2,bool arg3)37 Tensor foobar(const Tensor& self, bool arg1, bool arg2, bool arg3) {
38 return self;
39 }
40
41 // Aux function used to test functorch fallback warning
_test_functorch_fallback(const Tensor & self,const Tensor & other)42 Tensor _test_functorch_fallback(const Tensor& self, const Tensor& other) {
43 return self.clone();
44 }
45
46 } // namespace at::meta
47