1 #pragma once
2
3 #include <ATen/core/jit_type_base.h>
4 #include <ATen/core/ivalue.h>
5
6 namespace c10 {
7
8 template<class T> decltype(auto) getTypePtr();
9 std::string toString(const Type& type);
10
11 template<class T>
List(c10::intrusive_ptr<c10::detail::ListImpl> && elements)12 List<T>::List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements)
13 : impl_(std::move(elements)) {}
14
15 template<class T>
List(const c10::intrusive_ptr<c10::detail::ListImpl> & elements)16 List<T>::List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements)
17 : impl_(elements) {}
18
19 template<class T>
List()20 List<T>::List()
21 : List(make_intrusive<c10::detail::ListImpl>(
22 typename c10::detail::ListImpl::list_type(),
23 getTypePtr<T>())) {
24 static_assert(!std::is_same<T, IValue>::value, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType) instead.");
25 }
26
27 template<class T>
List(ArrayRef<T> values)28 List<T>::List(ArrayRef<T> values)
29 : List(make_intrusive<c10::detail::ListImpl>(
30 typename c10::detail::ListImpl::list_type(),
31 getTypePtr<T>())) {
32 static_assert(!std::is_same<T, IValue>::value, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
33 impl_->list.reserve(values.size());
34 for (const T& element : values) {
35 impl_->list.push_back(element);
36 }
37 }
38
39 template<class T>
List(std::initializer_list<T> initial_values)40 List<T>::List(std::initializer_list<T> initial_values)
41 : List(ArrayRef<T>(initial_values)) {
42 static_assert(!std::is_same<T, IValue>::value, "This constructor is not valid for List<IValue>. Please use c10::impl::GenericList(elementType).");
43 }
44
45 template<class T>
List(TypePtr elementType)46 List<T>::List(TypePtr elementType)
47 : List(make_intrusive<c10::detail::ListImpl>(
48 typename c10::detail::ListImpl::list_type(),
49 std::move(elementType))) {
50 static_assert(std::is_same<T, IValue>::value || std::is_same<T, c10::intrusive_ptr<ivalue::Future>>::value,
51 "This constructor is only valid for c10::impl::GenericList or List<Future>.");
52 }
53
54 namespace impl {
55 template<class T>
toTypedList(impl::GenericList list)56 List<T> toTypedList(impl::GenericList list) {
57 // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
58 // because upcasting would allow people to add types into the new list that would break the old list.
59 // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
60 // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
61 // without having to copy it. This is also used to provide backwards compatibility with some old models
62 // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
63 // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
64 // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
65 TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
66 || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr<T>()))
67 , "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr<T>()), ">. Types mismatch.");
68 return List<T>(std::move(list.impl_));
69 }
70
71 template<class T>
toList(List<T> && list)72 impl::GenericList toList(List<T>&& list) {
73 return GenericList(std::move(list.impl_));
74 }
75 template<class T>
toList(const List<T> & list)76 impl::GenericList toList(const List<T>& list) {
77 return GenericList(list.impl_);
78 }
79 }
80
81 template<class T>
copy()82 List<T> List<T>::copy() const {
83 return List<T>(impl_->copy());
84 }
85
86 namespace detail {
87 template<class T>
list_element_to(T element)88 T list_element_to(T element) {
89 return element;
90 }
91 template<class T>
list_element_to(const IValue & element)92 T list_element_to(const IValue& element) {
93 return element.template to<T>();
94 }
95 template<class T>
list_element_to(IValue && element)96 T list_element_to(IValue&& element) {
97 return std::move(element).template to<T>();
98 }
99 template<class T>
100 struct ListElementFrom {
fromListElementFrom101 static IValue from(const T& element) {
102 return element;
103 }
fromListElementFrom104 static IValue from(T&& element) {
105 return std::move(element);
106 }
107 };
108 template<>
109 struct ListElementFrom<IValue> {
110 static const IValue& from(const IValue& element) {
111 return element;
112 }
113 static IValue&& from(IValue&& element) {
114 return std::move(element);
115 }
116 };
117 }
118
119 namespace impl {
120
121 template <class T, class Iterator>
122 ListElementReference<T, Iterator>::operator std::conditional_t<
123 std::is_reference_v<typename c10::detail::ivalue_to_const_ref_overload_return<
124 T>::type>,
125 const T&,
126 T>() const {
127 return iterator_->template to<T>();
128 }
129
130 template<class T, class Iterator>
131 ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(T&& new_value) && {
132 *iterator_ = c10::detail::ListElementFrom<T>::from(std::move(new_value));
133 return *this;
134 }
135
136 template<class T, class Iterator>
137 ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(const T& new_value) && {
138 *iterator_ = c10::detail::ListElementFrom<T>::from(new_value);
139 return *this;
140 }
141
142 template<class T, class Iterator>
143 ListElementReference<T, Iterator>& ListElementReference<T, Iterator>::operator=(ListElementReference<T, Iterator>&& rhs) && noexcept {
144 *iterator_ = *rhs.iterator_;
145 return *this;
146 }
147
148 template<class T, class Iterator>
149 void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept {
150 std::swap(*lhs.iterator_, *rhs.iterator_);
151 }
152
153 template<class T, class Iterator>
154 bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs) {
155 const T& lhs_tmp = lhs;
156 return lhs_tmp == rhs;
157 }
158
159 template<class T, class Iterator>
160 inline bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs) {
161 return rhs == lhs;
162 }
163
164 template<class T>
165 inline typename ListElementConstReferenceTraits<T>::const_reference
166 list_element_to_const_ref(const IValue& element) {
167 return element.template to<T>();
168 }
169
170 template<>
171 inline typename ListElementConstReferenceTraits<std::optional<std::string>>::const_reference
172 list_element_to_const_ref<std::optional<std::string>>(const IValue& element) {
173 return element.toOptionalStringRef();
174 }
175
176 } // namespace impl
177
178 template<class T>
179 void List<T>::set(size_type pos, const value_type& value) const {
180 impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(value);
181 }
182
183 template<class T>
184 void List<T>::set(size_type pos, value_type&& value) const {
185 impl_->list.at(pos) = c10::detail::ListElementFrom<T>::from(std::move(value));
186 }
187
188 template<class T>
189 typename List<T>::internal_const_reference_type List<T>::get(size_type pos) const {
190 return operator[](pos);
191 }
192
193 template<class T>
194 typename List<T>::internal_const_reference_type List<T>::operator[](size_type pos) const {
195 return c10::impl::list_element_to_const_ref<T>(impl_->list.at(pos));
196 }
197
198 template<class T>
199 typename List<T>::internal_reference_type List<T>::operator[](size_type pos) {
200 static_cast<void>(impl_->list.at(pos)); // Throw the exception if it is out of range.
201 return {impl_->list.begin() + static_cast<typename decltype(impl_->list)::difference_type>(pos)};
202 }
203
204 template<class T>
205 typename List<T>::value_type List<T>::extract(size_type pos) const {
206 auto& elem = impl_->list.at(pos);
207 auto result = c10::detail::list_element_to<T>(std::move(elem));
208 // Reset the list element to a T() instead of None to keep it correctly typed
209 elem = c10::detail::ListElementFrom<T>::from(T{});
210 return result;
211 }
212
213 template<class T>
214 typename List<T>::iterator List<T>::begin() const {
215 return iterator(impl_->list.begin());
216 }
217
218 template<class T>
219 typename List<T>::iterator List<T>::end() const {
220 return iterator(impl_->list.end());
221 }
222
223 template<class T>
224 bool List<T>::empty() const {
225 return impl_->list.empty();
226 }
227
228 template<class T>
229 typename List<T>::size_type List<T>::size() const {
230 return impl_->list.size();
231 }
232
233 template<class T>
234 void List<T>::reserve(size_type new_cap) const {
235 impl_->list.reserve(new_cap);
236 }
237
238 template<class T>
239 void List<T>::clear() const {
240 impl_->list.clear();
241 }
242
243 template<class T>
244 typename List<T>::iterator List<T>::insert(iterator pos, const T& value) const {
245 return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(value)) };
246 }
247
248 template<class T>
249 typename List<T>::iterator List<T>::insert(iterator pos, T&& value) const {
250 return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom<T>::from(std::move(value))) };
251 }
252
253 template<class T>
254 template<class... Args>
255 typename List<T>::iterator List<T>::emplace(iterator pos, Args&&... value) const {
256 // TODO Use list_element_from?
257 return iterator { impl_->list.emplace(pos.iterator_, std::forward<Args>(value)...) };
258 }
259
260 template<class T>
261 void List<T>::push_back(const T& value) const {
262 impl_->list.push_back(c10::detail::ListElementFrom<T>::from(value));
263 }
264
265 template<class T>
266 void List<T>::push_back(T&& value) const {
267 impl_->list.push_back(c10::detail::ListElementFrom<T>::from(std::move(value)));
268 }
269
270 template<class T>
271 void List<T>::append(List<T> b) const {
272 if (b.use_count() == 1) {
273 impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end()));
274 } else {
275 impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end());
276 }
277 }
278
279 template<class T>
280 template<class... Args>
281 void List<T>::emplace_back(Args&&... args) const {
282 // TODO Use list_element_from?
283 impl_->list.push_back(T(std::forward<Args>(args)...));
284 }
285
286 template<class T>
287 typename List<T>::iterator List<T>::erase(iterator pos) const {
288 return iterator { impl_->list.erase(pos.iterator_) };
289 }
290
291 template<class T>
292 typename List<T>::iterator List<T>::erase(iterator first, iterator last) const {
293 return iterator { impl_->list.erase(first.iterator_, last.iterator_) };
294 }
295
296 template<class T>
297 void List<T>::pop_back() const {
298 impl_->list.pop_back();
299 }
300
301 template<class T>
302 void List<T>::resize(size_type count) const {
303 impl_->list.resize(count, T{});
304 }
305
306 template<class T>
307 void List<T>::resize(size_type count, const T& value) const {
308 impl_->list.resize(count, value);
309 }
310
311 template<class T>
312 bool operator==(const List<T>& lhs, const List<T>& rhs) {
313 // Lists with the same identity trivially compare equal.
314 if (lhs.impl_ == rhs.impl_) {
315 return true;
316 }
317
318 // Otherwise, just compare values directly.
319 return *lhs.impl_ == *rhs.impl_;
320 }
321
322 template<class T>
323 bool operator!=(const List<T>& lhs, const List<T>& rhs) {
324 return !(lhs == rhs);
325 }
326
327 template<class T>
328 bool List<T>::is(const List<T>& rhs) const {
329 return this->impl_ == rhs.impl_;
330 }
331
332 template<class T>
333 std::vector<T> List<T>::vec() const {
334 std::vector<T> result(begin(), end());
335 return result;
336 }
337
338 template<class T>
339 size_t List<T>::use_count() const {
340 return impl_.use_count();
341 }
342
343 template <class T>
344 TypePtr List<T>::elementType() const {
345 return impl_->elementType;
346 }
347
348 template <class T>
349 void List<T>::unsafeSetElementType(TypePtr t) {
350 impl_->elementType = std::move(t);
351 }
352
353 }
354