xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/tflite/operator.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/operator.h"
16 
17 #include <map>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/op.h"
25 #include "tensorflow/core/framework/op_def.pb.h"
26 #include "tensorflow/core/util/ptr_util.h"
27 
28 // TODO(ycling): Consider refactoring to extract the LSTM definition out of
29 // graph_transformation module.
30 #include "tensorflow/lite/builtin_op_data.h"
31 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
32 #include "tensorflow/lite/schema/schema_generated.h"
33 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
34 #include "tensorflow/lite/toco/model.h"
35 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
36 #include "tensorflow/lite/toco/tflite/custom_operator.h"
37 #include "tensorflow/lite/toco/tflite/simple_operator.h"
38 #include "tensorflow/lite/toco/tflite/types.h"
39 #include "tensorflow/lite/tools/versioning/op_version.h"
40 
41 namespace toco {
42 
43 namespace tflite {
44 
45 // LINT.IfChange
46 
GetTensorType(const ArrayDataType type)47 TfLiteType GetTensorType(const ArrayDataType type) {
48   const std::map<ArrayDataType, TfLiteType> tensor_type_map = {
49       {ArrayDataType::kBool, kTfLiteBool},
50       {ArrayDataType::kFloat, kTfLiteFloat32},
51       {ArrayDataType::kInt8, kTfLiteInt8},
52       {ArrayDataType::kUint8, kTfLiteUInt8},
53       {ArrayDataType::kInt16, kTfLiteInt16},
54       {ArrayDataType::kUint16, kTfLiteUInt16},
55       {ArrayDataType::kInt32, kTfLiteInt32},
56       {ArrayDataType::kUint32, kTfLiteUInt32},
57       {ArrayDataType::kInt64, kTfLiteInt64},
58       {ArrayDataType::kUint64, kTfLiteUInt64},
59       {ArrayDataType::kString, kTfLiteString},
60       {ArrayDataType::kComplex64, kTfLiteComplex64},
61       {ArrayDataType::kComplex128, kTfLiteComplex128},
62       {ArrayDataType::kFloat16, kTfLiteFloat16},
63       {ArrayDataType::kFloat64, kTfLiteFloat64}};
64 
65   auto it = tensor_type_map.find(type);
66   if (it != tensor_type_map.end()) {
67     return it->second;
68   }
69   return kTfLiteNoType;
70 }
71 
// Builds the ::tflite::OpSignature that ::tflite::GetBuiltinOperatorVersion
// consumes, for builtin `op` and the toco operator/model in `op_signature`.
//
// One OpSignatureTensorSpec is produced per input name and per output name, in
// order. A name that is not present in the model yields a spec with type
// kTfLiteNoType and no dims, so positions still line up with the operator's
// input/output lists. Dims are filled only when the array has a shape.
::tflite::OpSignature GetVersioningOpSig(
    const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) {
  // Converts a list of array names into tensor specs; shared by the input and
  // output lists so the lookup logic exists in one place.
  const auto to_specs = [&op_signature](const auto& names) {
    std::vector<::tflite::OpSignatureTensorSpec> specs;
    specs.reserve(names.size());
    for (const auto& name : names) {
      ::tflite::OpSignatureTensorSpec tensor = {kTfLiteNoType};
      if (op_signature.model->HasArray(name)) {
        const Array& array = op_signature.model->GetArray(name);
        tensor.type = GetTensorType(array.data_type);
        if (array.has_shape()) {
          tensor.dims = array.shape().dims();
        }
      }
      specs.push_back(std::move(tensor));
    }
    return specs;
  };
  // Pass the freshly built vectors as prvalues so they are moved, not copied,
  // into the signature.
  return ::tflite::OpSignature{op, to_specs(op_signature.op->inputs),
                               to_specs(op_signature.op->outputs)};
}
98 }
99 
100 class AveragePool
101     : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
102                              ::tflite::BuiltinOptions_Pool2DOptions> {
103  public:
104   using BuiltinOperator::BuiltinOperator;
105 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const106   flatbuffers::Offset<TfLiteOptions> WriteOptions(
107       const TocoOperator& op,
108       flatbuffers::FlatBufferBuilder* builder) const override {
109     auto padding = Padding::Serialize(op.padding.type);
110     auto activation_function =
111         ActivationFunction::Serialize(op.fused_activation_function);
112     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
113                                          op.stride_height, op.kwidth,
114                                          op.kheight, activation_function);
115   }
116 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const117   void ReadOptions(const TfLiteOptions& options,
118                    TocoOperator* op) const override {
119     op->padding.type = Padding::Deserialize(options.padding());
120     op->stride_width = options.stride_w();
121     op->stride_height = options.stride_h();
122     op->kwidth = options.filter_width();
123     op->kheight = options.filter_height();
124     op->fused_activation_function =
125         ActivationFunction::Deserialize(options.fused_activation_function());
126   }
127 };
128 
129 class Convolution
130     : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
131                              ::tflite::BuiltinOptions_Conv2DOptions> {
132  public:
133   using BuiltinOperator::BuiltinOperator;
134 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const135   flatbuffers::Offset<TfLiteOptions> WriteOptions(
136       const TocoOperator& op,
137       flatbuffers::FlatBufferBuilder* builder) const override {
138     auto padding = Padding::Serialize(op.padding.type);
139     auto activation_function =
140         ActivationFunction::Serialize(op.fused_activation_function);
141     return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
142                                          op.stride_height, activation_function,
143                                          op.dilation_width_factor,
144                                          op.dilation_height_factor);
145   }
146 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const147   void ReadOptions(const TfLiteOptions& options,
148                    TocoOperator* op) const override {
149     op->padding.type = Padding::Deserialize(options.padding());
150     op->stride_width = options.stride_w();
151     op->stride_height = options.stride_h();
152     op->dilation_width_factor = options.dilation_w_factor();
153     op->dilation_height_factor = options.dilation_h_factor();
154     op->fused_activation_function =
155         ActivationFunction::Deserialize(options.fused_activation_function());
156   }
157 };
158 
159 class DepthwiseConvolution
160     : public BuiltinOperator<DepthwiseConvOperator,
161                              ::tflite::DepthwiseConv2DOptions,
162                              ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
163  public:
164   using BuiltinOperator::BuiltinOperator;
165 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const166   flatbuffers::Offset<TfLiteOptions> WriteOptions(
167       const TocoOperator& op,
168       flatbuffers::FlatBufferBuilder* builder) const override {
169     auto padding = Padding::Serialize(op.padding.type);
170     auto activation_function =
171         ActivationFunction::Serialize(op.fused_activation_function);
172     return ::tflite::CreateDepthwiseConv2DOptions(
173         *builder, padding, op.stride_width, op.stride_height,
174         op.depth_multiplier, activation_function, op.dilation_width_factor,
175         op.dilation_height_factor);
176   }
177 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const178   void ReadOptions(const TfLiteOptions& options,
179                    TocoOperator* op) const override {
180     op->padding.type = Padding::Deserialize(options.padding());
181     op->stride_width = options.stride_w();
182     op->stride_height = options.stride_h();
183     op->depth_multiplier = options.depth_multiplier();
184     op->fused_activation_function =
185         ActivationFunction::Deserialize(options.fused_activation_function());
186     op->dilation_width_factor = options.dilation_w_factor();
187     op->dilation_height_factor = options.dilation_h_factor();
188   }
189 
GetVersion(const OperatorSignature & op_signature) const190   int GetVersion(const OperatorSignature& op_signature) const override {
191     const auto& conv_op =
192         static_cast<const DepthwiseConvOperator&>(*op_signature.op);
193     ::tflite::OpSignature op_sig =
194         GetVersioningOpSig(builtin_op(), op_signature);
195     TfLiteDepthwiseConvParams depthwise_conv_params = {};
196     depthwise_conv_params.dilation_width_factor = conv_op.dilation_width_factor;
197     depthwise_conv_params.dilation_height_factor =
198         conv_op.dilation_height_factor;
199     op_sig.builtin_data = reinterpret_cast<void*>(&depthwise_conv_params);
200     return ::tflite::GetBuiltinOperatorVersion(op_sig);
201   }
202 };
203 
204 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
205                                    ::tflite::BuiltinOptions_AddOptions> {
206  public:
207   using BuiltinOperator::BuiltinOperator;
208 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const209   flatbuffers::Offset<TfLiteOptions> WriteOptions(
210       const TocoOperator& op,
211       flatbuffers::FlatBufferBuilder* builder) const override {
212     auto activation_function =
213         ActivationFunction::Serialize(op.fused_activation_function);
214     return ::tflite::CreateAddOptions(*builder, activation_function);
215   }
216 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const217   void ReadOptions(const TfLiteOptions& options,
218                    TocoOperator* op) const override {
219     op->fused_activation_function =
220         ActivationFunction::Deserialize(options.fused_activation_function());
221   }
222 };
223 
224 class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
225                                     ::tflite::BuiltinOptions_AddNOptions> {
226  public:
227   using BuiltinOperator::BuiltinOperator;
228 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const229   flatbuffers::Offset<TfLiteOptions> WriteOptions(
230       const TocoOperator& op,
231       flatbuffers::FlatBufferBuilder* builder) const override {
232     return ::tflite::CreateAddNOptions(*builder);
233   }
234 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const235   void ReadOptions(const TfLiteOptions& options,
236                    TocoOperator* op) const override {}
237 };
238 
239 class SpaceToBatchND
240     : public BuiltinOperator<SpaceToBatchNDOperator,
241                              ::tflite::SpaceToBatchNDOptions,
242                              ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
243  public:
244   using BuiltinOperator::BuiltinOperator;
245 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const246   flatbuffers::Offset<TfLiteOptions> WriteOptions(
247       const TocoOperator& op,
248       flatbuffers::FlatBufferBuilder* builder) const override {
249     return ::tflite::CreateSpaceToBatchNDOptions(*builder);
250   }
251 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const252   void ReadOptions(const TfLiteOptions& options,
253                    TocoOperator* op) const override {}
254 
GetVersion(const OperatorSignature & op_signature) const255   int GetVersion(const OperatorSignature& op_signature) const override {
256     ::tflite::OpSignature op_sig =
257         GetVersioningOpSig(builtin_op(), op_signature);
258     return ::tflite::GetBuiltinOperatorVersion(op_sig);
259   }
260 };
261 
262 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
263                                    ::tflite::BuiltinOptions_SubOptions> {
264  public:
265   using BuiltinOperator::BuiltinOperator;
266 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const267   flatbuffers::Offset<TfLiteOptions> WriteOptions(
268       const TocoOperator& op,
269       flatbuffers::FlatBufferBuilder* builder) const override {
270     auto activation_function =
271         ActivationFunction::Serialize(op.fused_activation_function);
272     return ::tflite::CreateSubOptions(*builder, activation_function);
273   }
274 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const275   void ReadOptions(const TfLiteOptions& options,
276                    TocoOperator* op) const override {
277     op->fused_activation_function =
278         ActivationFunction::Deserialize(options.fused_activation_function());
279   }
280 
GetVersion(const OperatorSignature & op_signature) const281   int GetVersion(const OperatorSignature& op_signature) const override {
282     ::tflite::OpSignature op_sig =
283         GetVersioningOpSig(builtin_op(), op_signature);
284     return ::tflite::GetBuiltinOperatorVersion(op_sig);
285   }
286 };
287 
288 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
289                                    ::tflite::BuiltinOptions_DivOptions> {
290  public:
291   using BuiltinOperator::BuiltinOperator;
292 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const293   flatbuffers::Offset<TfLiteOptions> WriteOptions(
294       const TocoOperator& op,
295       flatbuffers::FlatBufferBuilder* builder) const override {
296     auto activation_function =
297         ActivationFunction::Serialize(op.fused_activation_function);
298     return ::tflite::CreateDivOptions(*builder, activation_function);
299   }
300 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const301   void ReadOptions(const TfLiteOptions& options,
302                    TocoOperator* op) const override {
303     op->fused_activation_function =
304         ActivationFunction::Deserialize(options.fused_activation_function());
305   }
306 
GetVersion(const OperatorSignature & op_signature) const307   int GetVersion(const OperatorSignature& op_signature) const override {
308     ::tflite::OpSignature op_sig =
309         GetVersioningOpSig(builtin_op(), op_signature);
310     return ::tflite::GetBuiltinOperatorVersion(op_sig);
311   }
312 };
313 
314 class BatchToSpaceND
315     : public BuiltinOperator<BatchToSpaceNDOperator,
316                              ::tflite::BatchToSpaceNDOptions,
317                              ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
318  public:
319   using BuiltinOperator::BuiltinOperator;
320 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const321   flatbuffers::Offset<TfLiteOptions> WriteOptions(
322       const TocoOperator& op,
323       flatbuffers::FlatBufferBuilder* builder) const override {
324     return ::tflite::CreateBatchToSpaceNDOptions(*builder);
325   }
326 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const327   void ReadOptions(const TfLiteOptions& options,
328                    TocoOperator* op) const override {}
329 
GetVersion(const OperatorSignature & op_signature) const330   int GetVersion(const OperatorSignature& op_signature) const override {
331     ::tflite::OpSignature op_sig =
332         GetVersioningOpSig(builtin_op(), op_signature);
333     return ::tflite::GetBuiltinOperatorVersion(op_sig);
334   }
335 };
336 
337 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
338                                     ::tflite::BuiltinOptions_CastOptions> {
339  public:
340   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const341   flatbuffers::Offset<TfLiteOptions> WriteOptions(
342       const TocoOperator& op,
343       flatbuffers::FlatBufferBuilder* builder) const override {
344     return ::tflite::CreateCastOptions(*builder,
345                                        DataType::Serialize(op.src_data_type),
346                                        DataType::Serialize(op.dst_data_type));
347   }
348 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const349   void ReadOptions(const TfLiteOptions& options,
350                    TocoOperator* op) const override {
351     op->src_data_type = DataType::Deserialize(options.in_data_type());
352     op->dst_data_type = DataType::Deserialize(options.out_data_type());
353   }
354 };
355 
356 class Concatenation
357     : public BuiltinOperator<ConcatenationOperator,
358                              ::tflite::ConcatenationOptions,
359                              ::tflite::BuiltinOptions_ConcatenationOptions> {
360  public:
361   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const362   flatbuffers::Offset<TfLiteOptions> WriteOptions(
363       const TocoOperator& op,
364       flatbuffers::FlatBufferBuilder* builder) const override {
365     return ::tflite::CreateConcatenationOptions(*builder, op.axis);
366   }
367 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const368   void ReadOptions(const TfLiteOptions& options,
369                    TocoOperator* op) const override {
370     op->axis = options.axis();
371   }
372 };
373 
374 class DepthToSpace
375     : public BuiltinOperator<DepthToSpaceOperator,
376                              ::tflite::DepthToSpaceOptions,
377                              ::tflite::BuiltinOptions_DepthToSpaceOptions> {
378  public:
379   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const380   flatbuffers::Offset<TfLiteOptions> WriteOptions(
381       const TocoOperator& op,
382       flatbuffers::FlatBufferBuilder* builder) const override {
383     return ::tflite::CreateDepthToSpaceOptions(*builder, op.block_size);
384   }
385 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const386   void ReadOptions(const TfLiteOptions& options,
387                    TocoOperator* op) const override {
388     op->block_size = options.block_size();
389   }
390 };
391 
392 class FakeQuant
393     : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
394                              ::tflite::BuiltinOptions_FakeQuantOptions> {
395  public:
396   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const397   flatbuffers::Offset<TfLiteOptions> WriteOptions(
398       const TocoOperator& op,
399       flatbuffers::FlatBufferBuilder* builder) const override {
400     return ::tflite::CreateFakeQuantOptions(
401         *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
402   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const403   void ReadOptions(const TfLiteOptions& options,
404                    TocoOperator* op) const override {
405     auto* minmax = new MinMax;
406     minmax->min = options.min();
407     minmax->max = options.max();
408     op->minmax.reset(minmax);
409     op->num_bits = options.num_bits();
410     op->narrow_range = options.narrow_range();
411   }
GetVersion(const OperatorSignature & op_signature) const412   int GetVersion(const OperatorSignature& op_signature) const override {
413     const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
414     ::tflite::OpSignature op_sig =
415         GetVersioningOpSig(builtin_op(), op_signature);
416     TfLiteFakeQuantParams fake_quant_params = {};
417     fake_quant_params.narrow_range = fq_op.narrow_range;
418     op_sig.builtin_data = reinterpret_cast<void*>(&fake_quant_params);
419     return ::tflite::GetBuiltinOperatorVersion(op_sig);
420   }
421 };
422 
423 class FullyConnected
424     : public BuiltinOperator<FullyConnectedOperator,
425                              ::tflite::FullyConnectedOptions,
426                              ::tflite::BuiltinOptions_FullyConnectedOptions> {
427  public:
428   using BuiltinOperator::BuiltinOperator;
429 
GetWeightFormat(FullyConnectedWeightsFormat fmt) const430   ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat(
431       FullyConnectedWeightsFormat fmt) const {
432     switch (fmt) {
433       case FullyConnectedWeightsFormat::kDefault:
434         return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
435       case FullyConnectedWeightsFormat::kShuffled4x16Int8:
436         return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
437       default:
438         LOG(ERROR) << "Unhandled FC weights format";
439         return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
440     }
441   }
442 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const443   flatbuffers::Offset<TfLiteOptions> WriteOptions(
444       const TocoOperator& op,
445       flatbuffers::FlatBufferBuilder* builder) const override {
446     auto activation_function =
447         ActivationFunction::Serialize(op.fused_activation_function);
448     return ::tflite::CreateFullyConnectedOptions(
449         *builder, activation_function, GetWeightFormat(op.weights_format));
450   }
451 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const452   void ReadOptions(const TfLiteOptions& options,
453                    TocoOperator* op) const override {
454     op->fused_activation_function =
455         ActivationFunction::Deserialize(options.fused_activation_function());
456     switch (options.weights_format()) {
457       case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
458         op->weights_format = FullyConnectedWeightsFormat::kDefault;
459         break;
460       case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
461         op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
462         break;
463       default:
464         LOG(ERROR) << "Unhandled FC weights format";
465         op->weights_format = FullyConnectedWeightsFormat::kDefault;
466     }
467   }
468 
GetVersion(const OperatorSignature & op_signature) const469   int GetVersion(const OperatorSignature& op_signature) const override {
470     const auto& fc_op =
471         static_cast<const FullyConnectedOperator&>(*op_signature.op);
472     ::tflite::OpSignature op_sig =
473         GetVersioningOpSig(builtin_op(), op_signature);
474     TfLiteFullyConnectedParams fully_connected_params = {};
475     fully_connected_params.keep_num_dims = fc_op.keep_num_dims;
476     fully_connected_params.weights_format =
477         static_cast<TfLiteFullyConnectedWeightsFormat>(
478             GetWeightFormat(fc_op.weights_format));
479     op_sig.builtin_data = reinterpret_cast<void*>(&fully_connected_params);
480     return ::tflite::GetBuiltinOperatorVersion(op_sig);
481   }
482 };
483 
484 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
485                                       ::tflite::BuiltinOptions_GatherOptions> {
486  public:
487   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const488   flatbuffers::Offset<TfLiteOptions> WriteOptions(
489       const TocoOperator& op,
490       flatbuffers::FlatBufferBuilder* builder) const override {
491     int axis = op.axis ? op.axis.value() : 0;
492     return ::tflite::CreateGatherOptions(*builder, axis);
493   }
494 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const495   void ReadOptions(const TfLiteOptions& options,
496                    TocoOperator* op) const override {
497     op->axis = {options.axis()};
498   }
499 };
500 
501 class GatherNd
502     : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
503                              ::tflite::BuiltinOptions_GatherNdOptions> {
504  public:
505   using BuiltinOperator::BuiltinOperator;
506 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const507   flatbuffers::Offset<TfLiteOptions> WriteOptions(
508       const TocoOperator& op,
509       flatbuffers::FlatBufferBuilder* builder) const override {
510     return ::tflite::CreateGatherNdOptions(*builder);
511   }
512 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const513   void ReadOptions(const TfLiteOptions& options,
514                    TocoOperator* op) const override {}
515 };
516 
517 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
518                                     ::tflite::BuiltinOptions_SVDFOptions> {
519  public:
520   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const521   flatbuffers::Offset<TfLiteOptions> WriteOptions(
522       const TocoOperator& op,
523       flatbuffers::FlatBufferBuilder* builder) const override {
524     auto activation_function =
525         ActivationFunction::Serialize(op.fused_activation_function);
526     return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
527   }
528 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const529   void ReadOptions(const TfLiteOptions& options,
530                    TocoOperator* op) const override {
531     op->fused_activation_function =
532         ActivationFunction::Deserialize(options.fused_activation_function());
533     op->rank = options.rank();
534   }
535 };
536 
537 class L2Normalization
538     : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
539                              ::tflite::BuiltinOptions_L2NormOptions> {
540  public:
541   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const542   flatbuffers::Offset<TfLiteOptions> WriteOptions(
543       const TocoOperator& op,
544       flatbuffers::FlatBufferBuilder* builder) const override {
545     auto activation_function =
546         ActivationFunction::Serialize(op.fused_activation_function);
547     return ::tflite::CreateL2NormOptions(*builder, activation_function);
548   }
549 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const550   void ReadOptions(const TfLiteOptions& options,
551                    TocoOperator* op) const override {
552     op->fused_activation_function =
553         ActivationFunction::Deserialize(options.fused_activation_function());
554   }
555 };
556 
557 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
558                                       ::tflite::BuiltinOptions_Pool2DOptions> {
559  public:
560   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const561   flatbuffers::Offset<TfLiteOptions> WriteOptions(
562       const TocoOperator& op,
563       flatbuffers::FlatBufferBuilder* builder) const override {
564     auto padding = Padding::Serialize(op.padding.type);
565     auto activation_function =
566         ActivationFunction::Serialize(op.fused_activation_function);
567     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
568                                          op.stride_height, op.kwidth,
569                                          op.kheight, activation_function);
570   }
571 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const572   void ReadOptions(const TfLiteOptions& options,
573                    TocoOperator* op) const override {
574     op->padding.type = Padding::Deserialize(options.padding());
575     op->stride_width = options.stride_w();
576     op->stride_height = options.stride_h();
577     op->kwidth = options.filter_width();
578     op->kheight = options.filter_height();
579     op->fused_activation_function =
580         ActivationFunction::Deserialize(options.fused_activation_function());
581   }
582 };
583 
584 class LocalResponseNormalization
585     : public BuiltinOperator<
586           LocalResponseNormalizationOperator,
587           ::tflite::LocalResponseNormalizationOptions,
588           ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
589  public:
590   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const591   flatbuffers::Offset<TfLiteOptions> WriteOptions(
592       const TocoOperator& op,
593       flatbuffers::FlatBufferBuilder* builder) const override {
594     return ::tflite::CreateLocalResponseNormalizationOptions(
595         *builder, op.range, op.bias, op.alpha, op.beta);
596   }
597 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const598   void ReadOptions(const TfLiteOptions& options,
599                    TocoOperator* op) const override {
600     op->range = options.radius();
601     op->bias = options.bias();
602     op->alpha = options.alpha();
603     op->beta = options.beta();
604   }
605 };
606 
607 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
608                                        ::tflite::BuiltinOptions_Pool2DOptions> {
609  public:
610   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const611   flatbuffers::Offset<TfLiteOptions> WriteOptions(
612       const TocoOperator& op,
613       flatbuffers::FlatBufferBuilder* builder) const override {
614     auto padding = Padding::Serialize(op.padding.type);
615     auto activation_function =
616         ActivationFunction::Serialize(op.fused_activation_function);
617     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
618                                          op.stride_height, op.kwidth,
619                                          op.kheight, activation_function);
620   }
621 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const622   void ReadOptions(const TfLiteOptions& options,
623                    TocoOperator* op) const override {
624     op->padding.type = Padding::Deserialize(options.padding());
625     op->stride_width = options.stride_w();
626     op->stride_height = options.stride_h();
627     op->kwidth = options.filter_width();
628     op->kheight = options.filter_height();
629     op->fused_activation_function =
630         ActivationFunction::Deserialize(options.fused_activation_function());
631   }
632 };
633 
634 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
635                                    ::tflite::BuiltinOptions_MulOptions> {
636  public:
637   using BuiltinOperator::BuiltinOperator;
638 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const639   flatbuffers::Offset<TfLiteOptions> WriteOptions(
640       const TocoOperator& op,
641       flatbuffers::FlatBufferBuilder* builder) const override {
642     auto activation_function =
643         ActivationFunction::Serialize(op.fused_activation_function);
644     return ::tflite::CreateMulOptions(*builder, activation_function);
645   }
646 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const647   void ReadOptions(const TfLiteOptions& options,
648                    TocoOperator* op) const override {
649     op->fused_activation_function =
650         ActivationFunction::Deserialize(options.fused_activation_function());
651   }
652 
GetVersion(const OperatorSignature & op_signature) const653   int GetVersion(const OperatorSignature& op_signature) const override {
654     const std::string& input1_name = op_signature.op->inputs[0];
655     const std::string& input2_name = op_signature.op->inputs[1];
656     const std::string& output_name = op_signature.op->outputs[0];
657     const Array& input1_array = op_signature.model->GetArray(input1_name);
658     const Array& input2_array = op_signature.model->GetArray(input2_name);
659     const Array& output_array = op_signature.model->GetArray(output_name);
660     const auto& input1_quant = input1_array.quantization_params;
661     const auto& input2_quant = input2_array.quantization_params;
662     const auto& output_quant = output_array.quantization_params;
663     const float input1_scale = input1_quant ? input1_quant->scale : 0.0f;
664     const float input2_scale = input2_quant ? input2_quant->scale : 0.0f;
665     const float output_scale = output_quant ? output_quant->scale : 0.0f;
666     ::tflite::OpSignature op_sig =
667         GetVersioningOpSig(builtin_op(), op_signature);
668     op_sig.ext_options.mul.input1_scale = input1_scale;
669     op_sig.ext_options.mul.input2_scale = input2_scale;
670     op_sig.ext_options.mul.output_scale = output_scale;
671     return ::tflite::GetBuiltinOperatorVersion(op_sig);
672   }
673 };
674 
675 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
676                                    ::tflite::BuiltinOptions_PadOptions> {
677  public:
678   using BuiltinOperator::BuiltinOperator;
679 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const680   flatbuffers::Offset<TfLiteOptions> WriteOptions(
681       const TocoOperator& op,
682       flatbuffers::FlatBufferBuilder* builder) const override {
683     return ::tflite::CreatePadOptions(*builder);
684   }
685 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const686   void ReadOptions(const TfLiteOptions& options,
687                    TocoOperator* op) const override {}
688 };
689 
690 class Tile
691     : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
692                              ::tflite::BuiltinOptions_TileOptions> {
693   using BuiltinOperator::BuiltinOperator;
694 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const695   flatbuffers::Offset<TfLiteOptions> WriteOptions(
696       const TocoOperator& op,
697       flatbuffers::FlatBufferBuilder* builder) const override {
698     return ::tflite::CreateTileOptions(*builder);
699   }
700 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const701   void ReadOptions(const TfLiteOptions& options,
702                    TocoOperator* op) const override {}
703 };
704 
705 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
706                                      ::tflite::BuiltinOptions_PadV2Options> {
707  public:
708   using BuiltinOperator::BuiltinOperator;
709 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const710   flatbuffers::Offset<TfLiteOptions> WriteOptions(
711       const TocoOperator& op,
712       flatbuffers::FlatBufferBuilder* builder) const override {
713     return ::tflite::CreatePadV2Options(*builder);
714   }
715 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const716   void ReadOptions(const TfLiteOptions& options,
717                    TocoOperator* op) const override {}
718 };
719 
720 class Reshape
721     : public BuiltinOperator<TensorFlowReshapeOperator,
722                              ::tflite::ReshapeOptions,
723                              ::tflite::BuiltinOptions_ReshapeOptions> {
724  public:
725   using BuiltinOperator::BuiltinOperator;
726 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const727   flatbuffers::Offset<TfLiteOptions> WriteOptions(
728       const TocoOperator& op,
729       flatbuffers::FlatBufferBuilder* builder) const override {
730     return ::tflite::CreateReshapeOptions(*builder,
731                                           builder->CreateVector(op.shape));
732   }
733 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const734   void ReadOptions(const TfLiteOptions& options,
735                    TocoOperator* op) const override {
736     op->shape.insert(op->shape.end(), options.new_shape()->begin(),
737                      options.new_shape()->end());
738   }
739 };
740 
741 class Softmax
742     : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
743                              ::tflite::BuiltinOptions_SoftmaxOptions> {
744  public:
745   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const746   flatbuffers::Offset<TfLiteOptions> WriteOptions(
747       const TocoOperator& op,
748       flatbuffers::FlatBufferBuilder* builder) const override {
749     return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
750   }
751 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const752   void ReadOptions(const TfLiteOptions& options,
753                    TocoOperator* op) const override {
754     op->beta = options.beta();
755   }
756 };
757 
758 class SpaceToDepth
759     : public BuiltinOperator<SpaceToDepthOperator,
760                              ::tflite::SpaceToDepthOptions,
761                              ::tflite::BuiltinOptions_SpaceToDepthOptions> {
762  public:
763   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const764   flatbuffers::Offset<TfLiteOptions> WriteOptions(
765       const TocoOperator& op,
766       flatbuffers::FlatBufferBuilder* builder) const override {
767     return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
768   }
769 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const770   void ReadOptions(const TfLiteOptions& options,
771                    TocoOperator* op) const override {
772     op->block_size = options.block_size();
773   }
774 };
775 
776 class Transpose
777     : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
778                              ::tflite::BuiltinOptions_TransposeOptions> {
779  public:
780   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const781   flatbuffers::Offset<TfLiteOptions> WriteOptions(
782       const TocoOperator& op,
783       flatbuffers::FlatBufferBuilder* builder) const override {
784     return ::tflite::CreateTransposeOptions(*builder);
785   }
786 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const787   void ReadOptions(const TfLiteOptions& options,
788                    TocoOperator* op) const override {}
789 };
790 
791 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
792                                     ::tflite::BuiltinOptions_LSTMOptions> {
793  public:
794   using BuiltinOperator::BuiltinOperator;
795 
GetKernelType(LstmCellOperator::KernelType type) const796   ::tflite::LSTMKernelType GetKernelType(
797       LstmCellOperator::KernelType type) const {
798     switch (type) {
799       case LstmCellOperator::KERNEL_BASIC:
800         return ::tflite::LSTMKernelType_BASIC;
801         break;
802       case LstmCellOperator::KERNEL_FULL:
803         return ::tflite::LSTMKernelType_FULL;
804         break;
805       default:
806         LOG(ERROR) << "Unhandled Kernel Type";
807         return static_cast<::tflite::LSTMKernelType>(-1);
808     }
809   }
810 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const811   flatbuffers::Offset<TfLiteOptions> WriteOptions(
812       const TocoOperator& op,
813       flatbuffers::FlatBufferBuilder* builder) const override {
814     ::tflite::LSTMKernelType kernel_type = GetKernelType(op.kernel_type);
815 
816     // Current toco converter only supports tanh, no clip.
817     return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
818                                        ::tflite::ActivationFunctionType_TANH,
819                                        /*cell_clip=*/0.0,
820                                        /*proj_clip=*/0.0, kernel_type);
821   }
822 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const823   void ReadOptions(const TfLiteOptions& options,
824                    TocoOperator* op) const override {
825     // Only support tanh activation, so check that tflite type is tanh.
826     CHECK(options.fused_activation_function() ==
827           ::tflite::ActivationFunctionType_TANH);
828 
829     switch (options.kernel_type()) {
830       case ::tflite::LSTMKernelType_BASIC:
831         op->kernel_type = LstmCellOperator::KERNEL_BASIC;
832         break;
833       case ::tflite::LSTMKernelType_FULL:
834         op->kernel_type = LstmCellOperator::KERNEL_FULL;
835         break;
836     }
837   }
838 
GetVersion(const OperatorSignature & op_signature) const839   int GetVersion(const OperatorSignature& op_signature) const override {
840     const auto& lstm_op =
841         static_cast<const LstmCellOperator&>(*op_signature.op);
842     ::tflite::OpSignature op_sig =
843         GetVersioningOpSig(builtin_op(), op_signature);
844     TfLiteLSTMParams lstm_params = {};
845     lstm_params.kernel_type =
846         static_cast<TfLiteLSTMKernelType>(GetKernelType(lstm_op.kernel_type));
847     op_sig.builtin_data = reinterpret_cast<void*>(&lstm_params);
848     return ::tflite::GetBuiltinOperatorVersion(op_sig);
849   }
850 
GetMutatingInputVariables(const Operator & op) const851   std::vector<bool> GetMutatingInputVariables(
852       const Operator& op) const override {
853     const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
854 
855     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
856     switch (lstm_op.kernel_type) {
857       case LstmCellOperator::KERNEL_FULL: {
858         mutating_input_variables[kInputActivationStateTensor] = true;
859         mutating_input_variables[kInputCellStateTensor] = true;
860         break;
861       }
862       case LstmCellOperator::KERNEL_BASIC: {
863         mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
864         mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
865         break;
866       }
867     }
868     return mutating_input_variables;
869   }
870 };
871 
872 class UnidirectionalSequenceLstm
873     : public BuiltinOperator<
874           UnidirectionalSequenceLstmOperator,
875           ::tflite::UnidirectionalSequenceLSTMOptions,
876           ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
877  public:
878   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const879   flatbuffers::Offset<TfLiteOptions> WriteOptions(
880       const TocoOperator& op,
881       flatbuffers::FlatBufferBuilder* builder) const override {
882     // Current toco converter only supports tanh, no clip.
883     return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
884         *builder, /*fused_activation_function=*/
885         ::tflite::ActivationFunctionType_TANH,
886         /*cell_clip=*/0.0,
887         /*proj_clip=*/0.0,
888         /*time_major=*/true);
889   }
890 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const891   void ReadOptions(const TfLiteOptions& options,
892                    TocoOperator* op) const override {
893     // Only support tanh activation, so check that tflite type is tanh.
894     DCHECK(options.fused_activation_function() ==
895            ::tflite::ActivationFunctionType_TANH);
896   }
897 
GetMutatingInputVariables(const Operator & op) const898   std::vector<bool> GetMutatingInputVariables(
899       const Operator& op) const override {
900     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
901     mutating_input_variables[kInputActivationStateTensor] = true;
902     mutating_input_variables[kInputCellStateTensor] = true;
903     return mutating_input_variables;
904   }
905 };
906 
907 class BidirectionalSequenceLstm
908     : public BuiltinOperator<
909           BidirectionalSequenceLstmOperator,
910           ::tflite::BidirectionalSequenceLSTMOptions,
911           ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
912  public:
913   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const914   flatbuffers::Offset<TfLiteOptions> WriteOptions(
915       const TocoOperator& op,
916       flatbuffers::FlatBufferBuilder* builder) const override {
917     // Current toco converter only supports tanh, no clip.
918     return ::tflite::CreateBidirectionalSequenceLSTMOptions(
919         *builder, /*fused_activation_function=*/
920         ::tflite::ActivationFunctionType_TANH,
921         /*cell_clip=*/0.0,
922         /*proj_clip=*/0.0,
923         /*merge_outputs=*/op.merge_outputs,
924         /*time_major=*/true);
925   }
926 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const927   void ReadOptions(const TfLiteOptions& options,
928                    TocoOperator* op) const override {
929     // Only support tanh activation, so check that tflite type is tanh.
930     DCHECK(options.fused_activation_function() ==
931            ::tflite::ActivationFunctionType_TANH);
932     op->merge_outputs = options.merge_outputs();
933   }
934 
GetMutatingInputVariables(const Operator & op) const935   std::vector<bool> GetMutatingInputVariables(
936       const Operator& op) const override {
937     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
938     // Forward input activation state.
939     mutating_input_variables[35] = true;
940     // Forward input cell state.
941     mutating_input_variables[36] = true;
942     // Backward input activation state.
943     mutating_input_variables[37] = true;
944     // Backward input cell state.
945     mutating_input_variables[38] = true;
946     return mutating_input_variables;
947   }
948 };
949 
950 class BidirectionalSequenceRnn
951     : public BuiltinOperator<
952           BidirectionalSequenceRnnOperator,
953           ::tflite::BidirectionalSequenceRNNOptions,
954           ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
955  public:
956   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const957   flatbuffers::Offset<TfLiteOptions> WriteOptions(
958       const TocoOperator& op,
959       flatbuffers::FlatBufferBuilder* builder) const override {
960     // Current toco converter only supports tanh, no clip.
961     return ::tflite::CreateBidirectionalSequenceRNNOptions(
962         *builder, /*time_major=*/true,
963         /*fused_activation_function=*/
964         ::tflite::ActivationFunctionType_TANH,
965         /*merge_outputs=*/op.merge_outputs);
966   }
967 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const968   void ReadOptions(const TfLiteOptions& options,
969                    TocoOperator* op) const override {
970     // Only support tanh activation, so check that tflite type is tanh.
971     DCHECK(options.fused_activation_function() ==
972            ::tflite::ActivationFunctionType_TANH);
973     op->merge_outputs = options.merge_outputs();
974   }
975 
GetMutatingInputVariables(const Operator & op) const976   std::vector<bool> GetMutatingInputVariables(
977       const Operator& op) const override {
978     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
979     // Forward hidden state.
980     mutating_input_variables[4] = true;
981     // Backward hidden state.
982     mutating_input_variables[8] = true;
983     return mutating_input_variables;
984   }
985 };
986 
987 class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
988                                     ::tflite::BuiltinOptions_ReducerOptions> {
989  public:
990   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const991   flatbuffers::Offset<TfLiteOptions> WriteOptions(
992       const TocoOperator& op,
993       flatbuffers::FlatBufferBuilder* builder) const override {
994     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
995   }
996 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const997   void ReadOptions(const TfLiteOptions& options,
998                    TocoOperator* op) const override {
999     op->keep_dims = options.keep_dims();
1000   }
1001 };
1002 
1003 class Sum
1004     : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
1005                              ::tflite::BuiltinOptions_ReducerOptions> {
1006  public:
1007   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1008   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1009       const TocoOperator& op,
1010       flatbuffers::FlatBufferBuilder* builder) const override {
1011     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1012   }
1013 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1014   void ReadOptions(const TfLiteOptions& options,
1015                    TocoOperator* op) const override {
1016     op->keep_dims = options.keep_dims();
1017   }
1018 };
1019 
1020 class ReduceMax
1021     : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
1022                              ::tflite::BuiltinOptions_ReducerOptions> {
1023  public:
1024   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1025   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1026       const TocoOperator& op,
1027       flatbuffers::FlatBufferBuilder* builder) const override {
1028     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1029   }
1030 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1031   void ReadOptions(const TfLiteOptions& options,
1032                    TocoOperator* op) const override {
1033     op->keep_dims = options.keep_dims();
1034   }
1035 };
1036 
1037 class ReduceMin
1038     : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
1039                              ::tflite::BuiltinOptions_ReducerOptions> {
1040  public:
1041   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1042   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1043       const TocoOperator& op,
1044       flatbuffers::FlatBufferBuilder* builder) const override {
1045     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1046   }
1047 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1048   void ReadOptions(const TfLiteOptions& options,
1049                    TocoOperator* op) const override {
1050     op->keep_dims = options.keep_dims();
1051   }
1052 };
1053 
1054 class ReduceProd
1055     : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
1056                              ::tflite::BuiltinOptions_ReducerOptions> {
1057  public:
1058   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1059   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1060       const TocoOperator& op,
1061       flatbuffers::FlatBufferBuilder* builder) const override {
1062     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1063   }
1064 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1065   void ReadOptions(const TfLiteOptions& options,
1066                    TocoOperator* op) const override {
1067     op->keep_dims = options.keep_dims();
1068   }
1069 };
1070 
1071 class ReduceAny
1072     : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
1073                              ::tflite::BuiltinOptions_ReducerOptions> {
1074  public:
1075   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1076   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1077       const TocoOperator& op,
1078       flatbuffers::FlatBufferBuilder* builder) const override {
1079     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1080   }
1081 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1082   void ReadOptions(const TfLiteOptions& options,
1083                    TocoOperator* op) const override {
1084     op->keep_dims = options.keep_dims();
1085   }
1086 };
1087 
1088 class ResizeBilinear
1089     : public BuiltinOperator<ResizeBilinearOperator,
1090                              ::tflite::ResizeBilinearOptions,
1091                              ::tflite::BuiltinOptions_ResizeBilinearOptions> {
1092  public:
1093   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1094   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1095       const TocoOperator& op,
1096       flatbuffers::FlatBufferBuilder* builder) const override {
1097     return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners,
1098                                                  op.half_pixel_centers);
1099   }
1100 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1101   void ReadOptions(const TfLiteOptions& options,
1102                    TocoOperator* op) const override {
1103     op->align_corners = options.align_corners();
1104     op->half_pixel_centers = options.half_pixel_centers();
1105   }
1106 
GetVersion(const OperatorSignature & op_signature) const1107   int GetVersion(const OperatorSignature& op_signature) const override {
1108     const auto& resize_bilinear_op =
1109         static_cast<const ResizeBilinearOperator&>(*op_signature.op);
1110     ::tflite::OpSignature op_sig =
1111         GetVersioningOpSig(builtin_op(), op_signature);
1112     TfLiteResizeBilinearParams resize_bilinear_params = {};
1113     resize_bilinear_params.half_pixel_centers =
1114         resize_bilinear_op.half_pixel_centers;
1115     resize_bilinear_params.align_corners = resize_bilinear_op.align_corners;
1116     op_sig.builtin_data = reinterpret_cast<void*>(&resize_bilinear_params);
1117     return ::tflite::GetBuiltinOperatorVersion(op_sig);
1118   }
1119 };
1120 
1121 class ResizeNearestNeighbor
1122     : public BuiltinOperator<
1123           ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
1124           ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
1125  public:
1126   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1127   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1128       const TocoOperator& op,
1129       flatbuffers::FlatBufferBuilder* builder) const override {
1130     return ::tflite::CreateResizeNearestNeighborOptions(
1131         *builder, op.align_corners, op.half_pixel_centers);
1132   }
1133 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1134   void ReadOptions(const TfLiteOptions& options,
1135                    TocoOperator* op) const override {
1136     op->align_corners = options.align_corners();
1137     op->half_pixel_centers = options.half_pixel_centers();
1138   }
1139 
GetVersion(const OperatorSignature & op_signature) const1140   int GetVersion(const OperatorSignature& op_signature) const override {
1141     const auto& resize_nn_op =
1142         static_cast<const ResizeNearestNeighborOperator&>(*op_signature.op);
1143     ::tflite::OpSignature op_sig =
1144         GetVersioningOpSig(builtin_op(), op_signature);
1145     TfLiteResizeNearestNeighborParams resize_nearest_neighbor_params = {};
1146     resize_nearest_neighbor_params.half_pixel_centers =
1147         resize_nn_op.half_pixel_centers;
1148     resize_nearest_neighbor_params.align_corners = resize_nn_op.align_corners;
1149     op_sig.builtin_data =
1150         reinterpret_cast<void*>(&resize_nearest_neighbor_params);
1151     return ::tflite::GetBuiltinOperatorVersion(op_sig);
1152   }
1153 };
1154 
1155 class Squeeze
1156     : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
1157                              ::tflite::BuiltinOptions_SqueezeOptions> {
1158  public:
1159   using BuiltinOperator::BuiltinOperator;
1160 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1161   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1162       const TocoOperator& op,
1163       flatbuffers::FlatBufferBuilder* builder) const override {
1164     auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
1165     return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
1166   }
1167 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1168   void ReadOptions(const TfLiteOptions& options,
1169                    TocoOperator* op) const override {
1170     op->squeeze_dims.insert(op->squeeze_dims.end(),
1171                             options.squeeze_dims()->begin(),
1172                             options.squeeze_dims()->end());
1173   }
1174 };
1175 
1176 class Split
1177     : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
1178                              ::tflite::BuiltinOptions_SplitOptions> {
1179  public:
1180   using BuiltinOperator::BuiltinOperator;
1181 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1182   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1183       const TocoOperator& op,
1184       flatbuffers::FlatBufferBuilder* builder) const override {
1185     return ::tflite::CreateSplitOptions(*builder, op.num_split);
1186   }
1187 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1188   void ReadOptions(const TfLiteOptions& options,
1189                    TocoOperator* op) const override {
1190     op->num_split = options.num_splits();
1191   }
1192 };
1193 
1194 class SplitV
1195     : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
1196                              ::tflite::BuiltinOptions_SplitVOptions> {
1197  public:
1198   using BuiltinOperator::BuiltinOperator;
1199 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1200   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1201       const TocoOperator& op,
1202       flatbuffers::FlatBufferBuilder* builder) const override {
1203     return ::tflite::CreateSplitVOptions(*builder, op.num_split);
1204   }
1205 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1206   void ReadOptions(const TfLiteOptions& options,
1207                    TocoOperator* op) const override {
1208     op->num_split = options.num_splits();
1209   }
1210 };
1211 
1212 class StridedSlice
1213     : public BuiltinOperator<StridedSliceOperator,
1214                              ::tflite::StridedSliceOptions,
1215                              ::tflite::BuiltinOptions_StridedSliceOptions> {
1216  public:
1217   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1218   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1219       const TocoOperator& op,
1220       flatbuffers::FlatBufferBuilder* builder) const override {
1221     return ::tflite::CreateStridedSliceOptions(
1222         *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
1223         op.new_axis_mask, op.shrink_axis_mask);
1224   }
1225 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1226   void ReadOptions(const TfLiteOptions& options,
1227                    TocoOperator* op) const override {
1228     op->begin_mask = options.begin_mask();
1229     op->end_mask = options.end_mask();
1230     op->ellipsis_mask = options.ellipsis_mask();
1231     op->new_axis_mask = options.new_axis_mask();
1232     op->shrink_axis_mask = options.shrink_axis_mask();
1233   }
1234 
GetVersion(const OperatorSignature & op_signature) const1235   int GetVersion(const OperatorSignature& op_signature) const override {
1236     const auto& ss_op =
1237         static_cast<const StridedSliceOperator&>(*op_signature.op);
1238     ::tflite::OpSignature op_sig =
1239         GetVersioningOpSig(builtin_op(), op_signature);
1240     op_sig.ext_options.strided_slice.num_dims = ss_op.start_indices.size();
1241     TfLiteStridedSliceParams strided_slice_params = {};
1242     strided_slice_params.ellipsis_mask = ss_op.ellipsis_mask;
1243     strided_slice_params.new_axis_mask = ss_op.new_axis_mask;
1244     op_sig.builtin_data = reinterpret_cast<void*>(&strided_slice_params);
1245     return ::tflite::GetBuiltinOperatorVersion(op_sig);
1246   }
1247 };
1248 
1249 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
1250                                        ::tflite::BuiltinOptions_TopKV2Options> {
1251  public:
1252   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1253   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1254       const TocoOperator& op,
1255       flatbuffers::FlatBufferBuilder* builder) const override {
1256     return ::tflite::CreateTopKV2Options(*builder);
1257   }
1258 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1259   void ReadOptions(const TfLiteOptions& options,
1260                    TocoOperator* op) const override {}
1261 };
1262 
1263 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
1264                                       ::tflite::BuiltinOptions_ArgMaxOptions> {
1265  public:
1266   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1267   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1268       const TocoOperator& op,
1269       flatbuffers::FlatBufferBuilder* builder) const override {
1270     return ::tflite::CreateArgMaxOptions(
1271         *builder, DataType::Serialize(op.output_data_type));
1272   }
1273 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1274   void ReadOptions(const TfLiteOptions& options,
1275                    TocoOperator* op) const override {
1276     op->output_data_type = DataType::Deserialize(options.output_type());
1277   }
1278 };
1279 
1280 class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
1281                                       ::tflite::BuiltinOptions_ArgMinOptions> {
1282  public:
1283   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1284   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1285       const TocoOperator& op,
1286       flatbuffers::FlatBufferBuilder* builder) const override {
1287     return ::tflite::CreateArgMinOptions(
1288         *builder, DataType::Serialize(op.output_data_type));
1289   }
1290 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1291   void ReadOptions(const TfLiteOptions& options,
1292                    TocoOperator* op) const override {
1293     op->output_data_type = DataType::Deserialize(options.output_type());
1294   }
1295 };
1296 
1297 class TransposeConv
1298     : public BuiltinOperator<TransposeConvOperator,
1299                              ::tflite::TransposeConvOptions,
1300                              ::tflite::BuiltinOptions_TransposeConvOptions> {
1301  public:
1302   using BuiltinOperator::BuiltinOperator;
1303 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1304   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1305       const TocoOperator& op,
1306       flatbuffers::FlatBufferBuilder* builder) const override {
1307     auto padding = Padding::Serialize(op.padding.type);
1308     return ::tflite::CreateTransposeConvOptions(
1309         *builder, padding, op.stride_width, op.stride_height);
1310   }
1311 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1312   void ReadOptions(const TfLiteOptions& options,
1313                    TocoOperator* op) const override {
1314     op->padding.type = Padding::Deserialize(options.padding());
1315     op->stride_width = options.stride_w();
1316     op->stride_height = options.stride_h();
1317   }
1318 };
1319 
1320 class SparseToDense
1321     : public BuiltinOperator<SparseToDenseOperator,
1322                              ::tflite::SparseToDenseOptions,
1323                              ::tflite::BuiltinOptions_SparseToDenseOptions> {
1324  public:
1325   using BuiltinOperator::BuiltinOperator;
1326 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1327   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1328       const TocoOperator& op,
1329       flatbuffers::FlatBufferBuilder* builder) const override {
1330     return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
1331   }
1332 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1333   void ReadOptions(const TfLiteOptions& options,
1334                    TocoOperator* op) const override {
1335     op->validate_indices = options.validate_indices();
1336   }
1337 };
1338 
1339 class ExpandDims
1340     : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
1341                              ::tflite::BuiltinOptions_ExpandDimsOptions> {
1342  public:
1343   using BuiltinOperator::BuiltinOperator;
1344 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1345   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1346       const TocoOperator& op,
1347       flatbuffers::FlatBufferBuilder* builder) const override {
1348     return ::tflite::CreateExpandDimsOptions(*builder);
1349   }
1350 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1351   void ReadOptions(const TfLiteOptions& options,
1352                    TocoOperator* op) const override {}
1353 };
1354 
1355 class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
1356                                     ::tflite::BuiltinOptions_PackOptions> {
1357  public:
1358   using BuiltinOperator::BuiltinOperator;
1359 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1360   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1361       const TocoOperator& op,
1362       flatbuffers::FlatBufferBuilder* builder) const override {
1363     return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
1364   }
1365 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1366   void ReadOptions(const TfLiteOptions& options,
1367                    TocoOperator* op) const override {
1368     op->values_count = options.values_count();
1369     op->axis = options.axis();
1370   }
1371 };
1372 
1373 class Shape
1374     : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
1375                              ::tflite::BuiltinOptions_ShapeOptions> {
1376  public:
1377   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1378   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1379       const TocoOperator& op,
1380       flatbuffers::FlatBufferBuilder* builder) const override {
1381     return ::tflite::CreateShapeOptions(
1382         *builder, DataType::Serialize(op.output_data_type));
1383   }
1384 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1385   void ReadOptions(const TfLiteOptions& options,
1386                    TocoOperator* op) const override {
1387     op->output_data_type = DataType::Deserialize(options.out_type());
1388   }
1389 };
1390 
1391 class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
1392                                       ::tflite::BuiltinOptions_OneHotOptions> {
1393  public:
1394   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1395   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1396       const TocoOperator& op,
1397       flatbuffers::FlatBufferBuilder* builder) const override {
1398     return ::tflite::CreateOneHotOptions(*builder, op.axis);
1399   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1400   void ReadOptions(const TfLiteOptions& options,
1401                    TocoOperator* op) const override {
1402     op->axis = options.axis();
1403   }
1404 };
1405 
1406 class CTCBeamSearchDecoder
1407     : public CustomOperator<CTCBeamSearchDecoderOperator> {
1408  public:
1409   using CustomOperator::CustomOperator;
1410 
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const1411   void WriteOptions(const TocoOperator& op,
1412                     flexbuffers::Builder* fbb) const override {
1413     fbb->Int("beam_width", op.beam_width);
1414     fbb->Int("top_paths", op.top_paths);
1415     fbb->Bool("merge_repeated", op.merge_repeated);
1416   }
1417 
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const1418   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
1419     op->beam_width = m["beam_width"].AsInt32();
1420     op->top_paths = m["top_paths"].AsInt32();
1421     op->merge_repeated = m["merge_repeated"].AsBool();
1422   }
1423 
GetVersion(const OperatorSignature & op_signature) const1424   int GetVersion(const OperatorSignature& op_signature) const override {
1425     return 1;
1426   }
1427 };
1428 
1429 class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
1430                                       ::tflite::BuiltinOptions_UnpackOptions> {
1431  public:
1432   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1433   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1434       const TocoOperator& op,
1435       flatbuffers::FlatBufferBuilder* builder) const override {
1436     return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
1437   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1438   void ReadOptions(const TfLiteOptions& options,
1439                    TocoOperator* op) const override {
1440     op->num = options.num();
1441     op->axis = options.axis();
1442   }
1443 
GetVersion(const OperatorSignature & op_signature) const1444   int GetVersion(const OperatorSignature& op_signature) const override {
1445     const std::string& input_name = op_signature.op->inputs[0];
1446     const Array& input_array = op_signature.model->GetArray(input_name);
1447     // If the op take int8/uint8 input, it is version 2.
1448     if (input_array.data_type == ArrayDataType::kInt8 ||
1449         input_array.data_type == ArrayDataType::kUint8) {
1450       return 2;
1451     }
1452     // If the op take bool input, it is version 3.
1453     if (input_array.data_type == ArrayDataType::kBool) {
1454       return 3;
1455     }
1456     return 1;
1457   }
1458 };
1459 
1460 class LeakyRelu
1461     : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
1462                              ::tflite::BuiltinOptions_LeakyReluOptions> {
1463  public:
1464   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1465   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1466       const TocoOperator& op,
1467       flatbuffers::FlatBufferBuilder* builder) const override {
1468     return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
1469   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1470   void ReadOptions(const TfLiteOptions& options,
1471                    TocoOperator* op) const override {
1472     op->alpha = options.alpha();
1473   }
1474 };
1475 
1476 class SquaredDifference
1477     : public BuiltinOperator<
1478           SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
1479           ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
1480  public:
1481   using BuiltinOperator::BuiltinOperator;
1482 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1483   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1484       const TocoOperator& op,
1485       flatbuffers::FlatBufferBuilder* builder) const override {
1486     return ::tflite::CreateSquaredDifferenceOptions(*builder);
1487   }
1488 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1489   void ReadOptions(const TfLiteOptions& options,
1490                    TocoOperator* op) const override {}
1491 };
1492 
1493 class MirrorPad
1494     : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
1495                              ::tflite::BuiltinOptions_MirrorPadOptions> {
1496  public:
1497   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1498   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1499       const TocoOperator& op,
1500       flatbuffers::FlatBufferBuilder* builder) const override {
1501     return ::tflite::CreateMirrorPadOptions(
1502         *builder, op.mode == MirrorPadMode::kReflect
1503                       ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1504                       : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC);
1505   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1506   void ReadOptions(const TfLiteOptions& options,
1507                    TocoOperator* op) const override {
1508     op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1509                    ? MirrorPadMode::kReflect
1510                    : MirrorPadMode::kSymmetric;
1511   }
1512 };
1513 
1514 class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
1515                                       ::tflite::BuiltinOptions_UniqueOptions> {
1516  public:
1517   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1518   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1519       const TocoOperator& op,
1520       flatbuffers::FlatBufferBuilder* builder) const override {
1521     const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
1522     return ::tflite::CreateUniqueOptions(
1523         *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
1524                       ? ::tflite::TensorType::TensorType_INT64
1525                       : ::tflite::TensorType_INT32);
1526   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1527   void ReadOptions(const TfLiteOptions& options,
1528                    TocoOperator* op) const override {
1529     UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
1530     unique_op->idx_out_type =
1531         options.idx_out_type() == ::tflite::TensorType_INT64
1532             ? toco::ArrayDataType::kInt64
1533             : toco::ArrayDataType::kInt32;
1534   }
1535 };
1536 
1537 class UnidirectionalSequenceRnn
1538     : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
1539                              ::tflite::SequenceRNNOptions,
1540                              ::tflite::BuiltinOptions_SequenceRNNOptions> {
1541  public:
1542   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1543   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1544       const TocoOperator& op,
1545       flatbuffers::FlatBufferBuilder* builder) const override {
1546     return ::tflite::CreateSequenceRNNOptions(
1547         *builder, /*time_major=*/true,
1548         /*fused_activation_function=*/
1549         ::tflite::ActivationFunctionType_TANH);
1550   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1551   void ReadOptions(const TfLiteOptions& options,
1552                    TocoOperator* op) const override {
1553     // Only support tanh activation, so check that tflite type is tanh.
1554     DCHECK(options.fused_activation_function() ==
1555            ::tflite::ActivationFunctionType_TANH);
1556   }
1557 
GetMutatingInputVariables(const Operator & op) const1558   std::vector<bool> GetMutatingInputVariables(
1559       const Operator& op) const override {
1560     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1561     mutating_input_variables[4] = true;
1562     return mutating_input_variables;
1563   }
1564 };
1565 
1566 class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
1567                                      ::tflite::BuiltinOptions_WhereOptions> {
1568  public:
1569   using BuiltinOperator::BuiltinOperator;
1570 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1571   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1572       const TocoOperator& op,
1573       flatbuffers::FlatBufferBuilder* builder) const override {
1574     return ::tflite::CreateWhereOptions(*builder);
1575   }
1576 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1577   void ReadOptions(const TfLiteOptions& options,
1578                    TocoOperator* op) const override {}
1579 };
1580 
WriteFlexOpOptions(const std::string & tensorflow_node_def)1581 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
1582     const std::string& tensorflow_node_def) {
1583   auto fbb = std::make_unique<flexbuffers::Builder>();
1584 
1585   ::tensorflow::NodeDef node_def;
1586   if (!node_def.ParseFromString(tensorflow_node_def)) {
1587     LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1588     return {};
1589   }
1590 
1591   fbb->Vector([&]() {
1592     fbb->String(node_def.op());
1593     fbb->String(tensorflow_node_def);
1594   });
1595   fbb->Finish();
1596   LOG(INFO) << "Writing flex op: " << node_def.op();
1597   return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1598 }
1599 
1600 class TensorFlowUnsupported : public BaseOperator {
1601  public:
TensorFlowUnsupported(const std::string & name,OperatorType type,bool enable_select_tf_ops)1602   TensorFlowUnsupported(const std::string& name, OperatorType type,
1603                         bool enable_select_tf_ops)
1604       : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
1605 
Serialize(const Operator & op,flatbuffers::FlatBufferBuilder * builder) const1606   Options Serialize(const Operator& op,
1607                     flatbuffers::FlatBufferBuilder* builder) const override {
1608     auto fbb =
1609         WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
1610     if (fbb) {
1611       return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
1612     } else {
1613       return Options::Custom(0);
1614     }
1615   }
1616 
Deserialize(const BuiltinOptions * builtin_options,const CustomOptions * custom_options) const1617   std::unique_ptr<Operator> Deserialize(
1618       const BuiltinOptions* builtin_options,
1619       const CustomOptions* custom_options) const override {
1620     // Deserializing Flex ops doesn't work now.
1621     // TODO(ycling): Revisit and decide if we should fix the flow for importing
1622     // TFLite models with Flex ops.
1623     auto op = std::make_unique<TensorFlowUnsupportedOperator>();
1624     if (custom_options) {
1625       auto flexbuffer_map =
1626           flexbuffers::GetRoot(custom_options->data(), custom_options->size())
1627               .AsMap();
1628       ReadOptions(flexbuffer_map, op.get());
1629     }
1630     return std::unique_ptr<Operator>(op.release());
1631   }
1632 
WriteOptions(const TensorFlowUnsupportedOperator & op) const1633   std::unique_ptr<flexbuffers::Builder> WriteOptions(
1634       const TensorFlowUnsupportedOperator& op) const {
1635     if (enable_select_tf_ops_) {
1636       return WriteFlexOpOptions(op.tensorflow_node_def);
1637     }
1638     auto fbb = std::make_unique<flexbuffers::Builder>();
1639 
1640     ::tensorflow::NodeDef node_def;
1641     if (!node_def.ParseFromString(op.tensorflow_node_def)) {
1642       LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1643       return std::unique_ptr<flexbuffers::Builder>();
1644     }
1645 
1646     if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
1647       fbb->Vector([&]() {
1648         fbb->String(node_def.op());
1649         fbb->String(op.tensorflow_node_def);
1650       });
1651       fbb->Finish();
1652       LOG(INFO) << "Writing flex op: " << node_def.op();
1653       return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1654     }
1655 
1656     bool has_valid_attr = false;
1657     size_t map_start = fbb->StartMap();
1658     for (const auto& pair : node_def.attr()) {
1659       const char* key = pair.first.c_str();
1660       const auto& attr = pair.second;
1661       switch (attr.value_case()) {
1662         case ::tensorflow::AttrValue::kS:
1663           fbb->String(key, attr.s());
1664           has_valid_attr = true;
1665           break;
1666         case ::tensorflow::AttrValue::kI:
1667           fbb->Int(key, attr.i());
1668           has_valid_attr = true;
1669           break;
1670         case ::tensorflow::AttrValue::kF:
1671           fbb->Float(key, attr.f());
1672           has_valid_attr = true;
1673           break;
1674         case ::tensorflow::AttrValue::kB:
1675           fbb->Bool(key, attr.b());
1676           has_valid_attr = true;
1677           break;
1678         case tensorflow::AttrValue::kList:
1679           if (attr.list().s_size() > 0) {
1680             auto start = fbb->StartVector(key);
1681             for (const std::string& v : attr.list().s()) {
1682               fbb->Add(v);
1683             }
1684             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1685             has_valid_attr = true;
1686           } else if (attr.list().i_size() > 0) {
1687             auto start = fbb->StartVector(key);
1688             for (const int64_t v : attr.list().i()) {
1689               fbb->Add(v);
1690             }
1691             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1692             has_valid_attr = true;
1693           } else if (attr.list().f_size() > 0) {
1694             auto start = fbb->StartVector(key);
1695             for (const float v : attr.list().f()) {
1696               fbb->Add(v);
1697             }
1698             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1699             has_valid_attr = true;
1700           } else {
1701             LOG(WARNING)
1702                 << "Ignoring unsupported type in list attribute with key '"
1703                 << key << "'";
1704           }
1705           break;
1706         default:
1707           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1708                        << key << "'";
1709           break;
1710       }
1711     }
1712     if (!has_valid_attr) {
1713       return std::unique_ptr<flexbuffers::Builder>();
1714     }
1715     fbb->EndMap(map_start);
1716     fbb->Finish();
1717     return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1718   }
1719 
ReadOptions(const flexbuffers::Map & m,TensorFlowUnsupportedOperator * op) const1720   void ReadOptions(const flexbuffers::Map& m,
1721                    TensorFlowUnsupportedOperator* op) const {
1722     ::tensorflow::NodeDef node_def;
1723     auto attr = node_def.mutable_attr();
1724 
1725     const auto& keys = m.Keys();
1726     for (size_t i = 0; i < keys.size(); ++i) {
1727       const auto key = keys[i].AsKey();
1728       const auto& value = m[key];
1729       switch (value.GetType()) {
1730         case flexbuffers::FBT_STRING:
1731           (*attr)[key].set_s(value.AsString().c_str());
1732           break;
1733         case flexbuffers::FBT_INT:
1734           (*attr)[key].set_i(value.AsInt64());
1735           break;
1736         case flexbuffers::FBT_FLOAT:
1737           (*attr)[key].set_f(value.AsFloat());
1738           break;
1739         case flexbuffers::FBT_BOOL:
1740           (*attr)[key].set_b(value.AsBool());
1741           if (std::string(key) == "_output_quantized") {
1742             op->quantized = value.AsBool();
1743           }
1744           if (std::string(key) ==
1745               "_support_output_type_float_in_quantized_op") {
1746             op->support_output_type_float_in_quantized_op = value.AsBool();
1747           }
1748           break;
1749         case flexbuffers::FBT_VECTOR_INT: {
1750           auto* list = (*attr)[key].mutable_list();
1751           const auto& vector = value.AsTypedVector();
1752           for (size_t i = 0; i < vector.size(); i++) {
1753             list->add_i(vector[i].AsInt64());
1754           }
1755           break;
1756         }
1757         case flexbuffers::FBT_VECTOR_FLOAT: {
1758           auto* list = (*attr)[key].mutable_list();
1759           const auto& vector = value.AsTypedVector();
1760           for (size_t i = 0; i < vector.size(); i++) {
1761             list->add_f(vector[i].AsFloat());
1762           }
1763           break;
1764         }
1765         case 15 /* TO_DO(wvo): flexbuffers::FBT_VECTOR_STRING_DEPRECATED*/: {
1766           auto* list = (*attr)[key].mutable_list();
1767           const auto& vector = value.AsTypedVector();
1768           for (size_t i = 0; i < vector.size(); i++) {
1769             list->add_s(vector[i].AsString().str());
1770           }
1771           break;
1772         }
1773         default:
1774           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1775                        << key << "'";
1776           break;
1777       }
1778     }
1779     node_def.SerializeToString(&op->tensorflow_node_def);
1780   }
1781 
GetVersion(const OperatorSignature & op_signature) const1782   int GetVersion(const OperatorSignature& op_signature) const override {
1783     // TODO(ycling): Design and implement a way to plumb the version of
1784     // custom ops.
1785     return 1;
1786   }
1787 
1788  private:
1789   const bool enable_select_tf_ops_;
1790 };
1791 
1792 class Dequantize
1793     : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
1794                              ::tflite::BuiltinOptions_DequantizeOptions> {
1795  public:
1796   using BuiltinOperator::BuiltinOperator;
1797 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1798   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1799       const TocoOperator& op,
1800       flatbuffers::FlatBufferBuilder* builder) const override {
1801     return ::tflite::CreateDequantizeOptions(*builder);
1802   }
1803 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1804   void ReadOptions(const TfLiteOptions& options,
1805                    TocoOperator* op) const override {}
1806 };
1807 
1808 class ReverseSequence
1809     : public BuiltinOperator<ReverseSequenceOperator,
1810                              ::tflite::ReverseSequenceOptions,
1811                              ::tflite::BuiltinOptions_ReverseSequenceOptions> {
1812  public:
1813   using BuiltinOperator::BuiltinOperator;
1814 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1815   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1816       const TocoOperator& op,
1817       flatbuffers::FlatBufferBuilder* builder) const override {
1818     return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim,
1819                                                   op.batch_dim);
1820   }
1821 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1822   void ReadOptions(const TfLiteOptions& options,
1823                    TocoOperator* op) const override {
1824     op->seq_dim = options.seq_dim();
1825     op->batch_dim = options.batch_dim();
1826   }
1827 };
1828 
1829 namespace {
1830 // Build a vector containing all the known operators.
BuildOperatorList(bool enable_select_tf_ops=false)1831 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
1832     bool enable_select_tf_ops = false) {
1833   std::vector<std::unique_ptr<BaseOperator>> ops;
1834   using tensorflow::MakeUnique;
1835   // Builtin Operators.
1836   ops.push_back(
1837       MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
1838   ops.push_back(
1839       MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
1840   ops.push_back(
1841       MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
1842   ops.push_back(
1843       MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
1844   ops.push_back(MakeUnique<AveragePool>(
1845       ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
1846   ops.push_back(
1847       MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
1848                                  OperatorType::kSpaceToBatchND));
1849   ops.push_back(
1850       MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
1851                                  OperatorType::kBatchToSpaceND));
1852   ops.push_back(MakeUnique<Concatenation>(
1853       ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
1854   ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
1855                                         OperatorType::kConv));
1856   ops.push_back(MakeUnique<DepthwiseConvolution>(
1857       ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
1858       OperatorType::kDepthwiseConv));
1859   ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
1860                                        OperatorType::kDequantize));
1861   ops.push_back(
1862       MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
1863                                  OperatorType::kFullyConnected));
1864   ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
1865                                    OperatorType::kGather));
1866   ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
1867                                      OperatorType::kGatherNd));
1868   ops.push_back(
1869       MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
1870                                   OperatorType::kL2Normalization));
1871   ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
1872                                    OperatorType::kL2Pool));
1873   ops.push_back(MakeUnique<LocalResponseNormalization>(
1874       ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
1875       OperatorType::kLocalResponseNormalization));
1876   ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
1877                                     OperatorType::kMaxPool));
1878   ops.push_back(
1879       MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
1880 
1881   ops.push_back(
1882       MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
1883   ops.push_back(
1884       MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
1885   ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
1886                                     OperatorType::kReshape));
1887   ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
1888                                     OperatorType::kSoftmax));
1889   ops.push_back(MakeUnique<SpaceToDepth>(
1890       ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
1891   ops.push_back(MakeUnique<DepthToSpace>(
1892       ::tflite::BuiltinOperator_DEPTH_TO_SPACE, OperatorType::kDepthToSpace));
1893   ops.push_back(
1894       MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
1895   ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
1896                                       OperatorType::kTranspose));
1897   ops.push_back(
1898       MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
1899   ops.push_back(
1900       MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
1901   ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
1902                                        OperatorType::kReduceProd));
1903   ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
1904                                       OperatorType::kReduceMax));
1905   ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
1906                                       OperatorType::kReduceMin));
1907   ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
1908                                       OperatorType::kAny));
1909   ops.push_back(
1910       MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
1911                                  OperatorType::kResizeBilinear));
1912   ops.push_back(MakeUnique<ResizeNearestNeighbor>(
1913       ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
1914       OperatorType::kResizeNearestNeighbor));
1915   ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
1916                                     OperatorType::kSqueeze));
1917   ops.push_back(
1918       MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
1919   ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
1920                                    OperatorType::kSplitV));
1921   ops.push_back(MakeUnique<StridedSlice>(
1922       ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
1923   ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
1924                                     OperatorType::kTopK_V2));
1925   ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
1926                                  OperatorType::kLstmCell));
1927   ops.push_back(
1928       MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
1929   ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
1930                                    OperatorType::kArgMax));
1931   ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
1932                                    OperatorType::kArgMin));
1933   ops.push_back(
1934       MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
1935   ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
1936                                        OperatorType::kExpandDims));
1937   ops.push_back(MakeUnique<TransposeConv>(
1938       ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
1939   ops.push_back(MakeUnique<SparseToDense>(
1940       ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
1941   ops.push_back(
1942       MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
1943   ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
1944                                       OperatorType::kFakeQuant));
1945   ops.push_back(
1946       MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
1947   ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
1948       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
1949       OperatorType::kUnidirectionalSequenceLstm));
1950   ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
1951       ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
1952       OperatorType::kBidirectionalSequenceLstm));
1953   ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
1954       ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
1955       OperatorType::kBidirectionalSequenceRnn));
1956   ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
1957                                    OperatorType::kOneHot));
1958   ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
1959                                    OperatorType::kUnpack));
1960   ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
1961                                       OperatorType::kLeakyRelu));
1962   ops.push_back(MakeUnique<SquaredDifference>(
1963       ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
1964       OperatorType::kSquaredDifference));
1965   ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
1966                                       OperatorType::kMirrorPad));
1967   ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
1968                                    OperatorType::kUnique));
1969   ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
1970       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
1971       OperatorType::kUnidirectionalSequenceRnn));
1972   ops.push_back(
1973       MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
1974   ops.push_back(
1975       MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
1976                                   OperatorType::kReverseSequence));
1977   ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
1978       ::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag));
1979   ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
1980       ::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag));
1981   // Custom Operators.
1982   ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
1983       "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
1984   ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
1985                                                   OperatorType::kUnsupported,
1986                                                   enable_select_tf_ops));
1987 
1988   // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
1989   // been modified to also export builtins. As TOCO evolved we added warnings
1990   // when custom ops are exported but SimpleOperator bypasses thoses. To
1991   // prevent user confusion we are settling on using SimpleOperator only for
1992   // builtins.
1993   ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>(
1994       ::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor));
1995   ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>(
1996       ::tflite::BuiltinOperator_CEIL, OperatorType::kCeil));
1997   ops.push_back(MakeUnique<SimpleOperator<EluOperator>>(
1998       ::tflite::BuiltinOperator_ELU, OperatorType::kElu));
1999   ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>(
2000       ::tflite::BuiltinOperator_ROUND, OperatorType::kRound));
2001   ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>(
2002       ::tflite::BuiltinOperator_RELU, OperatorType::kRelu));
2003   ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
2004       ::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1));
2005   ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>(
2006       ::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6));
2007   ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>(
2008       ::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu));
2009   ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>(
2010       ::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic));
2011   ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>(
2012       ::tflite::BuiltinOperator_TANH, OperatorType::kTanh));
2013   ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>(
2014       ::tflite::BuiltinOperator_EXP, OperatorType::kExp));
2015   ops.push_back(MakeUnique<SimpleOperator<CosOperator>>(
2016       ::tflite::BuiltinOperator_COS, OperatorType::kCos));
2017   ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
2018       ::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax));
2019   ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>(
2020       ::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum));
2021   ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>(
2022       ::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum));
2023   ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>(
2024       ::tflite::BuiltinOperator_GREATER, OperatorType::kGreater));
2025   ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>(
2026       ::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual));
2027   ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>(
2028       ::tflite::BuiltinOperator_LESS, OperatorType::kLess));
2029   ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>(
2030       ::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual));
2031   ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>(
2032       ::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual));
2033   ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>(
2034       ::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual));
2035   ops.push_back(MakeUnique<SimpleOperator<NegOperator>>(
2036       ::tflite::BuiltinOperator_NEG, OperatorType::kNeg));
2037   ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
2038       ::tflite::BuiltinOperator_SELECT, OperatorType::kSelect));
2039   ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>(
2040       ::tflite::BuiltinOperator_SLICE, OperatorType::kSlice));
2041   ops.push_back(MakeUnique<SimpleOperator<PowOperator>>(
2042       ::tflite::BuiltinOperator_POW, OperatorType::kPow));
2043   ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
2044       ::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr));
2045   ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
2046       ::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd));
2047   ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
2048       ::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot));
2049   ops.emplace_back(new SimpleOperator<FloorDivOperator>(
2050       ::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv));
2051   ops.emplace_back(new SimpleOperator<FloorModOperator>(
2052       ::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod));
2053   ops.emplace_back(new SimpleOperator<RangeOperator>(
2054       ::tflite::BuiltinOperator_RANGE, OperatorType::kRange));
2055   // Element-wise operator
2056   ops.push_back(MakeUnique<SimpleOperator<SinOperator>>(
2057       ::tflite::BuiltinOperator_SIN, OperatorType::kSin));
2058   ops.push_back(MakeUnique<SimpleOperator<LogOperator>>(
2059       ::tflite::BuiltinOperator_LOG, OperatorType::kLog));
2060   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
2061       ::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt));
2062   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
2063       ::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt));
2064   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
2065       ::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare));
2066   ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
2067       ::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike));
2068   ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>(
2069       ::tflite::BuiltinOperator_ABS, OperatorType::kAbs));
2070   ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>(
2071       ::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish));
2072   ops.push_back(MakeUnique<SimpleOperator<FillOperator>>(
2073       ::tflite::BuiltinOperator_FILL, OperatorType::kFill));
2074   ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
2075       ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
2076   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
2077       ::tflite::BuiltinOperator_RANK, OperatorType::kRank));
2078   ops.emplace_back(new SimpleOperator<SegmentSumOperator>(
2079       ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
2080   ops.emplace_back(MakeUnique<SimpleOperator<ScatterNdOperator>>(
2081       ::tflite::BuiltinOperator_SCATTER_ND, OperatorType::kScatterNd));
2082   return ops;
2083 }
2084 }  // namespace
2085 
2086 // LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
2087 
BuildOperatorByTypeMap(bool enable_select_tf_ops)2088 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
2089     bool enable_select_tf_ops) {
2090   std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
2091 
2092   std::vector<std::unique_ptr<BaseOperator>> ops =
2093       BuildOperatorList(enable_select_tf_ops);
2094   for (auto& op : ops) {
2095     result[op->type()] = std::move(op);
2096   }
2097 
2098   return result;
2099 }
2100 
BuildOperatorByNameMap(bool enable_select_tf_ops)2101 std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
2102     bool enable_select_tf_ops) {
2103   std::map<std::string, std::unique_ptr<BaseOperator>> result;
2104 
2105   std::vector<std::unique_ptr<BaseOperator>> ops =
2106       BuildOperatorList(enable_select_tf_ops);
2107   for (auto& op : ops) {
2108     result[op->name()] = std::move(op);
2109   }
2110 
2111   return result;
2112 }
2113 
ShouldExportAsFlexOp(bool enable_select_tf_ops,const std::string & tensorflow_op_name)2114 bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
2115                           const std::string& tensorflow_op_name) {
2116   // If Flex ops aren't allow at all, simply return false.
2117   if (!enable_select_tf_ops) {
2118     return false;
2119   }
2120   // Check if we can find the `OpDef` for the TensorFlow op. If we can find
2121   // it and it has been allowlisted, export the op as an Flex op. Otherwise,
2122   // export it as a regular custom op.
2123   const tensorflow::OpDef* op_def = nullptr;
2124   if (!tensorflow::OpRegistry::Global()
2125            ->LookUpOpDef(tensorflow_op_name, &op_def)
2126            .ok()) {
2127     return false;
2128   }
2129 
2130   if (!::tflite::flex::IsAllowlistedFlexOp(tensorflow_op_name)) {
2131     LOG(WARNING) << "Op " << tensorflow_op_name
2132                  << " is a valid TensorFlow op but has not been allowlisted for"
2133                     " the TensorFlow Lite flex op set.";
2134     return false;
2135   }
2136 
2137   return true;
2138 }
2139 
2140 }  // namespace tflite
2141 
2142 }  // namespace toco
2143