1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <algorithm>
#include <queue>
#include <set>
#include <string>
#include <utility>
18 
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/EquivalenceClasses.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
28 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
29 #include "mlir/IR/Value.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
31 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
38 
39 namespace mlir {
40 namespace tf_executor {
41 namespace {
42 
43 using TF::ResourceId;
44 static constexpr ResourceId kUnknownResourceId =
45     TF::detail::ResourceAliasAnalysisInfo::kUnknownResourceId;
46 static constexpr ResourceId kInvalidResourceId =
47     TF::detail::ResourceAliasAnalysisInfo::kInvalidResourceId;
48 using OperationSetTy = SmallPtrSet<Operation*, 4>;
49 using ResourceToOpsMapTy = DenseMap<ResourceId, OperationSetTy>;
50 
51 class ConvertControlToDataOutputsPass
52     : public TF::ExecutorConvertControlToDataOutputsPassBase<
53           ConvertControlToDataOutputsPass> {
54  public:
55   void runOnOperation() override;
56 };
57 
58 // Returns a vector of all tf.WhileOp(s) which use func as while body. If any of
59 // the uses is as a while condition, an empty vector is returned.
GetWhileCallers(func::FuncOp func,SymbolUserMap & symbol_map)60 SmallVector<TF::WhileOp> GetWhileCallers(func::FuncOp func,
61                                          SymbolUserMap& symbol_map) {
62   SmallVector<TF::WhileOp> while_callers;
63   for (auto user : symbol_map.getUsers(func)) {
64     if (auto while_caller = dyn_cast<TF::WhileOp>(user)) {
65       // If used as while conditional anywhere, then skip optimizing this
66       // function. Return empty vector.
67       if (while_caller.cond_function() == func) return {};
68       assert(while_caller.body_function() == func);
69       while_callers.push_back(while_caller);
70     }
71   }
72   return while_callers;
73 }
74 
75 // Populates `chain_resource_to_ops_map`, the map from all resources that need
76 // to be chained to the set of operations that access the resource, and
77 // `resource_equivalence_classes`. Resources are equivalent if they are accessed
78 // by a common op, and equivalent resources will be assigned to the same chain.
CollectChainResources(func::FuncOp func,ResourceToOpsMapTy & chain_resource_to_ops_map,llvm::EquivalenceClasses<ResourceId> & resource_equivalence_classes,const TF::SideEffectAnalysis::Info & side_effect_analysis)79 void CollectChainResources(
80     func::FuncOp func, ResourceToOpsMapTy& chain_resource_to_ops_map,
81     llvm::EquivalenceClasses<ResourceId>& resource_equivalence_classes,
82     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
83   auto graph_op = cast<GraphOp>(func.front().front());
84 
85   // For each op in the graph, get the resources it uses and update the access
86   // information for them.
87   graph_op.walk([&](IslandOp island) {
88     // This pass assumes that all functions are suitable for export i.e., each
89     // function has a single tf_executor.graph op and all islands wrap the
90     // internal op perfectly. Hence this assertion should never fail.
91     assert(island.WrapsSingleOp());
92     Operation& op = island.GetBody().front();
93 
94     ResourceId prev_resource_id = kInvalidResourceId;
95     for (auto resource_id_read_only_pair :
96          side_effect_analysis.GetResourceIds(&op)) {
97       ResourceId resource_id = resource_id_read_only_pair.first;
98       // If the resource was allocated by an op with `UniqueResourceAllocation`
99       // trait, then we don't need to chain resource ops accessing this resource
100       // between iterations: Every iteration will create a new independent
101       // resource. This enables more parallelism across iterations.
102       if (!side_effect_analysis.IsUniqueResourceAllocationId(resource_id)) {
103         chain_resource_to_ops_map[resource_id].insert(&op);
104         if (prev_resource_id != kInvalidResourceId) {
105           // Merge class of current ID with class of previous ID since both
106           // resources are accessed by `op`.
107           resource_equivalence_classes.unionSets(prev_resource_id, resource_id);
108         } else {
109           resource_equivalence_classes.insert(resource_id);
110         }
111         prev_resource_id = resource_id;
112       }
113     }
114   });
115 }
116 
117 // tf.NoOp islands are used to combine multiple control dependencies into one.
118 // These islands have a single tf.NoOp inside them and consume multiple control
119 // outputs to generate a single control output.
120 //
121 // For example,
122 // ```
123 // %merged_control = "tf_executor.island"(%control_a, %control_b) ({
124 //   "tf.NoOp"() : () -> ()
125 //   "tf_executor.yield"() : () -> ()
126 // }) : (!tf_executor.control, !tf_executor.control) -> (!tf_executor.control)
127 // ```
128 //
129 // `%merged_control` is a NoOp control barrier in this case.
130 //
131 // Checks if the value `control` is a NoOp control barrier.
IsNoOpControlBarrier(Value control)132 bool IsNoOpControlBarrier(Value control) {
133   if (!control.getType().isa<ControlType>()) return false;
134 
135   auto control_island = dyn_cast_or_null<IslandOp>(control.getDefiningOp());
136   if (!control_island) return false;
137 
138   // All islands perfectly wrap a single op is an invariant of this pass and
139   // is checked at the very beginning of the pass.
140   assert(control_island.WrapsSingleOp());
141   return control_island.outputs().empty() &&
142          isa<TF::NoOp>(control_island.GetBody().front());
143 }
144 
145 // Remove all control outputs of the function. Traverses NoOp control barrier
146 // chains from FetchOp to all NoOp control barriers. Returns true
147 // iff at least one control output is deleted.
RemoveAllControlOutputs(func::FuncOp func)148 bool RemoveAllControlOutputs(func::FuncOp func) {
149   auto graph_op = cast<GraphOp>(func.front().front());
150 
151   FetchOp fetch = graph_op.GetFetch();
152   // Return early if no control outputs exist.
153   if (fetch.getNumOperands() == graph_op->getNumResults()) return false;
154 
155   std::queue<Value> control_barrier_worklist;
156   for (Value control_output :
157        fetch.fetches().drop_front(graph_op->getNumResults())) {
158     if (IsNoOpControlBarrier(control_output))
159       control_barrier_worklist.push(control_output);
160   }
161 
162   // Erase all control outputs at the end from fetch.
163   fetch.fetchesMutable().erase(
164       graph_op.getNumResults(),
165       fetch.getNumOperands() - graph_op.getNumResults());
166 
167   // Iterate the worklist to remove all NoOp control barriers at the end of the
168   // function body that are used to merge two or more control dependencies.
169   while (!control_barrier_worklist.empty()) {
170     Value control_barrier = control_barrier_worklist.front();
171     control_barrier_worklist.pop();
172 
173     // We can only erase control barriers whose uses have been erased as well.
174     if (!control_barrier.use_empty()) continue;
175 
176     // Only values defined by IslandOp were inserted in the worklist.
177     IslandOp current_island = cast<IslandOp>(control_barrier.getDefiningOp());
178 
179     for (auto control_input : current_island.controlInputs()) {
180       if (IsNoOpControlBarrier(control_input))
181         control_barrier_worklist.push(control_input);
182     }
183     current_island.erase();
184   }
185   return true;
186 }
187 
188 // Appends function arguments with `num_resources` number of arguments of
189 // requested type.
AppendFunctionArguments(func::FuncOp func,int num_resources,ShapedType chaining_data_type)190 void AppendFunctionArguments(func::FuncOp func, int num_resources,
191                              ShapedType chaining_data_type) {
192   for (int i = 0; i < num_resources; ++i) {
193     func.getRegion().addArgument(chaining_data_type, func.getLoc());
194   }
195 
196   FunctionType ftype =
197       FunctionType::get(func.getContext(), func.getBody().getArgumentTypes(),
198                         func.getFunctionType().getResults());
199   func.setType(ftype);
200 }
201 
202 // Appends function results with `num_resources` number of results of requested
203 // type.
AppendFunctionResults(func::FuncOp func,int num_resources,ShapedType chaining_data_type)204 void AppendFunctionResults(func::FuncOp func, int num_resources,
205                            ShapedType chaining_data_type) {
206   Block& block = func.front();
207   auto graph_op = cast<GraphOp>(block.front());
208   // Note that func result types are same as the result types of
209   // GraphOp in the function `func`.
210   assert(std::equal(func->getResultTypes().begin(),
211                     func->getResultTypes().end(),
212                     graph_op->getResultTypes().begin()));
213   auto new_result_types =
214       llvm::to_vector<4>(func.getFunctionType().getResults());
215   for (int i = 0; i < num_resources; ++i) {
216     new_result_types.push_back(chaining_data_type);
217   }
218   FunctionType ftype = FunctionType::get(
219       func.getContext(), func.getArgumentTypes(), new_result_types);
220   func.setType(ftype);
221 
222   // Rewrite GraphOp to have same number of results as the
223   // function.
224   OpBuilder builder(graph_op);
225   auto new_graph_op =
226       builder.create<GraphOp>(graph_op.getLoc(), new_result_types);
227   new_graph_op.getRegion().takeBody(graph_op.getRegion());
228   graph_op->replaceAllUsesWith(
229       new_graph_op->getResults().drop_back(num_resources));
230   graph_op.erase();
231   func::ReturnOp return_op = cast<func::ReturnOp>(block.getTerminator());
232   int num_old_arguments = return_op.getNumOperands();
233   for (int i = 0; i < num_resources; ++i) {
234     return_op.operandsMutable().append(
235         new_graph_op.getResult(num_old_arguments + i));
236   }
237 }
238 
239 // Creates a wrapper island enclosing the `sub_op` dependent on
240 // `control_inputs`.
CreateIsland(Operation * sub_op,ValueRange control_inputs,OpBuilder builder)241 IslandOp CreateIsland(Operation* sub_op, ValueRange control_inputs,
242                       OpBuilder builder) {
243   assert(sub_op);
244   auto control_type = ControlType::get(builder.getContext());
245   auto island = builder.create<IslandOp>(
246       sub_op->getLoc(), sub_op->getResultTypes(), control_type, control_inputs);
247   island.body().push_back(new Block);
248   Block* block = &island.body().back();
249   builder.setInsertionPointToEnd(block);
250   sub_op->replaceAllUsesWith(island.outputs());
251   sub_op->moveBefore(block, block->begin());
252   builder.create<YieldOp>(sub_op->getLoc(), sub_op->getResults());
253   return island;
254 }
255 
256 // Adds control dependencies from/to chain arguments/results. It adds two
257 // identity ops, chain_src and chain_sink, per resource equivalence class.
258 // Using the resource to operations map, it adds (1) a control dependency
259 // from chain_src to all the operations that read/write to a resource of the
260 // equivalence class, and (2) a control dependency from all the operations that
261 // read/write to a resource of the class to the chain_sink operation.
ChainResourceOps(func::FuncOp func,ResourceToOpsMapTy & chain_resource_to_ops_map,llvm::EquivalenceClasses<ResourceId> & resource_equivalence_classes,int num_old_outputs)262 void ChainResourceOps(
263     func::FuncOp func, ResourceToOpsMapTy& chain_resource_to_ops_map,
264     llvm::EquivalenceClasses<ResourceId>& resource_equivalence_classes,
265     int num_old_outputs) {
266   assert(num_old_outputs + resource_equivalence_classes.getNumClasses() ==
267          func.getNumArguments());
268   auto graph_op = cast<GraphOp>(func.front().front());
269 
270   auto fetch = graph_op.GetFetch();
271   OpBuilder builder_chain_src(fetch);
272   builder_chain_src.setInsertionPointToStart(fetch->getBlock());
273 
274   OpBuilder builder_chain_sink(fetch);
275   int chain_index = num_old_outputs;
276 
277   // Iterate over all equivalence classes.
278   for (auto class_iter = resource_equivalence_classes.begin();
279        class_iter != resource_equivalence_classes.end(); ++class_iter) {
280     // Only visit one element per class, the leader.
281     if (!class_iter->isLeader()) continue;
282 
283     // Create chain source and sink identity islands for current equivalence
284     // class.
285     auto chain_arg = func.getArgument(chain_index++);
286     auto src_identity = builder_chain_src.create<TF::IdentityOp>(
287         chain_arg.getLoc(), chain_arg.getType(), chain_arg);
288     auto chain_src_island = CreateIsland(src_identity, {}, builder_chain_src);
289 
290     auto sink_identity = builder_chain_sink.create<TF::IdentityOp>(
291         chain_arg.getLoc(), chain_arg.getType(), chain_arg);
292     auto chain_sink_island =
293         CreateIsland(sink_identity, {}, builder_chain_sink);
294 
295     // Add the chain sink data output to fetch.
296     fetch.fetchesMutable().append(chain_sink_island.outputs().front());
297 
298     // Iterate over all members of the current equivalence class (represented
299     // by `class_iter`). Keep track of ops that have already been processed.
300     llvm::SmallDenseSet<Operation*> processed_ops;
301     for (auto member_iter =
302              resource_equivalence_classes.member_begin(class_iter);
303          member_iter != resource_equivalence_classes.member_end();
304          ++member_iter) {
305       ResourceId resource_id = *member_iter;
306       auto map_iter = chain_resource_to_ops_map.find(resource_id);
307       if (map_iter == chain_resource_to_ops_map.end()) continue;
308       OperationSetTy& resource_ops = map_iter->getSecond();
309 
310       // Add dependencies between all ops that access current resource and chain
311       // source and sink.
312       for (Operation* op : resource_ops) {
313         if (processed_ops.contains(op)) continue;
314 
315         IslandOp wrapper = op->getParentOfType<IslandOp>();
316         assert(wrapper);
317         wrapper.controlInputsMutable().append(chain_src_island.control());
318         chain_sink_island.controlInputsMutable().append(wrapper.control());
319         processed_ops.insert(op);
320       }
321     }
322   }
323   VLOG(2) << "Added " << resource_equivalence_classes.getNumClasses()
324           << " chains for " << chain_resource_to_ops_map.size() << " resources";
325 }
326 
327 // Generate a dummy constant island of requested type.
GetDummyConstant(OpBuilder builder,ShapedType const_type,Location loc)328 IslandOp GetDummyConstant(OpBuilder builder, ShapedType const_type,
329                           Location loc) {
330   DenseIntElementsAttr val = DenseIntElementsAttr::get(const_type, 1);
331   auto const_op = builder.create<TF::ConstOp>(loc, val);
332   auto const_island = CreateIsland(const_op, {}, builder);
333   return const_island;
334 }
335 
336 // Rewrites the while op with extra chaining operands and results. Uses a
337 // dummy constant of requested type as argument to all the new chaining
338 // operands.
RewriteWhileOp(TF::WhileOp while_op,int num_resource_inputs,ShapedType const_type)339 TF::WhileOp RewriteWhileOp(TF::WhileOp while_op, int num_resource_inputs,
340                            ShapedType const_type) {
341   IslandOp while_wrapper = while_op->getParentOfType<IslandOp>();
342   assert(while_wrapper && "While op is expected to be wrapped in a IslandOp");
343 
344   // Get the dummy constant.
345   OpBuilder builder(while_wrapper);
346   auto loc = NameLoc::get(
347       builder.getStringAttr("chain_control_outputs@" + while_op.body()));
348   IslandOp const_wrapper = GetDummyConstant(builder, const_type, loc);
349 
350   // Get new operand and result types.
351   auto new_operands = llvm::to_vector<4>(while_op->getOperands());
352   auto new_result_types = llvm::to_vector<4>(while_op->getResultTypes());
353   Value const_output = const_wrapper.outputs()[0];
354   for (int i = 0; i < num_resource_inputs; ++i) {
355     new_operands.push_back(const_output);
356     new_result_types.push_back(const_output.getType());
357   }
358 
359   // Replace old while op with new while op.
360   auto new_while_op = builder.create<TF::WhileOp>(
361       while_op.getLoc(), new_result_types, new_operands, while_op->getAttrs());
362   auto new_while_wrapper =
363       CreateIsland(new_while_op, while_wrapper.controlInputs(), builder);
364   for (auto result : while_wrapper.outputs()) {
365     result.replaceAllUsesWith(
366         new_while_wrapper.outputs()[result.getResultNumber()]);
367   }
368   while_wrapper.control().replaceAllUsesWith(new_while_wrapper.control());
369   while_wrapper.erase();
370   return new_while_op;
371 }
372 
373 // Converts the control outputs of the while body to data outputs, thus
374 // removing control barrier at the end of while loop body.
ConvertControlToDataOutputs(func::FuncOp while_body,SmallVectorImpl<TF::WhileOp> & while_callers,OperationSetTy & recompute_analysis_for_funcs,const TF::SideEffectAnalysis::Info & side_effect_analysis)375 void ConvertControlToDataOutputs(
376     func::FuncOp while_body, SmallVectorImpl<TF::WhileOp>& while_callers,
377     OperationSetTy& recompute_analysis_for_funcs,
378     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
379   if (while_callers.empty()) return;
380 
381   // Collect access information for each resource in the while body that needs
382   // to be chained, along with equivalence classes (resources in one class will
383   // use the same chain).
384   ResourceToOpsMapTy chain_resource_to_ops_map;
385   llvm::EquivalenceClasses<ResourceId> resource_equivalence_classes;
386   CollectChainResources(while_body, chain_resource_to_ops_map,
387                         resource_equivalence_classes, side_effect_analysis);
388 
389   // Check for presence of unknown side-effecting ops within the while loop
390   // body. These ops act as barriers and the optimization would not yield much
391   // inter iteration parallelism for this while loop body. So return with
392   // warning.
393   if (chain_resource_to_ops_map.count(kUnknownResourceId) > 0) {
394     std::set<std::string> blocking_ops;
395     for (Operation* op : chain_resource_to_ops_map[kUnknownResourceId]) {
396       std::string op_name = op->getName().getStringRef().str();
397       if (blocking_ops.insert(op_name).second) {
398         LOG(INFO) << "[`tf-executor-convert-control-to-data-outputs` disabled] "
399                      "Op type '"
400                   << op_name
401                   << "' has unknown side effects and blocks inter iteration "
402                      "parallelism for the while loop. Consider modeling side "
403                      "effects of this op.";
404       }
405     }
406     return;
407   }
408 
409   // First remove all control outputs of while loop body.
410   bool changed = RemoveAllControlOutputs(while_body);
411 
412   // If there was no control output to be removed, return early.
413   if (!changed) return;
414 
415   int num_chains = resource_equivalence_classes.getNumClasses();
416   RankedTensorType chaining_data_type =
417       RankedTensorType::get({}, OpBuilder(while_body).getI32Type());
418   // Create new while body
419   int num_old_outputs = while_body.getNumResults();
420   AppendFunctionArguments(while_body, num_chains, chaining_data_type);
421   AppendFunctionResults(while_body, num_chains, chaining_data_type);
422 
423   // Insert identity ops with control dep
424   ChainResourceOps(while_body, chain_resource_to_ops_map,
425                    resource_equivalence_classes, num_old_outputs);
426   // Modify all the while ops referencing the body function and the
427   // corresponding while condition functions. Note that each while condition
428   // needs to be modified only once.
429   OperationSetTy visited;
430   for (TF::WhileOp while_op : while_callers) {
431     // If the while callers are modified as part of the optimization, then the
432     // side effect analysis of their parent functions are invalidated. They
433     // need to be recomputed.
434     recompute_analysis_for_funcs.insert(
435         while_op->getParentOfType<func::FuncOp>());
436     func::FuncOp while_cond = while_op.cond_function();
437     // Rewrite while op with extra chaining arguments and results.
438     while_op = RewriteWhileOp(while_op, num_chains, chaining_data_type);
439     bool first_visit = visited.insert(while_cond).second;
440     if (!first_visit) continue;
441     // Modify while condition function with extra chaining arguments.
442     AppendFunctionArguments(while_cond, num_chains, chaining_data_type);
443   }
444 }
445 
runOnOperation()446 void ConvertControlToDataOutputsPass::runOnOperation() {
447   ModuleOp module = getOperation();
448   // This pass assumes that all functions are suitable for export i.e., each
449   // function has a single tf_executor.graph op and all islands wrap the
450   // internal op perfectly. Verify that in the beginning once.
451   if (failed(tensorflow::VerifyExportSuitable(module))) {
452     signalPassFailure();
453     return;
454   }
455   TF::SideEffectAnalysis side_effect_analysis(module);
456 
457   SymbolTableCollection table;
458   SymbolUserMap symbol_map(table, module);
459   llvm::SmallDenseMap<func::FuncOp, SmallVector<TF::WhileOp>>
460       while_body_func_to_while_ops;
461 
462   // Get all the while body functions and the corresponding while ops first
463   // because the symbol user map is invalidated once we start deleting while
464   // ops.
465   for (auto func : module.getOps<func::FuncOp>()) {
466     if (func.isExternal()) continue;
467     SmallVector<TF::WhileOp> while_callers = GetWhileCallers(func, symbol_map);
468     if (while_callers.empty()) continue;
469     while_body_func_to_while_ops[func] = while_callers;
470   }
471   // Keep track of functions whose side effect analysis is invalidated because
472   // of modifications to that function.
473   OperationSetTy recompute_analysis_for_funcs;
474 
475   for (auto& entry : while_body_func_to_while_ops) {
476     func::FuncOp while_body = entry.getFirst();
477     SmallVector<TF::WhileOp>& while_callers = entry.getSecond();
478     if (recompute_analysis_for_funcs.contains(while_body)) {
479       // TODO(b/202540801): Recomputing side effect analysis for the entire
480       // module is wasteful. It would be better to just recompute analysis for
481       // specific functions but the current side effect analysis interface
482       // does not allow that.
483       side_effect_analysis = TF::SideEffectAnalysis(module);
484     }
485     ConvertControlToDataOutputs(
486         while_body, while_callers, recompute_analysis_for_funcs,
487         side_effect_analysis.GetAnalysisForFunc(while_body));
488   }
489 }
490 
491 }  // namespace
492 
493 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorConvertControlToDataOutputsPass()494 CreateTFExecutorConvertControlToDataOutputsPass() {
495   return std::make_unique<ConvertControlToDataOutputsPass>();
496 }
497 
498 }  // namespace tf_executor
499 }  // namespace mlir
500