1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h"
16
17 #include <string>
18
19 #include "absl/container/flat_hash_map.h"
20 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
21 #include "tensorflow/core/tfrt/fallback/cost_recorder.h"
22
23 namespace tensorflow {
24 namespace tfrt_compiler {
25 namespace {
26
// The smallest cost unit used by this analysis. It is the flat cost of ops
// treated as cheap, the floor applied to each dimension size when counting
// tensor elements, and the starting value of the summed-operand cost.
constexpr int64_t kDefaultCheapCost{1};
28
GetRankedTensorSize(mlir::TensorType type)29 int64_t GetRankedTensorSize(mlir::TensorType type) {
30 auto shape = type.getShape();
31
32 int64_t size = 1;
33 for (int64_t dim : shape) {
34 // For unknown dimensions, use 1 as the size because it is usually the batch
35 // dimension.
36 //
37 // TODO(chky): Find out a better default number for this case.
38 size *= std::max(kDefaultCheapCost, dim);
39 }
40
41 return size;
42 }
43
InferTensorSize(const CostContext & context,mlir::TensorType type)44 int64_t InferTensorSize(const CostContext& context, mlir::TensorType type) {
45 if (type.hasRank()) return GetRankedTensorSize(type);
46 return context.default_unranked_tensor_size;
47 }
48
49 // The cost function for tf.LookupTableFindV2.
InferLookupTableFindV2Cost(const CostContext & context,mlir::TF::LookupTableFindV2Op op)50 int64_t InferLookupTableFindV2Cost(const CostContext& context,
51 mlir::TF::LookupTableFindV2Op op) {
52 // tf.LookupTableFindV2 ops are usually more costly than tf.AddV2 with the
53 // same input size, as it involves more operations like hashing, map lookup,
54 // etc.
55 constexpr int64_t kLookupTableFindCostScale = 8;
56 constexpr int64_t kLookupTableFindStringKeyCostScale = 16;
57
58 auto value_type = op.values().getType().cast<mlir::TensorType>();
59 auto key_type = op.keys().getType().cast<mlir::TensorType>();
60
61 int64_t output_size = InferTensorSize(context, value_type);
62
63 int64_t cost = kLookupTableFindCostScale * output_size;
64
65 if (key_type.getElementType().isa<mlir::TF::StringType>())
66 cost *= kLookupTableFindStringKeyCostScale;
67
68 return cost;
69 }
70
71 // The cost function for tf.GatherV2.
InferGatherV2Cost(const CostContext & context,mlir::TF::GatherV2Op op)72 int64_t InferGatherV2Cost(const CostContext& context, mlir::TF::GatherV2Op op) {
73 return InferTensorSize(context,
74 op.output().getType().cast<mlir::TensorType>());
75 }
76
77 // The cost function for tf.SparseSegmentSumOp.
78 template <typename OpType>
InferSparseSegmentOpCost(const CostContext & context,OpType op)79 int64_t InferSparseSegmentOpCost(const CostContext& context, OpType op) {
80 return InferTensorSize(
81 context, op.output().getType().template cast<mlir::TensorType>());
82 }
83
84 // CostFunctionRegistry is a map from op names to their cost functions.
85 using CostFunctionRegistry = absl::flat_hash_map<std::string, CostFunction>;
86
RegisterCostFunction(CostFunctionRegistry & registry,absl::string_view op_name,CostFunction cost_function)87 void RegisterCostFunction(CostFunctionRegistry& registry,
88 absl::string_view op_name,
89 CostFunction cost_function) {
90 auto r = registry.try_emplace(op_name, std::move(cost_function));
91 assert(r.second);
92 (void)r;
93 }
94
95 template <typename OpType, typename F>
RegisterCostFunction(CostFunctionRegistry & registry,F f)96 void RegisterCostFunction(CostFunctionRegistry& registry, F f) {
97 RegisterCostFunction(
98 registry, OpType::getOperationName().str(),
99 [f = std::move(f)](const CostContext& context, mlir::Operation* op) {
100 return f(context, llvm::cast<OpType>(op));
101 });
102 }
103
GetCostFunctionRegistry()104 CostFunctionRegistry& GetCostFunctionRegistry() {
105 static auto* const registry = []() {
106 auto* registry = new CostFunctionRegistry;
107 // TODO(chky): Find a more scalable way to register cost functions. One
108 // option is to incorporate it is TF MLIR ODS.
109 RegisterCostFunction<mlir::TF::GatherV2Op>(*registry, InferGatherV2Cost);
110 RegisterCostFunction<mlir::TF::SparseSegmentSumOp>(
111 *registry, InferSparseSegmentOpCost<mlir::TF::SparseSegmentSumOp>);
112 RegisterCostFunction<mlir::TF::SparseSegmentMeanOp>(
113 *registry, InferSparseSegmentOpCost<mlir::TF::SparseSegmentMeanOp>);
114 RegisterCostFunction<mlir::TF::SparseSegmentSqrtNOp>(
115 *registry, InferSparseSegmentOpCost<mlir::TF::SparseSegmentSqrtNOp>);
116 RegisterCostFunction<mlir::TF::LookupTableFindV2Op>(
117 *registry, InferLookupTableFindV2Cost);
118 return registry;
119 }();
120 return *registry;
121 }
122
123 } // namespace
124
RegisterCostFunction(absl::string_view op_name,CostFunction cost_function)125 void RegisterCostFunction(absl::string_view op_name,
126 CostFunction cost_function) {
127 RegisterCostFunction(GetCostFunctionRegistry(), op_name,
128 std::move(cost_function));
129 }
130
GetCost(mlir::Operation * op,int64_t op_key) const131 int64_t CostAnalysis::GetCost(mlir::Operation* op, int64_t op_key) const {
132 // Try to use its measured cost.
133 const auto& measured_cost_map = op_cost_map_proto_.op_cost_map();
134 if (const auto op_cost = measured_cost_map.find(op_key);
135 op_cost != measured_cost_map.end()) {
136 return op_cost->second;
137 }
138
139 assert(cost_map_.count(op) > 0);
140 return cost_map_.lookup(op);
141 }
142
AnalyzeArguments(mlir::func::FuncOp func_op)143 void CostAnalysis::AnalyzeArguments(mlir::func::FuncOp func_op) {
144 // Use the max size among function inputs as the default size of dynamic
145 // shaped tensors in the function.
146 for (auto arg : func_op.getArguments()) {
147 auto type = arg.getType().cast<mlir::TensorType>();
148 if (type.hasRank()) {
149 max_arg_size_ = std::max(max_arg_size_, GetRankedTensorSize(type));
150 }
151 }
152 }
153
AnalyzeBlock(mlir::Block * block)154 void CostAnalysis::AnalyzeBlock(mlir::Block* block) {
155 for (auto& op : *block) {
156 EvaluateCost(&op);
157 }
158 }
159
EvaluateCost(mlir::Operation * op)160 void CostAnalysis::EvaluateCost(mlir::Operation* op) {
161 if (!llvm::isa<mlir::TF::TensorFlowDialect>(op->getDialect())) {
162 cost_map_[op] = max_arg_size_;
163 return;
164 }
165
166 // These ops are cheap regardless of their input sizes.
167 //
168 // TODO(chky): Find a more scalable way to figure out cheap ops.
169 if (llvm::isa<mlir::TF::ShapeOp, mlir::TF::StridedSliceOp,
170 mlir::TF::ReshapeOp, mlir::TF::ExpandDimsOp>(op)) {
171 cost_map_[op] = kDefaultCheapCost;
172 return;
173 }
174
175 // Try to use its cost function if it is registered.
176 const auto& registry = GetCostFunctionRegistry();
177 absl::string_view op_name = op->getName().getStringRef();
178 auto iter = registry.find(op_name);
179 if (iter != registry.end()) {
180 CostContext context;
181 context.default_unranked_tensor_size = max_arg_size_;
182 cost_map_[op] = iter->second(context, op);
183 return;
184 }
185
186 // For other ops, use the sum of input sizes as its cost.
187 int64_t cost = kDefaultCheapCost;
188 for (auto operand : op->getOperands()) {
189 auto type = operand.getType().cast<mlir::TensorType>();
190 if (type.hasRank()) {
191 cost += GetRankedTensorSize(type);
192 } else {
193 // For unranked tensors, use the max size among the input tensors. This is
194 // because the only dynamic information of the function should be the
195 // input, so the size of dynamic tensors should be usually capped by
196 // inputs' sizes.
197 cost += max_arg_size_;
198 }
199 }
200
201 cost_map_[op] = cost;
202 }
203
ReadMeasuredCosts()204 Status CostAnalysis::ReadMeasuredCosts() {
205 const char* env_var = getenv("TF_TFRT_MEASURED_COST_PATH");
206 // No need to read because the cost measurement is disabled.
207 if (env_var == nullptr) return Status::OK();
208
209 tensorflow::Env* env = Env::Default();
210 const std::string measured_cost_path(env_var);
211 TF_RETURN_IF_ERROR(env->FileExists(measured_cost_path));
212 return ReadTextProto(env, measured_cost_path, &op_cost_map_proto_);
213 }
214
215 } // namespace tfrt_compiler
216 } // namespace tensorflow
217