xref: /aosp_15_r20/external/pytorch/tools/autograd/templates/python_nn_functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 // ${generated_comment}
3 
4 #include "torch/csrc/Device.h"
5 #include "torch/csrc/DynamicTypes.h"
6 #include "torch/csrc/Exceptions.h"
7 #include "torch/csrc/autograd/python_nn_functions.h"
8 #include "torch/csrc/autograd/generated/python_return_types.h"
9 #include "torch/csrc/autograd/python_variable.h"
10 #include "torch/csrc/autograd/utils/wrap_outputs.h"
11 #include "torch/csrc/autograd/utils/python_arg_parsing.h"
12 #include "torch/csrc/utils/pycfunction_helpers.h"
13 #include "torch/csrc/utils/python_arg_parser.h"
14 #include "torch/csrc/utils/structseq.h"
15 #include "torch/csrc/utils/tensor_memoryformats.h"
16 
17 #ifndef AT_PER_OPERATOR_HEADERS
18 #include <ATen/Functions.h>
19 #else
20 $ops_headers
21 #endif
22 
23 using at::Tensor;
24 using at::Scalar;
25 using at::MemoryFormat;
26 using at::Generator;
27 using at::IntArrayRef;
28 using at::ArrayRef;
29 
30 using namespace torch::autograd::utils;
31 
32 namespace torch::autograd {
33 
34 static PyObject* THPNNVariableFunctionsModule = NULL;
35 
THPVariable__parse_to(PyObject * module,PyObject * args,PyObject * kwargs)36 static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs)
37 {
38   HANDLE_TH_ERRORS
39   static PythonArgParser parser({
40     "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
41     "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
42     "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
43   });
44   ParsedArgs<5> parsed_args;
45   auto r = parser.parse(args, kwargs, parsed_args);
46   if (r.has_torch_function()) {
47     return handle_torch_function(r, args, kwargs, THPNNVariableFunctionsModule, "torch.nn", "_parse_to");
48   }
49   auto parsed = parse_to_conversion(r, /*allow_copy*/ false); // we don't want copy for nn.Module.to
50   auto& device = std::get<0>(parsed);
51   auto& scalarType = std::get<1>(parsed);
52   auto non_blocking = std::get<2>(parsed);
53   auto opt_memory_format = std::get<4>(parsed);
54   auto tuple = THPObjectPtr{PyTuple_New(4)};
55   if (!tuple) throw python_error();
56   if (device) {
57     PyTuple_SET_ITEM(tuple.get(), 0, THPDevice_New(*device));
58   } else {
59     Py_INCREF(Py_None);
60     PyTuple_SET_ITEM(tuple.get(), 0, Py_None);
61   }
62   if (scalarType) {
63     PyTuple_SET_ITEM(tuple.get(), 1, Py_NewRef(torch::getTHPDtype(*scalarType)));
64   } else {
65     Py_INCREF(Py_None);
66     PyTuple_SET_ITEM(tuple.get(), 1, Py_None);
67   }
68   PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking));
69   if (opt_memory_format.has_value()) {
70     PyTuple_SET_ITEM(tuple.get(), 3, Py_NewRef(torch::utils::getTHPMemoryFormat(opt_memory_format.value())));
71   } else {
72     Py_INCREF(Py_None);
73     PyTuple_SET_ITEM(tuple.get(), 3, Py_None);
74   }
75   return tuple.release();
76   END_HANDLE_TH_ERRORS
77 }
78 
79 // generated forward declarations start here
80 
81 ${py_forwards}
82 
83 static PyMethodDef nn_functions[] = {
84   {"_parse_to", castPyCFunctionWithKeywords(THPVariable__parse_to),
85     METH_VARARGS | METH_KEYWORDS, nullptr},
86   ${py_method_defs}
87   {NULL}
88 };
89 
initNNFunctions(PyObject * module)90 void initNNFunctions(PyObject* module) {
91   static struct PyModuleDef def = {
92      PyModuleDef_HEAD_INIT,
93      "torch._C._nn",
94      NULL,
95      -1,
96      nn_functions
97   };
98   PyObject* nn = PyModule_Create(&def);
99   THPNNVariableFunctionsModule = nn;
100   if (!nn) {
101     throw python_error();
102   }
103   // steals a reference to nn
104   if (PyModule_AddObject(module, "_nn", nn) != 0) {
105     throw python_error();
106   }
107 }
108 
109 // generated methods start here
110 
111 ${py_methods}
112 
113 } // namespace torch::autograd
114