#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/TypeProperties.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_has_compatible_shallow_copy_type_native.h>
#include <ATen/ops/_is_zerotensor_native.h>
#include <ATen/ops/can_cast_native.h>
#include <ATen/ops/is_complex_native.h>
#include <ATen/ops/is_conj_native.h>
#include <ATen/ops/is_distributed_native.h>
#include <ATen/ops/is_floating_point_native.h>
#include <ATen/ops/is_inference_native.h>
#include <ATen/ops/is_neg_native.h>
#include <ATen/ops/is_signed_native.h>
#include <ATen/ops/promote_types_native.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/result_type_native.h>
#include <ATen/ops/type_as_native.h>
#endif

namespace at::native {

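// Simple type predicates. Most forward directly to the corresponding Tensor
// method; is_distributed is kept for interface compatibility and always
// reports false, since ATen itself has no distributed tensor type.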
bool is_distributed(const Tensor& self) {
  return false;
}

bool is_complex(const Tensor& self) {
  return self.is_complex();
}

bool is_floating_point(const Tensor& self) {
  return self.is_floating_point();
}

bool is_inference(const Tensor& self) {
  return self.is_inference();
}

bool is_signed(const Tensor& self) {
  return self.is_signed();
}

bool _is_zerotensor(const Tensor& self) {
  return self._is_zerotensor();
}

bool is_conj(const Tensor& self) {
  return self.is_conj();
}

bool is_neg(const Tensor& self) {
  return self.is_neg();
}

// True if `self` and `from` have compatible tensor type so that `from`'s
// TensorImpl can be copied to `self`.
bool _has_compatible_shallow_copy_type(const Tensor& self, const Tensor& from) {
  return self.unsafeGetTensorImpl()->has_compatible_shallow_copy_type(
      from.key_set());
}

Tensor type_as(const Tensor& self, const Tensor& other) {
  return self.to(other.options());
}

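// Promotion helper that treats ScalarType::Undefined as an identity element,
// so slots of ResultTypeState that were never touched (and thus remain
// Undefined) do not perturb the promoted result.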
static inline ScalarType promote_skip_undefined(ScalarType a, ScalarType b) {
  if (a == ScalarType::Undefined) {
    return b;
  }
  if (b == ScalarType::Undefined) {
    return a;
  }
  return promoteTypes(a, b);
}

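// Combines two promotion slots. Category precedence is
// complex > floating point > integral > bool: `higher` comes from a
// higher-priority slot and wins within its category, but a `lower` slot of a
// higher category can raise the result, e.g. an integer tensor combined with
// a wrapped float scalar yields the default floating type, not an integer.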
static inline ScalarType combine_categories(ScalarType higher, ScalarType lower) {
  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (isComplexType(higher)) {
    return higher;
  } else if (isComplexType(lower)) {
    // Preserve the value type of `higher` when it is floating point, e.g.
    // (Double, ComplexFloat) promotes to ComplexDouble.
    if (isFloatingType(higher)) {
      return toComplexType(higher);
    }
    // For integral (or bool) `higher`, the complex type of `lower` wins.
    return lower;
  } else if (isFloatingType(higher)) {
    return higher;
  }
  if (higher == ScalarType::Bool || isFloatingType(lower)) {
    return promote_skip_undefined(higher, lower);
  }
  if (higher != ScalarType::Undefined) {
    return higher;
  }
  return lower;
}

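// Folds one operand into the running promotion state. Operands are bucketed
// by how strongly they participate in promotion: tensors with at least one
// dimension (dimResult), wrapped numbers, i.e. scalars lifted to 0-dim
// tensors (wrappedResult), and all remaining zero-dim tensors (zeroResult).
// A wrapped floating point or complex number is first rebased onto the
// default (complex) dtype so that, e.g., a Python float literal does not
// force float64 on the result.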
ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state) {
  if (!tensor.defined()) {
    return in_state;
  }
  ResultTypeState new_state = in_state;
  ScalarType current = tensor.scalar_type();
  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    if (isComplexType(current)) {
      current = typeMetaToScalarType(at::get_default_complex_dtype());
    } else if (isFloatingType(current)) {
      current = typeMetaToScalarType(at::get_default_dtype());
    }
  }
  if (tensor.dim() > 0) {
    new_state.dimResult = promote_skip_undefined(in_state.dimResult, current);
  } else if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    new_state.wrappedResult = promote_skip_undefined(in_state.wrappedResult, current);
  } else {
    new_state.zeroResult = promote_skip_undefined(in_state.zeroResult, current);
  }
  return new_state;
}

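// Scalar overload: a bare Scalar always lands in the wrappedResult slot and
// likewise picks up the default (complex) dtype when floating point or
// complex.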
ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state) {
  ResultTypeState new_state = in_state;
  ScalarType current = scalar.type();
  if (isComplexType(current)) {
    current = typeMetaToScalarType(at::get_default_complex_dtype());
  } else if (isFloatingType(current)) {
    current = typeMetaToScalarType(at::get_default_dtype());
  }
  new_state.wrappedResult = promote_skip_undefined(in_state.wrappedResult, current);
  return new_state;
}

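// Collapses the three slots into the final dtype. Slots combine in priority
// order: dimensioned tensors dominate zero-dim tensors, which in turn
// dominate wrapped numbers, except where a lower slot raises the type
// category (see combine_categories above).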
ScalarType result_type(const ResultTypeState& in_state) {
  return combine_categories(in_state.dimResult, combine_categories(in_state.zeroResult, in_state.wrappedResult));
}

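// E.g. result_type over {Long tensor, Float tensor} is Float, since both
// have dimensions and promoteTypes(Long, Float) == Float; two Long tensors
// stay Long.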
ScalarType result_type(ITensorListRef tensors) {
  ResultTypeState state = {};
  for (const Tensor& tensor : tensors) {
    state = update_result_type_state(tensor, state);
  }
  return result_type(state);
}

ScalarType result_type(const Tensor& tensor, const Tensor& other) {
  ResultTypeState state = {};
  state = update_result_type_state(tensor, state);
  state = update_result_type_state(other, state);
  return result_type(state);
}

ScalarType result_type(const Tensor& tensor, const Scalar& other) {
  ResultTypeState state = {};
  state = update_result_type_state(tensor, state);
  state = update_result_type_state(other, state);
  return result_type(state);
}

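// Symmetric overload: result_type is commutative, so delegate to the
// (Tensor, Scalar) form.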
ScalarType result_type(const Scalar& scalar, const Tensor& tensor) {
  return at::result_type(tensor, scalar);
}

ScalarType result_type(const Scalar& scalar1, const Scalar& scalar2) {
  ResultTypeState state = {};
  state = update_result_type_state(scalar1, state);
  state = update_result_type_state(scalar2, state);
  return result_type(state);
}

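// Thin wrapper over at::canCast: a cast is permitted only when it does not
// drop the type category (e.g. floating point -> integral is rejected).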
bool can_cast(const at::ScalarType from_, const at::ScalarType to) {
  return at::canCast(from_, to);
}

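// Unlike bare promoteTypes, this surfaces an unsupported pairing as a
// user-facing error: promoteTypes signals such pairs by returning
// ScalarType::Undefined, which the TORCH_CHECK below rejects.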
ScalarType promote_types(ScalarType type1, ScalarType type2) {
  ScalarType ret = promoteTypes(type1, type2);
  TORCH_CHECK(ret != ScalarType::Undefined, "Promotion from ", type1, " and ", type2, " is unsupported.");
  return ret;
}

} // namespace at::native