#include #include #include //////////////////////////////// // NodeBase /////////////////////////////// struct NodeBase { PyObject_HEAD bool _erased; NodeBase* _prev; NodeBase* _next; }; static PyObject* NodeBase_new( PyTypeObject* type, PyObject* args, PyObject* kwds) { PyObject* self = type->tp_alloc(type, 0); if (!self) return nullptr; return self; } static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) { self->_erased = false; Py_INCREF(self); self->_prev = self; Py_INCREF(self); self->_next = self; return 0; } // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static struct PyMemberDef NodeBase_members[] = { {"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr}, {"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr}, {"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr}, {nullptr} /* Sentinel */ }; static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) { Py_VISIT(self->_prev); Py_VISIT(self->_next); return 0; } static int NodeBase_clear(NodeBase* self) { Py_CLEAR(self->_prev); Py_CLEAR(self->_next); return 0; } static void NodeBase_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); (void)NodeBase_clear((NodeBase*)self); Py_TYPE(self)->tp_free(self); } static PyTypeObject NodeBaseType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */ sizeof(NodeBase), /* tp_basicsize */ 0, /* tp_itemsize */ (destructor)NodeBase_dealloc, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ nullptr, /* tp_reserved */ nullptr, /* tp_repr */ nullptr, /* tp_as_number */ nullptr, /* tp_as_sequence */ nullptr, /* tp_as_mapping */ nullptr, /* tp_hash */ nullptr, /* tp_call */ nullptr, /* tp_str */ nullptr, /* tp_getattro */ nullptr, /* tp_setattro */ nullptr, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ nullptr, /* tp_doc */ (traverseproc)NodeBase_traverse, /* tp_traverse */ (inquiry)NodeBase_clear, /* tp_clear */ nullptr, /* tp_richcompare */ 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ nullptr, /* tp_methods */ NodeBase_members, /* tp_members */ nullptr, /* tp_getset */ nullptr, /* tp_base */ nullptr, /* tp_dict */ nullptr, /* tp_descr_get */ nullptr, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc)NodeBase_init_fn, /* tp_init */ nullptr, /* tp_alloc */ NodeBase_new, /* tp_new */ }; bool NodeBase_init(PyObject* module) { if (PyModule_AddType(module, &NodeBaseType) < 0) { return false; } return true; } //////////////////////////////// // NodeIter //////////////////////////////// struct NodeIter { PyObject_HEAD bool _reversed; NodeBase* _root; NodeBase* _cur; }; static PyObject* NodeIter_new( PyTypeObject* type, PyObject* args, PyObject* kwds) { PyObject* self = type->tp_alloc(type, 0); if (!self) return nullptr; return self; } static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) { NodeBase* root = nullptr; bool reversed = false; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) constexpr const char* keywords[] = {"root", "reversed", nullptr}; if (!PyArg_ParseTupleAndKeywords( args, kwargs, "Ob|", // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const_cast(keywords), &root, &reversed)) { return -1; } self->_reversed = reversed; Py_INCREF(root); self->_root = root; Py_INCREF(root); self->_cur = root; return 0; } template PyObject* NodeIter_iternext_helper(NodeIter* self) { // It should be possible to relax the ref counting here // but in practice, we do not have that many _erased Nodes, // so probably not worth it. if constexpr (reversed) { NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev); Py_CLEAR(self->_cur); self->_cur = prev; } else { NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next); Py_CLEAR(self->_cur); self->_cur = next; } while (self->_cur != self->_root) { if (!self->_cur->_erased) { Py_INCREF(self->_cur); return (PyObject*)self->_cur; } if constexpr (reversed) { NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev); Py_CLEAR(self->_cur); self->_cur = prev; } else { NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next); Py_CLEAR(self->_cur); self->_cur = next; } } PyErr_SetNone(PyExc_StopIteration); return nullptr; } PyObject* NodeIter_iternext(PyObject* _self) { NodeIter* self = (NodeIter*)_self; if (self->_reversed) { return NodeIter_iternext_helper(self); } else { return NodeIter_iternext_helper(self); } } static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) { Py_VISIT(self->_root); Py_VISIT(self->_cur); return 0; } static int NodeIter_clear(NodeIter* self) { Py_CLEAR(self->_root); Py_CLEAR(self->_cur); return 0; } static void NodeIter_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); (void)NodeIter_clear((NodeIter*)self); Py_TYPE(self)->tp_free(self); } static PyTypeObject NodeIterType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeIter", /* tp_name */ sizeof(NodeIter), /* tp_basicsize */ 0, /* tp_itemsize */ (destructor)NodeIter_dealloc, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ nullptr, /* tp_reserved */ nullptr, /* tp_repr */ nullptr, /* tp_as_number */ nullptr, /* tp_as_sequence */ nullptr, /* tp_as_mapping */ nullptr, /* tp_hash */ nullptr, /* tp_call */ nullptr, /* tp_str */ nullptr, /* tp_getattro */ nullptr, /* tp_setattro */ nullptr, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */ nullptr, /* tp_doc */ (traverseproc)NodeIter_traverse, /* tp_traverse */ (inquiry)NodeIter_clear, /* tp_clear */ nullptr, /* tp_richcompare */ 0, /* tp_weaklistoffset */ PyObject_SelfIter, /* tp_iter */ NodeIter_iternext, /* tp_iternext */ nullptr, /* tp_methods */ nullptr, /* tp_members */ nullptr, /* tp_getset */ nullptr, /* tp_base */ nullptr, /* tp_dict */ nullptr, /* tp_descr_get */ nullptr, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc)NodeIter_init_fn, /* tp_init */ nullptr, /* tp_alloc */ NodeIter_new, /* tp_new */ }; bool NodeIter_init(PyObject* module) { if (PyModule_AddType(module, &NodeIterType) < 0) { return false; } return true; }