#include <torch/csrc/autograd/python_legacy_variable.h>

#include <ATen/ATen.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/tensor/python_tensor.h>

using namespace at;

namespace torch::autograd {

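// tp_new for torch._C._LegacyVariableBase. It implements the legacy
// Variable(data=None, requires_grad=False, volatile=False, _grad_fn=None,
// name=None) constructor signature.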
static PyObject* THPVariable_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  HANDLE_TH_ERRORS
  THPObjectPtr _data;
  PyObject* data = nullptr;
  PyObject* grad_fn = nullptr;
  char is_volatile = 0;
  char requires_grad = 0;
  const char* name = nullptr;

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  constexpr const char* accepted_args[] = {
      "data", "requires_grad", "volatile", "_grad_fn", "name", nullptr};
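  // Format "|ObbOz": every argument is optional; `data` and `_grad_fn` are
  // arbitrary Python objects, `requires_grad` and `volatile` are parsed as
  // unsigned chars, and `name` is a C string that may also be None.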
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwds,
          "|ObbOz",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<char**>(accepted_args),
          &data,
          &requires_grad,
          &is_volatile,
          &grad_fn,
          &name))
    return nullptr;

  if (grad_fn == Py_None)
    grad_fn = nullptr;

  if (is_volatile) {
    auto r = PyErr_WarnEx(
        PyExc_UserWarning,
        "volatile was removed and now has no effect. Use `with torch.no_grad():` "
        "instead.",
        1);
    if (r != 0)
      throw python_error();
  }

  TORCH_CHECK_VALUE(
      !is_volatile || !requires_grad,
      "Variable can't be volatile and require_grad at the same time!");
  if (grad_fn && !THPFunction_Check(grad_fn)) {
    throw TypeError(
        "_grad_fn has to be a Function object or None, but got %s",
        Py_TYPE(grad_fn)->tp_name);
  }
  Variable var;
  if (!data || data == Py_None) {
    // For legacy serialization code, create an empty tensor. This is also used
    // by nn.Parameter() with no arguments.
    auto dispatch_key = torch::tensors::get_default_dispatch_key();
    auto scalar_type = torch::tensors::get_default_scalar_type();
    auto options = TensorOptions(scalar_type)
                       .device(dispatchKeyToDeviceType(dispatch_key))
                       .layout(dispatchKeyToLayout(dispatch_key));
    var = at::empty({0}, options);
  } else if (THPVariable_Check(data)) {
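    // detach() shares storage with `data` but drops its autograd history, so
    // the new Variable starts out as a leaf.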
    var = THPVariable_Unpack(data).detach();
  } else {
    throw torch::TypeError(
        "Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name);
  }
  // We set `tensor`'s `allow_tensor_metadata_change` to true here, because we
  // want to allow the following use case for backward compatibility:
  //
  // ```python
  // var = Variable(torch.randn(2, 3))
  // var.resize_(4, 5)
  // ```
  var.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);

  TORCH_CHECK(
      !grad_fn,
      "_grad_fn argument to legacy Variable constructor is no longer supported.  "
      "Instead, please invoke your _grad_fn to produce a variable with it as the "
      "_grad_fn.");
  var.set_requires_grad(requires_grad);

  if (name) {
    impl::set_name(var, name);
  }

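  // If the JIT tracer is active and `data` already has a trace Value, alias it
  // to the new Variable so the traced graph treats them as the same value.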
  if (jit::tracer::isTracing() && data && data != Py_None &&
      THPVariable_Check(data)) {
    if (auto* v = jit::tracer::getValueTrace(THPVariable_Unpack(data))) {
      jit::tracer::setValueTrace(var, v);
    }
  }

  return THPVariable_Wrap(std::move(var));
  END_HANDLE_TH_ERRORS
}

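// Exposed to Python as torch._C._LegacyVariableBase. Only tp_new is
// overridden; the legacy torch.autograd.Variable wrapper subclasses this type
// from Python.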
PyTypeObject THPLegacyVariableType = {
    PyVarObject_HEAD_INIT(
        nullptr,
        0) "torch._C._LegacyVariableBase", /* tp_name */
    0, /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPVariable_pynew /* tp_new */
};

void init_legacy_variable(PyObject* module) {
  if (PyType_Ready(&THPLegacyVariableType) < 0) {
    throw python_error();
  }
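  // PyModule_AddObject steals the reference to `obj` only on success, so take
  // an extra reference first.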
  auto obj = (PyObject*)&THPLegacyVariableType;
  Py_INCREF(obj);
  if (PyModule_AddObject(module, "_LegacyVariableBase", obj) < 0) {
    throw python_error();
  }
}

} // namespace torch::autograd