xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/ordered_dict.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
9 
10 namespace torch {
/// An ordered dictionary implementation, akin to Python's `OrderedDict`.
///
/// Items are stored in a `std::vector` in insertion order, and a
/// `std::unordered_map` maps each key to the position of its item in that
/// vector. Lookups by key go through the map; iteration follows the vector.
template <typename Key, typename Value>
class OrderedDict {
 public:
  /// A (key, value) pair.
  class Item;

  // The lifetime of an iterator is bound to the lifetime of the `OrderedDict`.
  // Further, any `insert()` operation may invalidate all iterators pointing
  // into the vector, and `erase()` invalidates iterators at or after the
  // erased item.
  using Iterator = typename std::vector<Item>::iterator;
  using ConstIterator = typename std::vector<Item>::const_iterator;

  /// Constructs the `OrderedDict` with a short description of the kinds of keys
  /// stored in the `OrderedDict`. This description is used in error messages
  /// thrown by the `OrderedDict`.
  explicit OrderedDict(std::string key_description = "Key");

  /// Copy constructs this `OrderedDict` from `other` (items, index and key
  /// description).
  OrderedDict(const OrderedDict& other);

  /// Replaces the items and the key description of this `OrderedDict` with
  /// copies of those of `other`.
  OrderedDict& operator=(const OrderedDict& other);

  /// Move operations are the defaulted member-wise moves of the index, the
  /// items and the key description, declared `noexcept`.
  OrderedDict(OrderedDict&& other) noexcept = default;
  OrderedDict& operator=(OrderedDict&& other) noexcept = default;

  ~OrderedDict() = default;

  /// Constructs a new `OrderedDict` and pre-populates it with the given
  /// `Item`s, in order.
  /*implicit */ OrderedDict(std::initializer_list<Item> initializer_list);

  /// Returns the key description string the `OrderedDict` was constructed with.
  const std::string& key_description() const noexcept;

  // Element Access

  /// Returns the very first item in the `OrderedDict` and throws an exception
  /// if it is empty.
  Item& front();

  /// Returns the very first item in the `OrderedDict` and throws an exception
  /// if it is empty.
  const Item& front() const;

  /// Returns the very last item in the `OrderedDict` and throws an exception
  /// if it is empty.
  Item& back();

  /// Returns the very last item in the `OrderedDict` and throws an exception
  /// if it is empty.
  const Item& back() const;

  /// Returns the item at the `index`-th position in the `OrderedDict`. Throws
  /// an exception if the index is out of bounds.
  Item& operator[](size_t index);

  /// Returns the item at the `index`-th position in the `OrderedDict`. Throws
  /// an exception if the index is out of bounds.
  const Item& operator[](size_t index) const;

  /// Returns the value associated with the given `key`. Throws an exception if
  /// no such key is stored in the `OrderedDict`. Use `find()` for a
  /// non-throwing way of accessing a value if it is present.
  Value& operator[](const Key& key);

  /// Returns the value associated with the given `key`. Throws an exception if
  /// no such key is stored in the `OrderedDict`. Use `find()` for a
  /// non-throwing way of accessing a value if it is present.
  const Value& operator[](const Key& key) const;

  // Lookup

  /// Returns a pointer to the value associated with the given key, or a
  /// `nullptr` if no such key is stored in the `OrderedDict`.
  [[nodiscard]] Value* find(const Key& key) noexcept;

  /// Returns a pointer to the value associated with the given key, or a
  /// `nullptr` if no such key is stored in the `OrderedDict`.
  [[nodiscard]] const Value* find(const Key& key) const noexcept;

  /// Returns true if the key is present in the `OrderedDict`.
  [[nodiscard]] bool contains(const Key& key) const noexcept;

  // Iterators

  /// Returns an iterator to the first item in the `OrderedDict`. Iteration is
  /// ordered.
  Iterator begin();

  /// Returns an iterator to the first item in the `OrderedDict`. Iteration is
  /// ordered.
  ConstIterator begin() const;

  /// Returns an iterator one past the last item in the `OrderedDict`.
  Iterator end();

  /// Returns an iterator one past the last item in the `OrderedDict`.
  ConstIterator end() const;

  // Capacity

  /// Returns the number of items currently stored in the `OrderedDict`.
  size_t size() const noexcept;

  /// Returns true if the `OrderedDict` contains no elements.
  [[nodiscard]] bool is_empty() const noexcept;

  /// Resizes internal storage (both the item vector and the key index) to fit
  /// at least `requested_capacity` items without requiring reallocation.
  void reserve(size_t requested_capacity);

  // Modifiers

  /// Inserts a new `(key, value)` pair into the `OrderedDict`. Throws an
  /// exception if the key is already present. If insertion is successful,
  /// immediately returns a reference to the inserted value.
  template <typename K, typename V>
  Value& insert(K&& key, V&& value);

  /// Inserts a new `(key, value)` pair into the `OrderedDict`. Throws an
  /// exception if the key is already present. If insertion is successful,
  /// immediately returns a reference to the inserted value.
  Value& insert(Key key, Value&& value);

  /// Inserts all items from `other` into this `OrderedDict`. If any key from
  /// `other` is already present in this `OrderedDict`, an exception is thrown.
  /// Values are moved out of `other`; keys are copied, because `Item` only
  /// exposes its key as `const`.
  void update(OrderedDict&& other);

  /// Inserts all items from `other` into this `OrderedDict`. If any key from
  /// `other` is already present in this `OrderedDict`, an exception is thrown.
  void update(const OrderedDict& other);

  /// Removes the item that has `key` from this `OrderedDict`. Throws an
  /// exception if no such key is stored. Items after the removed one shift
  /// down by one position.
  void erase(const Key& key);

  /// Removes all items from this `OrderedDict`.
  void clear();

  // Observers

  /// Returns the items stored in the `OrderedDict`.
  const std::vector<Item>& items() const noexcept;

  /// Returns a newly allocated vector and copies all keys from this
  /// `OrderedDict` into the vector.
  [[nodiscard]] ::std::vector<Key> keys() const;

  /// Returns a newly allocated vector and copies all values from this
  /// `OrderedDict` into the vector.
  [[nodiscard]] ::std::vector<Value> values() const;

  /// Returns a newly allocated vector and copies all keys and values from this
  /// `OrderedDict` into a vector of `std::pair<Key, Value>`.
  [[nodiscard]] ::std::vector<std::pair<Key, Value>> pairs() const;

  /// Returns true if both dicts contain the same keys and values, in the same
  /// order.
  template <typename K, typename V>
  friend bool operator==(
      const OrderedDict<K, V>& a,
      const OrderedDict<K, V>& b);

 private:
  /// A mapping from a key to an index into the `items_` vector.
  ::std::unordered_map<Key, size_t> index_;

  /// The items stored in the `OrderedDict`, in insertion order.
  ::std::vector<Item> items_;

  /// A description of the keys stored in the `OrderedDict`.
  ::std::string key_description_{"Key"};
};
190 
191 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict::Item ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
192 
193 template <typename Key, typename Value>
194 class OrderedDict<Key, Value>::Item {
195  public:
196   /// Constructs a new item.
Item(Key key,Value value)197   Item(Key key, Value value) : pair_(std::move(key), std::move(value)) {}
198 
199   /// Returns a reference to the value.
200   Value& operator*() {
201     return value();
202   }
203 
204   /// Returns a reference to the value.
205   const Value& operator*() const {
206     return value();
207   }
208 
209   /// Allows access to the value using the arrow operator.
210   Value* operator->() {
211     return &value();
212   }
213 
214   /// Allows access to the value using the arrow operator.
215   const Value* operator->() const {
216     return &value();
217   }
218 
219   /// Returns a reference to the key.
key()220   const Key& key() const noexcept {
221     return pair_.first;
222   }
223 
224   /// Returns a reference to the value.
value()225   Value& value() noexcept {
226     return pair_.second;
227   }
228 
229   /// Returns a reference to the value.
value()230   const Value& value() const noexcept {
231     return pair_.second;
232   }
233 
234   /// Returns a `(key, value)` pair.
pair()235   const std::pair<Key, Value>& pair() const noexcept {
236     return pair_;
237   }
238 
239  private:
240   /// This is stored as an std::pair because it will make Python binding a lot,
241   /// lot easier.
242   ::std::pair<Key, Value> pair_;
243 };
244 
245 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
246 
247 template <typename Key, typename Value>
OrderedDict(std::string key_description)248 OrderedDict<Key, Value>::OrderedDict(std::string key_description)
249     : key_description_(std::move(key_description)) {}
250 
251 template <typename Key, typename Value>
OrderedDict(const OrderedDict & other)252 OrderedDict<Key, Value>::OrderedDict(const OrderedDict& other)
253     : index_(other.index_), key_description_(other.key_description_) {
254   // Copy we have to do ourselves, because items' keys are const, so we have to
255   // re-insert the items.
256   for (const auto& item : other.items_) {
257     items_.push_back(item);
258   }
259 }
260 
261 template <typename Key, typename Value>
262 OrderedDict<Key, Value>& OrderedDict<Key, Value>::operator=(
263     const OrderedDict& other) {
264   index_ = other.index_;
265   items_.clear();
266   for (auto& item : other.items_) {
267     items_.push_back(item);
268   }
269   key_description_ = other.key_description_;
270   return *this;
271 }
272 
273 template <typename Key, typename Value>
OrderedDict(std::initializer_list<Item> initializer_list)274 OrderedDict<Key, Value>::OrderedDict(
275     std::initializer_list<Item> initializer_list)
276     : OrderedDict("Key") {
277   items_.reserve(initializer_list.size());
278   for (auto& item : initializer_list) {
279     // Copy the key here and move it into the index.
280     items_.emplace_back(item.key(), std::move(item.value()));
281     index_.emplace(std::move(item.key()), size() - 1);
282   }
283 }
284 
285 template <typename Key, typename Value>
begin()286 typename OrderedDict<Key, Value>::Iterator OrderedDict<Key, Value>::begin() {
287   return items_.begin();
288 }
289 
290 template <typename Key, typename Value>
begin()291 typename OrderedDict<Key, Value>::ConstIterator OrderedDict<Key, Value>::begin()
292     const {
293   return items_.begin();
294 }
295 
296 template <typename Key, typename Value>
end()297 typename OrderedDict<Key, Value>::Iterator OrderedDict<Key, Value>::end() {
298   return items_.end();
299 }
300 
301 template <typename Key, typename Value>
end()302 typename OrderedDict<Key, Value>::ConstIterator OrderedDict<Key, Value>::end()
303     const {
304   return items_.end();
305 }
306 
307 template <typename Key, typename Value>
front()308 typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::front() {
309   TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
310   return items_.front();
311 }
312 
313 template <typename Key, typename Value>
front()314 const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::front()
315     const {
316   TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
317   return items_.front();
318 }
319 
320 template <typename Key, typename Value>
back()321 typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::back() {
322   TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
323   return items_.back();
324 }
325 
326 template <typename Key, typename Value>
back()327 const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::back()
328     const {
329   TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
330   return items_.back();
331 }
332 
333 template <typename Key, typename Value>
334 typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::operator[](
335     size_t index) {
336   TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
337   return items_[index];
338 }
339 
340 template <typename Key, typename Value>
341 const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::
342 operator[](size_t index) const {
343   TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
344   return items_[index];
345 }
346 
347 template <typename Key, typename Value>
348 Value& OrderedDict<Key, Value>::operator[](const Key& key) {
349   if (auto* value = find(key)) {
350     return *value;
351   }
352   AT_ERROR(key_description_, " '", key, "' is not defined");
353 }
354 
355 template <typename Key, typename Value>
356 const Value& OrderedDict<Key, Value>::operator[](const Key& key) const {
357   if (auto* value = find(key)) {
358     return *value;
359   }
360   AT_ERROR(key_description_, " '", key, "' is not defined");
361 }
362 
363 template <typename Key, typename Value>
364 template <typename K, typename V>
insert(K && key,V && value)365 Value& OrderedDict<Key, Value>::insert(K&& key, V&& value) {
366   TORCH_CHECK(
367       index_.count(key) == 0, key_description_, " '", key, "' already defined");
368   // Copy `key` here and move it into the index.
369   items_.emplace_back(key, std::forward<V>(value));
370   index_.emplace(std::forward<K>(key), size() - 1);
371   return items_.back().value();
372 }
373 
374 template <typename Key, typename Value>
insert(Key key,Value && value)375 Value& OrderedDict<Key, Value>::insert(Key key, Value&& value) {
376   return insert<Key, Value>(std::move(key), std::move(value));
377 }
378 
379 template <typename Key, typename Value>
update(OrderedDict && other)380 void OrderedDict<Key, Value>::update(OrderedDict&& other) {
381   reserve(size() + other.size());
382   for (auto& item : other) {
383     // We want to call `insert()` to prevent duplicate keys.
384     insert(std::move(item.key()), std::move(item.value()));
385   }
386 }
387 
388 template <typename Key, typename Value>
update(const OrderedDict & other)389 void OrderedDict<Key, Value>::update(const OrderedDict& other) {
390   reserve(size() + other.size());
391   for (auto& item : other) {
392     // We want to call `insert()` to prevent duplicate keys.
393     insert(item.key(), item.value());
394   }
395 }
396 
397 template <typename Key, typename Value>
find(const Key & key)398 Value* OrderedDict<Key, Value>::find(const Key& key) noexcept {
399   auto iterator = index_.find(key);
400   if (iterator == index_.end()) {
401     return nullptr;
402   }
403   return &items_[iterator->second].value();
404 }
405 
406 template <typename Key, typename Value>
find(const Key & key)407 const Value* OrderedDict<Key, Value>::find(const Key& key) const noexcept {
408   auto iterator = index_.find(key);
409   if (iterator == index_.end()) {
410     return nullptr;
411   }
412   return &items_[iterator->second].value();
413 }
414 
415 template <typename Key, typename Value>
erase(const Key & key)416 void OrderedDict<Key, Value>::erase(const Key& key) {
417   auto it = index_.find(key);
418   TORCH_CHECK(it != index_.end(), "Key '", key, "' doesn't exist");
419 
420   auto index = it->second;
421   index_.erase(it);
422   items_.erase(items_.begin() + index);
423 
424   for (auto& pair : index_)
425     if (pair.second > index)
426       --pair.second;
427 }
428 
429 template <typename Key, typename Value>
contains(const Key & key)430 bool OrderedDict<Key, Value>::contains(const Key& key) const noexcept {
431   return find(key) != nullptr;
432 }
433 
434 template <typename Key, typename Value>
clear()435 void OrderedDict<Key, Value>::clear() {
436   index_.clear();
437   items_.clear();
438 }
439 
440 template <typename Key, typename Value>
size()441 size_t OrderedDict<Key, Value>::size() const noexcept {
442   return items_.size();
443 }
444 
445 template <typename Key, typename Value>
is_empty()446 bool OrderedDict<Key, Value>::is_empty() const noexcept {
447   return items_.empty();
448 }
449 
450 template <typename Key, typename Value>
key_description()451 const std::string& OrderedDict<Key, Value>::key_description() const noexcept {
452   return key_description_;
453 }
454 
455 template <typename Key, typename Value>
456 const std::vector<typename OrderedDict<Key, Value>::Item>& OrderedDict<
457     Key,
items()458     Value>::items() const noexcept {
459   return items_;
460 }
461 
462 template <typename Key, typename Value>
keys()463 ::std::vector<Key> OrderedDict<Key, Value>::keys() const {
464   std::vector<Key> keys;
465   keys.reserve(size());
466   for (const auto& item : items_) {
467     keys.push_back(item.key());
468   }
469   return keys;
470 }
471 
472 template <typename Key, typename Value>
values()473 ::std::vector<Value> OrderedDict<Key, Value>::values() const {
474   std::vector<Value> values;
475   values.reserve(size());
476   for (const auto& item : items_) {
477     values.push_back(item.value());
478   }
479   return values;
480 }
481 
482 template <typename Key, typename Value>
pairs()483 ::std::vector<std::pair<Key, Value>> OrderedDict<Key, Value>::pairs() const {
484   std::vector<std::pair<Key, Value>> values;
485   values.reserve(size());
486   for (const auto& item : items_) {
487     values.push_back(item.pair());
488   }
489   return values;
490 }
491 
492 template <typename Key, typename Value>
reserve(size_t requested_capacity)493 void OrderedDict<Key, Value>::reserve(size_t requested_capacity) {
494   index_.reserve(requested_capacity);
495   items_.reserve(requested_capacity);
496 }
497 
498 template <typename K, typename V>
499 bool operator==(
500     const torch::OrderedDict<K, V>& a,
501     const torch::OrderedDict<K, V>& b) {
502   using Item = typename torch::OrderedDict<K, V>::Item;
503   if (a.index_ != b.index_)
504     return false;
505   if (a.items_.size() != b.items_.size())
506     return false;
507   // NOTE: There's no point in comparing keys for items_, as we already know
508   // that index is equal.
509   return std::equal(
510       a.items_.begin(),
511       a.items_.end(),
512       b.items_.begin(),
513       [](const Item& a, const Item& b) { return a.value() == b.value(); });
514 }
515 
516 } // namespace torch
517