/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/import_tensorflow.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/map.h"
#include "google/protobuf/text_format.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
#include "tensorflow/lite/toco/tensorflow_util.h"
#include "tensorflow/lite/toco/tooling_util.h"

using tensorflow::AttrValue;
using tensorflow::DT_BOOL;
using tensorflow::DT_COMPLEX64;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT16;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_QUINT8;
using tensorflow::DT_STRING;
using tensorflow::DT_UINT16;
using tensorflow::DT_UINT32;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;

namespace toco {

namespace {
bool HasAttr(const NodeDef& node, const std::string& attr_name) {
  return node.attr().count(attr_name) > 0;
}

bool HasWildcardDimension(const TensorShapeProto& shape) {
  for (const auto& dim : shape.dim()) {
    if (dim.size() == -1) return true;
  }
  return false;
}
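
// Example: a placeholder shaped [-1, 224, 224, 3] has an unknown (wildcard)
// batch dimension encoded as -1, so HasWildcardDimension() returns true for
// its TensorShapeProto.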

const std::string& GetStringAttr(const NodeDef& node,
                                 const std::string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kS);
  return attr.s();
}

int64_t GetIntAttr(const NodeDef& node, const std::string& attr_name) {
  CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
                                  << node.DebugString();
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kI);
  return attr.i();
}

float GetFloatAttr(const NodeDef& node, const std::string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kF);
  return attr.f();
}

bool GetBoolAttr(const NodeDef& node, const std::string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kB);
  return attr.b();
}

tensorflow::DataType GetDataTypeAttr(const NodeDef& node,
                                     const std::string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kType);
  return attr.type();
}

const TensorShapeProto& GetShapeAttr(const NodeDef& node,
                                     const std::string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kShape);
  return attr.shape();
}

const TensorProto& GetTensorAttr(const NodeDef& node,
                                 const std::string& attr_name) {
  CHECK(HasAttr(node, attr_name)) << "No attr named '" << attr_name << "'";
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kTensor);
  return attr.tensor();
}

const AttrValue::ListValue& GetListAttr(const NodeDef& node,
                                        const std::string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kList);
  return attr.list();
}

tensorflow::Status CheckOptionalAttr(const NodeDef& node,
                                     const std::string& attr_name,
                                     const std::string& expected_value) {
  if (HasAttr(node, attr_name)) {
    const std::string& value = GetStringAttr(node, attr_name);
    if (value != expected_value) {
      return tensorflow::errors::InvalidArgument(
          "Unexpected value for attribute '" + attr_name + "'. Expected '" +
          expected_value + "'");
    }
  }
  return ::tensorflow::OkStatus();
}

tensorflow::Status CheckOptionalAttr(
    const NodeDef& node, const std::string& attr_name,
    const tensorflow::DataType& expected_value) {
  if (HasAttr(node, attr_name)) {
    const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name);
    if (value != expected_value) {
      return tensorflow::errors::InvalidArgument(
          "Unexpected value for attribute '" + attr_name + "'. Expected '" +
          tensorflow::DataType_Name(expected_value) + "'");
    }
  }
  return ::tensorflow::OkStatus();
}

template <typename T1, typename T2>
tensorflow::Status ExpectValue(const T1& v1, const T2& v2,
                               const std::string& description) {
  if (v1 == v2) return ::tensorflow::OkStatus();
  return tensorflow::errors::InvalidArgument(absl::StrCat(
      "Unexpected ", description, ": got ", v1, ", expected ", v2));
}

ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
  if (dtype == DT_UINT8)
    return ArrayDataType::kUint8;
  else if (dtype == DT_FLOAT)
    return ArrayDataType::kFloat;
  else if (dtype == DT_BOOL)
    return ArrayDataType::kBool;
  else if (dtype == DT_INT16)
    return ArrayDataType::kInt16;
  else if (dtype == DT_UINT16)
    return ArrayDataType::kUint16;
  else if (dtype == DT_INT32)
    return ArrayDataType::kInt32;
  else if (dtype == DT_UINT32)
    return ArrayDataType::kUint32;
  else if (dtype == DT_INT64)
    return ArrayDataType::kInt64;
  else if (dtype == DT_STRING)
    return ArrayDataType::kString;
  else if (dtype == DT_COMPLEX64)
    return ArrayDataType::kComplex64;
  else
    LOG(INFO) << "Unsupported data type in placeholder op: " << dtype;
  return ArrayDataType::kNone;
}

tensorflow::Status ImportShape(
    const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim>&
        input_dims,
    int* input_flat_size, Shape* shape) {
  std::vector<int> input_dims_only_sizes;
  bool zero_sized_shape = false;
  for (auto& d : input_dims) {
    // TensorFlow's shapes use int64s, while TOCO uses ints.
    if (d.size() > std::numeric_limits<int>::max()) {
      return tensorflow::errors::InvalidArgument("Shape element overflows");
    }
    if (d.size() == 0) {
      zero_sized_shape = true;
    }
    input_dims_only_sizes.push_back(d.size());
  }

  // Note that up to this point we were OK with the input shape containing
  // elements valued -1 or 0, which are perfectly legal in tensorflow. However
  // our CheckValidShapeDimensions() insists on them being >= 1, with the
  // exception of the "scalar" shape [0]. The main issue with zero-valued shape
  // elements is that the corresponding arrays don't contain any data and the
  // allocation code gets a bit confused. It seems that the code expects an
  // empty shape for zero-sized shapes, so we will do just that, except for the
  // [0] case.
  // TODO(b/119325030): In order to correctly import the "scalar" shapes the
  // following test must include "&& input_dims_only_sizes.size() > 1", but
  // that seems to slow everything down a lot.
  if (zero_sized_shape) {
    shape->mutable_dims()->clear();
    if (input_flat_size != nullptr) *input_flat_size = 0;
    return ::tensorflow::OkStatus();
  }

  *shape->mutable_dims() = input_dims_only_sizes;

  if (input_flat_size == nullptr) return ::tensorflow::OkStatus();

  return NumElements(input_dims_only_sizes, input_flat_size);
}
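
// Worked example: a TensorShapeProto with dims {2, 3} yields dims [2, 3] and
// *input_flat_size == 6. Any zero-sized dim (e.g. {0} or {4, 0}) yields an
// empty shape and a flat size of 0, as explained in the note above.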

// Define ways to retrieve data from tensors of different types.
// TODO(b/80208043): simply use tensorflow::Tensor::FromProto() instead.
template <typename T>
struct TensorTraits;

template <>
struct TensorTraits<float> {
  static int size(const TensorProto& p) { return p.float_val_size(); }
  static float get(const TensorProto& p, int i) { return p.float_val(i); }
  static std::string accessor_name() { return "float_val"; }
  static std::string type_name() { return "float"; }
  static void CopyFromContent(const TensorProto& p, std::vector<float>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<uint8_t> {
  static int size(const TensorProto& p) { return p.int_val_size(); }
  static uint8_t get(const TensorProto& p, int i) { return p.int_val(i); }
  static std::string accessor_name() { return "int_val"; }
  static std::string type_name() { return "uint8"; }
  static void CopyFromContent(const TensorProto& p,
                              std::vector<uint8_t>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<std::complex<float>> {
  static int size(const TensorProto& p) { return p.scomplex_val_size() / 2; }
  static std::complex<float> get(const TensorProto& p, int i) {
    return std::complex<float>(p.scomplex_val(2 * i),
                               p.scomplex_val(2 * i + 1));
  }
  static std::string accessor_name() { return "scomplex_val"; }
  static std::string type_name() { return "complex64"; }
  static void CopyFromContent(const TensorProto& p,
                              std::vector<std::complex<float>>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<int32> {
  static int size(const TensorProto& p) { return p.int_val_size(); }
  static int32 get(const TensorProto& p, int i) { return p.int_val(i); }
  static std::string accessor_name() { return "int_val"; }
  static std::string type_name() { return "int32"; }
  static void CopyFromContent(const TensorProto& p, std::vector<int32>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<uint32> {
  static int size(const TensorProto& p) { return p.uint32_val_size(); }
  static uint32 get(const TensorProto& p, int i) { return p.uint32_val(i); }
  static std::string accessor_name() { return "uint32_val"; }
  static std::string type_name() { return "uint32"; }
  static void CopyFromContent(const TensorProto& p, std::vector<uint32>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<int64_t> {
  static int size(const TensorProto& p) { return p.int64_val_size(); }
  static int64_t get(const TensorProto& p, int i) { return p.int64_val(i); }
  static std::string accessor_name() { return "int64_val"; }
  static std::string type_name() { return "int64"; }
  static void CopyFromContent(const TensorProto& p,
                              std::vector<int64_t>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<bool> {
  static int size(const TensorProto& p) { return p.bool_val_size(); }
  static bool get(const TensorProto& p, int i) { return p.bool_val(i); }
  static std::string accessor_name() { return "bool_val"; }
  static std::string type_name() { return "bool"; }
  static void CopyFromContent(const TensorProto& p, std::vector<bool>* data) {
    std::vector<char> buf(p.tensor_content().size());
    toco::port::CopyToBuffer(p.tensor_content(), buf.data());
    for (int i = 0; i < p.tensor_content().size(); i++) {
      (*data)[i] = static_cast<bool>(buf[i]);
    }
  }
};
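
// The traits above give ImportTensorData() below a uniform view over the
// three encodings a TensorProto may use for constant data: a typed repeated
// field (e.g. float_val), a packed tensor_content byte string, and a
// truncated repeated field whose trailing values are implicitly repeated.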

template <typename T>
tensorflow::Status ImportTensorData(const TensorProto& input_tensor,
                                    int input_flat_size,
                                    std::vector<T>* output_data) {
  CHECK_GE(output_data->size(), input_flat_size);
  int num_elements_in_tensor = TensorTraits<T>::size(input_tensor);
  if (num_elements_in_tensor == input_flat_size) {
    for (int i = 0; i < num_elements_in_tensor; i++) {
      (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
    }
  } else if (input_tensor.tensor_content().size() ==
             input_flat_size * sizeof(T)) {
    TensorTraits<T>::CopyFromContent(input_tensor, output_data);
  } else if (num_elements_in_tensor >= 0 &&
             num_elements_in_tensor < input_flat_size) {
    // TODO(b/80208043): use tensorflow::Tensor::FromProto() which is the
    // official way to import tensor data. This particular else-if handles a
    // grappler optimization where the last few elements in a tensor are
    // omitted if they are repeated, and where all elements are omitted if they
    // are zero.
    int i = 0;
    for (; i < num_elements_in_tensor; ++i) {
      (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
    }
    auto last = i == 0 ? T(0) : (*output_data)[i - 1];
    for (; i < input_flat_size; ++i) {
      (*output_data)[i] = last;
    }
  } else {
    std::string accessor_name = TensorTraits<T>::accessor_name();
    std::string type_name = TensorTraits<T>::type_name();
    return tensorflow::errors::InvalidArgument(
        absl::StrCat("Neither input_content (",
                     input_tensor.tensor_content().size() / sizeof(T), ") nor ",
                     accessor_name, " (", num_elements_in_tensor,
                     ") have the right dimensions (", input_flat_size,
                     ") for this ", type_name, " tensor"));
  }
  return ::tensorflow::OkStatus();
}
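
// Example of the truncated-tensor branch above: an int32 tensor with
// input_flat_size == 5 and int_val == {7, 7} is expanded to {7, 7, 7, 7, 7},
// and one with an empty int_val is filled with zeros.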

tensorflow::Status ImportFloatArray(const TensorProto& input_tensor,
                                    Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_float_data =
      output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
  output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()),
                           0.f);
  return ImportTensorData<float>(input_tensor, input_flat_size,
                                 &output_float_data);
}

tensorflow::Status ImportComplex64Array(const TensorProto& input_tensor,
                                        Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_COMPLEX64);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 4);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_complex_data =
      output_array->GetMutableBuffer<ArrayDataType::kComplex64>().data;
  output_complex_data.resize(RequiredBufferSizeForShape(output_array->shape()),
                             std::complex<float>(0.f, 0.f));
  return ImportTensorData<std::complex<float>>(input_tensor, input_flat_size,
                                               &output_complex_data);
}

tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
                                     Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_int_data =
      output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
  output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
  return ImportTensorData<uint8_t>(input_tensor, input_flat_size,
                                   &output_int_data);
}

tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
                                    Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_INT32);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_int_data =
      output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
  output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
  return ImportTensorData<int32>(input_tensor, input_flat_size,
                                 &output_int_data);
}

tensorflow::Status ImportUint32Array(const TensorProto& input_tensor,
                                     Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_UINT32);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_int_data =
      output_array->GetMutableBuffer<ArrayDataType::kUint32>().data;
  output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
  return ImportTensorData<uint32>(input_tensor, input_flat_size,
                                  &output_int_data);
}

tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
                                    Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_INT64);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_int_data =
      output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
  output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
  return ImportTensorData<int64_t>(input_tensor, input_flat_size,
                                   &output_int_data);
}

tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
                                   Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_BOOL);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_bool_data =
      output_array->GetMutableBuffer<ArrayDataType::kBool>().data;
  output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
                          false);
  status =
      ImportTensorData<bool>(input_tensor, input_flat_size, &output_bool_data);
  if (!status.ok() && output_bool_data.size() == 1) {
    // Some graphs have bool Const nodes without an actual value, where
    // 'false' appears to be implied. So far we have only encountered this
    // with single-element arrays, so require that until we see a graph
    // where it doesn't hold.
    output_bool_data[0] = false;
    return ::tensorflow::OkStatus();
  }
  return status;
}

tensorflow::Status ImportStringArray(const TensorProto& input_tensor,
                                     Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_STRING);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  if (input_flat_size != input_tensor.string_val_size()) {
    return tensorflow::errors::InvalidArgument(
        "Input_content string_val doesn't have the right dimensions "
        "for this string tensor");
  }

  auto& output_string_data =
      output_array->GetMutableBuffer<ArrayDataType::kString>().data;
  output_string_data.resize(RequiredBufferSizeForShape(output_array->shape()));
  CHECK_GE(output_string_data.size(), input_flat_size);
  for (int i = 0; i < input_flat_size; ++i) {
    output_string_data[i] = input_tensor.string_val(i);
  }
  return ::tensorflow::OkStatus();
}

// Count the number of inputs of a given node. If
// `tf_import_flags.drop_control_dependency` is true, count the number of
// non-control-dependency inputs.
int GetInputsCount(const NodeDef& node,
                   const TensorFlowImportFlags& tf_import_flags) {
  if (tf_import_flags.drop_control_dependency) {
    for (size_t i = 0; i < node.input_size(); ++i) {
      if (node.input(i)[0] == '^') {
        return i;
      }
    }
  }
  return node.input_size();
}
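
// Example: control-dependency inputs carry a leading '^' and, by TensorFlow
// convention, follow all regular inputs. A node with inputs
// {"a", "b", "^init"} therefore counts 2 inputs when drop_control_dependency
// is set, and 3 otherwise.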

tensorflow::Status CheckInputsCount(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    int expected_input_count) {
  if (GetInputsCount(node, tf_import_flags) != expected_input_count) {
    return tensorflow::errors::FailedPrecondition(
        node.op(), " node expects ", expected_input_count,
        " input(s) other than control dependencies: ", node.DebugString());
  }
  return ::tensorflow::OkStatus();
}

// Utility function to create a const 1D array, useful for input parameters.
template <ArrayDataType T>
std::string CreateConstArray(
    Model* model, std::string const& name,
    std::vector<typename toco::DataType<T>> const& data) {
  std::string array_name = toco::AvailableArrayName(*model, name);
  auto& array = model->GetOrCreateArray(array_name);
  array.data_type = T;
  array.mutable_shape()->mutable_dims()->emplace_back(
      static_cast<int>(data.size()));
  array.GetMutableBuffer<T>().data = data;
  return array_name;
}
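
// Illustrative use (hypothetical names and values): a converter that needs a
// constant axis parameter could write
//   std::string axis = CreateConstArray<ArrayDataType::kInt32>(
//       model, node.name() + "_axis", {0});
// which yields a uniquely named 1-D int32 const array holding {0}.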

// Retain TensorFlow NodeDef in Toco Operator.
//
// If an op is supported by Toco but not supported by TFLite, TFLite exporter
// will use the retained NodeDef to populate a Flex op when Flex mode is
// enabled.
//
// This can't be easily applied to all operations, because a TensorFlow node
// may become multiple Toco operators. Thus we need to call this function in
// operator conversion functions one by one whenever feasible.
//
// This may cause problems if a graph transformation rule changes parameters
// of the node. When calling this function, please check if any existing
// graph transformation rule will change an existing operator with the same
// type.
//
// This provides a route to handle Toco-supported & TFLite-unsupported ops
// in Flex mode. However it's not a solid solution. Eventually we should
// get rid of this.
// TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove
// this function.
void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
  node.SerializeToString(&op->tensorflow_node_def);
}

void GetOutputNamesFromNodeDef(const NodeDef& node,
                               const tensorflow::OpDef& op_def,
                               TensorFlowUnsupportedOperator* op) {
  int next_output = 0;
  auto add_output = [&node, &next_output, op]() {
    if (next_output == 0) {
      op->outputs.push_back(node.name());  // Implicit :0.
    } else {
      op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
    }
    ++next_output;
  };
  for (int i = 0; i < op_def.output_arg_size(); ++i) {
    std::string multiples = op_def.output_arg(i).number_attr();
    if (!multiples.empty()) {
      CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
      int num_outputs = GetIntAttr(node, multiples);
      for (int j = 0; j < num_outputs; ++j) {
        add_output();
      }
    } else {
      std::string list = op_def.output_arg(i).type_list_attr();
      if (!list.empty()) {
        CHECK(HasAttr(node, list)) << "No attr named " << list;
        const AttrValue::ListValue& list_value = GetListAttr(node, list);
        for (int j = 0; j < list_value.type_size(); ++j) {
          add_output();
        }
      } else {
        add_output();
      }
    }
  }
}
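
// Example: an OpDef whose single output arg has number_attr "num_split"
// produces, for a node "s" with num_split == 3, the output names
// "s", "s:1", "s:2" (the first output's ":0" suffix is implicit).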

void GetOutputTypesFromNodeDef(const NodeDef& node,
                               const tensorflow::OpDef& op_def,
                               TensorFlowUnsupportedOperator* op) {
  // Add the given type to the op, or clear the op's output types if invalid.
  auto add_type = [&node, op](tensorflow::DataType type) {
    if (type == tensorflow::DT_INVALID) {
      LOG(WARNING) << "Op node missing output type attribute: " << node.name();
      op->output_data_types.clear();
    } else {
      op->output_data_types.push_back(ConvertDataType(type));
    }
  };

  // Retrieve the data type according to the OpDef definition: either the
  // "type" or "type_attr" field will be set.
  auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) {
    if (a.type() != tensorflow::DT_INVALID) {
      return a.type();
    } else if (HasAttr(node, a.type_attr())) {
      return GetDataTypeAttr(node, a.type_attr());
    } else {
      return tensorflow::DT_INVALID;
    }
  };

  for (int i = 0; i < op_def.output_arg_size(); ++i) {
    std::string multiples = op_def.output_arg(i).number_attr();
    if (!multiples.empty()) {
      CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
      int num_outputs = GetIntAttr(node, multiples);
      auto type = get_type(op_def.output_arg(i));
      for (int j = 0; j < num_outputs; ++j) {
        add_type(type);
      }
    } else {
      std::string list = op_def.output_arg(i).type_list_attr();
      if (!list.empty()) {
        CHECK(HasAttr(node, list)) << "No attr named " << list;
        const AttrValue::ListValue& list_value = GetListAttr(node, list);
        for (int j = 0; j < list_value.type_size(); ++j) {
          add_type(list_value.type(j));
        }
      } else {
        add_type(get_type(op_def.output_arg(i)));
      }
    }
  }
}
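
// Example: for an OpDef output arg declared as "output: T", get_type() reads
// the node's "T" attr; for "output: float" it returns DT_FLOAT directly. If
// neither resolves, add_type() sees DT_INVALID and clears all collected
// output types.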

tensorflow::Status ConvertUnsupportedOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  // Names of special attributes in TF graph that are used by Toco.
  static constexpr char kAttrOutputQuantized[] = "_output_quantized";
  static constexpr char kAttrOutputTypes[] = "_output_types";
  static constexpr char kAttrOutputShapes[] = "_output_shapes";
  static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
      "_support_output_type_float_in_quantized_op";

  LOG(INFO) << "Converting unsupported operation: " << node.op();

  auto* op = new TensorFlowUnsupportedOperator;
  op->tensorflow_op = node.op();

  // For Flex mode. Please read the comments of the function.
  RetainTensorFlowNodeDef(node, op);

  model->operators.emplace_back(op);

  // Parse inputs.
  const int num_inputs = GetInputsCount(node, tf_import_flags);
  for (int i = 0; i < num_inputs; ++i) {
    op->inputs.push_back(node.input(i));
  }

  // Parse outputs. Name them after the node's name, plus an ordinal suffix.
  // Note that some outputs are to be multiplied by a named attribute.
  const tensorflow::OpDef* op_def = nullptr;
  if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
    GetOutputNamesFromNodeDef(node, *op_def, op);
  } else {
    op->outputs.push_back(node.name());  // Implicit :0.
  }

  // Parse if the op supports quantization
  if (HasAttr(node, kAttrOutputQuantized)) {
    op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
  }
  // Parse if the quantized op allows output arrays of type float
  if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
    op->support_output_type_float_in_quantized_op =
        GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
  }

  // Parse output type(s).
  if (HasAttr(node, kAttrOutputTypes)) {
    const auto& output_types = GetListAttr(node, kAttrOutputTypes);
    for (int i = 0; i < output_types.type_size(); ++i) {
      op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
    }
  } else if (HasAttr(node, "Tout")) {
    const auto& output_type = GetDataTypeAttr(node, "Tout");
    op->output_data_types.push_back(ConvertDataType(output_type));
  } else if (op_def != nullptr) {
    GetOutputTypesFromNodeDef(node, *op_def, op);
  } else {
    // TODO(b/113613439): Figure out how to propagate types for custom ops
    // that have no OpDef.
    LOG(INFO) << "Unable to determine output type for op: " << node.op();
  }

  // Parse output shape(s).
  if (HasAttr(node, kAttrOutputShapes)) {
    const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
    Shape output_shape;
    for (int i = 0; i < output_shapes.shape_size(); ++i) {
      const auto& shape = output_shapes.shape(i);
      // TOCO doesn't yet properly handle shapes with wildcard dimensions.
      // TODO(b/113613439): Handle shape inference for unsupported ops that
      // have shapes with wildcard dimensions.
      if (HasWildcardDimension(shape)) {
        LOG(INFO) << "Skipping wildcard output shape(s) for node: "
                  << node.name();
        op->output_shapes.clear();
        break;
      }
      const auto status =
          ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
      if (!status.ok()) {
        return status;
      }
      op->output_shapes.push_back(output_shape);
    }
  }
  return ::tensorflow::OkStatus();
}

tensorflow::Status ConvertConstOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "Const");
  const auto& tensor = GetTensorAttr(node, "value");
  const auto dtype = GetDataTypeAttr(node, "dtype");

  tensorflow::Status status = ::tensorflow::OkStatus();

  auto& array = model->GetOrCreateArray(node.name());
  switch (dtype) {
    case DT_FLOAT:
      array.data_type = ArrayDataType::kFloat;
      status = ImportFloatArray(tensor, &array);
      break;
    case DT_INT32:
      array.data_type = ArrayDataType::kInt32;
      status = ImportInt32Array(tensor, &array);
      break;
    case DT_UINT32:
      array.data_type = ArrayDataType::kUint32;
      status = ImportUint32Array(tensor, &array);
      break;
    case DT_QUINT8:
      array.data_type = ArrayDataType::kUint8;
      status = ImportQuint8Array(tensor, &array);
      break;
    case DT_INT64:
      array.data_type = ArrayDataType::kInt64;
      status = ImportInt64Array(tensor, &array);
      break;
    case DT_STRING:
      array.data_type = ArrayDataType::kString;
      status = ImportStringArray(tensor, &array);
      break;
    case DT_BOOL:
      array.data_type = ArrayDataType::kBool;
      status = ImportBoolArray(tensor, &array);
      break;
    case DT_COMPLEX64:
      array.data_type = ArrayDataType::kComplex64;
      status = ImportComplex64Array(tensor, &array);
      break;
    default:
      array.data_type = ArrayDataType::kNone;
      // do nothing, silently ignore the Const data. We just make a dummy
      // buffer to indicate that this array does not rely on external input.
      array.GetMutableBuffer<ArrayDataType::kNone>();
      break;
  }
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      status, " (while processing node '" + node.name() + "')");
  return ::tensorflow::OkStatus();
}

tensorflow::Status ConvertConvOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "Conv2D");
  TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2));

  // We only support NHWC, which is the default data_format.
  // So if data_format is not defined, we're all good.
  TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
  TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));

  const auto& input_name = node.input(0);
  const auto& weights_name = node.input(1);
  const auto& reordered_weights_name =
      AvailableArrayName(*model, weights_name + "_reordered");
  // Check if a ReorderAxesOperator was already created for these weights
  // (that happens when multiple layers share the same weights).
  const Operator* existing_reorder =
      GetOpWithOutput(*model, reordered_weights_name);
  if (existing_reorder) {
    // Check that it is safe to rely on the _reordered naming of the output
    // array!
    CHECK(existing_reorder->type == OperatorType::kReorderAxes);
  } else {
    // Create a new ReorderAxesOperator.
    auto* reorder = new ReorderAxesOperator;
    reorder->inputs = {weights_name};
    reorder->outputs = {reordered_weights_name};
    reorder->input_axes_order = AxesOrder::kHWIO;
    reorder->output_axes_order = AxesOrder::kOHWI;
    model->operators.emplace_back(reorder);
  }
  if (!HasAttr(node, "strides")) {
    return tensorflow::errors::InvalidArgument("Missing attribute 'strides'");
  }
  const auto& strides = GetListAttr(node, "strides");
  TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
  TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
  TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
  int dilation_height_factor;
  int dilation_width_factor;
  if (HasAttr(node, "dilations")) {
    const auto& dilations = GetListAttr(node, "dilations");
    TF_RETURN_IF_ERROR(
        ExpectValue(dilations.i_size(), 4, "number of dilations"));
    if (dilations.i(0) != 1 || dilations.i(3) != 1) {
      return tensorflow::errors::InvalidArgument(absl::StrCat(
          "Can only import Conv ops with dilation along the height "
          "(1st) or width (2nd) axis. TensorFlow op \"",
          node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
          dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
    }
    dilation_height_factor = dilations.i(1);
    dilation_width_factor = dilations.i(2);
  } else {
    dilation_height_factor = 1;
    dilation_width_factor = 1;
  }
  const auto& padding = GetStringAttr(node, "padding");
  PaddingType padding_type;
  if (padding == "SAME") {
    padding_type = PaddingType::kSame;
  } else if (padding == "VALID") {
    padding_type = PaddingType::kValid;
  } else {
    return tensorflow::errors::InvalidArgument(
        "Bad padding (only SAME and VALID are supported)");
  }
  auto* conv = new ConvOperator;
  conv->inputs = {input_name, reordered_weights_name};
  conv->outputs = {node.name()};
  conv->stride_height = strides.i(1);
  conv->stride_width = strides.i(2);
  conv->dilation_height_factor = dilation_height_factor;
  conv->dilation_width_factor = dilation_width_factor;
  conv->padding.type = padding_type;
  model->operators.emplace_back(conv);

  return ::tensorflow::OkStatus();
}
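
// Note on the reordering above: TensorFlow stores Conv2D filters as
// [height, width, in_channels, out_channels] (HWIO), while TOCO's Conv
// expects [out_channels, height, width, in_channels] (OHWI), so a
// ReorderAxesOperator is inserted in front of the weights.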

tensorflow::Status ConvertDepthwiseConvOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "DepthwiseConv2dNative");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));

  // We only support NHWC, which is the default data_format.
  // So if data_format is not defined, we're all good.
  if (HasAttr(node, "data_format")) {
    CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
  }
  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);

  const auto& input_name = node.input(0);
  const auto& weights_name = node.input(1);
  const auto& reordered_weights_name = weights_name + "_reordered";
  // Check if a ReorderAxesOperator was already created for these weights
  // (that happens when multiple layers share the same weights).
  const Operator* existing_reorder =
      GetOpWithOutput(*model, reordered_weights_name);
  if (existing_reorder) {
    // Check that it is safe to rely on the _reordered naming of the output
    // array!
    CHECK(existing_reorder->type == OperatorType::kReorderAxes);
  } else {
    // Create a new ReorderAxesOperator.
    auto* reorder = new ReorderAxesOperator;
    reorder->inputs = {weights_name};
    reorder->outputs = {reordered_weights_name};
    reorder->input_axes_order = AxesOrder::kHWIM;
    reorder->output_axes_order = AxesOrder::k1HWO;
    model->operators.emplace_back(reorder);
  }
  const auto& strides = GetListAttr(node, "strides");
  TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
  TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
  TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
  int dilation_height_factor;
  int dilation_width_factor;
  if (HasAttr(node, "dilations")) {
    const auto& dilations = GetListAttr(node, "dilations");
    TF_RETURN_IF_ERROR(
        ExpectValue(dilations.i_size(), 4, "number of dilations"));
    if (dilations.i(0) != 1 || dilations.i(3) != 1) {
      return tensorflow::errors::InvalidArgument(absl::StrCat(
          "Can only import Conv ops with dilation along the height "
          "(1st) or width (2nd) axis. TensorFlow op \"",
          node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
          dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
    }
    dilation_height_factor = dilations.i(1);
    dilation_width_factor = dilations.i(2);
  } else {
    dilation_height_factor = 1;
    dilation_width_factor = 1;
  }
  const auto& padding = GetStringAttr(node, "padding");
  PaddingType padding_type;
  if (padding == "SAME") {
    padding_type = PaddingType::kSame;
  } else if (padding == "VALID") {
    padding_type = PaddingType::kValid;
  } else {
    return tensorflow::errors::InvalidArgument(
        "Bad padding (only SAME and VALID are supported)");
  }
  auto* conv = new DepthwiseConvOperator;
  conv->inputs = {input_name, reordered_weights_name};
  conv->outputs = {node.name()};
  conv->stride_height = strides.i(1);
  conv->stride_width = strides.i(2);
  conv->dilation_height_factor = dilation_height_factor;
  conv->dilation_width_factor = dilation_width_factor;
  conv->padding.type = padding_type;
  model->operators.emplace_back(conv);
  return ::tensorflow::OkStatus();
}

tensorflow::Status ConvertDepthToSpaceOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "DepthToSpace");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));

  tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
  if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
      dtype != DT_INT64) {
    const auto* enum_descriptor = tensorflow::DataType_descriptor();
    LOG(FATAL) << "TFLite does not support DepthToSpace with type T:"
               << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
               << "T must be one of {DT_FLOAT, DT_UINT8, DT_INT32, DT_INT64}.";
  }
  auto* op = new DepthToSpaceOperator;
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());
  op->block_size = GetIntAttr(node, "block_size");
  QCHECK_GE(op->block_size, 2);
  model->operators.emplace_back(op);
  return ::tensorflow::OkStatus();
}

tensorflow::Status ConvertSpaceToDepthOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "SpaceToDepth");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));

  tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
  if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
      dtype != DT_INT64) {
    const auto* enum_descriptor = tensorflow::DataType_descriptor();
    LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:"
               << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
               << "T must be one of {DT_FLOAT, DT_UINT8, DT_INT32, DT_INT64}.";
  }
  auto* op = new SpaceToDepthOperator;
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());
  op->block_size = GetIntAttr(node, "block_size");
  QCHECK_GE(op->block_size, 2);
  model->operators.emplace_back(op);
  return ::tensorflow::OkStatus();
}

tensorflow::Status ConvertBiasAddOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "BiasAdd");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));

  const auto& input_name = node.input(0);
  const auto& bias_name = node.input(1);
  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
  auto* biasadd = new AddOperator;
  biasadd->inputs.push_back(input_name);
  biasadd->inputs.push_back(bias_name);
  biasadd->outputs.push_back(node.name());
  model->operators.emplace_back(biasadd);
  return ::tensorflow::OkStatus();
}

tensorflow::Status ConvertRandomUniform(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "RandomUniform");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));

  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32);
  auto op = std::make_unique<RandomUniformOperator>();
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());
  op->dtype = ConvertDataType(GetDataTypeAttr(node, "dtype"));
  op->seed = GetIntAttr(node, "seed");
  op->seed2 = GetIntAttr(node, "seed2");
  CHECK(model != nullptr);
  model->operators.emplace_back(std::move(op));
  return ::tensorflow::OkStatus();
}

tensorflow::Status ConvertIdentityOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
        node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
        node.op() == "Snapshot" || node.op() == "EnsureShape");
  auto* op = new TensorFlowIdentityOperator;
  // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
  // identity nodes with multiple inputs, but the other inputs seem
  // to be gratuitous (in the case of rajeev_lstm.pb, these are
  // enumerating the LSTM state arrays). We will just ignore extra
  // inputs beyond the first input.
  QCHECK_GE(node.input_size(), 1)
      << node.op()
      << " node expects at least 1 input other than control dependencies: "
      << node.DebugString();
  const auto& input_name = node.input(0);
  op->inputs.push_back(input_name);
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return ::tensorflow::OkStatus();
}

tensorflow::Status ConvertIdentityNOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "IdentityN");
  for (int i = 0; i < node.input_size(); ++i) {
    auto* op = new TensorFlowIdentityOperator;
    const auto& input_name = node.input(i);
    std::string output_name = node.name();
    if (i > 0) {
      output_name = output_name + ":" + std::to_string(i);
    }
    op->inputs.push_back(input_name);
    op->outputs.push_back(output_name);
    model->operators.emplace_back(op);
  }
  return ::tensorflow::OkStatus();
}

tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  auto* op = new FakeQuantOperator;
  op->inputs.push_back(node.input(0));
  op->minmax = std::make_unique<MinMax>();
  auto& minmax = *op->minmax;
  minmax.min = GetFloatAttr(node, "min");
  minmax.max = GetFloatAttr(node, "max");
  op->outputs.push_back(node.name());
  // tf.fake_quant_with_min_max_args num_bits defaults to 8.
  op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
  if (HasAttr(node, "narrow_range")) {
    op->narrow_range = GetBoolAttr(node, "narrow_range");
  }
  model->operators.emplace_back(op);
  return ::tensorflow::OkStatus();
}

tensorflow::Status ConvertFakeQuantWithMinMaxVars(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
  const int num_inputs = GetInputsCount(node, tf_import_flags);
  QCHECK(num_inputs == 3 || num_inputs == 4)
      << "FakeQuantWithMinMaxVars node expects 3 or 4 inputs other than "
         "control dependencies: "
      << node.DebugString();
  auto* op = new FakeQuantOperator;
  for (int i = 0; i < 3; i++) {
    op->inputs.push_back(node.input(i));
  }
  op->outputs.push_back(node.name());
  op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
  if (HasAttr(node, "narrow_range")) {
    op->narrow_range = GetBoolAttr(node, "narrow_range");
  }
  model->operators.emplace_back(op);
  return ::tensorflow::OkStatus();
}
1159 
ConvertSqueezeOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,const ModelFlags & model_flags,Model * model)1160 tensorflow::Status ConvertSqueezeOperator(
1161     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1162     const ModelFlags& model_flags, Model* model) {
1163   CHECK_EQ(node.op(), "Squeeze");
1164   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1165   auto* op = new SqueezeOperator;
1166   op->inputs.push_back(node.input(0));
1167   op->outputs.push_back(node.name());
1168 
1169   // When omitted, all dimensions of size 1 are to be squeezed.
1170   if (HasAttr(node, "squeeze_dims")) {
1171     const auto& squeeze_dims = GetListAttr(node, "squeeze_dims");
1172     for (int i = 0; i < squeeze_dims.i_size(); ++i) {
1173       op->squeeze_dims.push_back(squeeze_dims.i(i));
1174     }
1175   }
1176 
1177   model->operators.emplace_back(op);
1178   return ::tensorflow::OkStatus();
1179 }
1180 
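// Example (hypothetical names): a Split node "s" with num_split = 3 takes
// inputs (axis, value) and produces outputs "s", "s:1" and "s:2"; the axis
// is wired through as a regular input array rather than resolved here.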
1181 tensorflow::Status ConvertSplitOperator(
1182     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1183     const ModelFlags& model_flags, Model* model) {
1184   CHECK_EQ(node.op(), "Split");
1185   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1186   auto* op = new TensorFlowSplitOperator;
1187   op->inputs.push_back(node.input(0));
1188   op->inputs.push_back(node.input(1));
1189   const int num_split = GetIntAttr(node, "num_split");
1190   op->outputs.push_back(node.name());
1191   for (int i = 1; i < num_split; i++) {
1192     op->outputs.push_back(absl::StrCat(node.name(), ":", i));
1193   }
1194   op->num_split = num_split;
1195   model->operators.emplace_back(op);
1196   return ::tensorflow::OkStatus();
1197 }
1198 
1199 tensorflow::Status ConvertSplitVOperator(
1200     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1201     const ModelFlags& model_flags, Model* model) {
1202   CHECK_EQ(node.op(), "SplitV");
1203   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1204   auto* op = new TensorFlowSplitVOperator;
1205   op->inputs.push_back(node.input(0));
1206   op->inputs.push_back(node.input(1));
1207   op->inputs.push_back(node.input(2));
1208   const int num_split = GetIntAttr(node, "num_split");
1209   op->outputs.push_back(node.name());
1210   for (int i = 1; i < num_split; i++) {
1211     op->outputs.push_back(absl::StrCat(node.name(), ":", i));
1212   }
1213   op->num_split = num_split;
1214   model->operators.emplace_back(op);
1215   return ::tensorflow::OkStatus();
1216 }
1217 
1218 tensorflow::Status ConvertSwitchOperator(
1219     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1220     const ModelFlags& model_flags, Model* model) {
1221   CHECK_EQ(node.op(), "Switch");
1222   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1223   auto* op = new TensorFlowSwitchOperator;
1224   op->inputs.push_back(node.input(0));
1225   op->inputs.push_back(node.input(1));
1226   op->outputs.push_back(node.name());
1227   // Switch operators have two outputs: "name" and "name:1".
1228   op->outputs.push_back(node.name() + ":1");
1229   model->operators.emplace_back(op);
1230   return ::tensorflow::OkStatus();
1231 }
1232 
1233 tensorflow::Status ConvertSoftmaxOperator(
1234     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1235     const ModelFlags& model_flags, Model* model) {
1236   CHECK_EQ(node.op(), "Softmax");
1237   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1238   const auto& input_name = node.input(0);
1239   auto* softmax = new SoftmaxOperator;
1240   softmax->inputs.push_back(input_name);
1241   softmax->outputs.push_back(node.name());
1242   // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter.
1243   CHECK(!node.attr().count("beta"));  // Stab in the dark, just in case.
1244   if (node.attr().count("_softmax_beta")) {
1245     softmax->beta = GetFloatAttr(node, "_softmax_beta");
1246   } else {
1247     softmax->beta = 1.f;
1248   }
1249   model->operators.emplace_back(softmax);
1250   return ::tensorflow::OkStatus();
1251 }
1252 
1253 tensorflow::Status ConvertLRNOperator(
1254     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1255     const ModelFlags& model_flags, Model* model) {
1256   CHECK_EQ(node.op(), "LRN");
1257   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1258   const auto& input_name = node.input(0);
1259   auto* lrn = new LocalResponseNormalizationOperator;
1260   lrn->inputs.push_back(input_name);
1261   lrn->outputs.push_back(node.name());
1262   lrn->range = GetIntAttr(node, "depth_radius");
1263   lrn->bias = GetFloatAttr(node, "bias");
1264   lrn->alpha = GetFloatAttr(node, "alpha");
1265   lrn->beta = GetFloatAttr(node, "beta");
1266   model->operators.emplace_back(lrn);
1267   return ::tensorflow::OkStatus();
1268 }
1269 
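// Example of the attributes consumed below (values are illustrative): a
// MaxPool node with strides = [1, 2, 2, 1] and ksize = [1, 3, 3, 1] yields
// stride_height/stride_width = 2 and kheight/kwidth = 3. The leading and
// trailing entries must be 1, since pooling across the batch or depth
// dimension is not supported.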
1270 tensorflow::Status ConvertMaxPoolOperator(
1271     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1272     const ModelFlags& model_flags, Model* model) {
1273   CHECK_EQ(node.op(), "MaxPool");
1274   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1275   const auto& input_name = node.input(0);
1276   // We only support NHWC, which is the default data_format.
1277   // So if data_format is not defined, we're all good.
1278   if (node.attr().count("data_format")) {
1279     CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
1280   }
1281   if (HasAttr(node, "T")) {
1282     CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1283   } else {
1284     LOG(WARNING) << "Found MaxPool operator missing 'T' attribute";
1285   }
1286   auto* maxpool = new MaxPoolOperator;
1287   maxpool->inputs.push_back(input_name);
1288   maxpool->outputs.push_back(node.name());
1289   const auto& strides = GetListAttr(node, "strides");
1290   CHECK_EQ(strides.i_size(), 4);
1291   CHECK_EQ(strides.i(0), 1);
1292   CHECK_EQ(strides.i(3), 1);
1293   maxpool->stride_height = strides.i(1);
1294   maxpool->stride_width = strides.i(2);
1295   const auto& ksize = GetListAttr(node, "ksize");
1296   CHECK_EQ(ksize.i_size(), 4);
1297   CHECK_EQ(ksize.i(0), 1);
1298   CHECK_EQ(ksize.i(3), 1);
1299   maxpool->kheight = ksize.i(1);
1300   maxpool->kwidth = ksize.i(2);
1301   const auto& padding = GetStringAttr(node, "padding");
1302   if (padding == "SAME") {
1303     maxpool->padding.type = PaddingType::kSame;
1304   } else if (padding == "VALID") {
1305     maxpool->padding.type = PaddingType::kValid;
1306   } else {
1307     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1308   }
1309   model->operators.emplace_back(maxpool);
1310   return ::tensorflow::OkStatus();
1311 }
1312 
1313 tensorflow::Status ConvertAvgPoolOperator(
1314     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1315     const ModelFlags& model_flags, Model* model) {
1316   CHECK_EQ(node.op(), "AvgPool");
1317   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1318   const auto& input_name = node.input(0);
1319   // We only support NHWC, which is the default data_format.
1320   // So if data_format is not defined, we're all good.
1321   if (node.attr().count("data_format")) {
1322     CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
1323   }
1324   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1325   auto* avgpool = new AveragePoolOperator;
1326   avgpool->inputs.push_back(input_name);
1327   avgpool->outputs.push_back(node.name());
1328   const auto& strides = GetListAttr(node, "strides");
1329   CHECK_EQ(strides.i_size(), 4);
1330   CHECK_EQ(strides.i(0), 1);
1331   CHECK_EQ(strides.i(3), 1);
1332   avgpool->stride_height = strides.i(1);
1333   avgpool->stride_width = strides.i(2);
1334   const auto& ksize = GetListAttr(node, "ksize");
1335   CHECK_EQ(ksize.i_size(), 4);
1336   CHECK_EQ(ksize.i(0), 1);
1337   CHECK_EQ(ksize.i(3), 1);
1338   avgpool->kheight = ksize.i(1);
1339   avgpool->kwidth = ksize.i(2);
1340   const auto& padding = GetStringAttr(node, "padding");
1341   if (padding == "SAME") {
1342     avgpool->padding.type = PaddingType::kSame;
1343   } else if (padding == "VALID") {
1344     avgpool->padding.type = PaddingType::kValid;
1345   } else {
1346     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1347   }
1348   model->operators.emplace_back(avgpool);
1349   return ::tensorflow::OkStatus();
1350 }
1351 
1352 tensorflow::Status ConvertBatchMatMulOperator(
1353     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1354     const ModelFlags& model_flags, Model* model) {
1355   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1356 
1357   auto* batch_matmul = new BatchMatMulOperator;
1358   // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions
1359   if (HasAttr(node, "adj_x")) {
1360     batch_matmul->adj_x = GetBoolAttr(node, "adj_x");
1361   }
1362   if (HasAttr(node, "adj_y")) {
1363     batch_matmul->adj_y = GetBoolAttr(node, "adj_y");
1364   }
1365   batch_matmul->inputs = {node.input(0), node.input(1)};
1366   batch_matmul->outputs = {node.name()};
1367 
1368   // For Flex mode. See the comment on RetainTensorFlowNodeDef.
1369   RetainTensorFlowNodeDef(node, batch_matmul);
1370 
1371   model->operators.emplace_back(batch_matmul);
1372   return ::tensorflow::OkStatus();
1373 }
1374 
1375 tensorflow::Status ConvertMatMulOperator(
1376     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1377     const ModelFlags& model_flags, Model* model) {
1378   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1379 
1380   CHECK(!HasAttr(node, "adjoint_a") ||
1381         (GetBoolAttr(node, "adjoint_a") == false));
1382   CHECK(!HasAttr(node, "adjoint_b") ||
1383         (GetBoolAttr(node, "adjoint_b") == false));
1384 
1385   auto* matmul = new TensorFlowMatMulOperator;
1386   if (HasAttr(node, "transpose_a")) {
1387     matmul->transpose_a = GetBoolAttr(node, "transpose_a");
1388   }
1389   if (HasAttr(node, "transpose_b")) {
1390     matmul->transpose_b = GetBoolAttr(node, "transpose_b");
1391   }
1392 
1393   matmul->inputs = {node.input(0), node.input(1)};
1394   matmul->outputs = {node.name()};
1395   model->operators.emplace_back(matmul);
1396   return ::tensorflow::OkStatus();
1397 }
1398 
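// Note on input layout (TensorFlow semantics): Concat takes the axis as its
// first input, ConcatV2 as its last; either way a node with N value inputs
// has N + 1 inputs in total, which is what the check against the "N"
// attribute below enforces. All inputs are wired through unchanged.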
1399 tensorflow::Status ConvertConcatOperator(
1400     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1401     const ModelFlags& model_flags, Model* model) {
1402   Operator* op = nullptr;
1403   if (node.op() == "Concat") {
1404     op = new TensorFlowConcatOperator;
1405   } else if (node.op() == "ConcatV2") {
1406     op = new TensorFlowConcatV2Operator;
1407   } else {
1408     LOG(FATAL) << "Expected Concat or ConcatV2";
1409   }
1410   const int num_inputs = GetInputsCount(node, tf_import_flags);
1411   QCHECK_GE(num_inputs, 2)
1412       << node.op()
1413       << " node expects at least 2 inputs other than control dependencies: "
1414       << node.DebugString();
1415   CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N"));
1416   for (int i = 0; i < num_inputs; ++i) {
1417     op->inputs.push_back(node.input(i));
1418   }
1419   op->outputs.push_back(node.name());
1420   model->operators.emplace_back(op);
1421   return ::tensorflow::OkStatus();
1422 }
1423 
1424 tensorflow::Status ConvertMirrorPadOperator(
1425     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1426     const ModelFlags& model_flags, Model* model) {
1427   if (node.op() != "MirrorPad") {
1428     LOG(FATAL) << "Expected MirrorPad.";
1429   }
1430   const int num_inputs = GetInputsCount(node, tf_import_flags);
1431   CHECK_EQ(num_inputs, 2);
1432   auto* op = new MirrorPadOperator;
1433   for (int i = 0; i < num_inputs; ++i) {
1434     op->inputs.push_back(node.input(i));
1435   }
1436   op->outputs.push_back(node.name());
1437   const auto mode = GetStringAttr(node, "mode");
1438   if (mode == "REFLECT") {
1439     op->mode = toco::MirrorPadMode::kReflect;
1440   } else if (mode == "SYMMETRIC") {
1441     op->mode = toco::MirrorPadMode::kSymmetric;
1442   }
1443 
1444   model->operators.emplace_back(op);
1445 
1446   return ::tensorflow::OkStatus();
1447 }
1448 
1449 static constexpr int kAnyNumInputs = -1;
1450 
1451 enum FlexSupport { kFlexOk, kFlexNotOk };
1452 
1453 // Converts a simple operator that takes no additional attributes. The
1454 // list of inputs is taken from the given NodeDef, and its number must
1455 // match NumInputs, unless kAnyNumInputs is passed in. If kFlexOk is
1456 // passed in, the resulting operator will be eligible for being exported
1457 // as a flex op.
1458 template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
1459 tensorflow::Status ConvertSimpleOperatorGeneric(
1460     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1461     const ModelFlags& model_flags, Model* model) {
1462   if (NumInputs != kAnyNumInputs) {
1463     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs));
1464   }
1465   auto* op = new Op;
1466   const int num_inputs = GetInputsCount(node, tf_import_flags);
1467   for (int i = 0; i < num_inputs; ++i) {
1468     op->inputs.push_back(node.input(i));
1469   }
1470   op->outputs.push_back(node.name());
1471   if (NumOutputs > 1) {
1472     for (int i = 1; i < NumOutputs; ++i) {
1473       op->outputs.push_back(node.name() + ":" + std::to_string(i));
1474     }
1475   }
1476 
1477   if (flex == kFlexOk) {
1478     RetainTensorFlowNodeDef(node, op);
1479   }
1480 
1481   model->operators.emplace_back(op);
1482   return ::tensorflow::OkStatus();
1483 }
1484 
1485 // Convert a simple operator which is not valid as a flex op.
1486 template <typename Op, int NumInputs, int NumOutputs>
1487 tensorflow::Status ConvertSimpleOperator(
1488     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1489     const ModelFlags& model_flags, Model* model) {
1490   return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
1491       node, tf_import_flags, model_flags, model);
1492 }
1493 
1494 // Convert a simple operator which is valid as a flex op.
1495 template <typename Op, int NumInputs, int NumOutputs>
1496 tensorflow::Status ConvertSimpleOperatorFlexOk(
1497     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1498     const ModelFlags& model_flags, Model* model) {
1499   return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
1500       node, tf_import_flags, model_flags, model);
1501 }
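// Usage sketch (illustrative only): a binary elementwise op with a single
// output would be handled by ConvertSimpleOperator<AddOperator, 2, 1>, while
// an op that may fall back to the Flex delegate would use
// ConvertSimpleOperatorFlexOk so that its original NodeDef is retained.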
1502 
1503 // Same as ConvertConstOperator, but revert to ConvertUnsupportedOperator if
1504 // the types are not supported. Converting Const operators here avoids
1505 // expensive copies of the protocol buffers downstream in the flex delegate.
1506 tensorflow::Status ConditionallyConvertConstOperator(
1507     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1508     const ModelFlags& model_flags, Model* model) {
1509   // We avoid incomplete and zero shapes because the resulting arrays
1510   // are not completely compatible with Eager/TensorFlow.
1511   const auto& tensor = GetTensorAttr(node, "value");
1512   const auto& shape = tensor.tensor_shape();
1513   for (const auto& dim : shape.dim()) {
1514     if (dim.size() <= 0) {
1515       return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
1516                                         model);
1517     }
1518   }
1519   switch (GetDataTypeAttr(node, "dtype")) {
1520     case DT_FLOAT:
1521     case DT_INT32:
1522     case DT_QUINT8:
1523     case DT_INT64:
1524     case DT_STRING:
1525     case DT_BOOL:
1526     case DT_COMPLEX64:
1527       return ConvertConstOperator(node, tf_import_flags, model_flags, model);
1528     default:
1529       return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
1530                                         model);
1531   }
1532 }
1533 
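// Reminder of TensorFlow's mask semantics, copied verbatim below: each mask
// is a bitfield over dimensions. For example, begin_mask bit i set means
// "ignore the begin value for dimension i" (as a Python slice like x[:n]
// ignores the begin), and shrink_axis_mask bit i set means dimension i is
// indexed away entirely.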
1534 tensorflow::Status ConvertStridedSliceOperator(
1535     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1536     const ModelFlags& model_flags, Model* model) {
1537   CHECK_EQ(node.op(), "StridedSlice");
1538   // TODO(soroosh): The 4th input (strides) should be optional, to be
1539   // consistent with TF.
1540   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
1541 
1542   auto* op = new StridedSliceOperator;
1543   for (const auto& input : node.input()) {
1544     op->inputs.push_back(input);
1545   }
1546   op->outputs.push_back(node.name());
1547 
1548   op->begin_mask =
1549       HasAttr(node, "begin_mask") ? GetIntAttr(node, "begin_mask") : 0;
1550   op->ellipsis_mask =
1551       HasAttr(node, "ellipsis_mask") ? GetIntAttr(node, "ellipsis_mask") : 0;
1552   op->end_mask = HasAttr(node, "end_mask") ? GetIntAttr(node, "end_mask") : 0;
1553   op->new_axis_mask =
1554       HasAttr(node, "new_axis_mask") ? GetIntAttr(node, "new_axis_mask") : 0;
1555   op->shrink_axis_mask = HasAttr(node, "shrink_axis_mask")
1556                              ? GetIntAttr(node, "shrink_axis_mask")
1557                              : 0;
1558 
1559   model->operators.emplace_back(op);
1560   return ::tensorflow::OkStatus();
1561 }
1562 
1563 tensorflow::Status ConvertPlaceholderOperator(
1564     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1565     const ModelFlags& model_flags, Model* model) {
1566   CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
1567   if (node.op() == "Placeholder") {
1568     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0));
1569   }
1570 
1571   bool inside_input_arrays = false;
1572   for (const auto& input_array : model_flags.input_arrays()) {
1573     if (node.name() == input_array.name()) {
1574       inside_input_arrays = true;
1575       break;
1576     }
1577   }
1578 
1579   if (!inside_input_arrays) {
1580     model->AddInvalidInputArray(node.name());
1581   }
1582 
1583   auto& array = model->GetOrCreateArray(node.name());
1584   if (node.attr().count("dtype")) {
1585     array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
1586   }
1587   if (node.attr().count("shape")) {
1588     const auto& shape = GetShapeAttr(node, "shape");
1589     auto num_dims = shape.dim_size();
1590     // TODO(b/62716978): This logic needs to be revisited.  During dims
1591     // refactoring it is an interim fix.
1592     if (num_dims > 0 && !HasWildcardDimension(shape)) {
1593       auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
1594       dst_array_dims.resize(num_dims);
1595       for (int i = 0; i < num_dims; i++) {
1596         dst_array_dims[i] = shape.dim(i).size();
1597       }
1598     }
1599   }
1600   return ::tensorflow::OkStatus();
1601 }
1602 
1603 tensorflow::Status ConvertNoOpOperator(
1604     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1605     const ModelFlags& model_flags, Model* model) {
1606   return ::tensorflow::OkStatus();
1607 }
1608 
1609 tensorflow::Status ConvertCastOperator(
1610     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1611     const ModelFlags& model_flags, Model* model) {
1612   CHECK_EQ(node.op(), "Cast");
1613   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1614   const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
1615   const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT");
1616   auto* op = new CastOperator;
1617   op->src_data_type = ConvertDataType(tf_src_dtype);
1618   op->dst_data_type = ConvertDataType(tf_dst_dtype);
1619   op->inputs.push_back(node.input(0));
1620   op->outputs.push_back(node.name());
1621   model->operators.emplace_back(op);
1622   return ::tensorflow::OkStatus();
1623 }
1624 
1625 tensorflow::Status ConvertFloorOperator(
1626     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1627     const ModelFlags& model_flags, Model* model) {
1628   CHECK_EQ(node.op(), "Floor");
1629   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1630   const auto data_type = GetDataTypeAttr(node, "T");
1631   CHECK(data_type == DT_FLOAT);
1632   auto* op = new FloorOperator;
1633   op->inputs.push_back(node.input(0));
1634   op->outputs.push_back(node.name());
1635   model->operators.emplace_back(op);
1636   return ::tensorflow::OkStatus();
1637 }
1638 
1639 tensorflow::Status ConvertCeilOperator(
1640     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1641     const ModelFlags& model_flags, Model* model) {
1642   CHECK_EQ(node.op(), "Ceil");
1643   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1644   const auto data_type = GetDataTypeAttr(node, "T");
1645   CHECK(data_type == DT_FLOAT);
1646   auto* op = new CeilOperator;
1647   op->inputs.push_back(node.input(0));
1648   op->outputs.push_back(node.name());
1649   model->operators.emplace_back(op);
1650   return ::tensorflow::OkStatus();
1651 }
1652 
1653 tensorflow::Status ConvertRoundOperator(
1654     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1655     const ModelFlags& model_flags, Model* model) {
1656   CHECK_EQ(node.op(), "Round");
1657   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1658   const auto data_type = GetDataTypeAttr(node, "T");
1659   CHECK(data_type == DT_FLOAT);
1660   auto* op = new RoundOperator;
1661   op->inputs.push_back(node.input(0));
1662   op->outputs.push_back(node.name());
1663   model->operators.emplace_back(op);
1664   return ::tensorflow::OkStatus();
1665 }
1666 
1667 tensorflow::Status ConvertGatherOperator(
1668     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1669     const ModelFlags& model_flags, Model* model) {
1670   CHECK(node.op() == "Gather" || node.op() == "GatherV2");
1671   if (node.op() == "Gather")
1672     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1673   if (node.op() == "GatherV2")
1674     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1675   const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1676   CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
1677   auto* op = new GatherOperator;
1678   op->inputs.push_back(node.input(0));
1679   op->inputs.push_back(node.input(1));
1680   if (node.input_size() >= 3) {
1681     // GatherV2 form where we are provided an axis. It may be either a constant
1682     // or runtime defined value, so we just wire up the array and let
1683     // ResolveGatherAttributes take care of it later on.
1684     const auto axis_data_type = GetDataTypeAttr(node, "Taxis");
1685     CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64);
1686     op->inputs.push_back(node.input(2));
1687   } else {
1688     // Gather form that assumes axis=0.
1689     op->axis = {0};
1690   }
1691   op->outputs.push_back(node.name());
1692   model->operators.emplace_back(op);
1693   return ::tensorflow::OkStatus();
1694 }
1695 
1696 tensorflow::Status ConvertGatherNdOperator(
1697     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1698     const ModelFlags& model_flags, Model* model) {
1699   CHECK_EQ(node.op(), "GatherNd");
1700   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1701   const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1702   CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
1703   auto* op = new GatherNdOperator;
1704   op->inputs.push_back(node.input(0));
1705   op->inputs.push_back(node.input(1));
1706   op->outputs.push_back(node.name());
1707   model->operators.emplace_back(op);
1708   return ::tensorflow::OkStatus();
1709 }
1710 
1711 template <typename Op>
1712 tensorflow::Status ConvertArgMinMaxOperator(
1713     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1714     const ModelFlags& model_flags, Model* model) {
1715   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1716   const auto axis_data_type =
1717       HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
1718   const auto output_type = HasAttr(node, "output_type")
1719                                ? GetDataTypeAttr(node, "output_type")
1720                                : DT_INT64;
1721   CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
1722   CHECK(output_type == DT_INT64 || output_type == DT_INT32);
1723   auto* op = new Op;
1724   op->output_data_type = ConvertDataType(output_type);
1725   op->inputs.push_back(node.input(0));
1726   op->inputs.push_back(node.input(1));
1727   op->outputs.push_back(node.name());
1728   model->operators.emplace_back(op);
1729   return ::tensorflow::OkStatus();
1730 }
1731 
1732 tensorflow::Status ConvertArgMaxOperator(
1733     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1734     const ModelFlags& model_flags, Model* model) {
1735   CHECK_EQ(node.op(), "ArgMax");
1736   return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags,
1737                                                   model_flags, model);
1738 }
1739 
1740 tensorflow::Status ConvertArgMinOperator(
1741     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1742     const ModelFlags& model_flags, Model* model) {
1743   CHECK_EQ(node.op(), "ArgMin");
1744   return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags,
1745                                                   model_flags, model);
1746 }
1747 
1748 tensorflow::Status ConvertResizeBilinearOperator(
1749     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1750     const ModelFlags& model_flags, Model* model) {
1751   CHECK_EQ(node.op(), "ResizeBilinear");
1752   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1753   auto* op = new ResizeBilinearOperator;
1754 
1755   op->align_corners = false;
1756   op->half_pixel_centers = false;
1757   if (HasAttr(node, "align_corners")) {
1758     op->align_corners = GetBoolAttr(node, "align_corners");
1759   }
1760   if (HasAttr(node, "half_pixel_centers")) {
1761     op->half_pixel_centers = GetBoolAttr(node, "half_pixel_centers");
1762   }
1763 
1764   op->inputs.push_back(node.input(0));
1765   op->inputs.push_back(node.input(1));
1766   op->outputs.push_back(node.name());
1767   model->operators.emplace_back(op);
1768   return ::tensorflow::OkStatus();
1769 }
1770 
1771 tensorflow::Status ConvertResizeNearestNeighborOperator(
1772     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1773     const ModelFlags& model_flags, Model* model) {
1774   CHECK_EQ(node.op(), "ResizeNearestNeighbor");
1775   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1776   auto* op = new ResizeNearestNeighborOperator;
1777 
1778   op->align_corners = false;
1779   op->half_pixel_centers = false;
1780   if (HasAttr(node, "align_corners")) {
1781     op->align_corners = GetBoolAttr(node, "align_corners");
1782   }
1783   if (HasAttr(node, "half_pixel_centers")) {
1784     op->half_pixel_centers = GetBoolAttr(node, "half_pixel_centers");
1785   }
1786 
1787   op->inputs.push_back(node.input(0));
1788   op->inputs.push_back(node.input(1));
1789   op->outputs.push_back(node.name());
1790   model->operators.emplace_back(op);
1791   return ::tensorflow::OkStatus();
1792 }
1793 
1794 tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
1795     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1796     const ModelFlags& model_flags, Model* model) {
1797   CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
1798   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
1799 
1800   // TODO(ahentz): to really match tensorflow we need to add variance_epsilon
1801   // to the input, before feeding it into TensorFlowRsqrtOperator.
1802   // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f);
1803 
1804   std::string multiplier = node.name() + "_mul";
1805   if (GetBoolAttr(node, "scale_after_normalization")) {
1806     // Create graph:
1807     //   v -> RSQRT ->
1808     //                 MUL  -> multiplier
1809     //   gamma  ----->
1810     std::string rsqrt = node.name() + "_rsqrt";
1811 
1812     auto* rsqrt_op = new TensorFlowRsqrtOperator;
1813     rsqrt_op->inputs.push_back(node.input(2));
1814     rsqrt_op->outputs.push_back(rsqrt);
1815     model->operators.emplace_back(rsqrt_op);
1816 
1817     auto* mul_op = new MulOperator;
1818     mul_op->inputs.push_back(rsqrt);
1819     mul_op->inputs.push_back(node.input(4));
1820     mul_op->outputs.push_back(multiplier);
1821     model->operators.emplace_back(mul_op);
1822   } else {
1823     // Create graph:
1824     //   v -> RSQRT -> multiplier
1825     auto* rsqrt_op = new TensorFlowRsqrtOperator;
1826     rsqrt_op->inputs.push_back(node.input(2));
1827     rsqrt_op->outputs.push_back(multiplier);
1828     model->operators.emplace_back(rsqrt_op);
1829   }
1830 
1831   auto* op = new BatchNormalizationOperator;
1832   op->global_normalization = true;
1833 
1834   op->inputs.push_back(node.input(0));
1835   op->inputs.push_back(node.input(1));
1836   op->inputs.push_back(multiplier);
1837   op->inputs.push_back(node.input(3));
1838   op->outputs.push_back(node.name());
1839 
1840   model->operators.emplace_back(op);
1841   return ::tensorflow::OkStatus();
1842 }
1843 
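// Summary of the decomposition built below (array names follow the
// node.name() + suffix scheme used in the code):
//   scale  = rsqrt(moving_variance + epsilon) * gamma
//   output = BatchNormalization(input, moving_mean, scale, beta)
// i.e. the fused op is expanded into Add, Rsqrt and Mul operators feeding a
// global BatchNormalizationOperator.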
1844 tensorflow::Status ConvertFusedBatchNormOperator(
1845     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1846     const ModelFlags& model_flags, Model* model) {
1847   CHECK((node.op() == "FusedBatchNorm") || (node.op() == "FusedBatchNormV3"));
1848   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
1849 
1850   // Declare shortcuts for the inputs.
1851   const std::string& gamma_input = node.input(1);
1852   const std::string& beta_input = node.input(2);
1853   const std::string& moving_mean_input = node.input(3);
1854   const std::string& moving_variance_input = node.input(4);
1855 
1856   // Create an array holding the epsilon value (typically, 0.001).
1857   const std::string epsilon_array_name =
1858       CreateConstArray<ArrayDataType::kFloat>(model,
1859                                               node.name() + "_epsilon_array",
1860                                               {GetFloatAttr(node, "epsilon")});
1861 
1862   // Add epsilon to the moving variance.
1863   const std::string epsilon_add_op_name = node.name() + "_epsilon";
1864   auto* epsilon_add_op = new AddOperator;
1865   epsilon_add_op->inputs.push_back(moving_variance_input);
1866   epsilon_add_op->inputs.push_back(epsilon_array_name);
1867   epsilon_add_op->outputs.push_back(epsilon_add_op_name);
1868   model->operators.emplace_back(epsilon_add_op);
1869 
1870   // Take the inverse square root of the (variance + epsilon).
1871   const std::string rsqrt_op_name = node.name() + "_rsqrt";
1872   auto* rsqrt_op = new TensorFlowRsqrtOperator;
1873   rsqrt_op->inputs.push_back(epsilon_add_op_name);
1874   rsqrt_op->outputs.push_back(rsqrt_op_name);
1875   model->operators.emplace_back(rsqrt_op);
1876 
1877   // Multiply the result by gamma.
1878   const std::string multiplier = node.name() + "_mul";
1879   auto* mul_op = new MulOperator;
1880   mul_op->inputs.push_back(rsqrt_op_name);
1881   mul_op->inputs.push_back(gamma_input);
1882   mul_op->outputs.push_back(multiplier);
1883   model->operators.emplace_back(mul_op);
1884 
1885   // Now we have all required inputs for the BatchNormalizationOperator.
1886   auto* op = new BatchNormalizationOperator;
1887   op->global_normalization = true;
1888 
1889   op->inputs.push_back(node.input(0));
1890   op->inputs.push_back(moving_mean_input);
1891   op->inputs.push_back(multiplier);
1892   op->inputs.push_back(beta_input);
1893   op->outputs.push_back(node.name());
1894 
1895   model->operators.emplace_back(op);
1896   return ::tensorflow::OkStatus();
1897 }
1898 
1899 tensorflow::Status ConvertSpaceToBatchNDOperator(
1900     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1901     const ModelFlags& model_flags, Model* model) {
1902   CHECK_EQ(node.op(), "SpaceToBatchND");
1903   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1904   CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1905   CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32);
1906   auto* op = new SpaceToBatchNDOperator;
1907   op->inputs.push_back(node.input(0));
1908   op->inputs.push_back(node.input(1));
1909   op->inputs.push_back(node.input(2));
1910   op->outputs.push_back(node.name());
1911   model->operators.emplace_back(op);
1912   return ::tensorflow::OkStatus();
1913 }
1914 
1915 tensorflow::Status ConvertBatchToSpaceNDOperator(
1916     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1917     const ModelFlags& model_flags, Model* model) {
1918   CHECK_EQ(node.op(), "BatchToSpaceND");
1919   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1920   CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1921   CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32);
1922   auto* op = new BatchToSpaceNDOperator;
1923   op->inputs.push_back(node.input(0));
1924   op->inputs.push_back(node.input(1));
1925   op->inputs.push_back(node.input(2));
1926   op->outputs.push_back(node.name());
1927   model->operators.emplace_back(op);
1928   return ::tensorflow::OkStatus();
1929 }
1930 
1931 template <typename T>
1932 tensorflow::Status ConvertReduceOperator(
1933     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1934     const ModelFlags& model_flags, Model* model) {
1935   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1936   auto* op = new T;
1937   op->inputs.push_back(node.input(0));
1938   op->inputs.push_back(node.input(1));
1939   op->outputs.push_back(node.name());
1940   model->operators.emplace_back(op);
1941   if (HasAttr(node, "keepdims")) {
1942     op->keep_dims = GetBoolAttr(node, "keepdims");
1943   } else if (HasAttr(node, "keep_dims")) {
1944     op->keep_dims = GetBoolAttr(node, "keep_dims");
1945   }
1946   return ::tensorflow::OkStatus();
1947 }
1948 
1949 // TODO(b/139320642): Add test when fused op is supported.
1950 tensorflow::Status ConvertSvdfOperator(
1951     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1952     const ModelFlags& model_flags, Model* model) {
1953   CHECK_EQ(node.op(), "Svdf");
1954   const int input_size = GetInputsCount(node, tf_import_flags);
1955   QCHECK(input_size == 4 || input_size == 5)
1956       << "Svdf node expects 3 or 4 inputs other than control dependencies: "
1957       << node.DebugString();
1958   bool has_bias = (input_size == 5);
1959   auto* op = new SvdfOperator;
1960   int index = 0;
1961   op->inputs.push_back(node.input(index++));
1962   op->inputs.push_back(node.input(index++));
1963   op->inputs.push_back(node.input(index++));
1964   if (has_bias) {
1965     op->inputs.push_back(node.input(index++));
1966   }
1967   op->inputs.push_back(node.input(index));
1968   op->outputs.push_back(node.name());
1969   if (node.attr().at("ActivationFunction").s() == "Relu") {
1970     op->fused_activation_function = FusedActivationFunctionType::kRelu;
1971   } else {
1972     op->fused_activation_function = FusedActivationFunctionType::kNone;
1973   }
1974   op->rank = node.attr().at("Rank").i();
1975   model->operators.emplace_back(op);
1976   return ::tensorflow::OkStatus();
1977 }
1978 
1979 // This is just bare bones support to get the shapes to propagate.
1980 tensorflow::Status ConvertTransposeConvOperator(
1981     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1982     const ModelFlags& model_flags, Model* model) {
1983   CHECK_EQ(node.op(), "Conv2DBackpropInput");
1984   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1985   auto* op = new TransposeConvOperator;
1986   op->inputs.push_back(node.input(0));
1987   op->inputs.push_back(node.input(1));
1988   op->inputs.push_back(node.input(2));
1989   op->outputs.push_back(node.name());
1990   const auto& strides = GetListAttr(node, "strides");
1993   CHECK_EQ(strides.i_size(), 4)
1994       << "Can only import TransposeConv ops with 4D strides. TensorFlow op \""
1995       << node.name() << "\" has " << strides.i_size() << "D strides.";
1996   CHECK((strides.i(0) == 1) && (strides.i(3) == 1))
1997       << "Can only import TransposeConv ops with striding along the height "
1998          "(1st) or width (2nd) axis. TensorFlow op \""
1999       << node.name() << "\" had strides:[ " << strides.i(0) << ", "
2000       << strides.i(1) << ", " << strides.i(2) << ", " << strides.i(3) << "].";
2001   op->stride_height = strides.i(1);
2002   op->stride_width = strides.i(2);
2003   if (HasAttr(node, "dilations")) {
2004     const auto& dilations = GetListAttr(node, "dilations");
2005     CHECK_EQ(dilations.i_size(), 4)
2006         << "Dilation unsupported in TransposeConv. TensorFlow op \""
2007         << node.name() << "\" had dilations";
2008     CHECK((dilations.i(0) == 1) && (dilations.i(1) == 1) &&
2009           (dilations.i(2) == 1) && (dilations.i(3) == 1))
2010         << "Dilation unsupported in TransposeConv. TensorFlow op \""
2011         << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
2012         << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
2013         << "].";
2014   }
2015 
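  // Why perm = {2, 0, 1, 3} below: output dimension j takes input dimension
  // perm[j], so HWOI weights (H=0, W=1, O=2, I=3) are reordered to OHWI,
  // the layout expected by the TransposeConv implementation.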
2016   const std::string& weights_name = node.input(TransposeConvOperator::WEIGHTS);
2017   const std::string& transposed_weights_name = weights_name + "_transposed";
2018   // Check if a TransposeOperator was already created for these weights
2019   // (can happen when multiple layers share the same weights).
2020   const Operator* existing_transpose =
2021       GetOpWithOutput(*model, transposed_weights_name);
2022   if (existing_transpose) {
2023     CHECK(existing_transpose->type == OperatorType::kTranspose);
2024   } else {
2025     // Transpose weights from HWOI order to OHWI order, which is more efficient
2026     // for computation. (Note that TensorFlow considers the order as HWIO
2027     // because they consider this a backward conv, inverting the sense of
2028     // input/output.)
2029     TransposeOperator* transpose = new TransposeOperator;
2030     std::string perm_array = CreateConstArray<ArrayDataType::kInt32>(
2031         model, node.name() + "_transpose_perm", {2, 0, 1, 3});
2032     transpose->inputs = {weights_name, perm_array};
2033     transpose->outputs = {transposed_weights_name};
2034     model->operators.emplace_back(transpose);
2035   }
2036   op->inputs[1] = transposed_weights_name;
2037 
2038   auto const& padding = GetStringAttr(node, "padding");
2039   if (padding == "SAME") {
2040     op->padding.type = PaddingType::kSame;
2041   } else if (padding == "VALID") {
2042     op->padding.type = PaddingType::kValid;
2043   } else {
2044     LOG(FATAL) << "Only SAME and VALID padding supported on "
2045                   "Conv2DBackpropInput nodes.";
2046   }
2047   model->operators.emplace_back(op);
2048   return ::tensorflow::OkStatus();
2049 }
2050 
2051 tensorflow::Status ConvertRangeOperator(
2052     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2053     const ModelFlags& model_flags, Model* model) {
2054   CHECK_EQ(node.op(), "Range");
2055   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
2056   auto* op = new RangeOperator;
2057   if (HasAttr(node, "Tidx")) {
2058     const auto dtype = toco::GetDataTypeAttr(node, "Tidx");
2059     CHECK(dtype == DT_UINT8 || dtype == DT_INT32 || dtype == DT_INT64 ||
2060           dtype == DT_FLOAT);
2061     op->dtype = ConvertDataType(dtype);
2062   }
2063   op->inputs.push_back(node.input(0));
2064   op->inputs.push_back(node.input(1));
2065   op->inputs.push_back(node.input(2));
2066   op->outputs.push_back(node.name());
2067 
2068   model->operators.emplace_back(op);
2069   return ::tensorflow::OkStatus();
2070 }
2071 
2072 // Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but
2073 // they aren't the same thing.  tf.stack results in a "Pack" operator.  "Stack"
2074 // operators also exist, but involve manipulating the TF runtime stack, and are
2075 // not directly related to tf.stack() usage.
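// Example (hypothetical names): tf.stack([a, b, c]) produces a Pack node
// with N = 3; three rank-R inputs become a single rank-(R+1) output,
// stacked along "axis" (which defaults to 0 below).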
2076 tensorflow::Status ConvertPackOperator(
2077     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2078     const ModelFlags& model_flags, Model* model) {
2079   CHECK_EQ(node.op(), "Pack");
2080   auto op = std::make_unique<PackOperator>();
2081   const int num_inputs = GetInputsCount(node, tf_import_flags);
2082   QCHECK_GE(num_inputs, 1)
2083       << node.op()
2084       << " node expects at least 1 input other than control dependencies: "
2085       << node.DebugString();
2086   CHECK_EQ(num_inputs, GetIntAttr(node, "N"));
2087   for (int i = 0; i < num_inputs; ++i) {
2088     op->inputs.push_back(node.input(i));
2089   }
2090   op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs;
2091   op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
2092   op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
2093   op->outputs.push_back(node.name());
2094   model->operators.emplace_back(std::move(op));
2095   return ::tensorflow::OkStatus();
2096 }
2097 
2098 tensorflow::Status ConvertUnpackOperator(
2099     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2100     const ModelFlags& model_flags, Model* model) {
2101   CHECK_EQ(node.op(), "Unpack");
2102   auto op = std::make_unique<UnpackOperator>();
2103   const int num_inputs = GetInputsCount(node, tf_import_flags);
2104   QCHECK_EQ(num_inputs, 1);
2105   op->inputs.push_back(node.input(0));
2106   op->num = GetIntAttr(node, "num");
2107   op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
2108   op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
2109 
2110   op->outputs.push_back(node.name());  // Implicit :0.
2111   for (int i = 1; i < op->num; ++i) {
2112     op->outputs.push_back(node.name() + ":" + std::to_string(i));
2113   }
2114   model->operators.emplace_back(std::move(op));
2115   return ::tensorflow::OkStatus();
2116 }
2117 
2118 // Some TensorFlow ops only occur in graph cycles, representing
2119 // control flow. We do not currently support control flow, so we wouldn't
2120 // be able to fully support such graphs, including performing inference,
2121 // anyway. However, rather than erroring out early on graphs being cyclic,
2122 // it helps to at least support these just enough to allow getting a
2123 // graph visualization. This is not trivial, as we require graphs to be
2124 // acyclic aside from RNN back-edges. The solution is to special-case
2125 // such ops as RNN back-edges, which is technically incorrect (does not
2126 // allow representing the op's semantics) but good enough to get a
2127 // graph visualization.
2128 tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
2129     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2130     const ModelFlags& model_flags, Model* model) {
2131   // At the moment, the only type of operator special-cased in this way is
2132   // NextIteration, occurring only in control-flow cycles.
2133   CHECK_EQ(node.op(), "NextIteration");
2134   CHECK_EQ(node.input_size(), 1);
2135   auto* rnn_state = model->flags.add_rnn_states();
2136   // This RNN state is not explicitly created by the user, so it's
2137   // OK for some later graph transformation to discard it.
2138   rnn_state->set_discardable(true);
2139   rnn_state->set_state_array(node.name());
2140   rnn_state->set_back_edge_source_array(node.input(0));
2141   // TODO(tianjuny): Temporarily set the size to 1 to avoid transient array
2142   // allocation crash. The real value should depend on the hidden_size of RNN.
2143   rnn_state->set_size(1);
2144   return ::tensorflow::OkStatus();
2145 }
2146 
2147 tensorflow::Status ConvertShapeOperator(
2148     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2149     const ModelFlags& model_flags, Model* model) {
2150   CHECK_EQ(node.op(), "Shape");
2151   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
2152   const auto out_type =
2153       HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32;
2154   CHECK(out_type == DT_INT64 || out_type == DT_INT32);
2155   auto op = std::make_unique<TensorFlowShapeOperator>();
2156   op->output_data_type = ConvertDataType(out_type);
2157   op->inputs.push_back(node.input(0));
2158   op->outputs.push_back(node.name());
2159   model->operators.push_back(std::move(op));
2160   return ::tensorflow::OkStatus();
2161 }
2162 
2163 tensorflow::Status ConvertReverseSequenceOperator(
2164     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2165     const ModelFlags& model_flags, Model* model) {
2166   CHECK_EQ(node.op(), "ReverseSequence");
2167   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2168   auto op = std::make_unique<ReverseSequenceOperator>();
2169   if (HasAttr(node, "seq_dim")) {
2170     op->seq_dim = GetIntAttr(node, "seq_dim");
2171   }
2172   // In tf.reverse_sequence, batch_dim defaults to 0.
2173   op->batch_dim =
2174       HasAttr(node, "batch_dim") ? GetIntAttr(node, "batch_dim") : 0;
2175   const int num_inputs = GetInputsCount(node, tf_import_flags);
2176   for (int i = 0; i < num_inputs; ++i) {
2177     op->inputs.push_back(node.input(i));
2178   }
2179   op->outputs.push_back(node.name());
2180   model->operators.push_back(std::move(op));
2181   return ::tensorflow::OkStatus();
2182 }
2183 
2184 void StripCaretFromArrayNames(Model* model) {
2185   for (auto& op : model->operators) {
2186     for (auto& input : op->inputs) {
2187       input = std::string(absl::StripPrefix(input, "^"));
2188     }
2189     for (auto& output : op->outputs) {
2190       output = std::string(absl::StripPrefix(output, "^"));
2191     }
2192   }
2193   for (auto& array : model->GetArrayMap()) {
2194     if (absl::StartsWith(array.first, "^")) {
2195       LOG(FATAL) << "Array '" << array.first << "' still starts with '^'.";
2196     }
2197   }
2198 }
2199 
2200 void StripZeroOutputIndexFromInputs(NodeDef* node) {
2201   for (auto& input : *node->mutable_input()) {
2202     input = std::string(absl::StripSuffix(input, ":0"));
2203   }
2204 }
2205 
2206 // In TensorFlow GraphDef, when a node has multiple outputs, they are named
2207 // name:0, name:1, ...
2208 // where 'name' is the node's name(). Just 'name' is an equivalent shorthand
2209 // form for name:0.
2210 // A TensorFlow GraphDef does not explicitly list all the outputs of each node
2211 // (unlike inputs), it being implied by the node's name and operator type
2212 // (the latter implies the number of outputs).
2213 // This makes it non-trivial for us to reconstruct the list of all arrays
2214 // present in the graph and, for each operator, the list of its outputs.
2215 // We do that by taking advantage of the fact that
2216 // at least each node lists explicitly its inputs, so after we've loaded
2217 // all nodes, we can use that information.
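// Example: if some operator consumes "foo:2" but the node "foo" was recorded
// with only its implicit output "foo", the loop below appends "foo:1" and
// "foo:2" to foo's output list, so that every consumed array has a producer.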
2218 void AddExtraOutputs(Model* model) {
2219   // Construct the list of all arrays consumed by anything in the graph.
2220   std::vector<std::string> consumed_arrays;
2221   // Add arrays consumed by an op.
2222   for (const auto& consumer_op : model->operators) {
2223     for (const std::string& input : consumer_op->inputs) {
2224       consumed_arrays.push_back(input);
2225     }
2226   }
2227   // Add global outputs of the model.
2228   for (const std::string& output_array : model->flags.output_arrays()) {
2229     consumed_arrays.push_back(output_array);
2230   }
2231   // Add arrays consumed by a RNN back-edge.
2232   for (const auto& rnn_state : model->flags.rnn_states()) {
2233     consumed_arrays.push_back(rnn_state.back_edge_source_array());
2234   }
2235   // Now add operator outputs so that all arrays that are consumed,
2236   // are produced.
2237   for (const std::string& consumed_array : consumed_arrays) {
2238     // Test if consumed_array is already the output of some op.
2239     // This has occurred in a model where separate nodes had names of the form
2240     // foo:$i with the same base name foo.
2241     if (GetOpWithOutput(*model, consumed_array)) {
2242       continue;
2243     }
2244     // Split the consumed array name into the form name:output_index.
2245     const std::vector<std::string>& split = absl::StrSplit(consumed_array, ':');
2246     // If not of the form name:output_index, then this is not an additional
2247     // output of a node with multiple outputs, so nothing to do here.
2248     if (split.size() != 2) {
2249       continue;
2250     }
2251     int output_index = 0;
2252     if (!absl::SimpleAtoi(split[1], &output_index)) {
2253       continue;
2254     }
2255     // Each op is initially recorded as producing at least the array that
2256     // has its name. We use that to identify the producer node.
2257     auto* producer_op = GetOpWithOutput(*model, split[0]);
2258     if (!producer_op) {
2259       continue;
2260     }
2261     // Add extra outputs to that producer node, all the way to the
2262     // output_index.
2263     while (producer_op->outputs.size() <= static_cast<size_t>(output_index)) {
2264       using toco::port::StringF;
2265       producer_op->outputs.push_back(
2266           StringF("%s:%d", split[0], producer_op->outputs.size()));
2267     }
2268   }
2269 }
2270 
2271 bool InlineAllFunctions(GraphDef* graphdef) {
2272   if (graphdef->library().function().empty()) {
2273     VLOG(kLogLevelModelUnchanged) << "No functions to inline.";
2274     return false;
2275   }
2276 
2277   // Override "_noinline" attribute on all functions
2278   GraphDef graphdef_copy(*graphdef);
2279   for (auto& function :
2280        (*graphdef_copy.mutable_library()->mutable_function())) {
2281     auto* attributes = function.mutable_attr();
2282     if (attributes->count(tensorflow::kNoInlineAttr) != 0) {
2283       (*attributes)[tensorflow::kNoInlineAttr].set_b(false);
2284     }
2285   }
2286 
2287   // Construct minimum resources needed to use ExpandInlineFunctions().
2288   tensorflow::SessionOptions options;
2289   auto* device_count = options.config.mutable_device_count();
2290   device_count->insert({"CPU", 1});
2291   std::vector<std::unique_ptr<tensorflow::Device>> devices;
2292   TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
2293       options, "/job:localhost/replica:0/task:0", &devices));
2294 
2295   tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
2296                                             graphdef_copy.library());
2297   tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
2298   tensorflow::ProcessFunctionLibraryRuntime pflr(
2299       &device_mgr, tensorflow::Env::Default(), &options.config,
2300       TF_GRAPH_DEF_VERSION, &fld,
2301       options.config.graph_options().optimizer_options(), nullptr);
2302   tensorflow::FunctionLibraryRuntime* flr;
2303   flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");
2304 
2305   tensorflow::Graph graph(fld);
2306   tensorflow::ImportGraphDefOptions gc_opts;
2307   gc_opts.validate_shape = false;
2308   const auto& tf_convert_status = tensorflow::ImportGraphDef(
2309       gc_opts, graphdef_copy, &graph, nullptr, nullptr);
2310   if (!tf_convert_status.ok()) {
2311     LOG(ERROR) << "tensorflow::ImportGraphDef failed with status: "
2312                << tf_convert_status.ToString();
2313     return false;
2314   }
2315 
2316   // Iterate over the graph until there are no more nodes to be inlined.
2317   bool graph_modified = false;
2318   while (tensorflow::ExpandInlineFunctions(flr, &graph)) {
2319     graph_modified = true;
2320   }
2321 
2322   // Output inlined graph
2323   if (graph_modified) {
2324     LOG(INFO) << "Found and inlined TensorFlow functions.";
2325     graph.ToGraphDef(graphdef);
2326   }
2327   return graph_modified;
2328 }
2329 
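// Note on the loop in InlineAllFunctions() above: each call to
// tensorflow::ExpandInlineFunctions() performs a single pass of inlining, and
// calls nested inside a just-inlined function body only become visible on the
// next pass, hence the iteration until it reports no further changes.
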
tensorflow::Status ConvertTopKV2Operator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
  auto op = std::make_unique<TopKV2Operator>();
  op->inputs.push_back(node.input(0));
  // K can be encoded as an attr (TopK); convert it to a const input.
  if (HasAttr(node, "k")) {
    std::string k_array = CreateConstArray<ArrayDataType::kInt32>(
        model, node.name() + "k", {static_cast<int32>(GetIntAttr(node, "k"))});
    op->inputs.push_back(k_array);
  } else {
    TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
    op->inputs.push_back(node.input(1));
  }
  // The op has two outputs.
  op->outputs.push_back(node.name());
  op->outputs.push_back(node.name() + ":1");
  model->operators.emplace_back(op.release());
  return ::tensorflow::OkStatus();
}

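// For ConvertTopKV2Operator() above: the older TopK op carries k as a node
// attribute, while TopKV2 takes k as a second input tensor. Synthesizing a
// const array for the attr case normalizes both forms to the two-input V2
// layout, with outputs node.name() (values) and node.name() + ":1" (indices).
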
tensorflow::Status ConvertDynamicPartitionOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  auto op = std::make_unique<DynamicPartitionOperator>();
  CHECK(HasAttr(node, "num_partitions"));
  op->num_partitions = GetIntAttr(node, "num_partitions");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  CHECK_GT(op->num_partitions, 1);
  op->outputs.push_back(node.name());  // Implicit :0.
  for (int i = 1; i < op->num_partitions; ++i) {
    op->outputs.push_back(node.name() + ":" + std::to_string(i));
  }
  model->operators.emplace_back(op.release());
  return ::tensorflow::OkStatus();
}

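// For ConvertDynamicPartitionOperator() above: DynamicPartition takes
// (data, partitions) and produces one output per partition, so
// num_partitions == 3 yields the output arrays {name, name:1, name:2}.
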
tensorflow::Status ConvertDynamicStitchOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  // The parallel and non-parallel variants differ only in whether the loop
  // may run in parallel; there are no behavioral differences.
  CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
  auto op = std::make_unique<DynamicStitchOperator>();
  CHECK(HasAttr(node, "N"));
  op->num_partitions = GetIntAttr(node, "N");
  // Expect all ID partitions + all value partitions.
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2));
  for (int i = 0; i < op->num_partitions * 2; ++i) {
    op->inputs.push_back(node.input(i));
  }
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op.release());
  return ::tensorflow::OkStatus();
}

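// For ConvertDynamicStitchOperator() above: with N == 2 the node's inputs
// arrive as (indices0, indices1, data0, data1), i.e. all N index tensors
// followed by all N data tensors, which the copy loop preserves in order.
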
tensorflow::Status ConvertSparseToDenseOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "SparseToDense");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));

  auto* op = new SparseToDenseOperator;
  for (const std::string& input : node.input()) {
    op->inputs.push_back(input);
  }
  op->outputs.push_back(node.name());

  op->validate_indices = HasAttr(node, "validate_indices")
                             ? GetBoolAttr(node, "validate_indices")
                             : true;
  model->operators.emplace_back(op);
  return ::tensorflow::OkStatus();
}

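// For ConvertSparseToDenseOperator() above: the four TensorFlow inputs are
// (sparse_indices, output_shape, sparse_values, default_value), copied
// through in graph order; validate_indices defaults to true, matching the
// TensorFlow op definition.
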
tensorflow::Status ConvertOneHotOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "OneHot");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));

  const auto dtype = GetDataTypeAttr(node, "T");
  // TODO(b/111744875): Support DT_UINT8 and quantization.
  CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
        dtype == DT_BOOL);

  auto op = std::make_unique<OneHotOperator>();
  op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
  for (const std::string& input : node.input()) {
    op->inputs.push_back(input);
  }
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op.release());
  return ::tensorflow::OkStatus();
}

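// For ConvertOneHotOperator() above: the four inputs are (indices, depth,
// on_value, off_value), and axis == -1 selects the innermost dimension, the
// TensorFlow default when the attribute is absent.
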
tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));

  auto* op = new CTCBeamSearchDecoderOperator;
  for (const std::string& input : node.input()) {
    op->inputs.push_back(input);
  }

  op->beam_width =
      HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1;
  op->top_paths =
      HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1;
  op->merge_repeated = HasAttr(node, "merge_repeated")
                           ? GetBoolAttr(node, "merge_repeated")
                           : true;

  // There are top_paths + 1 outputs.
  op->outputs.push_back(node.name());  // Implicit :0.
  for (int i = 0; i < op->top_paths; ++i) {
    op->outputs.push_back(node.name() + ":" + std::to_string(i + 1));
  }
  model->operators.emplace_back(op);
  return ::tensorflow::OkStatus();
}

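// For ConvertCTCBeamSearchDecoderOperator() above: with top_paths == 2 the
// recorded outputs are {name, name:1, name:2}, i.e. top_paths + 1 arrays.
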
// This isn't a TensorFlow builtin op. Currently this node can only be
// generated with the TFLite OpHint API.
tensorflow::Status ConvertUnidirectionalSequenceLstm(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");

  const auto& indices = GetListAttr(node, "_tflite_input_indices");

  auto* op = new UnidirectionalSequenceLstmOperator();

  // The number of inputs needs to match the TFLite UnidirectionalSequenceLstm
  // implementation.
  const int kInputsSize = 20;

  op->inputs.resize(kInputsSize);

  if (indices.i_size() != node.input().size()) {
    // New version: the optional inputs are filled with constant nodes.
    int count = 0;
    for (int idx = 0; idx < kInputsSize; ++idx) {
      if (count < indices.i_size() && indices.i(count) == idx) {
        // Specified input.
        op->inputs[idx] = node.input(idx);
        count++;
      } else {
        // Optional input.
        std::string optional_name = node.name() + "_" + std::to_string(idx);
        model->CreateOptionalArray(optional_name);
        op->inputs[idx] = optional_name;
      }
    }
  } else {  // Legacy version.
    std::vector<bool> done(kInputsSize);
    int idx = 0;
    for (const std::string& input : node.input()) {
      int real_index = indices.i(idx);
      op->inputs[real_index] = input;
      done[real_index] = true;
      idx++;
    }

    for (size_t idx = 0; idx < done.size(); idx++) {
      if (!done[idx]) {
        std::string optional_name = node.name() + "_" + std::to_string(idx);
        model->CreateOptionalArray(optional_name);
        op->inputs[idx] = optional_name;
      }
    }
  }

  // There are three outputs; only the last one is required.
  op->outputs.push_back(node.name() + ":2");
  model->operators.emplace_back(op);

  return ::tensorflow::OkStatus();
}

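// For ConvertUnidirectionalSequenceLstm() above: "_tflite_input_indices"
// records which of the 20 TFLite LSTM input slots the node actually wires up;
// every unlisted slot is backed by a named optional (empty) array so that the
// resulting operator always has exactly kInputsSize inputs.
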
tensorflow::Status ConvertLeakyReluOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  CHECK_EQ(node.op(), "LeakyRelu");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
  const auto& input_name = node.input(0);
  auto* op = new LeakyReluOperator;
  op->inputs.push_back(input_name);
  op->outputs.push_back(node.name());
  op->alpha = GetFloatAttr(node, "alpha");
  model->operators.emplace_back(op);
  return ::tensorflow::OkStatus();
}

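// For ConvertLeakyReluOperator() above: alpha is the slope applied to
// negative inputs, i.e. f(x) = x for x >= 0 and f(x) = alpha * x otherwise.
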
tensorflow::Status ConvertUnidirectionalSequenceRnn(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model) {
  DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn");

  const auto& indices = GetListAttr(node, "_tflite_input_indices");
  if (indices.i_size() != node.input().size()) {
    return tensorflow::errors::InvalidArgument("Input size does not match.");
  }

  auto* op = new UnidirectionalSequenceRnnOperator();
  for (const std::string& input : node.input()) {
    op->inputs.push_back(input);
  }
  // Only the last output is used.
  op->outputs.push_back(node.name() + ":1");
  model->operators.emplace_back(op);

  return ::tensorflow::OkStatus();
}

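// For ConvertUnidirectionalSequenceRnn() above: unlike the LSTM converter,
// this one does not synthesize optional arrays; "_tflite_input_indices" must
// cover every input, otherwise the conversion fails with InvalidArgument.
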
}  // namespace

namespace internal {

using ConverterType = tensorflow::Status (*)(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;

ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
  return std::unordered_map<std::string, ConverterType>({
      // We need to let TOCO convert Placeholder information into
      // array data, so that the data types are correct.
      {"LegacyFedInput", ConvertPlaceholderOperator},
      {"Placeholder", ConvertPlaceholderOperator},
      {"Const", ConditionallyConvertConstOperator},
  });
}

ConverterMapType GetTensorFlowNodeConverterMap() {
  return std::unordered_map<std::string, ConverterType>({
      {"Abs", ConvertSimpleOperator<AbsOperator, kAnyNumInputs, 1>},
      {"Add", ConvertSimpleOperator<AddOperator, 2, 1>},
      {"AddV2", ConvertSimpleOperator<AddOperator, 2, 1>},
      {"AddN", ConvertSimpleOperator<AddNOperator, kAnyNumInputs, 1>},
      {"All", ConvertSimpleOperator<TensorFlowAllOperator, kAnyNumInputs, 1>},
      {"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
      {"ArgMax", ConvertArgMaxOperator},
      {"ArgMin", ConvertArgMinOperator},
      {"Assert",
       ConvertSimpleOperator<TensorFlowAssertOperator, kAnyNumInputs, 1>},
      {"AvgPool", ConvertAvgPoolOperator},
      {"BatchMatMul", ConvertBatchMatMulOperator},
      {"BatchMatMulV2", ConvertBatchMatMulOperator},
      {"BatchNormWithGlobalNormalization",
       ConvertBatchNormWithGlobalNormalizationOperator},
      {"BatchToSpaceND", ConvertBatchToSpaceNDOperator},
      {"BiasAdd", ConvertBiasAddOperator},
      {"Cast", ConvertCastOperator},
      {"Ceil", ConvertCeilOperator},
      {"CheckNumerics", ConvertIdentityOperator},
      {"Concat", ConvertConcatOperator},
      {"ConcatV2", ConvertConcatOperator},
      {"Const", ConvertConstOperator},
      {"Conv2D", ConvertConvOperator},
      {"Conv2DBackpropInput", ConvertTransposeConvOperator},
      {"Cos", ConvertSimpleOperator<CosOperator, 1, 1>},
      {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
      {"DepthToSpace", ConvertDepthToSpaceOperator},
      {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
      {"Div", ConvertSimpleOperator<DivOperator, 2, 1>},
      {"DynamicPartition", ConvertDynamicPartitionOperator},
      {"DynamicStitch", ConvertDynamicStitchOperator},
      {"Elu", ConvertSimpleOperator<EluOperator, 1, 1>},
      {"EnsureShape", ConvertIdentityOperator},
      {"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2, 1>},
      {"Exp", ConvertSimpleOperator<ExpOperator, 1, 1>},
      {"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2, 1>},
      {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
      {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
      {"Fill", ConvertSimpleOperator<FillOperator, 2, 1>},
      {"Floor", ConvertFloorOperator},
      {"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2, 1>},
      {"FloorMod", ConvertSimpleOperator<FloorModOperator, 2, 1>},
      {"FusedBatchNorm", ConvertFusedBatchNormOperator},
      {"FusedBatchNormV3", ConvertFusedBatchNormOperator},
      {"Gather", ConvertGatherOperator},
      {"GatherV2", ConvertGatherOperator},
      {"GatherNd", ConvertGatherNdOperator},
      {"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2, 1>},
      {"GreaterEqual",
       ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2, 1>},
      {"Identity", ConvertIdentityOperator},
      {"IdentityN", ConvertIdentityNOperator},
      {"LRN", ConvertLRNOperator},
      {"LeakyRelu", ConvertLeakyReluOperator},
      {"LegacyFedInput", ConvertPlaceholderOperator},
      {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2, 1>},
      {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2, 1>},
      {"Log", ConvertSimpleOperator<LogOperator, 1, 1>},
      {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2, 1>},
      {"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2, 1>},
      {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
      {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
      {"MatMul", ConvertMatMulOperator},
      {"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
      {"MatrixDiagV2", ConvertSimpleOperator<MatrixDiagV2Operator, 5, 1>},
      // `MatrixDiagV3` has an `align` attribute. However, Toco only converts
      // `MatrixDiagV3` to `MatrixDiag` with default `k, num_rows, num_cols,
      // padding_value` inputs. In this case, `align` can be ignored.
      {"MatrixDiagV3", ConvertSimpleOperator<MatrixDiagV3Operator, 5, 1>},
      {"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
      {"MatrixSetDiagV2", ConvertSimpleOperator<MatrixSetDiagV2Operator, 3, 1>},
      // `MatrixSetDiagV3` has an `align` attribute. However, Toco only converts
      // `MatrixSetDiagV3` to `MatrixSetDiag` with default `k` inputs. In this
      // case, `align` can be ignored.
      {"MatrixSetDiagV3", ConvertSimpleOperator<MatrixSetDiagV3Operator, 3, 1>},
      {"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
      {"MaxPool", ConvertMaxPoolOperator},
      {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
      {"Mean", ConvertReduceOperator<MeanOperator>},
      {"Merge",
       ConvertSimpleOperator<TensorFlowMergeOperator, kAnyNumInputs, 1>},
      {"Min", ConvertReduceOperator<TensorFlowMinOperator>},
      {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2, 1>},
      {"Mul", ConvertSimpleOperator<MulOperator, 2, 1>},
      {"Neg", ConvertSimpleOperator<NegOperator, 1, 1>},
      {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
      {"NoOp", ConvertNoOpOperator},
      {"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2, 1>},
      {"OneHot", ConvertOneHotOperator},
      {"Pack", ConvertPackOperator},
      {"Pad", ConvertSimpleOperator<PadOperator, 2, 1>},
      {"PadV2", ConvertSimpleOperator<PadV2Operator, 3, 1>},
      {"ParallelDynamicStitch", ConvertDynamicStitchOperator},
      {"Placeholder", ConvertPlaceholderOperator},
      {"PlaceholderWithDefault", ConvertIdentityOperator},
      {"Pow", ConvertSimpleOperator<PowOperator, 2, 1>},
      {"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
      {"RandomUniform", ConvertRandomUniform},
      {"Range", ConvertRangeOperator},
      {"Rank", ConvertSimpleOperator<TensorFlowRankOperator, 1, 1>},
      {"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
      {"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
      {"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},
      {"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
      {"ResizeBilinear", ConvertResizeBilinearOperator},
      {"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
      {"ReverseSequence", ConvertReverseSequenceOperator},
      {"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
      {"Round", ConvertRoundOperator},
      {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
      {"ScatterNd", ConvertSimpleOperator<ScatterNdOperator, 3, 1>},
      {"SegmentSum", ConvertSimpleOperator<SegmentSumOperator, 2, 1>},
      {"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
      {"SelectV2", ConvertSimpleOperator<SelectOperator, 3, 1>},
      {"Shape", ConvertShapeOperator},
      {"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1, 1>},
      {"Sin", ConvertSimpleOperator<SinOperator, 1, 1>},
      {"Slice", ConvertSimpleOperator<SliceOperator, 3, 1>},
      {"Softmax", ConvertSoftmaxOperator},
      {"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
      {"SpaceToDepth", ConvertSpaceToDepthOperator},
      {"SparseToDense", ConvertSparseToDenseOperator},
      {"Split", ConvertSplitOperator},
      {"SplitV", ConvertSplitVOperator},
      {"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1, 1>},
      {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
      {"SquaredDifference",
       ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
      {"Snapshot", ConvertIdentityOperator},
      {"Squeeze", ConvertSqueezeOperator},
      {"StopGradient", ConvertIdentityOperator},
      {"StridedSlice", ConvertStridedSliceOperator},
      {"Sub", ConvertSimpleOperator<SubOperator, 2, 1>},
      {"Sum", ConvertReduceOperator<TensorFlowSumOperator>},
      {"Svdf", ConvertSvdfOperator},
      {"Switch", ConvertSwitchOperator},
      {"Tanh", ConvertSimpleOperator<TanhOperator, 1, 1>},
      {"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2, 1>},
      {"TopK", ConvertTopKV2Operator},
      {"TopKV2", ConvertTopKV2Operator},
      {"Transpose", ConvertSimpleOperator<TransposeOperator, 2, 1>},
      {"Unpack", ConvertUnpackOperator},
      {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1, 1>},
      {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
      {"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
      {"MirrorPad", ConvertMirrorPadOperator},
      {"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
      {"Where", ConvertSimpleOperator<WhereOperator, 1, 1>},
  });
}

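// A note on the table above, assuming the ConvertSimpleOperator template
// defined earlier in this file: its integer arguments give the expected input
// and output counts for the op (e.g. <AddOperator, 2, 1> expects two inputs
// and one output), while kAnyNumInputs lifts the input-count restriction.
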
tensorflow::Status ImportTensorFlowNode(
    const tensorflow::NodeDef& node,
    const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags,
    Model* model, const ConverterMapType& converter_map) {
  auto converter = converter_map.find(node.op());
  if (converter == converter_map.end()) {
    return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
                                      model);
  } else {
    return converter->second(node, tf_import_flags, model_flags, model);
  }
}
}  // namespace internal

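// For ImportTensorFlowNode() above: ops missing from the converter map are
// imported via ConvertUnsupportedOperator rather than rejected, which is what
// lets them be exported later as TFLite Flex (select TF) ops when that path
// is enabled.
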
std::unique_ptr<Model> ImportTensorFlowGraphDef(
    const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
    const GraphDef& tf_graph) {
  LogDumpGraphDef(kLogLevelModelChanged, "AT IMPORT", tf_graph);

  GraphDef inlined_graph(tf_graph);
  if (InlineAllFunctions(&inlined_graph)) {
    LogDumpGraphDef(kLogLevelModelChanged, "AFTER INLINING", inlined_graph);
  }

  // Check the input and output specifications.
  for (const auto& specified_input_array : model_flags.input_arrays()) {
    CHECK(!absl::EndsWith(specified_input_array.name(), ":0"))
        << "Unsupported explicit zero output index: "
        << specified_input_array.name();
  }
  for (const std::string& specified_output_array :
       model_flags.output_arrays()) {
    CHECK(!absl::EndsWith(specified_output_array, ":0"))
        << "Unsupported explicit zero output index: " << specified_output_array;
  }

  Model* model = new Model;
  internal::ConverterMapType converter_map;

  // This is used for the TFLite "Full Flex Mode" conversion. All the ops are
  // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
  // converted to TFLite Flex ops.
  if (!tf_import_flags.import_all_ops_as_unsupported) {
    converter_map = internal::GetTensorFlowNodeConverterMap();
  } else {
    converter_map = internal::GetTensorFlowNodeConverterMapForFlex();
  }

  for (auto node : inlined_graph.node()) {
    StripZeroOutputIndexFromInputs(&node);
    auto status = internal::ImportTensorFlowNode(
        node, tf_import_flags, model_flags, model, converter_map);
    CHECK(status.ok()) << status.error_message();
  }

  ResolveModelFlags(model_flags, model);

  StripCaretFromArrayNames(model);
  AddExtraOutputs(model);
  FixNoMissingArray(model);
  FixNoOrphanedArray(model);
  FixOperatorOrdering(model);
  CheckInvariants(*model);

  // If RNN state arrays are constant, make them transient.
  for (const auto& rnn_state : model->flags.rnn_states()) {
    model->GetArray(rnn_state.state_array()).buffer = nullptr;
  }

  return std::unique_ptr<Model>(model);
}

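// For ImportTensorFlowGraphDef() above: nodes are converted one at a time,
// and only afterwards do the global fix-up passes (AddExtraOutputs through
// CheckInvariants) repair cross-node state, such as the multi-output arrays
// that no single node conversion can see on its own.
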
std::unique_ptr<Model> ImportTensorFlowGraphDef(
    const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
    const std::string& input_file_contents) {
  std::unique_ptr<GraphDef> tf_graph(new GraphDef);
  CHECK(ParseFromStringEitherTextOrBinary(input_file_contents, tf_graph.get()));

  std::unique_ptr<GraphDef> pruned_graph =
      MaybeReplaceCompositeSubgraph(*tf_graph);
  if (pruned_graph) {
    tf_graph = std::move(pruned_graph);
  }
  return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph);
}

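// For the overload above: ParseFromStringEitherTextOrBinary accepts a
// GraphDef serialized either as a text proto or in binary wire format, so the
// same entry point handles both .pbtxt and .pb file contents.
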
}  // namespace toco