xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_list.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/ivalue.h>
2 #include <c10/util/irange.h>
3 #include <pybind11/detail/common.h>
4 #include <pybind11/pytypes.h>
5 #include <torch/csrc/jit/python/pybind_utils.h>
6 #include <torch/csrc/jit/python/python_list.h>
7 #include <torch/csrc/utils/pybind.h>
8 #include <stdexcept>
9 
10 namespace torch::jit {
11 
next()12 IValue ScriptListIterator::next() {
13   if (iter_ == end_) {
14     throw py::stop_iteration();
15   }
16 
17   IValue result = *iter_;
18 
19   // Advance the iterator for next time.
20   iter_++;
21 
22   return result;
23 }
24 
done() const25 bool ScriptListIterator::done() const {
26   return iter_ == end_;
27 }
28 
29 namespace {
scriptListToPyList(const ScriptList & src)30 py::list scriptListToPyList(const ScriptList& src) {
31   py::list out(src.len());
32   auto iter = src.iter();
33 
34   size_t i = 0;
35   while (!iter.done()) {
36     auto val = iter.next();
37     // TODO: Handle nested dictionaries.
38     if (val.isList()) {
39       out[i] = scriptListToPyList(val);
40     } else {
41       out[i] = toPyObject(val);
42     }
43     ++i;
44   }
45 
46   return out;
47 }
48 } // namespace
49 
initScriptListBindings(PyObject * module)50 void initScriptListBindings(PyObject* module) {
51   auto m = py::handle(module).cast<py::module>();
52 
53   py::class_<ScriptListIterator>(m, "ScriptListIterator")
54       .def(
55           "__next__",
56           [](ScriptListIterator& iter) {
57             auto result = iter.next();
58             return toPyObject(result);
59           })
60       .def("__iter__", [](ScriptListIterator& iter) { return iter; });
61 
62   py::class_<ScriptList, std::shared_ptr<ScriptList>>(m, "ScriptList")
63       .def(py::init([](py::list list) {
64         TypePtr type = nullptr;
65 
66         if (!list.empty()) {
67           // If the source list is nonempty, try to infer its type.
68           auto inferred_type = tryToInferType(list);
69 
70           if (!inferred_type.success()) {
71             std::stringstream ss;
72             ss << "Unable to infer type of list: " << inferred_type.reason();
73             throw JITException(ss.str());
74           }
75 
76           type = inferred_type.type();
77         } else {
78           // If is empty, assume the type is List[Tensor] as is done in
79           // TorchScript code.
80           type = ListType::create(TensorType::getInferred());
81         }
82 
83         auto data = toIValue(std::move(list), type);
84         return std::make_shared<ScriptList>(data);
85       }))
86       .def(
87           "__repr__",
88           [](const std::shared_ptr<ScriptList>& self) {
89             return toPyObject(self->repr());
90           })
91       .def(
92           "__bool__",
93           [](const std::shared_ptr<ScriptList>& self) {
94             return toPyObject(self->toBool());
95           })
96       .def(
97           "__len__",
98           [](const std::shared_ptr<ScriptList>& self) {
99             return toPyObject(static_cast<int64_t>(self->len()));
100           })
101       .def(
102           "__contains__",
103           [](const std::shared_ptr<ScriptList>& self, py::object elem) {
104             try {
105               return toPyObject(self->contains(
106                   toIValue(std::move(elem), self->type()->getElementType())));
107             } catch (const py::cast_error& e) {
108               throw py::type_error();
109             }
110           })
111       .def(
112           "__getitem__",
113           [](const std::shared_ptr<ScriptList>& self,
114              ScriptList::diff_type idx) {
115             try {
116               auto value = self->getItem(idx);
117               return toPyObject(value);
118             } catch (const std::out_of_range& e) {
119               throw py::index_error();
120             }
121           },
122           py::return_value_policy::
123               reference_internal) // Return value is a reference to an object
124                                   // that resides in the ScriptList
125       .def(
126           "__getitem__",
127           [](const std::shared_ptr<ScriptList>& self, const py::slice& slice) {
128             size_t start = 0, stop = 0, step = 0, slicelength = 0;
129 
130             if (!slice.compute(
131                     self->len(), &start, &stop, &step, &slicelength)) {
132               throw py::error_already_set();
133             }
134 
135             auto seq = std::make_shared<ScriptList>(self->type());
136 
137             for (const auto i [[maybe_unused]] : c10::irange(slicelength)) {
138               seq->append(self->getItem(static_cast<ptrdiff_t>(start)));
139               start += step;
140             }
141 
142             return seq;
143           })
144       .def(
145           "__setitem__",
146           [](const std::shared_ptr<ScriptList>& self,
147              ScriptList::diff_type idx,
148              py::object value) {
149             try {
150               self->setItem(
151                   idx,
152                   toIValue(std::move(value), self->type()->getElementType()));
153             } catch (const std::out_of_range& e) {
154               throw py::index_error();
155             } catch (const py::cast_error& e) {
156               throw py::type_error();
157             }
158           })
159       .def(
160           "__setitem__",
161           [](const std::shared_ptr<ScriptList>& self,
162              const py::slice& slice,
163              const py::list& value) {
164             size_t start = 0, stop = 0, step = 0, slicelength = 0;
165 
166             if (!slice.compute(
167                     self->len(), &start, &stop, &step, &slicelength)) {
168               throw py::error_already_set();
169             }
170 
171             if (slicelength != value.size()) {
172               throw std::runtime_error(
173                   "Left and right hand size of slice assignment have different sizes");
174             }
175 
176             for (const auto i : c10::irange(slicelength)) {
177               try {
178                 self->setItem(
179                     static_cast<ptrdiff_t>(start),
180                     toIValue(value[i], self->type()->getElementType()));
181               } catch (const py::cast_error& e) {
182                 throw py::type_error();
183               }
184               start += step;
185             }
186           })
187       .def(
188           "__delitem__",
189           [](const std::shared_ptr<ScriptList>& self,
190              ScriptList::diff_type idx) {
191             try {
192               self->delItem(idx);
193             } catch (const std::out_of_range& e) {
194               throw py::index_error();
195             }
196           })
197       .def(
198           "__iter__",
199           [](const std::shared_ptr<ScriptList>& self) { return self->iter(); },
200           py::keep_alive<0, 1>()) // ScriptList needs to be alive at least as
201                                   // long as the iterator
202       .def(
203           "count",
204           [](const std::shared_ptr<ScriptList>& self, py::object value) {
205             try {
206               return self->count(
207                   toIValue(std::move(value), self->type()->getElementType()));
208 
209             } catch (const py::cast_error& e) {
210               throw py::type_error();
211             }
212           })
213       .def(
214           "remove",
215           [](const std::shared_ptr<ScriptList>& self, py::object value) {
216             try {
217               return self->remove(
218                   toIValue(std::move(value), self->type()->getElementType()));
219             } catch (const py::cast_error& e) {
220               throw py::type_error();
221             }
222           })
223       .def(
224           "append",
225           [](const std::shared_ptr<ScriptList>& self, py::object value) {
226             try {
227               return self->append(
228                   toIValue(std::move(value), self->type()->getElementType()));
229             } catch (const py::cast_error& e) {
230               throw py::type_error();
231             }
232           })
233       .def(
234           "clear",
235           [](const std::shared_ptr<ScriptList>& self) { self->clear(); })
236       .def(
237           "extend",
238           [](const std::shared_ptr<ScriptList>& self, py::list list) {
239             try {
240               self->extend(toIValue(std::move(list), self->type()));
241             } catch (const py::cast_error& e) {
242               throw py::type_error();
243             }
244           })
245       .def(
246           "extend",
247           [](const std::shared_ptr<ScriptList>& self,
248              const py::iterable& iter) {
249             ScriptList iter_list(self->type());
250 
251             try {
252               for (py::handle obj : iter) {
253                 iter_list.append(toIValue(
254                     py::reinterpret_borrow<py::object>(obj),
255                     self->type()->getElementType()));
256               }
257             } catch (const py::cast_error& e) {
258               throw py::type_error();
259             }
260 
261             self->extend(toIValue(py::cast(iter_list), self->type()));
262           })
263       .def(
264           "pop",
265           [](const std::shared_ptr<ScriptList>& self) {
266             return toPyObject(self->pop());
267           })
268       .def(
269           "pop",
270           [](const std::shared_ptr<ScriptList>& self,
271              ScriptList::diff_type idx) { return toPyObject(self->pop(idx)); })
272       .def(
273           "insert",
274           [](const std::shared_ptr<ScriptList>& self,
275              ScriptList::diff_type idx,
276              py::object obj) {
277             try {
278               self->insert(
279                   toIValue(std::move(obj), self->type()->getElementType()),
280                   idx);
281             } catch (const py::cast_error& e) {
282               throw py::type_error();
283             }
284           })
285       .def(py::pickle(
286           [](const ScriptList& data) { // __getstate__
287             return scriptListToPyList(data);
288           },
289           [](py::list list) { // __setstate__
290             TypePtr type = nullptr;
291 
292             if (!list.empty()) {
293               // If the source list is nonempty, try to infer its type.
294               auto inferred_type = tryToInferType(list);
295 
296               if (!inferred_type.success()) {
297                 std::stringstream ss;
298                 ss << "Unable to infer type of list: "
299                    << inferred_type.reason();
300                 throw JITException(ss.str());
301               }
302 
303               type = inferred_type.type();
304             } else {
305               // If is empty, assume the type is List[Tensor] as is done in
306               // TorchScript code.
307               type = ListType::create(TensorType::getInferred());
308             }
309 
310             auto data = toIValue(std::move(list), type);
311             return std::make_shared<ScriptList>(data);
312           }));
313 }
314 
315 } // namespace torch::jit
316