1 #pragma once 2 3 #include <ATen/core/ivalue_to.h> 4 #include <ATen/core/jit_type_base.h> 5 #include <c10/macros/Macros.h> 6 #include <c10/macros/Export.h> 7 #include <c10/util/TypeTraits.h> 8 #include <c10/util/TypeList.h> 9 #include <c10/util/intrusive_ptr.h> 10 #include <c10/util/ArrayRef.h> 11 #include <optional> 12 #include <vector> 13 14 namespace at { 15 class Tensor; 16 } 17 namespace c10 { 18 struct IValue; 19 template<class T> class List; 20 struct Type; 21 22 namespace detail { 23 24 struct ListImpl final : public c10::intrusive_ptr_target { 25 using list_type = std::vector<IValue>; 26 27 explicit TORCH_API ListImpl(list_type list_, TypePtr elementType_); 28 29 list_type list; 30 31 TypePtr elementType; 32 copyfinal33 intrusive_ptr<ListImpl> copy() const { 34 return make_intrusive<ListImpl>(list, elementType); 35 } 36 friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs); 37 }; 38 } 39 40 namespace impl { 41 42 template<class T, class Iterator> class ListIterator; 43 44 template<class T, class Iterator> class ListElementReference; 45 46 template<class T, class Iterator> 47 void swap(ListElementReference<T, Iterator>&& lhs, ListElementReference<T, Iterator>&& rhs) noexcept; 48 49 template<class T, class Iterator> 50 bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs); 51 52 template<class T, class Iterator> 53 bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs); 54 55 template<class T> 56 struct ListElementConstReferenceTraits { 57 // In the general case, we use IValue::to(). 58 using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return<T>::type; 59 }; 60 61 // There is no to() overload for std::optional<std::string>. 62 template<> 63 struct ListElementConstReferenceTraits<std::optional<std::string>> { 64 using const_reference = std::optional<std::reference_wrapper<const std::string>>; 65 }; 66 67 template<class T, class Iterator> 68 class ListElementReference final { 69 public: 70 operator std::conditional_t< 71 std::is_reference_v<typename c10::detail:: 72 ivalue_to_const_ref_overload_return<T>::type>, 73 const T&, 74 T>() const; 75 76 ListElementReference& operator=(T&& new_value) &&; 77 78 ListElementReference& operator=(const T& new_value) &&; 79 80 // assigning another ref to this assigns the underlying value 81 ListElementReference& operator=(ListElementReference&& rhs) && noexcept; 82 83 const IValue& get() const& { 84 return *iterator_; 85 } 86 87 friend void swap<T, Iterator>(ListElementReference&& lhs, ListElementReference&& rhs) noexcept; 88 89 ListElementReference(const ListElementReference&) = delete; 90 ListElementReference& operator=(const ListElementReference&) = delete; 91 92 private: 93 ListElementReference(Iterator iter) 94 : iterator_(iter) {} 95 96 // allow moving, but only our friends (i.e. the List class) can move us 97 ListElementReference(ListElementReference&&) noexcept = default; 98 ListElementReference& operator=(ListElementReference&& rhs) & noexcept { 99 iterator_ = std::move(rhs.iterator_); 100 return *this; 101 } 102 103 friend class List<T>; 104 friend class ListIterator<T, Iterator>; 105 106 Iterator iterator_; 107 }; 108 109 // this wraps vector::iterator to make sure user code can't rely 110 // on it being the type of the underlying vector. 111 template <class T, class Iterator> 112 class ListIterator final { 113 public: 114 // C++17 friendly std::iterator implementation 115 using iterator_category = std::random_access_iterator_tag; 116 using value_type = T; 117 using difference_type = std::ptrdiff_t; 118 using pointer = T*; 119 using reference = ListElementReference<T, Iterator>; 120 121 explicit ListIterator() = default; 122 ~ListIterator() = default; 123 124 ListIterator(const ListIterator&) = default; 125 ListIterator(ListIterator&&) noexcept = default; 126 ListIterator& operator=(const ListIterator&) = default; 127 ListIterator& operator=(ListIterator&&) noexcept = default; 128 129 ListIterator& operator++() { 130 ++iterator_; 131 return *this; 132 } 133 134 ListIterator operator++(int) { 135 ListIterator copy(*this); 136 ++*this; 137 return copy; 138 } 139 140 ListIterator& operator--() { 141 --iterator_; 142 return *this; 143 } 144 145 ListIterator operator--(int) { 146 ListIterator copy(*this); 147 --*this; 148 return copy; 149 } 150 151 ListIterator& operator+=(typename List<T>::size_type offset) { 152 iterator_ += offset; 153 return *this; 154 } 155 156 ListIterator& operator-=(typename List<T>::size_type offset) { 157 iterator_ -= offset; 158 return *this; 159 } 160 161 ListIterator operator+(typename List<T>::size_type offset) const { 162 return ListIterator{iterator_ + offset}; 163 } 164 165 ListIterator operator-(typename List<T>::size_type offset) const { 166 return ListIterator{iterator_ - offset}; 167 } 168 169 friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) { 170 return lhs.iterator_ - rhs.iterator_; 171 } 172 173 ListElementReference<T, Iterator> operator*() const { 174 return {iterator_}; 175 } 176 177 ListElementReference<T, Iterator> operator[](typename List<T>::size_type offset) const { 178 return {iterator_ + offset}; 179 } 180 181 private: 182 explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {} 183 184 Iterator iterator_; 185 186 friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) { 187 return lhs.iterator_ == rhs.iterator_; 188 } 189 190 friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) { 191 return !(lhs == rhs); 192 } 193 194 friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) { 195 return lhs.iterator_ < rhs.iterator_; 196 } 197 198 friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) { 199 return lhs.iterator_ <= rhs.iterator_; 200 } 201 202 friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) { 203 return lhs.iterator_ > rhs.iterator_; 204 } 205 206 friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) { 207 return lhs.iterator_ >= rhs.iterator_; 208 } 209 210 friend class ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>; 211 friend class List<T>; 212 }; 213 214 template<class T> List<T> toTypedList(List<IValue> list); 215 template<class T> List<IValue> toList(List<T>&& list); 216 template<class T> List<IValue> toList(const List<T>& list); 217 const IValue* ptr_to_first_element(const List<IValue>& list); 218 } 219 220 /** 221 * An object of this class stores a list of values of type T. 222 * 223 * This is a pointer type. After a copy, both Lists 224 * will share the same storage: 225 * 226 * > List<int> a; 227 * > List<int> b = a; 228 * > b.push_back("three"); 229 * > ASSERT("three" == a.get(0)); 230 * 231 * We use this class in the PyTorch kernel API instead of 232 * std::vector<T>, because that allows us to do optimizations 233 * and switch out the underlying list implementation without 234 * breaking backwards compatibility for the kernel API. 235 */ 236 template<class T> 237 class List final { 238 private: 239 // This is an intrusive_ptr because List is a pointer type. 240 // Invariant: This will never be a nullptr, there will always be a valid 241 // ListImpl. 242 c10::intrusive_ptr<c10::detail::ListImpl> impl_; 243 244 using internal_reference_type = impl::ListElementReference<T, typename c10::detail::ListImpl::list_type::iterator>; 245 using internal_const_reference_type = typename impl::ListElementConstReferenceTraits<T>::const_reference; 246 247 public: 248 using value_type = T; 249 using size_type = typename c10::detail::ListImpl::list_type::size_type; 250 using iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>; 251 using const_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>; 252 using reverse_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::reverse_iterator>; 253 254 /** 255 * Constructs an empty list. 256 */ 257 explicit List(); 258 259 /** 260 * Constructs a list with some initial values. 261 * Example: 262 * List<int> a({2, 3, 4}); 263 */ 264 List(std::initializer_list<T> initial_values); 265 explicit List(ArrayRef<T> initial_values); 266 267 /** 268 * Create a generic list with runtime type information. 269 * This only works for c10::impl::GenericList and is not part of the public API 270 * but only supposed to be used internally by PyTorch. 271 */ 272 explicit List(TypePtr elementType); 273 274 List(const List&) = default; 275 List& operator=(const List&) = default; 276 277 /** 278 * Create a new List pointing to a deep copy of the same data. 279 * The List returned is a new list with separate storage. 280 * Changes in it are not reflected in the original list or vice versa. 281 */ 282 List copy() const; 283 284 /** 285 * Returns the element at specified location pos, with bounds checking. 286 * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. 287 */ 288 internal_const_reference_type get(size_type pos) const; 289 290 /** 291 * Moves out the element at the specified location pos and returns it, with bounds checking. 292 * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. 293 * The list contains an invalid element at position pos afterwards. Any operations 294 * on it before re-setting it are invalid. 295 */ 296 value_type extract(size_type pos) const; 297 298 /** 299 * Returns a reference to the element at specified location pos, with bounds checking. 300 * If pos is not within the range of the container, an exception of type std::out_of_range is thrown. 301 * 302 * You cannot store the reference, but you can read it and assign new values to it: 303 * 304 * List<int64_t> list = ...; 305 * list[2] = 5; 306 * int64_t v = list[1]; 307 */ 308 internal_const_reference_type operator[](size_type pos) const; 309 310 internal_reference_type operator[](size_type pos); 311 312 /** 313 * Assigns a new value to the element at location pos. 314 */ 315 void set(size_type pos, const value_type& value) const; 316 317 /** 318 * Assigns a new value to the element at location pos. 319 */ 320 void set(size_type pos, value_type&& value) const; 321 322 /** 323 * Returns an iterator to the first element of the container. 324 * If the container is empty, the returned iterator will be equal to end(). 325 */ 326 iterator begin() const; 327 328 /** 329 * Returns an iterator to the element following the last element of the container. 330 * This element acts as a placeholder; attempting to access it results in undefined behavior. 331 */ 332 iterator end() const; 333 334 /** 335 * Checks if the container has no elements. 336 */ 337 bool empty() const; 338 339 /** 340 * Returns the number of elements in the container 341 */ 342 size_type size() const; 343 344 /** 345 * Increase the capacity of the vector to a value that's greater or equal to new_cap. 346 */ 347 void reserve(size_type new_cap) const; 348 349 /** 350 * Erases all elements from the container. After this call, size() returns zero. 351 * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated. 352 */ 353 void clear() const; 354 355 /** 356 * Inserts value before pos. 357 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 358 */ 359 iterator insert(iterator pos, const T& value) const; 360 361 /** 362 * Inserts value before pos. 363 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 364 */ 365 iterator insert(iterator pos, T&& value) const; 366 367 /** 368 * Inserts a new element into the container directly before pos. 369 * The new element is constructed with the given arguments. 370 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 371 */ 372 template<class... Args> 373 iterator emplace(iterator pos, Args&&... value) const; 374 375 /** 376 * Appends the given element value to the end of the container. 377 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 378 */ 379 void push_back(const T& value) const; 380 381 /** 382 * Appends the given element value to the end of the container. 383 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 384 */ 385 void push_back(T&& value) const; 386 387 /** 388 * Appends the given list to the end of the container. Uses at most one memory allocation. 389 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 390 */ 391 void append(List<T> lst) const; 392 393 /** 394 * Appends the given element value to the end of the container. 395 * The new element is constructed with the given arguments. 396 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 397 */ 398 template<class... Args> 399 void emplace_back(Args&&... args) const; 400 401 /** 402 * Removes the element at pos. 403 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 404 */ 405 iterator erase(iterator pos) const; 406 407 /** 408 * Removes the elements in the range [first, last). 409 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 410 */ 411 iterator erase(iterator first, iterator last) const; 412 413 /** 414 * Removes the last element of the container. 415 * Calling pop_back on an empty container is undefined. 416 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 417 */ 418 void pop_back() const; 419 420 /** 421 * Resizes the container to contain count elements. 422 * If the current size is less than count, additional default-inserted elements are appended. 423 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 424 */ 425 void resize(size_type count) const; 426 427 /** 428 * Resizes the container to contain count elements. 429 * If the current size is less than count, additional copies of value are appended. 430 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated. 431 */ 432 void resize(size_type count, const T& value) const; 433 434 /** 435 * Value equality comparison. This function implements Python-like semantics for 436 * equality: two lists with the same identity (e.g. same pointer) trivially 437 * compare equal, otherwise each element is compared for equality. 438 */ 439 template <class T_> 440 friend bool operator==(const List<T_>& lhs, const List<T_>& rhs); 441 442 template <class T_> 443 friend bool operator!=(const List<T_>& lhs, const List<T_>& rhs); 444 445 /** 446 * Identity comparison. Returns true if and only if `rhs` represents the same 447 * List object as `this`. 448 */ 449 bool is(const List<T>& rhs) const; 450 451 std::vector<T> vec() const; 452 453 /** 454 * Returns the number of Lists currently pointing to this same list. 455 * If this is the only instance pointing to this list, returns 1. 456 */ 457 // TODO Test use_count 458 size_t use_count() const; 459 460 TypePtr elementType() const; 461 462 // See [unsafe set type] for why this exists. 463 void unsafeSetElementType(TypePtr t); 464 465 private: 466 explicit List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements); 467 explicit List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements); 468 friend struct IValue; 469 template<class T_> friend List<T_> impl::toTypedList(List<IValue>); 470 template<class T_> friend List<IValue> impl::toList(List<T_>&&); 471 template<class T_> friend List<IValue> impl::toList(const List<T_>&); 472 friend const IValue* impl::ptr_to_first_element(const List<IValue>& list); 473 }; 474 475 namespace impl { 476 // GenericList is how IValue stores lists. It is, however, not part of the 477 // public API. Kernels should use Lists with concrete types instead 478 // (maybe except for some internal prim ops). 479 using GenericList = List<IValue>; 480 481 } 482 } 483 484 namespace torch { 485 template<class T> using List = c10::List<T>; 486 } 487 488 #include <ATen/core/List_inl.h> // IWYU pragma: keep 489