1 #include <torch/csrc/utils/tensor_list.h>
2
3 #include <c10/util/irange.h>
4 #include <pybind11/pybind11.h>
5 #include <torch/csrc/Exceptions.h>
6 #include <torch/csrc/autograd/python_variable.h>
7 #include <torch/csrc/utils/pybind.h>
8 #include <torch/csrc/utils/python_scalars.h>
9
10 using namespace at;
11
12 namespace torch::utils {
13
recursive_to_list(const char * data,IntArrayRef sizes,IntArrayRef strides,int64_t dim,ScalarType scalarType,size_t elementSize)14 static PyObject* recursive_to_list(
15 const char* data,
16 IntArrayRef sizes,
17 IntArrayRef strides,
18 int64_t dim,
19 ScalarType scalarType,
20 size_t elementSize) {
21 int64_t ndim = static_cast<int64_t>(sizes.size());
22 if (dim == ndim) {
23 return torch::utils::load_scalar(data, scalarType);
24 }
25 auto n = sizes[dim];
26 auto list = THPObjectPtr(PyList_New(n));
27 if (!list)
28 throw python_error();
29 for (const auto i : c10::irange(n)) {
30 PyObject* obj = recursive_to_list(
31 data, sizes, strides, dim + 1, scalarType, elementSize);
32 if (!obj)
33 throw python_error();
34 PyList_SET_ITEM(list.get(), i, obj);
35 auto advance_data_ptr = strides[dim] * elementSize;
36 TORCH_INTERNAL_ASSERT(data || (advance_data_ptr == 0));
37 data += advance_data_ptr;
38 }
39 return list.release();
40 }
41
tensor_to_list(const Tensor & tensor)42 PyObject* tensor_to_list(const Tensor& tensor) {
43 {
44 py::object pytensor =
45 py::reinterpret_steal<py::object>(THPVariable_Wrap(tensor));
46 TORCH_CHECK(
47 !tensor.unsafeGetTensorImpl()->is_python_dispatch(),
48 ".tolist() is not supported for tensor subclasses, got ",
49 Py_TYPE(pytensor.ptr())->tp_name);
50 }
51 Tensor data = tensor.resolve_conj().resolve_neg();
52 if (!data.device().is_cpu()) {
53 pybind11::gil_scoped_release no_gil;
54 data = data.toBackend(Backend::CPU);
55 }
56 TORCH_CHECK(
57 tensor.numel() == 0 || data.const_data_ptr(),
58 "tolist() shouldn't be called on a tensor with unallocated storage");
59 return recursive_to_list(
60 (const char*)data.const_data_ptr(),
61 data.sizes(),
62 data.strides(),
63 0,
64 data.scalar_type(),
65 tensor.numel() == 0 ? 0 : data.dtype().itemsize());
66 }
67
68 } // namespace torch::utils
69