1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/native/TypeProperties.h>
4
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/_has_compatible_shallow_copy_type_native.h>
10 #include <ATen/ops/_is_zerotensor_native.h>
11 #include <ATen/ops/can_cast_native.h>
12 #include <ATen/ops/is_complex_native.h>
13 #include <ATen/ops/is_conj_native.h>
14 #include <ATen/ops/is_distributed_native.h>
15 #include <ATen/ops/is_floating_point_native.h>
16 #include <ATen/ops/is_inference_native.h>
17 #include <ATen/ops/is_neg_native.h>
18 #include <ATen/ops/is_signed_native.h>
19 #include <ATen/ops/promote_types_native.h>
20 #include <ATen/ops/result_type.h>
21 #include <ATen/ops/result_type_native.h>
22 #include <ATen/ops/type_as_native.h>
23 #endif
24
25 namespace at::native {
26
// Native implementation of `is_distributed`. This build always answers
// `false`; the tensor argument is never inspected.
bool is_distributed([[maybe_unused]] const Tensor& self) {
  constexpr bool kDistributed = false;
  return kDistributed;
}
30
is_complex(const Tensor & self)31 bool is_complex(const Tensor& self) {
32 return self.is_complex();
33 }
34
is_floating_point(const Tensor & self)35 bool is_floating_point(const Tensor& self) {
36 return self.is_floating_point();
37 }
38
is_inference(const Tensor & self)39 bool is_inference(const Tensor& self) {
40 return self.is_inference();
41 }
42
is_signed(const Tensor & self)43 bool is_signed(const Tensor &self) {
44 return self.is_signed();
45 }
46
_is_zerotensor(const Tensor & self)47 bool _is_zerotensor(const Tensor& self) {
48 return self._is_zerotensor();
49 }
50
is_conj(const Tensor & self)51 bool is_conj(const Tensor& self) {
52 return self.is_conj();
53 }
54
is_neg(const Tensor & self)55 bool is_neg(const Tensor& self) {
56 return self.is_neg();
57 }
58
59 // True if `self` and `from` have compatible tensor type so that `from`'s
60 // TensorImpl can be copied to `self`.
_has_compatible_shallow_copy_type(const Tensor & self,const Tensor & from)61 bool _has_compatible_shallow_copy_type(const Tensor& self, const Tensor& from) {
62 return self.unsafeGetTensorImpl()->has_compatible_shallow_copy_type(
63 from.key_set());
64 }
65
// Native implementation of `type_as`: converts `self` using the full
// TensorOptions of `other` (via Tensor::to).
Tensor type_as(const Tensor& self, const Tensor& other) {
  const auto target_options = other.options();
  return self.to(target_options);
}
69
promote_skip_undefined(ScalarType a,ScalarType b)70 static inline ScalarType promote_skip_undefined(ScalarType a, ScalarType b) {
71 if (a == ScalarType::Undefined) {
72 return b;
73 }
74 if (b == ScalarType::Undefined) {
75 return a;
76 }
77 return promoteTypes(a, b);
78 }
79
80
combine_categories(ScalarType higher,ScalarType lower)81 static inline ScalarType combine_categories(ScalarType higher, ScalarType lower) {
82 // NOLINTNEXTLINE(bugprone-branch-clone)
83 if(isComplexType(higher)) {
84 return higher;
85 } else if (isComplexType(lower)) {
86 // preserve value type of higher if it is floating type.
87 if (isFloatingType(higher)) {
88 return toComplexType(higher);
89 }
90 // in case of integral input
91 // lower complex takes precedence.
92 return lower;
93 } else if (isFloatingType(higher)) {
94 return higher;
95 }
96 if (higher == ScalarType::Bool || isFloatingType(lower)) {
97 return promote_skip_undefined(higher, lower);
98 }
99 if (higher != ScalarType::Undefined) {
100 return higher;
101 }
102 return lower;
103 }
104
// Folds one tensor operand into the running type-promotion state and returns
// the updated copy; `in_state` itself is not modified.
//   * Undefined tensors leave the state untouched.
//   * Wrapped numbers with a complex/floating dtype are first replaced by the
//     default complex dtype / default dtype, respectively.
//   * The dtype is then promoted into exactly one bucket: dimResult when
//     dim() > 0, wrappedResult for 0-dim wrapped numbers, zeroResult for all
//     other 0-dim tensors.
ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state) {
  if (!tensor.defined()) {
    return in_state;
  }
  const bool wrapped = tensor.unsafeGetTensorImpl()->is_wrapped_number();
  ScalarType dtype = tensor.scalar_type();
  if (wrapped) {
    if (isComplexType(dtype)) {
      dtype = typeMetaToScalarType(at::get_default_complex_dtype());
    } else if (isFloatingType(dtype)) {
      dtype = typeMetaToScalarType(at::get_default_dtype());
    }
  }

  ResultTypeState out_state = in_state;
  if (tensor.dim() > 0) {
    out_state.dimResult = promote_skip_undefined(in_state.dimResult, dtype);
  } else if (wrapped) {
    out_state.wrappedResult = promote_skip_undefined(in_state.wrappedResult, dtype);
  } else {
    out_state.zeroResult = promote_skip_undefined(in_state.zeroResult, dtype);
  }
  return out_state;
}
128
// Folds a Scalar operand into the running type-promotion state and returns
// the updated copy. Scalars always land in the wrappedResult bucket; complex
// and floating scalars are first replaced by the default complex dtype and
// the default dtype, respectively.
ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state) {
  ScalarType dtype = scalar.type();
  if (isComplexType(dtype)) {
    dtype = typeMetaToScalarType(at::get_default_complex_dtype());
  } else if (isFloatingType(dtype)) {
    dtype = typeMetaToScalarType(at::get_default_dtype());
  }

  ResultTypeState out_state = in_state;
  out_state.wrappedResult = promote_skip_undefined(in_state.wrappedResult, dtype);
  return out_state;
}
140
result_type(const ResultTypeState & in_state)141 ScalarType result_type(const ResultTypeState& in_state) {
142 return combine_categories(in_state.dimResult, combine_categories(in_state.zeroResult, in_state.wrappedResult));
143 }
144
result_type(ITensorListRef tensors)145 ScalarType result_type(ITensorListRef tensors) {
146 ResultTypeState state = {};
147 for (const Tensor& tensor : tensors) {
148 state = update_result_type_state(tensor, state);
149 }
150 return result_type(state);
151 }
152
result_type(const Tensor & tensor,const Tensor & other)153 ScalarType result_type(const Tensor &tensor, const Tensor &other) {
154 ResultTypeState state = {};
155 state = update_result_type_state(tensor, state);
156 state = update_result_type_state(other, state);
157 return result_type(state);
158 }
159
result_type(const Tensor & tensor,const Scalar & other)160 ScalarType result_type(const Tensor &tensor, const Scalar& other) {
161 ResultTypeState state = {};
162 state = update_result_type_state(tensor, state);
163 state = update_result_type_state(other, state);
164 return result_type(state);
165 }
166
result_type(const Scalar & scalar,const Tensor & tensor)167 ScalarType result_type(const Scalar& scalar, const Tensor &tensor) {
168 return at::result_type(tensor, scalar);
169 }
170
result_type(const Scalar & scalar1,const Scalar & scalar2)171 ScalarType result_type(const Scalar& scalar1, const Scalar& scalar2) {
172 ResultTypeState state = {};
173 state = update_result_type_state(scalar1, state);
174 state = update_result_type_state(scalar2, state);
175 return result_type(state);
176 }
177
can_cast(const at::ScalarType from_,const at::ScalarType to)178 bool can_cast(const at::ScalarType from_, const at::ScalarType to) {
179 return at::canCast(from_, to);
180 }
181
promote_types(ScalarType type1,ScalarType type2)182 ScalarType promote_types(ScalarType type1, ScalarType type2) {
183 ScalarType ret = promoteTypes(type1, type2);
184 TORCH_CHECK(ret != ScalarType::Undefined, "Promotion from ", type1, " and ", type2, " is unsupported.");
185 return ret;
186 }
187
188 } // namespace at::native
189