xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/graphcycles/graphcycles.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // GraphCycles provides incremental cycle detection on a dynamic
17 // graph using the following algorithm:
18 //
19 // A dynamic topological sort algorithm for directed acyclic graphs
20 // David J. Pearce, Paul H. J. Kelly
// Journal of Experimental Algorithmics (JEA)
22 // Volume 11, 2006, Article No. 1.7
23 //
24 // Brief summary of the algorithm:
25 //
26 // (1) Maintain a rank for each node that is consistent
27 //     with the topological sort of the graph. I.e., path from x to y
28 //     implies rank[x] < rank[y].
29 // (2) When a new edge (x->y) is inserted, do nothing if rank[x] < rank[y].
30 // (3) Otherwise: adjust ranks in the neighborhood of x and y.
31 
32 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
33 
34 #include <algorithm>
35 
36 #include "absl/algorithm/container.h"
37 #include "absl/container/flat_hash_set.h"
38 #include "absl/container/inlined_vector.h"
39 #include "absl/strings/str_cat.h"
40 #include "tensorflow/compiler/xla/service/graphcycles/ordered_set.h"
41 #include "tensorflow/core/platform/logging.h"
42 
43 namespace tensorflow {
44 
45 namespace {
46 
47 using NodeSet = absl::flat_hash_set<int32_t>;
48 using OrderedNodeSet = OrderedSet<int32_t>;
49 
50 template <typename T>
51 struct VecStruct {
52   typedef absl::InlinedVector<T, 4> type;
53 };
54 template <typename T>
55 using Vec = typename VecStruct<T>::type;
56 
57 struct Node {
58   int32_t rank;        // rank number assigned by Pearce-Kelly algorithm
59   bool visited;        // Temporary marker used by depth-first-search
60   void* data;          // User-supplied data
61   OrderedNodeSet in;   // List of immediate predecessor nodes in graph
62   OrderedNodeSet out;  // List of immediate successor nodes in graph
63 };
64 
65 }  // namespace
66 
67 struct GraphCycles::Rep {
68   Vec<Node*> nodes_;
69   Vec<int32_t> free_nodes_;  // Indices for unused entries in nodes_
70 
71   // Temporary state.
72   Vec<int32_t> deltaf_;  // Results of forward DFS
73   Vec<int32_t> deltab_;  // Results of backward DFS
74   Vec<int32_t> list_;    // All nodes to reprocess
75   Vec<int32_t> merged_;  // Rank values to assign to list_ entries
76   Vec<int32_t>
77       stack_;  // Emulates recursion stack when doing depth first search
78 };
79 
GraphCycles()80 GraphCycles::GraphCycles() : rep_(new Rep) {}
81 
~GraphCycles()82 GraphCycles::~GraphCycles() {
83   for (Vec<Node*>::size_type i = 0; i < rep_->nodes_.size(); i++) {
84     delete rep_->nodes_[i];
85   }
86   delete rep_;
87 }
88 
CheckInvariants() const89 bool GraphCycles::CheckInvariants() const {
90   Rep* r = rep_;
91   NodeSet ranks;  // Set of ranks seen so far.
92   for (Vec<Node*>::size_type x = 0; x < r->nodes_.size(); x++) {
93     Node* nx = r->nodes_[x];
94     if (nx->visited) {
95       LOG(FATAL) << "Did not clear visited marker on node " << x;
96     }
97     if (!ranks.insert(nx->rank).second) {
98       LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank;
99     }
100     for (int32_t y : nx->out.GetSequence()) {
101       Node* ny = r->nodes_[y];
102       if (nx->rank >= ny->rank) {
103         LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment "
104                    << nx->rank << "->" << ny->rank;
105       }
106     }
107   }
108   return true;
109 }
110 
NewNode()111 int32_t GraphCycles::NewNode() {
112   if (rep_->free_nodes_.empty()) {
113     Node* n = new Node;
114     n->visited = false;
115     n->data = nullptr;
116     n->rank = rep_->nodes_.size();
117     rep_->nodes_.push_back(n);
118     return n->rank;
119   } else {
120     // Preserve preceding rank since the set of ranks in use must be
121     // a permutation of [0,rep_->nodes_.size()-1].
122     int32_t r = rep_->free_nodes_.back();
123     rep_->nodes_[r]->data = nullptr;
124     rep_->free_nodes_.pop_back();
125     return r;
126   }
127 }
128 
RemoveNode(int32_t node)129 void GraphCycles::RemoveNode(int32_t node) {
130   Node* x = rep_->nodes_[node];
131   for (int32_t y : x->out.GetSequence()) {
132     rep_->nodes_[y]->in.Erase(node);
133   }
134   for (int32_t y : x->in.GetSequence()) {
135     rep_->nodes_[y]->out.Erase(node);
136   }
137   x->in.Clear();
138   x->out.Clear();
139   rep_->free_nodes_.push_back(node);
140 }
141 
GetNodeData(int32_t node) const142 void* GraphCycles::GetNodeData(int32_t node) const {
143   return rep_->nodes_[node]->data;
144 }
145 
SetNodeData(int32_t node,void * data)146 void GraphCycles::SetNodeData(int32_t node, void* data) {
147   rep_->nodes_[node]->data = data;
148 }
149 
HasEdge(int32_t x,int32_t y) const150 bool GraphCycles::HasEdge(int32_t x, int32_t y) const {
151   return rep_->nodes_[x]->out.Contains(y);
152 }
153 
RemoveEdge(int32_t x,int32_t y)154 void GraphCycles::RemoveEdge(int32_t x, int32_t y) {
155   rep_->nodes_[x]->out.Erase(y);
156   rep_->nodes_[y]->in.Erase(x);
157   // No need to update the rank assignment since a previous valid
158   // rank assignment remains valid after an edge deletion.
159 }
160 
161 static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound);
162 static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound);
163 static void Reorder(GraphCycles::Rep* r);
164 static void Sort(const Vec<Node*>&, Vec<int32_t>* delta);
165 static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
166                        Vec<int32_t>* dst);
167 static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes);
168 
InsertEdge(int32_t x,int32_t y)169 bool GraphCycles::InsertEdge(int32_t x, int32_t y) {
170   if (x == y) return false;
171   Rep* r = rep_;
172   Node* nx = r->nodes_[x];
173   if (!nx->out.Insert(y)) {
174     // Edge already exists.
175     return true;
176   }
177 
178   Node* ny = r->nodes_[y];
179   ny->in.Insert(x);
180 
181   if (nx->rank <= ny->rank) {
182     // New edge is consistent with existing rank assignment.
183     return true;
184   }
185 
186   // Current rank assignments are incompatible with the new edge.  Recompute.
187   // We only need to consider nodes that fall in the range [ny->rank,nx->rank].
188   if (!ForwardDFS(r, y, nx->rank)) {
189     // Found a cycle.  Undo the insertion and tell caller.
190     nx->out.Erase(y);
191     ny->in.Erase(x);
192     // Since we do not call Reorder() on this path, clear any visited
193     // markers left by ForwardDFS.
194     ClearVisitedBits(r, r->deltaf_);
195     return false;
196   }
197   BackwardDFS(r, x, ny->rank);
198   Reorder(r);
199   return true;
200 }
201 
ForwardDFS(GraphCycles::Rep * r,int32_t n,int32_t upper_bound)202 static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) {
203   // Avoid recursion since stack space might be limited.
204   // We instead keep a stack of nodes to visit.
205   r->deltaf_.clear();
206   r->stack_.clear();
207   r->stack_.push_back(n);
208   while (!r->stack_.empty()) {
209     n = r->stack_.back();
210     r->stack_.pop_back();
211     Node* nn = r->nodes_[n];
212     if (nn->visited) continue;
213 
214     nn->visited = true;
215     r->deltaf_.push_back(n);
216 
217     for (auto w : nn->out.GetSequence()) {
218       Node* nw = r->nodes_[w];
219       if (nw->rank == upper_bound) {
220         return false;  // Cycle
221       }
222       if (!nw->visited && nw->rank < upper_bound) {
223         r->stack_.push_back(w);
224       }
225     }
226   }
227   return true;
228 }
229 
BackwardDFS(GraphCycles::Rep * r,int32_t n,int32_t lower_bound)230 static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) {
231   r->deltab_.clear();
232   r->stack_.clear();
233   r->stack_.push_back(n);
234   while (!r->stack_.empty()) {
235     n = r->stack_.back();
236     r->stack_.pop_back();
237     Node* nn = r->nodes_[n];
238     if (nn->visited) continue;
239 
240     nn->visited = true;
241     r->deltab_.push_back(n);
242 
243     for (auto w : nn->in.GetSequence()) {
244       Node* nw = r->nodes_[w];
245       if (!nw->visited && lower_bound < nw->rank) {
246         r->stack_.push_back(w);
247       }
248     }
249   }
250 }
251 
Reorder(GraphCycles::Rep * r)252 static void Reorder(GraphCycles::Rep* r) {
253   Sort(r->nodes_, &r->deltab_);
254   Sort(r->nodes_, &r->deltaf_);
255 
256   // Adds contents of delta lists to list_ (backwards deltas first).
257   r->list_.clear();
258   MoveToList(r, &r->deltab_, &r->list_);
259   MoveToList(r, &r->deltaf_, &r->list_);
260 
261   // Produce sorted list of all ranks that will be reassigned.
262   r->merged_.resize(r->deltab_.size() + r->deltaf_.size());
263   std::merge(r->deltab_.begin(), r->deltab_.end(), r->deltaf_.begin(),
264              r->deltaf_.end(), r->merged_.begin());
265 
266   // Assign the ranks in order to the collected list.
267   for (Vec<int32_t>::size_type i = 0; i < r->list_.size(); i++) {
268     r->nodes_[r->list_[i]]->rank = r->merged_[i];
269   }
270 }
271 
Sort(const Vec<Node * > & nodes,Vec<int32_t> * delta)272 static void Sort(const Vec<Node*>& nodes, Vec<int32_t>* delta) {
273   struct ByRank {
274     const Vec<Node*>* nodes;
275     bool operator()(int32_t a, int32_t b) const {
276       return (*nodes)[a]->rank < (*nodes)[b]->rank;
277     }
278   };
279   ByRank cmp;
280   cmp.nodes = &nodes;
281   std::sort(delta->begin(), delta->end(), cmp);
282 }
283 
MoveToList(GraphCycles::Rep * r,Vec<int32_t> * src,Vec<int32_t> * dst)284 static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
285                        Vec<int32_t>* dst) {
286   for (Vec<int32_t>::size_type i = 0; i < src->size(); i++) {
287     int32_t w = (*src)[i];
288     (*src)[i] = r->nodes_[w]->rank;  // Replace src entry with its rank
289     r->nodes_[w]->visited = false;   // Prepare for future DFS calls
290     dst->push_back(w);
291   }
292 }
293 
ClearVisitedBits(GraphCycles::Rep * r,const Vec<int32_t> & nodes)294 static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes) {
295   for (Vec<int32_t>::size_type i = 0; i < nodes.size(); i++) {
296     r->nodes_[nodes[i]]->visited = false;
297   }
298 }
299 
FindPath(int32_t x,int32_t y,int max_path_len,int32_t path[]) const300 int GraphCycles::FindPath(int32_t x, int32_t y, int max_path_len,
301                           int32_t path[]) const {
302   // Forward depth first search starting at x until we hit y.
303   // As we descend into a node, we push it onto the path.
304   // As we leave a node, we remove it from the path.
305   int path_len = 0;
306 
307   Rep* r = rep_;
308   NodeSet seen;
309   r->stack_.clear();
310   r->stack_.push_back(x);
311   while (!r->stack_.empty()) {
312     int32_t n = r->stack_.back();
313     r->stack_.pop_back();
314     if (n < 0) {
315       // Marker to indicate that we are leaving a node
316       path_len--;
317       continue;
318     }
319 
320     if (path_len < max_path_len) {
321       path[path_len] = n;
322     }
323     path_len++;
324     r->stack_.push_back(-1);  // Will remove tentative path entry
325 
326     if (n == y) {
327       return path_len;
328     }
329 
330     for (auto w : r->nodes_[n]->out.GetSequence()) {
331       if (seen.insert(w).second) {
332         r->stack_.push_back(w);
333       }
334     }
335   }
336 
337   return 0;
338 }
339 
IsReachable(int32_t x,int32_t y) const340 bool GraphCycles::IsReachable(int32_t x, int32_t y) const {
341   return FindPath(x, y, 0, nullptr) > 0;
342 }
343 
IsReachableNonConst(int32_t x,int32_t y)344 bool GraphCycles::IsReachableNonConst(int32_t x, int32_t y) {
345   if (x == y) return true;
346   Rep* r = rep_;
347   Node* nx = r->nodes_[x];
348   Node* ny = r->nodes_[y];
349 
350   if (nx->rank >= ny->rank) {
351     // x cannot reach y since it is after it in the topological ordering
352     return false;
353   }
354 
355   // See if x can reach y using a DFS search that is limited to y's rank
356   bool reachable = !ForwardDFS(r, x, ny->rank);
357 
358   // Clear any visited markers left by ForwardDFS.
359   ClearVisitedBits(r, r->deltaf_);
360   return reachable;
361 }
362 
CanContractEdge(int32_t a,int32_t b)363 bool GraphCycles::CanContractEdge(int32_t a, int32_t b) {
364   CHECK(HasEdge(a, b)) << "No edge exists from " << a << " to " << b;
365   RemoveEdge(a, b);
366   bool reachable = IsReachableNonConst(a, b);
367   // Restore the graph to its original state.
368   InsertEdge(a, b);
369   // If reachable, then contracting edge will cause cycle.
370   return !reachable;
371 }
372 
ContractEdge(int32_t a,int32_t b)373 std::optional<int32_t> GraphCycles::ContractEdge(int32_t a, int32_t b) {
374   CHECK(HasEdge(a, b));
375   RemoveEdge(a, b);
376 
377   if (IsReachableNonConst(a, b)) {
378     // Restore the graph to its original state.
379     InsertEdge(a, b);
380     return std::nullopt;
381   }
382 
383   if (rep_->nodes_[b]->in.Size() + rep_->nodes_[b]->out.Size() >
384       rep_->nodes_[a]->in.Size() + rep_->nodes_[a]->out.Size()) {
385     // Swap "a" and "b" to minimize copying.
386     std::swap(a, b);
387   }
388 
389   Node* nb = rep_->nodes_[b];
390   OrderedNodeSet out = std::move(nb->out);
391   OrderedNodeSet in = std::move(nb->in);
392   for (int32_t y : out.GetSequence()) {
393     rep_->nodes_[y]->in.Erase(b);
394   }
395   for (int32_t y : in.GetSequence()) {
396     rep_->nodes_[y]->out.Erase(b);
397   }
398   rep_->free_nodes_.push_back(b);
399 
400   rep_->nodes_[a]->out.Reserve(rep_->nodes_[a]->out.Size() + out.Size());
401   for (int32_t y : out.GetSequence()) {
402     InsertEdge(a, y);
403   }
404 
405   rep_->nodes_[a]->in.Reserve(rep_->nodes_[a]->in.Size() + in.Size());
406   for (int32_t y : in.GetSequence()) {
407     InsertEdge(y, a);
408   }
409 
410   // Note, if the swap happened it might be what originally was called "b".
411   return a;
412 }
413 
Successors(int32_t node) const414 absl::Span<const int32_t> GraphCycles::Successors(int32_t node) const {
415   return rep_->nodes_[node]->out.GetSequence();
416 }
417 
Predecessors(int32_t node) const418 absl::Span<const int32_t> GraphCycles::Predecessors(int32_t node) const {
419   return rep_->nodes_[node]->in.GetSequence();
420 }
421 
SuccessorsCopy(int32_t node) const422 std::vector<int32_t> GraphCycles::SuccessorsCopy(int32_t node) const {
423   absl::Span<const int32_t> successors = Successors(node);
424   return std::vector<int32_t>(successors.begin(), successors.end());
425 }
426 
PredecessorsCopy(int32_t node) const427 std::vector<int32_t> GraphCycles::PredecessorsCopy(int32_t node) const {
428   absl::Span<const int32_t> predecessors = Predecessors(node);
429   return std::vector<int32_t>(predecessors.begin(), predecessors.end());
430 }
431 
432 namespace {
SortInPostOrder(absl::Span<Node * const> nodes,std::vector<int32_t> * to_sort)433 void SortInPostOrder(absl::Span<Node* const> nodes,
434                      std::vector<int32_t>* to_sort) {
435   absl::c_sort(*to_sort, [&](int32_t a, int32_t b) {
436     DCHECK(a == b || nodes[a]->rank != nodes[b]->rank);
437     return nodes[a]->rank > nodes[b]->rank;
438   });
439 }
440 }  // namespace
441 
AllNodesInPostOrder() const442 std::vector<int32_t> GraphCycles::AllNodesInPostOrder() const {
443   absl::flat_hash_set<int32_t> free_nodes_set;
444   absl::c_copy(rep_->free_nodes_,
445                std::inserter(free_nodes_set, free_nodes_set.begin()));
446 
447   std::vector<int32_t> all_nodes;
448   all_nodes.reserve(rep_->nodes_.size() - free_nodes_set.size());
449   for (int64_t i = 0, e = rep_->nodes_.size(); i < e; i++) {
450     if (!free_nodes_set.contains(i)) {
451       all_nodes.push_back(i);
452     }
453   }
454 
455   SortInPostOrder(rep_->nodes_, &all_nodes);
456   return all_nodes;
457 }
458 
DebugString() const459 std::string GraphCycles::DebugString() const {
460   absl::flat_hash_set<int32_t> free_nodes_set;
461   for (int32_t free_node : rep_->free_nodes_) {
462     free_nodes_set.insert(free_node);
463   }
464 
465   std::string result = "digraph {\n";
466   for (int i = 0, end = rep_->nodes_.size(); i < end; i++) {
467     if (free_nodes_set.contains(i)) {
468       continue;
469     }
470 
471     for (int32_t succ : rep_->nodes_[i]->out.GetSequence()) {
472       absl::StrAppend(&result, "  \"", i, "\" -> \"", succ, "\"\n");
473     }
474   }
475 
476   absl::StrAppend(&result, "}\n");
477 
478   return result;
479 }
480 
481 }  // namespace tensorflow
482