/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass forms a `tf_executor.island` per region of
// `tf_device.parallel_execute`.
//
// For example, the following:
//
//  %0 = tf_executor.island {
//    tf_executor.yield
//  }
//  %1:2 = tf_executor.island {
//    %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
//    tf_executor.yield %2 : tensor<i1>
//  }
//  %3:2 = tf_executor.island(%0) {
//    %4 = "tf_device.parallel_execute"() ({
//      %5 = "tf.opB"() : () -> tensor<i1>
//      tf_device.return %5 : tensor<i1>
//    }, {
//      %5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
//      tf_device.return
//    }) {} : () -> (tensor<i1>)
//    tf_executor.yield %4 : tensor<i1>
//  }
//  tf_executor.fetch %3#0 : tensor<i1>
//
// gets lowered to:
//
//  %0 = tf_executor.island {
//    tf_executor.yield
//  }
//  %1:2 = tf_executor.island {
//    %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
//    tf_executor.yield %2 : tensor<i1>
//  }
//
//  // Island for the first region of the above parallel_execute.
//  %3:2 = tf_executor.island(%0) {
//    %4 = "tf.opB"() : () -> tensor<i1>
//    tf_executor.yield %4 : tensor<i1>
//  }
//
//  // Island for the second region of the above parallel_execute.
//  %5 = tf_executor.island(%0) {
//    %6 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
//    tf_executor.yield
//  }
//
//  tf_executor.fetch %3#0, %5 : tensor<i1>, !tf_executor.control
//
//  When a tf_device.parallel_execute op is nested inside a tf_device.replicate
//  op, this pass is run after the `replicate-to-island` and
//  `tf-executor-break-up-islands` passes.
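//
//  As a usage sketch, assuming this pass is registered under the
//  `tf-parallel-execute-to-islands` flag, it can be exercised standalone
//  through `tf-opt`:
//
//    tf-opt -tf-parallel-execute-to-islands input.mlir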

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFDevice {
namespace {

struct ParallelExecuteToIslandsPass
    : public TF::ParallelExecuteToIslandsPassBase<
          ParallelExecuteToIslandsPass> {
  void runOnOperation() override;
};

// Converts a parallel_execute op to a set of islands where each region of
// the parallel_execute op becomes a separate island. This ensures that the
// regions of the parallel_execute op can be executed concurrently.
void ExpandParallelExecuteToIslands(
    tf_executor::IslandOp island_op,
    tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder,
    llvm::SmallVectorImpl<tf_executor::IslandOp>& executes) {
  const int num_regions = parallel_execute_op.getOperation()->getNumRegions();
  executes.reserve(num_regions);

  for (int i : llvm::seq<int>(0, num_regions)) {
    Block& execute_block = parallel_execute_op.GetRegionBlockWithIndex(i);

    // Replace the terminator with a tf_executor.YieldOp.
    Operation* terminator = execute_block.getTerminator();
    builder->setInsertionPoint(terminator);
    auto yield = builder->create<tf_executor::YieldOp>(
        terminator->getLoc(), terminator->getOperands());
    terminator->erase();

    // Create a new island for each region.
    builder->setInsertionPoint(island_op);
    auto execute_island = builder->create<tf_executor::IslandOp>(
        island_op.getLoc(), yield.getOperandTypes(),
        island_op.control().getType(), island_op.controlInputs());

    // Move the tf_device.parallel_execute body region into the newly created
    // island.
    execute_island.body().takeBody(*execute_block.getParent());
    executes.push_back(execute_island);
  }
}

void CreateIslandsFromParallelExecute(
    tf_executor::IslandOp island_op,
    tf_device::ParallelExecuteOp parallel_execute_op) {
  OpBuilder builder(island_op);

  // Create islands for each region of the parallel_execute op.
  llvm::SmallVector<tf_executor::IslandOp, 4> executes;
  ExpandParallelExecuteToIslands(island_op, parallel_execute_op, &builder,
                                 executes);

  // Remap all results of the parallel_execute op with outputs from the
  // newly created islands.
  llvm::SmallVector<Value, 8> parallel_execute_outputs;
  parallel_execute_outputs.reserve(
      parallel_execute_op.getOperation()->getNumResults());

  for (auto& execute : executes)
    parallel_execute_outputs.append(execute.outputs().begin(),
                                    execute.outputs().end());

  for (auto result : llvm::zip(island_op.outputs(), parallel_execute_outputs))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  // Add a sink island to pin all islands as a control dependency if there is
  // a control dependency leading from the parallel_execute originally.
  if (!island_op.control().use_empty()) {
    llvm::SmallVector<Value, 8> island_operands;
    for (auto& execute : executes) island_operands.push_back(execute.control());

    builder.setInsertionPoint(island_op);
    auto island_sink = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), llvm::ArrayRef<Type>{},
        island_op.control().getType(), island_operands);
    island_sink.body().push_back(new Block);
    builder.setInsertionPointToEnd(&island_sink.GetBody());
    builder.create<tf_executor::YieldOp>(island_op.getLoc(),
                                         llvm::ArrayRef<Value>{});
    island_op.control().replaceAllUsesWith(island_sink.control());
  }

  // Islands with no uses should be pinned to a graph fetch so they still
  // execute.
  llvm::SmallVector<Value, 8> unused_execute_controls;
  for (auto& execute : executes)
    if (execute.use_empty())
      unused_execute_controls.push_back(execute.control());

  if (!unused_execute_controls.empty()) {
    auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
    tf_executor::FetchOp fetch = graph_op.GetFetch();
    auto fetches = llvm::to_vector<8>(fetch.getOperands());
    fetches.append(unused_execute_controls.begin(),
                   unused_execute_controls.end());
    builder.setInsertionPoint(fetch);
    builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
    fetch.erase();
  }

  island_op.erase();
}

void ParallelExecuteToIslandsPass::runOnOperation() {
  // Find islands with a single `tf_device.parallel_execute` and create
  // individual islands per execute region of the parallel_execute.
  llvm::SmallVector<tf_executor::IslandOp, 4> parallel_execute_op_islands;
  getOperation().walk([&](tf_executor::GraphOp graph_op) {
    for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
      if (!island_op.WrapsSingleOp()) continue;

      if (isa<tf_device::ParallelExecuteOp>(&island_op.GetBody().front()))
        parallel_execute_op_islands.push_back(island_op);
    }
  });

  for (tf_executor::IslandOp island_op : parallel_execute_op_islands) {
    auto parallel_execute_op =
        cast<tf_device::ParallelExecuteOp>(island_op.GetBody().front());
    CreateIslandsFromParallelExecute(island_op, parallel_execute_op);
  }
}
}  // anonymous namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreateParallelExecuteToIslandsPass() {
  return std::make_unique<ParallelExecuteToIslandsPass>();
}

}  // namespace TFDevice
}  // namespace mlir