1 #pragma once 2 3 #include <c10/macros/Macros.h> 4 #include <c10/macros/Export.h> 5 #include <c10/util/TypeTraits.h> 6 #include <c10/util/TypeList.h> 7 #include <c10/util/intrusive_ptr.h> 8 #include <c10/util/order_preserving_flat_hash_map.h> 9 #include <optional> 10 #include <ATen/core/TensorBody.h> 11 #include <ATen/core/jit_type_base.h> 12 13 namespace c10 { 14 struct IValue; 15 template<class Key, class Value> class Dict; 16 struct Type; 17 18 namespace impl { 19 20 using valid_dict_key_types = guts::typelist::typelist< 21 int64_t, 22 std::string, 23 double, 24 c10::complex<double>, 25 bool, 26 at::Tensor 27 >; 28 } 29 30 namespace detail { 31 32 struct DictKeyHash { 33 size_t operator()(const IValue& ivalue) const; 34 }; 35 36 struct DictKeyEqualTo { 37 bool operator()(const IValue& lhs, const IValue& rhs) const; 38 }; 39 40 struct DictImpl final : public c10::intrusive_ptr_target { 41 using dict_map_type = ska_ordered::order_preserving_flat_hash_map<IValue, IValue, DictKeyHash, DictKeyEqualTo>; 42 struct DictElementTypes final { 43 TypePtr keyType; 44 TypePtr valueType; 45 }; 46 DictImplfinal47 explicit DictImpl(dict_map_type dict_, DictElementTypes elementTypes_) 48 : dict(std::move(dict_)) 49 , elementTypes(std::move(elementTypes_)) {} 50 dict_map_type dict; 51 52 DictElementTypes elementTypes; 53 54 intrusive_ptr<DictImpl> copy() const; 55 friend TORCH_API bool operator==(const DictImpl& lhs, const DictImpl& rhs); 56 }; 57 58 } 59 60 namespace impl { 61 template<class Key, class Value, class Iterator> class DictIterator; 62 63 /** 64 * A reference to an entry in the Dict. 65 * Use the `key()` and `value()` methods to read the element. 66 */ 67 template<class Key, class Value, class Iterator> 68 class DictEntryRef final { 69 public: DictEntryRef(Iterator iterator)70 explicit DictEntryRef(Iterator iterator) 71 : iterator_(std::move(iterator)) {} 72 key()73 decltype(auto) key() const { 74 return iterator_->first.template to<Key>(); 75 } 76 value()77 decltype(auto) value() const { 78 return iterator_->second.template to<Value>(); 79 } 80 81 template<class Value_> setValue(Value_ && value)82 void setValue(Value_&& value) const { 83 static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of setValue()"); 84 iterator_->second = Value(std::forward<Value_>(value)); 85 } 86 87 private: 88 // allow copying and moving, but only our friends (i.e. the Dict class) can do 89 // it. Copying/moving this reference wrapper would be too ambiguous to allow it 90 // in the public API. 91 DictEntryRef(const DictEntryRef&) = default; 92 DictEntryRef& operator=(const DictEntryRef&) = default; 93 DictEntryRef(DictEntryRef&&) noexcept = default; 94 DictEntryRef& operator=(DictEntryRef&& rhs) & noexcept = default; 95 96 Iterator iterator_; 97 friend class DictIterator<Key, Value, Iterator>; 98 friend class Dict<Key, Value>; 99 }; 100 101 // this wraps map_type::iterator to make sure user code can't rely 102 // on it being the type of the underlying map. 103 template<class Key, class Value, class Iterator> 104 class DictIterator final { 105 public: 106 // C++17 friendly std::iterator implementation 107 using iterator_category = std::forward_iterator_tag; 108 using value_type = DictEntryRef<Key, Value, Iterator>; 109 using difference_type = std::ptrdiff_t; 110 using pointer = value_type*; 111 using reference = value_type&; 112 113 explicit DictIterator() = default; 114 ~DictIterator() = default; 115 DictIterator(const DictIterator & rhs)116 DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {} DictIterator(DictIterator && rhs)117 DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {} 118 DictIterator& operator=(const DictIterator& rhs) { 119 entryRef_ = rhs.entryRef_; 120 return *this; 121 } 122 DictIterator& operator=(DictIterator&& rhs) noexcept { 123 entryRef_ = std::move(rhs.entryRef_); 124 return *this; 125 } 126 127 DictIterator& operator++() { 128 ++entryRef_.iterator_; 129 return *this; 130 } 131 132 DictIterator operator++(int) { 133 DictIterator copy(*this); 134 ++*this; 135 return copy; 136 } 137 138 const DictEntryRef<Key, Value, Iterator>& operator*() const { 139 return entryRef_; 140 } 141 142 const DictEntryRef<Key, Value, Iterator>* operator->() const { 143 return &entryRef_; 144 } 145 146 friend difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) { 147 return lhs.entryRef_.iterator_ - rhs.entryRef_.iterator_; 148 } 149 150 private: DictIterator(Iterator iterator)151 explicit DictIterator(Iterator iterator): entryRef_(std::move(iterator)) {} 152 get_iterator_()153 const Iterator& get_iterator_() const { 154 return entryRef_.iterator_; 155 } 156 157 friend bool operator==(const DictIterator& lhs, const DictIterator& rhs) { 158 return lhs.get_iterator_() == rhs.get_iterator_(); 159 } 160 161 friend bool operator!=(const DictIterator& lhs, const DictIterator& rhs) { 162 return lhs.get_iterator_() != rhs.get_iterator_(); 163 } 164 165 friend bool operator<(const DictIterator& lhs, const DictIterator& rhs) { 166 return lhs.get_iterator_() < rhs.get_iterator_(); 167 } 168 169 friend bool operator<=(const DictIterator& lhs, const DictIterator& rhs) { 170 return lhs.get_iterator_() <= rhs.get_iterator_(); 171 } 172 173 friend bool operator>(const DictIterator& lhs, const DictIterator& rhs) { 174 return lhs.get_iterator_() > rhs.get_iterator_(); 175 } 176 177 friend bool operator>=(const DictIterator& lhs, const DictIterator& rhs) { 178 return lhs.get_iterator_() >= rhs.get_iterator_(); 179 } 180 181 DictEntryRef<Key, Value, Iterator> entryRef_; 182 183 friend class DictIterator<Key, Value, typename c10::detail::DictImpl::dict_map_type::iterator>; 184 friend class Dict<Key, Value>; 185 }; 186 187 template<class Key, class Value> Dict<Key, Value> toTypedDict(Dict<IValue, IValue> dict); 188 template<class Key, class Value> Dict<IValue, IValue> toGenericDict(Dict<Key, Value> dict); 189 } 190 191 /** 192 * An object of this class stores a map from Key to Value. 193 * 194 * This is a pointer type. After a copy, both Dicts 195 * will share the same storage: 196 * 197 * > Dict<int, string> a; 198 * > Dict<int, string> b = a; 199 * > b.insert(3, "three"); 200 * > ASSERT("three" == a.at(3)); 201 * 202 * We use this class in the PyTorch kernel API because that 203 * allows us to do optimizations and switch out the underlying 204 * map implementation without breaking backwards compatibility 205 * for the kernel API. 206 */ 207 template<class Key, class Value> 208 class Dict final { 209 private: 210 static_assert((std::is_same_v<IValue, Key> && std::is_same_v<IValue, Value>) || guts::typelist::contains<impl::valid_dict_key_types, Key>::value, "Invalid Key type for Dict. We only support int64_t, double, bool, and string."); 211 212 // impl_ stores the underlying map as a ska_ordered::order_preserving_flat_hash_map. 213 // We intentionally don't offer conversion from/to 214 // order_preserving_flat_hash_map, return references to it or something like that, 215 // because such operations would get expensive if we switch out 216 // the actual map implementation. 217 // This is an intrusive_ptr because Dict is a pointer type. 218 // Invariant: This will never be a nullptr, there will always be a valid 219 // DictImpl. 220 c10::intrusive_ptr<detail::DictImpl> impl_; 221 222 explicit Dict(c10::intrusive_ptr<detail::DictImpl>&& impl); 223 friend struct IValue; 224 template<class K, class V> friend Dict<K, V> impl::toTypedDict(Dict<IValue, IValue>); 225 template<class K, class V> friend Dict<IValue, IValue> impl::toGenericDict(Dict<K, V>); 226 227 public: 228 using key_type = Key; 229 using mapped_type = Value; 230 using size_type = typename detail::DictImpl::dict_map_type::size_type; 231 using iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>; 232 233 /** 234 * Creates an empty dict. 235 */ 236 explicit Dict(); 237 238 /** 239 * Create a generic dict with runtime type information. 240 * This only works for c10::impl::GenericDict and is not part of the public API 241 * but only supposed to be used internally by PyTorch. 242 */ 243 explicit Dict(TypePtr keyType, TypePtr valueType); 244 245 ~Dict() = default; 246 247 Dict(const Dict&) = default; 248 Dict& operator=(const Dict&) = default; 249 250 /** 251 * Create a new Dict pointing to a deep copy of the same data. 252 * The Dict returned is a new dict with separate storage. 253 * Changes in it are not reflected in the original dict or vice versa. 254 */ 255 Dict copy() const; 256 257 /** 258 * Returns an iterator to the first element of the container. 259 * If the container is empty, the returned iterator will be equal to end(). 260 */ 261 iterator begin() const; 262 263 /** 264 * Returns an iterator to the element following the last element of the container. 265 * This element acts as a placeholder; attempting to access it results in undefined behavior. 266 */ 267 iterator end() const; 268 269 /** 270 * Checks if the container has no elements. 271 */ 272 bool empty() const; 273 274 /** 275 * Returns the number of elements in the container. 276 */ 277 size_type size() const; 278 279 /** 280 * Erases all elements from the container. After this call, size() returns zero. 281 * Invalidates any references, pointers, or iterators referring to contained elements. May also invalidate past-the-end iterators. 282 */ 283 void clear() const; 284 285 /** 286 * Inserts element(s) into the container, if the container doesn't already contain an element with an equivalent key. 287 * May invalidate any references, pointers, or iterators referring to contained elements. 288 * 289 * @return A pair consisting of an iterator to the inserted element (or to the element that prevented the insertion) and a bool denoting whether the insertion took place. 290 */ 291 template<class Key_, class Value_> 292 std::pair<iterator, bool> insert(Key_&& key, Value_&& value) const; 293 294 /** 295 * If an element with the given key already exists, it is overwritten with the given value. 296 * Otherwise, a new element with the given key and value are inserted. 297 * May invalidate any references, pointers, or iterators referring to contained elements. 298 * 299 * @return The bool component is true if the insertion took place and false if the assignment took place. The iterator component is pointing at the element that was inserted or updated. 300 */ 301 template<class Key_, class Value_> 302 std::pair<iterator, bool> insert_or_assign(Key_&& key, Value_&& value) const; 303 304 /** 305 * Removes the element pointed to by iter. 306 * May invalidate any references, pointers, or iterators referring to contained elements. 307 * The iterator iter must be valid and dereferenceable. Thus the end() iterator (which is valid, but is not dereferenceable) cannot be used as a value for iter. 308 */ 309 void erase(iterator iter) const; 310 311 /** 312 * Removes the element with the given key, if it exists. 313 * May invalidate any references, pointers, or iterators referring to contained elements. 314 * 315 * @return The number of elements removed. This is either '1' if an element with the key existed, or '0' if it didn't. 316 */ 317 C10_NODISCARD size_t erase(const Key& key) const; 318 319 /** 320 * Returns the mapped value of the element with key equivalent to key. 321 * If no such element exists, an exception of type std::out_of_range is thrown. 322 */ 323 Value at(const Key& key) const; 324 325 /** 326 * Finds an element with key equivalent to key. 327 * 328 * @return Iterator to an element with key equivalent to key. 329 * If no such element is found, past-the-end (see end()) iterator is returned. 330 */ 331 iterator find(const Key& key) const; 332 333 /** 334 * Checks if there is an element with key equivalent to key in the container. 335 * 336 * @return true if there is such an element, otherwise false. 337 */ 338 bool contains(const Key& key) const; 339 340 /** 341 * Increase the capacity so that at least count elements can be stored without 342 * having to reallocate or rehash. 343 */ 344 void reserve(size_type count) const; 345 346 /** 347 * Value equality comparison. This function implements Python-like semantics for 348 * equality: two dicts with the same identity (e.g. same pointer) trivially 349 * compare equal, otherwise each element is compared for equality. 350 */ 351 template <class Key_, class Value_> 352 friend bool operator==( 353 const Dict<Key_, Value_>& lhs, 354 const Dict<Key_, Value_>& rhs); 355 template <class Key_, class Value_> 356 friend bool operator!=( 357 const Dict<Key_, Value_>& lhs, 358 const Dict<Key_, Value_>& rhs); 359 360 /** 361 * Identity comparison. Returns true if and only if `rhs` represents the same 362 * Dict object as `this`. 363 */ 364 bool is(const Dict& rhs) const; 365 366 // private API for now because the return type will change to TypePtr 367 // instead of std::optional<TypePtr> once types are mandatory. 368 TypePtr keyType() const; 369 TypePtr valueType() const; 370 371 // [unsafe set type] 372 // These functions mutate the tagged type of this dictionary in place. 373 // There is no checking that the members of the dictionary are instances 374 // of the new types, nor is there a check that other IValues which 375 // hold references to this dictionary have the right static type. 376 // This functionality is used only in the unpickler, where at 377 // creation type the real type of the dictionary is unknown, but 378 // then later recovered from the static type information of the 379 // unpickled object. 380 void unsafeSetKeyType(TypePtr t); 381 void unsafeSetValueType(TypePtr t); 382 }; 383 384 namespace impl { 385 // GenericDict is how IValue stores dicts. It is, however, not part of the 386 // public API. Kernels should use Dicts with concrete Key, Value types instead 387 // (maybe except for some internal prim ops). 388 using GenericDict = Dict<IValue, IValue>; 389 390 } 391 } 392 393 namespace torch { 394 template<class Key, class Value> using Dict = c10::Dict<Key, Value>; 395 } 396 397 #include <ATen/core/Dict_inl.h> // IWYU pragma: keep 398