xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/experimental/tac/transforms/compute_cost.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/DenseSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
26 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
27 #include "mlir/IR/Operation.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
30 #include "mlir/Support/LLVM.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/cost.h"
32 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
33 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
34 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
35 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.h"
36 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
37 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
38 
39 namespace mlir {
40 namespace TFL {
41 namespace tac {
42 namespace {
43 
// We calculate the total compute cost for each FuncOp.
//
// The compute cost is simply the sum of the costs of all the operations
// within the FuncOp (excluding const ops, since they are just "data").
// We also ignore quant/dequant/requant costs within the FuncOp. The
// intuition:
//
// The assumption is that quant/dequant/requant ops only appear at the
// beginning and the end of the FuncOp (basically the "boundaries" of the
// subgraph). So if multiple graphs with the same inference type are placed
// back to back, the quant/dequant pairs between them can be squashed:
//
//         dequant         ------------
//            |
//          ops...             FuncOp1
//            |
//         quant           -------------
//           |         <--- can be squashed
//         dequant         -------------
//            |
//        ops...               FuncOp2
//           |
//         quant          ---------------
//
// However, quant & dequant ops can also appear "within" the FuncOp, normally
// to adjust quantization params. We should check more carefully and include
// those ops, since they would not be "squashed".
71 
72 class ComputeCostPass
73     : public mlir::PassWrapper<ComputeCostPass, mlir::OperationPass<ModuleOp>> {
74  public:
75   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ComputeCostPass)
76 
77  private:
getArgument() const78   llvm::StringRef getArgument() const final { return "tfl-compute-cost"; }
getDescription() const79   llvm::StringRef getDescription() const final {
80     return "Compute the total cost for each available subgraph.";
81   }
82   void runOnOperation() override;
83 };
84 
runOnOperation()85 void ComputeCostPass::runOnOperation() {
86   auto module = getOperation();
87 
88   for (auto func : module.getOps<func::FuncOp>()) {
89     // We only care about those functions annotated with "tac.interface_name".
90     auto interface_name = GetInterFaceName(func);
91     if (!interface_name.has_value()) continue;
92 
93     auto target = GetTargetAnnotation(func);
94     if (!target.has_value()) {
95       func.emitError("we cannot get hardware info for this function.");
96       signalPassFailure();
97     }
98 
99     float total_cost = GetCostForFunc(&func, target.getValue());
100     OpBuilder builder(func);
101     UpdateCost(func, total_cost, &builder);
102   }
103 }
104 
105 }  // namespace
106 
CreateComputeCostPass()107 std::unique_ptr<OperationPass<ModuleOp>> CreateComputeCostPass() {
108   return std::make_unique<ComputeCostPass>();
109 }
110 
111 static PassRegistration<ComputeCostPass> pass;
112 
113 }  // namespace tac
114 }  // namespace TFL
115 }  // namespace mlir
116