/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

// Module pass to optimize TensorFlow functional ops.
struct OptimizeFunctionalOpsPass
    : public OptimizeFunctionalOpsPassBase<OptimizeFunctionalOpsPass> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizeFunctionalOpsPass)

  void runOnOperation() override;
};

// Updates the return type of the given function to match its terminator op's
// operand types.
//
// Requires the function to have exactly one block.
void UpdateFuncType(func::FuncOp func) {
  Operation* terminator = func.front().getTerminator();
  auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());

  FunctionType func_type = func.getFunctionType();
  if (llvm::makeArrayRef(return_types) == func_type.getResults()) return;

  auto updated_type =
      FunctionType::get(func.getContext(), func_type.getInputs(), return_types);
  func.setType(updated_type);
}

// TODO(jpienaar): Remove when recursive side-effect modeling is added.
bool IsSideEffectFree(func::FuncOp func) {
  return !func.getBody()
              .walk([&](Operation* op) {
                if (!MemoryEffectOpInterface::hasNoEffect(op) &&
                    !op->hasTrait<OpTrait::IsTerminator>())
                  return WalkResult::interrupt();
                return WalkResult::advance();
              })
              .wasInterrupted();
}

// Folds a TensorFlow If op with a constant conditional operand by inlining the
// branch function body selected by the conditional value.
class FoldIfOp : public OpRewritePattern<TF::IfOp> {
 public:
  explicit FoldIfOp(MLIRContext* context)
      : OpRewritePattern<TF::IfOp>(context) {}

  LogicalResult matchAndRewrite(TF::IfOp op,
                                PatternRewriter& rewriter) const override {
    // This pattern is restricted to If ops in functions with exactly one block,
    // and therefore exactly one terminator op, so that the function return type
    // can be updated if the operands' shapes change after inlining. Without
    // this restriction, tensor cast ops would be required.
    func::FuncOp parent_op = op->getParentOfType<func::FuncOp>();
    if (!llvm::hasSingleElement(parent_op)) return failure();

    // Find the then and else branch functions.
    func::FuncOp then_func = op.then_function();
    func::FuncOp else_func = op.else_function();

    // If the If op has no uses and its branch functions are side-effect free,
    // then remove it.
    // TODO(jpienaar): Remove once recursive side-effects are supported.
    if (op.use_empty() &&
        (op.is_stateless() ||
         (IsSideEffectFree(then_func) && IsSideEffectFree(else_func)))) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }

    // Extract the constant cond value.
    DenseElementsAttr cond;
    if (!matchPattern(op.cond(), m_Constant(&cond))) return failure();

    // TODO(hinsu): Handle constants that are not scalar booleans.
    auto cond_type = cond.getType().dyn_cast<RankedTensorType>();
    if (!cond_type || !cond_type.getShape().equals({}) ||
        !cond_type.getElementType().isInteger(/*width=*/1))
      return failure();

    // Identify the branch to inline.
    bool cond_value = (*cond.value_begin<APInt>()).getSExtValue();
    func::FuncOp func = cond_value ? then_func : else_func;

    // Make sure that the function has exactly one block to simplify inlining.
    // TFLite doesn't use control flow with blocks, so functions with more than
    // one block are not encountered in practice.
    if (!llvm::hasSingleElement(func)) return failure();

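    // Map the inlined function's block arguments to the If op's data operands.
    // Operand 0 of the If op is the condition, so data operands start at
    // index 1.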
    BlockAndValueMapping mapper;
    for (int i = 0, e = func.getNumArguments(); i != e; ++i)
      mapper.map(func.getArgument(i), op.getOperand(i + 1));

    llvm::SmallVector<Value, 4> updated_results;
    for (auto& op_to_inline : func.front()) {
      // If this is a terminator, identify the values to use to replace the
      // original If op.
      if (op_to_inline.hasTrait<OpTrait::IsTerminator>()) {
        updated_results.reserve(op_to_inline.getNumOperands());
        for (Value operand : op_to_inline.getOperands())
          updated_results.push_back(mapper.lookup(operand));
        break;
      }

      // Otherwise, clone the op here.
      rewriter.clone(op_to_inline, mapper);
    }
    rewriter.replaceOp(op, updated_results);

    // The shapes of the updated results may not match those of the original
    // values. If any of these values are operands of the terminator op, the
    // function return type should be updated accordingly.
    UpdateFuncType(parent_op);

    return success();
  }
};

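// Applies the FoldIfOp pattern greedily across all functions in the module.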
void OptimizeFunctionalOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());

  patterns.add<FoldIfOp>(&getContext());

  ModuleOp module = getOperation();
  (void)applyPatternsAndFoldGreedily(module, std::move(patterns));
}
}  // namespace

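// Creates an instance of the pass that optimizes TensorFlow functional ops.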
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
  return std::make_unique<OptimizeFunctionalOpsPass>();
}

}  // namespace TFL
}  // namespace mlir