// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/Dict_inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue.h>
4 #include <c10/util/hash.h>
5 
6 namespace c10 {
7 namespace detail {
operator()8 inline bool DictKeyEqualTo::operator()(const IValue& lhs, const IValue& rhs) const {
9   if (lhs.isTensor() && rhs.isTensor()) {
10     // for tensors, we compare only by identity (following how it's done in Python).
11     return lhs.is(rhs);
12   }
13   // Otherwise, we first compare by identity for efficiency, then by value (see:
14   // [container equality])
15   return _fastEqualsForContainer(lhs, rhs);
16 }
17 }
18 
19 template<class T> decltype(auto) getTypePtr();
20 std::string toString(const Type& type);
21 
22 namespace impl {
23 
24 template<class Key, class Value>
toTypedDict(GenericDict dict)25 Dict<Key, Value> toTypedDict(GenericDict dict) {
26   TORCH_INTERNAL_ASSERT(*getTypePtr<Key>() == *dict.impl_->elementTypes.keyType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()), ">. Key types mismatch.");
27   TORCH_INTERNAL_ASSERT(*getTypePtr<Value>() == *dict.impl_->elementTypes.valueType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()), ">. Value types mismatch.");
28 
29   return Dict<Key, Value>(std::move(dict.impl_));
30 }
31 
32 template<class Key, class Value>
toGenericDict(Dict<Key,Value> dict)33 GenericDict toGenericDict(Dict<Key, Value> dict) {
34   return GenericDict(std::move(dict.impl_));
35 }
36 }
37 
38 namespace detail {
39 
operator()40 inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
41   if (ivalue.isInt()) {
42     return std::hash<int64_t>()(ivalue.toInt());
43   } else if (ivalue.isString()) {
44     return std::hash<c10::string_view>()(ivalue.toStringView());
45   } else if (ivalue.isDouble()) {
46     return std::hash<double>()(ivalue.toDouble());
47   } else if (ivalue.isComplexDouble()) {
48     return c10::hash<c10::complex<double>>()(ivalue.toComplexDouble());
49   } else if (ivalue.isBool()) {
50     return std::hash<bool>()(ivalue.toBool());
51   } else if (ivalue.isTensor()) {
52     return std::hash<TensorImpl*>()(ivalue.toTensor().unsafeGetTensorImpl());
53   } else if (ivalue.isDevice()) {
54     return std::hash<Device>()(ivalue.toDevice());
55   } else {
56     throw std::runtime_error(
57         "Can't hash IValues with tag '" + ivalue.tagKind() + "'");
58   }
59 }
60 
copy()61 inline intrusive_ptr<DictImpl> DictImpl::copy() const {
62   return make_intrusive<DictImpl>(dict, elementTypes);
63 }
64 
65 }
66 
67 template<class Key, class Value>
Dict()68 Dict<Key, Value>::Dict()
69   :Dict(make_intrusive<detail::DictImpl>(
70       detail::DictImpl::dict_map_type(),
71       detail::DictImpl::DictElementTypes{getTypePtr<Key>(), getTypePtr<Value>()})) {
72   static_assert(!std::is_same<Key, IValue>::value, "This constructor is not valid for Dict<IValue, _>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
73   static_assert(!std::is_same<Value, IValue>::value, "This constructor is not valid for Dict<_, IValue>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
74 }
75 
76 template<class Key, class Value>
Dict(TypePtr keyType,TypePtr valueType)77 Dict<Key, Value>::Dict(TypePtr keyType, TypePtr valueType)
78 : Dict(make_intrusive<detail::DictImpl>(
79     detail::DictImpl::dict_map_type(),
80     detail::DictImpl::DictElementTypes {std::move(keyType), std::move(valueType)})) {
81   static_assert(std::is_same<Key, IValue>::value, "This constructor is only valid for c10::impl::GenericDict.");
82   static_assert(std::is_same<Value, IValue>::value, "This constructor is only valid for c10::impl::GenericDict.");
83 }
84 
85 template<class Key, class Value>
Dict(c10::intrusive_ptr<detail::DictImpl> && impl)86 Dict<Key, Value>::Dict(c10::intrusive_ptr<detail::DictImpl>&& impl): impl_(std::move(impl)) {}
87 
88 template<class Key, class Value>
copy()89 Dict<Key, Value> Dict<Key, Value>::copy() const {
90   return Dict<Key, Value>(impl_->copy());
91 }
92 
93 template<class Key, class Value>
begin()94 typename Dict<Key, Value>::iterator Dict<Key, Value>::begin() const {
95   return iterator{impl_->dict.begin()};
96 }
97 
98 template<class Key, class Value>
end()99 typename Dict<Key, Value>::iterator Dict<Key, Value>::end() const {
100   return iterator{impl_->dict.end()};
101 }
102 
103 template<class Key, class Value>
empty()104 bool Dict<Key, Value>::empty() const {
105   return impl_->dict.empty();
106 }
107 
108 template<class Key, class Value>
size()109 typename Dict<Key, Value>::size_type Dict<Key, Value>::size() const {
110   return impl_->dict.size();
111 }
112 
113 template<class Key, class Value>
clear()114 void Dict<Key, Value>::clear() const {
115   impl_->dict.clear();
116 }
117 
118 template<class Key, class Value>
119 template<class Key_, class Value_>
insert(Key_ && key,Value_ && value)120 std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert(Key_&& key, Value_&& value) const {
121   static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert");
122   static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert");
123   auto inserted = impl_->dict.emplace(
124       Key(std::forward<Key_>(key)),
125       Value(std::forward<Value_>(value)));
126   return {iterator{inserted.first}, inserted.second};
127 }
128 
129 template<class Key, class Value>
130 template<class Key_, class Value_>
insert_or_assign(Key_ && key,Value_ && value)131 std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert_or_assign(Key_&& key, Value_&& value) const {
132   static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert_or_assign");
133   static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert_or_assign");
134   auto inserted = impl_->dict.insert_or_assign(
135     Key(std::forward<Key_>(key)),
136     Value(std::forward<Value_>(value)));
137   return {iterator{inserted.first}, inserted.second};
138 }
139 
140 template<class Key, class Value>
erase(iterator iter)141 void Dict<Key, Value>::erase(iterator iter) const {
142   impl_->dict.erase(iter.entryRef_.iterator_);
143 }
144 
145 template<class Key, class Value>
erase(const Key & key)146 C10_NODISCARD size_t Dict<Key, Value>::erase(const Key& key) const {
147   return impl_->dict.erase(key);
148 }
149 
150 template<class Key, class Value>
at(const Key & key)151 Value Dict<Key, Value>::at(const Key& key) const {
152   return impl_->dict.at(key).template to<Value>();
153 }
154 
155 template<class Key, class Value>
find(const Key & key)156 typename Dict<Key, Value>::iterator Dict<Key, Value>::find(const Key& key) const {
157   return iterator{impl_->dict.find(key)};
158 }
159 
160 template<class Key, class Value>
contains(const Key & key)161 bool Dict<Key, Value>::contains(const Key& key) const {
162   return end() != find(key);
163 }
164 
165 template<class Key, class Value>
reserve(size_type count)166 void Dict<Key, Value>::reserve(size_type count) const {
167   impl_->dict.reserve(count);
168 }
169 
170 template<class Key, class Value>
keyType()171 TypePtr Dict<Key, Value>::keyType() const {
172   return impl_->elementTypes.keyType;
173 }
174 
175 template<class Key, class Value>
valueType()176 TypePtr Dict<Key, Value>::valueType() const {
177   return impl_->elementTypes.valueType;
178 }
179 template <class Key, class Value>
unsafeSetKeyType(TypePtr t)180 void Dict<Key, Value>::unsafeSetKeyType(TypePtr t) {
181   impl_->elementTypes.keyType = std::move(t);
182 }
183 
184 template <class Key, class Value>
unsafeSetValueType(TypePtr t)185 void Dict<Key, Value>::unsafeSetValueType(TypePtr t) {
186   impl_->elementTypes.valueType = std::move(t);
187 }
188 
189 template <class Key_, class Value_>
190 bool operator==(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
191   // Dicts with the same identity trivially compare equal.
192   if (lhs.impl_ == rhs.impl_) {
193     return true;
194   }
195 
196   // Otherwise compare the values
197   return *lhs.impl_ == *rhs.impl_;
198 }
199 
200 template <class Key_, class Value_>
201 bool operator!=(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
202   return !(lhs == rhs);
203 }
204 
205 template <class Key, class Value>
is(const Dict & rhs)206 bool Dict<Key, Value>::is(const Dict& rhs) const {
207   return this->impl_ == rhs.impl_;
208 }
209 }
210