1 #include <torch/csrc/autograd/python_legacy_variable.h>
2
3 #include <ATen/ATen.h>
4
5 #include <torch/csrc/Exceptions.h>
6 #include <torch/csrc/autograd/python_function.h>
7 #include <torch/csrc/autograd/python_variable.h>
8 #include <torch/csrc/jit/frontend/tracer.h>
9 #include <torch/csrc/tensor/python_tensor.h>
10
11 using namespace at;
12
13 namespace torch::autograd {
14
THPVariable_pynew(PyTypeObject * type,PyObject * args,PyObject * kwds)15 static PyObject* THPVariable_pynew(
16 PyTypeObject* type,
17 PyObject* args,
18 PyObject* kwds) {
19 HANDLE_TH_ERRORS
20 THPObjectPtr _data;
21 PyObject* data = nullptr;
22 PyObject* grad_fn = nullptr;
23 char is_volatile = 0;
24 char requires_grad = 0;
25 const char* name = nullptr;
26
27 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
28 constexpr const char* accepted_args[] = {
29 "data", "requires_grad", "volatile", "_grad_fn", "name", nullptr};
30 if (!PyArg_ParseTupleAndKeywords(
31 args,
32 kwds,
33 "|ObbOz",
34 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
35 const_cast<char**>(accepted_args),
36 &data,
37 &requires_grad,
38 &is_volatile,
39 &grad_fn,
40 &name))
41 return nullptr;
42
43 if (grad_fn == Py_None)
44 grad_fn = nullptr;
45
46 if (is_volatile) {
47 auto r = PyErr_WarnEx(
48 PyExc_UserWarning,
49 "volatile was removed and now has no effect. Use `with torch.no_grad():` "
50 "instead.",
51 1);
52 if (r != 0)
53 throw python_error();
54 }
55
56 TORCH_CHECK_VALUE(
57 !is_volatile || !requires_grad,
58 "Variable can't be volatile and require_grad at the same time!");
59 if (grad_fn && !THPFunction_Check(grad_fn)) {
60 throw TypeError(
61 "_grad_fn has to be a Function object or None, but got %s",
62 Py_TYPE(grad_fn)->tp_name);
63 }
64 Variable var;
65 if (!data || data == Py_None) {
66 // For legacy serialization code, create an empty tensor. This is also used
67 // by nn.Parameter() with no arguments.
68 auto dispatch_key = torch::tensors::get_default_dispatch_key();
69 auto scalar_type = torch::tensors::get_default_scalar_type();
70 auto options = TensorOptions(scalar_type)
71 .device(dispatchKeyToDeviceType(dispatch_key))
72 .layout(dispatchKeyToLayout(dispatch_key));
73 var = at::empty({0}, options);
74 } else if (THPVariable_Check(data)) {
75 var = THPVariable_Unpack(data).detach();
76 } else {
77 throw torch::TypeError(
78 "Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name);
79 }
80 // We set `tensor`'s `allow_tensor_metadata_change` to true here, because we
81 // want to allow the following use case for backward compatibility:
82 //
83 // ```python
84 // var = Variable(torch.randn(2, 3))
85 // var.resize_(4, 5)
86 // ```
87 var.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
88
89 TORCH_CHECK(
90 !grad_fn,
91 "_grad_fn argument to legacy Variable constructor is no longer supported. "
92 "Instead, please invoke your _grad_fn to produce a variable with it as the "
93 "_grad_fn.");
94 var.set_requires_grad(requires_grad);
95
96 if (name) {
97 impl::set_name(var, name);
98 }
99
100 if (jit::tracer::isTracing() && data && data != Py_None &&
101 THPVariable_Check(data)) {
102 if (auto* v = jit::tracer::getValueTrace(THPVariable_Unpack(data))) {
103 jit::tracer::setValueTrace(var, v);
104 }
105 }
106
107 return THPVariable_Wrap(std::move(var));
108 END_HANDLE_TH_ERRORS
109 }
110
// Static Python type object for torch._C._LegacyVariableBase, readied and
// added to the module by init_legacy_variable.
//
// Only tp_name, tp_flags and tp_new are populated. Every other slot is left
// null/zero for PyType_Ready to complete.
//
// This is a positional aggregate initializer: entries must stay in exactly
// the PyTypeObject field order.
PyTypeObject THPLegacyVariableType = {
    PyVarObject_HEAD_INIT(
        nullptr,
        0) "torch._C._LegacyVariableBase", /* tp_name */
    0, /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_as_async (named tp_reserved before Python 3.5) */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // Py_TPFLAGS_BASETYPE lets Python code subclass this type.
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    // The legacy Variable constructor. It returns a wrapped tensor built by
    // THPVariable_Wrap rather than an instance of this type.
    THPVariable_pynew /* tp_new */
};
153
init_legacy_variable(PyObject * module)154 void init_legacy_variable(PyObject* module) {
155 if (PyType_Ready(&THPLegacyVariableType) < 0) {
156 throw python_error();
157 }
158 auto obj = (PyObject*)&THPLegacyVariableType;
159 Py_INCREF(obj);
160 if (PyModule_AddObject(module, "_LegacyVariableBase", obj) < 0) {
161 throw python_error();
162 }
163 }
164
165 } // namespace torch::autograd
166