xref: /aosp_15_r20/external/pytorch/torch/csrc/fx/node.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/fx/node.h>
2 
3 #include <structmember.h>
4 #include <torch/csrc/utils/pythoncapi_compat.h>
5 
6 ////////////////////////////////
7 // NodeBase
8 ///////////////////////////////
9 
10 struct NodeBase {
11   PyObject_HEAD bool _erased;
12   NodeBase* _prev;
13   NodeBase* _next;
14 };
15 
NodeBase_new(PyTypeObject * type,PyObject * args,PyObject * kwds)16 static PyObject* NodeBase_new(
17     PyTypeObject* type,
18     PyObject* args,
19     PyObject* kwds) {
20   PyObject* self = type->tp_alloc(type, 0);
21   if (!self)
22     return nullptr;
23   return self;
24 }
25 
NodeBase_init_fn(NodeBase * self,PyObject * args,PyObject * kwds)26 static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
27   self->_erased = false;
28   Py_INCREF(self);
29   self->_prev = self;
30   Py_INCREF(self);
31   self->_next = self;
32   return 0;
33 }
34 
35 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
36 static struct PyMemberDef NodeBase_members[] = {
37     {"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
38     {"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
39     {"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
40     {nullptr} /* Sentinel */
41 };
42 
NodeBase_traverse(NodeBase * self,visitproc visit,void * arg)43 static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
44   Py_VISIT(self->_prev);
45   Py_VISIT(self->_next);
46   return 0;
47 }
48 
NodeBase_clear(NodeBase * self)49 static int NodeBase_clear(NodeBase* self) {
50   Py_CLEAR(self->_prev);
51   Py_CLEAR(self->_next);
52   return 0;
53 }
54 
NodeBase_dealloc(PyObject * self)55 static void NodeBase_dealloc(PyObject* self) {
56   PyObject_GC_UnTrack(self);
57   (void)NodeBase_clear((NodeBase*)self);
58   Py_TYPE(self)->tp_free(self);
59 }
60 
61 static PyTypeObject NodeBaseType = {
62     PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */
63     sizeof(NodeBase), /* tp_basicsize */
64     0, /* tp_itemsize */
65     (destructor)NodeBase_dealloc, /* tp_dealloc */
66     0, /* tp_vectorcall_offset */
67     nullptr, /* tp_getattr */
68     nullptr, /* tp_setattr */
69     nullptr, /* tp_reserved */
70     nullptr, /* tp_repr */
71     nullptr, /* tp_as_number */
72     nullptr, /* tp_as_sequence */
73     nullptr, /* tp_as_mapping */
74     nullptr, /* tp_hash  */
75     nullptr, /* tp_call */
76     nullptr, /* tp_str */
77     nullptr, /* tp_getattro */
78     nullptr, /* tp_setattro */
79     nullptr, /* tp_as_buffer */
80     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
81         Py_TPFLAGS_HAVE_GC, /* tp_flags */
82     nullptr, /* tp_doc */
83     (traverseproc)NodeBase_traverse, /* tp_traverse */
84     (inquiry)NodeBase_clear, /* tp_clear */
85     nullptr, /* tp_richcompare */
86     0, /* tp_weaklistoffset */
87     nullptr, /* tp_iter */
88     nullptr, /* tp_iternext */
89     nullptr, /* tp_methods */
90     NodeBase_members, /* tp_members */
91     nullptr, /* tp_getset */
92     nullptr, /* tp_base */
93     nullptr, /* tp_dict */
94     nullptr, /* tp_descr_get */
95     nullptr, /* tp_descr_set */
96     0, /* tp_dictoffset */
97     (initproc)NodeBase_init_fn, /* tp_init */
98     nullptr, /* tp_alloc */
99     NodeBase_new, /* tp_new */
100 };
101 
NodeBase_init(PyObject * module)102 bool NodeBase_init(PyObject* module) {
103   if (PyModule_AddType(module, &NodeBaseType) < 0) {
104     return false;
105   }
106   return true;
107 }
108 
109 ////////////////////////////////
110 // NodeIter
111 ////////////////////////////////
112 
113 struct NodeIter {
114   PyObject_HEAD bool _reversed;
115   NodeBase* _root;
116   NodeBase* _cur;
117 };
118 
NodeIter_new(PyTypeObject * type,PyObject * args,PyObject * kwds)119 static PyObject* NodeIter_new(
120     PyTypeObject* type,
121     PyObject* args,
122     PyObject* kwds) {
123   PyObject* self = type->tp_alloc(type, 0);
124   if (!self)
125     return nullptr;
126   return self;
127 }
128 
NodeIter_init_fn(NodeIter * self,PyObject * args,PyObject * kwargs)129 static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) {
130   NodeBase* root = nullptr;
131   bool reversed = false;
132   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
133   constexpr const char* keywords[] = {"root", "reversed", nullptr};
134   if (!PyArg_ParseTupleAndKeywords(
135           args,
136           kwargs,
137           "Ob|",
138           // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
139           const_cast<char**>(keywords),
140           &root,
141           &reversed)) {
142     return -1;
143   }
144   self->_reversed = reversed;
145   Py_INCREF(root);
146   self->_root = root;
147   Py_INCREF(root);
148   self->_cur = root;
149   return 0;
150 }
151 
152 template <bool reversed>
NodeIter_iternext_helper(NodeIter * self)153 PyObject* NodeIter_iternext_helper(NodeIter* self) {
154   // It should be possible to relax the ref counting here
155   // but in practice, we do not have that many _erased Nodes,
156   // so probably not worth it.
157   if constexpr (reversed) {
158     NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
159     Py_CLEAR(self->_cur);
160     self->_cur = prev;
161   } else {
162     NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
163     Py_CLEAR(self->_cur);
164     self->_cur = next;
165   }
166   while (self->_cur != self->_root) {
167     if (!self->_cur->_erased) {
168       Py_INCREF(self->_cur);
169       return (PyObject*)self->_cur;
170     }
171     if constexpr (reversed) {
172       NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
173       Py_CLEAR(self->_cur);
174       self->_cur = prev;
175     } else {
176       NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
177       Py_CLEAR(self->_cur);
178       self->_cur = next;
179     }
180   }
181   PyErr_SetNone(PyExc_StopIteration);
182   return nullptr;
183 }
184 
NodeIter_iternext(PyObject * _self)185 PyObject* NodeIter_iternext(PyObject* _self) {
186   NodeIter* self = (NodeIter*)_self;
187   if (self->_reversed) {
188     return NodeIter_iternext_helper<true>(self);
189   } else {
190     return NodeIter_iternext_helper<false>(self);
191   }
192 }
193 
NodeIter_traverse(NodeIter * self,visitproc visit,void * arg)194 static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) {
195   Py_VISIT(self->_root);
196   Py_VISIT(self->_cur);
197   return 0;
198 }
199 
NodeIter_clear(NodeIter * self)200 static int NodeIter_clear(NodeIter* self) {
201   Py_CLEAR(self->_root);
202   Py_CLEAR(self->_cur);
203   return 0;
204 }
205 
NodeIter_dealloc(PyObject * self)206 static void NodeIter_dealloc(PyObject* self) {
207   PyObject_GC_UnTrack(self);
208   (void)NodeIter_clear((NodeIter*)self);
209   Py_TYPE(self)->tp_free(self);
210 }
211 
212 static PyTypeObject NodeIterType = {
213     PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeIter", /* tp_name */
214     sizeof(NodeIter), /* tp_basicsize */
215     0, /* tp_itemsize */
216     (destructor)NodeIter_dealloc, /* tp_dealloc */
217     0, /* tp_vectorcall_offset */
218     nullptr, /* tp_getattr */
219     nullptr, /* tp_setattr */
220     nullptr, /* tp_reserved */
221     nullptr, /* tp_repr */
222     nullptr, /* tp_as_number */
223     nullptr, /* tp_as_sequence */
224     nullptr, /* tp_as_mapping */
225     nullptr, /* tp_hash  */
226     nullptr, /* tp_call */
227     nullptr, /* tp_str */
228     nullptr, /* tp_getattro */
229     nullptr, /* tp_setattro */
230     nullptr, /* tp_as_buffer */
231     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
232     nullptr, /* tp_doc */
233     (traverseproc)NodeIter_traverse, /* tp_traverse */
234     (inquiry)NodeIter_clear, /* tp_clear */
235     nullptr, /* tp_richcompare */
236     0, /* tp_weaklistoffset */
237     PyObject_SelfIter, /* tp_iter */
238     NodeIter_iternext, /* tp_iternext */
239     nullptr, /* tp_methods */
240     nullptr, /* tp_members */
241     nullptr, /* tp_getset */
242     nullptr, /* tp_base */
243     nullptr, /* tp_dict */
244     nullptr, /* tp_descr_get */
245     nullptr, /* tp_descr_set */
246     0, /* tp_dictoffset */
247     (initproc)NodeIter_init_fn, /* tp_init */
248     nullptr, /* tp_alloc */
249     NodeIter_new, /* tp_new */
250 };
251 
NodeIter_init(PyObject * module)252 bool NodeIter_init(PyObject* module) {
253   if (PyModule_AddType(module, &NodeIterType) < 0) {
254     return false;
255   }
256   return true;
257 }
258