1 #pragma once 2 3 #include <ATen/core/ivalue_to.h> 4 #include <c10/util/ArrayRef.h> 5 #include <c10/util/Exception.h> 6 7 #include <functional> 8 #include <initializer_list> 9 #include <iterator> 10 #include <type_traits> 11 12 /* 13 * [Note: IListRef] 14 * Wrapper around different API containers (e.g. boxed and unboxed). 15 * 16 * What is it? 17 * =========== 18 * It is a tagged union of both boxed and unboxed API containers. 19 * Working implementations: 20 * 21 * - `IListRef<at::Tensor>` 22 * - `IListRef<at::OptionalTensorRef>` 23 * 24 * Note that `IListRef` is a view type. Meaning that it won't own the 25 * tensors it holds. It's intended to be used only as argument parameters. 26 * Specifically, where these 2 worlds overlap. 27 * 28 * What is this for? 29 * ================= 30 * Historically, PyTorch has maintained 2 different APIs: the unboxed 31 * (called from C++ API and Python eager mode) and boxed APIs (called 32 * from the TorchScript JIT, mobile interpreter, and boxed fallbacks). 33 * 34 * Calling unboxed kernels from the boxed "world" and vice-versa may 35 * result in non-negligible overhead. Lists are one of those types: 36 * 37 * - Boxed world: `c10::List` 38 * - Unboxed world: `c10::ArrayRef` 39 * 40 * In this context, `c10::IListRef` solves this problem by wrapping those 41 * 2 container types, so that we don't need to convert from one to 42 * the other. 43 * 44 * (see https://github.com/pytorch/pytorch/issues/66328) 45 * 46 * What does it do? 47 * ================ 48 * This container wraps around the different tagged containers 49 * (currently, only boxed and unboxed), without incurring in extra 50 * overhead for converting from one to another. It does so while 51 * exposing usual container methods, which dispatch to corresponding 52 * implementations. 53 * 54 * While it works with different container types, it introduces 55 * overhead for repeatedly calling member functions (since those will 56 * get dispatched, again). Therefore, you should only use it to iterate 57 * through the list up to one time. If you need to do more complex things, 58 * call `materialize()` first. 59 * 60 * Adding support for a new Tag 61 * ============================ 62 * Suppose we want to add a new tag: `Chest`. Here are the steps 63 * we would have to go through: 64 * 65 * 1. Add a line for it in the macro `TORCH_ILISTREF_FORALL_TAGS`. 66 * 67 * #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \ 68 * ... 69 * _(Chest, ##__VA_ARGS__) 70 * 71 * 2. Add type aliases, union members, and constructors. 72 * 73 * template <typename T> 74 * class IListRef { 75 * ... 76 * using chest_type = 77 * typename detail::IListRefTagImpl<T, IListRefTag::Chest>::list_type; 78 * ... 79 * IListRef(...) : tag_(IListRefTag::Chest) { 80 * ... 81 * } 82 * ... 83 * union Payload { 84 * ... 85 * chest_type chest; 86 * ... 87 * }; 88 * ... 89 * }; 90 * 91 * 3. Add a default implementation for it (in 'IListRef_inl.h'). It's 92 * preferable to make the default implementation work for `T = Tensor` 93 * (both `Unboxed` and `Boxed` do it). 94 * 95 * template <typename T, typename ListElemT> 96 * class IListRefTagImplBase<IListRefTag::Chest, T, ListElemT> { 97 * public: 98 * using elem_type = ListElemT; 99 * using list_type = ChestContainer<elem_type>; 100 * 101 * static const list_type& unwrap(const IListRef<T>& ilist) { ... } 102 * 103 * static typename list_type::const_iterator& unwrap( 104 * IListRefIterator<T>& it) { ... } 105 * 106 * static const typename list_type::const_iterator& unwrap( 107 * const IListRefIterator<T>& it) { ... } 108 * 109 * static IListRefConstRef<T> iterator_get( 110 * const typename list_type::const_iterator& it) { ... } 111 * } 112 * 113 * 4. Add an specialization for each of the already supported types. 114 * Finally, for consistency, add them to the tracking list. 115 * (see [Note: IListRefTagImpl Specializations]) 116 * 117 * template <> 118 * class IListRefTagImpl<IListRefTag::Chest, at::Tensor> 119 * : public IListRefTagImplBase<IListRefTag::Chest, at::Tensor> {}; 120 * 121 * Adding support for a new Type 122 * ============================= 123 * Suppose we want to add support for a new type: `Matrix`. 124 * Here are the steps we would have to go through: 125 * 126 * 1. Add an specialization for each of the existing tags. 127 * For consistency, add them to the tracking list. 128 * (see [Note: IListRefTagImpl Specializations]) 129 * 130 * template <> 131 * class IListRefTagImpl<IListRefTag::Unboxed, Matrix> 132 * : public IListRefTagImplBase<IListRefTag::Unboxed, Matrix> {}; 133 * 134 * template <> 135 * class IListRefTagImpl<Matrix, IListRefTag::Boxed> 136 * : public IListRefTagImplBase<IListRefTag::Boxed, Matrix> {}; 137 * 138 * Common Problems 139 * =============== 140 * 1. One of `IListRef(Iterator)` methods are failing to compile. 141 * 142 * That may be happening because the container type you added 143 * is not compatible with the code written for that method. If 144 * that's true, then you might have to transform that code into 145 * a static method call (see `List::operator[]` method). 146 * 147 * 2. Can't make `IListRefIterator<T>::operator*` return a const-reference. 148 * 149 * First, keep in mind that we assume that boxed containers will 150 * have to deal with `IValue` (e.g. `c10::List`). In this context, 151 * what may be happening is that `IValue` doesn't store internally 152 * your type `T`. Instead, it constructs a type new `T` everytime 153 * you try to get `T` for it (see `IListRef<at::OptinalTensorRef>`). 154 */ 155 156 namespace c10 { 157 template <typename T> 158 class IListRef; 159 160 /* 161 * Applies arbitrary macros to each `IListRefTag`. 162 */ 163 #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \ 164 _(Unboxed, ##__VA_ARGS__) \ 165 _(Boxed, ##__VA_ARGS__) \ 166 _(Materialized, ##__VA_ARGS__) 167 168 /* 169 * Defines a "switch-case" for `TAG`. Inside, it executes `BODY`, 170 * while bringing to scope: 171 * 172 * - `ImplT`: the implementation class for `TAG` 173 * - `this_`: the result of unwrapping `this` 174 */ 175 #define TORCH_ILISTREF_UNWRAP_CASE(TAG, BODY) \ 176 case c10::IListRefTag::TAG: { \ 177 using ImplT = c10::detail::IListRefTagImpl<IListRefTag::TAG, T>; \ 178 auto& this_ = ImplT::unwrap(*this); \ 179 BODY \ 180 } break; 181 182 /* 183 * Dispatches the unwrap call, depending on `TAG`, followed by 184 * the execution of `BODY`. It aborts if `TAG` is not a `IListRefTag`. 185 * 186 * This macro is useful because it allows us to handle different 187 * types (that correspond to different tags) to be implemented 188 * only once. We can do it even when the implementation of the 189 * different tags aren't syntatically the same, by dispatching 190 * it to a function (e.g. `ImplT::<dispatch-function>(this_)`). 191 */ 192 #define TORCH_ILISTREF_UNWRAP(TAG, BODY) \ 193 switch (TAG) { \ 194 TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \ 195 break; \ 196 default: \ 197 TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); \ 198 } 199 200 enum class IListRefTag { 201 #define DEFINE_TAG(tag, ...) tag, 202 TORCH_ILISTREF_FORALL_TAGS(DEFINE_TAG) 203 #undef DEFINE_TAG 204 None 205 }; 206 207 namespace detail { 208 /* 209 * Type alias that specifies whether we return a reference or a copy of `T`. 210 * 211 * What is this for? 212 * ================= 213 * Since values in the boxed world are represented by an `IValue`, we also 214 * depend on whether it can be converted to a const-reference (`Tensor`) or 215 * has to create a new copy of `T` (`OptionalTensorRef`). 216 */ 217 template <typename T> 218 using IListRefConstRef = typename ivalue_to_const_ref_overload_return<T>::type; 219 220 /* 221 * Interface that implements key functions for each `IListRefTag` type. 222 * 223 * What is this for? 224 * ================= 225 * Given an `IListRef(Iterator)<T>`, some methods have to be implemented 226 * differently for each `TAG`. Therefore, the methods inside this class 227 * are used as dispatch targets for the different `IListRefTag` values. 228 * 229 * You should create an specialization of this class for each possible 230 * combination of `IListRefTag` type (except `None`) and element types 231 * (e.g. `Tensor`). 232 * 233 * What does it do? 234 * ================ 235 * 1. defines static methods to be used as dispatch targets by both 236 * `IListRef<T>` and `IListRefIterator<T>` (see the implementation of 237 * `IListRefTagImplBase`). 238 * 239 * 2. defines the `elem_type` and `list_type` aliases that will be 240 * used in the definition of `IListRef<T>`. In general, we should do 241 * so by inheriting from `IListRefTagImplBase<TAG, T, ListElemT>`. 242 * 243 * [Note: IListRefTagImpl Specialization] 244 * ====================================== 245 * For `IListRef(Iterator)<at::Tensor>`: 246 * - <IListRefTag::Unboxed, at::Tensor> 247 * - <IListRefTag::Boxed, at::Tensor> 248 * - <IListRefTag::Materialized, at::Tensor> 249 * 250 * For `IListRef(Iterator)<at::OptionalTensorRef>`: 251 * - <IListRefTag::Unboxed, at::OptionalTensorRef> 252 * - <IListRefTag::Boxed, at::OptionalTensorRef> 253 * - <IListRefTag::Materialized, at::OptionalTensorRef> 254 */ 255 template <IListRefTag TAG, typename T> 256 class IListRefTagImpl {}; 257 258 /* 259 * Base implementation of `IListRefTagImpl<TAG, T>` methods. 260 * 261 * What is this for? 262 * ================= 263 * This should make adding specializations for new types easier. For 264 * example, one should be able to add a new type just by making its 265 * `IListRefTagImpl` specialization inherit from `IListRefTagImplBase`. 266 * 267 * You should create a partial specialization for this class only if 268 * you introduce a new `IListRefTag`. The idea being that there is one 269 * default implementation for each possible value of `IListRefTag`. 270 * 271 * What does it do? 272 * ================ 273 * 1. defines `elem_type` as an alias to `ListElemT`. 274 * 275 * 1. defines `list_type` as an alias to the default container type 276 * that will hold a collection of `elem_type`. The idea being that 277 * all types tagged as `TAG` will have `list_type` as its container, 278 * with different `elem_type`. 279 * 280 * 3. defines the default implementation for each of the methods that 281 * are supposed to be defined on `IListRefTagImpl` specializations. 282 * 283 * 4. inheriting from `IListRefTagImplBase<TAG, T, ListElemT>` also means 284 * that the payload of the type `IListRef<T>` will be of type `list_type` 285 * when it is tagged as `TAG`. 286 */ 287 template <IListRefTag TAG, typename T, typename ListElemT = T> 288 class IListRefTagImplBase {}; 289 290 /* 291 * Materialized container for `IListRef<T>`. 292 * 293 * What is this for? 294 * ================= 295 * Container that groups `T` references together. This exchanges the 296 * overhead of every method call from `IListRef<T>` for a dynamic allocation. 297 * 298 * You should use this container instead of `IListRef<T>` if: 299 * 300 * - You are going to iterate the list more than once 301 * - You need to repeatedly access arbitrary elements (using `operator[]`) 302 * What does it do? 303 304 * ================ 305 * Removes the reference (&) from the type, and wraps it into a 306 * `std::reference_wrapper`. If `IListRefConstRef<T>` is not a 307 * reference type, then it's left unchanged. 308 */ 309 template <typename T> 310 using _MaterializedIListRefElem = std::conditional_t< 311 std::is_reference_v<T>, 312 typename std::reference_wrapper<std::remove_reference_t<T>>, 313 T>; 314 315 template <typename T> 316 using MaterializedIListRefElem = _MaterializedIListRefElem<IListRefConstRef<T>>; 317 318 template <typename T> 319 using MaterializedIListRef = std::vector<MaterializedIListRefElem<T>>; 320 321 } // namespace detail 322 323 /* 324 * Iterator for `IListRef<T>`. 325 * 326 * What is it? 327 * =========== 328 * Currently, a `std::bidirectional_iterator` that wraps the iterator 329 * types defined for each of the `IListRefTag`. 330 * 331 * One should be able to use it, as if it were the unwrapped 332 * iterators themselves. 333 334 * What does it do? 335 * ================ 336 * Similarly to `IListRef<T>`, this is a wrapper class. Specifically, it 337 * wraps each container's `const_iterator` type alias. So, for example, 338 * given that the container for `IListRefTag::Boxed` is `c10::List`, this 339 * iterator will wrap a `c10::List::const_iterator`. 340 * 341 * [Note: MSVC Iterator Debug] 342 * =========================== 343 * MSVC `vector<T>::iterator` implementation (used in the boxed variant) 344 * makes it so this union's destructor, copy-constructor (assignment), and 345 * move-constructor (assignment) are implicitly deleted. 346 * 347 * Therefore, we need to explicitly define them as needed. Follows a list 348 * of places where these are needed and their reason: 349 * 350 * - `Payload` destructor: 351 * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is set to 2. 352 * 353 * - `IListRefIterator` destructor: 354 * same as above. However, we need to explicitly call the variant 355 * destructor explicitly. 356 * 357 * - `IListRefIterator` copy-constructor: 358 * it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is different 359 * than 0. 360 */ 361 template <typename T> 362 class IListRefIterator { 363 private: 364 #define DEFINE_FRIEND_CLASS(TAG, ...) \ 365 friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \ 366 friend class detail::IListRefTagImplBase< \ 367 IListRefTag::TAG, \ 368 T, \ 369 typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>; 370 TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS) 371 #undef DEFINE_FRIEND_CLASS 372 373 public: 374 // C++17 friendly std::iterator implementation 375 using iterator_category = std::bidirectional_iterator_tag; 376 using value_type = T; 377 using difference_type = std::ptrdiff_t; 378 using pointer = T*; 379 using reference = T&; 380 381 using unboxed_iterator_type = typename detail:: 382 IListRefTagImpl<IListRefTag::Unboxed, T>::list_type::const_iterator; 383 using boxed_iterator_type = typename detail:: 384 IListRefTagImpl<IListRefTag::Boxed, T>::list_type::const_iterator; 385 using materialized_iterator_type = 386 typename detail::MaterializedIListRef<T>::const_iterator; 387 IListRefIterator()388 IListRefIterator() : tag_(IListRefTag::None) {} 389 390 #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL != 0 391 // See [Note: MSVC Iterator Debug] IListRefIterator(const IListRefIterator & iterator)392 IListRefIterator(const IListRefIterator& iterator) 393 : tag_(iterator.tag_) { 394 switch (tag_) { 395 case IListRefTag::Boxed: 396 payload_.boxed_iterator = iterator.payload_.boxed_iterator; 397 break; 398 case IListRefTag::Unboxed: 399 payload_.unboxed_iterator = iterator.payload_.unboxed_iterator; 400 break; 401 case IListRefTag::Materialized: 402 payload_.materialized_iterator = iterator.payload_.materialized_iterator; 403 break; 404 default: 405 TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); 406 } 407 } 408 #endif 409 410 #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2 411 // See [Note: MSVC Iterator Debug] noexcept(false)412 ~IListRefIterator() noexcept(false) { 413 switch (tag_) { 414 case IListRefTag::Boxed: 415 payload_.boxed_iterator.~boxed_iterator_type(); 416 break; 417 case IListRefTag::Unboxed: 418 payload_.unboxed_iterator.~unboxed_iterator_type(); 419 break; 420 case IListRefTag::Materialized: 421 payload_.materialized_iterator.~materialized_iterator_type(); 422 break; 423 default: 424 TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); 425 } 426 } 427 #endif 428 IListRefIterator(boxed_iterator_type boxed)429 IListRefIterator(boxed_iterator_type boxed) : tag_(IListRefTag::Boxed) { 430 payload_.boxed_iterator = boxed; 431 } 432 IListRefIterator(unboxed_iterator_type unboxed)433 IListRefIterator(unboxed_iterator_type unboxed) : tag_(IListRefTag::Unboxed) { 434 payload_.unboxed_iterator = unboxed; 435 } 436 IListRefIterator(materialized_iterator_type materialized)437 IListRefIterator(materialized_iterator_type materialized) : tag_(IListRefTag::Materialized) { 438 payload_.materialized_iterator = materialized; 439 } 440 441 detail::IListRefConstRef<T> operator*() const { 442 TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::iterator_get(this_); }); 443 } 444 445 IListRefIterator& operator++() { 446 TORCH_ILISTREF_UNWRAP(tag_, { ++this_; }); 447 return *this; 448 } 449 450 IListRefIterator operator++(int) { 451 auto old = *this; 452 TORCH_ILISTREF_UNWRAP(tag_, { ++this_; }); 453 return old; 454 } 455 456 IListRefIterator& operator--() { 457 TORCH_ILISTREF_UNWRAP(tag_, { --this_; }); 458 return *this; 459 } 460 461 IListRefIterator operator--(int) { 462 auto old = *this; 463 TORCH_ILISTREF_UNWRAP(tag_, { --this_; }); 464 return old; 465 } 466 467 bool operator==(const IListRefIterator& rhs) const { 468 if (tag_ != rhs.tag_) { 469 return false; 470 } 471 TORCH_ILISTREF_UNWRAP(tag_, { 472 auto& rhs_it = ImplT::unwrap(rhs); 473 return this_ == rhs_it; 474 }); 475 } 476 477 bool operator!=(const IListRefIterator& rhs) const { 478 return !(*this == rhs); 479 } 480 481 private: 482 union Payload { 483 boxed_iterator_type boxed_iterator; 484 unboxed_iterator_type unboxed_iterator; 485 materialized_iterator_type materialized_iterator; 486 void* _init_ptr; Payload()487 Payload() : _init_ptr(nullptr) {} 488 #if defined(_MSC_VER) 489 // See [Note: MSVC Iterator Debug] ~Payload()490 ~Payload() {} 491 #endif 492 }; 493 494 Payload payload_; 495 IListRefTag tag_; 496 }; 497 498 /* 499 * See [Note: IListRef] 500 */ 501 template <typename T> 502 class IListRef { 503 private: 504 #define DEFINE_FRIEND_CLASS(TAG, ...) \ 505 friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \ 506 friend class detail::IListRefTagImplBase< \ 507 IListRefTag::TAG, \ 508 T, \ 509 typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>; 510 TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS) 511 #undef DEFINE_FRIEND_CLASS 512 513 public: 514 using unboxed_type = 515 typename detail::IListRefTagImpl<IListRefTag::Unboxed, T>::list_type; 516 using boxed_type = 517 typename detail::IListRefTagImpl<IListRefTag::Boxed, T>::list_type; 518 using materialized_type = 519 typename detail::MaterializedIListRef<T>; 520 521 using iterator = IListRefIterator<T>; 522 using const_iterator = IListRefIterator<T>; 523 using reverse_iterator = std::reverse_iterator<iterator>; 524 using value_type = typename iterator::value_type; 525 IListRef()526 IListRef() : tag_(IListRefTag::None) {} 527 IListRef(const boxed_type & boxed)528 IListRef(const boxed_type& boxed) : tag_(IListRefTag::Boxed) { 529 payload_.boxed = &boxed; 530 } 531 IListRef(const unboxed_type & unboxed)532 IListRef(const unboxed_type& unboxed) : tag_(IListRefTag::Unboxed) { 533 payload_.unboxed = unboxed; 534 } 535 IListRef(const std::initializer_list<T> & list)536 IListRef(const std::initializer_list<T>& list) : tag_(IListRefTag::Unboxed) { 537 payload_.unboxed = at::ArrayRef<T>(list); 538 } 539 540 template < 541 typename... UnboxedConstructorArgs, 542 typename = std::enable_if_t< 543 std::is_constructible_v<unboxed_type, UnboxedConstructorArgs...>>> IListRef(UnboxedConstructorArgs &&...args)544 IListRef(UnboxedConstructorArgs&&... args) : tag_(IListRefTag::Unboxed) { 545 payload_.unboxed = unboxed_type(std::forward<UnboxedConstructorArgs>(args)...); 546 } 547 IListRef(const materialized_type & materialized)548 IListRef(const materialized_type& materialized) : tag_(IListRefTag::Materialized) { 549 payload_.materialized = &materialized; 550 } 551 size()552 size_t size() const { 553 TORCH_ILISTREF_UNWRAP(tag_, { return this_.size(); }); 554 } 555 empty()556 bool empty() const { 557 return size() == 0; 558 } 559 begin()560 iterator begin() const { 561 TORCH_ILISTREF_UNWRAP(tag_, { return this_.begin(); }); 562 } 563 end()564 iterator end() const { 565 TORCH_ILISTREF_UNWRAP(tag_, { return this_.end(); }); 566 } 567 front()568 detail::IListRefConstRef<T> front() const { 569 TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::front(this_); }); 570 } 571 572 /* 573 * Materializes the `IListRef` into a `std::vector`. 574 * 575 * This should be used when one wishes to either: 576 * 577 * - iterate over the list more than once: each `IListRefIterator` 578 * member function call has to go through a switch, introducing 579 * non-negligible overhead 580 * 581 * - randomly access an arbitrary element using `operator[]`: 582 * same reason as above 583 */ materialize()584 detail::MaterializedIListRef<T> materialize() const { 585 if (isMaterialized()) { 586 return toMaterialized(); 587 } 588 589 detail::MaterializedIListRef<T> materialized; 590 materialized.reserve(size()); 591 for (const auto& t : *this) { 592 materialized.emplace_back(t); 593 } 594 return materialized; 595 } 596 597 #define DEFINE_CHECK(TAG, ...) \ 598 bool is##TAG() const { \ 599 return tag_ == IListRefTag::TAG; \ 600 } 601 TORCH_ILISTREF_FORALL_TAGS(DEFINE_CHECK); 602 #undef DEFINE_CHECK 603 isNone()604 bool isNone() const { 605 return tag_ == IListRefTag::None; 606 } 607 608 #define DEFINE_CASTING(TAG, ...) \ 609 const typename detail::IListRefTagImpl<IListRefTag::TAG, T>::list_type& \ 610 to##TAG() const { \ 611 TORCH_INTERNAL_ASSERT(is##TAG()); \ 612 return detail::IListRefTagImpl<IListRefTag::TAG, T>::unwrap(*this); \ 613 } 614 TORCH_ILISTREF_FORALL_TAGS(DEFINE_CASTING); 615 #undef DEFINE_CASTING 616 617 private: 618 union Payload { 619 const boxed_type* boxed; 620 unboxed_type unboxed; 621 const materialized_type* materialized; Payload()622 Payload() : boxed(nullptr) {} 623 }; 624 625 Payload payload_; 626 IListRefTag tag_; 627 }; 628 629 } // namespace c10 630 631 #include <ATen/core/IListRef_inl.h> 632