1 #pragma once 2 3 #include <c10/util/Exception.h> 4 5 namespace torch::jit { 6 7 // Intrusive doubly linked lists with sane reverse iterators. 8 // The header file is named generic_graph_node_list.h because it is ONLY 9 // used for Graph's Node lists, and if you want to use it for other 10 // things, you will have to do some refactoring. 11 // 12 // At the moment, the templated type T must support a few operations: 13 // 14 // - It must have a field: T* next_in_graph[2] = { nullptr, nullptr }; 15 // which are used for the intrusive linked list pointers. 16 // 17 // - It must have a method 'destroy()', which removes T from the 18 // list and frees a T. 19 // 20 // In practice, we are only using it with Node and const Node. 'destroy()' 21 // needs to be renegotiated if you want to use this somewhere else. 22 // 23 // Regardless of the iteration direction, iterators always physically point 24 // to the element they logically point to, rather than 25 // the off-by-one behavior for all standard library reverse iterators like 26 // std::list. 27 28 // The list is includes two sentinel nodes, one at the beginning and one at the 29 // end with a circular link between them. It is an error to insert nodes after 30 // the end sentinel node but before the beginning node: 31 32 // Visualization showing only the next() links: 33 // HEAD -> first -> second -> ... -> last -> TAIL 34 // ^------------------------------------------ 35 36 // Visualization showing only the prev() links: 37 // HEAD <- first <- second <- ... <- last <- TAIL 38 // ------------------------------------------^ 39 40 static constexpr int kNextDirection = 0; 41 static constexpr int kPrevDirection = 1; 42 43 template <typename T> 44 struct generic_graph_node_list; 45 46 template <typename T> 47 struct generic_graph_node_list_iterator; 48 49 struct Node; 50 using graph_node_list = generic_graph_node_list<Node>; 51 using const_graph_node_list = generic_graph_node_list<const Node>; 52 using graph_node_list_iterator = generic_graph_node_list_iterator<Node>; 53 using const_graph_node_list_iterator = 54 generic_graph_node_list_iterator<const Node>; 55 56 template <typename T> 57 struct generic_graph_node_list_iterator { generic_graph_node_list_iteratorgeneric_graph_node_list_iterator58 generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {} generic_graph_node_list_iteratorgeneric_graph_node_list_iterator59 generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {} 60 generic_graph_node_list_iterator( 61 const generic_graph_node_list_iterator& rhs) = default; 62 generic_graph_node_list_iterator( 63 generic_graph_node_list_iterator&& rhs) noexcept = default; 64 generic_graph_node_list_iterator& operator=( 65 const generic_graph_node_list_iterator& rhs) = default; 66 generic_graph_node_list_iterator& operator=( 67 generic_graph_node_list_iterator&& rhs) noexcept = default; 68 T* operator*() const { 69 return cur; 70 } 71 T* operator->() const { 72 return cur; 73 } 74 generic_graph_node_list_iterator& operator++() { 75 AT_ASSERT(cur); 76 cur = cur->next_in_graph[d]; 77 return *this; 78 } 79 generic_graph_node_list_iterator operator++(int) { 80 generic_graph_node_list_iterator old = *this; 81 ++(*this); 82 return old; 83 } 84 generic_graph_node_list_iterator& operator--() { 85 AT_ASSERT(cur); 86 cur = cur->next_in_graph[reverseDir()]; 87 return *this; 88 } 89 generic_graph_node_list_iterator operator--(int) { 90 generic_graph_node_list_iterator old = *this; 91 --(*this); 92 return old; 93 } 94 95 // erase cur without invalidating this iterator 96 // named differently from destroy so that ->/. bugs do not 97 // silently cause the wrong one to be called. 98 // iterator will point to the previous entry after call destroyCurrentgeneric_graph_node_list_iterator99 void destroyCurrent() { 100 T* n = cur; 101 cur = cur->next_in_graph[reverseDir()]; 102 n->destroy(); 103 } reversegeneric_graph_node_list_iterator104 generic_graph_node_list_iterator reverse() { 105 return generic_graph_node_list_iterator(cur, reverseDir()); 106 } 107 108 private: reverseDirgeneric_graph_node_list_iterator109 int reverseDir() { 110 return d == kNextDirection ? kPrevDirection : kNextDirection; 111 } 112 T* cur; 113 int d; // direction 0 is forward 1 is reverse, see next_in_graph 114 }; 115 116 template <typename T> 117 struct generic_graph_node_list { 118 using iterator = generic_graph_node_list_iterator<T>; 119 using const_iterator = generic_graph_node_list_iterator<const T>; begingeneric_graph_node_list120 generic_graph_node_list_iterator<T> begin() { 121 return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d); 122 } begingeneric_graph_node_list123 generic_graph_node_list_iterator<const T> begin() const { 124 return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d); 125 } endgeneric_graph_node_list126 generic_graph_node_list_iterator<T> end() { 127 return generic_graph_node_list_iterator<T>(head->next_in_graph[!d], d); 128 } endgeneric_graph_node_list129 generic_graph_node_list_iterator<const T> end() const { 130 return generic_graph_node_list_iterator<const T>( 131 head->next_in_graph[!d], d); 132 } rbegingeneric_graph_node_list133 generic_graph_node_list_iterator<T> rbegin() { 134 return reverse().begin(); 135 } rbegingeneric_graph_node_list136 generic_graph_node_list_iterator<const T> rbegin() const { 137 return reverse().begin(); 138 } rendgeneric_graph_node_list139 generic_graph_node_list_iterator<T> rend() { 140 return reverse().end(); 141 } rendgeneric_graph_node_list142 generic_graph_node_list_iterator<const T> rend() const { 143 return reverse().end(); 144 } reversegeneric_graph_node_list145 generic_graph_node_list reverse() { 146 return generic_graph_node_list(head->next_in_graph[!d], !d); 147 } reversegeneric_graph_node_list148 const generic_graph_node_list reverse() const { 149 return generic_graph_node_list(head->next_in_graph[!d], !d); 150 } frontgeneric_graph_node_list151 T* front() { 152 return head->next_in_graph[d]; 153 } frontgeneric_graph_node_list154 const T* front() const { 155 return head->next_in_graph[d]; 156 } backgeneric_graph_node_list157 T* back() { 158 return head->next_in_graph[!d]; 159 } backgeneric_graph_node_list160 const T* back() const { 161 return head->next_in_graph[!d]; 162 } generic_graph_node_listgeneric_graph_node_list163 generic_graph_node_list(T* head, int d) : head(head), d(d) {} 164 165 private: 166 T* head; // both head and tail are sentinel nodes 167 // the first real node is head->next_in_graph[d] 168 // the tail sentinel is head->next_in_graph[!d] 169 int d; 170 }; 171 172 template <typename T> 173 static inline bool operator==( 174 generic_graph_node_list_iterator<T> a, 175 generic_graph_node_list_iterator<T> b) { 176 return *a == *b; 177 } 178 179 template <typename T> 180 static inline bool operator!=( 181 generic_graph_node_list_iterator<T> a, 182 generic_graph_node_list_iterator<T> b) { 183 return *a != *b; 184 } 185 186 } // namespace torch::jit 187 188 namespace std { 189 190 template <typename T> 191 struct iterator_traits<torch::jit::generic_graph_node_list_iterator<T>> { 192 using difference_type = int64_t; 193 using value_type = T*; 194 using pointer = T**; 195 using reference = T*&; 196 using iterator_category = bidirectional_iterator_tag; 197 }; 198 199 } // namespace std 200