1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_ 16 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_ 17 18 #include <string> 19 #include <utility> 20 #include <vector> 21 22 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project 23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 24 #include "mlir/Pass/PassManager.h" // from @llvm-project 25 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h" 26 #include "tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h" 27 28 namespace mlir { 29 namespace TFL { 30 namespace tac { 31 32 // Main class for using Target Aware Conversion (TAC). 33 // To run TAC: 34 // 1) users should create object form this class, with desired options 35 // (TacModule::Options). 36 // 2) Use SetImporter/SetExporter to the desired importer 37 // and exporter. 38 // 3) Call Run() 39 // 40 // The module fetches all TargetHardware backends registered in the binary 41 // and only create TargetHardware requested in Options. 42 // 43 // This class is not thread safe. 44 class TacModule { 45 public: 46 // TAC options. Contains knobs to configure TAC as needed. 47 struct Options { 48 // List of names for the requested Target hardware. 49 std::vector<std::string> hardware_backends; 50 // Debug mode. 51 // This will output different alternative subgraphs in mlir format for debug 52 // purpose. 53 bool debug_mode = false; 54 // Whether to enable inliner passes or not. 55 bool enable_inliner = false; 56 // Whether to legalize ops to TFLite ops before exporting. 57 bool legalize_to_tflite_ops = false; 58 }; 59 ~TacModule()60 virtual ~TacModule() {} 61 TacModule(const Options & options)62 explicit TacModule(const Options& options) : options_(options) {} 63 SetImporter(std::unique_ptr<TacImporter> importer)64 void SetImporter(std::unique_ptr<TacImporter> importer) { 65 importer_ = std::move(importer); 66 } 67 SetExporter(std::unique_ptr<TacExporter> exporter)68 void SetExporter(std::unique_ptr<TacExporter> exporter) { 69 exporter_ = std::move(exporter); 70 } 71 72 // Returns pointer to the TargetHardware that is identified by 'hardware_name' 73 // Returns NULL If no hardware with this name found. 74 const tac::TargetHardware* GetTargetHardware( 75 const std::string& hardware_name) const; 76 77 // Runs the TAC workflow, configured as in the options provided during 78 // construction. 79 // SetImporter/SetExporter should be called prior to invoking `Run`. 80 // Returns Status of the Run. 81 virtual absl::Status Run(); 82 83 // Returns all available hardware backends registered in this module 84 // instance. GetAvailableHardwares()85 const std::vector<const tac::TargetHardware*>& GetAvailableHardwares() const { 86 return const_backends_; 87 } 88 89 // Registers all dialects in 'registry' with the module. 90 // This to allow clients to register extra dialects required. 91 void RegisterExtraDialects(mlir::DialectRegistry& registry); 92 93 protected: 94 // Adds TAC passes to the 'pass_manager'. 95 virtual void AddTACPass(mlir::OpPassManager* pass_manager, 96 llvm::ArrayRef<std::string> device_specs); 97 98 private: 99 // Runs all TAC passes on the provided module. 100 absl::Status RunTacPasses(mlir::ModuleOp* module, bool debug_mode = false); 101 102 // Create instances of all registered hardwares. 103 std::vector<std::unique_ptr<tac::TargetHardware>> InstantiateBackends(); 104 105 std::unique_ptr<TacImporter> importer_; 106 std::unique_ptr<TacExporter> exporter_; 107 // Owned list of all target hardware backends. 108 std::vector<std::unique_ptr<tac::TargetHardware>> backends_; 109 // Holder for const pointers for the data in 'backends_' 110 std::vector<const tac::TargetHardware*> const_backends_; 111 // Extra dialects requested by the user. 112 mlir::DialectRegistry registry_; 113 114 const Options options_; 115 }; 116 117 } // namespace tac 118 } // namespace TFL 119 } // namespace mlir 120 121 #endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_ 122