/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h"

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/tfrt/fallback/cost_recorder.h"

namespace tensorflow {
namespace tfrt_compiler {
namespace {

constexpr int64_t kDefaultCheapCost = 1;

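// Returns the number of elements in a ranked tensor type, counting each
// unknown dimension as 1. For example, tensor<?x128xf32> yields 1 * 128 = 128.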
int64_t GetRankedTensorSize(mlir::TensorType type) {
  auto shape = type.getShape();

  int64_t size = 1;
  for (int64_t dim : shape) {
    // For unknown dimensions, use 1 as the size because it is usually the
    // batch dimension.
    //
    // TODO(chky): Find out a better default number for this case.
    size *= std::max(kDefaultCheapCost, dim);
  }

  return size;
}

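// Returns the number of elements in `type`, falling back to the context's
// default size when the tensor is unranked.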
int64_t InferTensorSize(const CostContext& context, mlir::TensorType type) {
  if (type.hasRank()) return GetRankedTensorSize(type);
  return context.default_unranked_tensor_size;
}

// The cost function for tf.LookupTableFindV2.
int64_t InferLookupTableFindV2Cost(const CostContext& context,
                                   mlir::TF::LookupTableFindV2Op op) {
  // tf.LookupTableFindV2 ops are usually more costly than tf.AddV2 ops with
  // the same input size, as they involve more operations like hashing, map
  // lookup, etc.
  constexpr int64_t kLookupTableFindCostScale = 8;
  constexpr int64_t kLookupTableFindStringKeyCostScale = 16;
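  // For example, with these scales a lookup that produces 100 values is
  // assigned a cost of 8 * 100 = 800 for integer keys, and 16 * 800 = 12800
  // for string keys.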

  auto value_type = op.values().getType().cast<mlir::TensorType>();
  auto key_type = op.keys().getType().cast<mlir::TensorType>();

  int64_t output_size = InferTensorSize(context, value_type);

  int64_t cost = kLookupTableFindCostScale * output_size;

  if (key_type.getElementType().isa<mlir::TF::StringType>())
    cost *= kLookupTableFindStringKeyCostScale;

  return cost;
}

// The cost function for tf.GatherV2.
int64_t InferGatherV2Cost(const CostContext& context, mlir::TF::GatherV2Op op) {
  return InferTensorSize(context,
                         op.output().getType().cast<mlir::TensorType>());
}

// The cost function for the tf.SparseSegment* ops (Sum, Mean, SqrtN).
template <typename OpType>
int64_t InferSparseSegmentOpCost(const CostContext& context, OpType op) {
  return InferTensorSize(
      context, op.output().getType().template cast<mlir::TensorType>());
}

// CostFunctionRegistry is a map from op names to their cost functions.
using CostFunctionRegistry = absl::flat_hash_map<std::string, CostFunction>;

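// Registers `cost_function` under `op_name` in `registry`. Registering the
// same op name twice is a programming error (checked by the assert below).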
void RegisterCostFunction(CostFunctionRegistry& registry,
                          absl::string_view op_name,
                          CostFunction cost_function) {
  auto r = registry.try_emplace(op_name, std::move(cost_function));
  assert(r.second);
  (void)r;
}

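// Convenience overload that registers `f`, which takes a concrete `OpType`,
// under that op's operation name (e.g. "tf.GatherV2" for mlir::TF::GatherV2Op).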
template <typename OpType, typename F>
void RegisterCostFunction(CostFunctionRegistry& registry, F f) {
  RegisterCostFunction(
      registry, OpType::getOperationName().str(),
      [f = std::move(f)](const CostContext& context, mlir::Operation* op) {
        return f(context, llvm::cast<OpType>(op));
      });
}

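// Returns the singleton cost function registry, populating it with the
// built-in cost functions on first use.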
CostFunctionRegistry& GetCostFunctionRegistry() {
  static auto* const registry = []() {
    auto* registry = new CostFunctionRegistry;
    // TODO(chky): Find a more scalable way to register cost functions. One
    // option is to incorporate it into TF MLIR ODS.
    RegisterCostFunction<mlir::TF::GatherV2Op>(*registry, InferGatherV2Cost);
    RegisterCostFunction<mlir::TF::SparseSegmentSumOp>(
        *registry, InferSparseSegmentOpCost<mlir::TF::SparseSegmentSumOp>);
    RegisterCostFunction<mlir::TF::SparseSegmentMeanOp>(
        *registry, InferSparseSegmentOpCost<mlir::TF::SparseSegmentMeanOp>);
    RegisterCostFunction<mlir::TF::SparseSegmentSqrtNOp>(
        *registry, InferSparseSegmentOpCost<mlir::TF::SparseSegmentSqrtNOp>);
    RegisterCostFunction<mlir::TF::LookupTableFindV2Op>(
        *registry, InferLookupTableFindV2Cost);
    return registry;
  }();
  return *registry;
}

}  // namespace

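// Registers `cost_function` for `op_name` in the process-wide registry used by
// CostAnalysis::EvaluateCost. A minimal usage sketch, assuming a hypothetical
// custom op named "tf.MyCustomOp":
//
//   RegisterCostFunction(
//       "tf.MyCustomOp",
//       [](const CostContext& context, mlir::Operation* op) -> int64_t {
//         // e.g. charge the default unranked tensor size per result.
//         return context.default_unranked_tensor_size * op->getNumResults();
//       });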
void RegisterCostFunction(absl::string_view op_name,
                          CostFunction cost_function) {
  RegisterCostFunction(GetCostFunctionRegistry(), op_name,
                       std::move(cost_function));
}

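// Returns the cost of `op`: the measured cost keyed by `op_key` if one was
// recorded, otherwise the cost estimated by EvaluateCost().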
int64_t CostAnalysis::GetCost(mlir::Operation* op, int64_t op_key) const {
  // Try to use its measured cost.
  const auto& measured_cost_map = op_cost_map_proto_.op_cost_map();
  if (const auto op_cost = measured_cost_map.find(op_key);
      op_cost != measured_cost_map.end()) {
    return op_cost->second;
  }

  assert(cost_map_.count(op) > 0);
  return cost_map_.lookup(op);
}

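// Scans the function arguments and records the largest statically known
// argument size, which is later used as the default size for dynamically
// shaped tensors.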
void CostAnalysis::AnalyzeArguments(mlir::func::FuncOp func_op) {
  // Use the max size among the function inputs as the default size of
  // dynamically shaped tensors in the function.
  for (auto arg : func_op.getArguments()) {
    auto type = arg.getType().cast<mlir::TensorType>();
    if (type.hasRank()) {
      max_arg_size_ = std::max(max_arg_size_, GetRankedTensorSize(type));
    }
  }
}

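// Estimates a cost for every op in `block`.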
void CostAnalysis::AnalyzeBlock(mlir::Block* block) {
  for (auto& op : *block) {
    EvaluateCost(&op);
  }
}

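// Computes a heuristic cost for `op` and records it in cost_map_.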
void CostAnalysis::EvaluateCost(mlir::Operation* op) {
  if (!llvm::isa<mlir::TF::TensorFlowDialect>(op->getDialect())) {
    cost_map_[op] = max_arg_size_;
    return;
  }

  // These ops are cheap regardless of their input sizes.
  //
  // TODO(chky): Find a more scalable way to figure out cheap ops.
  if (llvm::isa<mlir::TF::ShapeOp, mlir::TF::StridedSliceOp,
                mlir::TF::ReshapeOp, mlir::TF::ExpandDimsOp>(op)) {
    cost_map_[op] = kDefaultCheapCost;
    return;
  }

  // Try to use its cost function if one is registered.
  const auto& registry = GetCostFunctionRegistry();
  absl::string_view op_name = op->getName().getStringRef();
  auto iter = registry.find(op_name);
  if (iter != registry.end()) {
    CostContext context;
    context.default_unranked_tensor_size = max_arg_size_;
    cost_map_[op] = iter->second(context, op);
    return;
  }

  // For other ops, use the sum of their input sizes as the cost.
  int64_t cost = kDefaultCheapCost;
  for (auto operand : op->getOperands()) {
    auto type = operand.getType().cast<mlir::TensorType>();
    if (type.hasRank()) {
      cost += GetRankedTensorSize(type);
    } else {
      // For unranked tensors, use the max size among the input tensors. This
      // is because the only dynamic information in the function should come
      // from its inputs, so the sizes of dynamic tensors are usually capped by
      // the input sizes.
      cost += max_arg_size_;
    }
  }

  cost_map_[op] = cost;
}

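// Loads previously measured op costs from the text proto file pointed to by
// the TF_TFRT_MEASURED_COST_PATH environment variable, if it is set.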
Status CostAnalysis::ReadMeasuredCosts() {
  const char* env_var = getenv("TF_TFRT_MEASURED_COST_PATH");
  // No need to read because the cost measurement is disabled.
  if (env_var == nullptr) return Status::OK();

  tensorflow::Env* env = Env::Default();
  const std::string measured_cost_path(env_var);
  TF_RETURN_IF_ERROR(env->FileExists(measured_cost_path));
  return ReadTextProto(env, measured_cost_path, &op_cost_map_proto_);
}

}  // namespace tfrt_compiler
}  // namespace tensorflow