xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/export_tensorflow.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <memory>
17 #include <string>
18 #include <unordered_map>
19 #include <utility>
20 #include <vector>
21 
22 #include "google/protobuf/map.h"
23 #include "google/protobuf/text_format.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/string_view.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/tensor.pb.h"
30 #include "tensorflow/core/framework/tensor_shape.pb.h"
31 #include "tensorflow/core/framework/types.pb.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/lite/toco/model.h"
34 #include "tensorflow/lite/toco/model_flags.pb.h"
35 #include "tensorflow/lite/toco/runtime/types.h"
36 #include "tensorflow/lite/toco/tensorflow_util.h"
37 #include "tensorflow/lite/toco/tooling_util.h"
38 
39 using tensorflow::DT_BOOL;
40 using tensorflow::DT_COMPLEX64;
41 using tensorflow::DT_FLOAT;
42 using tensorflow::DT_INT16;
43 using tensorflow::DT_INT32;
44 using tensorflow::DT_INT64;
45 using tensorflow::DT_UINT32;
46 using tensorflow::DT_UINT8;
47 using tensorflow::GraphDef;
48 using tensorflow::TensorProto;
49 
50 namespace toco {
51 namespace {
52 
GetTensorFlowDataType(ArrayDataType data_type,const std::string & error_location)53 tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type,
54                                            const std::string& error_location) {
55   switch (data_type) {
56     case ArrayDataType::kBool:
57       return tensorflow::DT_BOOL;
58     case ArrayDataType::kFloat:
59       return tensorflow::DT_FLOAT;
60     case ArrayDataType::kUint8:
61       return tensorflow::DT_UINT8;
62     case ArrayDataType::kInt16:
63       return tensorflow::DT_INT16;
64     case ArrayDataType::kUint16:
65       return tensorflow::DT_UINT16;
66     case ArrayDataType::kInt32:
67       return tensorflow::DT_INT32;
68     case ArrayDataType::kUint32:
69       return tensorflow::DT_UINT32;
70     case ArrayDataType::kInt64:
71       return tensorflow::DT_INT64;
72     case ArrayDataType::kString:
73       return tensorflow::DT_STRING;
74     case ArrayDataType::kComplex64:
75       return tensorflow::DT_COMPLEX64;
76     default:
77     case ArrayDataType::kNone:
78       LOG(FATAL) << "Unsupported data type '" << ArrayDataTypeName(data_type)
79                  << "' in " << error_location;
80       return tensorflow::DT_INVALID;
81   }
82 }
83 
GetTensorFlowDataTypeForOp(ArrayDataType data_type,const std::string & op_name)84 tensorflow::DataType GetTensorFlowDataTypeForOp(ArrayDataType data_type,
85                                                 const std::string& op_name) {
86   return GetTensorFlowDataType(data_type, "op '" + op_name + "'");
87 }
88 
GetTensorFlowDataType(const Model & model,const std::string & array_name)89 tensorflow::DataType GetTensorFlowDataType(const Model& model,
90                                            const std::string& array_name) {
91   return GetTensorFlowDataType(model.GetArray(array_name).data_type,
92                                "array '" + array_name + "'");
93 }
94 
// TensorFlow sometimes rejects what it calls "legacy scalars": 1-D shapes
// whose only dimension has size 1 (see OpKernel::IsLegacyScalar and
// OpKernel::allow_legacy_scalars). We therefore normally avoid producing
// them, by emitting a 0-D shape wherever a 1-D shape would have size 1.
//
// Bias vectors feeding BiasAdd nodes are the one exception: TensorFlow
// requires them to be 1-D, so when the depth is 1 we must keep the 1-D shape
// and knowingly produce a legacy scalar.
enum class LegacyScalarPolicy {
  // Emit a size-1, 1-D shape as a 0-D (scalar) shape instead.
  kAvoidLegacyScalars,
  // Always emit the array's dimensions as they are.
  kDoCreateLegacyScalars
};
108 
ExportFloatArray(const Shape & input_shape,const float * input_data,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)109 void ExportFloatArray(const Shape& input_shape, const float* input_data,
110                       TensorProto* output_tensor,
111                       LegacyScalarPolicy legacy_scalar_policy) {
112   output_tensor->set_dtype(DT_FLOAT);
113   const int input_flat_size = RequiredBufferSizeForShape(input_shape);
114   auto* shape = output_tensor->mutable_tensor_shape();
115 
116   const int kDims = input_shape.dimensions_count();
117   if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars ||
118       kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) {
119     for (int i = 0; i < kDims; ++i) {
120       shape->add_dim()->set_size(input_shape.dims(i));
121     }
122   }
123   output_tensor->set_tensor_content(
124       std::string(reinterpret_cast<const char*>(input_data),
125                   sizeof(*input_data) * input_flat_size));
126 }
127 
ExportFloatArray(AxesOrder input_axes_order,const Shape & input_shape,const float * input_data,AxesOrder output_axes_order,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)128 void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape,
129                       const float* input_data, AxesOrder output_axes_order,
130                       TensorProto* output_tensor,
131                       LegacyScalarPolicy legacy_scalar_policy) {
132   CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order));
133   output_tensor->set_dtype(DT_FLOAT);
134   CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order));
135   const int input_flat_size = RequiredBufferSizeForShape(input_shape);
136 
137   Shape shuffled_shape;
138   ShuffleDims(input_shape, input_axes_order, output_axes_order,
139               &shuffled_shape);
140   std::vector<float> shuffled_data(input_flat_size);
141   ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape,
142                input_data, shuffled_data.data());
143 
144   ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor,
145                    legacy_scalar_policy);
146 }
147 
HasAlreadyExportedConst(const std::string & name,const GraphDef & tensorflow_graph)148 bool HasAlreadyExportedConst(const std::string& name,
149                              const GraphDef& tensorflow_graph) {
150   for (const auto& node : tensorflow_graph.node()) {
151     if (node.op() == "Const" && node.name() == name) {
152       return true;
153     }
154   }
155   return false;
156 }
157 
ConvertFloatTensorConst(const std::string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph,LegacyScalarPolicy legacy_scalar_policy)158 void ConvertFloatTensorConst(const std::string& name, const Shape& input_shape,
159                              const float* input_data,
160                              AxesOrder input_axes_order,
161                              AxesOrder output_axes_order,
162                              GraphDef* tensorflow_graph,
163                              LegacyScalarPolicy legacy_scalar_policy) {
164   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
165     return;
166   }
167   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
168   const_op->set_op("Const");
169   const_op->set_name(name);
170   (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
171   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
172   ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
173                    tensor, legacy_scalar_policy);
174 }
175 
ConvertFloatTensorConst(const std::string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)176 void ConvertFloatTensorConst(const std::string& name, const Shape& input_shape,
177                              const float* input_data,
178                              AxesOrder input_axes_order,
179                              AxesOrder output_axes_order,
180                              GraphDef* tensorflow_graph) {
181   ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
182                           output_axes_order, tensorflow_graph,
183                           LegacyScalarPolicy::kAvoidLegacyScalars);
184 }
185 
ConvertFloatTensorConst(const Model & model,const std::string & name,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)186 void ConvertFloatTensorConst(const Model& model, const std::string& name,
187                              AxesOrder input_axes_order,
188                              AxesOrder output_axes_order,
189                              GraphDef* tensorflow_graph) {
190   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
191     return;
192   }
193   CHECK(model.HasArray(name));
194   const auto& input_array = model.GetArray(name);
195   const auto& input_shape = input_array.shape();
196   CHECK(input_array.buffer);
197   CHECK(input_array.buffer->type == ArrayDataType::kFloat);
198   const float* input_data =
199       input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
200   ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
201                           output_axes_order, tensorflow_graph);
202 }
203 
ConvertFloatTensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)204 void ConvertFloatTensorConst(const Model& model, const std::string& name,
205                              GraphDef* tensorflow_graph) {
206   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
207     return;
208   }
209   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
210   const_op->set_op("Const");
211   const_op->set_name(name);
212   (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
213   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
214   CHECK(model.HasArray(name));
215   const auto& input_array = model.GetArray(name);
216   const auto& input_shape = input_array.shape();
217   CHECK(input_array.buffer);
218   CHECK(input_array.buffer->type == ArrayDataType::kFloat);
219   const float* input_data =
220       input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
221   ExportFloatArray(input_shape, input_data, tensor,
222                    LegacyScalarPolicy::kAvoidLegacyScalars);
223 }
224 
ConvertBoolTensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)225 void ConvertBoolTensorConst(const Model& model, const std::string& name,
226                             GraphDef* tensorflow_graph) {
227   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
228     return;
229   }
230   CHECK(model.HasArray(name));
231   const auto& array = model.GetArray(name);
232   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
233   const_op->set_op("Const");
234   const_op->set_name(name);
235   (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL);
236   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
237   tensor->set_dtype(DT_BOOL);
238   const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
239   for (auto index : data) {
240     tensor->add_bool_val(index);
241   }
242   const auto& array_shape = array.shape();
243   auto* shape = tensor->mutable_tensor_shape();
244   for (int i = 0; i < array_shape.dimensions_count(); i++) {
245     shape->add_dim()->set_size(array_shape.dims(i));
246   }
247 }
248 
ConvertIntTensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)249 void ConvertIntTensorConst(const Model& model, const std::string& name,
250                            GraphDef* tensorflow_graph) {
251   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
252     return;
253   }
254   CHECK(model.HasArray(name));
255   const auto& array = model.GetArray(name);
256   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
257   const_op->set_op("Const");
258   const_op->set_name(name);
259   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
260   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
261   tensor->set_dtype(DT_INT32);
262   const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
263   for (auto index : data) {
264     tensor->add_int_val(index);
265   }
266   const auto& array_shape = array.shape();
267   auto* shape = tensor->mutable_tensor_shape();
268   for (int i = 0; i < array_shape.dimensions_count(); i++) {
269     shape->add_dim()->set_size(array_shape.dims(i));
270   }
271 }
272 
CreateIntTensorConst(const std::string & name,const std::vector<int32> & data,const std::vector<int32> & shape,GraphDef * tensorflow_graph)273 void CreateIntTensorConst(const std::string& name,
274                           const std::vector<int32>& data,
275                           const std::vector<int32>& shape,
276                           GraphDef* tensorflow_graph) {
277   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
278     return;
279   }
280   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
281   const_op->set_op("Const");
282   const_op->set_name(name);
283   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
284   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
285   tensor->set_dtype(DT_INT32);
286   for (auto index : data) {
287     tensor->add_int_val(index);
288   }
289   auto* tensor_shape = tensor->mutable_tensor_shape();
290   int num_elements = 1;
291   for (int size : shape) {
292     tensor_shape->add_dim()->set_size(size);
293     num_elements *= size;
294   }
295   CHECK_EQ(num_elements, data.size());
296 }
297 
ConvertComplex64TensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)298 void ConvertComplex64TensorConst(const Model& model, const std::string& name,
299                                  GraphDef* tensorflow_graph) {
300   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
301     return;
302   }
303   CHECK(model.HasArray(name));
304   const auto& array = model.GetArray(name);
305   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
306   const_op->set_op("Const");
307   const_op->set_name(name);
308   (*const_op->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
309   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
310   tensor->set_dtype(DT_COMPLEX64);
311   const auto& data = array.GetBuffer<ArrayDataType::kComplex64>().data;
312   for (auto index : data) {
313     tensor->add_scomplex_val(std::real(index));
314     tensor->add_scomplex_val(std::imag(index));
315   }
316   const auto& array_shape = array.shape();
317   auto* shape = tensor->mutable_tensor_shape();
318   for (int i = 0; i < array_shape.dimensions_count(); i++) {
319     shape->add_dim()->set_size(array_shape.dims(i));
320   }
321 }
322 
CreateMatrixShapeTensorConst(const std::string & name,int rows,int cols,GraphDef * tensorflow_graph)323 void CreateMatrixShapeTensorConst(const std::string& name, int rows, int cols,
324                                   GraphDef* tensorflow_graph) {
325   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
326     return;
327   }
328   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
329   const_op->set_op("Const");
330   const_op->set_name(name);
331   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
332   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
333   tensor->set_dtype(DT_INT32);
334   const int32 data[2] = {cols, rows};
335   tensor->set_tensor_content(
336       std::string(reinterpret_cast<const char*>(data), sizeof(data)));
337   auto* shape = tensor->mutable_tensor_shape();
338   shape->add_dim()->set_size(2);
339 }
340 
CreateDummyConcatDimTensorConst(const std::string & name,int dim,GraphDef * tensorflow_graph)341 void CreateDummyConcatDimTensorConst(const std::string& name, int dim,
342                                      GraphDef* tensorflow_graph) {
343   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
344     return;
345   }
346   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
347   const_op->set_op("Const");
348   const_op->set_name(name);
349   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
350   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
351   tensor->set_dtype(DT_INT32);
352   tensor->add_int_val(dim);
353 }
354 
CreateReshapeShapeTensorConst(const std::string & name,const std::vector<int32> & shape,GraphDef * tensorflow_graph)355 void CreateReshapeShapeTensorConst(const std::string& name,
356                                    const std::vector<int32>& shape,
357                                    GraphDef* tensorflow_graph) {
358   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
359     return;
360   }
361   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
362   const_op->set_op("Const");
363   const_op->set_name(name);
364   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
365   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
366   tensor->set_dtype(DT_INT32);
367   for (auto s : shape) {
368     tensor->add_int_val(s);
369   }
370   // TensorFlow sometimes forbids what it calls "legacy scalars",
371   // which are shapes of size 1 where the unique shape size is 1.
372   // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
373   if (shape.size() > 1) {
374     auto* tensor_shape = tensor->mutable_tensor_shape();
375     tensor_shape->add_dim()->set_size(shape.size());
376   }
377 }
378 
WalkUpToConstantArray(const Model & model,const std::string & name)379 std::string WalkUpToConstantArray(const Model& model, const std::string& name) {
380   const Array& original_array = model.GetArray(name);
381   if (original_array.buffer) {
382     return name;
383   }
384   const auto* op = GetOpWithOutput(model, name);
385   CHECK(op);
386   CHECK(op->type == OperatorType::kFakeQuant);
387   const std::string& input_of_fakequant_name = op->inputs[0];
388   const Array& input_of_fakequant = model.GetArray(input_of_fakequant_name);
389   CHECK(input_of_fakequant.buffer);
390   return input_of_fakequant_name;
391 }
392 
ConvertConvOperator(const Model & model,const ConvOperator & src_op,GraphDef * tensorflow_graph)393 void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
394                          GraphDef* tensorflow_graph) {
395   const bool has_bias = src_op.inputs.size() >= 3;
396   std::string conv_output = src_op.outputs[0];
397   if (has_bias) {
398     conv_output += "/conv";
399   }
400 
401   tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
402   conv2d_op->set_op("Conv2D");
403   conv2d_op->set_name(conv_output);
404   *conv2d_op->add_input() = src_op.inputs[0];
405   *conv2d_op->add_input() = src_op.inputs[1];
406   (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
407   const std::string& weights_array_name =
408       WalkUpToConstantArray(model, src_op.inputs[1]);
409   const auto& weights_array = model.GetArray(weights_array_name);
410   CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
411   ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
412                           AxesOrder::kHWIO, tensorflow_graph);
413   auto& strides = (*conv2d_op->mutable_attr())["strides"];
414   strides.mutable_list()->add_i(1);
415   strides.mutable_list()->add_i(src_op.stride_height);
416   strides.mutable_list()->add_i(src_op.stride_width);
417   strides.mutable_list()->add_i(1);
418   if ((src_op.dilation_width_factor != 1) ||
419       (src_op.dilation_height_factor != 1)) {
420     auto& dilations = (*conv2d_op->mutable_attr())["dilations"];
421     dilations.mutable_list()->add_i(1);
422     dilations.mutable_list()->add_i(src_op.dilation_height_factor);
423     dilations.mutable_list()->add_i(src_op.dilation_width_factor);
424     dilations.mutable_list()->add_i(1);
425   }
426   std::string padding;
427   if (src_op.padding.type == PaddingType::kSame) {
428     padding = "SAME";
429   } else if (src_op.padding.type == PaddingType::kValid) {
430     padding = "VALID";
431   } else {
432     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
433   }
434   (*conv2d_op->mutable_attr())["padding"].set_s(padding);
435 
436   if (has_bias) {
437     tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
438     biasadd_op->set_op("BiasAdd");
439     biasadd_op->set_name(src_op.outputs[0]);
440     biasadd_op->add_input(conv_output);
441     biasadd_op->add_input(src_op.inputs[2]);
442     (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
443     CHECK(model.HasArray(src_op.inputs[2]));
444     const std::string& bias_array_name =
445         WalkUpToConstantArray(model, src_op.inputs[2]);
446     const auto& bias_array = model.GetArray(bias_array_name);
447     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
448     Shape bias_shape_1d = bias_array.shape();
449     UnextendShape(&bias_shape_1d, 1);
450     CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
451     const float* bias_data =
452         bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
453     ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data,
454                             AxesOrder::kOneAxis, AxesOrder::kOneAxis,
455                             tensorflow_graph,
456                             LegacyScalarPolicy::kDoCreateLegacyScalars);
457   }
458 }
459 
ConvertDepthwiseConvOperator(const Model & model,const DepthwiseConvOperator & src_op,GraphDef * tensorflow_graph)460 void ConvertDepthwiseConvOperator(const Model& model,
461                                   const DepthwiseConvOperator& src_op,
462                                   GraphDef* tensorflow_graph) {
463   const bool has_bias = src_op.inputs.size() >= 3;
464   std::string conv_output = src_op.outputs[0];
465   if (has_bias) {
466     conv_output += "/conv";
467   }
468 
469   tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node();
470   dc2d_op->set_op("DepthwiseConv2dNative");
471   dc2d_op->set_name(conv_output);
472   *dc2d_op->add_input() = src_op.inputs[0];
473   *dc2d_op->add_input() = src_op.inputs[1];
474   (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
475 
476   // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth.
477   // We need to convert that to H x W x InputDepth x Multiplier.
478   // That's only a matter of constructing a Dims object; the actual
479   // array layout is the same.
480   CHECK(model.HasArray(src_op.inputs[1]));
481   const std::string& src_weights_name =
482       WalkUpToConstantArray(model, src_op.inputs[1]);
483   const auto& src_weights_array = model.GetArray(src_weights_name);
484   const auto& src_weights_shape = src_weights_array.shape();
485   CHECK_EQ(src_weights_shape.dimensions_count(), 4);
486   const Shape dst_weights_shape =
487       Shape({src_weights_shape.dims(1), src_weights_shape.dims(2),
488              src_weights_shape.dims(3) / src_op.depth_multiplier,
489              src_op.depth_multiplier});
490   CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0);
491   CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) ==
492         src_weights_shape.dims(3));
493   CHECK_EQ(src_weights_shape.dims(0), 1);
494 
495   CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat);
496   const float* src_weights_data =
497       src_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
498   ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data,
499                           AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph);
500 
501   auto& strides = (*dc2d_op->mutable_attr())["strides"];
502   strides.mutable_list()->add_i(1);
503   strides.mutable_list()->add_i(src_op.stride_height);
504   strides.mutable_list()->add_i(src_op.stride_width);
505   strides.mutable_list()->add_i(1);
506   // TODO(b/116063589): To return a working TF GraphDef, we should be returning
507   // the correct SpaceToBatchNd and BatchToSpaceND operation before and after
508   // the conv since TF doesn't support dilations.
509   if ((src_op.dilation_width_factor != 1) ||
510       (src_op.dilation_height_factor != 1)) {
511     auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
512     dilations.mutable_list()->add_i(1);
513     dilations.mutable_list()->add_i(src_op.dilation_height_factor);
514     dilations.mutable_list()->add_i(src_op.dilation_width_factor);
515     dilations.mutable_list()->add_i(1);
516   }
517   std::string padding;
518   if (src_op.padding.type == PaddingType::kSame) {
519     padding = "SAME";
520   } else if (src_op.padding.type == PaddingType::kValid) {
521     padding = "VALID";
522   } else {
523     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
524   }
525   (*dc2d_op->mutable_attr())["padding"].set_s(padding);
526 
527   if (has_bias) {
528     tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
529     biasadd_op->set_op("BiasAdd");
530     biasadd_op->set_name(src_op.outputs[0]);
531     biasadd_op->add_input(conv_output);
532     biasadd_op->add_input(src_op.inputs[2]);
533     (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
534     CHECK(model.HasArray(src_op.inputs[2]));
535     const std::string& bias_name =
536         WalkUpToConstantArray(model, src_op.inputs[2]);
537     const auto& bias_array = model.GetArray(bias_name);
538     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
539     Shape bias_shape_1d = bias_array.shape();
540     UnextendShape(&bias_shape_1d, 1);
541     CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
542     const float* bias_data =
543         bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
544     ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data,
545                             AxesOrder::kOneAxis, AxesOrder::kOneAxis,
546                             tensorflow_graph,
547                             LegacyScalarPolicy::kDoCreateLegacyScalars);
548   }
549 }
550 
ConvertTransposeConvOperator(const Model & model,const TransposeConvOperator & src_op,GraphDef * tensorflow_graph)551 void ConvertTransposeConvOperator(const Model& model,
552                                   const TransposeConvOperator& src_op,
553                                   GraphDef* tensorflow_graph) {
554   tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
555   conv2d_op->set_op("Conv2DBackpropInput");
556   conv2d_op->set_name(src_op.outputs[0]);
557   *conv2d_op->add_input() = src_op.inputs[0];
558   *conv2d_op->add_input() = src_op.inputs[1];
559   *conv2d_op->add_input() = src_op.inputs[2];
560   (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
561   const std::string& weights_array_name = WalkUpToConstantArray(
562       model, src_op.inputs[TransposeConvOperator::WEIGHTS]);
563   const auto& weights_array = model.GetArray(weights_array_name);
564   CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
565   ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
566                           AxesOrder::kHWOI, tensorflow_graph);
567   auto& strides = (*conv2d_op->mutable_attr())["strides"];
568   strides.mutable_list()->add_i(1);
569   strides.mutable_list()->add_i(src_op.stride_height);
570   strides.mutable_list()->add_i(src_op.stride_width);
571   strides.mutable_list()->add_i(1);
572   std::string padding;
573   if (src_op.padding.type == PaddingType::kSame) {
574     padding = "SAME";
575   } else if (src_op.padding.type == PaddingType::kValid) {
576     padding = "VALID";
577   } else {
578     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
579   }
580   (*conv2d_op->mutable_attr())["padding"].set_s(padding);
581 }
582 
ConvertDepthToSpaceOperator(const Model & model,const DepthToSpaceOperator & src_op,GraphDef * tensorflow_graph)583 void ConvertDepthToSpaceOperator(const Model& model,
584                                  const DepthToSpaceOperator& src_op,
585                                  GraphDef* tensorflow_graph) {
586   tensorflow::NodeDef* op = tensorflow_graph->add_node();
587   op->set_op("DepthToSpace");
588   op->set_name(src_op.outputs[0]);
589   *op->add_input() = src_op.inputs[0];
590   (*op->mutable_attr())["T"].set_type(DT_FLOAT);
591   (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
592 }
593 
ConvertSpaceToDepthOperator(const Model & model,const SpaceToDepthOperator & src_op,GraphDef * tensorflow_graph)594 void ConvertSpaceToDepthOperator(const Model& model,
595                                  const SpaceToDepthOperator& src_op,
596                                  GraphDef* tensorflow_graph) {
597   tensorflow::NodeDef* op = tensorflow_graph->add_node();
598   op->set_op("SpaceToDepth");
599   op->set_name(src_op.outputs[0]);
600   *op->add_input() = src_op.inputs[0];
601   (*op->mutable_attr())["T"].set_type(DT_FLOAT);
602   (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
603 }
604 
ConvertFullyConnectedOperator(const Model & model,const FullyConnectedOperator & src_op,GraphDef * tensorflow_graph)605 void ConvertFullyConnectedOperator(const Model& model,
606                                    const FullyConnectedOperator& src_op,
607                                    GraphDef* tensorflow_graph) {
608   // Reshape input activations to have the shape expected by the MatMul.
609   const std::string reshape_output =
610       AvailableArrayName(model, src_op.outputs[0] + "/reshape");
611   const std::string reshape_shape =
612       AvailableArrayName(model, reshape_output + "/shape");
613   const auto& fc_weights_array = model.GetArray(src_op.inputs[1]);
614   const auto& fc_weights_shape = fc_weights_array.shape();
615   CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
616   CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
617                                tensorflow_graph);
618   tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
619   reshape_op->set_op("Reshape");
620   reshape_op->set_name(reshape_output);
621   reshape_op->add_input(src_op.inputs[0]);
622   reshape_op->add_input(reshape_shape);
623   (*reshape_op->mutable_attr())["T"].set_type(
624       GetTensorFlowDataType(model, src_op.inputs[0]));
625 
626   const bool has_bias = src_op.inputs.size() >= 3;
627   std::string matmul_output = src_op.outputs[0];
628   if (has_bias) {
629     matmul_output += "/matmul";
630   }
631 
632   // Transpose the RHS input from column-major to row-major to match TensorFlow
633   // expectations. This is the inverse of the transpose we do during
634   // ResolveTensorFlowMatMul.
635   const std::string transpose_output =
636       AvailableArrayName(model, matmul_output + "/transpose_weights");
637   const std::string transpose_perm =
638       AvailableArrayName(model, transpose_output + "/perm");
639   CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph);
640   tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
641   transpose_op->set_op("Transpose");
642   transpose_op->set_name(transpose_output);
643   *transpose_op->add_input() = src_op.inputs[1];
644   *transpose_op->add_input() = transpose_perm;
645   (*transpose_op->mutable_attr())["T"].set_type(
646       GetTensorFlowDataType(model, src_op.inputs[1]));
647   (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32);
648 
649   tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
650   matmul_op->set_op("MatMul");
651   matmul_op->set_name(matmul_output);
652   *matmul_op->add_input() = reshape_output;
653   *matmul_op->add_input() = transpose_op->name();
654   (*matmul_op->mutable_attr())["T"].set_type(
655       GetTensorFlowDataType(model, src_op.inputs[0]));
656   (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
657   (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
658   CHECK(model.HasArray(src_op.inputs[1]));
659 
660   // Add the bias, if it exists.
661   if (has_bias) {
662     tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
663     biasadd_op->set_op("BiasAdd");
664     biasadd_op->set_name(src_op.outputs[0]);
665     biasadd_op->add_input(matmul_output);
666     biasadd_op->add_input(src_op.inputs[2]);
667     (*biasadd_op->mutable_attr())["T"].set_type(
668         GetTensorFlowDataType(model, src_op.inputs[0]));
669     CHECK(model.HasArray(src_op.inputs[2]));
670     const auto& bias_array = model.GetArray(src_op.inputs[2]);
671     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
672     Shape bias_shape_1d = bias_array.shape();
673     UnextendShape(&bias_shape_1d, 1);
674     CHECK(bias_array.buffer);
675     CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
676     const float* bias_data =
677         bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
678     ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]),
679                             bias_shape_1d, bias_data, AxesOrder::kOneAxis,
680                             AxesOrder::kOneAxis, tensorflow_graph,
681                             LegacyScalarPolicy::kDoCreateLegacyScalars);
682   }
683 }
684 
ConvertAddOperator(const Model & model,const AddOperator & src_op,GraphDef * tensorflow_graph)685 void ConvertAddOperator(const Model& model, const AddOperator& src_op,
686                         GraphDef* tensorflow_graph) {
687   tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
688   add_op->set_op("Add");
689   add_op->set_name(src_op.outputs[0]);
690   CHECK_EQ(src_op.inputs.size(), 2);
691   *add_op->add_input() = src_op.inputs[0];
692   *add_op->add_input() = src_op.inputs[1];
693   (*add_op->mutable_attr())["T"].set_type(
694       GetTensorFlowDataType(model, src_op.outputs[0]));
695 }
696 
ConvertAddNOperator(const Model & model,const AddNOperator & src_op,GraphDef * tensorflow_graph)697 void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
698                          GraphDef* tensorflow_graph) {
699   tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
700   add_op->set_op("AddN");
701   add_op->set_name(src_op.outputs[0]);
702   for (const auto& input : src_op.inputs) {
703     *add_op->add_input() = input;
704   }
705   (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
706   (*add_op->mutable_attr())["T"].set_type(
707       GetTensorFlowDataType(model, src_op.outputs[0]));
708 }
709 
ConvertMulOperator(const Model & model,const MulOperator & src_op,GraphDef * tensorflow_graph)710 void ConvertMulOperator(const Model& model, const MulOperator& src_op,
711                         GraphDef* tensorflow_graph) {
712   tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
713   mul_op->set_op("Mul");
714   mul_op->set_name(src_op.outputs[0]);
715   CHECK_EQ(src_op.inputs.size(), 2);
716   *mul_op->add_input() = src_op.inputs[0];
717   *mul_op->add_input() = src_op.inputs[1];
718   (*mul_op->mutable_attr())["T"].set_type(
719       GetTensorFlowDataType(model, src_op.outputs[0]));
720 }
721 
ConvertDivOperator(const Model & model,const DivOperator & src_op,GraphDef * tensorflow_graph)722 void ConvertDivOperator(const Model& model, const DivOperator& src_op,
723                         GraphDef* tensorflow_graph) {
724   tensorflow::NodeDef* div_op = tensorflow_graph->add_node();
725   div_op->set_op("Div");
726   div_op->set_name(src_op.outputs[0]);
727   CHECK_EQ(src_op.inputs.size(), 2);
728   *div_op->add_input() = src_op.inputs[0];
729   *div_op->add_input() = src_op.inputs[1];
730   (*div_op->mutable_attr())["T"].set_type(
731       GetTensorFlowDataType(model, src_op.outputs[0]));
732 }
733 
ConvertReluOperator(const Model & model,const ReluOperator & src_op,GraphDef * tensorflow_graph)734 void ConvertReluOperator(const Model& model, const ReluOperator& src_op,
735                          GraphDef* tensorflow_graph) {
736   tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
737   relu_op->set_op("Relu");
738   relu_op->set_name(src_op.outputs[0]);
739   *relu_op->add_input() = src_op.inputs[0];
740   (*relu_op->mutable_attr())["T"].set_type(
741       GetTensorFlowDataType(model, src_op.outputs[0]));
742 }
743 
ConvertRelu1Operator(const Relu1Operator & src_op,GraphDef * tensorflow_graph)744 void ConvertRelu1Operator(const Relu1Operator& src_op,
745                           GraphDef* tensorflow_graph) {
746   const std::string max_bounds = src_op.outputs[0] + "/max_bounds";
747   const std::string min_bounds = src_op.outputs[0] + "/min_bounds";
748   const std::string max_output = src_op.outputs[0] + "/max_output";
749 
750   tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node();
751   max_bounds_const_op->set_op("Const");
752   max_bounds_const_op->set_name(max_bounds);
753   (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
754   auto* max_bounds_const_op_tensor =
755       (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor();
756   max_bounds_const_op_tensor->set_dtype(DT_FLOAT);
757   max_bounds_const_op_tensor->add_float_val(-1.0f);
758 
759   tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node();
760   min_bounds_const_op->set_op("Const");
761   min_bounds_const_op->set_name(min_bounds);
762   (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
763   auto* min_bounds_const_op_tensor =
764       (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor();
765   min_bounds_const_op_tensor->set_dtype(DT_FLOAT);
766   min_bounds_const_op_tensor->add_float_val(1.0f);
767 
768   tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
769   max_op->set_op("Maximum");
770   max_op->set_name(max_output);
771   *max_op->add_input() = src_op.inputs[0];
772   *max_op->add_input() = max_bounds;
773   (*max_op->mutable_attr())["T"].set_type(DT_FLOAT);
774 
775   tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
776   min_op->set_op("Minimum");
777   min_op->set_name(src_op.outputs[0]);
778   *min_op->add_input() = max_output;
779   *min_op->add_input() = min_bounds;
780   (*min_op->mutable_attr())["T"].set_type(DT_FLOAT);
781 }
782 
ConvertRelu6Operator(const Relu6Operator & src_op,GraphDef * tensorflow_graph)783 void ConvertRelu6Operator(const Relu6Operator& src_op,
784                           GraphDef* tensorflow_graph) {
785   tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
786   relu_op->set_op("Relu6");
787   relu_op->set_name(src_op.outputs[0]);
788   *relu_op->add_input() = src_op.inputs[0];
789   (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
790 }
791 
ConvertLogOperator(const LogOperator & src_op,GraphDef * tensorflow_graph)792 void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) {
793   tensorflow::NodeDef* op = tensorflow_graph->add_node();
794   op->set_op("Log");
795   op->set_name(src_op.outputs[0]);
796   CHECK_EQ(src_op.inputs.size(), 1);
797   *op->add_input() = src_op.inputs[0];
798   (*op->mutable_attr())["T"].set_type(DT_FLOAT);
799 }
800 
ConvertLogisticOperator(const LogisticOperator & src_op,GraphDef * tensorflow_graph)801 void ConvertLogisticOperator(const LogisticOperator& src_op,
802                              GraphDef* tensorflow_graph) {
803   tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
804   relu_op->set_op("Sigmoid");
805   relu_op->set_name(src_op.outputs[0]);
806   *relu_op->add_input() = src_op.inputs[0];
807   (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
808 }
809 
ConvertTanhOperator(const TanhOperator & src_op,GraphDef * tensorflow_graph)810 void ConvertTanhOperator(const TanhOperator& src_op,
811                          GraphDef* tensorflow_graph) {
812   tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node();
813   tanh_op->set_op("Tanh");
814   tanh_op->set_name(src_op.outputs[0]);
815   *tanh_op->add_input() = src_op.inputs[0];
816   (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT);
817 }
818 
ConvertSoftmaxOperator(const Model & model,const SoftmaxOperator & src_op,GraphDef * tensorflow_graph)819 void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
820                             GraphDef* tensorflow_graph) {
821   std::string softmax_input;
822   Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
823   if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
824     softmax_input = src_op.inputs[0];
825   } else {
826     // Insert a reshape operator that reduces the dimensions down to the 2 that
827     // are required for TensorFlow Logits.
828     const std::string reshape_output =
829         src_op.outputs[0] + "/softmax_insert_reshape";
830     const std::string softmax_size = src_op.outputs[0] + "/softmax_insert_size";
831     softmax_input = reshape_output;
832 
833     tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
834     reshape_op->set_op("Reshape");
835     reshape_op->set_name(reshape_output);
836     *reshape_op->add_input() = src_op.inputs[0];
837     *reshape_op->add_input() = softmax_size;
838     (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
839 
840     const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
841     int32_t flattened_size = 1;
842     for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
843       flattened_size *= input_shape.dims(i);
844     }
845     const std::vector<int32> shape_data = {
846         flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
847     CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
848   }
849 
850   tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node();
851   softmax_op->set_op("Softmax");
852   softmax_op->set_name(src_op.outputs[0]);
853   *softmax_op->add_input() = softmax_input;
854   // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter
855   CHECK_EQ(src_op.beta, 1.f);
856   (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
857 }
858 
ConvertLogSoftmaxOperator(const Model & model,const LogSoftmaxOperator & src_op,GraphDef * tensorflow_graph)859 void ConvertLogSoftmaxOperator(const Model& model,
860                                const LogSoftmaxOperator& src_op,
861                                GraphDef* tensorflow_graph) {
862   std::string softmax_input;
863   Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
864   if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
865     softmax_input = src_op.inputs[0];
866   } else {
867     // Insert a reshape operator that reduces the dimensions down to the 2 that
868     // are required for TensorFlow Logits.
869     const std::string reshape_output =
870         src_op.outputs[0] + "/log_softmax_insert_reshape";
871     const std::string softmax_size =
872         src_op.outputs[0] + "/log_softmax_insert_size";
873     softmax_input = reshape_output;
874 
875     tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
876     reshape_op->set_op("Reshape");
877     reshape_op->set_name(reshape_output);
878     *reshape_op->add_input() = src_op.inputs[0];
879     *reshape_op->add_input() = softmax_size;
880     (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
881 
882     const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
883     int32_t flattened_size = 1;
884     for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
885       flattened_size *= input_shape.dims(i);
886     }
887     const std::vector<int32> shape_data = {
888         flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
889     CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
890   }
891 
892   tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node();
893   log_softmax_op->set_op("LogSoftmax");
894   log_softmax_op->set_name(src_op.outputs[0]);
895   *log_softmax_op->add_input() = softmax_input;
896   (*log_softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
897 }
898 
ConvertL2NormalizationOperator(const L2NormalizationOperator & src_op,GraphDef * tensorflow_graph)899 void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
900                                     GraphDef* tensorflow_graph) {
901   const std::string square_output = src_op.outputs[0] + "/square";
902   const std::string sum_reduction_indices =
903       src_op.outputs[0] + "/reduction_indices";
904   const std::string sum_output = src_op.outputs[0] + "/sum";
905   const std::string rsqrt_output = src_op.outputs[0] + "/rsqrt";
906   const std::string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled";
907 
908   tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node();
909   sum_reduction_indices_op->set_op("Const");
910   sum_reduction_indices_op->set_name(sum_reduction_indices);
911   (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32);
912   auto* sum_reduction_indices_tensor =
913       (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor();
914   sum_reduction_indices_tensor->set_dtype(DT_INT32);
915   auto* sum_reduction_indices_shape =
916       sum_reduction_indices_tensor->mutable_tensor_shape();
917   auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim();
918   sum_reduction_indices_dim->set_size(2);
919   sum_reduction_indices_tensor->add_int_val(0);
920   sum_reduction_indices_tensor->add_int_val(1);
921 
922   tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
923   square_op->set_op("Square");
924   square_op->set_name(square_output);
925   *square_op->add_input() = src_op.inputs[0];
926   (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
927 
928   tensorflow::NodeDef* sum_op = tensorflow_graph->add_node();
929   sum_op->set_op("Sum");
930   sum_op->set_name(sum_output);
931   *sum_op->add_input() = square_output;
932   *sum_op->add_input() = sum_reduction_indices;
933   (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT);
934 
935   tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
936   rsqrt_op->set_op("Rsqrt");
937   rsqrt_op->set_name(rsqrt_output);
938   *rsqrt_op->add_input() = sum_output;
939   (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
940 
941   tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
942   mul_op->set_op("Mul");
943   mul_op->set_name(src_op.outputs[0]);
944   *mul_op->add_input() = src_op.inputs[0];
945   *mul_op->add_input() = rsqrt_output;
946   (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT);
947 }
948 
ConvertLocalResponseNormalizationOperator(const LocalResponseNormalizationOperator & src_op,GraphDef * tensorflow_graph)949 void ConvertLocalResponseNormalizationOperator(
950     const LocalResponseNormalizationOperator& src_op,
951     GraphDef* tensorflow_graph) {
952   tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node();
953   lrn_op->set_op("LRN");
954   lrn_op->set_name(src_op.outputs[0]);
955   *lrn_op->add_input() = src_op.inputs[0];
956   (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range);
957   (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias);
958   (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha);
959   (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta);
960 }
961 
ConvertFakeQuantOperator(const FakeQuantOperator & src_op,GraphDef * tensorflow_graph)962 void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
963                               GraphDef* tensorflow_graph) {
964   tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node();
965   fakequant_op->set_op("FakeQuantWithMinMaxArgs");
966   fakequant_op->set_name(src_op.outputs[0]);
967   CHECK_EQ(src_op.inputs.size(), 1);
968   *fakequant_op->add_input() = src_op.inputs[0];
969   CHECK(src_op.minmax);
970   (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min);
971   (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max);
972   if (src_op.num_bits) {
973     (*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits);
974   }
975   if (src_op.narrow_range) {
976     (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range);
977   }
978 }
979 
ConvertMaxPoolOperator(const MaxPoolOperator & src_op,GraphDef * tensorflow_graph)980 void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
981                             GraphDef* tensorflow_graph) {
982   tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node();
983   maxpool_op->set_op("MaxPool");
984   maxpool_op->set_name(src_op.outputs[0]);
985   *maxpool_op->add_input() = src_op.inputs[0];
986   auto& strides = (*maxpool_op->mutable_attr())["strides"];
987   strides.mutable_list()->add_i(1);
988   strides.mutable_list()->add_i(src_op.stride_height);
989   strides.mutable_list()->add_i(src_op.stride_width);
990   strides.mutable_list()->add_i(1);
991   std::string padding;
992   if (src_op.padding.type == PaddingType::kSame) {
993     padding = "SAME";
994   } else if (src_op.padding.type == PaddingType::kValid) {
995     padding = "VALID";
996   } else {
997     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
998   }
999   (*maxpool_op->mutable_attr())["padding"].set_s(padding);
1000   (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1001   auto& ksize = (*maxpool_op->mutable_attr())["ksize"];
1002   ksize.mutable_list()->add_i(1);
1003   ksize.mutable_list()->add_i(src_op.kheight);
1004   ksize.mutable_list()->add_i(src_op.kwidth);
1005   ksize.mutable_list()->add_i(1);
1006 }
1007 
ConvertAveragePoolOperator(const AveragePoolOperator & src_op,GraphDef * tensorflow_graph)1008 void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
1009                                 GraphDef* tensorflow_graph) {
1010   tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
1011   avgpool_op->set_op("AvgPool");
1012   avgpool_op->set_name(src_op.outputs[0]);
1013   *avgpool_op->add_input() = src_op.inputs[0];
1014   auto& strides = (*avgpool_op->mutable_attr())["strides"];
1015   strides.mutable_list()->add_i(1);
1016   strides.mutable_list()->add_i(src_op.stride_height);
1017   strides.mutable_list()->add_i(src_op.stride_width);
1018   strides.mutable_list()->add_i(1);
1019   std::string padding;
1020   if (src_op.padding.type == PaddingType::kSame) {
1021     padding = "SAME";
1022   } else if (src_op.padding.type == PaddingType::kValid) {
1023     padding = "VALID";
1024   } else {
1025     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1026   }
1027   (*avgpool_op->mutable_attr())["padding"].set_s(padding);
1028   (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1029   auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
1030   ksize.mutable_list()->add_i(1);
1031   ksize.mutable_list()->add_i(src_op.kheight);
1032   ksize.mutable_list()->add_i(src_op.kwidth);
1033   ksize.mutable_list()->add_i(1);
1034 }
1035 
ConvertConcatenationOperator(const Model & model,const ConcatenationOperator & src_op,GraphDef * tensorflow_graph)1036 void ConvertConcatenationOperator(const Model& model,
1037                                   const ConcatenationOperator& src_op,
1038                                   GraphDef* tensorflow_graph) {
1039   tensorflow::NodeDef* dc_op = tensorflow_graph->add_node();
1040   dc_op->set_op("ConcatV2");
1041   dc_op->set_name(src_op.outputs[0]);
1042   const std::string dummy_axis = src_op.outputs[0] + "/axis";
1043   CreateDummyConcatDimTensorConst(dummy_axis, src_op.axis, tensorflow_graph);
1044   for (const auto& input : src_op.inputs) {
1045     *dc_op->add_input() = input;
1046   }
1047   *dc_op->add_input() = dummy_axis;
1048   (*dc_op->mutable_attr())["T"].set_type(
1049       GetTensorFlowDataType(model, src_op.inputs[0]));
1050   (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32);
1051   (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size());
1052 }
1053 
ConvertTensorFlowReshapeOperator(const Model & model,const TensorFlowReshapeOperator & src_op,GraphDef * tensorflow_graph)1054 void ConvertTensorFlowReshapeOperator(const Model& model,
1055                                       const TensorFlowReshapeOperator& src_op,
1056                                       GraphDef* tensorflow_graph) {
1057   tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
1058   reshape_op->set_op("Reshape");
1059   reshape_op->set_name(src_op.outputs[0]);
1060   CHECK_EQ(src_op.inputs.size(), 2);
1061   *reshape_op->add_input() = src_op.inputs[0];
1062   *reshape_op->add_input() = src_op.inputs[1];
1063   (*reshape_op->mutable_attr())["T"].set_type(
1064       GetTensorFlowDataType(model, src_op.outputs[0]));
1065   const auto& shape_array = model.GetArray(src_op.inputs[1]);
1066   QCHECK(shape_array.data_type == ArrayDataType::kInt32)
1067       << "Only int32 shape is supported.";
1068   QCHECK(shape_array.buffer != nullptr)
1069       << "Shape inferred at runtime is not supported.";
1070   const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1071   CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
1072 }
1073 
ConvertL2PoolOperator(const L2PoolOperator & src_op,GraphDef * tensorflow_graph)1074 void ConvertL2PoolOperator(const L2PoolOperator& src_op,
1075                            GraphDef* tensorflow_graph) {
1076   const std::string square_output = src_op.outputs[0] + "/square";
1077   const std::string avgpool_output = src_op.outputs[0] + "/avgpool";
1078 
1079   tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
1080   square_op->set_op("Square");
1081   square_op->set_name(square_output);
1082   *square_op->add_input() = src_op.inputs[0];
1083   (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
1084 
1085   std::string padding;
1086   if (src_op.padding.type == PaddingType::kSame) {
1087     padding = "SAME";
1088   } else if (src_op.padding.type == PaddingType::kValid) {
1089     padding = "VALID";
1090   } else {
1091     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1092   }
1093 
1094   tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
1095   avgpool_op->set_op("AvgPool");
1096   avgpool_op->set_name(avgpool_output);
1097   *avgpool_op->add_input() = square_output;
1098   auto& strides = (*avgpool_op->mutable_attr())["strides"];
1099   strides.mutable_list()->add_i(1);
1100   strides.mutable_list()->add_i(src_op.stride_height);
1101   strides.mutable_list()->add_i(src_op.stride_width);
1102   strides.mutable_list()->add_i(1);
1103 
1104   (*avgpool_op->mutable_attr())["padding"].set_s(padding);
1105   (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1106   auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
1107   ksize.mutable_list()->add_i(1);
1108   ksize.mutable_list()->add_i(src_op.kheight);
1109   ksize.mutable_list()->add_i(src_op.kwidth);
1110   ksize.mutable_list()->add_i(1);
1111 
1112   tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
1113   sqrt_op->set_op("Sqrt");
1114   sqrt_op->set_name(src_op.outputs[0]);
1115   *sqrt_op->add_input() = avgpool_output;
1116   (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
1117 }
1118 
ConvertSquareOperator(const TensorFlowSquareOperator & src_op,GraphDef * tensorflow_graph)1119 void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
1120                            GraphDef* tensorflow_graph) {
1121   tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
1122   square_op->set_op("Square");
1123   square_op->set_name(src_op.outputs[0]);
1124   CHECK_EQ(src_op.inputs.size(), 1);
1125   *square_op->add_input() = src_op.inputs[0];
1126   (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
1127 }
1128 
ConvertSqrtOperator(const TensorFlowSqrtOperator & src_op,GraphDef * tensorflow_graph)1129 void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
1130                          GraphDef* tensorflow_graph) {
1131   tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
1132   sqrt_op->set_op("Sqrt");
1133   sqrt_op->set_name(src_op.outputs[0]);
1134   CHECK_EQ(src_op.inputs.size(), 1);
1135   *sqrt_op->add_input() = src_op.inputs[0];
1136   (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
1137 }
1138 
ConvertRsqrtOperator(const Model & model,const TensorFlowRsqrtOperator & src_op,GraphDef * tensorflow_graph)1139 void ConvertRsqrtOperator(const Model& model,
1140                           const TensorFlowRsqrtOperator& src_op,
1141                           GraphDef* tensorflow_graph) {
1142   tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
1143   rsqrt_op->set_op("Rsqrt");
1144   rsqrt_op->set_name(src_op.outputs[0]);
1145   CHECK_EQ(src_op.inputs.size(), 1);
1146   *rsqrt_op->add_input() = src_op.inputs[0];
1147   const tensorflow::DataType data_type =
1148       GetTensorFlowDataType(model, src_op.inputs[0]);
1149   (*rsqrt_op->mutable_attr())["T"].set_type(data_type);
1150 }
1151 
ConvertSplitOperator(const Model & model,const TensorFlowSplitOperator & src_op,GraphDef * tensorflow_graph)1152 void ConvertSplitOperator(const Model& model,
1153                           const TensorFlowSplitOperator& src_op,
1154                           GraphDef* tensorflow_graph) {
1155   tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
1156   split_op->set_op("Split");
1157   split_op->set_name(src_op.outputs[0]);
1158   for (const auto& input : src_op.inputs) {
1159     *split_op->add_input() = input;
1160   }
1161   (*split_op->mutable_attr())["T"].set_type(
1162       GetTensorFlowDataType(model, src_op.outputs[0]));
1163   (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1164   const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
1165   CHECK(split_dim_array.buffer);
1166   CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
1167   const auto& split_dim_data =
1168       split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
1169   CHECK_EQ(split_dim_data.size(), 1);
1170   const int split_dim = split_dim_data[0];
1171   CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
1172                                   tensorflow_graph);
1173 }
1174 
ConvertSplitVOperator(const Model & model,const TensorFlowSplitVOperator & src_op,GraphDef * tensorflow_graph)1175 void ConvertSplitVOperator(const Model& model,
1176                            const TensorFlowSplitVOperator& src_op,
1177                            GraphDef* tensorflow_graph) {
1178   tensorflow::NodeDef* split_v_op = tensorflow_graph->add_node();
1179   split_v_op->set_op("SplitV");
1180   split_v_op->set_name(src_op.outputs[0]);
1181   for (const auto& input : src_op.inputs) {
1182     *split_v_op->add_input() = input;
1183   }
1184   (*split_v_op->mutable_attr())["T"].set_type(
1185       GetTensorFlowDataType(model, src_op.outputs[0]));
1186   (*split_v_op->mutable_attr())["Tlen"].set_type(
1187       GetTensorFlowDataType(model, src_op.inputs[1]));
1188   (*split_v_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1189   ConvertIntTensorConst(model, src_op.inputs[1], tensorflow_graph);
1190 }
1191 
ConvertCastOperator(const Model & model,const CastOperator & src_op,GraphDef * tensorflow_graph)1192 void ConvertCastOperator(const Model& model, const CastOperator& src_op,
1193                          GraphDef* tensorflow_graph) {
1194   tensorflow::NodeDef* cast_op = tensorflow_graph->add_node();
1195   cast_op->set_op("Cast");
1196   cast_op->set_name(src_op.outputs[0]);
1197   CHECK_EQ(src_op.inputs.size(), 1);
1198   *cast_op->add_input() = src_op.inputs[0];
1199 
1200   (*cast_op->mutable_attr())["DstT"].set_type(
1201       GetTensorFlowDataType(model, src_op.outputs[0]));
1202   (*cast_op->mutable_attr())["SrcT"].set_type(
1203       GetTensorFlowDataType(model, src_op.inputs[0]));
1204 }
1205 
ConvertFloorOperator(const Model & model,const FloorOperator & src_op,GraphDef * tensorflow_graph)1206 void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
1207                           GraphDef* tensorflow_graph) {
1208   tensorflow::NodeDef* floor_op = tensorflow_graph->add_node();
1209   floor_op->set_op("Floor");
1210   floor_op->set_name(src_op.outputs[0]);
1211   CHECK_EQ(src_op.inputs.size(), 1);
1212   *floor_op->add_input() = src_op.inputs[0];
1213   (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT);
1214 }
1215 
ConvertCeilOperator(const Model & model,const CeilOperator & src_op,GraphDef * tensorflow_graph)1216 void ConvertCeilOperator(const Model& model, const CeilOperator& src_op,
1217                          GraphDef* tensorflow_graph) {
1218   tensorflow::NodeDef* ceil_op = tensorflow_graph->add_node();
1219   ceil_op->set_op("Ceil");
1220   ceil_op->set_name(src_op.outputs[0]);
1221   CHECK_EQ(src_op.inputs.size(), 1);
1222   *ceil_op->add_input() = src_op.inputs[0];
1223   (*ceil_op->mutable_attr())["T"].set_type(DT_FLOAT);
1224 }
1225 
ConvertRoundOperator(const Model & model,const RoundOperator & src_op,GraphDef * tensorflow_graph)1226 void ConvertRoundOperator(const Model& model, const RoundOperator& src_op,
1227                           GraphDef* tensorflow_graph) {
1228   tensorflow::NodeDef* round_op = tensorflow_graph->add_node();
1229   round_op->set_op("Round");
1230   round_op->set_name(src_op.outputs[0]);
1231   CHECK_EQ(src_op.inputs.size(), 1);
1232   *round_op->add_input() = src_op.inputs[0];
1233   (*round_op->mutable_attr())["T"].set_type(DT_FLOAT);
1234 }
1235 
ConvertGatherOperator(const Model & model,const GatherOperator & src_op,GraphDef * tensorflow_graph)1236 void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
1237                            GraphDef* tensorflow_graph) {
1238   tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
1239   gather_op->set_op("GatherV2");
1240   gather_op->set_name(src_op.outputs[0]);
1241   *gather_op->add_input() = src_op.inputs[0];
1242   *gather_op->add_input() = src_op.inputs[1];
1243 
1244   if (!src_op.axis) {
1245     // Dynamic axis.
1246     CHECK_EQ(src_op.inputs.size(), 3);
1247     *gather_op->add_input() = src_op.inputs[2];
1248   } else {
1249     // Constant axis.
1250     CHECK_EQ(src_op.inputs.size(), 2);
1251     const std::string gather_axis =
1252         AvailableArrayName(model, gather_op->name() + "/axis");
1253     CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {},
1254                          tensorflow_graph);
1255     *gather_op->add_input() = gather_axis;
1256   }
1257 
1258   (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
1259   (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32);
1260   const tensorflow::DataType params_type =
1261       GetTensorFlowDataType(model, src_op.inputs[0]);
1262   (*gather_op->mutable_attr())["Tparams"].set_type(params_type);
1263 }
1264 
ConvertArgMaxOperator(const Model & model,const ArgMaxOperator & src_op,GraphDef * tensorflow_graph)1265 void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
1266                            GraphDef* tensorflow_graph) {
1267   tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node();
1268   argmax_op->set_op("ArgMax");
1269   argmax_op->set_name(src_op.outputs[0]);
1270   CHECK_EQ(src_op.inputs.size(), 2);
1271   *argmax_op->add_input() = src_op.inputs[0];
1272   *argmax_op->add_input() = src_op.inputs[1];
1273   (*argmax_op->mutable_attr())["T"].set_type(
1274       GetTensorFlowDataType(model, src_op.inputs[0]));
1275   (*argmax_op->mutable_attr())["Tidx"].set_type(
1276       GetTensorFlowDataType(model, src_op.inputs[1]));
1277   (*argmax_op->mutable_attr())["output_type"].set_type(
1278       GetTensorFlowDataType(model, src_op.outputs[0]));
1279 }
1280 
ConvertArgMinOperator(const Model & model,const ArgMinOperator & src_op,GraphDef * tensorflow_graph)1281 void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
1282                            GraphDef* tensorflow_graph) {
1283   tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
1284   argmin_op->set_op("ArgMin");
1285   argmin_op->set_name(src_op.outputs[0]);
1286   CHECK_EQ(src_op.inputs.size(), 2);
1287   *argmin_op->add_input() = src_op.inputs[0];
1288   *argmin_op->add_input() = src_op.inputs[1];
1289   (*argmin_op->mutable_attr())["T"].set_type(
1290       GetTensorFlowDataType(model, src_op.inputs[0]));
1291   (*argmin_op->mutable_attr())["Tidx"].set_type(
1292       GetTensorFlowDataType(model, src_op.inputs[1]));
1293   (*argmin_op->mutable_attr())["output_type"].set_type(
1294       GetTensorFlowDataType(model, src_op.outputs[0]));
1295 }
1296 
ConvertTransposeOperator(const Model & model,const TransposeOperator & src_op,GraphDef * tensorflow_graph)1297 void ConvertTransposeOperator(const Model& model,
1298                               const TransposeOperator& src_op,
1299                               GraphDef* tensorflow_graph) {
1300   tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
1301   transpose_op->set_op("Transpose");
1302   transpose_op->set_name(src_op.outputs[0]);
1303   CHECK_EQ(src_op.inputs.size(), 2);
1304   *transpose_op->add_input() = src_op.inputs[0];
1305   *transpose_op->add_input() = src_op.inputs[1];
1306   (*transpose_op->mutable_attr())["T"].set_type(
1307       GetTensorFlowDataType(model, src_op.inputs[0]));
1308   (*transpose_op->mutable_attr())["Tperm"].set_type(
1309       GetTensorFlowDataType(model, src_op.inputs[1]));
1310 }
1311 
ConvertTensorFlowShapeOperator(const Model & model,const TensorFlowShapeOperator & src_op,GraphDef * tensorflow_graph)1312 void ConvertTensorFlowShapeOperator(const Model& model,
1313                                     const TensorFlowShapeOperator& src_op,
1314                                     GraphDef* tensorflow_graph) {
1315   tensorflow::NodeDef* shape_op = tensorflow_graph->add_node();
1316   shape_op->set_op("Shape");
1317   shape_op->set_name(src_op.outputs[0]);
1318   CHECK_EQ(src_op.inputs.size(), 1);
1319   *shape_op->add_input() = src_op.inputs[0];
1320   (*shape_op->mutable_attr())["T"].set_type(
1321       GetTensorFlowDataType(model, src_op.inputs[0]));
1322   (*shape_op->mutable_attr())["out_type"].set_type(
1323       GetTensorFlowDataType(model, src_op.outputs[0]));
1324 }
1325 
ConvertRankOperator(const Model & model,const TensorFlowRankOperator & src_op,GraphDef * tensorflow_graph)1326 void ConvertRankOperator(const Model& model,
1327                          const TensorFlowRankOperator& src_op,
1328                          GraphDef* tensorflow_graph) {
1329   tensorflow::NodeDef* rank_op = tensorflow_graph->add_node();
1330   rank_op->set_op("Rank");
1331   rank_op->set_name(src_op.outputs[0]);
1332   CHECK_EQ(src_op.inputs.size(), 1);
1333   *rank_op->add_input() = src_op.inputs[0];
1334   (*rank_op->mutable_attr())["T"].set_type(
1335       GetTensorFlowDataType(model, src_op.inputs[0]));
1336 }
1337 
ConvertRangeOperator(const Model & model,const RangeOperator & src_op,GraphDef * tensorflow_graph)1338 void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
1339                           GraphDef* tensorflow_graph) {
1340   tensorflow::NodeDef* range_op = tensorflow_graph->add_node();
1341   range_op->set_op("Range");
1342   range_op->set_name(src_op.outputs[0]);
1343   CHECK_EQ(src_op.inputs.size(), 3);
1344   *range_op->add_input() = src_op.inputs[0];
1345   *range_op->add_input() = src_op.inputs[1];
1346   *range_op->add_input() = src_op.inputs[2];
1347   (*range_op->mutable_attr())["Tidx"].set_type(
1348       GetTensorFlowDataTypeForOp(src_op.dtype, /*op_name=*/src_op.outputs[0]));
1349 }
1350 
ConvertPackOperator(const Model & model,const PackOperator & src_op,GraphDef * tensorflow_graph)1351 void ConvertPackOperator(const Model& model, const PackOperator& src_op,
1352                          GraphDef* tensorflow_graph) {
1353   tensorflow::NodeDef* pack_op = tensorflow_graph->add_node();
1354   pack_op->set_op("Pack");
1355   pack_op->set_name(src_op.outputs[0]);
1356   for (const auto& input : src_op.inputs) {
1357     *pack_op->add_input() = input;
1358   }
1359   (*pack_op->mutable_attr())["axis"].set_i(src_op.axis);
1360   (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size());
1361   (*pack_op->mutable_attr())["T"].set_type(
1362       GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
1363 }
1364 
ConvertFillOperator(const Model & model,const FillOperator & src_op,GraphDef * tensorflow_graph)1365 void ConvertFillOperator(const Model& model, const FillOperator& src_op,
1366                          GraphDef* tensorflow_graph) {
1367   tensorflow::NodeDef* fill_op = tensorflow_graph->add_node();
1368   fill_op->set_op("Fill");
1369   fill_op->set_name(src_op.outputs[0]);
1370   CHECK_EQ(src_op.inputs.size(), 2);
1371   *fill_op->add_input() = src_op.inputs[0];
1372   *fill_op->add_input() = src_op.inputs[1];
1373   (*fill_op->mutable_attr())["index_type"].set_type(
1374       GetTensorFlowDataType(model, src_op.inputs[0]));
1375   (*fill_op->mutable_attr())["T"].set_type(
1376       GetTensorFlowDataType(model, src_op.inputs[1]));
1377 }
1378 
ConvertFloorDivOperator(const Model & model,const FloorDivOperator & src_op,GraphDef * tensorflow_graph)1379 void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op,
1380                              GraphDef* tensorflow_graph) {
1381   tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node();
1382   floor_div_op->set_op("FloorDiv");
1383   floor_div_op->set_name(src_op.outputs[0]);
1384   CHECK_EQ(src_op.inputs.size(), 2);
1385   *floor_div_op->add_input() = src_op.inputs[0];
1386   *floor_div_op->add_input() = src_op.inputs[1];
1387   (*floor_div_op->mutable_attr())["T"].set_type(
1388       GetTensorFlowDataType(model, src_op.inputs[0]));
1389 }
1390 
ConvertFloorModOperator(const Model & model,const FloorModOperator & src_op,GraphDef * tensorflow_graph)1391 void ConvertFloorModOperator(const Model& model, const FloorModOperator& src_op,
1392                              GraphDef* tensorflow_graph) {
1393   tensorflow::NodeDef* floor_mod_op = tensorflow_graph->add_node();
1394   floor_mod_op->set_op("FloorMod");
1395   floor_mod_op->set_name(src_op.outputs[0]);
1396   DCHECK_EQ(src_op.inputs.size(), 2);
1397   *floor_mod_op->add_input() = src_op.inputs[0];
1398   *floor_mod_op->add_input() = src_op.inputs[1];
1399   (*floor_mod_op->mutable_attr())["T"].set_type(
1400       GetTensorFlowDataType(model, src_op.inputs[0]));
1401 }
1402 
ConvertExpandDimsOperator(const Model & model,const ExpandDimsOperator & src_op,GraphDef * tensorflow_graph)1403 void ConvertExpandDimsOperator(const Model& model,
1404                                const ExpandDimsOperator& src_op,
1405                                GraphDef* tensorflow_graph) {
1406   tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node();
1407   expand_dims_op->set_op("ExpandDims");
1408   expand_dims_op->set_name(src_op.outputs[0]);
1409   CHECK_EQ(src_op.inputs.size(), 2);
1410   *expand_dims_op->add_input() = src_op.inputs[0];
1411   *expand_dims_op->add_input() = src_op.inputs[1];
1412   (*expand_dims_op->mutable_attr())["T"].set_type(
1413       GetTensorFlowDataType(model, src_op.inputs[0]));
1414   (*expand_dims_op->mutable_attr())["Tdim"].set_type(
1415       GetTensorFlowDataType(model, src_op.inputs[1]));
1416 }
1417 
ConvertResizeBilinearOperator(const Model & model,const ResizeBilinearOperator & src_op,GraphDef * tensorflow_graph)1418 void ConvertResizeBilinearOperator(const Model& model,
1419                                    const ResizeBilinearOperator& src_op,
1420                                    GraphDef* tensorflow_graph) {
1421   tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
1422   resize_op->set_op("ResizeBilinear");
1423   resize_op->set_name(src_op.outputs[0]);
1424   CHECK_EQ(src_op.inputs.size(), 2);
1425   *resize_op->add_input() = src_op.inputs[0];
1426   *resize_op->add_input() = src_op.inputs[1];
1427   (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
1428   (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
1429   (*resize_op->mutable_attr())["half_pixel_centers"].set_b(
1430       src_op.half_pixel_centers);
1431 }
1432 
ConvertResizeNearestNeighborOperator(const Model & model,const ResizeNearestNeighborOperator & src_op,GraphDef * tensorflow_graph)1433 void ConvertResizeNearestNeighborOperator(
1434     const Model& model, const ResizeNearestNeighborOperator& src_op,
1435     GraphDef* tensorflow_graph) {
1436   tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
1437   resize_op->set_op("ResizeNearestNeighbor");
1438   resize_op->set_name(src_op.outputs[0]);
1439   CHECK_EQ(src_op.inputs.size(), 2);
1440   *resize_op->add_input() = src_op.inputs[0];
1441   *resize_op->add_input() = src_op.inputs[1];
1442   (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
1443   (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
1444   (*resize_op->mutable_attr())["half_pixel_centers"].set_b(
1445       src_op.half_pixel_centers);
1446 }
1447 
ConvertOneHotOperator(const Model & model,const OneHotOperator & src_op,GraphDef * tensorflow_graph)1448 void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
1449                            GraphDef* tensorflow_graph) {
1450   tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
1451   onehot_op->set_op("OneHot");
1452   onehot_op->set_name(src_op.outputs[0]);
1453   CHECK_EQ(src_op.inputs.size(), 4);
1454   for (const auto& input : src_op.inputs) {
1455     *onehot_op->add_input() = input;
1456   }
1457   (*onehot_op->mutable_attr())["T"].set_type(
1458       GetTensorFlowDataType(model, src_op.outputs[0]));
1459   (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
1460 }
1461 
1462 namespace {
1463 // TODO(aselle): Remove when available in absl
FindLongestCommonPrefix(absl::string_view a,absl::string_view b)1464 absl::string_view FindLongestCommonPrefix(absl::string_view a,
1465                                           absl::string_view b) {
1466   if (a.empty() || b.empty()) return absl::string_view();
1467 
1468   const char* pa = a.data();
1469   const char* pb = b.data();
1470   std::string::difference_type count = 0;
1471   const std::string::difference_type limit = std::min(a.size(), b.size());
1472   while (count < limit && *pa == *pb) {
1473     ++pa;
1474     ++pb;
1475     ++count;
1476   }
1477 
1478   return absl::string_view(a.data(), count);
1479 }
1480 }  // namespace
1481 
ConvertLstmCellOperator(const Model & model,const LstmCellOperator & src_op,GraphDef * tensorflow_graph)1482 void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
1483                              GraphDef* tensorflow_graph) {
1484   // Find the base name
1485   const std::string base(
1486       FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT],
1487                               src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]));
1488 
1489   // Concatenate inputs
1490   const std::string concat_output = base + "basic_lstm_cell/concat";
1491   // Op names have been chosen to match the tf.slim LSTM naming
1492   // as closely as possible.
1493   const int axis =
1494       model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
1495           .shape()
1496           .dimensions_count() -
1497       1;
1498   // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
1499   // works the same since the tensor has the same underlying data layout.
1500   const std::string axis_output = concat_output + "/axis";
1501   CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph);
1502   tensorflow::NodeDef* concat_op = tensorflow_graph->add_node();
1503   concat_op->set_op("ConcatV2");
1504   concat_op->set_name(concat_output);
1505   *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
1506   *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
1507   *concat_op->add_input() = axis_output;
1508   (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
1509   (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
1510   (*concat_op->mutable_attr())["N"].set_i(2);  // Number of inputs
1511 
1512   // Write weights
1513   const std::string weights_output = base + "weights";
1514   CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
1515   const std::string weights_name = WalkUpToConstantArray(
1516       model, src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
1517   const auto& weights_array = model.GetArray(weights_name);
1518   // Convert 4D FullyConnected weights into 2D matrix
1519   const auto& weights_shape = weights_array.shape();
1520   CHECK_EQ(weights_shape.dimensions_count(), 2);
1521   CHECK(weights_array.buffer);
1522   CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
1523   const float* weights_data =
1524       weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
1525   ConvertFloatTensorConst(weights_output, weights_shape, weights_data,
1526                           AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);
1527 
1528   // Fully connected matrix multiply
1529   const std::string matmul_output = base + "MatMul";
1530   tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
1531   matmul_op->set_op("MatMul");
1532   matmul_op->set_name(matmul_output);
1533   *matmul_op->add_input() = concat_output;
1534   *matmul_op->add_input() = weights_output;
1535   (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
1536   (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
1537   (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
1538 
1539   // Write biases
1540   const std::string biases_output = base + "biases";
1541   CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
1542   const std::string bias_name = WalkUpToConstantArray(
1543       model, src_op.inputs[LstmCellOperator::BIASES_INPUT]);
1544   const auto& bias_array = model.GetArray(bias_name);
1545   // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
1546   Shape bias_shape_1d = bias_array.shape();
1547   UnextendShape(&bias_shape_1d, 1);
1548   CHECK(bias_array.buffer);
1549   CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
1550   const float* bias_data =
1551       bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
1552   ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data,
1553                           AxesOrder::kOneAxis, AxesOrder::kOneAxis,
1554                           tensorflow_graph,
1555                           LegacyScalarPolicy::kDoCreateLegacyScalars);
1556 
1557   // Add biases
1558   std::string biasadd_output = base + "BiasAdd";
1559   tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
1560   biasadd_op->set_op("BiasAdd");
1561   biasadd_op->set_name(biasadd_output);
1562   biasadd_op->add_input(matmul_output);
1563   biasadd_op->add_input(biases_output);
1564   (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC");
1565   (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
1566 
1567   // Split
1568   std::string split_dim_output = base + "split/split_dim";
1569   // The dimension is the same as the concatenation dimension
1570   CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph);
1571   std::string split_output = base + "split";
1572   tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
1573   split_op->set_op("Split");
1574   split_op->set_name(split_output);
1575   *split_op->add_input() = split_dim_output;
1576   *split_op->add_input() = biasadd_output;
1577   (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
1578   (*split_op->mutable_attr())["num_split"].set_i(4);  // Split into four outputs
1579 
1580   // Activation functions and memory computations
1581   const std::string tanh_0_output = base + "Tanh";
1582   tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node();
1583   tanh_0_op->set_op("Tanh");
1584   tanh_0_op->set_name(tanh_0_output);
1585   *tanh_0_op->add_input() = split_output + ":1";
1586   (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
1587 
1588   const std::string sigmoid_1_output = base + "Sigmoid_1";
1589   tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node();
1590   logistic_1_op->set_op("Sigmoid");
1591   logistic_1_op->set_name(sigmoid_1_output);
1592   *logistic_1_op->add_input() = split_output;
1593   (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1594 
1595   const std::string mul_1_output = base + "mul_1";
1596   tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node();
1597   mul_1_op->set_op("Mul");
1598   mul_1_op->set_name(mul_1_output);
1599   *mul_1_op->add_input() = sigmoid_1_output;
1600   *mul_1_op->add_input() = tanh_0_output;
1601   (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1602 
1603   const std::string sigmoid_0_output = base + "Sigmoid";
1604   tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node();
1605   logistic_2_op->set_op("Sigmoid");
1606   logistic_2_op->set_name(sigmoid_0_output);
1607   *logistic_2_op->add_input() = split_output + ":2";
1608   (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
1609 
1610   const std::string sigmoid_2_output = base + "Sigmoid_2";
1611   tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node();
1612   logistic_3_op->set_op("Sigmoid");
1613   logistic_3_op->set_name(sigmoid_2_output);
1614   *logistic_3_op->add_input() = split_output + ":3";
1615   (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);
1616 
1617   const std::string mul_0_output = base + "mul";
1618   tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node();
1619   mul_0_op->set_op("Mul");
1620   mul_0_op->set_name(mul_0_output);
1621   *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
1622   *mul_0_op->add_input() = sigmoid_0_output;
1623   (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
1624 
1625   const std::string add_1_output =
1626       src_op.outputs[LstmCellOperator::STATE_OUTPUT];
1627   tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node();
1628   add_1_op->set_op("Add");
1629   add_1_op->set_name(add_1_output);
1630   *add_1_op->add_input() = mul_0_output;
1631   *add_1_op->add_input() = mul_1_output;
1632   (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1633 
1634   const std::string tanh_1_output = base + "Tanh_1";
1635   tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node();
1636   tanh_1_op->set_op("Tanh");
1637   tanh_1_op->set_name(tanh_1_output);
1638   *tanh_1_op->add_input() = add_1_output;
1639   (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1640 
1641   const std::string mul_2_output =
1642       src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
1643   tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node();
1644   mul_2_op->set_op("Mul");
1645   mul_2_op->set_name(mul_2_output);
1646   *mul_2_op->add_input() = tanh_1_output;
1647   *mul_2_op->add_input() = sigmoid_2_output;
1648   (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
1649 }
1650 
ConvertSpaceToBatchNDOperator(const Model & model,const SpaceToBatchNDOperator & src_op,GraphDef * tensorflow_graph)1651 void ConvertSpaceToBatchNDOperator(const Model& model,
1652                                    const SpaceToBatchNDOperator& src_op,
1653                                    GraphDef* tensorflow_graph) {
1654   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1655   new_op->set_op("SpaceToBatchND");
1656   new_op->set_name(src_op.outputs[0]);
1657   CHECK_EQ(src_op.inputs.size(), 3);
1658   *new_op->add_input() = src_op.inputs[0];
1659   *new_op->add_input() = src_op.inputs[1];
1660   *new_op->add_input() = src_op.inputs[2];
1661   const tensorflow::DataType params_type =
1662       GetTensorFlowDataType(model, src_op.inputs[0]);
1663   (*new_op->mutable_attr())["T"].set_type(params_type);
1664   (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1665   (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
1666 }
1667 
ConvertBatchToSpaceNDOperator(const Model & model,const BatchToSpaceNDOperator & src_op,GraphDef * tensorflow_graph)1668 void ConvertBatchToSpaceNDOperator(const Model& model,
1669                                    const BatchToSpaceNDOperator& src_op,
1670                                    GraphDef* tensorflow_graph) {
1671   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1672   new_op->set_op("BatchToSpaceND");
1673   new_op->set_name(src_op.outputs[0]);
1674   CHECK_EQ(src_op.inputs.size(), 3);
1675   *new_op->add_input() = src_op.inputs[0];
1676   *new_op->add_input() = src_op.inputs[1];
1677   *new_op->add_input() = src_op.inputs[2];
1678   const tensorflow::DataType params_type =
1679       GetTensorFlowDataType(model, src_op.inputs[0]);
1680   (*new_op->mutable_attr())["T"].set_type(params_type);
1681   (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1682   (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
1683 }
1684 
ConvertPadOperator(const Model & model,const PadOperator & src_op,GraphDef * tensorflow_graph)1685 void ConvertPadOperator(const Model& model, const PadOperator& src_op,
1686                         GraphDef* tensorflow_graph) {
1687   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1688   new_op->set_op("Pad");
1689   new_op->set_name(src_op.outputs[0]);
1690   CHECK_EQ(src_op.inputs.size(), 2);
1691   *new_op->add_input() = src_op.inputs[0];
1692   *new_op->add_input() = src_op.inputs[1];
1693 
1694   const tensorflow::DataType params_type =
1695       GetTensorFlowDataType(model, src_op.inputs[0]);
1696   (*new_op->mutable_attr())["T"].set_type(params_type);
1697 
1698   // Create the params tensor.
1699   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1700   params_op->set_op("Const");
1701   params_op->set_name(src_op.inputs[1]);
1702   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1703   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1704   tensor->set_dtype(DT_INT32);
1705 
1706   CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1707   for (int i = 0; i < src_op.left_padding.size(); ++i) {
1708     tensor->add_int_val(src_op.left_padding[i]);
1709     tensor->add_int_val(src_op.right_padding[i]);
1710   }
1711   auto* shape = tensor->mutable_tensor_shape();
1712   shape->add_dim()->set_size(src_op.left_padding.size());
1713   shape->add_dim()->set_size(2);
1714 }
1715 
ConvertPadV2Operator(const Model & model,const PadV2Operator & src_op,GraphDef * tensorflow_graph)1716 void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
1717                           GraphDef* tensorflow_graph) {
1718   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1719   new_op->set_op("PadV2");
1720   new_op->set_name(src_op.outputs[0]);
1721   CHECK_EQ(src_op.inputs.size(), 2);
1722   *new_op->add_input() = src_op.inputs[0];
1723   *new_op->add_input() = src_op.inputs[1];
1724   *new_op->add_input() = src_op.inputs[2];
1725 
1726   const tensorflow::DataType params_type =
1727       GetTensorFlowDataType(model, src_op.inputs[0]);
1728   (*new_op->mutable_attr())["T"].set_type(params_type);
1729 
1730   // Create the params tensor.
1731   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1732   params_op->set_op("Const");
1733   params_op->set_name(src_op.inputs[1]);
1734   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1735   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1736   tensor->set_dtype(DT_INT32);
1737 
1738   CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1739   for (int i = 0; i < src_op.left_padding.size(); ++i) {
1740     tensor->add_int_val(src_op.left_padding[i]);
1741     tensor->add_int_val(src_op.right_padding[i]);
1742   }
1743   auto* shape = tensor->mutable_tensor_shape();
1744   shape->add_dim()->set_size(src_op.left_padding.size());
1745   shape->add_dim()->set_size(2);
1746 }
1747 
CreateSliceInput(const std::string & input_name,const std::vector<int> & values,GraphDef * tensorflow_graph)1748 void CreateSliceInput(const std::string& input_name,
1749                       const std::vector<int>& values,
1750                       GraphDef* tensorflow_graph) {
1751   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1752   params_op->set_op("Const");
1753   params_op->set_name(input_name);
1754   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1755   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1756   tensor->set_dtype(DT_INT32);
1757 
1758   for (int i = 0; i < values.size(); ++i) {
1759     tensor->add_int_val(values[i]);
1760   }
1761   auto* shape = tensor->mutable_tensor_shape();
1762   shape->add_dim()->set_size(values.size());
1763 }
1764 
ConvertStridedSliceOperator(const Model & model,const StridedSliceOperator & src_op,GraphDef * tensorflow_graph)1765 void ConvertStridedSliceOperator(const Model& model,
1766                                  const StridedSliceOperator& src_op,
1767                                  GraphDef* tensorflow_graph) {
1768   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1769   new_op->set_op("StridedSlice");
1770   new_op->set_name(src_op.outputs[0]);
1771   CHECK_EQ(src_op.inputs.size(), 4);
1772   *new_op->add_input() = src_op.inputs[0];
1773   *new_op->add_input() = src_op.inputs[1];
1774   *new_op->add_input() = src_op.inputs[2];
1775   *new_op->add_input() = src_op.inputs[3];
1776 
1777   const tensorflow::DataType params_type =
1778       GetTensorFlowDataType(model, src_op.inputs[0]);
1779   (*new_op->mutable_attr())["T"].set_type(params_type);
1780 
1781   (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1782   (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask);
1783   (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask);
1784   (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask);
1785   (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask);
1786   (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask);
1787 
1788   // Create tensors for start/stop indices and strides.
1789   CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph);
1790   CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph);
1791   CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph);
1792 }
1793 
ConvertSliceOperator(const Model & model,const SliceOperator & src_op,GraphDef * tensorflow_graph)1794 void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
1795                           GraphDef* tensorflow_graph) {
1796   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1797   new_op->set_op("Slice");
1798   new_op->set_name(src_op.outputs[0]);
1799   CHECK_EQ(src_op.inputs.size(), 3);
1800   *new_op->add_input() = src_op.inputs[0];
1801   *new_op->add_input() = src_op.inputs[1];
1802   *new_op->add_input() = src_op.inputs[2];
1803 
1804   const tensorflow::DataType params_type =
1805       GetTensorFlowDataType(model, src_op.inputs[0]);
1806   (*new_op->mutable_attr())["T"].set_type(params_type);
1807   (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1808 
1809   // Create tensors for begin and size inputs.
1810   CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph);
1811   CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph);
1812 }
1813 
1814 template <typename T>
ConvertReduceOperator(const Model & model,const T & src_op,GraphDef * tensorflow_graph,const std::string & op_name)1815 void ConvertReduceOperator(const Model& model, const T& src_op,
1816                            GraphDef* tensorflow_graph,
1817                            const std::string& op_name) {
1818   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1819   new_op->set_op(op_name);
1820   new_op->set_name(src_op.outputs[0]);
1821   CHECK_EQ(src_op.inputs.size(), 2);
1822   *new_op->add_input() = src_op.inputs[0];
1823   *new_op->add_input() = src_op.inputs[1];
1824 
1825   if (src_op.type != OperatorType::kAny) {
1826     const tensorflow::DataType params_type =
1827         GetTensorFlowDataType(model, src_op.inputs[0]);
1828     (*new_op->mutable_attr())["T"].set_type(params_type);
1829   }
1830   const tensorflow::DataType indices_type =
1831       GetTensorFlowDataType(model, src_op.inputs[1]);
1832   (*new_op->mutable_attr())["Tidx"].set_type(indices_type);
1833 
1834   if (src_op.keep_dims) {
1835     (*new_op->mutable_attr())["keep_dims"].set_b(true);
1836   }
1837 
1838   // Create the params tensor.
1839   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1840   params_op->set_op("Const");
1841   params_op->set_name(src_op.inputs[1]);
1842   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1843   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1844   tensor->set_dtype(DT_INT32);
1845 
1846   for (int i = 0; i < src_op.axis.size(); ++i) {
1847     tensor->add_int_val(src_op.axis[i]);
1848   }
1849   auto* shape = tensor->mutable_tensor_shape();
1850   shape->add_dim()->set_size(src_op.axis.size());
1851 }
1852 
ConvertSqueezeOperator(const Model & model,const SqueezeOperator & src_op,GraphDef * tensorflow_graph)1853 void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
1854                             GraphDef* tensorflow_graph) {
1855   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1856   new_op->set_op("Squeeze");
1857   new_op->set_name(src_op.outputs[0]);
1858   CHECK_EQ(src_op.inputs.size(), 1);
1859   *new_op->add_input() = src_op.inputs[0];
1860 
1861   const tensorflow::DataType params_type =
1862       GetTensorFlowDataType(model, src_op.inputs[0]);
1863   (*new_op->mutable_attr())["T"].set_type(params_type);
1864 
1865   if (!src_op.squeeze_dims.empty()) {
1866     auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"];
1867     for (int i : src_op.squeeze_dims) {
1868       squeeze_dims.mutable_list()->add_i(i);
1869     }
1870   }
1871 }
1872 
ConvertSubOperator(const Model & model,const SubOperator & src_op,GraphDef * tensorflow_graph)1873 void ConvertSubOperator(const Model& model, const SubOperator& src_op,
1874                         GraphDef* tensorflow_graph) {
1875   tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
1876   sub_op->set_op("Sub");
1877   sub_op->set_name(src_op.outputs[0]);
1878   CHECK_EQ(src_op.inputs.size(), 2);
1879   *sub_op->add_input() = src_op.inputs[0];
1880   *sub_op->add_input() = src_op.inputs[1];
1881   const tensorflow::DataType data_type =
1882       GetTensorFlowDataType(model, src_op.inputs[0]);
1883   (*sub_op->mutable_attr())["T"].set_type(data_type);
1884 }
1885 
ConvertTensorFlowMinimumOperator(const Model & model,const TensorFlowMinimumOperator & src_op,GraphDef * tensorflow_graph)1886 void ConvertTensorFlowMinimumOperator(const Model& model,
1887                                       const TensorFlowMinimumOperator& src_op,
1888                                       GraphDef* tensorflow_graph) {
1889   tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
1890   min_op->set_op("Minimum");
1891   min_op->set_name(src_op.outputs[0]);
1892   CHECK_EQ(src_op.inputs.size(), 2);
1893   *min_op->add_input() = src_op.inputs[0];
1894   *min_op->add_input() = src_op.inputs[1];
1895   const tensorflow::DataType data_type =
1896       GetTensorFlowDataType(model, src_op.inputs[0]);
1897   (*min_op->mutable_attr())["T"].set_type(data_type);
1898 }
1899 
ConvertTensorFlowMaximumOperator(const Model & model,const TensorFlowMaximumOperator & src_op,GraphDef * tensorflow_graph)1900 void ConvertTensorFlowMaximumOperator(const Model& model,
1901                                       const TensorFlowMaximumOperator& src_op,
1902                                       GraphDef* tensorflow_graph) {
1903   tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
1904   max_op->set_op("Maximum");
1905   max_op->set_name(src_op.outputs[0]);
1906   CHECK_EQ(src_op.inputs.size(), 2);
1907   *max_op->add_input() = src_op.inputs[0];
1908   *max_op->add_input() = src_op.inputs[1];
1909   const tensorflow::DataType data_type =
1910       GetTensorFlowDataType(model, src_op.inputs[0]);
1911   (*max_op->mutable_attr())["T"].set_type(data_type);
1912 }
1913 
ConvertSelectOperator(const Model & model,const SelectOperator & src_op,GraphDef * tensorflow_graph)1914 void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
1915                            GraphDef* tensorflow_graph) {
1916   tensorflow::NodeDef* select_op = tensorflow_graph->add_node();
1917   select_op->set_op("Select");
1918   select_op->set_name(src_op.outputs[0]);
1919   CHECK_EQ(src_op.inputs.size(), 3);
1920   *select_op->add_input() = src_op.inputs[0];
1921   *select_op->add_input() = src_op.inputs[1];
1922   *select_op->add_input() = src_op.inputs[2];
1923   const tensorflow::DataType data_type =
1924       GetTensorFlowDataType(model, src_op.inputs[1]);
1925   (*select_op->mutable_attr())["T"].set_type(data_type);
1926 }
1927 
ConvertTileOperator(const Model & model,const TensorFlowTileOperator & src_op,GraphDef * tensorflow_graph)1928 void ConvertTileOperator(const Model& model,
1929                          const TensorFlowTileOperator& src_op,
1930                          GraphDef* tensorflow_graph) {
1931   tensorflow::NodeDef* tile_op = tensorflow_graph->add_node();
1932   tile_op->set_op("Tile");
1933   tile_op->set_name(src_op.outputs[0]);
1934   CHECK_EQ(src_op.inputs.size(), 2);
1935   *tile_op->add_input() = src_op.inputs[0];
1936   *tile_op->add_input() = src_op.inputs[1];
1937   const tensorflow::DataType data_type =
1938       GetTensorFlowDataType(model, src_op.inputs[0]);
1939   (*tile_op->mutable_attr())["T"].set_type(data_type);
1940   const tensorflow::DataType multiples_data_type =
1941       GetTensorFlowDataType(model, src_op.inputs[1]);
1942   (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type);
1943 }
1944 
ConvertTopKV2Operator(const Model & model,const TopKV2Operator & src_op,GraphDef * tensorflow_graph)1945 void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
1946                            GraphDef* tensorflow_graph) {
1947   tensorflow::NodeDef* topk_op = tensorflow_graph->add_node();
1948   topk_op->set_op("TopKV2");
1949   topk_op->set_name(src_op.outputs[0]);
1950   CHECK_EQ(src_op.inputs.size(), 2);
1951   *topk_op->add_input() = src_op.inputs[0];
1952   *topk_op->add_input() = src_op.inputs[1];
1953   const tensorflow::DataType data_type =
1954       GetTensorFlowDataType(model, src_op.inputs[0]);
1955   (*topk_op->mutable_attr())["T"].set_type(data_type);
1956   (*topk_op->mutable_attr())["sorted"].set_b(true);
1957 }
1958 
ConvertRandomUniformOperator(const Model & model,const RandomUniformOperator & src_op,GraphDef * tensorflow_graph)1959 void ConvertRandomUniformOperator(const Model& model,
1960                                   const RandomUniformOperator& src_op,
1961                                   GraphDef* tensorflow_graph) {
1962   CHECK(tensorflow_graph != nullptr);
1963   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1964   new_op->set_op("RandomUniform");
1965   CHECK_EQ(src_op.inputs.size(), 1);
1966   new_op->set_name(src_op.outputs[0]);
1967   *new_op->add_input() = src_op.inputs[0];
1968   const tensorflow::DataType shape_type =
1969       GetTensorFlowDataType(model, src_op.inputs[0]);
1970   (*new_op->mutable_attr())["T"].set_type(shape_type);
1971   (*new_op->mutable_attr())["dtype"].set_type(
1972       GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
1973   (*new_op->mutable_attr())["seed"].set_i(src_op.seed);
1974   (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
1975 }
1976 
ConvertComparisonOperator(const Model & model,const Operator & src_op,const char * op_name,GraphDef * tensorflow_graph)1977 void ConvertComparisonOperator(const Model& model, const Operator& src_op,
1978                                const char* op_name,
1979                                GraphDef* tensorflow_graph) {
1980   tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node();
1981   comparison_op->set_op(op_name);
1982   comparison_op->set_name(src_op.outputs[0]);
1983   CHECK_EQ(src_op.inputs.size(), 2);
1984   *comparison_op->add_input() = src_op.inputs[0];
1985   *comparison_op->add_input() = src_op.inputs[1];
1986   const tensorflow::DataType data_type =
1987       GetTensorFlowDataType(model, src_op.inputs[0]);
1988   (*comparison_op->mutable_attr())["T"].set_type(data_type);
1989 }
1990 
ConvertSparseToDenseOperator(const Model & model,const SparseToDenseOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)1991 void ConvertSparseToDenseOperator(const Model& model,
1992                                   const SparseToDenseOperator& src_op,
1993                                   const char* op_name,
1994                                   GraphDef* tensorflow_graph) {
1995   tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node();
1996   sparse_to_dense_op->set_op(op_name);
1997   sparse_to_dense_op->set_name(src_op.outputs[0]);
1998   CHECK_EQ(src_op.inputs.size(), 4);
1999   for (int i = 0; i < 4; ++i) {
2000     *sparse_to_dense_op->add_input() = src_op.inputs[i];
2001   }
2002   const tensorflow::DataType data_type =
2003       GetTensorFlowDataType(model, src_op.inputs[3]);
2004   (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type);
2005   const tensorflow::DataType index_type =
2006       GetTensorFlowDataType(model, src_op.inputs[0]);
2007   (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type);
2008   (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b(
2009       src_op.validate_indices);
2010 }
2011 
ConvertPowOperator(const Model & model,const PowOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2012 void ConvertPowOperator(const Model& model, const PowOperator& src_op,
2013                         const char* op_name, GraphDef* tensorflow_graph) {
2014   tensorflow::NodeDef* pow_op = tensorflow_graph->add_node();
2015   pow_op->set_op(op_name);
2016   pow_op->set_name(src_op.outputs[0]);
2017   CHECK_EQ(src_op.inputs.size(), 2);
2018   for (int i = 0; i < 2; ++i) {
2019     *pow_op->add_input() = src_op.inputs[i];
2020   }
2021   const tensorflow::DataType data_type =
2022       GetTensorFlowDataType(model, src_op.inputs[0]);
2023   (*pow_op->mutable_attr())["T"].set_type(data_type);
2024 }
2025 
ConvertLogicalAndOperator(const Model & model,const LogicalAndOperator & src_op,GraphDef * tensorflow_graph)2026 void ConvertLogicalAndOperator(const Model& model,
2027                                const LogicalAndOperator& src_op,
2028                                GraphDef* tensorflow_graph) {
2029   tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
2030   logical_op->set_op("LogicalAnd");
2031   logical_op->set_name(src_op.outputs[0]);
2032   CHECK_EQ(src_op.inputs.size(), 2);
2033   for (int i = 0; i < 2; ++i) {
2034     *logical_op->add_input() = src_op.inputs[i];
2035   }
2036 }
2037 
ConvertLogicalNotOperator(const Model & model,const LogicalNotOperator & src_op,GraphDef * tensorflow_graph)2038 void ConvertLogicalNotOperator(const Model& model,
2039                                const LogicalNotOperator& src_op,
2040                                GraphDef* tensorflow_graph) {
2041   tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
2042   logical_op->set_op("LogicalNot");
2043   logical_op->set_name(src_op.outputs[0]);
2044   CHECK_EQ(src_op.inputs.size(), 1);
2045   *logical_op->add_input() = src_op.inputs[0];
2046 }
2047 
ConvertLogicalOrOperator(const Model & model,const LogicalOrOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2048 void ConvertLogicalOrOperator(const Model& model,
2049                               const LogicalOrOperator& src_op,
2050                               const char* op_name, GraphDef* tensorflow_graph) {
2051   tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node();
2052   logical_or_op->set_op(op_name);
2053   logical_or_op->set_name(src_op.outputs[0]);
2054   CHECK_EQ(src_op.inputs.size(), 2);
2055   for (int i = 0; i < 2; ++i) {
2056     *logical_or_op->add_input() = src_op.inputs[i];
2057   }
2058   const tensorflow::DataType data_type =
2059       GetTensorFlowDataType(model, src_op.inputs[0]);
2060   (*logical_or_op->mutable_attr())["T"].set_type(data_type);
2061 }
2062 
ConvertCTCBeamSearchDecoderOperator(const Model & model,const CTCBeamSearchDecoderOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2063 void ConvertCTCBeamSearchDecoderOperator(
2064     const Model& model, const CTCBeamSearchDecoderOperator& src_op,
2065     const char* op_name, GraphDef* tensorflow_graph) {
2066   auto* op = tensorflow_graph->add_node();
2067   op->set_op(op_name);
2068   op->set_name(src_op.outputs[0]);
2069   CHECK_EQ(src_op.inputs.size(), 2);
2070   for (int i = 0; i < 2; ++i) {
2071     *op->add_input() = src_op.inputs[i];
2072   }
2073   (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width);
2074   (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths);
2075   (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
2076 }
2077 
ConvertUnpackOperator(const Model & model,const UnpackOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2078 void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
2079                            const char* op_name, GraphDef* tensorflow_graph) {
2080   tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node();
2081   unpack_op->set_op(op_name);
2082   unpack_op->set_name(src_op.outputs[0]);
2083   CHECK_EQ(src_op.inputs.size(), 2);
2084   *unpack_op->add_input() = src_op.inputs[0];
2085   const tensorflow::DataType data_type =
2086       GetTensorFlowDataType(model, src_op.inputs[0]);
2087   (*unpack_op->mutable_attr())["T"].set_type(data_type);
2088   (*unpack_op->mutable_attr())["num"].set_i(src_op.num);
2089   (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
2090 }
2091 
ConvertZerosLikeOperator(const Model & model,const TensorFlowZerosLikeOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2092 void ConvertZerosLikeOperator(const Model& model,
2093                               const TensorFlowZerosLikeOperator& src_op,
2094                               const char* op_name, GraphDef* tensorflow_graph) {
2095   tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node();
2096   zeros_like_op->set_op(op_name);
2097   zeros_like_op->set_name(src_op.outputs[0]);
2098   DCHECK_EQ(src_op.inputs.size(), 1);
2099   *zeros_like_op->add_input() = src_op.inputs[0];
2100   const tensorflow::DataType data_type =
2101       GetTensorFlowDataType(model, src_op.inputs[0]);
2102   (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
2103 }
2104 
ConvertReverseV2Operator(const Model & model,const ReverseV2Operator & src_op,const char * op_name,GraphDef * tensorflow_graph)2105 void ConvertReverseV2Operator(const Model& model,
2106                               const ReverseV2Operator& src_op,
2107                               const char* op_name, GraphDef* tensorflow_graph) {
2108   tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node();
2109   reverse_v2_op->set_op(op_name);
2110   reverse_v2_op->set_name(src_op.outputs[0]);
2111   DCHECK_EQ(src_op.inputs.size(), 2);
2112   *reverse_v2_op->add_input() = src_op.inputs[0];
2113   *reverse_v2_op->add_input() = src_op.inputs[1];
2114   const tensorflow::DataType data_type =
2115       GetTensorFlowDataType(model, src_op.inputs[0]);
2116   (*reverse_v2_op->mutable_attr())["T"].set_type(data_type);
2117 }
2118 
ConvertReverseSequenceOperator(const Model & model,const ReverseSequenceOperator & src_op,GraphDef * tensorflow_graph)2119 void ConvertReverseSequenceOperator(const Model& model,
2120                                     const ReverseSequenceOperator& src_op,
2121                                     GraphDef* tensorflow_graph) {
2122   tensorflow::NodeDef* reverse_seq_op = tensorflow_graph->add_node();
2123   reverse_seq_op->set_op("ReverseSequence");
2124   reverse_seq_op->set_name(src_op.outputs[0]);
2125   CHECK_EQ(src_op.inputs.size(), 2);
2126   *reverse_seq_op->add_input() = src_op.inputs[0];
2127   *reverse_seq_op->add_input() = src_op.inputs[1];
2128   (*reverse_seq_op->mutable_attr())["seq_dim"].set_i(src_op.seq_dim);
2129   (*reverse_seq_op->mutable_attr())["batch_dim"].set_i(src_op.batch_dim);
2130 }
2131 
ConvertOperator(const Model & model,const Operator & src_op,GraphDef * tensorflow_graph)2132 void ConvertOperator(const Model& model, const Operator& src_op,
2133                      GraphDef* tensorflow_graph) {
2134   if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
2135     LOG(FATAL)
2136         << "Unsupported: the input model has a fused activation function";
2137   }
2138 
2139   if (src_op.type == OperatorType::kConv) {
2140     ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
2141                         tensorflow_graph);
2142   } else if (src_op.type == OperatorType::kDepthwiseConv) {
2143     ConvertDepthwiseConvOperator(
2144         model, static_cast<const DepthwiseConvOperator&>(src_op),
2145         tensorflow_graph);
2146   } else if (src_op.type == OperatorType::kDepthToSpace) {
2147     ConvertDepthToSpaceOperator(
2148         model, static_cast<const DepthToSpaceOperator&>(src_op),
2149         tensorflow_graph);
2150   } else if (src_op.type == OperatorType::kSpaceToDepth) {
2151     ConvertSpaceToDepthOperator(
2152         model, static_cast<const SpaceToDepthOperator&>(src_op),
2153         tensorflow_graph);
2154   } else if (src_op.type == OperatorType::kFullyConnected) {
2155     ConvertFullyConnectedOperator(
2156         model, static_cast<const FullyConnectedOperator&>(src_op),
2157         tensorflow_graph);
2158   } else if (src_op.type == OperatorType::kAdd) {
2159     ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
2160                        tensorflow_graph);
2161   } else if (src_op.type == OperatorType::kAddN) {
2162     ConvertAddNOperator(model, static_cast<const AddNOperator&>(src_op),
2163                         tensorflow_graph);
2164   } else if (src_op.type == OperatorType::kMul) {
2165     ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
2166                        tensorflow_graph);
2167   } else if (src_op.type == OperatorType::kDiv) {
2168     ConvertDivOperator(model, static_cast<const DivOperator&>(src_op),
2169                        tensorflow_graph);
2170   } else if (src_op.type == OperatorType::kRelu) {
2171     ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op),
2172                         tensorflow_graph);
2173   } else if (src_op.type == OperatorType::kRelu1) {
2174     ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
2175                          tensorflow_graph);
2176   } else if (src_op.type == OperatorType::kRelu6) {
2177     ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
2178                          tensorflow_graph);
2179   } else if (src_op.type == OperatorType::kLog) {
2180     ConvertLogOperator(static_cast<const LogOperator&>(src_op),
2181                        tensorflow_graph);
2182   } else if (src_op.type == OperatorType::kLogistic) {
2183     ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
2184                             tensorflow_graph);
2185   } else if (src_op.type == OperatorType::kTanh) {
2186     ConvertTanhOperator(static_cast<const TanhOperator&>(src_op),
2187                         tensorflow_graph);
2188   } else if (src_op.type == OperatorType::kL2Normalization) {
2189     ConvertL2NormalizationOperator(
2190         static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph);
2191   } else if (src_op.type == OperatorType::kSoftmax) {
2192     ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op),
2193                            tensorflow_graph);
2194   } else if (src_op.type == OperatorType::kLogSoftmax) {
2195     ConvertLogSoftmaxOperator(model,
2196                               static_cast<const LogSoftmaxOperator&>(src_op),
2197                               tensorflow_graph);
2198   } else if (src_op.type == OperatorType::kLocalResponseNormalization) {
2199     ConvertLocalResponseNormalizationOperator(
2200         static_cast<const LocalResponseNormalizationOperator&>(src_op),
2201         tensorflow_graph);
2202   } else if (src_op.type == OperatorType::kLstmCell) {
2203     ConvertLstmCellOperator(model, static_cast<const LstmCellOperator&>(src_op),
2204                             tensorflow_graph);
2205   } else if (src_op.type == OperatorType::kMaxPool) {
2206     ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op),
2207                            tensorflow_graph);
2208   } else if (src_op.type == OperatorType::kAveragePool) {
2209     ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op),
2210                                tensorflow_graph);
2211   } else if (src_op.type == OperatorType::kConcatenation) {
2212     ConvertConcatenationOperator(
2213         model, static_cast<const ConcatenationOperator&>(src_op),
2214         tensorflow_graph);
2215   } else if (src_op.type == OperatorType::kReshape) {
2216     ConvertTensorFlowReshapeOperator(
2217         model, static_cast<const TensorFlowReshapeOperator&>(src_op),
2218         tensorflow_graph);
2219   } else if (src_op.type == OperatorType::kL2Pool) {
2220     ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
2221                           tensorflow_graph);
2222   } else if (src_op.type == OperatorType::kSquare) {
2223     ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
2224                           tensorflow_graph);
2225   } else if (src_op.type == OperatorType::kSqrt) {
2226     ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
2227                         tensorflow_graph);
2228   } else if (src_op.type == OperatorType::kRsqrt) {
2229     ConvertRsqrtOperator(model,
2230                          static_cast<const TensorFlowRsqrtOperator&>(src_op),
2231                          tensorflow_graph);
2232   } else if (src_op.type == OperatorType::kSplit) {
2233     ConvertSplitOperator(model,
2234                          static_cast<const TensorFlowSplitOperator&>(src_op),
2235                          tensorflow_graph);
2236   } else if (src_op.type == OperatorType::kSplitV) {
2237     ConvertSplitVOperator(model,
2238                           static_cast<const TensorFlowSplitVOperator&>(src_op),
2239                           tensorflow_graph);
2240   } else if (src_op.type == OperatorType::kFakeQuant) {
2241     ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op),
2242                              tensorflow_graph);
2243   } else if (src_op.type == OperatorType::kCast) {
2244     ConvertCastOperator(model, static_cast<const CastOperator&>(src_op),
2245                         tensorflow_graph);
2246   } else if (src_op.type == OperatorType::kFloor) {
2247     ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op),
2248                          tensorflow_graph);
2249   } else if (src_op.type == OperatorType::kCeil) {
2250     ConvertCeilOperator(model, static_cast<const CeilOperator&>(src_op),
2251                         tensorflow_graph);
2252   } else if (src_op.type == OperatorType::kRound) {
2253     ConvertRoundOperator(model, static_cast<const RoundOperator&>(src_op),
2254                          tensorflow_graph);
2255   } else if (src_op.type == OperatorType::kGather) {
2256     ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
2257                           tensorflow_graph);
2258   } else if (src_op.type == OperatorType::kResizeBilinear) {
2259     ConvertResizeBilinearOperator(
2260         model, static_cast<const ResizeBilinearOperator&>(src_op),
2261         tensorflow_graph);
2262   } else if (src_op.type == OperatorType::kResizeNearestNeighbor) {
2263     ConvertResizeNearestNeighborOperator(
2264         model, static_cast<const ResizeNearestNeighborOperator&>(src_op),
2265         tensorflow_graph);
2266   } else if (src_op.type == OperatorType::kSpaceToBatchND) {
2267     ConvertSpaceToBatchNDOperator(
2268         model, static_cast<const SpaceToBatchNDOperator&>(src_op),
2269         tensorflow_graph);
2270   } else if (src_op.type == OperatorType::kBatchToSpaceND) {
2271     ConvertBatchToSpaceNDOperator(
2272         model, static_cast<const BatchToSpaceNDOperator&>(src_op),
2273         tensorflow_graph);
2274   } else if (src_op.type == OperatorType::kPad) {
2275     ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
2276                        tensorflow_graph);
2277   } else if (src_op.type == OperatorType::kPadV2) {
2278     ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op),
2279                          tensorflow_graph);
2280   } else if (src_op.type == OperatorType::kStridedSlice) {
2281     ConvertStridedSliceOperator(
2282         model, static_cast<const StridedSliceOperator&>(src_op),
2283         tensorflow_graph);
2284   } else if (src_op.type == OperatorType::kMean) {
2285     ConvertReduceOperator(model, static_cast<const MeanOperator&>(src_op),
2286                           tensorflow_graph, "Mean");
2287   } else if (src_op.type == OperatorType::kSum) {
2288     ConvertReduceOperator(model,
2289                           static_cast<const TensorFlowSumOperator&>(src_op),
2290                           tensorflow_graph, "Sum");
2291   } else if (src_op.type == OperatorType::kReduceProd) {
2292     ConvertReduceOperator(model,
2293                           static_cast<const TensorFlowProdOperator&>(src_op),
2294                           tensorflow_graph, "Prod");
2295   } else if (src_op.type == OperatorType::kReduceMin) {
2296     ConvertReduceOperator(model,
2297                           static_cast<const TensorFlowMinOperator&>(src_op),
2298                           tensorflow_graph, "Min");
2299   } else if (src_op.type == OperatorType::kReduceMax) {
2300     ConvertReduceOperator(model,
2301                           static_cast<const TensorFlowMaxOperator&>(src_op),
2302                           tensorflow_graph, "Max");
2303   } else if (src_op.type == OperatorType::kSub) {
2304     ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
2305                        tensorflow_graph);
2306   } else if (src_op.type == OperatorType::kMinimum) {
2307     ConvertTensorFlowMinimumOperator(
2308         model, static_cast<const TensorFlowMinimumOperator&>(src_op),
2309         tensorflow_graph);
2310   } else if (src_op.type == OperatorType::kMaximum) {
2311     ConvertTensorFlowMaximumOperator(
2312         model, static_cast<const TensorFlowMaximumOperator&>(src_op),
2313         tensorflow_graph);
2314   } else if (src_op.type == OperatorType::kSqueeze) {
2315     ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op),
2316                            tensorflow_graph);
2317   } else if (src_op.type == OperatorType::kSlice) {
2318     ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op),
2319                          tensorflow_graph);
2320   } else if (src_op.type == OperatorType::kArgMax) {
2321     ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
2322                           tensorflow_graph);
2323   } else if (src_op.type == OperatorType::kArgMin) {
2324     ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
2325                           tensorflow_graph);
2326   } else if (src_op.type == OperatorType::kTopK_V2) {
2327     ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
2328                           tensorflow_graph);
2329   } else if (src_op.type == OperatorType::kTranspose) {
2330     ConvertTransposeOperator(
2331         model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph);
2332   } else if (src_op.type == OperatorType::kShape) {
2333     ConvertTensorFlowShapeOperator(
2334         model, static_cast<const TensorFlowShapeOperator&>(src_op),
2335         tensorflow_graph);
2336   } else if (src_op.type == OperatorType::kRank) {
2337     ConvertRankOperator(model,
2338                         static_cast<const TensorFlowRankOperator&>(src_op),
2339                         tensorflow_graph);
2340   } else if (src_op.type == OperatorType::kRange) {
2341     ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),
2342                          tensorflow_graph);
2343   } else if (src_op.type == OperatorType::kPack) {
2344     ConvertPackOperator(model, static_cast<const PackOperator&>(src_op),
2345                         tensorflow_graph);
2346   } else if (src_op.type == OperatorType::kFill) {
2347     ConvertFillOperator(model, static_cast<const FillOperator&>(src_op),
2348                         tensorflow_graph);
2349   } else if (src_op.type == OperatorType::kFloorDiv) {
2350     ConvertFloorDivOperator(model, static_cast<const FloorDivOperator&>(src_op),
2351                             tensorflow_graph);
2352   } else if (src_op.type == OperatorType::kFloorMod) {
2353     ConvertFloorModOperator(model, static_cast<const FloorModOperator&>(src_op),
2354                             tensorflow_graph);
2355   } else if (src_op.type == OperatorType::kExpandDims) {
2356     ConvertExpandDimsOperator(model,
2357                               static_cast<const ExpandDimsOperator&>(src_op),
2358                               tensorflow_graph);
2359   } else if (src_op.type == OperatorType::kTransposeConv) {
2360     ConvertTransposeConvOperator(
2361         model, static_cast<const TransposeConvOperator&>(src_op),
2362         tensorflow_graph);
2363   } else if (src_op.type == OperatorType::kRandomUniform) {
2364     ConvertRandomUniformOperator(
2365         model, static_cast<const RandomUniformOperator&>(src_op),
2366         tensorflow_graph);
2367   } else if (src_op.type == OperatorType::kEqual) {
2368     ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph);
2369   } else if (src_op.type == OperatorType::kNotEqual) {
2370     ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph);
2371   } else if (src_op.type == OperatorType::kGreater) {
2372     ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
2373   } else if (src_op.type == OperatorType::kGreaterEqual) {
2374     ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
2375   } else if (src_op.type == OperatorType::kLess) {
2376     ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
2377   } else if (src_op.type == OperatorType::kLessEqual) {
2378     ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
2379   } else if (src_op.type == OperatorType::kSelect) {
2380     ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
2381                           tensorflow_graph);
2382   } else if (src_op.type == OperatorType::kTile) {
2383     ConvertTileOperator(model,
2384                         static_cast<const TensorFlowTileOperator&>(src_op),
2385                         tensorflow_graph);
2386   } else if (src_op.type == OperatorType::kPow) {
2387     ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
2388                        tensorflow_graph);
2389   } else if (src_op.type == OperatorType::kAny) {
2390     ConvertReduceOperator(model,
2391                           static_cast<const TensorFlowAnyOperator&>(src_op),
2392                           tensorflow_graph, "Any");
2393   } else if (src_op.type == OperatorType::kLogicalAnd) {
2394     ConvertLogicalAndOperator(model,
2395                               static_cast<const LogicalAndOperator&>(src_op),
2396                               tensorflow_graph);
2397   } else if (src_op.type == OperatorType::kLogicalNot) {
2398     ConvertLogicalNotOperator(model,
2399                               static_cast<const LogicalNotOperator&>(src_op),
2400                               tensorflow_graph);
2401   } else if (src_op.type == OperatorType::kOneHot) {
2402     ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
2403                           tensorflow_graph);
2404   } else if (src_op.type == OperatorType::kLogicalOr) {
2405     ConvertLogicalOrOperator(model,
2406                              static_cast<const LogicalOrOperator&>(src_op),
2407                              "LogicalOr", tensorflow_graph);
2408   } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) {
2409     ConvertCTCBeamSearchDecoderOperator(
2410         model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
2411         "CTCBeamSearchDecoder", tensorflow_graph);
2412   } else if (src_op.type == OperatorType::kUnpack) {
2413     ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
2414                           "Unpack", tensorflow_graph);
2415   } else if (src_op.type == OperatorType::kZerosLike) {
2416     ConvertZerosLikeOperator(
2417         model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
2418         "ZerosLike", tensorflow_graph);
2419   } else if (src_op.type == OperatorType::kReverseV2) {
2420     ConvertReverseV2Operator(model,
2421                              static_cast<const ReverseV2Operator&>(src_op),
2422                              "Reverse_V2", tensorflow_graph);
2423   } else if (src_op.type == OperatorType::kReverseSequence) {
2424     ConvertReverseSequenceOperator(
2425         model, static_cast<const ReverseSequenceOperator&>(src_op),
2426         tensorflow_graph);
2427   } else {
2428     LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
2429   }
2430 }
2431 
AddPlaceholder(const std::string & name,ArrayDataType type,GraphDef * tensorflow_graph)2432 void AddPlaceholder(const std::string& name, ArrayDataType type,
2433                     GraphDef* tensorflow_graph) {
2434   tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
2435   placeholder->set_op("Placeholder");
2436   switch (type) {
2437     case ArrayDataType::kBool:
2438       (*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
2439       break;
2440     case ArrayDataType::kFloat:
2441       (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
2442       break;
2443     case ArrayDataType::kUint8:
2444       (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
2445       break;
2446     case ArrayDataType::kInt32:
2447       (*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
2448       break;
2449     case ArrayDataType::kUint32:
2450       (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT32);
2451       break;
2452     case ArrayDataType::kInt64:
2453       (*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
2454       break;
2455     case ArrayDataType::kInt16:
2456       (*placeholder->mutable_attr())["dtype"].set_type(DT_INT16);
2457       break;
2458     case ArrayDataType::kComplex64:
2459       (*placeholder->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
2460       break;
2461     default:
2462       LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
2463   }
2464   placeholder->set_name(name);
2465 }
2466 
AddPlaceholderForRNNState(const Model & model,const std::string & name,int size,GraphDef * tensorflow_graph)2467 void AddPlaceholderForRNNState(const Model& model, const std::string& name,
2468                                int size, GraphDef* tensorflow_graph) {
2469   tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
2470   placeholder->set_op("Placeholder");
2471   placeholder->set_name(name);
2472   (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
2473 
2474   auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
2475   const auto& state_array = model.GetArray(name);
2476   if (state_array.has_shape()) {
2477     const auto& state_shape = state_array.shape();
2478     const int kDims = state_shape.dimensions_count();
2479     for (int i = 0; i < kDims; ++i) {
2480       shape->add_dim()->set_size(state_shape.dims(i));
2481     }
2482   } else {
2483     shape->add_dim()->set_size(1);
2484     shape->add_dim()->set_size(size);
2485   }
2486 }
2487 
ExportTensorFlowGraphDefImplementation(const Model & model,GraphDef * tensorflow_graph)2488 void ExportTensorFlowGraphDefImplementation(const Model& model,
2489                                             GraphDef* tensorflow_graph) {
2490   for (const auto& input_array : model.flags.input_arrays()) {
2491     AddPlaceholder(input_array.name(),
2492                    model.GetArray(input_array.name()).data_type,
2493                    tensorflow_graph);
2494   }
2495   for (const auto& rnn_state : model.flags.rnn_states()) {
2496     AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
2497                               tensorflow_graph);
2498   }
2499   for (const auto& op : model.operators) {
2500     ConvertOperator(model, *op, tensorflow_graph);
2501   }
2502   // Generically export arrays that haven't been exported already
2503   // by the above operators export. It's important that this comes
2504   // after, as some operators need to export arrays that they reference
2505   // in a specific way, rather than in the generic way done below.
2506   for (const auto& array_pair : model.GetArrayMap()) {
2507     const std::string& array_name = array_pair.first;
2508     const auto& array = *array_pair.second;
2509     if (array.buffer) {
2510       switch (array.data_type) {
2511         case ArrayDataType::kBool:
2512           ConvertBoolTensorConst(model, array_name, tensorflow_graph);
2513           break;
2514         case ArrayDataType::kFloat:
2515           ConvertFloatTensorConst(model, array_name, tensorflow_graph);
2516           break;
2517         case ArrayDataType::kInt32:
2518           ConvertIntTensorConst(model, array_name, tensorflow_graph);
2519           break;
2520         case ArrayDataType::kComplex64:
2521           ConvertComplex64TensorConst(model, array_name, tensorflow_graph);
2522           break;
2523         default:
2524           break;
2525       }
2526     }
2527   }
2528 }
2529 }  // namespace
2530 
EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model * model)2531 void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model) {
2532   for (const auto& array_kv : model->GetArrayMap()) {
2533     const std::string& array_name = array_kv.first;
2534     Array& array = *array_kv.second;
2535     if (!array.buffer || !array.minmax) {
2536       continue;
2537     }
2538     const std::string& wrapped_array_name =
2539         AvailableArrayName(*model, array_name + "/data");
2540     Array& wrapped_array = model->GetOrCreateArray(wrapped_array_name);
2541     wrapped_array.data_type = array.data_type;
2542     wrapped_array.copy_shape(array.shape());
2543     wrapped_array.buffer = std::move(array.buffer);
2544     FakeQuantOperator* fakequant_op = new FakeQuantOperator;
2545     fakequant_op->inputs = {wrapped_array_name};
2546     fakequant_op->outputs = {array_name};
2547     fakequant_op->minmax = std::make_unique<MinMax>();
2548     *fakequant_op->minmax = *array.minmax;
2549     const auto& it = FindOpWithInput(*model, array_name);
2550     model->operators.emplace(it, fakequant_op);
2551   }
2552   CheckInvariants(*model);
2553 }
2554 
ExportTensorFlowGraphDef(const Model & model,std::string * output_file_contents)2555 void ExportTensorFlowGraphDef(const Model& model,
2556                               std::string* output_file_contents) {
2557   CHECK(output_file_contents->empty());
2558   GraphDef tensorflow_graph;
2559   ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph);
2560   LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph);
2561   CHECK(tensorflow_graph.SerializeToString(output_file_contents));
2562 }
2563 }  // namespace toco
2564