#include <ATen/core/ivalue.h>
#include <c10/util/irange.h>
#include <pybind11/detail/common.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_list.h>
#include <torch/csrc/utils/pybind.h>

#include <sstream>
#include <stdexcept>
#include <vector>
9
10 namespace torch::jit {
11
next()12 IValue ScriptListIterator::next() {
13 if (iter_ == end_) {
14 throw py::stop_iteration();
15 }
16
17 IValue result = *iter_;
18
19 // Advance the iterator for next time.
20 iter_++;
21
22 return result;
23 }
24
done() const25 bool ScriptListIterator::done() const {
26 return iter_ == end_;
27 }
28
29 namespace {
scriptListToPyList(const ScriptList & src)30 py::list scriptListToPyList(const ScriptList& src) {
31 py::list out(src.len());
32 auto iter = src.iter();
33
34 size_t i = 0;
35 while (!iter.done()) {
36 auto val = iter.next();
37 // TODO: Handle nested dictionaries.
38 if (val.isList()) {
39 out[i] = scriptListToPyList(val);
40 } else {
41 out[i] = toPyObject(val);
42 }
43 ++i;
44 }
45
46 return out;
47 }
48 } // namespace
49
initScriptListBindings(PyObject * module)50 void initScriptListBindings(PyObject* module) {
51 auto m = py::handle(module).cast<py::module>();
52
53 py::class_<ScriptListIterator>(m, "ScriptListIterator")
54 .def(
55 "__next__",
56 [](ScriptListIterator& iter) {
57 auto result = iter.next();
58 return toPyObject(result);
59 })
60 .def("__iter__", [](ScriptListIterator& iter) { return iter; });
61
62 py::class_<ScriptList, std::shared_ptr<ScriptList>>(m, "ScriptList")
63 .def(py::init([](py::list list) {
64 TypePtr type = nullptr;
65
66 if (!list.empty()) {
67 // If the source list is nonempty, try to infer its type.
68 auto inferred_type = tryToInferType(list);
69
70 if (!inferred_type.success()) {
71 std::stringstream ss;
72 ss << "Unable to infer type of list: " << inferred_type.reason();
73 throw JITException(ss.str());
74 }
75
76 type = inferred_type.type();
77 } else {
78 // If is empty, assume the type is List[Tensor] as is done in
79 // TorchScript code.
80 type = ListType::create(TensorType::getInferred());
81 }
82
83 auto data = toIValue(std::move(list), type);
84 return std::make_shared<ScriptList>(data);
85 }))
86 .def(
87 "__repr__",
88 [](const std::shared_ptr<ScriptList>& self) {
89 return toPyObject(self->repr());
90 })
91 .def(
92 "__bool__",
93 [](const std::shared_ptr<ScriptList>& self) {
94 return toPyObject(self->toBool());
95 })
96 .def(
97 "__len__",
98 [](const std::shared_ptr<ScriptList>& self) {
99 return toPyObject(static_cast<int64_t>(self->len()));
100 })
101 .def(
102 "__contains__",
103 [](const std::shared_ptr<ScriptList>& self, py::object elem) {
104 try {
105 return toPyObject(self->contains(
106 toIValue(std::move(elem), self->type()->getElementType())));
107 } catch (const py::cast_error& e) {
108 throw py::type_error();
109 }
110 })
111 .def(
112 "__getitem__",
113 [](const std::shared_ptr<ScriptList>& self,
114 ScriptList::diff_type idx) {
115 try {
116 auto value = self->getItem(idx);
117 return toPyObject(value);
118 } catch (const std::out_of_range& e) {
119 throw py::index_error();
120 }
121 },
122 py::return_value_policy::
123 reference_internal) // Return value is a reference to an object
124 // that resides in the ScriptList
125 .def(
126 "__getitem__",
127 [](const std::shared_ptr<ScriptList>& self, const py::slice& slice) {
128 size_t start = 0, stop = 0, step = 0, slicelength = 0;
129
130 if (!slice.compute(
131 self->len(), &start, &stop, &step, &slicelength)) {
132 throw py::error_already_set();
133 }
134
135 auto seq = std::make_shared<ScriptList>(self->type());
136
137 for (const auto i [[maybe_unused]] : c10::irange(slicelength)) {
138 seq->append(self->getItem(static_cast<ptrdiff_t>(start)));
139 start += step;
140 }
141
142 return seq;
143 })
144 .def(
145 "__setitem__",
146 [](const std::shared_ptr<ScriptList>& self,
147 ScriptList::diff_type idx,
148 py::object value) {
149 try {
150 self->setItem(
151 idx,
152 toIValue(std::move(value), self->type()->getElementType()));
153 } catch (const std::out_of_range& e) {
154 throw py::index_error();
155 } catch (const py::cast_error& e) {
156 throw py::type_error();
157 }
158 })
159 .def(
160 "__setitem__",
161 [](const std::shared_ptr<ScriptList>& self,
162 const py::slice& slice,
163 const py::list& value) {
164 size_t start = 0, stop = 0, step = 0, slicelength = 0;
165
166 if (!slice.compute(
167 self->len(), &start, &stop, &step, &slicelength)) {
168 throw py::error_already_set();
169 }
170
171 if (slicelength != value.size()) {
172 throw std::runtime_error(
173 "Left and right hand size of slice assignment have different sizes");
174 }
175
176 for (const auto i : c10::irange(slicelength)) {
177 try {
178 self->setItem(
179 static_cast<ptrdiff_t>(start),
180 toIValue(value[i], self->type()->getElementType()));
181 } catch (const py::cast_error& e) {
182 throw py::type_error();
183 }
184 start += step;
185 }
186 })
187 .def(
188 "__delitem__",
189 [](const std::shared_ptr<ScriptList>& self,
190 ScriptList::diff_type idx) {
191 try {
192 self->delItem(idx);
193 } catch (const std::out_of_range& e) {
194 throw py::index_error();
195 }
196 })
197 .def(
198 "__iter__",
199 [](const std::shared_ptr<ScriptList>& self) { return self->iter(); },
200 py::keep_alive<0, 1>()) // ScriptList needs to be alive at least as
201 // long as the iterator
202 .def(
203 "count",
204 [](const std::shared_ptr<ScriptList>& self, py::object value) {
205 try {
206 return self->count(
207 toIValue(std::move(value), self->type()->getElementType()));
208
209 } catch (const py::cast_error& e) {
210 throw py::type_error();
211 }
212 })
213 .def(
214 "remove",
215 [](const std::shared_ptr<ScriptList>& self, py::object value) {
216 try {
217 return self->remove(
218 toIValue(std::move(value), self->type()->getElementType()));
219 } catch (const py::cast_error& e) {
220 throw py::type_error();
221 }
222 })
223 .def(
224 "append",
225 [](const std::shared_ptr<ScriptList>& self, py::object value) {
226 try {
227 return self->append(
228 toIValue(std::move(value), self->type()->getElementType()));
229 } catch (const py::cast_error& e) {
230 throw py::type_error();
231 }
232 })
233 .def(
234 "clear",
235 [](const std::shared_ptr<ScriptList>& self) { self->clear(); })
236 .def(
237 "extend",
238 [](const std::shared_ptr<ScriptList>& self, py::list list) {
239 try {
240 self->extend(toIValue(std::move(list), self->type()));
241 } catch (const py::cast_error& e) {
242 throw py::type_error();
243 }
244 })
245 .def(
246 "extend",
247 [](const std::shared_ptr<ScriptList>& self,
248 const py::iterable& iter) {
249 ScriptList iter_list(self->type());
250
251 try {
252 for (py::handle obj : iter) {
253 iter_list.append(toIValue(
254 py::reinterpret_borrow<py::object>(obj),
255 self->type()->getElementType()));
256 }
257 } catch (const py::cast_error& e) {
258 throw py::type_error();
259 }
260
261 self->extend(toIValue(py::cast(iter_list), self->type()));
262 })
263 .def(
264 "pop",
265 [](const std::shared_ptr<ScriptList>& self) {
266 return toPyObject(self->pop());
267 })
268 .def(
269 "pop",
270 [](const std::shared_ptr<ScriptList>& self,
271 ScriptList::diff_type idx) { return toPyObject(self->pop(idx)); })
272 .def(
273 "insert",
274 [](const std::shared_ptr<ScriptList>& self,
275 ScriptList::diff_type idx,
276 py::object obj) {
277 try {
278 self->insert(
279 toIValue(std::move(obj), self->type()->getElementType()),
280 idx);
281 } catch (const py::cast_error& e) {
282 throw py::type_error();
283 }
284 })
285 .def(py::pickle(
286 [](const ScriptList& data) { // __getstate__
287 return scriptListToPyList(data);
288 },
289 [](py::list list) { // __setstate__
290 TypePtr type = nullptr;
291
292 if (!list.empty()) {
293 // If the source list is nonempty, try to infer its type.
294 auto inferred_type = tryToInferType(list);
295
296 if (!inferred_type.success()) {
297 std::stringstream ss;
298 ss << "Unable to infer type of list: "
299 << inferred_type.reason();
300 throw JITException(ss.str());
301 }
302
303 type = inferred_type.type();
304 } else {
305 // If is empty, assume the type is List[Tensor] as is done in
306 // TorchScript code.
307 type = ListType::create(TensorType::getInferred());
308 }
309
310 auto data = toIValue(std::move(list), type);
311 return std::make_shared<ScriptList>(data);
312 }));
313 }
314
315 } // namespace torch::jit
316