xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 
23 #include "absl/strings/str_cat.h"
24 #include "llvm/Support/Casting.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/IR/Builders.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
31 #include "mlir/IR/Visitors.h"  // from @llvm-project
32 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
33 #include "mlir/Pass/Pass.h"  // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
35 #include "mlir/Support/LLVM.h"  // from @llvm-project
36 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/cost.h"
37 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
38 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
39 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
40 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
41 
42 namespace mlir {
43 namespace TFL {
44 namespace tac {
45 namespace {
46 
// These are just fake costs.
// Placeholder per-element costs for the type-conversion paths used by
// GetQuantDequantCost: dequantize (quantized -> float), quantize
// (float -> quantized), and requantize (quantized -> quantized).
constexpr float kDequantCost = 2.0;
constexpr float kQuantCost = 2.0;
constexpr float kRequantCost = 2.0;
51 
52 // TODO(renjieliu): Ideally this should consider different kinds of SOCs as
53 // well.
54 
55 // Get total bytes transferred.
GetTransferredTensorBytes(func::CallOp from_graph,func::CallOp to_graph)56 int64_t GetTransferredTensorBytes(func::CallOp from_graph,
57                                   func::CallOp to_graph) {
58   int64_t total_size_transferred = 0;
59   for (auto input : to_graph.getOperands()) {
60     Operation* input_op = input.getDefiningOp();
61     if (input_op && input_op == from_graph.getOperation()) {
62       auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
63       if (input_type == nullptr || !input_type.hasStaticShape()) continue;
64       // Quantized type does not support getSizeInBits.
65       if (IsQUI8Type(input_type) || IsQI8Type(input_type)) {
66         total_size_transferred += input_type.getNumElements() * 8;
67       } else {
68         total_size_transferred += input_type.cast<ShapedType>().getSizeInBits();
69       }
70     }
71   }
72   return total_size_transferred;
73 }
74 
75 // Get total tensor element size transferred.
GetTransferredElementCount(func::CallOp from_graph,func::CallOp to_graph)76 int64_t GetTransferredElementCount(func::CallOp from_graph,
77                                    func::CallOp to_graph) {
78   int64_t total_element_count = 0;
79   for (auto input : to_graph.getOperands()) {
80     Operation* input_op = input.getDefiningOp();
81     if (input_op && input_op == from_graph.getOperation()) {
82       auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
83       if (input_type == nullptr || !input_type.hasStaticShape()) continue;
84       total_element_count += input_type.getNumElements();
85     }
86   }
87   return total_element_count;
88 }
89 
// Function pass that annotates each eligible op in a function with the
// estimated cost of running it on the op's annotated target hardware
// (see runOnOperation below).
struct GetOpCostPass
    : mlir::PassWrapper<GetOpCostPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GetOpCostPass)

  // Command-line flag used to invoke this pass from pass pipelines / tools.
  llvm::StringRef getArgument() const final { return "tfl-get-op-cost"; }
  llvm::StringRef getDescription() const final {
    return "Get cost for every op";
  }
  void runOnOperation() override;
};
100 
runOnOperation()101 void GetOpCostPass::runOnOperation() {
102   auto func = getOperation();
103   OpBuilder builder(func);
104   func.walk([&](Operation* op) {
105     if (IsNonConstOp(op) && !IsTerminatorOp(op) &&
106         !llvm::isa<func::ReturnOp, func::FuncOp, CallOpInterface>(op)) {
107       auto hardware = GetTargetAnnotation(op);
108       if (!hardware) return;
109       float cost = GetCostForOp(op, hardware.getValue());
110       UpdateCost(op, cost, &builder);
111     }
112   });
113 }
114 
115 }  // namespace
116 
GetCostForOp(Operation * op,const std::string & hardware)117 float GetCostForOp(Operation* op, const std::string& hardware) {
118   auto* device_hardware = GetTargetHardware(hardware);
119   if (device_hardware == nullptr) {
120     return kDefaultFixedValuedCost;
121   }
122 
123   return device_hardware->GetOpCost(op);
124 }
125 
GetCostForFunc(func::FuncOp * func,const std::string & hardware)126 float GetCostForFunc(func::FuncOp* func, const std::string& hardware) {
127   auto* device_hardware = GetTargetHardware(hardware);
128   if (device_hardware == nullptr) {
129     return kDefaultFixedValuedCost;
130   }
131 
132   return device_hardware->GetFuncCost(func);
133 }
134 
GetTransferCost(const std::string & from_hardware_str,const std::string & to_hardware_str,func::CallOp from_graph,func::CallOp to_graph)135 float GetTransferCost(const std::string& from_hardware_str,
136                       const std::string& to_hardware_str,
137                       func::CallOp from_graph, func::CallOp to_graph) {
138   auto from_hardware = GetTargetHardware(from_hardware_str);
139   auto to_hardware = GetTargetHardware(to_hardware_str);
140   if (from_hardware == nullptr) {
141     from_graph.emitError(absl::StrCat(
142         "we cannot find the registered hardware: ", from_hardware_str));
143   }
144 
145   if (to_hardware == nullptr) {
146     to_graph.emitError(absl::StrCat("we cannot find the registered hardware: ",
147                                     to_hardware_str));
148   }
149 
150   const int64_t total_size_transferred =
151       GetTransferredTensorBytes(from_graph, to_graph);
152   return to_hardware->GetHardwareSwitchingCost(from_hardware,
153                                                total_size_transferred);
154 }
155 
GetQuantDequantCost(InferenceType from_inference_type,InferenceType to_inference_type,func::CallOp from_graph,func::CallOp to_graph)156 float GetQuantDequantCost(InferenceType from_inference_type,
157                           InferenceType to_inference_type,
158                           func::CallOp from_graph, func::CallOp to_graph) {
159   // Same inference type, no dequant/quant happens.
160   if (from_inference_type == to_inference_type) return 0;
161 
162   const int64_t total_element_count_transferred =
163       GetTransferredElementCount(from_graph, to_graph);
164 
165   if (from_inference_type == FLOAT || from_inference_type == HYBRID) {
166     // FLOAT <-> HYBRID will have no quant/dequant as well.
167     if (to_inference_type == FLOAT || to_inference_type == HYBRID) {
168       return 0;
169     } else if (to_inference_type == QUANTIZED_INT8 ||
170                to_inference_type == QUANTIZED_UINT8) {
171       // QUANT path.
172       return kQuantCost * total_element_count_transferred;
173     }
174   }
175 
176   if (from_inference_type == QUANTIZED_INT8 ||
177       from_inference_type == QUANTIZED_UINT8) {
178     // Dequant path.
179     if (to_inference_type == FLOAT || to_inference_type == HYBRID) {
180       return kDequantCost * total_element_count_transferred;
181     } else if (to_inference_type == QUANTIZED_INT8 ||
182                to_inference_type == QUANTIZED_UINT8) {
183       // Requant path.
184       return kRequantCost * total_element_count_transferred;
185     }
186   }
187 
188   // Default quant/dequant/requant cost.
189   return kDefaultFixedValuedCost;
190 }
191 
CreateGetOpCostPass()192 std::unique_ptr<OperationPass<func::FuncOp>> CreateGetOpCostPass() {
193   return std::make_unique<GetOpCostPass>();
194 }
195 
196 static PassRegistration<GetOpCostPass> pass;
197 
198 }  // namespace tac
199 }  // namespace TFL
200 }  // namespace mlir
201