1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
17 #include "mlir/Transforms/Passes.h" // from @llvm-project
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
19 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
20
21 namespace tensorflow {
22 namespace tfrt_compiler {
23 namespace {
24
25 // This pass removes tf.If ops' operands that are produced by tf.Const ops.
26 // These constants can be moved into branches' function body for further
27 // optimziation.
28 class RemoveTfIfConstArgs
29 : public mlir::PassWrapper<RemoveTfIfConstArgs,
30 mlir::OperationPass<mlir::ModuleOp>> {
31 public:
32 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RemoveTfIfConstArgs)
33
34 private:
getArgument() const35 llvm::StringRef getArgument() const final {
36 return "tfrt-remove-tf-if-const-args";
37 }
getDescription() const38 llvm::StringRef getDescription() const final {
39 return "Remove const args from tf.If ops";
40 }
41
runOnOperation()42 void runOnOperation() override {
43 auto module = getOperation();
44 for (auto func_op :
45 llvm::make_early_inc_range(module.getOps<mlir::func::FuncOp>())) {
46 ProcessFunction(func_op);
47 }
48 }
49
ProcessFunction(mlir::func::FuncOp op)50 void ProcessFunction(mlir::func::FuncOp op) {
51 // Set the insertion point to the current function, as we will insert new
52 // functions here.
53 mlir::OpBuilder builder(op);
54 for (mlir::Operation &op : op.front()) {
55 auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(&op);
56 if (!if_op) continue;
57
58 // Record the operands that are produced by tf.Const ops.
59 llvm::SmallVector<mlir::TF::ConstOp, 2> const_args;
60 // Record these operands's corresponding operand indices.
61 llvm::SmallVector<unsigned, 2> const_arg_indices;
62 // Record the remaining operands that won't be removed.
63 llvm::SmallVector<mlir::Value, 2> remaining_args;
64 for (auto iter : llvm::enumerate(if_op.input())) {
65 mlir::Value operand = iter.value();
66 if (auto const_op = operand.getDefiningOp<mlir::TF::ConstOp>()) {
67 const_args.push_back(const_op);
68 const_arg_indices.push_back(iter.index());
69 } else {
70 remaining_args.push_back(operand);
71 }
72 }
73
74 if (const_args.empty()) continue;
75
76 RemoveConstArgsFromTfIfOp(builder, if_op, const_args, const_arg_indices,
77 remaining_args);
78 }
79 }
80
RemoveConstArgsFromTfIfOp(mlir::OpBuilder & builder,mlir::TF::IfOp if_op,llvm::ArrayRef<mlir::TF::ConstOp> const_args,llvm::ArrayRef<unsigned> const_arg_indices,llvm::ArrayRef<mlir::Value> remaining_args)81 void RemoveConstArgsFromTfIfOp(mlir::OpBuilder &builder, mlir::TF::IfOp if_op,
82 llvm::ArrayRef<mlir::TF::ConstOp> const_args,
83 llvm::ArrayRef<unsigned> const_arg_indices,
84 llvm::ArrayRef<mlir::Value> remaining_args) {
85 auto branch_suffix = absl::StrCat("_removed_const_args_", id_++);
86
87 // Create wrapper functions with the new arguments (as const args are
88 // removed) for both then function and else function.
89 auto new_then_function_name =
90 CreateBranchFunction(builder, if_op.then_function(), branch_suffix,
91 const_args, const_arg_indices);
92 auto new_else_function_name =
93 CreateBranchFunction(builder, if_op.else_function(), branch_suffix,
94 const_args, const_arg_indices);
95
96 // Change the if_op's argumetns to the new arguments, branches to new
97 // branches. Note that the outputs are not changed.
98 if_op.inputMutable().assign(remaining_args);
99 if_op.then_branchAttr(
100 mlir::SymbolRefAttr::get(builder.getContext(), new_then_function_name));
101 if_op.else_branchAttr(
102 mlir::SymbolRefAttr::get(builder.getContext(), new_else_function_name));
103 }
104
CreateBranchFunction(mlir::OpBuilder & builder,mlir::func::FuncOp branch,absl::string_view branch_suffix,llvm::ArrayRef<mlir::TF::ConstOp> const_args,llvm::ArrayRef<unsigned> const_arg_indices)105 llvm::StringRef CreateBranchFunction(
106 mlir::OpBuilder &builder, mlir::func::FuncOp branch,
107 absl::string_view branch_suffix,
108 llvm::ArrayRef<mlir::TF::ConstOp> const_args,
109 llvm::ArrayRef<unsigned> const_arg_indices) {
110 // Get the new function type as const args are removed.
111 llvm::BitVector const_arg_indices_bv(branch.getNumArguments());
112 for (auto i : const_arg_indices) const_arg_indices_bv.set(i);
113 auto new_branch_type = branch.getFunctionType().getWithoutArgsAndResults(
114 const_arg_indices_bv, {});
115 std::string new_branch_name =
116 absl::StrCat(branch.getSymName().str(), branch_suffix);
117 // Create the wrapper function with the new arguments that calls the
118 // original branch.
119 auto new_branch = builder.create<mlir::func::FuncOp>(
120 branch.getLoc(), new_branch_name, new_branch_type);
121 new_branch.setVisibility(mlir::func::FuncOp::Visibility::Private);
122
123 // In its function body, we will add the corresponding const ops and call
124 // the original branch.
125
126 mlir::OpBuilder::InsertionGuard guard(builder);
127 auto *block = new_branch.addEntryBlock();
128 builder.setInsertionPointToStart(block);
129
130 // Prepare the function arguments of the original branch.
131 llvm::SmallVector<mlir::Value, 4> call_args(branch.getNumArguments());
132
133 // For those removed const args, we copy the tf.Const op, and use that as
134 // the corresponding argument when calling the original branch.
135 for (const auto &iter : llvm::zip(const_args, const_arg_indices)) {
136 auto const_op =
137 llvm::cast<mlir::TF::ConstOp>(builder.clone(*std::get<0>(iter)));
138 unsigned index = std::get<1>(iter);
139 call_args[index] = const_op;
140 }
141
142 // For the rest, they are now coming from the wrapper function's arguments
143 // in the original order.
144 for (int i = 0, j = 0; i < call_args.size(); ++i) {
145 if (!call_args[i]) {
146 assert(j < block->getNumArguments());
147 call_args[i] = block->getArgument(j++);
148 }
149 }
150
151 // Now create the call op to the original branch.
152 auto call_op = builder.create<mlir::TF::StatefulPartitionedCallOp>(
153 new_branch.getLoc(), new_branch_type.getResults(), call_args,
154 branch.getSymName(), "", "", "");
155 // Note that the outputs are not changed.
156 builder.create<mlir::func::ReturnOp>(new_branch.getLoc(), call_op.output());
157
158 return new_branch.getSymName();
159 }
160
161 int id_ = 0;
162 };
163
164 } // namespace
165
166 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateRemoveTfIfConstArgsPass()167 CreateRemoveTfIfConstArgsPass() {
168 return std::make_unique<RemoveTfIfConstArgs>();
169 }
170
171 static mlir::PassRegistration<RemoveTfIfConstArgs> register_pass(
172 CreateRemoveTfIfConstArgsPass);
173
174 } // namespace tfrt_compiler
175 } // namespace tensorflow
176