xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/flatbuffer_export.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
17 
18 #include <stddef.h>
19 #include <stdlib.h>
20 
21 #include <algorithm>
22 #include <cstdint>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/base/attributes.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/strings/match.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/string_view.h"
36 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
37 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
38 #include "llvm/ADT/ArrayRef.h"
39 #include "llvm/ADT/DenseMap.h"
40 #include "llvm/ADT/None.h"
41 #include "llvm/ADT/Optional.h"
42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/ADT/StringRef.h"
44 #include "llvm/Support/Casting.h"
45 #include "llvm/Support/CommandLine.h"
46 #include "llvm/Support/FormatVariadic.h"
47 #include "llvm/Support/ToolOutputFile.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
50 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
51 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
52 #include "mlir/IR/Attributes.h"  // from @llvm-project
53 #include "mlir/IR/Builders.h"  // from @llvm-project
54 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
55 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
56 #include "mlir/IR/Location.h"  // from @llvm-project
57 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
58 #include "mlir/IR/Operation.h"  // from @llvm-project
59 #include "mlir/IR/Types.h"  // from @llvm-project
60 #include "mlir/IR/Value.h"  // from @llvm-project
61 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
62 #include "mlir/Tools/mlir-translate/Translation.h"  // from @llvm-project
63 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
64 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
65 #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
66 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
67 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
68 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
70 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
71 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
72 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
73 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
74 #include "tensorflow/compiler/xla/statusor.h"
75 #include "tensorflow/core/framework/attr_value.pb.h"
76 #include "tensorflow/core/framework/node_def.pb.h"
77 #include "tensorflow/core/platform/errors.h"
78 #include "tensorflow/core/platform/logging.h"
79 #include "tensorflow/core/platform/status.h"
80 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
81 #include "tensorflow/lite/experimental/remat/metadata_util.h"
82 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
83 #include "tensorflow/lite/schema/schema_conversion_utils.h"
84 #include "tensorflow/lite/schema/schema_generated.h"
85 #include "tensorflow/lite/string_util.h"
86 #include "tensorflow/lite/tools/versioning/gpu_compatibility.h"
87 #include "tensorflow/lite/tools/versioning/op_version.h"
88 #include "tensorflow/lite/tools/versioning/runtime_version.h"
89 #include "tensorflow/lite/version.h"
90 
91 using llvm::dyn_cast;
92 using llvm::formatv;
93 using llvm::isa;
94 using llvm::Optional;
95 using llvm::StringRef;
96 using llvm::Twine;
97 using mlir::Dialect;
98 using mlir::ElementsAttr;
99 using mlir::MLIRContext;
100 using mlir::ModuleOp;
101 using mlir::NoneType;
102 using mlir::Operation;
103 using mlir::Region;
104 using mlir::StringAttr;
105 using mlir::TensorType;
106 using mlir::Type;
107 using mlir::UnknownLoc;
108 using mlir::Value;
109 using mlir::WalkResult;
110 using mlir::func::FuncOp;
111 using tensorflow::OpOrArgLocNameMapper;
112 using tensorflow::OpOrArgNameMapper;
113 using tensorflow::Status;
114 using tflite::flex::IsAllowlistedFlexOp;
115 using xla::StatusOr;
116 
117 template <typename T>
118 using BufferOffset = flatbuffers::Offset<T>;
119 
120 template <typename T>
121 using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;
122 
123 using CustomOptionsOffset = VectorBufferOffset<uint8_t>;
124 
125 namespace error = tensorflow::error;
126 namespace tfl = mlir::TFL;
127 
128 ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
129 
// Use the same initial buffer size for the flatbuffer builder as the TOCO
// exporter did (the rationale for this particular value is not documented).
132 constexpr size_t kInitialBufferSize = 10240;
133 
134 // Set `isSigned` to false if the `type` is an 8-bit unsigned integer type.
135 // Since tflite doesn't support unsigned for other types, returns error if
136 // `isSigned` is set to false for other types.
GetTFLiteType(Type type,bool is_signed=true)137 static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
138                                                   bool is_signed = true) {
139   if (!is_signed && type.isSignlessInteger(8)) {
140     return tflite::TensorType_UINT8;
141   }
142   if (!is_signed) {
143     return Status(error::INVALID_ARGUMENT,
144                   "'isSigned' can only be set for 8-bits integer type");
145   }
146 
147   if (type.isF32()) {
148     return tflite::TensorType_FLOAT32;
149   } else if (type.isF16()) {
150     return tflite::TensorType_FLOAT16;
151   } else if (type.isF64()) {
152     return tflite::TensorType_FLOAT64;
153   } else if (type.isa<mlir::TF::StringType>()) {
154     return tflite::TensorType_STRING;
155   } else if (type.isa<mlir::TF::Quint8Type>()) {
156     return tflite::TensorType_UINT8;
157   } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
158     auto ftype = complex_type.getElementType();
159     if (ftype.isF32()) {
160       return tflite::TensorType_COMPLEX64;
161     }
162     if (ftype.isF64()) {
163       return tflite::TensorType_COMPLEX128;
164     }
165     return Status(error::INVALID_ARGUMENT, "Unsupported type");
166   } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
167     switch (itype.getWidth()) {
168       case 1:
169         return tflite::TensorType_BOOL;
170       case 8:
171         return itype.isUnsigned() ? tflite::TensorType_UINT8
172                                   : tflite::TensorType_INT8;
173       case 16:
174         return itype.isUnsigned() ? tflite::TensorType_UINT16
175                                   : tflite::TensorType_INT16;
176       case 32:
177         return itype.isUnsigned() ? tflite::TensorType_UINT32
178                                   : tflite::TensorType_INT32;
179       case 64:
180         return itype.isUnsigned() ? tflite::TensorType_UINT64
181                                   : tflite::TensorType_INT64;
182     }
183   } else if (auto q_uniform_type =
184                  type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
185     return GetTFLiteType(q_uniform_type.getStorageType(),
186                          q_uniform_type.isSigned());
187   } else if (auto q_peraxis_type =
188                  type.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
189     return GetTFLiteType(q_peraxis_type.getStorageType(),
190                          q_peraxis_type.isSigned());
191   } else if (auto q_calibrated_type =
192                  type.dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
193     return GetTFLiteType(q_calibrated_type.getExpressedType());
194   } else if (type.isa<mlir::TF::ResourceType>()) {
195     return tflite::TensorType_RESOURCE;
196   } else if (type.isa<mlir::TF::VariantType>()) {
197     return tflite::TensorType_VARIANT;
198   }
199   // TFLite export fills FLOAT32 for unknown data types. Returning an error
200   // for now for safety and this could be revisited when required.
201   return Status(error::INVALID_ARGUMENT, "Unsupported type");
202 }
203 
// Returns true if `op` is one of the operations the exporter treats as a
// constant producer: func/arith/TF const ops, the TFL const variants
// (including quantized and sparse), and TFL's NoValueOp placeholder.
static bool IsConst(Operation* op) {
  return isa<mlir::func::ConstantOp, mlir::arith::ConstantOp, mlir::TF::ConstOp,
             tfl::ConstOp, tfl::QConstOp, tfl::SparseConstOp,
             tfl::SparseQConstOp, mlir::TFL::NoValueOp>(op);
}
209 
IsTFResourceOp(Operation * op)210 static bool IsTFResourceOp(Operation* op) {
211   for (const auto& operand : op->getOperands()) {
212     auto elementType = getElementTypeOrSelf(operand.getType());
213     if (elementType.isa<mlir::TF::ResourceType>()) {
214       return true;
215     }
216   }
217   for (const auto& result : op->getResults()) {
218     auto elementType = getElementTypeOrSelf(result.getType());
219     if (elementType.isa<mlir::TF::ResourceType>()) {
220       return true;
221     }
222   }
223   return false;
224 }
225 
// Returns whether the op named `op_name` is not supported by the TF Lite
// Flex runtime even though it is a valid TensorFlow op.
static bool IsUnsupportedFlexOp(const std::string& op_name) {
  static constexpr const char* kUnsupportedOps[] = {"PartitionedCall",
                                                    "StatefulPartitionedCall"};
  for (const char* unsupported : kUnsupportedOps) {
    if (op_name == unsupported) return true;
  }
  return false;
}
230 
231 // Create description of operation that could not be converted.
GetOpDescriptionForDebug(Operation * inst)232 static std::string GetOpDescriptionForDebug(Operation* inst) {
233   const int kLargeElementsAttr = 16;
234   std::string op_str;
235   llvm::raw_string_ostream os(op_str);
236   inst->getName().print(os);
237   os << "(";
238   if (!inst->getOperandTypes().empty()) {
239     bool first = true;
240     for (Type operand_type : inst->getOperandTypes()) {
241       os << (!first ? ", " : "");
242       first = false;
243       os << operand_type;
244     }
245   }
246   os << ") -> (";
247   if (!inst->getResultTypes().empty()) {
248     bool first = true;
249     for (Type result_type : inst->getResultTypes()) {
250       os << (!first ? ", " : "");
251       first = false;
252       os << result_type;
253     }
254   }
255   os << ")";
256   // Print out attributes except for large elementsattributes (which should
257   // rarely be the cause why the legalization didn't happen).
258   if (!inst->getAttrDictionary().empty()) {
259     os << " : {";
260     bool first = true;
261     for (auto& named_attr : inst->getAttrDictionary()) {
262       os << (!first ? ", " : "");
263       first = false;
264       os << named_attr.getName().getValue() << " = ";
265       if (auto element_attr = named_attr.getValue().dyn_cast<ElementsAttr>()) {
266         if (element_attr.getNumElements() <= kLargeElementsAttr) {
267           element_attr.print(os);
268         } else {
269           os << "<large>";
270         }
271       } else {
272         named_attr.getValue().print(os);
273       }
274     }
275     os << "}";
276   }
277   return os.str();
278 }
279 
280 // Create a summary with the given information regarding op names and
281 // descriptions.
GetOpsSummary(const std::map<std::string,std::set<std::string>> & ops,const std::string & summary_title)282 static std::string GetOpsSummary(
283     const std::map<std::string, std::set<std::string>>& ops,
284     const std::string& summary_title) {
285   std::string op_str;
286   llvm::raw_string_ostream os(op_str);
287 
288   std::vector<std::string> keys;
289   keys.reserve(ops.size());
290 
291   std::vector<std::string> values;
292   values.reserve(ops.size());
293 
294   for (auto const& op_name_and_details : ops) {
295     keys.push_back(op_name_and_details.first);
296     for (auto const& op_detail : op_name_and_details.second) {
297       values.push_back(op_detail);
298     }
299   }
300 
301   os << summary_title << " ops: " << absl::StrJoin(keys, ", ") << "\n";
302   os << "Details:\n\t" << absl::StrJoin(values, "\n\t");
303 
304   return os.str();
305 }
306 
307 template <typename T>
HasValidTFLiteType(Value value,T & error_handler)308 static bool HasValidTFLiteType(Value value, T& error_handler) {
309   // None type is allowed to represent unspecified operands.
310   if (value.getType().isa<NoneType>()) return true;
311 
312   auto type = value.getType().dyn_cast<TensorType>();
313   if (!type) {
314     if (auto op = value.getDefiningOp()) {
315       error_handler.emitError()
316           << '\'' << op << "' should produce value of tensor type instead of "
317           << value.getType();
318       return false;
319     }
320     error_handler.emitError("expected tensor type, got ") << value.getType();
321     return false;
322   }
323 
324   Type element_type = type.getElementType();
325   auto status = GetTFLiteType(element_type);
326   if (!status.ok()) {
327     return error_handler.emitError(
328                formatv("Failed to convert element type '{0}': {1}",
329                        element_type, status.status().error_message())),
330            false;
331   }
332   return true;
333 }
334 
// Returns true if the module holds all the invariants expected by the
// Translator class.
// TODO(hinsu): Now that translation is done by making a single pass over the
// MLIR module, consider inlining these validation checks at the place where
// these invariants are assumed instead of checking upfront.
static bool IsValidTFLiteMlirModule(ModuleOp module) {
  MLIRContext* context = module.getContext();

  // Verify that module has a function named main.
  FuncOp main_fn = module.lookupSymbol<FuncOp>("main");
  if (!main_fn) {
    // No main function: count functions carrying a non-empty
    // tf.entry_function attribute instead.
    int entry_func_count = 0;
    for (auto fn : module.getOps<FuncOp>()) {
      auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
      if (attrs && !attrs.empty()) {
        ++entry_func_count;
      }
    }

    // Verify that the module has at least one entry function.
    if (entry_func_count == 0) {
      return emitError(UnknownLoc::get(context),
                       "should have a least one entry function"),
             false;
    }
  }

  for (auto fn : module.getOps<FuncOp>()) {
    // Each function must consist of exactly one basic block.
    if (!llvm::hasSingleElement(fn)) {
      return fn.emitError("should have exactly one basic block"), false;
    }
    auto& bb = fn.front();

    // Every block argument must carry a TFLite-exportable type; variant
    // types get a dedicated, more actionable error message.
    for (auto arg : bb.getArguments()) {
      if (!HasValidTFLiteType(arg, fn)) {
        auto elementType = getElementTypeOrSelf(arg.getType());
        if (elementType.isa<mlir::TF::VariantType>()) {
          return fn.emitError(
                     "function argument uses variant type. Currently, the "
                     "variant type is not natively supported in TFLite. Please "
                     "consider not using the variant type: ")
                     << arg.getType(),
                 false;
        }
        return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
      }
    }

    // Verify that all operations except the terminator have exactly one
    // result of type supported by TFLite (or is a ControlType, which
    // will be removed later by ExtractControlEdges.)
    for (auto& inst : bb) {
      if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;

      for (auto result : inst.getResults()) {
        if (result.getType().isa<mlir::TFL::ControlType>()) continue;
        if (!HasValidTFLiteType(result, inst)) {
          auto elementType = getElementTypeOrSelf(result.getType());
          if (elementType.isa<mlir::TF::VariantType>()) {
            return inst.emitError(
                       "operand result uses variant type. Currently, the "
                       "variant type is not natively supported in TFLite. "
                       "Please "
                       "consider not using the variant type: ")
                       << result.getType(),
                   false;
          }
          // NOTE(review): the variant branch above reports via
          // inst.emitError while this path uses fn.emitError — confirm the
          // difference in diagnostic location is intended.
          return fn.emitError("invalid TFLite type: ") << result.getType(),
                 false;
        }
      }
    }
  }

  return true;
}
411 
GetTensorFlowNodeDef(::mlir::Operation * inst)412 static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
413     ::mlir::Operation* inst) {
414   // We pass empty string for the original node_def name since Flex runtime
415   // does not care about this being set correctly on node_def. There is no
416   // "easy" (see b/120948529) way yet to get this from MLIR inst.
417   auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef(
418       inst, /*name=*/"", /*ignore_unregistered_attrs=*/true);
419   if (!status_or_node_def.ok()) {
420     inst->emitOpError(
421         Twine("failed to obtain TensorFlow nodedef with status: " +
422               status_or_node_def.status().ToString()));
423     return {};
424   }
425   return std::move(status_or_node_def.ValueOrDie());
426 }
427 
428 // Converts a mlir padding StringRef to TfLitePadding.
429 // Returns llvm::None if conversion fails.
GetTflitePadding(Operation * inst,llvm::StringRef padding)430 static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
431                                                 llvm::StringRef padding) {
432   const tflite::Padding padding_attr =
433       std::move(llvm::StringSwitch<tflite::Padding>(padding)
434                     .Case("SAME", tflite::Padding_SAME)
435                     .Case("VALID", tflite::Padding_VALID));
436   if (padding_attr == tflite::Padding_SAME) {
437     return kTfLitePaddingSame;
438   }
439   if (padding_attr == tflite::Padding_VALID) {
440     return kTfLitePaddingValid;
441   }
442 
443   return inst->emitOpError() << "Invalid padding attribute: " << padding,
444          llvm::None;
445 }
446 
447 // Extracts TfLitePoolParams from a TFL custom op.
448 // Template parameter, TFLOp, should be a TFL custom op containing attributes
449 // generated from TfLitePoolParams.
450 // Returns llvm::None if conversion fails.
451 template <typename TFLOp>
GetTflitePoolParams(Operation * inst,TFLOp op)452 static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
453                                                       TFLOp op) {
454   TfLitePoolParams pool_params;
455   pool_params.stride_height = op.stride_h().getSExtValue();
456   pool_params.stride_width = op.stride_w().getSExtValue();
457   pool_params.filter_height = op.filter_h().getSExtValue();
458   pool_params.filter_width = op.filter_w().getSExtValue();
459   const auto padding = GetTflitePadding(inst, op.padding());
460   if (padding) {
461     pool_params.padding = *padding;
462     pool_params.activation = kTfLiteActNone;
463     pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
464     return pool_params;
465   }
466 
467   return llvm::None;
468 }
469 
470 namespace {
471 
// Helper struct that wraps inputs/outputs of a single SignatureDef.
struct SignatureDefData {
  // Note, we are using maps here to make order deterministic
  // for easily testing only.

  // Inputs defined in the signature def, mapped to subgraph tensor names.
  std::map<std::string, std::string> inputs;
  // Outputs defined in the signature def, mapped to subgraph tensor names.
  std::map<std::string, std::string> outputs;
  // Signature key (e.g. the saved-model signature name).
  std::string signature_key;
  // Index of the subgraph this signature refers to.
  uint32_t subgraph_index;
};
486 
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
// A Translator instance accumulates buffers, opcodes, subgraphs, and metadata
// into a single flatbuffers::FlatBufferBuilder as it walks the module.
class Translator {
 public:
  // Translates the given MLIR module into TFLite FlatBuffer format and returns
  // the serialized output. Returns llvm::None on unsupported, invalid inputs or
  // internal error.
  static Optional<std::string> Translate(
      ModuleOp module, const toco::TocoFlags& toco_flags,
      const std::unordered_set<std::string>& tags,
      OpOrArgNameMapper* op_or_arg_name_mapper,
      const std::map<std::string, std::string>& metadata);

 private:
  // Categories of ops the exporter may be allowed to emit, controlled by the
  // TOCO flags passed to the constructor.
  enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
  explicit Translator(ModuleOp module, const toco::TocoFlags& toco_flags,
                      const std::unordered_set<std::string>& saved_model_tags,
                      OpOrArgNameMapper* op_or_arg_name_mapper,
                      const std::map<std::string, std::string>& metadata)
      : module_(module),
        name_mapper_(*op_or_arg_name_mapper),
        builder_(kInitialBufferSize),
        saved_model_tags_(saved_model_tags),
        allow_all_select_tf_ops_(toco_flags.allow_all_select_tf_ops()),
        select_user_tf_ops_(toco_flags.select_user_tf_ops().begin(),
                            toco_flags.select_user_tf_ops().end()),
        metadata_(metadata),
        supported_backends_(toco_flags.supported_backends().begin(),
                            toco_flags.supported_backends().end()) {
    // The first buffer must be empty according to the schema definition.
    empty_buffer_ = tflite::CreateBuffer(builder_);
    buffers_.push_back(empty_buffer_);
    // Translate the TOCO flags into the set of enabled op categories.
    if (!toco_flags.force_select_tf_ops()) {
      enabled_op_types_.emplace(OpType::kTfliteBuiltin);
    }
    if (toco_flags.enable_select_tf_ops()) {
      enabled_op_types_.emplace(OpType::kSelectTf);
    }
    if (toco_flags.allow_custom_ops()) {
      enabled_op_types_.emplace(OpType::kCustomOp);
    }
    tf_dialect_ =
        module.getContext()->getOrLoadDialect<mlir::TF::TensorFlowDialect>();
    tfl_dialect_ = module.getContext()
                       ->getOrLoadDialect<mlir::TFL::TensorFlowLiteDialect>();
    // Right now the TF executor dialect is still needed to build NodeDef.
    module.getContext()
        ->getOrLoadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
  }

  // Performs the actual translation after construction; called by Translate.
  Optional<std::string> TranslateInternal();

  // Returns TFLite buffer populated with constant value if the operation is
  // TFLite constant operation. Otherwise, returns an empty buffer. Emits error
  // and returns llvm::None on failure.
  Optional<BufferOffset<tflite::Buffer>> BuildBuffer(Operation* inst);

  // Build TFLite tensor from the given type. This function is for tfl.lstm
  // intermediates, which should have UniformQuantizedType.
  Optional<BufferOffset<tflite::Tensor>> BuildTensorFromType(
      mlir::Type type, const std::string& name);

  // Builds TF::VariantType from the given element type. Returns llvm::None if
  // failure. Returns empty vector if the element type is not TF::VariantType or
  // there is empty TensorType in the TF::VariantType.
  Optional<std::vector<BufferOffset<tflite::VariantSubType>>>
  BuildTFVariantType(mlir::Type element_type);

  // Builds TFLite tensor from the given value. `buffer_idx` is index of the
  // corresponding buffer. Emits error and returns llvm::None on failure.
  Optional<BufferOffset<tflite::Tensor>> BuildTensor(
      Value value, const std::string& name, unsigned buffer_idx,
      const Optional<BufferOffset<tflite::QuantizationParameters>>&
          quant_parameters);

  // TODO(b/137395003): Legalize tf.IfOp to TFLite dialect, and change the
  // following method to handle TFL::IfOp.
  BufferOffset<tflite::Operator> BuildIfOperator(
      mlir::TF::IfOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Build while operator where cond & body are regions.
  Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
      mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Build call once operator.
  BufferOffset<tflite::Operator> BuildCallOnceOperator(
      mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Builds the NUMERIC_VERIFY debug operator for quantization validation.
  BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
      mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Builds a custom operator (user-provided op not known to TFLite builtins).
  BufferOffset<tflite::Operator> BuildCustomOperator(
      Operation* inst, mlir::TFL::CustomOp op,
      const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Serializes `node_def` into the custom-options payload for a Flex op.
  Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);

  // Serializes `node_def` into the custom-options payload for a custom op.
  Optional<CustomOptionsOffset> CreateCustomOpCustomOptions(
      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);

  // Builds a flexbuffer encoding of the attributes of `node_def`.
  std::unique_ptr<flexbuffers::Builder> CreateFlexBuilderWithNodeAttrs(
      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);

  // Returns opcode index for op identified by the op_name, if already
  // available. Otherwise, creates a new OperatorCode using the given `builtin`
  // operator and associates it with `op_name`.
  uint32_t GetOpcodeIndex(const std::string& op_name,
                          tflite::BuiltinOperator builtin);

  // Builds operator for the given operation with specified operand and result
  // tensor indices. Emits an error and returns llvm::None on failure.
  Optional<BufferOffset<tflite::Operator>> BuildOperator(
      Operation* inst, std::vector<int32_t> operands,
      const std::vector<int32_t>& results,
      const std::vector<int32_t>& intermediates);

  // Returns the quantization parameters for output value of "quant.stats" op.
  BufferOffset<tflite::QuantizationParameters>
  GetQuantizationForQuantStatsOpOutput(mlir::quantfork::StatisticsOp stats_op);

  // Build a subgraph with a given name out of the region either corresponding
  // to a function's body or while op. Modifies *region by calling
  // ExtractControlEdges.
  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
      const std::string& name, Region* region, const int index);

  // Modifies *block by unwrapping all ControlNodeOps. The DAG of the control
  // dependencies is returned as a vector of its edges, with node indices into
  // *block.
  std::vector<std::pair<int, int>> ExtractControlEdges(mlir::Block* block);

  // Builds Metadata with the given `name` and buffer `content`.
  BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
                                               StringRef content);

  // Encodes the `tfl.metadata` dictionary attribute of the module to the
  // metadata section in the final model.
  Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
  CreateMetadataVector();

  // Builds and returns list of tfl.SignatureDef sections in the model.
  Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
  CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);

  // Returns list of offsets for the passed 'items' in TensorMap structure
  // inside the flatbuffer.
  // 'items' is a map from tensor name in signatureDef to tensor name in
  // the subgraph, specified by the 'subgraph_index' argument.
  std::vector<BufferOffset<tflite::TensorMap>> GetList(
      const int subgraph_index,
      const std::map<std::string, std::string>& items);

  // Uses the tf.entry_function attribute (if set) to initialize the op to name
  // mapping.
  void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);

  // Determines if the specified operation op's operand at operand_index
  // is marked as a stateful operand.
  bool IsStatefulOperand(mlir::Operation* op, int operand_index);

  // Returns a unique name for `val`.
  std::string UniqueName(mlir::Value val);

  // Builds the sparsity parameters section for a sparse tensor.
  BufferOffset<tflite::SparsityParameters> BuildSparsityParameters(
      const mlir::TFL::SparsityParameterAttr& s_attr);

  // Sums per-op MAC counts into *count; returns false if any op's count is
  // undetermined.
  bool EstimateArithmeticCount(int64_t* count);

  // Check compatibility with GPU delegate and returns the compatibility.
  bool CheckGpuDelegateCompatibility(uint8_t* model_buffer_pointer);

  ModuleOp module_;

  tensorflow::OpOrArgNameMapper& name_mapper_;

  flatbuffers::FlatBufferBuilder builder_;
  // Offset of the mandatory empty buffer at index 0 (schema requirement).
  BufferOffset<tflite::Buffer> empty_buffer_;

  std::vector<BufferOffset<tflite::Buffer>> buffers_;
  // Maps subgraph index and tensor name in the graph to the tensor index.
  absl::flat_hash_map<int, absl::flat_hash_map<std::string, int>>
      tensor_index_map_;

  // Maps op name to index of the corresponding OperatorCode in opcodes_ vector.
  absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
  std::vector<BufferOffset<tflite::OperatorCode>> opcodes_;

  // Maps function name to index of the corresponding subgraph in the FlatBuffer
  // model.
  absl::flat_hash_map<std::string, int> subgraph_index_map_;
  absl::flat_hash_set<OpType> enabled_op_types_;

  // Points to TensorFlow and TFLite dialects, respectively. nullptr if the
  // dialect is not registered.
  const Dialect* tf_dialect_;
  const Dialect* tfl_dialect_;

  // The failed ops during legalization, keyed by op name with per-op detail
  // messages.
  std::map<std::string, std::set<std::string>> failed_flex_ops_;
  std::map<std::string, std::set<std::string>> failed_custom_ops_;

  // Ops to provide warning messages.
  std::map<std::string, std::set<std::string>> custom_ops_;
  std::map<std::string, std::set<std::string>> flex_ops_;

  // Resource ops to provide warning messages.
  std::map<std::string, std::set<std::string>> resource_ops_;

  // Set of saved model tags, if any.
  const std::unordered_set<std::string> saved_model_tags_;
  // Allows automatic pass through of TF ops as select Tensorflow ops.
  const bool allow_all_select_tf_ops_;
  // User's defined ops allowed with Flex.
  const std::unordered_set<std::string> select_user_tf_ops_;
  // Map of key value pairs of metadata to export.
  const std::map<std::string, std::string> metadata_;
  // User's defined supported backends.
  const std::unordered_set<std::string> supported_backends_;
  // A mapping table to mlir::Operation objects for TFL subgraph and operator
  // index in a flatbuffer.
  std::vector<std::vector<Operation*>> subgraph_op_inst_map_;

  // Will be populated by ExtractControlEdges to contain the control
  // dependencies contained in the ControlNodeOps. Will then be used to populate
  // metadata in the exported flatbuffer file.
  tflite::ModelControlDependencies model_control_dependencies_;
};
719 
EstimateArithmeticCount(int64_t * count)720 bool Translator::EstimateArithmeticCount(int64_t* count) {
721   int64_t result = 0;
722   bool encounter_undetermined_mac = false;
723   module_->walk([&](mlir::TFL::TflArithmeticCountOpInterface op) {
724     int64_t mac_count = op.GetArithmeticCount(op);
725     if (mac_count < 0) {
726       encounter_undetermined_mac = true;
727       return;
728     }
729     result += mac_count;
730   });
731 
732   *count = result;
733   return !encounter_undetermined_mac;
734 }
735 
UniqueName(mlir::Value val)736 std::string Translator::UniqueName(mlir::Value val) {
737   return std::string(name_mapper_.GetUniqueName(val));
738 }
739 
BuildBuffer(Operation * inst)740 Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
741     Operation* inst) {
742   ElementsAttr attr;
743   if (auto cst = dyn_cast<mlir::arith::ConstantOp>(inst)) {
744     // arith::ConstantOp have ElementAttr at this point due to validation of the
745     // TFLite module.
746     attr = cst.getValue().cast<ElementsAttr>();
747   } else if (auto cst = dyn_cast<mlir::TF::ConstOp>(inst)) {
748     attr = cst.value();
749   } else if (auto cst = dyn_cast<tfl::ConstOp>(inst)) {
750     attr = cst.value();
751   } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
752     attr = cst.value();
753   } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
754     attr = cst.compressed_data();
755   } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
756     attr = cst.compressed_data();
757   } else {
758     return empty_buffer_;
759   }
760 
761   tensorflow::Tensor tensor;
762   auto status = tensorflow::ConvertToTensor(attr, &tensor);
763   if (!status.ok()) {
764     inst->emitError(
765         Twine("failed to convert value attribute to tensor with error: " +
766               status.ToString()));
767     return llvm::None;
768   }
769 
770   // TensorFlow and TensorFlow Lite use different string encoding formats.
771   // Convert to TensorFlow Lite format is it's a constant string tensor.
772   if (tensor.dtype() == tensorflow::DT_STRING) {
773     ::tflite::DynamicBuffer dynamic_buffer;
774     auto flat = tensor.flat<::tensorflow::tstring>();
775     for (int i = 0; i < flat.size(); ++i) {
776       const auto& str = flat(i);
777       dynamic_buffer.AddString(str.c_str(), str.length());
778     }
779     char* tensor_buffer;
780     int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer);
781     auto buffer_data =
782         builder_.CreateVector(reinterpret_cast<uint8_t*>(tensor_buffer), bytes);
783     free(tensor_buffer);
784     return tflite::CreateBuffer(builder_, buffer_data);
785   }
786 
787   absl::string_view tensor_data = tensor.tensor_data();
788   auto buffer_data = builder_.CreateVector(
789       reinterpret_cast<const uint8_t*>(tensor_data.data()), tensor_data.size());
790   return tflite::CreateBuffer(builder_, buffer_data);
791 }
792 
793 Optional<std::vector<BufferOffset<tflite::VariantSubType>>>
BuildTFVariantType(mlir::Type element_type)794 Translator::BuildTFVariantType(mlir::Type element_type) {
795   std::vector<BufferOffset<tflite::VariantSubType>> variant_params;
796   auto variant_type = element_type.dyn_cast<mlir::TF::VariantType>();
797   if (!variant_type) {
798     return variant_params;
799   }
800 
801   // We only support up to one nested type in tf_type.variant_type.
802   if (variant_type.getSubtypes().size() > 1) {
803     return llvm::None;
804   }
805   if (variant_type.getSubtypes().empty()) {
806     return variant_params;
807   }
808   mlir::TensorType tensor_type = variant_type.getSubtypes().front();
809   tflite::TensorType tflite_element_type =
810       GetTFLiteType(tensor_type.getElementType()).ValueOrDie();
811   std::vector<int32_t> shape;
812   if (tensor_type.hasRank()) {
813     llvm::ArrayRef<int64_t> shape_ref = tensor_type.getShape();
814     shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
815   }
816 
817   variant_params.push_back(
818       tflite::CreateVariantSubType(builder_, builder_.CreateVector(shape),
819                                    tflite_element_type, tensor_type.hasRank()));
820   return variant_params;
821 }
822 
BuildTensorFromType(mlir::Type type,const std::string & name)823 Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensorFromType(
824     mlir::Type type, const std::string& name) {
825   auto tensor_type = type.cast<TensorType>();
826 
827   llvm::ArrayRef<int64_t> shape_ref;
828   std::vector<int32_t> shape;
829 
830   if (tensor_type.hasRank()) {
831     if (tensor_type.hasStaticShape()) {
832       shape_ref = tensor_type.getShape();
833       shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
834     } else {
835       return llvm::None;
836     }
837   }
838 
839   auto element_type = tensor_type.getElementType();
840   tflite::TensorType tflite_element_type =
841       GetTFLiteType(tensor_type.getElementType()).ValueOrDie();
842   Optional<std::vector<BufferOffset<tflite::VariantSubType>>> variant_params =
843       BuildTFVariantType(element_type);
844   if (!variant_params.hasValue()) {
845     return llvm::None;
846   }
847   BufferOffset<tflite::QuantizationParameters> q_params = 0;
848   if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
849     std::vector<float> scales = {static_cast<float>(qtype.getScale())};
850     std::vector<int64_t> zero_points = {qtype.getZeroPoint()};
851     q_params = tflite::CreateQuantizationParameters(
852         builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
853         builder_.CreateVector<int64_t>(zero_points));
854   } else if (auto qtype =
855                  element_type
856                      .dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
857     std::vector<float> mins = {static_cast<float>(qtype.getMin())};
858     std::vector<float> maxs = {static_cast<float>(qtype.getMax())};
859     q_params = tflite::CreateQuantizationParameters(
860         builder_, builder_.CreateVector<float>(mins),
861         builder_.CreateVector<float>(maxs));
862   }
863   return tflite::CreateTensor(
864       builder_, builder_.CreateVector(shape), tflite_element_type,
865       /*buffer=*/0, builder_.CreateString(name), q_params,
866       /*is_variable=*/false, /*sparsity=*/0, /*shape_signature=*/0,
867       /*has_rank=*/tensor_type.hasRank(),
868       variant_params->empty() ? 0 : builder_.CreateVector(*variant_params));
869 }
870 
// Builds a tflite::Tensor for `value`. The shape is derived from the value's
// type, or from the constant attribute's type when the result type is
// dynamic; sparsity and quantization parameters are attached when present,
// and the tensor is marked as a variable when any use of the value is a
// stateful operand. Returns llvm::None when a shape dimension overflows
// int32 or the variant element type is unsupported.
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
    Value value, const std::string& name, unsigned buffer_idx,
    const Optional<BufferOffset<tflite::QuantizationParameters>>&
        quant_parameters) {
  auto type = value.getType().cast<TensorType>();

  // TFLite requires tensor shape only for the inputs and constants.
  // However, we output all known shapes for better round-tripping
  auto check_shape =
      [&](llvm::ArrayRef<int64_t> shape_ref) -> mlir::LogicalResult {
    // Each dimension must fit into the int32 shape field of the flatbuffer.
    auto is_out_of_range = [](int64_t dim) {
      return dim > std::numeric_limits<int32_t>::max();
    };

    if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
      return mlir::emitError(
          value.getLoc(),
          "result shape dimensions out of 32 bit int type range");

    return mlir::success();
  };

  std::vector<int32_t> shape;
  std::vector<int32_t> shape_signature;
  auto* inst = value.getDefiningOp();
  if (type.hasStaticShape()) {
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  } else if (inst && IsConst(inst)) {
    // Const op can have a result of dynamic shaped type (e.g. due to constant
    // folding), but we can still derive the shape of a constant tensor for
    // its attribute type.
    auto tensor_attr = inst->getAttr("value").cast<mlir::TypedAttr>();
    llvm::ArrayRef<int64_t> shape_ref =
        tensor_attr.getType().cast<TensorType>().getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  } else if (type.hasRank()) {
    // Ranked but dynamic: export dynamic dims as 1 in `shape` and keep the
    // original (possibly -1) dims in `shape_signature`.
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape.reserve(shape_ref.size());
    for (auto& dim : shape_ref) {
      shape.push_back(dim == -1 ? 1 : dim);
    }
    shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  }

  // Sparse constants carry their sparsity layout in the s_param attribute.
  BufferOffset<tflite::SparsityParameters> s_params = 0;
  if (auto* inst = value.getDefiningOp()) {
    if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
      s_params = BuildSparsityParameters(cst.s_param());
    } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
      s_params = BuildSparsityParameters(cst.s_param());
    }
  }

  Type element_type = type.getElementType();
  tflite::TensorType tflite_element_type =
      GetTFLiteType(type.getElementType()).ValueOrDie();

  Optional<std::vector<BufferOffset<tflite::VariantSubType>>> variant_params =
      BuildTFVariantType(element_type);
  if (!variant_params.hasValue()) {
    return llvm::None;
  }

  BufferOffset<tflite::QuantizationParameters> q_params;
  if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
    std::vector<float> scales = {static_cast<float>(qtype.getScale())};
    std::vector<int64_t> zero_points = {qtype.getZeroPoint()};
    q_params = tflite::CreateQuantizationParameters(
        // min and max values are not stored in the quantized type from MLIR, so
        // both are set to 0 in the flatbuffer when they are exported.
        builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
        builder_.CreateVector<int64_t>(zero_points));
  } else if (auto qtype =
                 element_type
                     .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
    // Per-axis quantization: one scale/zero-point pair per slice along the
    // quantized dimension.
    std::vector<float> scales(qtype.getScales().begin(),
                              qtype.getScales().end());
    std::vector<int64_t> zero_points(qtype.getZeroPoints().begin(),
                                     qtype.getZeroPoints().end());
    q_params = tflite::CreateQuantizationParameters(
        builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
        builder_.CreateVector<int64_t>(zero_points),
        tflite::QuantizationDetails_NONE, /*details=*/0,
        qtype.getQuantizedDimension());
  } else if (quant_parameters.has_value()) {
    // Caller-supplied quantization parameters take effect only when the
    // element type itself carries no quantization info.
    q_params = quant_parameters.getValue();
  } else {
    q_params = tflite::CreateQuantizationParameters(builder_);
  }
  // Check if the value's uses includes an op and usage at an operand index
  // marked as a stateful. If so, set the tensor's is_variable as true
  // This is v1 ref variable semantics in the TFLite runtime.
  bool is_variable = false;
  for (auto& use : value.getUses()) {
    is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
    if (is_variable) {
      break;
    }
  }

  bool has_rank = type.hasRank();

  // NOTE(review): variables get buffer index 0 (the empty buffer) —
  // presumably so a variable tensor never aliases constant data; confirm
  // against the runtime's variable-tensor handling.
  if (shape_signature.empty()) {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, s_params, /*shape_signature=*/0,
        /*has_rank=*/has_rank,
        variant_params->empty() ? 0 : builder_.CreateVector(*variant_params));
  } else {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, s_params,
        /*shape_signature=*/builder_.CreateVector(shape_signature),
        /*has_rank=*/has_rank,
        variant_params->empty() ? 0 : builder_.CreateVector(*variant_params));
  }
}
997 
BuildIfOperator(mlir::TF::IfOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)998 BufferOffset<tflite::Operator> Translator::BuildIfOperator(
999     mlir::TF::IfOp op, const std::vector<int32_t>& operands,
1000     const std::vector<int32_t>& results) {
1001   auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF);
1002   int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str());
1003   int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str());
1004   auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index,
1005                                                  else_subgraph_index)
1006                              .Union();
1007   auto inputs = builder_.CreateVector(operands);
1008   auto outputs = builder_.CreateVector(results);
1009   return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
1010                                 tflite::BuiltinOptions_IfOptions,
1011                                 builtin_options);
1012 }
1013 
BuildCallOnceOperator(mlir::TFL::CallOnceOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)1014 BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator(
1015     mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
1016     const std::vector<int32_t>& results) {
1017   auto opcode_index =
1018       GetOpcodeIndex("call_once", tflite::BuiltinOperator_CALL_ONCE);
1019   int init_subgraph_index =
1020       subgraph_index_map_.at(op.session_init_function().str());
1021   auto builtin_options =
1022       tflite::CreateCallOnceOptions(builder_, init_subgraph_index).Union();
1023   auto inputs = builder_.CreateVector(operands);
1024   auto outputs = builder_.CreateVector(results);
1025   return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
1026                                 tflite::BuiltinOptions_CallOnceOptions,
1027                                 builtin_options);
1028 }
1029 
BuildWhileOperator(mlir::TFL::WhileOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)1030 Optional<BufferOffset<tflite::Operator>> Translator::BuildWhileOperator(
1031     mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
1032     const std::vector<int32_t>& results) {
1033   auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
1034   auto get_call_index = [&](mlir::Block& b) -> Optional<int> {
1035     if (b.getOperations().size() != 2) return llvm::None;
1036     if (auto call_op = dyn_cast<mlir::func::CallOp>(b.front()))
1037       return subgraph_index_map_.at(call_op.getCallee().str());
1038     return llvm::None;
1039   };
1040   auto body_subgraph_index = get_call_index(op.body().front());
1041   auto cond_subgraph_index = get_call_index(op.cond().front());
1042   if (!body_subgraph_index || !cond_subgraph_index)
1043     return op.emitOpError("only single call cond/body while export supported"),
1044            llvm::None;
1045   auto builtin_options =
1046       tflite::CreateWhileOptions(builder_, *cond_subgraph_index,
1047                                  *body_subgraph_index)
1048           .Union();
1049   auto inputs = builder_.CreateVector(operands);
1050   auto outputs = builder_.CreateVector(results);
1051   return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
1052                                 tflite::BuiltinOptions_WhileOptions,
1053                                 builtin_options);
1054 }
1055 
BuildNumericVerifyOperator(mlir::TFL::NumericVerifyOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)1056 BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
1057     mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
1058     const std::vector<int32_t>& results) {
1059   float tolerance = op.tolerance().convertToFloat();
1060   bool log_if_failed = op.log_if_failed();
1061   auto fbb = std::make_unique<flexbuffers::Builder>();
1062   fbb->Map([&]() {
1063     fbb->Float("tolerance", tolerance);
1064     fbb->Bool("log_if_failed", log_if_failed);
1065   });
1066   fbb->Finish();
1067   auto f = std::unique_ptr<flexbuffers::Builder>(fbb.release());
1068   auto custom_option = f->GetBuffer();
1069   auto opcode_index =
1070       GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
1071   return tflite::CreateOperator(
1072       builder_, opcode_index, builder_.CreateVector(operands),
1073       builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
1074       /*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_option),
1075       tflite::CustomOptionsFormat_FLEXBUFFERS);
1076 }
1077 
BuildCustomOperator(Operation * inst,mlir::TFL::CustomOp op,const std::vector<int32_t> & operands,const std::vector<int32_t> & results)1078 BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
1079     Operation* inst, mlir::TFL::CustomOp op,
1080     const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
1081   const std::string attrs =
1082       op.custom_option().cast<mlir::TFL::ConstBytesAttr>().getValue().str();
1083   std::vector<uint8_t> custom_option_vector(attrs.size());
1084   memcpy(custom_option_vector.data(), attrs.data(), attrs.size());
1085   auto opcode_index =
1086       GetOpcodeIndex(op.custom_code().str(), tflite::BuiltinOperator_CUSTOM);
1087   return tflite::CreateOperator(
1088       builder_, opcode_index, builder_.CreateVector(operands),
1089       builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
1090       /*builtin_options=*/0,
1091       builder_.CreateVector<uint8_t>(custom_option_vector),
1092       tflite::CustomOptionsFormat_FLEXBUFFERS);
1093 }
1094 
CreateFlexOpCustomOptions(const::tensorflow::NodeDef & node_def,const mlir::Location & loc)1095 Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
1096     const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1097   std::string node_def_str;
1098   if (!node_def.SerializeToString(&node_def_str)) {
1099     return emitError(loc, "failed to serialize tensorflow node_def"),
1100            llvm::None;
1101   }
1102 
1103   auto flex_builder = std::make_unique<flexbuffers::Builder>();
1104   flex_builder->Vector([&]() {
1105     flex_builder->String(node_def.op());
1106     flex_builder->String(node_def_str);
1107   });
1108   flex_builder->Finish();
1109   return builder_.CreateVector(flex_builder->GetBuffer());
1110 }
1111 
CreateCustomOpCustomOptions(const::tensorflow::NodeDef & node_def,const mlir::Location & loc)1112 Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
1113     const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1114   auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
1115   return builder_.CreateVector(flex_builder->GetBuffer());
1116 }
1117 
// Builds a finished flexbuffer map from the NodeDef's attributes. Attributes
// are emitted in sorted key order; supported value kinds are string, type
// (converted to the TFLite type enum), int, float, bool, and homogeneous
// string/int/float lists. Unsupported kinds emit a warning and are skipped.
std::unique_ptr<flexbuffers::Builder>
Translator::CreateFlexBuilderWithNodeAttrs(
    const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
  auto flex_builder = std::make_unique<flexbuffers::Builder>();
  size_t map_start = flex_builder->StartMap();
  using Item = std::pair<std::string, ::tensorflow::AttrValue>;
  // Sort by key so the emitted map is deterministic regardless of the
  // NodeDef's (unordered) attr map iteration order.
  std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
  std::sort(attrs.begin(), attrs.end(),
            [](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
  for (const Item& pair : attrs) {
    const char* key = pair.first.c_str();
    const ::tensorflow::AttrValue& attr = pair.second;
    switch (attr.value_case()) {
      case ::tensorflow::AttrValue::kS:
        flex_builder->String(key, attr.s());
        break;
      case ::tensorflow::AttrValue::kType: {
        // DataType attrs are stored as the TFLite type enum value; types with
        // no TFLite equivalent are dropped with a warning.
        auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type());
        if (status_or_tfl_type.ok()) {
          flex_builder->Int(key, status_or_tfl_type.ValueOrDie());
        } else {
          emitWarning(loc, "ignoring unsupported tensorflow type: ")
              << std::to_string(attr.type());
        }
        break;
      }
      case ::tensorflow::AttrValue::kI:
        flex_builder->Int(key, attr.i());
        break;
      case ::tensorflow::AttrValue::kF:
        flex_builder->Float(key, attr.f());
        break;
      case ::tensorflow::AttrValue::kB:
        flex_builder->Bool(key, attr.b());
        break;
      case tensorflow::AttrValue::kList:
        // Only the first non-empty list field (s, then i, then f) is emitted;
        // an AttrValue list holds at most one kind in practice.
        if (attr.list().s_size() > 0) {
          auto start = flex_builder->StartVector(key);
          for (const std::string& v : attr.list().s()) {
            flex_builder->Add(v);
          }
          flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
        } else if (attr.list().i_size() > 0) {
          auto start = flex_builder->StartVector(key);
          for (const int64_t v : attr.list().i()) {
            flex_builder->Add(v);
          }
          flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
        } else if (attr.list().f_size() > 0) {
          auto start = flex_builder->StartVector(key);
          for (const float v : attr.list().f()) {
            flex_builder->Add(v);
          }
          flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
        } else {
          emitWarning(loc,
                      "ignoring unsupported type in list attribute with key: ")
              << key;
        }
        break;
      default:
        emitWarning(loc, "ignoring unsupported attribute type with key: ")
            << key;
        break;
    }
  }
  flex_builder->EndMap(map_start);
  flex_builder->Finish();
  return flex_builder;
}
1188 
GetOpcodeIndex(const std::string & op_name,tflite::BuiltinOperator builtin)1189 uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
1190                                     tflite::BuiltinOperator builtin) {
1191   auto it = opcode_index_map_.insert({op_name, 0});
1192 
1193   // If the insert succeeded, the opcode has not been created already. Create a
1194   // new operator code and update its index value in the map.
1195   if (it.second) {
1196     it.first->second = opcodes_.size();
1197     auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM
1198                            ? builder_.CreateString(op_name)
1199                            : BufferOffset<flatbuffers::String>();
1200     // Use version 0 for builtin op. This is a way to serialize version field to
1201     // flatbuffer (since 0 is non default) and it will be corrected later.
1202     int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1;
1203     opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin,
1204                                           custom_code, op_version));
1205   }
1206   return it.first->second;
1207 }
1208 
// Translates `inst` into a flatbuffer Operator. TFL-dialect ops become
// builtin operators (with special handling for NumericVerify, Custom, While,
// and CallOnce); TF-dialect ops become either flex ops or custom ops
// depending on the enabled op types and allowlists. Returns llvm::None on
// any unsupported op, after emitting a diagnostic.
Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
    Operation* inst, std::vector<int32_t> operands,
    const std::vector<int32_t>& results,
    const std::vector<int32_t>& intermediates) {
  const auto* dialect = inst->getDialect();
  if (!dialect) {
    inst->emitOpError("dialect is not registered");
    return llvm::None;
  }

  // If TFLite built in op, create operator as a builtin op.
  if (dialect == tfl_dialect_) {
    // Only if built-in TFLite op emission is enabled, would legalization have
    // converted any TF->TFL.
    if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) {
      return inst->emitOpError(
                 "is a TFLite builtin op but builtin emission is not enabled"),
             llvm::None;
    }

    auto builtin_code = GetBuiltinOpCode(inst);
    if (!builtin_code) {
      // TFL ops without a builtin code fall into one of the special cases
      // below; anything else is unsupported.
      if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
        return BuildNumericVerifyOperator(verify_op, operands, results);
      }
      if (auto custom_op = dyn_cast<mlir::TFL::CustomOp>(inst)) {
        return BuildCustomOperator(inst, custom_op, operands, results);
      }
      if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
        if (inst->getNumOperands() != inst->getNumResults()) {
          inst->emitOpError(
              "number of operands and results don't match, only canonical "
              "TFL While supported");
          return llvm::None;
        }
        return BuildWhileOperator(whileOp, operands, results);
      }

      inst->emitOpError("is not a supported TFLite op");
      return llvm::None;
    }

    if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
      if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
        return BuildCallOnceOperator(initOp, operands, results);
      }
    }

    std::string op_name = inst->getName().getStringRef().str();
    uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);

    // If this is TransposeConv we need to do a special case of ignoring the
    // optional tensor, to allow newly created models to run on old runtimes.
    if (*builtin_code == tflite::BuiltinOperator_TRANSPOSE_CONV) {
      if (operands.size() == 4 && operands.at(3) == -1) {
        operands.pop_back();
      }
    }

    auto offset = CreateFlatBufferOperator(inst, opcode_index, operands,
                                           results, intermediates, &builder_);
    if (!offset) {
      inst->emitOpError("is not a supported TFLite op");
    }
    return offset;
  }

  if (dialect == tf_dialect_) {
    // TF::IfOp is exported as the builtin IF operator rather than flex/custom.
    if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
      return BuildIfOperator(ifOp, operands, results);
    }

    CustomOptionsOffset custom_options;

    // Ops in TF dialect can either be custom ops or flex ops.
    // The reason we go directly from TensorFlow dialect MLIR to tensorflow
    // node instead of going to TF table gen'd ops via generated code is that
    // we do not want to restrict custom and flex op conversion support to
    // only those TF ops that are currently registered in MLIR. The current
    // model is of an open op system.
    //
    //  The following algorithm is followed:
    //   if flex is enabled and the op is allowlisted as flex
    //     we emit op as flex.
    //   if custom is enabled
    //    we emit the op as custom.
    auto node_def = GetTensorFlowNodeDef(inst);
    if (!node_def) {
      return llvm::None;
    }

    std::string op_name = node_def->op();
    std::string op_desc = GetOpDescriptionForDebug(inst);

    // Track resource ops so warnings can be reported later.
    if (IsTFResourceOp(inst)) {
      resource_ops_[op_name].insert(op_desc);
    }

    // Flex eligibility: not explicitly unsupported, and either allowlisted or
    // user-selected/allow-all (provided the op is registered in TF).
    const bool is_allowed_flex_op =
        !IsUnsupportedFlexOp(node_def->op()) &&
        (IsAllowlistedFlexOp(node_def->op()) ||
         (((select_user_tf_ops_.count(node_def->op()) != 0) ||
           allow_all_select_tf_ops_) &&
          (tensorflow::OpRegistry::Global()->LookUp(node_def->op()) !=
           nullptr)));

    // Flex op case
    // Eventually, the allowlist will go away and we will rely on some TF op
    // trait (e.g. No side effect) to determine if it is a supported "Flex"
    // op or not.
    if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
      // Construct ops as flex op encoding TensorFlow node definition
      // as custom options.
      // Flex ops are named with the kFlexOpNamePrefix prefix to the actual
      // TF op name.
      op_name = std::string(kFlexOpNamePrefix) + node_def->op();
      if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
        custom_options = *options;
      } else {
        return llvm::None;
      }

      // Gather flex ops.
      flex_ops_[op_name].insert(op_desc);
    } else if (enabled_op_types_.contains(OpType::kCustomOp)) {
      // Generic case of custom ops - write using flex buffers since that
      // is the only custom options supported by TFLite today.
      op_name = node_def->op();
      if (auto options =
              CreateCustomOpCustomOptions(*node_def, inst->getLoc())) {
        custom_options = *options;
      } else {
        return llvm::None;
      }

      // Gather custom ops.
      custom_ops_[op_name].insert(op_desc);
    } else {
      // Neither path is enabled/eligible: record the failure (bucketed by
      // whether flex would have worked) and attach a converter error code.
      // Insert failed op to `flex_ops` or `custom_ops`.
      if (is_allowed_flex_op) {
        failed_flex_ops_[op_name].insert(op_desc);
        tfl::AttachErrorCode(
            inst->emitOpError("is neither a custom op nor a flex op"),
            tflite::metrics::ConverterErrorData::ERROR_NEEDS_FLEX_OPS);
      } else {
        failed_custom_ops_[op_name].insert(op_desc);
        tfl::AttachErrorCode(
            inst->emitOpError("is neither a custom op nor a flex op"),
            tflite::metrics::ConverterErrorData::ERROR_NEEDS_CUSTOM_OPS);
      }
      return llvm::None;
    }

    uint32_t opcode_index =
        GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
    auto inputs = builder_.CreateVector(operands);
    auto outputs = builder_.CreateVector(results);

    return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
                                  tflite::BuiltinOptions_NONE,
                                  /*builtin_options=*/0,
                                  /*custom_options=*/custom_options,
                                  tflite::CustomOptionsFormat_FLEXBUFFERS,
                                  /*mutating_variable_inputs=*/0);
  }

  return inst->emitOpError(
             "is not any of a builtin TFLite op, a flex TensorFlow op or a "
             "custom TensorFlow op"),
         llvm::None;
}
1380 
InitializeNamesFromAttribute(FuncOp fn,bool * has_input_attr)1381 void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
1382   auto dict_attr = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1383   if (!dict_attr) return;
1384 
1385   llvm::SmallVector<llvm::StringRef, 2> input_names;
1386   llvm::SmallVector<llvm::StringRef, 2> output_names;
1387   if (auto str = dict_attr.get("inputs").dyn_cast_or_null<mlir::StringAttr>()) {
1388     str.getValue().split(input_names, ',', /*MaxSplit=*/-1,
1389                          /*KeepEmpty=*/false);
1390     if (input_names.size() != fn.getNumArguments()) {
1391       fn.emitWarning() << "invalid entry function specification";
1392       return;
1393     }
1394     for (const auto& it : llvm::enumerate(fn.getArguments())) {
1395       name_mapper_.InitOpName(it.value(), input_names[it.index()].trim());
1396     }
1397     *has_input_attr = true;
1398   }
1399 
1400   if (auto str =
1401           dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
1402     str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
1403                          /*KeepEmpty=*/false);
1404     auto term = fn.back().getTerminator();
1405     if (output_names.size() != term->getNumOperands()) {
1406       fn.emitWarning() << "output names (" << output_names.size()
1407                        << ") != terminator operands (" << term->getNumOperands()
1408                        << ")";
1409       return;
1410     }
1411     for (const auto& it : llvm::enumerate(term->getOperands())) {
1412       name_mapper_.InitOpName(it.value(), output_names[it.index()].trim());
1413     }
1414   }
1415 }
1416 
IsStatefulOperand(mlir::Operation * op,int operand_index)1417 bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
1418   std::vector<int> operand_indices;
1419   if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
1420   return absl::c_find(operand_indices, operand_index) != operand_indices.end();
1421 }
1422 
1423 BufferOffset<tflite::QuantizationParameters>
GetQuantizationForQuantStatsOpOutput(mlir::quantfork::StatisticsOp stats_op)1424 Translator::GetQuantizationForQuantStatsOpOutput(
1425     mlir::quantfork::StatisticsOp stats_op) {
1426   auto layer_stats = stats_op.getLayerStats().cast<mlir::DenseFPElementsAttr>();
1427   Optional<mlir::ElementsAttr> axis_stats = stats_op.getAxisStats();
1428   Optional<uint64_t> axis = stats_op.getAxis();
1429   std::vector<float> mins, maxs;
1430   mlir::DenseFPElementsAttr min_max_attr =
1431       axis_stats.has_value()
1432           ? axis_stats.getValue().cast<mlir::DenseFPElementsAttr>()
1433           : layer_stats;
1434 
1435   for (const auto& index_and_value :
1436        llvm::enumerate(min_max_attr.getValues<llvm::APFloat>())) {
1437     const llvm::APFloat value = index_and_value.value();
1438     if (index_and_value.index() % 2 == 0) {
1439       mins.push_back(value.convertToFloat());
1440     } else {
1441       maxs.push_back(value.convertToFloat());
1442     }
1443   }
1444 
1445   return tflite::CreateQuantizationParameters(
1446       builder_, builder_.CreateVector<float>(mins),
1447       builder_.CreateVector<float>(maxs), /*scale=*/0, /*zero_point=*/0,
1448       tflite::QuantizationDetails_NONE, /*details=*/0,
1449       /*quantized_dimension=*/axis.has_value() ? axis.getValue() : 0);
1450 }
1451 
// Serializes one MLIR region into a flatbuffer SubGraph: builds tensors and
// buffers for block arguments and op results, translates every non-constant,
// non-terminator operation into a tflite::Operator, and records the
// subgraph's input/output tensor indices plus any control dependencies.
// Returns llvm::None if building any tensor or operator fails.
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
    const std::string& name, Region* region, const int index) {
  // Control edges are extracted first; they are converted to operator-index
  // pairs at the end of this function.
  const auto control_edges = ExtractControlEdges(&region->front());
  bool has_input_attr = false;
  if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
    InitializeNamesFromAttribute(fn, &has_input_attr);
  }
  std::vector<BufferOffset<tflite::Tensor>> tensors;
  // Maps an MLIR Value to the index of its tensor within `tensors`.
  llvm::DenseMap<Value, int> tensor_index_map;

  // Builds tensor and buffer for argument or operation result. Returns false
  // on failure.
  auto build_tensor_and_buffer = [&](Value value, const int subgraph_index,
                                     const std::string& tensor_name) {
    // NoneType represents optional and may be skipped here.
    if (value.getType().isa<NoneType>()) {
      return true;
    }

    tensor_index_map.insert({value, tensors.size()});
    tensor_index_map_[subgraph_index][tensor_name] = tensors.size();
    Optional<BufferOffset<tflite::QuantizationParameters>> quant_parameters;
    // A sole "quant.stats" user supplies the tensor's quantization stats.
    if (value.hasOneUse()) {
      auto stats_op =
          llvm::dyn_cast<mlir::quantfork::StatisticsOp>(*value.user_begin());
      if (stats_op) {
        quant_parameters = GetQuantizationForQuantStatsOpOutput(stats_op);
      }
    }
    // The tensor records buffers_.size() as its buffer index, so exactly one
    // buffer must be appended below to keep the indices in sync.
    auto tensor_or =
        BuildTensor(value, tensor_name, buffers_.size(), quant_parameters);
    if (!tensor_or) return false;
    tensors.push_back(*tensor_or);

    // TODO(ashwinm): Check if for stateful tensors, if it is also needed to
    // make the Buffer empty apart from setting the buffer_idx=0 in the
    // Tensor. This does not seem to affect runtime behavior for RNN/LSTM,
    // but would be good for reducing memory footprint.
    if (auto* inst = value.getDefiningOp()) {
      auto buffer_or = BuildBuffer(inst);
      if (!buffer_or) return false;
      buffers_.push_back(*buffer_or);
    } else {
      // Block arguments have no defining op; give them the shared empty
      // buffer.
      buffers_.push_back(empty_buffer_);
    }
    return true;
  };

  std::vector<BufferOffset<tflite::Operator>> operators;

  // Maps positions of operations in bb to positions in operators
  llvm::DenseMap<int, int> operation_index_to_operator_index;
  std::vector<Operation*> operators_in_mlir;
  auto& bb = region->front();

  // Main function's arguments are first passed to `input` op so they don't
  // have associated tensor and buffer. Build FlatBuffer tensor and buffer for
  // other functions.
  for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
    mlir::BlockArgument arg = bb.getArgument(i);
    std::string tensor_name;
    if (has_input_attr)
      tensor_name = std::string(name_mapper_.GetUniqueName(arg));
    if (tensor_name.empty()) tensor_name = absl::StrCat("arg", i);
    if (!build_tensor_and_buffer(arg, index, tensor_name)) return llvm::None;
  }

  // Translate each operation. Keep iterating after a failure so that all
  // unconvertible ops in the block get reported, then bail out at the end.
  bool failed_once = false;
  for (auto& item : llvm::enumerate(bb)) {
    Operation& inst = item.value();
    const int operation_index = item.index();
    if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
    // For "quant.stats" op, it's used to store the quantization parameters info
    // and its output should be then replaced by its input value.
    if (auto quant_stats_op =
            llvm::dyn_cast<mlir::quantfork::StatisticsOp>(inst)) {
      continue;
    }
    std::vector<int32_t> intermediates;
    // Build intermediate tensors for tfl.lstm and insert these tensors into
    // flatbuffer.
    if (llvm::isa<mlir::TFL::LSTMOp, mlir::TFL::UnidirectionalSequenceLSTMOp>(
            inst)) {
      std::vector<std::string> intermediate_names = {
          "input_to_input_intermediate", "input_to_forget_intermediate",
          "input_to_cell_intermediate", "input_to_output_intermediate",
          "effective_hidden_scale_intermediate"};
      for (const std::string& intermediate : intermediate_names) {
        auto intermediate_attr = inst.getAttr(intermediate);
        if (auto attr = intermediate_attr.dyn_cast_or_null<mlir::TypeAttr>()) {
          Type qtype = attr.getValue();
          auto tensor_or = BuildTensorFromType(
              qtype, name_mapper_.GetUniqueName(intermediate).str());
          if (!tensor_or.has_value()) {
            continue;
          } else {
            intermediates.push_back(tensors.size());
            tensors.push_back(tensor_or.getValue());
          }
        }
      }
    }

    for (auto val : inst.getResults()) {
      std::string tensor_name = UniqueName(val);
      // For "tfl.numeric_verify" op, the name is used to find out the original
      // activation tensor rather than its own unique name in the visualization
      // or debugging tools.
      auto builtin_code = GetBuiltinOpCode(&inst);
      if (!builtin_code && dyn_cast<mlir::TFL::NumericVerifyOp>(&inst)) {
        // The first operand is the quantized activation, the target of this
        // NumericVerify op.
        auto quantized_op_val = inst.getOperands().front();
        tensor_name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" +
                      std::to_string(tensor_index_map[quantized_op_val]);
      }
      if (!build_tensor_and_buffer(val, index, tensor_name)) return llvm::None;
    }

    // Skip constant ops as they don't represent a TFLite operator.
    if (IsConst(&inst)) continue;

    // Fetch operand and result tensor indices.
    std::vector<int32_t> results;
    results.reserve(inst.getNumResults());
    for (auto result : inst.getResults()) {
      results.push_back(tensor_index_map.lookup(result));
    }
    Operation* real_inst = &inst;
    std::vector<int32_t> operands;
    operands.reserve(real_inst->getNumOperands());
    for (auto operand : real_inst->getOperands()) {
      if (operand.getType().isa<NoneType>())
        operands.push_back(kTfLiteOptionalTensor);
      else if (auto stats_op =
                   llvm::dyn_cast_or_null<mlir::quantfork::StatisticsOp>(
                       operand.getDefiningOp()))
        // Operands fed through a "quant.stats" op resolve to the stats op's
        // own input tensor.
        operands.push_back(tensor_index_map.lookup(stats_op.getArg()));
      else
        operands.push_back(tensor_index_map.lookup(operand));
    }

    // CustomTfOp is just a wrapper around a TF op, we export the custom Op
    // not the wrapper, so we fetch the op from the region.
    if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
      // If we have custom op with a region, then use the first op in the
      // region, if it exists, otherwise just use params for custom op.
      if (!custom_op.body().empty()) {
        real_inst = &custom_op.body().front().front();
      } else {
        module_.emitError(
            "Invalid CustomTfOp: Custom TF Op have empty region.");
      }
    }
    if (auto tfl_operator =
            BuildOperator(real_inst, operands, results, intermediates)) {
      operation_index_to_operator_index.try_emplace(operation_index,
                                                    operators.size());
      operators.push_back(*tfl_operator);
      operators_in_mlir.push_back(real_inst);
    } else {
      failed_once = true;
    }
  }
  // Record this subgraph's operator->MLIR-op mapping (used later, e.g. by
  // CheckGpuDelegateCompatibility) even when conversion failed.
  if (index + 1 > subgraph_op_inst_map_.size()) {
    subgraph_op_inst_map_.resize(index + 1);
  }
  subgraph_op_inst_map_[index] = operators_in_mlir;
  if (failed_once) return llvm::None;

  // Get input and output tensor indices for the subgraph.
  std::vector<int32_t> inputs, outputs;
  for (auto arg : bb.getArguments()) {
    inputs.push_back(tensor_index_map[arg]);
  }
  for (auto result : bb.getTerminator()->getOperands()) {
    outputs.push_back(tensor_index_map[result]);
  }
  // Translate control edges from operation indices to operator indices.
  for (const auto& [from, to] : control_edges) {
    for (int what : {from, to}) {
      if (operation_index_to_operator_index.count(what) == 0) {
        module_.emitError(
            "dangling control edge -- at least one vertex Operation isn't a "
            "flatbuffer Operator.");
      }
    }
    model_control_dependencies_[index].emplace_back(
        operation_index_to_operator_index[from],
        operation_index_to_operator_index[to]);
  }
  return tflite::CreateSubGraph(
      builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
      builder_.CreateVector(outputs), builder_.CreateVector(operators),
      /*name=*/builder_.CreateString(name));
}
1647 
BuildMetadata(StringRef name,StringRef content)1648 BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
1649                                                          StringRef content) {
1650   auto buffer_index = buffers_.size();
1651   auto buffer_data = builder_.CreateVector(
1652       reinterpret_cast<const uint8_t*>(content.data()), content.size());
1653   buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
1654   return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
1655 }
1656 
1657 Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
CreateMetadataVector()1658 Translator::CreateMetadataVector() {
1659   auto dict_attr = module_->getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
1660   std::vector<BufferOffset<tflite::Metadata>> metadata;
1661   if (dict_attr) {
1662     for (const auto& named_attr : dict_attr) {
1663       StringRef name = named_attr.getName();
1664       mlir::Attribute attr = named_attr.getValue();
1665       if (auto content = attr.dyn_cast<StringAttr>()) {
1666         metadata.push_back(BuildMetadata(name, content.getValue()));
1667       } else {
1668         module_.emitError(
1669             "all values in tfl.metadata's dictionary key-value pairs should be "
1670             "string attributes");
1671         return llvm::None;
1672       }
1673     }
1674   }
1675   // Runtime version string is generated after we update the op
1676   // versions. Here we put a 16-byte dummy string as a placeholder. We choose
1677   // 16-byte because it's the alignment of buffers in flatbuffer, so it won't
1678   // cause any waste of space if the actual string is shorter than 16 bytes.
1679   constexpr std::size_t kByteStringSize = 16;
1680   metadata.push_back(
1681       BuildMetadata("min_runtime_version", std::string(kByteStringSize, '\0')));
1682   for (const auto& kv : metadata_) {
1683     const std::string& val = kv.second;
1684     // Only take the first kByteStringSize values.
1685     const int count = std::min(kByteStringSize, val.length());
1686     std::string value = std::string(kByteStringSize, '\0')
1687                             .assign(val.begin(), val.begin() + count);
1688     metadata.push_back(BuildMetadata(kv.first, value));
1689   }
1690 
1691   // Populate the model control dependencies metadata entry.
1692   if (std::any_of(
1693           model_control_dependencies_.begin(),
1694           model_control_dependencies_.end(),
1695           [](const tflite::ControlEdges& edges) { return !edges.empty(); })) {
1696     metadata.push_back(
1697         BuildMetadata(tflite::kModelControlDependenciesMetadataKey,
1698                       tflite::SerializeModelControlDependencies(
1699                           model_control_dependencies_)));
1700   }
1701   return builder_.CreateVector(metadata);
1702 }
1703 
1704 // Helper method that returns list of all strings in a StringAttr identified
1705 // by 'attr_key' and values are separated by a comma.
GetStringsFromAttrWithSeparator(mlir::DictionaryAttr attr,const std::string & attr_key)1706 llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
1707     mlir::DictionaryAttr attr, const std::string& attr_key) {
1708   llvm::SmallVector<llvm::StringRef, 2> result;
1709   if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
1710     str.getValue().split(result, ',', /*MaxSplit=*/-1,
1711                          /*KeepEmpty=*/false);
1712   }
1713   return result;
1714 }
1715 
1716 // Helper method that return list of string for all the StringAttr in the
1717 // Attribute identified by 'attr_name'.
GetStringsFromDictionaryAttr(const llvm::SmallVector<mlir::DictionaryAttr,4> & dict_attrs,const std::string & attr_name)1718 std::vector<std::string> GetStringsFromDictionaryAttr(
1719     const llvm::SmallVector<mlir::DictionaryAttr, 4>& dict_attrs,
1720     const std::string& attr_name) {
1721   std::vector<std::string> result;
1722   for (const auto& arg_attr : dict_attrs) {
1723     if (!arg_attr) continue;
1724 
1725     auto attrs = arg_attr.getValue();
1726     for (const auto attr : attrs) {
1727       if (attr.getName().str() == attr_name) {
1728         auto array_attr = attr.getValue().dyn_cast_or_null<mlir::ArrayAttr>();
1729         if (!array_attr || array_attr.empty()) continue;
1730         auto string_attr = array_attr[0].dyn_cast_or_null<mlir::StringAttr>();
1731         if (!string_attr) continue;
1732         result.push_back(string_attr.getValue().str());
1733       }
1734     }
1735   }
1736   return result;
1737 }
1738 
BuildSignaturedef(FuncOp main_op,const std::string & saved_model_tag,const uint32_t subgraph_index,tensorflow::OpOrArgNameMapper & name_mapper)1739 std::vector<SignatureDefData> BuildSignaturedef(
1740     FuncOp main_op, const std::string& saved_model_tag,
1741     const uint32_t subgraph_index, tensorflow::OpOrArgNameMapper& name_mapper) {
1742   static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
1743   static const char kEntryFunctionAttributes[] = "tf.entry_function";
1744 
1745   // Fetch inputs and outputs from the signature.
1746   llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs, res_attrs;
1747   main_op.getAllArgAttrs(arg_attrs);
1748   main_op.getAllResultAttrs(res_attrs);
1749   std::vector<std::string> sig_def_inputs =
1750       GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath);
1751   std::vector<std::string> sig_def_outputs =
1752       GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath);
1753 
1754   // If no defined saved model signature, then return empty list.
1755   // This can happen when we are converting model not from SavedModel.
1756   if (sig_def_inputs.empty() && sig_def_outputs.empty()) return {};
1757 
1758   // Fetch function inputs and outputs tensor names.
1759   auto dict_attr =
1760       main_op->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
1761   if (!dict_attr) return {};
1762 
1763   // Get Input and output tensor names from attribute.
1764   llvm::SmallVector<llvm::StringRef, 2> input_names =
1765       GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
1766   llvm::SmallVector<llvm::StringRef, 2> output_names =
1767       GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
1768 
1769   // Verify input size match the number of arguments.
1770   if (input_names.size() != main_op.getNumArguments()) {
1771     main_op.emitWarning() << "invalid entry function specification";
1772     return {};
1773   }
1774   // Verify output size match the number of arguments.
1775   auto term = main_op.back().getTerminator();
1776   if (output_names.size() != term->getNumOperands()) {
1777     main_op.emitWarning() << "output names (" << output_names.size()
1778                           << ") != terminator operands ("
1779                           << term->getNumOperands() << ")";
1780     return {};
1781   }
1782   // Verify number of tensors for inputs and outputs matches size
1783   // of the list in the signature def.
1784   if (input_names.size() != sig_def_inputs.size() ||
1785       output_names.size() != sig_def_outputs.size()) {
1786     main_op.emitWarning(
1787         "Mismatch between signature def inputs/outputs and main function "
1788         "arguments.");
1789     return {};
1790   }
1791   // Exported method name.
1792   auto exported_name =
1793       main_op->getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
1794   if (exported_name.empty()) {
1795     main_op.emitError("Empty exported names for main Function");
1796     return {};
1797   }
1798   // Fill the SignatureDefData container.
1799   // We create vector of size 1 as TFLite now supports only 1 signatureDef.
1800   std::vector<SignatureDefData> result(1);
1801   for (int i = 0; i < input_names.size(); ++i) {
1802     result[0].inputs[sig_def_inputs[i]] = input_names[i].str();
1803   }
1804   for (int i = 0; i < output_names.size(); ++i) {
1805     // Fetch the name from the actual operand and not rely on names from
1806     // outputs as deduping can make them invalid after conversion.
1807     auto& operand = term->getOpOperand(i);
1808     auto unique_name = std::string(name_mapper.GetUniqueName(operand.get()));
1809     result[0].outputs[sig_def_outputs[i]] = unique_name;
1810   }
1811   if (auto name_attr = exported_name[0].dyn_cast_or_null<StringAttr>())
1812     result[0].signature_key = name_attr.getValue().str();
1813   result[0].subgraph_index = subgraph_index;
1814   return result;
1815 }
1816 
GetList(const int subgraph_index,const std::map<std::string,std::string> & items)1817 std::vector<BufferOffset<tflite::TensorMap>> Translator::GetList(
1818     const int subgraph_index, const std::map<std::string, std::string>& items) {
1819   std::vector<BufferOffset<tflite::TensorMap>> result;
1820   for (const auto& item : items) {
1821     auto name_buf = builder_.CreateString(item.first);
1822     tflite::TensorMapBuilder tensor_map_builder(builder_);
1823     tensor_map_builder.add_name(name_buf);
1824     tensor_map_builder.add_tensor_index(
1825         tensor_index_map_[subgraph_index][item.second]);
1826     result.push_back(tensor_map_builder.Finish());
1827   }
1828   return result;
1829 }
1830 
1831 Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
CreateSignatureDefs(const std::vector<SignatureDefData> & signature_defs)1832 Translator::CreateSignatureDefs(
1833     const std::vector<SignatureDefData>& signature_defs) {
1834   std::vector<BufferOffset<tflite::SignatureDef>> signature_defs_buffer;
1835   // When we export each function in the module op, intentionally, we export the
1836   // entry functions at the beginning of the subgraph list and the
1837   // subgraph_index is the index in entry functions and at the same, is the
1838   // index in the subgraph list.
1839   int subgraph_index = 0;
1840   for (const auto& signature_def_data : signature_defs) {
1841     auto inputs = GetList(subgraph_index, signature_def_data.inputs);
1842     auto outputs = GetList(subgraph_index, signature_def_data.outputs);
1843     auto inputs_buf = builder_.CreateVector(inputs);
1844     auto outputs_buf = builder_.CreateVector(outputs);
1845     auto signature_key_buf =
1846         builder_.CreateString(signature_def_data.signature_key);
1847     tflite::SignatureDefBuilder sig_def_builder(builder_);
1848     sig_def_builder.add_inputs(inputs_buf);
1849     sig_def_builder.add_outputs(outputs_buf);
1850     sig_def_builder.add_signature_key(signature_key_buf);
1851     sig_def_builder.add_subgraph_index(signature_def_data.subgraph_index);
1852     signature_defs_buffer.push_back(sig_def_builder.Finish());
1853     ++subgraph_index;
1854   }
1855 
1856   return builder_.CreateVector(signature_defs_buffer);
1857 }
1858 
UpdateEntryFunction(ModuleOp module)1859 bool UpdateEntryFunction(ModuleOp module) {
1860   if (module.lookupSymbol<FuncOp>("main") != nullptr) {
1861     // We already have an entry function.
1862     return true;
1863   }
1864 
1865   int entry_func_count = 0;
1866   FuncOp entry_func = nullptr;
1867   for (auto fn : module.getOps<FuncOp>()) {
1868     auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1869     if (!attrs || attrs.empty()) continue;
1870     ++entry_func_count;
1871     entry_func = fn;
1872   }
1873 
1874   // We should have at least one entry function.
1875   if (entry_func_count == 0) return false;
1876 
1877   if (entry_func_count == 1) {
1878     // Update the entry func to main when the entry func is only & one.
1879     entry_func.setName(StringAttr::get(module.getContext(), "main"));
1880   }
1881   return true;
1882 }
1883 
Translate(ModuleOp module,const toco::TocoFlags & toco_flags,const std::unordered_set<std::string> & tags,OpOrArgNameMapper * op_or_arg_name_mapper,const std::map<std::string,std::string> & metadata)1884 Optional<std::string> Translator::Translate(
1885     ModuleOp module, const toco::TocoFlags& toco_flags,
1886     const std::unordered_set<std::string>& tags,
1887     OpOrArgNameMapper* op_or_arg_name_mapper,
1888     const std::map<std::string, std::string>& metadata) {
1889   OpOrArgLocNameMapper default_op_or_arg_name_mapper;
1890   if (!op_or_arg_name_mapper)
1891     op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
1892   if (!UpdateEntryFunction(module)) return llvm::None;
1893   if (!IsValidTFLiteMlirModule(module)) return llvm::None;
1894   Translator translator(module, toco_flags, tags, op_or_arg_name_mapper,
1895                         metadata);
1896   return translator.TranslateInternal();
1897 }
1898 
CheckGpuDelegateCompatibility(uint8_t * model_buffer_pointer)1899 bool Translator::CheckGpuDelegateCompatibility(uint8_t* model_buffer_pointer) {
1900   bool gpu_compatibile = true;
1901   auto model = tflite::GetModel(model_buffer_pointer);
1902   auto subgraphs = model->subgraphs();
1903 
1904   for (int i = 0; i < subgraphs->Length(); ++i) {
1905     const tflite::SubGraph* subgraph = subgraphs->Get(i);
1906     for (int j = 0; j < subgraph->operators()->Length(); ++j) {
1907       const tflite::Operator* op = subgraph->operators()->Get(j);
1908       const tflite::OperatorCode* op_code =
1909           model->operator_codes()->Get(op->opcode_index());
1910       auto status =
1911           tflite::CheckGpuDelegateCompatibility(op_code, op, subgraph, model);
1912       if (!status.ok()) {
1913         gpu_compatibile = false;
1914         auto inst = subgraph_op_inst_map_[i][j];
1915         tfl::AttachErrorCode(
1916             inst->emitOpError()
1917                 << "is not GPU compatible: " << std::string(status.message()),
1918             tflite::metrics::ConverterErrorData::ERROR_GPU_NOT_COMPATIBLE);
1919       }
1920     }
1921   }
1922   return gpu_compatibile;
1923 }
1924 
TranslateInternal()1925 Optional<std::string> Translator::TranslateInternal() {
1926   // A list of named regions in the module with main function being the first in
1927   // the list. The main function is required as the first subgraph in the model
1928   // is entry point for the model.
1929   std::vector<std::pair<std::string, Region*>> named_regions;
1930   named_regions.reserve(std::distance(module_.begin(), module_.end()));
1931 
1932   int subgraph_idx = 0;
1933 
1934   // Entry functions for signature defs.
1935   std::vector<FuncOp> entry_functions;
1936   std::vector<FuncOp> non_entry_functions;
1937   FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
1938   if (main_fn != nullptr) {
1939     // Treat the main function as a signature def when the given main function
1940     // contains on the tf.entry_function attribute.
1941     auto attrs =
1942         main_fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1943     if (attrs && !attrs.empty()) {
1944       entry_functions.push_back(main_fn);
1945     } else {
1946       non_entry_functions.push_back(main_fn);
1947     }
1948   }
1949 
1950   // Walk over the module collection ops with functions and while ops.
1951   module_.walk([&](FuncOp fn) {
1952     if (main_fn == fn) return WalkResult::advance();
1953     auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1954     if (attrs && !attrs.empty()) {
1955       entry_functions.push_back(fn);
1956     } else {
1957       non_entry_functions.push_back(fn);
1958     }
1959     return WalkResult::advance();
1960   });
1961 
1962   // Assign the subgraph index. Among the given functions, it will put entry
1963   // functions at the beginning of the list of the subgrahs.
1964   for (auto fn : entry_functions) {
1965     subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
1966     named_regions.emplace_back(fn.getName().str(), &fn.getBody());
1967   }
1968   for (auto fn : non_entry_functions) {
1969     subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
1970     named_regions.emplace_back(fn.getName().str(), &fn.getBody());
1971   }
1972 
1973   // Build subgraph for each of the named regions.
1974   std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
1975   subgraphs.reserve(named_regions.size());
1976   model_control_dependencies_.assign(named_regions.size(), {});
1977   int first_failed_func = -1;
1978 
1979   // When we export each function in the module op, intentionally, we export the
1980   // entry functions at the beginning of the subgraph list and the
1981   // subgraph_index is the index in entry functions and at the same, is the
1982   // index in the subgraph list.
1983   int subgraph_index = 0;
1984   for (const auto& it : llvm::enumerate(named_regions)) {
1985     auto subgraph_or =
1986         BuildSubGraph(it.value().first, it.value().second, subgraph_index);
1987     if (!subgraph_or) {
1988       if (first_failed_func == -1)
1989         // Record the index of the first region that cannot be converted.
1990         // Keep looping through all subgraphs in the module to make sure that
1991         // we collect the list of missing ops from the entire module.
1992         first_failed_func = it.index();
1993     } else {
1994       subgraphs.push_back(*subgraph_or);
1995       ++subgraph_index;
1996     }
1997   }
1998 
1999   if (!resource_ops_.empty()) {
2000     std::string resource_ops_summary =
2001         GetOpsSummary(resource_ops_, /*summary_title=*/"Resource");
2002     LOG(WARNING) << "Graph contains the following resource op(s), that use(s) "
2003                     "resource type. Currently, the "
2004                     "resource type is not natively supported in TFLite. Please "
2005                     "consider not using the resource type if there are issues "
2006                     "with either TFLite converter or TFLite runtime:\n"
2007                  << resource_ops_summary;
2008   }
2009 
2010   if (!flex_ops_.empty()) {
2011     std::string flex_ops_summary =
2012         GetOpsSummary(flex_ops_, /*summary_title=*/"Flex");
2013     LOG(WARNING) << "TFLite interpreter needs to link Flex delegate in order "
2014                     "to run the model since it contains the following Select TF"
2015                     "op(s):\n"
2016                  << flex_ops_summary
2017                  << "\nSee instructions: "
2018                     "https://www.tensorflow.org/lite/guide/ops_select";
2019   }
2020 
2021   if (!custom_ops_.empty()) {
2022     std::string custom_ops_summary =
2023         GetOpsSummary(custom_ops_, /*summary_title=*/"Custom");
2024     LOG(WARNING) << "The following operation(s) need TFLite custom op "
2025                     "implementation(s):\n"
2026                  << custom_ops_summary
2027                  << "\nSee instructions: "
2028                     "https://www.tensorflow.org/lite/guide/ops_custom";
2029   }
2030 
2031   if (first_failed_func != -1) {
2032     std::string failed_flex_ops_summary =
2033         GetOpsSummary(failed_flex_ops_, /*summary_title=*/"TF Select");
2034     std::string failed_custom_ops_summary =
2035         GetOpsSummary(failed_custom_ops_, /*summary_title=*/"Custom");
2036     std::string err;
2037     if (!failed_flex_ops_.empty())
2038       err +=
2039           "\nSome ops are not supported by the native TFLite runtime, you can "
2040           "enable TF kernels fallback using TF Select. See instructions: "
2041           "https://www.tensorflow.org/lite/guide/ops_select \n" +
2042           failed_flex_ops_summary + "\n";
2043     if (!failed_custom_ops_.empty())
2044       err +=
2045           "\nSome ops in the model are custom ops, "
2046           "See instructions to implement "
2047           "custom ops: https://www.tensorflow.org/lite/guide/ops_custom \n" +
2048           failed_custom_ops_summary + "\n";
2049 
2050     auto& failed_region = named_regions[first_failed_func];
2051     return failed_region.second->getParentOp()->emitError()
2052                << "failed while converting: '" << failed_region.first
2053                << "': " << err,
2054            llvm::None;
2055   }
2056 
2057   // Log MAC count.
2058   int64_t ops_count;
2059   if (EstimateArithmeticCount(&ops_count)) {
2060     const int64_t million = 1e6;
2061     const int64_t billion = 1e9;
2062     std::string flops_str;
2063     std::string mac_str;
2064     if (ops_count < 10000) {
2065       flops_str = absl::StrFormat("%ld ", ops_count);
2066       mac_str = absl::StrFormat("%ld ", ops_count / 2);
2067     } else if (ops_count < billion) {
2068       flops_str =
2069           absl::StrFormat("%.3f M ", static_cast<double>(ops_count) / million);
2070       mac_str = absl::StrFormat("%.3f M ",
2071                                 static_cast<double>(ops_count / 2) / million);
2072     } else {
2073       flops_str =
2074           absl::StrFormat("%.3f G ", static_cast<double>(ops_count) / billion);
2075       mac_str = absl::StrFormat("%.3f G ",
2076                                 static_cast<double>(ops_count / 2) / billion);
2077     }
2078     LOG(INFO) << "Estimated count of arithmetic ops: " << flops_str
2079               << " ops, equivalently " << mac_str << " MACs";
2080   }
2081 
2082   std::string model_description;
2083   if (auto attr = module_->getAttrOfType<StringAttr>("tfl.description")) {
2084     model_description = attr.getValue().str();
2085   } else {
2086     model_description = "MLIR Converted.";
2087   }
2088 
2089   // Build the model and finish the model building process.
2090   auto description = builder_.CreateString(model_description.data());
2091   VectorBufferOffset<int32_t> metadata_buffer = 0;  // Deprecated
2092   auto metadata = CreateMetadataVector();
2093   if (!metadata) return llvm::None;
2094 
2095   std::vector<SignatureDefData> signature_defs_vec;
2096   subgraph_index = 0;
2097   // Build SignatureDefs for the tf.entry_function based func ops.
2098   for (auto fn : entry_functions) {
2099     auto signature_defs = BuildSignaturedef(
2100         fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin(),
2101         subgraph_index, name_mapper_);
2102     for (const auto& signature_def : signature_defs) {
2103       signature_defs_vec.push_back(signature_def);
2104     }
2105     // When we export each function in the module op, intentionally, we export
2106     // the entry functions at the beginning of the subgraph list and the
2107     // subgraph_index is the index in entry functions and at the same, is the
2108     // index in the subgraph list.
2109     ++subgraph_index;
2110   }
2111   auto signature_defs = CreateSignatureDefs(signature_defs_vec);
2112 
2113   auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
2114                                    builder_.CreateVector(opcodes_),
2115                                    builder_.CreateVector(subgraphs),
2116                                    description, builder_.CreateVector(buffers_),
2117                                    metadata_buffer, *metadata, *signature_defs);
2118   tflite::FinishModelBuffer(builder_, model);
2119   // There is a limit of 2GB for a flatbuffer.
2120   if (builder_.GetSize() > 2147483648) {
2121     LOG(ERROR) << "Model size is bigger than 2gb";
2122     return llvm::None;
2123   }
2124   tflite::UpdateOpVersion(builder_.GetBufferPointer());
2125   tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
2126   if (supported_backends_.find("GPU") != supported_backends_.end()) {
2127     if (!CheckGpuDelegateCompatibility(builder_.GetBufferPointer())) {
2128       return llvm::None;
2129     }
2130   }
2131 
2132   // Return serialized string for the built FlatBuffer.
2133   return std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),
2134                      builder_.GetSize());
2135 }
2136 
BuildSparsityParameters(const mlir::TFL::SparsityParameterAttr & s_attr)2137 BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
2138     const mlir::TFL::SparsityParameterAttr& s_attr) {
2139   const int dim_size = s_attr.getDimMetadata().size();
2140   std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> fb_dim_metadata(
2141       dim_size);
2142   for (int i = 0; i < dim_size; i++) {
2143     const auto dim_metadata =
2144         s_attr.getDimMetadata()[i].dyn_cast<mlir::TFL::DimensionMetadataAttr>();
2145     if (dim_metadata.getFormat().getValue() ==
2146         mlir::TFL::DimensionType::DENSE) {
2147       fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
2148           builder_, tflite::DimensionType_DENSE, dim_metadata.getDenseSize());
2149 
2150     } else {
2151       auto segments = dim_metadata.getSegments();
2152       std::vector<int> vector_segments(segments.size(), 0);
2153       for (int j = 0, end = segments.size(); j < end; j++) {
2154         vector_segments[j] = segments[j];
2155       }
2156       tflite::SparseIndexVector segments_type;
2157       BufferOffset<void> array_segments;
2158       // The segment array is sorted.
2159       // TODO(b/147449640): Clean this up with util functions.
2160       int max_of_segments = vector_segments[segments.size() - 1];
2161       if (max_of_segments <= UINT8_MAX) {
2162         segments_type = tflite::SparseIndexVector_Uint8Vector;
2163         std::vector<uint8_t> uint8_vector(vector_segments.begin(),
2164                                           vector_segments.end());
2165         array_segments = tflite::CreateUint8Vector(
2166                              builder_, builder_.CreateVector(uint8_vector))
2167                              .Union();
2168       } else if (max_of_segments <= UINT16_MAX) {
2169         segments_type = tflite::SparseIndexVector_Uint16Vector;
2170         std::vector<uint16_t> uint16_vector(vector_segments.begin(),
2171                                             vector_segments.end());
2172         array_segments = tflite::CreateUint16Vector(
2173                              builder_, builder_.CreateVector(uint16_vector))
2174                              .Union();
2175       } else {
2176         segments_type = tflite::SparseIndexVector_Int32Vector;
2177         array_segments = tflite::CreateInt32Vector(
2178                              builder_, builder_.CreateVector(vector_segments))
2179                              .Union();
2180       }
2181 
2182       auto indices = dim_metadata.getIndices();
2183       std::vector<int> vector_indices(indices.size(), 0);
2184       int max_of_indices = 0;
2185       for (int j = 0, end = indices.size(); j < end; j++) {
2186         vector_indices[j] = indices[j];
2187         if (vector_indices[j] > max_of_indices) {
2188           max_of_indices = vector_indices[j];
2189         }
2190       }
2191       tflite::SparseIndexVector indices_type;
2192       BufferOffset<void> array_indices;
2193       if (max_of_indices <= UINT8_MAX) {
2194         indices_type = tflite::SparseIndexVector_Uint8Vector;
2195         std::vector<uint8_t> uint8_vector(vector_indices.begin(),
2196                                           vector_indices.end());
2197         array_indices = tflite::CreateUint8Vector(
2198                             builder_, builder_.CreateVector(uint8_vector))
2199                             .Union();
2200       } else if (max_of_indices <= UINT16_MAX) {
2201         indices_type = tflite::SparseIndexVector_Uint16Vector;
2202         std::vector<uint16_t> uint16_vector(vector_indices.begin(),
2203                                             vector_indices.end());
2204         array_indices = tflite::CreateUint16Vector(
2205                             builder_, builder_.CreateVector(uint16_vector))
2206                             .Union();
2207       } else {
2208         indices_type = tflite::SparseIndexVector_Int32Vector;
2209         array_indices = tflite::CreateInt32Vector(
2210                             builder_, builder_.CreateVector(vector_indices))
2211                             .Union();
2212       }
2213 
2214       fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
2215           builder_, tflite::DimensionType_SPARSE_CSR, 0, segments_type,
2216           array_segments, indices_type, array_indices);
2217     }
2218   }
2219 
2220   std::vector<int> traversal_order(dim_size);
2221   for (int i = 0; i < dim_size; i++) {
2222     traversal_order[i] = s_attr.getTraversalOrder()[i];
2223   }
2224   const int block_map_size = s_attr.getBlockMap().size();
2225   std::vector<int> block_map(block_map_size);
2226   for (int i = 0; i < block_map_size; i++) {
2227     block_map[i] = s_attr.getBlockMap()[i];
2228   }
2229 
2230   return tflite::CreateSparsityParameters(
2231       builder_, builder_.CreateVector(traversal_order),
2232       builder_.CreateVector(block_map), builder_.CreateVector(fb_dim_metadata));
2233 }
2234 
// Strips TFL::ControlNodeOp wrappers from `block` and returns the control
// dependencies they encoded, as (source, target) pairs of operation offsets
// relative to block->begin(). Each wrapper is replaced in place by a clone of
// the single operation it wraps, so the block carries no control tokens after
// this returns.
std::vector<std::pair<int, int>> Translator::ExtractControlEdges(
    mlir::Block* block) {
  std::vector<std::pair<int, int>> control_edges;

  mlir::IRRewriter rewriter(block->getParentOp()->getContext());

  // Since we're modifying *block, we store integer offsets to block->begin().
  llvm::DenseMap<Operation*, int> control_nodes_at;
  std::vector<Operation*> control_nodes;
  for (const auto& item : llvm::enumerate(*block)) {
    if (llvm::isa<mlir::TFL::ControlNodeOp>(item.value())) {
      control_nodes.push_back(&item.value());
      control_nodes_at.try_emplace(&item.value(), item.index());
    }
  }

  for (auto outer_op : control_nodes) {
    auto control_node_op = dyn_cast<mlir::TFL::ControlNodeOp>(outer_op);
    // The wrapper's single-block body holds the one operation it wraps.
    auto* inner_op = &control_node_op.body().front().front();
    auto control_token = control_node_op.control();

    // Now go through all uses. Since *block is in executable order, control
    // edges always point to operations we haven't modified yet.
    for (auto& use : control_token.getUses()) {
      auto owner = use.getOwner();
      // Control tokens can only be consumed by other ControlNodeOps,
      assert(llvm::isa<mlir::TFL::ControlNodeOp>(owner));
      assert(control_nodes_at.find(owner) != control_nodes_at.end());
      // Control edge in terms of offsets.
      control_edges.emplace_back(control_nodes_at[outer_op],
                                 control_nodes_at[owner]);
    }
    // Detach the token before erasing the wrapper so no dangling uses remain.
    control_token.dropAllUses();

    // Replace the ControlNodeOp with the wrapped operation.
    rewriter.setInsertionPointAfter(outer_op);
    auto* cloned_inner = rewriter.clone(*inner_op);
    // Reroute every data result of the wrapper to the clone's results.
    for (auto it :
         llvm::zip(control_node_op.outputs(), cloned_inner->getResults())) {
      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
    }
    rewriter.eraseOp(outer_op);
  }
  return control_edges;
}
2280 
2281 }  // namespace
2282 
2283 namespace tflite {
2284 
MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,const FlatbufferExportOptions & options,std::string * serialized_flatbuffer)2285 bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
2286                                        const FlatbufferExportOptions& options,
2287                                        std::string* serialized_flatbuffer) {
2288   auto maybe_translated = Translator::Translate(
2289       module, options.toco_flags, options.saved_model_tags,
2290       options.op_or_arg_name_mapper, options.metadata);
2291   if (!maybe_translated) return false;
2292   *serialized_flatbuffer = std::move(*maybe_translated);
2293   return true;
2294 }
2295 
2296 }  // namespace tflite
2297