#pragma once

#include <c10/core/SymBool.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

#include <cstdint>
#include <iterator>
#include <numeric>
#include <optional>
#include <ostream>
#include <type_traits>
namespace c10 {

class SymFloat;

// SymInt represents either a regular int64_t, or a symbolic integer
// (represented in a type erased way as SymNode). The intention is for SymInt
// to represent symbolic sizes that arise when doing shape computation in
// operator kernels. This allows for tracing through programs without baking
// concrete sizes into kernel calls.
//
// SymInt has an API equivalent to int64_t. In particular, it is a value type.
// Internally, SymInt is represented in a clever packed way, so that it only
// occupies one word of space; but morally, it is a union between an int64_t
// and an intrusive pointer to SymNodeImpl.
//
// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where
// is_int() returns true.
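//
// Illustrative usage sketch (not itself part of this header); `sizes`,
// `allocate_bytes`, and `itemsize` are hypothetical names:
//
//   c10::SymInt numel = 1;
//   for (const auto& s : sizes) {
//     numel = numel * s; // stays symbolic if any size is symbolic
//   }
//   if (auto n = numel.maybe_as_int()) {
//     allocate_bytes(*n * itemsize); // fast path: every size was concrete
//   }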

class C10_API SymInt {
 public:
  enum Unchecked {
    UNCHECKED,
  };

  /*implicit*/ SymInt(int64_t d) : data_(d) {
    if (is_heap_allocated()) {
      // Large negative number, heap allocate it
      promote_to_negative();
    }
  }
  SymInt() : data_(0) {}
  SymInt(SymNode n);

  // Unchecked c-tor accepting raw `data_`.
  // One appropriate use for this is when you are constructing a symint
  // in a situation where you know it is non-negative (or, if it is negative,
  // the negative value is -1; i.e., not user controlled).
  SymInt(Unchecked, int64_t d) : data_(d) {}

  // TODO: these implementations are not optimal because they allocate a
  // temporary and then use the move constructor/assignment
  SymInt(const SymInt& s) : data_(0) {
    if (s.is_heap_allocated()) {
      *this = SymInt(s.toSymNode());
    } else {
      data_ = s.data_;
    }
  }
  SymInt(SymInt&& s) noexcept : data_(s.data_) {
    s.data_ = 0;
  }

  SymInt& operator=(const SymInt& s) {
    if (this != &s) {
      if (s.is_heap_allocated()) {
        *this = SymInt(s.toSymNode());
      } else {
        data_ = s.data_;
      }
    }
    return *this;
  }
  SymInt& operator=(SymInt&& s) noexcept {
    if (this != &s) {
      release_(); // release the current SymNode if any
      data_ = s.data_;
      if (s.is_heap_allocated())
        s.data_ = 0;
    }
    return *this;
  }

  SymNodeImpl* toSymNodeImplUnowned() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_heap_allocated());
    uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
    uint64_t sign_bit_mask = 1ULL << (62 - 1);
    // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
    uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
    return static_cast<SymNodeImpl*>(
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits)));
  }
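
  // A note on the sign-extension idiom used in toSymNodeImplUnowned() above
  // (illustrative, following the linked StackOverflow post): for a value that
  // is really an N-bit signed quantity stored in a wider unsigned word,
  // `(x ^ sign_bit) - sign_bit` propagates the N-th bit into the upper bits.
  // E.g. for a 4-bit quantity with sign_bit == 0b1000:
  //
  //   (0b1010 ^ 0b1000) - 0b1000 == 0b0010 - 0b1000 == -6
  //
  // which is exactly 0b1010 reinterpreted as a signed 4-bit value.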

  void release_() {
    if (is_heap_allocated()) {
      SymNode::reclaim(toSymNodeImplUnowned()); // steal
    }
  }

  SymNodeImpl* release() && {
#ifndef C10_MOBILE
    TORCH_INTERNAL_ASSERT(is_heap_allocated());
    auto* r = toSymNodeImplUnowned();
    data_ = 0; // transfer ownership
    return r;
#else
    TORCH_INTERNAL_ASSERT(false);
#endif
  }

  // Only valid if is_heap_allocated()
  SymNode toSymNode() const;

  // Guaranteed to return a SymNode, wrapping using base if necessary
  SymNode wrap_node(const SymNode& base) const;

  ~SymInt() {
    release_();
  }

  // Require the int to be non-symbolic, and if it is symbolic raise an
  // error. This is safe to use for C++ code that doesn't work with symbolic
  // shapes and that you don't have time to fix immediately: if the path is
  // ever hit with a symbolic int, you will get an appropriate error instead
  // of silently wrong behavior.
  int64_t expect_int() const {
    if (auto r = maybe_as_int()) {
      return *r;
    }
    TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
        false, "when unpacking SymInt, expected int but got ", *this);
  }
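
  // Illustrative expect_int() sketch only; `legacy_resize_impl` and `sizes`
  // are hypothetical names, not part of this header:
  //
  //   void legacy_resize_impl(const std::vector<c10::SymInt>& sizes) {
  //     // This path has not been made symbolic-shape aware yet; fail loudly
  //     // rather than miscompute if a symbolic size ever reaches it.
  //     int64_t dim0 = sizes.at(0).expect_int();
  //     // ... use dim0 as a plain integer ...
  //   }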

  // Test if we have a hint for this int (e.g., guard_int would work).
  // Most of the time this is true; it is only false when you have
  // an unbacked SymInt.
  bool has_hint() const;

  // Insert a guard for the int to be its concrete value, and then return
  // that value. This operation always works, even if the int is symbolic,
  // so long as we know what the underlying value is (e.g., this won't work
  // if you call it on the size of nonzero output). Don't blindly put this
  // everywhere; you can cause overspecialization of PyTorch programs with
  // this method.
  //
  // It should be called as guard_int(__FILE__, __LINE__). The file and line
  // number can be used to diagnose overspecialization.
  int64_t guard_int(const char* file, int64_t line) const;
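
  // Illustrative guard_int() sketch only; `sizes` and `pick_kernel_for` are
  // hypothetical:
  //
  //   // Deliberately specialize on the concrete batch size seen while
  //   // tracing; guard_int records a guard so the trace is only reused
  //   // when the batch size matches again.
  //   int64_t batch = sizes[0].guard_int(__FILE__, __LINE__);
  //   auto kernel = pick_kernel_for(batch);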

  // Insert a guard that this SymInt must be size-like, returning true if
  // the integer actually is >= 0. Unlike manually performing a >= 0 test,
  // if the SymInt in question is an unbacked SymInt (or, potentially in the
  // future, if it contains unbacked SymInts), we will also treat the
  // unbacked SymInt as statically testing >= 2 (which will prevent us from
  // choking on, e.g., contiguity checks).
  bool expect_size(const char* file, int64_t line) const;
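
  // Illustrative expect_size() sketch only; `u0` stands for some unbacked
  // size (e.g. the length of a nonzero() result):
  //
  //   // Prefer this over `u0 >= 0`, which would need a hint we don't have.
  //   TORCH_CHECK(u0.expect_size(__FILE__, __LINE__), "expected a size");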

  // Distinguish actual symbolic values from constants stored on the heap.
  bool is_symbolic() const {
    return is_heap_allocated() &&
        !toSymNodeImplUnowned()->constant_int().has_value();
  }

  // N.B. It's important to keep this definition in the header,
  // as we expect the `if (is_heap_allocated())` checks to be constant-folded
  // on mobile builds, where `is_heap_allocated()` is always false, so that
  // the dead code paths are optimized away.
  C10_ALWAYS_INLINE bool is_heap_allocated() const {
#ifdef C10_MOBILE
    return false;
#else
    return !check_range(data_);
#endif
  }

  SymInt operator+(const SymInt& sci) const;
  SymInt operator-(const SymInt& sci) const;
  SymInt operator*(const SymInt& sci) const;
  SymInt operator/(const SymInt& sci) const;
  SymInt operator%(const SymInt& sci) const;
  void operator*=(const SymInt& sci);
  void operator+=(const SymInt& sci);
  void operator/=(const SymInt& sci);

  SymInt clone() const;

  SymBool sym_eq(const SymInt&) const;
  SymBool sym_ne(const SymInt&) const;
  SymBool sym_lt(const SymInt&) const;
  SymBool sym_le(const SymInt&) const;
  SymBool sym_gt(const SymInt&) const;
  SymBool sym_ge(const SymInt&) const;

  bool operator==(const SymInt& o) const {
    return sym_eq(o).guard_bool(__FILE__, __LINE__);
  }
  bool operator!=(const SymInt& o) const {
    return sym_ne(o).guard_bool(__FILE__, __LINE__);
  }
  bool operator<(const SymInt& o) const {
    return sym_lt(o).guard_bool(__FILE__, __LINE__);
  }
  bool operator<=(const SymInt& o) const {
    return sym_le(o).guard_bool(__FILE__, __LINE__);
  }
  bool operator>(const SymInt& o) const {
    return sym_gt(o).guard_bool(__FILE__, __LINE__);
  }
  bool operator>=(const SymInt& o) const {
    return sym_ge(o).guard_bool(__FILE__, __LINE__);
  }

  SymInt min(const SymInt& sci) const;
  SymInt max(const SymInt& sci) const;

  // If both are symbolic, this checks if they share the same node.
  // If both are not symbolic, this just checks normal equality.
  bool is_same(const SymInt& other) const;

  operator SymFloat() const;

  // Don't use this. Prefer maybe_as_int instead.
  int64_t as_int_unchecked() const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated());
    return data_;
  }

  std::optional<int64_t> maybe_as_int() const {
    if (!is_heap_allocated()) {
      return std::make_optional(data_);
    }
    auto* node = toSymNodeImplUnowned();
    if (auto c = node->constant_int()) {
      return c;
    }
    return node->maybe_as_int();
  }
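
  // Illustrative maybe_as_int() sketch only: the comparison operators above
  // install guards, while maybe_as_int lets you branch without guarding.
  // `dim_size`, `fast_path`, and `generic_path` are hypothetical:
  //
  //   if (auto c = dim_size.maybe_as_int()) {
  //     fast_path(*c); // statically known size, no guard recorded
  //   } else {
  //     generic_path(dim_size); // keep it symbolic
  //   }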

  // Return whether the integer is directly coercible to a SymInt
  // without requiring heap allocation. You don't need to use this
  // to check if you can pass an integer to SymInt; this is guaranteed
  // to work (it just might heap allocate!)
  static bool check_range(int64_t i) {
    return i > MAX_UNREPRESENTABLE_INT;
  }

  // Return the min representable integer as a SymInt without
  // heap allocation. For quantities that count bytes (or larger),
  // this is still far more negative than anything you need, so you may
  // consider using this as a more efficient version of MIN_INT.
  static constexpr int64_t min_representable_int() {
    return MAX_UNREPRESENTABLE_INT + 1;
  }
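
  // Illustrative check_range() sketch only; `raw` is a hypothetical int64_t:
  //
  //   if (c10::SymInt::check_range(raw)) {
  //     // raw fits inline: SymInt(raw) will not heap allocate
  //   } else {
  //     // raw is a very large negative number; SymInt(raw) still works,
  //     // but promote_to_negative() will heap-allocate it
  //   }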

 private:
  void promote_to_negative();

  // Constraints on the internal representation:
  //
  // - Should represent positive and small negative ints
  // - No conversion necessary for operations on ints
  // - Must represent valid 64-bit pointers
  // - Is symbolic test should be FAST (two arithmetic instructions is too
  //   much).
  //   This code being a hotpath is based on Strobelight profiles of
  //   is_heap_allocated(). FB only: https://fburl.com/strobelight/5l50ncxd
  //   (you will need to change the time window).
  //
  // So, the scheme is to reserve large negative numbers (assuming
  // two's complement):
  //
  // - 0b0.... means we are a positive int
  // - 0b11... means we are a small negative int
  // - 0b10... means we are a pointer. This means that
  //   [-2^63, -2^62-1] are not representable as ints.
  //   We don't actually need all of this space, as on x86_64
  //   the top 16 bits aren't used for anything.
  static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62 | 1ULL << 61;
  static constexpr uint64_t IS_SYM = 1ULL << 63 | 1ULL << 61;
  // We must manually translate the bit pattern test into a greater
  // than test because the compiler doesn't figure it out:
  // https://godbolt.org/z/356aferaW
  static constexpr int64_t MAX_UNREPRESENTABLE_INT =
      -1LL & static_cast<int64_t>(~(1ULL << 62));
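
  // Worked example of the encoding (derived from the constants and the
  // comment above): by the high-bit prefixes,
  //
  //   0b0...  : non-negative ints, stored as-is
  //   0b11... : small negative ints in [-2^62, -1], stored as-is
  //   0b10... : tagged pointer to a heap-allocated SymNodeImpl
  //
  // so MAX_UNREPRESENTABLE_INT == -(1LL << 62) - 1, min_representable_int()
  // == -(1LL << 62), and check_range(i) is the single signed comparison
  // `i > MAX_UNREPRESENTABLE_INT`, satisfying the "two arithmetic
  // instructions is too much" constraint above.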
  int64_t data_;
};

/// Product of a list of SymInt; accumulates into a c10::SymInt expression
template <
    typename C,
    typename std::enable_if_t<
        std::is_same_v<typename C::value_type, c10::SymInt>,
        int> = 0>
inline c10::SymInt multiply_integers(const C& container) {
  return std::accumulate(
      container.begin(),
      container.end(),
      c10::SymInt(1),
      [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
}

template <
    typename Iter,
    typename = std::enable_if_t<std::is_same_v<
        typename std::iterator_traits<Iter>::value_type,
        c10::SymInt>>>
inline c10::SymInt multiply_integers(Iter begin, Iter end) {
  return std::accumulate(
      begin,
      end,
      c10::SymInt(1),
      [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
}
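
// Illustrative multiply_integers() sketch only; `batch`, `height`, and
// `width` are hypothetical c10::SymInt values:
//
//   std::vector<c10::SymInt> sizes = {batch, c10::SymInt(3), height, width};
//   c10::SymInt numel = c10::multiply_integers(sizes);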

#define DECLARE_SYMINT_OP_INTONLY(scalar_t, RetTy)      \
  C10_API RetTy operator%(const SymInt& a, scalar_t b); \
  C10_API RetTy operator%(scalar_t a, const SymInt& b);

#define DECLARE_SYMINT_OP(scalar_t, RetTy)              \
  C10_API RetTy operator+(const SymInt& a, scalar_t b); \
  C10_API RetTy operator-(const SymInt& a, scalar_t b); \
  C10_API RetTy operator*(const SymInt& a, scalar_t b); \
  C10_API RetTy operator/(const SymInt& a, scalar_t b); \
  C10_API RetTy operator+(scalar_t a, const SymInt& b); \
  C10_API RetTy operator-(scalar_t a, const SymInt& b); \
  C10_API RetTy operator*(scalar_t a, const SymInt& b); \
  C10_API RetTy operator/(scalar_t a, const SymInt& b); \
  C10_API bool operator==(const SymInt& a, scalar_t b); \
  C10_API bool operator!=(const SymInt& a, scalar_t b); \
  C10_API bool operator<(const SymInt& a, scalar_t b);  \
  C10_API bool operator<=(const SymInt& a, scalar_t b); \
  C10_API bool operator>(const SymInt& a, scalar_t b);  \
  C10_API bool operator>=(const SymInt& a, scalar_t b); \
  C10_API bool operator==(scalar_t a, const SymInt& b); \
  C10_API bool operator!=(scalar_t a, const SymInt& b); \
  C10_API bool operator<(scalar_t a, const SymInt& b);  \
  C10_API bool operator<=(scalar_t a, const SymInt& b); \
  C10_API bool operator>(scalar_t a, const SymInt& b);  \
  C10_API bool operator>=(scalar_t a, const SymInt& b);

DECLARE_SYMINT_OP_INTONLY(int64_t, SymInt)
DECLARE_SYMINT_OP_INTONLY(int32_t, SymInt)
DECLARE_SYMINT_OP_INTONLY(uint64_t, SymInt)
DECLARE_SYMINT_OP_INTONLY(uint32_t, SymInt)
DECLARE_SYMINT_OP(int64_t, SymInt)
DECLARE_SYMINT_OP(int32_t, SymInt) // make sure constants work
DECLARE_SYMINT_OP(uint64_t, SymInt)
DECLARE_SYMINT_OP(uint32_t, SymInt)
DECLARE_SYMINT_OP(double, SymFloat)
DECLARE_SYMINT_OP(float, SymFloat) // just for completeness
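
// Illustrative sketch only: the operator declarations above let SymInt mix
// with plain integer and floating-point values. `dim` is a hypothetical
// c10::SymInt.
//
//   c10::SymInt padded = dim + 2;   // SymInt + int -> SymInt
//   c10::SymFloat half = dim / 2.0; // SymInt / double -> SymFloat
//   bool nonempty = dim != 0;       // scalar comparisons are declared too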

// On OSX size_t is different from uint64_t so we have to
// define it separately
#if defined(__APPLE__)
DECLARE_SYMINT_OP_INTONLY(size_t, SymInt)
DECLARE_SYMINT_OP(size_t, SymInt)
#endif

#undef DECLARE_SYMINT_OP
#undef DECLARE_SYMINT_OP_INTONLY

C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s);
C10_API SymInt operator-(const SymInt& s);

inline bool sym_eq(int64_t a, int64_t b) {
  return a == b;
}

inline SymBool sym_eq(const SymInt& a, const SymInt& b) {
  return a.sym_eq(b);
}

inline bool sym_ne(int64_t a, int64_t b) {
  return a != b;
}

inline SymBool sym_ne(const SymInt& a, const SymInt& b) {
  return a.sym_ne(b);
}

inline bool sym_lt(int64_t a, int64_t b) {
  return a < b;
}

inline SymBool sym_lt(const SymInt& a, const SymInt& b) {
  return a.sym_lt(b);
}

inline bool sym_le(int64_t a, int64_t b) {
  return a <= b;
}

inline SymBool sym_le(const SymInt& a, const SymInt& b) {
  return a.sym_le(b);
}

inline bool sym_gt(int64_t a, int64_t b) {
  return a > b;
}

inline SymBool sym_gt(const SymInt& a, const SymInt& b) {
  return a.sym_gt(b);
}

inline bool sym_ge(int64_t a, int64_t b) {
  return a >= b;
}

inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
  return a.sym_ge(b);
}
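
// Illustrative sketch only: the paired sym_* overloads above let templated
// code compare sizes uniformly, returning bool for plain int64_t and SymBool
// for SymInt. `T` here is either int64_t or c10::SymInt.
//
//   template <typename T>
//   auto same_dim(const T& a, const T& b) {
//     return sym_eq(a, b); // bool or SymBool depending on T
//   }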

inline bool definitely_true(
    const c10::SymBool& b,
    const char* file,
    int64_t line) {
  return b.has_hint() && b.guard_bool(file, line);
}
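
// Illustrative definitely_true() sketch only; `can_use_fast_path` is a
// hypothetical SymBool condition:
//
//   // Take the fast path only when the condition is hinted and guards true;
//   // unbacked/unknown conditions fall through to the general path.
//   if (definitely_true(can_use_fast_path, __FILE__, __LINE__)) {
//     // ... fast path ...
//   }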

} // namespace c10