1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22
23 #include "absl/strings/str_cat.h"
24 #include "llvm/Support/Casting.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
27 #include "mlir/IR/Builders.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30 #include "mlir/IR/PatternMatch.h" // from @llvm-project
31 #include "mlir/IR/Visitors.h" // from @llvm-project
32 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
33 #include "mlir/Pass/Pass.h" // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
35 #include "mlir/Support/LLVM.h" // from @llvm-project
36 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/cost.h"
37 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
38 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
39 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
40 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
41
42 namespace mlir {
43 namespace TFL {
44 namespace tac {
45 namespace {
46
// These are just fake costs.
// Per-element conversion costs used by GetQuantDequantCost below. The values
// are placeholders, not measured numbers; all three transitions are assumed
// equally expensive.
constexpr float kDequantCost = 2.0;   // quantized -> float, per element.
constexpr float kQuantCost = 2.0;     // float -> quantized, per element.
constexpr float kRequantCost = 2.0;   // quantized -> quantized, per element.

// TODO(renjieliu): Ideally this should consider different kinds of SOCs as
// well.
54
55 // Get total bytes transferred.
GetTransferredTensorBytes(func::CallOp from_graph,func::CallOp to_graph)56 int64_t GetTransferredTensorBytes(func::CallOp from_graph,
57 func::CallOp to_graph) {
58 int64_t total_size_transferred = 0;
59 for (auto input : to_graph.getOperands()) {
60 Operation* input_op = input.getDefiningOp();
61 if (input_op && input_op == from_graph.getOperation()) {
62 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
63 if (input_type == nullptr || !input_type.hasStaticShape()) continue;
64 // Quantized type does not support getSizeInBits.
65 if (IsQUI8Type(input_type) || IsQI8Type(input_type)) {
66 total_size_transferred += input_type.getNumElements() * 8;
67 } else {
68 total_size_transferred += input_type.cast<ShapedType>().getSizeInBits();
69 }
70 }
71 }
72 return total_size_transferred;
73 }
74
75 // Get total tensor element size transferred.
GetTransferredElementCount(func::CallOp from_graph,func::CallOp to_graph)76 int64_t GetTransferredElementCount(func::CallOp from_graph,
77 func::CallOp to_graph) {
78 int64_t total_element_count = 0;
79 for (auto input : to_graph.getOperands()) {
80 Operation* input_op = input.getDefiningOp();
81 if (input_op && input_op == from_graph.getOperation()) {
82 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
83 if (input_type == nullptr || !input_type.hasStaticShape()) continue;
84 total_element_count += input_type.getNumElements();
85 }
86 }
87 return total_element_count;
88 }
89
// Function pass that annotates each eligible op in a function with an
// estimated execution cost for the op's annotated target hardware
// (see runOnOperation below).
struct GetOpCostPass
    : mlir::PassWrapper<GetOpCostPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GetOpCostPass)

  // Command-line flag used to invoke this pass.
  llvm::StringRef getArgument() const final { return "tfl-get-op-cost"; }
  llvm::StringRef getDescription() const final {
    return "Get cost for every op";
  }
  void runOnOperation() override;
};
100
runOnOperation()101 void GetOpCostPass::runOnOperation() {
102 auto func = getOperation();
103 OpBuilder builder(func);
104 func.walk([&](Operation* op) {
105 if (IsNonConstOp(op) && !IsTerminatorOp(op) &&
106 !llvm::isa<func::ReturnOp, func::FuncOp, CallOpInterface>(op)) {
107 auto hardware = GetTargetAnnotation(op);
108 if (!hardware) return;
109 float cost = GetCostForOp(op, hardware.getValue());
110 UpdateCost(op, cost, &builder);
111 }
112 });
113 }
114
115 } // namespace
116
GetCostForOp(Operation * op,const std::string & hardware)117 float GetCostForOp(Operation* op, const std::string& hardware) {
118 auto* device_hardware = GetTargetHardware(hardware);
119 if (device_hardware == nullptr) {
120 return kDefaultFixedValuedCost;
121 }
122
123 return device_hardware->GetOpCost(op);
124 }
125
GetCostForFunc(func::FuncOp * func,const std::string & hardware)126 float GetCostForFunc(func::FuncOp* func, const std::string& hardware) {
127 auto* device_hardware = GetTargetHardware(hardware);
128 if (device_hardware == nullptr) {
129 return kDefaultFixedValuedCost;
130 }
131
132 return device_hardware->GetFuncCost(func);
133 }
134
GetTransferCost(const std::string & from_hardware_str,const std::string & to_hardware_str,func::CallOp from_graph,func::CallOp to_graph)135 float GetTransferCost(const std::string& from_hardware_str,
136 const std::string& to_hardware_str,
137 func::CallOp from_graph, func::CallOp to_graph) {
138 auto from_hardware = GetTargetHardware(from_hardware_str);
139 auto to_hardware = GetTargetHardware(to_hardware_str);
140 if (from_hardware == nullptr) {
141 from_graph.emitError(absl::StrCat(
142 "we cannot find the registered hardware: ", from_hardware_str));
143 }
144
145 if (to_hardware == nullptr) {
146 to_graph.emitError(absl::StrCat("we cannot find the registered hardware: ",
147 to_hardware_str));
148 }
149
150 const int64_t total_size_transferred =
151 GetTransferredTensorBytes(from_graph, to_graph);
152 return to_hardware->GetHardwareSwitchingCost(from_hardware,
153 total_size_transferred);
154 }
155
GetQuantDequantCost(InferenceType from_inference_type,InferenceType to_inference_type,func::CallOp from_graph,func::CallOp to_graph)156 float GetQuantDequantCost(InferenceType from_inference_type,
157 InferenceType to_inference_type,
158 func::CallOp from_graph, func::CallOp to_graph) {
159 // Same inference type, no dequant/quant happens.
160 if (from_inference_type == to_inference_type) return 0;
161
162 const int64_t total_element_count_transferred =
163 GetTransferredElementCount(from_graph, to_graph);
164
165 if (from_inference_type == FLOAT || from_inference_type == HYBRID) {
166 // FLOAT <-> HYBRID will have no quant/dequant as well.
167 if (to_inference_type == FLOAT || to_inference_type == HYBRID) {
168 return 0;
169 } else if (to_inference_type == QUANTIZED_INT8 ||
170 to_inference_type == QUANTIZED_UINT8) {
171 // QUANT path.
172 return kQuantCost * total_element_count_transferred;
173 }
174 }
175
176 if (from_inference_type == QUANTIZED_INT8 ||
177 from_inference_type == QUANTIZED_UINT8) {
178 // Dequant path.
179 if (to_inference_type == FLOAT || to_inference_type == HYBRID) {
180 return kDequantCost * total_element_count_transferred;
181 } else if (to_inference_type == QUANTIZED_INT8 ||
182 to_inference_type == QUANTIZED_UINT8) {
183 // Requant path.
184 return kRequantCost * total_element_count_transferred;
185 }
186 }
187
188 // Default quant/dequant/requant cost.
189 return kDefaultFixedValuedCost;
190 }
191
// Factory for the op-cost annotation pass; used by pass pipelines that cannot
// rely on the static registration below.
std::unique_ptr<OperationPass<func::FuncOp>> CreateGetOpCostPass() {
  return std::make_unique<GetOpCostPass>();
}

// Registers the pass under its "tfl-get-op-cost" command-line argument.
static PassRegistration<GetOpCostPass> pass;
197
198 } // namespace tac
199 } // namespace TFL
200 } // namespace mlir
201