xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/graph_node_list.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/Exception.h>
4 
5 namespace torch::jit {
6 
7 // Intrusive doubly linked lists with sane reverse iterators.
// The header file is named graph_node_list.h because it is ONLY
// used for Graph's Node lists, and if you want to use it for other
// things, you will have to do some refactoring.
11 //
12 // At the moment, the templated type T must support a few operations:
13 //
14 //  - It must have a field: T* next_in_graph[2] = { nullptr, nullptr };
15 //    which are used for the intrusive linked list pointers.
16 //
17 //  - It must have a method 'destroy()', which removes T from the
18 //    list and frees a T.
19 //
20 // In practice, we are only using it with Node and const Node.  'destroy()'
21 // needs to be renegotiated if you want to use this somewhere else.
22 //
// Regardless of the iteration direction, iterators always physically point
// to the element they logically point to, rather than exhibiting the
// off-by-one behavior of standard library reverse iterators (for example,
// those of std::list).
27 
// The list includes two sentinel nodes, one at the beginning and one at the
// end, with a circular link between them. It is an error to insert nodes after
// the end sentinel node but before the beginning node:
31 
32 // Visualization showing only the next() links:
33 //  HEAD -> first -> second  -> ... -> last -> TAIL
34 //   ^------------------------------------------
35 
36 // Visualization showing only the prev() links:
37 //  HEAD <- first <- second  <- ... <- last <- TAIL
38 //   ------------------------------------------^
39 
// Indices into T::next_in_graph[2]. Slot kNextDirection is followed when
// walking forward and slot kPrevDirection when walking backward; iterators
// and list views store one of these values as their traversal direction.
static constexpr int kNextDirection{0};
static constexpr int kPrevDirection{1};
42 
// Forward declarations of the list view and its iterator; both are defined
// later in this header.
template <class T>
struct generic_graph_node_list;

template <class T>
struct generic_graph_node_list_iterator;

// Node is only needed as an incomplete type here.
struct Node;

// Convenience names for the only instantiations in use: lists and iterators
// over mutable and const Node.
using graph_node_list = generic_graph_node_list<Node>;
using const_graph_node_list = generic_graph_node_list<const Node>;
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
using const_graph_node_list_iterator =
    generic_graph_node_list_iterator<const Node>;
55 
56 template <typename T>
57 struct generic_graph_node_list_iterator {
generic_graph_node_list_iteratorgeneric_graph_node_list_iterator58   generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
generic_graph_node_list_iteratorgeneric_graph_node_list_iterator59   generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {}
60   generic_graph_node_list_iterator(
61       const generic_graph_node_list_iterator& rhs) = default;
62   generic_graph_node_list_iterator(
63       generic_graph_node_list_iterator&& rhs) noexcept = default;
64   generic_graph_node_list_iterator& operator=(
65       const generic_graph_node_list_iterator& rhs) = default;
66   generic_graph_node_list_iterator& operator=(
67       generic_graph_node_list_iterator&& rhs) noexcept = default;
68   T* operator*() const {
69     return cur;
70   }
71   T* operator->() const {
72     return cur;
73   }
74   generic_graph_node_list_iterator& operator++() {
75     AT_ASSERT(cur);
76     cur = cur->next_in_graph[d];
77     return *this;
78   }
79   generic_graph_node_list_iterator operator++(int) {
80     generic_graph_node_list_iterator old = *this;
81     ++(*this);
82     return old;
83   }
84   generic_graph_node_list_iterator& operator--() {
85     AT_ASSERT(cur);
86     cur = cur->next_in_graph[reverseDir()];
87     return *this;
88   }
89   generic_graph_node_list_iterator operator--(int) {
90     generic_graph_node_list_iterator old = *this;
91     --(*this);
92     return old;
93   }
94 
95   // erase cur without invalidating this iterator
96   // named differently from destroy so that ->/. bugs do not
97   // silently cause the wrong one to be called.
98   // iterator will point to the previous entry after call
destroyCurrentgeneric_graph_node_list_iterator99   void destroyCurrent() {
100     T* n = cur;
101     cur = cur->next_in_graph[reverseDir()];
102     n->destroy();
103   }
reversegeneric_graph_node_list_iterator104   generic_graph_node_list_iterator reverse() {
105     return generic_graph_node_list_iterator(cur, reverseDir());
106   }
107 
108  private:
reverseDirgeneric_graph_node_list_iterator109   int reverseDir() {
110     return d == kNextDirection ? kPrevDirection : kNextDirection;
111   }
112   T* cur;
113   int d; // direction 0 is forward 1 is reverse, see next_in_graph
114 };
115 
116 template <typename T>
117 struct generic_graph_node_list {
118   using iterator = generic_graph_node_list_iterator<T>;
119   using const_iterator = generic_graph_node_list_iterator<const T>;
begingeneric_graph_node_list120   generic_graph_node_list_iterator<T> begin() {
121     return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d);
122   }
begingeneric_graph_node_list123   generic_graph_node_list_iterator<const T> begin() const {
124     return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
125   }
endgeneric_graph_node_list126   generic_graph_node_list_iterator<T> end() {
127     return generic_graph_node_list_iterator<T>(head->next_in_graph[!d], d);
128   }
endgeneric_graph_node_list129   generic_graph_node_list_iterator<const T> end() const {
130     return generic_graph_node_list_iterator<const T>(
131         head->next_in_graph[!d], d);
132   }
rbegingeneric_graph_node_list133   generic_graph_node_list_iterator<T> rbegin() {
134     return reverse().begin();
135   }
rbegingeneric_graph_node_list136   generic_graph_node_list_iterator<const T> rbegin() const {
137     return reverse().begin();
138   }
rendgeneric_graph_node_list139   generic_graph_node_list_iterator<T> rend() {
140     return reverse().end();
141   }
rendgeneric_graph_node_list142   generic_graph_node_list_iterator<const T> rend() const {
143     return reverse().end();
144   }
reversegeneric_graph_node_list145   generic_graph_node_list reverse() {
146     return generic_graph_node_list(head->next_in_graph[!d], !d);
147   }
reversegeneric_graph_node_list148   const generic_graph_node_list reverse() const {
149     return generic_graph_node_list(head->next_in_graph[!d], !d);
150   }
frontgeneric_graph_node_list151   T* front() {
152     return head->next_in_graph[d];
153   }
frontgeneric_graph_node_list154   const T* front() const {
155     return head->next_in_graph[d];
156   }
backgeneric_graph_node_list157   T* back() {
158     return head->next_in_graph[!d];
159   }
backgeneric_graph_node_list160   const T* back() const {
161     return head->next_in_graph[!d];
162   }
generic_graph_node_listgeneric_graph_node_list163   generic_graph_node_list(T* head, int d) : head(head), d(d) {}
164 
165  private:
166   T* head; // both head and tail are sentinel nodes
167            // the first real node is head->next_in_graph[d]
168            // the tail sentinel is head->next_in_graph[!d]
169   int d;
170 };
171 
172 template <typename T>
173 static inline bool operator==(
174     generic_graph_node_list_iterator<T> a,
175     generic_graph_node_list_iterator<T> b) {
176   return *a == *b;
177 }
178 
179 template <typename T>
180 static inline bool operator!=(
181     generic_graph_node_list_iterator<T> a,
182     generic_graph_node_list_iterator<T> b) {
183   return *a != *b;
184 }
185 
186 } // namespace torch::jit
187 
188 namespace std {
189 
190 template <typename T>
191 struct iterator_traits<torch::jit::generic_graph_node_list_iterator<T>> {
192   using difference_type = int64_t;
193   using value_type = T*;
194   using pointer = T**;
195   using reference = T*&;
196   using iterator_category = bidirectional_iterator_tag;
197 };
198 
199 } // namespace std
200