xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_
16 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_
17 
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h"
27 
28 namespace mlir {
29 namespace TFL {
30 namespace tac {
31 
32 // Main class for using Target Aware Conversion (TAC).
33 // To run TAC:
34 // 1) users should create object form this class, with desired options
35 //    (TacModule::Options).
36 // 2) Use SetImporter/SetExporter to the desired importer
37 //    and exporter.
38 // 3) Call Run()
39 //
40 // The module fetches all TargetHardware backends registered in the binary
41 // and only create TargetHardware requested in Options.
42 //
43 // This class is not thread safe.
44 class TacModule {
45  public:
46   // TAC options. Contains knobs to configure TAC as needed.
47   struct Options {
48     // List of names for the requested Target hardware.
49     std::vector<std::string> hardware_backends;
50     // Debug mode.
51     // This will output different alternative subgraphs in mlir format for debug
52     // purpose.
53     bool debug_mode = false;
54     // Whether to enable inliner passes or not.
55     bool enable_inliner = false;
56     // Whether to legalize ops to TFLite ops before exporting.
57     bool legalize_to_tflite_ops = false;
58   };
59 
~TacModule()60   virtual ~TacModule() {}
61 
TacModule(const Options & options)62   explicit TacModule(const Options& options) : options_(options) {}
63 
SetImporter(std::unique_ptr<TacImporter> importer)64   void SetImporter(std::unique_ptr<TacImporter> importer) {
65     importer_ = std::move(importer);
66   }
67 
SetExporter(std::unique_ptr<TacExporter> exporter)68   void SetExporter(std::unique_ptr<TacExporter> exporter) {
69     exporter_ = std::move(exporter);
70   }
71 
72   // Returns pointer to the TargetHardware that is identified by 'hardware_name'
73   // Returns NULL If no hardware with this name found.
74   const tac::TargetHardware* GetTargetHardware(
75       const std::string& hardware_name) const;
76 
77   // Runs the TAC workflow, configured as in the options provided during
78   // construction.
79   // SetImporter/SetExporter should be called prior to invoking `Run`.
80   // Returns Status of the Run.
81   virtual absl::Status Run();
82 
83   // Returns all available hardware backends registered in this module
84   // instance.
GetAvailableHardwares()85   const std::vector<const tac::TargetHardware*>& GetAvailableHardwares() const {
86     return const_backends_;
87   }
88 
89   // Registers all dialects in 'registry' with the module.
90   // This to allow clients to register extra dialects required.
91   void RegisterExtraDialects(mlir::DialectRegistry& registry);
92 
93  protected:
94   // Adds TAC passes to the 'pass_manager'.
95   virtual void AddTACPass(mlir::OpPassManager* pass_manager,
96                           llvm::ArrayRef<std::string> device_specs);
97 
98  private:
99   // Runs all TAC passes on the provided module.
100   absl::Status RunTacPasses(mlir::ModuleOp* module, bool debug_mode = false);
101 
102   // Create instances of all registered hardwares.
103   std::vector<std::unique_ptr<tac::TargetHardware>> InstantiateBackends();
104 
105   std::unique_ptr<TacImporter> importer_;
106   std::unique_ptr<TacExporter> exporter_;
107   // Owned list of all target hardware backends.
108   std::vector<std::unique_ptr<tac::TargetHardware>> backends_;
109   // Holder for const pointers for the data in 'backends_'
110   std::vector<const tac::TargetHardware*> const_backends_;
111   // Extra dialects requested by the user.
112   mlir::DialectRegistry registry_;
113 
114   const Options options_;
115 };
116 
117 }  // namespace tac
118 }  // namespace TFL
119 }  // namespace mlir
120 
121 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_
122