1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h"
16
17 #include <memory>
18 #include <string>
19
20 #include "absl/status/status.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/Transforms/Passes.h" // from @llvm-project
23 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
24 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
25 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
27
28 namespace mlir {
29 namespace TFL {
30 namespace tac {
31 namespace {
32 // TODO(b/177376459): We should make this configureable.
AddExportTFLPass(mlir::OpPassManager * pass_manager,bool enable_inliner)33 void AddExportTFLPass(mlir::OpPassManager* pass_manager, bool enable_inliner) {
34 if (enable_inliner) pass_manager->addPass(mlir::createInlinerPass());
35 pass_manager->addPass(mlir::createSymbolDCEPass());
36 pass_manager->addNestedPass<mlir::func::FuncOp>(
37 mlir::createCanonicalizerPass());
38 pass_manager->addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
39 }
40 } // namespace
41
42 // TODO(b/177376459): We should make this configureable.
AddTACPass(mlir::OpPassManager * pass_manager,llvm::ArrayRef<std::string> device_specs)43 void TacModule::AddTACPass(mlir::OpPassManager* pass_manager,
44 llvm::ArrayRef<std::string> device_specs) {
45 pass_manager->addPass(mlir::TFL::tac::CreateTargetAnnotationPass(this));
46 pass_manager->addPass(mlir::TFL::tac::CreateRaiseTargetSubgraphsPass());
47 pass_manager->addPass(mlir::TFL::tac::CreateFoldConstantsToSubgraphPass(
48 /*fold_all_constants=*/false));
49 pass_manager->addPass(
50 mlir::TFL::tac::CreateAlternativeSubgraphPass(device_specs));
51 if (options_.legalize_to_tflite_ops) {
52 // After we creat the alternative subgraph, we can still do canonicalization
53 // legalization & other optimizations as long as we're not inlining the
54 // function.
55 // And in fact, we probably need to do the proper legalization, for the
56 // compute cost to work. (in case we added some TF ops)
57 pass_manager->addPass(mlir::TFL::CreatePrepareTFPass(
58 /*unfold_batch_matmul=*/true,
59 /*allow_bf16_and_f16_type_legalization=*/false));
60 pass_manager->addNestedPass<mlir::func::FuncOp>(
61 mlir::createCanonicalizerPass());
62 pass_manager->addPass(
63 mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
64 pass_manager->addPass(
65 mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
66 }
67
68 pass_manager->addPass(mlir::TFL::tac::CreateComputeCostPass());
69 pass_manager->addPass(mlir::TFL::tac::CreatePickSubgraphsPass());
70 // After this pass, we may consider add a pass to merge small functions into
71 // large functions (and maybe other metadata as well).
72 }
73
GetTargetHardware(const std::string & hardware_name) const74 const tac::TargetHardware* TacModule::GetTargetHardware(
75 const std::string& hardware_name) const {
76 for (auto& hardware : backends_) {
77 if (GetHardwareName(hardware.get()) == hardware_name) return hardware.get();
78 }
79 return nullptr;
80 }
81
RunTacPasses(mlir::ModuleOp * module,bool debug_mode)82 absl::Status TacModule::RunTacPasses(mlir::ModuleOp* module, bool debug_mode) {
83 mlir::PassManager pm(module->getContext(),
84 mlir::OpPassManager::Nesting::Implicit);
85 AddTACPass(&pm, options_.hardware_backends);
86 if (!debug_mode) {
87 AddExportTFLPass(&pm, options_.enable_inliner);
88 }
89
90 mlir::StatusScopedDiagnosticHandler statusHandler(module->getContext(),
91 /*propagate=*/true);
92 if (failed(pm.run(*module))) {
93 return absl::InternalError("conversion error");
94 }
95 return absl::OkStatus();
96 }
97
98 std::vector<std::unique_ptr<tac::TargetHardware>>
InstantiateBackends()99 TacModule::InstantiateBackends() {
100 std::vector<std::unique_ptr<tac::TargetHardware>> backends;
101 for (const auto& hardware_name : options_.hardware_backends) {
102 auto factory = tac::GetTargetHardwareFactory(hardware_name);
103 backends.emplace_back(factory());
104 backends.back()->Init();
105 }
106 return backends;
107 }
108
Run()109 absl::Status TacModule::Run() {
110 // Construct all backends.
111 backends_ = InstantiateBackends();
112 const_backends_.resize(backends_.size());
113 for (const auto& backend : backends_)
114 const_backends_.emplace_back(backend.get());
115
116 if (!importer_) {
117 return absl::Status(absl::StatusCode::kFailedPrecondition,
118 "Null Importer provided");
119 }
120 if (!exporter_) {
121 return absl::Status(absl::StatusCode::kFailedPrecondition,
122 "Null Exporter provided");
123 }
124
125 auto module_status = importer_->Import();
126 if (!module_status.ok()) {
127 return module_status.status();
128 }
129 auto module = module_status->get();
130 auto* context = module->getContext();
131 context->appendDialectRegistry(registry_);
132 context->loadAllAvailableDialects();
133
134 // Run TAC passes.
135 auto status = RunTacPasses(&module, options_.debug_mode);
136
137 if (!status.ok()) {
138 return status;
139 }
140
141 return exporter_->Export(module);
142 }
143
RegisterExtraDialects(mlir::DialectRegistry & registry)144 void TacModule::RegisterExtraDialects(mlir::DialectRegistry& registry) {
145 registry.appendTo(registry_);
146 }
147 } // namespace tac
148 } // namespace TFL
149 } // namespace mlir
150