1 // Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.h"
16
17 #include <cstdint>
18 #include <map>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
24 #include "llvm/ADT/None.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/Support/Casting.h"
27 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
28 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
29 #include "mlir/IR/Attributes.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/Operation.h" // from @llvm-project
32 #include "mlir/IR/Region.h" // from @llvm-project
33 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
34 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
35 #include "tensorflow/compiler/mlir/lite/experimental/tac/runtime_metadata_generated.h"
36 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
38
39 namespace tflite {
40 namespace {
41
IsConst(mlir::Operation * op)42 bool IsConst(mlir::Operation* op) {
43 return llvm::isa<mlir::arith::ConstantOp, mlir::TF::ConstOp,
44 mlir::TFL::ConstOp, mlir::TFL::QConstOp>(op);
45 }
46
IsOpSupported(mlir::Operation * op,const std::string & hardware)47 bool IsOpSupported(mlir::Operation* op, const std::string& hardware) {
48 auto* devce_hardware = mlir::TFL::tac::GetTargetHardware(hardware);
49 if (devce_hardware == nullptr) return {};
50 return devce_hardware->IsOpSupported(op);
51 }
52
HasValidHardwareTarget(mlir::Operation * op)53 bool HasValidHardwareTarget(mlir::Operation* op) {
54 // All TFLite ops has CPU interface, should be enough to check for cpu.
55 return IsOpSupported(op, "CPU");
56 }
57
GetDeviceName(mlir::Operation * op)58 llvm::Optional<std::string> GetDeviceName(mlir::Operation* op) {
59 if (IsConst(op)) return llvm::None;
60
61 // The model may contain quant stats op which is unrelevant to the
62 // execution.
63 if (llvm::isa<mlir::func::ReturnOp, mlir::quantfork::StatisticsOp>(op))
64 return llvm::None;
65
66 if (!HasValidHardwareTarget(op)) return llvm::None;
67
68 auto device = op->getAttrOfType<mlir::StringAttr>(mlir::TFL::tac::kDevice);
69 if (device == nullptr) return llvm::None;
70
71 llvm::StringRef device_name_str = device.getValue();
72 return device_name_str.str();
73 }
74
GetPerDeviceCosts(const std::map<std::string,uint8_t> & hardware_map,mlir::Operation * op)75 llvm::Optional<std::vector<float>> GetPerDeviceCosts(
76 const std::map<std::string, uint8_t>& hardware_map, mlir::Operation* op) {
77 auto device_costs_attr =
78 op->getAttrOfType<mlir::DictionaryAttr>("per_device_costs");
79 if (device_costs_attr == nullptr) return llvm::None;
80
81 std::vector<float> device_costs(hardware_map.size(), -1.f);
82
83 for (const auto& kv : hardware_map) {
84 auto cost_attr = device_costs_attr.getNamed(kv.first);
85 if (!cost_attr.has_value()) return llvm::None;
86 float cost = cost_attr->getValue()
87 .dyn_cast_or_null<mlir::FloatAttr>()
88 .getValueAsDouble();
89 device_costs[kv.second] = cost;
90 }
91 return device_costs;
92 }
93
CreateSubgraphMetadata(const std::map<std::string,uint8_t> & hardware_map,mlir::Region * Region,flatbuffers::FlatBufferBuilder * builder)94 flatbuffers::Offset<SubgraphMetadata> CreateSubgraphMetadata(
95 const std::map<std::string, uint8_t>& hardware_map, mlir::Region* Region,
96 flatbuffers::FlatBufferBuilder* builder) {
97 auto& block = Region->front();
98 int index = 0;
99 std::vector<flatbuffers::Offset<tflite::OpMetadata>> ops;
100 for (auto& inst : block) {
101 // Const nodes are mapped to const vectors in flatbuffer, so skip.
102 if (IsConst(&inst)) continue;
103
104 // The model may contain quant stats op which is unrelevant to the
105 // execution.
106 if (llvm::isa<mlir::func::ReturnOp, mlir::quantfork::StatisticsOp>(&inst))
107 continue;
108
109 // If an op doesn't implement any of the hardware interface we skip it.
110 // This can happen in cases like Flex when we have non TFLite ops.
111 auto device_name = GetDeviceName(&inst);
112
113 if (device_name.has_value()) {
114 // Add per device costs if present.
115 auto per_device_cost = GetPerDeviceCosts(hardware_map, &inst);
116 flatbuffers::Offset<flatbuffers::Vector<float>> per_device_cost_offset;
117
118 if (per_device_cost.has_value()) {
119 per_device_cost_offset =
120 builder->CreateVector(per_device_cost.getValue());
121 }
122
123 OpMetadataBuilder op_builder(*builder);
124 op_builder.add_index(index);
125 uint8_t hardware = hardware_map.at(device_name.getValue());
126 op_builder.add_hardware(hardware);
127
128 if (per_device_cost.has_value()) {
129 op_builder.add_op_costs(per_device_cost_offset);
130 }
131
132 ops.push_back(op_builder.Finish());
133 }
134 index++;
135 }
136 return CreateSubgraphMetadata(*builder, builder->CreateVector(ops));
137 }
138
139 flatbuffers::Offset<tflite::HardwareMetadata>
CreateHardwareMetadataAndPopulateLookupTable(std::vector<mlir::func::FuncOp> * funcs,flatbuffers::FlatBufferBuilder * builder,std::map<std::string,uint8_t> * hardware_names)140 CreateHardwareMetadataAndPopulateLookupTable(
141 std::vector<mlir::func::FuncOp>* funcs,
142 flatbuffers::FlatBufferBuilder* builder,
143 std::map<std::string, uint8_t>* hardware_names) {
144 uint8_t index = 0;
145 for (auto& func : *funcs) {
146 func.walk([&hardware_names, &index](mlir::Operation* op) {
147 auto device_name = GetDeviceName(op);
148 if (!device_name.has_value()) return;
149
150 auto iter = hardware_names->find(device_name.getValue());
151 if (iter == hardware_names->end()) {
152 hardware_names->insert({device_name.getValue(), index++});
153 }
154 });
155 }
156
157 // Build the flatbuffer.
158 std::vector<flatbuffers::Offset<flatbuffers::String>> hardwares;
159 for (const auto& kv : *hardware_names) {
160 hardwares.push_back(builder->CreateString(kv.first));
161 }
162
163 return CreateHardwareMetadata(*builder, builder->CreateVector(hardwares));
164 }
165
166 } // namespace
167
ExportRuntimeMetadata(mlir::ModuleOp module)168 llvm::Optional<std::string> ExportRuntimeMetadata(mlir::ModuleOp module) {
169 mlir::func::FuncOp main_fn = module.lookupSymbol<mlir::func::FuncOp>("main");
170 if (!main_fn) return std::string("");
171
172 flatbuffers::FlatBufferBuilder fb_builder;
173 std::vector<mlir::func::FuncOp> funcs;
174 funcs.push_back(main_fn);
175 module.walk([&](mlir::func::FuncOp fn) {
176 if (fn != main_fn) {
177 funcs.push_back(fn);
178 }
179 });
180
181 // Populate the hardware metadata.
182 // And collect the hardwares used.
183 std::map<std::string, uint8_t> hardware_map;
184 flatbuffers::Offset<tflite::HardwareMetadata> hardware_metadata_offset =
185 CreateHardwareMetadataAndPopulateLookupTable(&funcs, &fb_builder,
186 &hardware_map);
187
188 // Populate the runtime metadata.
189 std::vector<flatbuffers::Offset<SubgraphMetadata>> subgraphs_metadata;
190 subgraphs_metadata.reserve(funcs.size());
191 for (auto& func : funcs) {
192 subgraphs_metadata.push_back(
193 CreateSubgraphMetadata(hardware_map, &func.getBody(), &fb_builder));
194 }
195 auto runtime_metadata =
196 CreateRuntimeMetadata(fb_builder, hardware_metadata_offset,
197 fb_builder.CreateVector(subgraphs_metadata));
198 fb_builder.Finish(runtime_metadata);
199 return std::string(
200 reinterpret_cast<const char*>(fb_builder.GetBufferPointer()),
201 fb_builder.GetSize());
202 }
203 } // namespace tflite
204