xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h"
16 
17 #include <memory>
18 #include <string>
19 
20 #include "absl/status/status.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/Transforms/Passes.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
24 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
25 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
27 
28 namespace mlir {
29 namespace TFL {
30 namespace tac {
31 namespace {
32 // TODO(b/177376459): We should make this configureable.
AddExportTFLPass(mlir::OpPassManager * pass_manager,bool enable_inliner)33 void AddExportTFLPass(mlir::OpPassManager* pass_manager, bool enable_inliner) {
34   if (enable_inliner) pass_manager->addPass(mlir::createInlinerPass());
35   pass_manager->addPass(mlir::createSymbolDCEPass());
36   pass_manager->addNestedPass<mlir::func::FuncOp>(
37       mlir::createCanonicalizerPass());
38   pass_manager->addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
39 }
40 }  // namespace
41 
42 // TODO(b/177376459): We should make this configureable.
AddTACPass(mlir::OpPassManager * pass_manager,llvm::ArrayRef<std::string> device_specs)43 void TacModule::AddTACPass(mlir::OpPassManager* pass_manager,
44                            llvm::ArrayRef<std::string> device_specs) {
45   pass_manager->addPass(mlir::TFL::tac::CreateTargetAnnotationPass(this));
46   pass_manager->addPass(mlir::TFL::tac::CreateRaiseTargetSubgraphsPass());
47   pass_manager->addPass(mlir::TFL::tac::CreateFoldConstantsToSubgraphPass(
48       /*fold_all_constants=*/false));
49   pass_manager->addPass(
50       mlir::TFL::tac::CreateAlternativeSubgraphPass(device_specs));
51   if (options_.legalize_to_tflite_ops) {
52     // After we creat the alternative subgraph, we can still do canonicalization
53     // legalization & other optimizations as long as we're not inlining the
54     // function.
55     // And in fact, we probably need to do the proper legalization, for the
56     // compute cost to work. (in case we added some TF ops)
57     pass_manager->addPass(mlir::TFL::CreatePrepareTFPass(
58         /*unfold_batch_matmul=*/true,
59         /*allow_bf16_and_f16_type_legalization=*/false));
60     pass_manager->addNestedPass<mlir::func::FuncOp>(
61         mlir::createCanonicalizerPass());
62     pass_manager->addPass(
63         mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
64     pass_manager->addPass(
65         mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
66   }
67 
68   pass_manager->addPass(mlir::TFL::tac::CreateComputeCostPass());
69   pass_manager->addPass(mlir::TFL::tac::CreatePickSubgraphsPass());
70   // After this pass, we may consider add a pass to merge small functions into
71   // large functions (and maybe other metadata as well).
72 }
73 
GetTargetHardware(const std::string & hardware_name) const74 const tac::TargetHardware* TacModule::GetTargetHardware(
75     const std::string& hardware_name) const {
76   for (auto& hardware : backends_) {
77     if (GetHardwareName(hardware.get()) == hardware_name) return hardware.get();
78   }
79   return nullptr;
80 }
81 
RunTacPasses(mlir::ModuleOp * module,bool debug_mode)82 absl::Status TacModule::RunTacPasses(mlir::ModuleOp* module, bool debug_mode) {
83   mlir::PassManager pm(module->getContext(),
84                        mlir::OpPassManager::Nesting::Implicit);
85   AddTACPass(&pm, options_.hardware_backends);
86   if (!debug_mode) {
87     AddExportTFLPass(&pm, options_.enable_inliner);
88   }
89 
90   mlir::StatusScopedDiagnosticHandler statusHandler(module->getContext(),
91                                                     /*propagate=*/true);
92   if (failed(pm.run(*module))) {
93     return absl::InternalError("conversion error");
94   }
95   return absl::OkStatus();
96 }
97 
98 std::vector<std::unique_ptr<tac::TargetHardware>>
InstantiateBackends()99 TacModule::InstantiateBackends() {
100   std::vector<std::unique_ptr<tac::TargetHardware>> backends;
101   for (const auto& hardware_name : options_.hardware_backends) {
102     auto factory = tac::GetTargetHardwareFactory(hardware_name);
103     backends.emplace_back(factory());
104     backends.back()->Init();
105   }
106   return backends;
107 }
108 
Run()109 absl::Status TacModule::Run() {
110   // Construct all backends.
111   backends_ = InstantiateBackends();
112   const_backends_.resize(backends_.size());
113   for (const auto& backend : backends_)
114     const_backends_.emplace_back(backend.get());
115 
116   if (!importer_) {
117     return absl::Status(absl::StatusCode::kFailedPrecondition,
118                         "Null Importer provided");
119   }
120   if (!exporter_) {
121     return absl::Status(absl::StatusCode::kFailedPrecondition,
122                         "Null Exporter provided");
123   }
124 
125   auto module_status = importer_->Import();
126   if (!module_status.ok()) {
127     return module_status.status();
128   }
129   auto module = module_status->get();
130   auto* context = module->getContext();
131   context->appendDialectRegistry(registry_);
132   context->loadAllAvailableDialects();
133 
134   // Run TAC passes.
135   auto status = RunTacPasses(&module, options_.debug_mode);
136 
137   if (!status.ok()) {
138     return status;
139   }
140 
141   return exporter_->Export(module);
142 }
143 
RegisterExtraDialects(mlir::DialectRegistry & registry)144 void TacModule::RegisterExtraDialects(mlir::DialectRegistry& registry) {
145   registry.appendTo(registry_);
146 }
147 }  // namespace tac
148 }  // namespace TFL
149 }  // namespace mlir
150