// xref: /aosp_15_r20/external/pytorch/c10/core/SymInt.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/ConstantSymNodeImpl.h>
2 #include <c10/core/SymFloat.h>
3 #include <c10/core/SymInt.h>
4 #include <c10/core/SymNodeImpl.h>
5 #include <c10/util/intrusive_ptr.h>
6 #include <c10/util/safe_numerics.h>
7 #include <functional>
8 
9 namespace c10 {
10 
11 // Precondition: data_ has a large negative number that should be
12 // treated as a constant.  It is NOT a valid pointer.  In other words,
13 // SymInt has temporarily violated invariants
14 // Postcondition: invariants on SymInt are fixed
promote_to_negative()15 void SymInt::promote_to_negative() {
16   auto s =
17       SymInt(SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(data_)));
18   // Similar to move operator=, but do NOT release data_
19   data_ = s.data_;
20   s.data_ = 0;
21 }
22 
toSymNode() const23 SymNode SymInt::toSymNode() const {
24   TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
25       is_heap_allocated(), "SymInt::toSymNode is_heap_allocated");
26   return SymNode::reclaim_copy(toSymNodeImplUnowned());
27 }
28 
SymInt(SymNode sin_sp)29 SymInt::SymInt(SymNode sin_sp) {
30   TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
31       sin_sp->is_int(), "SymInt::SymInt sin_sp->is_int()");
32   auto ptr = static_cast<uint64_t>(
33       reinterpret_cast<uintptr_t>(static_cast<void*>(sin_sp.release())));
34   auto rep = (ptr & ~MASK) | IS_SYM;
35   data_ = static_cast<int64_t>(rep);
36 }
37 
has_hint() const38 bool SymInt::has_hint() const {
39   if (!is_heap_allocated()) {
40     return true;
41   }
42   return toSymNodeImplUnowned()->has_hint();
43 }
44 
// DEFINE_BINARY(API, OP, METHOD, RET) defines the member
// `RET SymInt::API(const SymInt& sci) const`:
//   - both operands concrete: returns RET(OP(lhs, rhs)) on the int64 values;
//   - only the left one concrete: wraps it with the right node's wrap_int()
//     and calls METHOD on the wrapped node with the right node;
//   - only the right one concrete: wraps it with this node's wrap_int() and
//     calls METHOD on this node with the wrapped value;
//   - both symbolic: calls METHOD on this node with sci's node.
#define DEFINE_BINARY(API, OP, METHOD, RET)                          \
  RET SymInt::API(const SymInt& sci) const {                         \
    const auto lhs_int = maybe_as_int();                             \
    const auto rhs_int = sci.maybe_as_int();                         \
    if (lhs_int && rhs_int) {                                        \
      return RET(OP(*lhs_int, *rhs_int));                            \
    }                                                                \
    if (lhs_int) {                                                   \
      auto rhs_node = sci.toSymNode();                               \
      return RET(rhs_node->wrap_int(*lhs_int)->METHOD(rhs_node));    \
    }                                                                \
    if (rhs_int) {                                                   \
      auto lhs_node = toSymNodeImplUnowned();                        \
      return RET(lhs_node->METHOD(lhs_node->wrap_int(*rhs_int)));    \
    }                                                                \
    return RET(toSymNodeImplUnowned()->METHOD(sci.toSymNode()));     \
  }
63 
64 // clang-format off
65 DEFINE_BINARY(operator+, std::plus<>(), add, SymInt)
66 DEFINE_BINARY(operator-, std::minus<>(), sub, SymInt)
DEFINE_BINARY(operator *,std::multiplies<> (),mul,SymInt)67 DEFINE_BINARY(operator*, std::multiplies<>(), mul, SymInt)
68 DEFINE_BINARY(operator/, std::divides<>(), floordiv, SymInt)
69 DEFINE_BINARY(operator%, std::modulus<>(), mod, SymInt)
70 DEFINE_BINARY(sym_eq, std::equal_to<>(), eq, SymBool)
71 DEFINE_BINARY(sym_ne, std::not_equal_to<>(), ne, SymBool)
72 DEFINE_BINARY(sym_lt, std::less<>(), lt, SymBool)
73 DEFINE_BINARY(sym_le, std::less_equal<>(), le, SymBool)
74 DEFINE_BINARY(sym_gt, std::greater<>(), gt, SymBool)
75 DEFINE_BINARY(sym_ge, std::greater_equal<>(), ge, SymBool)
76 DEFINE_BINARY(min, std::min, sym_min, SymInt)
77 DEFINE_BINARY(max, std::max, sym_max, SymInt)
78 // clang-format on
79 
80 SymInt::operator SymFloat() const {
81   if (auto ma = maybe_as_int()) {
82     return SymFloat(double(*ma));
83   } else {
84     return SymFloat(toSymNodeImplUnowned()->sym_float());
85   }
86 }
87 
is_same(const SymInt & other) const88 bool SymInt::is_same(const SymInt& other) const {
89   if (is_heap_allocated() != other.is_heap_allocated()) {
90     return false;
91   }
92   // Both not heap allocated
93   if (!is_heap_allocated() && this->operator!=(other)) {
94     return false;
95   }
96   // Both heap allocated
97   if (is_heap_allocated() &&
98       toSymNodeImplUnowned() != other.toSymNodeImplUnowned()) {
99     return false;
100   }
101   return true;
102 }
103 
// Returns this value as a SymNode. A concrete value is wrapped through base's
// wrap_int(); a symbolic value returns a new reference to its own node.
SymNode SymInt::wrap_node(const SymNode& base) const {
  const auto concrete = maybe_as_int();
  if (!concrete) {
    return toSymNode();
  }
  return base->wrap_int(*concrete);
}
111 
clone() const112 SymInt SymInt::clone() const {
113   if (auto ma = maybe_as_int()) {
114     return SymInt(*ma);
115   } else {
116     return SymInt(toSymNodeImplUnowned()->clone());
117   }
118 }
119 
guard_int(const char * file,int64_t line) const120 int64_t SymInt::guard_int(const char* file, int64_t line) const {
121   if (auto ma = maybe_as_int()) {
122     return *ma;
123   } else {
124     return toSymNodeImplUnowned()->guard_int(file, line);
125   }
126 }
127 
expect_size(const char * file,int64_t line) const128 bool SymInt::expect_size(const char* file, int64_t line) const {
129   if (auto ma = maybe_as_int()) {
130     return *ma >= 0;
131   } else {
132     return toSymNodeImplUnowned()->expect_size(file, line);
133   }
134 }
135 
operator -(const SymInt & s)136 SymInt operator-(const SymInt& s) {
137   if (auto ma = s.maybe_as_int()) {
138     const auto val = *ma;
139     // Note: Result of `-std::numeric_limits<decltype(val)>::min()` is undefined
140     // But on many platforms it equals to self + setting Carry/Overflow flags
141     // Which in opimized code affects results of `check_range` condition
142     // Workaround by using ternary that avoids alterning the flags
143 #if C10_HAS_BUILTIN_OVERFLOW()
144     std::decay_t<decltype(val)> out = 0;
145     if (C10_UNLIKELY(__builtin_sub_overflow(out, val, &out))) {
146       return SymInt(val);
147     }
148     return SymInt(out);
149 #else
150     constexpr auto val_min = std::numeric_limits<decltype(val)>::min();
151     return SymInt(val != val_min ? -val : val_min);
152 #endif
153   } else {
154     return SymInt(s.toSymNodeImplUnowned()->neg());
155   }
156 }
157 
operator *=(const SymInt & sci)158 void SymInt::operator*=(const SymInt& sci) {
159   *this = *this * sci;
160 }
161 
operator /=(const SymInt & sci)162 void SymInt::operator/=(const SymInt& sci) {
163   *this = *this / sci;
164 }
165 
operator +=(const SymInt & sci)166 void SymInt::operator+=(const SymInt& sci) {
167   *this = *this + sci;
168 }
169 
operator <<(std::ostream & os,const SymInt & s)170 std::ostream& operator<<(std::ostream& os, const SymInt& s) {
171   if (s.is_heap_allocated()) {
172     os << s.toSymNodeImplUnowned()->str();
173   } else {
174     os << s.as_int_unchecked();
175   }
176   return os;
177 }
178 
179 // This template lets us not do a refcount bump when we do an
180 // identity conversion
181 template <typename T>
182 struct Convert {};
183 
184 template <>
185 struct Convert<SymInt> {
operator ()c10::Convert186   const SymInt& operator()(const SymInt& a) {
187     return a;
188   }
189 };
190 
191 template <>
192 struct Convert<SymFloat> {
operator ()c10::Convert193   SymFloat operator()(const SymInt& a) {
194     return a;
195   }
196 };
197 
// Defines the mixed-operand `%` overloads between a SymInt and a plain scalar
// of type scalar_t, in both operand orders. The scalar is wrapped as RetTy(x)
// and the SymInt goes through Convert<RetTy>, so RetTy's own operator% does
// the work. It is kept separate from DEFINE_SYMINT_OP because `%` is only
// instantiated for integer scalars.
// No semicolons follow the generated function bodies: they would be empty
// declarations that -Wextra-semi flags.
#define DEFINE_SYMINT_OP_INTONLY(scalar_t, RetTy) \
  RetTy operator%(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) % RetTy(b);        \
  }                                               \
  RetTy operator%(scalar_t a, const SymInt& b) {  \
    return RetTy(a) % Convert<RetTy>()(b);        \
  }
205 
// Defines the mixed-operand overloads between a SymInt and a plain scalar of
// type scalar_t, in both operand orders:
//   - arithmetic `+ - * /`, returning RetTy;
//   - comparisons `== != < <= > >=`, returning bool.
// The scalar is wrapped as RetTy(x) and the SymInt goes through
// Convert<RetTy>, so RetTy's own operators do the work.
// No semicolons follow the generated function bodies: they would be empty
// declarations that -Wextra-semi flags.
// NOTE(review): with scalar_t = uint64_t (or size_t on Apple) and
// RetTy = SymInt, RetTy(b) presumably goes through SymInt's int64_t
// constructor. Values above INT64_MAX would then not round-trip; confirm that
// callers never pass them.
#define DEFINE_SYMINT_OP(scalar_t, RetTy)        \
  RetTy operator+(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) + RetTy(b);       \
  }                                              \
  RetTy operator-(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) - RetTy(b);       \
  }                                              \
  RetTy operator*(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) * RetTy(b);       \
  }                                              \
  RetTy operator/(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) / RetTy(b);       \
  }                                              \
  RetTy operator+(scalar_t a, const SymInt& b) { \
    return RetTy(a) + Convert<RetTy>()(b);       \
  }                                              \
  RetTy operator-(scalar_t a, const SymInt& b) { \
    return RetTy(a) - Convert<RetTy>()(b);       \
  }                                              \
  RetTy operator*(scalar_t a, const SymInt& b) { \
    return RetTy(a) * Convert<RetTy>()(b);       \
  }                                              \
  RetTy operator/(scalar_t a, const SymInt& b) { \
    return RetTy(a) / Convert<RetTy>()(b);       \
  }                                              \
  bool operator==(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) == RetTy(b);      \
  }                                              \
  bool operator!=(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) != RetTy(b);      \
  }                                              \
  bool operator<(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) < RetTy(b);       \
  }                                              \
  bool operator<=(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) <= RetTy(b);      \
  }                                              \
  bool operator>(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) > RetTy(b);       \
  }                                              \
  bool operator>=(const SymInt& a, scalar_t b) { \
    return Convert<RetTy>()(a) >= RetTy(b);      \
  }                                              \
  bool operator==(scalar_t a, const SymInt& b) { \
    return RetTy(a) == Convert<RetTy>()(b);      \
  }                                              \
  bool operator!=(scalar_t a, const SymInt& b) { \
    return RetTy(a) != Convert<RetTy>()(b);      \
  }                                              \
  bool operator<(scalar_t a, const SymInt& b) {  \
    return RetTy(a) < Convert<RetTy>()(b);       \
  }                                              \
  bool operator<=(scalar_t a, const SymInt& b) { \
    return RetTy(a) <= Convert<RetTy>()(b);      \
  }                                              \
  bool operator>(scalar_t a, const SymInt& b) {  \
    return RetTy(a) > Convert<RetTy>()(b);       \
  }                                              \
  bool operator>=(scalar_t a, const SymInt& b) { \
    return RetTy(a) >= Convert<RetTy>()(b);      \
  }
267 
268 DEFINE_SYMINT_OP_INTONLY(int64_t, SymInt)
269 DEFINE_SYMINT_OP_INTONLY(int32_t, SymInt)
270 DEFINE_SYMINT_OP_INTONLY(uint64_t, SymInt)
271 DEFINE_SYMINT_OP_INTONLY(uint32_t, SymInt)
272 DEFINE_SYMINT_OP(int64_t, SymInt)
273 DEFINE_SYMINT_OP(int32_t, SymInt) // make sure constants work
274 DEFINE_SYMINT_OP(uint64_t, SymInt)
275 DEFINE_SYMINT_OP(uint32_t, SymInt)
276 DEFINE_SYMINT_OP(double, SymFloat)
277 DEFINE_SYMINT_OP(float, SymFloat) // just for completeness
278 
279 #if defined(__APPLE__)
280 DEFINE_SYMINT_OP_INTONLY(size_t, SymInt) // needed for osx
281 DEFINE_SYMINT_OP(size_t, SymInt) // needed for osx
282 #endif
283 
284 } // namespace c10
285