// xref: /aosp_15_r20/external/pytorch/torch/csrc/DynamicTypes.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/python_headers.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/Storage.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/utils/cuda_enabled.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/object_ptr.h>

#include <ATen/ATen.h>
#include <ATen/FunctionalStorageImpl.h>

#include <array>
#include <stdexcept>

namespace torch {
namespace {
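// Lookup tables mapping ATen enum values to their Python wrapper objects.
// Entries are filled in via registerDtypeObject()/registerLayoutObject(),
// typically while the Python bindings are being initialized.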
std::array<THPDtype*, static_cast<int>(at::ScalarType::NumOptions)>
    dtype_registry = {};

std::array<THPLayout*, static_cast<int>(at::Layout::NumOptions)>
    layout_registry = {};

} // namespace

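// Record the Python wrapper object for a given ScalarType/Layout so the
// getTHP* accessors below can return it later.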
void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType) {
  dtype_registry[static_cast<int>(scalarType)] = dtype;
}

void registerLayoutObject(THPLayout* thp_layout, at::Layout layout) {
  layout_registry[static_cast<int>(layout)] = thp_layout;
}

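// Return the registered Python wrapper for a ScalarType/Layout; throws
// std::invalid_argument if nothing has been registered for that value.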
THPDtype* getTHPDtype(at::ScalarType scalarType) {
  auto dtype = dtype_registry[static_cast<int>(scalarType)];
  if (!dtype) {
    throw std::invalid_argument("unsupported scalarType");
  }
  return dtype;
}

THPLayout* getTHPLayout(at::Layout layout) {
  auto thp_layout = layout_registry[static_cast<int>(layout)];
  if (!thp_layout) {
    throw std::invalid_argument("unsupported at::Layout");
  }
  return thp_layout;
}

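// A minimal usage sketch for getTHPDtype (not a call site in this file):
// binding code that needs the Python dtype object for an ATen scalar type
// can do, roughly,
//
//   THPDtype* thp_dtype = torch::getTHPDtype(at::kFloat);
//   PyObject* py_dtype = reinterpret_cast<PyObject*>(thp_dtype);
//
// THPDtype begins with PyObject_HEAD, so the cast yields a borrowed
// reference to the cached torch.float32 object.
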
PyObject* createPyObject(const at::Storage& storage) {
  // Note [Invalid Python Storages]
  // When a user creates a python tensor wrapper subclass, the subclass
  // is a tensor object that has a nullptr storage.
  // We still allow users to call `my_subclass.untyped_storage()` and get back
  // a valid storage object (this can be useful for detecting aliasing
  // information about storages from python). However, accessing the data_ptr
  // is not allowed, e.g. through x.untyped_storage().data_ptr().
  PyObject* obj = THPStorage_Wrap(storage);
  if (!obj)
    throw python_error();
  return obj;
}

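// Import torch.storage and fetch its TypedStorage class. The result is cached
// in a function-local static by getTypedStorageTypeObject() below, so the
// reference acquired here stays alive for the lifetime of the process.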
PyTypeObject* loadTypedStorageTypeObject() {
  PyObject* storage_module = PyImport_ImportModule("torch.storage");
  TORCH_INTERNAL_ASSERT(storage_module && PyModule_Check(storage_module));

  PyObject* typed_storage_obj =
      PyObject_GetAttrString(storage_module, "TypedStorage");
  TORCH_INTERNAL_ASSERT(typed_storage_obj && PyType_Check(typed_storage_obj));
  return reinterpret_cast<PyTypeObject*>(typed_storage_obj);
}

PyTypeObject* getTypedStorageTypeObject() {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
  static PyTypeObject* typed_storage_type_obj = loadTypedStorageTypeObject();
  return typed_storage_type_obj;
}

bool isStorage(PyObject* obj) {
  if (PyObject_TypeCheck(obj, getTypedStorageTypeObject())) {
    return true;
  }
  return THPStorage_Check(obj);
}

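// Unpack a Python storage object. Accepts either a torch.TypedStorage or an
// untyped storage (a THPStorage); returns the underlying at::Storage, the
// element ScalarType (kByte for untyped storages), and whether the input was
// a TypedStorage.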
std::tuple<at::Storage, at::ScalarType, bool> createStorageGetType(
    PyObject* obj) {
  at::ScalarType scalar_type = at::ScalarType::Undefined;
  bool is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
  PyObject* untyped_storage_obj = nullptr;

  if (is_typed_storage) {
    // NOTE: `PyObject_GetAttrString` increments the refcounts of `dtype` and
    // `_untyped_storage`, so we must decrement them. The refcounts will still
    // stay nonzero since the `TypedStorage` maintains a reference.
    PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
    TORCH_INTERNAL_ASSERT(dtype_obj);
    TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
    scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
    Py_DECREF(dtype_obj);

    untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
    TORCH_INTERNAL_ASSERT(untyped_storage_obj);
    Py_DECREF(untyped_storage_obj);

  } else {
    scalar_type = at::kByte;
    untyped_storage_obj = obj;
  }

  TORCH_CHECK(
      THPStorage_Check(untyped_storage_obj),
      "not a storage '",
      Py_TYPE(obj)->tp_name,
      "'");

  auto storage = THPStorage_Unpack(untyped_storage_obj);
  return std::make_tuple(storage, scalar_type, is_typed_storage);
}

at::Storage createStorage(PyObject* obj) {
  return std::get<0>(createStorageGetType(obj));
}
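
// A small sketch of how a caller might use these helpers (assumes the GIL is
// held and `obj` is a PyObject* received from Python; not a call site in this
// file):
//
//   if (torch::isStorage(obj)) {
//     at::Storage storage = torch::createStorage(obj);
//     // ... hand `storage` to ATen ...
//   }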

} // namespace torch