xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
17 #include "mlir/Transforms/Passes.h"  // from @llvm-project
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
19 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
20 
21 namespace tensorflow {
22 namespace tfrt_compiler {
23 namespace {
24 
25 // This pass removes tf.If ops' operands that are produced by tf.Const ops.
26 // These constants can be moved into branches' function body for further
27 // optimziation.
28 class RemoveTfIfConstArgs
29     : public mlir::PassWrapper<RemoveTfIfConstArgs,
30                                mlir::OperationPass<mlir::ModuleOp>> {
31  public:
32   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RemoveTfIfConstArgs)
33 
34  private:
getArgument() const35   llvm::StringRef getArgument() const final {
36     return "tfrt-remove-tf-if-const-args";
37   }
getDescription() const38   llvm::StringRef getDescription() const final {
39     return "Remove const args from tf.If ops";
40   }
41 
runOnOperation()42   void runOnOperation() override {
43     auto module = getOperation();
44     for (auto func_op :
45          llvm::make_early_inc_range(module.getOps<mlir::func::FuncOp>())) {
46       ProcessFunction(func_op);
47     }
48   }
49 
ProcessFunction(mlir::func::FuncOp op)50   void ProcessFunction(mlir::func::FuncOp op) {
51     // Set the insertion point to the current function, as we will insert new
52     // functions here.
53     mlir::OpBuilder builder(op);
54     for (mlir::Operation &op : op.front()) {
55       auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(&op);
56       if (!if_op) continue;
57 
58       // Record the operands that are produced by tf.Const ops.
59       llvm::SmallVector<mlir::TF::ConstOp, 2> const_args;
60       // Record these operands's corresponding operand indices.
61       llvm::SmallVector<unsigned, 2> const_arg_indices;
62       // Record the remaining operands that won't be removed.
63       llvm::SmallVector<mlir::Value, 2> remaining_args;
64       for (auto iter : llvm::enumerate(if_op.input())) {
65         mlir::Value operand = iter.value();
66         if (auto const_op = operand.getDefiningOp<mlir::TF::ConstOp>()) {
67           const_args.push_back(const_op);
68           const_arg_indices.push_back(iter.index());
69         } else {
70           remaining_args.push_back(operand);
71         }
72       }
73 
74       if (const_args.empty()) continue;
75 
76       RemoveConstArgsFromTfIfOp(builder, if_op, const_args, const_arg_indices,
77                                 remaining_args);
78     }
79   }
80 
RemoveConstArgsFromTfIfOp(mlir::OpBuilder & builder,mlir::TF::IfOp if_op,llvm::ArrayRef<mlir::TF::ConstOp> const_args,llvm::ArrayRef<unsigned> const_arg_indices,llvm::ArrayRef<mlir::Value> remaining_args)81   void RemoveConstArgsFromTfIfOp(mlir::OpBuilder &builder, mlir::TF::IfOp if_op,
82                                  llvm::ArrayRef<mlir::TF::ConstOp> const_args,
83                                  llvm::ArrayRef<unsigned> const_arg_indices,
84                                  llvm::ArrayRef<mlir::Value> remaining_args) {
85     auto branch_suffix = absl::StrCat("_removed_const_args_", id_++);
86 
87     // Create wrapper functions with the new arguments (as const args are
88     // removed) for both then function and else function.
89     auto new_then_function_name =
90         CreateBranchFunction(builder, if_op.then_function(), branch_suffix,
91                              const_args, const_arg_indices);
92     auto new_else_function_name =
93         CreateBranchFunction(builder, if_op.else_function(), branch_suffix,
94                              const_args, const_arg_indices);
95 
96     // Change the if_op's argumetns to the new arguments, branches to new
97     // branches. Note that the outputs are not changed.
98     if_op.inputMutable().assign(remaining_args);
99     if_op.then_branchAttr(
100         mlir::SymbolRefAttr::get(builder.getContext(), new_then_function_name));
101     if_op.else_branchAttr(
102         mlir::SymbolRefAttr::get(builder.getContext(), new_else_function_name));
103   }
104 
CreateBranchFunction(mlir::OpBuilder & builder,mlir::func::FuncOp branch,absl::string_view branch_suffix,llvm::ArrayRef<mlir::TF::ConstOp> const_args,llvm::ArrayRef<unsigned> const_arg_indices)105   llvm::StringRef CreateBranchFunction(
106       mlir::OpBuilder &builder, mlir::func::FuncOp branch,
107       absl::string_view branch_suffix,
108       llvm::ArrayRef<mlir::TF::ConstOp> const_args,
109       llvm::ArrayRef<unsigned> const_arg_indices) {
110     // Get the new function type as const args are removed.
111     llvm::BitVector const_arg_indices_bv(branch.getNumArguments());
112     for (auto i : const_arg_indices) const_arg_indices_bv.set(i);
113     auto new_branch_type = branch.getFunctionType().getWithoutArgsAndResults(
114         const_arg_indices_bv, {});
115     std::string new_branch_name =
116         absl::StrCat(branch.getSymName().str(), branch_suffix);
117     // Create the wrapper function with the new arguments that calls the
118     // original branch.
119     auto new_branch = builder.create<mlir::func::FuncOp>(
120         branch.getLoc(), new_branch_name, new_branch_type);
121     new_branch.setVisibility(mlir::func::FuncOp::Visibility::Private);
122 
123     // In its function body, we will add the corresponding const ops and call
124     // the original branch.
125 
126     mlir::OpBuilder::InsertionGuard guard(builder);
127     auto *block = new_branch.addEntryBlock();
128     builder.setInsertionPointToStart(block);
129 
130     // Prepare the function arguments of the original branch.
131     llvm::SmallVector<mlir::Value, 4> call_args(branch.getNumArguments());
132 
133     // For those removed const args, we copy the tf.Const op, and use that as
134     // the corresponding argument when calling the original branch.
135     for (const auto &iter : llvm::zip(const_args, const_arg_indices)) {
136       auto const_op =
137           llvm::cast<mlir::TF::ConstOp>(builder.clone(*std::get<0>(iter)));
138       unsigned index = std::get<1>(iter);
139       call_args[index] = const_op;
140     }
141 
142     // For the rest, they are now coming from the wrapper function's arguments
143     // in the original order.
144     for (int i = 0, j = 0; i < call_args.size(); ++i) {
145       if (!call_args[i]) {
146         assert(j < block->getNumArguments());
147         call_args[i] = block->getArgument(j++);
148       }
149     }
150 
151     // Now create the call op to the original branch.
152     auto call_op = builder.create<mlir::TF::StatefulPartitionedCallOp>(
153         new_branch.getLoc(), new_branch_type.getResults(), call_args,
154         branch.getSymName(), "", "", "");
155     // Note that the outputs are not changed.
156     builder.create<mlir::func::ReturnOp>(new_branch.getLoc(), call_op.output());
157 
158     return new_branch.getSymName();
159   }
160 
161   int id_ = 0;
162 };
163 
164 }  // namespace
165 
166 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateRemoveTfIfConstArgsPass()167 CreateRemoveTfIfConstArgsPass() {
168   return std::make_unique<RemoveTfIfConstArgs>();
169 }
170 
171 static mlir::PassRegistration<RemoveTfIfConstArgs> register_pass(
172     CreateRemoveTfIfConstArgsPass);
173 
174 }  // namespace tfrt_compiler
175 }  // namespace tensorflow
176