xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <memory>
18 #include <queue>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
30 #include "mlir/IR/Value.h"  // from @llvm-project
31 #include "mlir/Pass/Pass.h"  // from @llvm-project
32 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
33 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
40 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
41 
42 namespace mlir {
43 namespace TFDevice {
44 
45 namespace {
46 
47 // For the 2d vector type aliases defined below, the first dimension represents
48 // the class of the IfRegion group and the second dimension represents the
49 // segments of the IfRegion group.
50 // For example, if we want to merge the following six IfRegions
51 // which share the same if_cond (regionA)
52 // `````````````
53 // IfRegionA(1)
54 // IfRegionA(2)
55 // IfRegionA(3)
56 // IfRegionA(4)
57 // IfRegionA(5)
58 // IfRegionA(6)
59 // ``````````````
60 // After the analysis, we consider IfRegionA(1), IfRegionA(2) and IfRegionA(3)
61 // can be merged, IfRegionA(4) is standalone, IfRegionA(5) and IfRegionA(6)
62 // can be merged. Then the defined 2D vector is
63 // [[IfRegionA(1), IfRegionA(2), IfRegionA(3)],
64 //  [IfRegionA(4)],
65 //  [IfRegionA(5), IfRegionA(6)]]
66 using RegionVec2D = llvm::SmallVector<llvm::SmallVector<TF::IfRegionOp, 8>, 8>;
67 using OperationVec2D = llvm::SmallVector<llvm::SmallVector<Operation*, 8>, 8>;
68 using MapToRegionVec2D = llvm::SmallDenseMap<Value, RegionVec2D>;
69 using MapToOperationVec2D = llvm::SmallDenseMap<Value, OperationVec2D>;
70 using IfOpIterConst =
71     llvm::SmallVectorTemplateCommon<mlir::TF::IfRegionOp>::const_iterator;
72 
73 struct MergeControlFlowPass
74     : public TF::MergeControlFlowPassBase<MergeControlFlowPass> {
75   void runOnOperation() override;
76 };
77 
78 // Gets the IfRegion op and all of ops in the then and else branches.
GetAllOpsFromIf(TF::IfRegionOp if_op)79 llvm::SmallSetVector<Operation*, 4> GetAllOpsFromIf(TF::IfRegionOp if_op) {
80   llvm::SmallSetVector<Operation*, 4> all_ops;
81   all_ops.insert(if_op);
82   for (Operation& op : if_op.then_branch().front()) {
83     all_ops.insert(&op);
84   }
85   for (Operation& op : if_op.else_branch().front()) {
86     all_ops.insert(&op);
87   }
88   return all_ops;
89 }
90 
91 // Returns whether it is safe to merge `second_if` IfRegion into `first_if`
92 // IfRegion. `second if` must come after `first_if`.
93 // Note that `downstream_if_ops` means the ops in IfRegions except`first_if`.
SafeToMerge(TF::IfRegionOp first_if,TF::IfRegionOp second_if,llvm::SmallSetVector<Operation *,4> & downstream_if_ops,const TF::SideEffectAnalysis::Info & side_effect_analysis)94 bool SafeToMerge(TF::IfRegionOp first_if, TF::IfRegionOp second_if,
95                  llvm::SmallSetVector<Operation*, 4>& downstream_if_ops,
96                  const TF::SideEffectAnalysis::Info& side_effect_analysis) {
97   // IfRegion ops must be in the same block.
98   if (second_if.getOperation()->getBlock() !=
99       first_if.getOperation()->getBlock()) {
100     return false;
101   }
102   assert(first_if.getOperation()->isBeforeInBlock(second_if.getOperation()));
103 
104   llvm::SmallSetVector<Operation*, 4> destination_ops =
105       GetAllOpsFromIf(first_if);
106 
107   // If there is an intermediate data or side effect dependency between the
108   // ops in first_if and the ops in second_if, it's not safe to merge
109   // them.
110   std::vector<Operation*> dependencies;
111   for (auto* user : first_if.getOperation()->getUsers()) {
112     if (!downstream_if_ops.contains(user)) {
113       dependencies.push_back(user);
114     }
115   }
116   for (auto* successor :
117        side_effect_analysis.DirectControlSuccessors(first_if.getOperation())) {
118     if (!downstream_if_ops.contains(successor)) {
119       dependencies.push_back(successor);
120     }
121   }
122   for (Operation& op : first_if.then_branch().front()) {
123     for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
124       if (!downstream_if_ops.contains(successor) &&
125           !destination_ops.contains(successor))
126         dependencies.push_back(successor);
127     }
128   }
129   for (Operation& op : first_if.else_branch().front()) {
130     for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
131       if (!downstream_if_ops.contains(successor) &&
132           !destination_ops.contains(successor))
133         dependencies.push_back(successor);
134     }
135   }
136 
137   bool safe_to_merge = true;
138 
139   llvm::SmallPtrSet<Operation*, 4> visited;
140   while (!dependencies.empty()) {
141     Operation* dependency = dependencies.back();
142     dependencies.pop_back();
143     if (visited.count(dependency)) continue;
144     visited.insert(dependency);
145     for (auto* user : dependency->getUsers()) {
146       if (downstream_if_ops.contains(user)) {
147         safe_to_merge = false;
148         break;
149       } else {
150         dependencies.push_back(user);
151       }
152     }
153     for (auto* successor :
154          side_effect_analysis.DirectControlSuccessors(dependency)) {
155       if (downstream_if_ops.contains(successor)) {
156         safe_to_merge = false;
157         break;
158       } else {
159         dependencies.push_back(successor);
160       }
161     }
162     // If the op is nested, then also consider the users and successors of the
163     // parent op.
164     if (dependency->getBlock() != first_if.getOperation()->getBlock())
165       dependencies.push_back(dependency->getParentOp());
166     if (!safe_to_merge) break;
167   }
168   return safe_to_merge;
169 }
170 
171 // Move the body excluding the terminators of else and then regions from
172 // 'second_if' to 'first_if'.
MoveBranches(TF::IfRegionOp first_if,TF::IfRegionOp second_if)173 void MoveBranches(TF::IfRegionOp first_if, TF::IfRegionOp second_if) {
174   Block& first_if_then_block = first_if.then_branch().front();
175   auto& second_if_then_body = second_if.then_branch().front().getOperations();
176   first_if_then_block.getOperations().splice(
177       first_if_then_block.without_terminator().end(), second_if_then_body,
178       second_if_then_body.begin(), std::prev(second_if_then_body.end()));
179 
180   Block& first_if_else_block = first_if.else_branch().front();
181   auto& second_if_else_body = second_if.else_branch().front().getOperations();
182   first_if_else_block.getOperations().splice(
183       first_if_else_block.without_terminator().end(), second_if_else_body,
184       second_if_else_body.begin(), std::prev(second_if_else_body.end()));
185 }
186 
187 // Check if the `last` IfRegion can be added to the segment of
188 // IfRegion start with `first` IfRegion.
CanAddToIfSegment(IfOpIterConst first,IfOpIterConst last,const llvm::SmallVector<mlir::TF::IfRegionOp,8> & if_ops,const std::unique_ptr<TF::SideEffectAnalysis> & side_effect_analysis)189 bool CanAddToIfSegment(
190     IfOpIterConst first, IfOpIterConst last,
191     const llvm::SmallVector<mlir::TF::IfRegionOp, 8>& if_ops,
192     const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) {
193   if (last == if_ops.end()) {
194     return false;
195   }
196   // downstream_if_ops contain ops in those IfRegions between first IfRegion
197   // and last IfRegion plus the ops in the last IfRegion.
198   llvm::SmallSetVector<Operation*, 4> downstream_if_ops;
199 
200   TF::IfRegionOp second_if_op = *last;
201 
202   for (auto iter = std::prev(last); std::next(iter) != first; iter--) {
203     TF::IfRegionOp first_if_op = *iter;
204     func::FuncOp func = first_if_op->getParentOfType<func::FuncOp>();
205     const TF::SideEffectAnalysis::Info& analysis =
206         side_effect_analysis->GetAnalysisForFunc(func);
207     auto all_ops = GetAllOpsFromIf(*(std::next(iter)));
208     downstream_if_ops.insert(all_ops.begin(), all_ops.end());
209     if (!SafeToMerge(first_if_op, second_if_op, downstream_if_ops, analysis)) {
210       return false;
211     }
212   }
213   return true;
214 }
215 
216 // Return the iterator of the IfRegion Op. This is the last IfRegion
217 // in the segment.
218 // For example, we have the following sequence of IfRegions
219 // `````
220 //      1          2          3         4           5
221 // IfRegionA, IfRegionA, IfRegionA, IfRegionA, IfRegionA
222 // `````
223 // The first three IfRegionA are in one group and the last two are in another
224 // group. Then when we call FindLastIfInSegment for the first segment, it
225 // will return iterator of the 3rd IfRegionA.
226 // In the same way, when we call it for the second segment, it will return
227 // iterator of the 5th IfRegionA.
FindLastIfInSegment(IfOpIterConst first_if,const llvm::SmallVector<mlir::TF::IfRegionOp,8> & if_ops,const std::unique_ptr<TF::SideEffectAnalysis> & side_effect_analysis)228 IfOpIterConst FindLastIfInSegment(
229     IfOpIterConst first_if,
230     const llvm::SmallVector<mlir::TF::IfRegionOp, 8>& if_ops,
231     const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) {
232   IfOpIterConst last_if = first_if;
233   for (; CanAddToIfSegment(first_if, last_if, if_ops, side_effect_analysis);
234        last_if = std::next(last_if)) {
235   }
236   return std::prev(last_if);
237 }
238 
239 // Returns a set of ops to be moved after merged IfRegion between two IfRegions.
GetMoveOpsBetweenTwoIfRegions(Operation * result_op,Operation * after_op,llvm::SmallSetVector<Operation *,4> middle_if_ops,const std::unique_ptr<TF::SideEffectAnalysis> & side_effect_analysis)240 absl::flat_hash_set<Operation*> GetMoveOpsBetweenTwoIfRegions(
241     Operation* result_op, Operation* after_op,
242     llvm::SmallSetVector<Operation*, 4> middle_if_ops,
243     const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) {
244   Block* block = after_op->getBlock();
245   std::queue<Operation*> queue;
246   absl::flat_hash_set<Operation*> visited;
247   absl::flat_hash_set<Operation*> moved_ops;
248 
249   func::FuncOp func = result_op->getParentOfType<func::FuncOp>();
250   const TF::SideEffectAnalysis::Info& analysis =
251       side_effect_analysis->GetAnalysisForFunc(func);
252 
253   // Enqueue dependencies of source_op into queue.
254   auto enqueue_deps = [&](Operation* source_op) {
255     for (Operation* user : source_op->getUsers()) {
256       if (!visited.count(user) && !middle_if_ops.count(user)) {
257         visited.insert(user);
258         queue.push(user);
259       }
260     }
261     source_op->walk([&](Operation* walked_op) {
262       for (Operation* successor : analysis.DirectControlSuccessors(walked_op)) {
263         if (!source_op->isProperAncestor(successor)) {
264           if (!visited.count(successor) && !middle_if_ops.count(successor)) {
265             visited.insert(successor);
266             queue.push(successor);
267           }
268         }
269       }
270     });
271   };
272   enqueue_deps(result_op);
273 
274   while (!queue.empty()) {
275     auto* op = queue.front();
276     queue.pop();
277     while (op->getBlock() != block) op = op->getParentOp();
278     if (op->isBeforeInBlock(after_op)) {
279       moved_ops.insert(op);
280       enqueue_deps(op);
281     }
282   }
283   return moved_ops;
284 }
285 
286 // Returns a vector that contains the ops to be moved after merged IfRegion.
287 // `sub_if_group` refers to a segment of IfRegions.
288 // The returned vector preserves op order.
GetMoveOpList(llvm::SmallVector<TF::IfRegionOp,8> & sub_if_group,const std::unique_ptr<TF::SideEffectAnalysis> & side_effect_analysis)289 llvm::SmallVector<Operation*, 8> GetMoveOpList(
290     llvm::SmallVector<TF::IfRegionOp, 8>& sub_if_group,
291     const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) {
292   absl::flat_hash_set<Operation*> all_moved_ops;
293   Operation* last_if_op = sub_if_group.back().getOperation();
294   llvm::SmallSetVector<Operation*, 4> middle_if_ops;
295 
296   // reversely calculate the all ops need to be moved because in this way,
297   // ops in the middle IfRegions can be easily obtained by simply adding to the
298   // current set.
299   for (auto it = std::prev(std::prev(sub_if_group.end()));
300        std::next(it) != sub_if_group.begin(); --it) {
301     auto op_list = GetMoveOpsBetweenTwoIfRegions(
302         it->getOperation(), last_if_op, middle_if_ops, side_effect_analysis);
303     all_moved_ops.insert(op_list.begin(), op_list.end());
304     auto first_if_ops = GetAllOpsFromIf(*it);
305     middle_if_ops.insert(first_if_ops.begin(), first_if_ops.end());
306   }
307 
308   llvm::SmallVector<Operation*, 8> moved_ops_ordered;
309   moved_ops_ordered.reserve(all_moved_ops.size());
310   for (Operation& op : *last_if_op->getBlock()) {
311     if (all_moved_ops.count(&op)) {
312       moved_ops_ordered.push_back(&op);
313     }
314   }
315 
316   return moved_ops_ordered;
317 }
318 
319 // Generate the segments for each IfRegion groups. Each element in the segments
320 // are supposed to can be merged into one new IfRegion.`if_cond` refers to the
321 // if condition of the segment of IfRegions. `if_ops` refers to the segment of
322 // IfRegions. `merged_groups` refers to all segments of IfRegions.
323 // `moved_ops_groups` refers to the ops need to be moved after new merged
324 // IfRegions associated with each segment of IfRegions.
GenerateSegmentsPerIfGroups(const mlir::Value & if_cond,const llvm::SmallVector<mlir::TF::IfRegionOp,8> & if_ops,const std::unique_ptr<TF::SideEffectAnalysis> & side_effect_analysis,MapToRegionVec2D & merged_groups,MapToOperationVec2D & moved_ops_groups)325 void GenerateSegmentsPerIfGroups(
326     const mlir::Value& if_cond,
327     const llvm::SmallVector<mlir::TF::IfRegionOp, 8>& if_ops,
328     const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis,
329     MapToRegionVec2D& merged_groups, MapToOperationVec2D& moved_ops_groups) {
330   auto it_merged = merged_groups.try_emplace(if_cond);
331   auto it_moved = moved_ops_groups.try_emplace(if_cond);
332   llvm::SmallVector<TF::IfRegionOp, 8> sub_merged_groups;
333   auto begin_if_op_iter = if_ops.begin();
334 
335   while (begin_if_op_iter != if_ops.end()) {
336     auto current_last_if_op_iter =
337         FindLastIfInSegment(begin_if_op_iter, if_ops, side_effect_analysis);
338     assert(current_last_if_op_iter != if_ops.end());
339     llvm::SmallVector<TF::IfRegionOp, 8> sub_if_group;
340     for (auto it = begin_if_op_iter; it != std::next(current_last_if_op_iter);
341          ++it) {
342       sub_if_group.push_back(*it);
343     }
344     it_merged.first->getSecond().push_back(sub_if_group);
345     it_moved.first->getSecond().push_back(
346         GetMoveOpList(sub_if_group, side_effect_analysis));
347     begin_if_op_iter = std::next(current_last_if_op_iter);
348   }
349 }
350 
351 // Checks whether a return index should be kept for `current_if_op` by checking
352 // for results in `if_op_segment`.
GetReturnIndicesToKeep(TF::IfRegionOp current_if_op,const llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment)353 llvm::SmallVector<int, 4> GetReturnIndicesToKeep(
354     TF::IfRegionOp current_if_op,
355     const llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment) {
356   llvm::SmallVector<int, 4> return_indices_to_keep;
357   auto is_op_inside_IfRegions = [&](Operation* op) {
358     for (auto& if_op : if_op_segment) {
359       if (if_op == current_if_op) {
360         continue;
361       }
362       if (if_op->isProperAncestor(op)) {
363         return true;
364       }
365     }
366     return false;
367   };
368   for (auto& index_and_value : llvm::enumerate(current_if_op.getResults())) {
369     if (!llvm::all_of(index_and_value.value().getUsers(),
370                       is_op_inside_IfRegions)) {
371       return_indices_to_keep.push_back(index_and_value.index());
372     }
373   }
374   return return_indices_to_keep;
375 }
376 
377 // Return a vector of the return indices.
GetReturnIndicesVec(const llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment)378 llvm::SmallVector<llvm::SmallVector<int, 4>> GetReturnIndicesVec(
379     const llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment) {
380   llvm::SmallVector<llvm::SmallVector<int, 4>> return_indices_vec;
381   for (auto it = if_op_segment.begin(); it != if_op_segment.end(); ++it) {
382     llvm::SmallVector<int, 4> indices_to_keep_vec =
383         GetReturnIndicesToKeep(*it, if_op_segment);
384     return_indices_vec.push_back(indices_to_keep_vec);
385   }
386   return return_indices_vec;
387 }
388 
389 // Replace the internal usage in each pair of IfRegions from top to bottom for
390 // both then branch and else branch.
ReplaceInternalUsage(llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment)391 void ReplaceInternalUsage(llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment) {
392   for (auto it = if_op_segment.begin(); it != if_op_segment.end(); ++it) {
393     for (auto it2 = std::next(it); it2 != if_op_segment.end(); ++it2) {
394       for (OpResult result : it->getResults()) {
395         replaceAllUsesInRegionWith(
396             result,
397             it->then_branch().front().getTerminator()->getOperand(
398                 result.getResultNumber()),
399             it2->then_branch());
400         replaceAllUsesInRegionWith(
401             result,
402             it->else_branch().front().getTerminator()->getOperand(
403                 result.getResultNumber()),
404             it2->else_branch());
405       }
406     }
407   }
408 }
409 
410 // Move ops in the `moved_ops_ordered` after `last_op`.
MoveOpsAfter(Operation * last_op,llvm::SmallVector<Operation *,8> & moved_ops_ordered)411 void MoveOpsAfter(Operation* last_op,
412                   llvm::SmallVector<Operation*, 8>& moved_ops_ordered) {
413   auto block = last_op->getBlock();
414   absl::flat_hash_set<Operation*> all_moved_ops(moved_ops_ordered.begin(),
415                                                 moved_ops_ordered.end());
416   moved_ops_ordered.clear();
417   for (Operation& op : *block) {
418     // There are no mutations in the loop. So each call of `isBeforeInBlock`
419     // is O(1).
420     if (all_moved_ops.count(&op) && op.isBeforeInBlock(last_op)) {
421       moved_ops_ordered.push_back(&op);
422     }
423   }
424   // Move ops in order.
425   for (Operation* op : moved_ops_ordered) {
426     op->moveAfter(last_op);
427     last_op = op;
428   }
429 }
430 
431 // Replace all external usage for each IfRegion in the segment of IfRegions.
432 // `if_op_segment` refers to the segment of IfRegions, `new_if_op` refers to the
433 // new merged IfRegion, `return_indices` refers to the indices to be kept in new
434 // merged IfRegion.
ReplaceExternalUsage(llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment,TF::IfRegionOp new_if_op,llvm::SmallVector<llvm::SmallVector<int,4>> & return_indices)435 void ReplaceExternalUsage(
436     llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment,
437     TF::IfRegionOp new_if_op,
438     llvm::SmallVector<llvm::SmallVector<int, 4>>& return_indices) {
439   int new_return_index = 0;
440   for (const auto& index_and_value : llvm::enumerate(if_op_segment)) {
441     auto old_if_op = index_and_value.value();
442     for (int i : return_indices[index_and_value.index()]) {
443       old_if_op.getResult(i).replaceAllUsesWith(
444           new_if_op.getResult(new_return_index++));
445     }
446   }
447 }
448 
449 // Update the moved op list to remove old IfRegions from the list and add new
450 // merged IfRegions. `old_to_new_IfRegions_map` refers to a map from old
451 // IfRegion to new merged IfRegion. `moved_ops_list` refers to the list of ops
452 // to be moved after new merged IfRegion.
UpdateMovedOpList(llvm::SmallDenseMap<Operation *,TF::IfRegionOp> & old_to_new_IfRegion_map,llvm::SmallVector<Operation *,8> & moved_ops_list)453 void UpdateMovedOpList(
454     llvm::SmallDenseMap<Operation*, TF::IfRegionOp>& old_to_new_IfRegion_map,
455     llvm::SmallVector<Operation*, 8>& moved_ops_list) {
456   llvm::SmallDenseSet<TF::IfRegionOp> new_if_ops;
457   bool need_add_new_if_op = false;
458   for (auto iter = moved_ops_list.begin(); iter != moved_ops_list.end();
459        iter++) {
460     if (old_to_new_IfRegion_map.count(*iter)) {
461       need_add_new_if_op = true;
462       auto new_if_op = old_to_new_IfRegion_map[*iter];
463       new_if_ops.insert(new_if_op);
464       moved_ops_list.erase(iter--);
465     }
466   }
467   if (need_add_new_if_op) {
468     for (auto& new_if_op : new_if_ops) {
469       moved_ops_list.push_back(new_if_op.getOperation());
470     }
471   }
472 }
473 
474 // Create the Yield ops for both branches with merged results.
475 // `builder` is the OpBuilder.
476 // `if_op_segment` refers to the segment of IfRegions to be merged.
477 // `return_indices` refers to the return indices to be kept in merged IfRegion
478 // `new_if_op` refers to the created new IfRegion
CreateYieldOps(OpBuilder & builder,llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment,llvm::SmallVector<llvm::SmallVector<int,4>> & return_indices,TF::IfRegionOp new_if_op,TF::IfRegionOp first_if)479 void CreateYieldOps(
480     OpBuilder& builder, llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment,
481     llvm::SmallVector<llvm::SmallVector<int, 4>>& return_indices,
482     TF::IfRegionOp new_if_op, TF::IfRegionOp first_if) {
483   llvm::SmallVector<Value, 4> merged_then_yield_values;
484   for (const auto& index_and_value : llvm::enumerate(if_op_segment)) {
485     auto if_op = index_and_value.value();
486     for (auto i : return_indices[index_and_value.index()]) {
487       merged_then_yield_values.push_back(
488           if_op.then_branch().front().getTerminator()->getOperand(i));
489     }
490   }
491   builder.setInsertionPointToEnd(&new_if_op.then_branch().front());
492   builder.create<TF::YieldOp>(
493       first_if.then_branch().front().getTerminator()->getLoc(),
494       /*operands=*/merged_then_yield_values);
495 
496   llvm::SmallVector<Value, 4> merged_else_yield_values;
497   for (const auto& index_and_value : llvm::enumerate(if_op_segment)) {
498     auto if_op = index_and_value.value();
499     for (auto i : return_indices[index_and_value.index()]) {
500       merged_else_yield_values.push_back(
501           if_op.else_branch().front().getTerminator()->getOperand(i));
502     }
503   }
504   builder.setInsertionPointToEnd(&new_if_op.else_branch().front());
505   builder.create<TF::YieldOp>(
506       first_if.else_branch().front().getTerminator()->getLoc(),
507       /*operands=*/merged_else_yield_values);
508 }
509 
510 // Merge the IfRegions in each segment. In the meantime, the old IfRegions in
511 // the segment will be added to `regions_to_remove`. They will be erased in the
512 // end.
513 // `if_op_segment` refers to segments of IfRegions. `moved_op_list` refers to
514 // the ops to be moved after new merged IfRegion. `regions_to_remove` refers to
515 // the regions to be removed from the `moved_ops_list`.
516 // `old_to_new_IfRegion_map` refers to a map from old IfRegion to new merged
517 // IfRegion.
MergeIfPerSegment(llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment,llvm::SmallVector<Operation *,8> & moved_ops_list,llvm::SmallSetVector<TF::IfRegionOp,8> & regions_to_remove,llvm::SmallDenseMap<Operation *,TF::IfRegionOp> & old_to_new_IfRegion_map)518 void MergeIfPerSegment(
519     llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment,
520     llvm::SmallVector<Operation*, 8>& moved_ops_list,
521     llvm::SmallSetVector<TF::IfRegionOp, 8>& regions_to_remove,
522     llvm::SmallDenseMap<Operation*, TF::IfRegionOp>& old_to_new_IfRegion_map) {
523   TF::IfRegionOp first_if = if_op_segment[0];
524   llvm::SmallVector<Type, 4> merged_return_types;
525   llvm::SmallVector<TF::IfRegionOp, 8> sources_if_ops(
526       std::next(if_op_segment.begin()), if_op_segment.end());
527 
528   // Create new IfRegion's merged results.
529   auto return_indices = GetReturnIndicesVec(if_op_segment);
530   for (const auto& index_and_value : llvm::enumerate(return_indices)) {
531     TF::IfRegionOp if_op = if_op_segment[index_and_value.index()];
532     for (auto i : index_and_value.value()) {
533       merged_return_types.push_back(if_op.getResult(i).getType());
534     }
535   }
536 
537   // Create new IfRegion for merged all IfRegions in if_op_segmemt.
538   OpBuilder builder(first_if);
539   builder.setInsertionPoint(if_op_segment.back().getOperation());
540 
541   auto new_if_op = builder.create<TF::IfRegionOp>(
542       first_if.getLoc(), merged_return_types, first_if.cond(),
543       llvm::all_of(if_op_segment,
544                    [&](TF::IfRegionOp op) { return op.is_stateless(); }),
545       first_if._then_func_nameAttr(), first_if._else_func_nameAttr());
546   new_if_op.then_branch().push_back(new Block);
547   new_if_op.else_branch().push_back(new Block);
548 
549   // Replace internal usages of merged if ops.
550   ReplaceInternalUsage(if_op_segment);
551 
552   // Replace external usages of merged if ops.
553   ReplaceExternalUsage(if_op_segment, new_if_op, return_indices);
554 
555   // Move ops after the new merged If region.
556   MoveOpsAfter(new_if_op.getOperation(), moved_ops_list);
557 
558   // Create the Yield ops for both branches with merged results.
559   CreateYieldOps(builder, if_op_segment, return_indices, new_if_op, first_if);
560 
561   for (auto& old_if_op : if_op_segment) {
562     MoveBranches(/*first_if=*/new_if_op, /*second_if=*/old_if_op);
563   }
564 
565   for (auto& old_if_op : if_op_segment) {
566     old_to_new_IfRegion_map[old_if_op.getOperation()] = new_if_op;
567     regions_to_remove.insert(old_if_op);
568   }
569 }
570 
571 // Merge IfRegions for each IfRegion group. Each IfRegion group contains
572 // several segments of IfRegions and each segment of IfRegions can be merged
573 // into one IfRegion.
574 // `if_cond` refers to the if condition of the segments of IfRegions.
575 // `planned_merged_groups` refers to the groups of IfRegions to be merged
576 // `moved_ops_groups` refers to the ops need to be moved after new merged
577 // IfRegions associated with each segment of IfRegions.
578 // `regions_to_remove` refers to the regions to be removed
579 // `old_to_new_IfRegion_map` refers to a map from old IfRegion to new merged
580 // IfRegion.
MergeIfPerIfGroups(const Value & if_cond,MapToRegionVec2D & planned_merged_groups,MapToOperationVec2D & moved_ops_groups,llvm::SmallSetVector<TF::IfRegionOp,8> & regions_to_remove,llvm::SmallDenseMap<Operation *,TF::IfRegionOp> & old_to_new_IfRegion_map)581 void MergeIfPerIfGroups(
582     const Value& if_cond, MapToRegionVec2D& planned_merged_groups,
583     MapToOperationVec2D& moved_ops_groups,
584     llvm::SmallSetVector<TF::IfRegionOp, 8>& regions_to_remove,
585     llvm::SmallDenseMap<Operation*, TF::IfRegionOp>& old_to_new_IfRegion_map) {
586   OperationVec2D& moved_ops_group = moved_ops_groups[if_cond];
587   RegionVec2D& segments = planned_merged_groups[if_cond];
588 
589   for (auto i = 0; i < segments.size(); ++i) {
590     if (segments[i].size() >= 2) {
591       UpdateMovedOpList(old_to_new_IfRegion_map, moved_ops_group[i]);
592       MergeIfPerSegment(segments[i], moved_ops_group[i], regions_to_remove,
593                         old_to_new_IfRegion_map);
594     }
595   }
596 }
597 
598 // Groups IfRegions by common predicate and attemps to merge them.
OptimizeIfRegions(Block * block,ModuleOp module)599 void OptimizeIfRegions(Block* block, ModuleOp module) {
600   // Do side effect analysis only one time in the beginning
601   auto side_effect_analysis = std::make_unique<TF::SideEffectAnalysis>(module);
602 
603   // Determine IfRegions with the same predicate.
604   llvm::SmallDenseMap<Value, llvm::SmallVector<TF::IfRegionOp, 8>, 8>
605       grouped_if_ops;
606   llvm::SmallVector<Value, 4> if_cond_order;
607   block->walk([&](TF::IfRegionOp if_op) {
608     auto it = grouped_if_ops.try_emplace(if_op.cond());
609     if (it.second) {
610       if_cond_order.push_back(if_op.cond());
611     }
612     it.first->getSecond().push_back(if_op);
613   });
614 
615   MapToRegionVec2D planned_merged_groups;
616   MapToOperationVec2D moved_ops_groups;
617   llvm::SmallSetVector<TF::IfRegionOp, 8> regions_to_remove;
618   llvm::SmallDenseMap<Operation*, TF::IfRegionOp> old_to_new_IfRegion_map;
619 
620   // For each if group, determine the segments of each if groups
621   // that can be merged and their related ops to be moved after
622   // the new generated IfRegions
623   // We cache the infomation into two maps:
624   // planned_merged_groups and moved_ops_groups
625   for (const auto& if_cond : if_cond_order) {
626     GenerateSegmentsPerIfGroups(if_cond, grouped_if_ops[if_cond],
627                                 side_effect_analysis, planned_merged_groups,
628                                 moved_ops_groups);
629   }
630 
631   // Merge IfRegions for each IfRegion groups.
632   for (const auto& if_cond : if_cond_order) {
633     MergeIfPerIfGroups(if_cond, planned_merged_groups, moved_ops_groups,
634                        regions_to_remove, old_to_new_IfRegion_map);
635   }
636 
637   // Remove all old IfRegions that already been merged.
638   for (auto old_if_region : regions_to_remove) {
639     old_if_region.erase();
640   }
641 }
642 
runOnOperation()643 void MergeControlFlowPass::runOnOperation() {
644   ModuleOp module = getOperation();
645   auto result = module.walk([&](tf_device::ClusterOp cluster) {
646     OptimizeIfRegions(&cluster.GetBody(), module);
647     return WalkResult::advance();
648   });
649   if (result.wasInterrupted()) return signalPassFailure();
650 }
651 
652 }  // namespace
653 
CreateMergeControlFlowPass()654 std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass() {
655   return std::make_unique<MergeControlFlowPass>();
656 }
657 
658 }  // namespace TFDevice
659 }  // namespace mlir
660