/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {

// This pass folds tfl.const ops into each subgraph (func::FuncOp) that uses
// them. See the example below:
//
// In main:
// %0 = tfl.const...
// %1 = tfl.const...
// %2 = call func_1(..., %0, ...)
// %3 = call func_2(..., %0, ..., %1, ...)
// ...
//
// The consts are then copied into each function, and the uses of the
// corresponding arguments are replaced:
// func_1:
// %0 = tfl.const...
// func_2:
// %0 = tfl.const...
// %1 = tfl.const...
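//
// For illustration, a schematic "before" in concrete MLIR syntax (the names
// and shapes below are hypothetical):
//
//   func.func @main(%arg0: tensor<1xf32>) -> tensor<1xf32> {
//     %0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
//     %1 = func.call @func_1(%arg0, %0) : (tensor<1xf32>, tensor<i32>) -> tensor<1xf32>
//     func.return %1 : tensor<1xf32>
//   }
//
// After the pass, @func_1 materializes its own copy of the constant and the
// corresponding block argument becomes dead.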
class FoldConstantsToSubgraphPass
    : public mlir::PassWrapper<FoldConstantsToSubgraphPass,
                               mlir::OperationPass<ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FoldConstantsToSubgraphPass)

  llvm::StringRef getArgument() const final {
    return "tfl-fold-constants-to-subgraph";
  }
  llvm::StringRef getDescription() const final {
    return "Fold constants into each subgraph.";
  }
  FoldConstantsToSubgraphPass() = default;
  FoldConstantsToSubgraphPass(const FoldConstantsToSubgraphPass& other) {
    this->fold_all_constants_flag_ = other.fold_all_constants_flag_;
  }
  explicit FoldConstantsToSubgraphPass(bool fold_all_constants) {
    fold_all_constants_flag_ = fold_all_constants;
  }

 private:
  void runOnOperation() override;

  Option<bool> fold_all_constants_flag_{
      *this, "fold-all-constants",
      llvm::cl::desc(
          "Whether to fold all constants or just int32/int64 ones."),
      llvm::cl::init(false)};
};
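
// The registered argument and option map onto the usual textual pass syntax,
// e.g. with an opt-style driver that links this pass in (the driver name
// below is an assumption):
//
//   some-opt --tfl-fold-constants-to-subgraph='fold-all-constants=true' input.mlir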

void CopyConstantIntoFunc(int argument_index, Operation* const_op,
                          func::FuncOp func) {
  assert((llvm::isa<TFL::ConstOp, TFL::QConstOp>(const_op)) &&
         "Expect QConst or Const op.");
  OpBuilder builder(func.getBody());
  auto cloned_const_op = const_op->clone();
  cloned_const_op->setLoc(func.getBody().getLoc());
  builder.insert(cloned_const_op);
  // Rewire the usage.
  func.getArgument(argument_index)
      .replaceAllUsesWith(cloned_const_op->getResult(0));
}
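
// Note that CopyConstantIntoFunc only rewires uses of the argument inside the
// callee: the now-dead function argument and the matching call operand are
// left in place, since this pass does not rewrite function signatures.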

bool IsConstOrQConstInt(Operation* op) {
  if (!llvm::isa<TFL::ConstOp, TFL::QConstOp>(op)) return false;

  if (auto const_op = dyn_cast_or_null<TFL::ConstOp>(op)) {
    // ConstOp path. Bail out on non-ranked tensor types before inspecting
    // the element type.
    auto ranked_type = const_op.getType().dyn_cast_or_null<RankedTensorType>();
    if (!ranked_type) return false;
    auto type = ranked_type.getElementType();
    if (!type.isInteger(32) && !type.isInteger(64)) return false;
  } else {
    // QConstOp path.
    auto qconst_op = dyn_cast<TFL::QConstOp>(op);
    auto type =
        quant::QuantizedType::getQuantizedElementType(qconst_op.getType());
    if (!type || type.getStorageTypeIntegralWidth() != 32) {
      return false;
    }
  }
  return true;
}
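
// A few examples of what IsConstOrQConstInt accepts and rejects:
//   tfl.pseudo_const with element type i32 or i64          -> true
//   tfl.pseudo_const with element type f32                 -> false
//   tfl.pseudo_qconst with i32 storage type (e.g. a bias)  -> true
//   tfl.pseudo_qconst with i8 storage type                 -> false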

void FoldConstantsToSubgraphPass::runOnOperation() {
  auto module = getOperation();

  for (auto fn : module.getOps<func::FuncOp>()) {
    fn.walk([&](Operation* op) {
      if (!llvm::isa<TFL::ConstOp, TFL::QConstOp>(op)) return;

      // Unless the fold-all-constants flag is set, only int32/int64 Const ops
      // and i32 QConst ops are folded, since those mostly act as "configs" or
      // i32 biases. With the flag set, every const (and q_const) op is folded.
      if (!fold_all_constants_flag_) {
        if (!IsConstOrQConstInt(op)) return;
      }

      for (auto consumer : op->getResult(0).getUsers()) {
        auto consumer_call = llvm::dyn_cast_or_null<func::CallOp>(consumer);

        if (!consumer_call) continue;

        auto function_name = consumer_call.getCallee();

        // Locate the argument position of the use.
        int argument_index = -1;
        for (int i = 0; i < consumer_call.getNumOperands(); ++i) {
          if (consumer_call.getOperand(i) == op->getResult(0)) {
            argument_index = i;
            break;
          }
        }

        // Copy the const into the consumer func and replace its uses there.
        func::FuncOp func = module.lookupSymbol<func::FuncOp>(function_name);

        CopyConstantIntoFunc(argument_index, op, func);
      }
    });
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateFoldConstantsToSubgraphPass(
    bool fold_all_constants) {
  return std::make_unique<FoldConstantsToSubgraphPass>(fold_all_constants);
}
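
// A minimal programmatic usage sketch (the enclosing setup and `module` are
// assumed):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::TFL::tac::CreateFoldConstantsToSubgraphPass(
//       /*fold_all_constants=*/false));
//   if (mlir::failed(pm.run(module))) {
//     // Handle pass failure.
//   }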

static PassRegistration<FoldConstantsToSubgraphPass> pass;

}  // namespace tac
}  // namespace TFL
}  // namespace mlir