1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/operator.h"
16
17 #include <map>
18 #include <memory>
19 #include <string>
20 #include <utility>
21
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/op.h"
25 #include "tensorflow/core/framework/op_def.pb.h"
26 #include "tensorflow/core/util/ptr_util.h"
27
28 // TODO(ycling): Consider refactoring to extract the LSTM definition out of
29 // graph_transformation module.
30 #include "tensorflow/lite/builtin_op_data.h"
31 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
32 #include "tensorflow/lite/schema/schema_generated.h"
33 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
34 #include "tensorflow/lite/toco/model.h"
35 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
36 #include "tensorflow/lite/toco/tflite/custom_operator.h"
37 #include "tensorflow/lite/toco/tflite/simple_operator.h"
38 #include "tensorflow/lite/toco/tflite/types.h"
39 #include "tensorflow/lite/tools/versioning/op_version.h"
40
41 namespace toco {
42
43 namespace tflite {
44
45 // LINT.IfChange
46
GetTensorType(const ArrayDataType type)47 TfLiteType GetTensorType(const ArrayDataType type) {
48 const std::map<ArrayDataType, TfLiteType> tensor_type_map = {
49 {ArrayDataType::kBool, kTfLiteBool},
50 {ArrayDataType::kFloat, kTfLiteFloat32},
51 {ArrayDataType::kInt8, kTfLiteInt8},
52 {ArrayDataType::kUint8, kTfLiteUInt8},
53 {ArrayDataType::kInt16, kTfLiteInt16},
54 {ArrayDataType::kUint16, kTfLiteUInt16},
55 {ArrayDataType::kInt32, kTfLiteInt32},
56 {ArrayDataType::kUint32, kTfLiteUInt32},
57 {ArrayDataType::kInt64, kTfLiteInt64},
58 {ArrayDataType::kUint64, kTfLiteUInt64},
59 {ArrayDataType::kString, kTfLiteString},
60 {ArrayDataType::kComplex64, kTfLiteComplex64},
61 {ArrayDataType::kComplex128, kTfLiteComplex128},
62 {ArrayDataType::kFloat16, kTfLiteFloat16},
63 {ArrayDataType::kFloat64, kTfLiteFloat64}};
64
65 auto it = tensor_type_map.find(type);
66 if (it != tensor_type_map.end()) {
67 return it->second;
68 }
69 return kTfLiteNoType;
70 }
71
// Builds the ::tflite::OpSignature consumed by the TFLite op-versioning
// library for builtin `op`.
//
// One tensor spec is produced per input and per output name of
// `op_signature.op`, in order. A spec's type and dims are taken from the
// model array of that name when the model has it (dims only if the array has
// a shape); otherwise the spec keeps kTfLiteNoType and empty dims.
::tflite::OpSignature GetVersioningOpSig(
    const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) {
  // Shared by the input and output passes so both follow identical rules.
  const auto make_tensor_spec = [&op_signature](const std::string& name) {
    ::tflite::OpSignatureTensorSpec tensor = {kTfLiteNoType};
    if (op_signature.model->HasArray(name)) {
      const Array& array = op_signature.model->GetArray(name);
      tensor.type = GetTensorType(array.data_type);
      if (array.has_shape()) {
        tensor.dims = array.shape().dims();
      }
    }
    return tensor;
  };

  std::vector<::tflite::OpSignatureTensorSpec> inputs;
  inputs.reserve(op_signature.op->inputs.size());
  for (const auto& input_name : op_signature.op->inputs) {
    inputs.push_back(make_tensor_spec(input_name));
  }

  std::vector<::tflite::OpSignatureTensorSpec> outputs;
  outputs.reserve(op_signature.op->outputs.size());
  for (const auto& output_name : op_signature.op->outputs) {
    outputs.push_back(make_tensor_spec(output_name));
  }

  // Move the spec vectors into the signature instead of copying them.
  return ::tflite::OpSignature{op, std::move(inputs), std::move(outputs)};
}
99
100 class AveragePool
101 : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
102 ::tflite::BuiltinOptions_Pool2DOptions> {
103 public:
104 using BuiltinOperator::BuiltinOperator;
105
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const106 flatbuffers::Offset<TfLiteOptions> WriteOptions(
107 const TocoOperator& op,
108 flatbuffers::FlatBufferBuilder* builder) const override {
109 auto padding = Padding::Serialize(op.padding.type);
110 auto activation_function =
111 ActivationFunction::Serialize(op.fused_activation_function);
112 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
113 op.stride_height, op.kwidth,
114 op.kheight, activation_function);
115 }
116
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const117 void ReadOptions(const TfLiteOptions& options,
118 TocoOperator* op) const override {
119 op->padding.type = Padding::Deserialize(options.padding());
120 op->stride_width = options.stride_w();
121 op->stride_height = options.stride_h();
122 op->kwidth = options.filter_width();
123 op->kheight = options.filter_height();
124 op->fused_activation_function =
125 ActivationFunction::Deserialize(options.fused_activation_function());
126 }
127 };
128
129 class Convolution
130 : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
131 ::tflite::BuiltinOptions_Conv2DOptions> {
132 public:
133 using BuiltinOperator::BuiltinOperator;
134
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const135 flatbuffers::Offset<TfLiteOptions> WriteOptions(
136 const TocoOperator& op,
137 flatbuffers::FlatBufferBuilder* builder) const override {
138 auto padding = Padding::Serialize(op.padding.type);
139 auto activation_function =
140 ActivationFunction::Serialize(op.fused_activation_function);
141 return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
142 op.stride_height, activation_function,
143 op.dilation_width_factor,
144 op.dilation_height_factor);
145 }
146
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const147 void ReadOptions(const TfLiteOptions& options,
148 TocoOperator* op) const override {
149 op->padding.type = Padding::Deserialize(options.padding());
150 op->stride_width = options.stride_w();
151 op->stride_height = options.stride_h();
152 op->dilation_width_factor = options.dilation_w_factor();
153 op->dilation_height_factor = options.dilation_h_factor();
154 op->fused_activation_function =
155 ActivationFunction::Deserialize(options.fused_activation_function());
156 }
157 };
158
159 class DepthwiseConvolution
160 : public BuiltinOperator<DepthwiseConvOperator,
161 ::tflite::DepthwiseConv2DOptions,
162 ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
163 public:
164 using BuiltinOperator::BuiltinOperator;
165
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const166 flatbuffers::Offset<TfLiteOptions> WriteOptions(
167 const TocoOperator& op,
168 flatbuffers::FlatBufferBuilder* builder) const override {
169 auto padding = Padding::Serialize(op.padding.type);
170 auto activation_function =
171 ActivationFunction::Serialize(op.fused_activation_function);
172 return ::tflite::CreateDepthwiseConv2DOptions(
173 *builder, padding, op.stride_width, op.stride_height,
174 op.depth_multiplier, activation_function, op.dilation_width_factor,
175 op.dilation_height_factor);
176 }
177
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const178 void ReadOptions(const TfLiteOptions& options,
179 TocoOperator* op) const override {
180 op->padding.type = Padding::Deserialize(options.padding());
181 op->stride_width = options.stride_w();
182 op->stride_height = options.stride_h();
183 op->depth_multiplier = options.depth_multiplier();
184 op->fused_activation_function =
185 ActivationFunction::Deserialize(options.fused_activation_function());
186 op->dilation_width_factor = options.dilation_w_factor();
187 op->dilation_height_factor = options.dilation_h_factor();
188 }
189
GetVersion(const OperatorSignature & op_signature) const190 int GetVersion(const OperatorSignature& op_signature) const override {
191 const auto& conv_op =
192 static_cast<const DepthwiseConvOperator&>(*op_signature.op);
193 ::tflite::OpSignature op_sig =
194 GetVersioningOpSig(builtin_op(), op_signature);
195 TfLiteDepthwiseConvParams depthwise_conv_params = {};
196 depthwise_conv_params.dilation_width_factor = conv_op.dilation_width_factor;
197 depthwise_conv_params.dilation_height_factor =
198 conv_op.dilation_height_factor;
199 op_sig.builtin_data = reinterpret_cast<void*>(&depthwise_conv_params);
200 return ::tflite::GetBuiltinOperatorVersion(op_sig);
201 }
202 };
203
204 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
205 ::tflite::BuiltinOptions_AddOptions> {
206 public:
207 using BuiltinOperator::BuiltinOperator;
208
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const209 flatbuffers::Offset<TfLiteOptions> WriteOptions(
210 const TocoOperator& op,
211 flatbuffers::FlatBufferBuilder* builder) const override {
212 auto activation_function =
213 ActivationFunction::Serialize(op.fused_activation_function);
214 return ::tflite::CreateAddOptions(*builder, activation_function);
215 }
216
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const217 void ReadOptions(const TfLiteOptions& options,
218 TocoOperator* op) const override {
219 op->fused_activation_function =
220 ActivationFunction::Deserialize(options.fused_activation_function());
221 }
222 };
223
224 class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
225 ::tflite::BuiltinOptions_AddNOptions> {
226 public:
227 using BuiltinOperator::BuiltinOperator;
228
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const229 flatbuffers::Offset<TfLiteOptions> WriteOptions(
230 const TocoOperator& op,
231 flatbuffers::FlatBufferBuilder* builder) const override {
232 return ::tflite::CreateAddNOptions(*builder);
233 }
234
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const235 void ReadOptions(const TfLiteOptions& options,
236 TocoOperator* op) const override {}
237 };
238
239 class SpaceToBatchND
240 : public BuiltinOperator<SpaceToBatchNDOperator,
241 ::tflite::SpaceToBatchNDOptions,
242 ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
243 public:
244 using BuiltinOperator::BuiltinOperator;
245
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const246 flatbuffers::Offset<TfLiteOptions> WriteOptions(
247 const TocoOperator& op,
248 flatbuffers::FlatBufferBuilder* builder) const override {
249 return ::tflite::CreateSpaceToBatchNDOptions(*builder);
250 }
251
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const252 void ReadOptions(const TfLiteOptions& options,
253 TocoOperator* op) const override {}
254
GetVersion(const OperatorSignature & op_signature) const255 int GetVersion(const OperatorSignature& op_signature) const override {
256 ::tflite::OpSignature op_sig =
257 GetVersioningOpSig(builtin_op(), op_signature);
258 return ::tflite::GetBuiltinOperatorVersion(op_sig);
259 }
260 };
261
262 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
263 ::tflite::BuiltinOptions_SubOptions> {
264 public:
265 using BuiltinOperator::BuiltinOperator;
266
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const267 flatbuffers::Offset<TfLiteOptions> WriteOptions(
268 const TocoOperator& op,
269 flatbuffers::FlatBufferBuilder* builder) const override {
270 auto activation_function =
271 ActivationFunction::Serialize(op.fused_activation_function);
272 return ::tflite::CreateSubOptions(*builder, activation_function);
273 }
274
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const275 void ReadOptions(const TfLiteOptions& options,
276 TocoOperator* op) const override {
277 op->fused_activation_function =
278 ActivationFunction::Deserialize(options.fused_activation_function());
279 }
280
GetVersion(const OperatorSignature & op_signature) const281 int GetVersion(const OperatorSignature& op_signature) const override {
282 ::tflite::OpSignature op_sig =
283 GetVersioningOpSig(builtin_op(), op_signature);
284 return ::tflite::GetBuiltinOperatorVersion(op_sig);
285 }
286 };
287
288 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
289 ::tflite::BuiltinOptions_DivOptions> {
290 public:
291 using BuiltinOperator::BuiltinOperator;
292
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const293 flatbuffers::Offset<TfLiteOptions> WriteOptions(
294 const TocoOperator& op,
295 flatbuffers::FlatBufferBuilder* builder) const override {
296 auto activation_function =
297 ActivationFunction::Serialize(op.fused_activation_function);
298 return ::tflite::CreateDivOptions(*builder, activation_function);
299 }
300
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const301 void ReadOptions(const TfLiteOptions& options,
302 TocoOperator* op) const override {
303 op->fused_activation_function =
304 ActivationFunction::Deserialize(options.fused_activation_function());
305 }
306
GetVersion(const OperatorSignature & op_signature) const307 int GetVersion(const OperatorSignature& op_signature) const override {
308 ::tflite::OpSignature op_sig =
309 GetVersioningOpSig(builtin_op(), op_signature);
310 return ::tflite::GetBuiltinOperatorVersion(op_sig);
311 }
312 };
313
314 class BatchToSpaceND
315 : public BuiltinOperator<BatchToSpaceNDOperator,
316 ::tflite::BatchToSpaceNDOptions,
317 ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
318 public:
319 using BuiltinOperator::BuiltinOperator;
320
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const321 flatbuffers::Offset<TfLiteOptions> WriteOptions(
322 const TocoOperator& op,
323 flatbuffers::FlatBufferBuilder* builder) const override {
324 return ::tflite::CreateBatchToSpaceNDOptions(*builder);
325 }
326
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const327 void ReadOptions(const TfLiteOptions& options,
328 TocoOperator* op) const override {}
329
GetVersion(const OperatorSignature & op_signature) const330 int GetVersion(const OperatorSignature& op_signature) const override {
331 ::tflite::OpSignature op_sig =
332 GetVersioningOpSig(builtin_op(), op_signature);
333 return ::tflite::GetBuiltinOperatorVersion(op_sig);
334 }
335 };
336
337 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
338 ::tflite::BuiltinOptions_CastOptions> {
339 public:
340 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const341 flatbuffers::Offset<TfLiteOptions> WriteOptions(
342 const TocoOperator& op,
343 flatbuffers::FlatBufferBuilder* builder) const override {
344 return ::tflite::CreateCastOptions(*builder,
345 DataType::Serialize(op.src_data_type),
346 DataType::Serialize(op.dst_data_type));
347 }
348
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const349 void ReadOptions(const TfLiteOptions& options,
350 TocoOperator* op) const override {
351 op->src_data_type = DataType::Deserialize(options.in_data_type());
352 op->dst_data_type = DataType::Deserialize(options.out_data_type());
353 }
354 };
355
356 class Concatenation
357 : public BuiltinOperator<ConcatenationOperator,
358 ::tflite::ConcatenationOptions,
359 ::tflite::BuiltinOptions_ConcatenationOptions> {
360 public:
361 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const362 flatbuffers::Offset<TfLiteOptions> WriteOptions(
363 const TocoOperator& op,
364 flatbuffers::FlatBufferBuilder* builder) const override {
365 return ::tflite::CreateConcatenationOptions(*builder, op.axis);
366 }
367
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const368 void ReadOptions(const TfLiteOptions& options,
369 TocoOperator* op) const override {
370 op->axis = options.axis();
371 }
372 };
373
374 class DepthToSpace
375 : public BuiltinOperator<DepthToSpaceOperator,
376 ::tflite::DepthToSpaceOptions,
377 ::tflite::BuiltinOptions_DepthToSpaceOptions> {
378 public:
379 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const380 flatbuffers::Offset<TfLiteOptions> WriteOptions(
381 const TocoOperator& op,
382 flatbuffers::FlatBufferBuilder* builder) const override {
383 return ::tflite::CreateDepthToSpaceOptions(*builder, op.block_size);
384 }
385
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const386 void ReadOptions(const TfLiteOptions& options,
387 TocoOperator* op) const override {
388 op->block_size = options.block_size();
389 }
390 };
391
392 class FakeQuant
393 : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
394 ::tflite::BuiltinOptions_FakeQuantOptions> {
395 public:
396 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const397 flatbuffers::Offset<TfLiteOptions> WriteOptions(
398 const TocoOperator& op,
399 flatbuffers::FlatBufferBuilder* builder) const override {
400 return ::tflite::CreateFakeQuantOptions(
401 *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
402 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const403 void ReadOptions(const TfLiteOptions& options,
404 TocoOperator* op) const override {
405 auto* minmax = new MinMax;
406 minmax->min = options.min();
407 minmax->max = options.max();
408 op->minmax.reset(minmax);
409 op->num_bits = options.num_bits();
410 op->narrow_range = options.narrow_range();
411 }
GetVersion(const OperatorSignature & op_signature) const412 int GetVersion(const OperatorSignature& op_signature) const override {
413 const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
414 ::tflite::OpSignature op_sig =
415 GetVersioningOpSig(builtin_op(), op_signature);
416 TfLiteFakeQuantParams fake_quant_params = {};
417 fake_quant_params.narrow_range = fq_op.narrow_range;
418 op_sig.builtin_data = reinterpret_cast<void*>(&fake_quant_params);
419 return ::tflite::GetBuiltinOperatorVersion(op_sig);
420 }
421 };
422
423 class FullyConnected
424 : public BuiltinOperator<FullyConnectedOperator,
425 ::tflite::FullyConnectedOptions,
426 ::tflite::BuiltinOptions_FullyConnectedOptions> {
427 public:
428 using BuiltinOperator::BuiltinOperator;
429
GetWeightFormat(FullyConnectedWeightsFormat fmt) const430 ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat(
431 FullyConnectedWeightsFormat fmt) const {
432 switch (fmt) {
433 case FullyConnectedWeightsFormat::kDefault:
434 return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
435 case FullyConnectedWeightsFormat::kShuffled4x16Int8:
436 return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
437 default:
438 LOG(ERROR) << "Unhandled FC weights format";
439 return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
440 }
441 }
442
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const443 flatbuffers::Offset<TfLiteOptions> WriteOptions(
444 const TocoOperator& op,
445 flatbuffers::FlatBufferBuilder* builder) const override {
446 auto activation_function =
447 ActivationFunction::Serialize(op.fused_activation_function);
448 return ::tflite::CreateFullyConnectedOptions(
449 *builder, activation_function, GetWeightFormat(op.weights_format));
450 }
451
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const452 void ReadOptions(const TfLiteOptions& options,
453 TocoOperator* op) const override {
454 op->fused_activation_function =
455 ActivationFunction::Deserialize(options.fused_activation_function());
456 switch (options.weights_format()) {
457 case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
458 op->weights_format = FullyConnectedWeightsFormat::kDefault;
459 break;
460 case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
461 op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
462 break;
463 default:
464 LOG(ERROR) << "Unhandled FC weights format";
465 op->weights_format = FullyConnectedWeightsFormat::kDefault;
466 }
467 }
468
GetVersion(const OperatorSignature & op_signature) const469 int GetVersion(const OperatorSignature& op_signature) const override {
470 const auto& fc_op =
471 static_cast<const FullyConnectedOperator&>(*op_signature.op);
472 ::tflite::OpSignature op_sig =
473 GetVersioningOpSig(builtin_op(), op_signature);
474 TfLiteFullyConnectedParams fully_connected_params = {};
475 fully_connected_params.keep_num_dims = fc_op.keep_num_dims;
476 fully_connected_params.weights_format =
477 static_cast<TfLiteFullyConnectedWeightsFormat>(
478 GetWeightFormat(fc_op.weights_format));
479 op_sig.builtin_data = reinterpret_cast<void*>(&fully_connected_params);
480 return ::tflite::GetBuiltinOperatorVersion(op_sig);
481 }
482 };
483
484 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
485 ::tflite::BuiltinOptions_GatherOptions> {
486 public:
487 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const488 flatbuffers::Offset<TfLiteOptions> WriteOptions(
489 const TocoOperator& op,
490 flatbuffers::FlatBufferBuilder* builder) const override {
491 int axis = op.axis ? op.axis.value() : 0;
492 return ::tflite::CreateGatherOptions(*builder, axis);
493 }
494
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const495 void ReadOptions(const TfLiteOptions& options,
496 TocoOperator* op) const override {
497 op->axis = {options.axis()};
498 }
499 };
500
501 class GatherNd
502 : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
503 ::tflite::BuiltinOptions_GatherNdOptions> {
504 public:
505 using BuiltinOperator::BuiltinOperator;
506
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const507 flatbuffers::Offset<TfLiteOptions> WriteOptions(
508 const TocoOperator& op,
509 flatbuffers::FlatBufferBuilder* builder) const override {
510 return ::tflite::CreateGatherNdOptions(*builder);
511 }
512
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const513 void ReadOptions(const TfLiteOptions& options,
514 TocoOperator* op) const override {}
515 };
516
517 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
518 ::tflite::BuiltinOptions_SVDFOptions> {
519 public:
520 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const521 flatbuffers::Offset<TfLiteOptions> WriteOptions(
522 const TocoOperator& op,
523 flatbuffers::FlatBufferBuilder* builder) const override {
524 auto activation_function =
525 ActivationFunction::Serialize(op.fused_activation_function);
526 return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
527 }
528
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const529 void ReadOptions(const TfLiteOptions& options,
530 TocoOperator* op) const override {
531 op->fused_activation_function =
532 ActivationFunction::Deserialize(options.fused_activation_function());
533 op->rank = options.rank();
534 }
535 };
536
537 class L2Normalization
538 : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
539 ::tflite::BuiltinOptions_L2NormOptions> {
540 public:
541 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const542 flatbuffers::Offset<TfLiteOptions> WriteOptions(
543 const TocoOperator& op,
544 flatbuffers::FlatBufferBuilder* builder) const override {
545 auto activation_function =
546 ActivationFunction::Serialize(op.fused_activation_function);
547 return ::tflite::CreateL2NormOptions(*builder, activation_function);
548 }
549
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const550 void ReadOptions(const TfLiteOptions& options,
551 TocoOperator* op) const override {
552 op->fused_activation_function =
553 ActivationFunction::Deserialize(options.fused_activation_function());
554 }
555 };
556
557 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
558 ::tflite::BuiltinOptions_Pool2DOptions> {
559 public:
560 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const561 flatbuffers::Offset<TfLiteOptions> WriteOptions(
562 const TocoOperator& op,
563 flatbuffers::FlatBufferBuilder* builder) const override {
564 auto padding = Padding::Serialize(op.padding.type);
565 auto activation_function =
566 ActivationFunction::Serialize(op.fused_activation_function);
567 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
568 op.stride_height, op.kwidth,
569 op.kheight, activation_function);
570 }
571
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const572 void ReadOptions(const TfLiteOptions& options,
573 TocoOperator* op) const override {
574 op->padding.type = Padding::Deserialize(options.padding());
575 op->stride_width = options.stride_w();
576 op->stride_height = options.stride_h();
577 op->kwidth = options.filter_width();
578 op->kheight = options.filter_height();
579 op->fused_activation_function =
580 ActivationFunction::Deserialize(options.fused_activation_function());
581 }
582 };
583
584 class LocalResponseNormalization
585 : public BuiltinOperator<
586 LocalResponseNormalizationOperator,
587 ::tflite::LocalResponseNormalizationOptions,
588 ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
589 public:
590 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const591 flatbuffers::Offset<TfLiteOptions> WriteOptions(
592 const TocoOperator& op,
593 flatbuffers::FlatBufferBuilder* builder) const override {
594 return ::tflite::CreateLocalResponseNormalizationOptions(
595 *builder, op.range, op.bias, op.alpha, op.beta);
596 }
597
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const598 void ReadOptions(const TfLiteOptions& options,
599 TocoOperator* op) const override {
600 op->range = options.radius();
601 op->bias = options.bias();
602 op->alpha = options.alpha();
603 op->beta = options.beta();
604 }
605 };
606
607 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
608 ::tflite::BuiltinOptions_Pool2DOptions> {
609 public:
610 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const611 flatbuffers::Offset<TfLiteOptions> WriteOptions(
612 const TocoOperator& op,
613 flatbuffers::FlatBufferBuilder* builder) const override {
614 auto padding = Padding::Serialize(op.padding.type);
615 auto activation_function =
616 ActivationFunction::Serialize(op.fused_activation_function);
617 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
618 op.stride_height, op.kwidth,
619 op.kheight, activation_function);
620 }
621
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const622 void ReadOptions(const TfLiteOptions& options,
623 TocoOperator* op) const override {
624 op->padding.type = Padding::Deserialize(options.padding());
625 op->stride_width = options.stride_w();
626 op->stride_height = options.stride_h();
627 op->kwidth = options.filter_width();
628 op->kheight = options.filter_height();
629 op->fused_activation_function =
630 ActivationFunction::Deserialize(options.fused_activation_function());
631 }
632 };
633
634 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
635 ::tflite::BuiltinOptions_MulOptions> {
636 public:
637 using BuiltinOperator::BuiltinOperator;
638
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const639 flatbuffers::Offset<TfLiteOptions> WriteOptions(
640 const TocoOperator& op,
641 flatbuffers::FlatBufferBuilder* builder) const override {
642 auto activation_function =
643 ActivationFunction::Serialize(op.fused_activation_function);
644 return ::tflite::CreateMulOptions(*builder, activation_function);
645 }
646
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const647 void ReadOptions(const TfLiteOptions& options,
648 TocoOperator* op) const override {
649 op->fused_activation_function =
650 ActivationFunction::Deserialize(options.fused_activation_function());
651 }
652
GetVersion(const OperatorSignature & op_signature) const653 int GetVersion(const OperatorSignature& op_signature) const override {
654 const std::string& input1_name = op_signature.op->inputs[0];
655 const std::string& input2_name = op_signature.op->inputs[1];
656 const std::string& output_name = op_signature.op->outputs[0];
657 const Array& input1_array = op_signature.model->GetArray(input1_name);
658 const Array& input2_array = op_signature.model->GetArray(input2_name);
659 const Array& output_array = op_signature.model->GetArray(output_name);
660 const auto& input1_quant = input1_array.quantization_params;
661 const auto& input2_quant = input2_array.quantization_params;
662 const auto& output_quant = output_array.quantization_params;
663 const float input1_scale = input1_quant ? input1_quant->scale : 0.0f;
664 const float input2_scale = input2_quant ? input2_quant->scale : 0.0f;
665 const float output_scale = output_quant ? output_quant->scale : 0.0f;
666 ::tflite::OpSignature op_sig =
667 GetVersioningOpSig(builtin_op(), op_signature);
668 op_sig.ext_options.mul.input1_scale = input1_scale;
669 op_sig.ext_options.mul.input2_scale = input2_scale;
670 op_sig.ext_options.mul.output_scale = output_scale;
671 return ::tflite::GetBuiltinOperatorVersion(op_sig);
672 }
673 };
674
675 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
676 ::tflite::BuiltinOptions_PadOptions> {
677 public:
678 using BuiltinOperator::BuiltinOperator;
679
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const680 flatbuffers::Offset<TfLiteOptions> WriteOptions(
681 const TocoOperator& op,
682 flatbuffers::FlatBufferBuilder* builder) const override {
683 return ::tflite::CreatePadOptions(*builder);
684 }
685
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const686 void ReadOptions(const TfLiteOptions& options,
687 TocoOperator* op) const override {}
688 };
689
690 class Tile
691 : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
692 ::tflite::BuiltinOptions_TileOptions> {
693 using BuiltinOperator::BuiltinOperator;
694
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const695 flatbuffers::Offset<TfLiteOptions> WriteOptions(
696 const TocoOperator& op,
697 flatbuffers::FlatBufferBuilder* builder) const override {
698 return ::tflite::CreateTileOptions(*builder);
699 }
700
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const701 void ReadOptions(const TfLiteOptions& options,
702 TocoOperator* op) const override {}
703 };
704
705 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
706 ::tflite::BuiltinOptions_PadV2Options> {
707 public:
708 using BuiltinOperator::BuiltinOperator;
709
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const710 flatbuffers::Offset<TfLiteOptions> WriteOptions(
711 const TocoOperator& op,
712 flatbuffers::FlatBufferBuilder* builder) const override {
713 return ::tflite::CreatePadV2Options(*builder);
714 }
715
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const716 void ReadOptions(const TfLiteOptions& options,
717 TocoOperator* op) const override {}
718 };
719
720 class Reshape
721 : public BuiltinOperator<TensorFlowReshapeOperator,
722 ::tflite::ReshapeOptions,
723 ::tflite::BuiltinOptions_ReshapeOptions> {
724 public:
725 using BuiltinOperator::BuiltinOperator;
726
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const727 flatbuffers::Offset<TfLiteOptions> WriteOptions(
728 const TocoOperator& op,
729 flatbuffers::FlatBufferBuilder* builder) const override {
730 return ::tflite::CreateReshapeOptions(*builder,
731 builder->CreateVector(op.shape));
732 }
733
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const734 void ReadOptions(const TfLiteOptions& options,
735 TocoOperator* op) const override {
736 op->shape.insert(op->shape.end(), options.new_shape()->begin(),
737 options.new_shape()->end());
738 }
739 };
740
741 class Softmax
742 : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
743 ::tflite::BuiltinOptions_SoftmaxOptions> {
744 public:
745 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const746 flatbuffers::Offset<TfLiteOptions> WriteOptions(
747 const TocoOperator& op,
748 flatbuffers::FlatBufferBuilder* builder) const override {
749 return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
750 }
751
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const752 void ReadOptions(const TfLiteOptions& options,
753 TocoOperator* op) const override {
754 op->beta = options.beta();
755 }
756 };
757
758 class SpaceToDepth
759 : public BuiltinOperator<SpaceToDepthOperator,
760 ::tflite::SpaceToDepthOptions,
761 ::tflite::BuiltinOptions_SpaceToDepthOptions> {
762 public:
763 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const764 flatbuffers::Offset<TfLiteOptions> WriteOptions(
765 const TocoOperator& op,
766 flatbuffers::FlatBufferBuilder* builder) const override {
767 return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
768 }
769
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const770 void ReadOptions(const TfLiteOptions& options,
771 TocoOperator* op) const override {
772 op->block_size = options.block_size();
773 }
774 };
775
776 class Transpose
777 : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
778 ::tflite::BuiltinOptions_TransposeOptions> {
779 public:
780 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const781 flatbuffers::Offset<TfLiteOptions> WriteOptions(
782 const TocoOperator& op,
783 flatbuffers::FlatBufferBuilder* builder) const override {
784 return ::tflite::CreateTransposeOptions(*builder);
785 }
786
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const787 void ReadOptions(const TfLiteOptions& options,
788 TocoOperator* op) const override {}
789 };
790
791 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
792 ::tflite::BuiltinOptions_LSTMOptions> {
793 public:
794 using BuiltinOperator::BuiltinOperator;
795
GetKernelType(LstmCellOperator::KernelType type) const796 ::tflite::LSTMKernelType GetKernelType(
797 LstmCellOperator::KernelType type) const {
798 switch (type) {
799 case LstmCellOperator::KERNEL_BASIC:
800 return ::tflite::LSTMKernelType_BASIC;
801 break;
802 case LstmCellOperator::KERNEL_FULL:
803 return ::tflite::LSTMKernelType_FULL;
804 break;
805 default:
806 LOG(ERROR) << "Unhandled Kernel Type";
807 return static_cast<::tflite::LSTMKernelType>(-1);
808 }
809 }
810
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const811 flatbuffers::Offset<TfLiteOptions> WriteOptions(
812 const TocoOperator& op,
813 flatbuffers::FlatBufferBuilder* builder) const override {
814 ::tflite::LSTMKernelType kernel_type = GetKernelType(op.kernel_type);
815
816 // Current toco converter only supports tanh, no clip.
817 return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
818 ::tflite::ActivationFunctionType_TANH,
819 /*cell_clip=*/0.0,
820 /*proj_clip=*/0.0, kernel_type);
821 }
822
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const823 void ReadOptions(const TfLiteOptions& options,
824 TocoOperator* op) const override {
825 // Only support tanh activation, so check that tflite type is tanh.
826 CHECK(options.fused_activation_function() ==
827 ::tflite::ActivationFunctionType_TANH);
828
829 switch (options.kernel_type()) {
830 case ::tflite::LSTMKernelType_BASIC:
831 op->kernel_type = LstmCellOperator::KERNEL_BASIC;
832 break;
833 case ::tflite::LSTMKernelType_FULL:
834 op->kernel_type = LstmCellOperator::KERNEL_FULL;
835 break;
836 }
837 }
838
GetVersion(const OperatorSignature & op_signature) const839 int GetVersion(const OperatorSignature& op_signature) const override {
840 const auto& lstm_op =
841 static_cast<const LstmCellOperator&>(*op_signature.op);
842 ::tflite::OpSignature op_sig =
843 GetVersioningOpSig(builtin_op(), op_signature);
844 TfLiteLSTMParams lstm_params = {};
845 lstm_params.kernel_type =
846 static_cast<TfLiteLSTMKernelType>(GetKernelType(lstm_op.kernel_type));
847 op_sig.builtin_data = reinterpret_cast<void*>(&lstm_params);
848 return ::tflite::GetBuiltinOperatorVersion(op_sig);
849 }
850
GetMutatingInputVariables(const Operator & op) const851 std::vector<bool> GetMutatingInputVariables(
852 const Operator& op) const override {
853 const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
854
855 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
856 switch (lstm_op.kernel_type) {
857 case LstmCellOperator::KERNEL_FULL: {
858 mutating_input_variables[kInputActivationStateTensor] = true;
859 mutating_input_variables[kInputCellStateTensor] = true;
860 break;
861 }
862 case LstmCellOperator::KERNEL_BASIC: {
863 mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
864 mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
865 break;
866 }
867 }
868 return mutating_input_variables;
869 }
870 };
871
872 class UnidirectionalSequenceLstm
873 : public BuiltinOperator<
874 UnidirectionalSequenceLstmOperator,
875 ::tflite::UnidirectionalSequenceLSTMOptions,
876 ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
877 public:
878 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const879 flatbuffers::Offset<TfLiteOptions> WriteOptions(
880 const TocoOperator& op,
881 flatbuffers::FlatBufferBuilder* builder) const override {
882 // Current toco converter only supports tanh, no clip.
883 return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
884 *builder, /*fused_activation_function=*/
885 ::tflite::ActivationFunctionType_TANH,
886 /*cell_clip=*/0.0,
887 /*proj_clip=*/0.0,
888 /*time_major=*/true);
889 }
890
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const891 void ReadOptions(const TfLiteOptions& options,
892 TocoOperator* op) const override {
893 // Only support tanh activation, so check that tflite type is tanh.
894 DCHECK(options.fused_activation_function() ==
895 ::tflite::ActivationFunctionType_TANH);
896 }
897
GetMutatingInputVariables(const Operator & op) const898 std::vector<bool> GetMutatingInputVariables(
899 const Operator& op) const override {
900 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
901 mutating_input_variables[kInputActivationStateTensor] = true;
902 mutating_input_variables[kInputCellStateTensor] = true;
903 return mutating_input_variables;
904 }
905 };
906
907 class BidirectionalSequenceLstm
908 : public BuiltinOperator<
909 BidirectionalSequenceLstmOperator,
910 ::tflite::BidirectionalSequenceLSTMOptions,
911 ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
912 public:
913 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const914 flatbuffers::Offset<TfLiteOptions> WriteOptions(
915 const TocoOperator& op,
916 flatbuffers::FlatBufferBuilder* builder) const override {
917 // Current toco converter only supports tanh, no clip.
918 return ::tflite::CreateBidirectionalSequenceLSTMOptions(
919 *builder, /*fused_activation_function=*/
920 ::tflite::ActivationFunctionType_TANH,
921 /*cell_clip=*/0.0,
922 /*proj_clip=*/0.0,
923 /*merge_outputs=*/op.merge_outputs,
924 /*time_major=*/true);
925 }
926
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const927 void ReadOptions(const TfLiteOptions& options,
928 TocoOperator* op) const override {
929 // Only support tanh activation, so check that tflite type is tanh.
930 DCHECK(options.fused_activation_function() ==
931 ::tflite::ActivationFunctionType_TANH);
932 op->merge_outputs = options.merge_outputs();
933 }
934
GetMutatingInputVariables(const Operator & op) const935 std::vector<bool> GetMutatingInputVariables(
936 const Operator& op) const override {
937 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
938 // Forward input activation state.
939 mutating_input_variables[35] = true;
940 // Forward input cell state.
941 mutating_input_variables[36] = true;
942 // Backward input activation state.
943 mutating_input_variables[37] = true;
944 // Backward input cell state.
945 mutating_input_variables[38] = true;
946 return mutating_input_variables;
947 }
948 };
949
950 class BidirectionalSequenceRnn
951 : public BuiltinOperator<
952 BidirectionalSequenceRnnOperator,
953 ::tflite::BidirectionalSequenceRNNOptions,
954 ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
955 public:
956 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const957 flatbuffers::Offset<TfLiteOptions> WriteOptions(
958 const TocoOperator& op,
959 flatbuffers::FlatBufferBuilder* builder) const override {
960 // Current toco converter only supports tanh, no clip.
961 return ::tflite::CreateBidirectionalSequenceRNNOptions(
962 *builder, /*time_major=*/true,
963 /*fused_activation_function=*/
964 ::tflite::ActivationFunctionType_TANH,
965 /*merge_outputs=*/op.merge_outputs);
966 }
967
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const968 void ReadOptions(const TfLiteOptions& options,
969 TocoOperator* op) const override {
970 // Only support tanh activation, so check that tflite type is tanh.
971 DCHECK(options.fused_activation_function() ==
972 ::tflite::ActivationFunctionType_TANH);
973 op->merge_outputs = options.merge_outputs();
974 }
975
GetMutatingInputVariables(const Operator & op) const976 std::vector<bool> GetMutatingInputVariables(
977 const Operator& op) const override {
978 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
979 // Forward hidden state.
980 mutating_input_variables[4] = true;
981 // Backward hidden state.
982 mutating_input_variables[8] = true;
983 return mutating_input_variables;
984 }
985 };
986
987 class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
988 ::tflite::BuiltinOptions_ReducerOptions> {
989 public:
990 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const991 flatbuffers::Offset<TfLiteOptions> WriteOptions(
992 const TocoOperator& op,
993 flatbuffers::FlatBufferBuilder* builder) const override {
994 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
995 }
996
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const997 void ReadOptions(const TfLiteOptions& options,
998 TocoOperator* op) const override {
999 op->keep_dims = options.keep_dims();
1000 }
1001 };
1002
1003 class Sum
1004 : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
1005 ::tflite::BuiltinOptions_ReducerOptions> {
1006 public:
1007 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1008 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1009 const TocoOperator& op,
1010 flatbuffers::FlatBufferBuilder* builder) const override {
1011 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1012 }
1013
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1014 void ReadOptions(const TfLiteOptions& options,
1015 TocoOperator* op) const override {
1016 op->keep_dims = options.keep_dims();
1017 }
1018 };
1019
1020 class ReduceMax
1021 : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
1022 ::tflite::BuiltinOptions_ReducerOptions> {
1023 public:
1024 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1025 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1026 const TocoOperator& op,
1027 flatbuffers::FlatBufferBuilder* builder) const override {
1028 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1029 }
1030
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1031 void ReadOptions(const TfLiteOptions& options,
1032 TocoOperator* op) const override {
1033 op->keep_dims = options.keep_dims();
1034 }
1035 };
1036
1037 class ReduceMin
1038 : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
1039 ::tflite::BuiltinOptions_ReducerOptions> {
1040 public:
1041 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1042 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1043 const TocoOperator& op,
1044 flatbuffers::FlatBufferBuilder* builder) const override {
1045 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1046 }
1047
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1048 void ReadOptions(const TfLiteOptions& options,
1049 TocoOperator* op) const override {
1050 op->keep_dims = options.keep_dims();
1051 }
1052 };
1053
1054 class ReduceProd
1055 : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
1056 ::tflite::BuiltinOptions_ReducerOptions> {
1057 public:
1058 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1059 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1060 const TocoOperator& op,
1061 flatbuffers::FlatBufferBuilder* builder) const override {
1062 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1063 }
1064
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1065 void ReadOptions(const TfLiteOptions& options,
1066 TocoOperator* op) const override {
1067 op->keep_dims = options.keep_dims();
1068 }
1069 };
1070
1071 class ReduceAny
1072 : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
1073 ::tflite::BuiltinOptions_ReducerOptions> {
1074 public:
1075 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1076 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1077 const TocoOperator& op,
1078 flatbuffers::FlatBufferBuilder* builder) const override {
1079 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1080 }
1081
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1082 void ReadOptions(const TfLiteOptions& options,
1083 TocoOperator* op) const override {
1084 op->keep_dims = options.keep_dims();
1085 }
1086 };
1087
1088 class ResizeBilinear
1089 : public BuiltinOperator<ResizeBilinearOperator,
1090 ::tflite::ResizeBilinearOptions,
1091 ::tflite::BuiltinOptions_ResizeBilinearOptions> {
1092 public:
1093 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1094 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1095 const TocoOperator& op,
1096 flatbuffers::FlatBufferBuilder* builder) const override {
1097 return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners,
1098 op.half_pixel_centers);
1099 }
1100
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1101 void ReadOptions(const TfLiteOptions& options,
1102 TocoOperator* op) const override {
1103 op->align_corners = options.align_corners();
1104 op->half_pixel_centers = options.half_pixel_centers();
1105 }
1106
GetVersion(const OperatorSignature & op_signature) const1107 int GetVersion(const OperatorSignature& op_signature) const override {
1108 const auto& resize_bilinear_op =
1109 static_cast<const ResizeBilinearOperator&>(*op_signature.op);
1110 ::tflite::OpSignature op_sig =
1111 GetVersioningOpSig(builtin_op(), op_signature);
1112 TfLiteResizeBilinearParams resize_bilinear_params = {};
1113 resize_bilinear_params.half_pixel_centers =
1114 resize_bilinear_op.half_pixel_centers;
1115 resize_bilinear_params.align_corners = resize_bilinear_op.align_corners;
1116 op_sig.builtin_data = reinterpret_cast<void*>(&resize_bilinear_params);
1117 return ::tflite::GetBuiltinOperatorVersion(op_sig);
1118 }
1119 };
1120
1121 class ResizeNearestNeighbor
1122 : public BuiltinOperator<
1123 ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
1124 ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
1125 public:
1126 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1127 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1128 const TocoOperator& op,
1129 flatbuffers::FlatBufferBuilder* builder) const override {
1130 return ::tflite::CreateResizeNearestNeighborOptions(
1131 *builder, op.align_corners, op.half_pixel_centers);
1132 }
1133
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1134 void ReadOptions(const TfLiteOptions& options,
1135 TocoOperator* op) const override {
1136 op->align_corners = options.align_corners();
1137 op->half_pixel_centers = options.half_pixel_centers();
1138 }
1139
GetVersion(const OperatorSignature & op_signature) const1140 int GetVersion(const OperatorSignature& op_signature) const override {
1141 const auto& resize_nn_op =
1142 static_cast<const ResizeNearestNeighborOperator&>(*op_signature.op);
1143 ::tflite::OpSignature op_sig =
1144 GetVersioningOpSig(builtin_op(), op_signature);
1145 TfLiteResizeNearestNeighborParams resize_nearest_neighbor_params = {};
1146 resize_nearest_neighbor_params.half_pixel_centers =
1147 resize_nn_op.half_pixel_centers;
1148 resize_nearest_neighbor_params.align_corners = resize_nn_op.align_corners;
1149 op_sig.builtin_data =
1150 reinterpret_cast<void*>(&resize_nearest_neighbor_params);
1151 return ::tflite::GetBuiltinOperatorVersion(op_sig);
1152 }
1153 };
1154
1155 class Squeeze
1156 : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
1157 ::tflite::BuiltinOptions_SqueezeOptions> {
1158 public:
1159 using BuiltinOperator::BuiltinOperator;
1160
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1161 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1162 const TocoOperator& op,
1163 flatbuffers::FlatBufferBuilder* builder) const override {
1164 auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
1165 return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
1166 }
1167
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1168 void ReadOptions(const TfLiteOptions& options,
1169 TocoOperator* op) const override {
1170 op->squeeze_dims.insert(op->squeeze_dims.end(),
1171 options.squeeze_dims()->begin(),
1172 options.squeeze_dims()->end());
1173 }
1174 };
1175
1176 class Split
1177 : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
1178 ::tflite::BuiltinOptions_SplitOptions> {
1179 public:
1180 using BuiltinOperator::BuiltinOperator;
1181
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1182 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1183 const TocoOperator& op,
1184 flatbuffers::FlatBufferBuilder* builder) const override {
1185 return ::tflite::CreateSplitOptions(*builder, op.num_split);
1186 }
1187
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1188 void ReadOptions(const TfLiteOptions& options,
1189 TocoOperator* op) const override {
1190 op->num_split = options.num_splits();
1191 }
1192 };
1193
1194 class SplitV
1195 : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
1196 ::tflite::BuiltinOptions_SplitVOptions> {
1197 public:
1198 using BuiltinOperator::BuiltinOperator;
1199
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1200 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1201 const TocoOperator& op,
1202 flatbuffers::FlatBufferBuilder* builder) const override {
1203 return ::tflite::CreateSplitVOptions(*builder, op.num_split);
1204 }
1205
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1206 void ReadOptions(const TfLiteOptions& options,
1207 TocoOperator* op) const override {
1208 op->num_split = options.num_splits();
1209 }
1210 };
1211
1212 class StridedSlice
1213 : public BuiltinOperator<StridedSliceOperator,
1214 ::tflite::StridedSliceOptions,
1215 ::tflite::BuiltinOptions_StridedSliceOptions> {
1216 public:
1217 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1218 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1219 const TocoOperator& op,
1220 flatbuffers::FlatBufferBuilder* builder) const override {
1221 return ::tflite::CreateStridedSliceOptions(
1222 *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
1223 op.new_axis_mask, op.shrink_axis_mask);
1224 }
1225
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1226 void ReadOptions(const TfLiteOptions& options,
1227 TocoOperator* op) const override {
1228 op->begin_mask = options.begin_mask();
1229 op->end_mask = options.end_mask();
1230 op->ellipsis_mask = options.ellipsis_mask();
1231 op->new_axis_mask = options.new_axis_mask();
1232 op->shrink_axis_mask = options.shrink_axis_mask();
1233 }
1234
GetVersion(const OperatorSignature & op_signature) const1235 int GetVersion(const OperatorSignature& op_signature) const override {
1236 const auto& ss_op =
1237 static_cast<const StridedSliceOperator&>(*op_signature.op);
1238 ::tflite::OpSignature op_sig =
1239 GetVersioningOpSig(builtin_op(), op_signature);
1240 op_sig.ext_options.strided_slice.num_dims = ss_op.start_indices.size();
1241 TfLiteStridedSliceParams strided_slice_params = {};
1242 strided_slice_params.ellipsis_mask = ss_op.ellipsis_mask;
1243 strided_slice_params.new_axis_mask = ss_op.new_axis_mask;
1244 op_sig.builtin_data = reinterpret_cast<void*>(&strided_slice_params);
1245 return ::tflite::GetBuiltinOperatorVersion(op_sig);
1246 }
1247 };
1248
1249 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
1250 ::tflite::BuiltinOptions_TopKV2Options> {
1251 public:
1252 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1253 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1254 const TocoOperator& op,
1255 flatbuffers::FlatBufferBuilder* builder) const override {
1256 return ::tflite::CreateTopKV2Options(*builder);
1257 }
1258
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1259 void ReadOptions(const TfLiteOptions& options,
1260 TocoOperator* op) const override {}
1261 };
1262
1263 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
1264 ::tflite::BuiltinOptions_ArgMaxOptions> {
1265 public:
1266 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1267 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1268 const TocoOperator& op,
1269 flatbuffers::FlatBufferBuilder* builder) const override {
1270 return ::tflite::CreateArgMaxOptions(
1271 *builder, DataType::Serialize(op.output_data_type));
1272 }
1273
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1274 void ReadOptions(const TfLiteOptions& options,
1275 TocoOperator* op) const override {
1276 op->output_data_type = DataType::Deserialize(options.output_type());
1277 }
1278 };
1279
1280 class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
1281 ::tflite::BuiltinOptions_ArgMinOptions> {
1282 public:
1283 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1284 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1285 const TocoOperator& op,
1286 flatbuffers::FlatBufferBuilder* builder) const override {
1287 return ::tflite::CreateArgMinOptions(
1288 *builder, DataType::Serialize(op.output_data_type));
1289 }
1290
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1291 void ReadOptions(const TfLiteOptions& options,
1292 TocoOperator* op) const override {
1293 op->output_data_type = DataType::Deserialize(options.output_type());
1294 }
1295 };
1296
1297 class TransposeConv
1298 : public BuiltinOperator<TransposeConvOperator,
1299 ::tflite::TransposeConvOptions,
1300 ::tflite::BuiltinOptions_TransposeConvOptions> {
1301 public:
1302 using BuiltinOperator::BuiltinOperator;
1303
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1304 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1305 const TocoOperator& op,
1306 flatbuffers::FlatBufferBuilder* builder) const override {
1307 auto padding = Padding::Serialize(op.padding.type);
1308 return ::tflite::CreateTransposeConvOptions(
1309 *builder, padding, op.stride_width, op.stride_height);
1310 }
1311
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1312 void ReadOptions(const TfLiteOptions& options,
1313 TocoOperator* op) const override {
1314 op->padding.type = Padding::Deserialize(options.padding());
1315 op->stride_width = options.stride_w();
1316 op->stride_height = options.stride_h();
1317 }
1318 };
1319
1320 class SparseToDense
1321 : public BuiltinOperator<SparseToDenseOperator,
1322 ::tflite::SparseToDenseOptions,
1323 ::tflite::BuiltinOptions_SparseToDenseOptions> {
1324 public:
1325 using BuiltinOperator::BuiltinOperator;
1326
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1327 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1328 const TocoOperator& op,
1329 flatbuffers::FlatBufferBuilder* builder) const override {
1330 return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
1331 }
1332
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1333 void ReadOptions(const TfLiteOptions& options,
1334 TocoOperator* op) const override {
1335 op->validate_indices = options.validate_indices();
1336 }
1337 };
1338
1339 class ExpandDims
1340 : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
1341 ::tflite::BuiltinOptions_ExpandDimsOptions> {
1342 public:
1343 using BuiltinOperator::BuiltinOperator;
1344
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1345 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1346 const TocoOperator& op,
1347 flatbuffers::FlatBufferBuilder* builder) const override {
1348 return ::tflite::CreateExpandDimsOptions(*builder);
1349 }
1350
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1351 void ReadOptions(const TfLiteOptions& options,
1352 TocoOperator* op) const override {}
1353 };
1354
1355 class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
1356 ::tflite::BuiltinOptions_PackOptions> {
1357 public:
1358 using BuiltinOperator::BuiltinOperator;
1359
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1360 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1361 const TocoOperator& op,
1362 flatbuffers::FlatBufferBuilder* builder) const override {
1363 return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
1364 }
1365
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1366 void ReadOptions(const TfLiteOptions& options,
1367 TocoOperator* op) const override {
1368 op->values_count = options.values_count();
1369 op->axis = options.axis();
1370 }
1371 };
1372
1373 class Shape
1374 : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
1375 ::tflite::BuiltinOptions_ShapeOptions> {
1376 public:
1377 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1378 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1379 const TocoOperator& op,
1380 flatbuffers::FlatBufferBuilder* builder) const override {
1381 return ::tflite::CreateShapeOptions(
1382 *builder, DataType::Serialize(op.output_data_type));
1383 }
1384
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1385 void ReadOptions(const TfLiteOptions& options,
1386 TocoOperator* op) const override {
1387 op->output_data_type = DataType::Deserialize(options.out_type());
1388 }
1389 };
1390
1391 class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
1392 ::tflite::BuiltinOptions_OneHotOptions> {
1393 public:
1394 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1395 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1396 const TocoOperator& op,
1397 flatbuffers::FlatBufferBuilder* builder) const override {
1398 return ::tflite::CreateOneHotOptions(*builder, op.axis);
1399 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1400 void ReadOptions(const TfLiteOptions& options,
1401 TocoOperator* op) const override {
1402 op->axis = options.axis();
1403 }
1404 };
1405
1406 class CTCBeamSearchDecoder
1407 : public CustomOperator<CTCBeamSearchDecoderOperator> {
1408 public:
1409 using CustomOperator::CustomOperator;
1410
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const1411 void WriteOptions(const TocoOperator& op,
1412 flexbuffers::Builder* fbb) const override {
1413 fbb->Int("beam_width", op.beam_width);
1414 fbb->Int("top_paths", op.top_paths);
1415 fbb->Bool("merge_repeated", op.merge_repeated);
1416 }
1417
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const1418 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
1419 op->beam_width = m["beam_width"].AsInt32();
1420 op->top_paths = m["top_paths"].AsInt32();
1421 op->merge_repeated = m["merge_repeated"].AsBool();
1422 }
1423
GetVersion(const OperatorSignature & op_signature) const1424 int GetVersion(const OperatorSignature& op_signature) const override {
1425 return 1;
1426 }
1427 };
1428
1429 class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
1430 ::tflite::BuiltinOptions_UnpackOptions> {
1431 public:
1432 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1433 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1434 const TocoOperator& op,
1435 flatbuffers::FlatBufferBuilder* builder) const override {
1436 return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
1437 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1438 void ReadOptions(const TfLiteOptions& options,
1439 TocoOperator* op) const override {
1440 op->num = options.num();
1441 op->axis = options.axis();
1442 }
1443
GetVersion(const OperatorSignature & op_signature) const1444 int GetVersion(const OperatorSignature& op_signature) const override {
1445 const std::string& input_name = op_signature.op->inputs[0];
1446 const Array& input_array = op_signature.model->GetArray(input_name);
1447 // If the op take int8/uint8 input, it is version 2.
1448 if (input_array.data_type == ArrayDataType::kInt8 ||
1449 input_array.data_type == ArrayDataType::kUint8) {
1450 return 2;
1451 }
1452 // If the op take bool input, it is version 3.
1453 if (input_array.data_type == ArrayDataType::kBool) {
1454 return 3;
1455 }
1456 return 1;
1457 }
1458 };
1459
1460 class LeakyRelu
1461 : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
1462 ::tflite::BuiltinOptions_LeakyReluOptions> {
1463 public:
1464 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1465 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1466 const TocoOperator& op,
1467 flatbuffers::FlatBufferBuilder* builder) const override {
1468 return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
1469 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1470 void ReadOptions(const TfLiteOptions& options,
1471 TocoOperator* op) const override {
1472 op->alpha = options.alpha();
1473 }
1474 };
1475
1476 class SquaredDifference
1477 : public BuiltinOperator<
1478 SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
1479 ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
1480 public:
1481 using BuiltinOperator::BuiltinOperator;
1482
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1483 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1484 const TocoOperator& op,
1485 flatbuffers::FlatBufferBuilder* builder) const override {
1486 return ::tflite::CreateSquaredDifferenceOptions(*builder);
1487 }
1488
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1489 void ReadOptions(const TfLiteOptions& options,
1490 TocoOperator* op) const override {}
1491 };
1492
1493 class MirrorPad
1494 : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
1495 ::tflite::BuiltinOptions_MirrorPadOptions> {
1496 public:
1497 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1498 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1499 const TocoOperator& op,
1500 flatbuffers::FlatBufferBuilder* builder) const override {
1501 return ::tflite::CreateMirrorPadOptions(
1502 *builder, op.mode == MirrorPadMode::kReflect
1503 ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1504 : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC);
1505 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1506 void ReadOptions(const TfLiteOptions& options,
1507 TocoOperator* op) const override {
1508 op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1509 ? MirrorPadMode::kReflect
1510 : MirrorPadMode::kSymmetric;
1511 }
1512 };
1513
1514 class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
1515 ::tflite::BuiltinOptions_UniqueOptions> {
1516 public:
1517 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1518 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1519 const TocoOperator& op,
1520 flatbuffers::FlatBufferBuilder* builder) const override {
1521 const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
1522 return ::tflite::CreateUniqueOptions(
1523 *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
1524 ? ::tflite::TensorType::TensorType_INT64
1525 : ::tflite::TensorType_INT32);
1526 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1527 void ReadOptions(const TfLiteOptions& options,
1528 TocoOperator* op) const override {
1529 UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
1530 unique_op->idx_out_type =
1531 options.idx_out_type() == ::tflite::TensorType_INT64
1532 ? toco::ArrayDataType::kInt64
1533 : toco::ArrayDataType::kInt32;
1534 }
1535 };
1536
1537 class UnidirectionalSequenceRnn
1538 : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
1539 ::tflite::SequenceRNNOptions,
1540 ::tflite::BuiltinOptions_SequenceRNNOptions> {
1541 public:
1542 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1543 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1544 const TocoOperator& op,
1545 flatbuffers::FlatBufferBuilder* builder) const override {
1546 return ::tflite::CreateSequenceRNNOptions(
1547 *builder, /*time_major=*/true,
1548 /*fused_activation_function=*/
1549 ::tflite::ActivationFunctionType_TANH);
1550 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1551 void ReadOptions(const TfLiteOptions& options,
1552 TocoOperator* op) const override {
1553 // Only support tanh activation, so check that tflite type is tanh.
1554 DCHECK(options.fused_activation_function() ==
1555 ::tflite::ActivationFunctionType_TANH);
1556 }
1557
GetMutatingInputVariables(const Operator & op) const1558 std::vector<bool> GetMutatingInputVariables(
1559 const Operator& op) const override {
1560 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1561 mutating_input_variables[4] = true;
1562 return mutating_input_variables;
1563 }
1564 };
1565
1566 class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
1567 ::tflite::BuiltinOptions_WhereOptions> {
1568 public:
1569 using BuiltinOperator::BuiltinOperator;
1570
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1571 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1572 const TocoOperator& op,
1573 flatbuffers::FlatBufferBuilder* builder) const override {
1574 return ::tflite::CreateWhereOptions(*builder);
1575 }
1576
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1577 void ReadOptions(const TfLiteOptions& options,
1578 TocoOperator* op) const override {}
1579 };
1580
WriteFlexOpOptions(const std::string & tensorflow_node_def)1581 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
1582 const std::string& tensorflow_node_def) {
1583 auto fbb = std::make_unique<flexbuffers::Builder>();
1584
1585 ::tensorflow::NodeDef node_def;
1586 if (!node_def.ParseFromString(tensorflow_node_def)) {
1587 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1588 return {};
1589 }
1590
1591 fbb->Vector([&]() {
1592 fbb->String(node_def.op());
1593 fbb->String(tensorflow_node_def);
1594 });
1595 fbb->Finish();
1596 LOG(INFO) << "Writing flex op: " << node_def.op();
1597 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1598 }
1599
1600 class TensorFlowUnsupported : public BaseOperator {
1601 public:
TensorFlowUnsupported(const std::string & name,OperatorType type,bool enable_select_tf_ops)1602 TensorFlowUnsupported(const std::string& name, OperatorType type,
1603 bool enable_select_tf_ops)
1604 : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
1605
Serialize(const Operator & op,flatbuffers::FlatBufferBuilder * builder) const1606 Options Serialize(const Operator& op,
1607 flatbuffers::FlatBufferBuilder* builder) const override {
1608 auto fbb =
1609 WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
1610 if (fbb) {
1611 return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
1612 } else {
1613 return Options::Custom(0);
1614 }
1615 }
1616
Deserialize(const BuiltinOptions * builtin_options,const CustomOptions * custom_options) const1617 std::unique_ptr<Operator> Deserialize(
1618 const BuiltinOptions* builtin_options,
1619 const CustomOptions* custom_options) const override {
1620 // Deserializing Flex ops doesn't work now.
1621 // TODO(ycling): Revisit and decide if we should fix the flow for importing
1622 // TFLite models with Flex ops.
1623 auto op = std::make_unique<TensorFlowUnsupportedOperator>();
1624 if (custom_options) {
1625 auto flexbuffer_map =
1626 flexbuffers::GetRoot(custom_options->data(), custom_options->size())
1627 .AsMap();
1628 ReadOptions(flexbuffer_map, op.get());
1629 }
1630 return std::unique_ptr<Operator>(op.release());
1631 }
1632
WriteOptions(const TensorFlowUnsupportedOperator & op) const1633 std::unique_ptr<flexbuffers::Builder> WriteOptions(
1634 const TensorFlowUnsupportedOperator& op) const {
1635 if (enable_select_tf_ops_) {
1636 return WriteFlexOpOptions(op.tensorflow_node_def);
1637 }
1638 auto fbb = std::make_unique<flexbuffers::Builder>();
1639
1640 ::tensorflow::NodeDef node_def;
1641 if (!node_def.ParseFromString(op.tensorflow_node_def)) {
1642 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1643 return std::unique_ptr<flexbuffers::Builder>();
1644 }
1645
1646 if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
1647 fbb->Vector([&]() {
1648 fbb->String(node_def.op());
1649 fbb->String(op.tensorflow_node_def);
1650 });
1651 fbb->Finish();
1652 LOG(INFO) << "Writing flex op: " << node_def.op();
1653 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1654 }
1655
1656 bool has_valid_attr = false;
1657 size_t map_start = fbb->StartMap();
1658 for (const auto& pair : node_def.attr()) {
1659 const char* key = pair.first.c_str();
1660 const auto& attr = pair.second;
1661 switch (attr.value_case()) {
1662 case ::tensorflow::AttrValue::kS:
1663 fbb->String(key, attr.s());
1664 has_valid_attr = true;
1665 break;
1666 case ::tensorflow::AttrValue::kI:
1667 fbb->Int(key, attr.i());
1668 has_valid_attr = true;
1669 break;
1670 case ::tensorflow::AttrValue::kF:
1671 fbb->Float(key, attr.f());
1672 has_valid_attr = true;
1673 break;
1674 case ::tensorflow::AttrValue::kB:
1675 fbb->Bool(key, attr.b());
1676 has_valid_attr = true;
1677 break;
1678 case tensorflow::AttrValue::kList:
1679 if (attr.list().s_size() > 0) {
1680 auto start = fbb->StartVector(key);
1681 for (const std::string& v : attr.list().s()) {
1682 fbb->Add(v);
1683 }
1684 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1685 has_valid_attr = true;
1686 } else if (attr.list().i_size() > 0) {
1687 auto start = fbb->StartVector(key);
1688 for (const int64_t v : attr.list().i()) {
1689 fbb->Add(v);
1690 }
1691 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1692 has_valid_attr = true;
1693 } else if (attr.list().f_size() > 0) {
1694 auto start = fbb->StartVector(key);
1695 for (const float v : attr.list().f()) {
1696 fbb->Add(v);
1697 }
1698 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1699 has_valid_attr = true;
1700 } else {
1701 LOG(WARNING)
1702 << "Ignoring unsupported type in list attribute with key '"
1703 << key << "'";
1704 }
1705 break;
1706 default:
1707 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1708 << key << "'";
1709 break;
1710 }
1711 }
1712 if (!has_valid_attr) {
1713 return std::unique_ptr<flexbuffers::Builder>();
1714 }
1715 fbb->EndMap(map_start);
1716 fbb->Finish();
1717 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1718 }
1719
ReadOptions(const flexbuffers::Map & m,TensorFlowUnsupportedOperator * op) const1720 void ReadOptions(const flexbuffers::Map& m,
1721 TensorFlowUnsupportedOperator* op) const {
1722 ::tensorflow::NodeDef node_def;
1723 auto attr = node_def.mutable_attr();
1724
1725 const auto& keys = m.Keys();
1726 for (size_t i = 0; i < keys.size(); ++i) {
1727 const auto key = keys[i].AsKey();
1728 const auto& value = m[key];
1729 switch (value.GetType()) {
1730 case flexbuffers::FBT_STRING:
1731 (*attr)[key].set_s(value.AsString().c_str());
1732 break;
1733 case flexbuffers::FBT_INT:
1734 (*attr)[key].set_i(value.AsInt64());
1735 break;
1736 case flexbuffers::FBT_FLOAT:
1737 (*attr)[key].set_f(value.AsFloat());
1738 break;
1739 case flexbuffers::FBT_BOOL:
1740 (*attr)[key].set_b(value.AsBool());
1741 if (std::string(key) == "_output_quantized") {
1742 op->quantized = value.AsBool();
1743 }
1744 if (std::string(key) ==
1745 "_support_output_type_float_in_quantized_op") {
1746 op->support_output_type_float_in_quantized_op = value.AsBool();
1747 }
1748 break;
1749 case flexbuffers::FBT_VECTOR_INT: {
1750 auto* list = (*attr)[key].mutable_list();
1751 const auto& vector = value.AsTypedVector();
1752 for (size_t i = 0; i < vector.size(); i++) {
1753 list->add_i(vector[i].AsInt64());
1754 }
1755 break;
1756 }
1757 case flexbuffers::FBT_VECTOR_FLOAT: {
1758 auto* list = (*attr)[key].mutable_list();
1759 const auto& vector = value.AsTypedVector();
1760 for (size_t i = 0; i < vector.size(); i++) {
1761 list->add_f(vector[i].AsFloat());
1762 }
1763 break;
1764 }
1765 case 15 /* TO_DO(wvo): flexbuffers::FBT_VECTOR_STRING_DEPRECATED*/: {
1766 auto* list = (*attr)[key].mutable_list();
1767 const auto& vector = value.AsTypedVector();
1768 for (size_t i = 0; i < vector.size(); i++) {
1769 list->add_s(vector[i].AsString().str());
1770 }
1771 break;
1772 }
1773 default:
1774 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1775 << key << "'";
1776 break;
1777 }
1778 }
1779 node_def.SerializeToString(&op->tensorflow_node_def);
1780 }
1781
GetVersion(const OperatorSignature & op_signature) const1782 int GetVersion(const OperatorSignature& op_signature) const override {
1783 // TODO(ycling): Design and implement a way to plumb the version of
1784 // custom ops.
1785 return 1;
1786 }
1787
1788 private:
1789 const bool enable_select_tf_ops_;
1790 };
1791
1792 class Dequantize
1793 : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
1794 ::tflite::BuiltinOptions_DequantizeOptions> {
1795 public:
1796 using BuiltinOperator::BuiltinOperator;
1797
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1798 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1799 const TocoOperator& op,
1800 flatbuffers::FlatBufferBuilder* builder) const override {
1801 return ::tflite::CreateDequantizeOptions(*builder);
1802 }
1803
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1804 void ReadOptions(const TfLiteOptions& options,
1805 TocoOperator* op) const override {}
1806 };
1807
1808 class ReverseSequence
1809 : public BuiltinOperator<ReverseSequenceOperator,
1810 ::tflite::ReverseSequenceOptions,
1811 ::tflite::BuiltinOptions_ReverseSequenceOptions> {
1812 public:
1813 using BuiltinOperator::BuiltinOperator;
1814
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1815 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1816 const TocoOperator& op,
1817 flatbuffers::FlatBufferBuilder* builder) const override {
1818 return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim,
1819 op.batch_dim);
1820 }
1821
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1822 void ReadOptions(const TfLiteOptions& options,
1823 TocoOperator* op) const override {
1824 op->seq_dim = options.seq_dim();
1825 op->batch_dim = options.batch_dim();
1826 }
1827 };
1828
1829 namespace {
1830 // Build a vector containing all the known operators.
BuildOperatorList(bool enable_select_tf_ops=false)1831 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
1832 bool enable_select_tf_ops = false) {
1833 std::vector<std::unique_ptr<BaseOperator>> ops;
1834 using tensorflow::MakeUnique;
1835 // Builtin Operators.
1836 ops.push_back(
1837 MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
1838 ops.push_back(
1839 MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
1840 ops.push_back(
1841 MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
1842 ops.push_back(
1843 MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
1844 ops.push_back(MakeUnique<AveragePool>(
1845 ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
1846 ops.push_back(
1847 MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
1848 OperatorType::kSpaceToBatchND));
1849 ops.push_back(
1850 MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
1851 OperatorType::kBatchToSpaceND));
1852 ops.push_back(MakeUnique<Concatenation>(
1853 ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
1854 ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
1855 OperatorType::kConv));
1856 ops.push_back(MakeUnique<DepthwiseConvolution>(
1857 ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
1858 OperatorType::kDepthwiseConv));
1859 ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
1860 OperatorType::kDequantize));
1861 ops.push_back(
1862 MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
1863 OperatorType::kFullyConnected));
1864 ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
1865 OperatorType::kGather));
1866 ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
1867 OperatorType::kGatherNd));
1868 ops.push_back(
1869 MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
1870 OperatorType::kL2Normalization));
1871 ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
1872 OperatorType::kL2Pool));
1873 ops.push_back(MakeUnique<LocalResponseNormalization>(
1874 ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
1875 OperatorType::kLocalResponseNormalization));
1876 ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
1877 OperatorType::kMaxPool));
1878 ops.push_back(
1879 MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
1880
1881 ops.push_back(
1882 MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
1883 ops.push_back(
1884 MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
1885 ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
1886 OperatorType::kReshape));
1887 ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
1888 OperatorType::kSoftmax));
1889 ops.push_back(MakeUnique<SpaceToDepth>(
1890 ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
1891 ops.push_back(MakeUnique<DepthToSpace>(
1892 ::tflite::BuiltinOperator_DEPTH_TO_SPACE, OperatorType::kDepthToSpace));
1893 ops.push_back(
1894 MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
1895 ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
1896 OperatorType::kTranspose));
1897 ops.push_back(
1898 MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
1899 ops.push_back(
1900 MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
1901 ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
1902 OperatorType::kReduceProd));
1903 ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
1904 OperatorType::kReduceMax));
1905 ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
1906 OperatorType::kReduceMin));
1907 ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
1908 OperatorType::kAny));
1909 ops.push_back(
1910 MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
1911 OperatorType::kResizeBilinear));
1912 ops.push_back(MakeUnique<ResizeNearestNeighbor>(
1913 ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
1914 OperatorType::kResizeNearestNeighbor));
1915 ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
1916 OperatorType::kSqueeze));
1917 ops.push_back(
1918 MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
1919 ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
1920 OperatorType::kSplitV));
1921 ops.push_back(MakeUnique<StridedSlice>(
1922 ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
1923 ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
1924 OperatorType::kTopK_V2));
1925 ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
1926 OperatorType::kLstmCell));
1927 ops.push_back(
1928 MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
1929 ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
1930 OperatorType::kArgMax));
1931 ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
1932 OperatorType::kArgMin));
1933 ops.push_back(
1934 MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
1935 ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
1936 OperatorType::kExpandDims));
1937 ops.push_back(MakeUnique<TransposeConv>(
1938 ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
1939 ops.push_back(MakeUnique<SparseToDense>(
1940 ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
1941 ops.push_back(
1942 MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
1943 ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
1944 OperatorType::kFakeQuant));
1945 ops.push_back(
1946 MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
1947 ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
1948 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
1949 OperatorType::kUnidirectionalSequenceLstm));
1950 ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
1951 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
1952 OperatorType::kBidirectionalSequenceLstm));
1953 ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
1954 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
1955 OperatorType::kBidirectionalSequenceRnn));
1956 ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
1957 OperatorType::kOneHot));
1958 ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
1959 OperatorType::kUnpack));
1960 ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
1961 OperatorType::kLeakyRelu));
1962 ops.push_back(MakeUnique<SquaredDifference>(
1963 ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
1964 OperatorType::kSquaredDifference));
1965 ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
1966 OperatorType::kMirrorPad));
1967 ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
1968 OperatorType::kUnique));
1969 ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
1970 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
1971 OperatorType::kUnidirectionalSequenceRnn));
1972 ops.push_back(
1973 MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
1974 ops.push_back(
1975 MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
1976 OperatorType::kReverseSequence));
1977 ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
1978 ::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag));
1979 ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
1980 ::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag));
1981 // Custom Operators.
1982 ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
1983 "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
1984 ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
1985 OperatorType::kUnsupported,
1986 enable_select_tf_ops));
1987
1988 // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
1989 // been modified to also export builtins. As TOCO evolved we added warnings
1990 // when custom ops are exported but SimpleOperator bypasses thoses. To
1991 // prevent user confusion we are settling on using SimpleOperator only for
1992 // builtins.
1993 ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>(
1994 ::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor));
1995 ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>(
1996 ::tflite::BuiltinOperator_CEIL, OperatorType::kCeil));
1997 ops.push_back(MakeUnique<SimpleOperator<EluOperator>>(
1998 ::tflite::BuiltinOperator_ELU, OperatorType::kElu));
1999 ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>(
2000 ::tflite::BuiltinOperator_ROUND, OperatorType::kRound));
2001 ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>(
2002 ::tflite::BuiltinOperator_RELU, OperatorType::kRelu));
2003 ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
2004 ::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1));
2005 ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>(
2006 ::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6));
2007 ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>(
2008 ::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu));
2009 ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>(
2010 ::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic));
2011 ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>(
2012 ::tflite::BuiltinOperator_TANH, OperatorType::kTanh));
2013 ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>(
2014 ::tflite::BuiltinOperator_EXP, OperatorType::kExp));
2015 ops.push_back(MakeUnique<SimpleOperator<CosOperator>>(
2016 ::tflite::BuiltinOperator_COS, OperatorType::kCos));
2017 ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
2018 ::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax));
2019 ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>(
2020 ::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum));
2021 ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>(
2022 ::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum));
2023 ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>(
2024 ::tflite::BuiltinOperator_GREATER, OperatorType::kGreater));
2025 ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>(
2026 ::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual));
2027 ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>(
2028 ::tflite::BuiltinOperator_LESS, OperatorType::kLess));
2029 ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>(
2030 ::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual));
2031 ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>(
2032 ::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual));
2033 ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>(
2034 ::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual));
2035 ops.push_back(MakeUnique<SimpleOperator<NegOperator>>(
2036 ::tflite::BuiltinOperator_NEG, OperatorType::kNeg));
2037 ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
2038 ::tflite::BuiltinOperator_SELECT, OperatorType::kSelect));
2039 ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>(
2040 ::tflite::BuiltinOperator_SLICE, OperatorType::kSlice));
2041 ops.push_back(MakeUnique<SimpleOperator<PowOperator>>(
2042 ::tflite::BuiltinOperator_POW, OperatorType::kPow));
2043 ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
2044 ::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr));
2045 ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
2046 ::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd));
2047 ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
2048 ::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot));
2049 ops.emplace_back(new SimpleOperator<FloorDivOperator>(
2050 ::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv));
2051 ops.emplace_back(new SimpleOperator<FloorModOperator>(
2052 ::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod));
2053 ops.emplace_back(new SimpleOperator<RangeOperator>(
2054 ::tflite::BuiltinOperator_RANGE, OperatorType::kRange));
2055 // Element-wise operator
2056 ops.push_back(MakeUnique<SimpleOperator<SinOperator>>(
2057 ::tflite::BuiltinOperator_SIN, OperatorType::kSin));
2058 ops.push_back(MakeUnique<SimpleOperator<LogOperator>>(
2059 ::tflite::BuiltinOperator_LOG, OperatorType::kLog));
2060 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
2061 ::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt));
2062 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
2063 ::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt));
2064 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
2065 ::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare));
2066 ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
2067 ::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike));
2068 ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>(
2069 ::tflite::BuiltinOperator_ABS, OperatorType::kAbs));
2070 ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>(
2071 ::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish));
2072 ops.push_back(MakeUnique<SimpleOperator<FillOperator>>(
2073 ::tflite::BuiltinOperator_FILL, OperatorType::kFill));
2074 ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
2075 ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
2076 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
2077 ::tflite::BuiltinOperator_RANK, OperatorType::kRank));
2078 ops.emplace_back(new SimpleOperator<SegmentSumOperator>(
2079 ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
2080 ops.emplace_back(MakeUnique<SimpleOperator<ScatterNdOperator>>(
2081 ::tflite::BuiltinOperator_SCATTER_ND, OperatorType::kScatterNd));
2082 return ops;
2083 }
2084 } // namespace
2085
2086 // LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
2087
BuildOperatorByTypeMap(bool enable_select_tf_ops)2088 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
2089 bool enable_select_tf_ops) {
2090 std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
2091
2092 std::vector<std::unique_ptr<BaseOperator>> ops =
2093 BuildOperatorList(enable_select_tf_ops);
2094 for (auto& op : ops) {
2095 result[op->type()] = std::move(op);
2096 }
2097
2098 return result;
2099 }
2100
BuildOperatorByNameMap(bool enable_select_tf_ops)2101 std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
2102 bool enable_select_tf_ops) {
2103 std::map<std::string, std::unique_ptr<BaseOperator>> result;
2104
2105 std::vector<std::unique_ptr<BaseOperator>> ops =
2106 BuildOperatorList(enable_select_tf_ops);
2107 for (auto& op : ops) {
2108 result[op->name()] = std::move(op);
2109 }
2110
2111 return result;
2112 }
2113
ShouldExportAsFlexOp(bool enable_select_tf_ops,const std::string & tensorflow_op_name)2114 bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
2115 const std::string& tensorflow_op_name) {
2116 // If Flex ops aren't allow at all, simply return false.
2117 if (!enable_select_tf_ops) {
2118 return false;
2119 }
2120 // Check if we can find the `OpDef` for the TensorFlow op. If we can find
2121 // it and it has been allowlisted, export the op as an Flex op. Otherwise,
2122 // export it as a regular custom op.
2123 const tensorflow::OpDef* op_def = nullptr;
2124 if (!tensorflow::OpRegistry::Global()
2125 ->LookUpOpDef(tensorflow_op_name, &op_def)
2126 .ok()) {
2127 return false;
2128 }
2129
2130 if (!::tflite::flex::IsAllowlistedFlexOp(tensorflow_op_name)) {
2131 LOG(WARNING) << "Op " << tensorflow_op_name
2132 << " is a valid TensorFlow op but has not been allowlisted for"
2133 " the TensorFlow Lite flex op set.";
2134 return false;
2135 }
2136
2137 return true;
2138 }
2139
2140 } // namespace tflite
2141
2142 } // namespace toco
2143