1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass forms `tf_executor.island` per region of
17 // `tf_device.parallel_execute`.
18 //
19 // For example, the following:
20 //
21 // %0 = tf_executor.island {
22 // tf_executor.yield
23 // }
24 // %1:2 = tf_executor.island {
25 // %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
26 // tf_executor.yield %2 : tensor<i1>
27 // }
28 // %3:2 = tf_executor.island(%0) {
29 // %4 = "tf_device.parallel_execute"() ({
30 // %5 = "tf.opB"() : () -> tensor<i1>
31 // tf_device.return %5 : tensor<i1>
32 // }, {
33 // %5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
34 // tf_device.return
35 // }) {} : () -> (tensor<i1>)
36 // tf_executor.yield %4 : tensor<i1>
37 // }
38 // tf_executor.fetch %3#0 : tensor<i1>
39 //
40 // gets lowered to:
41 //
42 // %0 = tf_executor.island {
43 // tf_executor.yield
44 // }
45 // %1:2 = tf_executor.island {
46 // %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
47 // tf_executor.yield %2 : tensor<i1>
48 // }
49 //
50 // // Island for the first region of above parallel_execute.
51 // %3:2 = tf_executor.island(%0) {
52 // %4 = "tf.opB"() : () -> tensor<i1>
53 // tf_executor.yield %4 : tensor<i1>
54 // }
55 //
56 // // Island for the second region of above parallel_execute.
57 // %5 = tf_executor.island(%0) {
58 // %6 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
59 // tf_executor.yield
60 // }
61 //
62 // tf_executor.fetch %3#0, %5 : tensor<i1>, !tf_executor.control
63 //
// When a tf_device.parallel_execute op is enclosed in a tf_device.replicate
// op, this pass is expected to run after the `replicate-to-island` pass and
// the `tf-executor-break-up-islands` pass.
67
68 #include "llvm/ADT/STLExtras.h"
69 #include "llvm/ADT/SmallVector.h"
70 #include "mlir/IR/Block.h" // from @llvm-project
71 #include "mlir/IR/Builders.h" // from @llvm-project
72 #include "mlir/IR/Value.h" // from @llvm-project
73 #include "mlir/Pass/Pass.h" // from @llvm-project
74 #include "mlir/Support/LLVM.h" // from @llvm-project
75 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
76 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
77 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
78
79 namespace mlir {
80 namespace TFDevice {
81 namespace {
82
83 struct ParallelExecuteToIslandsPass
84 : public TF::ParallelExecuteToIslandsPassBase<
85 ParallelExecuteToIslandsPass> {
86 void runOnOperation() override;
87 };
88
89 // Convert parallel_execute op to a set of islands where each region of
90 // parallel_execute op becomes a separate island. This ensures that the regions
91 // of the parallel_execute op gets executed concurrently.
ExpandParallelExecuteToIslands(tf_executor::IslandOp island_op,tf_device::ParallelExecuteOp parallel_execute_op,OpBuilder * builder,llvm::SmallVectorImpl<tf_executor::IslandOp> & executes)92 void ExpandParallelExecuteToIslands(
93 tf_executor::IslandOp island_op,
94 tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder,
95 llvm::SmallVectorImpl<tf_executor::IslandOp>& executes) {
96 const int num_regions = parallel_execute_op.getOperation()->getNumRegions();
97 executes.reserve(num_regions);
98
99 for (int i : llvm::seq<int>(0, num_regions)) {
100 Block& execute_block = parallel_execute_op.GetRegionBlockWithIndex(i);
101
102 // Replace terminator with tf_executor.YieldOp.
103 Operation* terminator = execute_block.getTerminator();
104 builder->setInsertionPoint(terminator);
105 auto yield = builder->create<tf_executor::YieldOp>(
106 terminator->getLoc(), terminator->getOperands());
107 terminator->erase();
108
109 // Create new island for each region.
110 builder->setInsertionPoint(island_op);
111 auto execute_island = builder->create<tf_executor::IslandOp>(
112 island_op.getLoc(), yield.getOperandTypes(),
113 island_op.control().getType(), island_op.controlInputs());
114
115 // Move over tf_device.parallel_execute body region into newly the created
116 // island.
117 execute_island.body().takeBody(*execute_block.getParent());
118 executes.push_back(execute_island);
119 }
120 }
121
CreateIslandsFromParallelExecute(tf_executor::IslandOp island_op,tf_device::ParallelExecuteOp parallel_execute_op)122 void CreateIslandsFromParallelExecute(
123 tf_executor::IslandOp island_op,
124 tf_device::ParallelExecuteOp parallel_execute_op) {
125 OpBuilder builder(island_op);
126
127 // Create islands for each region of the parallel_execute op.
128 llvm::SmallVector<tf_executor::IslandOp, 4> executes;
129 ExpandParallelExecuteToIslands(island_op, parallel_execute_op, &builder,
130 executes);
131
132 // Remap all results of parallel_execute op with outputs from newly created
133 // islands.
134 llvm::SmallVector<Value, 8> parallel_execute_outputs;
135 parallel_execute_outputs.reserve(
136 parallel_execute_op.getOperation()->getNumResults());
137
138 for (auto& execute : executes)
139 parallel_execute_outputs.append(execute.outputs().begin(),
140 execute.outputs().end());
141
142 for (auto result : llvm::zip(island_op.outputs(), parallel_execute_outputs))
143 std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
144
145 // Add sink island to pin all islands as a control dependency if there is a
146 // control dependency leading from the parallel_execute originally.
147 if (!island_op.control().use_empty()) {
148 llvm::SmallVector<Value, 8> island_operands;
149 for (auto& execute : executes) island_operands.push_back(execute.control());
150
151 builder.setInsertionPoint(island_op);
152 auto island_sink = builder.create<tf_executor::IslandOp>(
153 island_op.getLoc(), llvm::ArrayRef<Type>{},
154 island_op.control().getType(), island_operands);
155 island_sink.body().push_back(new Block);
156 builder.setInsertionPointToEnd(&island_sink.GetBody());
157 builder.create<tf_executor::YieldOp>(island_op.getLoc(),
158 llvm::ArrayRef<Value>{});
159 island_op.control().replaceAllUsesWith(island_sink.control());
160 }
161
162 // Islands with no uses should be pinned to a graph fetch so they still
163 // execute.
164 llvm::SmallVector<Value, 8> unused_execute_controls;
165 for (auto& execute : executes)
166 if (execute.use_empty())
167 unused_execute_controls.push_back(execute.control());
168
169 if (!unused_execute_controls.empty()) {
170 auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
171 tf_executor::FetchOp fetch = graph_op.GetFetch();
172 auto fetches = llvm::to_vector<8>(fetch.getOperands());
173 fetches.append(unused_execute_controls.begin(),
174 unused_execute_controls.end());
175 builder.setInsertionPoint(fetch);
176 builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
177 fetch.erase();
178 }
179
180 island_op.erase();
181 }
182
runOnOperation()183 void ParallelExecuteToIslandsPass::runOnOperation() {
184 // Find islands with a single `tf_device.parallel_execute` and create
185 // individual islands per execute region of the parallel_execute.
186 llvm::SmallVector<tf_executor::IslandOp, 4> parallel_execute_op_islands;
187 getOperation().walk([&](tf_executor::GraphOp graph_op) {
188 for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
189 if (!island_op.WrapsSingleOp()) continue;
190
191 if (isa<tf_device::ParallelExecuteOp>(&island_op.GetBody().front()))
192 parallel_execute_op_islands.push_back(island_op);
193 }
194 });
195
196 for (tf_executor::IslandOp island_op : parallel_execute_op_islands) {
197 auto parallel_execute_op =
198 cast<tf_device::ParallelExecuteOp>(island_op.GetBody().front());
199 CreateIslandsFromParallelExecute(island_op, parallel_execute_op);
200 }
201 }
202 } // anonymous namespace
203
204 std::unique_ptr<OperationPass<func::FuncOp>>
CreateParallelExecuteToIslandsPass()205 CreateParallelExecuteToIslandsPass() {
206 return std::make_unique<ParallelExecuteToIslandsPass>();
207 }
208
209 } // namespace TFDevice
210 } // namespace mlir
211