xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/IR/Value.h"  // from @llvm-project
27 #include "mlir/Pass/Pass.h"  // from @llvm-project
28 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 
33 // This pass is used in preparation for Graph export.
34 // The GraphDef exporter expects each op to be in its own island.
35 // This pass puts the IR in that form.
36 //
37 // We do this as an IR->IR transform to keep the Graph exporter as simple as
38 // possible.
39 
40 namespace mlir {
41 
42 namespace {
43 
44 class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
45                            BreakUpIslands, TF::SideEffectAnalysis> {
getDependentDialects(DialectRegistry & registry) const46   void getDependentDialects(DialectRegistry& registry) const override {
47     registry.insert<tf_executor::TensorFlowExecutorDialect>();
48   }
49 
50  public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BreakUpIslands)51   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BreakUpIslands)
52 
53   StringRef getArgument() const final { return "tf-executor-break-up-islands"; }
54 
getDescription() const55   StringRef getDescription() const final {
56     return "Transform from TF control dialect to TF executor dialect.";
57   }
58 
59   void runOnFunction(func::FuncOp func,
60                      const TF::SideEffectAnalysis::Info& side_effect_analysis);
61 
62   void BreakUpIsland(tf_executor::IslandOp island_op,
63                      const TF::SideEffectAnalysis::Info& side_effect_analysis,
64                      llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
65                          new_control_inputs);
66 };
67 
68 // Returns true if the operation is a stateful If, Case, or While op.
IsStatefulFunctionalControlFlowOp(Operation * op)69 bool IsStatefulFunctionalControlFlowOp(Operation* op) {
70   if (!isa<TF::IfOp, TF::CaseOp, TF::WhileOp>(op)) {
71     return false;
72   }
73 
74   if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless")) {
75     return !is_stateless.getValue();
76   }
77   return false;
78 }
79 
80 // Add control dependencies from stateful control-flow ops to graph fetch op.
81 // This is needed to avoid that such control-flow ops get pruned because of a
82 // bug in common runtime (see b/185483669).
AddStatefulControlFlowDependencies(tf_executor::GraphOp graph_op)83 void AddStatefulControlFlowDependencies(tf_executor::GraphOp graph_op) {
84   llvm::SmallDenseSet<Value, 8> graph_fetches;
85   for (Value value : graph_op.GetFetch().fetches()) {
86     graph_fetches.insert(value);
87   }
88   for (Operation& op : graph_op.GetBody().without_terminator()) {
89     auto island = dyn_cast<tf_executor::IslandOp>(&op);
90     if (!island) continue;
91     if (!island.WrapsSingleOp()) continue;
92     Operation& wrapped_op = island.GetBody().front();
93     if (!IsStatefulFunctionalControlFlowOp(&wrapped_op)) continue;
94     if (graph_fetches.contains(island.control())) continue;
95 
96     graph_op.GetFetch().fetchesMutable().append(island.control());
97   }
98 }
99 
runOnFunction(func::FuncOp func,const TF::SideEffectAnalysis::Info & side_effect_analysis)100 void BreakUpIslands::runOnFunction(
101     func::FuncOp func,
102     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
103   auto graph_op_range = func.front().without_terminator();
104   tf_executor::GraphOp graph_op;
105 
106   if (llvm::hasSingleElement(graph_op_range))
107     graph_op = dyn_cast<tf_executor::GraphOp>(func.front().front());
108 
109   if (!graph_op) {
110     func.emitError("expected function to contain only a graph_op");
111     signalPassFailure();
112     return;
113   }
114 
115   // New control inputs to be added. For an operation x, new_control_inputs[x]
116   // contains all control inputs that need to be added to x as operands.
117   llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>> new_control_inputs;
118   // Iterate in reverse order to avoid invalidating Operation* stored in
119   // new_control_inputs.
120   for (auto& item :
121        llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
122     if (auto island = dyn_cast<tf_executor::IslandOp>(&item)) {
123       BreakUpIsland(island, side_effect_analysis, &new_control_inputs);
124     }
125   }
126   OpBuilder builder(func);
127 
128   // For every op, add new control inputs in reverse order so that the ops don't
129   // get invalidated.
130   llvm::SmallVector<Value, 8> operands;
131   llvm::SmallPtrSet<Operation*, 4> defining_ops;
132   llvm::SmallVector<Type, 4> types;
133   for (auto& item :
134        llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
135     auto it = new_control_inputs.find(&item);
136     if (it == new_control_inputs.end()) continue;
137     auto& new_control_inputs_for_item = it->second;
138     builder.setInsertionPoint(&item);
139     OperationState state(item.getLoc(), item.getName());
140     types.assign(item.result_type_begin(), item.result_type_end());
141     state.addTypes(types);
142     for (Region& region : item.getRegions()) {
143       state.addRegion()->takeBody(region);
144     }
145     // Assign existing operands for item.
146     operands.assign(item.operand_begin(), item.operand_end());
147 
148     // Collect defining ops for existing operands.
149     defining_ops.clear();
150     for (Value operand : operands) {
151       defining_ops.insert(operand.getDefiningOp());
152     }
153     for (Value new_control_input : llvm::reverse(new_control_inputs_for_item)) {
154       // Add new control input if its defining op is not already a defining
155       // op for some other operand. Update defining_ops.
156       if (defining_ops.insert(new_control_input.getDefiningOp()).second) {
157         operands.push_back(new_control_input);
158       }
159     }
160     state.addOperands(operands);
161     Operation* new_op = builder.create(state);
162     item.replaceAllUsesWith(new_op);
163     new_op->setAttrs(item.getAttrDictionary());
164     item.erase();
165   }
166   AddStatefulControlFlowDependencies(graph_op);
167 }
168 
169 // Populates an empty IslandOp and with a NoOp or Identity/IdentityN depending
170 // on if there are any data results.
PopulateEmptyIsland(tf_executor::IslandOp island)171 void PopulateEmptyIsland(tf_executor::IslandOp island) {
172   OpBuilder builder(&island.GetBody(), island.GetBody().begin());
173   tf_executor::YieldOp yield = island.GetYield();
174   if (yield.getNumOperands() == 0) {
175     builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
176   } else if (yield.getNumOperands() == 1) {
177     Value operand = yield.getOperand(0);
178     auto identity = builder.create<TF::IdentityOp>(island.getLoc(),
179                                                    operand.getType(), operand);
180     yield.setOperand(0, identity.output());
181   } else {
182     auto identity_n = builder.create<TF::IdentityNOp>(
183         island.getLoc(), yield.getOperandTypes(), yield.getOperands());
184     for (auto it : llvm::enumerate(identity_n.getResults()))
185       yield.setOperand(it.index(), it.value());
186   }
187 }
188 
189 // Helper that creates an island. If `sub_op` is not nullptr, it will be moved
190 // to the island. Otherwise a NoOp will be added to the island.
CreateIsland(TypeRange result_types,ValueRange control_inputs,const tf_executor::ControlType & control_type,const Location & loc,Operation * sub_op,tf_executor::IslandOp original_island)191 tf_executor::IslandOp CreateIsland(TypeRange result_types,
192                                    ValueRange control_inputs,
193                                    const tf_executor::ControlType& control_type,
194                                    const Location& loc, Operation* sub_op,
195                                    tf_executor::IslandOp original_island) {
196   OpBuilder builder(original_island);
197   auto island = builder.create<tf_executor::IslandOp>(
198       loc, result_types, control_type, control_inputs);
199   island.body().push_back(new Block);
200   Block* block = &island.body().back();
201   OpBuilder island_builder(original_island);
202   island_builder.setInsertionPointToEnd(block);
203   if (sub_op) {
204     sub_op->replaceAllUsesWith(island.outputs());
205     sub_op->moveBefore(block, block->begin());
206     island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
207   } else {
208     island_builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
209     island_builder.create<tf_executor::YieldOp>(loc, ValueRange{});
210   }
211   return island;
212 }
213 
214 // A struct that contains the operations in an island that need explicit control
215 // dependencies added going into and out of the island to capture inter-island
216 // dependencies properly.
217 struct IslandSourcesAndSinks {
218   // Sub-ops that need a control dependency going into the island. This includes
219   // sub-ops that do not depend on other sub-ops in the island and functional
220   // control ops (e.g. if, while, case) with side effects that must not take
221   // effect before the previous island is finished executing.
222   llvm::SmallPtrSet<Operation*, 4> sources;
223 
224   // Sub-ops that need a control dependency going out of the island. This
225   // includes sub-ops that do not have other sub-ops in the island depending on
226   // them (excluding yield) and functional control ops (e.g. if, while, case)
227   // with side effects that must take effect before the next island starts
228   // executing.
229   llvm::SmallPtrSet<Operation*, 4> sinks;
230 };
231 
232 // Finds IslandSourcesAndSinks for an unmodified island.
FindSourcesAndSinksInIsland(tf_executor::IslandOp island,const TF::SideEffectAnalysis::Info & side_effect_analysis)233 IslandSourcesAndSinks FindSourcesAndSinksInIsland(
234     tf_executor::IslandOp island,
235     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
236   IslandSourcesAndSinks result;
237   auto island_body = island.GetBody().without_terminator();
238   for (Operation& sub_op : island_body) {
239     auto predecessors = side_effect_analysis.DirectControlPredecessors(&sub_op);
240     result.sinks.insert(&sub_op);
241     // Remove predecessor from sinks.
242     for (auto predecessor : predecessors) result.sinks.erase(predecessor);
243     bool has_in_island_operands = false;
244     for (auto operand : sub_op.getOperands()) {
245       auto defining_op = operand.getDefiningOp();
246       if (!defining_op || defining_op->getParentOp() != island) continue;
247       has_in_island_operands = true;
248 
249       // Remove operands from sinks.
250       // We don't remove the operand if it is a stateful functional control flow
251       // op to work around an issue in LowerFunctionalOpsPass where the operand
252       // dependency isn't enough to ensure the side effects take place
253       // (b/185483669).
254       if (!IsStatefulFunctionalControlFlowOp(defining_op)) {
255         result.sinks.erase(defining_op);
256       }
257     }
258     if (predecessors.empty() && (!has_in_island_operands ||
259                                  IsStatefulFunctionalControlFlowOp(&sub_op))) {
260       result.sources.insert(&sub_op);
261     }
262   }
263   return result;
264 }
265 
266 // Converts a single island into multiple islands (one for each op). The islands
267 // are chained together by control flow values.
BreakUpIsland(tf_executor::IslandOp island_op,const TF::SideEffectAnalysis::Info & side_effect_analysis,llvm::DenseMap<Operation *,llvm::SmallVector<Value,4>> * new_control_inputs)268 void BreakUpIslands::BreakUpIsland(
269     tf_executor::IslandOp island_op,
270     const TF::SideEffectAnalysis::Info& side_effect_analysis,
271     llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
272         new_control_inputs) {
273   auto island_body = island_op.GetBody().without_terminator();
274   // Populate islands that are empty (only yield).
275   if (island_body.empty()) {
276     PopulateEmptyIsland(island_op);
277     return;
278   }
279 
280   // Skip islands that are already only a single op.
281   if (island_op.WrapsSingleOp()) return;
282 
283   auto control_type = tf_executor::ControlType::get(&getContext());
284   auto island_control_inputs = llvm::to_vector<4>(island_op.controlInputs());
285   // Add control dependencies for yields of values defined by other islands to
286   // the island that defines that fetched value.
287   for (auto fetch : island_op.GetYield().fetches()) {
288     if (!fetch.getDefiningOp()) {
289       // Skip, because there is no op to add control to (eg: function args).
290       continue;
291     } else if (fetch.getDefiningOp()->getParentOp() == island_op) {
292       // Skip, because it is the same island.
293       continue;
294     } else if (auto other_island_op = llvm::dyn_cast<tf_executor::IslandOp>(
295                    fetch.getDefiningOp())) {
296       island_control_inputs.push_back(other_island_op.control());
297     } else {
298       // TODO(parkers): Any defining op that has a control output can be handled
299       // just like an island.
300       fetch.getDefiningOp()->emitError("fetching non-island as dependency");
301       return signalPassFailure();
302     }
303   }
304   // If there are multiple control inputs, create an empty island to group them.
305   if (island_control_inputs.size() > 1) {
306     auto new_island = CreateIsland({}, island_control_inputs, control_type,
307                                    island_op.getLoc(), nullptr, island_op);
308     island_control_inputs.clear();
309     island_control_inputs.push_back(new_island.control());
310   }
311   // Find sources and sinks inside the original island.
312   IslandSourcesAndSinks sources_and_sinks =
313       FindSourcesAndSinksInIsland(island_op, side_effect_analysis);
314   // The corresponding control output of the new island created for each sub-op.
315   llvm::SmallDenseMap<Operation*, Value, 8> new_control_for_sub_ops;
316   // Control outputs of newly created islands that are sinks.
317   llvm::SmallVector<Value, 8> sink_island_controls;
318   // For each operation in the island, construct a new island to wrap the op,
319   // yield all the results, and replace all the usages with the results of the
320   // new island.
321   for (auto& sub_op : llvm::make_early_inc_range(island_body)) {
322     const auto predecessors =
323         side_effect_analysis.DirectControlPredecessors(&sub_op);
324     // Get the controls from the predecessors.
325     llvm::SmallVector<Value, 4> predecessor_controls;
326     predecessor_controls.reserve(predecessors.size());
327     for (auto predecessor : predecessors) {
328       predecessor_controls.push_back(new_control_for_sub_ops[predecessor]);
329     }
330     // If sub_op is a source, use island_control_inputs, because that's required
331     // by inter-islands dependencies; otherwise, we do not need to include
332     // island_control_inputs, since they must have been tracked by the (direct
333     // or indirect) control predecessors or operands.
334     ArrayRef<Value> control = sources_and_sinks.sources.count(&sub_op) > 0
335                                   ? island_control_inputs
336                                   : predecessor_controls;
337     auto new_island =
338         CreateIsland(sub_op.getResultTypes(), control, control_type,
339                      sub_op.getLoc(), &sub_op, island_op);
340     new_control_for_sub_ops[&sub_op] = new_island.control();
341     if (sources_and_sinks.sinks.count(&sub_op)) {
342       sink_island_controls.push_back(new_island.control());
343     }
344   }
345   // Create control outputs for the sinks.
346   assert(!sink_island_controls.empty());
347   // If there are multiple control outputs, create an empty island to group
348   // them.
349   if (sink_island_controls.size() > 1) {
350     auto new_island = CreateIsland({}, sink_island_controls, control_type,
351                                    island_op.getLoc(), nullptr, island_op);
352     sink_island_controls.clear();
353     sink_island_controls.push_back(new_island.control());
354   }
355   assert(sink_island_controls.size() == 1);
356   auto& sink_island_control = sink_island_controls[0];
357   island_op.control().replaceAllUsesWith(sink_island_control);
358   // All existing outputs need to add sink_island_control as control input.
359   // GraphOp, YieldOp and NextIterationSourceOp don't have control inputs so
360   // exclude them below.
361   for (Value out : island_op.outputs()) {
362     for (auto& use : out.getUses()) {
363       Operation* owner = use.getOwner();
364       if (auto other_island_op =
365               llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
366         (*new_control_inputs)[other_island_op].push_back(sink_island_control);
367       } else if (owner->getDialect() == island_op->getDialect() &&
368                  !llvm::isa<tf_executor::GraphOp, tf_executor::YieldOp,
369                             tf_executor::NextIterationSourceOp>(owner)) {
370         (*new_control_inputs)[owner].push_back(sink_island_control);
371       } else {
372         owner->emitOpError("adding control dependency not supported");
373         return signalPassFailure();
374       }
375     }
376   }
377   for (auto item :
378        llvm::zip(island_op.outputs(), island_op.GetYield().fetches()))
379     std::get<0>(item).replaceAllUsesWith(std::get<1>(item));
380   island_op.erase();
381 }
382 
383 }  // namespace
384 
CreateBreakUpIslandsPass()385 std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass() {
386   return std::make_unique<BreakUpIslands>();
387 }
388 
389 }  // namespace mlir
390 
391