1 #include <ATen/core/ivalue.h>
2 #include <pybind11/cast.h>
3 #include <pybind11/detail/common.h>
4 #include <torch/csrc/jit/python/pybind_utils.h>
5 #include <torch/csrc/jit/python/python_dict.h>
6 #include <torch/csrc/jit/runtime/jit_exception.h>
7 #include <torch/csrc/utils/pybind.h>
8 #include <sstream>
9 #include <stdexcept>
10
11 namespace torch::jit {
12
next()13 IValue ScriptDictIterator::next() {
14 if (iter_ == end_) {
15 throw py::stop_iteration();
16 }
17
18 // Since this is the iterator for .items(), the current key and value
19 // should be returned as a tuple.
20 IValue result = c10::ivalue::Tuple::create({iter_->key(), iter_->value()});
21
22 // Advance the iterator for next time.
23 iter_++;
24
25 return result;
26 }
27
next()28 IValue ScriptDictKeyIterator::next() {
29 if (iter_ == end_) {
30 throw py::stop_iteration();
31 }
32
33 // Since this is the iterator for .keys() and __iter__(), return only the key.
34 IValue result = iter_->key();
35
36 // Advance the iterator for next time.
37 iter_++;
38
39 return result;
40 }
41
initScriptDictBindings(PyObject * module)42 void initScriptDictBindings(PyObject* module) {
43 auto m = py::handle(module).cast<py::module>();
44
45 py::class_<ScriptDictKeyIterator>(m, "ScriptDictKeyIterator")
46 .def(
47 "__next__",
48 [](ScriptDictKeyIterator& iter) {
49 auto result = iter.next();
50 return toPyObject(result);
51 })
52 .def("__iter__", [](ScriptDictKeyIterator& iter) { return iter; });
53
54 py::class_<ScriptDictIterator>(m, "ScriptDictIterator")
55 .def(
56 "__next__",
57 [](ScriptDictIterator& iter) {
58 auto result = iter.next();
59 return toPyObject(result);
60 })
61 .def("__iter__", [](ScriptDictIterator& iter) { return iter; });
62
63 py::class_<ScriptDict, std::shared_ptr<ScriptDict>>(m, "ScriptDict")
64 .def(py::init([](py::dict dict) {
65 TypePtr type = nullptr;
66
67 if (!dict.empty()) {
68 // If the source dictionary is nonempty, try to infer its type.
69 auto inferred_type = tryToInferType(dict);
70
71 if (!inferred_type.success()) {
72 std::stringstream ss;
73 ss << "Unable to infer type of dictionary: "
74 << inferred_type.reason();
75 throw JITException(ss.str());
76 }
77
78 type = inferred_type.type();
79 } else {
80 // If is empty, assume the type is Dict[str, Tensor] as is done in
81 // TorchScript code.
82 type = DictType::create(StringType::get(), TensorType::getInferred());
83 }
84
85 auto data = toIValue(std::move(dict), type);
86 return std::make_shared<ScriptDict>(data);
87 }))
88 .def(
89 "__repr__",
90 [](const std::shared_ptr<ScriptDict>& self) {
91 return toPyObject(self->repr());
92 })
93 .def(
94 "__bool__",
95 [](const std::shared_ptr<ScriptDict>& self) {
96 return toPyObject(self->toBool());
97 })
98 .def(
99 "__len__",
100 [](const std::shared_ptr<ScriptDict>& self) {
101 return toPyObject(self->len());
102 })
103 .def(
104 "__contains__",
105 [](const std::shared_ptr<ScriptDict>& self, py::object key) {
106 try {
107 return toPyObject(self->contains(
108 toIValue(std::move(key), self->type()->getKeyType())));
109 } catch (const py::cast_error& e) {
110 throw py::key_error();
111 }
112 })
113 .def(
114 "__getitem__",
115 [](const std::shared_ptr<ScriptDict>& self, py::object key) {
116 IValue value;
117
118 // Convert key to IValue.
119 try {
120 value = toIValue(std::move(key), self->type()->getKeyType());
121 } catch (const py::cast_error& e) {
122 // It would be nice to throw py::type_error here but py::key_error
123 // needs to be thrown for parity with eager mode.
124 throw py::key_error();
125 }
126
127 // Call getItem on self.
128 try {
129 value = self->getItem(value);
130 } catch (const std::out_of_range& e) { // Key doesn't exist.
131 throw py::key_error();
132 }
133
134 return toPyObject(std::move(value));
135 },
136 py::return_value_policy::
137 reference_internal) // Return value is a reference to an object
138 // that resides in the ScriptDict
139 .def(
140 "__setitem__",
141 [](const std::shared_ptr<ScriptDict>& self,
142 py::object key,
143 py::object value) {
144 IValue key_ivalue, value_ivalue;
145
146 // Try to convert the key to an IValue.
147 try {
148 key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
149 } catch (const py::cast_error& e) {
150 throw py::type_error();
151 }
152
153 // Try to convert the value to an IValue.
154 try {
155 value_ivalue =
156 toIValue(std::move(value), self->type()->getValueType());
157 } catch (const py::cast_error& e) {
158 throw py::type_error();
159 }
160
161 self->setItem(key_ivalue, value_ivalue);
162 })
163 .def(
164 "__delitem__",
165 [](const std::shared_ptr<ScriptDict>& self, py::object key) {
166 IValue key_ivalue;
167
168 // Try to convert the key to an IValue.
169 try {
170 key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
171 } catch (const py::cast_error& e) {
172 throw py::type_error();
173 }
174
175 // If removed = false, that means the key didn't exist in the
176 // dictionary.
177 bool removed = self->delItem(key_ivalue);
178
179 if (!removed) {
180 throw py::key_error();
181 }
182 })
183 .def(
184 "__iter__",
185 [](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
186 py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
187 // long as the iterator
188 .def(
189 "items",
190 [](const std::shared_ptr<ScriptDict>& self) { return self->items(); },
191 py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
192 // long as the iterator
193 .def(
194 "keys",
195 [](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
196 py::keep_alive<0, 1>()); // ScriptDict needs to be alive at least as
197 // long as the iterator
198 }
199
200 } // namespace torch::jit
201