xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/IListRef.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/core/ivalue_to.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <new>
#include <type_traits>
#include <utility>
#include <vector>
11 
12 /*
13  * [Note: IListRef]
14  * Wrapper around different API containers (e.g. boxed and unboxed).
15  *
16  * What is it?
17  * ===========
18  * It is a tagged union of both boxed and unboxed API containers.
19  * Working implementations:
20  *
21  * - `IListRef<at::Tensor>`
22  * - `IListRef<at::OptionalTensorRef>`
23  *
24  * Note that `IListRef` is a view type. Meaning that it won't own the
25  * tensors it holds. It's intended to be used only as argument parameters.
26  * Specifically, where these 2 worlds overlap.
27  *
28  * What is this for?
29  * =================
30  * Historically, PyTorch has maintained 2 different APIs: the unboxed
31  * (called from C++ API and Python eager mode) and boxed APIs (called
32  * from the TorchScript JIT, mobile interpreter, and boxed fallbacks).
33  *
34  * Calling unboxed kernels from the boxed "world" and vice-versa may
35  * result in non-negligible overhead. Lists are one of those types:
36  *
37  * - Boxed world: `c10::List`
38  * - Unboxed world: `c10::ArrayRef`
39  *
40  * In this context, `c10::IListRef` solves this problem by wrapping those
41  * 2 container types, so that we don't need to convert from one to
42  * the other.
43  *
44  * (see https://github.com/pytorch/pytorch/issues/66328)
45  *
46  * What does it do?
47  * ================
 * This container wraps around the different tagged containers
 * (currently: unboxed, boxed, and materialized) without incurring
 * extra overhead for converting from one to another. It does so while
 * exposing the usual container methods, which dispatch to the
 * corresponding implementations.
53  *
54  * While it works with different container types, it introduces
55  * overhead for repeatedly calling member functions (since those will
56  * get dispatched, again). Therefore, you should only use it to iterate
 * through the list at most once. If you need to do more complex things,
58  * call `materialize()` first.
59  *
60  * Adding support for a new Tag
61  * ============================
62  * Suppose we want to add a new tag: `Chest`. Here are the steps
63  * we would have to go through:
64  *
65  * 1. Add a line for it in the macro `TORCH_ILISTREF_FORALL_TAGS`.
66  *
67  *   #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
68  *     ...
69  *     _(Chest, ##__VA_ARGS__)
70  *
71  * 2. Add type aliases, union members, and constructors.
72  *
73  *   template <typename T>
74  *   class IListRef {
75  *     ...
 *     using chest_type =
 *       typename detail::IListRefTagImpl<IListRefTag::Chest, T>::list_type;
78  *     ...
79  *     IListRef(...) : tag_(IListRefTag::Chest) {
80  *       ...
81  *     }
82  *     ...
83  *     union Payload {
84  *       ...
85  *       chest_type chest;
86  *       ...
87  *     };
88  *     ...
89  *   };
90  *
91  * 3. Add a default implementation for it (in 'IListRef_inl.h'). It's
92  *    preferable to make the default implementation work for `T = Tensor`
93  *    (both `Unboxed` and `Boxed` do it).
94  *
95  *   template <typename T, typename ListElemT>
96  *   class IListRefTagImplBase<IListRefTag::Chest, T, ListElemT> {
97  *    public:
98  *     using elem_type = ListElemT;
99  *     using list_type = ChestContainer<elem_type>;
100  *
101  *     static const list_type& unwrap(const IListRef<T>& ilist) { ... }
102  *
103  *     static typename list_type::const_iterator& unwrap(
104  *         IListRefIterator<T>& it) { ... }
105  *
106  *     static const typename list_type::const_iterator& unwrap(
107  *         const IListRefIterator<T>& it) { ... }
108  *
109  *     static IListRefConstRef<T> iterator_get(
110  *         const typename list_type::const_iterator& it) { ... }
 *   };
112  *
 * 4. Add a specialization for each of the already supported types.
114  *    Finally, for consistency, add them to the tracking list.
115  *    (see [Note: IListRefTagImpl Specializations])
116  *
117  *   template <>
118  *   class IListRefTagImpl<IListRefTag::Chest, at::Tensor>
119  *       : public IListRefTagImplBase<IListRefTag::Chest, at::Tensor> {};
120  *
121  * Adding support for a new Type
122  * =============================
123  * Suppose we want to add support for a new type: `Matrix`.
124  * Here are the steps we would have to go through:
125  *
 * 1. Add a specialization for each of the existing tags.
127  *    For consistency, add them to the tracking list.
128  *    (see [Note: IListRefTagImpl Specializations])
129  *
130  *   template <>
131  *   class IListRefTagImpl<IListRefTag::Unboxed, Matrix>
132  *       : public IListRefTagImplBase<IListRefTag::Unboxed, Matrix> {};
133  *
134  *   template <>
 *   class IListRefTagImpl<IListRefTag::Boxed, Matrix>
136  *       : public IListRefTagImplBase<IListRefTag::Boxed, Matrix> {};
137  *
138  * Common Problems
139  * ===============
 * 1. One of the `IListRef(Iterator)` methods is failing to compile.
141  *
142  *     That may be happening because the container type you added
143  *     is not compatible with the code written for that method. If
144  *     that's true, then you might have to transform that code into
145  *     a static method call (see `List::operator[]` method).
146  *
147  * 2. Can't make `IListRefIterator<T>::operator*` return a const-reference.
148  *
149  *    First, keep in mind that we assume that boxed containers will
150  *    have to deal with `IValue` (e.g. `c10::List`). In this context,
151  *    what may be happening is that `IValue` doesn't store internally
 *    your type `T`. Instead, it constructs a new `T` every time
 *    you try to get a `T` from it (see `IListRef<at::OptionalTensorRef>`).
154  */
155 
156 namespace c10 {
template <class T>
class IListRef;

/*
 * X-macro over every concrete `IListRefTag` value.
 *
 * `TAG_MACRO(TagName, <extra args...>)` is expanded once per tag, in the
 * order listed below; any extra arguments given to this macro are passed
 * through to each `TAG_MACRO` invocation unchanged.
 */
#define TORCH_ILISTREF_FORALL_TAGS(TAG_MACRO, ...) \
  TAG_MACRO(Unboxed, ##__VA_ARGS__)                \
  TAG_MACRO(Boxed, ##__VA_ARGS__)                  \
  TAG_MACRO(Materialized, ##__VA_ARGS__)
167 
/*
 * Defines a "switch-case" for `TAG`. Inside, it executes the body given
 * as the remaining macro arguments, while bringing to scope:
 *
 * - `ImplT`: the implementation class for `TAG`
 * - `this_`: the result of unwrapping `this`
 *
 * The body is taken as `...` (rather than a single named parameter) so
 * that it may contain unparenthesized commas, e.g. `std::pair<A, B>`; the
 * preprocessor pastes the arguments back together verbatim. Single-body
 * callers behave exactly as before.
 */
#define TORCH_ILISTREF_UNWRAP_CASE(TAG, ...)                              \
  case c10::IListRefTag::TAG: {                                           \
    using ImplT = c10::detail::IListRefTagImpl<c10::IListRefTag::TAG, T>; \
    auto& this_ = ImplT::unwrap(*this);                                   \
    __VA_ARGS__                                                           \
  } break;

/*
 * Dispatches the unwrap call, depending on `TAG`, followed by the
 * execution of the body (the remaining macro arguments). It aborts, via
 * `TORCH_INTERNAL_ASSERT`, if `TAG` is not one of the tags listed in
 * `TORCH_ILISTREF_FORALL_TAGS` (e.g. `IListRefTag::None`).
 *
 * This macro is useful because it allows the handling of different
 * types (that correspond to different tags) to be implemented only
 * once. We can do it even when the implementations for the different
 * tags aren't syntactically the same, by dispatching to a function
 * (e.g. `ImplT::<dispatch-function>(this_)`).
 */
#define TORCH_ILISTREF_UNWRAP(TAG, ...)                                 \
  switch (TAG) {                                                        \
    TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, __VA_ARGS__) \
    default:                                                            \
      TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");            \
  }
199 
200 enum class IListRefTag {
201 #define DEFINE_TAG(tag, ...) tag,
202   TORCH_ILISTREF_FORALL_TAGS(DEFINE_TAG)
203 #undef DEFINE_TAG
204       None
205 };
206 
207 namespace detail {
208 /*
209  * Type alias that specifies whether we return a reference or a copy of `T`.
210  *
211  * What is this for?
212  * =================
213  * Since values in the boxed world are represented by an `IValue`, we also
214  * depend on whether it can be converted to a const-reference (`Tensor`) or
215  * has to create a new copy of `T` (`OptionalTensorRef`).
216  */
217 template <typename T>
218 using IListRefConstRef = typename ivalue_to_const_ref_overload_return<T>::type;
219 
220 /*
221  * Interface that implements key functions for each `IListRefTag` type.
222  *
223  * What is this for?
224  * =================
225  * Given an `IListRef(Iterator)<T>`, some methods have to be implemented
226  * differently for each `TAG`. Therefore, the methods inside this class
227  * are used as dispatch targets for the different `IListRefTag` values.
228  *
229  * You should create an specialization of this class for each possible
230  * combination of `IListRefTag` type (except `None`) and element types
231  * (e.g. `Tensor`).
232  *
233  * What does it do?
234  * ================
235  * 1. defines static methods to be used as dispatch targets by both
236  *    `IListRef<T>` and `IListRefIterator<T>` (see the implementation of
237  *    `IListRefTagImplBase`).
238  *
239  * 2. defines the `elem_type` and `list_type` aliases that will be
240  *    used in the definition of `IListRef<T>`. In general, we should do
241  *    so by inheriting from `IListRefTagImplBase<TAG, T, ListElemT>`.
242  *
243  * [Note: IListRefTagImpl Specialization]
244  * ======================================
245  * For `IListRef(Iterator)<at::Tensor>`:
246  * - <IListRefTag::Unboxed, at::Tensor>
247  * - <IListRefTag::Boxed, at::Tensor>
248  * - <IListRefTag::Materialized, at::Tensor>
249  *
250  * For `IListRef(Iterator)<at::OptionalTensorRef>`:
251  * - <IListRefTag::Unboxed, at::OptionalTensorRef>
252  * - <IListRefTag::Boxed, at::OptionalTensorRef>
253  * - <IListRefTag::Materialized, at::OptionalTensorRef>
254  */
255 template <IListRefTag TAG, typename T>
256 class IListRefTagImpl {};
257 
258 /*
259  * Base implementation of `IListRefTagImpl<TAG, T>` methods.
260  *
261  * What is this for?
262  * =================
263  * This should make adding specializations for new types easier. For
264  * example, one should be able to add a new type just by making its
265  * `IListRefTagImpl` specialization inherit from `IListRefTagImplBase`.
266  *
267  * You should create a partial specialization for this class only if
268  * you introduce a new `IListRefTag`. The idea being that there is one
269  * default implementation for each possible value of `IListRefTag`.
270  *
271  * What does it do?
272  * ================
273  * 1. defines `elem_type` as an alias to `ListElemT`.
274  *
275  * 1. defines `list_type` as an alias to the default container type
276  *    that will hold a collection of `elem_type`. The idea being that
277  *    all types tagged as `TAG` will have `list_type` as its container,
278  *    with different `elem_type`.
279  *
280  * 3. defines the default implementation for each of the methods that
281  *    are supposed to be defined on `IListRefTagImpl` specializations.
282  *
283  * 4. inheriting from `IListRefTagImplBase<TAG, T, ListElemT>` also means
284  *    that the payload of the type `IListRef<T>` will be of type `list_type`
285  *    when it is tagged as `TAG`.
286  */
287 template <IListRefTag TAG, typename T, typename ListElemT = T>
288 class IListRefTagImplBase {};
289 
290 /*
291  * Materialized container for `IListRef<T>`.
292  *
293  * What is this for?
294  * =================
295  * Container that groups `T` references together. This exchanges the
296  * overhead of every method call from `IListRef<T>` for a dynamic allocation.
297  *
298  * You should use this container instead of `IListRef<T>` if:
299  *
300  *   - You are going to iterate the list more than once
301  *   - You need to repeatedly access arbitrary elements (using `operator[]`)
302  * What does it do?
303 
304  * ================
305  * Removes the reference (&) from the type, and wraps it into a
306  * `std::reference_wrapper`. If `IListRefConstRef<T>` is not a
307  * reference type, then it's left unchanged.
308  */
309 template <typename T>
310 using _MaterializedIListRefElem = std::conditional_t<
311     std::is_reference_v<T>,
312     typename std::reference_wrapper<std::remove_reference_t<T>>,
313     T>;
314 
315 template <typename T>
316 using MaterializedIListRefElem = _MaterializedIListRefElem<IListRefConstRef<T>>;
317 
318 template <typename T>
319 using MaterializedIListRef = std::vector<MaterializedIListRefElem<T>>;
320 
321 } // namespace detail
322 
323 /*
324  * Iterator for `IListRef<T>`.
325  *
326  * What is it?
327  * ===========
328  * Currently, a `std::bidirectional_iterator` that wraps the iterator
329  * types defined for each of the `IListRefTag`.
330  *
331  * One should be able to use it, as if it were the unwrapped
332  * iterators themselves.
333 
334  * What does it do?
335  * ================
336  * Similarly to `IListRef<T>`, this is a wrapper class. Specifically, it
337  * wraps each container's `const_iterator` type alias. So, for example,
338  * given that the container for `IListRefTag::Boxed` is `c10::List`, this
339  * iterator will wrap a `c10::List::const_iterator`.
340  *
341  * [Note: MSVC Iterator Debug]
342  * ===========================
343  * MSVC `vector<T>::iterator` implementation (used in the boxed variant)
344  * makes it so this union's destructor, copy-constructor (assignment), and
345  * move-constructor (assignment) are implicitly deleted.
346  *
347  * Therefore, we need to explicitly define them as needed. Follows a list
348  * of places where these are needed and their reason:
349  *
350  *   - `Payload` destructor:
351  *     it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is set to 2.
352  *
353  *   - `IListRefIterator` destructor:
354  *     same as above. However, we need to explicitly call the variant
355  *     destructor explicitly.
356  *
357  *   - `IListRefIterator` copy-constructor:
358  *     it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is different
359  *     than 0.
360  */
361 template <typename T>
362 class IListRefIterator {
363  private:
364 #define DEFINE_FRIEND_CLASS(TAG, ...)                        \
365   friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \
366   friend class detail::IListRefTagImplBase<                  \
367       IListRefTag::TAG,                                      \
368       T,                                                     \
369       typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>;
370   TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
371 #undef DEFINE_FRIEND_CLASS
372 
373  public:
374   // C++17 friendly std::iterator implementation
375   using iterator_category = std::bidirectional_iterator_tag;
376   using value_type = T;
377   using difference_type = std::ptrdiff_t;
378   using pointer = T*;
379   using reference = T&;
380 
381   using unboxed_iterator_type = typename detail::
382       IListRefTagImpl<IListRefTag::Unboxed, T>::list_type::const_iterator;
383   using boxed_iterator_type = typename detail::
384       IListRefTagImpl<IListRefTag::Boxed, T>::list_type::const_iterator;
385   using materialized_iterator_type =
386       typename detail::MaterializedIListRef<T>::const_iterator;
387 
IListRefIterator()388   IListRefIterator() : tag_(IListRefTag::None) {}
389 
390 #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL != 0
391   // See [Note: MSVC Iterator Debug]
IListRefIterator(const IListRefIterator & iterator)392   IListRefIterator(const IListRefIterator& iterator)
393       : tag_(iterator.tag_) {
394     switch (tag_) {
395       case IListRefTag::Boxed:
396         payload_.boxed_iterator = iterator.payload_.boxed_iterator;
397         break;
398       case IListRefTag::Unboxed:
399         payload_.unboxed_iterator = iterator.payload_.unboxed_iterator;
400         break;
401       case IListRefTag::Materialized:
402         payload_.materialized_iterator = iterator.payload_.materialized_iterator;
403         break;
404       default:
405         TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
406     }
407   }
408 #endif
409 
410 #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2
411   // See [Note: MSVC Iterator Debug]
noexcept(false)412   ~IListRefIterator() noexcept(false) {
413     switch (tag_) {
414       case IListRefTag::Boxed:
415         payload_.boxed_iterator.~boxed_iterator_type();
416         break;
417       case IListRefTag::Unboxed:
418         payload_.unboxed_iterator.~unboxed_iterator_type();
419         break;
420       case IListRefTag::Materialized:
421         payload_.materialized_iterator.~materialized_iterator_type();
422         break;
423       default:
424         TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
425     }
426   }
427 #endif
428 
IListRefIterator(boxed_iterator_type boxed)429   IListRefIterator(boxed_iterator_type boxed) : tag_(IListRefTag::Boxed) {
430     payload_.boxed_iterator = boxed;
431   }
432 
IListRefIterator(unboxed_iterator_type unboxed)433   IListRefIterator(unboxed_iterator_type unboxed) : tag_(IListRefTag::Unboxed) {
434     payload_.unboxed_iterator = unboxed;
435   }
436 
IListRefIterator(materialized_iterator_type materialized)437   IListRefIterator(materialized_iterator_type materialized) : tag_(IListRefTag::Materialized) {
438     payload_.materialized_iterator = materialized;
439   }
440 
441   detail::IListRefConstRef<T> operator*() const {
442     TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::iterator_get(this_); });
443   }
444 
445   IListRefIterator& operator++() {
446     TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
447     return *this;
448   }
449 
450   IListRefIterator operator++(int) {
451     auto old = *this;
452     TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
453     return old;
454   }
455 
456   IListRefIterator& operator--() {
457     TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
458     return *this;
459   }
460 
461   IListRefIterator operator--(int) {
462     auto old = *this;
463     TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
464     return old;
465   }
466 
467   bool operator==(const IListRefIterator& rhs) const {
468     if (tag_ != rhs.tag_) {
469       return false;
470     }
471     TORCH_ILISTREF_UNWRAP(tag_, {
472       auto& rhs_it = ImplT::unwrap(rhs);
473       return this_ == rhs_it;
474     });
475   }
476 
477   bool operator!=(const IListRefIterator& rhs) const {
478     return !(*this == rhs);
479   }
480 
481  private:
482   union Payload {
483     boxed_iterator_type boxed_iterator;
484     unboxed_iterator_type unboxed_iterator;
485     materialized_iterator_type materialized_iterator;
486     void* _init_ptr;
Payload()487     Payload() : _init_ptr(nullptr) {}
488 #if defined(_MSC_VER)
489     // See [Note: MSVC Iterator Debug]
~Payload()490     ~Payload() {}
491 #endif
492   };
493 
494   Payload payload_;
495   IListRefTag tag_;
496 };
497 
498 /*
499  * See [Note: IListRef]
500  */
501 template <typename T>
502 class IListRef {
503  private:
504 #define DEFINE_FRIEND_CLASS(TAG, ...)                        \
505   friend class detail::IListRefTagImpl<IListRefTag::TAG, T>; \
506   friend class detail::IListRefTagImplBase<                  \
507       IListRefTag::TAG,                                      \
508       T,                                                     \
509       typename detail::IListRefTagImpl<IListRefTag::TAG, T>::elem_type>;
510   TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
511 #undef DEFINE_FRIEND_CLASS
512 
513  public:
514   using unboxed_type =
515       typename detail::IListRefTagImpl<IListRefTag::Unboxed, T>::list_type;
516   using boxed_type =
517       typename detail::IListRefTagImpl<IListRefTag::Boxed, T>::list_type;
518   using materialized_type =
519       typename detail::MaterializedIListRef<T>;
520 
521   using iterator = IListRefIterator<T>;
522   using const_iterator = IListRefIterator<T>;
523   using reverse_iterator = std::reverse_iterator<iterator>;
524   using value_type = typename iterator::value_type;
525 
IListRef()526   IListRef() : tag_(IListRefTag::None) {}
527 
IListRef(const boxed_type & boxed)528   IListRef(const boxed_type& boxed) : tag_(IListRefTag::Boxed) {
529     payload_.boxed = &boxed;
530   }
531 
IListRef(const unboxed_type & unboxed)532   IListRef(const unboxed_type& unboxed) : tag_(IListRefTag::Unboxed) {
533     payload_.unboxed = unboxed;
534   }
535 
IListRef(const std::initializer_list<T> & list)536   IListRef(const std::initializer_list<T>& list) : tag_(IListRefTag::Unboxed) {
537     payload_.unboxed = at::ArrayRef<T>(list);
538   }
539 
540   template <
541       typename... UnboxedConstructorArgs,
542       typename = std::enable_if_t<
543           std::is_constructible_v<unboxed_type, UnboxedConstructorArgs...>>>
IListRef(UnboxedConstructorArgs &&...args)544   IListRef(UnboxedConstructorArgs&&... args) : tag_(IListRefTag::Unboxed) {
545     payload_.unboxed = unboxed_type(std::forward<UnboxedConstructorArgs>(args)...);
546   }
547 
IListRef(const materialized_type & materialized)548   IListRef(const materialized_type& materialized) : tag_(IListRefTag::Materialized) {
549     payload_.materialized = &materialized;
550   }
551 
size()552   size_t size() const {
553     TORCH_ILISTREF_UNWRAP(tag_, { return this_.size(); });
554   }
555 
empty()556   bool empty() const {
557     return size() == 0;
558   }
559 
begin()560   iterator begin() const {
561     TORCH_ILISTREF_UNWRAP(tag_, { return this_.begin(); });
562   }
563 
end()564   iterator end() const {
565     TORCH_ILISTREF_UNWRAP(tag_, { return this_.end(); });
566   }
567 
front()568   detail::IListRefConstRef<T> front() const {
569     TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::front(this_); });
570   }
571 
572   /*
573    * Materializes the `IListRef` into a `std::vector`.
574    *
575    * This should be used when one wishes to either:
576    *
577    *   - iterate over the list more than once: each `IListRefIterator`
578    *     member function call has to go through a switch, introducing
579    *     non-negligible overhead
580    *
581    *   - randomly access an arbitrary element using `operator[]`:
582    *     same reason as above
583    */
materialize()584   detail::MaterializedIListRef<T> materialize() const {
585     if (isMaterialized()) {
586       return toMaterialized();
587     }
588 
589     detail::MaterializedIListRef<T> materialized;
590     materialized.reserve(size());
591     for (const auto& t : *this) {
592       materialized.emplace_back(t);
593     }
594     return materialized;
595   }
596 
597 #define DEFINE_CHECK(TAG, ...)    \
598   bool is##TAG() const {          \
599     return tag_ == IListRefTag::TAG; \
600   }
601   TORCH_ILISTREF_FORALL_TAGS(DEFINE_CHECK);
602 #undef DEFINE_CHECK
603 
isNone()604   bool isNone() const {
605     return tag_ == IListRefTag::None;
606   }
607 
608 #define DEFINE_CASTING(TAG, ...)                                          \
609   const typename detail::IListRefTagImpl<IListRefTag::TAG, T>::list_type& \
610       to##TAG() const {                                                   \
611     TORCH_INTERNAL_ASSERT(is##TAG());                                     \
612     return detail::IListRefTagImpl<IListRefTag::TAG, T>::unwrap(*this);   \
613   }
614   TORCH_ILISTREF_FORALL_TAGS(DEFINE_CASTING);
615 #undef DEFINE_CASTING
616 
617  private:
618   union Payload {
619     const boxed_type* boxed;
620     unboxed_type unboxed;
621     const materialized_type* materialized;
Payload()622     Payload() : boxed(nullptr) {}
623   };
624 
625   Payload payload_;
626   IListRefTag tag_;
627 };
628 
629 } // namespace c10
630 
631 #include <ATen/core/IListRef_inl.h>
632