1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
17
18 #include <algorithm>
19 #include <cctype>
20 #include <climits>
21 #include <cstdint>
22 #include <iostream>
23 #include <sstream>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "absl/base/casts.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/strings/string_view.h"
32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/DenseMap.h"
36 #include "llvm/ADT/None.h"
37 #include "llvm/ADT/Optional.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/StringExtras.h"
41 #include "llvm/ADT/StringRef.h"
42 #include "llvm/Support/Casting.h"
43 #include "llvm/Support/CommandLine.h"
44 #include "llvm/Support/Endian.h"
45 #include "llvm/Support/FormatVariadic.h"
46 #include "llvm/Support/MemoryBuffer.h"
47 #include "llvm/Support/SourceMgr.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
50 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
51 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
52 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
53 #include "mlir/IR/Attributes.h" // from @llvm-project
54 #include "mlir/IR/Builders.h" // from @llvm-project
55 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
56 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
57 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
58 #include "mlir/IR/Diagnostics.h" // from @llvm-project
59 #include "mlir/IR/Location.h" // from @llvm-project
60 #include "mlir/IR/MLIRContext.h" // from @llvm-project
61 #include "mlir/IR/Operation.h" // from @llvm-project
62 #include "mlir/IR/OperationSupport.h" // from @llvm-project
63 #include "mlir/IR/Types.h" // from @llvm-project
64 #include "mlir/IR/Value.h" // from @llvm-project
65 #include "mlir/Support/LLVM.h" // from @llvm-project
66 #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project
67 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
68 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
69 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
70 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
71 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
72 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
73 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
74 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
75 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
76 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
77 #include "tensorflow/compiler/xla/statusor.h"
78 #include "tensorflow/core/framework/tensor.pb.h"
79 #include "tensorflow/core/framework/tensor_shape.pb.h"
80 #include "tensorflow/core/platform/errors.h"
81 #include "tensorflow/core/platform/status.h"
82 #include "tensorflow/lite/model.h"
83 #include "tensorflow/lite/schema/schema_generated.h"
84 #include "tensorflow/lite/schema/schema_utils.h"
85 #include "tensorflow/lite/string_util.h"
86
87 using llvm::ArrayRef;
88 using mlir::Builder;
89 using mlir::DenseElementsAttr;
90 using mlir::Location;
91 using mlir::MLIRContext;
92 using mlir::OpBuilder;
93 using mlir::Operation;
94 using mlir::OperationState;
95 using mlir::OwningOpRef;
96 using mlir::RankedTensorType;
97 using mlir::UnrankedTensorType;
98 using mlir::Value;
99 using mlir::func::FuncOp;
100 using mlir::quant::QuantizedType;
101 using tflite::OperatorT;
102 using tflite::TensorT;
103 using xla::Status;
104 using xla::StatusOr;
105
106 namespace errors = tensorflow::errors;
107 namespace tfl = mlir::TFL;
108
109 namespace {
110
111 bool IsQuantized(const TensorT& tensor) {
112 return (tensor.quantization != nullptr) &&
113 !tensor.quantization->zero_point.empty();
114 }
115
116 // Create the MLIR NamedLoc location corresponding to a given tensor
117 Location TensorLoc(const TensorT& tensor, Builder builder, Location base) {
118 if (tensor.name.empty()) {
119 return base;
120 }
121 return mlir::NameLoc::get(builder.getStringAttr(tensor.name), base);
122 }
123
124 // Create the MLIR Location corresponding to a given op. This is an
125 // experimental/debugging feature and production code should not rely on names
126 // of intermediate tensors since the importer doesn't guarantee to preserve
127 // tensor names except for output tensors.
128 Location OpLoc(const OperatorT& op,
129 const std::vector<std::unique_ptr<tflite::TensorT>>& tensors,
130 Builder builder, Location base) {
131 if (op.outputs.empty()) return base;
132
133 llvm::SmallVector<Location, 4> locations;
134 locations.reserve(op.outputs.size());
135 for (auto tensor_index : op.outputs) {
136 locations.push_back(TensorLoc(*tensors[tensor_index], builder, base));
137 }
138 return mlir::FusedLoc::get(builder.getContext(), locations);
139 }
140
141 // Returns the correct type for a quantized tensor
142 // We have a special case for constants since they have a higher minimum value.
143 StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
144 bool is_constant = false) {
145 tflite::QuantizationParametersT& quant_params = *tensor.quantization;
146 if (quant_params.details.AsCustomQuantization()) {
147 return errors::Unimplemented("Cannot handle experimental quantization");
148 }
149
150 bool is_signed = true;
151 mlir::IntegerType storage_type;
152 if (tensor.type == tflite::TensorType_UINT8) {
153 is_signed = false;
154 storage_type = builder.getIntegerType(8);
155 } else {
156 auto raw_elem_type = ConvertElementType(tensor.type, builder);
157 if (!raw_elem_type.isa<mlir::IntegerType>()) {
158 return errors::InvalidArgument(
159 "Quantized tensors must be stored as integers");
160 }
161 storage_type = raw_elem_type.cast<mlir::IntegerType>();
162 }
163
164 // TFLite uses narrow-range [u]int8 for constant buffers of quantized weights.
165 // Since we don't know which ones are weights, we represent this optimization
166 // as a change in the storage bounds for the type for all constants of this
167 // type.
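// For example, a narrow-range signed int8 weight buffer uses the storage
// range [-127, 127] instead of the default [-128, 127].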
168 bool is_weight_buffer = is_constant && (storage_type.getWidth() == 8);
169
170 int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
171 is_signed, storage_type.getWidth()) +
172 static_cast<int>(is_weight_buffer);
173 int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
174 is_signed, storage_type.getWidth());
175 uint32_t flags =
176 is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;
177
178 // Rejects if quantized tensors have zero scales.
179 for (float scale : quant_params.scale) {
180 if (scale == 0) {
181 return errors::InvalidArgument(
182 "Quantized tensors must have non-zero scales");
183 }
184 }
185
186 // Scale size can't be zero as it is checked before.
187 if (quant_params.scale.size() != 1) {
188 llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),
189 quant_params.scale.end());
190 return mlir::quant::UniformQuantizedPerAxisType::get(
191 flags, storage_type, builder.getF32Type(), scales,
192 quant_params.zero_point, quant_params.quantized_dimension, storage_min,
193 storage_max);
194 }
195 return mlir::quant::UniformQuantizedType::get(
196 flags, storage_type, builder.getF32Type(), quant_params.scale.at(0),
197 quant_params.zero_point.at(0), storage_min, storage_max);
198 }
199
200 // Imports a float tensor with calibration values as a calibrated quantized type.
201 StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
202 Builder builder) {
203 if (tensor.quantization == nullptr) {
204 return errors::InvalidArgument("The tensor is not quantized.");
205 }
206 auto raw_elem_type = ConvertElementType(tensor.type, builder);
207 float min = tensor.quantization->min[0];
208 float max = tensor.quantization->max[0];
209 return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
210 }
211
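// Returns the MLIR tensor type for the given flatbuffer tensor: converts the
// element type (including variant, quantized, and calibrated cases) and picks
// the shape from `shape_signature`/`shape` when available, falling back to an
// unranked tensor type otherwise.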
212 StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
213 bool is_constant = false,
214 bool is_intermediate = false) {
215 mlir::Type elem_type = ConvertElementType(tensor.type, builder);
216 if (tensor.type == tflite::TensorType_VARIANT) {
217 llvm::SmallVector<mlir::TensorType> tensor_types;
218 if (tensor.variant_tensors.size() > 1) {
219 return errors::InvalidArgument(
220 "Have more than one nested type in `variant_tensors`.");
221 }
222 for (const auto& nested_tensor : tensor.variant_tensors) {
223 mlir::Type nested_elem_type =
224 ConvertElementType(nested_tensor->type, builder);
225 if (nested_tensor->has_rank) {
226 llvm::SmallVector<int64_t> shape(nested_tensor->shape.begin(),
227 nested_tensor->shape.end());
228 tensor_types.push_back(RankedTensorType::get(shape, nested_elem_type));
229 } else {
230 tensor_types.push_back(UnrankedTensorType::get(nested_elem_type));
231 }
232 }
233 elem_type = mlir::TF::VariantType::get(tensor_types, builder.getContext());
234 }
235 if (IsQuantized(tensor)) {
236 TF_ASSIGN_OR_RETURN(elem_type,
237 GetQuantizedType(tensor, builder, is_constant));
238 }
239
240 // Intermediate tensors with calibration values (but not scales and zero points)
241 // should return a calibrated quantized type.
242 if (is_intermediate && tensor.quantization != nullptr &&
243 !IsQuantized(tensor)) {
244 TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
245 }
246
247 if (tensor.shape.empty() && (is_constant || tensor.has_rank)) {
248 return RankedTensorType::get({}, elem_type);
249 }
250
251 if (!tensor.shape_signature.empty()) {
252 llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
253 tensor.shape_signature.end());
254 return RankedTensorType::get(shape, elem_type);
255 }
256
257 if (!tensor.shape.empty()) {
258 llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
259 tensor.shape.end());
260 return RankedTensorType::get(shape, elem_type);
261 }
262
263 return UnrankedTensorType::get(elem_type);
264 }
265
266 // Extract the min max information in the tensor and create the quant stats op.
267 // If the input `tensor` has scale/zero_point, `res` should have quantized
268 // type, thus no stats op is required and nullptr is returned.
269 // If the min max information is invalid, nullptr is returned.
270 mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
271 Value res) {
272 // If the `tensor` has scale/zero_point, it must have been quantized, and the
273 // min/max stats are only informational, so ignore them.
274 if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
275 // If the result isn't a float type, the min/max is ignored.
276 if (!res.getType()
277 .cast<mlir::ShapedType>()
278 .getElementType()
279 .isa<mlir::FloatType>()) {
280 return nullptr;
281 }
282 auto mins = tensor.quantization->min;
283 auto maxs = tensor.quantization->max;
284 if (mins.size() != maxs.size() || mins.empty()) return nullptr;
285
286 llvm::SmallVector<llvm::APFloat, 4> min_maxs;
287 min_maxs.reserve(mins.size() * 2);
288 for (int i = 0, end = mins.size(); i < end; ++i) {
289 llvm::APFloat min(mins[i]);
290 llvm::APFloat max(maxs[i]);
291 min_maxs.push_back(min);
292 min_maxs.push_back(max);
293 }
294 // The layer stats contain only the first min/max pair.
295 mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
296 mlir::RankedTensorType::get({2}, b.getF32Type()),
297 {min_maxs[0], min_maxs[1]});
298 mlir::ElementsAttr axis_stats;
299 mlir::IntegerAttr axis;
300 if (mins.size() > 1) {
301 llvm::SmallVector<int64_t, 4> axis_stats_shape{
302 static_cast<int64_t>(mins.size()), 2};
303 axis_stats = mlir::DenseFPElementsAttr::get(
304 mlir::RankedTensorType::get(axis_stats_shape, b.getF32Type()),
305 min_maxs);
306 // TODO(fengliuai): this quantization dimension isn't correct.
307 axis = b.getI64IntegerAttr(tensor.quantization->quantized_dimension);
308 }
309 return b.create<mlir::quantfork::StatisticsOp>(b.getUnknownLoc(), res,
310 layer_stats, axis_stats, axis);
311 }
312
313 // Returns true if this is a basic LSTM op.
314 bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
315 if (const auto* op = op_union.AsLSTMOptions()) {
316 return op->kernel_type == tflite::LSTMKernelType_BASIC;
317 } else {
318 return false;
319 }
320 }
321
322 // Gets the MLIR op name with the dialect name for the flatbuffer operator.
323 std::string GetMlirOpName(const tflite::OperatorT& op,
324 const tflite::OperatorCodeT& op_code) {
325 if (IsBasicLSTMOp(op.builtin_options)) {
326 return std::string("tfl.basic_lstm");
327 }
328 return mlir::GetMlirOpNameFromOpCode(op_code);
329 }
330
331 // The buffers in TFLite flatbuffers have their contents stored as a vector of
332 // bytes that represent little-endian values.
333 // The read_size parameter is present to allow reading both float16 and float32s
334 // without a case split.
335 template <typename T>
336 std::vector<T> ReadAsLittleEndian(ArrayRef<uint8_t> bytes) {
337 std::vector<T> ret;
338 size_t read_size = sizeof(T);
339 int bytes_len = bytes.size();
340 assert(bytes_len % read_size == 0);
341
342 int elem_count = bytes_len / read_size;
343 ret.reserve(elem_count);
344
345 const char* data_ptr = reinterpret_cast<const char*>(bytes.data());
346 for (int i = 0; i < elem_count; i++) {
347 ret.push_back(
348 llvm::support::endian::readNext<T, llvm::support::little,
349 llvm::support::unaligned>(data_ptr));
350 }
351 return ret;
352 }
353
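// Converts a TFLite constant tensor and its backing buffer into a TensorFlow
// TensorProto; strings are copied value by value, all other types as raw
// tensor content.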
354 tensorflow::TensorProto ConvertTfliteConstTensor(
355 const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer) {
356 tensorflow::TensorProto ret;
357 ret.set_dtype(TflTypeToTfType(tensor.type));
358
359 tensorflow::TensorShapeProto* shape = ret.mutable_tensor_shape();
360 shape->set_unknown_rank(false);
361 for (auto dim : tensor.shape) {
362 shape->add_dim()->set_size(int64_t{dim});
363 }
364 // TensorFlow Lite uses tflite::DynamicBuffer to encode a vector of strings.
365 if (tensor.type == tflite::TensorType_STRING) {
366 for (int i = 0; i < tflite::GetStringCount(buffer.data()); ++i) {
367 tflite::StringRef str = tflite::GetString(buffer.data(), i);
368 ret.add_string_val(str.str, str.len);
369 }
370 return ret;
371 }
372 std::string content;
373 content.assign(reinterpret_cast<const char*>(buffer.data()), buffer.size());
374 ret.set_tensor_content(content);
375 return ret;
376 }
377
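// Converts a little-endian buffer of 16-, 32-, or 64-bit floats into a
// DenseElementsAttr of the given shaped type.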
378 StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
379 mlir::RankedTensorType shaped_type, mlir::FloatType elem_type,
380 const std::vector<uint8_t>& buffer) {
381 size_t bytes_len = buffer.size();
382
383 // The bytes of floats are stored little-endian.
384 switch (elem_type.getWidth()) {
385 case 16: {
386 assert(bytes_len % 2 == 0);
387 int elem_count = bytes_len / 2;
388 std::vector<llvm::APFloat> values;
389 values.reserve(elem_count);
390
391 const char* data = reinterpret_cast<const char*>(buffer.data());
392 auto& semantics = elem_type.getFloatSemantics();
393
394 for (int i = 0; i < elem_count; i++) {
395 uint16_t bit_repr =
396 llvm::support::endian::readNext<uint16_t, llvm::support::little,
397 llvm::support::unaligned>(data);
398 llvm::APInt int_repr(16, bit_repr);
399 values.emplace_back(semantics, int_repr);
400 }
401
402 return mlir::ElementsAttr(DenseElementsAttr::get(shaped_type, values));
403 }
404 case 32: {
405 assert(bytes_len % 4 == 0);
406 int elem_count = bytes_len / 4;
407 std::vector<float> values;
408 values.reserve(elem_count);
409
410 const char* data = reinterpret_cast<const char*>(buffer.data());
411
412 for (int i = 0; i < elem_count; i++) {
413 uint32_t bit_repr =
414 llvm::support::endian::readNext<uint32_t, llvm::support::little,
415 llvm::support::unaligned>(data);
416 values.push_back(absl::bit_cast<float>(bit_repr));
417 }
418 return mlir::ElementsAttr(
419 DenseElementsAttr::get(shaped_type, ArrayRef<float>(values)));
420 }
421 case 64: {
422 assert(bytes_len % 8 == 0);
423 int elem_count = bytes_len / 8;
424 std::vector<double> values;
425 values.reserve(elem_count);
426
427 const char* data = reinterpret_cast<const char*>(buffer.data());
428
429 for (int i = 0; i < elem_count; i++) {
430 uint64_t bit_repr =
431 llvm::support::endian::readNext<uint64_t, llvm::support::little,
432 llvm::support::unaligned>(data);
433 values.push_back(absl::bit_cast<double>(bit_repr));
434 }
435 return mlir::ElementsAttr(
436 DenseElementsAttr::get(shaped_type, ArrayRef<double>(values)));
437 }
438 }
439 return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
440 }
441
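// Converts a buffer of integer (or quantized-storage) values into a
// DenseElementsAttr, reading multi-byte elements as little-endian.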
442 StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
443 mlir::RankedTensorType shaped_type, mlir::Type elem_type,
444 const std::vector<uint8_t>& buffer) {
445 unsigned bit_width;
446 if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
447 bit_width = itype.getWidth();
448 } else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
449 bit_width = qtype.getStorageTypeIntegralWidth();
450 shaped_type = mlir::RankedTensorType::get(shaped_type.getShape(),
451 qtype.getStorageType());
452 } else {
453 return errors::InvalidArgument("unsupported integer constant type");
454 }
455
456 switch (bit_width) {
457 case 1: {
458 // vector<bool> doesn't convert to an ArrayRef
459 llvm::SmallVector<bool, 8> values;
460 values.reserve(buffer.size());
461 for (auto b : buffer) {
462 values.emplace_back(b != 0);
463 }
464 return mlir::ElementsAttr(
465 DenseElementsAttr::get(shaped_type, ArrayRef<bool>(values)));
466 }
467 case 8: {
468 return mlir::ElementsAttr(
469 DenseElementsAttr::get(shaped_type, ArrayRef<uint8_t>(buffer)));
470 }
471 case 16: {
472 auto values = ReadAsLittleEndian<uint16_t>(buffer);
473 return mlir::ElementsAttr(
474 DenseElementsAttr::get(shaped_type, ArrayRef<uint16_t>(values)));
475 }
476 case 32: {
477 auto values = ReadAsLittleEndian<uint32_t>(buffer);
478 return mlir::ElementsAttr(
479 DenseElementsAttr::get(shaped_type, ArrayRef<uint32_t>(values)));
480 }
481 case 64: {
482 auto values = ReadAsLittleEndian<uint64_t>(buffer);
483 return mlir::ElementsAttr(
484 DenseElementsAttr::get(shaped_type, ArrayRef<uint64_t>(values)));
485 }
486 default:
487 return errors::Unimplemented("Cannot handle bit width ", bit_width);
488 }
489 }
490
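// Creates a tfl::ExternalConstOp that references the constant's buffer by
// index instead of embedding the data in the IR.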
491 StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
492 int32_t buffer_index,
493 OpBuilder builder, Location loc) {
494 TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
495 /*is_constant=*/true));
496 auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
497 if (!shaped_type) {
498 return errors::Internal("Constant doesn't have a shape");
499 }
500 auto op = builder.create<tfl::ExternalConstOp>(
501 loc, shaped_type, builder.getI32IntegerAttr(buffer_index));
502 return op.getOperation();
503 }
504
505 // Gets a constant splat for the given value of type. Requires the type to be a
506 // statically shaped RankedTensorType. `unique_index` is used to get the unique
507 // value for the attribute.
508 static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index,
509 OpBuilder builder) {
510 mlir::Type element_ty = getElementTypeOrSelf(type);
511
512 if (element_ty.isSignlessInteger())
513 return DenseElementsAttr::get(
514 type, builder.getIntegerAttr(element_ty, unique_index));
515
516 if (element_ty.isa<mlir::FloatType>())
517 return DenseElementsAttr::get(
518 type, builder.getFloatAttr(element_ty, unique_index));
519
520 if (auto qtype = element_ty.dyn_cast<QuantizedType>()) {
521 mlir::RankedTensorType new_type =
522 RankedTensorType::get(type.getShape(), qtype.getStorageType());
523 return DenseElementsAttr::get(
524 new_type, builder.getIntegerAttr(qtype.getStorageType(), unique_index));
525 }
526 llvm_unreachable("unhandled element type");
527 }
528
529 // TODO(b/172664358): Creates a new op instead of reusing constant op.
530 // Creates a constant op to represent a stateful variable. The function-static
531 // variable `stateful_variable_idx` is used as a unique value for each constant
532 // to avoid it being CSE'd. `tensor` is the flatbuffer data structure. `shaped_type`
533 // is the ShapedType for the const op.
534 Operation* BuildVariableOp(const tflite::TensorT& tensor,
535 mlir::RankedTensorType shaped_type,
536 OpBuilder builder, Location loc) {
537 static int stateful_variable_idx = 0;
538 mlir::ElementsAttr value =
539 GetSplat(shaped_type, stateful_variable_idx++, builder);
540 if (IsQuantized(tensor)) {
541 auto op = builder.create<tfl::QConstOp>(
542 loc, mlir::TypeAttr::get(shaped_type), value);
543 return op.getOperation();
544 }
545 auto op = builder.create<tfl::ConstOp>(loc, value);
546 if (tensor.quantization && !tensor.quantization->min.empty()) {
547 if (auto stats_op =
548 ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
549 return stats_op;
550 }
551 }
552 return op.getOperation();
553 }
554
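// Converts a sparse index vector (stored as int32, uint16, or uint8 values)
// into a vector of int32_t.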
555 static StatusOr<std::vector<int32_t>> ConvertSparseIndexVector(
556 const tflite::SparseIndexVectorUnion& sparse_index_vector) {
557 if (sparse_index_vector.type == tflite::SparseIndexVector_Int32Vector) {
558 return sparse_index_vector.AsInt32Vector()->values;
559 } else if (sparse_index_vector.type ==
560 tflite::SparseIndexVector_Uint16Vector) {
561 const auto& inputs = sparse_index_vector.AsUint16Vector()->values;
562 std::vector<int32_t> outputs(inputs.size());
563 std::transform(inputs.begin(), inputs.end(), outputs.begin(),
564 [](auto x) { return static_cast<int32_t>(x); });
565 return outputs;
566 } else if (sparse_index_vector.type ==
567 tflite::SparseIndexVector_Uint8Vector) {
568 const auto& inputs = sparse_index_vector.AsUint8Vector()->values;
569 std::vector<int32_t> outputs(inputs.size());
570 std::transform(inputs.begin(), inputs.end(), outputs.begin(),
571 [](auto x) { return static_cast<int32_t>(x); });
572 return outputs;
573 } else {
574 return errors::Unimplemented("Unsupported SparseIndexVector type");
575 }
576 }
577
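// Builds a sparse constant op (tfl::SparseConstOp or tfl::SparseQConstOp)
// carrying the compressed data and the per-dimension sparsity metadata.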
578 static StatusOr<Operation*> BuildSparseConstOp(
579 const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer,
580 const mlir::RankedTensorType shaped_type, OpBuilder& builder,
581 Location loc) {
582 tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
583 repr.clear_tensor_shape();
584 if (IsQuantized(tensor)) {
585 repr.mutable_tensor_shape()->add_dim()->set_size(buffer.size());
586 repr.set_dtype(tensorflow::DT_INT8);
587 } else {
588 repr.mutable_tensor_shape()->add_dim()->set_size(
589 buffer.size() / (shaped_type.getElementTypeBitWidth() / CHAR_BIT));
590 }
591 TF_ASSIGN_OR_RETURN(mlir::ElementsAttr compressed_data,
592 tensorflow::ConvertTensorProto(repr, &builder));
593
594 const int dim_metadata_size = tensor.sparsity->dim_metadata.size();
595 std::vector<mlir::TFL::DimensionMetadataAttr> dim_metadata(dim_metadata_size);
596 for (int i = 0; i < dim_metadata_size; i++) {
597 if (tensor.sparsity->dim_metadata[i]->format ==
598 tflite::DimensionType_DENSE) {
599 dim_metadata[i] = tfl::DimensionMetadataAttr::get(
600 builder.getContext(),
601 mlir::TFL::DimensionTypeAttr::get(builder.getContext(),
602 tfl::DimensionType::DENSE),
603 tensor.sparsity->dim_metadata[i]->dense_size, {}, {});
604 } else if (tensor.sparsity->dim_metadata[i]->format ==
605 tflite::DimensionType_SPARSE_CSR) {
606 TF_ASSIGN_OR_RETURN(
607 auto segments, ConvertSparseIndexVector(
608 tensor.sparsity->dim_metadata[i]->array_segments));
609 TF_ASSIGN_OR_RETURN(auto indices,
610 ConvertSparseIndexVector(
611 tensor.sparsity->dim_metadata[i]->array_indices));
612 dim_metadata[i] = tfl::DimensionMetadataAttr::get(
613 builder.getContext(),
614 mlir::TFL::DimensionTypeAttr::get(builder.getContext(),
615 tfl::DimensionType::SPARSE_CSR),
616 0, segments, indices);
617 } else {
618 return errors::Unimplemented("Unsupported dimension metadata type");
619 }
620 }
621 auto s_param = tfl::SparsityParameterAttr::get(
622 builder.getContext(), tensor.sparsity->traversal_order,
623 tensor.sparsity->block_map, dim_metadata);
624
625 auto value_type = shaped_type;
626 if (IsQuantized(tensor)) {
627 value_type = RankedTensorType::get(
628 shaped_type.getShape(), shaped_type.getElementType()
629 .dyn_cast<mlir::quant::QuantizedType>()
630 .getStorageType());
631 }
632 std::vector<char> dense_buffer(
633 value_type.getElementType().getIntOrFloatBitWidth() / CHAR_BIT);
634 mlir::Attribute dummy_value =
635 mlir::DenseIntOrFPElementsAttr::getFromRawBuffer(value_type,
636 dense_buffer);
637
638 if (IsQuantized(tensor)) {
639 return builder
640 .create<tfl::SparseQConstOp>(loc, mlir::TypeAttr::get(shaped_type),
641 dummy_value, s_param, compressed_data)
642 .getOperation();
643 }
644 return builder
645 .create<tfl::SparseConstOp>(loc, dummy_value, s_param, compressed_data)
646 .getOperation();
647 }
648
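// Builds the constant op for the given flatbuffer tensor and buffer,
// dispatching to the sparse, variable, quantized, string, or plain constant
// cases depending on the tensor's properties.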
649 StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
650 const std::vector<uint8_t>& buffer,
651 bool is_variable, OpBuilder builder,
652 Location loc) {
653 TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
654 /*is_constant=*/true));
655 auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
656 if (!shaped_type) {
657 return errors::Internal("Constant doesn't have a shape");
658 }
659
660 if (tensor.sparsity != nullptr) {
661 return BuildSparseConstOp(tensor, buffer, shaped_type, builder, loc);
662 }
663
664 auto elem_type = shaped_type.getElementType();
665
666 mlir::ElementsAttr value;
667 if (is_variable) {
668 return BuildVariableOp(tensor, shaped_type, builder, loc);
669 } else if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
670 TF_ASSIGN_OR_RETURN(value,
671 ConvertFloatBuffer(shaped_type, float_type, buffer));
672 } else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
673 TF_ASSIGN_OR_RETURN(value,
674 ConvertIntBuffer(shaped_type, elem_type, buffer));
675 } else if (elem_type.isa<mlir::TF::StringType>()) {
676 tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
677 std::vector<llvm::StringRef> refs;
678 refs.reserve(repr.string_val_size());
679
680 for (const auto& ref : repr.string_val())
681 refs.push_back({ref.data(), ref.size()});
682
683 value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
684 } else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
685 tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
686 std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
687
688 value = mlir::TF::TensorProtoAttr::get(shaped_type, mangled);
689 } else {
690 return errors::Unimplemented("Constant of unsupported type");
691 }
692
693 if (IsQuantized(tensor)) {
694 auto op = builder.create<tfl::QConstOp>(
695 loc, mlir::TypeAttr::get(shaped_type), value);
696 return op.getOperation();
697 }
698 auto op = builder.create<tfl::ConstOp>(loc, value);
699 return op.getOperation();
700 }
701
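// Converts the subgraph indices stored in CallOnce/If/While builtin options
// into function-reference attributes, validating that each referenced
// subgraph index exists.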
702 StatusOr<llvm::SmallVector<mlir::NamedAttribute, 4>>
703 ConvertSubgraphIdxsToFunctionAttrs(tflite::BuiltinOptionsUnion options,
704 const std::vector<std::string>& func_names,
705 Builder builder) {
706 if (auto* opts = options.AsCallOnceOptions()) {
707 uint32_t init_idx = opts->init_subgraph_index;
708 if (init_idx >= func_names.size()) {
709 return errors::InvalidArgument("subgraph with index not found: ",
710 init_idx);
711 }
712 auto init_attr = builder.getStringAttr(func_names.at(init_idx));
713
714 return llvm::SmallVector<mlir::NamedAttribute, 4>{
715 builder.getNamedAttr("session_init_function", init_attr)};
716 }
717 if (auto* opts = options.AsIfOptions()) {
718 uint32_t then_idx = opts->then_subgraph_index;
719 if (then_idx >= func_names.size()) {
720 return errors::InvalidArgument("subgraph with index not found: ",
721 then_idx);
722 }
723 auto then_attr =
724 mlir::SymbolRefAttr::get(builder.getContext(), func_names.at(then_idx));
725 uint32_t else_idx = opts->else_subgraph_index;
726 if (else_idx >= func_names.size()) {
727 return errors::InvalidArgument("subgraph with index not found: ",
728 else_idx);
729 }
730 auto else_attr =
731 mlir::SymbolRefAttr::get(builder.getContext(), func_names.at(else_idx));
732
733 return llvm::SmallVector<mlir::NamedAttribute, 4>{
734 builder.getNamedAttr("then_branch", then_attr),
735 builder.getNamedAttr("else_branch", else_attr),
736 // TODO(b/139667752): Analyze statelessness correctly
737 builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
738 }
739 if (auto* opts = options.AsWhileOptions()) {
740 uint32_t cond_idx = opts->cond_subgraph_index;
741 if (cond_idx >= func_names.size()) {
742 return errors::InvalidArgument("subgraph with index not found: ",
743 cond_idx);
744 }
745 auto cond_attr =
746 mlir::SymbolRefAttr::get(builder.getContext(), func_names.at(cond_idx));
747 uint32_t body_idx = opts->body_subgraph_index;
748 if (body_idx >= func_names.size()) {
749 return errors::InvalidArgument("subgraph with index not found: ",
750 body_idx);
751 }
752 auto body_attr =
753 mlir::SymbolRefAttr::get(builder.getContext(), func_names.at(body_idx));
754
755 return llvm::SmallVector<mlir::NamedAttribute, 4>{
756 builder.getNamedAttr("cond", cond_attr),
757 builder.getNamedAttr("body", body_attr)};
758 }
759 return llvm::SmallVector<mlir::NamedAttribute, 4>{};
760 }
761
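// Attaches the five LSTM intermediate tensor types to the op state as type
// attributes; returns an error if an unexpected number of intermediates is
// present.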
762 Status AddOpIntermediatesForLstm(
763 const tflite::OperatorT& op,
764 const std::vector<mlir::TensorType>& intermediate_types,
765 OperationState& op_state, Location loc, OpBuilder& builder) {
766 if (!op.intermediates.empty()) {
767 if (op.intermediates.size() != 5) {
768 auto err = errors::InvalidArgument(
769 "operator has intermediate tensors but the number of them is not "
770 "five.");
771 return emitError(loc, err.ToString()), err;
772 }
773 // Create intermediate value
774
775 const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
776 "input_to_input_intermediate", "input_to_forget_intermediate",
777 "input_to_cell_intermediate", "input_to_output_intermediate",
778 "effective_hidden_scale_intermediate"};
779 for (auto type_and_name :
780 llvm::zip(intermediate_types, kIntermediateNames)) {
781 mlir::TypeAttr type_attr =
782 mlir::TypeAttr::get(std::get<0>(type_and_name));
783 auto named_attr =
784 builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
785 op_state.addAttribute(named_attr.getName(), named_attr.getValue());
786 }
787 }
788 return ::tensorflow::OkStatus();
789 }
790
791 // TODO(krzysd) Handle function calls
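// Converts a single TFLite operator into an MLIR operation: resolves operand
// values and result types, handles special cases (quantize, reshape, LSTM,
// while), and attaches builtin/custom options as attributes.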
792 StatusOr<Operation*> ConvertOp(
793 const tflite::OperatorT& op, const std::vector<Value>& vals_map,
794 const std::vector<mlir::TensorType>& intermediate_types,
795 Value optional_arg_marker,
796 const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
797 const std::vector<std::string>& func_names,
798 const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
799 OpBuilder builder) {
800 llvm::SmallVector<Value, 4> operands;
801 llvm::SmallVector<mlir::Type, 2> outputTypes;
802
803 const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index);
804
805 const std::string op_name = GetMlirOpName(op, op_code);
806
807 OperationState op_state(loc, op_name);
808
809 for (auto input_num : op.inputs) {
810 if (input_num == -1) {
811 assert(optional_arg_marker != nullptr);
812 op_state.addOperands({optional_arg_marker});
813 } else {
814 op_state.addOperands({vals_map.at(input_num)});
815 }
816 }
817
818 for (auto output_num : op.outputs) {
819 auto& tensor = *tensors.at(output_num);
820 auto type_or_err = GetTensorType(tensor, builder);
821 if (!type_or_err.ok()) {
822 return emitError(loc, type_or_err.status().ToString()),
823 type_or_err.status();
824 }
825 auto type = std::move(type_or_err).value();
826
827 if (op_name == "tfl.quantize") {
828 // Special case for quantize: return type must also be in qtype attribute
829 op_state.addAttribute("qtype", mlir::TypeAttr::get(type));
830 } else if (op_name == "tfl.reshape" && op_state.operands.size() == 1) {
831 // Special case for reshape: the second op is optional in the old
832 // converter and kernel, so we create the second operand, which is
833 // required by the new converter, from the reshape op's option.
834 auto new_shape = op.builtin_options.AsReshapeOptions()->new_shape;
835 auto shape_type = RankedTensorType::get(
836 {static_cast<int64_t>(new_shape.size())}, builder.getIntegerType(32));
837
838 mlir::SmallVector<mlir::Attribute, 4> shape;
839 for (auto s : new_shape) {
840 shape.push_back(builder.getI32IntegerAttr(static_cast<int32_t>(s)));
841 }
842 auto output_shape = DenseElementsAttr::get(shape_type, shape);
843 auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
844 op_state.addOperands({shape_op});
845 }
846
847 op_state.addTypes({type});
848 }
849
850 // While the last several tensors could be optional tensors for a TFL op, the
851 // number of input operands could vary. Gets the min/max number of
852 // operands from the tflite op name.
853 // Also, since the above code special-handles the `tfl.reshape` op and adds an
854 // additional input, we put this logic here.
855 llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name);
856 int input_max_num = input_min_max.Max;
857 int op_input_num = op_state.operands.size();
858 if (input_max_num != 0 && input_max_num > op_input_num) {
859 // If the number of current inputs is less than the op definition, fill in
860 // with `none` values.
861 llvm::SmallVector<Value, 4> none_operands(
862 input_max_num - op_input_num,
863 builder.create<mlir::TFL::NoValueOp>(loc, builder.getNoneType(),
864 builder.getUnitAttr()));
865 op_state.addOperands(ArrayRef<Value>(none_operands));
866 }
867
868 if (op_name == "tfl.lstm") {
869 // TODO(b/147587779): add the right region if region is empty.
870 op_state.addRegion();
871 TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
872 builder));
873 }
874 if (op_name == "tfl.while") {
875 // Adds two empty regions for "tfl.while". We will fill the regions after
876 // creating the callee functions because the "tfl.while" input/output types
877 // may differ from the callee functions' types, and the call ops need to sync
878 // with the callee function types.
879 op_state.addRegion();
880 op_state.addRegion();
881 }
882 if (op_name == "tfl.unidirectional_sequence_lstm") {
883 TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
884 builder));
885 }
886 if (op_name == "tfl.reshape") {
887 // Flatten the shape operand of reshape ops when it has more than one dimension.
888 mlir::DenseIntElementsAttr shape_attr;
889 if (matchPattern(op_state.operands[1], m_Constant(&shape_attr))) {
890 auto shape_ty =
891 op_state.operands[1].getType().dyn_cast<RankedTensorType>();
892 if (shape_ty != nullptr && shape_ty.hasRank() && shape_ty.getRank() > 1) {
893 llvm::SmallVector<mlir::Attribute, 4> shape;
894 int32_t dim_size = 0;
895 for (const auto& dim :
896 llvm::enumerate(shape_attr.getValues<llvm::APInt>())) {
897 const int64_t size = dim.value().getSExtValue();
898 shape.push_back(
899 builder.getI32IntegerAttr(static_cast<int32_t>(size)));
900 ++dim_size;
901 }
902 auto shape_type = RankedTensorType::get(
903 {static_cast<int32_t>(dim_size)}, builder.getIntegerType(32));
904 auto output_shape = mlir::DenseElementsAttr::get(shape_type, shape);
905 auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
906 op_state.operands[1] = shape_op;
907 }
908 }
909 }
910
911 llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
912 auto builtin_code = tflite::GetBuiltinCode(&op_code);
913 if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
914 auto status = mlir::CustomOptionsToAttributes(
915 op_code.custom_code, op.custom_options, builder, loc, &attrs);
916 if (!status.ok()) {
917 return emitError(loc, status.ToString()), status;
918 }
919 } else {
920 mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
921 }
922 op_state.addAttributes(attrs);
923
924 // Handle the conversion from subgraph index to functions for If and While. We
925 // will add CallOps in the region to call the functions later for While.
926 TF_ASSIGN_OR_RETURN(auto function_ref_attrs,
927 ConvertSubgraphIdxsToFunctionAttrs(op.builtin_options,
928 func_names, builder));
929 op_state.addAttributes(function_ref_attrs);
930
931 return builder.create(op_state);
932 }
933
934 // Returns indices of the given tensors in the subgraph. Returns error if a
935 // tensor name cannot be found in the subgraph.
936 StatusOr<std::vector<int>> GetTensorIndices(
937 const tflite::SubGraphT& subgraph,
938 const std::vector<std::string>& tensor_names) {
939 absl::flat_hash_map<std::string, int> name_to_index;
940 for (const auto& index_and_tensor : llvm::enumerate(subgraph.tensors)) {
941 name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
942 }
943
944 std::vector<int> indices;
945 indices.reserve(tensor_names.size());
946
947 for (const auto& name : tensor_names) {
948 auto found = name_to_index.find(name);
949 if (found != name_to_index.end()) {
950 indices.push_back(found->second);
951 } else {
952 return errors::InvalidArgument("could not find tensor in subgraph: ",
953 name);
954 }
955 }
956
957 return indices;
958 }
959
960 // Given a list of tensor indices, returns a string of concatenated tensor names
961 // wrapped in a NamedAttribute.
962 template <typename ContainerType>
963 mlir::NamedAttribute BuildTFEntryFunctionAttribute(
964 const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
965 const ContainerType indices) {
966 auto tensor_names = llvm::map_range(
967 indices, [&](int i) { return subgraph.tensors.at(i)->name; });
968 return builder->getNamedAttr(
969 name, builder->getStringAttr(llvm::join(tensor_names, ",")));
970 }
971
972 // Traverses the subgraph from output_indices to input_indices and returns the
973 // set of ops that are visited.
974 StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
975 const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
976 ArrayRef<int32_t> output_indices) {
977 // Create a map from tensor index to defining op.
978 absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
979 for (const auto& op : subgraph.operators) {
980 for (int32_t output : op->outputs) {
981 if (!llvm::is_contained(input_indices, output)) {
982 defining_op[output] = op.get();
983 }
984 }
985 }
986
987 std::vector<const tflite::OperatorT*> queue;
988 for (int32_t output : output_indices) {
989 if (auto& op = defining_op[output]) {
990 queue.push_back(op);
991 }
992 }
993
994 // Traverse the graph towards inputs.
995 absl::flat_hash_set<const tflite::OperatorT*> visited;
996 while (!queue.empty()) {
997 const tflite::OperatorT* op = queue.back();
998 queue.pop_back();
999 if (!visited.insert(op).second) {
1000 // The node has already been visited.
1001 continue;
1002 }
1003
1004 for (int32_t input : op->inputs) {
1005 // Input tensor may not have a defining op in case it is a subgraph input
1006 // or a constant tensor.
1007 if (auto& op = defining_op[input]) {
1008 queue.push_back(op);
1009 }
1010 }
1011 }
1012
1013 return visited;
1014 }
1015
1016 // Adjusts the func op according to some cross-op information.
1017 static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
1018 OpBuilder builder(func);
1019 // When a quantized constant is imported, its quantization parameter is set
1020 // to be narrow range. Here we revert it to the full range if the user doesn't
1021 // require narrow range.
1022 func.walk([&](tfl::QConstOp cst) {
1023 Value value = cst.getResult();
1024 Value full_range_const = value;
1025 auto qtype = mlir::quant::UniformQuantizedType::getQuantizedElementType(
1026 value.getType());
1027 // Only the 8-bit constants are imported with narrow range.
1028 if (!qtype || qtype.getStorageTypeIntegralWidth() != 8 ||
1029 !(qtype.isa<mlir::quant::UniformQuantizedType>() ||
1030 qtype.isa<mlir::quant::UniformQuantizedPerAxisType>())) {
1031 return;
1032 }
1033 for (auto& use : value.getUses()) {
1034 Operation* user = use.getOwner();
1035 if (user->hasTrait<mlir::OpTrait::IsTerminator>()) continue;
1036
1037 auto affine_user = llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
1038 if (affine_user &&
1039 affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
1040 affine_user.RequiredNarrowRangeAffineOperand())
1041 continue;
1042 // Create a full-range quantized constant.
1043 if (full_range_const == value) {
1044 mlir::quant::QuantizedType new_qtype;
1045 if (auto per_axis =
1046 qtype.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
1047 new_qtype = mlir::quant::UniformQuantizedPerAxisType::get(
1048 per_axis.getFlags(), per_axis.getStorageType(),
1049 per_axis.getExpressedType(), per_axis.getScales(),
1050 per_axis.getZeroPoints(), per_axis.getQuantizedDimension(),
1051 per_axis.getStorageTypeMin() - 1, per_axis.getStorageTypeMax());
1052 } else if (auto per_tensor =
1053 qtype.dyn_cast<mlir::quant::UniformQuantizedType>()) {
1054 new_qtype = mlir::quant::UniformQuantizedType::get(
1055 per_tensor.getFlags(), per_tensor.getStorageType(),
1056 per_tensor.getExpressedType(), per_tensor.getScale(),
1057 per_tensor.getZeroPoint(), per_tensor.getStorageTypeMin() - 1,
1058 per_tensor.getStorageTypeMax());
1059 } else {
1060 return; // Should not reach here, as it's already checked.
1061 }
1062 auto new_output_type = new_qtype.castFromExpressedType(
1063 mlir::quant::UniformQuantizedType::castToExpressedType(
1064 value.getType()));
1065 builder.setInsertionPointAfter(cst.getOperation());
1066 auto new_op = builder.create<tfl::QConstOp>(
1067 cst.getLoc(), new_output_type, mlir::TypeAttr::get(new_output_type),
1068 cst.valueAttr());
1069 full_range_const = new_op.output();
1070 }
1071 use.set(full_range_const);
1072 }
1073 if (cst.use_empty()) cst.erase();
1074 });
1075 return func;
1076 }
1077
1078 // Helper method that returns the index of the tensor with name 'tensor_name'
1079 // in the list of tensor names 'tensors'. It allows excluding some indices.
1080 int GetTensorIndex(const std::string& tensor_name,
1081 llvm::SmallVector<llvm::StringRef, 2> tensors,
1082 const std::set<int>& exclude_indices = {}) {
1083 for (const auto& tensor_index_pair : llvm::enumerate(tensors)) {
1084 if (tensor_index_pair.value() == tensor_name &&
1085 exclude_indices.find(tensor_index_pair.index()) ==
1086 exclude_indices.end())
1087 return tensor_index_pair.index();
1088 }
1089 return -1;
1090 }
1091
1092 // Helper method that returns the list of all strings in a StringAttr identified
1093 // by 'attr_key'; the values are separated by commas.
1094 llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
1095 mlir::DictionaryAttr attr, const std::string& attr_key) {
1096 llvm::SmallVector<llvm::StringRef, 2> result;
1097 if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
1098 str.getValue().split(result, ',', /*MaxSplit=*/-1,
1099 /*KeepEmpty=*/false);
1100 }
1101 return result;
1102 }
1103
1104 // Sets signature attributes on the function.
1105 void SetSignature(
1106 FuncOp func, const tflite::SignatureDefT* signature,
1107 const std::vector<std::unique_ptr<tflite::TensorT>>& tensors) {
1108 auto* context = func->getContext();
1109 static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
1110 static const char kExportedNameAttr[] = "tf_saved_model.exported_names";
1111 static const char kEntryFunctionAttributes[] = "tf.entry_function";
1112
1113 auto dict_attr =
1114 func->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
1115 if (!dict_attr) return;
1116
1117 // Get Input and output tensor names from attribute.
1118 llvm::SmallVector<llvm::StringRef, 2> input_names =
1119 GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
1120 llvm::SmallVector<llvm::StringRef, 2> output_names =
1121 GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
1122
1123 for (const auto& input_pair : llvm::enumerate(signature->inputs)) {
1124 const int arg_index = GetTensorIndex(
1125 tensors[input_pair.value()->tensor_index]->name, input_names);
1126 if (arg_index == -1) {
1127 func->emitWarning("Invalid signature tensors specified.");
1128 return;
1129 }
1130 func.setArgAttr(
1131 arg_index, kSignatureDefIndexPath,
1132 mlir::ArrayAttr::get(context, {mlir::StringAttr::get(
1133 context, input_pair.value()->name)}));
1134 }
1135 // Multiple signature outputs can refer to the same tensor. Avoid setting
1136 // signature output attribute at the same index by maintaining a set.
1137 std::set<int> seen_indices;
1138 for (const auto& output_pair : llvm::enumerate(signature->outputs)) {
1139 const int arg_index =
1140 GetTensorIndex(tensors[output_pair.value()->tensor_index]->name,
1141 output_names, seen_indices);
1142 if (arg_index == -1) {
1143 func->emitWarning("Invalid signature tensors specified.");
1144 return;
1145 }
1146 func.setResultAttr(arg_index, kSignatureDefIndexPath,
1147 mlir::ArrayAttr::get(
1148 context, {mlir::StringAttr::get(
1149 context, output_pair.value()->name)}));
1150 seen_indices.insert(arg_index);
1151 }
1152 func->setAttr(
1153 kExportedNameAttr,
1154 mlir::ArrayAttr::get(
1155 context, {mlir::StringAttr::get(context, signature->signature_key)}));
1156 }
1157
1158 // Builds a FuncOp from a tflite SubGraph.
1159 // The buffers are directly taken
1160 // from the deserialized flatbuffer as we do not have the type information to
1161 // interpret them until this point. The base_loc parameter is the location of
1162 // the flatbuffer as a whole (usually a file). If ordered_output_arrays is not
1163 // empty, then the imported mlir function will only return nodes in
1164 // ordered_output_arrays in the same order. If signature is not null, then the
1165 // inputs/outputs in signature will be attached to the FuncOp.
1166 StatusOr<FuncOp> ConvertSubgraph(
1167 const tflite::SubGraphT& subgraph, llvm::StringRef name,
1168 const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
1169 const std::vector<std::string>& func_names,
1170 const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
1171 Location base_loc, Builder builder, bool is_entry_point,
1172 bool use_external_constant,
1173 const std::vector<std::string>& ordered_input_arrays,
1174 const std::vector<std::string>& ordered_output_arrays,
1175 bool experimental_prune_unreachable_nodes_unconditionally,
1176 const tflite::SignatureDefT* signature) {
1177 llvm::SmallVector<mlir::Type, 2> ret_types;
1178 llvm::SmallVector<mlir::Type, 4> input_types;
1179
1180 auto func_loc = mlir::NameLoc::get(builder.getStringAttr(name), base_loc);
1181
1182 std::vector<int> func_inputs = subgraph.inputs;
1183 if (is_entry_point && !ordered_input_arrays.empty()) {
1184 if (!experimental_prune_unreachable_nodes_unconditionally) {
1185 // TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
1186 return errors::InvalidArgument(
1187 "input-arrays should be used with experimental pruning flag");
1188 }
1189 TF_ASSIGN_OR_RETURN(func_inputs,
1190 GetTensorIndices(subgraph, ordered_input_arrays));
1191 }
1192
1193 for (int input : func_inputs) {
1194 auto& tensor = *subgraph.tensors.at(input);
1195 auto type_or_err = GetTensorType(tensor, builder);
1196 if (!type_or_err.ok()) {
1197 emitError(func_loc, "error reading argument types")
1198 << type_or_err.status().ToString();
1199 return type_or_err.status();
1200 }
1201 auto type = std::move(type_or_err).value();
1202 input_types.push_back(type);
1203 }
1204
1205 llvm::SmallVector<bool, 16> is_op_output(subgraph.tensors.size(), false);
1206 for (auto& op : subgraph.operators) {
1207 for (auto output : op->outputs) {
1208 is_op_output[output] = true;
1209 }
1210 }
1211
1212 std::vector<int> func_outputs = subgraph.outputs;
1213 if (is_entry_point && !ordered_output_arrays.empty()) {
1214 TF_ASSIGN_OR_RETURN(func_outputs,
1215 GetTensorIndices(subgraph, ordered_output_arrays));
1216 }
1217
1218 for (auto output : func_outputs) {
1219 const bool is_func_input = std::find(func_inputs.begin(), func_inputs.end(),
1220 output) != func_inputs.end();
1221 bool is_constant = !is_op_output[output] && !is_func_input;
1222
1223 auto type_or_err =
1224 GetTensorType(*subgraph.tensors.at(output), builder, is_constant);
1225 if (!type_or_err.ok()) {
1226 emitError(func_loc, "error reading return types")
1227 << type_or_err.status().ToString();
1228 return type_or_err.status();
1229 }
1230 auto type = std::move(type_or_err).value();
1231 ret_types.push_back(type);
1232 }
1233 auto func_type = builder.getFunctionType(input_types, ret_types);
1234
1235 // Construct function object
1236 auto func = FuncOp::create(func_loc, name, func_type, /* attrs= */ {});
1237 func.addEntryBlock();
1238 auto& body = func.getBody();
1239 OpBuilder op_builder{body};
1240
1241 std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
1242 Value maybe_optional_arg_marker = nullptr;
1243
1244 // Get or construct MLIR values for each input
1245 for (int i = 0, e = func_inputs.size(); i < e; i++) {
1246 auto input_tensor = func_inputs[i];
1247 const auto& tensor = *subgraph.tensors.at(input_tensor);
1248 auto loc = TensorLoc(tensor, builder, base_loc);
1249 if (vals_map[input_tensor]) {
1250 auto err = errors::FailedPrecondition("duplicate input arguments");
1251 return emitError(loc, err.ToString()), err;
1252 }
1253 Value input_value = func.getArgument(i);
1254
1255 // If the `tensor` has min/max and doesn't have scale/zero_point
1256 // information, a stats op is created to use the input_value, then the
1257 // `tensor` should be mapped to the result of this new stats op.
1258 if (auto stats_op =
1259 ConvertMinMaxToStatsOp(tensor, op_builder, input_value)) {
1260 vals_map[input_tensor] = stats_op->getResult(0);
1261 } else {
1262 vals_map[input_tensor] = input_value;
1263 }
1264 }
1265
1266 // Set tf.entry_function attribute
1267 if (is_entry_point) {
1268 llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
1269 if (!func_inputs.empty()) {
1270 attributes.push_back(BuildTFEntryFunctionAttribute(
1271 subgraph, &builder, "inputs", func_inputs));
1272 }
1273 if (!func_outputs.empty()) {
1274 attributes.push_back(BuildTFEntryFunctionAttribute(
1275 subgraph, &builder, "outputs", func_outputs));
1276 }
1277 func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
1278 } else {
1279 func.setPrivate();
1280 }
1281
1282 // Set signature on function.
1283 if (signature) {
1284 SetSignature(func, signature, subgraph.tensors);
1285 }
1286
1287 absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
1288 if (experimental_prune_unreachable_nodes_unconditionally) {
1289 TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
1290 PruneSubgraph(subgraph, func_inputs, func_outputs));
1291 }
1292
1293 // Construct MLIR operators from TFLite operators
1294 for (auto& op : subgraph.operators) {
1295 if (experimental_prune_unreachable_nodes_unconditionally &&
1296 !pruned_subgraph_ops.contains(op)) {
1297 continue;
1298 }
1299
1300 for (auto input_num : op->inputs) {
1301 // The operators in a graph are topologically sorted
1302 // and so if no previous operation has produced a tensor
1303 // it must be a constant.
1304 if (input_num == -1) {
1305 if (maybe_optional_arg_marker == nullptr) {
1306 maybe_optional_arg_marker =
1307 op_builder
1308 .create<mlir::TFL::NoValueOp>(base_loc, builder.getNoneType(),
1309 builder.getUnitAttr())
1310 .getResult();
1311 }
1312 } else if (!vals_map.at(input_num)) {
1313 auto& const_tensor = *subgraph.tensors[input_num];
1314 auto const_loc = TensorLoc(const_tensor, builder, base_loc);
1315 auto op_or_err =
1316 use_external_constant
1317 ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
1318 op_builder, const_loc)
1319 : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
1320 const_tensor.is_variable, op_builder, const_loc);
1321 if (!op_or_err.ok()) {
1322 return emitError(const_loc, op_or_err.status().ToString()),
1323 op_or_err.status();
1324 }
1325 vals_map[input_num] = op_or_err.ValueOrDie()->getResult(0);
1326 }
1327 }
1328
1329 // Intermediate tensors for LSTMs are used to carry quantization range
1330 // in their types, so we only need to extract their types.
1331 std::vector<mlir::TensorType> intermediate_types;
1332 intermediate_types.reserve(5);
1333 for (auto intermediate : op->intermediates) {
1334 TF_ASSIGN_OR_RETURN(
1335 auto type,
1336 GetTensorType(*subgraph.tensors[intermediate], builder,
1337 /*is_constant=*/false, /*is_intermediate=*/true));
1338 intermediate_types.emplace_back(type);
1339 }
1340
1341 auto op_loc = OpLoc(*op, subgraph.tensors, builder, base_loc);
1342
1343 // If there's an optional argument, maybe_optional_arg_marker has been set
1344 // to a valid Value
1345 TF_ASSIGN_OR_RETURN(
1346 auto* mlir_op,
1347 ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
1348 op_codes, func_names, subgraph.tensors, op_loc, op_builder));
1349
1350 // Add the results to the value maps. There are two cases: 1. the result
1351 // tensor has no min/max values, so the original op result is used
1352 // directly; 2. the result tensor has min/max values, so a stats op is
1353 // created and its result is used instead.
1354 for (const auto& pair : llvm::enumerate(mlir_op->getResults())) {
1355 int output_tensor_index = op->outputs[pair.index()];
1356 auto& tensor = *subgraph.tensors[output_tensor_index];
1357 if (auto stats_op =
1358 ConvertMinMaxToStatsOp(tensor, op_builder, pair.value())) {
1359 vals_map[output_tensor_index] = stats_op->getResult(0);
1360 } else {
1361 vals_map[output_tensor_index] = pair.value();
1362 }
1363 }
1364 }
1365
1366 // Construct return values
1367 llvm::SmallVector<Value, 4> return_operands;
1368 for (auto index : func_outputs) {
1369 if (!vals_map.at(index)) {
1370 auto& const_tensor = *subgraph.tensors[index];
1371 auto const_loc = TensorLoc(const_tensor, builder, base_loc);
1372 auto op_or_err =
1373 use_external_constant
1374 ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
1375 op_builder, const_loc)
1376 : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
1377 const_tensor.is_variable, op_builder, const_loc);
1378 if (!op_or_err.ok()) {
1379 return emitError(const_loc, op_or_err.status().ToString()),
1380 op_or_err.status();
1381 }
1382 vals_map[index] = op_or_err.ValueOrDie()->getResult(0);
1383 }
1384 return_operands.push_back(vals_map[index]);
1385 }
1386
1387 op_builder.create<mlir::func::ReturnOp>(base_loc, return_operands);
1388
1389 return PostProcessFuncOp(func);
1390 }
1391
1392 // TFLite subgraphs do not necessarily have names, but MLIR functions must
1393 // have them, so we generate a name here for any subgraph that is missing one.
1394 // Note: in TFLite the first subgraph is the entry point, and in the MLIR that
1395 // represents a TFLite model this entry point must be called "main".
1396 std::string SubgraphName(bool set_implicit_main_func, unsigned index,
1397 const tflite::SubGraphT& subgraph) {
1398 if (index == 0 && set_implicit_main_func) {
1399 return "main";
1400 }
1401 if (subgraph.name.empty()) {
1402 return llvm::formatv("fn_{0}", index).str();
1403 }
1404 return subgraph.name;
1405 }
1406
1407 // Adds a CallOp in `region` that calls `func` and yields the CallOp's
1408 // results through a TFL::YieldOp.
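// Schematically (illustrative only), for a callee @cond the region becomes:
//   ^bb0(%arg0, ..., %argN):
//     %results = func.call @cond(%arg0, ..., %argN)
//     tfl.yield %results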
1409 void AddCallOpInWhileOpRegion(mlir::Region& region, mlir::func::FuncOp func) {
1410 OpBuilder op_builder{region};
1411 region.push_back(new mlir::Block());
1412 Location loc = region.getLoc();
1413 auto inputs = func.getFunctionType().getInputs();
1414 region.addArguments(inputs, mlir::SmallVector<Location>(inputs.size(), loc));
1415 op_builder.setInsertionPointToStart(&region.front());
1416 auto call_op = op_builder.create<mlir::func::CallOp>(
1417 loc, func.getFunctionType().getResults(), func.getSymName(),
1418 region.getArguments());
1419 op_builder.create<mlir::TFL::YieldOp>(loc, call_op.getResults());
1420 }
1421
1422 // TFL::WhileOp carries its condition and body as regions, so for every while
1423 // op we add a CallOp in each region that calls the corresponding FuncOp.
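// For example (illustrative), a tfl.while imported with the attributes
// {cond = @cond, body = @body} has those attributes removed and gains two
// regions whose blocks simply call @cond and @body.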
1424 void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
1425 mlir::SymbolTable symbol_table(module);
1426 module.walk([&](mlir::TFL::WhileOp while_op) {
1427 auto cond = symbol_table.lookup<mlir::func::FuncOp>(
1428 while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
1429 AddCallOpInWhileOpRegion(while_op.cond(), cond);
1430 while_op->removeAttr("cond");
1431 auto body = symbol_table.lookup<mlir::func::FuncOp>(
1432 while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
1433 AddCallOpInWhileOpRegion(while_op.body(), body);
1434 while_op->removeAttr("body");
1435 });
1436 }
1437 } // namespace
1438
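// Example usage (a minimal, hypothetical sketch; `serialized` holds the raw
// bytes of a .tflite file):
//   mlir::MLIRContext context;
//   mlir::Location loc = mlir::UnknownLoc::get(&context);
//   mlir::OwningOpRef<mlir::ModuleOp> module = tflite::FlatBufferToMlir(
//       serialized, &context, loc, /*use_external_constant=*/false,
//       /*ordered_input_arrays=*/{}, /*ordered_output_arrays=*/{},
//       /*experimental_prune_unreachable_nodes_unconditionally=*/false);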
1439 OwningOpRef<mlir::ModuleOp> tflite::FlatBufferToMlir(
1440 absl::string_view buffer, MLIRContext* context, Location base_loc,
1441 bool use_external_constant,
1442 const std::vector<std::string>& ordered_input_arrays,
1443 const std::vector<std::string>& ordered_output_arrays,
1444 bool experimental_prune_unreachable_nodes_unconditionally) {
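// Load every dialect that imported operations may belong to before any ops
// are created.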
1445 context->loadDialect<mlir::arith::ArithmeticDialect, mlir::func::FuncDialect,
1446 mlir::quant::QuantizationDialect,
1447 mlir::quantfork::QuantizationForkDialect,
1448 mlir::TFL::TensorFlowLiteDialect,
1449 mlir::TF::TensorFlowDialect>();
1450
1451 auto model_ptr =
1452 FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
1453 if (nullptr == model_ptr) {
1454 return emitError(base_loc, "couldn't parse flatbuffer"), nullptr;
1455 }
1456
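// Unpack the verified flatbuffer into the mutable object API (ModelT) so the
// model can be traversed with plain C++ containers.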
1457 std::unique_ptr<ModelT> model(model_ptr->GetModel()->UnPack());
1458
1459 auto builder = Builder(context);
1460
1461 std::vector<std::string> func_names;
1462 for (auto& subgraph : model->subgraphs) {
1463 func_names.push_back(subgraph->name);
1464 }
1465
1466 auto module = mlir::ModuleOp::create(base_loc);
1467 // We currently don't use this to make any decisions, but it could be used
1468 // during export or if there are breaking schema changes.
1469 module->setAttr("tfl.schema_version",
1470 builder.getI32IntegerAttr(model->version));
1471 if (!model->description.empty()) {
1472 module->setAttr("tfl.description",
1473 builder.getStringAttr(model->description));
1474 }
1475
1476 if (!model->signature_defs.empty()) {
1477 module->setAttr("tf_saved_model.semantics",
1478 mlir::UnitAttr::get(builder.getContext()));
1479 }
1480
1481 absl::flat_hash_map<uint32_t, tflite::SignatureDefT*>
1482 subgraph_to_signature_map;
1483 for (int i = 0; i < model->signature_defs.size(); i++) {
1484 auto* signature_def = model->signature_defs[i].get();
1485 const uint32_t subgraph_index = signature_def->subgraph_index;
1486 subgraph_to_signature_map[subgraph_index] = signature_def;
1487 }
1488
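// With at most one SignatureDef, fall back to the TFLite convention that the
// first subgraph is the entry point and name it "main".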
1489 const bool set_implicit_main_func = subgraph_to_signature_map.size() <= 1;
1490 for (const auto& e : llvm::enumerate(model->subgraphs)) {
1491 auto& subgraph = e.value();
1492 std::string name =
1493 SubgraphName(set_implicit_main_func, e.index(), *subgraph);
1494 uint32_t subgraph_index = static_cast<uint32_t>(e.index());
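// A subgraph is treated as an entry point either because it is subgraph 0
// (implicit main) or because a SignatureDef references it.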
1495 auto func_or_error = ConvertSubgraph(
1496 *subgraph, name, model->operator_codes, func_names, model->buffers,
1497 base_loc, builder,
1498 /*is_entry_point=*/
1499 set_implicit_main_func
1500 ? e.index() == 0
1501 : subgraph_to_signature_map.contains(subgraph_index),
1502 /*use_external_constant=*/use_external_constant, ordered_input_arrays,
1503 ordered_output_arrays,
1504 experimental_prune_unreachable_nodes_unconditionally,
1505 subgraph_to_signature_map.contains(subgraph_index)
1506 ? subgraph_to_signature_map.at(subgraph_index)
1507 : nullptr);
1508 if (!func_or_error.ok()) {
1509 return emitError(base_loc, "could not translate function ")
1510 << subgraph->name << ": "
1511 << func_or_error.status().error_message(),
1512 nullptr;
1513 }
1514 module.push_back(std::move(func_or_error).value());
1515 }
1516 AddRegionsForTflWhileOp(module);
1517 return OwningOpRef<mlir::ModuleOp>(module);
1518 }
1519