/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <queue>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"

namespace mlir {
namespace tf_executor {
namespace {

using TF::ResourceId;
static constexpr ResourceId kUnknownResourceId =
    TF::detail::ResourceAliasAnalysisInfo::kUnknownResourceId;
static constexpr ResourceId kInvalidResourceId =
    TF::detail::ResourceAliasAnalysisInfo::kInvalidResourceId;
using OperationSetTy = SmallPtrSet<Operation*, 4>;
using ResourceToOpsMapTy = DenseMap<ResourceId, OperationSetTy>;

class ConvertControlToDataOutputsPass
    : public TF::ExecutorConvertControlToDataOutputsPassBase<
          ConvertControlToDataOutputsPass> {
 public:
  void runOnOperation() override;
};

// Returns a vector of all tf.WhileOp(s) that use `func` as the while body. If
// any of the uses is as a while condition, an empty vector is returned.
SmallVector<TF::WhileOp> GetWhileCallers(func::FuncOp func,
                                         SymbolUserMap& symbol_map) {
  SmallVector<TF::WhileOp> while_callers;
  for (auto user : symbol_map.getUsers(func)) {
    if (auto while_caller = dyn_cast<TF::WhileOp>(user)) {
      // If the function is also used as a while condition anywhere, skip
      // optimizing it and return an empty vector.
      if (while_caller.cond_function() == func) return {};
      assert(while_caller.body_function() == func);
      while_callers.push_back(while_caller);
    }
  }
  return while_callers;
}

// Populates `chain_resource_to_ops_map`, the map from all resources that need
// to be chained to the set of operations that access the resource, and
// `resource_equivalence_classes`. Resources are equivalent if they are
// accessed by a common op, and equivalent resources will be assigned to the
// same chain.
void CollectChainResources(
    func::FuncOp func, ResourceToOpsMapTy& chain_resource_to_ops_map,
    llvm::EquivalenceClasses<ResourceId>& resource_equivalence_classes,
    const TF::SideEffectAnalysis::Info& side_effect_analysis) {
  auto graph_op = cast<GraphOp>(func.front().front());

  // For each op in the graph, get the resources it uses and update the access
  // information for them.
  graph_op.walk([&](IslandOp island) {
    // This pass assumes that all functions are suitable for export i.e., each
    // function has a single tf_executor.graph op and all islands wrap the
    // internal op perfectly. Hence this assertion should never fail.
    assert(island.WrapsSingleOp());
    Operation& op = island.GetBody().front();

    ResourceId prev_resource_id = kInvalidResourceId;
    for (auto resource_id_read_only_pair :
         side_effect_analysis.GetResourceIds(&op)) {
      ResourceId resource_id = resource_id_read_only_pair.first;
      // If the resource was allocated by an op with `UniqueResourceAllocation`
      // trait, then we don't need to chain resource ops accessing this
      // resource between iterations: Every iteration will create a new
      // independent resource. This enables more parallelism across iterations.
      if (!side_effect_analysis.IsUniqueResourceAllocationId(resource_id)) {
        chain_resource_to_ops_map[resource_id].insert(&op);
        if (prev_resource_id != kInvalidResourceId) {
          // Merge class of current ID with class of previous ID since both
          // resources are accessed by `op`.
          resource_equivalence_classes.unionSets(prev_resource_id,
                                                 resource_id);
        } else {
          resource_equivalence_classes.insert(resource_id);
        }
        prev_resource_id = resource_id;
      }
    }
  });
}

// tf.NoOp islands are used to combine multiple control dependencies into one.
// These islands have a single tf.NoOp inside them and consume multiple control
// outputs to generate a single control output.
//
// For example,
// ```
// %merged_control = "tf_executor.island"(%control_a, %control_b) ({
//   "tf.NoOp"() : () -> ()
//   "tf_executor.yield"() : () -> ()
// }) : (!tf_executor.control, !tf_executor.control) -> (!tf_executor.control)
// ```
//
// `%merged_control` is a NoOp control barrier in this case.
//
// Checks if the value `control` is a NoOp control barrier.
bool IsNoOpControlBarrier(Value control) {
  if (!control.getType().isa<ControlType>()) return false;

  auto control_island = dyn_cast_or_null<IslandOp>(control.getDefiningOp());
  if (!control_island) return false;

  // That all islands perfectly wrap a single op is an invariant of this pass
  // and is checked at the very beginning of the pass.
  assert(control_island.WrapsSingleOp());
  return control_island.outputs().empty() &&
         isa<TF::NoOp>(control_island.GetBody().front());
}

// Removes all control outputs of the function. Traverses chains of NoOp
// control barriers backward from the FetchOp. Returns true iff at least one
// control output is deleted.
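// Note that a NoOp control barrier island is erased only once all of its uses
// have been removed; barriers that still have remaining uses are kept.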
bool RemoveAllControlOutputs(func::FuncOp func) {
  auto graph_op = cast<GraphOp>(func.front().front());

  FetchOp fetch = graph_op.GetFetch();
  // Return early if no control outputs exist.
  if (fetch.getNumOperands() == graph_op->getNumResults()) return false;

  std::queue<Value> control_barrier_worklist;
  for (Value control_output :
       fetch.fetches().drop_front(graph_op->getNumResults())) {
    if (IsNoOpControlBarrier(control_output))
      control_barrier_worklist.push(control_output);
  }

  // Erase all control outputs at the end of the fetch.
  fetch.fetchesMutable().erase(
      graph_op.getNumResults(),
      fetch.getNumOperands() - graph_op.getNumResults());

  // Iterate the worklist to remove all NoOp control barriers at the end of the
  // function body that are used to merge two or more control dependencies.
  while (!control_barrier_worklist.empty()) {
    Value control_barrier = control_barrier_worklist.front();
    control_barrier_worklist.pop();

    // We can only erase control barriers whose uses have been erased as well.
    if (!control_barrier.use_empty()) continue;

    // Only values defined by an IslandOp were inserted in the worklist.
    IslandOp current_island = cast<IslandOp>(control_barrier.getDefiningOp());

    for (auto control_input : current_island.controlInputs()) {
      if (IsNoOpControlBarrier(control_input))
        control_barrier_worklist.push(control_input);
    }
    current_island.erase();
  }
  return true;
}

// Appends `num_resources` arguments of the requested type to the function's
// argument list.
void AppendFunctionArguments(func::FuncOp func, int num_resources,
                             ShapedType chaining_data_type) {
  for (int i = 0; i < num_resources; ++i) {
    func.getRegion().addArgument(chaining_data_type, func.getLoc());
  }

  FunctionType ftype =
      FunctionType::get(func.getContext(), func.getBody().getArgumentTypes(),
                        func.getFunctionType().getResults());
  func.setType(ftype);
}

// Appends `num_resources` results of the requested type to the function's
// result list.
void AppendFunctionResults(func::FuncOp func, int num_resources,
                           ShapedType chaining_data_type) {
  Block& block = func.front();
  auto graph_op = cast<GraphOp>(block.front());
  // Note that the function result types are the same as the result types of
  // the GraphOp in the function `func`.
  assert(std::equal(func->getResultTypes().begin(),
                    func->getResultTypes().end(),
                    graph_op->getResultTypes().begin()));
  auto new_result_types =
      llvm::to_vector<4>(func.getFunctionType().getResults());
  for (int i = 0; i < num_resources; ++i) {
    new_result_types.push_back(chaining_data_type);
  }
  FunctionType ftype = FunctionType::get(
      func.getContext(), func.getArgumentTypes(), new_result_types);
  func.setType(ftype);

  // Rewrite the GraphOp to have the same number of results as the function.
  OpBuilder builder(graph_op);
  auto new_graph_op =
      builder.create<GraphOp>(graph_op.getLoc(), new_result_types);
  new_graph_op.getRegion().takeBody(graph_op.getRegion());
  graph_op->replaceAllUsesWith(
      new_graph_op->getResults().drop_back(num_resources));
  graph_op.erase();
  func::ReturnOp return_op = cast<func::ReturnOp>(block.getTerminator());
  int num_old_arguments = return_op.getNumOperands();
  for (int i = 0; i < num_resources; ++i) {
    return_op.operandsMutable().append(
        new_graph_op.getResult(num_old_arguments + i));
  }
}

// Creates a wrapper island enclosing the `sub_op` dependent on
// `control_inputs`.
IslandOp CreateIsland(Operation* sub_op, ValueRange control_inputs,
                      OpBuilder builder) {
  assert(sub_op);
  auto control_type = ControlType::get(builder.getContext());
  auto island = builder.create<IslandOp>(
      sub_op->getLoc(), sub_op->getResultTypes(), control_type,
      control_inputs);
  island.body().push_back(new Block);
  Block* block = &island.body().back();
  builder.setInsertionPointToEnd(block);
  sub_op->replaceAllUsesWith(island.outputs());
  sub_op->moveBefore(block, block->begin());
  builder.create<YieldOp>(sub_op->getLoc(), sub_op->getResults());
  return island;
}

// Adds control dependencies from/to chain arguments/results. It adds two
// identity ops, chain_src and chain_sink, per resource equivalence class.
// Using the resource to operations map, it adds (1) a control dependency
// from chain_src to all the operations that read/write to a resource of the
// equivalence class, and (2) a control dependency from all the operations that
// read/write to a resource of the class to the chain_sink operation.
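//
// Schematically, each equivalence class is chained roughly as follows
// (illustrative only; names differ from what the pass actually produces):
// ```
// %src, %src_ctl = tf_executor.island wraps "tf.Identity"(%chain_arg)
// %r, %r_ctl = tf_executor.island(%src_ctl) wraps <op accessing the resource>
// %sink, %sink_ctl = tf_executor.island(%r_ctl) wraps "tf.Identity"(%chain_arg)
// tf_executor.fetch ..., %sink
// ```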
void ChainResourceOps(
    func::FuncOp func, ResourceToOpsMapTy& chain_resource_to_ops_map,
    llvm::EquivalenceClasses<ResourceId>& resource_equivalence_classes,
    int num_old_outputs) {
  assert(num_old_outputs + resource_equivalence_classes.getNumClasses() ==
         func.getNumArguments());
  auto graph_op = cast<GraphOp>(func.front().front());

  auto fetch = graph_op.GetFetch();
  OpBuilder builder_chain_src(fetch);
  builder_chain_src.setInsertionPointToStart(fetch->getBlock());

  OpBuilder builder_chain_sink(fetch);
  int chain_index = num_old_outputs;

  // Iterate over all equivalence classes.
  for (auto class_iter = resource_equivalence_classes.begin();
       class_iter != resource_equivalence_classes.end(); ++class_iter) {
    // Only visit one element per class, the leader.
    if (!class_iter->isLeader()) continue;

    // Create chain source and sink identity islands for the current
    // equivalence class.
    auto chain_arg = func.getArgument(chain_index++);
    auto src_identity = builder_chain_src.create<TF::IdentityOp>(
        chain_arg.getLoc(), chain_arg.getType(), chain_arg);
    auto chain_src_island = CreateIsland(src_identity, {}, builder_chain_src);

    auto sink_identity = builder_chain_sink.create<TF::IdentityOp>(
        chain_arg.getLoc(), chain_arg.getType(), chain_arg);
    auto chain_sink_island =
        CreateIsland(sink_identity, {}, builder_chain_sink);

    // Add the chain sink data output to fetch.
    fetch.fetchesMutable().append(chain_sink_island.outputs().front());

    // Iterate over all members of the current equivalence class (represented
    // by `class_iter`). Keep track of ops that have already been processed.
    llvm::SmallDenseSet<Operation*> processed_ops;
    for (auto member_iter =
             resource_equivalence_classes.member_begin(class_iter);
         member_iter != resource_equivalence_classes.member_end();
         ++member_iter) {
      ResourceId resource_id = *member_iter;
      auto map_iter = chain_resource_to_ops_map.find(resource_id);
      if (map_iter == chain_resource_to_ops_map.end()) continue;
      OperationSetTy& resource_ops = map_iter->getSecond();

      // Add dependencies between all ops that access the current resource and
      // the chain source and sink.
      for (Operation* op : resource_ops) {
        if (processed_ops.contains(op)) continue;

        IslandOp wrapper = op->getParentOfType<IslandOp>();
        assert(wrapper);
        wrapper.controlInputsMutable().append(chain_src_island.control());
        chain_sink_island.controlInputsMutable().append(wrapper.control());
        processed_ops.insert(op);
      }
    }
  }
  VLOG(2) << "Added " << resource_equivalence_classes.getNumClasses()
          << " chains for " << chain_resource_to_ops_map.size()
          << " resources";
}

// Generates a dummy constant island of the requested type.
IslandOp GetDummyConstant(OpBuilder builder, ShapedType const_type,
                          Location loc) {
  DenseIntElementsAttr val = DenseIntElementsAttr::get(const_type, 1);
  auto const_op = builder.create<TF::ConstOp>(loc, val);
  auto const_island = CreateIsland(const_op, {}, builder);
  return const_island;
}

// Rewrites the while op with extra chaining operands and results. Uses a
// dummy constant of the requested type as the argument for all new chaining
// operands.
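// The dummy constants only provide initial values for the new chaining data
// operands; the actual ordering comes from the control dependencies added in
// ChainResourceOps.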
TF::WhileOp RewriteWhileOp(TF::WhileOp while_op, int num_resource_inputs,
                           ShapedType const_type) {
  IslandOp while_wrapper = while_op->getParentOfType<IslandOp>();
  assert(while_wrapper && "While op is expected to be wrapped in an IslandOp");

  // Get the dummy constant.
  OpBuilder builder(while_wrapper);
  auto loc = NameLoc::get(
      builder.getStringAttr("chain_control_outputs@" + while_op.body()));
  IslandOp const_wrapper = GetDummyConstant(builder, const_type, loc);

  // Get new operand and result types.
  auto new_operands = llvm::to_vector<4>(while_op->getOperands());
  auto new_result_types = llvm::to_vector<4>(while_op->getResultTypes());
  Value const_output = const_wrapper.outputs()[0];
  for (int i = 0; i < num_resource_inputs; ++i) {
    new_operands.push_back(const_output);
    new_result_types.push_back(const_output.getType());
  }

  // Replace the old while op with a new while op.
  auto new_while_op = builder.create<TF::WhileOp>(
      while_op.getLoc(), new_result_types, new_operands, while_op->getAttrs());
  auto new_while_wrapper =
      CreateIsland(new_while_op, while_wrapper.controlInputs(), builder);
  for (auto result : while_wrapper.outputs()) {
    result.replaceAllUsesWith(
        new_while_wrapper.outputs()[result.getResultNumber()]);
  }
  while_wrapper.control().replaceAllUsesWith(new_while_wrapper.control());
  while_wrapper.erase();
  return new_while_op;
}

// Converts the control outputs of the while body to data outputs, thus
// removing the control barrier at the end of the while loop body.
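//
// At a high level this (1) collects the resources accessed in the while body
// and groups them into equivalence classes, (2) bails out if any op has
// unknown side effects, (3) removes the existing control outputs and their
// NoOp control barriers, (4) appends one chaining argument/result per
// equivalence class and wires the resource ops between per-class chain
// source/sink identities, and (5) rewrites all while callers (and their
// condition functions) to match the new signature.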
void ConvertControlToDataOutputs(
    func::FuncOp while_body, SmallVectorImpl<TF::WhileOp>& while_callers,
    OperationSetTy& recompute_analysis_for_funcs,
    const TF::SideEffectAnalysis::Info& side_effect_analysis) {
  if (while_callers.empty()) return;

  // Collect access information for each resource in the while body that needs
  // to be chained, along with equivalence classes (resources in one class will
  // use the same chain).
  ResourceToOpsMapTy chain_resource_to_ops_map;
  llvm::EquivalenceClasses<ResourceId> resource_equivalence_classes;
  CollectChainResources(while_body, chain_resource_to_ops_map,
                        resource_equivalence_classes, side_effect_analysis);

  // Check for the presence of unknown side-effecting ops within the while loop
  // body. These ops act as barriers, and the optimization would not yield much
  // inter-iteration parallelism for this while loop body, so return with a
  // warning.
  if (chain_resource_to_ops_map.count(kUnknownResourceId) > 0) {
    std::set<std::string> blocking_ops;
    for (Operation* op : chain_resource_to_ops_map[kUnknownResourceId]) {
      std::string op_name = op->getName().getStringRef().str();
      if (blocking_ops.insert(op_name).second) {
        LOG(INFO) << "[`tf-executor-convert-control-to-data-outputs` disabled] "
                     "Op type '"
                  << op_name
                  << "' has unknown side effects and blocks inter iteration "
                     "parallelism for the while loop. Consider modeling side "
                     "effects of this op.";
      }
    }
    return;
  }

  // First, remove all control outputs of the while loop body.
  bool changed = RemoveAllControlOutputs(while_body);

  // If there was no control output to be removed, return early.
  if (!changed) return;

  int num_chains = resource_equivalence_classes.getNumClasses();
  RankedTensorType chaining_data_type =
      RankedTensorType::get({}, OpBuilder(while_body).getI32Type());
  // Create the new while body.
  int num_old_outputs = while_body.getNumResults();
  AppendFunctionArguments(while_body, num_chains, chaining_data_type);
  AppendFunctionResults(while_body, num_chains, chaining_data_type);

  // Insert identity ops with control dependencies.
  ChainResourceOps(while_body, chain_resource_to_ops_map,
                   resource_equivalence_classes, num_old_outputs);
  // Modify all the while ops referencing the body function and the
  // corresponding while condition functions. Note that each while condition
  // needs to be modified only once.
  OperationSetTy visited;
  for (TF::WhileOp while_op : while_callers) {
    // If the while callers are modified as part of the optimization, then the
    // side effect analyses of their parent functions are invalidated and need
    // to be recomputed.
    recompute_analysis_for_funcs.insert(
        while_op->getParentOfType<func::FuncOp>());
    func::FuncOp while_cond = while_op.cond_function();
    // Rewrite the while op with extra chaining arguments and results.
    while_op = RewriteWhileOp(while_op, num_chains, chaining_data_type);
    bool first_visit = visited.insert(while_cond).second;
    if (!first_visit) continue;
    // Modify the while condition function with extra chaining arguments.
    AppendFunctionArguments(while_cond, num_chains, chaining_data_type);
  }
}

void ConvertControlToDataOutputsPass::runOnOperation() {
  ModuleOp module = getOperation();
  // This pass assumes that all functions are suitable for export i.e., each
  // function has a single tf_executor.graph op and all islands wrap the
  // internal op perfectly. Verify that in the beginning once.
  if (failed(tensorflow::VerifyExportSuitable(module))) {
    signalPassFailure();
    return;
  }
  TF::SideEffectAnalysis side_effect_analysis(module);

  SymbolTableCollection table;
  SymbolUserMap symbol_map(table, module);
  llvm::SmallDenseMap<func::FuncOp, SmallVector<TF::WhileOp>>
      while_body_func_to_while_ops;

  // Get all the while body functions and the corresponding while ops first
  // because the symbol user map is invalidated once we start deleting while
  // ops.
  for (auto func : module.getOps<func::FuncOp>()) {
    if (func.isExternal()) continue;
    SmallVector<TF::WhileOp> while_callers = GetWhileCallers(func, symbol_map);
    if (while_callers.empty()) continue;
    while_body_func_to_while_ops[func] = while_callers;
  }
  // Keep track of functions whose side effect analysis is invalidated because
  // of modifications to that function.
  OperationSetTy recompute_analysis_for_funcs;

  for (auto& entry : while_body_func_to_while_ops) {
    func::FuncOp while_body = entry.getFirst();
    SmallVector<TF::WhileOp>& while_callers = entry.getSecond();
    if (recompute_analysis_for_funcs.contains(while_body)) {
      // TODO(b/202540801): Recomputing side effect analysis for the entire
      // module is wasteful. It would be better to just recompute analysis for
      // specific functions but the current side effect analysis interface
      // does not allow that.
      side_effect_analysis = TF::SideEffectAnalysis(module);
    }
    ConvertControlToDataOutputs(
        while_body, while_callers, recompute_analysis_for_funcs,
        side_effect_analysis.GetAnalysisForFunc(while_body));
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorConvertControlToDataOutputsPass() {
  return std::make_unique<ConvertControlToDataOutputsPass>();
}

}  // namespace tf_executor
}  // namespace mlir