1 #include <c10/core/ConstantSymNodeImpl.h>
2 #include <c10/core/SymFloat.h>
3 #include <c10/core/SymInt.h>
4 #include <c10/core/SymNodeImpl.h>
5 #include <c10/util/intrusive_ptr.h>
6 #include <c10/util/safe_numerics.h>
7 #include <functional>
8
9 namespace c10 {
10
// Precondition: data_ holds a large negative integer that must be treated
// as a constant value. It is NOT a valid pointer; in other words, this
// SymInt is temporarily violating its invariants.
// Postcondition: the SymInt invariants hold again -- data_ now encodes a
// node (ConstantSymNodeImpl) that wraps the original integer value.
void SymInt::promote_to_negative() {
  // Box the raw value in a constant node so it can be carried in the
  // node-backed representation produced by SymInt(SymNode).
  auto s =
      SymInt(SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(data_)));
  // Similar to move operator=, but do NOT release data_: the old data_ is a
  // plain integer, not a pointer, so there is nothing to release.
  data_ = s.data_;
  // Clear the temporary's word so it no longer refers to the node that
  // *this has just taken over (mirrors what a move would leave behind).
  s.data_ = 0;
}
22
toSymNode() const23 SymNode SymInt::toSymNode() const {
24 TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
25 is_heap_allocated(), "SymInt::toSymNode is_heap_allocated");
26 return SymNode::reclaim_copy(toSymNodeImplUnowned());
27 }
28
SymInt(SymNode sin_sp)29 SymInt::SymInt(SymNode sin_sp) {
30 TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
31 sin_sp->is_int(), "SymInt::SymInt sin_sp->is_int()");
32 auto ptr = static_cast<uint64_t>(
33 reinterpret_cast<uintptr_t>(static_cast<void*>(sin_sp.release())));
34 auto rep = (ptr & ~MASK) | IS_SYM;
35 data_ = static_cast<int64_t>(rep);
36 }
37
has_hint() const38 bool SymInt::has_hint() const {
39 if (!is_heap_allocated()) {
40 return true;
41 }
42 return toSymNodeImplUnowned()->has_hint();
43 }
44
// DEFINE_BINARY(API, OP, METHOD, RET) defines the member function
// `RET SymInt::API(const SymInt&) const`.
//  - both operands concrete: apply OP to the two integers directly;
//  - exactly one operand concrete: wrap that integer with wrap_int() on the
//    other operand's node, then dispatch to the node's METHOD;
//  - both symbolic: call METHOD on this node with the other node.
#define DEFINE_BINARY(API, OP, METHOD, RET)                           \
  RET SymInt::API(const SymInt& sci) const {                          \
    auto lhs_val = maybe_as_int();                                    \
    auto rhs_val = sci.maybe_as_int();                                \
    if (lhs_val && rhs_val) {                                         \
      return RET(OP(*lhs_val, *rhs_val));                             \
    }                                                                 \
    if (lhs_val) {                                                    \
      auto rhs_node = sci.toSymNode();                                \
      return RET(rhs_node->wrap_int(*lhs_val)->METHOD(rhs_node));     \
    }                                                                 \
    auto lhs_node = toSymNodeImplUnowned();                           \
    if (rhs_val) {                                                    \
      return RET(lhs_node->METHOD(lhs_node->wrap_int(*rhs_val)));     \
    }                                                                 \
    return RET(lhs_node->METHOD(sci.toSymNode()));                    \
  }
63
64 // clang-format off
65 DEFINE_BINARY(operator+, std::plus<>(), add, SymInt)
66 DEFINE_BINARY(operator-, std::minus<>(), sub, SymInt)
DEFINE_BINARY(operator *,std::multiplies<> (),mul,SymInt)67 DEFINE_BINARY(operator*, std::multiplies<>(), mul, SymInt)
68 DEFINE_BINARY(operator/, std::divides<>(), floordiv, SymInt)
69 DEFINE_BINARY(operator%, std::modulus<>(), mod, SymInt)
70 DEFINE_BINARY(sym_eq, std::equal_to<>(), eq, SymBool)
71 DEFINE_BINARY(sym_ne, std::not_equal_to<>(), ne, SymBool)
72 DEFINE_BINARY(sym_lt, std::less<>(), lt, SymBool)
73 DEFINE_BINARY(sym_le, std::less_equal<>(), le, SymBool)
74 DEFINE_BINARY(sym_gt, std::greater<>(), gt, SymBool)
75 DEFINE_BINARY(sym_ge, std::greater_equal<>(), ge, SymBool)
76 DEFINE_BINARY(min, std::min, sym_min, SymInt)
77 DEFINE_BINARY(max, std::max, sym_max, SymInt)
78 // clang-format on
79
80 SymInt::operator SymFloat() const {
81 if (auto ma = maybe_as_int()) {
82 return SymFloat(double(*ma));
83 } else {
84 return SymFloat(toSymNodeImplUnowned()->sym_float());
85 }
86 }
87
is_same(const SymInt & other) const88 bool SymInt::is_same(const SymInt& other) const {
89 if (is_heap_allocated() != other.is_heap_allocated()) {
90 return false;
91 }
92 // Both not heap allocated
93 if (!is_heap_allocated() && this->operator!=(other)) {
94 return false;
95 }
96 // Both heap allocated
97 if (is_heap_allocated() &&
98 toSymNodeImplUnowned() != other.toSymNodeImplUnowned()) {
99 return false;
100 }
101 return true;
102 }
103
// Returns this SymInt as a node. A concrete value is wrapped using `base`'s
// wrap_int(); a node-backed value returns (a new reference to) its own node.
SymNode SymInt::wrap_node(const SymNode& base) const {
  const auto concrete = maybe_as_int();
  if (!concrete) {
    return toSymNode();
  }
  return base->wrap_int(*concrete);
}
111
clone() const112 SymInt SymInt::clone() const {
113 if (auto ma = maybe_as_int()) {
114 return SymInt(*ma);
115 } else {
116 return SymInt(toSymNodeImplUnowned()->clone());
117 }
118 }
119
guard_int(const char * file,int64_t line) const120 int64_t SymInt::guard_int(const char* file, int64_t line) const {
121 if (auto ma = maybe_as_int()) {
122 return *ma;
123 } else {
124 return toSymNodeImplUnowned()->guard_int(file, line);
125 }
126 }
127
expect_size(const char * file,int64_t line) const128 bool SymInt::expect_size(const char* file, int64_t line) const {
129 if (auto ma = maybe_as_int()) {
130 return *ma >= 0;
131 } else {
132 return toSymNodeImplUnowned()->expect_size(file, line);
133 }
134 }
135
// Unary negation. Concrete values are negated here (with the minimum value
// returned unchanged instead of overflowing); node-backed values delegate to
// the node's neg().
SymInt operator-(const SymInt& s) {
  if (auto ma = s.maybe_as_int()) {
    const auto val = *ma;
    // Note: the result of `-std::numeric_limits<decltype(val)>::min()` is
    // undefined. On many platforms it evaluates to the operand itself while
    // setting the carry/overflow flags, which in optimized code affects the
    // result of the `check_range` condition.
    // Work around this by never performing the raw negation (an overflow
    // builtin, or a ternary fallback) so the flags are not altered.
#if C10_HAS_BUILTIN_OVERFLOW()
    std::decay_t<decltype(val)> out = 0;
    // Computes 0 - val into `out`; true means the subtraction overflowed,
    // in which case the original value is returned unchanged.
    if (C10_UNLIKELY(__builtin_sub_overflow(out, val, &out))) {
      return SymInt(val);
    }
    return SymInt(out);
#else
    // Fallback: the minimum value maps to itself, everything else is negated.
    constexpr auto val_min = std::numeric_limits<decltype(val)>::min();
    return SymInt(val != val_min ? -val : val_min);
#endif
  } else {
    return SymInt(s.toSymNodeImplUnowned()->neg());
  }
}
157
operator *=(const SymInt & sci)158 void SymInt::operator*=(const SymInt& sci) {
159 *this = *this * sci;
160 }
161
operator /=(const SymInt & sci)162 void SymInt::operator/=(const SymInt& sci) {
163 *this = *this / sci;
164 }
165
operator +=(const SymInt & sci)166 void SymInt::operator+=(const SymInt& sci) {
167 *this = *this + sci;
168 }
169
operator <<(std::ostream & os,const SymInt & s)170 std::ostream& operator<<(std::ostream& os, const SymInt& s) {
171 if (s.is_heap_allocated()) {
172 os << s.toSymNodeImplUnowned()->str();
173 } else {
174 os << s.as_int_unchecked();
175 }
176 return os;
177 }
178
179 // This template lets us not do a refcount bump when we do an
180 // identity conversion
181 template <typename T>
182 struct Convert {};
183
184 template <>
185 struct Convert<SymInt> {
operator ()c10::Convert186 const SymInt& operator()(const SymInt& a) {
187 return a;
188 }
189 };
190
191 template <>
192 struct Convert<SymFloat> {
operator ()c10::Convert193 SymFloat operator()(const SymInt& a) {
194 return a;
195 }
196 };
197
// DEFINE_SYMINT_OP_INTONLY(scalar_t, RetTy) defines `%` between a SymInt and
// a plain scalar_t, in both operand orders. The SymInt is converted with
// Convert<RetTy>, the scalar is wrapped in RetTy, and RetTy's own operator%
// does the work.
// The generated function definitions are deliberately not followed by ';':
// a stray semicolon at namespace scope is an empty declaration that
// -Wextra-semi / -Wpedantic warn about.
#define DEFINE_SYMINT_OP_INTONLY(scalar_t, RetTy) \
  RetTy operator%(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) % RetTy(b);        \
  }                                               \
  RetTy operator%(scalar_t a, const SymInt& b) {  \
    return RetTy(a) % Convert<RetTy>()(b);        \
  }
205
// DEFINE_SYMINT_OP(scalar_t, RetTy) defines +, -, *, / (returning RetTy) and
// ==, !=, <, <=, >, >= (returning bool) between a SymInt and a plain
// scalar_t, in both operand orders. The SymInt side is converted with
// Convert<RetTy> and the scalar side is wrapped in RetTy, so RetTy's own
// operators perform the computation.
// The generated function definitions are deliberately not followed by ';':
// a stray semicolon at namespace scope is an empty declaration that
// -Wextra-semi / -Wpedantic warn about.
#define DEFINE_SYMINT_OP(scalar_t, RetTy)         \
  RetTy operator+(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) + RetTy(b);        \
  }                                               \
  RetTy operator-(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) - RetTy(b);        \
  }                                               \
  RetTy operator*(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) * RetTy(b);        \
  }                                               \
  RetTy operator/(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) / RetTy(b);        \
  }                                               \
  RetTy operator+(scalar_t a, const SymInt& b) {  \
    return RetTy(a) + Convert<RetTy>()(b);        \
  }                                               \
  RetTy operator-(scalar_t a, const SymInt& b) {  \
    return RetTy(a) - Convert<RetTy>()(b);        \
  }                                               \
  RetTy operator*(scalar_t a, const SymInt& b) {  \
    return RetTy(a) * Convert<RetTy>()(b);        \
  }                                               \
  RetTy operator/(scalar_t a, const SymInt& b) {  \
    return RetTy(a) / Convert<RetTy>()(b);        \
  }                                               \
  bool operator==(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) == RetTy(b);       \
  }                                               \
  bool operator!=(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) != RetTy(b);       \
  }                                               \
  bool operator<(const SymInt& a, scalar_t b) {   \
    return Convert<RetTy>()(a) < RetTy(b);        \
  }                                               \
  bool operator<=(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) <= RetTy(b);       \
  }                                               \
  bool operator>(const SymInt& a, scalar_t b) {   \
    return Convert<RetTy>()(a) > RetTy(b);        \
  }                                               \
  bool operator>=(const SymInt& a, scalar_t b) {  \
    return Convert<RetTy>()(a) >= RetTy(b);       \
  }                                               \
  bool operator==(scalar_t a, const SymInt& b) {  \
    return RetTy(a) == Convert<RetTy>()(b);       \
  }                                               \
  bool operator!=(scalar_t a, const SymInt& b) {  \
    return RetTy(a) != Convert<RetTy>()(b);       \
  }                                               \
  bool operator<(scalar_t a, const SymInt& b) {   \
    return RetTy(a) < Convert<RetTy>()(b);        \
  }                                               \
  bool operator<=(scalar_t a, const SymInt& b) {  \
    return RetTy(a) <= Convert<RetTy>()(b);       \
  }                                               \
  bool operator>(scalar_t a, const SymInt& b) {   \
    return RetTy(a) > Convert<RetTy>()(b);        \
  }                                               \
  bool operator>=(scalar_t a, const SymInt& b) {  \
    return RetTy(a) >= Convert<RetTy>()(b);       \
  }
267
268 DEFINE_SYMINT_OP_INTONLY(int64_t, SymInt)
269 DEFINE_SYMINT_OP_INTONLY(int32_t, SymInt)
270 DEFINE_SYMINT_OP_INTONLY(uint64_t, SymInt)
271 DEFINE_SYMINT_OP_INTONLY(uint32_t, SymInt)
272 DEFINE_SYMINT_OP(int64_t, SymInt)
273 DEFINE_SYMINT_OP(int32_t, SymInt) // make sure constants work
274 DEFINE_SYMINT_OP(uint64_t, SymInt)
275 DEFINE_SYMINT_OP(uint32_t, SymInt)
276 DEFINE_SYMINT_OP(double, SymFloat)
277 DEFINE_SYMINT_OP(float, SymFloat) // just for completeness
278
279 #if defined(__APPLE__)
280 DEFINE_SYMINT_OP_INTONLY(size_t, SymInt) // needed for osx
281 DEFINE_SYMINT_OP(size_t, SymInt) // needed for osx
282 #endif
283
284 } // namespace c10
285