xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/flatbuffer_import.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
17 
18 #include <algorithm>
19 #include <cctype>
20 #include <climits>
21 #include <cstdint>
22 #include <iostream>
23 #include <sstream>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/base/casts.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/strings/string_view.h"
32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/DenseMap.h"
36 #include "llvm/ADT/None.h"
37 #include "llvm/ADT/Optional.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/StringExtras.h"
41 #include "llvm/ADT/StringRef.h"
42 #include "llvm/Support/Casting.h"
43 #include "llvm/Support/CommandLine.h"
44 #include "llvm/Support/Endian.h"
45 #include "llvm/Support/FormatVariadic.h"
46 #include "llvm/Support/MemoryBuffer.h"
47 #include "llvm/Support/SourceMgr.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
50 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
51 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
52 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
53 #include "mlir/IR/Attributes.h"  // from @llvm-project
54 #include "mlir/IR/Builders.h"  // from @llvm-project
55 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
56 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
57 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
58 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
59 #include "mlir/IR/Location.h"  // from @llvm-project
60 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
61 #include "mlir/IR/Operation.h"  // from @llvm-project
62 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
63 #include "mlir/IR/Types.h"  // from @llvm-project
64 #include "mlir/IR/Value.h"  // from @llvm-project
65 #include "mlir/Support/LLVM.h"  // from @llvm-project
66 #include "mlir/Tools/mlir-translate/Translation.h"  // from @llvm-project
67 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
68 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
69 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
70 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
71 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
72 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
73 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
74 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
75 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
76 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
77 #include "tensorflow/compiler/xla/statusor.h"
78 #include "tensorflow/core/framework/tensor.pb.h"
79 #include "tensorflow/core/framework/tensor_shape.pb.h"
80 #include "tensorflow/core/platform/errors.h"
81 #include "tensorflow/core/platform/status.h"
82 #include "tensorflow/lite/model.h"
83 #include "tensorflow/lite/schema/schema_generated.h"
84 #include "tensorflow/lite/schema/schema_utils.h"
85 #include "tensorflow/lite/string_util.h"
86 
87 using llvm::ArrayRef;
88 using mlir::Builder;
89 using mlir::DenseElementsAttr;
90 using mlir::Location;
91 using mlir::MLIRContext;
92 using mlir::OpBuilder;
93 using mlir::Operation;
94 using mlir::OperationState;
95 using mlir::OwningOpRef;
96 using mlir::RankedTensorType;
97 using mlir::UnrankedTensorType;
98 using mlir::Value;
99 using mlir::func::FuncOp;
100 using mlir::quant::QuantizedType;
101 using tflite::OperatorT;
102 using tflite::TensorT;
103 using xla::Status;
104 using xla::StatusOr;
105 
106 namespace errors = tensorflow::errors;
107 namespace tfl = mlir::TFL;
108 
109 namespace {
110 
IsQuantized(const TensorT & tensor)111 bool IsQuantized(const TensorT& tensor) {
112   return (tensor.quantization != nullptr) &&
113          !tensor.quantization->zero_point.empty();
114 }
115 
116 // Create the MLIR NamedLoc location corresponding to a given tensor
TensorLoc(const TensorT & tensor,Builder builder,Location base)117 Location TensorLoc(const TensorT& tensor, Builder builder, Location base) {
118   if (tensor.name.empty()) {
119     return base;
120   }
121   return mlir::NameLoc::get(builder.getStringAttr(tensor.name), base);
122 }
123 
124 // Create the MLIR Location corresponding to a given op. This is an
125 // experimental/debugging feature and production code should not rely on names
126 // of intermediate tensors since importer doesn't guarantee to preserve tensor
127 // names except output tensors.
OpLoc(const OperatorT & op,const std::vector<std::unique_ptr<tflite::TensorT>> & tensors,Builder builder,Location base)128 Location OpLoc(const OperatorT& op,
129                const std::vector<std::unique_ptr<tflite::TensorT>>& tensors,
130                Builder builder, Location base) {
131   if (op.outputs.empty()) return base;
132 
133   llvm::SmallVector<Location, 4> locations;
134   locations.reserve(op.outputs.size());
135   for (auto tensor_index : op.outputs) {
136     locations.push_back(TensorLoc(*tensors[tensor_index], builder, base));
137   }
138   return mlir::FusedLoc::get(builder.getContext(), locations);
139 }
140 
141 // Returns the correct type for a quantized tensor
142 // We have a special case for constants since they have a higher minimum value.
GetQuantizedType(const TensorT & tensor,Builder builder,bool is_constant=false)143 StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
144                                          bool is_constant = false) {
145   tflite::QuantizationParametersT& quant_params = *tensor.quantization;
146   if (quant_params.details.AsCustomQuantization()) {
147     return errors::Unimplemented("Cannot handle experimental quantization");
148   }
149 
150   bool is_signed = true;
151   mlir::IntegerType storage_type;
152   if (tensor.type == tflite::TensorType_UINT8) {
153     is_signed = false;
154     storage_type = builder.getIntegerType(8);
155   } else {
156     auto raw_elem_type = ConvertElementType(tensor.type, builder);
157     if (!raw_elem_type.isa<mlir::IntegerType>()) {
158       return errors::InvalidArgument(
159           "Quantized tensors must be stored as integers");
160     }
161     storage_type = raw_elem_type.cast<mlir::IntegerType>();
162   }
163 
164   // TFlite uses narrow-range [u]int8 for constant buffers of quantized weights.
165   // Since we don't know which ones are weights, we represent this optimization
166   // as a change in the storage bounds for the type for all constants of this
167   // type.
168   bool is_weight_buffer = is_constant && (storage_type.getWidth() == 8);
169 
170   int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
171                             is_signed, storage_type.getWidth()) +
172                         static_cast<int>(is_weight_buffer);
173   int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
174       is_signed, storage_type.getWidth());
175   uint32_t flags =
176       is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;
177 
178   // Rejects if quantized tensors have zero scales.
179   for (float scale : quant_params.scale) {
180     if (scale == 0) {
181       return errors::InvalidArgument(
182           "Quantized tensors must have non-zero scales");
183     }
184   }
185 
186   // Scale size can't be zero as it is checked before.
187   if (quant_params.scale.size() != 1) {
188     llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),
189                                         quant_params.scale.end());
190     return mlir::quant::UniformQuantizedPerAxisType::get(
191         flags, storage_type, builder.getF32Type(), scales,
192         quant_params.zero_point, quant_params.quantized_dimension, storage_min,
193         storage_max);
194   }
195   return mlir::quant::UniformQuantizedType::get(
196       flags, storage_type, builder.getF32Type(), quant_params.scale.at(0),
197       quant_params.zero_point.at(0), storage_min, storage_max);
198 }
199 
200 // import float tensor with calibration value into calibrated quantized type.
GetCalibratedQuantizedType(const TensorT & tensor,Builder builder)201 StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
202                                                    Builder builder) {
203   if (tensor.quantization == nullptr) {
204     return errors::InvalidArgument("The tensor is not quantized.");
205   }
206   auto raw_elem_type = ConvertElementType(tensor.type, builder);
207   float min = tensor.quantization->min[0];
208   float max = tensor.quantization->max[0];
209   return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
210 }
211 
GetTensorType(const TensorT & tensor,Builder builder,bool is_constant=false,bool is_intermediate=false)212 StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
213                                          bool is_constant = false,
214                                          bool is_intermediate = false) {
215   mlir::Type elem_type = ConvertElementType(tensor.type, builder);
216   if (tensor.type == tflite::TensorType_VARIANT) {
217     llvm::SmallVector<mlir::TensorType> tensor_types;
218     if (tensor.variant_tensors.size() > 1) {
219       return errors::InvalidArgument(
220           "Have more than one nested type in `variant_tensors`.");
221     }
222     for (const auto& nested_tensor : tensor.variant_tensors) {
223       mlir::Type nested_elem_type =
224           ConvertElementType(nested_tensor->type, builder);
225       if (nested_tensor->has_rank) {
226         llvm::SmallVector<int64_t> shape(nested_tensor->shape.begin(),
227                                          nested_tensor->shape.end());
228         tensor_types.push_back(RankedTensorType::get(shape, nested_elem_type));
229       } else {
230         tensor_types.push_back(UnrankedTensorType::get(nested_elem_type));
231       }
232     }
233     elem_type = mlir::TF::VariantType::get(tensor_types, builder.getContext());
234   }
235   if (IsQuantized(tensor)) {
236     TF_ASSIGN_OR_RETURN(elem_type,
237                         GetQuantizedType(tensor, builder, is_constant));
238   }
239 
240   // Intermediate tensors with calibration value (but not scale and zero points)
241   // should return calibrated quantized type.
242   if (is_intermediate && tensor.quantization != nullptr &&
243       !IsQuantized(tensor)) {
244     TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
245   }
246 
247   if (tensor.shape.empty() && (is_constant || tensor.has_rank)) {
248     return RankedTensorType::get({}, elem_type);
249   }
250 
251   if (!tensor.shape_signature.empty()) {
252     llvm::SmallVector<int64_t, 4> shape(tensor.shape_signature.begin(),
253                                         tensor.shape_signature.end());
254     return RankedTensorType::get(shape, elem_type);
255   }
256 
257   if (!tensor.shape.empty()) {
258     llvm::SmallVector<int64_t, 4> shape(tensor.shape.begin(),
259                                         tensor.shape.end());
260     return RankedTensorType::get(shape, elem_type);
261   }
262 
263   return UnrankedTensorType::get(elem_type);
264 }
265 
266 // Extract the min max information in the tensor and create the quant stats op.
267 // If the input `tensor` has scale/zero_point, `res` should have quantized
268 // type, thus none stats op is required and nullptr is retruned.
269 // If the min max information is invalid, nullptr is returned.
ConvertMinMaxToStatsOp(const TensorT & tensor,OpBuilder b,Value res)270 mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
271                                         Value res) {
272   // If the `tensor` has scale/zero_point, it must have been quantized, then the
273   // min/max stats is just for comments, so ignore it.
274   if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
275   // If the result isn't float and unquantizable, the min/max is ignored.
276   if (!res.getType()
277            .cast<mlir::ShapedType>()
278            .getElementType()
279            .isa<mlir::FloatType>()) {
280     return nullptr;
281   }
282   auto mins = tensor.quantization->min;
283   auto maxs = tensor.quantization->max;
284   if (mins.size() != maxs.size() || mins.empty()) return nullptr;
285 
286   llvm::SmallVector<llvm::APFloat, 4> min_maxs;
287   min_maxs.reserve(mins.size() * 2);
288   for (int i = 0, end = mins.size(); i < end; ++i) {
289     llvm::APFloat min(mins[i]);
290     llvm::APFloat max(maxs[i]);
291     min_maxs.push_back(min);
292     min_maxs.push_back(max);
293   }
294   // The layer stats contain only the first min/max pairs.
295   mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
296       mlir::RankedTensorType::get({2}, b.getF32Type()),
297       {min_maxs[0], min_maxs[1]});
298   mlir::ElementsAttr axis_stats;
299   mlir::IntegerAttr axis;
300   if (mins.size() > 1) {
301     llvm::SmallVector<int64_t, 4> axis_stats_shape{
302         static_cast<int64_t>(mins.size()), 2};
303     axis_stats = mlir::DenseFPElementsAttr::get(
304         mlir::RankedTensorType::get(axis_stats_shape, b.getF32Type()),
305         min_maxs);
306     // TODO(fengliuai): this quantization dimension isn't correct.
307     axis = b.getI64IntegerAttr(tensor.quantization->quantized_dimension);
308   }
309   return b.create<mlir::quantfork::StatisticsOp>(b.getUnknownLoc(), res,
310                                                  layer_stats, axis_stats, axis);
311 }
312 
313 // Returns true if this is a basic LSTM op.
IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union)314 bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
315   if (const auto* op = op_union.AsLSTMOptions()) {
316     return op->kernel_type == tflite::LSTMKernelType_BASIC;
317   } else {
318     return false;
319   }
320 }
321 
322 // Gets the MLIR op name with the dialect name for the flatbuffer operator.
GetMlirOpName(const tflite::OperatorT & op,const tflite::OperatorCodeT & op_code)323 std::string GetMlirOpName(const tflite::OperatorT& op,
324                           const tflite::OperatorCodeT& op_code) {
325   if (IsBasicLSTMOp(op.builtin_options)) {
326     return std::string("tfl.basic_lstm");
327   }
328   return mlir::GetMlirOpNameFromOpCode(op_code);
329 }
330 
331 // The buffers in TFLite flatbuffers have their contents stored as a vector of
332 // bytes that represent little-endian values.
333 // The read_size parameter is present to allow reading both float16 and float32s
334 // without a case split.
335 template <typename T>
ReadAsLittleEndian(ArrayRef<uint8_t> bytes)336 std::vector<T> ReadAsLittleEndian(ArrayRef<uint8_t> bytes) {
337   std::vector<T> ret;
338   size_t read_size = sizeof(T);
339   int bytes_len = bytes.size();
340   assert(bytes_len % read_size == 0);
341 
342   int elem_count = bytes_len / read_size;
343   ret.reserve(elem_count);
344 
345   const char* data_ptr = reinterpret_cast<const char*>(bytes.data());
346   for (int i = 0; i < elem_count; i++) {
347     ret.push_back(
348         llvm::support::endian::readNext<T, llvm::support::little,
349                                         llvm::support::unaligned>(data_ptr));
350   }
351   return ret;
352 }
353 
ConvertTfliteConstTensor(const tflite::TensorT & tensor,const std::vector<uint8_t> & buffer)354 tensorflow::TensorProto ConvertTfliteConstTensor(
355     const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer) {
356   tensorflow::TensorProto ret;
357   ret.set_dtype(TflTypeToTfType(tensor.type));
358 
359   tensorflow::TensorShapeProto* shape = ret.mutable_tensor_shape();
360   shape->set_unknown_rank(false);
361   for (auto dim : tensor.shape) {
362     shape->add_dim()->set_size(int64_t{dim});
363   }
364   // TensorFlow Lite uses tflite::DynamicBufer to encode vector of strings.
365   if (tensor.type == tflite::TensorType_STRING) {
366     for (int i = 0; i < tflite::GetStringCount(buffer.data()); ++i) {
367       tflite::StringRef str = tflite::GetString(buffer.data(), i);
368       ret.add_string_val(str.str, str.len);
369     }
370     return ret;
371   }
372   std::string content;
373   content.assign(reinterpret_cast<const char*>(buffer.data()), buffer.size());
374   ret.set_tensor_content(content);
375   return ret;
376 }
377 
ConvertFloatBuffer(mlir::RankedTensorType shaped_type,mlir::FloatType elem_type,const std::vector<uint8_t> & buffer)378 StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
379     mlir::RankedTensorType shaped_type, mlir::FloatType elem_type,
380     const std::vector<uint8_t>& buffer) {
381   size_t bytes_len = buffer.size();
382 
383   // The bytes of floats are stored little-endian.
384   switch (elem_type.getWidth()) {
385     case 16: {
386       assert(bytes_len % 2 == 0);
387       int elem_count = bytes_len / 2;
388       std::vector<llvm::APFloat> values;
389       values.reserve(elem_count);
390 
391       const char* data = reinterpret_cast<const char*>(buffer.data());
392       auto& semantics = elem_type.getFloatSemantics();
393 
394       for (int i = 0; i < elem_count; i++) {
395         uint16_t bit_repr =
396             llvm::support::endian::readNext<uint16_t, llvm::support::little,
397                                             llvm::support::unaligned>(data);
398         llvm::APInt int_repr(16, bit_repr);
399         values.emplace_back(semantics, int_repr);
400       }
401 
402       return mlir::ElementsAttr(DenseElementsAttr::get(shaped_type, values));
403     }
404     case 32: {
405       assert(bytes_len % 4 == 0);
406       int elem_count = bytes_len / 4;
407       std::vector<float> values;
408       values.reserve(elem_count);
409 
410       const char* data = reinterpret_cast<const char*>(buffer.data());
411 
412       for (int i = 0; i < elem_count; i++) {
413         uint32_t bit_repr =
414             llvm::support::endian::readNext<uint32_t, llvm::support::little,
415                                             llvm::support::unaligned>(data);
416         values.push_back(absl::bit_cast<float>(bit_repr));
417       }
418       return mlir::ElementsAttr(
419           DenseElementsAttr::get(shaped_type, ArrayRef<float>(values)));
420     }
421     case 64: {
422       assert(bytes_len % 8 == 0);
423       int elem_count = bytes_len / 8;
424       std::vector<double> values;
425       values.reserve(elem_count);
426 
427       const char* data = reinterpret_cast<const char*>(buffer.data());
428 
429       for (int i = 0; i < elem_count; i++) {
430         uint64_t bit_repr =
431             llvm::support::endian::readNext<uint64_t, llvm::support::little,
432                                             llvm::support::unaligned>(data);
433         values.push_back(absl::bit_cast<double>(bit_repr));
434       }
435       return mlir::ElementsAttr(
436           DenseElementsAttr::get(shaped_type, ArrayRef<double>(values)));
437     }
438   }
439   return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
440 }
441 
ConvertIntBuffer(mlir::RankedTensorType shaped_type,mlir::Type elem_type,const std::vector<uint8_t> & buffer)442 StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
443     mlir::RankedTensorType shaped_type, mlir::Type elem_type,
444     const std::vector<uint8_t>& buffer) {
445   unsigned bit_width;
446   if (auto itype = elem_type.dyn_cast<mlir::IntegerType>()) {
447     bit_width = itype.getWidth();
448   } else if (auto qtype = elem_type.dyn_cast<QuantizedType>()) {
449     bit_width = qtype.getStorageTypeIntegralWidth();
450     shaped_type = mlir::RankedTensorType::get(shaped_type.getShape(),
451                                               qtype.getStorageType());
452   } else {
453     return errors::InvalidArgument("unsupported integer constant type");
454   }
455 
456   switch (bit_width) {
457     case 1: {
458       // vector<bool> doesn't convert to an ArrayRef
459       llvm::SmallVector<bool, 8> values;
460       values.reserve(buffer.size());
461       for (auto b : buffer) {
462         values.emplace_back(b != 0);
463       }
464       return mlir::ElementsAttr(
465           DenseElementsAttr::get(shaped_type, ArrayRef<bool>(values)));
466     }
467     case 8: {
468       return mlir::ElementsAttr(
469           DenseElementsAttr::get(shaped_type, ArrayRef<uint8_t>(buffer)));
470     }
471     case 16: {
472       auto values = ReadAsLittleEndian<uint16_t>(buffer);
473       return mlir::ElementsAttr(
474           DenseElementsAttr::get(shaped_type, ArrayRef<uint16_t>(values)));
475     }
476     case 32: {
477       auto values = ReadAsLittleEndian<uint32_t>(buffer);
478       return mlir::ElementsAttr(
479           DenseElementsAttr::get(shaped_type, ArrayRef<uint32_t>(values)));
480     }
481     case 64: {
482       auto values = ReadAsLittleEndian<uint64_t>(buffer);
483       return mlir::ElementsAttr(
484           DenseElementsAttr::get(shaped_type, ArrayRef<uint64_t>(values)));
485     }
486     default:
487       return errors::Unimplemented("Cannot handle bit width ", bit_width);
488   }
489 }
490 
BuildExternalConstOp(const tflite::TensorT & tensor,int32_t buffer_index,OpBuilder builder,Location loc)491 StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
492                                           int32_t buffer_index,
493                                           OpBuilder builder, Location loc) {
494   TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
495                                                /*is_constant=*/true));
496   auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
497   if (!shaped_type) {
498     return errors::Internal("Constant doesn't have a shape");
499   }
500   auto op = builder.create<tfl::ExternalConstOp>(
501       loc, shaped_type, builder.getI32IntegerAttr(buffer_index));
502   return op.getOperation();
503 }
504 
505 // Gets a constant splat for the given value of type. Requires value to be of
506 // type static shaped RankedTensorType. `unique_index` is used to get the unique
507 // value for the attribute.
GetSplat(RankedTensorType type,int unique_index,OpBuilder builder)508 static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index,
509                                    OpBuilder builder) {
510   mlir::Type element_ty = getElementTypeOrSelf(type);
511 
512   if (element_ty.isSignlessInteger())
513     return DenseElementsAttr::get(
514         type, builder.getIntegerAttr(element_ty, unique_index));
515 
516   if (element_ty.isa<mlir::FloatType>())
517     return DenseElementsAttr::get(
518         type, builder.getFloatAttr(element_ty, unique_index));
519 
520   if (auto qtype = element_ty.dyn_cast<QuantizedType>()) {
521     mlir::RankedTensorType new_type =
522         RankedTensorType::get(type.getShape(), qtype.getStorageType());
523     return DenseElementsAttr::get(
524         new_type, builder.getIntegerAttr(qtype.getStorageType(), unique_index));
525   }
526   llvm_unreachable("unhandled element type");
527 }
528 
529 // TODO(b/172664358): Creates a new op instead of reusing constant op.
530 // Creates a constant op to represent stateful variable. The function static
531 // variable `stateful_variable_idx` is used as a unique value for each constant
532 // to avoid CSEed. `tensor` is the data structure of flatbuffer. `shaped_type`
533 // is the ShapedType for the const op.
BuildVariableOp(const tflite::TensorT & tensor,mlir::RankedTensorType shaped_type,OpBuilder builder,Location loc)534 Operation* BuildVariableOp(const tflite::TensorT& tensor,
535                            mlir::RankedTensorType shaped_type,
536                            OpBuilder builder, Location loc) {
537   static int stateful_variable_idx = 0;
538   mlir::ElementsAttr value =
539       GetSplat(shaped_type, stateful_variable_idx++, builder);
540   if (IsQuantized(tensor)) {
541     auto op = builder.create<tfl::QConstOp>(
542         loc, mlir::TypeAttr::get(shaped_type), value);
543     return op.getOperation();
544   }
545   auto op = builder.create<tfl::ConstOp>(loc, value);
546   if (tensor.quantization && !tensor.quantization->min.empty()) {
547     if (auto stats_op =
548             ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
549       return stats_op;
550     }
551   }
552   return op.getOperation();
553 }
554 
ConvertSparseIndexVector(const tflite::SparseIndexVectorUnion & sparse_index_vector)555 static StatusOr<std::vector<int32_t>> ConvertSparseIndexVector(
556     const tflite::SparseIndexVectorUnion& sparse_index_vector) {
557   if (sparse_index_vector.type == tflite::SparseIndexVector_Int32Vector) {
558     return sparse_index_vector.AsInt32Vector()->values;
559   } else if (sparse_index_vector.type ==
560              tflite::SparseIndexVector_Uint16Vector) {
561     const auto& inputs = sparse_index_vector.AsUint16Vector()->values;
562     std::vector<int32_t> outputs(inputs.size());
563     std::transform(inputs.begin(), inputs.end(), outputs.begin(),
564                    [](auto x) { return static_cast<int32_t>(x); });
565     return outputs;
566   } else if (sparse_index_vector.type ==
567              tflite::SparseIndexVector_Uint8Vector) {
568     const auto& inputs = sparse_index_vector.AsUint8Vector()->values;
569     std::vector<int32_t> outputs(inputs.size());
570     std::transform(inputs.begin(), inputs.end(), outputs.begin(),
571                    [](auto x) { return static_cast<int32_t>(x); });
572     return outputs;
573   } else {
574     return errors::Unimplemented("Unsupported SparseIndexVector type");
575   }
576 }
577 
BuildSparseConstOp(const tflite::TensorT & tensor,const std::vector<uint8_t> & buffer,const mlir::RankedTensorType shaped_type,OpBuilder & builder,Location loc)578 static StatusOr<Operation*> BuildSparseConstOp(
579     const tflite::TensorT& tensor, const std::vector<uint8_t>& buffer,
580     const mlir::RankedTensorType shaped_type, OpBuilder& builder,
581     Location loc) {
582   tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
583   repr.clear_tensor_shape();
584   if (IsQuantized(tensor)) {
585     repr.mutable_tensor_shape()->add_dim()->set_size(buffer.size());
586     repr.set_dtype(tensorflow::DT_INT8);
587   } else {
588     repr.mutable_tensor_shape()->add_dim()->set_size(
589         buffer.size() / (shaped_type.getElementTypeBitWidth() / CHAR_BIT));
590   }
591   TF_ASSIGN_OR_RETURN(mlir::ElementsAttr compressed_data,
592                       tensorflow::ConvertTensorProto(repr, &builder));
593 
594   const int dim_metadata_size = tensor.sparsity->dim_metadata.size();
595   std::vector<mlir::TFL::DimensionMetadataAttr> dim_metadata(dim_metadata_size);
596   for (int i = 0; i < dim_metadata_size; i++) {
597     if (tensor.sparsity->dim_metadata[i]->format ==
598         tflite::DimensionType_DENSE) {
599       dim_metadata[i] = tfl::DimensionMetadataAttr::get(
600           builder.getContext(),
601           mlir::TFL::DimensionTypeAttr::get(builder.getContext(),
602                                             tfl::DimensionType::DENSE),
603           tensor.sparsity->dim_metadata[i]->dense_size, {}, {});
604     } else if (tensor.sparsity->dim_metadata[i]->format ==
605                tflite::DimensionType_SPARSE_CSR) {
606       TF_ASSIGN_OR_RETURN(
607           auto segments, ConvertSparseIndexVector(
608                              tensor.sparsity->dim_metadata[i]->array_segments));
609       TF_ASSIGN_OR_RETURN(auto indices,
610                           ConvertSparseIndexVector(
611                               tensor.sparsity->dim_metadata[i]->array_indices));
612       dim_metadata[i] = tfl::DimensionMetadataAttr::get(
613           builder.getContext(),
614           mlir::TFL::DimensionTypeAttr::get(builder.getContext(),
615                                             tfl::DimensionType::SPARSE_CSR),
616           0, segments, indices);
617     } else {
618       return errors::Unimplemented("Unsupported dimension metadata type");
619     }
620   }
621   auto s_param = tfl::SparsityParameterAttr::get(
622       builder.getContext(), tensor.sparsity->traversal_order,
623       tensor.sparsity->block_map, dim_metadata);
624 
625   auto value_type = shaped_type;
626   if (IsQuantized(tensor)) {
627     value_type = RankedTensorType::get(
628         shaped_type.getShape(), shaped_type.getElementType()
629                                     .dyn_cast<mlir::quant::QuantizedType>()
630                                     .getStorageType());
631   }
632   std::vector<char> dense_buffer(
633       value_type.getElementType().getIntOrFloatBitWidth() / CHAR_BIT);
634   mlir::Attribute dummy_value =
635       mlir::DenseIntOrFPElementsAttr::getFromRawBuffer(value_type,
636                                                        dense_buffer);
637 
638   if (IsQuantized(tensor)) {
639     return builder
640         .create<tfl::SparseQConstOp>(loc, mlir::TypeAttr::get(shaped_type),
641                                      dummy_value, s_param, compressed_data)
642         .getOperation();
643   }
644   return builder
645       .create<tfl::SparseConstOp>(loc, dummy_value, s_param, compressed_data)
646       .getOperation();
647 }
648 
BuildConstOp(const tflite::TensorT & tensor,const std::vector<uint8_t> & buffer,bool is_variable,OpBuilder builder,Location loc)649 StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
650                                   const std::vector<uint8_t>& buffer,
651                                   bool is_variable, OpBuilder builder,
652                                   Location loc) {
653   TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
654                                                /*is_constant=*/true));
655   auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
656   if (!shaped_type) {
657     return errors::Internal("Constant doesn't have a shape");
658   }
659 
660   if (tensor.sparsity != nullptr) {
661     return BuildSparseConstOp(tensor, buffer, shaped_type, builder, loc);
662   }
663 
664   auto elem_type = shaped_type.getElementType();
665 
666   mlir::ElementsAttr value;
667   if (is_variable) {
668     return BuildVariableOp(tensor, shaped_type, builder, loc);
669   } else if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
670     TF_ASSIGN_OR_RETURN(value,
671                         ConvertFloatBuffer(shaped_type, float_type, buffer));
672   } else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
673     TF_ASSIGN_OR_RETURN(value,
674                         ConvertIntBuffer(shaped_type, elem_type, buffer));
675   } else if (elem_type.isa<mlir::TF::StringType>()) {
676     tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
677     std::vector<llvm::StringRef> refs;
678     refs.reserve(repr.string_val_size());
679 
680     for (const auto& ref : repr.string_val())
681       refs.push_back({ref.data(), ref.size()});
682 
683     value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
684   } else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
685     tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
686     std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
687 
688     value = mlir::TF::TensorProtoAttr::get(shaped_type, mangled);
689   } else {
690     return errors::Unimplemented("Constant of unsupported type");
691   }
692 
693   if (IsQuantized(tensor)) {
694     auto op = builder.create<tfl::QConstOp>(
695         loc, mlir::TypeAttr::get(shaped_type), value);
696     return op.getOperation();
697   }
698   auto op = builder.create<tfl::ConstOp>(loc, value);
699   return op.getOperation();
700 }
701 
702 StatusOr<llvm::SmallVector<mlir::NamedAttribute, 4>>
ConvertSubgraphIdxsToFunctionAttrs(tflite::BuiltinOptionsUnion options,const std::vector<std::string> & func_names,Builder builder)703 ConvertSubgraphIdxsToFunctionAttrs(tflite::BuiltinOptionsUnion options,
704                                    const std::vector<std::string>& func_names,
705                                    Builder builder) {
706   if (auto* opts = options.AsCallOnceOptions()) {
707     uint32_t init_idx = opts->init_subgraph_index;
708     if (init_idx >= func_names.size()) {
709       return errors::InvalidArgument("subgraph with index not found: ",
710                                      init_idx);
711     }
712     auto init_attr = builder.getStringAttr(func_names.at(init_idx));
713 
714     return llvm::SmallVector<mlir::NamedAttribute, 4>{
715         builder.getNamedAttr("session_init_function", init_attr)};
716   }
717   if (auto* opts = options.AsIfOptions()) {
718     uint32_t then_idx = opts->then_subgraph_index;
719     if (then_idx >= func_names.size()) {
720       return errors::InvalidArgument("subgraph with index not found: ",
721                                      then_idx);
722     }
723     auto then_attr =
724         mlir::SymbolRefAttr::get(builder.getContext(), func_names.at(then_idx));
725     uint32_t else_idx = opts->else_subgraph_index;
726     if (else_idx >= func_names.size()) {
727       return errors::InvalidArgument("subgraph with index not found: ",
728                                      else_idx);
729     }
730     auto else_attr =
731         mlir::SymbolRefAttr::get(builder.getContext(), func_names.at(else_idx));
732 
733     return llvm::SmallVector<mlir::NamedAttribute, 4>{
734         builder.getNamedAttr("then_branch", then_attr),
735         builder.getNamedAttr("else_branch", else_attr),
736         // TODO(b/139667752): Analyze statelessness correctly
737         builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
738   }
739   if (auto* opts = options.AsWhileOptions()) {
740     uint32_t cond_idx = opts->cond_subgraph_index;
741     if (cond_idx >= func_names.size()) {
742       return errors::InvalidArgument("subgraph with index not found: ",
743                                      cond_idx);
744     }
745     auto cond_attr =
746         mlir::SymbolRefAttr::get(builder.getContext(), func_names.at(cond_idx));
747     uint32_t body_idx = opts->body_subgraph_index;
748     if (body_idx >= func_names.size()) {
749       return errors::InvalidArgument("subgraph with index not found: ",
750                                      body_idx);
751     }
752     auto body_attr =
753         mlir::SymbolRefAttr::get(builder.getContext(), func_names.at(body_idx));
754 
755     return llvm::SmallVector<mlir::NamedAttribute, 4>{
756         builder.getNamedAttr("cond", cond_attr),
757         builder.getNamedAttr("body", body_attr)};
758   }
759   return llvm::SmallVector<mlir::NamedAttribute, 4>{};
760 }
761 
AddOpIntermediatesForLstm(const tflite::OperatorT & op,const std::vector<mlir::TensorType> & intermediate_types,OperationState & op_state,Location loc,OpBuilder & builder)762 Status AddOpIntermediatesForLstm(
763     const tflite::OperatorT& op,
764     const std::vector<mlir::TensorType>& intermediate_types,
765     OperationState& op_state, Location loc, OpBuilder& builder) {
766   if (!op.intermediates.empty()) {
767     if (op.intermediates.size() != 5) {
768       auto err = errors::InvalidArgument(
769           "operator has intermediate tensors but the number of them is not "
770           "five.");
771       return emitError(loc, err.ToString()), err;
772     }
773     // Create intermediate value
774 
775     const llvm::SmallVector<llvm::StringRef, 5> kIntermediateNames = {
776         "input_to_input_intermediate", "input_to_forget_intermediate",
777         "input_to_cell_intermediate", "input_to_output_intermediate",
778         "effective_hidden_scale_intermediate"};
779     for (auto type_and_name :
780          llvm::zip(intermediate_types, kIntermediateNames)) {
781       mlir::TypeAttr type_attr =
782           mlir::TypeAttr::get(std::get<0>(type_and_name));
783       auto named_attr =
784           builder.getNamedAttr(std::get<1>(type_and_name), type_attr);
785       op_state.addAttribute(named_attr.getName(), named_attr.getValue());
786     }
787   }
788   return ::tensorflow::OkStatus();
789 }
790 
791 // TODO(krzysd) Handle function calls
ConvertOp(const tflite::OperatorT & op,const std::vector<Value> & vals_map,const std::vector<mlir::TensorType> & intermediate_types,Value optional_arg_marker,const std::vector<std::unique_ptr<tflite::OperatorCodeT>> & op_codes,const std::vector<std::string> & func_names,const std::vector<std::unique_ptr<tflite::TensorT>> & tensors,Location loc,OpBuilder builder)792 StatusOr<Operation*> ConvertOp(
793     const tflite::OperatorT& op, const std::vector<Value>& vals_map,
794     const std::vector<mlir::TensorType>& intermediate_types,
795     Value optional_arg_marker,
796     const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
797     const std::vector<std::string>& func_names,
798     const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
799     OpBuilder builder) {
800   llvm::SmallVector<Value, 4> operands;
801   llvm::SmallVector<mlir::Type, 2> outputTypes;
802 
803   const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index);
804 
805   const std::string op_name = GetMlirOpName(op, op_code);
806 
807   OperationState op_state(loc, op_name);
808 
809   for (auto input_num : op.inputs) {
810     if (input_num == -1) {
811       assert(optional_arg_marker != nullptr);
812       op_state.addOperands({optional_arg_marker});
813     } else {
814       op_state.addOperands({vals_map.at(input_num)});
815     }
816   }
817 
818   for (auto output_num : op.outputs) {
819     auto& tensor = *tensors.at(output_num);
820     auto type_or_err = GetTensorType(tensor, builder);
821     if (!type_or_err.ok()) {
822       return emitError(loc, type_or_err.status().ToString()),
823              type_or_err.status();
824     }
825     auto type = std::move(type_or_err).value();
826 
827     if (op_name == "tfl.quantize") {
828       // Special case for quantize: return type must also be in qtype attribute
829       op_state.addAttribute("qtype", mlir::TypeAttr::get(type));
830     } else if (op_name == "tfl.reshape" && op_state.operands.size() == 1) {
831       // Special case for reshape: the second op is optional in the old
832       // converter and kernel, so we create the second operand, which is
833       // required by the new converter, from the reshape op's option.
834       auto new_shape = op.builtin_options.AsReshapeOptions()->new_shape;
835       auto shape_type = RankedTensorType::get(
836           {static_cast<int64_t>(new_shape.size())}, builder.getIntegerType(32));
837 
838       mlir::SmallVector<mlir::Attribute, 4> shape;
839       for (auto s : new_shape) {
840         shape.push_back(builder.getI32IntegerAttr(static_cast<int32_t>(s)));
841       }
842       auto output_shape = DenseElementsAttr::get(shape_type, shape);
843       auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
844       op_state.addOperands({shape_op});
845     }
846 
847     op_state.addTypes({type});
848   }
849 
850   // While the last several tensors could be optional tensors for an tfl op, the
851   // number of input operands could vary. Gets the min/max number of
852   // operands from tflite op name.
853   // Also, since the above code special-handles the `tfl.reshape` op and add an
854   // additional input, we put these function block here.
855   llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name);
856   int input_max_num = input_min_max.Max;
857   int op_input_num = op_state.operands.size();
858   if (input_max_num != 0 && input_max_num > op_input_num) {
859     // If the number of current inputs is less than the op definition, fill in
860     // with `none` value,
861     llvm::SmallVector<Value, 4> none_operands(
862         input_max_num - op_input_num,
863         builder.create<mlir::TFL::NoValueOp>(loc, builder.getNoneType(),
864                                              builder.getUnitAttr()));
865     op_state.addOperands(ArrayRef<Value>(none_operands));
866   }
867 
868   if (op_name == "tfl.lstm") {
869     // TODO(b/147587779): add the right region if region is empty.
870     op_state.addRegion();
871     TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
872                                           builder));
873   }
874   if (op_name == "tfl.while") {
875     // Adds two empty regions for "tfl.while". We will fill the regions after
876     // creating the callee functions because the "tfl.while" input/output types
877     // may be different with the callee functions, and the call ops need to sync
878     // with callee function types.
879     op_state.addRegion();
880     op_state.addRegion();
881   }
882   if (op_name == "tfl.unidirectional_sequence_lstm") {
883     TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
884                                           builder));
885   }
886   if (op_name == "tfl.reshape") {
887     // Flattern reshape ops when more than one dimension shape operand is given.
888     mlir::DenseIntElementsAttr shape_attr;
889     if (matchPattern(op_state.operands[1], m_Constant(&shape_attr))) {
890       auto shape_ty =
891           op_state.operands[1].getType().dyn_cast<RankedTensorType>();
892       if (shape_ty != nullptr && shape_ty.hasRank() && shape_ty.getRank() > 1) {
893         llvm::SmallVector<mlir::Attribute, 4> shape;
894         int32_t dim_size = 0;
895         for (const auto& dim :
896              llvm::enumerate(shape_attr.getValues<llvm::APInt>())) {
897           const int64_t size = dim.value().getSExtValue();
898           shape.push_back(
899               builder.getI32IntegerAttr(static_cast<int32_t>(size)));
900           ++dim_size;
901         }
902         auto shape_type = RankedTensorType::get(
903             {static_cast<int32_t>(dim_size)}, builder.getIntegerType(32));
904         auto output_shape = mlir::DenseElementsAttr::get(shape_type, shape);
905         auto shape_op = builder.create<tfl::ConstOp>(loc, output_shape);
906         op_state.operands[1] = shape_op;
907       }
908     }
909   }
910 
911   llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
912   auto builtin_code = tflite::GetBuiltinCode(&op_code);
913   if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
914     auto status = mlir::CustomOptionsToAttributes(
915         op_code.custom_code, op.custom_options, builder, loc, &attrs);
916     if (!status.ok()) {
917       return emitError(loc, status.ToString()), status;
918     }
919   } else {
920     mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
921   }
922   op_state.addAttributes(attrs);
923 
924   // Handle the conversion from subgraph index to functions for If and While. We
925   // will add CallOps in the region to call the functions later for While.
926   TF_ASSIGN_OR_RETURN(auto function_ref_attrs,
927                       ConvertSubgraphIdxsToFunctionAttrs(op.builtin_options,
928                                                          func_names, builder));
929   op_state.addAttributes(function_ref_attrs);
930 
931   return builder.create(op_state);
932 }
933 
934 // Returns indices of the given tensors in the subgraph. Returns error if a
935 // tensor name cannot be found in the subgraph.
GetTensorIndices(const tflite::SubGraphT & subgraph,const std::vector<std::string> & tensor_names)936 StatusOr<std::vector<int>> GetTensorIndices(
937     const tflite::SubGraphT& subgraph,
938     const std::vector<std::string>& tensor_names) {
939   absl::flat_hash_map<std::string, int> name_to_index;
940   for (const auto& index_and_tensor : llvm::enumerate(subgraph.tensors)) {
941     name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
942   }
943 
944   std::vector<int> indices;
945   indices.reserve(tensor_names.size());
946 
947   for (const auto& name : tensor_names) {
948     auto found = name_to_index.find(name);
949     if (found != name_to_index.end()) {
950       indices.push_back(found->second);
951     } else {
952       return errors::InvalidArgument("could not find tensor in subgraph: ",
953                                      name);
954     }
955   }
956 
957   return indices;
958 }
959 
960 // Given a list of tensor indices, returns a string of concatenated tensor names
961 // wrapped in a NamedAttribute.
962 template <typename ContainerType>
BuildTFEntryFunctionAttribute(const tflite::SubGraphT & subgraph,Builder * builder,const std::string name,const ContainerType indices)963 mlir::NamedAttribute BuildTFEntryFunctionAttribute(
964     const tflite::SubGraphT& subgraph, Builder* builder, const std::string name,
965     const ContainerType indices) {
966   auto tensor_names = llvm::map_range(
967       indices, [&](int i) { return subgraph.tensors.at(i)->name; });
968   return builder->getNamedAttr(
969       name, builder->getStringAttr(llvm::join(tensor_names, ",")));
970 }
971 
972 // Traverses the subgraph from output_indices to input_indices and returns the
973 // set of ops that are visited.
PruneSubgraph(const tflite::SubGraphT & subgraph,ArrayRef<int32_t> input_indices,ArrayRef<int32_t> output_indices)974 StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
975     const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
976     ArrayRef<int32_t> output_indices) {
977   // Create a map from tensor index to defining op.
978   absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
979   for (const auto& op : subgraph.operators) {
980     for (int32_t output : op->outputs) {
981       if (!llvm::is_contained(input_indices, output)) {
982         defining_op[output] = op.get();
983       }
984     }
985   }
986 
987   std::vector<const tflite::OperatorT*> queue;
988   for (int32_t output : output_indices) {
989     if (auto& op = defining_op[output]) {
990       queue.push_back(op);
991     }
992   }
993 
994   // Traverse the graph towards inputs.
995   absl::flat_hash_set<const tflite::OperatorT*> visited;
996   while (!queue.empty()) {
997     const tflite::OperatorT* op = queue.back();
998     queue.pop_back();
999     if (!visited.insert(op).second) {
1000       // The node has already been visited.
1001       continue;
1002     }
1003 
1004     for (int32_t input : op->inputs) {
1005       // Input tensor may not have a defining op in case it is a subgraph input
1006       // or a constant tensor.
1007       if (auto& op = defining_op[input]) {
1008         queue.push_back(op);
1009       }
1010     }
1011   }
1012 
1013   return visited;
1014 }
1015 
1016 // We want to adjust the func op according to some cross ops information.
PostProcessFuncOp(FuncOp func)1017 static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
1018   OpBuilder builder(func);
1019   // When a quantized constant is imported, its quantization parameter is set
1020   // to be narrow range. Here revert to be the fully range if the user doesn't
1021   // require narrow range.
1022   func.walk([&](tfl::QConstOp cst) {
1023     Value value = cst.getResult();
1024     Value full_range_const = value;
1025     auto qtype = mlir::quant::UniformQuantizedType::getQuantizedElementType(
1026         value.getType());
1027     // Only the 8-bit constants are imported with narrow range.
1028     if (!qtype || qtype.getStorageTypeIntegralWidth() != 8 ||
1029         !(qtype.isa<mlir::quant::UniformQuantizedType>() ||
1030           qtype.isa<mlir::quant::UniformQuantizedPerAxisType>())) {
1031       return;
1032     }
1033     for (auto& use : value.getUses()) {
1034       Operation* user = use.getOwner();
1035       if (user->hasTrait<mlir::OpTrait::IsTerminator>()) continue;
1036 
1037       auto affine_user = llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
1038       if (affine_user &&
1039           affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
1040           affine_user.RequiredNarrowRangeAffineOperand())
1041         continue;
1042       // Create a fully range quantized constant.
1043       if (full_range_const == value) {
1044         mlir::quant::QuantizedType new_qtype;
1045         if (auto per_axis =
1046                 qtype.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
1047           new_qtype = mlir::quant::UniformQuantizedPerAxisType::get(
1048               per_axis.getFlags(), per_axis.getStorageType(),
1049               per_axis.getExpressedType(), per_axis.getScales(),
1050               per_axis.getZeroPoints(), per_axis.getQuantizedDimension(),
1051               per_axis.getStorageTypeMin() - 1, per_axis.getStorageTypeMax());
1052         } else if (auto per_tensor =
1053                        qtype.dyn_cast<mlir::quant::UniformQuantizedType>()) {
1054           new_qtype = mlir::quant::UniformQuantizedType::get(
1055               per_tensor.getFlags(), per_tensor.getStorageType(),
1056               per_tensor.getExpressedType(), per_tensor.getScale(),
1057               per_tensor.getZeroPoint(), per_tensor.getStorageTypeMin() - 1,
1058               per_tensor.getStorageTypeMax());
1059         } else {
1060           return;  // Should not reach here, as it's already checked.
1061         }
1062         auto new_output_type = new_qtype.castFromExpressedType(
1063             mlir::quant::UniformQuantizedType::castToExpressedType(
1064                 value.getType()));
1065         builder.setInsertionPointAfter(cst.getOperation());
1066         auto new_op = builder.create<tfl::QConstOp>(
1067             cst.getLoc(), new_output_type, mlir::TypeAttr::get(new_output_type),
1068             cst.valueAttr());
1069         full_range_const = new_op.output();
1070       }
1071       use.set(full_range_const);
1072     }
1073     if (cst.use_empty()) cst.erase();
1074   });
1075   return func;
1076 }
1077 
1078 // Helper method that returns the index of the tensor with name 'tensor_name'
1079 // in the list of tensor names 'tensors'. It allows excluding some indices.
GetTensorIndex(const std::string & tensor_name,llvm::SmallVector<llvm::StringRef,2> tensors,const std::set<int> & exclude_indices={})1080 int GetTensorIndex(const std::string& tensor_name,
1081                    llvm::SmallVector<llvm::StringRef, 2> tensors,
1082                    const std::set<int>& exclude_indices = {}) {
1083   for (const auto& tensor_index_pair : llvm::enumerate(tensors)) {
1084     if (tensor_index_pair.value() == tensor_name &&
1085         exclude_indices.find(tensor_index_pair.index()) ==
1086             exclude_indices.end())
1087       return tensor_index_pair.index();
1088   }
1089   return -1;
1090 }
1091 
1092 // Helper method that returns list of all strings in a StringAttr identified
1093 // by 'attr_key' and values are separated by a comma.
GetStringsFromAttrWithSeparator(mlir::DictionaryAttr attr,const std::string & attr_key)1094 llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
1095     mlir::DictionaryAttr attr, const std::string& attr_key) {
1096   llvm::SmallVector<llvm::StringRef, 2> result;
1097   if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
1098     str.getValue().split(result, ',', /*MaxSplit=*/-1,
1099                          /*KeepEmpty=*/false);
1100   }
1101   return result;
1102 }
1103 
1104 // Sets signature attributes on the function.
SetSignature(FuncOp func,const tflite::SignatureDefT * signature,const std::vector<std::unique_ptr<tflite::TensorT>> & tensors)1105 void SetSignature(
1106     FuncOp func, const tflite::SignatureDefT* signature,
1107     const std::vector<std::unique_ptr<tflite::TensorT>>& tensors) {
1108   auto* context = func->getContext();
1109   static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
1110   static const char kExportedNameAttr[] = "tf_saved_model.exported_names";
1111   static const char kEntryFunctionAttributes[] = "tf.entry_function";
1112 
1113   auto dict_attr =
1114       func->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
1115   if (!dict_attr) return;
1116 
1117   // Get Input and output tensor names from attribute.
1118   llvm::SmallVector<llvm::StringRef, 2> input_names =
1119       GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
1120   llvm::SmallVector<llvm::StringRef, 2> output_names =
1121       GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
1122 
1123   for (const auto& input_pair : llvm::enumerate(signature->inputs)) {
1124     const int arg_index = GetTensorIndex(
1125         tensors[input_pair.value()->tensor_index]->name, input_names);
1126     if (arg_index == -1) {
1127       func->emitWarning("Invalid signature tensors specified.");
1128       return;
1129     }
1130     func.setArgAttr(
1131         arg_index, kSignatureDefIndexPath,
1132         mlir::ArrayAttr::get(context, {mlir::StringAttr::get(
1133                                           context, input_pair.value()->name)}));
1134   }
1135   // Multiple signature outputs can refer to the same tensor. Avoid setting
1136   // signature output attribute at the same index by maintaining a set.
1137   std::set<int> seen_indices;
1138   for (const auto& output_pair : llvm::enumerate(signature->outputs)) {
1139     const int arg_index =
1140         GetTensorIndex(tensors[output_pair.value()->tensor_index]->name,
1141                        output_names, seen_indices);
1142     if (arg_index == -1) {
1143       func->emitWarning("Invalid signature tensors specified.");
1144       return;
1145     }
1146     func.setResultAttr(arg_index, kSignatureDefIndexPath,
1147                        mlir::ArrayAttr::get(
1148                            context, {mlir::StringAttr::get(
1149                                         context, output_pair.value()->name)}));
1150     seen_indices.insert(arg_index);
1151   }
1152   func->setAttr(
1153       kExportedNameAttr,
1154       mlir::ArrayAttr::get(
1155           context, {mlir::StringAttr::get(context, signature->signature_key)}));
1156 }
1157 
1158 // Build a FuncOp from a tflite SubGraph
1159 // The buffers are directly taken
1160 // from the deserialized flatbuffer as we do not have the type information to
1161 // interpret them until this point. The base_loc parameter is the location of
1162 // the flatbuffer as a whole (usually a file). If ordered_output_arrays is not
1163 // empty, then the imported mlir function will only return nodes in
1164 // ordered_output_arrays in the same order. If signature is not null, then the
1165 // inputs/outputs in signature will be attached to the FuncOp.
ConvertSubgraph(const tflite::SubGraphT & subgraph,llvm::StringRef name,const std::vector<std::unique_ptr<tflite::OperatorCodeT>> & op_codes,const std::vector<std::string> & func_names,const std::vector<std::unique_ptr<tflite::BufferT>> & buffers,Location base_loc,Builder builder,bool is_entry_point,bool use_external_constant,const std::vector<std::string> & ordered_input_arrays,const std::vector<std::string> & ordered_output_arrays,bool experimental_prune_unreachable_nodes_unconditionally,const tflite::SignatureDefT * signature)1166 StatusOr<FuncOp> ConvertSubgraph(
1167     const tflite::SubGraphT& subgraph, llvm::StringRef name,
1168     const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
1169     const std::vector<std::string>& func_names,
1170     const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
1171     Location base_loc, Builder builder, bool is_entry_point,
1172     bool use_external_constant,
1173     const std::vector<std::string>& ordered_input_arrays,
1174     const std::vector<std::string>& ordered_output_arrays,
1175     bool experimental_prune_unreachable_nodes_unconditionally,
1176     const tflite::SignatureDefT* signature) {
1177   llvm::SmallVector<mlir::Type, 2> ret_types;
1178   llvm::SmallVector<mlir::Type, 4> input_types;
1179 
1180   auto func_loc = mlir::NameLoc::get(builder.getStringAttr(name), base_loc);
1181 
1182   std::vector<int> func_inputs = subgraph.inputs;
1183   if (is_entry_point && !ordered_input_arrays.empty()) {
1184     if (!experimental_prune_unreachable_nodes_unconditionally) {
1185       // TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
1186       return errors::InvalidArgument(
1187           "input-arrays should be used with experimental pruning flag");
1188     }
1189     TF_ASSIGN_OR_RETURN(func_inputs,
1190                         GetTensorIndices(subgraph, ordered_input_arrays));
1191   }
1192 
1193   for (int input : func_inputs) {
1194     auto& tensor = *subgraph.tensors.at(input);
1195     auto type_or_err = GetTensorType(tensor, builder);
1196     if (!type_or_err.ok()) {
1197       emitError(func_loc, "error reading argument types")
1198           << type_or_err.status().ToString();
1199       return type_or_err.status();
1200     }
1201     auto type = std::move(type_or_err).value();
1202     input_types.push_back(type);
1203   }
1204 
1205   llvm::SmallVector<bool, 16> is_op_output(subgraph.tensors.size(), false);
1206   for (auto& op : subgraph.operators) {
1207     for (auto output : op->outputs) {
1208       is_op_output[output] = true;
1209     }
1210   }
1211 
1212   std::vector<int> func_outputs = subgraph.outputs;
1213   if (is_entry_point && !ordered_output_arrays.empty()) {
1214     TF_ASSIGN_OR_RETURN(func_outputs,
1215                         GetTensorIndices(subgraph, ordered_output_arrays));
1216   }
1217 
1218   for (auto output : func_outputs) {
1219     const bool is_func_input = std::find(func_inputs.begin(), func_inputs.end(),
1220                                          output) != func_inputs.end();
1221     bool is_constant = !is_op_output[output] && !is_func_input;
1222 
1223     auto type_or_err =
1224         GetTensorType(*subgraph.tensors.at(output), builder, is_constant);
1225     if (!type_or_err.ok()) {
1226       emitError(func_loc, "error reading return types")
1227           << type_or_err.status().ToString();
1228       return type_or_err.status();
1229     }
1230     auto type = std::move(type_or_err).value();
1231     ret_types.push_back(type);
1232   }
1233   auto func_type = builder.getFunctionType(input_types, ret_types);
1234 
1235   // Construct function object
1236   auto func = FuncOp::create(func_loc, name, func_type, /* attrs= */ {});
1237   func.addEntryBlock();
1238   auto& body = func.getBody();
1239   OpBuilder op_builder{body};
1240 
1241   std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
1242   Value maybe_optional_arg_marker = nullptr;
1243 
1244   // Get or construct MLIR values for each input
1245   for (int i = 0, e = func_inputs.size(); i < e; i++) {
1246     auto input_tensor = func_inputs[i];
1247     const auto& tensor = *subgraph.tensors.at(input_tensor);
1248     auto loc = TensorLoc(tensor, builder, base_loc);
1249     if (vals_map[input_tensor]) {
1250       auto err = errors::FailedPrecondition("duplicate input arguments");
1251       return emitError(loc, err.ToString()), err;
1252     }
1253     Value input_value = func.getArgument(i);
1254 
1255     // If the `tensor` has min/max and doesn't have scale/zero_point
1256     // information, a stats op is created to use the input_value, then the
1257     // `tensor` should be mapped to the result of this new stats op.
1258     if (auto stats_op =
1259             ConvertMinMaxToStatsOp(tensor, op_builder, input_value)) {
1260       vals_map[input_tensor] = stats_op->getResult(0);
1261     } else {
1262       vals_map[input_tensor] = input_value;
1263     }
1264   }
1265 
1266   // Set tf.entry_function attribute
1267   if (is_entry_point) {
1268     llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
1269     if (!func_inputs.empty()) {
1270       attributes.push_back(BuildTFEntryFunctionAttribute(
1271           subgraph, &builder, "inputs", func_inputs));
1272     }
1273     if (!func_outputs.empty()) {
1274       attributes.push_back(BuildTFEntryFunctionAttribute(
1275           subgraph, &builder, "outputs", func_outputs));
1276     }
1277     func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
1278   } else {
1279     func.setPrivate();
1280   }
1281 
1282   // Set signature on function.
1283   if (signature) {
1284     SetSignature(func, signature, subgraph.tensors);
1285   }
1286 
1287   absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
1288   if (experimental_prune_unreachable_nodes_unconditionally) {
1289     TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
1290                         PruneSubgraph(subgraph, func_inputs, func_outputs));
1291   }
1292 
1293   // Construct MLIR operators from TFLite operators
1294   for (auto& op : subgraph.operators) {
1295     if (experimental_prune_unreachable_nodes_unconditionally &&
1296         !pruned_subgraph_ops.contains(op)) {
1297       continue;
1298     }
1299 
1300     for (auto input_num : op->inputs) {
1301       // The operators in a graph are topologically sorted
1302       // and so if no previous operation has produced a tensor
1303       // it must be a constant.
1304       if (input_num == -1) {
1305         if (maybe_optional_arg_marker == nullptr) {
1306           maybe_optional_arg_marker =
1307               op_builder
1308                   .create<mlir::TFL::NoValueOp>(base_loc, builder.getNoneType(),
1309                                                 builder.getUnitAttr())
1310                   .getResult();
1311         }
1312       } else if (!vals_map.at(input_num)) {
1313         auto& const_tensor = *subgraph.tensors[input_num];
1314         auto const_loc = TensorLoc(const_tensor, builder, base_loc);
1315         auto op_or_err =
1316             use_external_constant
1317                 ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
1318                                        op_builder, const_loc)
1319                 : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
1320                                const_tensor.is_variable, op_builder, const_loc);
1321         if (!op_or_err.ok()) {
1322           return emitError(const_loc, op_or_err.status().ToString()),
1323                  op_or_err.status();
1324         }
1325         vals_map[input_num] = op_or_err.ValueOrDie()->getResult(0);
1326       }
1327     }
1328 
1329     // Intermediate tensors for LSTMs are used to carry quantization range
1330     // in their types, so we only need and extract their types.
1331     std::vector<mlir::TensorType> intermediate_types;
1332     intermediate_types.reserve(5);
1333     for (auto intermediate : op->intermediates) {
1334       TF_ASSIGN_OR_RETURN(
1335           auto type,
1336           GetTensorType(*subgraph.tensors[intermediate], builder,
1337                         /*is_constant=*/false, /*is_intermediate=*/true));
1338       intermediate_types.emplace_back(type);
1339     }
1340 
1341     auto op_loc = OpLoc(*op, subgraph.tensors, builder, base_loc);
1342 
1343     // If there's an optional argument, maybe_optional_arg_marker has been set
1344     // to a valid Value
1345     TF_ASSIGN_OR_RETURN(
1346         auto* mlir_op,
1347         ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
1348                   op_codes, func_names, subgraph.tensors, op_loc, op_builder));
1349 
1350     // Add the results to the value maps. There are two cases: 1. the result
1351     // tensor does not have min/max values, the original op result is used
1352     // directly; 2. the result tensor has some min/max values, a stats op is
1353     // created, then the result of the stats op is used.
1354     for (const auto& pair : llvm::enumerate(mlir_op->getResults())) {
1355       int output_tensor_index = op->outputs[pair.index()];
1356       auto& tensor = *subgraph.tensors[output_tensor_index];
1357       if (auto stats_op =
1358               ConvertMinMaxToStatsOp(tensor, op_builder, pair.value())) {
1359         vals_map[output_tensor_index] = stats_op->getResult(0);
1360       } else {
1361         vals_map[output_tensor_index] = pair.value();
1362       }
1363     }
1364   }
1365 
1366   // Construct return values
1367   llvm::SmallVector<Value, 4> return_operands;
1368   for (auto index : func_outputs) {
1369     if (!vals_map.at(index)) {
1370       auto& const_tensor = *subgraph.tensors[index];
1371       auto const_loc = TensorLoc(const_tensor, builder, base_loc);
1372       auto op_or_err =
1373           use_external_constant
1374               ? BuildExternalConstOp(const_tensor, const_tensor.buffer,
1375                                      op_builder, const_loc)
1376               : BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
1377                              const_tensor.is_variable, op_builder, const_loc);
1378       if (!op_or_err.ok()) {
1379         return emitError(const_loc, op_or_err.status().ToString()),
1380                op_or_err.status();
1381       }
1382       vals_map[index] = op_or_err.ValueOrDie()->getResult(0);
1383     }
1384     return_operands.push_back(vals_map[index]);
1385   }
1386 
1387   op_builder.create<mlir::func::ReturnOp>(base_loc, return_operands);
1388 
1389   return PostProcessFuncOp(func);
1390 }
1391 
1392 // TFLite subgraphs do not necessarily have names, though MLIR functions must
1393 // have them, so we generate a name for subgraphs that are missing one here.
1394 // Note: in TFLite, the first subgraph is the entry point, and in MLIR that
1395 // represents TFLite, this entry point must be called "main"
SubgraphName(bool set_implicit_main_func,unsigned index,const tflite::SubGraphT & subgraph)1396 std::string SubgraphName(bool set_implicit_main_func, unsigned index,
1397                          const tflite::SubGraphT& subgraph) {
1398   if (index == 0 && set_implicit_main_func) {
1399     return "main";
1400   }
1401   if (subgraph.name.empty()) {
1402     return llvm::formatv("fn_{0}", index).str();
1403   }
1404   return subgraph.name;
1405 }
1406 
1407 // Adds a CallOp in `region` to call the `func` and returns the results of
1408 // CallOp.
AddCallOpInWhileOpRegion(mlir::Region & region,mlir::func::FuncOp func)1409 void AddCallOpInWhileOpRegion(mlir::Region& region, mlir::func::FuncOp func) {
1410   OpBuilder op_builder{region};
1411   region.push_back(new mlir::Block());
1412   Location loc = region.getLoc();
1413   auto inputs = func.getFunctionType().getInputs();
1414   region.addArguments(inputs, mlir::SmallVector<Location>(inputs.size(), loc));
1415   op_builder.setInsertionPointToStart(&region.front());
1416   auto call_op = op_builder.create<mlir::func::CallOp>(
1417       loc, func.getFunctionType().getResults(), func.getSymName(),
1418       region.getArguments());
1419   op_builder.create<mlir::TFL::YieldOp>(loc, call_op.getResults());
1420 }
1421 
1422 // TFL::WhileOp has regions, so we add CallOp to call the FuncOp in the regions
1423 // if we have while ops.
AddRegionsForTflWhileOp(mlir::ModuleOp module)1424 void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
1425   mlir::SymbolTable symbol_table(module);
1426   module.walk([&](mlir::TFL::WhileOp while_op) {
1427     auto cond = symbol_table.lookup<mlir::func::FuncOp>(
1428         while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
1429     AddCallOpInWhileOpRegion(while_op.cond(), cond);
1430     while_op->removeAttr("cond");
1431     auto body = symbol_table.lookup<mlir::func::FuncOp>(
1432         while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
1433     AddCallOpInWhileOpRegion(while_op.body(), body);
1434     while_op->removeAttr("body");
1435   });
1436 }
1437 }  // namespace
1438 
FlatBufferToMlir(absl::string_view buffer,MLIRContext * context,Location base_loc,bool use_external_constant,const std::vector<std::string> & ordered_input_arrays,const std::vector<std::string> & ordered_output_arrays,bool experimental_prune_unreachable_nodes_unconditionally)1439 OwningOpRef<mlir::ModuleOp> tflite::FlatBufferToMlir(
1440     absl::string_view buffer, MLIRContext* context, Location base_loc,
1441     bool use_external_constant,
1442     const std::vector<std::string>& ordered_input_arrays,
1443     const std::vector<std::string>& ordered_output_arrays,
1444     bool experimental_prune_unreachable_nodes_unconditionally) {
1445   context->loadDialect<mlir::arith::ArithmeticDialect, mlir::func::FuncDialect,
1446                        mlir::quant::QuantizationDialect,
1447                        mlir::quantfork::QuantizationForkDialect,
1448                        mlir::TFL::TensorFlowLiteDialect,
1449                        mlir::TF::TensorFlowDialect>();
1450 
1451   auto model_ptr =
1452       FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
1453   if (nullptr == model_ptr) {
1454     return emitError(base_loc, "couldn't parse flatbuffer"), nullptr;
1455   }
1456 
1457   std::unique_ptr<ModelT> model(model_ptr->GetModel()->UnPack());
1458 
1459   auto builder = Builder(context);
1460 
1461   std::vector<std::string> func_names;
1462   for (auto& subgraph : model->subgraphs) {
1463     func_names.push_back(subgraph->name);
1464   }
1465 
1466   auto module = mlir::ModuleOp::create(base_loc);
1467   // We currently don't use this to make decisions, but we could
1468   // use it in exports or if there are breaking changes
1469   module->setAttr("tfl.schema_version",
1470                   builder.getI32IntegerAttr(model->version));
1471   if (!model->description.empty()) {
1472     module->setAttr("tfl.description",
1473                     builder.getStringAttr(model->description));
1474   }
1475 
1476   if (!model->signature_defs.empty()) {
1477     module->setAttr("tf_saved_model.semantics",
1478                     mlir::UnitAttr::get(builder.getContext()));
1479   }
1480 
1481   absl::flat_hash_map<uint32_t, tflite::SignatureDefT*>
1482       subgraph_to_signature_map;
1483   for (int i = 0; i < model->signature_defs.size(); i++) {
1484     auto* signature_def = model->signature_defs[i].get();
1485     const uint32_t subgraph_index = signature_def->subgraph_index;
1486     subgraph_to_signature_map[subgraph_index] = signature_def;
1487   }
1488 
1489   const bool set_implicit_main_func = subgraph_to_signature_map.size() <= 1;
1490   for (const auto& e : llvm::enumerate(model->subgraphs)) {
1491     auto& subgraph = e.value();
1492     std::string name =
1493         SubgraphName(set_implicit_main_func, e.index(), *subgraph);
1494     uint32_t subgraph_index = static_cast<uint32_t>(e.index());
1495     auto func_or_error = ConvertSubgraph(
1496         *subgraph, name, model->operator_codes, func_names, model->buffers,
1497         base_loc, builder,
1498         /*is_entry_point=*/
1499         set_implicit_main_func
1500             ? e.index() == 0
1501             : subgraph_to_signature_map.contains(subgraph_index),
1502         /*use_external_constant=*/use_external_constant, ordered_input_arrays,
1503         ordered_output_arrays,
1504         experimental_prune_unreachable_nodes_unconditionally,
1505         subgraph_to_signature_map.contains(subgraph_index)
1506             ? subgraph_to_signature_map.at(subgraph_index)
1507             : nullptr);
1508     if (!func_or_error.ok()) {
1509       return emitError(base_loc, "could not translate function ")
1510                  << subgraph->name << ": "
1511                  << func_or_error.status().error_message(),
1512              nullptr;
1513     }
1514     module.push_back(std::move(func_or_error).value());
1515   }
1516   AddRegionsForTflWhileOp(module);
1517   return OwningOpRef<mlir::ModuleOp>(module);
1518 }
1519