xref: /aosp_15_r20/external/pytorch/c10/core/SymInt.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/SymBool.h>
4 #include <c10/core/SymNodeImpl.h>
5 #include <c10/macros/Export.h>
6 #include <c10/macros/Macros.h>
7 #include <c10/util/Exception.h>
8 #include <c10/util/Optional.h>
9 
10 #include <cstdint>
11 #include <iterator>
12 #include <numeric>
13 #include <optional>
14 #include <ostream>
15 #include <type_traits>
16 
17 namespace c10 {
18 
19 class SymFloat;
20 
21 // SymInt represents either a regular int64_t, or a symbolic integer
22 // (represented in a type erased way as SymNode).  The intention is for SymInt
23 // to represent symbolic sizes that arise when doing shape computation in
24 // operator kernels. This allows for tracing through programs without baking in
25 // concrete sizes into kernel calls.
26 //
27 // SymInt has an API equivalent to int64_t.  In particular, it is a value type.
28 // Internally, SymInt is represented in a clever packed way, so that it only
29 // occupies one word of space; but morally, it is a union between an int64_t
30 // and an intrusive pointer to SymNodeImpl.
31 //
32 // Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where
33 // is_int() returns true
34 
35 class C10_API SymInt {
36  public:
37   enum Unchecked {
38     UNCHECKED,
39   };
40 
SymInt(int64_t d)41   /*implicit*/ SymInt(int64_t d) : data_(d) {
42     if (is_heap_allocated()) {
43       // Large negative number, heap allocate it
44       promote_to_negative();
45     }
46   };
SymInt()47   SymInt() : data_(0) {}
48   SymInt(SymNode n);
49 
50   // unchecked c-tor accepting raw `data_`
51   // One appropriate use for this is when you are constructing a symint
52   // in a situation where you know it is non-negative (or, if it is negative,
53   // the negative value is -1; i.e., not user controlled)
SymInt(Unchecked,int64_t d)54   SymInt(Unchecked, int64_t d) : data_(d) {}
55 
56   // TODO: these implementations are not optimal because they allocate a
57   // temporary and then use the move constructor/assignment
SymInt(const SymInt & s)58   SymInt(const SymInt& s) : data_(0) {
59     if (s.is_heap_allocated()) {
60       *this = SymInt(s.toSymNode());
61     } else {
62       data_ = s.data_;
63     }
64   }
SymInt(SymInt && s)65   SymInt(SymInt&& s) noexcept : data_(s.data_) {
66     s.data_ = 0;
67   }
68 
69   SymInt& operator=(const SymInt& s) {
70     if (this != &s) {
71       if (s.is_heap_allocated()) {
72         *this = SymInt(s.toSymNode());
73       } else {
74         data_ = s.data_;
75       }
76     }
77     return *this;
78   }
79   SymInt& operator=(SymInt&& s) noexcept {
80     if (this != &s) {
81       release_(); // release the current SymNode if any
82       data_ = s.data_;
83       if (s.is_heap_allocated())
84         s.data_ = 0;
85     };
86     return *this;
87   }
88 
toSymNodeImplUnowned()89   SymNodeImpl* toSymNodeImplUnowned() const {
90     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_heap_allocated());
91     uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
92     uint64_t sign_bit_mask = 1ULL << (62 - 1);
93     // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
94     uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
95     return static_cast<SymNodeImpl*>(
96         // NOLINTNEXTLINE(performance-no-int-to-ptr)
97         reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits)));
98   }
99 
release_()100   void release_() {
101     if (is_heap_allocated()) {
102       SymNode::reclaim(toSymNodeImplUnowned()); // steal
103     }
104   }
105 
release()106   SymNodeImpl* release() && {
107 #ifndef C10_MOBILE
108     TORCH_INTERNAL_ASSERT(is_heap_allocated());
109     auto* r = toSymNodeImplUnowned();
110     data_ = 0; // transfer ownership
111     return r;
112 #else
113     TORCH_INTERNAL_ASSERT(false);
114 #endif
115   }
116 
117   // Only valid if is_heap_allocated()
118   SymNode toSymNode() const;
119 
120   // Guaranteed to return a SymNode, wrapping using base if necessary
121   SymNode wrap_node(const SymNode& base) const;
122 
~SymInt()123   ~SymInt() {
124     release_();
125   }
126 
127   // Require the int to be non-symbolic, and if it is symbolic raise an
128   // error.  This is safe to use for C++ code that doesn't work for symbolic
129   // shapes, and you don't have time to fix it immediately, as if we
130   // try to trigger the path in C++ you'll appropriately get an error
expect_int()131   int64_t expect_int() const {
132     if (auto r = maybe_as_int()) {
133       return *r;
134     }
135     TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
136         false, "when unpacking SymInt, expected int but got ", *this);
137   }
138 
139   // Test if we have a hint for this int (e.g., guard_int would work).
140   // Most of the time this is true; it is only false when you have
141   // an unbacked SymInt.
142   bool has_hint() const;
143 
144   // Insert a guard for the int to be its concrete value, and then return
145   // that value.  This operation always works, even if the int is symbolic,
146   // so long as we know what the underlying value is (e.g., this won't work
147   // if you call it on the size of nonzero output).  Don't blindly put this
148   // everywhere; you can cause overspecialization of PyTorch programs with
149   // this method.
150   //
151   // It should be called as guard_int(__FILE__, __LINE__).  The file and line
152   // number can be used to diagnose overspecialization.
153   int64_t guard_int(const char* file, int64_t line) const;
154 
155   // Insert a guard that this SymInt must be size-like, returning true if
156   // the integer actually is >= 0.  Unlike manually performing a >= 0 test,
157   // if the SymInt in question is an unbacked SymInt (or, potentially in the
158   // future, if it contains unbacked SymInts), we will also treat the
159   // unbacked SymInt as statically testing >= 2 (which will prevent us from
160   // choking on, e.g., contiguity checks.)
161   bool expect_size(const char* file, int64_t line) const;
162 
163   // Distinguish actual symbolic values from constants stored on the heap
is_symbolic()164   bool is_symbolic() const {
165     return is_heap_allocated() &&
166         !toSymNodeImplUnowned()->constant_int().has_value();
167   }
168 
169   // N.B. It's important to keep this definition in the header
170   // as we expect if checks to be folded for mobile builds
171   // where `is_heap_allocated` is always false and optimize dead code paths
is_heap_allocated()172   C10_ALWAYS_INLINE bool is_heap_allocated() const {
173 #ifdef C10_MOBILE
174     return false;
175 #else
176     return !check_range(data_);
177 #endif
178   }
179 
180   SymInt operator+(const SymInt& sci) const;
181   SymInt operator-(const SymInt& sci) const;
182   SymInt operator*(const SymInt& sci) const;
183   SymInt operator/(const SymInt& sci) const;
184   SymInt operator%(const SymInt& sci) const;
185   void operator*=(const SymInt& sci);
186   void operator+=(const SymInt& sci);
187   void operator/=(const SymInt& sci);
188 
189   SymInt clone() const;
190 
191   SymBool sym_eq(const SymInt&) const;
192   SymBool sym_ne(const SymInt&) const;
193   SymBool sym_lt(const SymInt&) const;
194   SymBool sym_le(const SymInt&) const;
195   SymBool sym_gt(const SymInt&) const;
196   SymBool sym_ge(const SymInt&) const;
197 
198   bool operator==(const SymInt& o) const {
199     return sym_eq(o).guard_bool(__FILE__, __LINE__);
200   }
201   bool operator!=(const SymInt& o) const {
202     return sym_ne(o).guard_bool(__FILE__, __LINE__);
203   }
204   bool operator<(const SymInt& o) const {
205     return sym_lt(o).guard_bool(__FILE__, __LINE__);
206   }
207   bool operator<=(const SymInt& o) const {
208     return sym_le(o).guard_bool(__FILE__, __LINE__);
209   }
210   bool operator>(const SymInt& o) const {
211     return sym_gt(o).guard_bool(__FILE__, __LINE__);
212   }
213   bool operator>=(const SymInt& o) const {
214     return sym_ge(o).guard_bool(__FILE__, __LINE__);
215   }
216 
217   SymInt min(const SymInt& sci) const;
218   SymInt max(const SymInt& sci) const;
219 
220   // If both are symbolic, this checks if
221   // they share the same node.
222   // If both are not symbolic this just checks normal equality.
223   bool is_same(const SymInt& other) const;
224 
225   operator SymFloat() const;
226 
227   // Don't use this.  Prefer maybe_as_int instead
as_int_unchecked()228   int64_t as_int_unchecked() const {
229     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated());
230     return data_;
231   }
232 
maybe_as_int()233   std::optional<int64_t> maybe_as_int() const {
234     if (!is_heap_allocated()) {
235       return std::make_optional(data_);
236     }
237     auto* node = toSymNodeImplUnowned();
238     if (auto c = node->constant_int()) {
239       return c;
240     }
241     return node->maybe_as_int();
242   }
243 
244   // Return whether the integer is directly coercible to a SymInt
245   // without requiring heap allocation.  You don't need to use this
246   // to check if you can pass an integer to SymInt; this is guaranteed
247   // to work (it just might heap allocate!)
check_range(int64_t i)248   static bool check_range(int64_t i) {
249     return i > MAX_UNREPRESENTABLE_INT;
250   }
251 
252   // Return the min representable integer as a SymInt without
253   // heap allocation.  For quantities that count bytes (or larger),
254   // this is still much larger than you need, so you may consider
255   // using this as a more efficient version of MIN_INT
min_representable_int()256   static constexpr int64_t min_representable_int() {
257     return MAX_UNREPRESENTABLE_INT + 1;
258   }
259 
260  private:
261   void promote_to_negative();
262 
263   // Constraints on the internal representation:
264   //
265   // - Should represent positive and small negative ints
266   // - No conversion necessary for operations on ints
267   // - Must represent valid 64-bit pointers
268   // - Is symbolic test should be FAST (two arithmetic instructions is too
269   // much).
270   //   This code being a hotpath is based on Strobelight profiles of
271   //   is_heap_allocated().  FB only: https://fburl.com/strobelight/5l50ncxd
272   //   (you will need to change the time window).
273   //
274   // So, the scheme is to reserve large negative numbers (assuming
275   // two's complement):
276   //
277   // - 0b0.... means we are a positive int
278   // - 0b11... means we are a small negative int
279   // - 0b10... means we are are a pointer. This means that
280   //           [-2^63, -2^62-1] are not representable as ints.
281   //           We don't actually need all of this space as on x86_64
282   //           as the top 16bits aren't used for anything
283   static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62 | 1ULL << 61;
284   static constexpr uint64_t IS_SYM = 1ULL << 63 | 1ULL << 61;
285   // We must manually translate the bit pattern test into a greater
286   // than test because compiler doesn't figure it out:
287   // https://godbolt.org/z/356aferaW
288   static constexpr int64_t MAX_UNREPRESENTABLE_INT =
289       -1LL & static_cast<int64_t>(~(1ULL << 62));
290   int64_t data_;
291 };
292 
293 /// Sum of a list of SymInt; accumulates into the c10::SymInt expression
294 template <
295     typename C,
296     typename std::enable_if_t<
297         std::is_same_v<typename C::value_type, c10::SymInt>,
298         int> = 0>
multiply_integers(const C & container)299 inline c10::SymInt multiply_integers(const C& container) {
300   return std::accumulate(
301       container.begin(),
302       container.end(),
303       c10::SymInt(1),
304       [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
305 }
306 
307 template <
308     typename Iter,
309     typename = std::enable_if_t<std::is_same_v<
310         typename std::iterator_traits<Iter>::value_type,
311         c10::SymInt>>>
multiply_integers(Iter begin,Iter end)312 inline c10::SymInt multiply_integers(Iter begin, Iter end) {
313   return std::accumulate(
314       begin,
315       end,
316       c10::SymInt(1),
317       [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
318 }
319 
320 #define DECLARE_SYMINT_OP_INTONLY(scalar_t, RetTy)      \
321   C10_API RetTy operator%(const SymInt& a, scalar_t b); \
322   C10_API RetTy operator%(scalar_t a, const SymInt& b);
323 
324 #define DECLARE_SYMINT_OP(scalar_t, RetTy)              \
325   C10_API RetTy operator+(const SymInt& a, scalar_t b); \
326   C10_API RetTy operator-(const SymInt& a, scalar_t b); \
327   C10_API RetTy operator*(const SymInt& a, scalar_t b); \
328   C10_API RetTy operator/(const SymInt& a, scalar_t b); \
329   C10_API RetTy operator+(scalar_t a, const SymInt& b); \
330   C10_API RetTy operator-(scalar_t a, const SymInt& b); \
331   C10_API RetTy operator*(scalar_t a, const SymInt& b); \
332   C10_API RetTy operator/(scalar_t a, const SymInt& b); \
333   C10_API bool operator==(const SymInt& a, scalar_t b); \
334   C10_API bool operator!=(const SymInt& a, scalar_t b); \
335   C10_API bool operator<(const SymInt& a, scalar_t b);  \
336   C10_API bool operator<=(const SymInt& a, scalar_t b); \
337   C10_API bool operator>(const SymInt& a, scalar_t b);  \
338   C10_API bool operator>=(const SymInt& a, scalar_t b); \
339   C10_API bool operator==(scalar_t a, const SymInt& b); \
340   C10_API bool operator!=(scalar_t a, const SymInt& b); \
341   C10_API bool operator<(scalar_t a, const SymInt& b);  \
342   C10_API bool operator<=(scalar_t a, const SymInt& b); \
343   C10_API bool operator>(scalar_t a, const SymInt& b);  \
344   C10_API bool operator>=(scalar_t a, const SymInt& b);
345 
346 DECLARE_SYMINT_OP_INTONLY(int64_t, SymInt)
347 DECLARE_SYMINT_OP_INTONLY(int32_t, SymInt)
348 DECLARE_SYMINT_OP_INTONLY(uint64_t, SymInt)
349 DECLARE_SYMINT_OP_INTONLY(uint32_t, SymInt)
350 DECLARE_SYMINT_OP(int64_t, SymInt)
351 DECLARE_SYMINT_OP(int32_t, SymInt) // make sure constants work
352 DECLARE_SYMINT_OP(uint64_t, SymInt)
353 DECLARE_SYMINT_OP(uint32_t, SymInt)
354 DECLARE_SYMINT_OP(double, SymFloat)
355 DECLARE_SYMINT_OP(float, SymFloat) // just for completeness
356 
357 // On OSX size_t is different than uint64_t so we have to
358 // define it separately
359 #if defined(__APPLE__)
360 DECLARE_SYMINT_OP_INTONLY(size_t, SymInt)
361 DECLARE_SYMINT_OP(size_t, SymInt)
362 #endif
363 
364 #undef DECLARE_SYMINT_OP
365 
366 C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s);
367 C10_API SymInt operator-(const SymInt& s);
368 
/// Non-symbolic equality on plain integers (mirrors the SymInt overload).
inline bool sym_eq(int64_t a, int64_t b) {
  return !(a != b);
}
372 
sym_eq(const SymInt & a,const SymInt & b)373 inline SymBool sym_eq(const SymInt& a, const SymInt& b) {
374   return a.sym_eq(b);
375 }
376 
/// Non-symbolic inequality on plain integers (mirrors the SymInt overload).
inline bool sym_ne(int64_t a, int64_t b) {
  return !(a == b);
}
380 
sym_ne(const SymInt & a,const SymInt & b)381 inline SymBool sym_ne(const SymInt& a, const SymInt& b) {
382   return a.sym_ne(b);
383 }
384 
/// Non-symbolic "a < b" on plain integers (mirrors the SymInt overload).
inline bool sym_lt(int64_t a, int64_t b) {
  return b > a;
}
388 
sym_lt(const SymInt & a,const SymInt & b)389 inline SymBool sym_lt(const SymInt& a, const SymInt& b) {
390   return a.sym_lt(b);
391 }
392 
/// Non-symbolic "a <= b" on plain integers (mirrors the SymInt overload).
inline bool sym_le(int64_t a, int64_t b) {
  return !(b < a);
}
396 
sym_le(const SymInt & a,const SymInt & b)397 inline SymBool sym_le(const SymInt& a, const SymInt& b) {
398   return a.sym_le(b);
399 }
400 
/// Non-symbolic "a > b" on plain integers (mirrors the SymInt overload).
inline bool sym_gt(int64_t a, int64_t b) {
  return b < a;
}
404 
sym_gt(const SymInt & a,const SymInt & b)405 inline SymBool sym_gt(const SymInt& a, const SymInt& b) {
406   return a.sym_gt(b);
407 }
408 
/// Non-symbolic "a >= b" on plain integers (mirrors the SymInt overload).
inline bool sym_ge(int64_t a, int64_t b) {
  return !(a < b);
}
412 
sym_ge(const SymInt & a,const SymInt & b)413 inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
414   return a.sym_ge(b);
415 }
416 
definitely_true(const c10::SymBool & b,const char * file,int64_t line)417 inline bool definitely_true(
418     const c10::SymBool& b,
419     const char* file,
420     int64_t line) {
421   return b.has_hint() && b.guard_bool(file, line);
422 }
423 
424 } // namespace c10
425