xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/tensor_list.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/tensor_list.h>
2 
3 #include <c10/util/irange.h>
4 #include <pybind11/pybind11.h>
5 #include <torch/csrc/Exceptions.h>
6 #include <torch/csrc/autograd/python_variable.h>
7 #include <torch/csrc/utils/pybind.h>
8 #include <torch/csrc/utils/python_scalars.h>
9 
10 using namespace at;
11 
12 namespace torch::utils {
13 
recursive_to_list(const char * data,IntArrayRef sizes,IntArrayRef strides,int64_t dim,ScalarType scalarType,size_t elementSize)14 static PyObject* recursive_to_list(
15     const char* data,
16     IntArrayRef sizes,
17     IntArrayRef strides,
18     int64_t dim,
19     ScalarType scalarType,
20     size_t elementSize) {
21   int64_t ndim = static_cast<int64_t>(sizes.size());
22   if (dim == ndim) {
23     return torch::utils::load_scalar(data, scalarType);
24   }
25   auto n = sizes[dim];
26   auto list = THPObjectPtr(PyList_New(n));
27   if (!list)
28     throw python_error();
29   for (const auto i : c10::irange(n)) {
30     PyObject* obj = recursive_to_list(
31         data, sizes, strides, dim + 1, scalarType, elementSize);
32     if (!obj)
33       throw python_error();
34     PyList_SET_ITEM(list.get(), i, obj);
35     auto advance_data_ptr = strides[dim] * elementSize;
36     TORCH_INTERNAL_ASSERT(data || (advance_data_ptr == 0));
37     data += advance_data_ptr;
38   }
39   return list.release();
40 }
41 
tensor_to_list(const Tensor & tensor)42 PyObject* tensor_to_list(const Tensor& tensor) {
43   {
44     py::object pytensor =
45         py::reinterpret_steal<py::object>(THPVariable_Wrap(tensor));
46     TORCH_CHECK(
47         !tensor.unsafeGetTensorImpl()->is_python_dispatch(),
48         ".tolist() is not supported for tensor subclasses, got ",
49         Py_TYPE(pytensor.ptr())->tp_name);
50   }
51   Tensor data = tensor.resolve_conj().resolve_neg();
52   if (!data.device().is_cpu()) {
53     pybind11::gil_scoped_release no_gil;
54     data = data.toBackend(Backend::CPU);
55   }
56   TORCH_CHECK(
57       tensor.numel() == 0 || data.const_data_ptr(),
58       "tolist() shouldn't be called on a tensor with unallocated storage");
59   return recursive_to_list(
60       (const char*)data.const_data_ptr(),
61       data.sizes(),
62       data.strides(),
63       0,
64       data.scalar_type(),
65       tensor.numel() == 0 ? 0 : data.dtype().itemsize());
66 }
67 
68 } // namespace torch::utils
69