xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/Dict.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 #include <c10/macros/Export.h>
5 #include <c10/util/TypeTraits.h>
6 #include <c10/util/TypeList.h>
7 #include <c10/util/intrusive_ptr.h>
8 #include <c10/util/order_preserving_flat_hash_map.h>
9 #include <optional>
10 #include <ATen/core/TensorBody.h>
11 #include <ATen/core/jit_type_base.h>
12 
13 namespace c10 {
14 struct IValue;
15 template<class Key, class Value> class Dict;
16 struct Type;
17 
18 namespace impl {
19 
20 using valid_dict_key_types = guts::typelist::typelist<
21   int64_t,
22   std::string,
23   double,
24   c10::complex<double>,
25   bool,
26   at::Tensor
27 >;
28 }
29 
30 namespace detail {
31 
32 struct DictKeyHash {
33   size_t operator()(const IValue& ivalue) const;
34 };
35 
36 struct DictKeyEqualTo {
37   bool operator()(const IValue& lhs, const IValue& rhs) const;
38 };
39 
40 struct DictImpl final : public c10::intrusive_ptr_target {
41   using dict_map_type = ska_ordered::order_preserving_flat_hash_map<IValue, IValue, DictKeyHash, DictKeyEqualTo>;
42   struct DictElementTypes final {
43     TypePtr keyType;
44     TypePtr valueType;
45   };
46 
DictImplfinal47   explicit DictImpl(dict_map_type dict_, DictElementTypes elementTypes_)
48   : dict(std::move(dict_))
49   , elementTypes(std::move(elementTypes_)) {}
50   dict_map_type dict;
51 
52   DictElementTypes elementTypes;
53 
54   intrusive_ptr<DictImpl> copy() const;
55   friend TORCH_API bool operator==(const DictImpl& lhs, const DictImpl& rhs);
56 };
57 
58 }
59 
60 namespace impl {
61 template<class Key, class Value, class Iterator> class DictIterator;
62 
63 /**
64  * A reference to an entry in the Dict.
65  * Use the `key()` and `value()` methods to read the element.
66  */
67 template<class Key, class Value, class Iterator>
68 class DictEntryRef final {
69 public:
DictEntryRef(Iterator iterator)70   explicit DictEntryRef(Iterator iterator)
71   : iterator_(std::move(iterator)) {}
72 
key()73   decltype(auto) key() const {
74     return iterator_->first.template to<Key>();
75   }
76 
value()77   decltype(auto) value() const {
78     return iterator_->second.template to<Value>();
79   }
80 
81   template<class Value_>
setValue(Value_ && value)82   void setValue(Value_&& value) const {
83     static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of setValue()");
84     iterator_->second = Value(std::forward<Value_>(value));
85   }
86 
87 private:
88   // allow copying and moving, but only our friends (i.e. the Dict class) can do
89   // it. Copying/moving this reference wrapper would be too ambiguous to allow it
90   // in the public API.
91   DictEntryRef(const DictEntryRef&) = default;
92   DictEntryRef& operator=(const DictEntryRef&) = default;
93   DictEntryRef(DictEntryRef&&) noexcept = default;
94   DictEntryRef& operator=(DictEntryRef&& rhs) & noexcept = default;
95 
96   Iterator iterator_;
97   friend class DictIterator<Key, Value, Iterator>;
98   friend class Dict<Key, Value>;
99 };
100 
101 // this wraps map_type::iterator to make sure user code can't rely
102 // on it being the type of the underlying map.
103 template<class Key, class Value, class Iterator>
104 class DictIterator final {
105 public:
106    // C++17 friendly std::iterator implementation
107   using iterator_category = std::forward_iterator_tag;
108   using value_type = DictEntryRef<Key, Value, Iterator>;
109   using difference_type = std::ptrdiff_t;
110   using pointer = value_type*;
111   using reference = value_type&;
112 
113   explicit DictIterator() = default;
114   ~DictIterator() = default;
115 
DictIterator(const DictIterator & rhs)116   DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {}
DictIterator(DictIterator && rhs)117   DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {}
118   DictIterator& operator=(const DictIterator& rhs) {
119     entryRef_ = rhs.entryRef_;
120     return *this;
121   }
122   DictIterator& operator=(DictIterator&& rhs) noexcept {
123     entryRef_ = std::move(rhs.entryRef_);
124     return *this;
125   }
126 
127   DictIterator& operator++() {
128       ++entryRef_.iterator_;
129       return *this;
130   }
131 
132   DictIterator operator++(int) {
133       DictIterator copy(*this);
134       ++*this;
135       return copy;
136   }
137 
138   const DictEntryRef<Key, Value, Iterator>& operator*() const {
139       return entryRef_;
140   }
141 
142   const DictEntryRef<Key, Value, Iterator>* operator->() const {
143     return &entryRef_;
144   }
145 
146   friend difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) {
147     return lhs.entryRef_.iterator_ - rhs.entryRef_.iterator_;
148   }
149 
150 private:
DictIterator(Iterator iterator)151   explicit DictIterator(Iterator iterator): entryRef_(std::move(iterator)) {}
152 
get_iterator_()153   const Iterator& get_iterator_() const {
154     return entryRef_.iterator_;
155   }
156 
157   friend bool operator==(const DictIterator& lhs, const DictIterator& rhs) {
158     return lhs.get_iterator_() == rhs.get_iterator_();
159   }
160 
161   friend bool operator!=(const DictIterator& lhs, const DictIterator& rhs) {
162     return lhs.get_iterator_() != rhs.get_iterator_();
163   }
164 
165   friend bool operator<(const DictIterator& lhs, const DictIterator& rhs) {
166     return lhs.get_iterator_() < rhs.get_iterator_();
167   }
168 
169   friend bool operator<=(const DictIterator& lhs, const DictIterator& rhs) {
170     return lhs.get_iterator_() <= rhs.get_iterator_();
171   }
172 
173   friend bool operator>(const DictIterator& lhs, const DictIterator& rhs) {
174     return lhs.get_iterator_() > rhs.get_iterator_();
175   }
176 
177   friend bool operator>=(const DictIterator& lhs, const DictIterator& rhs) {
178     return lhs.get_iterator_() >= rhs.get_iterator_();
179   }
180 
181   DictEntryRef<Key, Value, Iterator> entryRef_;
182 
183   friend class DictIterator<Key, Value, typename c10::detail::DictImpl::dict_map_type::iterator>;
184   friend class Dict<Key, Value>;
185 };
186 
187 template<class Key, class Value> Dict<Key, Value> toTypedDict(Dict<IValue, IValue> dict);
188 template<class Key, class Value> Dict<IValue, IValue> toGenericDict(Dict<Key, Value> dict);
189 }
190 
191 /**
192  * An object of this class stores a map from Key to Value.
193  *
194  * This is a pointer type. After a copy, both Dicts
195  * will share the same storage:
196  *
197  * > Dict<int, string> a;
198  * > Dict<int, string> b = a;
199  * > b.insert(3, "three");
200  * > ASSERT("three" == a.at(3));
201  *
202  * We use this class in the PyTorch kernel API because that
203  * allows us to do optimizations and switch out the underlying
204  * map implementation without breaking backwards compatibility
205  * for the kernel API.
206  */
207 template<class Key, class Value>
208 class Dict final {
209 private:
210   static_assert((std::is_same_v<IValue, Key> && std::is_same_v<IValue, Value>) || guts::typelist::contains<impl::valid_dict_key_types, Key>::value, "Invalid Key type for Dict. We only support int64_t, double, bool, and string.");
211 
212   // impl_ stores the underlying map as a ska_ordered::order_preserving_flat_hash_map.
213   // We intentionally don't offer conversion from/to
214   // order_preserving_flat_hash_map, return references to it or something like that,
215   // because such operations would get expensive if we switch out
216   // the actual map implementation.
217   // This is an intrusive_ptr because Dict is a pointer type.
218   // Invariant: This will never be a nullptr, there will always be a valid
219   // DictImpl.
220   c10::intrusive_ptr<detail::DictImpl> impl_;
221 
222   explicit Dict(c10::intrusive_ptr<detail::DictImpl>&& impl);
223   friend struct IValue;
224   template<class K, class V> friend Dict<K, V> impl::toTypedDict(Dict<IValue, IValue>);
225   template<class K, class V> friend Dict<IValue, IValue> impl::toGenericDict(Dict<K, V>);
226 
227 public:
228   using key_type = Key;
229   using mapped_type = Value;
230   using size_type = typename detail::DictImpl::dict_map_type::size_type;
231   using iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>;
232 
233   /**
234    * Creates an empty dict.
235    */
236   explicit Dict();
237 
238   /**
239    * Create a generic dict with runtime type information.
240    * This only works for c10::impl::GenericDict and is not part of the public API
241    * but only supposed to be used internally by PyTorch.
242    */
243   explicit Dict(TypePtr keyType, TypePtr valueType);
244 
245   ~Dict() = default;
246 
247   Dict(const Dict&) = default;
248   Dict& operator=(const Dict&) = default;
249 
250   /**
251    * Create a new Dict pointing to a deep copy of the same data.
252    * The Dict returned is a new dict with separate storage.
253    * Changes in it are not reflected in the original dict or vice versa.
254    */
255   Dict copy() const;
256 
257   /**
258    * Returns an iterator to the first element of the container.
259    * If the container is empty, the returned iterator will be equal to end().
260    */
261   iterator begin() const;
262 
263   /**
264    * Returns an iterator to the element following the last element of the container.
265    * This element acts as a placeholder; attempting to access it results in undefined behavior.
266    */
267   iterator end() const;
268 
269   /**
270    * Checks if the container has no elements.
271    */
272   bool empty() const;
273 
274   /**
275    * Returns the number of elements in the container.
276    */
277   size_type size() const;
278 
279   /**
280    * Erases all elements from the container. After this call, size() returns zero.
281    * Invalidates any references, pointers, or iterators referring to contained elements. May also invalidate past-the-end iterators.
282    */
283   void clear() const;
284 
285   /**
286    * Inserts element(s) into the container, if the container doesn't already contain an element with an equivalent key.
287    * May invalidate any references, pointers, or iterators referring to contained elements.
288    *
289    * @return A pair consisting of an iterator to the inserted element (or to the element that prevented the insertion) and a bool denoting whether the insertion took place.
290    */
291   template<class Key_, class Value_>
292   std::pair<iterator, bool> insert(Key_&& key, Value_&& value) const;
293 
294   /**
295    * If an element with the given key already exists, it is overwritten with the given value.
296    * Otherwise, a new element with the given key and value are inserted.
297    * May invalidate any references, pointers, or iterators referring to contained elements.
298    *
299    * @return The bool component is true if the insertion took place and false if the assignment took place. The iterator component is pointing at the element that was inserted or updated.
300    */
301   template<class Key_, class Value_>
302   std::pair<iterator, bool> insert_or_assign(Key_&& key, Value_&& value) const;
303 
304   /**
305    * Removes the element pointed to by iter.
306    * May invalidate any references, pointers, or iterators referring to contained elements.
307    * The iterator iter must be valid and dereferenceable. Thus the end() iterator (which is valid, but is not dereferenceable) cannot be used as a value for iter.
308    */
309   void erase(iterator iter) const;
310 
311   /**
312    * Removes the element with the given key, if it exists.
313    * May invalidate any references, pointers, or iterators referring to contained elements.
314    *
315    * @return The number of elements removed. This is either '1' if an element with the key existed, or '0' if it didn't.
316    */
317   C10_NODISCARD size_t erase(const Key& key) const;
318 
319   /**
320    * Returns the mapped value of the element with key equivalent to key.
321    * If no such element exists, an exception of type std::out_of_range is thrown.
322    */
323   Value at(const Key& key) const;
324 
325   /**
326    * Finds an element with key equivalent to key.
327    *
328    * @return Iterator to an element with key equivalent to key.
329    *         If no such element is found, past-the-end (see end()) iterator is returned.
330    */
331   iterator find(const Key& key) const;
332 
333   /**
334    * Checks if there is an element with key equivalent to key in the container.
335    *
336    * @return true if there is such an element, otherwise false.
337    */
338   bool contains(const Key& key) const;
339 
340   /**
341    * Increase the capacity so that at least count elements can be stored without
342    * having to reallocate or rehash.
343    */
344   void reserve(size_type count) const;
345 
346   /**
347    * Value equality comparison. This function implements Python-like semantics for
348    * equality: two dicts with the same identity (e.g. same pointer) trivially
349    * compare equal, otherwise each element is compared for equality.
350    */
351   template <class Key_, class Value_>
352   friend bool operator==(
353       const Dict<Key_, Value_>& lhs,
354       const Dict<Key_, Value_>& rhs);
355   template <class Key_, class Value_>
356   friend bool operator!=(
357       const Dict<Key_, Value_>& lhs,
358       const Dict<Key_, Value_>& rhs);
359 
360   /**
361    * Identity comparison. Returns true if and only if `rhs` represents the same
362    * Dict object as `this`.
363    */
364   bool is(const Dict& rhs) const;
365 
366   // private API for now because the return type will change to TypePtr
367   // instead of std::optional<TypePtr> once types are mandatory.
368   TypePtr keyType() const;
369   TypePtr valueType() const;
370 
371   // [unsafe set type]
372   // These functions mutate the tagged type of this dictionary in place.
373   // There is no checking that the members of the dictionary are instances
374   // of the new types, nor is there a check that other IValues which
375   // hold references to this dictionary have the right static type.
376   // This functionality is used only in the unpickler, where at
377   // creation type the real type of the dictionary is unknown, but
378   // then later recovered from the static type information of the
379   // unpickled object.
380   void unsafeSetKeyType(TypePtr t);
381   void unsafeSetValueType(TypePtr t);
382 };
383 
384 namespace impl {
385 // GenericDict is how IValue stores dicts. It is, however, not part of the
386 // public API. Kernels should use Dicts with concrete Key, Value types instead
387 // (maybe except for some internal prim ops).
388 using GenericDict = Dict<IValue, IValue>;
389 
390 }
391 }
392 
393 namespace torch {
394   template<class Key, class Value> using Dict = c10::Dict<Key, Value>;
395 }
396 
397 #include <ATen/core/Dict_inl.h>  // IWYU pragma: keep
398