1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/Builders.h" // from @llvm-project
25 #include "mlir/IR/Operation.h" // from @llvm-project
26 #include "mlir/IR/Value.h" // from @llvm-project
27 #include "mlir/Pass/Pass.h" // from @llvm-project
28 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32
33 // This pass is used in preparation for Graph export.
34 // The GraphDef exporter expects each op to be in its own island.
35 // This pass puts the IR in that form.
36 //
37 // We do this as an IR->IR transform to keep the Graph exporter as simple as
38 // possible.
39
40 namespace mlir {
41
42 namespace {
43
44 class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
45 BreakUpIslands, TF::SideEffectAnalysis> {
getDependentDialects(DialectRegistry & registry) const46 void getDependentDialects(DialectRegistry& registry) const override {
47 registry.insert<tf_executor::TensorFlowExecutorDialect>();
48 }
49
50 public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BreakUpIslands)51 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BreakUpIslands)
52
53 StringRef getArgument() const final { return "tf-executor-break-up-islands"; }
54
getDescription() const55 StringRef getDescription() const final {
56 return "Transform from TF control dialect to TF executor dialect.";
57 }
58
59 void runOnFunction(func::FuncOp func,
60 const TF::SideEffectAnalysis::Info& side_effect_analysis);
61
62 void BreakUpIsland(tf_executor::IslandOp island_op,
63 const TF::SideEffectAnalysis::Info& side_effect_analysis,
64 llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
65 new_control_inputs);
66 };
67
68 // Returns true if the operation is a stateful If, Case, or While op.
IsStatefulFunctionalControlFlowOp(Operation * op)69 bool IsStatefulFunctionalControlFlowOp(Operation* op) {
70 if (!isa<TF::IfOp, TF::CaseOp, TF::WhileOp>(op)) {
71 return false;
72 }
73
74 if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless")) {
75 return !is_stateless.getValue();
76 }
77 return false;
78 }
79
80 // Add control dependencies from stateful control-flow ops to graph fetch op.
81 // This is needed to avoid that such control-flow ops get pruned because of a
82 // bug in common runtime (see b/185483669).
AddStatefulControlFlowDependencies(tf_executor::GraphOp graph_op)83 void AddStatefulControlFlowDependencies(tf_executor::GraphOp graph_op) {
84 llvm::SmallDenseSet<Value, 8> graph_fetches;
85 for (Value value : graph_op.GetFetch().fetches()) {
86 graph_fetches.insert(value);
87 }
88 for (Operation& op : graph_op.GetBody().without_terminator()) {
89 auto island = dyn_cast<tf_executor::IslandOp>(&op);
90 if (!island) continue;
91 if (!island.WrapsSingleOp()) continue;
92 Operation& wrapped_op = island.GetBody().front();
93 if (!IsStatefulFunctionalControlFlowOp(&wrapped_op)) continue;
94 if (graph_fetches.contains(island.control())) continue;
95
96 graph_op.GetFetch().fetchesMutable().append(island.control());
97 }
98 }
99
runOnFunction(func::FuncOp func,const TF::SideEffectAnalysis::Info & side_effect_analysis)100 void BreakUpIslands::runOnFunction(
101 func::FuncOp func,
102 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
103 auto graph_op_range = func.front().without_terminator();
104 tf_executor::GraphOp graph_op;
105
106 if (llvm::hasSingleElement(graph_op_range))
107 graph_op = dyn_cast<tf_executor::GraphOp>(func.front().front());
108
109 if (!graph_op) {
110 func.emitError("expected function to contain only a graph_op");
111 signalPassFailure();
112 return;
113 }
114
115 // New control inputs to be added. For an operation x, new_control_inputs[x]
116 // contains all control inputs that need to be added to x as operands.
117 llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>> new_control_inputs;
118 // Iterate in reverse order to avoid invalidating Operation* stored in
119 // new_control_inputs.
120 for (auto& item :
121 llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
122 if (auto island = dyn_cast<tf_executor::IslandOp>(&item)) {
123 BreakUpIsland(island, side_effect_analysis, &new_control_inputs);
124 }
125 }
126 OpBuilder builder(func);
127
128 // For every op, add new control inputs in reverse order so that the ops don't
129 // get invalidated.
130 llvm::SmallVector<Value, 8> operands;
131 llvm::SmallPtrSet<Operation*, 4> defining_ops;
132 llvm::SmallVector<Type, 4> types;
133 for (auto& item :
134 llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
135 auto it = new_control_inputs.find(&item);
136 if (it == new_control_inputs.end()) continue;
137 auto& new_control_inputs_for_item = it->second;
138 builder.setInsertionPoint(&item);
139 OperationState state(item.getLoc(), item.getName());
140 types.assign(item.result_type_begin(), item.result_type_end());
141 state.addTypes(types);
142 for (Region& region : item.getRegions()) {
143 state.addRegion()->takeBody(region);
144 }
145 // Assign existing operands for item.
146 operands.assign(item.operand_begin(), item.operand_end());
147
148 // Collect defining ops for existing operands.
149 defining_ops.clear();
150 for (Value operand : operands) {
151 defining_ops.insert(operand.getDefiningOp());
152 }
153 for (Value new_control_input : llvm::reverse(new_control_inputs_for_item)) {
154 // Add new control input if its defining op is not already a defining
155 // op for some other operand. Update defining_ops.
156 if (defining_ops.insert(new_control_input.getDefiningOp()).second) {
157 operands.push_back(new_control_input);
158 }
159 }
160 state.addOperands(operands);
161 Operation* new_op = builder.create(state);
162 item.replaceAllUsesWith(new_op);
163 new_op->setAttrs(item.getAttrDictionary());
164 item.erase();
165 }
166 AddStatefulControlFlowDependencies(graph_op);
167 }
168
169 // Populates an empty IslandOp and with a NoOp or Identity/IdentityN depending
170 // on if there are any data results.
PopulateEmptyIsland(tf_executor::IslandOp island)171 void PopulateEmptyIsland(tf_executor::IslandOp island) {
172 OpBuilder builder(&island.GetBody(), island.GetBody().begin());
173 tf_executor::YieldOp yield = island.GetYield();
174 if (yield.getNumOperands() == 0) {
175 builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
176 } else if (yield.getNumOperands() == 1) {
177 Value operand = yield.getOperand(0);
178 auto identity = builder.create<TF::IdentityOp>(island.getLoc(),
179 operand.getType(), operand);
180 yield.setOperand(0, identity.output());
181 } else {
182 auto identity_n = builder.create<TF::IdentityNOp>(
183 island.getLoc(), yield.getOperandTypes(), yield.getOperands());
184 for (auto it : llvm::enumerate(identity_n.getResults()))
185 yield.setOperand(it.index(), it.value());
186 }
187 }
188
189 // Helper that creates an island. If `sub_op` is not nullptr, it will be moved
190 // to the island. Otherwise a NoOp will be added to the island.
CreateIsland(TypeRange result_types,ValueRange control_inputs,const tf_executor::ControlType & control_type,const Location & loc,Operation * sub_op,tf_executor::IslandOp original_island)191 tf_executor::IslandOp CreateIsland(TypeRange result_types,
192 ValueRange control_inputs,
193 const tf_executor::ControlType& control_type,
194 const Location& loc, Operation* sub_op,
195 tf_executor::IslandOp original_island) {
196 OpBuilder builder(original_island);
197 auto island = builder.create<tf_executor::IslandOp>(
198 loc, result_types, control_type, control_inputs);
199 island.body().push_back(new Block);
200 Block* block = &island.body().back();
201 OpBuilder island_builder(original_island);
202 island_builder.setInsertionPointToEnd(block);
203 if (sub_op) {
204 sub_op->replaceAllUsesWith(island.outputs());
205 sub_op->moveBefore(block, block->begin());
206 island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
207 } else {
208 island_builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
209 island_builder.create<tf_executor::YieldOp>(loc, ValueRange{});
210 }
211 return island;
212 }
213
214 // A struct that contains the operations in an island that need explicit control
215 // dependencies added going into and out of the island to capture inter-island
216 // dependencies properly.
217 struct IslandSourcesAndSinks {
218 // Sub-ops that need a control dependency going into the island. This includes
219 // sub-ops that do not depend on other sub-ops in the island and functional
220 // control ops (e.g. if, while, case) with side effects that must not take
221 // effect before the previous island is finished executing.
222 llvm::SmallPtrSet<Operation*, 4> sources;
223
224 // Sub-ops that need a control dependency going out of the island. This
225 // includes sub-ops that do not have other sub-ops in the island depending on
226 // them (excluding yield) and functional control ops (e.g. if, while, case)
227 // with side effects that must take effect before the next island starts
228 // executing.
229 llvm::SmallPtrSet<Operation*, 4> sinks;
230 };
231
232 // Finds IslandSourcesAndSinks for an unmodified island.
FindSourcesAndSinksInIsland(tf_executor::IslandOp island,const TF::SideEffectAnalysis::Info & side_effect_analysis)233 IslandSourcesAndSinks FindSourcesAndSinksInIsland(
234 tf_executor::IslandOp island,
235 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
236 IslandSourcesAndSinks result;
237 auto island_body = island.GetBody().without_terminator();
238 for (Operation& sub_op : island_body) {
239 auto predecessors = side_effect_analysis.DirectControlPredecessors(&sub_op);
240 result.sinks.insert(&sub_op);
241 // Remove predecessor from sinks.
242 for (auto predecessor : predecessors) result.sinks.erase(predecessor);
243 bool has_in_island_operands = false;
244 for (auto operand : sub_op.getOperands()) {
245 auto defining_op = operand.getDefiningOp();
246 if (!defining_op || defining_op->getParentOp() != island) continue;
247 has_in_island_operands = true;
248
249 // Remove operands from sinks.
250 // We don't remove the operand if it is a stateful functional control flow
251 // op to work around an issue in LowerFunctionalOpsPass where the operand
252 // dependency isn't enough to ensure the side effects take place
253 // (b/185483669).
254 if (!IsStatefulFunctionalControlFlowOp(defining_op)) {
255 result.sinks.erase(defining_op);
256 }
257 }
258 if (predecessors.empty() && (!has_in_island_operands ||
259 IsStatefulFunctionalControlFlowOp(&sub_op))) {
260 result.sources.insert(&sub_op);
261 }
262 }
263 return result;
264 }
265
266 // Converts a single island into multiple islands (one for each op). The islands
267 // are chained together by control flow values.
BreakUpIsland(tf_executor::IslandOp island_op,const TF::SideEffectAnalysis::Info & side_effect_analysis,llvm::DenseMap<Operation *,llvm::SmallVector<Value,4>> * new_control_inputs)268 void BreakUpIslands::BreakUpIsland(
269 tf_executor::IslandOp island_op,
270 const TF::SideEffectAnalysis::Info& side_effect_analysis,
271 llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
272 new_control_inputs) {
273 auto island_body = island_op.GetBody().without_terminator();
274 // Populate islands that are empty (only yield).
275 if (island_body.empty()) {
276 PopulateEmptyIsland(island_op);
277 return;
278 }
279
280 // Skip islands that are already only a single op.
281 if (island_op.WrapsSingleOp()) return;
282
283 auto control_type = tf_executor::ControlType::get(&getContext());
284 auto island_control_inputs = llvm::to_vector<4>(island_op.controlInputs());
285 // Add control dependencies for yields of values defined by other islands to
286 // the island that defines that fetched value.
287 for (auto fetch : island_op.GetYield().fetches()) {
288 if (!fetch.getDefiningOp()) {
289 // Skip, because there is no op to add control to (eg: function args).
290 continue;
291 } else if (fetch.getDefiningOp()->getParentOp() == island_op) {
292 // Skip, because it is the same island.
293 continue;
294 } else if (auto other_island_op = llvm::dyn_cast<tf_executor::IslandOp>(
295 fetch.getDefiningOp())) {
296 island_control_inputs.push_back(other_island_op.control());
297 } else {
298 // TODO(parkers): Any defining op that has a control output can be handled
299 // just like an island.
300 fetch.getDefiningOp()->emitError("fetching non-island as dependency");
301 return signalPassFailure();
302 }
303 }
304 // If there are multiple control inputs, create an empty island to group them.
305 if (island_control_inputs.size() > 1) {
306 auto new_island = CreateIsland({}, island_control_inputs, control_type,
307 island_op.getLoc(), nullptr, island_op);
308 island_control_inputs.clear();
309 island_control_inputs.push_back(new_island.control());
310 }
311 // Find sources and sinks inside the original island.
312 IslandSourcesAndSinks sources_and_sinks =
313 FindSourcesAndSinksInIsland(island_op, side_effect_analysis);
314 // The corresponding control output of the new island created for each sub-op.
315 llvm::SmallDenseMap<Operation*, Value, 8> new_control_for_sub_ops;
316 // Control outputs of newly created islands that are sinks.
317 llvm::SmallVector<Value, 8> sink_island_controls;
318 // For each operation in the island, construct a new island to wrap the op,
319 // yield all the results, and replace all the usages with the results of the
320 // new island.
321 for (auto& sub_op : llvm::make_early_inc_range(island_body)) {
322 const auto predecessors =
323 side_effect_analysis.DirectControlPredecessors(&sub_op);
324 // Get the controls from the predecessors.
325 llvm::SmallVector<Value, 4> predecessor_controls;
326 predecessor_controls.reserve(predecessors.size());
327 for (auto predecessor : predecessors) {
328 predecessor_controls.push_back(new_control_for_sub_ops[predecessor]);
329 }
330 // If sub_op is a source, use island_control_inputs, because that's required
331 // by inter-islands dependencies; otherwise, we do not need to include
332 // island_control_inputs, since they must have been tracked by the (direct
333 // or indirect) control predecessors or operands.
334 ArrayRef<Value> control = sources_and_sinks.sources.count(&sub_op) > 0
335 ? island_control_inputs
336 : predecessor_controls;
337 auto new_island =
338 CreateIsland(sub_op.getResultTypes(), control, control_type,
339 sub_op.getLoc(), &sub_op, island_op);
340 new_control_for_sub_ops[&sub_op] = new_island.control();
341 if (sources_and_sinks.sinks.count(&sub_op)) {
342 sink_island_controls.push_back(new_island.control());
343 }
344 }
345 // Create control outputs for the sinks.
346 assert(!sink_island_controls.empty());
347 // If there are multiple control outputs, create an empty island to group
348 // them.
349 if (sink_island_controls.size() > 1) {
350 auto new_island = CreateIsland({}, sink_island_controls, control_type,
351 island_op.getLoc(), nullptr, island_op);
352 sink_island_controls.clear();
353 sink_island_controls.push_back(new_island.control());
354 }
355 assert(sink_island_controls.size() == 1);
356 auto& sink_island_control = sink_island_controls[0];
357 island_op.control().replaceAllUsesWith(sink_island_control);
358 // All existing outputs need to add sink_island_control as control input.
359 // GraphOp, YieldOp and NextIterationSourceOp don't have control inputs so
360 // exclude them below.
361 for (Value out : island_op.outputs()) {
362 for (auto& use : out.getUses()) {
363 Operation* owner = use.getOwner();
364 if (auto other_island_op =
365 llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
366 (*new_control_inputs)[other_island_op].push_back(sink_island_control);
367 } else if (owner->getDialect() == island_op->getDialect() &&
368 !llvm::isa<tf_executor::GraphOp, tf_executor::YieldOp,
369 tf_executor::NextIterationSourceOp>(owner)) {
370 (*new_control_inputs)[owner].push_back(sink_island_control);
371 } else {
372 owner->emitOpError("adding control dependency not supported");
373 return signalPassFailure();
374 }
375 }
376 }
377 for (auto item :
378 llvm::zip(island_op.outputs(), island_op.GetYield().fetches()))
379 std::get<0>(item).replaceAllUsesWith(std::get<1>(item));
380 island_op.erase();
381 }
382
383 } // namespace
384
CreateBreakUpIslandsPass()385 std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass() {
386 return std::make_unique<BreakUpIslands>();
387 }
388
389 } // namespace mlir
390
391