1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h"
17 
18 #include "llvm/ADT/BitVector.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/IR/Value.h"  // from @llvm-project
23 #include "mlir/IR/Visitors.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
27 
28 namespace mlir {
29 namespace {
30 
IsResource(Value value)31 bool IsResource(Value value) {
32   return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
33 }
34 
35 // Checks if a cast op is casting a resource -> resource.
IsCastOfResource(Operation & op)36 bool IsCastOfResource(Operation &op) {
37   auto cast = dyn_cast<TF::CastOp>(op);
38   if (!cast) return false;
39   return IsResource(cast.x());
40 }
41 
42 // Removes passthrough ops in the block. The device computation does not need
43 // such nodes to carry information.
RemovePassthroughOp(Block & block)44 void RemovePassthroughOp(Block &block) {
45   for (auto &op : llvm::make_early_inc_range(block)) {
46     if (isa<TF::IdentityOp, TF::IdentityNOp>(op) || IsCastOfResource(op)) {
47       op.replaceAllUsesWith(op.getOperands());
48       op.erase();
49     }
50   }
51 }
52 
53 // Eliminate local variables that are only assigned to but never read, and thus
54 // are dead.
RemoveDeadLocalVariables(Block & block)55 void RemoveDeadLocalVariables(Block &block) {
56   llvm::SmallVector<TF::MlirLocalVarOp, 8> local_vars;
57   for (Operation &op : block) {
58     if (auto local_var = llvm::dyn_cast<TF::MlirLocalVarOp>(&op)) {
59       local_vars.push_back(local_var);
60     }
61   }
62   for (auto local_var : local_vars) {
63     auto users = local_var.resource().getUsers();
64     if (llvm::all_of(users, [](const Operation *user) {
65           return isa<TF::AssignVariableOp>(user);
66         })) {
67       for (auto user : llvm::make_early_inc_range(users)) user->erase();
68       local_var.erase();
69     }
70   }
71 }
72 
73 LogicalResult CleanupAndCanonicalize(Operation *parent_op);
74 
75 // Eliminates unusued results from an operation `op` by cloning it with reduced
76 // result types and doing appropriate use replacements. `results_to_eliminate`
77 // is a bitvector of result positions to eliminate. If its null, then all unused
78 // results of the operation will be eliminated.
EliminateUnusedResults(Operation * op,const llvm::BitVector * results_to_eliminate=nullptr)79 void EliminateUnusedResults(
80     Operation *op, const llvm::BitVector *results_to_eliminate = nullptr) {
81   auto can_eliminate = [&](OpResult &result) -> bool {
82     if (!result.use_empty()) return false;
83     if (results_to_eliminate)
84       return results_to_eliminate->test(result.getResultNumber());
85     else
86       return true;
87   };
88   SmallVector<Type, 4> new_result_types;
89   for (OpResult result : op->getResults()) {
90     if (can_eliminate(result)) continue;
91     new_result_types.push_back(result.getType());
92   }
93 
94   // Rebuild the new operation with lesser number of results.
95   OpBuilder builder(op);
96   Operation *new_op = Operation::create(
97       op->getLoc(), op->getName(), new_result_types, op->getOperands(),
98       op->getAttrs(), op->getSuccessors(), op->getNumRegions());
99   builder.insert(new_op);
100 
101   // Move region bodies to the new operation.
102   for (auto it : llvm::zip(op->getRegions(), new_op->getRegions())) {
103     Region &old_region = std::get<0>(it);
104     Region &new_region = std::get<1>(it);
105     new_region.takeBody(old_region);
106   }
107 
108   // Replace used results and erase the old op.
109   int next_result_idx = 0;
110   for (OpResult result : op->getResults()) {
111     if (can_eliminate(result)) continue;
112     result.replaceAllUsesWith(new_op->getResult(next_result_idx++));
113   }
114   op->erase();
115 }
116 
117 // Clones a function if it cannot be patched in place. Clone if there are
118 // multiple uses or unknown uses (for external functions). The cloned function
119 // will be marked as private.
CloneFunctionIfNeeded(func::FuncOp func)120 func::FuncOp CloneFunctionIfNeeded(func::FuncOp func) {
121   ModuleOp module = func->getParentOfType<ModuleOp>();
122   auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
123   if (func_uses.has_value() && llvm::hasSingleElement(func_uses.getValue()))
124     return func;
125   func::FuncOp cloned = func.clone();
126   cloned.setPrivate();
127   cloned.setName(
128       StringAttr::get(func.getContext(), func.getName().str() + "_lifted"));
129   SymbolTable(module).insert(cloned);
130   return cloned;
131 }
132 
133 // Eliminates unused results for If/Case operations. Also patches up the
134 // branch functions to (a) drop the ununsed return values, and (b) as a result
135 // if some argument becomes unused in all branches, drop that argument and the
136 // corresponding if/case input operand.
EliminateUnusedResultsForIfCase(Operation * op,ArrayRef<func::FuncOp> branches)137 void EliminateUnusedResultsForIfCase(Operation *op,
138                                      ArrayRef<func::FuncOp> branches) {
139   // Clone branch functions if needed since we will be mutating them.
140   SmallVector<func::FuncOp, 2> cloned_branches;
141   cloned_branches.reserve(branches.size());
142   for (func::FuncOp func : branches) {
143     func::FuncOp cloned = CloneFunctionIfNeeded(func);
144     cloned_branches.push_back(cloned);
145     if (cloned == func) continue;
146     // Patch up the op attribute to point to the new function.
147     for (NamedAttribute attr : op->getAttrs()) {
148       auto symref = attr.getValue().dyn_cast<FlatSymbolRefAttr>();
149       if (!symref) continue;
150       if (symref.getValue() != func.getName()) continue;
151       op->setAttr(attr.getName(),
152                   FlatSymbolRefAttr::get(op->getContext(), cloned.getName()));
153       break;
154     }
155   }
156 
157   // Traverse results backward so that indices to be deleted stay unchanged.
158   for (OpResult result : llvm::reverse(op->getResults())) {
159     if (!result.use_empty()) continue;
160     int result_idx = result.getResultNumber();
161     for (func::FuncOp func : cloned_branches)
162       func.front().getTerminator()->eraseOperand(result_idx);
163   }
164 
165   // Check which function arguments are unused in all branches. We can drop
166   // those as well.
167   int num_args = cloned_branches[0].getNumArguments();
168   llvm::BitVector used_args(num_args);
169   for (func::FuncOp func : branches) {
170     for (BlockArgument arg : func.getArguments()) {
171       if (!arg.use_empty()) used_args.set(arg.getArgNumber());
172     }
173   }
174 
175   // There are some unused args that we can drop. Also drop the corresponding
176   // input operand.
177   if (used_args.count() != num_args) {
178     // Traverse arguments backward so that indices to be deleted stay unchanged.
179     for (int idx = num_args - 1; idx >= 0; --idx) {
180       if (used_args.test(idx)) continue;
181       for (func::FuncOp func : cloned_branches) func.eraseArgument(idx);
182       // For if/case, arg #i of attached function corresponds to operand #i+1
183       op->eraseOperand(idx + 1);
184     }
185   }
186 
187   // Patch up function types (with less number of return values and potentially
188   // less number of arguments)
189   for (func::FuncOp func : cloned_branches) {
190     func.setType(
191         FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
192                           func.front().getTerminator()->getOperandTypes()));
193   }
194 
195   EliminateUnusedResults(op);
196 }
197 
198 // Eliminated unused results from a functional while.
EliminateUnusedResultsForWhile(TF::WhileOp op)199 void EliminateUnusedResultsForWhile(TF::WhileOp op) {
200   func::FuncOp cond = op.cond_function();
201   func::FuncOp body = op.body_function();
202 
203   llvm::BitVector can_eliminate(op.getNumResults());
204   for (OpResult result : llvm::reverse(op.getResults())) {
205     if (!result.use_empty()) continue;
206     int result_idx = result.getResultNumber();
207     BlockArgument cond_arg = cond.getArgument(result_idx);
208     BlockArgument body_arg = cond.getArgument(result_idx);
209     Operation *body_ret = body.front().getTerminator();
210     // We can eliminate a result if its unused and the corresponding argument
211     // is unused in cond and the only use in body is use it as a return value.
212     if (cond_arg.use_empty() && body_arg.hasOneUse() &&
213         body_arg.use_begin()->getOperandNumber() == result_idx &&
214         body_arg.use_begin()->getOwner() == body_ret) {
215       can_eliminate.set(result_idx);
216     }
217   }
218 
219   if (can_eliminate.empty()) return;
220 
221   func::FuncOp cloned_cond = CloneFunctionIfNeeded(cond);
222   func::FuncOp cloned_body = CloneFunctionIfNeeded(body);
223   op.condAttr(FlatSymbolRefAttr::get(op.getContext(), cloned_cond.getName()));
224   op.bodyAttr(FlatSymbolRefAttr::get(op.getContext(), cloned_body.getName()));
225 
226   // Drop cond/body args and return value. WhileOp result will be dropped later
227   // in EliminateUnusedResults. Traverse in reverse order so that indices to be
228   // deleted stay unchanged.
229   for (int idx = op.getNumResults() - 1; idx >= 0; --idx) {
230     if (!can_eliminate.test(idx)) continue;
231     cloned_cond.eraseArgument(idx);
232     cloned_body.front().getTerminator()->eraseOperand(idx);
233     cloned_body.eraseArgument(idx);
234   }
235 
236   // Patch up branch function types.
237   for (func::FuncOp func : {cloned_cond, cloned_body}) {
238     func.setType(
239         FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
240                           func.front().getTerminator()->getOperandTypes()));
241   }
242   EliminateUnusedResults(op, &can_eliminate);
243 }
244 
245 // For resource results, replace all uses with the resource input to which the
246 // result is tied to. After this, resource outputs of this op are expected to be
247 // unused.
ForwardCommonArgToOutput(Operation * op,ArrayRef<func::FuncOp> branches,ValueRange branch_args,bool & has_resource_result)248 LogicalResult ForwardCommonArgToOutput(Operation *op,
249                                        ArrayRef<func::FuncOp> branches,
250                                        ValueRange branch_args,
251                                        bool &has_resource_result) {
252   // For while, the branch inputs and outputs need to match.
253   bool io_match = isa<TF::WhileOp>(op);
254 
255   has_resource_result = false;
256   // Check if the same input argument number is passed through all functions.
257   for (OpResult result : op->getResults()) {
258     if (!IsResource(result)) continue;
259 
260     has_resource_result = true;
261     int result_idx = result.getResultNumber();
262     Optional<int> common_arg_index;
263     for (func::FuncOp func : branches) {
264       auto ret = func.front().getTerminator();
265       auto block_arg = ret->getOperand(result_idx).dyn_cast<BlockArgument>();
266       if (!block_arg) {
267         return op->emitOpError("result #")
268                << result_idx << " not tied to function argument for branch @"
269                << func.getName();
270       }
271       if (!common_arg_index.has_value()) {
272         common_arg_index = block_arg.getArgNumber();
273       } else if (common_arg_index.getValue() != block_arg.getArgNumber()) {
274         return op->emitError("result #")
275                << result_idx
276                << " is not tied to the same argument across all branches";
277       }
278     }
279 
280     if (io_match && result_idx != common_arg_index.getValue()) {
281       return op->emitOpError("Result #")
282              << result_idx << " is tied to argument #"
283              << common_arg_index.getValue();
284     }
285 
286     // Forward the corresponding input to the output
287     result.replaceAllUsesWith(branch_args[common_arg_index.getValue()]);
288   }
289   return success();
290 }
291 
292 // Canonicalizes a function if. Forwards input argument to resource results and
293 // then deletes the resource results.
CanonicalizeFunctionalIfCase(Operation * op,ArrayRef<func::FuncOp> branches,ValueRange branch_args)294 LogicalResult CanonicalizeFunctionalIfCase(Operation *op,
295                                            ArrayRef<func::FuncOp> branches,
296                                            ValueRange branch_args) {
297   for (func::FuncOp func : branches) {
298     if (failed(CleanupAndCanonicalize(func))) return failure();
299   }
300 
301   bool has_resource_result = false;
302   if (failed(ForwardCommonArgToOutput(op, branches, branch_args,
303                                       has_resource_result)))
304     return failure();
305 
306   // If no resource type results were found, no further cleanup needed.
307   if (!has_resource_result) return success();
308 
309   // Drop unused results.
310   EliminateUnusedResultsForIfCase(op, branches);
311   return success();
312 }
313 
314 // Canonicalizes a functional while. Forwards common argument to results and
315 // drop resource results if posible.
CanonicalizeFunctionalWhile(TF::WhileOp op)316 LogicalResult CanonicalizeFunctionalWhile(TF::WhileOp op) {
317   for (func::FuncOp func : {op.cond_function(), op.body_function()}) {
318     if (failed(CleanupAndCanonicalize(func))) return failure();
319   }
320 
321   // For while, just use the body function to forward operand to result.
322   bool has_resource_result = false;
323   if (failed(ForwardCommonArgToOutput(op, {op.body_function()},
324                                       op.getOperands(), has_resource_result)))
325     return failure();
326   // If no resource type results were found, no further cleanup needed.
327   if (!has_resource_result) return success();
328 
329   // Drop unused results.
330   EliminateUnusedResultsForWhile(op);
331   return success();
332 }
333 
334 // Canonicalizes region based if/case and cluster operations. If the same
335 // captured resource typed value is used for all region results, then that value
336 // is forwared to the result and the result is dropped.
CanonicalizeRegionIfCaseCluster(Operation * op)337 LogicalResult CanonicalizeRegionIfCaseCluster(Operation *op) {
338   // Check if the same value is used for all region results for this output.
339   bool has_resource_result = false;
340   for (OpResult result : op->getResults()) {
341     if (!IsResource(result)) continue;
342     has_resource_result = true;
343     int result_idx = result.getResultNumber();
344 
345     Value ret0 =
346         op->getRegion(0).front().getTerminator()->getOperand(result_idx);
347     for (Region &region : op->getRegions().drop_front()) {
348       Value ret = region.front().getTerminator()->getOperand(result_idx);
349       if (ret != ret0) {
350         return op->emitError("Result #")
351                << result_idx
352                << " not tied to the same capture across all regions";
353       }
354     }
355     result.replaceAllUsesWith(ret0);
356   }
357 
358   if (!has_resource_result) return success();
359 
360   // Eliminate unused region results. Traverse in reverse order so that
361   // indices to be deleted stay unchanged.
362   for (OpResult result : llvm::reverse(op->getResults())) {
363     if (!result.use_empty()) continue;
364     int result_idx = result.getResultNumber();
365     for (Region &region : op->getRegions())
366       region.front().getTerminator()->eraseOperand(result_idx);
367   }
368   EliminateUnusedResults(op);
369   return success();
370 }
371 
372 // Canonicalizes a region based while. If the same value is passed through
373 // the body, the result is replaced with the operand and all argument/results
374 // and retuns values corresponding to that result are dropped.
CanonicalizeWhileRegion(TF::WhileRegionOp op)375 LogicalResult CanonicalizeWhileRegion(TF::WhileRegionOp op) {
376   Region &body = op.body();
377   Region &cond = op.cond();
378   llvm::BitVector can_eliminate(op.getNumResults());
379 
380   // Traverse in reverse order so that indices to be deleted stay unchanged.
381   for (OpResult result : llvm::reverse(op.getResults())) {
382     if (!IsResource(result)) continue;
383     int result_idx = result.getResultNumber();
384     Operation *yield_op = body.front().getTerminator();
385     Value yield_operand = yield_op->getOperand(result_idx);
386     Value while_operand = op.getOperand(result_idx);
387     Value body_arg = body.getArgument(result_idx);
388     Value cond_arg = cond.getArgument(result_idx);
389     if (yield_operand != body_arg && yield_operand != while_operand) {
390       return op.emitOpError("Result #") << result_idx << " is not tied to arg #"
391                                         << result_idx << " of the body";
392     }
393     body_arg.replaceAllUsesWith(while_operand);
394     cond_arg.replaceAllUsesWith(while_operand);
395     result.replaceAllUsesWith(while_operand);
396     body.front().getTerminator()->eraseOperand(result_idx);
397     body.eraseArgument(result_idx);
398     cond.eraseArgument(result_idx);
399     op.getOperation()->eraseOperand(result_idx);
400     can_eliminate.set(result_idx);
401   }
402   EliminateUnusedResults(op, &can_eliminate);
403   return success();
404 }
405 
406 // Removes identities and canonicalizes all operations within `parent_op`.
CleanupAndCanonicalize(Operation * parent_op)407 LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
408   auto walk_result = parent_op->walk([](Operation *op) {
409     // Cleanup code in attached regions.
410     for (Region &region : op->getRegions()) {
411       if (!llvm::hasSingleElement(region)) return WalkResult::interrupt();
412       RemovePassthroughOp(region.front());
413       RemoveDeadLocalVariables(region.front());
414     }
415 
416     LogicalResult result = success();
417 
418     // While condition cannot write to resource variables.
419     auto check_while_cond = [&](TF::AssignVariableOp assign) {
420       op->emitOpError("found resource write in loop condition.");
421       return WalkResult::interrupt();
422     };
423 
424     if (auto if_op = dyn_cast<TF::IfOp>(op)) {
425       result = CanonicalizeFunctionalIfCase(
426           op, {if_op.then_function(), if_op.else_function()}, if_op.input());
427     } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
428       SmallVector<func::FuncOp, 4> branches;
429       case_op.get_branch_functions(branches);
430       result = CanonicalizeFunctionalIfCase(case_op, branches, case_op.input());
431     } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
432       if (while_op.cond_function().walk(check_while_cond).wasInterrupted())
433         return WalkResult::interrupt();
434       result = CanonicalizeFunctionalWhile(while_op);
435     } else if (isa<TF::IfRegionOp, TF::CaseRegionOp, tf_device::ClusterOp>(
436                    op)) {
437       result = CanonicalizeRegionIfCaseCluster(op);
438     } else if (auto while_region = dyn_cast<TF::WhileRegionOp>(op)) {
439       if (while_region.cond().walk(check_while_cond).wasInterrupted())
440         return WalkResult::interrupt();
441       // For while region, the body input and output arg should match.
442       result = CanonicalizeWhileRegion(while_region);
443     } else if (auto call = dyn_cast<CallOpInterface>(op)) {
444       func::FuncOp func = dyn_cast<func::FuncOp>(call.resolveCallable());
445       if (!func) return WalkResult::interrupt();
446       result = CleanupAndCanonicalize(func);
447     }
448     return failed(result) ? WalkResult::interrupt() : WalkResult::advance();
449   });
450 
451   return failure(walk_result.wasInterrupted());
452 }
453 
454 }  // anonymous namespace
455 
456 namespace TF {
457 
CleanupAndCanonicalizeForResourceOpLifting(func::FuncOp func)458 LogicalResult CleanupAndCanonicalizeForResourceOpLifting(func::FuncOp func) {
459   return CleanupAndCanonicalize(func);
460 }
461 
CleanupAndCanonicalizeForResourceOpLifting(ModuleOp module)462 LogicalResult CleanupAndCanonicalizeForResourceOpLifting(ModuleOp module) {
463   auto walk_result = module.walk([](tf_device::ClusterOp cluster) {
464     if (failed(CleanupAndCanonicalize(cluster))) return WalkResult::interrupt();
465     return WalkResult::advance();
466   });
467   return failure(walk_result.wasInterrupted());
468 }
469 
470 }  // namespace TF
471 }  // namespace mlir
472