// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/List.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue_to.h>
4 #include <ATen/core/jit_type_base.h>
5 #include <c10/macros/Macros.h>
6 #include <c10/macros/Export.h>
7 #include <c10/util/TypeTraits.h>
8 #include <c10/util/TypeList.h>
9 #include <c10/util/intrusive_ptr.h>
10 #include <c10/util/ArrayRef.h>
11 #include <optional>
12 #include <vector>
13 
// Forward declaration of at::Tensor.
namespace at {
class Tensor;
} // namespace at
17 namespace c10 {
18 struct IValue;
19 template<class T> class List;
20 struct Type;
21 
22 namespace detail {
23 
24 struct ListImpl final : public c10::intrusive_ptr_target {
25   using list_type = std::vector<IValue>;
26 
27   explicit TORCH_API ListImpl(list_type list_, TypePtr elementType_);
28 
29   list_type list;
30 
31   TypePtr elementType;
32 
copyfinal33   intrusive_ptr<ListImpl> copy() const {
34     return make_intrusive<ListImpl>(list, elementType);
35   }
36   friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs);
37 };
38 }
39 
40 namespace impl {
41 
42 template<class T, class Iterator> class ListIterator;
43 
44 template<class T, class Iterator> class ListElementReference;
45 
46 template<class T, class Iterator>
47 void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept;
48 
49 template<class T, class Iterator>
50 bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs);
51 
52 template<class T, class Iterator>
53 bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs);
54 
55 template<class T>
56 struct ListElementConstReferenceTraits {
57   // In the general case, we use IValue::to().
58   using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return<T>::type;
59 };
60 
61 // There is no to() overload for std::optional<std::string>.
62 template<>
63 struct ListElementConstReferenceTraits<std::optional<std::string>> {
64   using const_reference = std::optional<std::reference_wrapper<const std::string>>;
65 };
66 
67 template<class T, class Iterator>
68 class ListElementReference final {
69 public:
70   operator std::conditional_t<
71       std::is_reference_v<typename c10::detail::
72                             ivalue_to_const_ref_overload_return<T>::type>,
73       const T&,
74       T>() const;
75 
76   ListElementReference& operator=(T&& new_value) &&;
77 
78   ListElementReference& operator=(const T& new_value) &&;
79 
80   // assigning another ref to this assigns the underlying value
81   ListElementReference& operator=(ListElementReference&& rhs) && noexcept;
82 
83   const IValue& get() const& {
84     return *iterator_;
85   }
86 
87   friend void swap<T, Iterator>(ListElementReference&& lhs, ListElementReference&& rhs) noexcept;
88 
89   ListElementReference(const ListElementReference&) = delete;
90   ListElementReference& operator=(const ListElementReference&) = delete;
91 
92 private:
93   ListElementReference(Iterator iter)
94   : iterator_(iter) {}
95 
96   // allow moving, but only our friends (i.e. the List class) can move us
97   ListElementReference(ListElementReference&&) noexcept = default;
98   ListElementReference& operator=(ListElementReference&& rhs) & noexcept {
99     iterator_ = std::move(rhs.iterator_);
100     return *this;
101   }
102 
103   friend class List<T>;
104   friend class ListIterator<T, Iterator>;
105 
106   Iterator iterator_;
107 };
108 
109 // this wraps vector::iterator to make sure user code can't rely
110 // on it being the type of the underlying vector.
111 template <class T, class Iterator>
112 class ListIterator final {
113  public:
114    // C++17 friendly std::iterator implementation
115   using iterator_category = std::random_access_iterator_tag;
116   using value_type = T;
117   using difference_type = std::ptrdiff_t;
118   using pointer = T*;
119   using reference = ListElementReference<T, Iterator>;
120 
121   explicit ListIterator() = default;
122   ~ListIterator() = default;
123 
124   ListIterator(const ListIterator&) = default;
125   ListIterator(ListIterator&&) noexcept = default;
126   ListIterator& operator=(const ListIterator&) = default;
127   ListIterator& operator=(ListIterator&&) noexcept = default;
128 
129   ListIterator& operator++() {
130       ++iterator_;
131       return *this;
132   }
133 
134   ListIterator operator++(int) {
135       ListIterator copy(*this);
136       ++*this;
137       return copy;
138   }
139 
140   ListIterator& operator--() {
141       --iterator_;
142       return *this;
143   }
144 
145   ListIterator operator--(int) {
146       ListIterator copy(*this);
147       --*this;
148       return copy;
149   }
150 
151   ListIterator& operator+=(typename List<T>::size_type offset) {
152       iterator_ += offset;
153       return *this;
154   }
155 
156   ListIterator& operator-=(typename List<T>::size_type offset) {
157       iterator_ -= offset;
158       return *this;
159   }
160 
161   ListIterator operator+(typename List<T>::size_type offset) const {
162     return ListIterator{iterator_ + offset};
163   }
164 
165   ListIterator operator-(typename List<T>::size_type offset) const {
166     return ListIterator{iterator_ - offset};
167   }
168 
169   friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) {
170     return lhs.iterator_ - rhs.iterator_;
171   }
172 
173   ListElementReference<T, Iterator> operator*() const {
174     return {iterator_};
175   }
176 
177   ListElementReference<T, Iterator> operator[](typename List<T>::size_type offset) const {
178     return {iterator_ + offset};
179   }
180 
181 private:
182   explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {}
183 
184   Iterator iterator_;
185 
186   friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) {
187     return lhs.iterator_ == rhs.iterator_;
188   }
189 
190   friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) {
191     return !(lhs == rhs);
192   }
193 
194   friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) {
195     return lhs.iterator_ < rhs.iterator_;
196   }
197 
198   friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) {
199     return lhs.iterator_ <= rhs.iterator_;
200   }
201 
202   friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) {
203     return lhs.iterator_ > rhs.iterator_;
204   }
205 
206   friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) {
207     return lhs.iterator_ >= rhs.iterator_;
208   }
209 
210   friend class ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
211   friend class List<T>;
212 };
213 
214 template<class T> List<T> toTypedList(List<IValue> list);
215 template<class T> List<IValue> toList(List<T>&& list);
216 template<class T> List<IValue> toList(const List<T>& list);
217 const IValue* ptr_to_first_element(const List<IValue>& list);
218 }
219 
220 /**
221  * An object of this class stores a list of values of type T.
222  *
223  * This is a pointer type. After a copy, both Lists
224  * will share the same storage:
225  *
226  * > List<int> a;
227  * > List<int> b = a;
228  * > b.push_back("three");
229  * > ASSERT("three" == a.get(0));
230  *
231  * We use this class in the PyTorch kernel API instead of
232  * std::vector<T>, because that allows us to do optimizations
233  * and switch out the underlying list implementation without
234  * breaking backwards compatibility for the kernel API.
235  */
236 template<class T>
237 class List final {
238 private:
239   // This is an intrusive_ptr because List is a pointer type.
240   // Invariant: This will never be a nullptr, there will always be a valid
241   // ListImpl.
242   c10::intrusive_ptr<c10::detail::ListImpl> impl_;
243 
244   using internal_reference_type = impl::ListElementReference<T, typename c10::detail::ListImpl::list_type::iterator>;
245   using internal_const_reference_type = typename impl::ListElementConstReferenceTraits<T>::const_reference;
246 
247 public:
248   using value_type = T;
249   using size_type = typename c10::detail::ListImpl::list_type::size_type;
250   using iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
251   using const_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
252   using reverse_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::reverse_iterator>;
253 
254   /**
255    * Constructs an empty list.
256    */
257   explicit List();
258 
259   /**
260    * Constructs a list with some initial values.
261    * Example:
262    *   List<int> a({2, 3, 4});
263    */
264   List(std::initializer_list<T> initial_values);
265   explicit List(ArrayRef<T> initial_values);
266 
267   /**
268    * Create a generic list with runtime type information.
269    * This only works for c10::impl::GenericList and is not part of the public API
270    * but only supposed to be used internally by PyTorch.
271    */
272   explicit List(TypePtr elementType);
273 
274   List(const List&) = default;
275   List& operator=(const List&) = default;
276 
277   /**
278    * Create a new List pointing to a deep copy of the same data.
279    * The List returned is a new list with separate storage.
280    * Changes in it are not reflected in the original list or vice versa.
281    */
282   List copy() const;
283 
284   /**
285    * Returns the element at specified location pos, with bounds checking.
286    * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
287    */
288   internal_const_reference_type get(size_type pos) const;
289 
290   /**
291    * Moves out the element at the specified location pos and returns it, with bounds checking.
292    * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
293    * The list contains an invalid element at position pos afterwards. Any operations
294    * on it before re-setting it are invalid.
295    */
296   value_type extract(size_type pos) const;
297 
298   /**
299    * Returns a reference to the element at specified location pos, with bounds checking.
300    * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
301    *
302    * You cannot store the reference, but you can read it and assign new values to it:
303    *
304    *   List<int64_t> list = ...;
305    *   list[2] = 5;
306    *   int64_t v = list[1];
307    */
308   internal_const_reference_type operator[](size_type pos) const;
309 
310   internal_reference_type operator[](size_type pos);
311 
312   /**
313    * Assigns a new value to the element at location pos.
314    */
315   void set(size_type pos, const value_type& value) const;
316 
317   /**
318    * Assigns a new value to the element at location pos.
319    */
320   void set(size_type pos, value_type&& value) const;
321 
322   /**
323    * Returns an iterator to the first element of the container.
324    * If the container is empty, the returned iterator will be equal to end().
325    */
326   iterator begin() const;
327 
328   /**
329    * Returns an iterator to the element following the last element of the container.
330    * This element acts as a placeholder; attempting to access it results in undefined behavior.
331    */
332   iterator end() const;
333 
334   /**
335    * Checks if the container has no elements.
336    */
337   bool empty() const;
338 
339   /**
340    * Returns the number of elements in the container
341    */
342   size_type size() const;
343 
344   /**
345    * Increase the capacity of the vector to a value that's greater or equal to new_cap.
346    */
347   void reserve(size_type new_cap) const;
348 
349   /**
350    * Erases all elements from the container. After this call, size() returns zero.
351    * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated.
352    */
353   void clear() const;
354 
355   /**
356    * Inserts value before pos.
357    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
358    */
359   iterator insert(iterator pos, const T& value) const;
360 
361   /**
362    * Inserts value before pos.
363    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
364    */
365   iterator insert(iterator pos, T&& value) const;
366 
367   /**
368    * Inserts a new element into the container directly before pos.
369    * The new element is constructed with the given arguments.
370    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
371    */
372   template<class... Args>
373   iterator emplace(iterator pos, Args&&... value) const;
374 
375   /**
376    * Appends the given element value to the end of the container.
377    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
378    */
379   void push_back(const T& value) const;
380 
381   /**
382    * Appends the given element value to the end of the container.
383    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
384    */
385   void push_back(T&& value) const;
386 
387   /**
388    * Appends the given list to the end of the container. Uses at most one memory allocation.
389    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
390    */
391   void append(List<T> lst) const;
392 
393   /**
394    * Appends the given element value to the end of the container.
395    * The new element is constructed with the given arguments.
396    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
397    */
398   template<class... Args>
399   void emplace_back(Args&&... args) const;
400 
401   /**
402    * Removes the element at pos.
403    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
404    */
405   iterator erase(iterator pos) const;
406 
407   /**
408    * Removes the elements in the range [first, last).
409    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
410    */
411   iterator erase(iterator first, iterator last) const;
412 
413   /**
414    * Removes the last element of the container.
415    * Calling pop_back on an empty container is undefined.
416    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
417    */
418   void pop_back() const;
419 
420   /**
421    * Resizes the container to contain count elements.
422    * If the current size is less than count, additional default-inserted elements are appended.
423    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
424    */
425   void resize(size_type count) const;
426 
427   /**
428    * Resizes the container to contain count elements.
429    * If the current size is less than count, additional copies of value are appended.
430    * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
431    */
432   void resize(size_type count, const T& value) const;
433 
434   /**
435    * Value equality comparison. This function implements Python-like semantics for
436    * equality: two lists with the same identity (e.g. same pointer) trivially
437    * compare equal, otherwise each element is compared for equality.
438    */
439   template <class T_>
440   friend bool operator==(const List<T_>& lhs, const List<T_>& rhs);
441 
442   template <class T_>
443   friend bool operator!=(const List<T_>& lhs, const List<T_>& rhs);
444 
445   /**
446    * Identity comparison. Returns true if and only if `rhs` represents the same
447    * List object as `this`.
448    */
449   bool is(const List<T>& rhs) const;
450 
451   std::vector<T> vec() const;
452 
453   /**
454    * Returns the number of Lists currently pointing to this same list.
455    * If this is the only instance pointing to this list, returns 1.
456    */
457   // TODO Test use_count
458   size_t use_count() const;
459 
460   TypePtr elementType() const;
461 
462   // See [unsafe set type] for why this exists.
463   void unsafeSetElementType(TypePtr t);
464 
465 private:
466   explicit List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements);
467   explicit List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements);
468   friend struct IValue;
469   template<class T_> friend List<T_> impl::toTypedList(List<IValue>);
470   template<class T_> friend List<IValue> impl::toList(List<T_>&&);
471   template<class T_> friend List<IValue> impl::toList(const List<T_>&);
472   friend const IValue* impl::ptr_to_first_element(const List<IValue>& list);
473 };
474 
475 namespace impl {
476 // GenericList is how IValue stores lists. It is, however, not part of the
477 // public API. Kernels should use Lists with concrete types instead
478 // (maybe except for some internal prim ops).
479 using GenericList = List<IValue>;
480 
481 }
482 }
483 
484 namespace torch {
485   template<class T> using List = c10::List<T>;
486 }
487 
488 #include <ATen/core/List_inl.h>  // IWYU pragma: keep
489