1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
16
17 #include <algorithm>
18 #include <cctype>
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
27 #include "mlir/Support/TypeID.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
29 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
30 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
31
32 namespace mlir {
33 namespace TFL {
34 namespace tac {
35 namespace {
36 struct RegisteredTargetHardware {
37 // TODO(b/177376459): Remove this constructor.
RegisteredTargetHardwaremlir::TFL::tac::__anonad033e510111::RegisteredTargetHardware38 RegisteredTargetHardware(const std::string& name,
39 const std::string& description, mlir::TypeID type_id,
40 std::unique_ptr<TargetHardware> target_hardware)
41 : unique_name(GetCanonicalHardwareName(name)),
42 description(description),
43 type_id(type_id),
44 target_hardware(std::move(target_hardware)) {}
45
RegisteredTargetHardwaremlir::TFL::tac::__anonad033e510111::RegisteredTargetHardware46 RegisteredTargetHardware(
47 const std::string& name, const std::string& description,
48 mlir::TypeID type_id,
49 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory)
50 : unique_name(GetCanonicalHardwareName(name)),
51 description(description),
52 target_hardware_factory(target_hardware_factory) {}
53
54 std::string unique_name;
55 std::string description;
56 mlir::TypeID type_id;
57 std::unique_ptr<TargetHardware> target_hardware;
58 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory;
59 };
60
61 struct RegisteredTargetHardwareOps {
RegisteredTargetHardwareOpsmlir::TFL::tac::__anonad033e510111::RegisteredTargetHardwareOps62 explicit RegisteredTargetHardwareOps(mlir::TypeID hardware_type)
63 : hardware_typeid(hardware_type) {}
64 // Key is the Operation TypeID
65 llvm::DenseMap<mlir::TypeID, std::unique_ptr<TargetHardwareOperation>>
66 target_hardware_ops;
67 // Key is the Operation TypeID
68 llvm::DenseMap<mlir::TypeID,
69 std::function<std::unique_ptr<TargetHardwareOperation>()>>
70 target_hardware_ops_factory;
71 mlir::TypeID hardware_typeid;
72 };
73
74 std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>*
GetRegisteredTargetHardwareOps()75 GetRegisteredTargetHardwareOps() {
76 static std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>*
77 hardwares_ops =
78 []() -> std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>* {
79 return new std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>();
80 }();
81 return hardwares_ops;
82 }
83
GetRegisteredHardwares()84 std::vector<RegisteredTargetHardware>* GetRegisteredHardwares() {
85 static std::vector<RegisteredTargetHardware>* hardwares =
86 []() -> std::vector<RegisteredTargetHardware>* {
87 return new std::vector<RegisteredTargetHardware>();
88 }();
89 return hardwares;
90 }
91
92 llvm::DenseMap<mlir::TypeID, std::unique_ptr<TargetHardwareOperation>>*
getRegisteredOperationsForHardware(mlir::TypeID type_id)93 getRegisteredOperationsForHardware(mlir::TypeID type_id) {
94 auto* hardwares = GetRegisteredTargetHardwareOps();
95 for (auto& hardware : *hardwares) {
96 if (hardware->hardware_typeid == type_id) {
97 return &hardware->target_hardware_ops;
98 }
99 }
100 return nullptr;
101 }
102
103 // A deny list for op cost computation since those ops are not arithemtic.
IsNonArithmeticOp(mlir::Operation * op)104 inline bool IsNonArithmeticOp(mlir::Operation* op) {
105 if (llvm::isa<func::ReturnOp, func::FuncOp>(op)) return true;
106 if (op->hasTrait<OpTrait::ConstantLike>()) return true;
107 if (llvm::isa<QConstOp, SparseQConstOp>(op)) return true;
108 if (!NotTFLQuantDequantizeOp(op)) return true;
109 return false;
110 }
111
112 } // namespace
113
Init()114 bool TargetHardware::Init() {
115 auto* hardware_ops_factory = GetRegisteredTargetHardwareOps();
116 for (auto& hardware_ops : *hardware_ops_factory) {
117 if (hardware_ops->hardware_typeid != this->GetTypeId()) continue;
118 auto& op_factories = hardware_ops->target_hardware_ops_factory;
119 for (auto& op_factory : op_factories) {
120 hardware_ops_.emplace_back(op_factory.getSecond()());
121 }
122 break;
123 }
124 return true;
125 }
126
GetOpCost(mlir::Operation * op) const127 double TargetHardware::GetOpCost(mlir::Operation* op) const {
128 auto* registered_ops = getRegisteredOperationsForHardware(GetTypeId());
129 if (registered_ops == nullptr) {
130 return kDefaultFixedValuedCost;
131 }
132 auto abstract_op = op->getRegisteredInfo();
133 auto hardware_op = registered_ops->find(abstract_op->getTypeID());
134 if (hardware_op == registered_ops->end()) return kDefaultFixedValuedCost;
135 return hardware_op->second->GetOpCost(op);
136 }
137
IsOpSupported(mlir::Operation * op) const138 bool TargetHardware::IsOpSupported(mlir::Operation* op) const {
139 auto* registered_ops = getRegisteredOperationsForHardware(GetTypeId());
140 if (registered_ops == nullptr) {
141 return false;
142 }
143 auto abstract_op = op->getRegisteredInfo();
144 auto hardware_op = registered_ops->find(abstract_op->getTypeID());
145 if (hardware_op == registered_ops->end()) return false;
146 return hardware_op->second->IsOpSupported(op);
147 }
148
GetFuncCost(func::FuncOp * func) const149 double TargetHardware::GetFuncCost(func::FuncOp* func) const {
150 double total_cost = 0.0;
151 func->walk([&](Operation* op) {
152 if (IsNonArithmeticOp(op)) return;
153 // We will always defer to the hardware to decide the cost.
154 total_cost += GetOpCost(op);
155 });
156 return total_cost;
157 }
158
GetTargetHardware(const std::string & hardware_name)159 const TargetHardware* GetTargetHardware(const std::string& hardware_name) {
160 const std::string canonical_name = GetCanonicalHardwareName(hardware_name);
161 // Just loop for now, we don't expect number of hardwares to be huge.
162 // Revisit to have map if number of elements increased.
163 auto* registered_hardwares = GetRegisteredHardwares();
164 for (const auto& hardware : *registered_hardwares) {
165 if (hardware.unique_name == canonical_name) {
166 return hardware.target_hardware.get();
167 }
168 }
169 return nullptr;
170 }
171
GetTargetHardwareFactory(const std::string & hardware_name)172 std::function<std::unique_ptr<TargetHardware>()> GetTargetHardwareFactory(
173 const std::string& hardware_name) {
174 const std::string canonical_name = GetCanonicalHardwareName(hardware_name);
175 // Just loop for now, we don't expect number of hardwares to be huge.
176 // Revisit to have map if number of elements increased.
177 auto* registered_hardwares = GetRegisteredHardwares();
178 for (const auto& hardware : *registered_hardwares) {
179 if (hardware.unique_name == canonical_name) {
180 return hardware.target_hardware_factory;
181 }
182 }
183 return nullptr;
184 }
185
186 namespace internal {
187
RegisterTargetHardware(const std::string & unique_name,const std::string & description,mlir::TypeID type_id,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)188 void RegisterTargetHardware(
189 const std::string& unique_name, const std::string& description,
190 mlir::TypeID type_id,
191 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
192 auto* registered_hardwares = GetRegisteredHardwares();
193 for (const auto& hardware : *registered_hardwares) {
194 if (hardware.unique_name == unique_name) {
195 llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name
196 << " already registered\n";
197 return;
198 }
199 }
200 registered_hardwares->push_back(RegisteredTargetHardware(
201 unique_name, description, type_id, target_hardware_factory()));
202 }
203
RegisterTargetHardwareFactory(const std::string & unique_name,const std::string & description,mlir::TypeID type_id,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)204 void RegisterTargetHardwareFactory(
205 const std::string& unique_name, const std::string& description,
206 mlir::TypeID type_id,
207 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
208 auto* registered_hardwares = GetRegisteredHardwares();
209 for (auto& hardware : *registered_hardwares) {
210 if (hardware.unique_name == unique_name) {
211 llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name
212 << " already registered\n";
213 hardware.target_hardware_factory = target_hardware_factory;
214 return;
215 }
216 }
217 registered_hardwares->push_back(RegisteredTargetHardware(
218 unique_name, description, type_id, target_hardware_factory));
219 }
220
RegisterTargetHardwareOp(mlir::TypeID hardware_type,mlir::TypeID op_type,std::function<std::unique_ptr<TargetHardwareOperation> ()> target_hardware_op_factory)221 void RegisterTargetHardwareOp(
222 mlir::TypeID hardware_type, mlir::TypeID op_type,
223 std::function<std::unique_ptr<TargetHardwareOperation>()>
224 target_hardware_op_factory) {
225 auto* registered_hardware_ops = GetRegisteredTargetHardwareOps();
226 for (auto& hardware : *registered_hardware_ops) {
227 if (hardware->hardware_typeid == hardware_type) {
228 if (hardware->target_hardware_ops.count(op_type)) {
229 llvm::errs() << "Trying to register duplicate Op";
230 return;
231 }
232 hardware->target_hardware_ops[op_type] = target_hardware_op_factory();
233 return;
234 }
235 }
236 registered_hardware_ops->push_back(
237 std::make_unique<RegisteredTargetHardwareOps>(
238 RegisteredTargetHardwareOps(hardware_type)));
239 registered_hardware_ops->back()->target_hardware_ops[op_type] =
240 target_hardware_op_factory();
241 }
242
RegisterTargetHardwareOpFactory(mlir::TypeID hardware_type,mlir::TypeID op_type,std::function<std::unique_ptr<TargetHardwareOperation> ()> target_hardware_op_factory)243 void RegisterTargetHardwareOpFactory(
244 mlir::TypeID hardware_type, mlir::TypeID op_type,
245 std::function<std::unique_ptr<TargetHardwareOperation>()>
246 target_hardware_op_factory) {
247 auto* registered_hardware_ops = GetRegisteredTargetHardwareOps();
248 for (auto& hardware : *registered_hardware_ops) {
249 if (hardware->hardware_typeid == hardware_type) {
250 if (hardware->target_hardware_ops_factory.count(op_type)) {
251 llvm::errs() << "Trying to register duplicate Op";
252 return;
253 }
254 hardware->target_hardware_ops_factory[op_type] =
255 target_hardware_op_factory;
256 return;
257 }
258 }
259 registered_hardware_ops->push_back(
260 std::make_unique<RegisteredTargetHardwareOps>(
261 RegisteredTargetHardwareOps(hardware_type)));
262 registered_hardware_ops->back()->target_hardware_ops_factory[op_type] =
263 target_hardware_op_factory;
264 }
265
266 } // namespace internal
267
ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,std::vector<std::string> * device_specs)268 bool ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,
269 std::vector<std::string>* device_specs) {
270 bool cpu_include = false;
271 for (auto& device_spec : specified_device_specs) {
272 auto device = GetCanonicalHardwareName(device_spec);
273
274 if (device == "CPU") cpu_include = true;
275 device_specs->push_back(device);
276 }
277 if (!cpu_include) {
278 device_specs->push_back("CPU");
279 }
280
281 // Make sure all the devices are registered.
282 for (const std::string& device : *device_specs) {
283 if (GetTargetHardware(device) == nullptr) {
284 llvm::errs() << "cannot get target hardware for device: " << device;
285 return false;
286 }
287 }
288
289 return true;
290 }
291
GetHardwareName(const TargetHardware * hardware)292 std::string GetHardwareName(const TargetHardware* hardware) {
293 const auto* registered_hardwares = GetRegisteredHardwares();
294 for (const auto& registered_hardware : *registered_hardwares) {
295 if (registered_hardware.type_id == hardware->GetTypeId())
296 return registered_hardware.unique_name;
297 }
298 return "";
299 }
300
301 } // namespace tac
302 } // namespace TFL
303 } // namespace mlir
304