xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/List_inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/jit_type_base.h>
4 #include <ATen/core/ivalue.h>
5 
6 namespace c10 {
7 
// Forward declarations of helpers used by the definitions below:
// getTypePtr<T>() supplies the element type recorded in a List<T>, and
// toString() renders a Type for toTypedList's mismatch message.
// Neither is defined in this header.
template<class T> decltype(auto) getTypePtr();
std::string toString(const Type& type);
10 
11 template<class T>
List(c10::intrusive_ptr<c10::detail::ListImpl> && elements)12 List<T>::List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements)
13 : impl_(std::move(elements)) {}
14 
15 template<class T>
List(const c10::intrusive_ptr<c10::detail::ListImpl> & elements)16 List<T>::List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements)
17 : impl_(elements) {}
18 
19 template<class T>
List()20 List<T>::List()
21 : List(make_intrusive<c10::detail::ListImpl>(
22   typename c10::detail::ListImpl::list_type(),
23   getTypePtr<T>())) {
24   static_assert(!std::is_same<T, IValue>::value, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType) instead.");
25 }
26 
27 template<class T>
List(ArrayRef<T> values)28 List<T>::List(ArrayRef<T> values)
29 : List(make_intrusive<c10::detail::ListImpl>(
30     typename c10::detail::ListImpl::list_type(),
31     getTypePtr<T>())) {
32   static_assert(!std::is_same<T, IValue>::value, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
33   impl_->list.reserve(values.size());
34   for (const T& element : values) {
35     impl_->list.push_back(element);
36   }
37 }
38 
39 template<class T>
List(std::initializer_list<T> initial_values)40 List<T>::List(std::initializer_list<T> initial_values)
41 : List(ArrayRef<T>(initial_values)) {
42   static_assert(!std::is_same<T, IValue>::value, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
43 }
44 
45 template<class T>
List(TypePtr elementType)46 List<T>::List(TypePtr elementType)
47 : List(make_intrusive<c10::detail::ListImpl>(
48     typename c10::detail::ListImpl::list_type(),
49     std::move(elementType))) {
50   static_assert(std::is_same<T, IValue>::value || std::is_same<T, c10::intrusive_ptr<ivalue::Future>>::value,
51                 "This constructor is only valid for c10::impl::GenericList or List<Future>.");
52 }
53 
54 namespace impl {
55 template<class T>
toTypedList(impl::GenericList list)56 List<T> toTypedList(impl::GenericList list) {
57   // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
58   // because upcasting would allow people to add types into the new list that would break the old list.
59   // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
60   // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
61   // without having to copy it. This is also used to provide backwards compatibility with some old models
62   // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
63   // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
64   // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
65   TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
66     || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr<T>()))
67     , "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr<T>()), ">. Types mismatch.");
68   return List<T>(std::move(list.impl_));
69 }
70 
71 template<class T>
toList(List<T> && list)72 impl::GenericList toList(List<T>&& list) {
73   return GenericList(std::move(list.impl_));
74 }
75 template<class T>
toList(const List<T> & list)76 impl::GenericList toList(const List<T>& list) {
77   return GenericList(list.impl_);
78 }
79 }
80 
81 template<class T>
copy()82 List<T> List<T>::copy() const {
83   return List<T>(impl_->copy());
84 }
85 
86 namespace detail {
87   template<class T>
list_element_to(T element)88   T list_element_to(T element) {
89     return element;
90   }
91   template<class T>
list_element_to(const IValue & element)92   T list_element_to(const IValue& element) {
93     return element.template to<T>();
94   }
95   template<class T>
list_element_to(IValue && element)96   T list_element_to(IValue&& element) {
97     return std::move(element).template to<T>();
98   }
99   template<class T>
100   struct ListElementFrom {
fromListElementFrom101     static IValue from(const T& element) {
102       return element;
103     }
fromListElementFrom104     static IValue from(T&& element) {
105       return std::move(element);
106     }
107   };
108   template<>
109   struct ListElementFrom<IValue> {
110     static const IValue& from(const IValue& element) {
111       return element;
112     }
113     static IValue&& from(IValue&& element) {
114       return std::move(element);
115     }
116   };
117 }
118 
119 namespace impl {
120 
// Reads the referenced list element as a T via IValue::to<T>().
// The result type is `const T&` when
// c10::detail::ivalue_to_const_ref_overload_return<T>::type is a reference
// type, and a T by value otherwise.
// Out-of-line definition: keep this return-type spelling in sync with the
// declaration.
template <class T, class Iterator>
ListElementReference<T, Iterator>::operator std::conditional_t<
    std::is_reference_v<typename c10::detail::ivalue_to_const_ref_overload_return<
        T>::type>,
    const T&,
    T>() const {
  return iterator_->template to<T>();
}
129 
130 template<class T, class Iterator>
131 ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(T&& new_value) && {
132   *iterator_ = c10::detail::ListElementFrom<T>::from(std::move(new_value));
133   return *this;
134 }
135 
136 template<class T, class Iterator>
137 ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(const T& new_value) && {
138   *iterator_ = c10::detail::ListElementFrom<T>::from(new_value);
139   return *this;
140 }
141 
142 template<class T, class Iterator>
143 ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(ListElementReference<T, Iterator>&& rhs) && noexcept {
144   *iterator_ = *rhs.iterator_;
145   return *this;
146 }
147 
148 template<class T, class Iterator>
149 void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs)  noexcept {
150   std::swap(*lhs.iterator_, *rhs.iterator_);
151 }
152 
153 template<class T, class Iterator>
154 bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs) {
155   const T& lhs_tmp = lhs;
156   return lhs_tmp == rhs;
157 }
158 
// Mirror of operator==(const ListElementReference&, const T&) so the
// comparison also works with the plain value on the left.
template<class T, class Iterator>
inline bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs) {
  return rhs == lhs;
}
163 
164 template<class T>
165 inline typename ListElementConstReferenceTraits<T>::const_reference
166 list_element_to_const_ref(const IValue& element) {
167   return element.template to<T>();
168 }
169 
170 template<>
171 inline typename ListElementConstReferenceTraits<std::optional<std::string>>::const_reference
172 list_element_to_const_ref<std::optional<std::string>>(const IValue& element) {
173   return element.toOptionalStringRef();
174 }
175 
176 } // namespace impl
177 
178 template<class T>
179 void List<T>::set(size_type pos, const value_type& value) const {
180   impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(value);
181 }
182 
183 template<class T>
184 void List<T>::set(size_type pos, value_type&& value) const {
185   impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(std::move(value));
186 }
187 
188 template<class T>
189 typename List<T>::internal_const_reference_type List<T>::get(size_type pos) const {
190   return operator[](pos);
191 }
192 
193 template<class T>
194 typename List<T>::internal_const_reference_type List<T>::operator[](size_type pos) const {
195   return c10::impl::list_element_to_const_ref<T>(impl_->list.at(pos));
196 }
197 
198 template<class T>
199 typename List<T>::internal_reference_type List<T>::operator[](size_type pos) {
200   static_cast<void>(impl_->list.at(pos)); // Throw the exception if it is out of range.
201   return {impl_->list.begin() + static_cast<typename decltype(impl_->list)::difference_type>(pos)};
202 }
203 
204 template<class T>
205 typename List<T>::value_type List<T>::extract(size_type pos) const {
206   auto& elem = impl_->list.at(pos);
207   auto result = c10::detail::list_element_to<T>(std::move(elem));
208   // Reset the list element to a T() instead of None to keep it correctly typed
209   elem = c10::detail::ListElementFrom<T>::from(T{});
210   return result;
211 }
212 
213 template<class T>
214 typename List<T>::iterator List<T>::begin() const {
215   return iterator(impl_->list.begin());
216 }
217 
218 template<class T>
219 typename List<T>::iterator List<T>::end() const {
220   return iterator(impl_->list.end());
221 }
222 
223 template<class T>
224 bool List<T>::empty() const {
225   return impl_->list.empty();
226 }
227 
228 template<class T>
229 typename List<T>::size_type List<T>::size() const {
230   return impl_->list.size();
231 }
232 
233 template<class T>
234 void List<T>::reserve(size_type new_cap) const {
235   impl_->list.reserve(new_cap);
236 }
237 
238 template<class T>
239 void List<T>::clear() const {
240   impl_->list.clear();
241 }
242 
243 template<class T>
244 typename List<T>::iterator List<T>::insert(iterator pos, const T& value) const {
245   return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(value)) };
246 }
247 
248 template<class T>
249 typename List<T>::iterator List<T>::insert(iterator pos, T&& value) const {
250   return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(std::move(value))) };
251 }
252 
253 template<class T>
254 template<class... Args>
255 typename List<T>::iterator List<T>::emplace(iterator pos, Args&&... value) const {
256   // TODO Use list_element_from?
257   return iterator { impl_->list.emplace(pos.iterator_, std::forward<Args>(value)...) };
258 }
259 
260 template<class T>
261 void List<T>::push_back(const T& value) const {
262   impl_->list.push_back(c10::detail::ListElementFrom<T>::from(value));
263 }
264 
265 template<class T>
266 void List<T>::push_back(T&& value) const {
267   impl_->list.push_back(c10::detail::ListElementFrom<T>::from(std::move(value)));
268 }
269 
270 template<class T>
271 void List<T>::append(List<T> b) const {
272   if (b.use_count() == 1) {
273     impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end()));
274   } else {
275     impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end());
276   }
277 }
278 
279 template<class T>
280 template<class... Args>
281 void List<T>::emplace_back(Args&&... args) const {
282   // TODO Use list_element_from?
283   impl_->list.push_back(T(std::forward<Args>(args)...));
284 }
285 
286 template<class T>
287 typename List<T>::iterator List<T>::erase(iterator pos) const {
288   return iterator { impl_->list.erase(pos.iterator_) };
289 }
290 
291 template<class T>
292 typename List<T>::iterator List<T>::erase(iterator first, iterator last) const {
293   return iterator { impl_->list.erase(first.iterator_, last.iterator_) };
294 }
295 
296 template<class T>
297 void List<T>::pop_back() const {
298   impl_->list.pop_back();
299 }
300 
301 template<class T>
302 void List<T>::resize(size_type count) const {
303   impl_->list.resize(count, T{});
304 }
305 
306 template<class T>
307 void List<T>::resize(size_type count, const T& value) const {
308   impl_->list.resize(count, value);
309 }
310 
311 template<class T>
312 bool operator==(const List<T>& lhs, const List<T>& rhs) {
313   // Lists with the same identity trivially compare equal.
314   if (lhs.impl_ == rhs.impl_) {
315     return true;
316   }
317 
318   // Otherwise, just compare values directly.
319   return *lhs.impl_ == *rhs.impl_;
320 }
321 
322 template<class T>
323 bool operator!=(const List<T>& lhs, const List<T>& rhs) {
324   return !(lhs == rhs);
325 }
326 
327 template<class T>
328 bool List<T>::is(const List<T>& rhs) const {
329   return this->impl_ == rhs.impl_;
330 }
331 
332 template<class T>
333 std::vector<T> List<T>::vec() const {
334   std::vector<T> result(begin(), end());
335   return result;
336 }
337 
338 template<class T>
339 size_t List<T>::use_count() const {
340   return impl_.use_count();
341 }
342 
343 template <class T>
344 TypePtr List<T>::elementType() const {
345   return impl_->elementType;
346 }
347 
348 template <class T>
349 void List<T>::unsafeSetElementType(TypePtr t) {
350   impl_->elementType = std::move(t);
351 }
352 
353 }
354