1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 
19 #include "absl/strings/str_cat.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/IR/Builders.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
31 #include "mlir/Pass/Pass.h"  // from @llvm-project
32 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
33 #include "mlir/Support/LLVM.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
35 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
36 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
37 
38 namespace mlir {
39 namespace TFL {
40 namespace tac {
41 namespace {
42 
43 // This pass is used to fold tfl.const ops to each subgraph (func::FuncOp):
44 // See the example below:
45 //
46 // In main:
47 // %0 = tfl.const...
48 // %1 = tfl.const...
49 // %2 = call func_1(..., %0,...)
50 // %3 = call func_2(..., %0, ..., %1...)
51 // ...
52 //
53 // Then those consts will be copied into each function and replace their usage.
54 // func_1:
55 //   %0 = tfl.const...
56 // func_2:
57 //   %0 = tfl.const...
58 //   %1 = tfl.const...
59 class FoldConstantsToSubgraphPass
60     : public mlir::PassWrapper<FoldConstantsToSubgraphPass,
61                                mlir::OperationPass<ModuleOp>> {
62  public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FoldConstantsToSubgraphPass)63   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FoldConstantsToSubgraphPass)
64 
65   llvm::StringRef getArgument() const final {
66     return "tfl-fold-constants-to-subgraph";
67   }
getDescription() const68   llvm::StringRef getDescription() const final {
69     return "Fold constants into each subgraph.";
70   }
71   FoldConstantsToSubgraphPass() = default;
FoldConstantsToSubgraphPass(const FoldConstantsToSubgraphPass & other)72   FoldConstantsToSubgraphPass(const FoldConstantsToSubgraphPass& other) {
73     this->fold_all_constants_flag_ = other.fold_all_constants_flag_;
74   }
FoldConstantsToSubgraphPass(bool fold_all_constants)75   explicit FoldConstantsToSubgraphPass(bool fold_all_constants) {
76     fold_all_constants_flag_ = fold_all_constants;
77   }
78 
79  private:
80   void runOnOperation() override;
81 
82   Option<bool> fold_all_constants_flag_{
83       *this, "fold-all-constants",
84       llvm::cl::desc("Whether to fold all constants or just i32."),
85       llvm::cl::init(false)};
86 };
87 
CopyConstantIntoFunc(int argument_index,Operation * const_op,func::FuncOp func)88 void CopyConstantIntoFunc(int argument_index, Operation* const_op,
89                           func::FuncOp func) {
90   assert((llvm::isa<TFL::ConstOp, TFL::QConstOp>(const_op)) &&
91          "Expect QConst or Const op.");
92   OpBuilder builder(func.getBody());
93   auto cloned_const_op = const_op->clone();
94   cloned_const_op->setLoc(func.getBody().getLoc());
95   builder.insert(cloned_const_op);
96   // Rewire the usage.
97   func.getArgument(argument_index)
98       .replaceAllUsesWith(cloned_const_op->getResult(0));
99 }
100 
IsConstOrQConstInt(Operation * op)101 bool IsConstOrQConstInt(Operation* op) {
102   if (!llvm::isa<TFL::ConstOp, TFL::QConstOp>(op)) return false;
103 
104   if (auto const_op = dyn_cast_or_null<TFL::ConstOp>(op)) {
105     // ConstOp path.
106     auto type = const_op.getType()
107                     .dyn_cast_or_null<RankedTensorType>()
108                     .getElementType();
109     if (!type.isInteger(32) && !type.isInteger(64)) return false;
110   } else {
111     // QConstOp path.
112     auto qconst_op = dyn_cast<TFL::QConstOp>(op);
113     auto type =
114         quant::QuantizedType::getQuantizedElementType(qconst_op.getType());
115     if (type.getStorageTypeIntegralWidth() != 32) {
116       return false;
117     }
118   }
119   return true;
120 }
121 
runOnOperation()122 void FoldConstantsToSubgraphPass::runOnOperation() {
123   auto module = getOperation();
124 
125   for (auto fn : module.getOps<func::FuncOp>()) {
126     fn.walk([&](Operation* op) {
127       if (!llvm::isa<TFL::ConstOp, TFL::QConstOp>(op)) return;
128 
129       // We only fold int32/int64 for Const and i32 for QConst if not specify
130       // all constants flag. (Since they're more like "configs" or i32 biases.)
131       // We will fold every const ops (and q_const ops) if we speicfy the
132       // fold_all_constants_flag.
133       if (!fold_all_constants_flag_) {
134         if (!IsConstOrQConstInt(op)) return;
135       }
136 
137       for (auto consumer : op->getResult(0).getUsers()) {
138         auto consumer_call = llvm::dyn_cast_or_null<func::CallOp>(consumer);
139 
140         if (!consumer_call) continue;
141 
142         auto function_name = consumer_call.getCallee();
143 
144         // Locate the argument position of the use.
145         int argument_index = -1;
146         for (int i = 0; i < consumer_call.getNumOperands(); ++i) {
147           if (consumer_call.getOperand(i) == op->getResult(0)) {
148             argument_index = i;
149             break;
150           }
151         }
152 
153         // Copy the const into the consumer func and replace their usages.
154         func::FuncOp func = module.lookupSymbol<func::FuncOp>(function_name);
155 
156         CopyConstantIntoFunc(argument_index, op, func);
157       }
158     });
159   }
160 }
161 
162 }  // namespace
163 
CreateFoldConstantsToSubgraphPass(bool fold_all_constants)164 std::unique_ptr<OperationPass<ModuleOp>> CreateFoldConstantsToSubgraphPass(
165     bool fold_all_constants) {
166   return std::make_unique<FoldConstantsToSubgraphPass>(fold_all_constants);
167 }
168 
169 static PassRegistration<FoldConstantsToSubgraphPass> pass;
170 
171 }  // namespace tac
172 }  // namespace TFL
173 }  // namespace mlir
174