1 #include <torch/csrc/Dtype.h>
2 #include <torch/csrc/DynamicTypes.h>
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/autograd/function.h>
5 #include <torch/csrc/autograd/functions/basic_ops.h>
6 #include <torch/csrc/autograd/functions/utils.h>
7 #include <torch/csrc/autograd/generated/variable_factories.h>
8 #include <torch/csrc/autograd/python_torch_functions.h>
9 #include <torch/csrc/autograd/python_variable.h>
10 #include <torch/csrc/autograd/utils/wrap_outputs.h>
11 #include <torch/csrc/jit/frontend/tracer.h>
12 #include <torch/csrc/utils/device_lazy_init.h>
13 #include <torch/csrc/utils/out_types.h>
14 #include <torch/csrc/utils/pybind.h>
15 #include <torch/csrc/utils/pycfunction_helpers.h>
16 #include <torch/csrc/utils/python_arg_parser.h>
17 #include <torch/csrc/utils/structseq.h>
18 #include <torch/csrc/utils/tensor_layouts.h>
19 #include <torch/csrc/utils/tensor_new.h>
20 #include <torch/csrc/utils/tensor_numpy.h>
21
22 #include <ATen/ATen.h>
23 #include <ATen/FunctionalTensorWrapper.h>
24 #include <ATen/native/Resize.h>
25
#include <Python.h>
#include <fmt/format.h>
#include <pybind11/pybind11.h>

#include <algorithm>
#include <array>
#include <cstring>
#include <utility>
#include <vector>
31
32 using at::DeviceGuard;
33 using at::DimnameList;
34 using at::IntArrayRef;
35 using at::OptionalDeviceGuard;
36 using at::Scalar;
37 using at::Tensor;
38 using at::TensorList;
39 using at::TensorOptions;
40
41 using torch::utils::check_out_type_matches;
42 using namespace torch::autograd::utils;
43
44 namespace torch::autograd {
45
// The single instance of torch._C._VariableFunctionsClass. It is created with
// PyType_GenericNew in initTorchFunctions() and registered on the module as
// `_VariableFunctions`. The bindings in this file pass it to
// handle_torch_function() as the namespace object for `torch`.
// Remains nullptr until initTorchFunctions() has run.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* THPVariableFunctionsModule = nullptr;
48
dispatch_range(const Scalar & start,const Scalar & end,const Scalar & step,Tensor result)49 inline Tensor dispatch_range(
50 const Scalar& start,
51 const Scalar& end,
52 const Scalar& step,
53 Tensor result) {
54 pybind11::gil_scoped_release no_gil;
55 OptionalDeviceGuard device_guard(device_of(result));
56 return at::range_out(result, start, end, step);
57 }
58
dispatch_range(const Scalar & start,const Scalar & end,const Scalar & step,const TensorOptions & options)59 inline Tensor dispatch_range(
60 const Scalar& start,
61 const Scalar& end,
62 const Scalar& step,
63 const TensorOptions& options) {
64 torch::utils::maybe_initialize_device(options);
65 pybind11::gil_scoped_release no_gil;
66 DeviceGuard device_guard(options.device());
67 return torch::range(start, end, step, options);
68 }
69
THPVariable_range(PyObject * self,PyObject * args,PyObject * kwargs)70 static PyObject* THPVariable_range(
71 PyObject* self,
72 PyObject* args,
73 PyObject* kwargs) {
74 HANDLE_TH_ERRORS
75 static PythonArgParser parser({
76 "range(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
77 });
78
79 ParsedArgs<8> parsed_args;
80 auto r = parser.parse(args, kwargs, parsed_args);
81
82 if (r.idx == 0) {
83 auto ret = PyErr_WarnEx(
84 PyExc_UserWarning,
85 "torch.range is deprecated and will be removed in a future release "
86 "because its behavior is inconsistent with Python's range builtin. "
87 "Instead, use torch.arange, which produces values in [start, end).",
88 1);
89 if (ret != 0)
90 throw python_error();
91 if (r.isNone(3)) {
92 const auto options = TensorOptions()
93 .dtype(r.scalartype(4))
94 .device(r.device(6))
95 .layout(r.layout(5))
96 .requires_grad(r.toBool(7));
97 return wrap(
98 dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), options));
99 } else {
100 check_out_type_matches(
101 r.tensor(3),
102 r.scalartype(4),
103 r.isNone(4),
104 r.layout(5),
105 r.device(6),
106 r.isNone(6));
107 return wrap(
108 dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3))
109 .set_requires_grad(r.toBool(7)));
110 }
111 }
112 Py_RETURN_NONE;
113 END_HANDLE_TH_ERRORS
114 }
115
116 // implemented on python object to allow torch.as_tensor to be constructed with
117 // arbitrarily nested python objects - list, tuple, np array, scalar, etc.
THPVariable_as_tensor(PyObject * self,PyObject * args,PyObject * kwargs)118 static PyObject* THPVariable_as_tensor(
119 PyObject* self,
120 PyObject* args,
121 PyObject* kwargs) {
122 HANDLE_TH_ERRORS
123 static PythonArgParser parser({
124 "as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)",
125 });
126
127 ParsedArgs<3> parsed_args;
128 auto r = parser.parse(args, kwargs, parsed_args);
129 if (r.has_torch_function()) {
130 return handle_torch_function(
131 r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
132 }
133 jit::tracer::warn("torch.as_tensor", jit::tracer::WARN_CONSTRUCTOR);
134 return THPVariable_Wrap(torch::utils::as_tensor(
135 torch::tensors::get_default_dispatch_key(),
136 torch::tensors::get_default_scalar_type(),
137 r));
138 END_HANDLE_TH_ERRORS
139 }
140
// Implemented on the python object here because PyObject is currently not
// natively declarable. See ATen/native/README.md for more context.
THPVariable_from_numpy(PyObject * module,PyObject * arg)143 static PyObject* THPVariable_from_numpy(PyObject* module, PyObject* arg) {
144 HANDLE_TH_ERRORS
145 jit::tracer::warn("torch.from_numpy", jit::tracer::WARN_CONSTRUCTOR);
146 return THPVariable_Wrap(torch::utils::tensor_from_numpy(arg));
147 END_HANDLE_TH_ERRORS
148 }
149
// Runs self.nonzero() with the GIL released and the device reported by
// device_of(self) made current for the duration of the call.
static Tensor dispatch_nonzero(const Tensor& self) {
  pybind11::gil_scoped_release gil_release;
  const OptionalDeviceGuard guard{device_of(self)};
  return self.nonzero();
}
155
// `out=` variant: writes the result of nonzero(self) into `out` via
// at::nonzero_out. It runs with the GIL released, under `self`'s device.
static Tensor dispatch_nonzero(const Tensor& self, Tensor out) {
  pybind11::gil_scoped_release gil_release;
  const OptionalDeviceGuard guard{device_of(self)};
  return at::nonzero_out(out, self);
}
161
dispatch_nonzero_numpy(const Tensor & self)162 static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor& self) {
163 pybind11::gil_scoped_release no_gil;
164 OptionalDeviceGuard device_guard(device_of(self));
165 return self.nonzero_numpy();
166 }
167
// Forward declaration: the torch_functions_manual table below refers to this
// binding before its definition later in the file.
static PyObject* THPVariable_nonzero(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs);
172
// Defines `static PyObject* THPVariable_<NAME>(self, args, kwargs)`, a Python
// binding that:
//   * parses its arguments with SIGNATURES, a parenthesized braced list of
//     signature strings;
//   * defers to __torch_function__ overrides when present;
//   * emits the JIT tracer constructor warning;
//   * forwards to torch::utils::<NAME>_ctor with the default dispatch key and
//     default scalar type.
// NUM_ARGS sizes the ParsedArgs buffer.
// (No `//` comments inside the macro body: they would swallow the line
// continuations.)
#define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NUM_ARGS, SIGNATURES)    \
  static PyObject* THPVariable_##NAME(                                    \
      PyObject* self, PyObject* args, PyObject* kwargs) {                 \
    HANDLE_TH_ERRORS                                                      \
    static PythonArgParser parser SIGNATURES;                             \
    ParsedArgs<NUM_ARGS> parsed;                                          \
    auto r = parser.parse(args, kwargs, parsed);                          \
    if (r.has_torch_function()) {                                         \
      return handle_torch_function(                                       \
          r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); \
    }                                                                     \
    jit::tracer::warn("torch." #NAME, jit::tracer::WARN_CONSTRUCTOR);     \
    const auto dispatch_key = torch::tensors::get_default_dispatch_key(); \
    const auto scalar_type = torch::tensors::get_default_scalar_type();   \
    return THPVariable_Wrap(                                              \
        torch::utils::NAME##_ctor(dispatch_key, scalar_type, r));         \
    END_HANDLE_TH_ERRORS                                                  \
  }
191
// One binding per sparse compressed layout, generated by
// THPVARIABLE_SPARSE_COMPRESSED_CTOR.
// Each binding accepts two overloads, with and without an explicit `size`.
// NUM_ARGS is 10, the parameter count of the sized overload.
// There are no trailing semicolons: each expansion is already a complete
// function definition.

// Generic compressed layout (compressed_indices / plain_indices).
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_compressed_tensor,
    10,
    ({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
// CSR: crow_indices / col_indices.
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_csr_tensor,
    10,
    ({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
// CSC: ccol_indices / row_indices.
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_csc_tensor,
    10,
    ({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
// BSR: crow_indices / col_indices.
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_bsr_tensor,
    10,
    ({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
// BSC: ccol_indices / row_indices.
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
    sparse_bsc_tensor,
    10,
    ({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
      "sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
217
THPVariable_sparse_coo_tensor(PyObject * self,PyObject * args,PyObject * kwargs)218 static PyObject* THPVariable_sparse_coo_tensor(
219 PyObject* self,
220 PyObject* args,
221 PyObject* kwargs) {
222 HANDLE_TH_ERRORS
223 static PythonArgParser parser({
224 "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
225 "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None, bool is_coalesced=None)",
226 "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)",
227 });
228
229 ParsedArgs<9> parsed_args;
230 auto r = parser.parse(args, kwargs, parsed_args);
231 if (r.has_torch_function()) {
232 return handle_torch_function(
233 r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
234 }
235 jit::tracer::warn("torch.sparse_coo_tensor", jit::tracer::WARN_CONSTRUCTOR);
236 return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(
237 torch::tensors::get_default_dispatch_key(),
238 torch::tensors::get_default_scalar_type(),
239 r));
240 END_HANDLE_TH_ERRORS
241 }
242
243 // implemented on python object to allow torch.tensor to be constructed with
244 // arbitrarily nested python objects - list, tuple, np array, scalar, etc.
THPVariable_tensor(PyObject * self,PyObject * args,PyObject * kwargs)245 static PyObject* THPVariable_tensor(
246 PyObject* self,
247 PyObject* args,
248 PyObject* kwargs) {
249 HANDLE_TH_ERRORS
250 static PythonArgParser parser({
251 "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)",
252 });
253
254 constexpr int ctor_num_args = 6;
255 ParsedArgs<ctor_num_args> parsed_args;
256 auto r = parser.parse(args, kwargs, parsed_args);
257 if (r.has_torch_function()) {
258 return handle_torch_function(
259 r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
260 }
261 jit::tracer::warn("torch.tensor", jit::tracer::WARN_CONSTRUCTOR);
262 return THPVariable_Wrap(torch::utils::tensor_ctor(
263 torch::tensors::get_default_dispatch_key(),
264 torch::tensors::get_default_scalar_type(),
265 r));
266 END_HANDLE_TH_ERRORS
267 }
268
THPVariable_get_device(PyObject * self_,PyObject * args,PyObject * kwargs)269 static PyObject* THPVariable_get_device(
270 PyObject* self_,
271 PyObject* args,
272 PyObject* kwargs) {
273 HANDLE_TH_ERRORS
274 static PythonArgParser parser(
275 {
276 "get_device(Tensor input)",
277 },
278 /*traceable=*/false);
279
280 ParsedArgs<1> parsed_args;
281 auto r = parser.parse(args, kwargs, parsed_args);
282 if (r.has_torch_function()) {
283 return handle_torch_function(
284 r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
285 }
286
287 if (r.idx == 0) {
288 return wrap(r.tensor(0).get_device());
289 }
290 Py_RETURN_NONE;
291 END_HANDLE_TH_ERRORS
292 }
293
THPVariable_frombuffer(PyObject * self_,PyObject * args,PyObject * kwargs)294 static PyObject* THPVariable_frombuffer(
295 PyObject* self_,
296 PyObject* args,
297 PyObject* kwargs) {
298 HANDLE_TH_ERRORS
299 static PythonArgParser parser(
300 {
301 "frombuffer(PyObject* buffer, *, ScalarType dtype, int64_t count=-1, int64_t offset=0, bool requires_grad=False)",
302 },
303 /*traceable=*/false);
304
305 ParsedArgs<5> parsed_args;
306 auto r = parser.parse(args, kwargs, parsed_args);
307
308 if (r.idx == 0) {
309 auto buffer = r.pyobject(0);
310 auto dtype = r.scalartype(1);
311 auto count = r.toInt64(2);
312 auto offset = r.toInt64(3);
313 auto requires_grad = r.toBool(4);
314
315 TORCH_CHECK_VALUE(
316 PyObject_CheckBuffer(buffer) != 0,
317 "object does not implement Python buffer protocol.");
318 return wrap(torch::utils::tensor_frombuffer(
319 buffer, dtype, count, offset, requires_grad));
320 }
321
322 Py_RETURN_NONE;
323 END_HANDLE_TH_ERRORS
324 }
325
THPVariable_asarray(PyObject * self_,PyObject * args,PyObject * kwargs)326 static PyObject* THPVariable_asarray(
327 PyObject* self_,
328 PyObject* args,
329 PyObject* kwargs) {
330 HANDLE_TH_ERRORS
331 static PythonArgParser parser(
332 {
333 "asarray(PyObject* obj, *, ScalarType? dtype=None, Device? device=None, bool? copy=None, bool requires_grad=False)",
334 },
335 /*traceable=*/false);
336
337 ParsedArgs<5> parsed_args;
338 auto r = parser.parse(args, kwargs, parsed_args);
339
340 if (r.has_torch_function()) {
341 return handle_torch_function(
342 r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
343 }
344
345 if (r.idx == 0) {
346 auto obj = r.pyobject(0);
347 auto dtype = r.scalartypeOptional(1);
348 auto device = r.deviceOptional(2);
349 auto copy = r.toBoolOptional(3);
350 auto requires_grad = r.toBool(4);
351 return wrap(torch::utils::asarray(obj, dtype, device, copy, requires_grad));
352 }
353
354 Py_RETURN_NONE;
355 END_HANDLE_TH_ERRORS
356 }
357
// Forward declaration: the torch_functions_manual table below refers to this
// binding before its definition later in the file.
static PyObject* THPVariable_numel(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs);
362
// XXX: ops that are bound here are not exposed to the C++ API nor the JIT.
// Any new ops added here should be accompanied by a comment explaining why they
// are not being registered through native_functions.yaml, and be tagged cpp / JIT.
366 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef torch_functions_manual[] = {
    // Hand-written bindings. gatherTorchFunctions() copies them, in this
    // order, to the front of the method table and appends the null sentinel
    // itself, so no sentinel entry belongs here. Every entry is METH_STATIC
    // because the table becomes the tp_methods of
    // torch._C._VariableFunctionsClass.
    {"asarray",
     castPyCFunctionWithKeywords(THPVariable_asarray),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"as_tensor",
     castPyCFunctionWithKeywords(THPVariable_as_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    // METH_O: receives exactly one positional argument, with no kwargs.
    {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, nullptr},
    {"frombuffer",
     castPyCFunctionWithKeywords(THPVariable_frombuffer),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"nonzero",
     castPyCFunctionWithKeywords(THPVariable_nonzero),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"range",
     castPyCFunctionWithKeywords(THPVariable_range),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_coo_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    // The sparse compressed constructors below are generated by
    // THPVARIABLE_SPARSE_COMPRESSED_CTOR.
    {"sparse_compressed_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_compressed_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_csr_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_csc_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_csc_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_bsr_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_bsr_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"sparse_bsc_tensor",
     castPyCFunctionWithKeywords(THPVariable_sparse_bsc_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"tensor",
     castPyCFunctionWithKeywords(THPVariable_tensor),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"get_device",
     castPyCFunctionWithKeywords(THPVariable_get_device),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
    {"numel",
     castPyCFunctionWithKeywords(THPVariable_numel),
     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
     nullptr},
};
426
THPVariable_nonzero(PyObject * self,PyObject * args,PyObject * kwargs)427 static PyObject* THPVariable_nonzero(
428 PyObject* self,
429 PyObject* args,
430 PyObject* kwargs) {
431 HANDLE_TH_ERRORS
432 static PythonArgParser parser({
433 "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)",
434 });
435 ParsedArgs<3> parsed_args;
436 auto r = parser.parse(args, kwargs, parsed_args);
437
438 if (r.has_torch_function()) {
439 return handle_torch_function(
440 r, args, kwargs, THPVariableFunctionsModule, "torch");
441 }
442
443 const auto as_tuple = r.toBool(1);
444 const auto has_out = !r.isNone(2);
445
446 if (as_tuple) {
447 TORCH_CHECK(
448 !has_out,
449 "nonzero does not support the out kwarg when as_tuple is True");
450 return wrap(dispatch_nonzero_numpy(r.tensor(0)));
451 }
452
453 if (has_out) {
454 return wrap(dispatch_nonzero(r.tensor(0), r.tensor(2)));
455 }
456
457 return wrap(dispatch_nonzero(r.tensor(0)));
458
459 END_HANDLE_TH_ERRORS
460 }
461
THPVariable_numel(PyObject * self_,PyObject * args,PyObject * kwargs)462 static PyObject* THPVariable_numel(
463 PyObject* self_,
464 PyObject* args,
465 PyObject* kwargs) {
466 HANDLE_TH_ERRORS
467 static PythonArgParser parser(
468 {
469 "numel(Tensor input)",
470 },
471 /*traceable=*/false);
472
473 ParsedArgs<1> parsed_args;
474 auto r = parser.parse(args, kwargs, parsed_args);
475
476 if (r.has_torch_function()) {
477 return handle_torch_function(
478 r, args, kwargs, THPVariableFunctionsModule, "torch");
479 }
480
481 if (r.idx == 0) {
482 return py::cast(r.tensor(0).sym_numel()).release().ptr();
483 }
484 Py_RETURN_NONE;
485 END_HANDLE_TH_ERRORS
486 }
487
// Sharded method-table builders. They are declared here but not defined in
// this file. gatherTorchFunctions() calls them after seeding the vector with
// torch_functions_manual, so they presumably append their PyMethodDef entries.
// The number of shards must stay in sync with num_shards in
// tools/autograd/gen_python_functions.py.
void gatherTorchFunctions_0(std::vector<PyMethodDef>& torch_functions);
void gatherTorchFunctions_1(std::vector<PyMethodDef>& torch_functions);
void gatherTorchFunctions_2(std::vector<PyMethodDef>& torch_functions);
492
gatherTorchFunctions(std::vector<PyMethodDef> & torch_functions)493 void gatherTorchFunctions(std::vector<PyMethodDef>& torch_functions) {
494 constexpr size_t num_functions =
495 sizeof(torch_functions_manual) / sizeof(torch_functions_manual[0]);
496 torch_functions.assign(
497 torch_functions_manual, torch_functions_manual + num_functions);
498 // NOTE: Must be synced with num_shards in
499 // tools/autograd/gen_python_functions.py
500 gatherTorchFunctions_0(torch_functions);
501 gatherTorchFunctions_1(torch_functions);
502 gatherTorchFunctions_2(torch_functions);
503
504 static std::array<std::pair<const char*, const char*>, 4> aliases{
505 {// Canonical function, alias name
506 {"sspaddmm", "saddmm"},
507 {"mm", "spmm"},
508 {"mm", "dsmm"},
509 {"hspmm", "hsmm"}}};
510
511 for (const auto& alias : aliases) {
512 auto it = std::find_if(
513 torch_functions.begin(),
514 torch_functions.end(),
515 [&](const PyMethodDef& def) {
516 return strcmp(def.ml_name, alias.first) == 0;
517 });
518 TORCH_INTERNAL_ASSERT(
519 it != torch_functions.end(),
520 "Failed to create function alias from ",
521 alias.first,
522 " to ",
523 alias.second);
524 PyMethodDef alias_def = *it;
525 alias_def.ml_name = alias.second;
526
527 torch_functions.push_back(alias_def);
528 }
529
530 torch_functions.push_back({nullptr});
531 torch_functions.shrink_to_fit();
532 }
533
// Static type object for torch._C._VariableFunctionsClass.
// tp_methods is left null here. initTorchFunctions() fills it from
// gatherTorchFunctions() before calling PyType_Ready. That function also
// creates the type's single instance with PyType_GenericNew, since no tp_new
// is provided.
static PyTypeObject THPVariableFunctions = {
    PyVarObject_HEAD_INIT(
        nullptr,
        0) "torch._C._VariableFunctionsClass", /* tp_name */
    0, /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_as_async (tp_reserved in older CPython) */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags (no Py_TPFLAGS_BASETYPE: not subclassable) */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods (set at runtime by initTorchFunctions) */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr /* tp_new */
};
575
initTorchFunctions(PyObject * module)576 void initTorchFunctions(PyObject* module) {
577 static std::vector<PyMethodDef> torch_functions;
578 gatherTorchFunctions(torch_functions);
579 THPVariableFunctions.tp_methods = torch_functions.data();
580
581 if (PyType_Ready(&THPVariableFunctions) < 0) {
582 throw python_error();
583 }
584 Py_INCREF(&THPVariableFunctions);
585
586 // Steals
587 Py_INCREF(&THPVariableFunctions);
588 if (PyModule_AddObject(
589 module,
590 "_VariableFunctionsClass",
591 reinterpret_cast<PyObject*>(&THPVariableFunctions)) < 0) {
592 throw python_error();
593 }
594 // PyType_GenericNew returns a new reference
595 THPVariableFunctionsModule =
596 PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
597 // PyModule_AddObject steals a reference
598 if (PyModule_AddObject(
599 module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
600 throw python_error();
601 }
602
603 // pybind registrations to torch module
604 // TODO: move these from torch.* to torch._C.*
605 auto py_module = py::module::import("torch");
606
607 py_module.def(
608 "_functionalize_are_all_mutations_under_no_grad_or_inference_mode",
609 [](const at::Tensor& t) {
610 TORCH_INTERNAL_ASSERT(
611 at::functionalization::impl::isFunctionalTensor(t));
612 return at::functionalization::impl::
613 are_all_mutations_under_no_grad_or_inference_mode(t);
614 });
615 py_module.def(
616 "_functionalize_was_inductor_storage_resized", [](const at::Tensor& t) {
617 TORCH_INTERNAL_ASSERT(
618 at::functionalization::impl::isFunctionalTensor(t));
619 auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
620 return impl->was_inductor_storage_resized();
621 });
622 py_module.def(
623 "_functionalize_are_all_mutations_hidden_from_autograd",
624 [](const at::Tensor& t) {
625 TORCH_INTERNAL_ASSERT(
626 at::functionalization::impl::isFunctionalTensor(t));
627 return at::functionalization::impl::
628 are_all_mutations_hidden_from_autograd(t);
629 });
630 py_module.def(
631 "_functionalize_mark_mutation_hidden_from_autograd",
632 [](const at::Tensor& t) {
633 TORCH_INTERNAL_ASSERT(
634 at::functionalization::impl::isFunctionalTensor(t));
635 at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
636 });
637 py_module.def(
638 "_functionalize_apply_view_metas",
639 [](const at::Tensor& tensor, const at::Tensor& base) {
640 TORCH_INTERNAL_ASSERT(
641 at::functionalization::impl::isFunctionalTensor(tensor));
642 auto impl =
643 at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
644 return impl->apply_view_metas(base);
645 });
646 py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) {
647 TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
648 auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
649 return impl->is_symbolic();
650 });
651 py_module.def("_functionalize_sync", [](const at::Tensor& t) {
652 TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
653 at::functionalization::impl::sync(t);
654 });
655 py_module.def("_functionalize_commit_update", [](const at::Tensor& t) {
656 TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
657 at::functionalization::impl::commit_update(t);
658 });
659 py_module.def(
660 "_functionalize_replace", [](const at::Tensor& t, const at::Tensor& o) {
661 TORCH_INTERNAL_ASSERT(
662 at::functionalization::impl::isFunctionalTensor(t));
663 TORCH_INTERNAL_ASSERT(
664 !at::functionalization::impl::isFunctionalTensor(o));
665 at::functionalization::impl::replace_(t, o);
666 });
667 py_module.def("_is_functional_tensor_base", [](const at::Tensor& t) {
668 TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
669 return at::functionalization::impl::isBaseTensor(t);
670 });
671 py_module.def("_functionalize_is_multi_output_view", [](const at::Tensor& t) {
672 TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
673 auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
674 return t_impl->is_multi_output_view();
675 });
676 py_module.def(
677 "_functionalize_enable_reapply_views",
678 [](bool reapply_views = false) {
679 auto old =
680 at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
681 at::functionalization::impl::setFunctionalizationReapplyViewsTLS(
682 reapply_views);
683 return old;
684 },
685 py::arg("reapply_views") = false);
686 py_module.def(
687 "_functionalize_has_metadata_mutation", [](const at::Tensor& t) {
688 TORCH_INTERNAL_ASSERT(
689 at::functionalization::impl::isFunctionalTensor(t));
690 auto t_impl =
691 at::functionalization::impl::unsafeGetFunctionalWrapper(t);
692 return t_impl->has_metadata_mutation();
693 });
694 py_module.def("_functionalize_has_data_mutation", [](const at::Tensor& t) {
695 TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
696 auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
697 return t_impl->has_data_mutation();
698 });
699 py_module.def(
700 "_functionalize_get_storage_size", [](const at::Tensor& t, bool before) {
701 TORCH_INTERNAL_ASSERT(
702 at::functionalization::impl::isFunctionalTensor(t));
703 auto wrapper =
704 at::functionalization::impl::unsafeGetFunctionalWrapper(t);
705 auto size = wrapper->get_storage_size(/*before=*/before);
706 return size;
707 });
708 py_module.def("_functionalize_set_storage_changed", [](const at::Tensor& t) {
709 TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
710 auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
711 wrapper->set_storage_changed();
712 });
713 py_module.def("_functionalize_was_storage_changed", [](const at::Tensor& t) {
714 TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
715 auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
716 return wrapper->was_storage_changed();
717 });
718 py_module.def(
719 "_functionalize_unsafe_set", [](at::Tensor& dst, const at::Tensor& src) {
720 // Forcefully/unsafely dumps src.storage into dst.
721 // This API is technically and not specific to functionalization
722 // (it just runs set_() without the safety checks).
723 // But its main intended purpose today is during functionalization.
724 // In particular: when we generate a new FunctionalTensor from a view
725 // op, we need to ensure it shares a storage with the view input.
726 //
727 // Other subclasses shouldn't really need to care about this,
728 // because we define aliasing on wrapper subclasses such that:
729 // - differentiable aliasing: subclass_x and subclass_y share a ._base.
730 // - non-differentiable aliasing: aliasing of subclass_x and subclass_y
731 // is defined recursively based on the aliasing of their inner
732 // tensors.
733 at::native::checkSetStorage(
734 dst,
735 src.storage(),
736 dst.sym_storage_offset(),
737 dst.sym_sizes(),
738 dst.sym_strides());
739 });
740 py_module.def(
741 "_functionalize_mark_mutation_hidden_from_autograd",
742 [](const at::Tensor& t) {
743 TORCH_INTERNAL_ASSERT(
744 at::functionalization::impl::isFunctionalTensor(t));
745 at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
746 });
747 py_module.def("_is_functional_tensor", [](const at::Tensor& t) {
748 return at::functionalization::impl::isFunctionalTensor(t);
749 });
750 py_module.def("_to_functional_tensor", [](const at::Tensor& t) {
751 return at::functionalization::impl::to_functional_tensor(t);
752 });
753 py_module.def("_from_functional_tensor", [](const at::Tensor& t) {
754 return at::functionalization::impl::from_functional_tensor(t);
755 });
756 py_module.def("_freeze_functional_tensor", [](const at::Tensor& t) {
757 at::functionalization::impl::freeze_functional_tensor(t);
758 });
759 py_module.def(
760 "_enable_functionalization",
761 [](bool reapply_views = false) {
762 if (c10::impl::tls_is_dispatch_key_included(
763 at::DispatchKey::Functionalize)) {
764 TORCH_INTERNAL_ASSERT(
765 false,
766 "multiple layers of mode-style functionalization nesting is not"
767 " currently supported, outside of the functionalize() transform");
768 }
769 c10::impl::tls_set_dispatch_key_included(
770 at::DispatchKey::Functionalize, true);
771 if (reapply_views) {
772 at::functionalization::impl::setFunctionalizationReapplyViewsTLS(
773 true);
774 }
775 },
776 py::arg("reapply_views") = false);
777 py_module.def("_disable_functionalization", []() {
778 c10::impl::tls_set_dispatch_key_included(
779 at::DispatchKey::Functionalize, false);
780 at::functionalization::impl::setFunctionalizationReapplyViewsTLS(false);
781 });
782 py_module.def(
783 "_mirror_autograd_meta_to",
784 [](const at::Tensor& src_, const at::Tensor& dst_) {
785 // Here, we unsafely set the grad function on the wrapper to be the same
786 // as the inner. We expect this grad_fn to NEVER be used. It's needed so
787 // that .is_leaf metadata is accurate on the wrapper
788 auto inner_autograd_meta = impl::get_autograd_meta(src_);
789 if (inner_autograd_meta) {
790 dst_.set_requires_grad(src_.requires_grad());
791 if (dst_.requires_grad()) {
792 auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
793 new torch::autograd::Error(
794 "Cannot backprop through mirrored meta, file a bug in PyTorch"),
795 torch::autograd::deleteNode);
796 torch::autograd::set_history(dst_, new_grad_fn);
797 }
798 }
799 });
800 }
801
802 } // namespace torch::autograd
803