xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/python_symnode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/SafePyObject.h>
4 #include <c10/core/SymNodeImpl.h>
5 
6 #include <torch/csrc/PyInterpreter.h>
7 #include <torch/csrc/autograd/python_variable.h>
8 #include <torch/csrc/utils/pybind.h>
9 
10 namespace torch {
11 
12 TORCH_PYTHON_API py::handle get_symint_class();
13 TORCH_PYTHON_API py::handle get_symfloat_class();
14 TORCH_PYTHON_API py::handle get_symbool_class();
15 
// NB: These functions must not be called too early; otherwise torch is not yet
// set up. An alternative design would be for torch to "register" the classes
// with us.
is_symint(py::handle obj)18 inline bool is_symint(py::handle obj) {
19   return py::isinstance(obj, get_symint_class());
20 }
is_symfloat(py::handle obj)21 inline bool is_symfloat(py::handle obj) {
22   return py::isinstance(obj, get_symfloat_class());
23 }
is_symbool(py::handle obj)24 inline bool is_symbool(py::handle obj) {
25   return py::isinstance(obj, get_symbool_class());
26 }
27 
28 namespace impl {
29 
30 // This c10::SymNodeImpl simply backends to a Python object that
31 // implements the API.   The Python object is the source of truth,
32 // this is just an adapter so C++ calls can get to the object.
33 class PythonSymNodeImpl : public c10::SymNodeImpl {
34  public:
PythonSymNodeImpl(py::object pyobj)35   PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() {
36     pyobj_ = std::make_shared<c10::SafePyObject>(
37         pyobj.release().ptr(), getPyInterpreter());
38   };
39 
wrap_int(int64_t num)40   c10::SymNode wrap_int(int64_t num) override {
41     py::gil_scoped_acquire acquire;
42     auto r = getPyObj().attr("wrap_int")(num);
43     return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
44   }
45 
wrap_float(double num)46   c10::SymNode wrap_float(double num) override {
47     py::gil_scoped_acquire acquire;
48     auto r = getPyObj().attr("wrap_float")(num);
49     return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
50   }
51 
wrap_bool(bool num)52   c10::SymNode wrap_bool(bool num) override {
53     py::gil_scoped_acquire acquire;
54     auto r = getPyObj().attr("wrap_bool")(num);
55     return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
56   }
57 
58 #define TORCH_SYMNODE_SIZES_STRIDES(n)                                        \
59   c10::SymNode n(                                                             \
60       c10::ArrayRef<c10::SymNode> sizes, c10::ArrayRef<c10::SymNode> strides) \
61       override {                                                              \
62     py::gil_scoped_acquire acquire;                                           \
63     auto r = getPyObj().attr(#n)(sizes, strides);                             \
64     return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));              \
65   }
66 
67   // clang-format off
68     TORCH_SYMNODE_SIZES_STRIDES(is_contiguous)
TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_2d)69     TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_2d)
70     TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_3d)
71     TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
72     TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
73     TORCH_SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
74   // clang-format on
75 
76 #undef TORCH_SYMNODE_SIZES_STRIDES
77 
78   bool bool_() override {
79     py::gil_scoped_acquire acquire;
80     return getPyObj().attr("bool_")().is(py::handle(Py_True));
81   }
82 
is_int()83   bool is_int() override {
84     py::gil_scoped_acquire acquire;
85     return getPyObj().attr("is_int")().is(py::handle(Py_True));
86   }
87 
is_float()88   bool is_float() override {
89     py::gil_scoped_acquire acquire;
90     return getPyObj().attr("is_float")().is(py::handle(Py_True));
91   }
92 
is_bool()93   bool is_bool() override {
94     py::gil_scoped_acquire acquire;
95     return getPyObj().attr("is_bool")().is(py::handle(Py_True));
96   }
97 
is_nested_int()98   bool is_nested_int() const override {
99     py::gil_scoped_acquire acquire;
100     return getPyObj().attr("is_nested_int")().is(py::handle(Py_True));
101   }
102 
has_hint()103   bool has_hint() override {
104     py::gil_scoped_acquire acquire;
105     return getPyObj().attr("has_hint")().is(py::handle(Py_True));
106   }
107 
guard_int(const char * file,int64_t line)108   int64_t guard_int(const char* file, int64_t line) override {
109     py::gil_scoped_acquire acquire;
110     return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
111   }
112 
guard_float(const char * file,int64_t line)113   double guard_float(const char* file, int64_t line) override {
114     py::gil_scoped_acquire acquire;
115     return getPyObj().attr("guard_float")(file, line).cast<double>();
116   }
117 
guard_bool(const char * file,int64_t line)118   bool guard_bool(const char* file, int64_t line) override {
119     py::gil_scoped_acquire acquire;
120     return getPyObj().attr("guard_bool")(file, line).cast<bool>();
121   }
122 
expect_true(const char * file,int64_t line)123   bool expect_true(const char* file, int64_t line) override {
124     py::gil_scoped_acquire acquire;
125     return getPyObj().attr("expect_true")(file, line).cast<bool>();
126   }
127 
expect_size(const char * file,int64_t line)128   bool expect_size(const char* file, int64_t line) override {
129     py::gil_scoped_acquire acquire;
130     return getPyObj().attr("expect_size")(file, line).cast<bool>();
131   }
132 
guard_size_oblivious(const char * file,int64_t line)133   bool guard_size_oblivious(const char* file, int64_t line) override {
134     py::gil_scoped_acquire acquire;
135     return getPyObj().attr("guard_size_oblivious")(file, line).cast<bool>();
136   }
137 
int_()138   int64_t int_() override {
139     py::gil_scoped_acquire acquire;
140     return getPyObj().attr("int_")().cast<int64_t>();
141   }
142 
maybe_as_int()143   std::optional<int64_t> maybe_as_int() override {
144     py::gil_scoped_acquire acquire;
145     const auto& r = getPyObj().attr("maybe_as_int")();
146     if (r.is_none()) {
147       return std::nullopt;
148     } else {
149       return r.cast<int64_t>();
150     }
151   }
152 
str()153   std::string str() override {
154     py::gil_scoped_acquire acquire;
155     return getPyObj().attr("str")().cast<std::string>();
156   }
157 
_graph_repr()158   std::string _graph_repr() override {
159     py::gil_scoped_acquire acquire;
160     return getPyObj().attr("_graph_repr")().cast<std::string>();
161   }
162 
dispatch_sym_ite_(const char * fname,const c10::SymNode & other,const c10::SymNode & third)163   c10::SymNode dispatch_sym_ite_(
164       const char* fname,
165       const c10::SymNode& other,
166       const c10::SymNode& third) {
167     auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
168     auto pthird = dynamic_cast<PythonSymNodeImpl*>(third.get());
169     TORCH_CHECK(pother);
170     TORCH_CHECK(pthird);
171     py::gil_scoped_acquire acquire;
172     auto r = getPyObj().attr(fname)(pother->getPyObj(), pthird->getPyObj());
173     return c10::make_intrusive<PythonSymNodeImpl>(r);
174   }
175 
dispatch_common_(const char * fname,const c10::SymNode & other)176   c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
177     auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
178     TORCH_CHECK(pother);
179     py::gil_scoped_acquire acquire;
180     auto r = getPyObj().attr(fname)(pother->getPyObj());
181     return c10::make_intrusive<PythonSymNodeImpl>(r);
182   }
183 
dispatch_common_(const char * fname)184   c10::SymNode dispatch_common_(const char* fname) {
185     py::gil_scoped_acquire acquire;
186     auto r = getPyObj().attr(fname)();
187     return c10::make_intrusive<PythonSymNodeImpl>(r);
188   }
189 
add(const c10::SymNode & other)190   c10::SymNode add(const c10::SymNode& other) override {
191     return dispatch_common_(__func__, other);
192   }
193 
sub(const c10::SymNode & other)194   c10::SymNode sub(const c10::SymNode& other) override {
195     return dispatch_common_(__func__, other);
196   }
197 
mul(const c10::SymNode & other)198   c10::SymNode mul(const c10::SymNode& other) override {
199     return dispatch_common_(__func__, other);
200   }
201 
truediv(const c10::SymNode & other)202   c10::SymNode truediv(const c10::SymNode& other) override {
203     return dispatch_common_(__func__, other);
204   }
205 
float_truediv(const c10::SymNode & other)206   c10::SymNode float_truediv(const c10::SymNode& other) override {
207     return dispatch_common_(__func__, other);
208   }
209 
int_truediv(const c10::SymNode & other)210   c10::SymNode int_truediv(const c10::SymNode& other) override {
211     return dispatch_common_(__func__, other);
212   }
213 
pow(const c10::SymNode & other)214   c10::SymNode pow(const c10::SymNode& other) override {
215     return dispatch_common_(__func__, other);
216   }
217 
float_pow(const c10::SymNode & other)218   c10::SymNode float_pow(const c10::SymNode& other) override {
219     return dispatch_common_(__func__, other);
220   }
221 
pow_by_natural(const c10::SymNode & other)222   c10::SymNode pow_by_natural(const c10::SymNode& other) override {
223     return dispatch_common_(__func__, other);
224   }
225 
floordiv(const c10::SymNode & other)226   c10::SymNode floordiv(const c10::SymNode& other) override {
227     return dispatch_common_(__func__, other);
228   }
229 
int_floordiv(const c10::SymNode & other)230   c10::SymNode int_floordiv(const c10::SymNode& other) override {
231     return dispatch_common_(__func__, other);
232   }
233 
mod(const c10::SymNode & other)234   c10::SymNode mod(const c10::SymNode& other) override {
235     return dispatch_common_(__func__, other);
236   }
237 
eq(const c10::SymNode & other)238   c10::SymNode eq(const c10::SymNode& other) override {
239     return dispatch_common_(__func__, other);
240   }
241 
ne(const c10::SymNode & other)242   c10::SymNode ne(const c10::SymNode& other) override {
243     return dispatch_common_(__func__, other);
244   }
245 
gt(const c10::SymNode & other)246   c10::SymNode gt(const c10::SymNode& other) override {
247     return dispatch_common_(__func__, other);
248   }
249 
lt(const c10::SymNode & other)250   c10::SymNode lt(const c10::SymNode& other) override {
251     return dispatch_common_(__func__, other);
252   }
253 
le(const c10::SymNode & other)254   c10::SymNode le(const c10::SymNode& other) override {
255     return dispatch_common_(__func__, other);
256   }
257 
ge(const c10::SymNode & other)258   c10::SymNode ge(const c10::SymNode& other) override {
259     return dispatch_common_(__func__, other);
260   }
261 
sym_min(const c10::SymNode & other)262   c10::SymNode sym_min(const c10::SymNode& other) override {
263     return dispatch_common_(__func__, other);
264   }
sym_max(const c10::SymNode & other)265   c10::SymNode sym_max(const c10::SymNode& other) override {
266     return dispatch_common_(__func__, other);
267   }
268 
sym_and(const c10::SymNode & other)269   c10::SymNode sym_and(const c10::SymNode& other) override {
270     return dispatch_common_(__func__, other);
271   }
272 
sym_or(const c10::SymNode & other)273   c10::SymNode sym_or(const c10::SymNode& other) override {
274     return dispatch_common_(__func__, other);
275   }
276 
sym_ite(const c10::SymNode & other,const c10::SymNode & third)277   c10::SymNode sym_ite(const c10::SymNode& other, const c10::SymNode& third)
278       override {
279     return dispatch_sym_ite_(__func__, other, third);
280   }
281 
sym_not()282   c10::SymNode sym_not() override {
283     return dispatch_common_(__func__);
284   }
285 
ceil()286   c10::SymNode ceil() override {
287     return dispatch_common_(__func__);
288   }
289 
floor()290   c10::SymNode floor() override {
291     return dispatch_common_(__func__);
292   }
293 
neg()294   c10::SymNode neg() override {
295     return dispatch_common_(__func__);
296   }
297 
clone()298   c10::SymNode clone() override {
299     return dispatch_common_(__func__);
300   }
301 
sym_float()302   c10::SymNode sym_float() override {
303     return dispatch_common_(__func__);
304   }
305 
getPyObj()306   py::handle getPyObj() const {
307     return py::handle(pyobj_->ptr(getPyInterpreter()));
308   }
309   std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
310 };
311 
312 } // namespace impl
313 } // namespace torch
314