1 #include <torch/csrc/fx/node.h>
2
3 #include <structmember.h>
4 #include <torch/csrc/utils/pythoncapi_compat.h>
5
6 ////////////////////////////////
7 // NodeBase
8 ///////////////////////////////
9
10 struct NodeBase {
11 PyObject_HEAD bool _erased;
12 NodeBase* _prev;
13 NodeBase* _next;
14 };
15
NodeBase_new(PyTypeObject * type,PyObject * args,PyObject * kwds)16 static PyObject* NodeBase_new(
17 PyTypeObject* type,
18 PyObject* args,
19 PyObject* kwds) {
20 PyObject* self = type->tp_alloc(type, 0);
21 if (!self)
22 return nullptr;
23 return self;
24 }
25
NodeBase_init_fn(NodeBase * self,PyObject * args,PyObject * kwds)26 static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
27 self->_erased = false;
28 Py_INCREF(self);
29 self->_prev = self;
30 Py_INCREF(self);
31 self->_next = self;
32 return 0;
33 }
34
35 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
36 static struct PyMemberDef NodeBase_members[] = {
37 {"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
38 {"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
39 {"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
40 {nullptr} /* Sentinel */
41 };
42
NodeBase_traverse(NodeBase * self,visitproc visit,void * arg)43 static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
44 Py_VISIT(self->_prev);
45 Py_VISIT(self->_next);
46 return 0;
47 }
48
NodeBase_clear(NodeBase * self)49 static int NodeBase_clear(NodeBase* self) {
50 Py_CLEAR(self->_prev);
51 Py_CLEAR(self->_next);
52 return 0;
53 }
54
NodeBase_dealloc(PyObject * self)55 static void NodeBase_dealloc(PyObject* self) {
56 PyObject_GC_UnTrack(self);
57 (void)NodeBase_clear((NodeBase*)self);
58 Py_TYPE(self)->tp_free(self);
59 }
60
61 static PyTypeObject NodeBaseType = {
62 PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */
63 sizeof(NodeBase), /* tp_basicsize */
64 0, /* tp_itemsize */
65 (destructor)NodeBase_dealloc, /* tp_dealloc */
66 0, /* tp_vectorcall_offset */
67 nullptr, /* tp_getattr */
68 nullptr, /* tp_setattr */
69 nullptr, /* tp_reserved */
70 nullptr, /* tp_repr */
71 nullptr, /* tp_as_number */
72 nullptr, /* tp_as_sequence */
73 nullptr, /* tp_as_mapping */
74 nullptr, /* tp_hash */
75 nullptr, /* tp_call */
76 nullptr, /* tp_str */
77 nullptr, /* tp_getattro */
78 nullptr, /* tp_setattro */
79 nullptr, /* tp_as_buffer */
80 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
81 Py_TPFLAGS_HAVE_GC, /* tp_flags */
82 nullptr, /* tp_doc */
83 (traverseproc)NodeBase_traverse, /* tp_traverse */
84 (inquiry)NodeBase_clear, /* tp_clear */
85 nullptr, /* tp_richcompare */
86 0, /* tp_weaklistoffset */
87 nullptr, /* tp_iter */
88 nullptr, /* tp_iternext */
89 nullptr, /* tp_methods */
90 NodeBase_members, /* tp_members */
91 nullptr, /* tp_getset */
92 nullptr, /* tp_base */
93 nullptr, /* tp_dict */
94 nullptr, /* tp_descr_get */
95 nullptr, /* tp_descr_set */
96 0, /* tp_dictoffset */
97 (initproc)NodeBase_init_fn, /* tp_init */
98 nullptr, /* tp_alloc */
99 NodeBase_new, /* tp_new */
100 };
101
NodeBase_init(PyObject * module)102 bool NodeBase_init(PyObject* module) {
103 if (PyModule_AddType(module, &NodeBaseType) < 0) {
104 return false;
105 }
106 return true;
107 }
108
109 ////////////////////////////////
110 // NodeIter
111 ////////////////////////////////
112
113 struct NodeIter {
114 PyObject_HEAD bool _reversed;
115 NodeBase* _root;
116 NodeBase* _cur;
117 };
118
NodeIter_new(PyTypeObject * type,PyObject * args,PyObject * kwds)119 static PyObject* NodeIter_new(
120 PyTypeObject* type,
121 PyObject* args,
122 PyObject* kwds) {
123 PyObject* self = type->tp_alloc(type, 0);
124 if (!self)
125 return nullptr;
126 return self;
127 }
128
NodeIter_init_fn(NodeIter * self,PyObject * args,PyObject * kwargs)129 static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) {
130 NodeBase* root = nullptr;
131 bool reversed = false;
132 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
133 constexpr const char* keywords[] = {"root", "reversed", nullptr};
134 if (!PyArg_ParseTupleAndKeywords(
135 args,
136 kwargs,
137 "Ob|",
138 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
139 const_cast<char**>(keywords),
140 &root,
141 &reversed)) {
142 return -1;
143 }
144 self->_reversed = reversed;
145 Py_INCREF(root);
146 self->_root = root;
147 Py_INCREF(root);
148 self->_cur = root;
149 return 0;
150 }
151
152 template <bool reversed>
NodeIter_iternext_helper(NodeIter * self)153 PyObject* NodeIter_iternext_helper(NodeIter* self) {
154 // It should be possible to relax the ref counting here
155 // but in practice, we do not have that many _erased Nodes,
156 // so probably not worth it.
157 if constexpr (reversed) {
158 NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
159 Py_CLEAR(self->_cur);
160 self->_cur = prev;
161 } else {
162 NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
163 Py_CLEAR(self->_cur);
164 self->_cur = next;
165 }
166 while (self->_cur != self->_root) {
167 if (!self->_cur->_erased) {
168 Py_INCREF(self->_cur);
169 return (PyObject*)self->_cur;
170 }
171 if constexpr (reversed) {
172 NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
173 Py_CLEAR(self->_cur);
174 self->_cur = prev;
175 } else {
176 NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
177 Py_CLEAR(self->_cur);
178 self->_cur = next;
179 }
180 }
181 PyErr_SetNone(PyExc_StopIteration);
182 return nullptr;
183 }
184
NodeIter_iternext(PyObject * _self)185 PyObject* NodeIter_iternext(PyObject* _self) {
186 NodeIter* self = (NodeIter*)_self;
187 if (self->_reversed) {
188 return NodeIter_iternext_helper<true>(self);
189 } else {
190 return NodeIter_iternext_helper<false>(self);
191 }
192 }
193
NodeIter_traverse(NodeIter * self,visitproc visit,void * arg)194 static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) {
195 Py_VISIT(self->_root);
196 Py_VISIT(self->_cur);
197 return 0;
198 }
199
NodeIter_clear(NodeIter * self)200 static int NodeIter_clear(NodeIter* self) {
201 Py_CLEAR(self->_root);
202 Py_CLEAR(self->_cur);
203 return 0;
204 }
205
NodeIter_dealloc(PyObject * self)206 static void NodeIter_dealloc(PyObject* self) {
207 PyObject_GC_UnTrack(self);
208 (void)NodeIter_clear((NodeIter*)self);
209 Py_TYPE(self)->tp_free(self);
210 }
211
212 static PyTypeObject NodeIterType = {
213 PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeIter", /* tp_name */
214 sizeof(NodeIter), /* tp_basicsize */
215 0, /* tp_itemsize */
216 (destructor)NodeIter_dealloc, /* tp_dealloc */
217 0, /* tp_vectorcall_offset */
218 nullptr, /* tp_getattr */
219 nullptr, /* tp_setattr */
220 nullptr, /* tp_reserved */
221 nullptr, /* tp_repr */
222 nullptr, /* tp_as_number */
223 nullptr, /* tp_as_sequence */
224 nullptr, /* tp_as_mapping */
225 nullptr, /* tp_hash */
226 nullptr, /* tp_call */
227 nullptr, /* tp_str */
228 nullptr, /* tp_getattro */
229 nullptr, /* tp_setattro */
230 nullptr, /* tp_as_buffer */
231 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
232 nullptr, /* tp_doc */
233 (traverseproc)NodeIter_traverse, /* tp_traverse */
234 (inquiry)NodeIter_clear, /* tp_clear */
235 nullptr, /* tp_richcompare */
236 0, /* tp_weaklistoffset */
237 PyObject_SelfIter, /* tp_iter */
238 NodeIter_iternext, /* tp_iternext */
239 nullptr, /* tp_methods */
240 nullptr, /* tp_members */
241 nullptr, /* tp_getset */
242 nullptr, /* tp_base */
243 nullptr, /* tp_dict */
244 nullptr, /* tp_descr_get */
245 nullptr, /* tp_descr_set */
246 0, /* tp_dictoffset */
247 (initproc)NodeIter_init_fn, /* tp_init */
248 nullptr, /* tp_alloc */
249 NodeIter_new, /* tp_new */
250 };
251
NodeIter_init(PyObject * module)252 bool NodeIter_init(PyObject* module) {
253 if (PyModule_AddType(module, &NodeIterType) < 0) {
254 return false;
255 }
256 return true;
257 }
258