1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/operator.h"
16
17 #include <string>
18
19 #include "flatbuffers/flexbuffers.h"
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/lite/toco/model.h"
25 #include "tensorflow/lite/toco/tooling_util.h"
26
27 namespace toco {
28
29 namespace tflite {
30 namespace {
31
32 class OperatorTest : public ::testing::Test {
33 protected:
34 // Return the operator for the given name and type.
GetOperator(const std::string & name,OperatorType type)35 const BaseOperator& GetOperator(const std::string& name, OperatorType type) {
36 using OpsByName = std::map<std::string, std::unique_ptr<BaseOperator>>;
37 using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>;
38
39 static auto* by_name = new OpsByName(BuildOperatorByNameMap());
40 static auto* by_type = new OpsByType(BuildOperatorByTypeMap());
41
42 // Make sure the two maps were consistently built.
43 CHECK(by_name->count(name)) << "No operator for '" << name << "'.";
44 BaseOperator* op1 = by_name->at(name).get();
45 CHECK(op1->type() == type) << "while verifying '" << name << "'.";
46
47 CHECK(by_type->count(type))
48 << "No operator for '" << OperatorTypeName(type) << "'.";
49 BaseOperator* op2 = by_type->at(type).get();
50 CHECK(op2->name() == name)
51 << "while verifying '" << OperatorTypeName(type) << "'.";
52
53 return *op1;
54 }
55
56 // Use the given BaseOperator to serialize the tf.mini operator into a set of
57 // TF Lite options. Proceed to deserialize the options back into a new
58 // tf.mini operator, which is then returned. If `options` is given, it will
59 // be populated with the serialized options.
60 template <typename T>
SerializeAndDeserialize(const BaseOperator & op,const T & toco_op,Options * options=nullptr)61 std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op,
62 const T& toco_op,
63 Options* options = nullptr) {
64 flatbuffers::FlatBufferBuilder builder;
65 Options input_options = op.Serialize(toco_op, &builder);
66
67 if (options) {
68 *options = input_options;
69 }
70
71 builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type,
72 input_options.builtin, input_options.custom,
73 ::tflite::CustomOptionsFormat_FLEXBUFFERS));
74 auto* output_options =
75 flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer());
76 auto new_toco_op = op.Deserialize(output_options->builtin_options(),
77 output_options->custom_options());
78
79 CHECK(new_toco_op->type == toco_op.type)
80 << "The type of the serialized and deserialized"
81 << HelpfulOperatorTypeName(*new_toco_op)
82 << " does not match the type of the original "
83 << HelpfulOperatorTypeName(toco_op);
84
85 return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release()));
86 }
87
88 // Verify serialization and deserialization of simple operators (those
89 // that don't have any configuration parameters).
90 template <typename T>
CheckSimpleOperator(const std::string & name,OperatorType type)91 void CheckSimpleOperator(const std::string& name, OperatorType type) {
92 Options options;
93 auto output_toco_op =
94 SerializeAndDeserialize(GetOperator(name, type), T(), &options);
95
96 ASSERT_EQ(0, options.builtin.o);
97 ASSERT_EQ(0, options.custom.o);
98 ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type);
99
100 ASSERT_NE(nullptr, output_toco_op.get());
101 }
102
103 template <typename T>
CheckReducerOperator(const std::string & name,OperatorType type)104 void CheckReducerOperator(const std::string& name, OperatorType type) {
105 T op;
106
107 op.keep_dims = false;
108
109 auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op);
110 EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
111 }
112 };
113
// Every operator listed here has no configuration parameters; each one must
// round-trip with empty builtin/custom options.
TEST_F(OperatorTest, SimpleOperators) {
  CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
  CheckSimpleOperator<CeilOperator>("CEIL", OperatorType::kCeil);
  CheckSimpleOperator<EluOperator>("ELU", OperatorType::kElu);
  CheckSimpleOperator<RoundOperator>("ROUND", OperatorType::kRound);
  CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
  CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
  CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
  CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
  CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
  CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp);
  CheckSimpleOperator<CosOperator>("COS", OperatorType::kCos);
  CheckSimpleOperator<LogSoftmaxOperator>(
      "LOG_SOFTMAX", OperatorType::kLogSoftmax);
  // Element-wise maximum / minimum (not the reducers).
  CheckSimpleOperator<TensorFlowMaximumOperator>(
      "MAXIMUM", OperatorType::kMaximum);
  CheckSimpleOperator<TensorFlowMinimumOperator>(
      "MINIMUM", OperatorType::kMinimum);
  CheckSimpleOperator<TensorFlowLessOperator>("LESS", OperatorType::kLess);
  CheckSimpleOperator<NegOperator>("NEG", OperatorType::kNeg);
  CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect);
  CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice);
  CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin);
  CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL", OperatorType::kEqual);
  CheckSimpleOperator<TensorFlowNotEqualOperator>(
      "NOT_EQUAL", OperatorType::kNotEqual);
  CheckSimpleOperator<LogOperator>("LOG", OperatorType::kLog);
  CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt);
  CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt);
  CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow);
  CheckSimpleOperator<LogicalOrOperator>(
      "LOGICAL_OR", OperatorType::kLogicalOr);
  CheckSimpleOperator<LogicalAndOperator>(
      "LOGICAL_AND", OperatorType::kLogicalAnd);
  CheckSimpleOperator<LogicalNotOperator>(
      "LOGICAL_NOT", OperatorType::kLogicalNot);
  CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
  CheckSimpleOperator<TensorFlowSquareOperator>(
      "SQUARE", OperatorType::kSquare);
  CheckSimpleOperator<TensorFlowZerosLikeOperator>(
      "ZEROS_LIKE", OperatorType::kZerosLike);
  CheckSimpleOperator<FloorModOperator>("FLOOR_MOD", OperatorType::kFloorMod);
  CheckSimpleOperator<RangeOperator>("RANGE", OperatorType::kRange);
  CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill);
  CheckSimpleOperator<ReverseV2Operator>(
      "REVERSE_V2", OperatorType::kReverseV2);
  CheckSimpleOperator<TensorFlowRankOperator>("RANK", OperatorType::kRank);
}
162
// ADD: fused_activation_function must survive a serialize/deserialize trip.
TEST_F(OperatorTest, BuiltinAdd) {
  AddOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  const BaseOperator& serializer = GetOperator("ADD", OperatorType::kAdd);
  std::unique_ptr<AddOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            round_tripped->fused_activation_function);
}
171
// ADD_N has no options to check; the round trip just has to yield an AddN op.
TEST_F(OperatorTest, BuiltinAddN) {
  const BaseOperator& serializer = GetOperator("ADD_N", OperatorType::kAddN);
  std::unique_ptr<AddNOperator> round_tripped =
      SerializeAndDeserialize(serializer, AddNOperator());
  ASSERT_NE(nullptr, round_tripped.get());
}
178
// Each reducer must round-trip its keep_dims flag.
TEST_F(OperatorTest, BuiltinReducerOps) {
  CheckReducerOperator<MeanOperator>("MEAN", OperatorType::kMean);
  CheckReducerOperator<TensorFlowSumOperator>("SUM", OperatorType::kSum);
  CheckReducerOperator<TensorFlowProdOperator>(
      "REDUCE_PROD", OperatorType::kReduceProd);
  CheckReducerOperator<TensorFlowMaxOperator>(
      "REDUCE_MAX", OperatorType::kReduceMax);
  CheckReducerOperator<TensorFlowMinOperator>(
      "REDUCE_MIN", OperatorType::kReduceMin);
  CheckReducerOperator<TensorFlowAnyOperator>(
      "REDUCE_ANY", OperatorType::kAny);
}
190
// CAST: both source and destination data types must survive the round trip.
TEST_F(OperatorTest, BuiltinCast) {
  CastOperator original;
  original.src_data_type = ArrayDataType::kFloat;
  original.dst_data_type = ArrayDataType::kUint8;
  const BaseOperator& serializer = GetOperator("CAST", OperatorType::kCast);
  std::unique_ptr<CastOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.src_data_type, round_tripped->src_data_type);
  EXPECT_EQ(original.dst_data_type, round_tripped->dst_data_type);
}
200
// CONCATENATION: the axis attribute must survive the round trip.
TEST_F(OperatorTest, CustomConcatenation) {
  ConcatenationOperator original;
  original.axis = 123;
  const BaseOperator& serializer =
      GetOperator("CONCATENATION", OperatorType::kConcatenation);
  std::unique_ptr<ConcatenationOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.axis, round_tripped->axis);
}
208
// DEPTH_TO_SPACE: block_size must survive the round trip.
TEST_F(OperatorTest, CustomDepthToSpace) {
  DepthToSpaceOperator original;
  original.block_size = 123;
  const BaseOperator& serializer =
      GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace);
  std::unique_ptr<DepthToSpaceOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.block_size, round_tripped->block_size);
}
216
// FAKE_QUANT: the min/max range and num_bits must survive the round trip.
TEST_F(OperatorTest, CustomFakeQuant) {
  FakeQuantOperator op;
  // Hand the MinMax to its owner immediately rather than holding a raw
  // pointer in a local.
  op.minmax.reset(new MinMax);
  op.minmax->min = -10;
  op.minmax->max = 200;
  op.num_bits = 16;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  // Guard before dereferencing: a dropped range should fail this test, not
  // crash the whole test binary.
  ASSERT_NE(nullptr, output_toco_op->minmax.get());
  EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min);
  EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max);
  EXPECT_EQ(op.num_bits, output_toco_op->num_bits);
}
230
// FULLY_CONNECTED: fused_activation_function must survive the round trip.
TEST_F(OperatorTest, CustomFullyConnected) {
  FullyConnectedOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  const BaseOperator& serializer =
      GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected);
  std::unique_ptr<FullyConnectedOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            round_tripped->fused_activation_function);
}
239
// GATHER: the round trip must produce a Gather operator.
TEST_F(OperatorTest, BuiltinGather) {
  const BaseOperator& serializer = GetOperator("GATHER", OperatorType::kGather);
  std::unique_ptr<GatherOperator> round_tripped =
      SerializeAndDeserialize(serializer, GatherOperator());
  ASSERT_NE(nullptr, round_tripped.get());
}
246
// GATHER_ND: the round trip must produce a GatherNd operator.
TEST_F(OperatorTest, BuiltinGatherNd) {
  const BaseOperator& serializer =
      GetOperator("GATHER_ND", OperatorType::kGatherNd);
  std::unique_ptr<GatherNdOperator> round_tripped =
      SerializeAndDeserialize(serializer, GatherNdOperator());
  ASSERT_NE(nullptr, round_tripped.get());
}
253
// WHERE: the round trip must produce a Where operator.
TEST_F(OperatorTest, BuiltinWhere) {
  const BaseOperator& serializer = GetOperator("WHERE", OperatorType::kWhere);
  std::unique_ptr<WhereOperator> round_tripped =
      SerializeAndDeserialize(serializer, WhereOperator());
  ASSERT_NE(nullptr, round_tripped.get());
}
260
// L2_POOL_2D: strides, padding and kernel size must survive the round trip.
TEST_F(OperatorTest, BuiltinL2Pool) {
  L2PoolOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.kwidth = 480;
  original.kheight = 1080;

  const BaseOperator& serializer =
      GetOperator("L2_POOL_2D", OperatorType::kL2Pool);
  std::unique_ptr<L2PoolOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);

  EXPECT_EQ(original.stride_width, round_tripped->stride_width);
  EXPECT_EQ(original.stride_height, round_tripped->stride_height);
  EXPECT_EQ(original.padding.type, round_tripped->padding.type);
  EXPECT_EQ(original.kwidth, round_tripped->kwidth);
  EXPECT_EQ(original.kheight, round_tripped->kheight);
}
276
// LOCAL_RESPONSE_NORMALIZATION: range, bias, alpha and beta must survive the
// round trip.
TEST_F(OperatorTest, BuiltinLocalResponseNormalization) {
  LocalResponseNormalizationOperator original;
  original.range = 123;
  original.bias = 1.23;
  original.alpha = 12.3;
  original.beta = 0.123;

  const BaseOperator& serializer =
      GetOperator("LOCAL_RESPONSE_NORMALIZATION",
                  OperatorType::kLocalResponseNormalization);
  std::unique_ptr<LocalResponseNormalizationOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);

  EXPECT_EQ(original.range, round_tripped->range);
  EXPECT_EQ(original.bias, round_tripped->bias);
  EXPECT_EQ(original.alpha, round_tripped->alpha);
  EXPECT_EQ(original.beta, round_tripped->beta);
}
292
// MAX_POOL_2D: strides, padding and kernel size must survive the round trip.
TEST_F(OperatorTest, BuiltinMaxPool) {
  MaxPoolOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.kwidth = 480;
  original.kheight = 1080;

  const BaseOperator& serializer =
      GetOperator("MAX_POOL_2D", OperatorType::kMaxPool);
  std::unique_ptr<MaxPoolOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);

  EXPECT_EQ(original.stride_width, round_tripped->stride_width);
  EXPECT_EQ(original.stride_height, round_tripped->stride_height);
  EXPECT_EQ(original.padding.type, round_tripped->padding.type);
  EXPECT_EQ(original.kwidth, round_tripped->kwidth);
  EXPECT_EQ(original.kheight, round_tripped->kheight);
}
308
// RESHAPE: the target shape must survive the round trip.
TEST_F(OperatorTest, BuiltinReshape) {
  TensorFlowReshapeOperator original;
  original.shape = {1, 2, 4, 5, 8};
  const BaseOperator& serializer =
      GetOperator("RESHAPE", OperatorType::kReshape);
  std::unique_ptr<TensorFlowReshapeOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.shape, round_tripped->shape);
}
316
// SOFTMAX: beta must survive the round trip.
TEST_F(OperatorTest, CustomSoftmax) {
  SoftmaxOperator original;
  original.beta = 123.1;
  const BaseOperator& serializer =
      GetOperator("SOFTMAX", OperatorType::kSoftmax);
  std::unique_ptr<SoftmaxOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.beta, round_tripped->beta);
}
324
// SPACE_TO_DEPTH: block_size must survive the round trip.
TEST_F(OperatorTest, BuiltinSpaceToDepth) {
  SpaceToDepthOperator original;
  original.block_size = 123;
  const BaseOperator& serializer =
      GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth);
  std::unique_ptr<SpaceToDepthOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.block_size, round_tripped->block_size);
}
332
// SPLIT: num_split must survive the round trip.
TEST_F(OperatorTest, CustomSplit) {
  TensorFlowSplitOperator original;
  original.num_split = 123;
  const BaseOperator& serializer = GetOperator("SPLIT", OperatorType::kSplit);
  std::unique_ptr<TensorFlowSplitOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.num_split, round_tripped->num_split);
}
340
// SPLIT_V: num_split must survive the round trip.
TEST_F(OperatorTest, CustomSplitV) {
  TensorFlowSplitVOperator original;
  original.num_split = 123;
  const BaseOperator& serializer =
      GetOperator("SPLIT_V", OperatorType::kSplitV);
  std::unique_ptr<TensorFlowSplitVOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.num_split, round_tripped->num_split);
}
348
// AVERAGE_POOL_2D: activation, strides, padding and kernel size must survive
// the round trip.
TEST_F(OperatorTest, BuiltinAveragePool) {
  AveragePoolOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.kwidth = 480;
  original.kheight = 1080;

  const BaseOperator& serializer =
      GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool);
  std::unique_ptr<AveragePoolOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);

  EXPECT_EQ(original.fused_activation_function,
            round_tripped->fused_activation_function);
  EXPECT_EQ(original.stride_width, round_tripped->stride_width);
  EXPECT_EQ(original.stride_height, round_tripped->stride_height);
  EXPECT_EQ(original.padding.type, round_tripped->padding.type);
  EXPECT_EQ(original.kwidth, round_tripped->kwidth);
  EXPECT_EQ(original.kheight, round_tripped->kheight);
}
367
// CONV_2D: strides, padding and activation must survive the round trip.
TEST_F(OperatorTest, BuiltinConvolution) {
  ConvOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;

  const BaseOperator& serializer = GetOperator("CONV_2D", OperatorType::kConv);
  std::unique_ptr<ConvOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);

  EXPECT_EQ(original.stride_width, round_tripped->stride_width);
  EXPECT_EQ(original.stride_height, round_tripped->stride_height);
  EXPECT_EQ(original.padding.type, round_tripped->padding.type);
  EXPECT_EQ(original.fused_activation_function,
            round_tripped->fused_activation_function);
}
382
// DEPTHWISE_CONV_2D: strides, padding, depth multiplier and activation must
// survive the round trip.
TEST_F(OperatorTest, BuiltinDepthwiseConvolution) {
  DepthwiseConvOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;
  original.depth_multiplier = 6;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;

  const BaseOperator& serializer =
      GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv);
  std::unique_ptr<DepthwiseConvOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);

  EXPECT_EQ(original.stride_width, round_tripped->stride_width);
  EXPECT_EQ(original.stride_height, round_tripped->stride_height);
  EXPECT_EQ(original.padding.type, round_tripped->padding.type);
  EXPECT_EQ(original.depth_multiplier, round_tripped->depth_multiplier);
  EXPECT_EQ(original.fused_activation_function,
            round_tripped->fused_activation_function);
}
399
// L2_NORMALIZATION: fused_activation_function must survive the round trip.
TEST_F(OperatorTest, BuiltinL2Norm) {
  L2NormalizationOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  const BaseOperator& serializer =
      GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization);
  std::unique_ptr<L2NormalizationOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            round_tripped->fused_activation_function);
}
408
// MUL: fused_activation_function must survive the round trip.
TEST_F(OperatorTest, BuiltinMul) {
  MulOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu6;
  const BaseOperator& serializer = GetOperator("MUL", OperatorType::kMul);
  std::unique_ptr<MulOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            round_tripped->fused_activation_function);
}
417
// RESIZE_BILINEAR with align_corners set and half_pixel_centers cleared.
TEST_F(OperatorTest, ResizeBilinear) {
  ResizeBilinearOperator original;
  original.align_corners = true;
  original.half_pixel_centers = false;
  const BaseOperator& serializer =
      GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear);
  std::unique_ptr<ResizeBilinearOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.align_corners, round_tripped->align_corners);
  EXPECT_EQ(original.half_pixel_centers, round_tripped->half_pixel_centers);
}
427
// RESIZE_BILINEAR with both align_corners and half_pixel_centers set.
TEST_F(OperatorTest, ResizeBilinear_HalfPixelCenters) {
  ResizeBilinearOperator original;
  original.align_corners = true;
  original.half_pixel_centers = true;
  const BaseOperator& serializer =
      GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear);
  std::unique_ptr<ResizeBilinearOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.align_corners, round_tripped->align_corners);
  EXPECT_EQ(original.half_pixel_centers, round_tripped->half_pixel_centers);
}
437
// RESIZE_NEAREST_NEIGHBOR with align_corners set and half_pixel_centers
// cleared.
TEST_F(OperatorTest, ResizeNearestNeighbor) {
  ResizeNearestNeighborOperator original;
  original.align_corners = true;
  original.half_pixel_centers = false;
  const BaseOperator& serializer = GetOperator(
      "RESIZE_NEAREST_NEIGHBOR", OperatorType::kResizeNearestNeighbor);
  std::unique_ptr<ResizeNearestNeighborOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.align_corners, round_tripped->align_corners);
  EXPECT_EQ(original.half_pixel_centers, round_tripped->half_pixel_centers);
}
449
// RESIZE_NEAREST_NEIGHBOR with both align_corners and half_pixel_centers set.
TEST_F(OperatorTest, ResizeNearestNeighbor_HalfPixelCenters) {
  ResizeNearestNeighborOperator original;
  original.align_corners = true;
  original.half_pixel_centers = true;
  const BaseOperator& serializer = GetOperator(
      "RESIZE_NEAREST_NEIGHBOR", OperatorType::kResizeNearestNeighbor);
  std::unique_ptr<ResizeNearestNeighborOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.align_corners, round_tripped->align_corners);
  EXPECT_EQ(original.half_pixel_centers, round_tripped->half_pixel_centers);
}
461
// SVDF: activation and rank must survive the round trip.
TEST_F(OperatorTest, Svdf) {
  SvdfOperator original;
  original.fused_activation_function = FusedActivationFunctionType::kRelu;
  original.rank = 1;
  const BaseOperator& serializer = GetOperator("SVDF", OperatorType::kSvdf);
  std::unique_ptr<SvdfOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.fused_activation_function,
            round_tripped->fused_activation_function);
  EXPECT_EQ(original.rank, round_tripped->rank);
}
472
// SQUEEZE: squeeze_dims (including negative and repeated entries) must
// survive the round trip unchanged.
TEST_F(OperatorTest, Squeeze) {
  SqueezeOperator original;
  original.squeeze_dims = {-2, -3, 4, 1, 4};
  const BaseOperator& serializer =
      GetOperator("SQUEEZE", OperatorType::kSqueeze);
  std::unique_ptr<SqueezeOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.squeeze_dims, round_tripped->squeeze_dims);
}
481
// STRIDED_SLICE: every mask must survive the round trip. start_indices,
// stop_indices and strides are left at their defaults here and must also come
// back unchanged.
TEST_F(OperatorTest, StridedSlice) {
  StridedSliceOperator op;

  op.begin_mask = 1;
  op.end_mask = 2;
  op.ellipsis_mask = 1;
  op.new_axis_mask = 1;
  op.shrink_axis_mask = 2;

  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.start_indices, output_toco_op->start_indices);
  EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices);
  EXPECT_EQ(op.strides, output_toco_op->strides);
  // Each mask is checked exactly once (end_mask was previously duplicated).
  EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask);
  EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
  EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask);
  EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask);
  EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask);
}
503
// TOPK_V2: the round trip must produce a TopKV2 operator.
TEST_F(OperatorTest, BuiltinTopKV2) {
  const BaseOperator& serializer =
      GetOperator("TOPK_V2", OperatorType::kTopK_V2);
  std::unique_ptr<TopKV2Operator> round_tripped =
      SerializeAndDeserialize(serializer, TopKV2Operator());
  ASSERT_NE(nullptr, round_tripped.get());
}
510
// ARG_MAX: output_data_type must survive the round trip. It is set explicitly
// so the test does not merely compare whatever the constructor left in place.
TEST_F(OperatorTest, BuiltinArgMax) {
  ArgMaxOperator op;
  op.output_data_type = ArrayDataType::kInt32;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("ARG_MAX", OperatorType::kArgMax), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
517
// ARG_MIN: output_data_type must survive the round trip. It is set explicitly
// so the test does not merely compare whatever the constructor left in place.
TEST_F(OperatorTest, BuiltinArgMin) {
  ArgMinOperator op;
  op.output_data_type = ArrayDataType::kInt32;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("ARG_MIN", OperatorType::kArgMin), op);
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
524
// DEQUANTIZE has no options; the round trip must still produce a Dequantize
// operator. Previously the result was discarded and nothing was asserted.
TEST_F(OperatorTest, BuiltinDequantize) {
  DequantizeOperator op;
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("DEQUANTIZE", OperatorType::kDequantize), op);
  ASSERT_NE(nullptr, output_toco_op.get());
}
530
// TRANSPOSE_CONV: strides and padding must survive the round trip.
TEST_F(OperatorTest, BuiltinTransposeConv) {
  TransposeConvOperator original;
  original.stride_width = 123;
  original.stride_height = 124;
  original.padding.type = PaddingType::kValid;

  const BaseOperator& serializer =
      GetOperator("TRANSPOSE_CONV", OperatorType::kTransposeConv);
  std::unique_ptr<TransposeConvOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);

  EXPECT_EQ(original.stride_width, round_tripped->stride_width);
  EXPECT_EQ(original.stride_height, round_tripped->stride_height);
  EXPECT_EQ(original.padding.type, round_tripped->padding.type);
}
542
// SHAPE: output_data_type must survive the round trip.
TEST_F(OperatorTest, BuiltinShape) {
  TensorFlowShapeOperator original;
  original.output_data_type = ArrayDataType::kInt64;
  const BaseOperator& serializer = GetOperator("SHAPE", OperatorType::kShape);
  std::unique_ptr<TensorFlowShapeOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.output_data_type, round_tripped->output_data_type);
}
550
// SPARSE_TO_DENSE: validate_indices must survive the round trip.
TEST_F(OperatorTest, BuiltinSparseToDense) {
  SparseToDenseOperator original;
  original.validate_indices = false;
  const BaseOperator& serializer =
      GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense);
  auto round_tripped = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.validate_indices, round_tripped->validate_indices);
}
559
// SPARSE_TO_DENSE versioning depends on the data type of the third input
// (`input_values`): int32 -> 1, int64 -> 2, int8/uint8 -> 3.
TEST_F(OperatorTest, VersioningSpareToDense) {
  SparseToDenseOperator op;
  op.inputs = {"indices", "output_shape", "input_values", "default_value"};
  const auto operator_by_type_map =
      BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = operator_by_type_map.at(op.type).get();

  // Builds a fresh model in which the array named op.inputs[2] is created with
  // `values_type`, and returns the version `base_op` reports for `op` there.
  auto version_for_values_type = [&](ArrayDataType values_type) {
    Model model;
    model.GetOrCreateArray(op.inputs[2]).data_type = values_type;
    OperatorSignature signature = {.op = &op, .model = &model};
    return base_op->GetVersion(signature);
  };

  EXPECT_EQ(version_for_values_type(ArrayDataType::kInt32), 1);
  EXPECT_EQ(version_for_values_type(ArrayDataType::kInt64), 2);
  EXPECT_EQ(version_for_values_type(ArrayDataType::kInt8), 3);
  EXPECT_EQ(version_for_values_type(ArrayDataType::kUint8), 3);
}
592
// PACK: values_count and axis must survive the round trip.
TEST_F(OperatorTest, BuiltinPack) {
  PackOperator original;
  original.values_count = 3;
  original.axis = 1;
  const BaseOperator& serializer = GetOperator("PACK", OperatorType::kPack);
  auto round_tripped = SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.values_count, round_tripped->values_count);
  EXPECT_EQ(original.axis, round_tripped->axis);
}
602
// ONE_HOT: axis must survive the round trip.
TEST_F(OperatorTest, BuiltinOneHot) {
  OneHotOperator original;
  original.axis = 2;
  const BaseOperator& serializer =
      GetOperator("ONE_HOT", OperatorType::kOneHot);
  std::unique_ptr<OneHotOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.axis, round_tripped->axis);
}
610
// UNPACK: num and axis must survive the round trip.
TEST_F(OperatorTest, BuiltinUnpack) {
  UnpackOperator original;
  original.num = 5;
  original.axis = 2;
  const BaseOperator& serializer =
      GetOperator("UNPACK", OperatorType::kUnpack);
  std::unique_ptr<UnpackOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.num, round_tripped->num);
  EXPECT_EQ(original.axis, round_tripped->axis);
}
620
// LEAKY_RELU: alpha must survive the round trip.
TEST_F(OperatorTest, BuiltinLeakyRelu) {
  LeakyReluOperator original;
  original.alpha = 3;
  const BaseOperator& serializer =
      GetOperator("LEAKY_RELU", OperatorType::kLeakyRelu);
  std::unique_ptr<LeakyReluOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.alpha, round_tripped->alpha);
}
628
// SQUARED_DIFFERENCE: the round trip must produce a SquaredDifference op.
TEST_F(OperatorTest, BuiltinSquaredDifference) {
  const BaseOperator& serializer =
      GetOperator("SQUARED_DIFFERENCE", OperatorType::kSquaredDifference);
  std::unique_ptr<SquaredDifferenceOperator> round_tripped =
      SerializeAndDeserialize(serializer, SquaredDifferenceOperator());
  ASSERT_NE(nullptr, round_tripped.get());
}
635
// SCATTER_ND: the round trip must produce a ScatterNd operator.
TEST_F(OperatorTest, BuiltinScatterNd) {
  const BaseOperator& serializer =
      GetOperator("SCATTER_ND", OperatorType::kScatterNd);
  std::unique_ptr<ScatterNdOperator> round_tripped =
      SerializeAndDeserialize(serializer, ScatterNdOperator());
  ASSERT_NE(nullptr, round_tripped.get());
}
642
// CTC_BEAM_SEARCH_DECODER: beam_width, top_paths and merge_repeated must
// survive the round trip.
TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) {
  CTCBeamSearchDecoderOperator original;
  original.beam_width = 3;
  original.top_paths = 2;
  original.merge_repeated = false;

  const BaseOperator& serializer = GetOperator(
      "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder);
  auto round_tripped = SerializeAndDeserialize(serializer, original);

  EXPECT_EQ(original.beam_width, round_tripped->beam_width);
  EXPECT_EQ(original.top_paths, round_tripped->top_paths);
  EXPECT_EQ(original.merge_repeated, round_tripped->merge_repeated);
}
656
// TENSORFLOW_UNSUPPORTED: attributes of the embedded NodeDef (scalar float,
// string, int, bool, and string/float/int lists) must survive the round trip
// through the custom options.
TEST_F(OperatorTest, TensorFlowUnsupported) {
  TensorFlowUnsupportedOperator op;
  op.tensorflow_op = "MyCustomUnsupportedOp";

  ::tensorflow::NodeDef node_def;
  auto attr = node_def.mutable_attr();
  (*attr)["float_attr"].set_f(2.0);
  (*attr)["str_attr"].set_s("Hello World");
  (*attr)["int_attr"].set_i(17);
  (*attr)["bool_attr"].set_b(true);
  {
    auto* list = (*attr)["list_string_attr"].mutable_list();
    list->add_s("abcde");
    list->add_s("1234");
    list->add_s("");
    list->add_s("zyxwv");
    list->add_s("!-.");
  }
  {
    auto* list = (*attr)["list_float_attr"].mutable_list();
    list->add_f(std::numeric_limits<float>::min());
    list->add_f(2.0);
    list->add_f(-std::numeric_limits<float>::max());
  }
  {
    auto* list = (*attr)["list_int_attr"].mutable_list();
    list->add_i(1);
    list->add_i(20);
    list->add_i(1LL << 40);
    list->add_i(-(1LL << 40));
  }
  // Serialization/parsing failures are reported instead of being ignored, so
  // a broken NodeDef cannot masquerade as an attribute mismatch below.
  ASSERT_TRUE(node_def.SerializeToString(&op.tensorflow_node_def));

  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
  ASSERT_NE(nullptr, output_toco_op.get());

  ::tensorflow::NodeDef output_node_def;
  ASSERT_TRUE(output_node_def.ParseFromString(output_toco_op->tensorflow_node_def));
  const auto& output_attr = output_node_def.attr();
  EXPECT_EQ(2.0, output_attr.at("float_attr").f());
  EXPECT_EQ("Hello World", output_attr.at("str_attr").s());
  EXPECT_EQ(17, output_attr.at("int_attr").i());
  EXPECT_TRUE(output_attr.at("bool_attr").b());
  {
    const auto& list = output_attr.at("list_string_attr").list();
    ASSERT_EQ(5, list.s_size());
    EXPECT_EQ("abcde", list.s(0));
    EXPECT_EQ("1234", list.s(1));
    EXPECT_EQ("", list.s(2));
    EXPECT_EQ("zyxwv", list.s(3));
    EXPECT_EQ("!-.", list.s(4));
  }
  {
    const auto& list = output_attr.at("list_float_attr").list();
    ASSERT_EQ(3, list.f_size());
    EXPECT_EQ(std::numeric_limits<float>::min(), list.f(0));
    EXPECT_EQ(2.0, list.f(1));
    EXPECT_EQ(-std::numeric_limits<float>::max(), list.f(2));
  }
  {
    const auto& list = output_attr.at("list_int_attr").list();
    ASSERT_EQ(4, list.i_size());
    EXPECT_EQ(1, list.i(0));
    EXPECT_EQ(20, list.i(1));
    EXPECT_EQ(1LL << 40, list.i(2));
    EXPECT_EQ(-(1LL << 40), list.i(3));
  }
}
725
// TENSORFLOW_UNSUPPORTED with no NodeDef attributes: the round-tripped NodeDef
// must parse successfully and carry no attributes.
TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
  TensorFlowUnsupportedOperator op;
  op.tensorflow_op = "MyCustomUnsupportedOp";
  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
  ASSERT_NE(nullptr, output_toco_op.get());

  ::tensorflow::NodeDef output_node_def;
  // A parse failure would otherwise leave an empty NodeDef and make the
  // emptiness check below pass vacuously.
  ASSERT_TRUE(output_node_def.ParseFromString(output_toco_op->tensorflow_node_def));
  EXPECT_TRUE(output_node_def.attr().empty());
}
736
// ShouldExportAsFlexOp(enable_flex_ops, tensorflow_op_name).
TEST_F(OperatorTest, TestShouldExportAsFlexOp) {
  // With flex ops disabled, even a known TensorFlow op is not exported as flex.
  EXPECT_FALSE(ShouldExportAsFlexOp(false, "Conv2D"));

  // With flex ops enabled, known TensorFlow ops are exported as flex...
  EXPECT_TRUE(ShouldExportAsFlexOp(true, "Conv2D"));
  EXPECT_TRUE(ShouldExportAsFlexOp(true, "EluGrad"));
  EXPECT_TRUE(ShouldExportAsFlexOp(true, "RFFT"));

  // ...but an unknown custom op name is not.
  EXPECT_FALSE(ShouldExportAsFlexOp(true, "MyAwesomeCustomOp"));

  EXPECT_TRUE(ShouldExportAsFlexOp(true, "RandomShuffle"));
}
745
// MIRROR_PAD: the padding mode must survive the round trip.
TEST_F(OperatorTest, BuiltinMirrorPad) {
  MirrorPadOperator original;
  original.mode = MirrorPadMode::kReflect;
  const BaseOperator& serializer =
      GetOperator("MIRROR_PAD", OperatorType::kMirrorPad);
  std::unique_ptr<MirrorPadOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  EXPECT_EQ(original.mode, round_tripped->mode);
}
753
// UNIQUE: idx_out_type must survive the round trip.
TEST_F(OperatorTest, BuiltinUnique) {
  UniqueOperator original;
  original.idx_out_type = ArrayDataType::kInt64;
  const BaseOperator& serializer =
      GetOperator("UNIQUE", OperatorType::kUnique);
  std::unique_ptr<UniqueOperator> round_tripped =
      SerializeAndDeserialize(serializer, original);
  ASSERT_NE(nullptr, round_tripped.get());
  EXPECT_EQ(original.idx_out_type, round_tripped->idx_out_type);
}
762
// SEGMENT_SUM: the round trip must produce a SegmentSum operator.
TEST_F(OperatorTest, BuiltinSegmentSum) {
  const BaseOperator& serializer =
      GetOperator("SEGMENT_SUM", OperatorType::kSegmentSum);
  std::unique_ptr<SegmentSumOperator> round_tripped =
      SerializeAndDeserialize(serializer, SegmentSumOperator());
  ASSERT_NE(nullptr, round_tripped.get());
}
769
// Round-trips REVERSE_SEQUENCE and checks that both the sequence and batch
// dimensions are preserved.
TEST_F(OperatorTest, BuiltinReverseSequence) {
  ReverseSequenceOperator op;
  op.seq_dim = 3;
  op.batch_dim = 1;
  std::unique_ptr<toco::ReverseSequenceOperator> output_toco_op =
      SerializeAndDeserialize(
          GetOperator("REVERSE_SEQUENCE", OperatorType::kReverseSequence), op);
  // Fail cleanly rather than dereference a null result.
  ASSERT_NE(nullptr, output_toco_op.get());
  EXPECT_EQ(op.seq_dim, output_toco_op->seq_dim);
  EXPECT_EQ(op.batch_dim, output_toco_op->batch_dim);
}
780
// Round-trips MATRIX_DIAG (no options are set) and checks that an op comes
// back; previously the result was never inspected, so the test asserted
// nothing.
TEST_F(OperatorTest, BuiltinMatrixDiag) {
  MatrixDiagOperator op;
  std::unique_ptr<toco::MatrixDiagOperator> output_toco_op =
      SerializeAndDeserialize(
          GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op);
  ASSERT_NE(nullptr, output_toco_op.get());
}
787
// Round-trips MATRIX_SET_DIAG (no options are set) and checks that an op
// comes back; previously the result was never inspected, so the test
// asserted nothing.
TEST_F(OperatorTest, BuiltinMatrixSetDiag) {
  MatrixSetDiagOperator op;
  std::unique_ptr<toco::MatrixSetDiagOperator> output_toco_op =
      SerializeAndDeserialize(
          GetOperator("MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag), op);
  ASSERT_NE(nullptr, output_toco_op.get());
}
794
795 // Test version for a simple Op with 2 versions and the input type controls the
796 // version.
797 template <typename Op>
SimpleVersioningTest()798 void SimpleVersioningTest() {
799 Op op;
800 op.inputs = {"input1"};
801 auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
802 const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
803
804 Model uint8_model;
805 Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
806 uint8_array.data_type = ArrayDataType::kUint8;
807 OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
808 EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
809
810 Model int8_model;
811 Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
812 int8_array.data_type = ArrayDataType::kInt8;
813 OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
814 EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
815 }
816
817 // Test version for a simple Op with 2 versions and the output type controls the
818 // version.
819 template <typename Op>
SimpleOutputVersioningTest()820 void SimpleOutputVersioningTest() {
821 Op op;
822 op.outputs = {"output1"};
823 auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
824 const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
825
826 Model uint8_model;
827 Array& uint8_array = uint8_model.GetOrCreateArray(op.outputs[0]);
828 uint8_array.data_type = ArrayDataType::kUint8;
829 OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
830 EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
831
832 Model int8_model;
833 Array& int8_array = int8_model.GetOrCreateArray(op.outputs[0]);
834 int8_array.data_type = ArrayDataType::kInt8;
835 OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
836 EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
837 }
838
// The comparison operators below follow the two-version scheme checked by
// SimpleVersioningTest: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningEqualTest) {
  SimpleVersioningTest<TensorFlowEqualOperator>();
}

TEST_F(OperatorTest, VersioningNotEqualTest) {
  SimpleVersioningTest<TensorFlowNotEqualOperator>();
}

TEST_F(OperatorTest, VersioningLessTest) {
  SimpleVersioningTest<TensorFlowLessOperator>();
}

TEST_F(OperatorTest, VersioningLessEqualTest) {
  SimpleVersioningTest<TensorFlowLessEqualOperator>();
}

TEST_F(OperatorTest, VersioningGreaterTest) {
  SimpleVersioningTest<TensorFlowGreaterOperator>();
}

TEST_F(OperatorTest, VersioningGreaterEqualTest) {
  SimpleVersioningTest<TensorFlowGreaterEqualOperator>();
}
862
// SPACE_TO_BATCH_ND: 4-D uint8 / int8 inputs select versions 1 / 2, and a 3-D
// float input selects version 3.
TEST_F(OperatorTest, VersioningSpaceToBatchNDTest) {
  SpaceToBatchNDOperator op;
  op.inputs = {"input1"};
  auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = ops_by_type.at(op.type).get();

  // Builds a model holding only the op's input, with the given data type and
  // shape, and returns the version the operator reports for it.
  auto version_for = [&](ArrayDataType data_type, const Shape& shape) {
    Model model;
    Array& input = model.GetOrCreateArray(op.inputs[0]);
    input.copy_shape(shape);
    input.data_type = data_type;
    OperatorSignature signature = {.op = &op, .model = &model};
    return base_op->GetVersion(signature);
  };

  EXPECT_EQ(version_for(ArrayDataType::kUint8, {1, 2, 2, 2}), 1);
  EXPECT_EQ(version_for(ArrayDataType::kInt8, {1, 2, 2, 2}), 2);
  EXPECT_EQ(version_for(ArrayDataType::kFloat, {1, 2, 2}), 3);
}
890
// LOG_SOFTMAX and PACK: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningLogSoftmaxTest) {
  SimpleVersioningTest<LogSoftmaxOperator>();
}

TEST_F(OperatorTest, VersioningPackTest) {
  SimpleVersioningTest<PackOperator>();
}
898
// UNPACK: an int32 input selects version 1; both 8-bit input types (uint8 and
// int8) select version 2.
TEST_F(OperatorTest, VersioningUnpackTest) {
  UnpackOperator op;
  op.inputs = {"input1"};
  auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = ops_by_type.at(op.type).get();

  // Builds a model holding only the op's input, typed as `input_type`, and
  // returns the version the operator reports for it.
  auto version_for = [&](ArrayDataType input_type) {
    Model model;
    model.GetOrCreateArray(op.inputs[0]).data_type = input_type;
    OperatorSignature signature = {.op = &op, .model = &model};
    return base_op->GetVersion(signature);
  };

  EXPECT_EQ(version_for(ArrayDataType::kInt32), 1);
  EXPECT_EQ(version_for(ArrayDataType::kUint8), 2);
  EXPECT_EQ(version_for(ArrayDataType::kInt8), 2);
}
923
// BATCH_TO_SPACE_ND: 4-D uint8 / int8 inputs select versions 1 / 2, and a 3-D
// float input selects version 3.
TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
  BatchToSpaceNDOperator op;
  op.inputs = {"input1"};
  auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = ops_by_type.at(op.type).get();

  // Builds a model holding only the op's input, with the given data type and
  // shape, and returns the version the operator reports for it.
  auto version_for = [&](ArrayDataType data_type, const Shape& shape) {
    Model model;
    Array& input = model.GetOrCreateArray(op.inputs[0]);
    input.data_type = data_type;
    input.copy_shape(shape);
    OperatorSignature signature = {.op = &op, .model = &model};
    return base_op->GetVersion(signature);
  };

  EXPECT_EQ(version_for(ArrayDataType::kUint8, {1, 2, 2, 2}), 1);
  EXPECT_EQ(version_for(ArrayDataType::kInt8, {1, 2, 2, 2}), 2);
  EXPECT_EQ(version_for(ArrayDataType::kFloat, {1, 2, 2}), 3);
}
951
// TANH: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningTanhTest) {
  SimpleVersioningTest<TanhOperator>();
}
955
// STRIDED_SLICE: with no explicit indices the input type selects the version
// (uint8 -> 1, int8 -> 2, bool -> 3); once 5-element start/stop/stride
// vectors are set, each of those input types reports version 4.
TEST_F(OperatorTest, VersioningStridedSliceTest) {
  StridedSliceOperator op;
  op.inputs = {"input1"};
  op.ellipsis_mask = 0;
  op.new_axis_mask = 0;
  auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = ops_by_type.at(op.type).get();

  // Builds a model holding only the op's input, typed as `input_type`, and
  // returns the version the operator (in its current configuration) reports.
  auto version_for = [&](ArrayDataType input_type) {
    Model model;
    model.GetOrCreateArray(op.inputs[0]).data_type = input_type;
    OperatorSignature signature = {.op = &op, .model = &model};
    return base_op->GetVersion(signature);
  };

  EXPECT_EQ(version_for(ArrayDataType::kUint8), 1);
  EXPECT_EQ(version_for(ArrayDataType::kInt8), 2);
  EXPECT_EQ(version_for(ArrayDataType::kBool), 3);

  op.start_indices = {0, 0, 0, 0, 0};
  op.stop_indices = {1, 2, 2, 2, 2};
  op.strides = {1, 1, 1, 1, 1};
  for (const ArrayDataType input_type :
       {ArrayDataType::kUint8, ArrayDataType::kInt8, ArrayDataType::kBool}) {
    EXPECT_EQ(version_for(input_type), 4);
  }
}
989
// SPACE_TO_DEPTH: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningSpaceToDepthTest) {
  SimpleVersioningTest<SpaceToDepthOperator>();
}
993
// SLICE: uint8 / int8 inputs select versions 1 / 2 (via the shared helper),
// and a string input selects version 3.
TEST_F(OperatorTest, VersioningSliceTest) {
  SimpleVersioningTest<SliceOperator>();

  SliceOperator op;
  op.inputs = {"input1"};
  auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = ops_by_type.at(op.type).get();

  Model model;
  model.GetOrCreateArray(op.inputs[0]).data_type = ArrayDataType::kString;
  OperatorSignature signature = {.op = &op, .model = &model};
  EXPECT_EQ(base_op->GetVersion(signature), 3);
}
1009
// Each operator below has two versions selected by an 8-bit data type
// (uint8 -> 1, int8 -> 2). L2_NORMALIZATION keys off its output array; all
// the others key off their input array.
TEST_F(OperatorTest, VersioningLogisticTest) {
  SimpleVersioningTest<LogisticOperator>();
}

TEST_F(OperatorTest, VersioningL2NormTest) {
  SimpleOutputVersioningTest<L2NormalizationOperator>();
}

TEST_F(OperatorTest, VersioningMaxTest) {
  SimpleVersioningTest<TensorFlowMaximumOperator>();
}

TEST_F(OperatorTest, VersioningMinTest) {
  SimpleVersioningTest<TensorFlowMinimumOperator>();
}

TEST_F(OperatorTest, VersioningMeanTest) {
  SimpleVersioningTest<MeanOperator>();
}

TEST_F(OperatorTest, VersioningSumTest) {
  SimpleVersioningTest<TensorFlowSumOperator>();
}

TEST_F(OperatorTest, VersioningAddTest) {
  SimpleVersioningTest<AddOperator>();
}
1035
SimpleMulVersioningTest(ArrayDataType data_type,float multiplier,int version)1036 void SimpleMulVersioningTest(ArrayDataType data_type, float multiplier,
1037 int version) {
1038 MulOperator op;
1039 op.inputs = {"input1", "input2"};
1040 op.outputs = {"output"};
1041 auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
1042 const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
1043
1044 Model model;
1045 Array& input0 = model.GetOrCreateArray(op.inputs[0]);
1046 Array& input1 = model.GetOrCreateArray(op.inputs[1]);
1047 Array& output = model.GetOrCreateArray(op.outputs[0]);
1048
1049 input0.data_type = data_type;
1050 input0.GetOrCreateQuantizationParams().scale = 1.0f;
1051 input1.data_type = data_type;
1052 input1.GetOrCreateQuantizationParams().scale = 1.0f;
1053 output.data_type = data_type;
1054 output.GetOrCreateQuantizationParams().scale = 1.0f / multiplier;
1055
1056 OperatorSignature signature = {.op = &op, .model = &model};
1057 EXPECT_EQ(base_op->GetVersion(signature), version);
1058 }
1059
// MUL: uint8 -> version 1; int8 -> version 2 with a multiplier of 0.5 and
// version 3 with a multiplier of 2.0 (see SimpleMulVersioningTest for how the
// multiplier is encoded in the quantization scales).
TEST_F(OperatorTest, VersioningMulTest) {
  struct MulCase {
    ArrayDataType data_type;
    float multiplier;
    int version;
  };
  const MulCase cases[] = {
      {ArrayDataType::kUint8, 0.5f, 1},
      {ArrayDataType::kInt8, 0.5f, 2},
      {ArrayDataType::kInt8, 2.0f, 3},
  };
  for (const MulCase& c : cases) {
    SimpleMulVersioningTest(c.data_type, c.multiplier, c.version);
  }
}
1065
1066 template <typename OpType>
SimpleTwoInputsVersioningTest(ArrayDataType data_type,Shape shape1,Shape shape2,int version)1067 void SimpleTwoInputsVersioningTest(ArrayDataType data_type, Shape shape1,
1068 Shape shape2, int version) {
1069 OpType op;
1070 op.inputs = {"input1", "input2"};
1071 op.outputs = {"output"};
1072 auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
1073 const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
1074
1075 Model model;
1076 Array& input0 = model.GetOrCreateArray(op.inputs[0]);
1077 Array& input1 = model.GetOrCreateArray(op.inputs[1]);
1078 Array& output = model.GetOrCreateArray(op.outputs[0]);
1079
1080 input0.data_type = data_type;
1081 input0.copy_shape(shape1);
1082 input1.data_type = data_type;
1083 input1.copy_shape(shape2);
1084 output.data_type = data_type;
1085
1086 OperatorSignature signature = {.op = &op, .model = &model};
1087 EXPECT_EQ(base_op->GetVersion(signature), version);
1088 }
1089
1090 template <typename OpType>
SimpleThreeInputsVersioningTest(ArrayDataType data_type,Shape shape1,Shape shape2,Shape shape3,int version)1091 void SimpleThreeInputsVersioningTest(ArrayDataType data_type, Shape shape1,
1092 Shape shape2, Shape shape3, int version) {
1093 OpType op;
1094 op.inputs = {"input1", "input2", "input3"};
1095 op.outputs = {"output"};
1096 auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
1097 const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
1098
1099 Model model;
1100 Array& input0 = model.GetOrCreateArray(op.inputs[0]);
1101 Array& input1 = model.GetOrCreateArray(op.inputs[1]);
1102 Array& input2 = model.GetOrCreateArray(op.inputs[2]);
1103 Array& output = model.GetOrCreateArray(op.outputs[0]);
1104
1105 input0.data_type = data_type;
1106 input0.copy_shape(shape1);
1107 input1.data_type = data_type;
1108 input1.copy_shape(shape2);
1109 input2.data_type = data_type;
1110 input2.copy_shape(shape3);
1111 output.data_type = data_type;
1112
1113 OperatorSignature signature = {.op = &op, .model = &model};
1114 EXPECT_EQ(base_op->GetVersion(signature), version);
1115 }
1116
// SUB: for inputs of rank <= 4, uint8 -> version 1 and int8 -> version 2
// (with or without differing shapes); the rank-5 cases report version 3 for
// both types.
TEST_F(OperatorTest, VersioningSubTest) {
  struct SubCase {
    ArrayDataType data_type;
    Shape shape1;
    Shape shape2;
    int version;
  };
  const SubCase cases[] = {
      {ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 2}, 1},
      {ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 2}, 2},
      {ArrayDataType::kUint8, {1, 2, 2}, {1, 2, 2}, 1},
      {ArrayDataType::kInt8, {1, 2, 2}, {1, 2, 2}, 2},
      {ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 1}, 1},
      {ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 1}, 2},
      {ArrayDataType::kUint8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 3},
      {ArrayDataType::kInt8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 3},
  };
  for (const SubCase& c : cases) {
    SimpleTwoInputsVersioningTest<SubOperator>(c.data_type, c.shape1,
                                               c.shape2, c.version);
  }
}
1135
// DIV: inputs of rank <= 4 report version 1; the rank-5 case reports
// version 2.
TEST_F(OperatorTest, VersioningDivTest) {
  struct DivCase {
    ArrayDataType data_type;
    Shape shape1;
    Shape shape2;
    int version;
  };
  const DivCase cases[] = {
      {ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 2}, 1},
      {ArrayDataType::kInt8, {1, 2, 2}, {1, 2, 2}, 1},
      {ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 1}, 1},
      {ArrayDataType::kInt8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 2},
  };
  for (const DivCase& c : cases) {
    SimpleTwoInputsVersioningTest<DivOperator>(c.data_type, c.shape1,
                                               c.shape2, c.version);
  }
}
1146
// PAD, PADV2 and CONCATENATION: uint8 input -> version 1, int8 -> version 2.
TEST_F(OperatorTest, VersioningPadTest) {
  SimpleVersioningTest<PadOperator>();
}

TEST_F(OperatorTest, VersioningPadV2Test) {
  SimpleVersioningTest<PadV2Operator>();
}

TEST_F(OperatorTest, VersioningConcatenationTest) {
  SimpleVersioningTest<ConcatenationOperator>();
}
1156
// SELECT: rank-4 uint8 / int8 inputs report versions 1 / 2; the rank-5 int8
// case reports version 3.
TEST_F(OperatorTest, VersioningSelectTest) {
  struct SelectCase {
    ArrayDataType data_type;
    Shape shape1;
    Shape shape2;
    Shape shape3;
    int version;
  };
  const SelectCase cases[] = {
      {ArrayDataType::kUint8, {1, 2, 2, 2}, {1, 2, 2, 1}, {1, 2, 2, 1}, 1},
      {ArrayDataType::kInt8, {1, 2, 2, 2}, {1, 2, 2, 1}, {1, 2, 2, 1}, 2},
      {ArrayDataType::kInt8,
       {1, 2, 2, 2, 1},
       {1, 2, 2, 1, 1},
       {1, 2, 2, 1, 1},
       3},
  };
  for (const SelectCase& c : cases) {
    SimpleThreeInputsVersioningTest<SelectOperator>(
        c.data_type, c.shape1, c.shape2, c.shape3, c.version);
  }
}
1166
// RELU6: uint8 input -> version 1, int8 input -> version 2.
TEST_F(OperatorTest, VersioningRelu6Test) {
  SimpleVersioningTest<Relu6Operator>();
}
1170
// FULLY_CONNECTED: a model whose input, weights and output are all uint8, and
// one where they are all int8, both report version 6.
TEST_F(OperatorTest, VersioningFullyConnectedTest) {
  FullyConnectedOperator fc_op;
  fc_op.inputs = {"input", "weight"};
  fc_op.outputs = {"output"};
  auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* op = ops_by_type.at(fc_op.type).get();

  // Builds a model in which the input, weight and output arrays all have
  // `data_type`, and returns the version the operator reports for it.
  auto version_for = [&](ArrayDataType data_type) {
    Model model;
    model.GetOrCreateArray(fc_op.inputs[0]).data_type = data_type;
    model.GetOrCreateArray(fc_op.inputs[1]).data_type = data_type;
    model.GetOrCreateArray(fc_op.outputs[0]).data_type = data_type;
    OperatorSignature signature = {.op = &fc_op, .model = &model};
    return op->GetVersion(signature);
  };

  EXPECT_EQ(version_for(ArrayDataType::kUint8), 6);
  EXPECT_EQ(version_for(ArrayDataType::kInt8), 6);
}
1207
// DEQUANTIZE: the input data type selects the version — int16 and float16 ->
// 3, int8 -> 2, float -> 1.
TEST_F(OperatorTest, VersioningDequantizeTest) {
  DequantizeOperator dequant_op;
  dequant_op.inputs = {"input"};
  dequant_op.outputs = {"output"};
  auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* op = ops_by_type.at(dequant_op.type).get();

  // Builds a model holding only the op's input, typed as `input_type`, and
  // returns the version the operator reports for it.
  auto version_for = [&](ArrayDataType input_type) {
    Model model;
    model.GetOrCreateArray(dequant_op.inputs[0]).data_type = input_type;
    OperatorSignature signature = {.op = &dequant_op, .model = &model};
    return op->GetVersion(signature);
  };

  EXPECT_EQ(version_for(ArrayDataType::kInt16), 3);
  EXPECT_EQ(version_for(ArrayDataType::kFloat16), 3);
  EXPECT_EQ(version_for(ArrayDataType::kInt8), 2);
  EXPECT_EQ(version_for(ArrayDataType::kFloat), 1);
}
1243
// CONV_2D: all-uint8 -> version 1, all-int8 -> version 3, and float
// input/output with an int8 filter -> version 2.
TEST_F(OperatorTest, VersioningConv2DTest) {
  ConvOperator conv_op;
  conv_op.inputs = {"input", "filter"};
  conv_op.outputs = {"output"};
  auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* op = ops_by_type.at(conv_op.type).get();

  // Builds a model with the given input, filter and output data types and
  // returns the version the operator reports for it.
  auto version_for = [&](ArrayDataType input_type, ArrayDataType filter_type,
                         ArrayDataType output_type) {
    Model model;
    model.GetOrCreateArray(conv_op.inputs[0]).data_type = input_type;
    model.GetOrCreateArray(conv_op.inputs[1]).data_type = filter_type;
    model.GetOrCreateArray(conv_op.outputs[0]).data_type = output_type;
    OperatorSignature signature = {.op = &conv_op, .model = &model};
    return op->GetVersion(signature);
  };

  EXPECT_EQ(version_for(ArrayDataType::kUint8, ArrayDataType::kUint8,
                        ArrayDataType::kUint8),
            1);
  EXPECT_EQ(version_for(ArrayDataType::kInt8, ArrayDataType::kInt8,
                        ArrayDataType::kInt8),
            3);
  EXPECT_EQ(version_for(ArrayDataType::kFloat, ArrayDataType::kInt8,
                        ArrayDataType::kFloat),
            2);
}
1281
// FLOOR_DIV: an int32 input selects version 1, a float input version 2.
TEST_F(OperatorTest, VersioningFloorDivOperatorTest) {
  FloorDivOperator floordiv_op;
  floordiv_op.inputs = {"input1"};
  auto ops_by_type = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* op = ops_by_type.at(floordiv_op.type).get();

  // Builds a model holding only the op's first input, typed as `input_type`,
  // and returns the version the operator reports for it.
  auto version_for = [&](ArrayDataType input_type) {
    Model model;
    model.GetOrCreateArray(floordiv_op.inputs[0]).data_type = input_type;
    OperatorSignature signature = {.op = &floordiv_op, .model = &model};
    return op->GetVersion(signature);
  };

  EXPECT_EQ(version_for(ArrayDataType::kInt32), 1);
  EXPECT_EQ(version_for(ArrayDataType::kFloat), 2);
}
1304
1305 } // namespace
1306 } // namespace tflite
1307
1308 } // namespace toco
1309