xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/ivalue.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/DimVector.h>
4 #include <ATen/core/TensorBody.h>
5 #include <ATen/core/blob.h>
6 #include <ATen/core/custom_class.h>
7 #include <ATen/core/ivalue_to.h>
8 #include <ATen/core/jit_type_base.h>
9 #include <ATen/core/type_factory.h>
10 #include <c10/core/SymBool.h>
11 #include <c10/core/SymFloat.h>
12 #include <c10/macros/Export.h>
13 #include <c10/util/MaybeOwned.h>
14 #include <c10/util/intrusive_ptr.h>
15 #include <type_traits>
16 #include <unordered_map>
17 #include <unordered_set>
18 #include <utility>
19 
20 namespace torch {
21 class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
22 namespace jit {
23 using ::torch::CustomClassHolder;
24 struct Function;
25 struct CompilationUnit;
26 struct Module;
27 } // namespace jit
28 } // namespace torch
29 namespace c10 {
30 template <class Key, class Value>
31 class Dict;
32 template <class T>
33 class List;
34 template <class T>
35 class IListRef;
36 struct IValue;
37 struct ClassType;
38 struct Type;
39 class RRefInterface;
40 
41 struct ClassType;
42 using ClassTypePtr = std::shared_ptr<ClassType>;
43 
44 TORCH_API bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs);
45 
46 TORCH_API torch::jit::Function* checkObjectSortSchema(
47     const c10::ClassTypePtr& t,
48     std::stringstream& why_not);
49 
50 // A comparator that checks ordering of two IValues of the same type.
51 typedef std::function<bool(const IValue& a, const IValue& b)> IValueComparator;
52 
53 TORCH_API IValueComparator getLessThanComparator(const IValue& v);
54 TORCH_API IValueComparator getGreaterThanComparator(const IValue& v);
55 
56 namespace ivalue {
57 struct Tuple;
58 struct Future;
59 struct Await;
60 struct ConstantString;
61 struct GenericDict;
62 struct Object;
63 struct PyObjectHolder;
64 struct EnumHolder;
65 // We need a ComplexHolder because currently the payloads in the Union
66 // only take 64 bits. Since ComplexDouble takes up 128 bits, and is too big
67 // to fit in the IValue directly, we indirect complex numbers through an
68 // intrusive pointer to ComplexHolder (which contains a c10::complex).
69 struct ComplexHolder : c10::intrusive_ptr_target {
70  public:
71   template <typename T>
72   ComplexHolder(c10::complex<T> c) {
73     val = convert<decltype(val), c10::complex<T>>(c);
74   }
75   ComplexHolder() = default;
76   c10::complex<double> val;
77 };
78 
79 // Similar to ComplexHolder, for StreamData3
80 struct StreamData3Holder : c10::intrusive_ptr_target {
81  public:
82   StreamData3Holder(struct c10::StreamData3 d) : val(d) {}
83   StreamData3Holder() = delete;
84   struct c10::StreamData3 val;
85 };
86 
87 } // namespace ivalue
88 
89 // This is an owning wrapper for a std::optional<std::vector<T>>
90 // that can be implicitly converted to a (non-owning) std::optional<ArrayRef<T>>.
91 // Its purpose is to be used in generated code to keep the vector alive
92 // either until the end of a statement (as a temporary), or as a saved arg
93 // in autograd.
94 template <typename T>
95 struct OptionalArray {
96   std::optional<std::vector<T>> list;
97 
98   OptionalArray() = default;
99   OptionalArray(std::vector<T> val) : list(std::move(val)) {}
100 
101   // Used when saving an argument for the backwards pass.
102   OptionalArray& operator=(std::optional<ArrayRef<T>> ref) {
103     if (ref) {
104       list = std::vector<T>(ref->begin(), ref->end());
105     } else {
106       list = std::nullopt;
107     }
108     return *this;
109   }
110 
111   // Used when saving an argument for the backwards pass.
112   OptionalArray& operator=(c10::OptionalArrayRef<T> ref) {
113     if (ref) {
114       list = std::vector<T>(ref->begin(), ref->end());
115     } else {
116       list = std::nullopt;
117     }
118     return *this;
119   }
120 
121   operator std::optional<c10::ArrayRef<T>>() {
122     if (!list) {
123       return std::nullopt;
124     }
125     return *list;
126   }
127 
128   operator c10::OptionalArrayRef<T>() {
129     if (!list) {
130       return std::nullopt;
131     }
132     return *list;
133   }
134 };
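// A minimal usage sketch (assuming the usual c10 headers are available; the
// variable names are illustrative only): OptionalArray owns the data while
// exposing a non-owning optional ArrayRef view of it.
//
//   std::vector<int64_t> sizes = {1, 2, 3};
//   c10::OptionalArray<int64_t> saved;
//   saved = std::optional<c10::ArrayRef<int64_t>>(sizes);   // deep-copies
//   std::optional<c10::ArrayRef<int64_t>> view = saved;     // non-owning view
//   bool has_value = view.has_value();                      // true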
135 
136 // Capsule is an internal implementation detail of custom C++ classes. We
137 // define it as an owning wrapper for
138 // c10::intrusive_ptr<torch::CustomClassHolder>. This wrapper serves as an
139 // abstraction over the type-erased custom class object pointer. It also allows
140 // pybind11 to treat this as a standalone class and register a separate type
141 // caster for it, instead of a custom pointer holder, which the pointer-holder
142 // type caster would try to "unwrap" automatically.
143 struct Capsule {
144   c10::intrusive_ptr<torch::CustomClassHolder> obj_ptr;
145   explicit Capsule(c10::intrusive_ptr<torch::CustomClassHolder> ptr)
146       : obj_ptr(std::move(ptr)) {}
147 };
148 
149 // IValue is the generic tagged union used by the interpreter to hold
150 // all value types.
151 // It is a 16-byte object with an 8-byte payload and an 8-byte tag.
152 // The tag is a 4-byte enum (padded to 8 bytes) that determines the type;
153 // whether that type is a subtype of c10::intrusive_ptr_target and needs
154 // retain/release calls is derived from the tag (see isIntrusivePtr()).
155 
156 #define TORCH_FORALL_TAGS(_) \
157   _(None)                    \
158   _(Tensor)                  \
159   _(Storage)                 \
160   _(Double)                  \
161   _(ComplexDouble)           \
162   _(Int)                     \
163   _(SymInt)                  \
164   _(SymFloat)                \
165   _(SymBool)                 \
166   _(Bool)                    \
167   _(Tuple)                   \
168   _(String)                  \
169   _(Blob)                    \
170   _(GenericList)             \
171   _(GenericDict)             \
172   _(Future)                  \
173   _(Await)                   \
174   _(Device)                  \
175   _(Stream)                  \
176   _(Object)                  \
177   _(PyObject)                \
178   _(Uninitialized)           \
179   _(Capsule)                 \
180   _(RRef)                    \
181   _(Quantizer)               \
182   _(Generator)               \
183   _(Enum)
184 
185 // [doxygen private]
186 // These methods are not actually private but we don't want to document them, so
187 // they are marked `@private`, which hides them on the doxygen documentation for
188 // this page.
189 
190 /// IValue (Interpreter Value) is a tagged union over the types
191 /// supported by the TorchScript interpreter. IValues contain their
192 /// values as an `IValue::Payload`, which holds primitive types
193 /// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values,
194 /// and all other types as a `c10::intrusive_ptr`. In order to
195 /// optimize performance of the destructor and related operations by
196 /// making the `Tensor` and `c10::intrusive_ptr` paths generate the
197 /// same code, we represent a null `c10::intrusive_ptr` as
198 /// `UndefinedTensorImpl::singleton()`, *not* `nullptr`.
199 ///
200 /// IValues are used as inputs to and outputs from the TorchScript interpreter.
201 /// To retrieve the value contained within an IValue, use the `.toX()` methods,
202 /// where `X` is the type you are trying to get. Note that neither the `.toX()`
203 /// methods nor the templated `.to<T>` functions do any kind of casting; they
204 /// only unwrap the contained value. For example:
205 ///
206 /// \rst
207 /// .. code-block:: cpp
208 ///
209 ///   // Make the IValue
210 ///   torch::IValue my_ivalue(26);
211 ///   std::cout << my_ivalue << "\n";
212 ///
213 ///   // Unwrap the IValue
214 ///   int64_t my_int = my_ivalue.toInt();
215 ///   std::cout << my_int << "\n";
216 ///
217 ///   // This will throw an error!
218 ///   // `my_ivalue` is tagged as an int and cannot be used as another type
219 ///   torch::Tensor my_tensor = my_ivalue.toTensor();
220 /// \endrst
221 struct TORCH_API IValue final {
222   IValue(const IValue& rhs) : IValue(rhs.payload, rhs.tag) {
223     if (isIntrusivePtr() &&
224         payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
225       c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
226     }
227   }
228 
229   IValue(IValue&& rhs) noexcept : tag(rhs.tag) {
230     moveFrom(std::move(rhs));
231   }
232 
233   /// @private [doxygen private]
234   ~IValue() {
235     destroy();
236   }
237 
238   C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept {
239     if (&rhs == this) {
240       return *this;
241     }
242 
243     destroy();
244     moveFrom(std::move(rhs));
245     return *this;
246   }
247 
248   IValue& operator=(IValue const& rhs) & {
249     *this = IValue(rhs);
250     return *this;
251   }
252 
253   void dump() const;
254 
255   /**
256    * Equality comparison. The semantics are the same as Python's `==`:
257    * 1. Numerical types are compared by value.
258    * 2. Tensors compute element-wise equality, returning a BoolTensor (see:
259    * `torch.eq()`)
260    * 3. Strings are compared by value.
261    * 4. Sequence types (list, tuple) are compared lexicographically by
262    *    comparing their elements. Different sequence types never compare equal.
263    * 5. Mappings (dict) must have equal (key, value) pairs.
264    * 6. If not listed above, the default behavior is to test identity
265    * equality (e.g. pointer equality).
266    *
267    * Why does this return an IValue instead of a bool? Because in PyTorch,
268    * `tensor1 == tensor2` returns a `BoolTensor`, not a bool.
269    *
270    * NOTE: we (like Python) assume that identity equality implies value equality
271    * for efficiency.
272    * TODO: need to support customizing equality
273    */
274   IValue equals(const IValue& rhs) const;
275   /**
276    * This implements the same semantics as `bool(lhs == rhs)` in Python. which
277    * is the same as `equals()` except for Tensor types.
278    */
279   TORCH_API friend bool operator==(const IValue& lhs, const IValue& rhs);
280   TORCH_API friend bool operator!=(const IValue& lhs, const IValue& rhs);
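  // A brief sketch of equals() vs. operator== (assuming torch::ones from the
  // ATen/torch headers; variable names are illustrative only):
  //
  //   torch::IValue a(2), b(2);
  //   bool same = (a == b);               // true: numbers compare by value
  //
  //   torch::IValue t1(torch::ones({2}));
  //   torch::IValue t2(torch::ones({2}));
  //   torch::IValue eq = t1.equals(t2);   // boxes a BoolTensor, like torch.eq
  //   // (t1 == t2) follows Python's bool(t1 == t2) semantics instead.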
281 
282   /**
283    * Identity comparison. Checks if `this` is the same object as `rhs`. The
284    * semantics are the same as Python's `is` operator.
285    *
286    * NOTE: Like in Python, this operation is poorly defined for primitive types
287    * like numbers and strings. Prefer to use `==` unless you really want to
288    * check identity equality.
289    */
290   bool is(const IValue& rhs) const;
291 
292   /**
293    * Hashing for IValues. Returns an IValue-boxed int.
294    *
295    * Some notes:
296    * - Like eager, Tensors are hashed by looking at the pointer. This is not
297    *   strictly correct because two value-equal tensors with different tensor
298    *   pointers will hash differently, but we choose to reproduce the eager
299    *   semantics.
300    * - Hashing is not defined on all built-in IValue types (e.g. list and
301    *   dict), following Python. Calling `hash()` on these types will throw.
302    */
303   IValue hash() const {
304     return (int64_t)IValue::hash(*this);
305   }
306   // This is defined because `c10::hash` dispatches to a function of this
307   // signature. See the member function `hash()`.
308   static size_t hash(const IValue& iv);
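  // A small sketch of the two hash entry points (assuming the usual headers):
  //
  //   torch::IValue iv(42);
  //   torch::IValue boxed = iv.hash();       // IValue-boxed int
  //   size_t raw = c10::IValue::hash(iv);    // raw size_t used by c10::hash
  //   // Calling hash() on a list or dict IValue throws, as noted above.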
309 
310   /**
311    * @private [doxygen private]
312    * [container equality]
313    * This is an equality implementation that assumes objects with the same
314    * identity equal themselves, for efficiency reasons. We primarily have this
315    * for consistency, because Python does the same thing. This actually
316    * provokes user-visible changes in behavior due to quirks in torch:
317    *      [tensor1] == [tensor1] -> True (container equality first compares identity)
318    *      [tensor1] == [tensor1_copy] -> RuntimeError: Boolean value of Tensor
319    *          with more than one value is ambiguous
320    */
321   TORCH_API friend bool _fastEqualsForContainer(
322       const IValue& lhs,
323       const IValue& rhs);
324 
325  private:
326   static bool isAliasOf(const at::Tensor& a, const at::Tensor& b) {
327     if (a.is_sparse()) {
328       return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
329     }
330     if (b.is_sparse()) {
331       return isAliasOf(a, b._values()) || isAliasOf(a, b._indices());
332     }
333     if (a.is_sparse_csr()) {
334       return isAliasOf(a.values(), b) || isAliasOf(a.crow_indices(), b) ||
335           isAliasOf(a.col_indices(), b);
336     }
337     if (b.is_sparse_csr()) {
338       return isAliasOf(a, b.values()) || isAliasOf(a, b.crow_indices()) ||
339           isAliasOf(a, b.col_indices());
340     }
341 
342     // Opaque tensors such as the ones constructed by the MKL-DNN backend
343     // don't have storage so we just compare their TensorImpls.
344     // TODO: Find way to expose alias info for opaque tensors.
345     if (!a.has_storage() || !b.has_storage()) {
346       return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl();
347     }
348 
349     return a.is_alias_of(b);
350   }
351 
352   template <typename T>
353   bool isListOf() const;
354 
355  public:
356   /// @private [doxygen private]
357   bool isAliasOf(const IValue& rhs) const {
358     if (this->tag != rhs.tag) {
359       // Trivially don't alias if the type is different
360       return false;
361     }
362 
363     // Tensors should be compared based on internal storage
364     if (this->isTensor()) {
365       return isAliasOf(this->toTensor(), rhs.toTensor());
366     }
367 
368     if (!isIntrusivePtr()) {
369       // Primitive types don't alias anything
370       return false;
371     }
372 
373     AT_ASSERT(rhs.isIntrusivePtr());
374 
375     // Other types can be compared by their ptr value
376     return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
377   }
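  // Sketch: IValues alias when their tensors share storage (assuming
  // torch::ones and Tensor::slice from the ATen headers):
  //
  //   torch::Tensor base = torch::ones({4});
  //   torch::IValue whole(base);
  //   torch::IValue part(base.slice(0, 0, 2));   // view into the same storage
  //   bool aliased = whole.isAliasOf(part);      // true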
378 
379   /// @private [doxygen private]
380   size_t use_count() const noexcept {
381     if (isTensor()) {
382       return payload.as_tensor.use_count();
383     }
384 
385     if (!isIntrusivePtrLegacyBehavior()) {
386       return 1;
387     }
388 
389     if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
390       return 0;
391     }
392     return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr);
393   }
394 
395   /// @private [doxygen private]
396   void swap(IValue& rhs) noexcept {
397     if (isTensor() && rhs.isTensor()) {
398       std::swap(payload.as_tensor, rhs.payload.as_tensor);
399     } else if (isTensor()) {
400       at::Tensor t = std::move(payload.as_tensor);
401       // As far as I can tell, omitting the usual explicit destructor call
402       // is not UB in and of itself, and it's a slight perf win. The
403       // destructor is a no-op, because the moved-from Tensor is
404       // effectively an intrusive_ptr in the null state, so we don't need
405       // the behavior for correctness reasons either. Leaving this
406       // explanatory comment, including commented-out destructor call, to
407       // make this abundantly clear.
408       //
409       // payload.as_tensor.~Tensor();
410       payload.u = rhs.payload.u;
411       new (&rhs.payload.as_tensor) at::Tensor(std::move(t));
412     } else if (rhs.isTensor()) {
413       rhs.swap(*this);
414       return;
415     } else {
416       std::swap(payload.u, rhs.payload.u);
417     }
418     std::swap(tag, rhs.tag);
419   }
420 
421   // Accessors for subtypes are arranged together below
422   // While some of these accessors could be generated through templates,
423   // we prefer to write them manually for clarity
424 
425   IValue(at::TensorBase t) : tag(Tag::Tensor) {
426     new (&payload.as_tensor) at::Tensor(std::move(t));
427   }
428   bool isTensor() const {
429     return Tag::Tensor == tag;
430   }
431 
432  private:
433   // Outlined error path so that toTensor() can be inlined.
434   [[noreturn]] void reportToTensorTypeError() const;
435 
436  public:
437   at::Tensor toTensor() &&;
438   at::Tensor& toTensor() &;
439   const at::Tensor& toTensor() const&;
440   at::TensorImpl* unsafeToTensorImpl() const {
441     TORCH_INTERNAL_ASSERT(isTensor());
442     return payload.as_tensor.unsafeGetTensorImpl();
443   }
444 
445   IValue(at::Storage s) : tag(Tag::Storage) {
446     payload.u.as_intrusive_ptr =
447         null_to_undefined_tensor(s.unsafeReleaseStorageImpl());
448   }
449   bool isStorage() const {
450     return Tag::Storage == tag;
451   }
452   c10::Storage toStorage() &&;
453   c10::Storage toStorage() const&;
454 
455   const IValue& toIValue() const {
456     return *this;
457   }
458   IValue& toIValue() {
459     return *this;
460   }
461 
462   /// @private [doxygen private]
463   IValue(intrusive_ptr<caffe2::Blob> blob) : tag(Tag::Blob) {
464     // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
465     // and store it as a Tensor instead.
466     payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
467   }
468 
469   /// @private [doxygen private]
470   bool isBlob() const {
471     return Tag::Blob == tag;
472   }
473 
474   /// @private [doxygen private]
475   c10::intrusive_ptr<caffe2::Blob> toBlob() &&;
476 
477   /// @private [doxygen private]
478   c10::intrusive_ptr<caffe2::Blob> toBlob() const&;
479 
480   // Capsule. No new callsites of these APIs should
481   // be introduced.
482   static inline IValue make_capsule(
483       intrusive_ptr<torch::CustomClassHolder> blob);
484   bool isCapsule() const {
485     return Tag::Capsule == tag;
486   }
487   c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() &&;
488   c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() const&;
489 
490   // Custom C++ classes
491   template <
492       typename T,
493       std::enable_if_t<std::is_base_of_v<torch::CustomClassHolder, T>, int> = 0>
494   IValue(intrusive_ptr<T> custom_class);
495   bool isCustomClass() const;
496   template <typename T>
497   c10::intrusive_ptr<T> toCustomClass() &&;
498   template <typename T>
499   c10::intrusive_ptr<T> toCustomClass() const&;
500 
501   // Tuple
502   IValue(c10::intrusive_ptr<ivalue::Tuple> v);
503 
504   template <
505       typename... Args,
506       std::enable_if_t<
507           !std::disjunction_v<
508               std::is_lvalue_reference<Args>...,
509               std::negation<std::is_constructible<IValue, Args>>...>,
510           std::nullptr_t> = nullptr>
511   IValue(const std::tuple<Args...>& t);
512   template <
513       typename... Args,
514       std::enable_if_t<
515           !std::disjunction_v<
516               std::is_lvalue_reference<Args>...,
517               std::negation<std::is_constructible<IValue, Args>>...>,
518           std::nullptr_t> = nullptr>
519   IValue(std::tuple<Args...>&& t);
520   bool isTuple() const {
521     return Tag::Tuple == tag;
522   }
523   c10::intrusive_ptr<ivalue::Tuple> toTuple() &&;
524   c10::intrusive_ptr<ivalue::Tuple> toTuple() const&;
525   C10_NODISCARD ivalue::Tuple& toTupleRef() const;
526 
527   // Double
528   IValue(double d) : tag(Tag::Double) {
529     payload.u.as_double = d;
530   }
531   bool isDouble() const {
532     return Tag::Double == tag;
533   }
534   double toDouble() const {
535     if (isDouble()) {
536       return payload.u.as_double;
537     } else if (isSymFloat()) {
538       return toSymFloat().guard_float(__FILE__, __LINE__);
539     } else {
540       TORCH_INTERNAL_ASSERT(0, "expected double");
541     }
542   }
543 
544   // ComplexDouble
545   template <typename T>
546   IValue(c10::complex<T> c);
547   bool isComplexDouble() const {
548     return Tag::ComplexDouble == tag;
549   }
550   c10::complex<double> toComplexDouble() const;
551 
552   // Future
553   IValue(c10::intrusive_ptr<ivalue::Future> v);
554   bool isFuture() const {
555     return Tag::Future == tag;
556   }
557   c10::intrusive_ptr<ivalue::Future> toFuture() &&;
558   c10::intrusive_ptr<ivalue::Future> toFuture() const&;
559 
560   IValue(c10::intrusive_ptr<ivalue::Await> v);
561   bool isAwait() const {
562     return Tag::Await == tag;
563   }
564   c10::intrusive_ptr<ivalue::Await> toAwait() &&;
565   c10::intrusive_ptr<ivalue::Await> toAwait() const&;
566 
567   // RRef
568   IValue(c10::intrusive_ptr<c10::RRefInterface> v);
569   bool isRRef() const {
570     return Tag::RRef == tag;
571   }
572   c10::intrusive_ptr<c10::RRefInterface> toRRef() &&;
573   c10::intrusive_ptr<c10::RRefInterface> toRRef() const&;
574 
575   // Quantizer
576   IValue(c10::intrusive_ptr<at::Quantizer> v);
577   bool isQuantizer() const {
578     return Tag::Quantizer == tag;
579   }
580   c10::intrusive_ptr<at::Quantizer> toQuantizer() &&;
581   c10::intrusive_ptr<at::Quantizer> toQuantizer() const&;
582 
583   // Int
584   IValue(int64_t i) : tag(Tag::Int) {
585     payload.u.as_int = i;
586   }
587 
588   IValue(const c10::SymInt& i) {
589     if (auto mi = i.maybe_as_int()) {
590       tag = Tag::Int;
591       payload.u.as_int = *mi;
592     } else {
593       tag = Tag::SymInt;
594       payload.u.as_intrusive_ptr = i.toSymNode().release();
595     }
596   }
597 
598   bool isSymInt() const {
599     return Tag::SymInt == tag;
600   }
601 
602   c10::SymInt toSymInt() &&;
603   c10::SymInt toSymInt() const&;
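  // Sketch: a SymInt that wraps a plain integer decays to Tag::Int, so the
  // ordinary integer accessors keep working (assuming the c10 headers):
  //
  //   c10::SymInt si(5);                 // carries a concrete value, not symbolic
  //   torch::IValue iv(si);
  //   bool stored_as_int = iv.isInt();   // true
  //   int64_t v = iv.toInt();            // 5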
604 
605   IValue(const c10::SymFloat& i) {
606     if (i.is_symbolic()) {
607       tag = Tag::SymFloat;
608       payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
609     } else {
610       tag = Tag::Double;
611       payload.u.as_double = i.as_float_unchecked();
612     }
613   }
614 
615   bool isSymFloat() const {
616     return Tag::SymFloat == tag;
617   }
618 
619   c10::SymFloat toSymFloat() &&;
620   c10::SymFloat toSymFloat() const&;
621 
622   IValue(const c10::SymBool& i) {
623     if (auto mi = i.maybe_as_bool()) {
624       tag = Tag::Bool;
625       payload.u.as_int = *mi;
626     } else {
627       tag = Tag::SymBool;
628       payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
629     }
630   }
631 
632   bool isSymBool() const {
633     return Tag::SymBool == tag;
634   }
635 
636   c10::SymBool toSymBool() &&;
637   c10::SymBool toSymBool() const&;
638 
639   // allow you to pass literals (3, 4) without ambiguity
640   IValue(int32_t i) : IValue(static_cast<int64_t>(i)) {}
641 
642   bool isInt() const {
643     return Tag::Int == tag;
644   }
645 
646   int64_t toInt() const {
647     if (isInt()) {
648       return payload.u.as_int;
649     } else if (isSymInt()) {
650       return toSymInt().guard_int(__FILE__, __LINE__);
651     } else {
652       TORCH_INTERNAL_ASSERT(0, "expected int");
653     }
654   }
655 
656   // Bool
657   IValue(bool b) : tag(Tag::Bool) {
658 #if defined(__clang__) && defined(__x86_64__)
659     // Initializing the entire payload stops valgrind from reporting
660     // "jump or move depends on uninitialised value" in IValue copy constructor
661     // See https://github.com/pytorch/pytorch/issues/37117
662     payload.u.as_int = b;
663 #else
664     payload.u.as_bool = b;
665 #endif
666   }
667   bool isBool() const {
668     return Tag::Bool == tag;
669   }
670   bool toBool() const {
671     if (isBool()) {
672       return payload.u.as_bool;
673     } else if (isSymBool()) {
674       return toSymBool().guard_bool(__FILE__, __LINE__);
675     } else {
676       TORCH_INTERNAL_ASSERT(0, "expected bool");
677     }
678   }
679 
680   // IntList
681   bool isIntList() const;
682   bool isSymIntList() const;
683   c10::List<int64_t> toIntList() &&;
684   c10::List<int64_t> toIntList() const&;
685   std::vector<int64_t> toIntVector() const;
686   std::vector<c10::SymInt> toSymIntVector() const;
687   at::DimVector toDimVector() const;
688 
689   // ConstantString
690   IValue(c10::intrusive_ptr<ivalue::ConstantString> v);
691   IValue(std::string v);
692   IValue(const char* v) : IValue(std::string(v)) {}
693   IValue(c10::string_view v) : IValue(std::string(v)){};
694   bool isString() const {
695     return Tag::String == tag;
696   }
697   c10::intrusive_ptr<ivalue::ConstantString> toString() &&;
698   c10::intrusive_ptr<ivalue::ConstantString> toString() const&;
699   const std::string& toStringRef() const;
700   std::optional<std::reference_wrapper<const std::string>> toOptionalStringRef()
701       const;
702   c10::string_view toStringView() const;
703 
704   // DoubleList
705   bool isDoubleList() const;
706   c10::List<double> toDoubleList() &&;
707   c10::List<double> toDoubleList() const&;
708   std::vector<double> toDoubleVector() const;
709 
710   // ComplexDoubleList
711   bool isComplexDoubleList() const;
712   c10::List<c10::complex<double>> toComplexDoubleList() &&;
713   c10::List<c10::complex<double>> toComplexDoubleList() const&;
714   std::vector<c10::complex<double>> toComplexDoubleVector() const;
715 
716   // BoolList
717   bool isBoolList() const;
718   c10::List<bool> toBoolList() &&;
719   c10::List<bool> toBoolList() const&;
720 
721   // TensorList
722   bool isTensorList() const;
723   c10::List<at::Tensor> toTensorList() &&;
724   c10::List<at::Tensor> toTensorList() const&;
725   std::vector<at::Tensor> toTensorVector() const;
726 
727   // OptionalTensorList
728   bool isOptionalTensorList() const;
729   c10::List<std::optional<at::Tensor>> toOptionalTensorList() &&;
730   c10::List<std::optional<at::Tensor>> toOptionalTensorList() const&;
731   std::vector<std::optional<at::Tensor>> toOptionalTensorVector() const;
732 
733   // GenericList
734   IValue(c10::List<IValue> v);
735   bool isList() const {
736     return Tag::GenericList == tag;
737   }
738   c10::List<IValue> toList() &&;
739   c10::List<IValue> toList() const&;
740   c10::ArrayRef<IValue> toListRef() const;
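  // Sketch: boxing a typed list and reading it back (assuming the ATen
  // headers; toIntVector copies the elements out):
  //
  //   c10::List<int64_t> xs({1, 2, 3});
  //   torch::IValue iv(xs);
  //   bool is_int_list = iv.isIntList();             // true
  //   std::vector<int64_t> back = iv.toIntVector();  // {1, 2, 3}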
741 
742   // Some template constructors of IValue call another constructor recursively.
743   // This SFINAE check ensures that the called constructor exists.
744   template <class T>
745   using enable_if_ivalue_constructible =
746       std::enable_if_t<std::is_constructible_v<IValue, T>, std::nullptr_t>;
747 
748   // The rule for lists is more complicated; the generic constructor is only
749   // acceptable if your element isn't SymInt.  If you do have a SymInt element,
750   // then you must also, at construction time, check if you can decay the list
751   // into an int list (this is MANDATORY, as at a use site we may expect
752   // toIntList to work even if at the call site you had a SymIntArrayRef
753   // argument).  In practice, only SymIntArrayRef is used this way, so we
754   // didn't bother making it work for the other constructors; we just make sure
755   // they're not selectable.
756   template <class T>
757   using enable_if_list_is_ivalue_constructible = std::enable_if_t<
758       std::is_constructible_v<IValue, T> && !std::is_same_v<T, c10::SymInt>,
759       std::nullptr_t>;
760 
761   template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
762   IValue(c10::List<T>&& v);
763   template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
764   IValue(const c10::List<T>& v);
765   template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
766   IValue(at::ArrayRef<T> v);
767   template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
768   IValue(const std::vector<T>& v);
769   template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
770   IValue(std::vector<T>&& v);
771   template <class T, size_t N>
772   IValue(std::array<T, N> v);
773 
774   // Manual constructors for lists of symints, which decay to int list if
775   // possible.  To avoid ambiguous overload situations, we template them
776   // to prevent implicit conversions
777   template <class T>
778   using enable_if_symint =
779       std::enable_if_t<std::is_same_v<T, c10::SymInt>, std::nullptr_t>;
780 
781   template <class T, enable_if_symint<T> = nullptr>
782   IValue(at::ArrayRef<T> v);
783   template <class T, enable_if_symint<T> = nullptr>
784   IValue(at::OptionalArrayRef<T> v);
785   template <class T, enable_if_symint<T> = nullptr>
786   IValue(const std::vector<T>& v);
787   template <class T, enable_if_symint<T> = nullptr>
788   IValue(std::vector<T>&& v);
789 
790   template <class T>
791   using enable_if_ilist_is_ivalue_constructible = std::enable_if_t<
792       std::is_constructible_v<IValue, T> &&
793           std::is_constructible_v<IValue, typename IListRef<T>::boxed_type> &&
794           !std::is_same_v<T, c10::SymInt>,
795       std::nullptr_t>;
796 
797   template <class T, enable_if_ilist_is_ivalue_constructible<T> = nullptr>
798   IValue(c10::IListRef<T> v);
799 
800   // GenericDict
801   IValue(c10::Dict<IValue, IValue> v);
802   bool isGenericDict() const {
803     return Tag::GenericDict == tag;
804   }
805   c10::Dict<IValue, IValue> toGenericDict() &&;
806   c10::Dict<IValue, IValue> toGenericDict() const&;
807 
808   template <class Key, class Value>
809   IValue(c10::Dict<Key, Value> v);
810 
811   template <class Key, class Value>
812   /// \cond
813   /// DOXYGEN_CANNOT_HANDLE_CONSTRUCTORS_WITH_MACROS_SO_EXCLUDE_THIS_LINE_FROM_DOXYGEN
814   C10_DEPRECATED_MESSAGE(
815       "IValues based on std::unordered_map<K, V> are slow and deprecated. Please use c10::Dict<K, V> instead.")
816       /// \endcond
817       IValue(std::unordered_map<Key, Value> v);
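  // Sketch: boxing a typed c10::Dict (assuming the ATen headers):
  //
  //   c10::Dict<std::string, int64_t> d;
  //   d.insert("answer", 42);
  //   torch::IValue iv(d);
  //   bool is_dict = iv.isGenericDict();   // true
  //   auto generic = iv.toGenericDict();   // c10::Dict<IValue, IValue>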
818 
819   template <class T, enable_if_ivalue_constructible<T> = nullptr>
820   IValue(std::optional<T> v);
821   template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
822   IValue(c10::OptionalArrayRef<T> v);
823   IValue(std::nullopt_t);
824 
825   // ClassType
826   IValue(c10::intrusive_ptr<ivalue::Object> v);
827   bool isObject() const {
828     return tag == Tag::Object;
829   }
830   c10::intrusive_ptr<ivalue::Object> toObject() &&;
831   c10::intrusive_ptr<ivalue::Object> toObject() const&;
832   ivalue::Object& toObjectRef() const;
833 
834   torch::jit::Module toModule() const;
835   bool isModule() const;
836 
837   // PyObject
838   IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v);
839   bool isPyObject() const {
840     return tag == Tag::PyObject;
841   }
842   c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() &&;
843   c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() const&;
844   PyObject* toPyObject() const;
845 
846   // Enum
847   explicit IValue(c10::intrusive_ptr<ivalue::EnumHolder> v);
848   bool isEnum() const {
849     return tag == Tag::Enum;
850   }
851   c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() &&;
852   c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() const&;
853 
854   // None
855   IValue() = default;
856   bool isNone() const {
857     return Tag::None == tag;
858   }
859   std::string toNone() const {
860     AT_ASSERT(isNone());
861     return "None";
862   }
863 
864   static IValue uninitialized() {
865     auto i = IValue();
866     i.tag = Tag::Uninitialized;
867     return i;
868   }
869 
870   // Scalar, which gets encoded as either an Int, a Double or a ComplexDouble
871   IValue(const at::Scalar& s) : IValue() {
872     // NB: do the symbolic versions first, as isFloatingPoint is true
873     // for both SymFloat and double
874     if (s.isSymInt()) {
875       tag = Tag::SymInt;
876       payload.u.as_intrusive_ptr = s.toSymInt().toSymNode().release();
877     } else if (s.isSymFloat()) {
878       tag = Tag::SymFloat;
879       payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
880     } else if (s.isSymBool()) {
881       tag = Tag::SymBool;
882       payload.u.as_intrusive_ptr = s.toSymBool().toSymNodeImpl().release();
883     } else if (s.isFloatingPoint()) {
884       tag = Tag::Double;
885       payload.u.as_double = s.toDouble();
886     } else if (s.isComplex()) {
887       *this = s.toComplexDouble();
888     } else if (s.isBoolean()) {
889       tag = Tag::Bool;
890       payload.u.as_bool = s.toBool();
891     } else {
892       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
893           s.isIntegral(false), "Unknown type in Scalar");
894       tag = Tag::Int;
895       payload.u.as_int = s.toLong();
896     }
897   }
898 
899   bool isScalar() const {
900     return isDouble() || isInt() || isComplexDouble() || isBool() ||
901         isSymInt() || isSymFloat() || isSymBool();
902   }
903 
904   at::Scalar toScalar() const {
905     if (isDouble())
906       return toDouble();
907     else if (isInt())
908       return toInt();
909     else if (isComplexDouble())
910       return toComplexDouble();
911     else if (isBool())
912       return toBool();
913     else if (isSymInt())
914       return toSymInt();
915     else if (isSymFloat())
916       return toSymFloat();
917     else if (isSymBool())
918       return toSymBool();
919     throw std::runtime_error("IValue is not a Scalar");
920   }
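  // Sketch: round-tripping an at::Scalar; the IValue picks the matching tag
  // (Double for a floating-point scalar), so toScalar() recovers it:
  //
  //   at::Scalar s(2.5);
  //   torch::IValue iv(s);
  //   bool is_double = iv.isDouble();    // true
  //   at::Scalar back = iv.toScalar();   // 2.5 again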
921 
922   // Device
923   IValue(c10::Device d) : tag(Tag::Device) {
924     payload.u.as_device.type = d.type();
925     payload.u.as_device.index = d.index();
926   }
927   bool isDevice() const {
928     return Tag::Device == tag;
929   }
930   c10::Device toDevice() const {
931     AT_ASSERT(isDevice());
932     return c10::Device(payload.u.as_device.type, payload.u.as_device.index);
933   }
934 
935   // Stream
936   IValue(c10::Stream s) : tag(Tag::Stream) {
937     auto v = c10::make_intrusive<ivalue::StreamData3Holder>(s.pack3());
938     payload.u.as_intrusive_ptr = v.release();
939   }
940   c10::Stream toStream() &&;
941   c10::Stream toStream() const&;
942   bool isStream() const {
943     return Tag::Stream == tag;
944   }
945 
946   // ScalarType
947   IValue(ScalarType t)
948       : IValue(static_cast<std::underlying_type_t<ScalarType>>(t)) {}
949   at::ScalarType toScalarType() const {
950     return static_cast<at::ScalarType>(toInt());
951   }
952 
953   // Layout
954   IValue(Layout l) : IValue(static_cast<std::underlying_type_t<Layout>>(l)) {}
955   at::Layout toLayout() const {
956     return static_cast<at::Layout>(toInt());
957   }
958 
959   // MemoryFormat
960   IValue(MemoryFormat m)
961       : IValue(static_cast<std::underlying_type_t<MemoryFormat>>(m)) {}
962   at::MemoryFormat toMemoryFormat() const {
963     return static_cast<at::MemoryFormat>(toInt());
964   }
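  // Sketch: enum-like types (ScalarType, Layout, MemoryFormat) are stored as
  // plain Ints, so they round-trip through the integer payload:
  //
  //   torch::IValue iv(at::kFloat);
  //   bool stored_as_int = iv.isInt();         // true
  //   at::ScalarType st = iv.toScalarType();   // at::kFloat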
965 
966   // QScheme
967   IValue(at::QScheme qscheme) : tag(Tag::Int) {
968     payload.u.as_int = static_cast<int64_t>(qscheme);
969   }
970 
971   at::QScheme toQScheme() const {
972     return static_cast<at::QScheme>(toInt());
973   }
974 
975   // Dimname
976   IValue(at::Dimname dimname) : IValue(dimname.symbol().toQualString()) {}
977 
978   at::Dimname toDimname() const {
979     return at::Dimname::fromSymbol(Symbol::fromQualString(toStringRef()));
980   }
981 
982   // Generator
983   IValue(at::Generator g) : tag(Tag::Generator) {
984     payload.u.as_intrusive_ptr =
985         null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl());
986   }
987   bool isGenerator() const {
988     return Tag::Generator == tag;
989   }
990   at::Generator toGenerator() &&;
991   at::Generator toGenerator() const&;
992 
993   // for debugging
994   std::string tagKind() const {
995     switch (tag) {
996 #define DEFINE_CASE(x) \
997   case Tag::x:         \
998     return #x;
999       TORCH_FORALL_TAGS(DEFINE_CASE)
1000 #undef DEFINE_CASE
1001     }
1002     return "InvalidTag(" + std::to_string(static_cast<int>(tag)) + ")";
1003   }
1004 
1005   // generic v.to<at::Tensor>() implementations
1006   // that can be used in special functions like pop/push
1007   // that use template meta-programming.
1008   // prefer the directly named methods when you can,
1009   // since they are simpler to understand
1010 
1011   // Note: if you get linker errors saying one of these is missing,
1012   // change it to ... && = delete; and you will see better error messages for
1013   // why. However, we cannot commit this because some compiler versions barf on
1014   // it.
1015   template <typename T>
1016   T to() &&;
1017   template <typename T>
1018   typename c10::detail::ivalue_to_const_ref_overload_return<T>::type to()
1019       const&;
1020 
1021   // ToOptional: convert an IValue to an Optional object that accepts both T and
1022   // None
1023   template <typename T>
1024   std::optional<T> toOptional();
1025   template <typename T>
1026   std::optional<T> toOptional() const;
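  // Sketch: toOptional maps a None IValue to std::nullopt and unwraps the
  // value otherwise:
  //
  //   torch::IValue none;                     // default-constructed -> None
  //   torch::IValue five(static_cast<int64_t>(5));
  //   auto a = none.toOptional<int64_t>();    // std::nullopt
  //   auto b = five.toOptional<int64_t>();    // std::optional<int64_t>(5)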
1027 
1028   /// @private [doxygen private]
1029   /// this is a shallow comparison of two IValues to test the object identity
1030   bool isSameIdentity(const IValue& rhs) const;
1031 
1032   // Computes the "official" string representation of an IValue. This produces a
1033   // TorchScript expression that can be used to recreate an IValue with the same
1034   // value (e.g. when we are printing constants in the serializer).
1035   //
1036   // Callers can use `customFormatter` to override how `repr()` prints out an
1037   // IValue. This is useful if you have some other environment where you can
1038   // look up values, and you want to print a reference to that environment (like
1039   // the serializer's constant table).
1040   //
1041   // repr() is not necessarily defined on all objects!
1042   std::ostream& repr(
1043       std::ostream& stream,
1044       std::function<bool(std::ostream&, const IValue& v)> customFormatter)
1045       const;
1046 
1047   // Computes an "informal" string representation of an IValue. This should be
1048   // used for debugging, or servicing `print()`-like functions.
1049   // This is different from `repr()` in that there is no expectation that we can
1050   // exactly reconstruct an IValue from the output; feel free to use a
1051   // concise/pretty form
1052   TORCH_API friend std::ostream& operator<<(std::ostream& out, const IValue& v);
1053 
1054   bool isPtrType() const {
1055     if (isTensor()) {
1056       return payload.as_tensor.defined();
1057     }
1058     return isIntrusivePtrLegacyBehavior();
1059   }
1060 
1061   /// @private [doxygen private]
1062   const void* internalToPointer() const {
1063     TORCH_INTERNAL_ASSERT(
1064         isPtrType(), "Can only call internalToPointer() for pointer types");
1065     if (isTensor()) {
1066       return payload.as_tensor.unsafeGetTensorImpl();
1067     } else {
1068       return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()
1069           ? payload.u.as_intrusive_ptr
1070           : nullptr;
1071     }
1072   }
1073 
1074   template <typename T = c10::PlatformType>
1075   TypePtr type() const;
1076 
1077   // Detect aliased tensors.
1078   struct HashAliasedIValue {
1079     size_t hashTensor(const at::Tensor& ten) const {
1080       if (ten.is_sparse()) {
1081         // COO sparse tensors have a "values" tensor and an "indices" tensor
1082         // so this will detect overlap of sparse tensors that share a values
1083         // tensor, but not sparse tensors that share an indices tensor.
1084         return hashTensor(ten._values());
1085       } else if (ten.is_sparse_csr()) {
1086         // CSR sparse tensors have a "values" tensor plus "crow_indices" and
1087         // "col_indices" tensors, so this will detect overlap of CSR tensors that
1088         // share a values tensor, but not ones that only share the index tensors.
1089         return hashTensor(ten.values());
1090       } else if (!ten.has_storage()) {
1091         // Opaque tensors such as the ones constructed by the MKL-DNN backend
1092         // don't have storage so we just use their TensorImpls.
1093         // TODO: Find way to expose alias info for opaque tensors.
1094         return reinterpret_cast<size_t>(ten.unsafeGetTensorImpl());
1095       } else {
1096         return reinterpret_cast<size_t>(ten.storage().unsafeGetStorageImpl());
1097       }
1098     }
1099     size_t operator()(const IValue& val) const {
1100       if (val.isTensor()) {
1101         return hashTensor(val.toTensor());
1102       }
1103       // If it is not a Tensor, then two mutable IValues alias each other only
1104       // if they are the same pointer.
1105       return val.payload.u.as_int;
1106     }
1107   };
1108 
1109   struct CompAliasedIValues {
1110     bool operator()(const IValue& lhs, const IValue& rhs) const {
1111       return lhs.isAliasOf(rhs);
1112     }
1113   };
1114 
1115   using HashAliasedIValues =
1116       std::unordered_set<IValue, HashAliasedIValue, CompAliasedIValues>;
1117   using HashAliasedIValueMap =
1118       std::unordered_map<IValue, IValue, HashAliasedIValue, CompAliasedIValues>;
1119 
1120   struct HashIdentityIValue {
1121     size_t operator()(const IValue& val) const {
1122       return val.payload.u.as_int;
1123     }
1124   };
1125 
1126   struct CompIdentityIValues {
1127     bool operator()(const IValue& lhs, const IValue& rhs) const {
1128       return lhs.is(rhs);
1129     }
1130   };
1131 
1132   using HashIdentityIValues =
1133       std::unordered_set<IValue, HashIdentityIValue, CompIdentityIValues>;
1134   using HashIdentityIValueMap =
1135       std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>;
1136 
1137   // Checks if this and rhs have any subvalues in common.
1138   // For example, [t1, t2] and [t2, t3] overlap because they share t2.
1139   bool overlaps(const IValue& rhs) const;
1140 
1141   // Inserts all subvalues of this in subValues.
1142   void getSubValues(HashAliasedIValues& subValues) const;
1143 
1144   // Apply visitor to every subvalue.
1145   // TODO: There are several places that recurse over IValue. This is fragile.
1146   // This visitor should be used to recurse over ivalues.
1147   void visit(const std::function<bool(const IValue&)>& visitor) const;
1148   IValue deepcopy(std::optional<at::Device> device = std::nullopt) const;
1149   IValue deepcopy(
1150       HashIdentityIValueMap& memo,
1151       std::optional<at::Device> device = std::nullopt) const;
1152 
1153  private:
1154   static c10::intrusive_ptr_target* null_to_undefined_tensor(
1155       c10::intrusive_ptr_target* p) {
1156     return p ? p
1157              : static_cast<c10::intrusive_ptr_target*>(
1158                    c10::UndefinedTensorImpl::singleton());
1159   }
1160 
1161   static bool ptrEqual(const IValue& lhs, const IValue& rhs);
1162   // NOTE: IValue tags are intentionally private. In the future we may encode
1163   // this value differently (e.g. using NaN boxing), and this would make it more
1164   // costly to determine the tag for all types vs just determining if something
1165   // is a particular type. Instead we want clients to use the `isX` methods when
1166   // possible. If for perf. reasons you really, absolutely, must have a jump
1167   // table, then we can revisit this.
1168   enum class Tag : uint32_t {
1169 #define DEFINE_TAG(x) x,
1170     TORCH_FORALL_TAGS(DEFINE_TAG)
1171 #undef DEFINE_TAG
1172   };
1173 
1174 #define COUNT_TAG(x) 1 +
1175   static constexpr auto kNumTags = TORCH_FORALL_TAGS(COUNT_TAG) 0;
1176 #undef COUNT_TAG
1177 
1178   template <
1179       class T,
1180       class NullType = c10::detail::intrusive_target_default_null_type<T>>
1181   c10::intrusive_ptr<T, NullType> moveToIntrusivePtr();
1182   template <
1183       typename T,
1184       class NullType = c10::detail::intrusive_target_default_null_type<T>>
1185   c10::intrusive_ptr<T, NullType> toIntrusivePtr() const;
1186 
1187   void destroy() {
1188     // We carefully construct this call to both 1) avoid UB by using
1189     // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable
1190     // the compiler to generate the same code for each case. It is
1191     // surprisingly difficult to get this right.
1192     if (isTensor() || isIntrusivePtr()) {
1193       c10::intrusive_ptr_target* p = isTensor()
1194           ? payload.as_tensor.unsafeGetTensorImpl()
1195           : payload.u.as_intrusive_ptr;
1196       c10::intrusive_ptr<intrusive_ptr_target, c10::UndefinedTensorImpl>::
1197           reclaim(p);
1198       // No need to make this destructor call!
1199       // payload.as_tensor.~Tensor();
1200     }
1201   }
1202 
1203   // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
1204   C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept {
1205     if (rhs.isTensor()) {
1206       new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
1207       // As far as I can tell, omitting the usual explicit destructor call
1208       // is not UB in and of itself, and it's a slight perf win. The
1209       // destructor is a no-op, because the moved-from Tensor is
1210       // effectively an intrusive_ptr in the null state, so we don't need
1211       // the behavior for correctness reasons either. Leaving this
1212       // explanatory comment, including commented-out destructor call, to
1213       // make this abundantly clear.
1214       //
1215       // rhs.payload.as_tensor.~Tensor();
1216     } else {
1217       payload.u = rhs.payload.u;
1218     }
1219     tag = rhs.tag;
1220     rhs.clearToNone();
1221   }
1222 
1223   void clearToNone() noexcept {
1224     payload.u.as_int = 0;
1225     tag = Tag::None;
1226   }
1227 
1228  private:
1229   // This is the source of truth for isIntrusivePtr; edit results here
1230   // as needed and isIntrusivePtr will pick them up.
1231   // NOLINTBEGIN(bugprone-branch-clone)
1232   static constexpr bool isIntrusivePtrConstexpr(Tag tag) {
1233     switch (tag) {
1234       case Tag::None:
1235         return false;
1236       case Tag::Tensor:
1237         return false;
1238       case Tag::Storage:
1239         return true;
1240       case Tag::Generator:
1241         return true;
1242       case Tag::Double:
1243         return false;
1244       case Tag::ComplexDouble:
1245         return true;
1246       case Tag::Int:
1247         return false;
1248       case Tag::SymInt:
1249         return true;
1250       case Tag::SymFloat:
1251         return true;
1252       case Tag::SymBool:
1253         return true;
1254       case Tag::Bool:
1255         return false;
1256       case Tag::Tuple:
1257         return true;
1258       case Tag::String:
1259         return true;
1260       case Tag::Blob:
1261         return true;
1262       case Tag::GenericList:
1263         return true;
1264       case Tag::GenericDict:
1265         return true;
1266       case Tag::Future:
1267         return true;
1268       case Tag::Await:
1269         return true;
1270       case Tag::Device:
1271         return false;
1272       case Tag::Stream:
1273         return true;
1274       case Tag::Object:
1275         return true;
1276       case Tag::PyObject:
1277         return true;
1278       case Tag::Uninitialized:
1279         return false;
1280       case Tag::Capsule:
1281         return true;
1282       case Tag::RRef:
1283         return true;
1284       case Tag::Quantizer:
1285         return true;
1286       case Tag::Enum:
1287         return true;
1288     }
1289     return false;
1290   }
1291   // NOLINTEND(bugprone-branch-clone)
1292 
1293  public:
1294   // Don't edit this just to add results for new tags; edit
1295   // isIntrusivePtrConstexpr above.
1296   bool isIntrusivePtr() const {
1297     // Implementation NOTE: the switch in isIntrusivePtrConstexpr
1298     // above is the previous production implementation of this
1299     // function. We observed that, at least on x86_64, the generated
1300     // instruction sequence was a similar bit vector test to what we
1301     // have manually implemented below, except that there was an extra
1302     // "bounds check" branch confirming, essentially, that `tag <
1303     // kNumTags` and providing a consistent result in that case. We
1304     // don't care about the result if tag is out of bounds, so we'd
1305     // like to eliminate that comparison and branch; manually
1306     // implementing this function as a bit test is the simplest way I
1307     // could find to accomplish that elimination.
1308     static constexpr uint32_t kTruthTableBitVector =
1309 #define TRUTH_TABLE_ENTRY(tag) \
1310   (uint32_t(isIntrusivePtrConstexpr(Tag::tag)) << uint32_t(Tag::tag)) |
1311         TORCH_FORALL_TAGS(TRUTH_TABLE_ENTRY)
1312 #undef TRUTH_TABLE_ENTRY
1313             0;
1314 
1315     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1316         static_cast<uint32_t>(tag) < kNumTags,
1317         "unexpected tag ",
1318         static_cast<int>(tag));
1319     return kTruthTableBitVector & (1 << (uint32_t(tag) % 32));
1320   }
1321 
1322   // Storage and Generator were treated specially when
1323   // is_intrusive_ptr was stored as explicit state. This getter
1324   // preserves the old behavior for use with WeakIValue for now.
1325   bool isIntrusivePtrLegacyBehavior() const {
1326     if (tag == Tag::Storage || tag == Tag::Generator) {
1327       return payload.u.as_intrusive_ptr !=
1328           c10::UndefinedTensorImpl::singleton();
1329     } else {
1330       return isIntrusivePtr();
1331     }
1332   }
1333 
1334   union Payload {
1335     // [TriviallyCopyablePayload]
1336     // We use a nested union here so that we can make the copy easy
1337     // and efficient in the non-tensor (i.e., trivially copyable)
1338     // case. Specifically, we do not have to do a switch-on-tag to
1339     // figure out which union member to assign; we can just use
1340     // TriviallyCopyablePayload::operator=.
1341     union TriviallyCopyablePayload {
1342       TriviallyCopyablePayload() : as_int(0) {}
1343       int64_t as_int;
1344       double as_double;
1345       bool as_bool;
1346       // Invariant: never nullptr; null state is represented as
1347       // c10::UndefinedTensorImpl::singleton() for consistency of
1348       // representation with Tensor.
1349       c10::intrusive_ptr_target* as_intrusive_ptr;
1350       struct {
1351         c10::DeviceType type;
1352         DeviceIndex index;
1353       } as_device;
1354     } u;
1355     at::Tensor as_tensor;
1356     Payload() : u() {}
1357     ~Payload() {}
1358   };
1359 
1360   IValue(const Payload& p, Tag t) : tag(t) {
1361     if (isTensor()) {
1362       new (&payload.as_tensor) at::Tensor(p.as_tensor);
1363     } else {
1364       payload.u = p.u;
1365     }
1366   }
1367 
1368   template <typename T>
1369   struct TagType {};
1370 
1371   friend MaybeOwnedTraits<IValue>;
1372 
1373   Payload payload;
1374   Tag tag{IValue::Tag::None};
1375   friend struct WeakIValue;
1376 };
1377 
1378 struct TORCH_API WeakIValue final {
1379   WeakIValue() = default;
1380 
1381   WeakIValue(const WeakIValue& rhs)
1382       : payload(rhs.payload),
1383         tag(rhs.tag),
1384         is_intrusive_ptr(rhs.is_intrusive_ptr) {
1385     if (is_intrusive_ptr &&
1386         payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
1387       c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
1388     }
1389   }
1390   WeakIValue(const IValue& rhs)
1391       : tag(rhs.tag), is_intrusive_ptr(rhs.isIntrusivePtrLegacyBehavior()) {
1392     if (rhs.isTensor()) {
1393       payload.as_intrusive_ptr = rhs.unsafeToTensorImpl();
1394       is_intrusive_ptr = true;
1395     } else {
1396       payload = rhs.payload.u;
1397     }
1398     if (is_intrusive_ptr) {
1399       if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
1400         c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
1401       }
1402     }
1403   }
1404   WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() {
1405     swap(rhs);
1406   }
1407   ~WeakIValue() {
1408     if (is_intrusive_ptr &&
1409         payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
1410       c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr);
1411     }
1412   }
1413   WeakIValue& operator=(WeakIValue&& rhs) & noexcept {
1414     WeakIValue(std::move(rhs)).swap(*this); // this also sets rhs to None
1415     return *this;
1416   }
1417   WeakIValue& operator=(WeakIValue const& rhs) & {
1418     WeakIValue(rhs).swap(*this);
1419     return *this;
1420   }
1421   void swap(WeakIValue& rhs) noexcept {
1422     std::swap(payload, rhs.payload);
1423     std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr);
1424     std::swap(tag, rhs.tag);
1425   }
1426 
1427   bool isSameIdentity(const WeakIValue& rhs) const {
1428     return payload.as_int == rhs.payload.as_int && tag == rhs.tag &&
1429         is_intrusive_ptr == rhs.is_intrusive_ptr;
1430   }
1431 
1432   IValue lock() const {
1433     if (!is_intrusive_ptr) {
1434       IValue::Payload newPayload;
1435       newPayload.u = payload;
1436       return IValue(newPayload, tag);
1437     }
1438     if (IValue::Tag::Tensor == tag) {
1439       auto temp =
1440           c10::weak_intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl>::
1441               reclaim(static_cast<at::TensorImpl*>(payload.as_intrusive_ptr));
1442       c10::intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl> ip(
1443           temp.lock());
1444       temp.release();
1445       if (!ip) {
1446         return IValue();
1447       } else {
1448         return IValue(at::Tensor(std::move(ip)));
1449       }
1450     } else {
1451       auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
1452           payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
1453               ? nullptr
1454               : payload.as_intrusive_ptr);
1455       IValue::Payload pl;
1456       pl.u.as_intrusive_ptr = temp.lock().release();
1457       temp.release();
1458       if (!pl.u.as_intrusive_ptr) {
1459         return IValue();
1460       } else {
1461         return IValue(pl, tag);
1462       }
1463     }
1464   }
1465 
1466   size_t use_count() const noexcept {
1467     if (!is_intrusive_ptr) {
1468       return 1;
1469     }
1470     auto temp = c10::weak_intrusive_ptr<
1471         c10::intrusive_ptr_target,
1472         c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr);
1473     size_t result = temp.use_count();
1474     temp.release();
1475     return result;
1476   }
1477 
1478   size_t weak_use_count() const noexcept {
1479     if (!is_intrusive_ptr) {
1480       return 1;
1481     }
1482     auto temp = c10::weak_intrusive_ptr<
1483         c10::intrusive_ptr_target,
1484         c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr);
1485     size_t result = temp.weak_use_count();
1486     temp.release();
1487     return result;
1488   }
1489   size_t hash() const {
1490     return payload.as_int;
1491   }
1492 
1493  private:
1494   using Payload = IValue::Payload::TriviallyCopyablePayload;
1495   Payload payload;
1496   IValue::Tag tag{IValue::Tag::None};
1497   bool is_intrusive_ptr{false};
1498 };
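// Sketch: WeakIValue does not keep its target alive; lock() recovers a strong
// IValue only while some strong reference still exists (assuming torch::ones
// from the ATen/torch headers):
//
//   torch::IValue strong(torch::ones({2}));
//   c10::WeakIValue weak(strong);
//   bool alive = !weak.lock().isNone();   // true: a strong reference exists
//   strong = torch::IValue();             // drop the last strong reference
//   bool freed = weak.lock().isNone();    // true: the tensor has been freed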
1499 
1500 // An owning pointer to a type. When the type is a class type, it requires a pair
1501 // of shared_ptrs to the class type and its owning CU, so that the class type is
1502 // guaranteed to stay alive as long as we hold this object.
1503 struct TORCH_API StrongTypePtr {
1504   StrongTypePtr(std::shared_ptr<torch::jit::CompilationUnit> cu, TypePtr type);
1505 
1506   std::shared_ptr<torch::jit::CompilationUnit> cu_;
1507   TypePtr type_;
1508 };
1509 
1510 // [Constant Object Weak CompilationUnit Reference]
1511 // A non-owning pointer to a type. When a class gets inserted as a constant
1512 // into a graph, using a strong pointer would create a circular reference
1513 // from Object -> CompilationUnit and CompilationUnit -> Graph (which owns the
1514 // constant Object).
1515 struct TORCH_API WeakTypePtr {
1516   WeakTypePtr(std::weak_ptr<torch::jit::CompilationUnit> cu, TypePtr type);
1517 
1518   std::weak_ptr<torch::jit::CompilationUnit> cu_;
1519   TypePtr type_;
1520 };
1521 
1522 // internal build errors with std::variant :/
1523 struct WeakOrStrongCompilationUnit {
1524   explicit WeakOrStrongCompilationUnit(
1525       std::shared_ptr<torch::jit::CompilationUnit> shared_cu)
1526       : strong_ptr_(std::move(shared_cu)), weak_ptr_(std::nullopt) {}
1527 
1528   explicit WeakOrStrongCompilationUnit(
1529       std::weak_ptr<torch::jit::CompilationUnit> weak_cu)
1530       : strong_ptr_(std::nullopt), weak_ptr_(std::move(weak_cu)) {}
1531 
1532   std::shared_ptr<torch::jit::CompilationUnit> getStrongRefOrThrow() const {
1533     TORCH_INTERNAL_ASSERT(strong_ptr_ != std::nullopt);
1534     return *strong_ptr_;
1535   }
1536 
1537   std::weak_ptr<torch::jit::CompilationUnit> getWeakRefOrThrow() const {
1538     TORCH_INTERNAL_ASSERT(weak_ptr_ != std::nullopt);
1539     return *weak_ptr_;
1540   }
1541 
1542   bool holdingStrongRef() const {
1543     return strong_ptr_ != std::nullopt;
1544   }
1545 
1546   bool holdingEmptyStrongRef() const {
1547     return holdingStrongRef() && *strong_ptr_ == nullptr;
1548   }
1549 
1550   std::optional<std::shared_ptr<torch::jit::CompilationUnit>> strong_ptr_;
1551   std::optional<std::weak_ptr<torch::jit::CompilationUnit>> weak_ptr_;
1552 };
1553 
1554 // An Object will hold a non-owning Compilation Unit reference if it is a
1555 // Constant in the graph, and an owning reference otherwise.
1556 struct TORCH_API WeakOrStrongTypePtr {
1557   explicit WeakOrStrongTypePtr(WeakTypePtr weak)
1558       : cu_(WeakOrStrongCompilationUnit(std::move(weak.cu_))),
1559         type_(std::move(weak.type_)) {}
1560   explicit WeakOrStrongTypePtr(StrongTypePtr strong)
1561       : cu_(WeakOrStrongCompilationUnit(std::move(strong.cu_))),
1562         type_(std::move(strong.type_)) {}
1563   explicit WeakOrStrongTypePtr(WeakOrStrongCompilationUnit cu, TypePtr type)
1564       : cu_(std::move(cu)), type_(std::move(type)) {}
1565   WeakTypePtr asWeakTypePtr() const;
1566 
1567   WeakOrStrongCompilationUnit cu_;
1568   TypePtr type_;
1569 
1570   bool holds_strong_ref() const {
1571     return cu_.holdingStrongRef();
1572   }
1573 
1574   bool holds_empty_strong_ref() const {
1575     return cu_.holdingEmptyStrongRef();
1576   }
1577 };
1578 
1579 } // namespace c10
1580 
1581 #include <ATen/core/ivalue_inl.h> // IWYU pragma: keep
1582