xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_dict.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/ivalue.h>
2 #include <pybind11/cast.h>
3 #include <pybind11/detail/common.h>
4 #include <torch/csrc/jit/python/pybind_utils.h>
5 #include <torch/csrc/jit/python/python_dict.h>
6 #include <torch/csrc/jit/runtime/jit_exception.h>
7 #include <torch/csrc/utils/pybind.h>
8 #include <sstream>
9 #include <stdexcept>
10 
11 namespace torch::jit {
12 
next()13 IValue ScriptDictIterator::next() {
14   if (iter_ == end_) {
15     throw py::stop_iteration();
16   }
17 
18   // Since this is the iterator for .items(), the current key and value
19   // should be returned as a tuple.
20   IValue result = c10::ivalue::Tuple::create({iter_->key(), iter_->value()});
21 
22   // Advance the iterator for next time.
23   iter_++;
24 
25   return result;
26 }
27 
next()28 IValue ScriptDictKeyIterator::next() {
29   if (iter_ == end_) {
30     throw py::stop_iteration();
31   }
32 
33   // Since this is the iterator for .keys() and __iter__(), return only the key.
34   IValue result = iter_->key();
35 
36   // Advance the iterator for next time.
37   iter_++;
38 
39   return result;
40 }
41 
initScriptDictBindings(PyObject * module)42 void initScriptDictBindings(PyObject* module) {
43   auto m = py::handle(module).cast<py::module>();
44 
45   py::class_<ScriptDictKeyIterator>(m, "ScriptDictKeyIterator")
46       .def(
47           "__next__",
48           [](ScriptDictKeyIterator& iter) {
49             auto result = iter.next();
50             return toPyObject(result);
51           })
52       .def("__iter__", [](ScriptDictKeyIterator& iter) { return iter; });
53 
54   py::class_<ScriptDictIterator>(m, "ScriptDictIterator")
55       .def(
56           "__next__",
57           [](ScriptDictIterator& iter) {
58             auto result = iter.next();
59             return toPyObject(result);
60           })
61       .def("__iter__", [](ScriptDictIterator& iter) { return iter; });
62 
63   py::class_<ScriptDict, std::shared_ptr<ScriptDict>>(m, "ScriptDict")
64       .def(py::init([](py::dict dict) {
65         TypePtr type = nullptr;
66 
67         if (!dict.empty()) {
68           // If the source dictionary is nonempty, try to infer its type.
69           auto inferred_type = tryToInferType(dict);
70 
71           if (!inferred_type.success()) {
72             std::stringstream ss;
73             ss << "Unable to infer type of dictionary: "
74                << inferred_type.reason();
75             throw JITException(ss.str());
76           }
77 
78           type = inferred_type.type();
79         } else {
80           // If is empty, assume the type is Dict[str, Tensor] as is done in
81           // TorchScript code.
82           type = DictType::create(StringType::get(), TensorType::getInferred());
83         }
84 
85         auto data = toIValue(std::move(dict), type);
86         return std::make_shared<ScriptDict>(data);
87       }))
88       .def(
89           "__repr__",
90           [](const std::shared_ptr<ScriptDict>& self) {
91             return toPyObject(self->repr());
92           })
93       .def(
94           "__bool__",
95           [](const std::shared_ptr<ScriptDict>& self) {
96             return toPyObject(self->toBool());
97           })
98       .def(
99           "__len__",
100           [](const std::shared_ptr<ScriptDict>& self) {
101             return toPyObject(self->len());
102           })
103       .def(
104           "__contains__",
105           [](const std::shared_ptr<ScriptDict>& self, py::object key) {
106             try {
107               return toPyObject(self->contains(
108                   toIValue(std::move(key), self->type()->getKeyType())));
109             } catch (const py::cast_error& e) {
110               throw py::key_error();
111             }
112           })
113       .def(
114           "__getitem__",
115           [](const std::shared_ptr<ScriptDict>& self, py::object key) {
116             IValue value;
117 
118             // Convert key to IValue.
119             try {
120               value = toIValue(std::move(key), self->type()->getKeyType());
121             } catch (const py::cast_error& e) {
122               // It would be nice to throw py::type_error here but py::key_error
123               // needs to be thrown for parity with eager mode.
124               throw py::key_error();
125             }
126 
127             // Call getItem on self.
128             try {
129               value = self->getItem(value);
130             } catch (const std::out_of_range& e) { // Key doesn't exist.
131               throw py::key_error();
132             }
133 
134             return toPyObject(std::move(value));
135           },
136           py::return_value_policy::
137               reference_internal) // Return value is a reference to an object
138                                   // that resides in the ScriptDict
139       .def(
140           "__setitem__",
141           [](const std::shared_ptr<ScriptDict>& self,
142              py::object key,
143              py::object value) {
144             IValue key_ivalue, value_ivalue;
145 
146             // Try to convert the key to an IValue.
147             try {
148               key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
149             } catch (const py::cast_error& e) {
150               throw py::type_error();
151             }
152 
153             // Try to convert the value to an IValue.
154             try {
155               value_ivalue =
156                   toIValue(std::move(value), self->type()->getValueType());
157             } catch (const py::cast_error& e) {
158               throw py::type_error();
159             }
160 
161             self->setItem(key_ivalue, value_ivalue);
162           })
163       .def(
164           "__delitem__",
165           [](const std::shared_ptr<ScriptDict>& self, py::object key) {
166             IValue key_ivalue;
167 
168             // Try to convert the key to an IValue.
169             try {
170               key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
171             } catch (const py::cast_error& e) {
172               throw py::type_error();
173             }
174 
175             // If removed = false, that means the key didn't exist in the
176             // dictionary.
177             bool removed = self->delItem(key_ivalue);
178 
179             if (!removed) {
180               throw py::key_error();
181             }
182           })
183       .def(
184           "__iter__",
185           [](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
186           py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
187                                   // long as the iterator
188       .def(
189           "items",
190           [](const std::shared_ptr<ScriptDict>& self) { return self->items(); },
191           py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
192                                   // long as the iterator
193       .def(
194           "keys",
195           [](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
196           py::keep_alive<0, 1>()); // ScriptDict needs to be alive at least as
197                                    // long as the iterator
198 }
199 
200 } // namespace torch::jit
201