1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
16 
17 #include <algorithm>
18 #include <cctype>
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/Support/TypeID.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
29 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
30 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
31 
32 namespace mlir {
33 namespace TFL {
34 namespace tac {
35 namespace {
36 struct RegisteredTargetHardware {
37   // TODO(b/177376459): Remove this constructor.
RegisteredTargetHardwaremlir::TFL::tac::__anonad033e510111::RegisteredTargetHardware38   RegisteredTargetHardware(const std::string& name,
39                            const std::string& description, mlir::TypeID type_id,
40                            std::unique_ptr<TargetHardware> target_hardware)
41       : unique_name(GetCanonicalHardwareName(name)),
42         description(description),
43         type_id(type_id),
44         target_hardware(std::move(target_hardware)) {}
45 
RegisteredTargetHardwaremlir::TFL::tac::__anonad033e510111::RegisteredTargetHardware46   RegisteredTargetHardware(
47       const std::string& name, const std::string& description,
48       mlir::TypeID type_id,
49       std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory)
50       : unique_name(GetCanonicalHardwareName(name)),
51         description(description),
52         target_hardware_factory(target_hardware_factory) {}
53 
54   std::string unique_name;
55   std::string description;
56   mlir::TypeID type_id;
57   std::unique_ptr<TargetHardware> target_hardware;
58   std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory;
59 };
60 
61 struct RegisteredTargetHardwareOps {
RegisteredTargetHardwareOpsmlir::TFL::tac::__anonad033e510111::RegisteredTargetHardwareOps62   explicit RegisteredTargetHardwareOps(mlir::TypeID hardware_type)
63       : hardware_typeid(hardware_type) {}
64   // Key is the Operation TypeID
65   llvm::DenseMap<mlir::TypeID, std::unique_ptr<TargetHardwareOperation>>
66       target_hardware_ops;
67   // Key is the Operation TypeID
68   llvm::DenseMap<mlir::TypeID,
69                  std::function<std::unique_ptr<TargetHardwareOperation>()>>
70       target_hardware_ops_factory;
71   mlir::TypeID hardware_typeid;
72 };
73 
74 std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>*
GetRegisteredTargetHardwareOps()75 GetRegisteredTargetHardwareOps() {
76   static std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>*
77       hardwares_ops =
78           []() -> std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>* {
79     return new std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>();
80   }();
81   return hardwares_ops;
82 }
83 
GetRegisteredHardwares()84 std::vector<RegisteredTargetHardware>* GetRegisteredHardwares() {
85   static std::vector<RegisteredTargetHardware>* hardwares =
86       []() -> std::vector<RegisteredTargetHardware>* {
87     return new std::vector<RegisteredTargetHardware>();
88   }();
89   return hardwares;
90 }
91 
92 llvm::DenseMap<mlir::TypeID, std::unique_ptr<TargetHardwareOperation>>*
getRegisteredOperationsForHardware(mlir::TypeID type_id)93 getRegisteredOperationsForHardware(mlir::TypeID type_id) {
94   auto* hardwares = GetRegisteredTargetHardwareOps();
95   for (auto& hardware : *hardwares) {
96     if (hardware->hardware_typeid == type_id) {
97       return &hardware->target_hardware_ops;
98     }
99   }
100   return nullptr;
101 }
102 
103 // A deny list for op cost computation since those ops are not arithemtic.
IsNonArithmeticOp(mlir::Operation * op)104 inline bool IsNonArithmeticOp(mlir::Operation* op) {
105   if (llvm::isa<func::ReturnOp, func::FuncOp>(op)) return true;
106   if (op->hasTrait<OpTrait::ConstantLike>()) return true;
107   if (llvm::isa<QConstOp, SparseQConstOp>(op)) return true;
108   if (!NotTFLQuantDequantizeOp(op)) return true;
109   return false;
110 }
111 
112 }  // namespace
113 
Init()114 bool TargetHardware::Init() {
115   auto* hardware_ops_factory = GetRegisteredTargetHardwareOps();
116   for (auto& hardware_ops : *hardware_ops_factory) {
117     if (hardware_ops->hardware_typeid != this->GetTypeId()) continue;
118     auto& op_factories = hardware_ops->target_hardware_ops_factory;
119     for (auto& op_factory : op_factories) {
120       hardware_ops_.emplace_back(op_factory.getSecond()());
121     }
122     break;
123   }
124   return true;
125 }
126 
GetOpCost(mlir::Operation * op) const127 double TargetHardware::GetOpCost(mlir::Operation* op) const {
128   auto* registered_ops = getRegisteredOperationsForHardware(GetTypeId());
129   if (registered_ops == nullptr) {
130     return kDefaultFixedValuedCost;
131   }
132   auto abstract_op = op->getRegisteredInfo();
133   auto hardware_op = registered_ops->find(abstract_op->getTypeID());
134   if (hardware_op == registered_ops->end()) return kDefaultFixedValuedCost;
135   return hardware_op->second->GetOpCost(op);
136 }
137 
IsOpSupported(mlir::Operation * op) const138 bool TargetHardware::IsOpSupported(mlir::Operation* op) const {
139   auto* registered_ops = getRegisteredOperationsForHardware(GetTypeId());
140   if (registered_ops == nullptr) {
141     return false;
142   }
143   auto abstract_op = op->getRegisteredInfo();
144   auto hardware_op = registered_ops->find(abstract_op->getTypeID());
145   if (hardware_op == registered_ops->end()) return false;
146   return hardware_op->second->IsOpSupported(op);
147 }
148 
GetFuncCost(func::FuncOp * func) const149 double TargetHardware::GetFuncCost(func::FuncOp* func) const {
150   double total_cost = 0.0;
151   func->walk([&](Operation* op) {
152     if (IsNonArithmeticOp(op)) return;
153     // We will always defer to the hardware to decide the cost.
154     total_cost += GetOpCost(op);
155   });
156   return total_cost;
157 }
158 
GetTargetHardware(const std::string & hardware_name)159 const TargetHardware* GetTargetHardware(const std::string& hardware_name) {
160   const std::string canonical_name = GetCanonicalHardwareName(hardware_name);
161   // Just loop for now, we don't expect number of hardwares to be huge.
162   // Revisit to have map if number of elements increased.
163   auto* registered_hardwares = GetRegisteredHardwares();
164   for (const auto& hardware : *registered_hardwares) {
165     if (hardware.unique_name == canonical_name) {
166       return hardware.target_hardware.get();
167     }
168   }
169   return nullptr;
170 }
171 
GetTargetHardwareFactory(const std::string & hardware_name)172 std::function<std::unique_ptr<TargetHardware>()> GetTargetHardwareFactory(
173     const std::string& hardware_name) {
174   const std::string canonical_name = GetCanonicalHardwareName(hardware_name);
175   // Just loop for now, we don't expect number of hardwares to be huge.
176   // Revisit to have map if number of elements increased.
177   auto* registered_hardwares = GetRegisteredHardwares();
178   for (const auto& hardware : *registered_hardwares) {
179     if (hardware.unique_name == canonical_name) {
180       return hardware.target_hardware_factory;
181     }
182   }
183   return nullptr;
184 }
185 
186 namespace internal {
187 
RegisterTargetHardware(const std::string & unique_name,const std::string & description,mlir::TypeID type_id,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)188 void RegisterTargetHardware(
189     const std::string& unique_name, const std::string& description,
190     mlir::TypeID type_id,
191     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
192   auto* registered_hardwares = GetRegisteredHardwares();
193   for (const auto& hardware : *registered_hardwares) {
194     if (hardware.unique_name == unique_name) {
195       llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name
196                    << " already registered\n";
197       return;
198     }
199   }
200   registered_hardwares->push_back(RegisteredTargetHardware(
201       unique_name, description, type_id, target_hardware_factory()));
202 }
203 
RegisterTargetHardwareFactory(const std::string & unique_name,const std::string & description,mlir::TypeID type_id,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)204 void RegisterTargetHardwareFactory(
205     const std::string& unique_name, const std::string& description,
206     mlir::TypeID type_id,
207     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
208   auto* registered_hardwares = GetRegisteredHardwares();
209   for (auto& hardware : *registered_hardwares) {
210     if (hardware.unique_name == unique_name) {
211       llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name
212                    << " already registered\n";
213       hardware.target_hardware_factory = target_hardware_factory;
214       return;
215     }
216   }
217   registered_hardwares->push_back(RegisteredTargetHardware(
218       unique_name, description, type_id, target_hardware_factory));
219 }
220 
RegisterTargetHardwareOp(mlir::TypeID hardware_type,mlir::TypeID op_type,std::function<std::unique_ptr<TargetHardwareOperation> ()> target_hardware_op_factory)221 void RegisterTargetHardwareOp(
222     mlir::TypeID hardware_type, mlir::TypeID op_type,
223     std::function<std::unique_ptr<TargetHardwareOperation>()>
224         target_hardware_op_factory) {
225   auto* registered_hardware_ops = GetRegisteredTargetHardwareOps();
226   for (auto& hardware : *registered_hardware_ops) {
227     if (hardware->hardware_typeid == hardware_type) {
228       if (hardware->target_hardware_ops.count(op_type)) {
229         llvm::errs() << "Trying to register duplicate Op";
230         return;
231       }
232       hardware->target_hardware_ops[op_type] = target_hardware_op_factory();
233       return;
234     }
235   }
236   registered_hardware_ops->push_back(
237       std::make_unique<RegisteredTargetHardwareOps>(
238           RegisteredTargetHardwareOps(hardware_type)));
239   registered_hardware_ops->back()->target_hardware_ops[op_type] =
240       target_hardware_op_factory();
241 }
242 
RegisterTargetHardwareOpFactory(mlir::TypeID hardware_type,mlir::TypeID op_type,std::function<std::unique_ptr<TargetHardwareOperation> ()> target_hardware_op_factory)243 void RegisterTargetHardwareOpFactory(
244     mlir::TypeID hardware_type, mlir::TypeID op_type,
245     std::function<std::unique_ptr<TargetHardwareOperation>()>
246         target_hardware_op_factory) {
247   auto* registered_hardware_ops = GetRegisteredTargetHardwareOps();
248   for (auto& hardware : *registered_hardware_ops) {
249     if (hardware->hardware_typeid == hardware_type) {
250       if (hardware->target_hardware_ops_factory.count(op_type)) {
251         llvm::errs() << "Trying to register duplicate Op";
252         return;
253       }
254       hardware->target_hardware_ops_factory[op_type] =
255           target_hardware_op_factory;
256       return;
257     }
258   }
259   registered_hardware_ops->push_back(
260       std::make_unique<RegisteredTargetHardwareOps>(
261           RegisteredTargetHardwareOps(hardware_type)));
262   registered_hardware_ops->back()->target_hardware_ops_factory[op_type] =
263       target_hardware_op_factory;
264 }
265 
266 }  // namespace internal
267 
ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,std::vector<std::string> * device_specs)268 bool ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,
269                           std::vector<std::string>* device_specs) {
270   bool cpu_include = false;
271   for (auto& device_spec : specified_device_specs) {
272     auto device = GetCanonicalHardwareName(device_spec);
273 
274     if (device == "CPU") cpu_include = true;
275     device_specs->push_back(device);
276   }
277   if (!cpu_include) {
278     device_specs->push_back("CPU");
279   }
280 
281   // Make sure all the devices are registered.
282   for (const std::string& device : *device_specs) {
283     if (GetTargetHardware(device) == nullptr) {
284       llvm::errs() << "cannot get target hardware for device: " << device;
285       return false;
286     }
287   }
288 
289   return true;
290 }
291 
GetHardwareName(const TargetHardware * hardware)292 std::string GetHardwareName(const TargetHardware* hardware) {
293   const auto* registered_hardwares = GetRegisteredHardwares();
294   for (const auto& registered_hardware : *registered_hardwares) {
295     if (registered_hardware.type_id == hardware->GetTypeId())
296       return registered_hardware.unique_name;
297   }
298   return "";
299 }
300 
301 }  // namespace tac
302 }  // namespace TFL
303 }  // namespace mlir
304