1 #pragma once
2
#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
9
10 namespace torch {
/// An ordered dictionary implementation, akin to Python's `OrderedDict`.
///
/// Items are stored in insertion order in a `std::vector<Item>`; a separate
/// `std::unordered_map` maps every key to its position in that vector, so key
/// lookups go through the hash map while iteration follows insertion order.
/// Unlike Python's `OrderedDict`, `insert()` throws instead of overwriting
/// when a key is already present.
template <typename Key, typename Value>
class OrderedDict {
 public:
  /// A (key, value) pair.
  class Item;

  // The lifetime of an iterator is bound to the lifetime of the `OrderedDict`.
  // Iterators, and references to items or values, point into a vector, so any
  // operation that adds or removes items (`insert()`, `update()`, `erase()`,
  // `clear()`) or that may reallocate storage (`reserve()`) may invalidate
  // them.
  using Iterator = typename std::vector<Item>::iterator;
  using ConstIterator = typename std::vector<Item>::const_iterator;

  /// Constructs the `OrderedDict` with a short description of the kinds of keys
  /// stored in the `OrderedDict`. This description is used in error messages
  /// thrown by the `OrderedDict`.
  explicit OrderedDict(std::string key_description = "Key");

  /// Copy constructs this `OrderedDict` from `other`, including its items and
  /// key description.
  OrderedDict(const OrderedDict& other);

  /// Replaces the items and key description of this `OrderedDict` with copies
  /// of those of `other`.
  OrderedDict& operator=(const OrderedDict& other);

  // Moves are member-wise (defaulted) and declared `noexcept`, which lets
  // containers such as `std::vector<OrderedDict>` move rather than copy their
  // elements when they reallocate. A `noexcept` conditional on the move
  // constructors of `index_` and `items_` would be more precise, but that
  // spelling did not compile on Windows.
  OrderedDict(OrderedDict&& other) noexcept = default;
  OrderedDict& operator=(OrderedDict&& other) noexcept = default;

  ~OrderedDict() = default;

  /// Constructs a new `OrderedDict` and pre-populates it with the given
  /// `Item`s, in the given order. The key description is set to "Key".
  /*implicit */ OrderedDict(std::initializer_list<Item> initializer_list);

  /// Returns the key description string the `OrderedDict` was constructed with.
  const std::string& key_description() const noexcept;

  // Element Access

  /// Returns the very first item in the `OrderedDict` and throws an exception
  /// if it is empty.
  Item& front();

  /// Returns the very first item in the `OrderedDict` and throws an exception
  /// if it is empty.
  const Item& front() const;

  /// Returns the very last item in the `OrderedDict` and throws an exception
  /// if it is empty.
  Item& back();

  /// Returns the very last item in the `OrderedDict` and throws an exception
  /// if it is empty.
  const Item& back() const;

  /// Returns the item at the `index`-th position in the `OrderedDict`. Throws
  /// an exception if the index is out of bounds.
  Item& operator[](size_t index);

  /// Returns the item at the `index`-th position in the `OrderedDict`. Throws
  /// an exception if the index is out of bounds.
  const Item& operator[](size_t index) const;

  /// Returns the value associated with the given `key`. Throws an exception if
  /// no such key is stored in the `OrderedDict`. Use `find()` for a
  /// non-throwing way of accessing a value if it is present.
  Value& operator[](const Key& key);

  /// Returns the value associated with the given `key`. Throws an exception if
  /// no such key is stored in the `OrderedDict`. Use `find()` for a
  /// non-throwing way of accessing a value if it is present.
  const Value& operator[](const Key& key) const;

  // Lookup

  /// Returns a pointer to the value associated with the given key, or a
  /// `nullptr` if no such key is stored in the `OrderedDict`.
  Value* find(const Key& key) noexcept;

  /// Returns a pointer to the value associated with the given key, or a
  /// `nullptr` if no such key is stored in the `OrderedDict`.
  const Value* find(const Key& key) const noexcept;

  /// Returns true if the key is present in the `OrderedDict`.
  bool contains(const Key& key) const noexcept;

  // Iterators

  /// Returns an iterator to the first item in the `OrderedDict`. Iteration is
  /// ordered.
  Iterator begin();

  /// Returns an iterator to the first item in the `OrderedDict`. Iteration is
  /// ordered.
  ConstIterator begin() const;

  /// Returns an iterator one past the last item in the `OrderedDict`.
  Iterator end();

  /// Returns an iterator one past the last item in the `OrderedDict`.
  ConstIterator end() const;

  // Capacity

  /// Returns the number of items currently stored in the `OrderedDict`.
  size_t size() const noexcept;

  /// Returns true if the `OrderedDict` contains no elements.
  bool is_empty() const noexcept;

  /// Resizes internal storage to fit at least `requested_capacity` items
  /// without requiring reallocation.
  void reserve(size_t requested_capacity);

  // Modifiers

  /// Inserts a new `(key, value)` pair at the end of the `OrderedDict`. Throws
  /// an exception if the key is already present. If insertion is successful,
  /// immediately returns a reference to the inserted value; that reference
  /// may be invalidated by later modifications of the `OrderedDict`.
  template <typename K, typename V>
  Value& insert(K&& key, V&& value);

  /// Inserts a new `(key, value)` pair at the end of the `OrderedDict`. Throws
  /// an exception if the key is already present. If insertion is successful,
  /// immediately returns a reference to the inserted value. Because this
  /// overload is not a template, `value` may be given as a braced initializer
  /// list, which the forwarding overload above cannot deduce.
  Value& insert(Key key, Value&& value);

  /// Inserts all items from `other` into this `OrderedDict`, in `other`'s
  /// iteration order. Values are moved out of `other`. If any key from `other`
  /// is already present in this `OrderedDict`, an exception is thrown; items
  /// of `other` that precede the offending key have already been inserted.
  void update(OrderedDict&& other);

  /// Inserts copies of all items from `other` into this `OrderedDict`, in
  /// `other`'s iteration order. If any key from `other` is already present in
  /// this `OrderedDict`, an exception is thrown; items of `other` that precede
  /// the offending key have already been inserted.
  void update(const OrderedDict& other);

  /// Removes the item with the given `key`, keeping the remaining items in
  /// their relative order. Throws an exception if no such key is stored in the
  /// `OrderedDict`. Takes time linear in `size()`, because the items after the
  /// removed one shift down and their positions in the index are updated.
  void erase(const Key& key);

  /// Removes all items from this `OrderedDict`.
  void clear();

  // Observers

  /// Returns the items stored in the `OrderedDict`.
  const std::vector<Item>& items() const noexcept;

  /// Returns a newly allocated vector and copies all keys from this
  /// `OrderedDict` into the vector.
  ::std::vector<Key> keys() const;

  /// Returns a newly allocated vector and copies all values from this
  /// `OrderedDict` into the vector.
  ::std::vector<Value> values() const;

  /// Returns a newly allocated vector and copies all keys and values from this
  /// `OrderedDict` into a vector of `std::pair<Key, Value>`.
  ::std::vector<std::pair<Key, Value>> pairs() const;

  /// Returns true if both dicts contain the same keys and values, in the same
  /// order. The key descriptions are not compared.
  template <typename K, typename V>
  friend bool operator==(
      const OrderedDict<K, V>& a,
      const OrderedDict<K, V>& b);

 private:
  /// A mapping from a key to an index into the `items_` vector.
  ::std::unordered_map<Key, size_t> index_;

  /// The items stored in the `OrderedDict`, in insertion order.
  ::std::vector<Item> items_;

  /// A description of the keys stored in the `OrderedDict`.
  ::std::string key_description_{"Key"};
};
190
191 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict::Item ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
192
193 template <typename Key, typename Value>
194 class OrderedDict<Key, Value>::Item {
195 public:
196 /// Constructs a new item.
Item(Key key,Value value)197 Item(Key key, Value value) : pair_(std::move(key), std::move(value)) {}
198
199 /// Returns a reference to the value.
200 Value& operator*() {
201 return value();
202 }
203
204 /// Returns a reference to the value.
205 const Value& operator*() const {
206 return value();
207 }
208
209 /// Allows access to the value using the arrow operator.
210 Value* operator->() {
211 return &value();
212 }
213
214 /// Allows access to the value using the arrow operator.
215 const Value* operator->() const {
216 return &value();
217 }
218
219 /// Returns a reference to the key.
key()220 const Key& key() const noexcept {
221 return pair_.first;
222 }
223
224 /// Returns a reference to the value.
value()225 Value& value() noexcept {
226 return pair_.second;
227 }
228
229 /// Returns a reference to the value.
value()230 const Value& value() const noexcept {
231 return pair_.second;
232 }
233
234 /// Returns a `(key, value)` pair.
pair()235 const std::pair<Key, Value>& pair() const noexcept {
236 return pair_;
237 }
238
239 private:
240 /// This is stored as an std::pair because it will make Python binding a lot,
241 /// lot easier.
242 ::std::pair<Key, Value> pair_;
243 };
244
245 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
246
247 template <typename Key, typename Value>
OrderedDict(std::string key_description)248 OrderedDict<Key, Value>::OrderedDict(std::string key_description)
249 : key_description_(std::move(key_description)) {}
250
251 template <typename Key, typename Value>
OrderedDict(const OrderedDict & other)252 OrderedDict<Key, Value>::OrderedDict(const OrderedDict& other)
253 : index_(other.index_), key_description_(other.key_description_) {
254 // Copy we have to do ourselves, because items' keys are const, so we have to
255 // re-insert the items.
256 for (const auto& item : other.items_) {
257 items_.push_back(item);
258 }
259 }
260
261 template <typename Key, typename Value>
262 OrderedDict<Key, Value>& OrderedDict<Key, Value>::operator=(
263 const OrderedDict& other) {
264 index_ = other.index_;
265 items_.clear();
266 for (auto& item : other.items_) {
267 items_.push_back(item);
268 }
269 key_description_ = other.key_description_;
270 return *this;
271 }
272
273 template <typename Key, typename Value>
OrderedDict(std::initializer_list<Item> initializer_list)274 OrderedDict<Key, Value>::OrderedDict(
275 std::initializer_list<Item> initializer_list)
276 : OrderedDict("Key") {
277 items_.reserve(initializer_list.size());
278 for (auto& item : initializer_list) {
279 // Copy the key here and move it into the index.
280 items_.emplace_back(item.key(), std::move(item.value()));
281 index_.emplace(std::move(item.key()), size() - 1);
282 }
283 }
284
285 template <typename Key, typename Value>
begin()286 typename OrderedDict<Key, Value>::Iterator OrderedDict<Key, Value>::begin() {
287 return items_.begin();
288 }
289
290 template <typename Key, typename Value>
begin()291 typename OrderedDict<Key, Value>::ConstIterator OrderedDict<Key, Value>::begin()
292 const {
293 return items_.begin();
294 }
295
296 template <typename Key, typename Value>
end()297 typename OrderedDict<Key, Value>::Iterator OrderedDict<Key, Value>::end() {
298 return items_.end();
299 }
300
301 template <typename Key, typename Value>
end()302 typename OrderedDict<Key, Value>::ConstIterator OrderedDict<Key, Value>::end()
303 const {
304 return items_.end();
305 }
306
307 template <typename Key, typename Value>
front()308 typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::front() {
309 TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
310 return items_.front();
311 }
312
313 template <typename Key, typename Value>
front()314 const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::front()
315 const {
316 TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
317 return items_.front();
318 }
319
320 template <typename Key, typename Value>
back()321 typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::back() {
322 TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
323 return items_.back();
324 }
325
326 template <typename Key, typename Value>
back()327 const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::back()
328 const {
329 TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
330 return items_.back();
331 }
332
333 template <typename Key, typename Value>
334 typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::operator[](
335 size_t index) {
336 TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
337 return items_[index];
338 }
339
340 template <typename Key, typename Value>
341 const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::
342 operator[](size_t index) const {
343 TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
344 return items_[index];
345 }
346
347 template <typename Key, typename Value>
348 Value& OrderedDict<Key, Value>::operator[](const Key& key) {
349 if (auto* value = find(key)) {
350 return *value;
351 }
352 AT_ERROR(key_description_, " '", key, "' is not defined");
353 }
354
355 template <typename Key, typename Value>
356 const Value& OrderedDict<Key, Value>::operator[](const Key& key) const {
357 if (auto* value = find(key)) {
358 return *value;
359 }
360 AT_ERROR(key_description_, " '", key, "' is not defined");
361 }
362
363 template <typename Key, typename Value>
364 template <typename K, typename V>
insert(K && key,V && value)365 Value& OrderedDict<Key, Value>::insert(K&& key, V&& value) {
366 TORCH_CHECK(
367 index_.count(key) == 0, key_description_, " '", key, "' already defined");
368 // Copy `key` here and move it into the index.
369 items_.emplace_back(key, std::forward<V>(value));
370 index_.emplace(std::forward<K>(key), size() - 1);
371 return items_.back().value();
372 }
373
374 template <typename Key, typename Value>
insert(Key key,Value && value)375 Value& OrderedDict<Key, Value>::insert(Key key, Value&& value) {
376 return insert<Key, Value>(std::move(key), std::move(value));
377 }
378
379 template <typename Key, typename Value>
update(OrderedDict && other)380 void OrderedDict<Key, Value>::update(OrderedDict&& other) {
381 reserve(size() + other.size());
382 for (auto& item : other) {
383 // We want to call `insert()` to prevent duplicate keys.
384 insert(std::move(item.key()), std::move(item.value()));
385 }
386 }
387
388 template <typename Key, typename Value>
update(const OrderedDict & other)389 void OrderedDict<Key, Value>::update(const OrderedDict& other) {
390 reserve(size() + other.size());
391 for (auto& item : other) {
392 // We want to call `insert()` to prevent duplicate keys.
393 insert(item.key(), item.value());
394 }
395 }
396
397 template <typename Key, typename Value>
find(const Key & key)398 Value* OrderedDict<Key, Value>::find(const Key& key) noexcept {
399 auto iterator = index_.find(key);
400 if (iterator == index_.end()) {
401 return nullptr;
402 }
403 return &items_[iterator->second].value();
404 }
405
406 template <typename Key, typename Value>
find(const Key & key)407 const Value* OrderedDict<Key, Value>::find(const Key& key) const noexcept {
408 auto iterator = index_.find(key);
409 if (iterator == index_.end()) {
410 return nullptr;
411 }
412 return &items_[iterator->second].value();
413 }
414
415 template <typename Key, typename Value>
erase(const Key & key)416 void OrderedDict<Key, Value>::erase(const Key& key) {
417 auto it = index_.find(key);
418 TORCH_CHECK(it != index_.end(), "Key '", key, "' doesn't exist");
419
420 auto index = it->second;
421 index_.erase(it);
422 items_.erase(items_.begin() + index);
423
424 for (auto& pair : index_)
425 if (pair.second > index)
426 --pair.second;
427 }
428
429 template <typename Key, typename Value>
contains(const Key & key)430 bool OrderedDict<Key, Value>::contains(const Key& key) const noexcept {
431 return find(key) != nullptr;
432 }
433
434 template <typename Key, typename Value>
clear()435 void OrderedDict<Key, Value>::clear() {
436 index_.clear();
437 items_.clear();
438 }
439
440 template <typename Key, typename Value>
size()441 size_t OrderedDict<Key, Value>::size() const noexcept {
442 return items_.size();
443 }
444
445 template <typename Key, typename Value>
is_empty()446 bool OrderedDict<Key, Value>::is_empty() const noexcept {
447 return items_.empty();
448 }
449
450 template <typename Key, typename Value>
key_description()451 const std::string& OrderedDict<Key, Value>::key_description() const noexcept {
452 return key_description_;
453 }
454
455 template <typename Key, typename Value>
456 const std::vector<typename OrderedDict<Key, Value>::Item>& OrderedDict<
457 Key,
items()458 Value>::items() const noexcept {
459 return items_;
460 }
461
462 template <typename Key, typename Value>
keys()463 ::std::vector<Key> OrderedDict<Key, Value>::keys() const {
464 std::vector<Key> keys;
465 keys.reserve(size());
466 for (const auto& item : items_) {
467 keys.push_back(item.key());
468 }
469 return keys;
470 }
471
472 template <typename Key, typename Value>
values()473 ::std::vector<Value> OrderedDict<Key, Value>::values() const {
474 std::vector<Value> values;
475 values.reserve(size());
476 for (const auto& item : items_) {
477 values.push_back(item.value());
478 }
479 return values;
480 }
481
482 template <typename Key, typename Value>
pairs()483 ::std::vector<std::pair<Key, Value>> OrderedDict<Key, Value>::pairs() const {
484 std::vector<std::pair<Key, Value>> values;
485 values.reserve(size());
486 for (const auto& item : items_) {
487 values.push_back(item.pair());
488 }
489 return values;
490 }
491
492 template <typename Key, typename Value>
reserve(size_t requested_capacity)493 void OrderedDict<Key, Value>::reserve(size_t requested_capacity) {
494 index_.reserve(requested_capacity);
495 items_.reserve(requested_capacity);
496 }
497
498 template <typename K, typename V>
499 bool operator==(
500 const torch::OrderedDict<K, V>& a,
501 const torch::OrderedDict<K, V>& b) {
502 using Item = typename torch::OrderedDict<K, V>::Item;
503 if (a.index_ != b.index_)
504 return false;
505 if (a.items_.size() != b.items_.size())
506 return false;
507 // NOTE: There's no point in comparing keys for items_, as we already know
508 // that index is equal.
509 return std::equal(
510 a.items_.begin(),
511 a.items_.end(),
512 b.items_.begin(),
513 [](const Item& a, const Item& b) { return a.value() == b.value(); });
514 }
515
516 } // namespace torch
517