xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_COST_ANALYSIS_H_
17 #define TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_COST_ANALYSIS_H_
18 
19 #include "absl/strings/string_view.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "tensorflow/core/platform/status.h"
23 #include "tensorflow/core/tfrt/fallback/op_cost_map.pb.h"
24 
25 namespace tensorflow {
26 namespace tfrt_compiler {
27 
28 // Analyze costs for tensorflow operations.
29 //
30 // The current heuristic used is quite simple, which is to calculate the total
31 // size of input tensors. The exception is that ops whose cost is irrelevant to
32 // input sizes, such as tf.Shape and tf.Reshape, are whitelisted to have cheap
33 // cost. This cost analysis is expected to be used conservatively (eg. use a low
34 // threshold to decide whether a cost is cheap or expensive), as it might not be
35 // accurate in some cases.
36 //
37 class CostAnalysis {
38  public:
CostAnalysis(mlir::func::FuncOp func_op)39   explicit CostAnalysis(mlir::func::FuncOp func_op) {
40     AnalyzeArguments(func_op);
41     TF_CHECK_OK(ReadMeasuredCosts());
42     AnalyzeBlock(&func_op.front());
43   }
44 
45   int64_t GetCost(mlir::Operation* op, int64_t op_key) const;
46 
47  private:
48   void AnalyzeArguments(mlir::func::FuncOp func_op);
49   void AnalyzeBlock(mlir::Block* block);
50   void EvaluateCost(mlir::Operation* op);
51   Status ReadMeasuredCosts();
52 
53   int64_t max_arg_size_ = 1;
54   llvm::DenseMap<mlir::Operation*, int64_t> cost_map_;
55   tfrt_stub::OpCostMapProto op_cost_map_proto_;
56 };
57 
// Context handed to every cost function (see the CostFunction signature).
struct CostContext {
  // NOTE(review): going by the name, the size to assume for a tensor whose
  // rank is unknown -- confirm against the cost functions that read it.
  //
  // Initialized to 0 so that a default-initialized CostContext never carries
  // an indeterminate value; callers are still expected to set a real value.
  // Aggregate initialization (`CostContext{size}`) keeps working.
  int64_t default_unranked_tensor_size = 0;
};
61 
62 using CostFunction =
63     std::function<int64_t(const CostContext&, mlir::Operation*)>;
64 
65 void RegisterCostFunction(absl::string_view op_name,
66                           CostFunction cost_function);
67 
68 template <typename OpType, typename F>
RegisterCostFunction(F f)69 void RegisterCostFunction(F f) {
70   RegisterCostFunction(
71       OpType::getOperationName().str(),
72       [f = std::move(f)](const CostContext& context, mlir::Operation* op) {
73         return f(context, llvm::cast<OpType>(op));
74       });
75 }
76 
77 template <typename OpType>
78 struct CostFunctionRegistration {
CostFunctionRegistrationCostFunctionRegistration79   explicit CostFunctionRegistration(
80       std::function<int64_t(const CostContext&, OpType)> cost_function) {
81     RegisterCostFunction<OpType>(std::move(cost_function));
82   }
83 };
84 
85 }  // namespace tfrt_compiler
86 }  // namespace tensorflow
87 
88 #endif  // TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_COST_ANALYSIS_H_
89