1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
16 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
17
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/TypeID.h"  // from @llvm-project
27
28 namespace mlir {
29 namespace TFL {
30 namespace tac {
31
// Default fixed cost value for ops.
// NOTE(review): presumably used when no hardware-specific cost is available —
// confirm against the cost-model implementation.
// `inline constexpr` (C++17) gives a single definition shared by every
// translation unit including this header, instead of one internal-linkage copy
// per TU as `static` did.
inline constexpr float kDefaultFixedValuedCost = 1000000.0f;

// Cost per byte of moving data across hardwares. This is just fake data.
inline constexpr float kCrossHardwareTransferPerByteCost = 5.0f;

// Fixed cost of a cross-hardware transfer. This is just fake data.
inline constexpr float kCrossHardwareTransferFixedCost = 10.0f;
40
41 // Interface for an Operation capabilities which should be tied to
42 // a specific hardware.
43 // Users should implement the interface and use TargetHardwareOpRegistration
44 // for registering the operation.
45 class TargetHardwareOperation {
46 public:
~TargetHardwareOperation()47 virtual ~TargetHardwareOperation() {}
48
49 virtual double GetOpCost(mlir::Operation* op) const = 0;
50
51 virtual bool IsOpSupported(mlir::Operation* op) const = 0;
52 };
53
54 // Abstract base class for a hardware.
55 // To introduce new hardware
56 // users should implement the interface and use TargetHardwareRegistration
57 // for registering the hardware.
58 // Subclasses must implement the pure virtual function interface and
59 // define static member variable that retrieves string identifying the Target
60 // Hardware. Example,
61 // class MyType : public TargetHardware {
62 // public:
63 // static constexpr char kId[] = "MyHardware";
64 // };
65 class TargetHardware {
66 public:
~TargetHardware()67 virtual ~TargetHardware() {}
68
69 // Initializes all TargetHardwareOperation registered for this hardware.
70 // Users overriding this function, should call the base class method to
71 // initialize the ops.
72 virtual bool Init();
73
74 // Returns the cost of running 'op' on this Hardware.
75 virtual double GetOpCost(mlir::Operation* op) const;
76
77 // Returns the cost of running the whole function on this hardware.
78 // By default this is the sum of the cost of individual cost for each op.
79 virtual double GetFuncCost(func::FuncOp* func) const;
80
81 // Returns true if 'op' can run on this Hardware.
82 virtual bool IsOpSupported(mlir::Operation* op) const;
83
84 // Switching cost between from hardware and this hardware.
85 // If both the hardwares are the same, the transfer cost is basically 0.
86 virtual double GetHardwareSwitchingCost(const TargetHardware* from,
87 size_t buffer_size) const = 0;
88
89 // Returns a list of all patterns to apply for this hardware.
90 virtual mlir::RewritePatternSet GetTransformations(
91 MLIRContext* context) const = 0;
92
93 // Returns TypeId for the provided hardware.
94 // Usually should be something like mlir::TypeID::get<MyType>()
95 virtual mlir::TypeID GetTypeId() const = 0;
96
97 protected:
98 // All registered hardware ops.
99 std::vector<std::unique_ptr<TargetHardwareOperation>> hardware_ops_;
100 };
101
102 // Returns pointer to the Hardware identified by 'hardware_name'.
103 // If not found nullptr is returned.
104 // DEPRECATED: Do not use, prefer GetTargetHardwareFactory instead.
105 const TargetHardware* GetTargetHardware(const std::string& hardware_name);
106
107 // Returns the factory method for the requested hardware if present.
108 std::function<std::unique_ptr<TargetHardware>()> GetTargetHardwareFactory(
109 const std::string& hardware_name);
110
111 namespace internal {
112 // DEPRECATED: Do not use, prefer using RegisterTargetHardwareFactory instead.
113 void RegisterTargetHardware(
114 const std::string& unique_name, const std::string& description,
115 mlir::TypeID type_id,
116 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory);
117
118 // DEPRECATED: Do not use, prefer using RegisterTargetHardwareFactory instead.
119 template <typename T>
RegisterTargetHardware(const std::string & description,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)120 void RegisterTargetHardware(
121 const std::string& description,
122 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
123 RegisterTargetHardware(T::kId, description, mlir::TypeID::get<T>(),
124 target_hardware_factory);
125 }
126
127 void RegisterTargetHardwareFactory(
128 const std::string& unique_name, const std::string& description,
129 mlir::TypeID type_id,
130 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory);
131
132 // Registers the provided target hardware factory.
133 template <typename T>
RegisterTargetHardwareFactory(const std::string & description,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)134 void RegisterTargetHardwareFactory(
135 const std::string& description,
136 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
137 RegisterTargetHardwareFactory(T::kId, description, mlir::TypeID::get<T>(),
138 target_hardware_factory);
139 }
140
141 // DEPRECATED: Do not use, prefer RegisterTargetHardwareOpFactory intstead.
142 void RegisterTargetHardwareOp(
143 mlir::TypeID hardware_type, mlir::TypeID op_type,
144 std::function<std::unique_ptr<TargetHardwareOperation>()>
145 target_hardware_op_factory);
146
147 void RegisterTargetHardwareOpFactory(
148 mlir::TypeID hardware_type, mlir::TypeID op_type,
149 std::function<std::unique_ptr<TargetHardwareOperation>()>
150 target_hardware_op_factory);
151 } // namespace internal
152
153 // Register target hardware.
154 template <typename Hardware>
155 struct TargetHardwareRegistration {
TargetHardwareRegistrationTargetHardwareRegistration156 TargetHardwareRegistration(const std::string& description,
157 std::function<std::unique_ptr<TargetHardware>()>
158 target_hardware_factory) {
159 // TODO(b/177376459): remove this.
160 internal::RegisterTargetHardware<Hardware>(description,
161 target_hardware_factory);
162 internal::RegisterTargetHardwareFactory<Hardware>(description,
163 target_hardware_factory);
164 }
165 };
166
167 // Register Op capabilities for specific hardware.
168 template <typename Hardware, typename Op>
169 struct TargetHardwareOpRegistration {
TargetHardwareOpRegistrationTargetHardwareOpRegistration170 explicit TargetHardwareOpRegistration(
171 std::function<std::unique_ptr<TargetHardwareOperation>()>
172 target_hardware_op_factory) {
173 // TODO(b/177376459): remove this.
174 internal::RegisterTargetHardwareOp(mlir::TypeID::get<Hardware>(),
175 mlir::TypeID::get<Op>(),
176 target_hardware_op_factory);
177 internal::RegisterTargetHardwareOpFactory(mlir::TypeID::get<Hardware>(),
178 mlir::TypeID::get<Op>(),
179 target_hardware_op_factory);
180 }
181 };
182
183 //======== util functions ==========
184
185 // Process user specified device specs, will always add CPU if it's not there.
186 // specified_deivce_specs: ',' separated, like "GPU,DSP,CPU".
187 // device_specs: processed device specs enum.
188 bool ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,
189 std::vector<std::string>* device_specs);
190
191 // Check whether two hardwares are the same.
IsSameHardware(const TargetHardware * lhs,const TargetHardware * rhs)192 inline bool IsSameHardware(const TargetHardware* lhs,
193 const TargetHardware* rhs) {
194 return lhs->GetTypeId() == rhs->GetTypeId();
195 }
196
197 // Returns the ID identifying 'hardware'. This should match the ID defined
198 // in the hardware field ID.
199 // For example, if MyHardware is passed the value returned should match
200 // MyHardware::kId.
201 std::string GetHardwareName(const TargetHardware* hardware);
202
203 } // namespace tac
204 } // namespace TFL
205 } // namespace mlir
206
207 #endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
208