#pragma once

#include <ATen/core/DimVector.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/blob.h>
#include <ATen/core/custom_class.h>
#include <ATen/core/ivalue_to.h>
#include <ATen/core/jit_type_base.h>
#include <ATen/core/type_factory.h>
#include <c10/core/SymBool.h>
#include <c10/core/SymFloat.h>
#include <c10/macros/Export.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/intrusive_ptr.h>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace torch {
class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
namespace jit {
using ::torch::CustomClassHolder;
struct Function;
struct CompilationUnit;
struct Module;
} // namespace jit
} // namespace torch
namespace c10 {
template <class Key, class Value>
class Dict;
template <class T>
class List;
template <class T>
class IListRef;
struct IValue;
struct ClassType;
struct Type;
class RRefInterface;

using ClassTypePtr = std::shared_ptr<ClassType>;

TORCH_API bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs);

TORCH_API torch::jit::Function* checkObjectSortSchema(
    const c10::ClassTypePtr& t,
    std::stringstream& why_not);

// A comparator that checks ordering of two IValues of the same type.
typedef std::function<bool(const IValue& a, const IValue& b)> IValueComparator;

TORCH_API IValueComparator getLessThanComparator(const IValue& v);
TORCH_API IValueComparator getGreaterThanComparator(const IValue& v);

namespace ivalue {
struct Tuple;
struct Future;
struct Await;
struct ConstantString;
struct GenericDict;
struct Object;
struct PyObjectHolder;
struct EnumHolder;

// We need a ComplexHolder because currently the payloads in the Union
// only take 64 bits. Since ComplexDouble takes up 128 bits, and is too big
// to fit in the IValue directly, we indirect complex numbers through an
// intrusive pointer to ComplexHolder (which contains a c10::complex).
struct ComplexHolder : c10::intrusive_ptr_target {
 public:
  template <typename T>
  ComplexHolder(c10::complex<T> c) {
    val = convert<decltype(val), c10::complex<T>>(c);
  }
  ComplexHolder() = default;
  c10::complex<double> val;
};

// Similar to ComplexHolder, for StreamData3
struct StreamData3Holder : c10::intrusive_ptr_target {
 public:
  StreamData3Holder(struct c10::StreamData3 d) : val(d) {}
  StreamData3Holder() = delete;
  struct c10::StreamData3 val;
};

} // namespace ivalue

// This is an owning wrapper for a std::optional<std::vector<T>>
// that can be implicitly converted to a (non-owning) std::optional<ArrayRef<T>>.
// Its purpose is to be used in generated code to keep the vector alive
// either until the end of a statement (as a temporary), or as a saved arg
// in autograd.
template <typename T>
struct OptionalArray {
  std::optional<std::vector<T>> list;

  OptionalArray() = default;
  OptionalArray(std::vector<T> val) : list(std::move(val)) {}

  // Used when saving an argument for the backwards pass.
  OptionalArray& operator=(std::optional<ArrayRef<T>> ref) {
    if (ref) {
      list = std::vector<T>(ref->begin(), ref->end());
    } else {
      list = std::nullopt;
    }
    return *this;
  }

  // Used when saving an argument for the backwards pass.
  OptionalArray& operator=(c10::OptionalArrayRef<T> ref) {
    if (ref) {
      list = std::vector<T>(ref->begin(), ref->end());
    } else {
      list = std::nullopt;
    }
    return *this;
  }

  operator std::optional<c10::ArrayRef<T>>() {
    if (!list) {
      return std::nullopt;
    }
    return *list;
  }

  operator c10::OptionalArrayRef<T>() {
    if (!list) {
      return std::nullopt;
    }
    return *list;
  }
};
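
// Illustrative usage sketch (comment only, not part of the API surface): the
// wrapper owns the vector, while callees receive a non-owning view through the
// implicit conversion operator above.
//
//   OptionalArray<int64_t> saved(std::vector<int64_t>{2, 3});
//   std::optional<c10::ArrayRef<int64_t>> view = saved;  // non-owning view
//   saved = std::optional<c10::ArrayRef<int64_t>>();     // clears the storage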

// Capsule is an internal implementation detail of custom C++ classes. We
// define it as an owning wrapper for
// c10::intrusive_ptr<torch::CustomClassHolder>. This wrapper is here to serve
// as an abstraction of the type-erased custom class object pointer. It also
// allows pybind11 to treat this as a standalone class to register as a
// separate type caster, instead of as a custom pointer holder, which the
// pointer-holder type caster would try to "unwrap" automatically.
struct Capsule {
  c10::intrusive_ptr<torch::CustomClassHolder> obj_ptr;
  explicit Capsule(c10::intrusive_ptr<torch::CustomClassHolder> ptr)
      : obj_ptr(std::move(ptr)) {}
};

// IValue is the generic tagged union used by the interpreter to hold
// all value types.
// It is a 16-byte object with an 8-byte payload and an 8-byte tag.
// The tag is currently 4 bytes to determine the type, and 1 byte
// to mark whether that type is a subtype of c10::intrusive_ptr_target and
// needs retain/release calls.

#define TORCH_FORALL_TAGS(_) \
  _(None)                    \
  _(Tensor)                  \
  _(Storage)                 \
  _(Double)                  \
  _(ComplexDouble)           \
  _(Int)                     \
  _(SymInt)                  \
  _(SymFloat)                \
  _(SymBool)                 \
  _(Bool)                    \
  _(Tuple)                   \
  _(String)                  \
  _(Blob)                    \
  _(GenericList)             \
  _(GenericDict)             \
  _(Future)                  \
  _(Await)                   \
  _(Device)                  \
  _(Stream)                  \
  _(Object)                  \
  _(PyObject)                \
  _(Uninitialized)           \
  _(Capsule)                 \
  _(RRef)                    \
  _(Quantizer)               \
  _(Generator)               \
  _(Enum)

// [doxygen private]
// These methods are not actually private but we don't want to document them,
// so they are marked `@private`, which hides them in the doxygen documentation
// for this page.

/// IValue (Interpreter Value) is a tagged union over the types
/// supported by the TorchScript interpreter. IValues contain their
/// values as an `IValue::Payload`, which holds primitive types
/// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values,
/// and all other types as a `c10::intrusive_ptr`. In order to
/// optimize performance of the destructor and related operations by
/// making the `Tensor` and `c10::intrusive_ptr` paths generate the
/// same code, we represent a null `c10::intrusive_ptr` as
/// `UndefinedTensorImpl::singleton()`, *not* `nullptr`.
///
/// IValues are used as inputs to and outputs from the TorchScript interpreter.
/// To retrieve the value contained within an IValue, use the `.toX()` methods,
/// where `X` is the type you are trying to get. Note that neither the `.toX()`
/// methods nor the templated `.to<T>` functions do any kind of casting; they
/// only unwrap the contained value. For example:
///
/// \rst
/// .. code-block:: cpp
///
///   // Make the IValue
///   torch::IValue my_ivalue(26);
///   std::cout << my_ivalue << "\n";
///
///   // Unwrap the IValue
///   int64_t my_int = my_ivalue.toInt();
///   std::cout << my_int << "\n";
///
///   // This will throw an error!
///   // `my_ivalue` is tagged as an int and cannot be used as another type
///   torch::Tensor my_tensor = my_ivalue.toTensor();
/// \endrst
struct TORCH_API IValue final {
  IValue(const IValue& rhs) : IValue(rhs.payload, rhs.tag) {
    if (isIntrusivePtr() &&
        payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
      c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
    }
  }

  IValue(IValue&& rhs) noexcept : tag(rhs.tag) {
    moveFrom(std::move(rhs));
  }

  /// @private [doxygen private]
  ~IValue() {
    destroy();
  }

  C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept {
    if (&rhs == this) {
      return *this;
    }

    destroy();
    moveFrom(std::move(rhs));
    return *this;
  }

  IValue& operator=(IValue const& rhs) & {
    *this = IValue(rhs);
    return *this;
  }

  void dump() const;

  /**
   * Equality comparison. The semantics are the same as Python's `==`:
   * 1. Numerical types are compared by value.
   * 2. Tensors compute element-wise equality, returning a BoolTensor (see:
   *    `torch.eq()`)
   * 3. Strings are compared by value.
   * 4. Sequence types (list, tuple) are compared lexicographically by
   *    comparing their elements. Different sequence types never compare
   *    equal.
   * 5. Mappings (dict) must have equal (key, value) pairs.
   * 6. If not listed above, the default behavior is to test identity
   *    equality (e.g. pointer equality).
   *
   * Why does this return an IValue instead of a bool? Because in PyTorch,
   * `tensor1 == tensor2` returns a `BoolTensor`, not a bool.
   *
   * NOTE: we (like Python) assume that identity equality implies value
   * equality for efficiency.
   * TODO: need to support customizing equality
   */
  IValue equals(const IValue& rhs) const;
  /**
   * This implements the same semantics as `bool(lhs == rhs)` in Python, which
   * is the same as `equals()` except for Tensor types.
   */
  TORCH_API friend bool operator==(const IValue& lhs, const IValue& rhs);
  TORCH_API friend bool operator!=(const IValue& lhs, const IValue& rhs);

  /**
   * Identity comparison. Checks if `this` is the same object as `rhs`. The
   * semantics are the same as Python's `is` operator.
   *
   * NOTE: Like in Python, this operation is poorly defined for primitive
   * types like numbers and strings. Prefer to use `==` unless you really
   * want to check identity equality.
   */
  bool is(const IValue& rhs) const;
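
  // Illustrative sketch of the comparison flavours documented above (comment
  // only; the results simply restate the documented semantics):
  //
  //   IValue a(1), b(1);
  //   bool by_value = (a == b);    // true: numeric types compare by value
  //   IValue boxed = a.equals(b);  // an IValue wrapping `true`
  //   bool same_obj = a.is(b);     // identity check; ill-defined for ints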

  /**
   * Hashing for IValues. Returns an IValue-boxed int.
   *
   * Some notes:
   * - Like eager, Tensors are hashed by looking at the pointer. This is not
   *   strictly correct because two value-equal tensors with different tensor
   *   pointers will hash differently, but we choose to reproduce the eager
   *   semantics.
   * - Hashing is not defined on all built-in IValue types (e.g. list and
   *   dict), following Python. Calling `hash()` on these types will throw.
   */
  IValue hash() const {
    return (int64_t)IValue::hash(*this);
  }
  // This is defined because `c10::hash` dispatches to a function of this
  // signature. See the member function `hash()`.
  static size_t hash(const IValue& iv);

  /**
   * @private [doxygen private]
   * [container equality]
   * This is an equality implementation that assumes objects with the same
   * identity equal themselves, for efficiency reasons. We primarily have this
   * for consistency, because Python does the same thing. This actually
   * provokes user-visible changes in behavior due to quirks in torch:
   *   [tensor1] == [tensor1] -> True (because container equality will first
   *     compare identity)
   *   [tensor1] == [tensor1_copy] -> RuntimeError:
   *     Boolean value of Tensor with more than one value is ambiguous
   */
  TORCH_API friend bool _fastEqualsForContainer(
      const IValue& lhs,
      const IValue& rhs);

 private:
  static bool isAliasOf(const at::Tensor& a, const at::Tensor& b) {
    if (a.is_sparse()) {
      return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
    }
    if (b.is_sparse()) {
      return isAliasOf(a, b._values()) || isAliasOf(a, b._indices());
    }
    if (a.is_sparse_csr()) {
      return isAliasOf(a.values(), b) || isAliasOf(a.crow_indices(), b) ||
          isAliasOf(a.col_indices(), b);
    }
    if (b.is_sparse_csr()) {
      return isAliasOf(a, b.values()) || isAliasOf(a, b.crow_indices()) ||
          isAliasOf(a, b.col_indices());
    }

    // Opaque tensors such as the ones constructed by the MKL-DNN backend
    // don't have storage so we just compare their TensorImpls.
    // TODO: Find way to expose alias info for opaque tensors.
    if (!a.has_storage() || !b.has_storage()) {
      return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl();
    }

    return a.is_alias_of(b);
  }

  template <typename T>
  bool isListOf() const;

 public:
  /// @private [doxygen private]
  bool isAliasOf(const IValue& rhs) const {
    if (this->tag != rhs.tag) {
      // Trivially don't alias if the type is different
      return false;
    }

    // Tensors should be compared based on internal storage
    if (this->isTensor()) {
      return isAliasOf(this->toTensor(), rhs.toTensor());
    }

    if (!isIntrusivePtr()) {
      // Primitive types don't alias anything
      return false;
    }

    AT_ASSERT(rhs.isIntrusivePtr());

    // Other types can be compared by their ptr value
    return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
  }

  /// @private [doxygen private]
  size_t use_count() const noexcept {
    if (isTensor()) {
      return payload.as_tensor.use_count();
    }

    if (!isIntrusivePtrLegacyBehavior()) {
      return 1;
    }

    if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
      return 0;
    }
    return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr);
  }

  /// @private [doxygen private]
  void swap(IValue& rhs) noexcept {
    if (isTensor() && rhs.isTensor()) {
      std::swap(payload.as_tensor, rhs.payload.as_tensor);
    } else if (isTensor()) {
      at::Tensor t = std::move(payload.as_tensor);
      // As far as I can tell, omitting the usual explicit destructor call
      // is not UB in and of itself, and it's a slight perf win. The
      // destructor is a no-op, because the moved-from Tensor is
      // effectively an intrusive_ptr in the null state, so we don't need
      // the behavior for correctness reasons either. Leaving this
      // explanatory comment, including commented-out destructor call, to
      // make this abundantly clear.
      //
      // payload.as_tensor.~Tensor();
      payload.u = rhs.payload.u;
      new (&rhs.payload.as_tensor) at::Tensor(std::move(t));
    } else if (rhs.isTensor()) {
      rhs.swap(*this);
      return;
    } else {
      std::swap(payload.u, rhs.payload.u);
    }
    std::swap(tag, rhs.tag);
  }

  // Accessors for subtypes are arranged together below.
  // While some of these accessors could be generated through templates,
  // we prefer to write them manually for clarity.

  IValue(at::TensorBase t) : tag(Tag::Tensor) {
    new (&payload.as_tensor) at::Tensor(std::move(t));
  }
  bool isTensor() const {
    return Tag::Tensor == tag;
  }

 private:
  // Outlined error path so that toTensor() can be inlined.
  [[noreturn]] void reportToTensorTypeError() const;

 public:
  at::Tensor toTensor() &&;
  at::Tensor& toTensor() &;
  const at::Tensor& toTensor() const&;
  at::TensorImpl* unsafeToTensorImpl() const {
    TORCH_INTERNAL_ASSERT(isTensor());
    return payload.as_tensor.unsafeGetTensorImpl();
  }

  IValue(at::Storage s) : tag(Tag::Storage) {
    payload.u.as_intrusive_ptr =
        null_to_undefined_tensor(s.unsafeReleaseStorageImpl());
  }
  bool isStorage() const {
    return Tag::Storage == tag;
  }
  c10::Storage toStorage() &&;
  c10::Storage toStorage() const&;

  const IValue& toIValue() const {
    return *this;
  }
  IValue& toIValue() {
    return *this;
  }

  /// @private [doxygen private]
  IValue(intrusive_ptr<caffe2::Blob> blob) : tag(Tag::Blob) {
    // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
    // and store it as a Tensor instead.
    payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
  }

  /// @private [doxygen private]
  bool isBlob() const {
    return Tag::Blob == tag;
  }

  /// @private [doxygen private]
  c10::intrusive_ptr<caffe2::Blob> toBlob() &&;

  /// @private [doxygen private]
  c10::intrusive_ptr<caffe2::Blob> toBlob() const&;

  // Capsule. No new callsites of these APIs should
  // be introduced.
  static inline IValue make_capsule(
      intrusive_ptr<torch::CustomClassHolder> blob);
  bool isCapsule() const {
    return Tag::Capsule == tag;
  }
  c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() &&;
  c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() const&;

  // Custom C++ classes
  template <
      typename T,
      std::enable_if_t<std::is_base_of_v<torch::CustomClassHolder, T>, int> = 0>
  IValue(intrusive_ptr<T> custom_class);
  bool isCustomClass() const;
  template <typename T>
  c10::intrusive_ptr<T> toCustomClass() &&;
  template <typename T>
  c10::intrusive_ptr<T> toCustomClass() const&;

  // Tuple
  IValue(c10::intrusive_ptr<ivalue::Tuple> v);

  template <
      typename... Args,
      std::enable_if_t<
          !std::disjunction_v<
              std::is_lvalue_reference<Args>...,
              std::negation<std::is_constructible<IValue, Args>>...>,
          std::nullptr_t> = nullptr>
  IValue(const std::tuple<Args...>& t);
  template <
      typename... Args,
      std::enable_if_t<
          !std::disjunction_v<
              std::is_lvalue_reference<Args>...,
              std::negation<std::is_constructible<IValue, Args>>...>,
          std::nullptr_t> = nullptr>
  IValue(std::tuple<Args...>&& t);
  bool isTuple() const {
    return Tag::Tuple == tag;
  }
  c10::intrusive_ptr<ivalue::Tuple> toTuple() &&;
  c10::intrusive_ptr<ivalue::Tuple> toTuple() const&;
  C10_NODISCARD ivalue::Tuple& toTupleRef() const;

  // Double
  IValue(double d) : tag(Tag::Double) {
    payload.u.as_double = d;
  }
  bool isDouble() const {
    return Tag::Double == tag;
  }
  double toDouble() const {
    if (isDouble()) {
      return payload.u.as_double;
    } else if (isSymFloat()) {
      return toSymFloat().guard_float(__FILE__, __LINE__);
    } else {
      TORCH_INTERNAL_ASSERT(0, "expected double");
    }
  }

  // ComplexDouble
  template <typename T>
  IValue(c10::complex<T> c);
  bool isComplexDouble() const {
    return Tag::ComplexDouble == tag;
  }
  c10::complex<double> toComplexDouble() const;

  // Future
  IValue(c10::intrusive_ptr<ivalue::Future> v);
  bool isFuture() const {
    return Tag::Future == tag;
  }
  c10::intrusive_ptr<ivalue::Future> toFuture() &&;
  c10::intrusive_ptr<ivalue::Future> toFuture() const&;

  // Await
  IValue(c10::intrusive_ptr<ivalue::Await> v);
  bool isAwait() const {
    return Tag::Await == tag;
  }
  c10::intrusive_ptr<ivalue::Await> toAwait() &&;
  c10::intrusive_ptr<ivalue::Await> toAwait() const&;

  // RRef
  IValue(c10::intrusive_ptr<c10::RRefInterface> v);
  bool isRRef() const {
    return Tag::RRef == tag;
  }
  c10::intrusive_ptr<c10::RRefInterface> toRRef() &&;
  c10::intrusive_ptr<c10::RRefInterface> toRRef() const&;

  // Quantizer
  IValue(c10::intrusive_ptr<at::Quantizer> v);
  bool isQuantizer() const {
    return Tag::Quantizer == tag;
  }
  c10::intrusive_ptr<at::Quantizer> toQuantizer() &&;
  c10::intrusive_ptr<at::Quantizer> toQuantizer() const&;

  // Int
  IValue(int64_t i) : tag(Tag::Int) {
    payload.u.as_int = i;
  }

  IValue(const c10::SymInt& i) {
    if (auto mi = i.maybe_as_int()) {
      tag = Tag::Int;
      payload.u.as_int = *mi;
    } else {
      tag = Tag::SymInt;
      payload.u.as_intrusive_ptr = i.toSymNode().release();
    }
  }

  bool isSymInt() const {
    return Tag::SymInt == tag;
  }

  c10::SymInt toSymInt() &&;
  c10::SymInt toSymInt() const&;

  IValue(const c10::SymFloat& i) {
    if (i.is_symbolic()) {
      tag = Tag::SymFloat;
      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
    } else {
      tag = Tag::Double;
      payload.u.as_double = i.as_float_unchecked();
    }
  }

  bool isSymFloat() const {
    return Tag::SymFloat == tag;
  }

  c10::SymFloat toSymFloat() &&;
  c10::SymFloat toSymFloat() const&;

  IValue(const c10::SymBool& i) {
    if (auto mi = i.maybe_as_bool()) {
      tag = Tag::Bool;
      payload.u.as_int = *mi;
    } else {
      tag = Tag::SymBool;
      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
    }
  }

  bool isSymBool() const {
    return Tag::SymBool == tag;
  }

  c10::SymBool toSymBool() &&;
  c10::SymBool toSymBool() const&;
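
  // Illustrative note (comment only): when a SymInt/SymFloat/SymBool is backed
  // by a concrete value, the constructors above collapse it to the plain
  // Int/Double/Bool tag, e.g.:
  //
  //   IValue iv(c10::SymInt(5));
  //   // iv.isInt() is true and iv.toInt() == 5; iv.isSymInt() is false.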

  // Allow you to pass literals (3, 4) without ambiguity.
  IValue(int32_t i) : IValue(static_cast<int64_t>(i)) {}

  bool isInt() const {
    return Tag::Int == tag;
  }

  int64_t toInt() const {
    if (isInt()) {
      return payload.u.as_int;
    } else if (isSymInt()) {
      return toSymInt().guard_int(__FILE__, __LINE__);
    } else {
      TORCH_INTERNAL_ASSERT(0, "expected int");
    }
  }

  // Bool
  IValue(bool b) : tag(Tag::Bool) {
#if defined(__clang__) && defined(__x86_64__)
    // Initializing the entire payload stops valgrind from reporting
    // "jump or move depends on uninitialised value" in the IValue copy
    // constructor. See https://github.com/pytorch/pytorch/issues/37117
    payload.u.as_int = b;
#else
    payload.u.as_bool = b;
#endif
  }
  bool isBool() const {
    return Tag::Bool == tag;
  }
  bool toBool() const {
    if (isBool()) {
      return payload.u.as_bool;
    } else if (isSymBool()) {
      return toSymBool().guard_bool(__FILE__, __LINE__);
    } else {
      TORCH_INTERNAL_ASSERT(0, "expected bool");
    }
  }

  // IntList
  bool isIntList() const;
  bool isSymIntList() const;
  c10::List<int64_t> toIntList() &&;
  c10::List<int64_t> toIntList() const&;
  std::vector<int64_t> toIntVector() const;
  std::vector<c10::SymInt> toSymIntVector() const;
  at::DimVector toDimVector() const;

  // ConstantString
  IValue(c10::intrusive_ptr<ivalue::ConstantString> v);
  IValue(std::string v);
  IValue(const char* v) : IValue(std::string(v)) {}
  IValue(c10::string_view v) : IValue(std::string(v)) {}
  bool isString() const {
    return Tag::String == tag;
  }
  c10::intrusive_ptr<ivalue::ConstantString> toString() &&;
  c10::intrusive_ptr<ivalue::ConstantString> toString() const&;
  const std::string& toStringRef() const;
  std::optional<std::reference_wrapper<const std::string>> toOptionalStringRef()
      const;
  c10::string_view toStringView() const;

  // DoubleList
  bool isDoubleList() const;
  c10::List<double> toDoubleList() &&;
  c10::List<double> toDoubleList() const&;
  std::vector<double> toDoubleVector() const;

  // ComplexDoubleList
  bool isComplexDoubleList() const;
  c10::List<c10::complex<double>> toComplexDoubleList() &&;
  c10::List<c10::complex<double>> toComplexDoubleList() const&;
  std::vector<c10::complex<double>> toComplexDoubleVector() const;

  // BoolList
  bool isBoolList() const;
  c10::List<bool> toBoolList() &&;
  c10::List<bool> toBoolList() const&;

  // TensorList
  bool isTensorList() const;
  c10::List<at::Tensor> toTensorList() &&;
  c10::List<at::Tensor> toTensorList() const&;
  std::vector<at::Tensor> toTensorVector() const;

  // OptionalTensorList
  bool isOptionalTensorList() const;
  c10::List<std::optional<at::Tensor>> toOptionalTensorList() &&;
  c10::List<std::optional<at::Tensor>> toOptionalTensorList() const&;
  std::vector<std::optional<at::Tensor>> toOptionalTensorVector() const;

  // GenericList
  IValue(c10::List<IValue> v);
  bool isList() const {
    return Tag::GenericList == tag;
  }
  c10::List<IValue> toList() &&;
  c10::List<IValue> toList() const&;
  c10::ArrayRef<IValue> toListRef() const;

  // Some template constructors of IValue call another constructor
  // recursively. This SFINAE check ensures the called constructor exists.
  template <class T>
  using enable_if_ivalue_constructible =
      std::enable_if_t<std::is_constructible_v<IValue, T>, std::nullptr_t>;

  // The rule for lists is more complicated; the generic constructor is only
  // acceptable if your element isn't SymInt. If you do have a SymInt element,
  // then you must also, at construction time, check if you can decay the list
  // into an int list (this is MANDATORY, as at a use site we may expect
  // toIntList to work even if at the call site you had a SymIntArrayRef
  // argument). In practice, only SymIntArrayRef is used this way, so we
  // didn't bother making it work for the other constructors; we just make
  // sure they're not selectable. See the illustrative note after these
  // constructors.
  template <class T>
  using enable_if_list_is_ivalue_constructible = std::enable_if_t<
      std::is_constructible_v<IValue, T> && !std::is_same_v<T, c10::SymInt>,
      std::nullptr_t>;

  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
  IValue(c10::List<T>&& v);
  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
  IValue(const c10::List<T>& v);
  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
  IValue(at::ArrayRef<T> v);
  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
  IValue(const std::vector<T>& v);
  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
  IValue(std::vector<T>&& v);
  template <class T, size_t N>
  IValue(std::array<T, N> v);

  // Manual constructors for lists of symints, which decay to an int list if
  // possible. To avoid ambiguous overload situations, we template them
  // to prevent implicit conversions.
  template <class T>
  using enable_if_symint =
      std::enable_if_t<std::is_same_v<T, c10::SymInt>, std::nullptr_t>;

  template <class T, enable_if_symint<T> = nullptr>
  IValue(at::ArrayRef<T> v);
  template <class T, enable_if_symint<T> = nullptr>
  IValue(at::OptionalArrayRef<T> v);
  template <class T, enable_if_symint<T> = nullptr>
  IValue(const std::vector<T>& v);
  template <class T, enable_if_symint<T> = nullptr>
  IValue(std::vector<T>&& v);

  template <class T>
  using enable_if_ilist_is_ivalue_constructible = std::enable_if_t<
      std::is_constructible_v<IValue, T> &&
          std::is_constructible_v<IValue, typename IListRef<T>::boxed_type> &&
          !std::is_same_v<T, c10::SymInt>,
      std::nullptr_t>;

  template <class T, enable_if_ilist_is_ivalue_constructible<T> = nullptr>
  IValue(c10::IListRef<T> v);
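
  // Illustrative sketch (comment only) of the decay rule described above: a
  // SymInt list whose elements are all backed by concrete integers is stored
  // as a plain int list, so toIntList() is expected to work at the use site.
  //
  //   std::vector<c10::SymInt> sizes = {c10::SymInt(2), c10::SymInt(3)};
  //   IValue iv(sizes);           // all hints are concrete -> decays
  //   auto ints = iv.toIntList(); // c10::List<int64_t> holding {2, 3}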

  // GenericDict
  IValue(c10::Dict<IValue, IValue> v);
  bool isGenericDict() const {
    return Tag::GenericDict == tag;
  }
  c10::Dict<IValue, IValue> toGenericDict() &&;
  c10::Dict<IValue, IValue> toGenericDict() const&;

  template <class Key, class Value>
  IValue(c10::Dict<Key, Value> v);

  template <class Key, class Value>
  /// \cond
  /// DOXYGEN_CANNOT_HANDLE_CONSTRUCTORS_WITH_MACROS_SO_EXCLUDE_THIS_LINE_FROM_DOXYGEN
  C10_DEPRECATED_MESSAGE(
      "IValues based on std::unordered_map<K, V> are slow and deprecated. Please use c10::Dict<K, V> instead.")
  /// \endcond
  IValue(std::unordered_map<Key, Value> v);

  template <class T, enable_if_ivalue_constructible<T> = nullptr>
  IValue(std::optional<T> v);
  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
  IValue(c10::OptionalArrayRef<T> v);
  IValue(std::nullopt_t);

  // ClassType
  IValue(c10::intrusive_ptr<ivalue::Object> v);
  bool isObject() const {
    return tag == Tag::Object;
  }
  c10::intrusive_ptr<ivalue::Object> toObject() &&;
  c10::intrusive_ptr<ivalue::Object> toObject() const&;
  ivalue::Object& toObjectRef() const;

  torch::jit::Module toModule() const;
  bool isModule() const;

  // PyObject
  IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v);
  bool isPyObject() const {
    return tag == Tag::PyObject;
  }
  c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() &&;
  c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() const&;
  PyObject* toPyObject() const;

  // Enum
  explicit IValue(c10::intrusive_ptr<ivalue::EnumHolder> v);
  bool isEnum() const {
    return tag == Tag::Enum;
  }
  c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() &&;
  c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() const&;

  // None
  IValue() = default;
  bool isNone() const {
    return Tag::None == tag;
  }
  std::string toNone() const {
    AT_ASSERT(isNone());
    return "None";
  }

  static IValue uninitialized() {
    auto i = IValue();
    i.tag = Tag::Uninitialized;
    return i;
  }

  // Scalar, which gets encoded as an Int, Double, ComplexDouble, Bool, or one
  // of the Sym* tags depending on the value it holds.
  IValue(const at::Scalar& s) : IValue() {
    // NB: do the symbolic versions first, as isFloatingPoint is true
    // for both SymFloat and double
    if (s.isSymInt()) {
      tag = Tag::SymInt;
      payload.u.as_intrusive_ptr = s.toSymInt().toSymNode().release();
    } else if (s.isSymFloat()) {
      tag = Tag::SymFloat;
      payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
    } else if (s.isSymBool()) {
      tag = Tag::SymBool;
      payload.u.as_intrusive_ptr = s.toSymBool().toSymNodeImpl().release();
    } else if (s.isFloatingPoint()) {
      tag = Tag::Double;
      payload.u.as_double = s.toDouble();
    } else if (s.isComplex()) {
      *this = s.toComplexDouble();
    } else if (s.isBoolean()) {
      tag = Tag::Bool;
      payload.u.as_bool = s.toBool();
    } else {
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          s.isIntegral(false), "Unknown type in Scalar");
      tag = Tag::Int;
      payload.u.as_int = s.toLong();
    }
  }

  bool isScalar() const {
    return isDouble() || isInt() || isComplexDouble() || isBool() ||
        isSymInt() || isSymFloat() || isSymBool();
  }

  at::Scalar toScalar() const {
    if (isDouble())
      return toDouble();
    else if (isInt())
      return toInt();
    else if (isComplexDouble())
      return toComplexDouble();
    else if (isBool())
      return toBool();
    else if (isSymInt())
      return toSymInt();
    else if (isSymFloat())
      return toSymFloat();
    else if (isSymBool())
      return toSymBool();
    throw std::runtime_error("IValue is not a Scalar");
  }
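
  // Illustrative note (comment only): a Scalar round-trips through whichever
  // tag the constructor above picks for its value, e.g.
  //
  //   IValue iv(at::Scalar(2.5));
  //   // iv.isDouble() and iv.isScalar() are both true
  //   at::Scalar s = iv.toScalar(); // holds 2.5 again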

  // Device
  IValue(c10::Device d) : tag(Tag::Device) {
    payload.u.as_device.type = d.type();
    payload.u.as_device.index = d.index();
  }
  bool isDevice() const {
    return Tag::Device == tag;
  }
  c10::Device toDevice() const {
    AT_ASSERT(isDevice());
    return c10::Device(payload.u.as_device.type, payload.u.as_device.index);
  }

  // Stream
  IValue(c10::Stream s) : tag(Tag::Stream) {
    auto v = c10::make_intrusive<ivalue::StreamData3Holder>(s.pack3());
    payload.u.as_intrusive_ptr = v.release();
  }
  c10::Stream toStream() &&;
  c10::Stream toStream() const&;
  bool isStream() const {
    return Tag::Stream == tag;
  }

  // ScalarType
  IValue(ScalarType t)
      : IValue(static_cast<std::underlying_type_t<ScalarType>>(t)) {}
  at::ScalarType toScalarType() const {
    return static_cast<at::ScalarType>(toInt());
  }

  // Layout
  IValue(Layout l) : IValue(static_cast<std::underlying_type_t<Layout>>(l)) {}
  at::Layout toLayout() const {
    return static_cast<at::Layout>(toInt());
  }

  // MemoryFormat
  IValue(MemoryFormat m)
      : IValue(static_cast<std::underlying_type_t<MemoryFormat>>(m)) {}
  at::MemoryFormat toMemoryFormat() const {
    return static_cast<at::MemoryFormat>(toInt());
  }
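
  // Illustrative note (comment only): the enum-like types above are boxed with
  // the plain Int tag, so there is no distinct isScalarType()/isLayout() query.
  //
  //   IValue iv(at::kFloat);
  //   // iv.isInt() is true; iv.toScalarType() == at::kFloat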

  // QScheme
  IValue(at::QScheme qscheme) : tag(Tag::Int) {
    payload.u.as_int = static_cast<int64_t>(qscheme);
  }

  at::QScheme toQScheme() const {
    return static_cast<at::QScheme>(toInt());
  }

  // Dimname
  IValue(at::Dimname dimname) : IValue(dimname.symbol().toQualString()) {}

  at::Dimname toDimname() const {
    return at::Dimname::fromSymbol(Symbol::fromQualString(toStringRef()));
  }

  // Generator
  IValue(at::Generator g) : tag(Tag::Generator) {
    payload.u.as_intrusive_ptr =
        null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl());
  }
  bool isGenerator() const {
    return Tag::Generator == tag;
  }
  at::Generator toGenerator() &&;
  at::Generator toGenerator() const&;

  // For debugging.
  std::string tagKind() const {
    switch (tag) {
#define DEFINE_CASE(x) \
  case Tag::x:         \
    return #x;
      TORCH_FORALL_TAGS(DEFINE_CASE)
#undef DEFINE_CASE
    }
    return "InvalidTag(" + std::to_string(static_cast<int>(tag)) + ")";
  }

  // Generic v.to<at::Tensor>() implementations
  // that can be used in special functions like pop/push
  // that use template meta-programming.
  // Prefer the directly named methods when you can,
  // since they are simpler to understand.

  // Note: if you get linker errors saying one of these is missing,
  // change it to ... && = delete; and you will see better error messages for
  // why. However, we cannot commit this because some compiler versions barf
  // on it.
  template <typename T>
  T to() &&;
  template <typename T>
  typename c10::detail::ivalue_to_const_ref_overload_return<T>::type to()
      const&;

  // ToOptional: convert an IValue to an Optional object that accepts both T
  // and None.
  template <typename T>
  std::optional<T> toOptional();
  template <typename T>
  std::optional<T> toOptional() const;

  /// @private [doxygen private]
  /// This is a shallow comparison of two IValues to test the object identity.
  bool isSameIdentity(const IValue& rhs) const;

  // Computes the "official" string representation of an IValue. This produces
  // a TorchScript expression that can be used to recreate an IValue with the
  // same value (e.g. when we are printing constants in the serializer).
  //
  // Callers can use `customFormatter` to override how `repr()` prints out an
  // IValue. This is useful if you have some other environment where you can
  // look up values, and you want to print a reference to that environment
  // (like the serializer's constant table).
  //
  // repr() is not necessarily defined on all objects!
  std::ostream& repr(
      std::ostream& stream,
      std::function<bool(std::ostream&, const IValue& v)> customFormatter)
      const;

  // Computes an "informal" string representation of an IValue. This should be
  // used for debugging, or servicing `print()`-like functions.
  // This is different from `repr()` in that there is no expectation that we
  // can exactly reconstruct an IValue from the output; feel free to use a
  // concise/pretty form.
  TORCH_API friend std::ostream& operator<<(std::ostream& out, const IValue& v);

  bool isPtrType() const {
    if (isTensor()) {
      return payload.as_tensor.defined();
    }
    return isIntrusivePtrLegacyBehavior();
  }

  /// @private [doxygen private]
  const void* internalToPointer() const {
    TORCH_INTERNAL_ASSERT(
        isPtrType(), "Can only call internalToPointer() for pointer types");
    if (isTensor()) {
      return payload.as_tensor.unsafeGetTensorImpl();
    } else {
      return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()
          ? payload.u.as_intrusive_ptr
          : nullptr;
    }
  }

  template <typename T = c10::PlatformType>
  TypePtr type() const;

  // Detect aliased tensors.
  struct HashAliasedIValue {
    size_t hashTensor(const at::Tensor& ten) const {
      if (ten.is_sparse()) {
        // COO sparse tensors have a "values" tensor and an "indices" tensor,
        // so this will detect overlap of sparse tensors that share a values
        // tensor, but not sparse tensors that share an indices tensor.
        return hashTensor(ten._values());
      } else if (ten.is_sparse_csr()) {
        // Sparse CSR tensors have a "values" tensor and crow/col indices, so,
        // analogously, this detects overlap of CSR tensors that share a values
        // tensor, but not CSR tensors that only share an indices tensor.
        return hashTensor(ten.values());
      } else if (!ten.has_storage()) {
        // Opaque tensors such as the ones constructed by the MKL-DNN backend
        // don't have storage so we just use their TensorImpls.
        // TODO: Find way to expose alias info for opaque tensors.
        return reinterpret_cast<size_t>(ten.unsafeGetTensorImpl());
      } else {
        return reinterpret_cast<size_t>(ten.storage().unsafeGetStorageImpl());
      }
    }
    size_t operator()(const IValue& val) const {
      if (val.isTensor()) {
        return hashTensor(val.toTensor());
      }
      // If it is not a Tensor, then two mutable IValues alias each other only
      // if they are the same pointer.
      return val.payload.u.as_int;
    }
  };

  struct CompAliasedIValues {
    bool operator()(const IValue& lhs, const IValue& rhs) const {
      return lhs.isAliasOf(rhs);
    }
  };

  using HashAliasedIValues =
      std::unordered_set<IValue, HashAliasedIValue, CompAliasedIValues>;
  using HashAliasedIValueMap =
      std::unordered_map<IValue, IValue, HashAliasedIValue, CompAliasedIValues>;

  struct HashIdentityIValue {
    size_t operator()(const IValue& val) const {
      return val.payload.u.as_int;
    }
  };

  struct CompIdentityIValues {
    bool operator()(const IValue& lhs, const IValue& rhs) const {
      return lhs.is(rhs);
    }
  };

  using HashIdentityIValues =
      std::unordered_set<IValue, HashIdentityIValue, CompIdentityIValues>;
  using HashIdentityIValueMap =
      std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>;

  // Checks if this and rhs have a subvalue in common.
  // [t1, t2] and [t2, t3] returns true.
  bool overlaps(const IValue& rhs) const;

  // Inserts all subvalues of this in subValues.
  void getSubValues(HashAliasedIValues& subValues) const;

  // Apply visitor to every subvalue.
  // TODO: There are several places that recurse over IValue. This is fragile.
  // This visitor should be used to recurse over ivalues.
  void visit(const std::function<bool(const IValue&)>& visitor) const;
  IValue deepcopy(std::optional<at::Device> device = std::nullopt) const;
  IValue deepcopy(
      HashIdentityIValueMap& memo,
      std::optional<at::Device> device = std::nullopt) const;

 private:
  static c10::intrusive_ptr_target* null_to_undefined_tensor(
      c10::intrusive_ptr_target* p) {
    return p ? p
             : static_cast<c10::intrusive_ptr_target*>(
                   c10::UndefinedTensorImpl::singleton());
  }

  static bool ptrEqual(const IValue& lhs, const IValue& rhs);
  // NOTE: IValue tags are intentionally private. In the future we may encode
  // this value differently (e.g. using NaN boxing), and this would make it
  // more costly to determine the tag for all types vs just determining if
  // something is a particular type. Instead we want clients to use the `isX`
  // methods when possible. If for perf. reasons you really, absolutely, must
  // have a jump table, then we can revisit this.
  enum class Tag : uint32_t {
#define DEFINE_TAG(x) x,
    TORCH_FORALL_TAGS(DEFINE_TAG)
#undef DEFINE_TAG
  };

#define COUNT_TAG(x) 1 +
  static constexpr auto kNumTags = TORCH_FORALL_TAGS(COUNT_TAG) 0;
#undef COUNT_TAG

  template <
      class T,
      class NullType = c10::detail::intrusive_target_default_null_type<T>>
  c10::intrusive_ptr<T, NullType> moveToIntrusivePtr();
  template <
      typename T,
      class NullType = c10::detail::intrusive_target_default_null_type<T>>
  c10::intrusive_ptr<T, NullType> toIntrusivePtr() const;

  void destroy() {
    // We carefully construct this call to both 1) avoid UB by using
    // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable
    // the compiler to generate the same code for each case. It is
    // surprisingly difficult to get this right.
    if (isTensor() || isIntrusivePtr()) {
      c10::intrusive_ptr_target* p = isTensor()
          ? payload.as_tensor.unsafeGetTensorImpl()
          : payload.u.as_intrusive_ptr;
      c10::intrusive_ptr<intrusive_ptr_target, c10::UndefinedTensorImpl>::
          reclaim(p);
      // No need to make this destructor call!
      // payload.as_tensor.~Tensor();
    }
  }

  // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
  C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept {
    if (rhs.isTensor()) {
      new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
      // As far as I can tell, omitting the usual explicit destructor call
      // is not UB in and of itself, and it's a slight perf win. The
      // destructor is a no-op, because the moved-from Tensor is
      // effectively an intrusive_ptr in the null state, so we don't need
      // the behavior for correctness reasons either. Leaving this
      // explanatory comment, including commented-out destructor call, to
      // make this abundantly clear.
      //
      // rhs.payload.as_tensor.~Tensor();
    } else {
      payload.u = rhs.payload.u;
    }
    tag = rhs.tag;
    rhs.clearToNone();
  }

  void clearToNone() noexcept {
    payload.u.as_int = 0;
    tag = Tag::None;
  }

 private:
  // This is the source of truth for isIntrusivePtr; edit results here
  // as needed and isIntrusivePtr will pick them up.
  // NOLINTBEGIN(bugprone-branch-clone)
  static constexpr bool isIntrusivePtrConstexpr(Tag tag) {
    switch (tag) {
      case Tag::None:
        return false;
      case Tag::Tensor:
        return false;
      case Tag::Storage:
        return true;
      case Tag::Generator:
        return true;
      case Tag::Double:
        return false;
      case Tag::ComplexDouble:
        return true;
      case Tag::Int:
        return false;
      case Tag::SymInt:
        return true;
      case Tag::SymFloat:
        return true;
      case Tag::SymBool:
        return true;
      case Tag::Bool:
        return false;
      case Tag::Tuple:
        return true;
      case Tag::String:
        return true;
      case Tag::Blob:
        return true;
      case Tag::GenericList:
        return true;
      case Tag::GenericDict:
        return true;
      case Tag::Future:
        return true;
      case Tag::Await:
        return true;
      case Tag::Device:
        return false;
      case Tag::Stream:
        return true;
      case Tag::Object:
        return true;
      case Tag::PyObject:
        return true;
      case Tag::Uninitialized:
        return false;
      case Tag::Capsule:
        return true;
      case Tag::RRef:
        return true;
      case Tag::Quantizer:
        return true;
      case Tag::Enum:
        return true;
    }
    return false;
  }
  // NOLINTEND(bugprone-branch-clone)

 public:
  // Don't edit this just to add results for new tags; edit
  // isIntrusivePtrConstexpr above.
  bool isIntrusivePtr() const {
    // Implementation NOTE: the switch in isIntrusivePtrConstexpr
    // above is the previous production implementation of this
    // function. We observed that, at least on x86_64, the generated
    // instruction sequence was a similar bit vector test to what we
    // have manually implemented below, except that there was an extra
    // "bounds check" branch confirming, essentially, that `tag <
    // kNumTags` and providing a consistent result in that case. We
    // don't care about the result if tag is out of bounds, so we'd
    // like to eliminate that comparison and branch; manually
    // implementing this function as a bit test is the simplest way I
    // could find to accomplish that elimination.
    static constexpr uint32_t kTruthTableBitVector =
#define TRUTH_TABLE_ENTRY(tag) \
  (uint32_t(isIntrusivePtrConstexpr(Tag::tag)) << uint32_t(Tag::tag)) |
        TORCH_FORALL_TAGS(TRUTH_TABLE_ENTRY)
#undef TRUTH_TABLE_ENTRY
        0;

    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        static_cast<uint32_t>(tag) < kNumTags,
        "unexpected tag ",
        static_cast<int>(tag));
    return kTruthTableBitVector & (1 << (uint32_t(tag) % 32));
  }

  // Storage and Generator were treated specially when
  // is_intrusive_ptr was stored as explicit state. This getter
  // preserves the old behavior for use with WeakIValue for now.
  bool isIntrusivePtrLegacyBehavior() const {
    if (tag == Tag::Storage || tag == Tag::Generator) {
      return payload.u.as_intrusive_ptr !=
          c10::UndefinedTensorImpl::singleton();
    } else {
      return isIntrusivePtr();
    }
  }

  union Payload {
    // [TriviallyCopyablePayload]
    // We use a nested union here so that we can make the copy easy
    // and efficient in the non-tensor (i.e., trivially copyable)
    // case. Specifically, we do not have to do a switch-on-tag to
    // figure out which union member to assign; we can just use
    // TriviallyCopyablePayload::operator=.
    union TriviallyCopyablePayload {
      TriviallyCopyablePayload() : as_int(0) {}
      int64_t as_int;
      double as_double;
      bool as_bool;
      // Invariant: never nullptr; null state is represented as
      // c10::UndefinedTensorImpl::singleton() for consistency of
      // representation with Tensor.
      c10::intrusive_ptr_target* as_intrusive_ptr;
      struct {
        c10::DeviceType type;
        DeviceIndex index;
      } as_device;
    } u;
    at::Tensor as_tensor;
    Payload() : u() {}
    ~Payload() {}
  };

  IValue(const Payload& p, Tag t) : tag(t) {
    if (isTensor()) {
      new (&payload.as_tensor) at::Tensor(p.as_tensor);
    } else {
      payload.u = p.u;
    }
  }

  template <typename T>
  struct TagType {};

  friend MaybeOwnedTraits<IValue>;

  Payload payload;
  Tag tag{IValue::Tag::None};
  friend struct WeakIValue;
};

struct TORCH_API WeakIValue final {
  WeakIValue() = default;

  WeakIValue(const WeakIValue& rhs)
      : payload(rhs.payload),
        tag(rhs.tag),
        is_intrusive_ptr(rhs.is_intrusive_ptr) {
    if (is_intrusive_ptr &&
        payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
      c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
    }
  }
  WeakIValue(const IValue& rhs)
      : tag(rhs.tag), is_intrusive_ptr(rhs.isIntrusivePtrLegacyBehavior()) {
    if (rhs.isTensor()) {
      payload.as_intrusive_ptr = rhs.unsafeToTensorImpl();
      is_intrusive_ptr = true;
    } else {
      payload = rhs.payload.u;
    }
    if (is_intrusive_ptr) {
      if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
        c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
      }
    }
  }
  WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() {
    swap(rhs);
  }
  ~WeakIValue() {
    if (is_intrusive_ptr &&
        payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
      c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr);
    }
  }
  WeakIValue& operator=(WeakIValue&& rhs) & noexcept {
    WeakIValue(std::move(rhs)).swap(*this); // this also sets rhs to None
    return *this;
  }
  WeakIValue& operator=(WeakIValue const& rhs) & {
    WeakIValue(rhs).swap(*this);
    return *this;
  }
  void swap(WeakIValue& rhs) noexcept {
    std::swap(payload, rhs.payload);
    std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr);
    std::swap(tag, rhs.tag);
  }

  bool isSameIdentity(const WeakIValue& rhs) const {
    return payload.as_int == rhs.payload.as_int && tag == rhs.tag &&
        is_intrusive_ptr == rhs.is_intrusive_ptr;
  }

  IValue lock() const {
    if (!is_intrusive_ptr) {
      IValue::Payload newPayload;
      newPayload.u = payload;
      return IValue(newPayload, tag);
    }
    if (IValue::Tag::Tensor == tag) {
      auto temp =
          c10::weak_intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl>::
              reclaim(static_cast<at::TensorImpl*>(payload.as_intrusive_ptr));
      c10::intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl> ip(
          temp.lock());
      temp.release();
      if (!ip) {
        return IValue();
      } else {
        return IValue(at::Tensor(std::move(ip)));
      }
    } else {
      auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
          payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
              ? nullptr
              : payload.as_intrusive_ptr);
      IValue::Payload pl;
      pl.u.as_intrusive_ptr = temp.lock().release();
      temp.release();
      if (!pl.u.as_intrusive_ptr) {
        return IValue();
      } else {
        return IValue(pl, tag);
      }
    }
  }

  size_t use_count() const noexcept {
    if (!is_intrusive_ptr) {
      return 1;
    }
    auto temp = c10::weak_intrusive_ptr<
        c10::intrusive_ptr_target,
        c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr);
    size_t result = temp.use_count();
    temp.release();
    return result;
  }

  size_t weak_use_count() const noexcept {
    if (!is_intrusive_ptr) {
      return 1;
    }
    auto temp = c10::weak_intrusive_ptr<
        c10::intrusive_ptr_target,
        c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr);
    size_t result = temp.weak_use_count();
    temp.release();
    return result;
  }
  size_t hash() const {
    return payload.as_int;
  }

 private:
  using Payload = IValue::Payload::TriviallyCopyablePayload;
  Payload payload;
  IValue::Tag tag{IValue::Tag::None};
  bool is_intrusive_ptr{false};
};

// An owning pointer to a type. When the type is a class type, it requires a
// pair of shared_ptrs to the class type and its owning CU, so that the class
// type is guaranteed to stay alive as long as we hold this object.
struct TORCH_API StrongTypePtr {
  StrongTypePtr(std::shared_ptr<torch::jit::CompilationUnit> cu, TypePtr type);

  std::shared_ptr<torch::jit::CompilationUnit> cu_;
  TypePtr type_;
};

// [Constant Object Weak CompilationUnit Reference]
// A non-owning pointer to a type. When a class gets inserted as a constant
// into a graph, if we used a strong pointer we would have a circular reference
// from Object -> CompilationUnit and CompilationUnit -> Graph (which owns the
// Constant Object).
struct TORCH_API WeakTypePtr {
  WeakTypePtr(std::weak_ptr<torch::jit::CompilationUnit> cu, TypePtr type);

  std::weak_ptr<torch::jit::CompilationUnit> cu_;
  TypePtr type_;
};

// internal build errors with std::variant :/
struct WeakOrStrongCompilationUnit {
  explicit WeakOrStrongCompilationUnit(
      std::shared_ptr<torch::jit::CompilationUnit> shared_cu)
      : strong_ptr_(std::move(shared_cu)), weak_ptr_(std::nullopt) {}

  explicit WeakOrStrongCompilationUnit(
      std::weak_ptr<torch::jit::CompilationUnit> weak_cu)
      : strong_ptr_(std::nullopt), weak_ptr_(std::move(weak_cu)) {}

  std::shared_ptr<torch::jit::CompilationUnit> getStrongRefOrThrow() const {
    TORCH_INTERNAL_ASSERT(strong_ptr_ != std::nullopt);
    return *strong_ptr_;
  }

  std::weak_ptr<torch::jit::CompilationUnit> getWeakRefOrThrow() const {
    TORCH_INTERNAL_ASSERT(weak_ptr_ != std::nullopt);
    return *weak_ptr_;
  }

  bool holdingStrongRef() const {
    return strong_ptr_ != std::nullopt;
  }

  bool holdingEmptyStrongRef() const {
    return holdingStrongRef() && *strong_ptr_ == nullptr;
  }

  std::optional<std::shared_ptr<torch::jit::CompilationUnit>> strong_ptr_;
  std::optional<std::weak_ptr<torch::jit::CompilationUnit>> weak_ptr_;
};

// An Object will hold a non-owning CompilationUnit reference if it is a
// constant in the graph, and an owning reference otherwise.
struct TORCH_API WeakOrStrongTypePtr {
  explicit WeakOrStrongTypePtr(WeakTypePtr weak)
      : cu_(WeakOrStrongCompilationUnit(std::move(weak.cu_))),
        type_(std::move(weak.type_)) {}
  explicit WeakOrStrongTypePtr(StrongTypePtr strong)
      : cu_(WeakOrStrongCompilationUnit(std::move(strong.cu_))),
        type_(std::move(strong.type_)) {}
  explicit WeakOrStrongTypePtr(WeakOrStrongCompilationUnit cu, TypePtr type)
      : cu_(std::move(cu)), type_(std::move(type)) {}
  WeakTypePtr asWeakTypePtr() const;

  WeakOrStrongCompilationUnit cu_;
  TypePtr type_;

  bool holds_strong_ref() const {
    return cu_.holdingStrongRef();
  }

  bool holds_empty_strong_ref() const {
    return cu_.holdingEmptyStrongRef();
  }
};

} // namespace c10

#include <ATen/core/ivalue_inl.h> // IWYU pragma: keep