xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
16 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
24 #include "mlir/IR/Operation.h"  // from @llvm-project
25 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
26 #include "mlir/Support/TypeID.h"  // from @llvm-project
27 
28 namespace mlir {
29 namespace TFL {
30 namespace tac {
31 
// Default fixed cost value for ops.
static constexpr float kDefaultFixedValuedCost = 1e6f;

// Per-byte cost of a cross-hardware transfer (placeholder value, not measured).
static constexpr float kCrossHardwareTransferPerByteCost = 5.0f;

// Fixed cost of a cross-hardware transfer (placeholder value, not measured).
static constexpr float kCrossHardwareTransferFixedCost = 10.0f;
40 
41 // Interface for an Operation capabilities which should be tied to
42 // a specific hardware.
43 // Users should implement the interface and use TargetHardwareOpRegistration
44 // for registering the operation.
45 class TargetHardwareOperation {
46  public:
~TargetHardwareOperation()47   virtual ~TargetHardwareOperation() {}
48 
49   virtual double GetOpCost(mlir::Operation* op) const = 0;
50 
51   virtual bool IsOpSupported(mlir::Operation* op) const = 0;
52 };
53 
54 // Abstract base class for a hardware.
55 // To introduce new hardware
56 // users should implement the interface and use TargetHardwareRegistration
57 // for registering the hardware.
58 // Subclasses must implement the pure virtual function interface and
59 // define static member variable that retrieves string identifying the Target
60 // Hardware. Example,
61 // class MyType : public TargetHardware {
62 //  public:
63 //   static constexpr char kId[] = "MyHardware";
64 // };
65 class TargetHardware {
66  public:
~TargetHardware()67   virtual ~TargetHardware() {}
68 
69   // Initializes all TargetHardwareOperation registered for this hardware.
70   // Users overriding this function, should call the base class method to
71   // initialize the ops.
72   virtual bool Init();
73 
74   // Returns the cost of running 'op' on this Hardware.
75   virtual double GetOpCost(mlir::Operation* op) const;
76 
77   // Returns the cost of running the whole function on this hardware.
78   // By default this is the sum of the cost of individual cost for each op.
79   virtual double GetFuncCost(func::FuncOp* func) const;
80 
81   // Returns true if 'op' can run on this Hardware.
82   virtual bool IsOpSupported(mlir::Operation* op) const;
83 
84   // Switching cost between from hardware and this hardware.
85   // If both the hardwares are the same, the transfer cost is basically 0.
86   virtual double GetHardwareSwitchingCost(const TargetHardware* from,
87                                           size_t buffer_size) const = 0;
88 
89   // Returns a list of all patterns to apply for this hardware.
90   virtual mlir::RewritePatternSet GetTransformations(
91       MLIRContext* context) const = 0;
92 
93   // Returns TypeId for the provided hardware.
94   // Usually should be something like mlir::TypeID::get<MyType>()
95   virtual mlir::TypeID GetTypeId() const = 0;
96 
97  protected:
98   // All registered hardware ops.
99   std::vector<std::unique_ptr<TargetHardwareOperation>> hardware_ops_;
100 };
101 
102 // Returns pointer to the Hardware identified by 'hardware_name'.
103 // If not found nullptr is returned.
104 // DEPRECATED: Do not use, prefer GetTargetHardwareFactory instead.
105 const TargetHardware* GetTargetHardware(const std::string& hardware_name);
106 
107 // Returns the factory method for the requested hardware if present.
108 std::function<std::unique_ptr<TargetHardware>()> GetTargetHardwareFactory(
109     const std::string& hardware_name);
110 
111 namespace internal {
112 // DEPRECATED: Do not use, prefer using RegisterTargetHardwareFactory instead.
113 void RegisterTargetHardware(
114     const std::string& unique_name, const std::string& description,
115     mlir::TypeID type_id,
116     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory);
117 
118 // DEPRECATED: Do not use, prefer using RegisterTargetHardwareFactory instead.
119 template <typename T>
RegisterTargetHardware(const std::string & description,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)120 void RegisterTargetHardware(
121     const std::string& description,
122     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
123   RegisterTargetHardware(T::kId, description, mlir::TypeID::get<T>(),
124                          target_hardware_factory);
125 }
126 
127 void RegisterTargetHardwareFactory(
128     const std::string& unique_name, const std::string& description,
129     mlir::TypeID type_id,
130     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory);
131 
132 // Registers the provided target hardware factory.
133 template <typename T>
RegisterTargetHardwareFactory(const std::string & description,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)134 void RegisterTargetHardwareFactory(
135     const std::string& description,
136     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
137   RegisterTargetHardwareFactory(T::kId, description, mlir::TypeID::get<T>(),
138                                 target_hardware_factory);
139 }
140 
141 // DEPRECATED: Do not use, prefer RegisterTargetHardwareOpFactory intstead.
142 void RegisterTargetHardwareOp(
143     mlir::TypeID hardware_type, mlir::TypeID op_type,
144     std::function<std::unique_ptr<TargetHardwareOperation>()>
145         target_hardware_op_factory);
146 
147 void RegisterTargetHardwareOpFactory(
148     mlir::TypeID hardware_type, mlir::TypeID op_type,
149     std::function<std::unique_ptr<TargetHardwareOperation>()>
150         target_hardware_op_factory);
151 }  // namespace internal
152 
153 // Register target hardware.
154 template <typename Hardware>
155 struct TargetHardwareRegistration {
TargetHardwareRegistrationTargetHardwareRegistration156   TargetHardwareRegistration(const std::string& description,
157                              std::function<std::unique_ptr<TargetHardware>()>
158                                  target_hardware_factory) {
159     // TODO(b/177376459): remove this.
160     internal::RegisterTargetHardware<Hardware>(description,
161                                                target_hardware_factory);
162     internal::RegisterTargetHardwareFactory<Hardware>(description,
163                                                       target_hardware_factory);
164   }
165 };
166 
167 // Register Op capabilities for specific hardware.
168 template <typename Hardware, typename Op>
169 struct TargetHardwareOpRegistration {
TargetHardwareOpRegistrationTargetHardwareOpRegistration170   explicit TargetHardwareOpRegistration(
171       std::function<std::unique_ptr<TargetHardwareOperation>()>
172           target_hardware_op_factory) {
173     // TODO(b/177376459): remove this.
174     internal::RegisterTargetHardwareOp(mlir::TypeID::get<Hardware>(),
175                                        mlir::TypeID::get<Op>(),
176                                        target_hardware_op_factory);
177     internal::RegisterTargetHardwareOpFactory(mlir::TypeID::get<Hardware>(),
178                                               mlir::TypeID::get<Op>(),
179                                               target_hardware_op_factory);
180   }
181 };
182 
183 //======== util functions ==========
184 
185 // Process user specified device specs, will always add CPU if it's not there.
186 // specified_deivce_specs: ',' separated, like "GPU,DSP,CPU".
187 // device_specs: processed device specs enum.
188 bool ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,
189                           std::vector<std::string>* device_specs);
190 
191 // Check whether two hardwares are the same.
IsSameHardware(const TargetHardware * lhs,const TargetHardware * rhs)192 inline bool IsSameHardware(const TargetHardware* lhs,
193                            const TargetHardware* rhs) {
194   return lhs->GetTypeId() == rhs->GetTypeId();
195 }
196 
197 // Returns the ID identifying 'hardware'. This should match the ID defined
198 // in the hardware field ID.
199 // For example, if MyHardware is passed the value returned should match
200 // MyHardware::kId.
201 std::string GetHardwareName(const TargetHardware* hardware);
202 
203 }  // namespace tac
204 }  // namespace TFL
205 }  // namespace mlir
206 
207 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
208