// xref: /aosp_15_r20/external/pytorch/torch/csrc/Size.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/irange.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Size.h>
#include <torch/csrc/utils/pybind.h>

#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_tuples.h>
#include <limits>
#include <string>

#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/frontend/tracer.h>

16 struct THPSize {
17   PyTupleObject tuple;
18 };
19 
THPSize_New(const torch::autograd::Variable & var)20 PyObject* THPSize_New(const torch::autograd::Variable& var) {
21   if (!torch::jit::tracer::isTracing()) {
22     auto sizes = var.sizes();
23     return THPSize_NewFromSizes(var.dim(), sizes.data());
24   }
25   auto self = THPObjectPtr(THPSizeType.tp_alloc(&THPSizeType, var.dim()));
26   if (!self)
27     throw python_error();
28 
29   for (const auto i : c10::irange(var.dim())) {
30     PyObject* py_size_tensor =
31         THPVariable_Wrap(torch::jit::tracer::getSizeOf(var, i));
32     if (!py_size_tensor)
33       throw python_error();
34     PyTuple_SET_ITEM(self.get(), i, py_size_tensor);
35   }
36 
37   return self.release();
38 }
39 
THPSize_NewFromSizes(int64_t dim,const int64_t * sizes)40 PyObject* THPSize_NewFromSizes(int64_t dim, const int64_t* sizes) {
41   auto self = THPObjectPtr(THPSizeType.tp_alloc(&THPSizeType, dim));
42   if (!self)
43     throw python_error();
44   THPUtils_packInt64Array(self, dim, sizes);
45   return self.release();
46 }
47 
THPSize_NewFromSymSizes(const at::Tensor & self_)48 PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
49   auto sym_sizes = self_.sym_sizes();
50 
51   auto ret = THPObjectPtr(THPSizeType.tp_alloc(
52       &THPSizeType, static_cast<Py_ssize_t>(sym_sizes.size())));
53   if (!ret)
54     throw python_error();
55 
56   for (auto i : c10::irange(sym_sizes.size())) {
57     auto si = sym_sizes[i];
58     if (si.is_symbolic()) {
59       // First check for actual symbolic values.
60       // Reason: so that we don't replace it by its integer replacement
61       // implicitly.
62       TORCH_CHECK(
63           !torch::jit::tracer::isTracing(),
64           "JIT Tracing of SymInts isn't supported");
65       auto py_symint = py::cast(si).release().ptr();
66       if (!py_symint)
67         throw python_error();
68       PyTuple_SET_ITEM(ret.get(), i, py_symint);
69     } else {
70       // Otherwise, we know that it is an actual integer value.
71       auto m = si.maybe_as_int();
72       if (torch::jit::tracer::isTracing()) {
73         PyObject* py_size_tensor = THPVariable_Wrap(
74             torch::jit::tracer::getSizeOf(self_, static_cast<int64_t>(i)));
75         if (!py_size_tensor)
76           throw python_error();
77         PyTuple_SET_ITEM(ret.get(), i, py_size_tensor);
78       } else {
79         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
80         PyTuple_SET_ITEM(ret.get(), i, THPUtils_packInt64(*m));
81       }
82     }
83   }
84   return ret.release();
85 }
86 
isTracedZeroDimVar(PyObject * item)87 static bool isTracedZeroDimVar(PyObject* item) {
88   if (!THPVariable_Check(item))
89     return false;
90   auto& var = THPVariable_Unpack(item);
91   return var.dim() == 0 && torch::jit::tracer::getValueTrace(var);
92 }
93 
THPSize_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)94 static PyObject* THPSize_pynew(
95     PyTypeObject* type,
96     PyObject* args,
97     PyObject* kwargs) {
98   HANDLE_TH_ERRORS
99   THPObjectPtr self(PyTuple_Type.tp_new(type, args, kwargs));
100   if (self) {
101     for (Py_ssize_t i = 0; i < PyTuple_Size(self); ++i) {
102       PyObject* item = PyTuple_GET_ITEM(self.get(), i);
103       if (THPUtils_checkLong(item)) {
104         continue;
105       }
106       if (torch::is_symint(item)) {
107         continue;
108       }
109       if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) {
110         continue;
111       }
112       // item.__index__() works with 0-dim tensors and tensors with one element
113       THPObjectPtr number(PyNumber_Index(item));
114       if (number && THPUtils_checkLong(number.get())) {
115         Py_INCREF(number.get());
116         auto status = PyTuple_SetItem(self, i, number.get());
117         if (status != 0) {
118           throw python_error();
119         }
120         continue;
121       }
122       return PyErr_Format(
123           PyExc_TypeError,
124           "torch.Size() takes an iterable of 'int' (item %zd is '%s')",
125           i,
126           Py_TYPE(item)->tp_name);
127     }
128   }
129   return self.release();
130   END_HANDLE_TH_ERRORS
131 }
132 
THPSize_repr(THPSize * self)133 static PyObject* THPSize_repr(THPSize* self) {
134   HANDLE_TH_ERRORS
135   std::string repr("torch.Size([");
136   for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) {
137     if (i != 0) {
138       repr += ", ";
139     }
140     auto item = PyTuple_GET_ITEM(self, i);
141     auto ih = py::handle(item);
142 
143     repr += torch::is_symint(ih)
144         ? std::string(py::str(ih))
145         : std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)));
146   }
147   repr += "])";
148   return THPUtils_packString(repr);
149   END_HANDLE_TH_ERRORS
150 }
151 
152 extern PyTypeObject THPSizeType;
153 
154 template <typename FnType, FnType fn, typename... Args>
wrap_tuple_fn(Args...args)155 static PyObject* wrap_tuple_fn(Args... args) {
156   THPObjectPtr result((*fn)(std::forward<Args>(args)...));
157   if (!result)
158     return nullptr;
159   if (PyTuple_Check(result.get())) {
160     return PyObject_CallFunctionObjArgs(
161         (PyObject*)&THPSizeType, result.get(), nullptr);
162   }
163   return result.release();
164 }
165 
166 // We use an anonymous namespace instead of static to work around
167 // (what @peterjc123 think is) a bug in Visual Studio
168 namespace {
169 auto sq_concat = PyTuple_Type.tp_as_sequence->sq_concat;
170 auto sq_repeat = PyTuple_Type.tp_as_sequence->sq_repeat;
171 binaryfunc mp_subscript = PyTuple_Type.tp_as_mapping->mp_subscript;
172 } // namespace
173 
174 static PySequenceMethods THPSize_as_sequence = {
175     nullptr, /* sq_length */
176     wrap_tuple_fn<decltype(&sq_concat), &sq_concat>,
177     wrap_tuple_fn<decltype(&sq_repeat), &sq_repeat>,
178     nullptr, /* sq_item */
179     nullptr, /* sq_slice */
180     nullptr, /* sq_ass_item */
181     nullptr, /* sq_ass_slice */
182     nullptr /* sq_contains */
183 };
184 
185 static PyMappingMethods THPSize_as_mapping = {
186     nullptr, /* mp_length */
187     wrap_tuple_fn<decltype(&mp_subscript), &mp_subscript>,
188     nullptr};
189 
THPSize_numel(PyObject * _self,PyObject * noargs)190 static PyObject* THPSize_numel(PyObject* _self, PyObject* noargs) {
191   HANDLE_TH_ERRORS
192   auto self = (THPSize*)_self;
193   int64_t numel = 1;
194   for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) {
195     numel *= THPUtils_unpackLong(PyTuple_GET_ITEM(self, i));
196   }
197   return THPUtils_packInt64(numel);
198   END_HANDLE_TH_ERRORS
199 }
200 
THPSize_reduce(PyObject * _self,PyObject * noargs)201 static PyObject* THPSize_reduce(PyObject* _self, PyObject* noargs) {
202   HANDLE_TH_ERRORS
203   auto self = (THPSize*)_self;
204   auto ret = THPObjectPtr{PyTuple_New(2)};
205   if (!ret)
206     throw python_error();
207 
208   auto obj = (PyObject*)(&THPSizeType);
209   Py_INCREF(&THPSizeType);
210   PyTuple_SET_ITEM(ret.get(), 0, obj);
211 
212   THPObjectPtr t(PyTuple_New(PyTuple_Size((PyObject*)self)));
213   if (!t)
214     throw python_error();
215   for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) {
216     auto d = PyTuple_GET_ITEM(self, i);
217     Py_INCREF(d);
218     PyTuple_SET_ITEM(t.get(), i, d);
219   }
220 
221   THPObjectPtr dims(Py_BuildValue("(O)", t.get()));
222   if (!dims)
223     throw python_error();
224   PyTuple_SET_ITEM(ret.get(), 1, dims.release());
225 
226   return ret.release();
227   END_HANDLE_TH_ERRORS
228 }
229 
230 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
231 static PyMethodDef THPSize_methods[] = {
232     {"numel", THPSize_numel, METH_NOARGS, nullptr},
233     {"__reduce__", THPSize_reduce, METH_NOARGS, nullptr},
234     {nullptr}};
235 
236 PyTypeObject THPSizeType = {
237     PyVarObject_HEAD_INIT(nullptr, 0) "torch.Size", /* tp_name */
238     sizeof(THPSize), /* tp_basicsize */
239     0, /* tp_itemsize */
240     nullptr, /* tp_dealloc */
241     0, /* tp_vectorcall_offset */
242     nullptr, /* tp_getattr */
243     nullptr, /* tp_setattr */
244     nullptr, /* tp_reserved */
245     (reprfunc)THPSize_repr, /* tp_repr */
246     nullptr, /* tp_as_number */
247     &THPSize_as_sequence, /* tp_as_sequence */
248     &THPSize_as_mapping, /* tp_as_mapping */
249     nullptr, /* tp_hash  */
250     nullptr, /* tp_call */
251     nullptr, /* tp_str */
252     nullptr, /* tp_getattro */
253     nullptr, /* tp_setattro */
254     nullptr, /* tp_as_buffer */
255     Py_TPFLAGS_DEFAULT, /* tp_flags */
256     nullptr, /* tp_doc */
257     nullptr, /* tp_traverse */
258     nullptr, /* tp_clear */
259     nullptr, /* tp_richcompare */
260     0, /* tp_weaklistoffset */
261     nullptr, /* tp_iter */
262     nullptr, /* tp_iternext */
263     THPSize_methods, /* tp_methods */
264     nullptr, /* tp_members */
265     nullptr, /* tp_getset */
266     &PyTuple_Type, /* tp_base */
267     nullptr, /* tp_dict */
268     nullptr, /* tp_descr_get */
269     nullptr, /* tp_descr_set */
270     0, /* tp_dictoffset */
271     nullptr, /* tp_init */
272     nullptr, /* tp_alloc */
273     THPSize_pynew, /* tp_new */
274 };
275 
THPSize_init(PyObject * module)276 void THPSize_init(PyObject* module) {
277   if (PyType_Ready(&THPSizeType) < 0) {
278     throw python_error();
279   }
280   Py_INCREF(&THPSizeType);
281   if (PyModule_AddObject(module, "Size", (PyObject*)&THPSizeType) < 0) {
282     throw python_error();
283   }
284 }
285