#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/python_torch_functions.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/out_types.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/structseq.h>
#include <torch/csrc/utils/tensor_layouts.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>

#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/native/Resize.h>

#include <Python.h>
#include <fmt/format.h>
#include <pybind11/pybind11.h>
#include <utility>
#include <vector>

using at::DeviceGuard;
using at::DimnameList;
using at::IntArrayRef;
using at::OptionalDeviceGuard;
using at::Scalar;
using at::Tensor;
using at::TensorList;
using at::TensorOptions;

using torch::utils::check_out_type_matches;
using namespace torch::autograd::utils;

namespace torch::autograd {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* THPVariableFunctionsModule = nullptr;

inline Tensor dispatch_range(
    const Scalar& start,
    const Scalar& end,
    const Scalar& step,
    Tensor result) {
  pybind11::gil_scoped_release no_gil;
  OptionalDeviceGuard device_guard(device_of(result));
  return at::range_out(result, start, end, step);
}

inline Tensor dispatch_range(
    const Scalar& start,
    const Scalar& end,
    const Scalar& step,
    const TensorOptions& options) {
  torch::utils::maybe_initialize_device(options);
  pybind11::gil_scoped_release no_gil;
  DeviceGuard device_guard(options.device());
  return torch::range(start, end, step, options);
}

static PyObject* THPVariable_range(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "range(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
  });

  ParsedArgs<8> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    auto ret = PyErr_WarnEx(
        PyExc_UserWarning,
        "torch.range is deprecated and will be removed in a future release "
        "because its behavior is inconsistent with Python's range builtin. "
        "Instead, use torch.arange, which produces values in [start, end).",
        1);
    if (ret != 0)
      throw python_error();
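    // Illustrative difference (assuming default dtypes):
    //   torch.range(0, 5)  -> tensor([0., 1., 2., 3., 4., 5.])  (end inclusive)
    //   torch.arange(0, 5) -> tensor([0, 1, 2, 3, 4])           (end exclusive)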
    if (r.isNone(3)) {
      const auto options = TensorOptions()
                               .dtype(r.scalartype(4))
                               .device(r.device(6))
                               .layout(r.layout(5))
                               .requires_grad(r.toBool(7));
      return wrap(
          dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), options));
    } else {
      check_out_type_matches(
          r.tensor(3),
          r.scalartype(4),
          r.isNone(4),
          r.layout(5),
          r.device(6),
          r.isNone(6));
      return wrap(
          dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3))
              .set_requires_grad(r.toBool(7)));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// implemented on python object to allow torch.as_tensor to be constructed with
// arbitrarily nested python objects - list, tuple, np array, scalar, etc.
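// Illustrative Python-side calls (a sketch, not exhaustive):
//   torch.as_tensor([[1, 2], [3, 4]])
//   torch.as_tensor(np.arange(3), dtype=torch.float32)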
static PyObject* THPVariable_as_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)",
  });

  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  jit::tracer::warn("torch.as_tensor", jit::tracer::WARN_CONSTRUCTOR);
  return THPVariable_Wrap(torch::utils::as_tensor(
      torch::tensors::get_default_dispatch_key(),
      torch::tensors::get_default_scalar_type(),
      r));
  END_HANDLE_TH_ERRORS
}

// implemented on python object here because PyObject is currently not natively
// declarable; see ATen/native/README.md for more context
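// Illustrative Python-side call: torch.from_numpy(np.zeros((2, 3))) returns a
// tensor that shares memory with the given ndarray.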
static PyObject* THPVariable_from_numpy(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  jit::tracer::warn("torch.from_numpy", jit::tracer::WARN_CONSTRUCTOR);
  return THPVariable_Wrap(torch::utils::tensor_from_numpy(arg));
  END_HANDLE_TH_ERRORS
}

static Tensor dispatch_nonzero(const Tensor& self) {
  pybind11::gil_scoped_release no_gil;
  OptionalDeviceGuard device_guard(device_of(self));
  return self.nonzero();
}

static Tensor dispatch_nonzero(const Tensor& self, Tensor out) {
  pybind11::gil_scoped_release no_gil;
  OptionalDeviceGuard device_guard(device_of(self));
  return at::nonzero_out(out, self);
}

static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor& self) {
  pybind11::gil_scoped_release no_gil;
  OptionalDeviceGuard device_guard(device_of(self));
  return self.nonzero_numpy();
}

static PyObject* THPVariable_nonzero(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs);

#define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NARGS, SIGNATURES)       \
  static PyObject* THPVariable_##NAME(                                    \
      PyObject* self, PyObject* args, PyObject* kwargs) {                 \
    HANDLE_TH_ERRORS                                                      \
    static PythonArgParser parser SIGNATURES;                             \
    ParsedArgs<NARGS> parsed_args;                                        \
    auto r = parser.parse(args, kwargs, parsed_args);                     \
    if (r.has_torch_function()) {                                         \
      return handle_torch_function(                                       \
          r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); \
    }                                                                     \
    jit::tracer::warn("torch." #NAME, jit::tracer::WARN_CONSTRUCTOR);     \
    return THPVariable_Wrap(torch::utils::NAME##_ctor(                    \
        torch::tensors::get_default_dispatch_key(),                       \
        torch::tensors::get_default_scalar_type(),                        \
        r));                                                              \
    END_HANDLE_TH_ERRORS                                                  \
  }
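
// Each expansion of the macro above defines a THPVariable_<NAME> binding that
// parses the listed signatures and forwards to torch::utils::<NAME>_ctor,
// e.g. THPVariable_sparse_csr_tensor -> torch::utils::sparse_csr_tensor_ctor.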

THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_compressed_tensor,
    10,
    ({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_csr_tensor,
    10,
    ({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_csc_tensor,
    10,
    ({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_bsr_tensor,
    10,
    ({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_bsc_tensor,
    10,
    ({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
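
// Illustrative CSR construction from Python (a sketch, not code in this file):
//   crow = torch.tensor([0, 2, 4]); col = torch.tensor([0, 1, 0, 1])
//   vals = torch.tensor([1., 2., 3., 4.])
//   torch.sparse_csr_tensor(crow, col, vals, (2, 2))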

static PyObject* THPVariable_sparse_coo_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None, bool is_coalesced=None)",
      "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)",
  });

  ParsedArgs<9> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  jit::tracer::warn("torch.sparse_coo_tensor", jit::tracer::WARN_CONSTRUCTOR);
  return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(
      torch::tensors::get_default_dispatch_key(),
      torch::tensors::get_default_scalar_type(),
      r));
  END_HANDLE_TH_ERRORS
}
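
// Illustrative COO construction from Python (a sketch): indices have shape
// (ndim, nnz), e.g.
//   i = torch.tensor([[0, 1], [1, 0]]); v = torch.tensor([3., 4.])
//   torch.sparse_coo_tensor(i, v, (2, 2))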

// implemented on python object to allow torch.tensor to be constructed with
// arbitrarily nested python objects - list, tuple, np array, scalar, etc.
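// Illustrative Python-side call:
//   torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float64, requires_grad=True)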
static PyObject* THPVariable_tensor(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)",
  });

  constexpr int ctor_num_args = 6;
  ParsedArgs<ctor_num_args> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  jit::tracer::warn("torch.tensor", jit::tracer::WARN_CONSTRUCTOR);
  return THPVariable_Wrap(torch::utils::tensor_ctor(
      torch::tensors::get_default_dispatch_key(),
      torch::tensors::get_default_scalar_type(),
      r));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_get_device(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "get_device(Tensor input)",
      },
      /*traceable=*/false);

  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  if (r.idx == 0) {
    return wrap(r.tensor(0).get_device());
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_frombuffer(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "frombuffer(PyObject* buffer, *, ScalarType dtype, int64_t count=-1, int64_t offset=0, bool requires_grad=False)",
      },
      /*traceable=*/false);

  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.idx == 0) {
    auto buffer = r.pyobject(0);
    auto dtype = r.scalartype(1);
    auto count = r.toInt64(2);
    auto offset = r.toInt64(3);
    auto requires_grad = r.toBool(4);

    TORCH_CHECK_VALUE(
        PyObject_CheckBuffer(buffer) != 0,
        "object does not implement Python buffer protocol.");
    return wrap(torch::utils::tensor_frombuffer(
        buffer, dtype, count, offset, requires_grad));
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
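
// Illustrative Python-side call (a sketch): any object exposing the buffer
// protocol works, e.g.
//   torch.frombuffer(bytearray(b"\x01\x02\x03"), dtype=torch.uint8)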

static PyObject* THPVariable_asarray(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "asarray(PyObject* obj, *, ScalarType? dtype=None, Device? device=None, bool? copy=None, bool requires_grad=False)",
      },
      /*traceable=*/false);

  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  if (r.idx == 0) {
    auto obj = r.pyobject(0);
    auto dtype = r.scalartypeOptional(1);
    auto device = r.deviceOptional(2);
    auto copy = r.toBoolOptional(3);
    auto requires_grad = r.toBool(4);
    return wrap(torch::utils::asarray(obj, dtype, device, copy, requires_grad));
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_numel(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs);

// XXX: ops that are bound here are not exposed to the C++ API nor the JIT.
// Any new ops added here should be accompanied by a comment explaining why
// they are not registered through native_functions.yaml, and should be tagged
// cpp / JIT.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef torch_functions_manual[] = {
    {"asarray",
     castPyCFunctionWithKeywords(THPVariable_asarray),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"as_tensor",
     castPyCFunctionWithKeywords(THPVariable_as_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, nullptr},
    {"frombuffer",
     castPyCFunctionWithKeywords(THPVariable_frombuffer),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"nonzero",
     castPyCFunctionWithKeywords(THPVariable_nonzero),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"range",
     castPyCFunctionWithKeywords(THPVariable_range),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_coo_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_compressed_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_compressed_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_csr_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_csc_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_csc_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_bsr_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_bsr_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_bsc_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_bsc_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"tensor",
     castPyCFunctionWithKeywords(THPVariable_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"get_device",
     castPyCFunctionWithKeywords(THPVariable_get_device),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"numel",
     castPyCFunctionWithKeywords(THPVariable_numel),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
};

static PyObject* THPVariable_nonzero(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)",
  });
  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(
        r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  const auto as_tuple = r.toBool(1);
  const auto has_out = !r.isNone(2);

  if (as_tuple) {
    TORCH_CHECK(
        !has_out,
        "nonzero does not support the out kwarg when as_tuple is True");
    return wrap(dispatch_nonzero_numpy(r.tensor(0)));
  }

  if (has_out) {
    return wrap(dispatch_nonzero(r.tensor(0), r.tensor(2)));
  }

  return wrap(dispatch_nonzero(r.tensor(0)));

  END_HANDLE_TH_ERRORS
}
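
// Illustrative semantics: torch.nonzero(t) returns a 2-D tensor of indices,
// while torch.nonzero(t, as_tuple=True) returns one 1-D index tensor per
// dimension (numpy.nonzero-style), via dispatch_nonzero_numpy above.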

static PyObject* THPVariable_numel(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {
          "numel(Tensor input)",
      },
      /*traceable=*/false);

  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(
        r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  if (r.idx == 0) {
    return py::cast(r.tensor(0).sym_numel()).release().ptr();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Sharded function definitions
void gatherTorchFunctions_0(std::vector<PyMethodDef>& torch_functions);
void gatherTorchFunctions_1(std::vector<PyMethodDef>& torch_functions);
void gatherTorchFunctions_2(std::vector<PyMethodDef>& torch_functions);

void gatherTorchFunctions(std::vector<PyMethodDef>& torch_functions) {
  constexpr size_t num_functions =
      sizeof(torch_functions_manual) / sizeof(torch_functions_manual[0]);
  torch_functions.assign(
      torch_functions_manual, torch_functions_manual + num_functions);
  // NOTE: Must be synced with num_shards in
  // tools/autograd/gen_python_functions.py
  gatherTorchFunctions_0(torch_functions);
  gatherTorchFunctions_1(torch_functions);
  gatherTorchFunctions_2(torch_functions);

  static std::array<std::pair<const char*, const char*>, 4> aliases{
      {// Canonical function, alias name
       {"sspaddmm", "saddmm"},
       {"mm", "spmm"},
       {"mm", "dsmm"},
       {"hspmm", "hsmm"}}};

  for (const auto& alias : aliases) {
    auto it = std::find_if(
        torch_functions.begin(),
        torch_functions.end(),
        [&](const PyMethodDef& def) {
          return strcmp(def.ml_name, alias.first) == 0;
        });
    TORCH_INTERNAL_ASSERT(
        it != torch_functions.end(),
        "Failed to create function alias from ",
        alias.first,
        " to ",
        alias.second);
    PyMethodDef alias_def = *it;
    alias_def.ml_name = alias.second;

    torch_functions.push_back(alias_def);
  }
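
  // The loop above duplicates the canonical PyMethodDef under the alias name,
  // so e.g. torch.dsmm and torch.spmm resolve to the same C function as
  // torch.mm.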

  torch_functions.push_back({nullptr});
  torch_functions.shrink_to_fit();
}

static PyTypeObject THPVariableFunctions = {
    PyVarObject_HEAD_INIT(
        nullptr,
        0) "torch._C._VariableFunctionsClass", /* tp_name */
    0, /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr /* tp_new */
};

void initTorchFunctions(PyObject* module) {
  static std::vector<PyMethodDef> torch_functions;
  gatherTorchFunctions(torch_functions);
  THPVariableFunctions.tp_methods = torch_functions.data();

  if (PyType_Ready(&THPVariableFunctions) < 0) {
    throw python_error();
  }
  Py_INCREF(&THPVariableFunctions);

  // PyModule_AddObject below steals a reference, so INCREF first to keep the
  // type alive independently of the module entry.
  Py_INCREF(&THPVariableFunctions);
  if (PyModule_AddObject(
          module,
          "_VariableFunctionsClass",
          reinterpret_cast<PyObject*>(&THPVariableFunctions)) < 0) {
    throw python_error();
  }
  // PyType_GenericNew returns a new reference
  THPVariableFunctionsModule =
      PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
  // PyModule_AddObject steals a reference
  if (PyModule_AddObject(
          module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
    throw python_error();
  }

  // pybind registrations to torch module
  // TODO: move these from torch.* to torch._C.*
  auto py_module = py::module::import("torch");
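  // Since py_module is the top-level torch module, the defs below are
  // reachable from Python as e.g. torch._functionalize_sync(t).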

  py_module.def(
      "_functionalize_are_all_mutations_under_no_grad_or_inference_mode",
      [](const at::Tensor& t) {
        TORCH_INTERNAL_ASSERT(
            at::functionalization::impl::isFunctionalTensor(t));
        return at::functionalization::impl::
            are_all_mutations_under_no_grad_or_inference_mode(t);
      });
  py_module.def(
      "_functionalize_was_inductor_storage_resized", [](const at::Tensor& t) {
        TORCH_INTERNAL_ASSERT(
            at::functionalization::impl::isFunctionalTensor(t));
        auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
        return impl->was_inductor_storage_resized();
      });
  py_module.def(
      "_functionalize_are_all_mutations_hidden_from_autograd",
      [](const at::Tensor& t) {
        TORCH_INTERNAL_ASSERT(
            at::functionalization::impl::isFunctionalTensor(t));
        return at::functionalization::impl::
            are_all_mutations_hidden_from_autograd(t);
      });
  py_module.def(
      "_functionalize_mark_mutation_hidden_from_autograd",
      [](const at::Tensor& t) {
        TORCH_INTERNAL_ASSERT(
            at::functionalization::impl::isFunctionalTensor(t));
        at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
      });
  py_module.def(
      "_functionalize_apply_view_metas",
      [](const at::Tensor& tensor, const at::Tensor& base) {
        TORCH_INTERNAL_ASSERT(
            at::functionalization::impl::isFunctionalTensor(tensor));
        auto impl =
            at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
        return impl->apply_view_metas(base);
      });
  py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
    return impl->is_symbolic();
  });
  py_module.def("_functionalize_sync", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    at::functionalization::impl::sync(t);
  });
  py_module.def("_functionalize_commit_update", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    at::functionalization::impl::commit_update(t);
  });
  py_module.def(
      "_functionalize_replace", [](const at::Tensor& t, const at::Tensor& o) {
        TORCH_INTERNAL_ASSERT(
            at::functionalization::impl::isFunctionalTensor(t));
        TORCH_INTERNAL_ASSERT(
            !at::functionalization::impl::isFunctionalTensor(o));
        at::functionalization::impl::replace_(t, o);
      });
  py_module.def("_is_functional_tensor_base", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    return at::functionalization::impl::isBaseTensor(t);
  });
  py_module.def("_functionalize_is_multi_output_view", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
    return t_impl->is_multi_output_view();
  });
  py_module.def(
      "_functionalize_enable_reapply_views",
      [](bool reapply_views = false) {
        auto old =
            at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
        at::functionalization::impl::setFunctionalizationReapplyViewsTLS(
            reapply_views);
        return old;
      },
      py::arg("reapply_views") = false);
  py_module.def(
      "_functionalize_has_metadata_mutation", [](const at::Tensor& t) {
        TORCH_INTERNAL_ASSERT(
            at::functionalization::impl::isFunctionalTensor(t));
        auto t_impl =
            at::functionalization::impl::unsafeGetFunctionalWrapper(t);
        return t_impl->has_metadata_mutation();
      });
  py_module.def("_functionalize_has_data_mutation", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
    return t_impl->has_data_mutation();
  });
  py_module.def(
      "_functionalize_get_storage_size", [](const at::Tensor& t, bool before) {
        TORCH_INTERNAL_ASSERT(
            at::functionalization::impl::isFunctionalTensor(t));
        auto wrapper =
            at::functionalization::impl::unsafeGetFunctionalWrapper(t);
        auto size = wrapper->get_storage_size(/*before=*/before);
        return size;
      });
  py_module.def("_functionalize_set_storage_changed", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
    wrapper->set_storage_changed();
  });
  py_module.def("_functionalize_was_storage_changed", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
    return wrapper->was_storage_changed();
  });
  py_module.def(
      "_functionalize_unsafe_set", [](at::Tensor& dst, const at::Tensor& src) {
        // Forcefully/unsafely dumps src.storage into dst.
        // This API is technically not specific to functionalization
        // (it just runs set_() without the safety checks).
        // But its main intended purpose today is during functionalization.
        // In particular: when we generate a new FunctionalTensor from a view
        // op, we need to ensure it shares a storage with the view input.
        //
        // Other subclasses shouldn't really need to care about this,
        // because we define aliasing on wrapper subclasses such that:
        // - differentiable aliasing: subclass_x and subclass_y share a ._base.
        // - non-differentiable aliasing: aliasing of subclass_x and subclass_y
        //   is defined recursively based on the aliasing of their inner
        //   tensors.
        at::native::checkSetStorage(
            dst,
            src.storage(),
            dst.sym_storage_offset(),
            dst.sym_sizes(),
            dst.sym_strides());
      });
  py_module.def(
      "_functionalize_mark_mutation_hidden_from_autograd",
      [](const at::Tensor& t) {
        TORCH_INTERNAL_ASSERT(
            at::functionalization::impl::isFunctionalTensor(t));
        at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
      });
  py_module.def("_is_functional_tensor", [](const at::Tensor& t) {
    return at::functionalization::impl::isFunctionalTensor(t);
  });
  py_module.def("_to_functional_tensor", [](const at::Tensor& t) {
    return at::functionalization::impl::to_functional_tensor(t);
  });
  py_module.def("_from_functional_tensor", [](const at::Tensor& t) {
    return at::functionalization::impl::from_functional_tensor(t);
  });
  py_module.def("_freeze_functional_tensor", [](const at::Tensor& t) {
    at::functionalization::impl::freeze_functional_tensor(t);
  });
  py_module.def(
      "_enable_functionalization",
      [](bool reapply_views = false) {
        if (c10::impl::tls_is_dispatch_key_included(
                at::DispatchKey::Functionalize)) {
          TORCH_INTERNAL_ASSERT(
              false,
              "multiple layers of mode-style functionalization nesting is not"
              " currently supported, outside of the functionalize() transform");
        }
        c10::impl::tls_set_dispatch_key_included(
            at::DispatchKey::Functionalize, true);
        if (reapply_views) {
          at::functionalization::impl::setFunctionalizationReapplyViewsTLS(
              true);
        }
      },
      py::arg("reapply_views") = false);
  py_module.def("_disable_functionalization", []() {
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Functionalize, false);
    at::functionalization::impl::setFunctionalizationReapplyViewsTLS(false);
  });
  py_module.def(
      "_mirror_autograd_meta_to",
      [](const at::Tensor& src_, const at::Tensor& dst_) {
        // Here, we unsafely set the grad function on the wrapper to be the same
        // as the inner. We expect this grad_fn to NEVER be used. It's needed so
        // that .is_leaf metadata is accurate on the wrapper
        auto inner_autograd_meta = impl::get_autograd_meta(src_);
        if (inner_autograd_meta) {
          dst_.set_requires_grad(src_.requires_grad());
          if (dst_.requires_grad()) {
            auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
                new torch::autograd::Error(
                    "Cannot backprop through mirrored meta, file a bug in PyTorch"),
                torch::autograd::deleteNode);
            torch::autograd::set_history(dst_, new_grad_fn);
          }
        }
      });
}

} // namespace torch::autograd