// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #include "tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.h"
16 
17 #include <cstdint>
18 #include <map>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
24 #include "llvm/ADT/None.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/Support/Casting.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
28 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
29 #include "mlir/IR/Attributes.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/Operation.h"  // from @llvm-project
32 #include "mlir/IR/Region.h"  // from @llvm-project
33 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
34 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
35 #include "tensorflow/compiler/mlir/lite/experimental/tac/runtime_metadata_generated.h"
36 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
38 
39 namespace tflite {
40 namespace {
41 
IsConst(mlir::Operation * op)42 bool IsConst(mlir::Operation* op) {
43   return llvm::isa<mlir::arith::ConstantOp, mlir::TF::ConstOp,
44                    mlir::TFL::ConstOp, mlir::TFL::QConstOp>(op);
45 }
46 
IsOpSupported(mlir::Operation * op,const std::string & hardware)47 bool IsOpSupported(mlir::Operation* op, const std::string& hardware) {
48   auto* devce_hardware = mlir::TFL::tac::GetTargetHardware(hardware);
49   if (devce_hardware == nullptr) return {};
50   return devce_hardware->IsOpSupported(op);
51 }
52 
HasValidHardwareTarget(mlir::Operation * op)53 bool HasValidHardwareTarget(mlir::Operation* op) {
54   // All TFLite ops has CPU interface, should be enough to check for cpu.
55   return IsOpSupported(op, "CPU");
56 }
57 
GetDeviceName(mlir::Operation * op)58 llvm::Optional<std::string> GetDeviceName(mlir::Operation* op) {
59   if (IsConst(op)) return llvm::None;
60 
61   // The model may contain quant stats op which is unrelevant to the
62   // execution.
63   if (llvm::isa<mlir::func::ReturnOp, mlir::quantfork::StatisticsOp>(op))
64     return llvm::None;
65 
66   if (!HasValidHardwareTarget(op)) return llvm::None;
67 
68   auto device = op->getAttrOfType<mlir::StringAttr>(mlir::TFL::tac::kDevice);
69   if (device == nullptr) return llvm::None;
70 
71   llvm::StringRef device_name_str = device.getValue();
72   return device_name_str.str();
73 }
74 
GetPerDeviceCosts(const std::map<std::string,uint8_t> & hardware_map,mlir::Operation * op)75 llvm::Optional<std::vector<float>> GetPerDeviceCosts(
76     const std::map<std::string, uint8_t>& hardware_map, mlir::Operation* op) {
77   auto device_costs_attr =
78       op->getAttrOfType<mlir::DictionaryAttr>("per_device_costs");
79   if (device_costs_attr == nullptr) return llvm::None;
80 
81   std::vector<float> device_costs(hardware_map.size(), -1.f);
82 
83   for (const auto& kv : hardware_map) {
84     auto cost_attr = device_costs_attr.getNamed(kv.first);
85     if (!cost_attr.has_value()) return llvm::None;
86     float cost = cost_attr->getValue()
87                      .dyn_cast_or_null<mlir::FloatAttr>()
88                      .getValueAsDouble();
89     device_costs[kv.second] = cost;
90   }
91   return device_costs;
92 }
93 
// Builds the SubgraphMetadata table for one function body.
//
// Walks the first block of `Region` in order and, for every op with a
// resolvable device, emits an OpMetadata entry holding the op's index,
// its hardware id from `hardware_map`, and (when present) its per-device
// cost vector. Note `index` is only advanced for ops that reach the
// bottom of the loop: const, return and quant-stats ops `continue` before
// the increment, so they consume no index — presumably matching the
// operator numbering of the serialized flatbuffer, where consts become
// tensors rather than operators (TODO confirm against the schema).
flatbuffers::Offset<SubgraphMetadata> CreateSubgraphMetadata(
    const std::map<std::string, uint8_t>& hardware_map, mlir::Region* Region,
    flatbuffers::FlatBufferBuilder* builder) {
  auto& block = Region->front();
  int index = 0;
  std::vector<flatbuffers::Offset<tflite::OpMetadata>> ops;
  for (auto& inst : block) {
    // Const nodes are mapped to const vectors in flatbuffer, so skip.
    if (IsConst(&inst)) continue;

    // The model may contain quant stats op which is irrelevant to the
    // execution.
    if (llvm::isa<mlir::func::ReturnOp, mlir::quantfork::StatisticsOp>(&inst))
      continue;

    // If an op doesn't implement any of the hardware interface we skip it.
    // This can happen in cases like Flex when we have non TFLite ops.
    auto device_name = GetDeviceName(&inst);

    if (device_name.has_value()) {
      // Add per device costs if present.
      // The nested cost vector must be serialized *before* the OpMetadata
      // table builder below starts — FlatBufferBuilder forbids creating
      // nested objects while a table is under construction.
      auto per_device_cost = GetPerDeviceCosts(hardware_map, &inst);
      flatbuffers::Offset<flatbuffers::Vector<float>> per_device_cost_offset;

      if (per_device_cost.has_value()) {
        per_device_cost_offset =
            builder->CreateVector(per_device_cost.getValue());
      }

      OpMetadataBuilder op_builder(*builder);
      op_builder.add_index(index);
      uint8_t hardware = hardware_map.at(device_name.getValue());
      op_builder.add_hardware(hardware);

      if (per_device_cost.has_value()) {
        op_builder.add_op_costs(per_device_cost_offset);
      }

      ops.push_back(op_builder.Finish());
    }
    index++;
  }
  return CreateSubgraphMetadata(*builder, builder->CreateVector(ops));
}
138 
139 flatbuffers::Offset<tflite::HardwareMetadata>
CreateHardwareMetadataAndPopulateLookupTable(std::vector<mlir::func::FuncOp> * funcs,flatbuffers::FlatBufferBuilder * builder,std::map<std::string,uint8_t> * hardware_names)140 CreateHardwareMetadataAndPopulateLookupTable(
141     std::vector<mlir::func::FuncOp>* funcs,
142     flatbuffers::FlatBufferBuilder* builder,
143     std::map<std::string, uint8_t>* hardware_names) {
144   uint8_t index = 0;
145   for (auto& func : *funcs) {
146     func.walk([&hardware_names, &index](mlir::Operation* op) {
147       auto device_name = GetDeviceName(op);
148       if (!device_name.has_value()) return;
149 
150       auto iter = hardware_names->find(device_name.getValue());
151       if (iter == hardware_names->end()) {
152         hardware_names->insert({device_name.getValue(), index++});
153       }
154     });
155   }
156 
157   // Build the flatbuffer.
158   std::vector<flatbuffers::Offset<flatbuffers::String>> hardwares;
159   for (const auto& kv : *hardware_names) {
160     hardwares.push_back(builder->CreateString(kv.first));
161   }
162 
163   return CreateHardwareMetadata(*builder, builder->CreateVector(hardwares));
164 }
165 
166 }  // namespace
167 
// Entry point: serializes per-op execution metadata (target hardware ids
// and optional per-device costs) for every function in `module` into a
// RuntimeMetadata flatbuffer and returns its raw bytes as a std::string.
//
// Returns an empty string when the module has no "main" function; as
// written, the llvm::None case is never produced.
llvm::Optional<std::string> ExportRuntimeMetadata(mlir::ModuleOp module) {
  mlir::func::FuncOp main_fn = module.lookupSymbol<mlir::func::FuncOp>("main");
  if (!main_fn) return std::string("");

  flatbuffers::FlatBufferBuilder fb_builder;
  // "main" is pushed first so it becomes subgraph 0; all other functions
  // follow in module-walk order.
  std::vector<mlir::func::FuncOp> funcs;
  funcs.push_back(main_fn);
  module.walk([&](mlir::func::FuncOp fn) {
    if (fn != main_fn) {
      funcs.push_back(fn);
    }
  });

  // Populate the hardware metadata.
  // And collect the hardwares used. The resulting name->id map is needed
  // before any subgraph is serialized, since OpMetadata stores the ids.
  std::map<std::string, uint8_t> hardware_map;
  flatbuffers::Offset<tflite::HardwareMetadata> hardware_metadata_offset =
      CreateHardwareMetadataAndPopulateLookupTable(&funcs, &fb_builder,
                                                   &hardware_map);

  // Populate the runtime metadata.
  std::vector<flatbuffers::Offset<SubgraphMetadata>> subgraphs_metadata;
  subgraphs_metadata.reserve(funcs.size());
  for (auto& func : funcs) {
    subgraphs_metadata.push_back(
        CreateSubgraphMetadata(hardware_map, &func.getBody(), &fb_builder));
  }
  auto runtime_metadata =
      CreateRuntimeMetadata(fb_builder, hardware_metadata_offset,
                            fb_builder.CreateVector(subgraphs_metadata));
  fb_builder.Finish(runtime_metadata);
  // Copy the finished buffer out by value; fb_builder (and the memory it
  // owns) is destroyed when this function returns.
  return std::string(
      reinterpret_cast<const char*>(fb_builder.GetBufferPointer()),
      fb_builder.GetSize());
}
203 }  // namespace tflite
204