1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <memory>
18 #include <queue>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/container/flat_hash_set.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
30 #include "mlir/IR/Value.h" // from @llvm-project
31 #include "mlir/Pass/Pass.h" // from @llvm-project
32 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
33 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
40 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
41
42 namespace mlir {
43 namespace TFDevice {
44
45 namespace {
46
47 // For the 2d vector type aliases defined below, the first dimension represents
48 // the class of the IfRegion group and the second dimension represents the
49 // segments of the IfRegion group.
50 // For example, if we want to merge the following six IfRegions
51 // which share the same if_cond (regionA)
52 // `````````````
53 // IfRegionA(1)
54 // IfRegionA(2)
55 // IfRegionA(3)
56 // IfRegionA(4)
57 // IfRegionA(5)
58 // IfRegionA(6)
59 // ``````````````
60 // After the analysis, we consider IfRegionA(1), IfRegionA(2) and IfRegionA(3)
61 // can be merged, IfRegionA(4) is standalone, IfRegionA(5) and IfRegionA(6)
62 // can be merged. Then the defined 2D vector is
63 // [[IfRegionA(1), IfRegionA(2), IfRegionA(3)],
64 // [IfRegionA(4)],
65 // [IfRegionA(5), IfRegionA(6)]]
66 using RegionVec2D = llvm::SmallVector<llvm::SmallVector<TF::IfRegionOp, 8>, 8>;
67 using OperationVec2D = llvm::SmallVector<llvm::SmallVector<Operation*, 8>, 8>;
68 using MapToRegionVec2D = llvm::SmallDenseMap<Value, RegionVec2D>;
69 using MapToOperationVec2D = llvm::SmallDenseMap<Value, OperationVec2D>;
70 using IfOpIterConst =
71 llvm::SmallVectorTemplateCommon<mlir::TF::IfRegionOp>::const_iterator;
72
73 struct MergeControlFlowPass
74 : public TF::MergeControlFlowPassBase<MergeControlFlowPass> {
75 void runOnOperation() override;
76 };
77
78 // Gets the IfRegion op and all of ops in the then and else branches.
GetAllOpsFromIf(TF::IfRegionOp if_op)79 llvm::SmallSetVector<Operation*, 4> GetAllOpsFromIf(TF::IfRegionOp if_op) {
80 llvm::SmallSetVector<Operation*, 4> all_ops;
81 all_ops.insert(if_op);
82 for (Operation& op : if_op.then_branch().front()) {
83 all_ops.insert(&op);
84 }
85 for (Operation& op : if_op.else_branch().front()) {
86 all_ops.insert(&op);
87 }
88 return all_ops;
89 }
90
91 // Returns whether it is safe to merge `second_if` IfRegion into `first_if`
92 // IfRegion. `second if` must come after `first_if`.
93 // Note that `downstream_if_ops` means the ops in IfRegions except`first_if`.
SafeToMerge(TF::IfRegionOp first_if,TF::IfRegionOp second_if,llvm::SmallSetVector<Operation *,4> & downstream_if_ops,const TF::SideEffectAnalysis::Info & side_effect_analysis)94 bool SafeToMerge(TF::IfRegionOp first_if, TF::IfRegionOp second_if,
95 llvm::SmallSetVector<Operation*, 4>& downstream_if_ops,
96 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
97 // IfRegion ops must be in the same block.
98 if (second_if.getOperation()->getBlock() !=
99 first_if.getOperation()->getBlock()) {
100 return false;
101 }
102 assert(first_if.getOperation()->isBeforeInBlock(second_if.getOperation()));
103
104 llvm::SmallSetVector<Operation*, 4> destination_ops =
105 GetAllOpsFromIf(first_if);
106
107 // If there is an intermediate data or side effect dependency between the
108 // ops in first_if and the ops in second_if, it's not safe to merge
109 // them.
110 std::vector<Operation*> dependencies;
111 for (auto* user : first_if.getOperation()->getUsers()) {
112 if (!downstream_if_ops.contains(user)) {
113 dependencies.push_back(user);
114 }
115 }
116 for (auto* successor :
117 side_effect_analysis.DirectControlSuccessors(first_if.getOperation())) {
118 if (!downstream_if_ops.contains(successor)) {
119 dependencies.push_back(successor);
120 }
121 }
122 for (Operation& op : first_if.then_branch().front()) {
123 for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
124 if (!downstream_if_ops.contains(successor) &&
125 !destination_ops.contains(successor))
126 dependencies.push_back(successor);
127 }
128 }
129 for (Operation& op : first_if.else_branch().front()) {
130 for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
131 if (!downstream_if_ops.contains(successor) &&
132 !destination_ops.contains(successor))
133 dependencies.push_back(successor);
134 }
135 }
136
137 bool safe_to_merge = true;
138
139 llvm::SmallPtrSet<Operation*, 4> visited;
140 while (!dependencies.empty()) {
141 Operation* dependency = dependencies.back();
142 dependencies.pop_back();
143 if (visited.count(dependency)) continue;
144 visited.insert(dependency);
145 for (auto* user : dependency->getUsers()) {
146 if (downstream_if_ops.contains(user)) {
147 safe_to_merge = false;
148 break;
149 } else {
150 dependencies.push_back(user);
151 }
152 }
153 for (auto* successor :
154 side_effect_analysis.DirectControlSuccessors(dependency)) {
155 if (downstream_if_ops.contains(successor)) {
156 safe_to_merge = false;
157 break;
158 } else {
159 dependencies.push_back(successor);
160 }
161 }
162 // If the op is nested, then also consider the users and successors of the
163 // parent op.
164 if (dependency->getBlock() != first_if.getOperation()->getBlock())
165 dependencies.push_back(dependency->getParentOp());
166 if (!safe_to_merge) break;
167 }
168 return safe_to_merge;
169 }
170
171 // Move the body excluding the terminators of else and then regions from
172 // 'second_if' to 'first_if'.
MoveBranches(TF::IfRegionOp first_if,TF::IfRegionOp second_if)173 void MoveBranches(TF::IfRegionOp first_if, TF::IfRegionOp second_if) {
174 Block& first_if_then_block = first_if.then_branch().front();
175 auto& second_if_then_body = second_if.then_branch().front().getOperations();
176 first_if_then_block.getOperations().splice(
177 first_if_then_block.without_terminator().end(), second_if_then_body,
178 second_if_then_body.begin(), std::prev(second_if_then_body.end()));
179
180 Block& first_if_else_block = first_if.else_branch().front();
181 auto& second_if_else_body = second_if.else_branch().front().getOperations();
182 first_if_else_block.getOperations().splice(
183 first_if_else_block.without_terminator().end(), second_if_else_body,
184 second_if_else_body.begin(), std::prev(second_if_else_body.end()));
185 }
186
187 // Check if the `last` IfRegion can be added to the segment of
188 // IfRegion start with `first` IfRegion.
CanAddToIfSegment(IfOpIterConst first,IfOpIterConst last,const llvm::SmallVector<mlir::TF::IfRegionOp,8> & if_ops,const std::unique_ptr<TF::SideEffectAnalysis> & side_effect_analysis)189 bool CanAddToIfSegment(
190 IfOpIterConst first, IfOpIterConst last,
191 const llvm::SmallVector<mlir::TF::IfRegionOp, 8>& if_ops,
192 const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) {
193 if (last == if_ops.end()) {
194 return false;
195 }
196 // downstream_if_ops contain ops in those IfRegions between first IfRegion
197 // and last IfRegion plus the ops in the last IfRegion.
198 llvm::SmallSetVector<Operation*, 4> downstream_if_ops;
199
200 TF::IfRegionOp second_if_op = *last;
201
202 for (auto iter = std::prev(last); std::next(iter) != first; iter--) {
203 TF::IfRegionOp first_if_op = *iter;
204 func::FuncOp func = first_if_op->getParentOfType<func::FuncOp>();
205 const TF::SideEffectAnalysis::Info& analysis =
206 side_effect_analysis->GetAnalysisForFunc(func);
207 auto all_ops = GetAllOpsFromIf(*(std::next(iter)));
208 downstream_if_ops.insert(all_ops.begin(), all_ops.end());
209 if (!SafeToMerge(first_if_op, second_if_op, downstream_if_ops, analysis)) {
210 return false;
211 }
212 }
213 return true;
214 }
215
216 // Return the iterator of the IfRegion Op. This is the last IfRegion
217 // in the segment.
218 // For example, we have the following sequence of IfRegions
219 // `````
220 // 1 2 3 4 5
221 // IfRegionA, IfRegionA, IfRegionA, IfRegionA, IfRegionA
222 // `````
223 // The first three IfRegionA are in one group and the last two are in another
224 // group. Then when we call FindLastIfInSegment for the first segment, it
225 // will return iterator of the 3rd IfRegionA.
226 // In the same way, when we call it for the second segment, it will return
227 // iterator of the 5th IfRegionA.
FindLastIfInSegment(IfOpIterConst first_if,const llvm::SmallVector<mlir::TF::IfRegionOp,8> & if_ops,const std::unique_ptr<TF::SideEffectAnalysis> & side_effect_analysis)228 IfOpIterConst FindLastIfInSegment(
229 IfOpIterConst first_if,
230 const llvm::SmallVector<mlir::TF::IfRegionOp, 8>& if_ops,
231 const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) {
232 IfOpIterConst last_if = first_if;
233 for (; CanAddToIfSegment(first_if, last_if, if_ops, side_effect_analysis);
234 last_if = std::next(last_if)) {
235 }
236 return std::prev(last_if);
237 }
238
239 // Returns a set of ops to be moved after merged IfRegion between two IfRegions.
GetMoveOpsBetweenTwoIfRegions(Operation * result_op,Operation * after_op,llvm::SmallSetVector<Operation *,4> middle_if_ops,const std::unique_ptr<TF::SideEffectAnalysis> & side_effect_analysis)240 absl::flat_hash_set<Operation*> GetMoveOpsBetweenTwoIfRegions(
241 Operation* result_op, Operation* after_op,
242 llvm::SmallSetVector<Operation*, 4> middle_if_ops,
243 const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) {
244 Block* block = after_op->getBlock();
245 std::queue<Operation*> queue;
246 absl::flat_hash_set<Operation*> visited;
247 absl::flat_hash_set<Operation*> moved_ops;
248
249 func::FuncOp func = result_op->getParentOfType<func::FuncOp>();
250 const TF::SideEffectAnalysis::Info& analysis =
251 side_effect_analysis->GetAnalysisForFunc(func);
252
253 // Enqueue dependencies of source_op into queue.
254 auto enqueue_deps = [&](Operation* source_op) {
255 for (Operation* user : source_op->getUsers()) {
256 if (!visited.count(user) && !middle_if_ops.count(user)) {
257 visited.insert(user);
258 queue.push(user);
259 }
260 }
261 source_op->walk([&](Operation* walked_op) {
262 for (Operation* successor : analysis.DirectControlSuccessors(walked_op)) {
263 if (!source_op->isProperAncestor(successor)) {
264 if (!visited.count(successor) && !middle_if_ops.count(successor)) {
265 visited.insert(successor);
266 queue.push(successor);
267 }
268 }
269 }
270 });
271 };
272 enqueue_deps(result_op);
273
274 while (!queue.empty()) {
275 auto* op = queue.front();
276 queue.pop();
277 while (op->getBlock() != block) op = op->getParentOp();
278 if (op->isBeforeInBlock(after_op)) {
279 moved_ops.insert(op);
280 enqueue_deps(op);
281 }
282 }
283 return moved_ops;
284 }
285
286 // Returns a vector that contains the ops to be moved after merged IfRegion.
287 // `sub_if_group` refers to a segment of IfRegions.
288 // The returned vector preserves op order.
GetMoveOpList(llvm::SmallVector<TF::IfRegionOp,8> & sub_if_group,const std::unique_ptr<TF::SideEffectAnalysis> & side_effect_analysis)289 llvm::SmallVector<Operation*, 8> GetMoveOpList(
290 llvm::SmallVector<TF::IfRegionOp, 8>& sub_if_group,
291 const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) {
292 absl::flat_hash_set<Operation*> all_moved_ops;
293 Operation* last_if_op = sub_if_group.back().getOperation();
294 llvm::SmallSetVector<Operation*, 4> middle_if_ops;
295
296 // reversely calculate the all ops need to be moved because in this way,
297 // ops in the middle IfRegions can be easily obtained by simply adding to the
298 // current set.
299 for (auto it = std::prev(std::prev(sub_if_group.end()));
300 std::next(it) != sub_if_group.begin(); --it) {
301 auto op_list = GetMoveOpsBetweenTwoIfRegions(
302 it->getOperation(), last_if_op, middle_if_ops, side_effect_analysis);
303 all_moved_ops.insert(op_list.begin(), op_list.end());
304 auto first_if_ops = GetAllOpsFromIf(*it);
305 middle_if_ops.insert(first_if_ops.begin(), first_if_ops.end());
306 }
307
308 llvm::SmallVector<Operation*, 8> moved_ops_ordered;
309 moved_ops_ordered.reserve(all_moved_ops.size());
310 for (Operation& op : *last_if_op->getBlock()) {
311 if (all_moved_ops.count(&op)) {
312 moved_ops_ordered.push_back(&op);
313 }
314 }
315
316 return moved_ops_ordered;
317 }
318
319 // Generate the segments for each IfRegion groups. Each element in the segments
320 // are supposed to can be merged into one new IfRegion.`if_cond` refers to the
321 // if condition of the segment of IfRegions. `if_ops` refers to the segment of
322 // IfRegions. `merged_groups` refers to all segments of IfRegions.
323 // `moved_ops_groups` refers to the ops need to be moved after new merged
324 // IfRegions associated with each segment of IfRegions.
GenerateSegmentsPerIfGroups(const mlir::Value & if_cond,const llvm::SmallVector<mlir::TF::IfRegionOp,8> & if_ops,const std::unique_ptr<TF::SideEffectAnalysis> & side_effect_analysis,MapToRegionVec2D & merged_groups,MapToOperationVec2D & moved_ops_groups)325 void GenerateSegmentsPerIfGroups(
326 const mlir::Value& if_cond,
327 const llvm::SmallVector<mlir::TF::IfRegionOp, 8>& if_ops,
328 const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis,
329 MapToRegionVec2D& merged_groups, MapToOperationVec2D& moved_ops_groups) {
330 auto it_merged = merged_groups.try_emplace(if_cond);
331 auto it_moved = moved_ops_groups.try_emplace(if_cond);
332 llvm::SmallVector<TF::IfRegionOp, 8> sub_merged_groups;
333 auto begin_if_op_iter = if_ops.begin();
334
335 while (begin_if_op_iter != if_ops.end()) {
336 auto current_last_if_op_iter =
337 FindLastIfInSegment(begin_if_op_iter, if_ops, side_effect_analysis);
338 assert(current_last_if_op_iter != if_ops.end());
339 llvm::SmallVector<TF::IfRegionOp, 8> sub_if_group;
340 for (auto it = begin_if_op_iter; it != std::next(current_last_if_op_iter);
341 ++it) {
342 sub_if_group.push_back(*it);
343 }
344 it_merged.first->getSecond().push_back(sub_if_group);
345 it_moved.first->getSecond().push_back(
346 GetMoveOpList(sub_if_group, side_effect_analysis));
347 begin_if_op_iter = std::next(current_last_if_op_iter);
348 }
349 }
350
351 // Checks whether a return index should be kept for `current_if_op` by checking
352 // for results in `if_op_segment`.
GetReturnIndicesToKeep(TF::IfRegionOp current_if_op,const llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment)353 llvm::SmallVector<int, 4> GetReturnIndicesToKeep(
354 TF::IfRegionOp current_if_op,
355 const llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment) {
356 llvm::SmallVector<int, 4> return_indices_to_keep;
357 auto is_op_inside_IfRegions = [&](Operation* op) {
358 for (auto& if_op : if_op_segment) {
359 if (if_op == current_if_op) {
360 continue;
361 }
362 if (if_op->isProperAncestor(op)) {
363 return true;
364 }
365 }
366 return false;
367 };
368 for (auto& index_and_value : llvm::enumerate(current_if_op.getResults())) {
369 if (!llvm::all_of(index_and_value.value().getUsers(),
370 is_op_inside_IfRegions)) {
371 return_indices_to_keep.push_back(index_and_value.index());
372 }
373 }
374 return return_indices_to_keep;
375 }
376
377 // Return a vector of the return indices.
GetReturnIndicesVec(const llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment)378 llvm::SmallVector<llvm::SmallVector<int, 4>> GetReturnIndicesVec(
379 const llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment) {
380 llvm::SmallVector<llvm::SmallVector<int, 4>> return_indices_vec;
381 for (auto it = if_op_segment.begin(); it != if_op_segment.end(); ++it) {
382 llvm::SmallVector<int, 4> indices_to_keep_vec =
383 GetReturnIndicesToKeep(*it, if_op_segment);
384 return_indices_vec.push_back(indices_to_keep_vec);
385 }
386 return return_indices_vec;
387 }
388
389 // Replace the internal usage in each pair of IfRegions from top to bottom for
390 // both then branch and else branch.
ReplaceInternalUsage(llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment)391 void ReplaceInternalUsage(llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment) {
392 for (auto it = if_op_segment.begin(); it != if_op_segment.end(); ++it) {
393 for (auto it2 = std::next(it); it2 != if_op_segment.end(); ++it2) {
394 for (OpResult result : it->getResults()) {
395 replaceAllUsesInRegionWith(
396 result,
397 it->then_branch().front().getTerminator()->getOperand(
398 result.getResultNumber()),
399 it2->then_branch());
400 replaceAllUsesInRegionWith(
401 result,
402 it->else_branch().front().getTerminator()->getOperand(
403 result.getResultNumber()),
404 it2->else_branch());
405 }
406 }
407 }
408 }
409
410 // Move ops in the `moved_ops_ordered` after `last_op`.
MoveOpsAfter(Operation * last_op,llvm::SmallVector<Operation *,8> & moved_ops_ordered)411 void MoveOpsAfter(Operation* last_op,
412 llvm::SmallVector<Operation*, 8>& moved_ops_ordered) {
413 auto block = last_op->getBlock();
414 absl::flat_hash_set<Operation*> all_moved_ops(moved_ops_ordered.begin(),
415 moved_ops_ordered.end());
416 moved_ops_ordered.clear();
417 for (Operation& op : *block) {
418 // There are no mutations in the loop. So each call of `isBeforeInBlock`
419 // is O(1).
420 if (all_moved_ops.count(&op) && op.isBeforeInBlock(last_op)) {
421 moved_ops_ordered.push_back(&op);
422 }
423 }
424 // Move ops in order.
425 for (Operation* op : moved_ops_ordered) {
426 op->moveAfter(last_op);
427 last_op = op;
428 }
429 }
430
431 // Replace all external usage for each IfRegion in the segment of IfRegions.
432 // `if_op_segment` refers to the segment of IfRegions, `new_if_op` refers to the
433 // new merged IfRegion, `return_indices` refers to the indices to be kept in new
434 // merged IfRegion.
ReplaceExternalUsage(llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment,TF::IfRegionOp new_if_op,llvm::SmallVector<llvm::SmallVector<int,4>> & return_indices)435 void ReplaceExternalUsage(
436 llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment,
437 TF::IfRegionOp new_if_op,
438 llvm::SmallVector<llvm::SmallVector<int, 4>>& return_indices) {
439 int new_return_index = 0;
440 for (const auto& index_and_value : llvm::enumerate(if_op_segment)) {
441 auto old_if_op = index_and_value.value();
442 for (int i : return_indices[index_and_value.index()]) {
443 old_if_op.getResult(i).replaceAllUsesWith(
444 new_if_op.getResult(new_return_index++));
445 }
446 }
447 }
448
449 // Update the moved op list to remove old IfRegions from the list and add new
450 // merged IfRegions. `old_to_new_IfRegions_map` refers to a map from old
451 // IfRegion to new merged IfRegion. `moved_ops_list` refers to the list of ops
452 // to be moved after new merged IfRegion.
UpdateMovedOpList(llvm::SmallDenseMap<Operation *,TF::IfRegionOp> & old_to_new_IfRegion_map,llvm::SmallVector<Operation *,8> & moved_ops_list)453 void UpdateMovedOpList(
454 llvm::SmallDenseMap<Operation*, TF::IfRegionOp>& old_to_new_IfRegion_map,
455 llvm::SmallVector<Operation*, 8>& moved_ops_list) {
456 llvm::SmallDenseSet<TF::IfRegionOp> new_if_ops;
457 bool need_add_new_if_op = false;
458 for (auto iter = moved_ops_list.begin(); iter != moved_ops_list.end();
459 iter++) {
460 if (old_to_new_IfRegion_map.count(*iter)) {
461 need_add_new_if_op = true;
462 auto new_if_op = old_to_new_IfRegion_map[*iter];
463 new_if_ops.insert(new_if_op);
464 moved_ops_list.erase(iter--);
465 }
466 }
467 if (need_add_new_if_op) {
468 for (auto& new_if_op : new_if_ops) {
469 moved_ops_list.push_back(new_if_op.getOperation());
470 }
471 }
472 }
473
474 // Create the Yield ops for both branches with merged results.
475 // `builder` is the OpBuilder.
476 // `if_op_segment` refers to the segment of IfRegions to be merged.
477 // `return_indices` refers to the return indices to be kept in merged IfRegion
478 // `new_if_op` refers to the created new IfRegion
CreateYieldOps(OpBuilder & builder,llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment,llvm::SmallVector<llvm::SmallVector<int,4>> & return_indices,TF::IfRegionOp new_if_op,TF::IfRegionOp first_if)479 void CreateYieldOps(
480 OpBuilder& builder, llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment,
481 llvm::SmallVector<llvm::SmallVector<int, 4>>& return_indices,
482 TF::IfRegionOp new_if_op, TF::IfRegionOp first_if) {
483 llvm::SmallVector<Value, 4> merged_then_yield_values;
484 for (const auto& index_and_value : llvm::enumerate(if_op_segment)) {
485 auto if_op = index_and_value.value();
486 for (auto i : return_indices[index_and_value.index()]) {
487 merged_then_yield_values.push_back(
488 if_op.then_branch().front().getTerminator()->getOperand(i));
489 }
490 }
491 builder.setInsertionPointToEnd(&new_if_op.then_branch().front());
492 builder.create<TF::YieldOp>(
493 first_if.then_branch().front().getTerminator()->getLoc(),
494 /*operands=*/merged_then_yield_values);
495
496 llvm::SmallVector<Value, 4> merged_else_yield_values;
497 for (const auto& index_and_value : llvm::enumerate(if_op_segment)) {
498 auto if_op = index_and_value.value();
499 for (auto i : return_indices[index_and_value.index()]) {
500 merged_else_yield_values.push_back(
501 if_op.else_branch().front().getTerminator()->getOperand(i));
502 }
503 }
504 builder.setInsertionPointToEnd(&new_if_op.else_branch().front());
505 builder.create<TF::YieldOp>(
506 first_if.else_branch().front().getTerminator()->getLoc(),
507 /*operands=*/merged_else_yield_values);
508 }
509
510 // Merge the IfRegions in each segment. In the meantime, the old IfRegions in
511 // the segment will be added to `regions_to_remove`. They will be erased in the
512 // end.
513 // `if_op_segment` refers to segments of IfRegions. `moved_op_list` refers to
514 // the ops to be moved after new merged IfRegion. `regions_to_remove` refers to
515 // the regions to be removed from the `moved_ops_list`.
516 // `old_to_new_IfRegion_map` refers to a map from old IfRegion to new merged
517 // IfRegion.
MergeIfPerSegment(llvm::SmallVector<TF::IfRegionOp,8> & if_op_segment,llvm::SmallVector<Operation *,8> & moved_ops_list,llvm::SmallSetVector<TF::IfRegionOp,8> & regions_to_remove,llvm::SmallDenseMap<Operation *,TF::IfRegionOp> & old_to_new_IfRegion_map)518 void MergeIfPerSegment(
519 llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment,
520 llvm::SmallVector<Operation*, 8>& moved_ops_list,
521 llvm::SmallSetVector<TF::IfRegionOp, 8>& regions_to_remove,
522 llvm::SmallDenseMap<Operation*, TF::IfRegionOp>& old_to_new_IfRegion_map) {
523 TF::IfRegionOp first_if = if_op_segment[0];
524 llvm::SmallVector<Type, 4> merged_return_types;
525 llvm::SmallVector<TF::IfRegionOp, 8> sources_if_ops(
526 std::next(if_op_segment.begin()), if_op_segment.end());
527
528 // Create new IfRegion's merged results.
529 auto return_indices = GetReturnIndicesVec(if_op_segment);
530 for (const auto& index_and_value : llvm::enumerate(return_indices)) {
531 TF::IfRegionOp if_op = if_op_segment[index_and_value.index()];
532 for (auto i : index_and_value.value()) {
533 merged_return_types.push_back(if_op.getResult(i).getType());
534 }
535 }
536
537 // Create new IfRegion for merged all IfRegions in if_op_segmemt.
538 OpBuilder builder(first_if);
539 builder.setInsertionPoint(if_op_segment.back().getOperation());
540
541 auto new_if_op = builder.create<TF::IfRegionOp>(
542 first_if.getLoc(), merged_return_types, first_if.cond(),
543 llvm::all_of(if_op_segment,
544 [&](TF::IfRegionOp op) { return op.is_stateless(); }),
545 first_if._then_func_nameAttr(), first_if._else_func_nameAttr());
546 new_if_op.then_branch().push_back(new Block);
547 new_if_op.else_branch().push_back(new Block);
548
549 // Replace internal usages of merged if ops.
550 ReplaceInternalUsage(if_op_segment);
551
552 // Replace external usages of merged if ops.
553 ReplaceExternalUsage(if_op_segment, new_if_op, return_indices);
554
555 // Move ops after the new merged If region.
556 MoveOpsAfter(new_if_op.getOperation(), moved_ops_list);
557
558 // Create the Yield ops for both branches with merged results.
559 CreateYieldOps(builder, if_op_segment, return_indices, new_if_op, first_if);
560
561 for (auto& old_if_op : if_op_segment) {
562 MoveBranches(/*first_if=*/new_if_op, /*second_if=*/old_if_op);
563 }
564
565 for (auto& old_if_op : if_op_segment) {
566 old_to_new_IfRegion_map[old_if_op.getOperation()] = new_if_op;
567 regions_to_remove.insert(old_if_op);
568 }
569 }
570
571 // Merge IfRegions for each IfRegion group. Each IfRegion group contains
572 // several segments of IfRegions and each segment of IfRegions can be merged
573 // into one IfRegion.
574 // `if_cond` refers to the if condition of the segments of IfRegions.
575 // `planned_merged_groups` refers to the groups of IfRegions to be merged
576 // `moved_ops_groups` refers to the ops need to be moved after new merged
577 // IfRegions associated with each segment of IfRegions.
578 // `regions_to_remove` refers to the regions to be removed
579 // `old_to_new_IfRegion_map` refers to a map from old IfRegion to new merged
580 // IfRegion.
MergeIfPerIfGroups(const Value & if_cond,MapToRegionVec2D & planned_merged_groups,MapToOperationVec2D & moved_ops_groups,llvm::SmallSetVector<TF::IfRegionOp,8> & regions_to_remove,llvm::SmallDenseMap<Operation *,TF::IfRegionOp> & old_to_new_IfRegion_map)581 void MergeIfPerIfGroups(
582 const Value& if_cond, MapToRegionVec2D& planned_merged_groups,
583 MapToOperationVec2D& moved_ops_groups,
584 llvm::SmallSetVector<TF::IfRegionOp, 8>& regions_to_remove,
585 llvm::SmallDenseMap<Operation*, TF::IfRegionOp>& old_to_new_IfRegion_map) {
586 OperationVec2D& moved_ops_group = moved_ops_groups[if_cond];
587 RegionVec2D& segments = planned_merged_groups[if_cond];
588
589 for (auto i = 0; i < segments.size(); ++i) {
590 if (segments[i].size() >= 2) {
591 UpdateMovedOpList(old_to_new_IfRegion_map, moved_ops_group[i]);
592 MergeIfPerSegment(segments[i], moved_ops_group[i], regions_to_remove,
593 old_to_new_IfRegion_map);
594 }
595 }
596 }
597
598 // Groups IfRegions by common predicate and attemps to merge them.
OptimizeIfRegions(Block * block,ModuleOp module)599 void OptimizeIfRegions(Block* block, ModuleOp module) {
600 // Do side effect analysis only one time in the beginning
601 auto side_effect_analysis = std::make_unique<TF::SideEffectAnalysis>(module);
602
603 // Determine IfRegions with the same predicate.
604 llvm::SmallDenseMap<Value, llvm::SmallVector<TF::IfRegionOp, 8>, 8>
605 grouped_if_ops;
606 llvm::SmallVector<Value, 4> if_cond_order;
607 block->walk([&](TF::IfRegionOp if_op) {
608 auto it = grouped_if_ops.try_emplace(if_op.cond());
609 if (it.second) {
610 if_cond_order.push_back(if_op.cond());
611 }
612 it.first->getSecond().push_back(if_op);
613 });
614
615 MapToRegionVec2D planned_merged_groups;
616 MapToOperationVec2D moved_ops_groups;
617 llvm::SmallSetVector<TF::IfRegionOp, 8> regions_to_remove;
618 llvm::SmallDenseMap<Operation*, TF::IfRegionOp> old_to_new_IfRegion_map;
619
620 // For each if group, determine the segments of each if groups
621 // that can be merged and their related ops to be moved after
622 // the new generated IfRegions
623 // We cache the infomation into two maps:
624 // planned_merged_groups and moved_ops_groups
625 for (const auto& if_cond : if_cond_order) {
626 GenerateSegmentsPerIfGroups(if_cond, grouped_if_ops[if_cond],
627 side_effect_analysis, planned_merged_groups,
628 moved_ops_groups);
629 }
630
631 // Merge IfRegions for each IfRegion groups.
632 for (const auto& if_cond : if_cond_order) {
633 MergeIfPerIfGroups(if_cond, planned_merged_groups, moved_ops_groups,
634 regions_to_remove, old_to_new_IfRegion_map);
635 }
636
637 // Remove all old IfRegions that already been merged.
638 for (auto old_if_region : regions_to_remove) {
639 old_if_region.erase();
640 }
641 }
642
runOnOperation()643 void MergeControlFlowPass::runOnOperation() {
644 ModuleOp module = getOperation();
645 auto result = module.walk([&](tf_device::ClusterOp cluster) {
646 OptimizeIfRegions(&cluster.GetBody(), module);
647 return WalkResult::advance();
648 });
649 if (result.wasInterrupted()) return signalPassFailure();
650 }
651
652 } // namespace
653
CreateMergeControlFlowPass()654 std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass() {
655 return std::make_unique<MergeControlFlowPass>();
656 }
657
658 } // namespace TFDevice
659 } // namespace mlir
660