1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h"
17
18 #include "llvm/ADT/BitVector.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
20 #include "mlir/IR/Attributes.h" // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
22 #include "mlir/IR/Value.h" // from @llvm-project
23 #include "mlir/IR/Visitors.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
27
28 namespace mlir {
29 namespace {
30
IsResource(Value value)31 bool IsResource(Value value) {
32 return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
33 }
34
35 // Checks if a cast op is casting a resource -> resource.
IsCastOfResource(Operation & op)36 bool IsCastOfResource(Operation &op) {
37 auto cast = dyn_cast<TF::CastOp>(op);
38 if (!cast) return false;
39 return IsResource(cast.x());
40 }
41
42 // Removes passthrough ops in the block. The device computation does not need
43 // such nodes to carry information.
RemovePassthroughOp(Block & block)44 void RemovePassthroughOp(Block &block) {
45 for (auto &op : llvm::make_early_inc_range(block)) {
46 if (isa<TF::IdentityOp, TF::IdentityNOp>(op) || IsCastOfResource(op)) {
47 op.replaceAllUsesWith(op.getOperands());
48 op.erase();
49 }
50 }
51 }
52
53 // Eliminate local variables that are only assigned to but never read, and thus
54 // are dead.
RemoveDeadLocalVariables(Block & block)55 void RemoveDeadLocalVariables(Block &block) {
56 llvm::SmallVector<TF::MlirLocalVarOp, 8> local_vars;
57 for (Operation &op : block) {
58 if (auto local_var = llvm::dyn_cast<TF::MlirLocalVarOp>(&op)) {
59 local_vars.push_back(local_var);
60 }
61 }
62 for (auto local_var : local_vars) {
63 auto users = local_var.resource().getUsers();
64 if (llvm::all_of(users, [](const Operation *user) {
65 return isa<TF::AssignVariableOp>(user);
66 })) {
67 for (auto user : llvm::make_early_inc_range(users)) user->erase();
68 local_var.erase();
69 }
70 }
71 }
72
73 LogicalResult CleanupAndCanonicalize(Operation *parent_op);
74
75 // Eliminates unusued results from an operation `op` by cloning it with reduced
76 // result types and doing appropriate use replacements. `results_to_eliminate`
77 // is a bitvector of result positions to eliminate. If its null, then all unused
78 // results of the operation will be eliminated.
EliminateUnusedResults(Operation * op,const llvm::BitVector * results_to_eliminate=nullptr)79 void EliminateUnusedResults(
80 Operation *op, const llvm::BitVector *results_to_eliminate = nullptr) {
81 auto can_eliminate = [&](OpResult &result) -> bool {
82 if (!result.use_empty()) return false;
83 if (results_to_eliminate)
84 return results_to_eliminate->test(result.getResultNumber());
85 else
86 return true;
87 };
88 SmallVector<Type, 4> new_result_types;
89 for (OpResult result : op->getResults()) {
90 if (can_eliminate(result)) continue;
91 new_result_types.push_back(result.getType());
92 }
93
94 // Rebuild the new operation with lesser number of results.
95 OpBuilder builder(op);
96 Operation *new_op = Operation::create(
97 op->getLoc(), op->getName(), new_result_types, op->getOperands(),
98 op->getAttrs(), op->getSuccessors(), op->getNumRegions());
99 builder.insert(new_op);
100
101 // Move region bodies to the new operation.
102 for (auto it : llvm::zip(op->getRegions(), new_op->getRegions())) {
103 Region &old_region = std::get<0>(it);
104 Region &new_region = std::get<1>(it);
105 new_region.takeBody(old_region);
106 }
107
108 // Replace used results and erase the old op.
109 int next_result_idx = 0;
110 for (OpResult result : op->getResults()) {
111 if (can_eliminate(result)) continue;
112 result.replaceAllUsesWith(new_op->getResult(next_result_idx++));
113 }
114 op->erase();
115 }
116
117 // Clones a function if it cannot be patched in place. Clone if there are
118 // multiple uses or unknown uses (for external functions). The cloned function
119 // will be marked as private.
CloneFunctionIfNeeded(func::FuncOp func)120 func::FuncOp CloneFunctionIfNeeded(func::FuncOp func) {
121 ModuleOp module = func->getParentOfType<ModuleOp>();
122 auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
123 if (func_uses.has_value() && llvm::hasSingleElement(func_uses.getValue()))
124 return func;
125 func::FuncOp cloned = func.clone();
126 cloned.setPrivate();
127 cloned.setName(
128 StringAttr::get(func.getContext(), func.getName().str() + "_lifted"));
129 SymbolTable(module).insert(cloned);
130 return cloned;
131 }
132
133 // Eliminates unused results for If/Case operations. Also patches up the
134 // branch functions to (a) drop the ununsed return values, and (b) as a result
135 // if some argument becomes unused in all branches, drop that argument and the
136 // corresponding if/case input operand.
EliminateUnusedResultsForIfCase(Operation * op,ArrayRef<func::FuncOp> branches)137 void EliminateUnusedResultsForIfCase(Operation *op,
138 ArrayRef<func::FuncOp> branches) {
139 // Clone branch functions if needed since we will be mutating them.
140 SmallVector<func::FuncOp, 2> cloned_branches;
141 cloned_branches.reserve(branches.size());
142 for (func::FuncOp func : branches) {
143 func::FuncOp cloned = CloneFunctionIfNeeded(func);
144 cloned_branches.push_back(cloned);
145 if (cloned == func) continue;
146 // Patch up the op attribute to point to the new function.
147 for (NamedAttribute attr : op->getAttrs()) {
148 auto symref = attr.getValue().dyn_cast<FlatSymbolRefAttr>();
149 if (!symref) continue;
150 if (symref.getValue() != func.getName()) continue;
151 op->setAttr(attr.getName(),
152 FlatSymbolRefAttr::get(op->getContext(), cloned.getName()));
153 break;
154 }
155 }
156
157 // Traverse results backward so that indices to be deleted stay unchanged.
158 for (OpResult result : llvm::reverse(op->getResults())) {
159 if (!result.use_empty()) continue;
160 int result_idx = result.getResultNumber();
161 for (func::FuncOp func : cloned_branches)
162 func.front().getTerminator()->eraseOperand(result_idx);
163 }
164
165 // Check which function arguments are unused in all branches. We can drop
166 // those as well.
167 int num_args = cloned_branches[0].getNumArguments();
168 llvm::BitVector used_args(num_args);
169 for (func::FuncOp func : branches) {
170 for (BlockArgument arg : func.getArguments()) {
171 if (!arg.use_empty()) used_args.set(arg.getArgNumber());
172 }
173 }
174
175 // There are some unused args that we can drop. Also drop the corresponding
176 // input operand.
177 if (used_args.count() != num_args) {
178 // Traverse arguments backward so that indices to be deleted stay unchanged.
179 for (int idx = num_args - 1; idx >= 0; --idx) {
180 if (used_args.test(idx)) continue;
181 for (func::FuncOp func : cloned_branches) func.eraseArgument(idx);
182 // For if/case, arg #i of attached function corresponds to operand #i+1
183 op->eraseOperand(idx + 1);
184 }
185 }
186
187 // Patch up function types (with less number of return values and potentially
188 // less number of arguments)
189 for (func::FuncOp func : cloned_branches) {
190 func.setType(
191 FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
192 func.front().getTerminator()->getOperandTypes()));
193 }
194
195 EliminateUnusedResults(op);
196 }
197
198 // Eliminated unused results from a functional while.
EliminateUnusedResultsForWhile(TF::WhileOp op)199 void EliminateUnusedResultsForWhile(TF::WhileOp op) {
200 func::FuncOp cond = op.cond_function();
201 func::FuncOp body = op.body_function();
202
203 llvm::BitVector can_eliminate(op.getNumResults());
204 for (OpResult result : llvm::reverse(op.getResults())) {
205 if (!result.use_empty()) continue;
206 int result_idx = result.getResultNumber();
207 BlockArgument cond_arg = cond.getArgument(result_idx);
208 BlockArgument body_arg = cond.getArgument(result_idx);
209 Operation *body_ret = body.front().getTerminator();
210 // We can eliminate a result if its unused and the corresponding argument
211 // is unused in cond and the only use in body is use it as a return value.
212 if (cond_arg.use_empty() && body_arg.hasOneUse() &&
213 body_arg.use_begin()->getOperandNumber() == result_idx &&
214 body_arg.use_begin()->getOwner() == body_ret) {
215 can_eliminate.set(result_idx);
216 }
217 }
218
219 if (can_eliminate.empty()) return;
220
221 func::FuncOp cloned_cond = CloneFunctionIfNeeded(cond);
222 func::FuncOp cloned_body = CloneFunctionIfNeeded(body);
223 op.condAttr(FlatSymbolRefAttr::get(op.getContext(), cloned_cond.getName()));
224 op.bodyAttr(FlatSymbolRefAttr::get(op.getContext(), cloned_body.getName()));
225
226 // Drop cond/body args and return value. WhileOp result will be dropped later
227 // in EliminateUnusedResults. Traverse in reverse order so that indices to be
228 // deleted stay unchanged.
229 for (int idx = op.getNumResults() - 1; idx >= 0; --idx) {
230 if (!can_eliminate.test(idx)) continue;
231 cloned_cond.eraseArgument(idx);
232 cloned_body.front().getTerminator()->eraseOperand(idx);
233 cloned_body.eraseArgument(idx);
234 }
235
236 // Patch up branch function types.
237 for (func::FuncOp func : {cloned_cond, cloned_body}) {
238 func.setType(
239 FunctionType::get(func.getContext(), func.front().getArgumentTypes(),
240 func.front().getTerminator()->getOperandTypes()));
241 }
242 EliminateUnusedResults(op, &can_eliminate);
243 }
244
245 // For resource results, replace all uses with the resource input to which the
246 // result is tied to. After this, resource outputs of this op are expected to be
247 // unused.
ForwardCommonArgToOutput(Operation * op,ArrayRef<func::FuncOp> branches,ValueRange branch_args,bool & has_resource_result)248 LogicalResult ForwardCommonArgToOutput(Operation *op,
249 ArrayRef<func::FuncOp> branches,
250 ValueRange branch_args,
251 bool &has_resource_result) {
252 // For while, the branch inputs and outputs need to match.
253 bool io_match = isa<TF::WhileOp>(op);
254
255 has_resource_result = false;
256 // Check if the same input argument number is passed through all functions.
257 for (OpResult result : op->getResults()) {
258 if (!IsResource(result)) continue;
259
260 has_resource_result = true;
261 int result_idx = result.getResultNumber();
262 Optional<int> common_arg_index;
263 for (func::FuncOp func : branches) {
264 auto ret = func.front().getTerminator();
265 auto block_arg = ret->getOperand(result_idx).dyn_cast<BlockArgument>();
266 if (!block_arg) {
267 return op->emitOpError("result #")
268 << result_idx << " not tied to function argument for branch @"
269 << func.getName();
270 }
271 if (!common_arg_index.has_value()) {
272 common_arg_index = block_arg.getArgNumber();
273 } else if (common_arg_index.getValue() != block_arg.getArgNumber()) {
274 return op->emitError("result #")
275 << result_idx
276 << " is not tied to the same argument across all branches";
277 }
278 }
279
280 if (io_match && result_idx != common_arg_index.getValue()) {
281 return op->emitOpError("Result #")
282 << result_idx << " is tied to argument #"
283 << common_arg_index.getValue();
284 }
285
286 // Forward the corresponding input to the output
287 result.replaceAllUsesWith(branch_args[common_arg_index.getValue()]);
288 }
289 return success();
290 }
291
292 // Canonicalizes a function if. Forwards input argument to resource results and
293 // then deletes the resource results.
CanonicalizeFunctionalIfCase(Operation * op,ArrayRef<func::FuncOp> branches,ValueRange branch_args)294 LogicalResult CanonicalizeFunctionalIfCase(Operation *op,
295 ArrayRef<func::FuncOp> branches,
296 ValueRange branch_args) {
297 for (func::FuncOp func : branches) {
298 if (failed(CleanupAndCanonicalize(func))) return failure();
299 }
300
301 bool has_resource_result = false;
302 if (failed(ForwardCommonArgToOutput(op, branches, branch_args,
303 has_resource_result)))
304 return failure();
305
306 // If no resource type results were found, no further cleanup needed.
307 if (!has_resource_result) return success();
308
309 // Drop unused results.
310 EliminateUnusedResultsForIfCase(op, branches);
311 return success();
312 }
313
314 // Canonicalizes a functional while. Forwards common argument to results and
315 // drop resource results if posible.
CanonicalizeFunctionalWhile(TF::WhileOp op)316 LogicalResult CanonicalizeFunctionalWhile(TF::WhileOp op) {
317 for (func::FuncOp func : {op.cond_function(), op.body_function()}) {
318 if (failed(CleanupAndCanonicalize(func))) return failure();
319 }
320
321 // For while, just use the body function to forward operand to result.
322 bool has_resource_result = false;
323 if (failed(ForwardCommonArgToOutput(op, {op.body_function()},
324 op.getOperands(), has_resource_result)))
325 return failure();
326 // If no resource type results were found, no further cleanup needed.
327 if (!has_resource_result) return success();
328
329 // Drop unused results.
330 EliminateUnusedResultsForWhile(op);
331 return success();
332 }
333
334 // Canonicalizes region based if/case and cluster operations. If the same
335 // captured resource typed value is used for all region results, then that value
336 // is forwared to the result and the result is dropped.
CanonicalizeRegionIfCaseCluster(Operation * op)337 LogicalResult CanonicalizeRegionIfCaseCluster(Operation *op) {
338 // Check if the same value is used for all region results for this output.
339 bool has_resource_result = false;
340 for (OpResult result : op->getResults()) {
341 if (!IsResource(result)) continue;
342 has_resource_result = true;
343 int result_idx = result.getResultNumber();
344
345 Value ret0 =
346 op->getRegion(0).front().getTerminator()->getOperand(result_idx);
347 for (Region ®ion : op->getRegions().drop_front()) {
348 Value ret = region.front().getTerminator()->getOperand(result_idx);
349 if (ret != ret0) {
350 return op->emitError("Result #")
351 << result_idx
352 << " not tied to the same capture across all regions";
353 }
354 }
355 result.replaceAllUsesWith(ret0);
356 }
357
358 if (!has_resource_result) return success();
359
360 // Eliminate unused region results. Traverse in reverse order so that
361 // indices to be deleted stay unchanged.
362 for (OpResult result : llvm::reverse(op->getResults())) {
363 if (!result.use_empty()) continue;
364 int result_idx = result.getResultNumber();
365 for (Region ®ion : op->getRegions())
366 region.front().getTerminator()->eraseOperand(result_idx);
367 }
368 EliminateUnusedResults(op);
369 return success();
370 }
371
372 // Canonicalizes a region based while. If the same value is passed through
373 // the body, the result is replaced with the operand and all argument/results
374 // and retuns values corresponding to that result are dropped.
CanonicalizeWhileRegion(TF::WhileRegionOp op)375 LogicalResult CanonicalizeWhileRegion(TF::WhileRegionOp op) {
376 Region &body = op.body();
377 Region &cond = op.cond();
378 llvm::BitVector can_eliminate(op.getNumResults());
379
380 // Traverse in reverse order so that indices to be deleted stay unchanged.
381 for (OpResult result : llvm::reverse(op.getResults())) {
382 if (!IsResource(result)) continue;
383 int result_idx = result.getResultNumber();
384 Operation *yield_op = body.front().getTerminator();
385 Value yield_operand = yield_op->getOperand(result_idx);
386 Value while_operand = op.getOperand(result_idx);
387 Value body_arg = body.getArgument(result_idx);
388 Value cond_arg = cond.getArgument(result_idx);
389 if (yield_operand != body_arg && yield_operand != while_operand) {
390 return op.emitOpError("Result #") << result_idx << " is not tied to arg #"
391 << result_idx << " of the body";
392 }
393 body_arg.replaceAllUsesWith(while_operand);
394 cond_arg.replaceAllUsesWith(while_operand);
395 result.replaceAllUsesWith(while_operand);
396 body.front().getTerminator()->eraseOperand(result_idx);
397 body.eraseArgument(result_idx);
398 cond.eraseArgument(result_idx);
399 op.getOperation()->eraseOperand(result_idx);
400 can_eliminate.set(result_idx);
401 }
402 EliminateUnusedResults(op, &can_eliminate);
403 return success();
404 }
405
406 // Removes identities and canonicalizes all operations within `parent_op`.
CleanupAndCanonicalize(Operation * parent_op)407 LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
408 auto walk_result = parent_op->walk([](Operation *op) {
409 // Cleanup code in attached regions.
410 for (Region ®ion : op->getRegions()) {
411 if (!llvm::hasSingleElement(region)) return WalkResult::interrupt();
412 RemovePassthroughOp(region.front());
413 RemoveDeadLocalVariables(region.front());
414 }
415
416 LogicalResult result = success();
417
418 // While condition cannot write to resource variables.
419 auto check_while_cond = [&](TF::AssignVariableOp assign) {
420 op->emitOpError("found resource write in loop condition.");
421 return WalkResult::interrupt();
422 };
423
424 if (auto if_op = dyn_cast<TF::IfOp>(op)) {
425 result = CanonicalizeFunctionalIfCase(
426 op, {if_op.then_function(), if_op.else_function()}, if_op.input());
427 } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
428 SmallVector<func::FuncOp, 4> branches;
429 case_op.get_branch_functions(branches);
430 result = CanonicalizeFunctionalIfCase(case_op, branches, case_op.input());
431 } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
432 if (while_op.cond_function().walk(check_while_cond).wasInterrupted())
433 return WalkResult::interrupt();
434 result = CanonicalizeFunctionalWhile(while_op);
435 } else if (isa<TF::IfRegionOp, TF::CaseRegionOp, tf_device::ClusterOp>(
436 op)) {
437 result = CanonicalizeRegionIfCaseCluster(op);
438 } else if (auto while_region = dyn_cast<TF::WhileRegionOp>(op)) {
439 if (while_region.cond().walk(check_while_cond).wasInterrupted())
440 return WalkResult::interrupt();
441 // For while region, the body input and output arg should match.
442 result = CanonicalizeWhileRegion(while_region);
443 } else if (auto call = dyn_cast<CallOpInterface>(op)) {
444 func::FuncOp func = dyn_cast<func::FuncOp>(call.resolveCallable());
445 if (!func) return WalkResult::interrupt();
446 result = CleanupAndCanonicalize(func);
447 }
448 return failed(result) ? WalkResult::interrupt() : WalkResult::advance();
449 });
450
451 return failure(walk_result.wasInterrupted());
452 }
453
454 } // anonymous namespace
455
456 namespace TF {
457
CleanupAndCanonicalizeForResourceOpLifting(func::FuncOp func)458 LogicalResult CleanupAndCanonicalizeForResourceOpLifting(func::FuncOp func) {
459 return CleanupAndCanonicalize(func);
460 }
461
CleanupAndCanonicalizeForResourceOpLifting(ModuleOp module)462 LogicalResult CleanupAndCanonicalizeForResourceOpLifting(ModuleOp module) {
463 auto walk_result = module.walk([](tf_device::ClusterOp cluster) {
464 if (failed(CleanupAndCanonicalize(cluster))) return WalkResult::interrupt();
465 return WalkResult::advance();
466 });
467 return failure(walk_result.wasInterrupted());
468 }
469
470 } // namespace TF
471 } // namespace mlir
472