xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/tflite/operator_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/toco/tflite/operator.h"

#include <limits>
#include <map>
#include <memory>
#include <string>

#include "flatbuffers/flexbuffers.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
26 
27 namespace toco {
28 
29 namespace tflite {
30 namespace {
31 
32 class OperatorTest : public ::testing::Test {
33  protected:
34   // Return the operator for the given name and type.
GetOperator(const std::string & name,OperatorType type)35   const BaseOperator& GetOperator(const std::string& name, OperatorType type) {
36     using OpsByName = std::map<std::string, std::unique_ptr<BaseOperator>>;
37     using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>;
38 
39     static auto* by_name = new OpsByName(BuildOperatorByNameMap());
40     static auto* by_type = new OpsByType(BuildOperatorByTypeMap());
41 
42     // Make sure the two maps were consistently built.
43     CHECK(by_name->count(name)) << "No operator for '" << name << "'.";
44     BaseOperator* op1 = by_name->at(name).get();
45     CHECK(op1->type() == type) << "while verifying '" << name << "'.";
46 
47     CHECK(by_type->count(type))
48         << "No operator for '" << OperatorTypeName(type) << "'.";
49     BaseOperator* op2 = by_type->at(type).get();
50     CHECK(op2->name() == name)
51         << "while verifying '" << OperatorTypeName(type) << "'.";
52 
53     return *op1;
54   }
55 
56   // Use the given BaseOperator to serialize the tf.mini operator into a set of
57   // TF Lite options. Proceed to deserialize the options back into a new
58   // tf.mini operator, which is then returned. If `options` is given, it will
59   // be populated with the serialized options.
60   template <typename T>
SerializeAndDeserialize(const BaseOperator & op,const T & toco_op,Options * options=nullptr)61   std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op,
62                                              const T& toco_op,
63                                              Options* options = nullptr) {
64     flatbuffers::FlatBufferBuilder builder;
65     Options input_options = op.Serialize(toco_op, &builder);
66 
67     if (options) {
68       *options = input_options;
69     }
70 
71     builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type,
72                                   input_options.builtin, input_options.custom,
73                                   ::tflite::CustomOptionsFormat_FLEXBUFFERS));
74     auto* output_options =
75         flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer());
76     auto new_toco_op = op.Deserialize(output_options->builtin_options(),
77                                       output_options->custom_options());
78 
79     CHECK(new_toco_op->type == toco_op.type)
80         << "The type of the serialized and deserialized"
81         << HelpfulOperatorTypeName(*new_toco_op)
82         << " does not match the type of the original "
83         << HelpfulOperatorTypeName(toco_op);
84 
85     return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release()));
86   }
87 
88   // Verify serialization and deserialization of simple operators (those
89   // that don't have any configuration parameters).
90   template <typename T>
CheckSimpleOperator(const std::string & name,OperatorType type)91   void CheckSimpleOperator(const std::string& name, OperatorType type) {
92     Options options;
93     auto output_toco_op =
94         SerializeAndDeserialize(GetOperator(name, type), T(), &options);
95 
96     ASSERT_EQ(0, options.builtin.o);
97     ASSERT_EQ(0, options.custom.o);
98     ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type);
99 
100     ASSERT_NE(nullptr, output_toco_op.get());
101   }
102 
103   template <typename T>
CheckReducerOperator(const std::string & name,OperatorType type)104   void CheckReducerOperator(const std::string& name, OperatorType type) {
105     T op;
106 
107     op.keep_dims = false;
108 
109     auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op);
110     EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
111   }
112 };
113 
TEST_F(OperatorTest, SimpleOperators) {
  // Every operator listed here is checked by CheckSimpleOperator, which
  // requires that serialization emits neither builtin nor custom options and
  // that deserialization yields a non-null operator.
  CheckSimpleOperator<FloorOperator>(
      "FLOOR", OperatorType::kFloor);
  CheckSimpleOperator<CeilOperator>(
      "CEIL", OperatorType::kCeil);
  CheckSimpleOperator<EluOperator>(
      "ELU", OperatorType::kElu);
  CheckSimpleOperator<RoundOperator>(
      "ROUND", OperatorType::kRound);
  CheckSimpleOperator<ReluOperator>(
      "RELU", OperatorType::kRelu);
  CheckSimpleOperator<Relu1Operator>(
      "RELU_N1_TO_1", OperatorType::kRelu1);
  CheckSimpleOperator<Relu6Operator>(
      "RELU6", OperatorType::kRelu6);
  CheckSimpleOperator<LogisticOperator>(
      "LOGISTIC", OperatorType::kLogistic);
  CheckSimpleOperator<TanhOperator>(
      "TANH", OperatorType::kTanh);
  CheckSimpleOperator<ExpOperator>(
      "EXP", OperatorType::kExp);
  CheckSimpleOperator<CosOperator>(
      "COS", OperatorType::kCos);
  CheckSimpleOperator<LogSoftmaxOperator>(
      "LOG_SOFTMAX", OperatorType::kLogSoftmax);
  // Element-wise Maximum.
  CheckSimpleOperator<TensorFlowMaximumOperator>(
      "MAXIMUM", OperatorType::kMaximum);
  // Element-wise Minimum.
  CheckSimpleOperator<TensorFlowMinimumOperator>(
      "MINIMUM", OperatorType::kMinimum);
  CheckSimpleOperator<TensorFlowLessOperator>(
      "LESS", OperatorType::kLess);
  CheckSimpleOperator<NegOperator>(
      "NEG", OperatorType::kNeg);
  CheckSimpleOperator<SelectOperator>(
      "SELECT", OperatorType::kSelect);
  CheckSimpleOperator<SliceOperator>(
      "SLICE", OperatorType::kSlice);
  CheckSimpleOperator<SinOperator>(
      "SIN", OperatorType::kSin);
  CheckSimpleOperator<TensorFlowEqualOperator>(
      "EQUAL", OperatorType::kEqual);
  CheckSimpleOperator<TensorFlowNotEqualOperator>(
      "NOT_EQUAL", OperatorType::kNotEqual);
  CheckSimpleOperator<LogOperator>(
      "LOG", OperatorType::kLog);
  CheckSimpleOperator<TensorFlowSqrtOperator>(
      "SQRT", OperatorType::kSqrt);
  CheckSimpleOperator<TensorFlowRsqrtOperator>(
      "RSQRT", OperatorType::kRsqrt);
  CheckSimpleOperator<PowOperator>(
      "POW", OperatorType::kPow);
  CheckSimpleOperator<LogicalOrOperator>(
      "LOGICAL_OR", OperatorType::kLogicalOr);
  CheckSimpleOperator<LogicalAndOperator>(
      "LOGICAL_AND", OperatorType::kLogicalAnd);
  CheckSimpleOperator<LogicalNotOperator>(
      "LOGICAL_NOT", OperatorType::kLogicalNot);
  CheckSimpleOperator<FloorDivOperator>(
      "FLOOR_DIV", OperatorType::kFloorDiv);
  CheckSimpleOperator<TensorFlowSquareOperator>(
      "SQUARE", OperatorType::kSquare);
  CheckSimpleOperator<TensorFlowZerosLikeOperator>(
      "ZEROS_LIKE", OperatorType::kZerosLike);
  CheckSimpleOperator<FloorModOperator>(
      "FLOOR_MOD", OperatorType::kFloorMod);
  CheckSimpleOperator<RangeOperator>(
      "RANGE", OperatorType::kRange);
  CheckSimpleOperator<FillOperator>(
      "FILL", OperatorType::kFill);
  CheckSimpleOperator<ReverseV2Operator>(
      "REVERSE_V2", OperatorType::kReverseV2);
  CheckSimpleOperator<TensorFlowRankOperator>(
      "RANK", OperatorType::kRank);
}
162 
TEST_F(OperatorTest, BuiltinAdd) {
  // ADD must preserve its fused activation across the round trip.
  AddOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  const BaseOperator& serializer = GetOperator("ADD", OperatorType::kAdd);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
171 
TEST_F(OperatorTest, BuiltinAddN) {
  // ADD_N carries no options; the round trip only has to yield an operator.
  AddNOperator original;
  const BaseOperator& serializer = GetOperator("ADD_N", OperatorType::kAddN);
  const auto restored = SerializeAndDeserialize(serializer, original);
  ASSERT_NE(restored.get(), nullptr);
}
178 
TEST_F(OperatorTest, BuiltinReducerOps) {
  // Round-trips the `keep_dims` flag of each reducer via CheckReducerOperator.
  CheckReducerOperator<MeanOperator>(
      "MEAN", OperatorType::kMean);
  CheckReducerOperator<TensorFlowSumOperator>(
      "SUM", OperatorType::kSum);
  CheckReducerOperator<TensorFlowProdOperator>(
      "REDUCE_PROD", OperatorType::kReduceProd);
  CheckReducerOperator<TensorFlowMaxOperator>(
      "REDUCE_MAX", OperatorType::kReduceMax);
  CheckReducerOperator<TensorFlowMinOperator>(
      "REDUCE_MIN", OperatorType::kReduceMin);
  CheckReducerOperator<TensorFlowAnyOperator>(
      "REDUCE_ANY", OperatorType::kAny);
}
190 
TEST_F(OperatorTest, BuiltinCast) {
  // CAST must round-trip both its source and destination data types.
  CastOperator original;
  original.src_data_type = ArrayDataType::kFloat;
  original.dst_data_type = ArrayDataType::kUint8;
  const BaseOperator& serializer = GetOperator("CAST", OperatorType::kCast);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.src_data_type, restored->src_data_type);
  EXPECT_EQ(original.dst_data_type, restored->dst_data_type);
}
200 
TEST_F(OperatorTest, CustomConcatenation) {
  // CONCATENATION must round-trip its axis.
  ConcatenationOperator original;
  original.axis = 123;
  const BaseOperator& serializer =
      GetOperator("CONCATENATION", OperatorType::kConcatenation);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.axis, restored->axis);
}
208 
TEST_F(OperatorTest, CustomDepthToSpace) {
  // DEPTH_TO_SPACE must round-trip its block size.
  DepthToSpaceOperator original;
  original.block_size = 123;
  const BaseOperator& serializer =
      GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.block_size, restored->block_size);
}
216 
TEST_F(OperatorTest, CustomFakeQuant) {
  // FAKE_QUANT must round-trip its min/max range and bit width.
  FakeQuantOperator op;
  // Hand the MinMax to its owning pointer at construction instead of a raw
  // `new` followed by reset().
  op.minmax = std::make_unique<MinMax>();
  op.minmax->min = -10;
  op.minmax->max = 200;
  op.num_bits = 16;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  // Fail cleanly rather than dereferencing a null MinMax if it was dropped.
  ASSERT_NE(nullptr, output_toco_op->minmax.get());
  EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min);
  EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max);
  EXPECT_EQ(op.num_bits, output_toco_op->num_bits);
}
230 
TEST_F(OperatorTest, CustomFullyConnected) {
  // FULLY_CONNECTED must preserve its fused activation.
  FullyConnectedOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  const BaseOperator& serializer =
      GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
239 
TEST_F(OperatorTest, BuiltinGather) {
  // GATHER only needs to survive the round trip as a non-null operator.
  GatherOperator original;
  const BaseOperator& serializer = GetOperator("GATHER", OperatorType::kGather);
  const auto restored = SerializeAndDeserialize(serializer, original);
  ASSERT_NE(nullptr, restored.get());
}
246 
TEST_F(OperatorTest, BuiltinGatherNd) {
  // GATHER_ND only needs to survive the round trip as a non-null operator.
  GatherNdOperator original;
  const BaseOperator& serializer =
      GetOperator("GATHER_ND", OperatorType::kGatherNd);
  const auto restored = SerializeAndDeserialize(serializer, original);
  ASSERT_NE(restored.get(), nullptr);
}
253 
TEST_F(OperatorTest, BuiltinWhere) {
  // WHERE only needs to survive the round trip as a non-null operator.
  WhereOperator original;
  const BaseOperator& serializer = GetOperator("WHERE", OperatorType::kWhere);
  const auto restored = SerializeAndDeserialize(serializer, original);
  ASSERT_NE(restored.get(), nullptr);
}
260 
TEST_F(OperatorTest, BuiltinL2Pool) {
  // L2_POOL_2D must round-trip strides, padding and kernel size.
  L2PoolOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.kwidth = 480;
  original.kheight = 1080;
  const BaseOperator& serializer =
      GetOperator("L2_POOL_2D", OperatorType::kL2Pool);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
  EXPECT_EQ(original.kwidth, restored->kwidth);
  EXPECT_EQ(original.kheight, restored->kheight);
}
276 
TEST_F(OperatorTest, BuiltinLocalResponseNormalization) {
  // LOCAL_RESPONSE_NORMALIZATION must round-trip range, bias, alpha and beta.
  LocalResponseNormalizationOperator original;
  original.range = 123;
  original.bias = 1.23;
  original.alpha = 12.3;
  original.beta = .123;
  const BaseOperator& serializer =
      GetOperator("LOCAL_RESPONSE_NORMALIZATION",
                  OperatorType::kLocalResponseNormalization);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.range, restored->range);
  EXPECT_EQ(original.bias, restored->bias);
  EXPECT_EQ(original.alpha, restored->alpha);
  EXPECT_EQ(original.beta, restored->beta);
}
292 
TEST_F(OperatorTest, BuiltinMaxPool) {
  // MAX_POOL_2D must round-trip strides, padding and kernel size.
  MaxPoolOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.kwidth = 480;
  original.kheight = 1080;
  const BaseOperator& serializer =
      GetOperator("MAX_POOL_2D", OperatorType::kMaxPool);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
  EXPECT_EQ(original.kwidth, restored->kwidth);
  EXPECT_EQ(original.kheight, restored->kheight);
}
308 
TEST_F(OperatorTest, BuiltinReshape) {
  // RESHAPE must round-trip its target shape.
  TensorFlowReshapeOperator original;
  original.shape = {1, 2, 4, 5, 8};
  const BaseOperator& serializer =
      GetOperator("RESHAPE", OperatorType::kReshape);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.shape, restored->shape);
}
316 
TEST_F(OperatorTest, CustomSoftmax) {
  // SOFTMAX must round-trip its beta.
  SoftmaxOperator original;
  original.beta = 123.1;
  const BaseOperator& serializer =
      GetOperator("SOFTMAX", OperatorType::kSoftmax);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.beta, restored->beta);
}
324 
TEST_F(OperatorTest, BuiltinSpaceToDepth) {
  // SPACE_TO_DEPTH must round-trip its block size.
  SpaceToDepthOperator original;
  original.block_size = 123;
  const BaseOperator& serializer =
      GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.block_size, restored->block_size);
}
332 
TEST_F(OperatorTest, CustomSplit) {
  // SPLIT must round-trip its number of splits.
  TensorFlowSplitOperator original;
  original.num_split = 123;
  const BaseOperator& serializer = GetOperator("SPLIT", OperatorType::kSplit);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.num_split, restored->num_split);
}
340 
TEST_F(OperatorTest, CustomSplitV) {
  // SPLIT_V must round-trip its number of splits.
  TensorFlowSplitVOperator original;
  original.num_split = 123;
  const BaseOperator& serializer =
      GetOperator("SPLIT_V", OperatorType::kSplitV);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.num_split, restored->num_split);
}
348 
TEST_F(OperatorTest, BuiltinAveragePool) {
  // AVERAGE_POOL_2D must round-trip its activation, strides, padding and
  // kernel size.
  AveragePoolOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.kwidth = 480;
  original.kheight = 1080;
  const BaseOperator& serializer =
      GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
  EXPECT_EQ(original.kwidth, restored->kwidth);
  EXPECT_EQ(original.kheight, restored->kheight);
}
367 
TEST_F(OperatorTest, BuiltinConvolution) {
  // CONV_2D must round-trip strides, padding and fused activation.
  ConvOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  const BaseOperator& serializer = GetOperator("CONV_2D", OperatorType::kConv);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
382 
TEST_F(OperatorTest, BuiltinDepthwiseConvolution) {
  // DEPTHWISE_CONV_2D must round-trip strides, padding, depth multiplier and
  // fused activation.
  DepthwiseConvOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.depth_multiplier = 6;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  const BaseOperator& serializer =
      GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
  EXPECT_EQ(original.depth_multiplier, restored->depth_multiplier);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
399 
TEST_F(OperatorTest, BuiltinL2Norm) {
  // L2_NORMALIZATION must preserve its fused activation.
  L2NormalizationOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  const BaseOperator& serializer =
      GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
408 
TEST_F(OperatorTest, BuiltinMul) {
  // MUL must preserve its fused activation.
  MulOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  const BaseOperator& serializer = GetOperator("MUL", OperatorType::kMul);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
}
417 
TEST_F(OperatorTest, ResizeBilinear) {
  // RESIZE_BILINEAR with align_corners set and half_pixel_centers cleared.
  ResizeBilinearOperator original;
  original.align_corners = true;
  original.half_pixel_centers = false;
  const BaseOperator& serializer =
      GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.align_corners, restored->align_corners);
  EXPECT_EQ(original.half_pixel_centers, restored->half_pixel_centers);
}
427 
TEST_F(OperatorTest, ResizeBilinear_HalfPixelCenters) {
  // RESIZE_BILINEAR with both align_corners and half_pixel_centers set.
  ResizeBilinearOperator original;
  original.align_corners = true;
  original.half_pixel_centers = true;
  const BaseOperator& serializer =
      GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.align_corners, restored->align_corners);
  EXPECT_EQ(original.half_pixel_centers, restored->half_pixel_centers);
}
437 
TEST_F(OperatorTest, ResizeNearestNeighbor) {
  // RESIZE_NEAREST_NEIGHBOR with align_corners set and half_pixel_centers
  // cleared.
  ResizeNearestNeighborOperator original;
  original.align_corners = true;
  original.half_pixel_centers = false;
  const BaseOperator& serializer = GetOperator(
      "RESIZE_NEAREST_NEIGHBOR", OperatorType::kResizeNearestNeighbor);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.align_corners, restored->align_corners);
  EXPECT_EQ(original.half_pixel_centers, restored->half_pixel_centers);
}
449 
TEST_F(OperatorTest, ResizeNearestNeighbor_HalfPixelCenters) {
  // RESIZE_NEAREST_NEIGHBOR with both align_corners and half_pixel_centers
  // set.
  ResizeNearestNeighborOperator original;
  original.align_corners = true;
  original.half_pixel_centers = true;
  const BaseOperator& serializer = GetOperator(
      "RESIZE_NEAREST_NEIGHBOR", OperatorType::kResizeNearestNeighbor);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.align_corners, restored->align_corners);
  EXPECT_EQ(original.half_pixel_centers, restored->half_pixel_centers);
}
461 
TEST_F(OperatorTest, Svdf) {
  // SVDF must round-trip both its fused activation and its rank.
  SvdfOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu;
  original.rank = 1;
  const BaseOperator& serializer = GetOperator("SVDF", OperatorType::kSvdf);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            restored->fused_activation_function);
  EXPECT_EQ(original.rank, restored->rank);
}
472 
TEST_F(OperatorTest, Squeeze) {
  // SQUEEZE must round-trip its dimension list, including negative and
  // repeated entries.
  SqueezeOperator original;
  original.squeeze_dims = {-2, -3, 4, 1, 4};
  const BaseOperator& serializer =
      GetOperator("SQUEEZE", OperatorType::kSqueeze);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.squeeze_dims, restored->squeeze_dims);
}
481 
TEST_F(OperatorTest, StridedSlice) {
  // STRIDED_SLICE must round-trip all of its mask fields.
  StridedSliceOperator op;

  op.begin_mask = 1;
  op.end_mask = 2;
  op.ellipsis_mask = 1;
  op.new_axis_mask = 1;
  op.shrink_axis_mask = 2;

  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  // The index vectors are not set above; comparing them still catches any
  // values introduced by the round trip.
  EXPECT_EQ(op.start_indices, output_toco_op->start_indices);
  EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices);
  EXPECT_EQ(op.strides, output_toco_op->strides);
  EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask);
  EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
  EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask);
  EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask);
  EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask);
}
503 
TEST_F(OperatorTest, BuiltinTopKV2) {
  // TOPK_V2 only needs to survive the round trip as a non-null operator.
  TopKV2Operator original;
  const BaseOperator& serializer =
      GetOperator("TOPK_V2", OperatorType::kTopK_V2);
  const auto restored = SerializeAndDeserialize(serializer, original);
  ASSERT_NE(nullptr, restored.get());
}
510 
TEST_F(OperatorTest, BuiltinArgMax) {
  // ARG_MAX must round-trip its output data type. Set it explicitly so the
  // test does not merely compare a default-constructed value with itself.
  ArgMaxOperator op;
  op.output_data_type = ArrayDataType::kInt32;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("ARG_MAX", OperatorType::kArgMax), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
517 
TEST_F(OperatorTest, BuiltinArgMin) {
  // ARG_MIN must round-trip its output data type. Set it explicitly so the
  // test does not merely compare a default-constructed value with itself.
  ArgMinOperator op;
  op.output_data_type = ArrayDataType::kInt32;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("ARG_MIN", OperatorType::kArgMin), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
524 
TEST_F(OperatorTest, BuiltinDequantize) {
  // DEQUANTIZE carries no options; the round trip must still yield an
  // operator. Previously the result was never checked.
  DequantizeOperator op;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("DEQUANTIZE", OperatorType::kDequantize), op);
  ASSERT_NE(nullptr, output_toco_op.get());
}
530 
TEST_F(OperatorTest, BuiltinTransposeConv) {
  // TRANSPOSE_CONV must round-trip its strides and padding.
  TransposeConvOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  const BaseOperator& serializer =
      GetOperator("TRANSPOSE_CONV", OperatorType::kTransposeConv);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.stride_width, restored->stride_width);
  EXPECT_EQ(original.stride_height, restored->stride_height);
  EXPECT_EQ(original.padding.type, restored->padding.type);
}
542 
TEST_F(OperatorTest, BuiltinShape) {
  // SHAPE must round-trip its output data type.
  TensorFlowShapeOperator original;
  original.output_data_type = ArrayDataType::kInt64;
  const BaseOperator& serializer = GetOperator("SHAPE", OperatorType::kShape);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.output_data_type, restored->output_data_type);
}
550 
TEST_F(OperatorTest, BuiltinSparseToDense) {
  // SPARSE_TO_DENSE must round-trip its validate_indices flag.
  SparseToDenseOperator original;
  original.validate_indices = false;
  const BaseOperator& serializer =
      GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.validate_indices, restored->validate_indices);
}
559 
TEST_F(OperatorTest, VersioningSpareToDense) {
  // Checks the SPARSE_TO_DENSE version reported for various data types of
  // its third input ("input_values").
  // NOTE(review): "Spare" in the test name looks like a typo for "Sparse";
  // left unchanged so existing test filters keep matching.
  SparseToDenseOperator op;
  op.inputs = {"indices", "output_shape", "input_values", "default_value"};
  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = operator_by_type_map.at(op.type).get();

  // Builds a fresh model whose only array is `input_values`, typed as
  // `values_type`, and returns the version `base_op` reports for it.
  const auto version_for = [&op, base_op](ArrayDataType values_type) {
    Model model;
    model.GetOrCreateArray(op.inputs[2]).data_type = values_type;
    OperatorSignature signature = {.op = &op, .model = &model};
    return base_op->GetVersion(signature);
  };

  EXPECT_EQ(version_for(ArrayDataType::kInt32), 1);

  // Expect version 2 for int64 input.
  EXPECT_EQ(version_for(ArrayDataType::kInt64), 2);

  // Expect version 3 for int8 and uint8 input.
  EXPECT_EQ(version_for(ArrayDataType::kInt8), 3);
  EXPECT_EQ(version_for(ArrayDataType::kUint8), 3);
}
592 
TEST_F(OperatorTest, BuiltinPack) {
  // PACK must round-trip its value count and axis.
  PackOperator original;
  original.values_count = 3;
  original.axis = 1;
  const BaseOperator& serializer = GetOperator("PACK", OperatorType::kPack);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.values_count, restored->values_count);
  EXPECT_EQ(original.axis, restored->axis);
}
602 
TEST_F(OperatorTest, BuiltinOneHot) {
  // ONE_HOT must round-trip its axis.
  OneHotOperator original;
  original.axis = 2;
  const BaseOperator& serializer =
      GetOperator("ONE_HOT", OperatorType::kOneHot);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.axis, restored->axis);
}
610 
TEST_F(OperatorTest, BuiltinUnpack) {
  // UNPACK must round-trip its output count and axis.
  UnpackOperator original;
  original.num = 5;
  original.axis = 2;
  const BaseOperator& serializer = GetOperator("UNPACK", OperatorType::kUnpack);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.num, restored->num);
  EXPECT_EQ(original.axis, restored->axis);
}
620 
TEST_F(OperatorTest, BuiltinLeakyRelu) {
  // LEAKY_RELU must round-trip its alpha.
  LeakyReluOperator original;
  original.alpha = 3;
  const BaseOperator& serializer =
      GetOperator("LEAKY_RELU", OperatorType::kLeakyRelu);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.alpha, restored->alpha);
}
628 
TEST_F(OperatorTest, BuiltinSquaredDifference) {
  // SQUARED_DIFFERENCE only needs to survive the round trip as a non-null
  // operator.
  SquaredDifferenceOperator original;
  const BaseOperator& serializer =
      GetOperator("SQUARED_DIFFERENCE", OperatorType::kSquaredDifference);
  const auto restored = SerializeAndDeserialize(serializer, original);
  ASSERT_NE(nullptr, restored.get());
}
635 
TEST_F(OperatorTest, BuiltinScatterNd) {
  // SCATTER_ND only needs to survive the round trip as a non-null operator.
  ScatterNdOperator original;
  const BaseOperator& serializer =
      GetOperator("SCATTER_ND", OperatorType::kScatterNd);
  const auto restored = SerializeAndDeserialize(serializer, original);
  ASSERT_NE(nullptr, restored.get());
}
642 
TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) {
  // CTC_BEAM_SEARCH_DECODER must round-trip beam width, path count and the
  // merge_repeated flag.
  CTCBeamSearchDecoderOperator original;
  original.beam_width = 3;
  original.top_paths = 2;
  original.merge_repeated = false;
  const BaseOperator& serializer = GetOperator(
      "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder);
  const auto restored = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.beam_width, restored->beam_width);
  EXPECT_EQ(original.top_paths, restored->top_paths);
  EXPECT_EQ(original.merge_repeated, restored->merge_repeated);
}
656 
TEST_F(OperatorTest,TensorFlowUnsupported)657 TEST_F(OperatorTest, TensorFlowUnsupported) {
658   TensorFlowUnsupportedOperator op;
659   op.tensorflow_op = "MyCustomUnsupportedOp";
660 
661   ::tensorflow::NodeDef node_def;
662   auto attr = node_def.mutable_attr();
663   (*attr)["float_attr"].set_f(2.0);
664   (*attr)["str_attr"].set_s("Hello World");
665   (*attr)["int_attr"].set_i(17);
666   (*attr)["bool_attr"].set_b(true);
667   {
668     auto* list = (*attr)["list_string_attr"].mutable_list();
669     list->add_s("abcde");
670     list->add_s("1234");
671     list->add_s("");
672     list->add_s("zyxwv");
673     list->add_s("!-.");
674   }
675   {
676     auto* list = (*attr)["list_float_attr"].mutable_list();
677     list->add_f(std::numeric_limits<float>::min());
678     list->add_f(2.0);
679     list->add_f(-std::numeric_limits<float>::max());
680   }
681   {
682     auto* list = (*attr)["list_int_attr"].mutable_list();
683     list->add_i(1);
684     list->add_i(20);
685     list->add_i(1LL << 40);
686     list->add_i(-(1LL << 40));
687   }
688   node_def.SerializeToString(&op.tensorflow_node_def);
689 
690   auto output_toco_op = SerializeAndDeserialize(
691       GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
692 
693   ::tensorflow::NodeDef output_node_def;
694   output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
695   const auto& output_attr = output_node_def.attr();
696   EXPECT_EQ(2.0, output_attr.at("float_attr").f());
697   EXPECT_EQ("Hello World", output_attr.at("str_attr").s());
698   EXPECT_EQ(17, output_attr.at("int_attr").i());
699   EXPECT_EQ(true, output_attr.at("bool_attr").b());
700   {
701     const auto& list = output_attr.at("list_string_attr").list();
702     ASSERT_EQ(5, list.s_size());
703     EXPECT_EQ("abcde", list.s(0));
704     EXPECT_EQ("1234", list.s(1));
705     EXPECT_EQ("", list.s(2));
706     EXPECT_EQ("zyxwv", list.s(3));
707     EXPECT_EQ("!-.", list.s(4));
708   }
709   {
710     const auto& list = output_attr.at("list_float_attr").list();
711     ASSERT_EQ(3, list.f_size());
712     EXPECT_EQ(std::numeric_limits<float>::min(), list.f(0));
713     EXPECT_EQ(2.0, list.f(1));
714     EXPECT_EQ(-std::numeric_limits<float>::max(), list.f(2));
715   }
716   {
717     const auto& list = output_attr.at("list_int_attr").list();
718     ASSERT_EQ(4, list.i_size());
719     EXPECT_EQ(1, list.i(0));
720     EXPECT_EQ(20, list.i(1));
721     EXPECT_EQ(1LL << 40, list.i(2));
722     EXPECT_EQ(-(1LL << 40), list.i(3));
723   }
724 }
725 
// Round-trips a TensorFlowUnsupportedOperator that has no serialized NodeDef
// and checks that the deserialized NodeDef carries no attributes.
TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
  TensorFlowUnsupportedOperator op;
  op.tensorflow_op = "MyCustomUnsupportedOp";
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
  ASSERT_NE(nullptr, output_toco_op.get());

  // Check the parse result. Otherwise an unparsable payload could still yield
  // an (incidentally) empty attribute map and let the test pass.
  ::tensorflow::NodeDef output_node_def;
  ASSERT_TRUE(
      output_node_def.ParseFromString(output_toco_op->tensorflow_node_def));
  EXPECT_TRUE(output_node_def.attr().empty());
}
736 
// Table of (flex ops enabled?, TensorFlow op name, expected answer) checked
// against ShouldExportAsFlexOp.
TEST_F(OperatorTest, TestShouldExportAsFlexOp) {
  struct FlexCase {
    bool enable_flex_ops;
    const char* op_name;
    bool expected;
  };
  const FlexCase kCases[] = {
      {false, "Conv2D", false},
      {true, "Conv2D", true},
      {true, "EluGrad", true},
      {true, "RFFT", true},
      {true, "MyAwesomeCustomOp", false},
      {true, "RandomShuffle", true},
  };
  for (const FlexCase& c : kCases) {
    EXPECT_EQ(c.expected, ShouldExportAsFlexOp(c.enable_flex_ops, c.op_name))
        << c.op_name;
  }
}
745 
// Round-trips a MirrorPad op and checks that its padding mode is preserved.
TEST_F(OperatorTest, BuiltinMirrorPad) {
  MirrorPadOperator op;
  op.mode = MirrorPadMode::kReflect;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("MIRROR_PAD", OperatorType::kMirrorPad), op);
  // Fail cleanly rather than crash on a null dereference below.
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.mode, output_toco_op->mode);
}
753 
// Round-trips a Unique op and checks that idx_out_type is preserved.
TEST_F(OperatorTest, BuiltinUnique) {
  UniqueOperator unique_op;
  unique_op.idx_out_type = ArrayDataType::kInt64;

  const auto& serializer = GetOperator("UNIQUE", OperatorType::kUnique);
  auto round_tripped = SerializeAndDeserialize(serializer, unique_op);

  ASSERT_NE(nullptr, round_tripped.get());
  EXPECT_EQ(unique_op.idx_out_type, round_tripped->idx_out_type);
}
762 
// SegmentSum has no options; only check that a round trip yields an op.
TEST_F(OperatorTest, BuiltinSegmentSum) {
  SegmentSumOperator segment_sum_op;
  const auto& serializer =
      GetOperator("SEGMENT_SUM", OperatorType::kSegmentSum);
  auto round_tripped = SerializeAndDeserialize(serializer, segment_sum_op);
  ASSERT_NE(nullptr, round_tripped.get());
}
769 
// Round-trips a ReverseSequence op and checks that seq_dim and batch_dim are
// preserved.
TEST_F(OperatorTest, BuiltinReverseSequence) {
  ReverseSequenceOperator op;
  op.seq_dim = 3;
  op.batch_dim = 1;
  std::unique_ptr<toco::ReverseSequenceOperator> output_toco_op =
      SerializeAndDeserialize(
          GetOperator("REVERSE_SEQUENCE", OperatorType::kReverseSequence), op);
  // Guard the dereferences below.
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.seq_dim, output_toco_op->seq_dim);
  EXPECT_EQ(op.batch_dim, output_toco_op->batch_dim);
}
780 
// MatrixDiag has no options; check that a round trip yields an op. The
// original version discarded the result and asserted nothing.
TEST_F(OperatorTest, BuiltinMatrixDiag) {
  MatrixDiagOperator op;
  std::unique_ptr<toco::MatrixDiagOperator> output_toco_op =
      SerializeAndDeserialize(
          GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op);
  ASSERT_NE(nullptr, output_toco_op.get());
}
787 
// MatrixSetDiag has no options; check that a round trip yields an op. The
// original version discarded the result and asserted nothing.
TEST_F(OperatorTest, BuiltinMatrixSetDiag) {
  MatrixSetDiagOperator op;
  std::unique_ptr<toco::MatrixSetDiagOperator> output_toco_op =
      SerializeAndDeserialize(
          GetOperator("MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag), op);
  ASSERT_NE(nullptr, output_toco_op.get());
}
794 
795 // Test version for a simple Op with 2 versions and the input type controls the
796 // version.
797 template <typename Op>
SimpleVersioningTest()798 void SimpleVersioningTest() {
799   Op op;
800   op.inputs = {"input1"};
801   auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
802   const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
803 
804   Model uint8_model;
805   Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
806   uint8_array.data_type = ArrayDataType::kUint8;
807   OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
808   EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
809 
810   Model int8_model;
811   Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
812   int8_array.data_type = ArrayDataType::kInt8;
813   OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
814   EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
815 }
816 
817 // Test version for a simple Op with 2 versions and the output type controls the
818 // version.
819 template <typename Op>
SimpleOutputVersioningTest()820 void SimpleOutputVersioningTest() {
821   Op op;
822   op.outputs = {"output1"};
823   auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
824   const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
825 
826   Model uint8_model;
827   Array& uint8_array = uint8_model.GetOrCreateArray(op.outputs[0]);
828   uint8_array.data_type = ArrayDataType::kUint8;
829   OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
830   EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
831 
832   Model int8_model;
833   Array& int8_array = int8_model.GetOrCreateArray(op.outputs[0]);
834   int8_array.data_type = ArrayDataType::kInt8;
835   OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
836   EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
837 }
838 
// Equal: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningEqualTest) {
  using EqualOp = TensorFlowEqualOperator;
  SimpleVersioningTest<EqualOp>();
}
842 
// NotEqual: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningNotEqualTest) {
  using NotEqualOp = TensorFlowNotEqualOperator;
  SimpleVersioningTest<NotEqualOp>();
}
846 
// Less: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningLessTest) {
  using LessOp = TensorFlowLessOperator;
  SimpleVersioningTest<LessOp>();
}
850 
// LessEqual: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningLessEqualTest) {
  using LessEqualOp = TensorFlowLessEqualOperator;
  SimpleVersioningTest<LessEqualOp>();
}
854 
// Greater: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningGreaterTest) {
  using GreaterOp = TensorFlowGreaterOperator;
  SimpleVersioningTest<GreaterOp>();
}
858 
// GreaterEqual: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningGreaterEqualTest) {
  using GreaterEqualOp = TensorFlowGreaterEqualOperator;
  SimpleVersioningTest<GreaterEqualOp>();
}
862 
// SpaceToBatchND versions depend on the input's data type and shape.
TEST_F(OperatorTest, VersioningSpaceToBatchNDTest) {
  SpaceToBatchNDOperator op;
  op.inputs = {"input1"};
  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = operator_by_type_map.at(op.type).get();

  // Returns the version for a fresh model whose only array is the input, with
  // the given shape and data type.
  auto version_for = [&](const Shape& shape, ArrayDataType type) {
    Model model;
    Array& input = model.GetOrCreateArray(op.inputs[0]);
    input.copy_shape(shape);
    input.data_type = type;
    OperatorSignature signature = {.op = &op, .model = &model};
    return base_op->GetVersion(signature);
  };

  // 4-D inputs: uint8 -> 1, int8 -> 2.
  EXPECT_EQ(version_for({1, 2, 2, 2}, ArrayDataType::kUint8), 1);
  EXPECT_EQ(version_for({1, 2, 2, 2}, ArrayDataType::kInt8), 2);
  // 3-D float input -> 3.
  EXPECT_EQ(version_for({1, 2, 2}, ArrayDataType::kFloat), 3);
}
890 
// LogSoftmax: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningLogSoftmaxTest) {
  using LogSoftmaxOp = LogSoftmaxOperator;
  SimpleVersioningTest<LogSoftmaxOp>();
}
894 
// Pack: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningPackTest) {
  using PackOp = PackOperator;
  SimpleVersioningTest<PackOp>();
}
898 
// Unpack: int32 input -> version 1; uint8 and int8 inputs -> version 2.
TEST_F(OperatorTest, VersioningUnpackTest) {
  UnpackOperator op;
  op.inputs = {"input1"};
  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = operator_by_type_map.at(op.type).get();

  struct UnpackCase {
    ArrayDataType input_type;
    int expected_version;
  };
  const UnpackCase kCases[] = {
      {ArrayDataType::kInt32, 1},
      {ArrayDataType::kUint8, 2},
      {ArrayDataType::kInt8, 2},
  };
  for (const UnpackCase& c : kCases) {
    Model model;
    model.GetOrCreateArray(op.inputs[0]).data_type = c.input_type;
    OperatorSignature signature = {.op = &op, .model = &model};
    EXPECT_EQ(base_op->GetVersion(signature), c.expected_version);
  }
}
923 
// BatchToSpaceND versions depend on the input's data type and shape.
TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
  BatchToSpaceNDOperator op;
  op.inputs = {"input1"};
  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = operator_by_type_map.at(op.type).get();

  struct BatchToSpaceCase {
    Shape shape;
    ArrayDataType type;
    int expected_version;
  };
  const BatchToSpaceCase kCases[] = {
      {{1, 2, 2, 2}, ArrayDataType::kUint8, 1},
      {{1, 2, 2, 2}, ArrayDataType::kInt8, 2},
      {{1, 2, 2}, ArrayDataType::kFloat, 3},
  };
  for (const BatchToSpaceCase& c : kCases) {
    // Each case uses its own model holding only the input array.
    Model model;
    Array& input = model.GetOrCreateArray(op.inputs[0]);
    input.data_type = c.type;
    input.copy_shape(c.shape);
    OperatorSignature signature = {.op = &op, .model = &model};
    EXPECT_EQ(base_op->GetVersion(signature), c.expected_version);
  }
}
951 
// Tanh: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningTanhTest) {
  using TanhOp = TanhOperator;
  SimpleVersioningTest<TanhOp>();
}
955 
// StridedSlice versions depend on the input data type and, once the op has
// 5-element start/stop/stride vectors, on the slice rank.
TEST_F(OperatorTest, VersioningStridedSliceTest) {
  StridedSliceOperator op;
  op.inputs = {"input1"};
  op.ellipsis_mask = 0;
  op.new_axis_mask = 0;
  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = operator_by_type_map.at(op.type).get();

  // One model per input type. The models, and the signatures that point at
  // them, stay alive so they can be re-queried after `op` is mutated below.
  Model uint8_model;
  Model int8_model;
  Model bool_model;
  uint8_model.GetOrCreateArray(op.inputs[0]).data_type = ArrayDataType::kUint8;
  int8_model.GetOrCreateArray(op.inputs[0]).data_type = ArrayDataType::kInt8;
  bool_model.GetOrCreateArray(op.inputs[0]).data_type = ArrayDataType::kBool;
  OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
  OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
  OperatorSignature bool_signature = {.op = &op, .model = &bool_model};

  EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
  EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
  EXPECT_EQ(base_op->GetVersion(bool_signature), 3);

  // With 5-element start/stop/stride vectors, every input type maps to
  // version 4.
  op.start_indices = {0, 0, 0, 0, 0};
  op.stop_indices = {1, 2, 2, 2, 2};
  op.strides = {1, 1, 1, 1, 1};
  for (OperatorSignature* signature :
       {&uint8_signature, &int8_signature, &bool_signature}) {
    EXPECT_EQ(base_op->GetVersion(*signature), 4);
  }
}
989 
// SpaceToDepth: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningSpaceToDepthTest) {
  using SpaceToDepthOp = SpaceToDepthOperator;
  SimpleVersioningTest<SpaceToDepthOp>();
}
993 
// Slice: uint8 -> version 1 and int8 -> version 2 (shared simple check), plus
// a string input, which must yield version 3.
TEST_F(OperatorTest, VersioningSliceTest) {
  SimpleVersioningTest<SliceOperator>();

  SliceOperator slice_op;
  slice_op.inputs = {"input1"};
  const auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* const slice = ops_by_type.at(slice_op.type).get();

  Model model;
  model.GetOrCreateArray(slice_op.inputs[0]).data_type = ArrayDataType::kString;
  OperatorSignature signature = {.op = &slice_op, .model = &model};
  EXPECT_EQ(slice->GetVersion(signature), 3);
}
1009 
// Logistic: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningLogisticTest) {
  using LogisticOp = LogisticOperator;
  SimpleVersioningTest<LogisticOp>();
}
1013 
// L2Normalization is versioned by its output type: uint8 -> 1, int8 -> 2.
TEST_F(OperatorTest, VersioningL2NormTest) {
  using L2NormOp = L2NormalizationOperator;
  SimpleOutputVersioningTest<L2NormOp>();
}
1017 
// Maximum: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningMaxTest) {
  using MaximumOp = TensorFlowMaximumOperator;
  SimpleVersioningTest<MaximumOp>();
}
1021 
// Minimum: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningMinTest) {
  using MinimumOp = TensorFlowMinimumOperator;
  SimpleVersioningTest<MinimumOp>();
}
1025 
// Mean: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningMeanTest) {
  using MeanOp = MeanOperator;
  SimpleVersioningTest<MeanOp>();
}
1029 
// Sum: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningSumTest) {
  using SumOp = TensorFlowSumOperator;
  SimpleVersioningTest<SumOp>();
}
1033 
// Add: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningAddTest) {
  using AddOp = AddOperator;
  SimpleVersioningTest<AddOp>();
}
1035 
SimpleMulVersioningTest(ArrayDataType data_type,float multiplier,int version)1036 void SimpleMulVersioningTest(ArrayDataType data_type, float multiplier,
1037                              int version) {
1038   MulOperator op;
1039   op.inputs = {"input1", "input2"};
1040   op.outputs = {"output"};
1041   auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
1042   const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
1043 
1044   Model model;
1045   Array& input0 = model.GetOrCreateArray(op.inputs[0]);
1046   Array& input1 = model.GetOrCreateArray(op.inputs[1]);
1047   Array& output = model.GetOrCreateArray(op.outputs[0]);
1048 
1049   input0.data_type = data_type;
1050   input0.GetOrCreateQuantizationParams().scale = 1.0f;
1051   input1.data_type = data_type;
1052   input1.GetOrCreateQuantizationParams().scale = 1.0f;
1053   output.data_type = data_type;
1054   output.GetOrCreateQuantizationParams().scale = 1.0f / multiplier;
1055 
1056   OperatorSignature signature = {.op = &op, .model = &model};
1057   EXPECT_EQ(base_op->GetVersion(signature), version);
1058 }
1059 
// Mul versions by data type and output multiplier (output scale is
// 1 / multiplier with unit input scales).
TEST_F(OperatorTest, VersioningMulTest) {
  struct MulCase {
    ArrayDataType data_type;
    float multiplier;
    int expected_version;
  };
  const MulCase kCases[] = {
      {ArrayDataType::kUint8, 0.5f, 1},
      {ArrayDataType::kInt8, 0.5f, 2},
      {ArrayDataType::kInt8, 2.0f, 3},
  };
  for (const MulCase& mul_case : kCases) {
    SimpleMulVersioningTest(mul_case.data_type, mul_case.multiplier,
                            mul_case.expected_version);
  }
}
1065 
1066 template <typename OpType>
SimpleTwoInputsVersioningTest(ArrayDataType data_type,Shape shape1,Shape shape2,int version)1067 void SimpleTwoInputsVersioningTest(ArrayDataType data_type, Shape shape1,
1068                                    Shape shape2, int version) {
1069   OpType op;
1070   op.inputs = {"input1", "input2"};
1071   op.outputs = {"output"};
1072   auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
1073   const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
1074 
1075   Model model;
1076   Array& input0 = model.GetOrCreateArray(op.inputs[0]);
1077   Array& input1 = model.GetOrCreateArray(op.inputs[1]);
1078   Array& output = model.GetOrCreateArray(op.outputs[0]);
1079 
1080   input0.data_type = data_type;
1081   input0.copy_shape(shape1);
1082   input1.data_type = data_type;
1083   input1.copy_shape(shape2);
1084   output.data_type = data_type;
1085 
1086   OperatorSignature signature = {.op = &op, .model = &model};
1087   EXPECT_EQ(base_op->GetVersion(signature), version);
1088 }
1089 
1090 template <typename OpType>
SimpleThreeInputsVersioningTest(ArrayDataType data_type,Shape shape1,Shape shape2,Shape shape3,int version)1091 void SimpleThreeInputsVersioningTest(ArrayDataType data_type, Shape shape1,
1092                                      Shape shape2, Shape shape3, int version) {
1093   OpType op;
1094   op.inputs = {"input1", "input2", "input3"};
1095   op.outputs = {"output"};
1096   auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
1097   const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
1098 
1099   Model model;
1100   Array& input0 = model.GetOrCreateArray(op.inputs[0]);
1101   Array& input1 = model.GetOrCreateArray(op.inputs[1]);
1102   Array& input2 = model.GetOrCreateArray(op.inputs[2]);
1103   Array& output = model.GetOrCreateArray(op.outputs[0]);
1104 
1105   input0.data_type = data_type;
1106   input0.copy_shape(shape1);
1107   input1.data_type = data_type;
1108   input1.copy_shape(shape2);
1109   input2.data_type = data_type;
1110   input2.copy_shape(shape3);
1111   output.data_type = data_type;
1112 
1113   OperatorSignature signature = {.op = &op, .model = &model};
1114   EXPECT_EQ(base_op->GetVersion(signature), version);
1115 }
1116 
// Sub versions by data type and input shapes.
TEST_F(OperatorTest, VersioningSubTest) {
  struct SubCase {
    ArrayDataType type;
    Shape lhs_shape;
    Shape rhs_shape;
    int expected_version;
  };
  const SubCase kCases[] = {
      // Identical 4-D shapes.
      {ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 2}, 1},
      {ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 2}, 2},
      // Identical 3-D shapes.
      {ArrayDataType::kUint8, {1, 2, 2}, {1, 2, 2}, 1},
      {ArrayDataType::kInt8, {1, 2, 2}, {1, 2, 2}, 2},
      // 4-D shapes differing in the last dimension.
      {ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 1}, 1},
      {ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 1}, 2},
      // 5-D shapes differing in the last dimension: version 3 for both types.
      {ArrayDataType::kUint8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 3},
      {ArrayDataType::kInt8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 3},
  };
  for (const SubCase& c : kCases) {
    SimpleTwoInputsVersioningTest<SubOperator>(c.type, c.lhs_shape,
                                               c.rhs_shape, c.expected_version);
  }
}
1135 
// Div versions by data type and input shapes.
TEST_F(OperatorTest, VersioningDivTest) {
  struct DivCase {
    ArrayDataType type;
    Shape lhs_shape;
    Shape rhs_shape;
    int expected_version;
  };
  const DivCase kCases[] = {
      {ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 2}, 1},
      {ArrayDataType::kInt8, {1, 2, 2}, {1, 2, 2}, 1},
      {ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 1}, 1},
      // 5-D shapes differing in the last dimension.
      {ArrayDataType::kInt8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 2},
  };
  for (const DivCase& c : kCases) {
    SimpleTwoInputsVersioningTest<DivOperator>(c.type, c.lhs_shape,
                                               c.rhs_shape, c.expected_version);
  }
}
1146 
// Pad: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningPadTest) {
  using PadOp = PadOperator;
  SimpleVersioningTest<PadOp>();
}
1148 
// PadV2: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningPadV2Test) {
  using PadV2Op = PadV2Operator;
  SimpleVersioningTest<PadV2Op>();
}
1152 
// Concatenation: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningConcatenationTest) {
  using ConcatenationOp = ConcatenationOperator;
  SimpleVersioningTest<ConcatenationOp>();
}
1156 
// Select versions by data type and input shapes.
TEST_F(OperatorTest, VersioningSelectTest) {
  struct SelectCase {
    ArrayDataType type;
    Shape first_shape;
    Shape second_shape;
    Shape third_shape;
    int expected_version;
  };
  const SelectCase kCases[] = {
      {ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 1}, {1, 2, 2, 1}, 1},
      {ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 1}, {1, 2, 2, 1}, 2},
      // 5-D inputs.
      {ArrayDataType::kInt8,
       {1, 2, 2, 2, 1},
       {1, 2, 2, 1, 1},
       {1, 2, 2, 1, 1},
       3},
  };
  for (const SelectCase& c : kCases) {
    SimpleThreeInputsVersioningTest<SelectOperator>(
        c.type, c.first_shape, c.second_shape, c.third_shape,
        c.expected_version);
  }
}
1166 
// Relu6: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningRelu6Test) {
  using Relu6Op = Relu6Operator;
  SimpleVersioningTest<Relu6Op>();
}
1170 
// FullyConnected with input, weight and output all uint8, or all int8, is
// expected to report version 6.
TEST_F(OperatorTest, VersioningFullyConnectedTest) {
  FullyConnectedOperator fully_connected_op;
  fully_connected_op.inputs = {"input", "weight"};
  fully_connected_op.outputs = {"output"};
  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* op =
      operator_by_type_map.at(fully_connected_op.type).get();

  // Gives input, weight and output the same data type in a fresh model and
  // returns the resulting version.
  auto version_with_all = [&](ArrayDataType type) {
    Model model;
    for (const std::string& name :
         {fully_connected_op.inputs[0], fully_connected_op.inputs[1],
          fully_connected_op.outputs[0]}) {
      model.GetOrCreateArray(name).data_type = type;
    }
    OperatorSignature signature = {.op = &fully_connected_op, .model = &model};
    return op->GetVersion(signature);
  };

  EXPECT_EQ(version_with_all(ArrayDataType::kUint8), 6);
  EXPECT_EQ(version_with_all(ArrayDataType::kInt8), 6);
}
1207 
// Dequantize versions by input data type: int16/float16 -> 3, int8 -> 2,
// float -> 1.
TEST_F(OperatorTest, VersioningDequantizeTest) {
  DequantizeOperator dequant_op;
  dequant_op.inputs = {"input"};
  dequant_op.outputs = {"output"};
  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* op = operator_by_type_map.at(dequant_op.type).get();

  struct DequantizeCase {
    ArrayDataType input_type;
    int expected_version;
  };
  const DequantizeCase kCases[] = {
      {ArrayDataType::kInt16, 3},
      {ArrayDataType::kFloat16, 3},
      {ArrayDataType::kInt8, 2},
      {ArrayDataType::kFloat, 1},
  };
  for (const DequantizeCase& c : kCases) {
    // Only the input array is created in each model.
    Model model;
    model.GetOrCreateArray(dequant_op.inputs[0]).data_type = c.input_type;
    OperatorSignature signature = {.op = &dequant_op, .model = &model};
    EXPECT_EQ(op->GetVersion(signature), c.expected_version);
  }
}
1243 
// Conv2D versions by the data types of input, filter and output.
TEST_F(OperatorTest, VersioningConv2DTest) {
  ConvOperator conv_op;
  conv_op.inputs = {"input", "filter"};
  conv_op.outputs = {"output"};
  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* op = operator_by_type_map.at(conv_op.type).get();

  // Builds a fresh model with the given input/filter/output types and returns
  // the version Conv2D reports for it.
  auto version_for = [&](ArrayDataType input_type, ArrayDataType filter_type,
                         ArrayDataType output_type) {
    Model model;
    model.GetOrCreateArray(conv_op.inputs[0]).data_type = input_type;
    model.GetOrCreateArray(conv_op.inputs[1]).data_type = filter_type;
    model.GetOrCreateArray(conv_op.outputs[0]).data_type = output_type;
    OperatorSignature signature = {.op = &conv_op, .model = &model};
    return op->GetVersion(signature);
  };

  // All uint8.
  EXPECT_EQ(version_for(ArrayDataType::kUint8, ArrayDataType::kUint8,
                        ArrayDataType::kUint8),
            1);
  // All int8.
  EXPECT_EQ(version_for(ArrayDataType::kInt8, ArrayDataType::kInt8,
                        ArrayDataType::kInt8),
            3);
  // Float input and output with an int8 filter.
  EXPECT_EQ(version_for(ArrayDataType::kFloat, ArrayDataType::kInt8,
                        ArrayDataType::kFloat),
            2);
}
1281 
// FloorDiv versions by input data type: int32 -> 1, float -> 2.
TEST_F(OperatorTest, VersioningFloorDivOperatorTest) {
  FloorDivOperator floordiv_op;
  floordiv_op.inputs = {"input1"};
  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* op = operator_by_type_map.at(floordiv_op.type).get();

  // Returns the version for a fresh model whose only array is the first
  // input, typed as `input_type`.
  auto version_for = [&](ArrayDataType input_type) {
    Model model;
    model.GetOrCreateArray(floordiv_op.inputs[0]).data_type = input_type;
    OperatorSignature signature = {.op = &floordiv_op, .model = &model};
    return op->GetVersion(signature);
  };

  EXPECT_EQ(version_for(ArrayDataType::kInt32), 1);
  EXPECT_EQ(version_for(ArrayDataType::kFloat), 2);
}
1304 
1305 }  // namespace
1306 }  // namespace tflite
1307 
1308 }  // namespace toco
1309