xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/activations_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 #include <stdlib.h>
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <initializer_list>
21 #include <limits>
22 #include <map>
23 #include <memory>
24 #include <random>
25 #include <string>
26 #include <utility>
27 #include <vector>
28 
29 #include <gtest/gtest.h>
30 #include "absl/memory/memory.h"
31 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
32 #include "tensorflow/lite/core/api/op_resolver.h"
33 #include "tensorflow/lite/interpreter.h"
34 #include "tensorflow/lite/kernels/test_util.h"
35 #include "tensorflow/lite/schema/schema_generated.h"
36 #include "tensorflow/lite/string_type.h"
37 
38 namespace tflite {
39 
40 namespace ops {
41 namespace builtin {
42 
43 // Tanh kernel registrations.
44 TfLiteRegistration* Register_TANH_REF();
45 TfLiteRegistration* Register_TANH_GENERIC_OPT();
46 TfLiteRegistration* Register_TANH_FIXED_POINT_OPT();
47 
48 // Logistic kernel registrations.
49 TfLiteRegistration* Register_LOGISTIC_REF();
50 TfLiteRegistration* Register_LOGISTIC_GENERIC_OPT();
51 TfLiteRegistration* Register_LOGISTIC_FIXED_POINT_OPT();
52 
53 // LogSoftmax kernel registrations.
54 TfLiteRegistration* Register_LOG_SOFTMAX_REF();
55 TfLiteRegistration* Register_LOG_SOFTMAX();
56 
57 // Softmax kernel registrations.
58 TfLiteRegistration* Register_SOFTMAX_REF();
59 TfLiteRegistration* Register_SOFTMAX();
60 
61 // PRelu kernel registrations.
62 TfLiteRegistration* Register_PRELU_REF();
63 TfLiteRegistration* Register_PRELU();
64 
65 // LeakyRelu kernel registrations.
66 TfLiteRegistration* Register_LEAKY_RELU_REF();
67 TfLiteRegistration* Register_LEAKY_RELU();
68 
69 }  // namespace builtin
70 }  // namespace ops
71 
72 namespace {
73 
74 using ::testing::ElementsAreArray;
75 
76 class BaseActivationsOpModel : public SingleOpModel {
77  public:
78   // Most activations don't take any options, so this constructor works for
79   // them.
BaseActivationsOpModel(BuiltinOperator type,TensorData input)80   BaseActivationsOpModel(BuiltinOperator type, TensorData input) {
81     input_ = AddInput(input);
82     if (input.type == TensorType_UINT8) {
83       output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
84     } else if (input.type == TensorType_INT8) {
85       output_ = AddOutput({input.type, {}, 0, 0, 1. / 256, -128});
86     } else {
87       output_ = AddOutput({input.type, {}});
88     }
89     SetBuiltinOp(type, BuiltinOptions_NONE, 0);
90     BuildInterpreter({GetShape(input_)});
91   }
92 
BaseActivationsOpModel(TfLiteRegistration * registration,BuiltinOperator type,TensorData input)93   BaseActivationsOpModel(TfLiteRegistration* registration, BuiltinOperator type,
94                          TensorData input) {
95     input_ = AddInput(input);
96     if (input.type == TensorType_UINT8) {
97       output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
98     } else if (input.type == TensorType_INT8) {
99       output_ = AddOutput({input.type, {}, 0, 0, 1. / 256, -128});
100     } else {
101       output_ = AddOutput({input.type, {}});
102     }
103     SetBuiltinOp(type, BuiltinOptions_NONE, 0);
104     resolver_ = std::make_unique<SingleOpResolver>(type, registration);
105     BuildInterpreter({GetShape(input_)});
106   }
107 
108   // A dedicated constructor for SOFTMAX, which does some options.
BaseActivationsOpModel(TfLiteRegistration * registration,float softmax_beta,TensorData input,TensorType output_type)109   BaseActivationsOpModel(TfLiteRegistration* registration, float softmax_beta,
110                          TensorData input, TensorType output_type) {
111     input_ = AddInput(input);
112     if (output_type == TensorType_UINT8) {
113       output_ = AddOutput({TensorType_UINT8, {}, 0, 0, 1. / 256});
114     } else if (output_type == TensorType_INT8) {
115       output_ = AddOutput({TensorType_INT8, {}, 0, 0, 1. / 256, -128});
116     } else if (input.type == TensorType_INT16 &&
117                output_type == TensorType_INT16) {
118       output_ = AddOutput({TensorType_INT16,
119                            {},
120                            0,
121                            0,
122                            1.0f / (std::numeric_limits<int16_t>::max() + 1),
123                            0});
124     } else if (input.type != TensorType_INT16 &&
125                output_type == TensorType_INT16) {
126       output_ = AddOutput({TensorType_INT16, {}, 0, 0, 1. / 65536, -32768});
127     } else {
128       output_ = AddOutput({output_type, {}});
129     }
130     SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
131                  CreateSoftmaxOptions(builder_, softmax_beta).Union());
132     resolver_ = std::make_unique<SingleOpResolver>(BuiltinOperator_SOFTMAX,
133                                                    registration);
134     BuildInterpreter({GetShape(input_)});
135   }
136 
137   // A dedicated constructor for LeakyRelu, which does some options.
BaseActivationsOpModel(TfLiteRegistration * registration,TensorData input,float alpha)138   BaseActivationsOpModel(TfLiteRegistration* registration, TensorData input,
139                          float alpha) {
140     input_ = AddInput(input);
141     // The output scale and input scale might be different.
142     if (input.type == TensorType_UINT8 || input.type == TensorType_INT8 ||
143         input.type == TensorType_INT16) {
144       auto output_min = (input.min >= 0) ? input.min : input.min * alpha;
145       auto output_max = (input.max >= 0) ? input.max : input.max * alpha;
146       if (input.type == TensorType_INT16) {
147         output_ = AddOutput({TensorType_INT16,
148                              {},
149                              0,
150                              0,
151                              output_max / (std::numeric_limits<int16_t>::max()),
152                              0});
153       } else {
154         output_ = AddOutput({input.type, {}, output_min, output_max});
155       }
156     } else {
157       output_ = AddOutput({input.type, {}});
158     }
159     SetBuiltinOp(BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions,
160                  CreateLeakyReluOptions(builder_, alpha).Union());
161     resolver_ = std::make_unique<SingleOpResolver>(BuiltinOperator_LEAKY_RELU,
162                                                    registration);
163     BuildInterpreter({GetShape(input_)});
164   }
165 
BaseActivationsOpModel(BuiltinOperator type,const TensorData & input,const TensorData & output)166   BaseActivationsOpModel(BuiltinOperator type, const TensorData& input,
167                          const TensorData& output) {
168     input_ = AddInput(input);
169     output_ = AddOutput(output);
170     SetBuiltinOp(type, BuiltinOptions_NONE, 0);
171     BuildInterpreter({GetShape(input_)});
172   }
173 
BaseActivationsOpModel(TfLiteRegistration * registration,BuiltinOperator type,const TensorData & input,const TensorData & output)174   BaseActivationsOpModel(TfLiteRegistration* registration, BuiltinOperator type,
175                          const TensorData& input, const TensorData& output) {
176     input_ = AddInput(input);
177     output_ = AddOutput(output);
178     SetBuiltinOp(type, BuiltinOptions_NONE, 0);
179     resolver_ = std::make_unique<SingleOpResolver>(type, registration);
180     BuildInterpreter({GetShape(input_)});
181   }
182 
183  protected:
184   int input_;
185   int output_;
186 };
187 
188 class FloatActivationsOpModel : public BaseActivationsOpModel {
189  public:
190   using BaseActivationsOpModel::BaseActivationsOpModel;
191 
SetInput(const std::vector<float> & data)192   void SetInput(const std::vector<float>& data) {
193     PopulateTensor(input_, data);
194   }
GetOutput()195   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
196 };
197 
// The fixed-point math functions behind the quantized kernels have roughly
// 12 bits of accuracy when specialized to 16-bit fixed-point arithmetic.
// Getting closer to 16 bits would be possible but more expensive, and it is
// not needed: the output is either immediately requantized to 8 bits or
// typically feeds the rest of an LSTM cell. So roughly 2^-12 accuracy is
// required for 16-bit outputs, and more or less the full 2^-8 accuracy can be
// expected for 8-bit outputs.
//
// However, the representable output interval is often [-1, 1]. It has to be
// for tanh, and even logistic is usually implemented in fixed point on such a
// symmetric interval (e.g. ARM NEON only has signed fixed-point arithmetic,
// SQRDMULH). Because [-1, 1] has width 2, representable values are diluted by
// a factor of 2, which is where the factor of 2 below comes from.
constexpr float kQuantizedTolerance = 2 * (1. / 256);
constexpr float kQuantizedToleranceInt16 = 2 * (1. / 4096);
216 
217 class QuantizedActivationsOpModel : public BaseActivationsOpModel {
218  public:
219   using BaseActivationsOpModel::BaseActivationsOpModel;
220 
221   template <typename T>
SetInput(const std::vector<float> & data)222   void SetInput(const std::vector<float>& data) {
223     QuantizeAndPopulate<T>(input_, data);
224   }
225   template <typename T>
GetOutput()226   std::vector<T> GetOutput() {
227     return ExtractVector<T>(output_);
228   }
229 
230   template <typename T>
GetDequantizedOutput()231   std::vector<float> GetDequantizedOutput() {
232     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
233                          GetZeroPoint(output_));
234   }
235 };
236 
237 const auto kTanhKernelMap = new std::map<string, TfLiteRegistration*>({
238     {"Reference", ops::builtin::Register_TANH_REF()},
239     {"GenericOptimized", ops::builtin::Register_TANH_GENERIC_OPT()},
240     {"FixedPointOptimized", ops::builtin::Register_TANH_FIXED_POINT_OPT()},
241 });
242 
243 class TanhOpTest : public SingleOpTest {
244  protected:
GetKernelMap()245   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
246     return *kTanhKernelMap;
247   }
248 };
249 
250 const auto kLogisticKernelMap = new std::map<string, TfLiteRegistration*>({
251     {"Reference", ops::builtin::Register_LOGISTIC_REF()},
252     {"GenericOptimized", ops::builtin::Register_LOGISTIC_GENERIC_OPT()},
253     {"FixedPointOptimized", ops::builtin::Register_LOGISTIC_FIXED_POINT_OPT()},
254 });
255 
256 class LogisticOpTest : public SingleOpTest {
257  protected:
GetKernelMap()258   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
259     return *kLogisticKernelMap;
260   }
261 };
262 
263 const auto kLogSoftmaxKernelMap = new std::map<string, TfLiteRegistration*>({
264     {"Reference", ops::builtin::Register_LOG_SOFTMAX_REF()},
265     {"GenericOptimized", ops::builtin::Register_LOG_SOFTMAX()},
266 });
267 
268 class LogSoftmaxOpTest : public SingleOpTest {
269  protected:
GetKernelMap()270   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
271     return *kLogSoftmaxKernelMap;
272   }
273 };
274 
275 const auto kSoftmaxKernelMap = new std::map<string, TfLiteRegistration*>({
276     {"Reference", ops::builtin::Register_SOFTMAX_REF()},
277     {"GenericOptimized", ops::builtin::Register_SOFTMAX()},
278 });
279 
280 class SoftmaxOpTest : public SingleOpTest {
281  protected:
GetKernelMap()282   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
283     return *kSoftmaxKernelMap;
284   }
285 };
286 
TEST(FloatActivationsOpTest, Elu) {
  const std::vector<float> input = {
      0, -6, 2, -4,     //
      3, -2, 10, -0.1,  //
  };
  // Non-negative inputs pass through; negative ones map to exp(x) - 1.
  const std::vector<float> expected = {
      0.0, -0.997521, 2.0, -0.981684,    //
      3.0, -0.864665, 10.0, -0.0951626,  //
  };

  FloatActivationsOpModel m(BuiltinOperator_ELU,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  m.SetInput(input);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
300 
TEST(QuantizedActivationsOpTest, EluInt8) {
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  // Input and output share the range [8 * kMin, 8 * kMax].
  QuantizedActivationsOpModel m(
      BuiltinOperator_ELU,
      /*input=*/{TensorType_INT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
      /*output=*/{TensorType_INT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax});

  const std::vector<float> input = {
      0, -6, 2, -4,    //
      3, -2, 6, -0.1,  //
  };
  // Float ELU results as they come out after quantization.
  const std::vector<float> expected = {
      0, -1.0, 2.0, -1,          //
      3.0, -0.875, 6.0, -0.125,  //
  };

  m.SetInput<int8_t>(input);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(expected, kQuantizedTolerance)));
}
323 
TEST(FloatActivationsOpTest, Relu) {
  const std::vector<float> input = {
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  };
  // RELU zeroes negatives and leaves everything else untouched.
  const std::vector<float> expected = {
      0, 0, 2, 4,   //
      3, 0, 10, 1,  //
  };

  FloatActivationsOpModel m(BuiltinOperator_RELU,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  m.SetInput(input);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(expected));
}
337 
// RELU_0_TO_1 clamps every input into [0, 1].
TEST(FloatActivationsOpTest, Relu0To1) {
  FloatActivationsOpModel m(BuiltinOperator_RELU_0_TO_1,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  m.SetInput({
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -2.0, 1.1, -0.1,  //
  });
  // Like every other test here, stop immediately if the kernel reports an
  // error instead of comparing whatever the output tensor happens to hold.
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
                                 0.0, 0.0, 0.2, 0.0,  //
                                 0.3, 0.0, 1.0, 0.0,  //
                             }));
}
351 
TEST(FloatActivationsOpTest, Relu1) {
  const std::vector<float> input = {
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -2.0, 1.1, -0.1,  //
  };
  // RELU_N1_TO_1 clamps into [-1, 1]; only -2.0 and 1.1 are affected.
  const std::vector<float> expected = {
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -1.0, 1.0, -0.1,  //
  };

  FloatActivationsOpModel m(BuiltinOperator_RELU_N1_TO_1,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  m.SetInput(input);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(expected));
}
365 
TEST(FloatActivationsOpTest, Relu6) {
  const std::vector<float> input = {
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  };
  // RELU6 clamps into [0, 6].
  const std::vector<float> expected = {
      0, 0, 2, 4,  //
      3, 0, 6, 1,  //
  };

  FloatActivationsOpModel m(BuiltinOperator_RELU6,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  m.SetInput(input);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(expected));
}
379 
// Fills `*result` with `size` floats spread over [min, max] using
// `random_engine`.
//
// std::uniform_*_distribution and std::default_random_engine are deliberately
// avoided: both are implementation-defined, so a toolchain update or a new
// platform could change the generated values and break tests.
// std::minstd_rand is a fully specified std::linear_congruential_engine, the
// cheapest generator in the C++11 standard library, and good enough here.
void GenerateUniformRandomVector(int size, float min, float max,
                                 std::minstd_rand* random_engine,
                                 std::vector<float>* result) {
  // Whether `max` itself can be produced does not matter. It can be, through
  // rounding: std::minstd_rand::modulus (2^31 - 1) exceeds the inverse of the
  // float epsilon, so draws close to the modulus round to 1.0f.
  const float inverse_modulus =
      1.0f / static_cast<float>(std::minstd_rand::modulus);
  const float span = max - min;

  result->resize(size);
  for (float& value : *result) {
    const float unit = (*random_engine)() * inverse_modulus;
    value = min + span * unit;
  }
}
401 
// Reference float hard-swish, x * relu6(x + 3) / 6, applied to the first
// `size` elements of `input`; `*result` is resized to `size`.
void EvalTestReferenceHardSwish(int size, const std::vector<float>& input,
                                std::vector<float>* result) {
  result->resize(size);
  std::transform(input.begin(), input.begin() + size, result->begin(),
                 [](float x) {
                   const float relu6 = std::min(6.0f, std::max(0.0f, x + 3));
                   return x * relu6 * (1.0f / 6.0f);
                 });
}
410 
TestFloatHardSwish(int size,std::minstd_rand * random_engine)411 void TestFloatHardSwish(int size, std::minstd_rand* random_engine) {
412   std::vector<float> float_input_values;
413   const float kMin = -10.0f;
414   const float kMax = 10.0f;
415   GenerateUniformRandomVector(size, kMin, kMax, random_engine,
416                               &float_input_values);
417   std::vector<float> float_ref_output_values;
418   EvalTestReferenceHardSwish(size, float_input_values,
419                              &float_ref_output_values);
420   FloatActivationsOpModel m(BuiltinOperator_HARD_SWISH,
421                             /*input=*/{TensorType_FLOAT32, {1, 1, 1, size}},
422                             /*output=*/{TensorType_FLOAT32, {1, 1, 1, size}});
423   m.SetInput(float_input_values);
424 
425   ASSERT_EQ(m.Invoke(), kTfLiteOk);
426   EXPECT_THAT(m.GetOutput(),
427               ElementsAreArray(ArrayFloatNear(float_ref_output_values)));
428 }
429 
430 template <typename QuantizedType>
TestQuantizedHardSwish(TensorType tensor_type,int size,float input_min,float input_max,float output_min,float output_max,std::minstd_rand * random_engine)431 void TestQuantizedHardSwish(TensorType tensor_type, int size, float input_min,
432                             float input_max, float output_min, float output_max,
433                             std::minstd_rand* random_engine) {
434   std::vector<float> float_input_values;
435   GenerateUniformRandomVector(size, input_min, input_max, random_engine,
436                               &float_input_values);
437   std::vector<float> float_ref_output_values;
438   EvalTestReferenceHardSwish(size, float_input_values,
439                              &float_ref_output_values);
440   for (float& val : float_ref_output_values) {
441     val = std::min(output_max, std::max(output_min, val));
442   }
443   QuantizedActivationsOpModel m(
444       BuiltinOperator_HARD_SWISH,
445       /*input=*/{tensor_type, {1, 1, 1, size}, input_min, input_max},
446       /*output=*/{tensor_type, {1, 1, 1, size}, output_min, output_max});
447   m.SetInput<QuantizedType>(float_input_values);
448 
449   ASSERT_EQ(m.Invoke(), kTfLiteOk);
450   const std::vector<float>& dequantized_output =
451       m.GetDequantizedOutput<QuantizedType>();
452   // The numerical error for any 8bit quantized function is at least one half
453   // times the quantization step: 0.5 * (kOutMax - kOutMin) / 256.
454   // To that we add again the quantization step (kOutMax - kOutMin) / 256
455   // to allow for an off-by-one rounding error.
456   const float kTolerance =
457       std::max(input_max - input_min, output_max - output_min) * (1.5f / 256.f);
458   EXPECT_THAT(dequantized_output, ElementsAreArray(ArrayFloatNear(
459                                       float_ref_output_values, kTolerance)));
460 }
461 
462 template <typename QuantizedType>
TestQuantizedHardSwishBias(TensorType tensor_type,float input_min,float input_max,float output_min,float output_max,float tolerated_bias)463 void TestQuantizedHardSwishBias(TensorType tensor_type, float input_min,
464                                 float input_max, float output_min,
465                                 float output_max, float tolerated_bias) {
466   const float quantized_type_range =
467       static_cast<float>(std::numeric_limits<QuantizedType>::max()) -
468       static_cast<float>(std::numeric_limits<QuantizedType>::min());
469   const float input_scale = (input_max - input_min) / quantized_type_range;
470   const float output_scale = (output_max - output_min) / quantized_type_range;
471   const float max_scale = std::max(output_scale, input_scale);
472 
473   // In this bias-focused test case, no need for randomly generated input
474   // values.
475   ASSERT_LE(input_min, -3.0f);
476   ASSERT_GE(input_max, 3.0f);
477   const int quantized_input_negative_three =
478       std::round(std::numeric_limits<QuantizedType>::min() +
479                  (-3.0f - input_min) / input_scale);
480   const int quantized_input_positive_three =
481       std::round(std::numeric_limits<QuantizedType>::min() +
482                  (3.0f - input_min) / input_scale);
483   std::vector<float> float_input_values;
484   for (int i = quantized_input_negative_three;
485        i <= quantized_input_positive_three; i++) {
486     float_input_values.push_back(
487         input_min +
488         (i - std::numeric_limits<QuantizedType>::min()) * input_scale);
489   }
490   const int size = float_input_values.size();
491   std::vector<float> float_ref_output_values;
492   EvalTestReferenceHardSwish(size, float_input_values,
493                              &float_ref_output_values);
494   for (float& val : float_ref_output_values) {
495     val = std::min(output_max, std::max(output_min, val));
496   }
497   QuantizedActivationsOpModel m(
498       BuiltinOperator_HARD_SWISH,
499       /*input=*/{tensor_type, {1, 1, 1, size}, input_min, input_max},
500       /*output=*/{tensor_type, {1, 1, 1, size}, output_min, output_max});
501   m.SetInput<QuantizedType>(float_input_values);
502 
503   ASSERT_EQ(m.Invoke(), kTfLiteOk);
504   const std::vector<float>& dequantized_output =
505       m.GetDequantizedOutput<QuantizedType>();
506 
507   float sum_diff = 0;
508   for (int i = 0; i < size; i++) {
509     sum_diff += dequantized_output[i] - float_ref_output_values[i];
510   }
511   const float bias = sum_diff / (size * max_scale);
512   EXPECT_LE(std::abs(bias), tolerated_bias);
513 }
514 
TEST(FloatActivationsOpTest, HardSwish) {
  // A single engine is shared across sizes, so each size sees fresh inputs.
  std::minstd_rand random_engine;
  const int sizes[] = {1, 2, 3, 4, 10, 20, 30, 40, 100};
  for (const int size : sizes) {
    TestFloatHardSwish(size, &random_engine);
  }
}
521 
TEST(QuantizedActivationsOpTest, HardSwish) {
  std::minstd_rand random_engine;
  // Every (input range, output range) combination is tested for both uint8
  // and int8, at several tensor sizes.
  const std::vector<std::pair<float, float>> ranges{
      {0.f, 1.f}, {-2.f, 1.f}, {-5.f, 10.f}, {-40.f, 60.f}};
  const int sizes[] = {1, 3, 10, 100};
  for (const auto& in_range : ranges) {
    for (const auto& out_range : ranges) {
      for (const int size : sizes) {
        TestQuantizedHardSwish<uint8_t>(TensorType_UINT8, size, in_range.first,
                                        in_range.second, out_range.first,
                                        out_range.second, &random_engine);
        TestQuantizedHardSwish<int8_t>(TensorType_INT8, size, in_range.first,
                                       in_range.second, out_range.first,
                                       out_range.second, &random_engine);
      }
    }
  }
}
543 
544 // See the comment in the reference implementation of quantized HardSwish:
545 // A numerical issue significantly affecting ImageNet classification accuracy
546 // with MobileNet v3 is only observable at the scale of HardSwish unit tests
547 // if we monitor specifically bias. This testcase is extracted from one of the
548 // HardSwish nodes in that MobileNet v3 that exhibited this issue.
TEST(QuantizedActivationsOpTest,HardSwishBias)549 TEST(QuantizedActivationsOpTest, HardSwishBias) {
550   TestQuantizedHardSwishBias<uint8_t>(TensorType_UINT8, -11.654928f, 25.036512f,
551                                       -0.3905796f, 24.50887f, 0.035);
552 }
553 
TEST_P(TanhOpTest, Tanh) {
  const std::vector<float> input = {
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  };
  const std::vector<float> expected = {
      0, -0.9999877, 0.9640275, 0.999329,    //
      0.99505475, -0.9640275, 1, 0.7615941,  //
  };

  FloatActivationsOpModel m(GetRegistration(), BuiltinOperator_TANH,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  m.SetInput(input);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
567 
TEST(QuantizedActivationsOpTest, Relu6Uint8) {
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  QuantizedActivationsOpModel m(
      BuiltinOperator_RELU6,
      /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
      /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax});

  const std::vector<float> input = {
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  };
  m.SetInput<uint8_t>(input);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {
      0, 0, 2, 4,  //
      3, 0, 6, 1,  //
  };
  EXPECT_THAT(
      m.GetDequantizedOutput<uint8_t>(),
      ElementsAreArray(ArrayFloatNear(expected_values, kQuantizedTolerance)));

  // The same results as raw codes: 0.0 is code 128 and one unit is 16 codes.
  const std::vector<uint8_t> expected_codes = {128, 128, 160, 192,
                                               176, 128, 224, 144};
  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(expected_codes));
}
590 
591 const auto kLeakyReluKernelMap = new std::map<string, TfLiteRegistration*>({
592     {"Reference", ops::builtin::Register_LEAKY_RELU_REF()},
593     {"GenericOptimized", ops::builtin::Register_LEAKY_RELU()},
594 });
595 
596 class LeakyReluOpTest : public SingleOpTest {
597  protected:
GetKernelMap()598   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
599     return *kLeakyReluKernelMap;
600   }
601 };
602 
TEST_P(LeakyReluOpTest, LeakyReluUint8) {
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  // alpha = 0.5: negative inputs are halved, non-negative ones pass through.
  QuantizedActivationsOpModel m(
      GetRegistration(),
      /*input=*/{TensorType_UINT8, {2, 3}, 8 * kMin, 8 * kMax}, 0.5);

  const std::vector<float> input = {
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -1.0f, -2.0f,  // Row 2
  };
  const std::vector<float> expected = {
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -0.5f, -1.0f,  // Row 2
  };

  m.SetInput<uint8_t>(input);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      m.GetDequantizedOutput<uint8_t>(),
      ElementsAreArray(ArrayFloatNear(expected, kQuantizedTolerance * 8)));
}
623 
624 template <TensorType tensor_type, typename integer_dtype>
QuantizedActivationsOpTestLeakyRelu(TfLiteRegistration * registration)625 void QuantizedActivationsOpTestLeakyRelu(TfLiteRegistration* registration) {
626   const float kMin = -1;
627   const float kMax =
628       std::numeric_limits<integer_dtype>::max() /
629       static_cast<float>(std::numeric_limits<integer_dtype>::max() + 1);
630 
631   QuantizedActivationsOpModel m(
632       registration,
633       /*input=*/{tensor_type, {5, 5}, 5 * kMin, 5 * kMax}, 0.1);
634 
635   m.SetInput<integer_dtype>({
636       -5.0f, -4.6f, -4.2f, -3.8f, -3.4f,  // Row 1
637       -3.0f, -2.6f, -2.2f, -1.8f, -1.4f,  // Row 2
638       -1.0f, -0.6f, -0.2f, 0.2f,  0.6f,   // Row 3
639       1.0f,  1.4f,  1.8f,  2.2f,  2.6f,   // Row 4
640       3.0f,  3.4f,  3.8f,  4.2f,  4.6f,   // Row 5
641   });
642   ASSERT_EQ(m.Invoke(), kTfLiteOk);
643 
644   float kTestQuantizedTolerance = tensor_type == TensorType_INT16
645                                       ? kQuantizedToleranceInt16
646                                       : kQuantizedTolerance * 5;
647 
648   EXPECT_THAT(m.GetDequantizedOutput<integer_dtype>(),
649               ElementsAreArray(ArrayFloatNear(
650                   {
651                       -0.50f, -0.46f, -0.42f, -0.38f, -0.34f,  // Row 1
652                       -0.30f, -0.26f, -0.22f, -0.18f, -0.14f,  // Row 2
653                       -0.10f, -0.06f, -0.02f, 0.20f,  0.60f,   // Row 3
654                       1.00f,  1.40f,  1.80f,  2.20f,  2.60f,   // Row 4
655                       3.00f,  3.40f,  3.80f,  4.20f,  4.60f,   // Row 5
656                   },
657                   kTestQuantizedTolerance)));
658 }
659 
// Runs the shared quantized LeakyRelu check with int8 tensors.
TEST_P(LeakyReluOpTest, LeakyReluInt8) {
  TfLiteRegistration* const registration = GetRegistration();
  QuantizedActivationsOpTestLeakyRelu<TensorType_INT8, int8_t>(registration);
}

// Runs the shared quantized LeakyRelu check with int16 tensors.
TEST_P(LeakyReluOpTest, LeakyReluInt16) {
  TfLiteRegistration* const registration = GetRegistration();
  QuantizedActivationsOpTestLeakyRelu<TensorType_INT16, int16_t>(registration);
}
669 
TEST(QuantizedActivationsOpTest, Relu0To1Int8) {
  const float kMin = 0;
  const float kMax = 1;
  QuantizedActivationsOpModel m(
      BuiltinOperator_RELU_0_TO_1,
      /*input=*/{TensorType_INT8, {1, 2, 4, 1}, 2 * kMin, kMax},
      /*output=*/{TensorType_INT8, {1, 2, 4, 1}, 2 * kMin, kMax});

  const std::vector<float> input = {
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -2.0, 1.1, -0.1,  //
  };
  // Values outside [0, 1] are expected to come out clamped.
  const std::vector<float> expected = {
      0.0, 0.0, 0.2, 0.0,  //
      0.3, 0.0, 1.0, 0.0,  //
  };

  m.SetInput<int8_t>(input);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(expected, kQuantizedTolerance)));
}
691 
TEST(QuantizedActivationsOpTest, Relu1Int8) {
  const float kMin = -1;
  const float kMax = 1;
  QuantizedActivationsOpModel m(
      BuiltinOperator_RELU_N1_TO_1,
      /*input=*/{TensorType_INT8, {1, 2, 4, 1}, 2 * kMin, kMax},
      /*output=*/{TensorType_INT8, {1, 2, 4, 1}, 2 * kMin, kMax});

  const std::vector<float> input = {
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -2.0, 1.1, -0.1,  //
  };
  // Values outside [-1, 1] are expected to come out clamped.
  const std::vector<float> expected = {
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -1.0, 1.0, -0.1,  //
  };

  m.SetInput<int8_t>(input);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(expected, kQuantizedTolerance)));
}
714 
// RELU_0_TO_1 on uint8 tensors must clamp every value into [0, 1].
//
// Both tensors are quantized over [kRangeMin, kRangeMax] = [-1, 2] rather than
// over [0, 1]. With a [0, 1] range every out-of-range input (-0.6, -2.0, 1.1,
// ...) would already be saturated by quantization before the op runs, so the
// op's own clamp bounds would never be exercised. The [-1, 2] range keeps
// values below 0 and above 1 representable. Its scale of 3/255 puts 0.2 and
// 1.0 on whole quantization steps (17 and 85 steps from zero).
TEST(QuantizedActivationsOpTest, Relu0To1UInt8) {
  const float kRangeMin = -1;
  const float kRangeMax = 2;
  QuantizedActivationsOpModel m(
      BuiltinOperator_RELU_0_TO_1,
      /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, kRangeMin, kRangeMax},
      /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, kRangeMin, kRangeMax});

  m.SetInput<uint8_t>({
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -2.0, 1.1, -0.1,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {
                      0.0, 0.0, 0.2, 0.0,  //
                      0.3, 0.0, 1.0, 0.0,  //
                  },
                  kQuantizedTolerance)));
}
737 
// RELU_N1_TO_1 on uint8 tensors clamps values into [kLow, kHigh] = [-1, 1].
// Input and output share one spec quantized over [2 * kLow, kHigh] = [-2, 1],
// so inputs below -1 remain representable.
TEST(QuantizedActivationsOpTest, Relu1UInt8) {
  const float kLow = -1;
  const float kHigh = 1;
  const TensorData spec{TensorType_UINT8, {1, 2, 4, 1}, 2 * kLow, kHigh};
  QuantizedActivationsOpModel m(BuiltinOperator_RELU_N1_TO_1,
                                /*input=*/spec, /*output=*/spec);

  m.SetInput<uint8_t>({
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -2.0, 1.1, -0.1,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -1.0, 1.0, -0.1,  //
  };
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected, kQuantizedTolerance)));
}
760 
// RELU6 on int8 tensors: negatives become 0 and values above 6 are capped at
// 6. Both tensors use [8 * kMin, 8 * kMax] = [-8, 7.9375], whose scale is
// 1/16, so each raw output is the expected real value times 16.
TEST(QuantizedActivationsOpTest, Relu6Int8) {
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  const TensorData spec{TensorType_INT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax};
  QuantizedActivationsOpModel m(BuiltinOperator_RELU6,
                                /*input=*/spec, /*output=*/spec);
  m.SetInput<int8_t>({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected_real = {
      0, 0, 2, 4,  //
      3, 0, 6, 1,  //
  };
  const std::vector<int> expected_raw = {0, 0, 32, 64, 48, 0, 96, 16};
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_real, kQuantizedTolerance)));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(expected_raw));
}
782 
// RELU6 on int16 tensors: negatives become 0 and values above 6 are capped at
// 6. Both tensors use [8 * kMin, 8 * kMax]. The raw expectations encode a
// scale of 1/4096 (e.g. 2.0 -> 8192).
TEST(QuantizedActivationsOpTest, Relu6Int16) {
  const float kMin = -1;
  const float kMax = 32767.f / 32768.f;
  const TensorData spec{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax};
  QuantizedActivationsOpModel m(BuiltinOperator_RELU6,
                                /*input=*/spec, /*output=*/spec);
  m.SetInput<int16_t>({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected_real = {
      0, 0, 2, 4,  //
      3, 0, 6, 1,  //
  };
  const std::vector<int> expected_raw = {0,     0, 8192,  16384,
                                         12288, 0, 24576, 4096};
  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_real, kQuantizedToleranceInt16)));
  EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAreArray(expected_raw));
}
805 
// RELU on uint8 tensors zeroes negatives and passes positives through. Both
// tensors use [8 * kMin, 8 * kMax] = [-8, 7.9375]. The raw expectations encode
// scale 1/16 with real 0 at raw 128.
TEST(QuantizedActivationsOpTest, ReluUint8) {
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  const TensorData spec{TensorType_UINT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax};
  QuantizedActivationsOpModel m(BuiltinOperator_RELU,
                                /*input=*/spec, /*output=*/spec);
  m.SetInput<uint8_t>({
      0, -6, 2, 4,  //
      3, -2, 7, 1,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected_real = {
      0, 0, 2, 4,  //
      3, 0, 7, 1,  //
  };
  const std::vector<int> expected_raw = {128, 128, 160, 192,
                                         176, 128, 240, 144};
  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_real, kQuantizedTolerance)));
  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(expected_raw));
}
828 
// RELU on int8 tensors zeroes negatives and passes positives through. Both
// tensors use [8 * kMin, 8 * kMax] = [-8, 7.9375]. The raw expectations encode
// scale 1/16 with real 0 at raw 0.
TEST(QuantizedActivationsOpTest, ReluInt8) {
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  const TensorData spec{TensorType_INT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax};
  QuantizedActivationsOpModel m(BuiltinOperator_RELU,
                                /*input=*/spec, /*output=*/spec);
  m.SetInput<int8_t>({
      0, -6, 2, 4,  //
      3, -2, 7, 1,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected_real = {
      0, 0, 2, 4,  //
      3, 0, 7, 1,  //
  };
  const std::vector<int> expected_raw = {0, 0, 32, 64, 48, 0, 112, 16};
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_real, kQuantizedTolerance)));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray(expected_raw));
}
850 
// RELU on int16 tensors zeroes negatives and passes positives through. Both
// tensors use [8 * kMin, 8 * kMax]. The raw expectations encode a scale of
// 1/4096 (e.g. 7.0 -> 28672).
TEST(QuantizedActivationsOpTest, ReluInt16) {
  const float kMin = -1;
  const float kMax = 32767.f / 32768.f;
  const TensorData spec{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax};
  QuantizedActivationsOpModel m(BuiltinOperator_RELU,
                                /*input=*/spec, /*output=*/spec);
  m.SetInput<int16_t>({
      0, -6, 2, 4,  //
      3, -2, 7, 1,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected_real = {
      0, 0, 2, 4,  //
      3, 0, 7, 1,  //
  };
  const std::vector<int> expected_raw = {0,     0, 8192,  16384,
                                         12288, 0, 28672, 4096};
  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_real, kQuantizedToleranceInt16)));
  EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAreArray(expected_raw));
}
873 
TEST_P(TanhOpTest,TanhUint8)874 TEST_P(TanhOpTest, TanhUint8) {
875   const float kMin = -1;
876   const float kMax = 127.f / 128.f;
877   const float kTanhTolerance = 0.014f;
878   QuantizedActivationsOpModel m(
879       GetRegistration(), BuiltinOperator_TANH,
880       /*input=*/{TensorType_UINT8, {89}, 8 * kMin, 8 * kMax},
881       /*output=*/{TensorType_UINT8, {89}, kMin, kMax});
882   // 64+16+8+1 elements, from -8 to 8.
883   m.SetInput<uint8_t>(
884       {-8.0000000000, -7.8181818182, -7.6363636364, -7.4545454545,
885        -7.2727272727, -7.0909090909, -6.9090909091, -6.7272727273,
886        -6.5454545455, -6.3636363636, -6.1818181818, -6.0000000000,
887        -5.8181818182, -5.6363636364, -5.4545454545, -5.2727272727,
888        -5.0909090909, -4.9090909091, -4.7272727273, -4.5454545455,
889        -4.3636363636, -4.1818181818, -4.0000000000, -3.8181818182,
890        -3.6363636364, -3.4545454545, -3.2727272727, -3.0909090909,
891        -2.9090909091, -2.7272727273, -2.5454545455, -2.3636363636,
892        -2.1818181818, -2.0000000000, -1.8181818182, -1.6363636364,
893        -1.4545454545, -1.2727272727, -1.0909090909, -0.9090909091,
894        -0.7272727273, -0.5454545455, -0.3636363636, -0.1818181818,
895        0.0000000000,  0.1818181818,  0.3636363636,  0.5454545455,
896        0.7272727273,  0.9090909091,  1.0909090909,  1.2727272727,
897        1.4545454545,  1.6363636364,  1.8181818182,  2.0000000000,
898        2.1818181818,  2.3636363636,  2.5454545455,  2.7272727273,
899        2.9090909091,  3.0909090909,  3.2727272727,  3.4545454545,
900        3.6363636364,  3.8181818182,  4.0000000000,  4.1818181818,
901        4.3636363636,  4.5454545455,  4.7272727273,  4.9090909091,
902        5.0909090909,  5.2727272727,  5.4545454545,  5.6363636364,
903        5.8181818182,  6.0000000000,  6.1818181818,  6.3636363636,
904        6.5454545455,  6.7272727273,  6.9090909091,  7.0909090909,
905        7.2727272727,  7.4545454545,  7.6363636364,  7.8181818182,
906        8.0000000000});
907   ASSERT_EQ(m.Invoke(), kTfLiteOk);
908   EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
909               ElementsAreArray(ArrayFloatNear(
910                   {-0.9999997749, -0.9999996762, -0.9999995342, -0.9999993300,
911                    -0.9999990361, -0.9999986134, -0.9999980053, -0.9999971306,
912                    -0.9999958722, -0.9999940619, -0.9999914578, -0.9999877117,
913                    -0.9999823226, -0.9999745703, -0.9999634183, -0.9999473758,
914                    -0.9999242982, -0.9998911009, -0.9998433469, -0.9997746542,
915                    -0.9996758446, -0.9995337191, -0.9993292997, -0.9990353053,
916                    -0.9986125310, -0.9980046622, -0.9971308601, -0.9958751909,
917                    -0.9940716137, -0.9914827859, -0.9877703933, -0.9824541388,
918                    -0.9748561217, -0.9640275801, -0.9486568273, -0.9269625051,
919                    -0.8965880154, -0.8545351057, -0.7972097087, -0.7206956332,
920                    -0.6213939966, -0.4971057414, -0.3484130125, -0.1798408185,
921                    0.0000000000,  0.1798408185,  0.3484130125,  0.4971057414,
922                    0.6213939966,  0.7206956332,  0.7972097087,  0.8545351057,
923                    0.8965880154,  0.9269625051,  0.9486568273,  0.9640275801,
924                    0.9748561217,  0.9824541388,  0.9877703933,  0.9914827859,
925                    0.9940716137,  0.9958751909,  0.9971308601,  0.9980046622,
926                    0.9986125310,  0.9990353053,  0.9993292997,  0.9995337191,
927                    0.9996758446,  0.9997746542,  0.9998433469,  0.9998911009,
928                    0.9999242982,  0.9999473758,  0.9999634183,  0.9999745703,
929                    0.9999823226,  0.9999877117,  0.9999914578,  0.9999940619,
930                    0.9999958722,  0.9999971306,  0.9999980053,  0.9999986134,
931                    0.9999990361,  0.9999993300,  0.9999995342,  0.9999996762,
932                    0.9999997749},
933                   kTanhTolerance)));
934 }
935 
TEST_P(TanhOpTest,TanhInt8)936 TEST_P(TanhOpTest, TanhInt8) {
937   const float kMin = -1;
938   const float kMax = 127.f / 128.f;
939   const float kTanhTolerance = 0.014f;
940   QuantizedActivationsOpModel m(
941       GetRegistration(), BuiltinOperator_TANH,
942       /*input=*/{TensorType_INT8, {89}, 8 * kMin, 8 * kMax},
943       /*output=*/{TensorType_INT8, {89}, kMin, kMax});
944   // 64+16+8+1 elements, from -8 to 8.
945   m.SetInput<int8_t>(
946       {-8.0000000000, -7.8181818182, -7.6363636364, -7.4545454545,
947        -7.2727272727, -7.0909090909, -6.9090909091, -6.7272727273,
948        -6.5454545455, -6.3636363636, -6.1818181818, -6.0000000000,
949        -5.8181818182, -5.6363636364, -5.4545454545, -5.2727272727,
950        -5.0909090909, -4.9090909091, -4.7272727273, -4.5454545455,
951        -4.3636363636, -4.1818181818, -4.0000000000, -3.8181818182,
952        -3.6363636364, -3.4545454545, -3.2727272727, -3.0909090909,
953        -2.9090909091, -2.7272727273, -2.5454545455, -2.3636363636,
954        -2.1818181818, -2.0000000000, -1.8181818182, -1.6363636364,
955        -1.4545454545, -1.2727272727, -1.0909090909, -0.9090909091,
956        -0.7272727273, -0.5454545455, -0.3636363636, -0.1818181818,
957        0.0000000000,  0.1818181818,  0.3636363636,  0.5454545455,
958        0.7272727273,  0.9090909091,  1.0909090909,  1.2727272727,
959        1.4545454545,  1.6363636364,  1.8181818182,  2.0000000000,
960        2.1818181818,  2.3636363636,  2.5454545455,  2.7272727273,
961        2.9090909091,  3.0909090909,  3.2727272727,  3.4545454545,
962        3.6363636364,  3.8181818182,  4.0000000000,  4.1818181818,
963        4.3636363636,  4.5454545455,  4.7272727273,  4.9090909091,
964        5.0909090909,  5.2727272727,  5.4545454545,  5.6363636364,
965        5.8181818182,  6.0000000000,  6.1818181818,  6.3636363636,
966        6.5454545455,  6.7272727273,  6.9090909091,  7.0909090909,
967        7.2727272727,  7.4545454545,  7.6363636364,  7.8181818182,
968        8.0000000000});
969   ASSERT_EQ(m.Invoke(), kTfLiteOk);
970   EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
971               ElementsAreArray(ArrayFloatNear(
972                   {-0.9999997749, -0.9999996762, -0.9999995342, -0.9999993300,
973                    -0.9999990361, -0.9999986134, -0.9999980053, -0.9999971306,
974                    -0.9999958722, -0.9999940619, -0.9999914578, -0.9999877117,
975                    -0.9999823226, -0.9999745703, -0.9999634183, -0.9999473758,
976                    -0.9999242982, -0.9998911009, -0.9998433469, -0.9997746542,
977                    -0.9996758446, -0.9995337191, -0.9993292997, -0.9990353053,
978                    -0.9986125310, -0.9980046622, -0.9971308601, -0.9958751909,
979                    -0.9940716137, -0.9914827859, -0.9877703933, -0.9824541388,
980                    -0.9748561217, -0.9640275801, -0.9486568273, -0.9269625051,
981                    -0.8965880154, -0.8545351057, -0.7972097087, -0.7206956332,
982                    -0.6213939966, -0.4971057414, -0.3484130125, -0.1798408185,
983                    0.0000000000,  0.1798408185,  0.3484130125,  0.4971057414,
984                    0.6213939966,  0.7206956332,  0.7972097087,  0.8545351057,
985                    0.8965880154,  0.9269625051,  0.9486568273,  0.9640275801,
986                    0.9748561217,  0.9824541388,  0.9877703933,  0.9914827859,
987                    0.9940716137,  0.9958751909,  0.9971308601,  0.9980046622,
988                    0.9986125310,  0.9990353053,  0.9993292997,  0.9995337191,
989                    0.9996758446,  0.9997746542,  0.9998433469,  0.9998911009,
990                    0.9999242982,  0.9999473758,  0.9999634183,  0.9999745703,
991                    0.9999823226,  0.9999877117,  0.9999914578,  0.9999940619,
992                    0.9999958722,  0.9999971306,  0.9999980053,  0.9999986134,
993                    0.9999990361,  0.9999993300,  0.9999995342,  0.9999996762,
994                    0.9999997749},
995                   kTanhTolerance)));
996 }
997 
TEST_P(TanhOpTest,TanhInt16)998 TEST_P(TanhOpTest, TanhInt16) {
999   const float kMin = -1;
1000   const float kMax = 32767.f / 32768.f;
1001   QuantizedActivationsOpModel m(
1002       GetRegistration(), BuiltinOperator_TANH,
1003       /*input=*/{TensorType_INT16, {177}, 16 * kMin, 16 * kMax},
1004       /*output=*/{TensorType_INT16, {177}, kMin, kMax});
1005   m.SetInput<int16_t>(
1006       {-20.0000000000, -19.7727272727, -19.5454545455, -19.3181818182,
1007        -19.0909090909, -18.8636363636, -18.6363636364, -18.4090909091,
1008        -18.1818181818, -17.9545454545, -17.7272727273, -17.5000000000,
1009        -17.2727272727, -17.0454545455, -16.8181818182, -16.5909090909,
1010        -16.3636363636, -16.1363636364, -15.9090909091, -15.6818181818,
1011        -15.4545454545, -15.2272727273, -15.0000000000, -14.7727272727,
1012        -14.5454545455, -14.3181818182, -14.0909090909, -13.8636363636,
1013        -13.6363636364, -13.4090909091, -13.1818181818, -12.9545454545,
1014        -12.7272727273, -12.5000000000, -12.2727272727, -12.0454545455,
1015        -11.8181818182, -11.5909090909, -11.3636363636, -11.1363636364,
1016        -10.9090909091, -10.6818181818, -10.4545454545, -10.2272727273,
1017        -10.0000000000, -9.7727272727,  -9.5454545455,  -9.3181818182,
1018        -9.0909090909,  -8.8636363636,  -8.6363636364,  -8.4090909091,
1019        -8.1818181818,  -7.9545454545,  -7.7272727273,  -7.5000000000,
1020        -7.2727272727,  -7.0454545455,  -6.8181818182,  -6.5909090909,
1021        -6.3636363636,  -6.1363636364,  -5.9090909091,  -5.6818181818,
1022        -5.4545454545,  -5.2272727273,  -5.0000000000,  -4.7727272727,
1023        -4.5454545455,  -4.3181818182,  -4.0909090909,  -3.8636363636,
1024        -3.6363636364,  -3.4090909091,  -3.1818181818,  -2.9545454545,
1025        -2.7272727273,  -2.5000000000,  -2.2727272727,  -2.0454545455,
1026        -1.8181818182,  -1.5909090909,  -1.3636363636,  -1.1363636364,
1027        -0.9090909091,  -0.6818181818,  -0.4545454545,  -0.2272727273,
1028        0.0000000000,   0.2272727273,   0.4545454545,   0.6818181818,
1029        0.9090909091,   1.1363636364,   1.3636363636,   1.5909090909,
1030        1.8181818182,   2.0454545455,   2.2727272727,   2.5000000000,
1031        2.7272727273,   2.9545454545,   3.1818181818,   3.4090909091,
1032        3.6363636364,   3.8636363636,   4.0909090909,   4.3181818182,
1033        4.5454545455,   4.7727272727,   5.0000000000,   5.2272727273,
1034        5.4545454545,   5.6818181818,   5.9090909091,   6.1363636364,
1035        6.3636363636,   6.5909090909,   6.8181818182,   7.0454545455,
1036        7.2727272727,   7.5000000000,   7.7272727273,   7.9545454545,
1037        8.1818181818,   8.4090909091,   8.6363636364,   8.8636363636,
1038        9.0909090909,   9.3181818182,   9.5454545455,   9.7727272727,
1039        10.0000000000,  10.2272727273,  10.4545454545,  10.6818181818,
1040        10.9090909091,  11.1363636364,  11.3636363636,  11.5909090909,
1041        11.8181818182,  12.0454545455,  12.2727272727,  12.5000000000,
1042        12.7272727273,  12.9545454545,  13.1818181818,  13.4090909091,
1043        13.6363636364,  13.8636363636,  14.0909090909,  14.3181818182,
1044        14.5454545455,  14.7727272727,  15.0000000000,  15.2272727273,
1045        15.4545454545,  15.6818181818,  15.9090909091,  16.1363636364,
1046        16.3636363636,  16.5909090909,  16.8181818182,  17.0454545455,
1047        17.2727272727,  17.5000000000,  17.7272727273,  17.9545454545,
1048        18.1818181818,  18.4090909091,  18.6363636364,  18.8636363636,
1049        19.0909090909,  19.3181818182,  19.5454545455,  19.7727272727,
1050        20.0000000000});
1051   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1052   EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
1053               ElementsAreArray(ArrayFloatNear(
1054                   {-1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000,
1055                    -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000,
1056                    -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000,
1057                    -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000,
1058                    -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000,
1059                    -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000,
1060                    -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000,
1061                    -1.0000000000, -1.0000000000, -1.0000000000, -1.0000000000,
1062                    -1.0000000000, -1.0000000000, -1.0000000000, -0.9999999999,
1063                    -0.9999999999, -0.9999999998, -0.9999999997, -0.9999999996,
1064                    -0.9999999993, -0.9999999989, -0.9999999983, -0.9999999974,
1065                    -0.9999999959, -0.9999999935, -0.9999999898, -0.9999999839,
1066                    -0.9999999746, -0.9999999600, -0.9999999370, -0.9999999007,
1067                    -0.9999998435, -0.9999997535, -0.9999996117, -0.9999993882,
1068                    -0.9999990361, -0.9999984815, -0.9999976076, -0.9999962309,
1069                    -0.9999940619, -0.9999906449, -0.9999852614, -0.9999767801,
1070                    -0.9999634183, -0.9999423677, -0.9999092043, -0.9998569589,
1071                    -0.9997746542, -0.9996450004, -0.9994407705, -0.9991190997,
1072                    -0.9986125310, -0.9978149744, -0.9965597488, -0.9945853915,
1073                    -0.9914827859, -0.9866142982, -0.9789923110, -0.9671021386,
1074                    -0.9486568273, -0.9202886021, -0.8772337852, -0.8131859906,
1075                    -0.7206956332, -0.5927001330, -0.4256281972, -0.2234388228,
1076                    0.0000000000,  0.2234388228,  0.4256281972,  0.5927001330,
1077                    0.7206956332,  0.8131859906,  0.8772337852,  0.9202886021,
1078                    0.9486568273,  0.9671021386,  0.9789923110,  0.9866142982,
1079                    0.9914827859,  0.9945853915,  0.9965597488,  0.9978149744,
1080                    0.9986125310,  0.9991190997,  0.9994407705,  0.9996450004,
1081                    0.9997746542,  0.9998569589,  0.9999092043,  0.9999423677,
1082                    0.9999634183,  0.9999767801,  0.9999852614,  0.9999906449,
1083                    0.9999940619,  0.9999962309,  0.9999976076,  0.9999984815,
1084                    0.9999990361,  0.9999993882,  0.9999996117,  0.9999997535,
1085                    0.9999998435,  0.9999999007,  0.9999999370,  0.9999999600,
1086                    0.9999999746,  0.9999999839,  0.9999999898,  0.9999999935,
1087                    0.9999999959,  0.9999999974,  0.9999999983,  0.9999999989,
1088                    0.9999999993,  0.9999999996,  0.9999999997,  0.9999999998,
1089                    0.9999999999,  0.9999999999,  1.0000000000,  1.0000000000,
1090                    1.0000000000,  1.0000000000,  1.0000000000,  1.0000000000,
1091                    1.0000000000,  1.0000000000,  1.0000000000,  1.0000000000,
1092                    1.0000000000,  1.0000000000,  1.0000000000,  1.0000000000,
1093                    1.0000000000,  1.0000000000,  1.0000000000,  1.0000000000,
1094                    1.0000000000,  1.0000000000,  1.0000000000,  1.0000000000,
1095                    1.0000000000,  1.0000000000,  1.0000000000,  1.0000000000,
1096                    1.0000000000,  1.0000000000,  1.0000000000,  1.0000000000,
1097                    1.0000000000,  1.0000000000,  1.0000000000,  1.0000000000,
1098                    1.0000000000},
1099                   kQuantizedToleranceInt16)));
1100 }
1101 
// TANH on int16 tensors for a small, irregular set of inputs. The input is
// quantized over [11 * kMin, 11 * kMax] and the output over [kMin, kMax].
TEST_P(TanhOpTest, TanhInt16General) {
  const float kMin = -1;
  const float kMax = 32767.f / 32768.f;
  const TensorData input_spec{TensorType_INT16, {10}, 11 * kMin, 11 * kMax};
  const TensorData output_spec{TensorType_INT16, {10}, kMin, kMax};
  QuantizedActivationsOpModel m(GetRegistration(), BuiltinOperator_TANH,
                                /*input=*/input_spec, /*output=*/output_spec);
  m.SetInput<int16_t>({-10, -4, 1, 0.5, 0.25,  //
                       0, -0.1, 6, 7.0909090909, 8});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected_tanh = {
      -1.0, -0.999329, 0.761594, 0.462117, 0.244919,  //
      0.0,  -0.099668, 0.999988, 0.999999, 1.0};
  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_tanh, kQuantizedToleranceInt16)));
}
1118 
// Float LOGISTIC: each output should equal 1 / (1 + exp(-x)) to within
// ArrayFloatNear's default tolerance.
TEST_P(LogisticOpTest, Sigmoid) {
  const TensorData input_spec{TensorType_FLOAT32, {1, 2, 4, 1}};
  FloatActivationsOpModel m(GetRegistration(), BuiltinOperator_LOGISTIC,
                            /*input=*/input_spec);
  m.SetInput({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  const std::vector<float> expected_sigmoid = {
      0.5,      0.002473, 0.880797, 0.982014,  //
      0.952574, 0.119203, 0.999955, 0.731059,  //
  };
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_sigmoid)));
}
1132 
TEST_P(LogisticOpTest,SigmoidUint8)1133 TEST_P(LogisticOpTest, SigmoidUint8) {
1134   QuantizedActivationsOpModel m(GetRegistration(), BuiltinOperator_LOGISTIC,
1135                                 /*input=*/{TensorType_UINT8, {89}, -10, 10});
1136   // 64+16+8+1 elements, from -10 to 10
1137   m.SetInput<uint8_t>(
1138       {-10.0000000000, -9.7727272727, -9.5454545455, -9.3181818182,
1139        -9.0909090909,  -8.8636363636, -8.6363636364, -8.4090909091,
1140        -8.1818181818,  -7.9545454545, -7.7272727273, -7.5000000000,
1141        -7.2727272727,  -7.0454545455, -6.8181818182, -6.5909090909,
1142        -6.3636363636,  -6.1363636364, -5.9090909091, -5.6818181818,
1143        -5.4545454545,  -5.2272727273, -5.0000000000, -4.7727272727,
1144        -4.5454545455,  -4.3181818182, -4.0909090909, -3.8636363636,
1145        -3.6363636364,  -3.4090909091, -3.1818181818, -2.9545454545,
1146        -2.7272727273,  -2.5000000000, -2.2727272727, -2.0454545455,
1147        -1.8181818182,  -1.5909090909, -1.3636363636, -1.1363636364,
1148        -0.9090909091,  -0.6818181818, -0.4545454545, -0.2272727273,
1149        0.0000000000,   0.2272727273,  0.4545454545,  0.6818181818,
1150        0.9090909091,   1.1363636364,  1.3636363636,  1.5909090909,
1151        1.8181818182,   2.0454545455,  2.2727272727,  2.5000000000,
1152        2.7272727273,   2.9545454545,  3.1818181818,  3.4090909091,
1153        3.6363636364,   3.8636363636,  4.0909090909,  4.3181818182,
1154        4.5454545455,   4.7727272727,  5.0000000000,  5.2272727273,
1155        5.4545454545,   5.6818181818,  5.9090909091,  6.1363636364,
1156        6.3636363636,   6.5909090909,  6.8181818182,  7.0454545455,
1157        7.2727272727,   7.5000000000,  7.7272727273,  7.9545454545,
1158        8.1818181818,   8.4090909091,  8.6363636364,  8.8636363636,
1159        9.0909090909,   9.3181818182,  9.5454545455,  9.7727272727,
1160        10.0000000000});
1161   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1162   EXPECT_THAT(
1163       m.GetDequantizedOutput<uint8_t>(),
1164       ElementsAreArray(ArrayFloatNear(
1165           {0.0000453979, 0.0000569815, 0.0000715205, 0.0000897689, 0.0001126729,
1166            0.0001414198, 0.0001774998, 0.0002227827, 0.0002796147, 0.0003509396,
1167            0.0004404502, 0.0005527786, 0.0006937345, 0.0008706021, 0.0010925128,
1168            0.0013709094, 0.0017201256, 0.0021581065, 0.0027073042, 0.0033957870,
1169            0.0042586071, 0.0053394826, 0.0066928509, 0.0083863576, 0.0105038445,
1170            0.0131488902, 0.0164489307, 0.0205599431, 0.0256715863, 0.0320125562,
1171            0.0398556989, 0.0495221198, 0.0613831074, 0.0758581800, 0.0934070047,
1172            0.1145124805, 0.1396521834, 0.1692560327, 0.2036499335, 0.2429886272,
1173            0.2871859014, 0.3358556241, 0.3882805886, 0.4434251301, 0.5000000000,
1174            0.5565748699, 0.6117194114, 0.6641443759, 0.7128140986, 0.7570113728,
1175            0.7963500665, 0.8307439673, 0.8603478166, 0.8854875195, 0.9065929953,
1176            0.9241418200, 0.9386168926, 0.9504778802, 0.9601443011, 0.9679874438,
1177            0.9743284137, 0.9794400569, 0.9835510693, 0.9868511098, 0.9894961555,
1178            0.9916136424, 0.9933071491, 0.9946605174, 0.9957413929, 0.9966042130,
1179            0.9972926958, 0.9978418935, 0.9982798744, 0.9986290906, 0.9989074872,
1180            0.9991293979, 0.9993062655, 0.9994472214, 0.9995595498, 0.9996490604,
1181            0.9997203853, 0.9997772173, 0.9998225002, 0.9998585802, 0.9998873271,
1182            0.9999102311, 0.9999284795, 0.9999430185, 0.9999546021},
1183           kQuantizedTolerance)));
1184 }
1185 
TEST_P(LogisticOpTest,SigmoidInt8)1186 TEST_P(LogisticOpTest, SigmoidInt8) {
1187   QuantizedActivationsOpModel m(GetRegistration(), BuiltinOperator_LOGISTIC,
1188                                 /*input=*/{TensorType_INT8, {89}, -10, 10});
1189   // 64+16+8+1 elements, from -10 to 10
1190   m.SetInput<int8_t>(
1191       {-10.0000000000, -9.7727272727, -9.5454545455, -9.3181818182,
1192        -9.0909090909,  -8.8636363636, -8.6363636364, -8.4090909091,
1193        -8.1818181818,  -7.9545454545, -7.7272727273, -7.5000000000,
1194        -7.2727272727,  -7.0454545455, -6.8181818182, -6.5909090909,
1195        -6.3636363636,  -6.1363636364, -5.9090909091, -5.6818181818,
1196        -5.4545454545,  -5.2272727273, -5.0000000000, -4.7727272727,
1197        -4.5454545455,  -4.3181818182, -4.0909090909, -3.8636363636,
1198        -3.6363636364,  -3.4090909091, -3.1818181818, -2.9545454545,
1199        -2.7272727273,  -2.5000000000, -2.2727272727, -2.0454545455,
1200        -1.8181818182,  -1.5909090909, -1.3636363636, -1.1363636364,
1201        -0.9090909091,  -0.6818181818, -0.4545454545, -0.2272727273,
1202        0.0000000000,   0.2272727273,  0.4545454545,  0.6818181818,
1203        0.9090909091,   1.1363636364,  1.3636363636,  1.5909090909,
1204        1.8181818182,   2.0454545455,  2.2727272727,  2.5000000000,
1205        2.7272727273,   2.9545454545,  3.1818181818,  3.4090909091,
1206        3.6363636364,   3.8636363636,  4.0909090909,  4.3181818182,
1207        4.5454545455,   4.7727272727,  5.0000000000,  5.2272727273,
1208        5.4545454545,   5.6818181818,  5.9090909091,  6.1363636364,
1209        6.3636363636,   6.5909090909,  6.8181818182,  7.0454545455,
1210        7.2727272727,   7.5000000000,  7.7272727273,  7.9545454545,
1211        8.1818181818,   8.4090909091,  8.6363636364,  8.8636363636,
1212        9.0909090909,   9.3181818182,  9.5454545455,  9.7727272727,
1213        10.0000000000});
1214   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1215   EXPECT_THAT(
1216       m.GetDequantizedOutput<int8_t>(),
1217       ElementsAreArray(ArrayFloatNear(
1218           {0.0000453979, 0.0000569815, 0.0000715205, 0.0000897689, 0.0001126729,
1219            0.0001414198, 0.0001774998, 0.0002227827, 0.0002796147, 0.0003509396,
1220            0.0004404502, 0.0005527786, 0.0006937345, 0.0008706021, 0.0010925128,
1221            0.0013709094, 0.0017201256, 0.0021581065, 0.0027073042, 0.0033957870,
1222            0.0042586071, 0.0053394826, 0.0066928509, 0.0083863576, 0.0105038445,
1223            0.0131488902, 0.0164489307, 0.0205599431, 0.0256715863, 0.0320125562,
1224            0.0398556989, 0.0495221198, 0.0613831074, 0.0758581800, 0.0934070047,
1225            0.1145124805, 0.1396521834, 0.1692560327, 0.2036499335, 0.2429886272,
1226            0.2871859014, 0.3358556241, 0.3882805886, 0.4434251301, 0.5000000000,
1227            0.5565748699, 0.6117194114, 0.6641443759, 0.7128140986, 0.7570113728,
1228            0.7963500665, 0.8307439673, 0.8603478166, 0.8854875195, 0.9065929953,
1229            0.9241418200, 0.9386168926, 0.9504778802, 0.9601443011, 0.9679874438,
1230            0.9743284137, 0.9794400569, 0.9835510693, 0.9868511098, 0.9894961555,
1231            0.9916136424, 0.9933071491, 0.9946605174, 0.9957413929, 0.9966042130,
1232            0.9972926958, 0.9978418935, 0.9982798744, 0.9986290906, 0.9989074872,
1233            0.9991293979, 0.9993062655, 0.9994472214, 0.9995595498, 0.9996490604,
1234            0.9997203853, 0.9997772173, 0.9998225002, 0.9998585802, 0.9998873271,
1235            0.9999102311, 0.9999284795, 0.9999430185, 0.9999546021},
1236           kQuantizedTolerance)));
1237 }
1238 
TEST_P(LogisticOpTest,SigmoidInt16)1239 TEST_P(LogisticOpTest, SigmoidInt16) {
1240   const float kMin = -1;
1241   const float kMax = 32767.f / 32768.f;
1242   QuantizedActivationsOpModel m(
1243       GetRegistration(), BuiltinOperator_LOGISTIC,
1244       /*input=*/{TensorType_INT16, {177}, 16 * kMin, 16 * kMax},
1245       /*output=*/{TensorType_INT16, {177}, kMin, kMax});
1246   m.SetInput<int16_t>(
1247       {-20.0000000000, -19.7727272727, -19.5454545455, -19.3181818182,
1248        -19.0909090909, -18.8636363636, -18.6363636364, -18.4090909091,
1249        -18.1818181818, -17.9545454545, -17.7272727273, -17.5000000000,
1250        -17.2727272727, -17.0454545455, -16.8181818182, -16.5909090909,
1251        -16.3636363636, -16.1363636364, -15.9090909091, -15.6818181818,
1252        -15.4545454545, -15.2272727273, -15.0000000000, -14.7727272727,
1253        -14.5454545455, -14.3181818182, -14.0909090909, -13.8636363636,
1254        -13.6363636364, -13.4090909091, -13.1818181818, -12.9545454545,
1255        -12.7272727273, -12.5000000000, -12.2727272727, -12.0454545455,
1256        -11.8181818182, -11.5909090909, -11.3636363636, -11.1363636364,
1257        -10.9090909091, -10.6818181818, -10.4545454545, -10.2272727273,
1258        -10.0000000000, -9.7727272727,  -9.5454545455,  -9.3181818182,
1259        -9.0909090909,  -8.8636363636,  -8.6363636364,  -8.4090909091,
1260        -8.1818181818,  -7.9545454545,  -7.7272727273,  -7.5000000000,
1261        -7.2727272727,  -7.0454545455,  -6.8181818182,  -6.5909090909,
1262        -6.3636363636,  -6.1363636364,  -5.9090909091,  -5.6818181818,
1263        -5.4545454545,  -5.2272727273,  -5.0000000000,  -4.7727272727,
1264        -4.5454545455,  -4.3181818182,  -4.0909090909,  -3.8636363636,
1265        -3.6363636364,  -3.4090909091,  -3.1818181818,  -2.9545454545,
1266        -2.7272727273,  -2.5000000000,  -2.2727272727,  -2.0454545455,
1267        -1.8181818182,  -1.5909090909,  -1.3636363636,  -1.1363636364,
1268        -0.9090909091,  -0.6818181818,  -0.4545454545,  -0.2272727273,
1269        0.0000000000,   0.2272727273,   0.4545454545,   0.6818181818,
1270        0.9090909091,   1.1363636364,   1.3636363636,   1.5909090909,
1271        1.8181818182,   2.0454545455,   2.2727272727,   2.5000000000,
1272        2.7272727273,   2.9545454545,   3.1818181818,   3.4090909091,
1273        3.6363636364,   3.8636363636,   4.0909090909,   4.3181818182,
1274        4.5454545455,   4.7727272727,   5.0000000000,   5.2272727273,
1275        5.4545454545,   5.6818181818,   5.9090909091,   6.1363636364,
1276        6.3636363636,   6.5909090909,   6.8181818182,   7.0454545455,
1277        7.2727272727,   7.5000000000,   7.7272727273,   7.9545454545,
1278        8.1818181818,   8.4090909091,   8.6363636364,   8.8636363636,
1279        9.0909090909,   9.3181818182,   9.5454545455,   9.7727272727,
1280        10.0000000000,  10.2272727273,  10.4545454545,  10.6818181818,
1281        10.9090909091,  11.1363636364,  11.3636363636,  11.5909090909,
1282        11.8181818182,  12.0454545455,  12.2727272727,  12.5000000000,
1283        12.7272727273,  12.9545454545,  13.1818181818,  13.4090909091,
1284        13.6363636364,  13.8636363636,  14.0909090909,  14.3181818182,
1285        14.5454545455,  14.7727272727,  15.0000000000,  15.2272727273,
1286        15.4545454545,  15.6818181818,  15.9090909091,  16.1363636364,
1287        16.3636363636,  16.5909090909,  16.8181818182,  17.0454545455,
1288        17.2727272727,  17.5000000000,  17.7272727273,  17.9545454545,
1289        18.1818181818,  18.4090909091,  18.6363636364,  18.8636363636,
1290        19.0909090909,  19.3181818182,  19.5454545455,  19.7727272727,
1291        20.0000000000});
1292   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1293   EXPECT_THAT(
1294       m.GetDequantizedOutput<int16_t>(),
1295       ElementsAreArray(ArrayFloatNear(
1296           {0.0000000021, 0.0000000026, 0.0000000032, 0.0000000041, 0.0000000051,
1297            0.0000000064, 0.0000000081, 0.0000000101, 0.0000000127, 0.0000000159,
1298            0.0000000200, 0.0000000251, 0.0000000315, 0.0000000396, 0.0000000497,
1299            0.0000000623, 0.0000000782, 0.0000000982, 0.0000001232, 0.0000001547,
1300            0.0000001942, 0.0000002437, 0.0000003059, 0.0000003840, 0.0000004819,
1301            0.0000006049, 0.0000007593, 0.0000009530, 0.0000011962, 0.0000015014,
1302            0.0000018846, 0.0000023654, 0.0000029690, 0.0000037266, 0.0000046776,
1303            0.0000058711, 0.0000073693, 0.0000092497, 0.0000116100, 0.0000145724,
1304            0.0000182909, 0.0000229581, 0.0000288162, 0.0000361690, 0.0000453979,
1305            0.0000569815, 0.0000715205, 0.0000897689, 0.0001126729, 0.0001414198,
1306            0.0001774998, 0.0002227827, 0.0002796147, 0.0003509396, 0.0004404502,
1307            0.0005527786, 0.0006937345, 0.0008706021, 0.0010925128, 0.0013709094,
1308            0.0017201256, 0.0021581065, 0.0027073042, 0.0033957870, 0.0042586071,
1309            0.0053394826, 0.0066928509, 0.0083863576, 0.0105038445, 0.0131488902,
1310            0.0164489307, 0.0205599431, 0.0256715863, 0.0320125562, 0.0398556989,
1311            0.0495221198, 0.0613831074, 0.0758581800, 0.0934070047, 0.1145124805,
1312            0.1396521834, 0.1692560327, 0.2036499335, 0.2429886272, 0.2871859014,
1313            0.3358556241, 0.3882805886, 0.4434251301, 0.5000000000, 0.5565748699,
1314            0.6117194114, 0.6641443759, 0.7128140986, 0.7570113728, 0.7963500665,
1315            0.8307439673, 0.8603478166, 0.8854875195, 0.9065929953, 0.9241418200,
1316            0.9386168926, 0.9504778802, 0.9601443011, 0.9679874438, 0.9743284137,
1317            0.9794400569, 0.9835510693, 0.9868511098, 0.9894961555, 0.9916136424,
1318            0.9933071491, 0.9946605174, 0.9957413929, 0.9966042130, 0.9972926958,
1319            0.9978418935, 0.9982798744, 0.9986290906, 0.9989074872, 0.9991293979,
1320            0.9993062655, 0.9994472214, 0.9995595498, 0.9996490604, 0.9997203853,
1321            0.9997772173, 0.9998225002, 0.9998585802, 0.9998873271, 0.9999102311,
1322            0.9999284795, 0.9999430185, 0.9999546021, 0.9999638310, 0.9999711838,
1323            0.9999770419, 0.9999817091, 0.9999854276, 0.9999883900, 0.9999907503,
1324            0.9999926307, 0.9999941289, 0.9999953224, 0.9999962734, 0.9999970310,
1325            0.9999976346, 0.9999981154, 0.9999984986, 0.9999988038, 0.9999990470,
1326            0.9999992407, 0.9999993951, 0.9999995181, 0.9999996160, 0.9999996941,
1327            0.9999997563, 0.9999998058, 0.9999998453, 0.9999998768, 0.9999999018,
1328            0.9999999218, 0.9999999377, 0.9999999503, 0.9999999604, 0.9999999685,
1329            0.9999999749, 0.9999999800, 0.9999999841, 0.9999999873, 0.9999999899,
1330            0.9999999919, 0.9999999936, 0.9999999949, 0.9999999959, 0.9999999968,
1331            0.9999999974, 0.9999999979},
1332           kQuantizedToleranceInt16)));
1333 }
1334 
TEST_P(LogisticOpTest, SigmoidInt16General) {
  // Int16 logistic on twelve hand-picked inputs, including non-integers.
  // The input tensor spans 13x the output interval [-1, 32767/32768].
  const float kOutputLow = -1;
  const float kOutputHigh = 32767.f / 32768.f;

  // sigmoid(x) for each input, in the same order as the input list.
  const std::vector<float> expected_sigmoid = {
      0.5,      0.002473, 0.880797, 0.982014, 0.524979, 0.999994,  //
      0.952574, 0.119203, 0.999955, 0.731059, 0.562177, 0,         //
  };

  QuantizedActivationsOpModel model(
      GetRegistration(), BuiltinOperator_LOGISTIC,
      /*input=*/{TensorType_INT16, {12}, 13 * kOutputLow, 13 * kOutputHigh},
      /*output=*/{TensorType_INT16, {12}, kOutputLow, kOutputHigh});
  model.SetInput<int16_t>({
      0, -6, 2, 4, 0.1, 12,    //
      3, -2, 10, 1, 0.25, -12  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_sigmoid, kQuantizedToleranceInt16)));
}
1353 
TEST_P(SoftmaxOpTest, Softmax4D) {
  // Float softmax with beta = 0.1: each expected value equals
  // softmax(0.1 * x) taken along the innermost axis. The same eight values
  // are run twice, first with an innermost axis of length 4, then regrouped
  // so the innermost axis has length 2.
  const std::vector<float> expected_axis4 = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };
  const std::vector<float> expected_axis2 = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  FloatActivationsOpModel axis4_model(GetRegistration(), 0.1f,
                                      {TensorType_FLOAT32, {1, 2, 1, 4}},
                                      TensorType_FLOAT32);
  axis4_model.SetInput({
      0, -6, 2, 4,   // depth = 0
      3, -2, 10, 1,  // depth = 1
  });
  ASSERT_EQ(axis4_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis4_model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_axis4)));

  FloatActivationsOpModel axis2_model(GetRegistration(), 0.1f,
                                      {TensorType_FLOAT32, {4, 1, 1, 2}},
                                      TensorType_FLOAT32);
  axis2_model.SetInput({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(axis2_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis2_model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_axis2)));
}
1386 
TEST_P(SoftmaxOpTest, Softmax4DUint8) {
  // Uint8 -> uint8 softmax (beta = 0.1) on the same data as the float
  // SoftmaxOpTest.Softmax4D test, checked after dequantization within
  // kQuantizedTolerance. Run once with an innermost axis of 4 and once with
  // the values regrouped into an innermost axis of 2.
  const std::vector<float> expected_axis4 = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };
  const std::vector<float> expected_axis2 = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  QuantizedActivationsOpModel axis4_model(
      GetRegistration(), 0.1f, {TensorType_UINT8, {1, 2, 1, 4}, -10, 10},
      TensorType_UINT8);
  axis4_model.SetInput<uint8_t>({
      0, -6, 2, 4,   // depth = 0
      3, -2, 10, 1,  // depth = 1
  });
  ASSERT_EQ(axis4_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      axis4_model.GetDequantizedOutput<uint8_t>(),
      ElementsAreArray(ArrayFloatNear(expected_axis4, kQuantizedTolerance)));

  QuantizedActivationsOpModel axis2_model(
      GetRegistration(), 0.1f, {TensorType_UINT8, {4, 1, 1, 2}, -10, 10},
      TensorType_UINT8);
  axis2_model.SetInput<uint8_t>({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(axis2_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      axis2_model.GetDequantizedOutput<uint8_t>(),
      ElementsAreArray(ArrayFloatNear(expected_axis2, kQuantizedTolerance)));
}
1425 
TEST_P(SoftmaxOpTest, Softmax4DUint8Int16) {
  // Uint8 input, int16 output softmax (beta = 0.1) on the same data as the
  // float SoftmaxOpTest.Softmax4D test. The comparison uses the 8-bit
  // tolerance, since the input is only 8-bit. Run once with an innermost axis
  // of 4 and once regrouped into an innermost axis of 2.
  const std::vector<float> expected_axis4 = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };
  const std::vector<float> expected_axis2 = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  QuantizedActivationsOpModel axis4_model(
      GetRegistration(), 0.1f, {TensorType_UINT8, {1, 2, 1, 4}, -10, 10},
      TensorType_INT16);
  axis4_model.SetInput<uint8_t>({
      0, -6, 2, 4,   // depth = 0
      3, -2, 10, 1,  // depth = 1
  });
  ASSERT_EQ(axis4_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      axis4_model.GetDequantizedOutput<int16_t>(),
      ElementsAreArray(ArrayFloatNear(expected_axis4, kQuantizedTolerance)));

  QuantizedActivationsOpModel axis2_model(
      GetRegistration(), 0.1f, {TensorType_UINT8, {4, 1, 1, 2}, -10, 10},
      TensorType_INT16);
  axis2_model.SetInput<uint8_t>({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(axis2_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      axis2_model.GetDequantizedOutput<int16_t>(),
      ElementsAreArray(ArrayFloatNear(expected_axis2, kQuantizedTolerance)));
}
1464 
1465 // Test quantized softmax with int8 input and output. With the same input as in
1466 // QuantizedActivationsOpTest.Softmax1D, the dequantized output is identical.
TEST_P(SoftmaxOpTest,Softmax1DInt8)1467 TEST_P(SoftmaxOpTest, Softmax1DInt8) {
1468   QuantizedActivationsOpModel m(
1469       GetRegistration(), 0.1, {TensorType_INT8, {8}, -10, 10}, TensorType_INT8);
1470   m.SetInput<int8_t>({0, -6, 2, 4, 3, -2, 10, 1});
1471   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1472   EXPECT_THAT(
1473       m.GetDequantizedOutput<int8_t>(),
1474       ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453,
1475                                        0.13281, 0.07813, 0.26563, 0.10938},
1476                                       kQuantizedTolerance)));
1477 }
1478 
1479 // Test quantized softmax with int16 input and output. With the same input as in
1480 // QuantizedActivationsOpTest.Softmax2D, the dequantized output is identical.
TEST_P(SoftmaxOpTest,Softmax1DInt16)1481 TEST_P(SoftmaxOpTest, Softmax1DInt16) {
1482   const float kMin = -1;
1483   const float kMax = 32767.f / 32768.f;
1484   QuantizedActivationsOpModel m(
1485       GetRegistration(), 1,
1486       /*input=*/{TensorType_INT16, {3}, 3 * kMin, 3 * kMax},
1487       /*output_type-*/ TensorType_INT16);
1488   m.SetInput<int16_t>({1, 2, 3});
1489   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1490   EXPECT_THAT(
1491       m.GetDequantizedOutput<int16_t>(),
1492       ElementsAreArray(ArrayFloatNear({0.0900269, 0.2447285, 0.66524096},
1493                                       kQuantizedToleranceInt16)));
1494 }
1495 
TEST_P(SoftmaxOpTest, Softmax1DInt16ZeroElement) {
  // Int16 softmax over a single-element tensor holding 0. With only one
  // element, its probability must come out as 1.
  const float kRangeLow = -1;
  const float kRangeHigh = 32767.f / 32768.f;
  const std::vector<float> expected_probs = {1};

  QuantizedActivationsOpModel model(
      GetRegistration(), 0.1,
      /*input=*/{TensorType_INT16, {1}, 1 * kRangeLow, 1 * kRangeHigh},
      TensorType_INT16);
  model.SetInput<int16_t>({0});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_probs, kQuantizedToleranceInt16)));
}
1507 
TEST_P(SoftmaxOpTest, Softmax2DInt16) {
  // Int16 softmax (beta = 0.1) on a 2D tensor, compared within
  // kQuantizedToleranceInt16. The input range is 10x the output interval
  // [-1, 32767/32768]. Run as {2, 4}, then with the same values as {4, 2}.
  const float kRangeLow = -1;
  const float kRangeHigh = 32767.f / 32768.f;
  const std::vector<float> expected_rows_of_4 = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };
  const std::vector<float> expected_rows_of_2 = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  QuantizedActivationsOpModel wide_model(
      GetRegistration(), 0.1,
      /*input=*/{TensorType_INT16, {2, 4}, 10 * kRangeLow, 10 * kRangeHigh},
      TensorType_INT16);
  wide_model.SetInput<int16_t>({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(wide_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(wide_model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear(expected_rows_of_4,
                                              kQuantizedToleranceInt16)));

  QuantizedActivationsOpModel narrow_model(
      GetRegistration(), 0.1,
      /*input=*/{TensorType_INT16, {4, 2}, 10 * kRangeLow, 10 * kRangeHigh},
      TensorType_INT16);
  narrow_model.SetInput<int16_t>({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(narrow_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(narrow_model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear(expected_rows_of_2,
                                              kQuantizedToleranceInt16)));
}
1550 
TEST_P(SoftmaxOpTest, Softmax3DInt16) {
  // Int16 softmax with beta = 1: each expected value equals softmax(x) along
  // the innermost axis. Run as {1, 2, 4}, then with the same values as
  // {4, 1, 2}.
  // NOTE(review): this test compares against kQuantizedTolerance (the 8-bit
  // tolerance) even though the output is int16. That is presumably because
  // beta = 1 makes some logits very far below the row max. TODO confirm
  // whether kQuantizedToleranceInt16 would also hold.
  const float kRangeLow = -1;
  const float kRangeHigh = 32767.f / 32768.f;
  const std::vector<float> expected_axis4 = {
      .0158756, .000039, .1173, .866779,   //
      .00091, .0000061, .998959, .000123,  //
  };
  const std::vector<float> expected_axis2 = {
      0.997527, 0.0024726,       //
      0.11920292, 0.88079707,    //
      0.99330715, 0.00669285,    //
      0.999876605, 0.000123395,  //
  };

  QuantizedActivationsOpModel axis4_model(
      GetRegistration(), 1,
      /*input=*/{TensorType_INT16, {1, 2, 4}, 10 * kRangeLow, 10 * kRangeHigh},
      TensorType_INT16);
  axis4_model.SetInput<int16_t>({
      0, -6, 2, 4,   // depth = 0
      3, -2, 10, 1,  // depth = 1
  });
  ASSERT_EQ(axis4_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      axis4_model.GetDequantizedOutput<int16_t>(),
      ElementsAreArray(ArrayFloatNear(expected_axis4, kQuantizedTolerance)));

  QuantizedActivationsOpModel axis2_model(
      GetRegistration(), 1,
      /*input=*/{TensorType_INT16, {4, 1, 2}, 10 * kRangeLow, 10 * kRangeHigh},
      TensorType_INT16);
  axis2_model.SetInput<int16_t>({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(axis2_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      axis2_model.GetDequantizedOutput<int16_t>(),
      ElementsAreArray(ArrayFloatNear(expected_axis2, kQuantizedTolerance)));
}
1593 
1594 // Test quantized softmax with int16 input and output. With the same input as in
1595 // QuantizedActivationsOpTest.Softmax4D, the dequantized output is identical.
TEST_P(SoftmaxOpTest,Softmax4DInt16)1596 TEST_P(SoftmaxOpTest, Softmax4DInt16) {
1597   const float kMin = -1;
1598   const float kMax = 32767.f / 32768.f;
1599   QuantizedActivationsOpModel m(
1600       GetRegistration(), 0.1,
1601       /*input=*/{TensorType_INT16, {1, 2, 1, 4}, 10 * kMin, 10 * kMax},
1602       TensorType_INT16);
1603   m.SetInput<int16_t>({
1604       0, -6, 2, 4,   // depth = 0
1605       3, -2, 10, 1,  // depth = 1
1606   });
1607   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1608   EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
1609               ElementsAreArray(ArrayFloatNear(
1610                   {
1611                       .23463, .12877, .28658, .35003,  //
1612                       .22528, .13664, .45365, .18443,  //
1613                   },
1614                   kQuantizedToleranceInt16)));
1615 
1616   // Same input, but a different shape.
1617   QuantizedActivationsOpModel m2(
1618       GetRegistration(), 0.1,
1619       /*input=*/{TensorType_INT16, {4, 1, 1, 2}, 10 * kMin, 10 * kMax},
1620       TensorType_INT16);
1621   m2.SetInput<int16_t>({
1622       0, -6,  //
1623       2, 4,   //
1624       3, -2,  //
1625       10, 1,  //
1626   });
1627   ASSERT_EQ(m2.Invoke(), kTfLiteOk);
1628   EXPECT_THAT(m2.GetDequantizedOutput<int16_t>(),
1629               ElementsAreArray(ArrayFloatNear(
1630                   {
1631                       0.645656, 0.354344,  //
1632                       0.450166, 0.549834,  //
1633                       0.622459, 0.377541,  //
1634                       0.710949, 0.28905,   //
1635                   },
1636                   kQuantizedToleranceInt16)));
1637 }
1638 
1639 // Test quantized softmax with int8 input and int16 output. With the same input
1640 // as in QuantizedActivationsOpTest.Softmax1D, the dequantized output is
1641 // identical.
TEST_P(SoftmaxOpTest,Softmax1DInt8Int16)1642 TEST_P(SoftmaxOpTest, Softmax1DInt8Int16) {
1643   QuantizedActivationsOpModel m(GetRegistration(), 0.1f,
1644                                 {TensorType_INT8, {8}, -10, 10},
1645                                 TensorType_INT16);
1646   m.SetInput<int8_t>({0, -6, 2, 4, 3, -2, 10, 1});
1647   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1648   EXPECT_THAT(
1649       m.GetDequantizedOutput<int16_t>(),
1650       ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453,
1651                                        0.13281, 0.07813, 0.26563, 0.10938},
1652                                       kQuantizedTolerance)));
1653 }
1654 
1655 // Test quantized softmax with int8 input and output. With the same input as in
1656 // QuantizedActivationsOpTest.Softmax2D, the dequantized output is identical.
TEST_P(SoftmaxOpTest,Softmax2DInt8)1657 TEST_P(SoftmaxOpTest, Softmax2DInt8) {
1658   QuantizedActivationsOpModel m(GetRegistration(), 0.1f,
1659                                 {TensorType_INT8, {2, 4}, -10, 10},
1660                                 TensorType_INT8);
1661   m.SetInput<int8_t>({
1662       0, -6, 2, 4,   //
1663       3, -2, 10, 1,  //
1664   });
1665   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1666   EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
1667               ElementsAreArray(ArrayFloatNear(
1668                   {
1669                       .23463, .12877, .28658, .35003,  //
1670                       .22528, .13664, .45365, .18443,  //
1671                   },
1672                   kQuantizedTolerance)));
1673 
1674   // Same input, but a different shape.
1675   QuantizedActivationsOpModel m2(GetRegistration(), 0.1f,
1676                                  {TensorType_INT8, {4, 2}, -10, 10},
1677                                  TensorType_INT8);
1678   m2.SetInput<int8_t>({
1679       0, -6,  //
1680       2, 4,   //
1681       3, -2,  //
1682       10, 1,  //
1683   });
1684   ASSERT_EQ(m2.Invoke(), kTfLiteOk);
1685   EXPECT_THAT(m2.GetDequantizedOutput<int8_t>(),
1686               ElementsAreArray(ArrayFloatNear(
1687                   {
1688                       0.645656, 0.354344,  //
1689                       0.450166, 0.549834,  //
1690                       0.622459, 0.377541,  //
1691                       0.710949, 0.28905,   //
1692                   },
1693                   kQuantizedTolerance)));
1694 }
1695 
1696 // Test quantized softmax with int8 input and int16 output. With the same input
1697 // as in QuantizedActivationsOpTest.Softmax2D, the dequantized output is
1698 // identical.
TEST_P(SoftmaxOpTest,Softmax2DInt8Int16)1699 TEST_P(SoftmaxOpTest, Softmax2DInt8Int16) {
1700   QuantizedActivationsOpModel m(GetRegistration(), 0.1f,
1701                                 {TensorType_INT8, {2, 4}, -10, 10},
1702                                 TensorType_INT16);
1703   m.SetInput<int8_t>({
1704       0, -6, 2, 4,   //
1705       3, -2, 10, 1,  //
1706   });
1707   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1708   EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
1709               ElementsAreArray(ArrayFloatNear(
1710                   {
1711                       .23463, .12877, .28658, .35003,  //
1712                       .22528, .13664, .45365, .18443,  //
1713                   },
1714                   kQuantizedTolerance)));
1715 
1716   // Same input, but a different shape.
1717   QuantizedActivationsOpModel m2(GetRegistration(), 0.1f,
1718                                  {TensorType_INT8, {4, 2}, -10, 10},
1719                                  TensorType_INT16);
1720   m2.SetInput<int8_t>({
1721       0, -6,  //
1722       2, 4,   //
1723       3, -2,  //
1724       10, 1,  //
1725   });
1726   ASSERT_EQ(m2.Invoke(), kTfLiteOk);
1727   EXPECT_THAT(m2.GetDequantizedOutput<int16_t>(),
1728               ElementsAreArray(ArrayFloatNear(
1729                   {
1730                       0.645656, 0.354344,  //
1731                       0.450166, 0.549834,  //
1732                       0.622459, 0.377541,  //
1733                       0.710949, 0.28905,   //
1734                   },
1735                   kQuantizedTolerance)));
1736 }
1737 
1738 // Test quantized softmax with int8 input and output. With the same input as in
1739 // QuantizedActivationsOpTest.Softmax3D, the dequantized output is identical.
TEST_P(SoftmaxOpTest,Softmax3DInt8)1740 TEST_P(SoftmaxOpTest, Softmax3DInt8) {
1741   QuantizedActivationsOpModel m(GetRegistration(), 0.1f,
1742                                 {TensorType_INT8, {1, 2, 4}, -10, 10},
1743                                 TensorType_INT8);
1744   m.SetInput<int8_t>({
1745       0, -6, 2, 4,   // depth = 0
1746       3, -2, 10, 1,  // depth = 1
1747   });
1748   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1749   EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
1750               ElementsAreArray(ArrayFloatNear(
1751                   {
1752                       .23463, .12877, .28658, .35003,  //
1753                       .22528, .13664, .45365, .18443,  //
1754                   },
1755                   kQuantizedTolerance)));
1756 
1757   // Same input, but a different shape.
1758   QuantizedActivationsOpModel m2(GetRegistration(), 0.1f,
1759                                  {TensorType_INT8, {4, 1, 2}, -10, 10},
1760                                  TensorType_INT8);
1761   m2.SetInput<int8_t>({
1762       0, -6,  //
1763       2, 4,   //
1764       3, -2,  //
1765       10, 1,  //
1766   });
1767   ASSERT_EQ(m2.Invoke(), kTfLiteOk);
1768   EXPECT_THAT(m2.GetDequantizedOutput<int8_t>(),
1769               ElementsAreArray(ArrayFloatNear(
1770                   {
1771                       0.645656, 0.354344,  //
1772                       0.450166, 0.549834,  //
1773                       0.622459, 0.377541,  //
1774                       0.710949, 0.28905,   //
1775                   },
1776                   kQuantizedTolerance)));
1777 }
1778 
1779 // Test quantized softmax with int8 input and output. With the same input as in
1780 // QuantizedActivationsOpTest.Softmax3D, the dequantized output is identical.
TEST_P(SoftmaxOpTest,Softmax3DInt8Int16)1781 TEST_P(SoftmaxOpTest, Softmax3DInt8Int16) {
1782   QuantizedActivationsOpModel m(GetRegistration(), 0.1f,
1783                                 {TensorType_INT8, {1, 2, 4}, -10, 10},
1784                                 TensorType_INT16);
1785   m.SetInput<int8_t>({
1786       0, -6, 2, 4,   // depth = 0
1787       3, -2, 10, 1,  // depth = 1
1788   });
1789   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1790   EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
1791               ElementsAreArray(ArrayFloatNear(
1792                   {
1793                       .23463, .12877, .28658, .35003,  //
1794                       .22528, .13664, .45365, .18443,  //
1795                   },
1796                   kQuantizedTolerance)));
1797 
1798   // Same input, but a different shape.
1799   QuantizedActivationsOpModel m2(GetRegistration(), 0.1f,
1800                                  {TensorType_INT8, {4, 1, 2}, -10, 10},
1801                                  TensorType_INT16);
1802   m2.SetInput<int8_t>({
1803       0, -6,  //
1804       2, 4,   //
1805       3, -2,  //
1806       10, 1,  //
1807   });
1808   ASSERT_EQ(m2.Invoke(), kTfLiteOk);
1809   EXPECT_THAT(m2.GetDequantizedOutput<int16_t>(),
1810               ElementsAreArray(ArrayFloatNear(
1811                   {
1812                       0.645656, 0.354344,  //
1813                       0.450166, 0.549834,  //
1814                       0.622459, 0.377541,  //
1815                       0.710949, 0.28905,   //
1816                   },
1817                   kQuantizedTolerance)));
1818 }
1819 
1820 // Test quantized softmax with int8 input and output. With the same input as in
1821 // QuantizedActivationsOpTest.Softmax4D, the dequantized output is identical.
TEST_P(SoftmaxOpTest,Softmax4DInt8)1822 TEST_P(SoftmaxOpTest, Softmax4DInt8) {
1823   QuantizedActivationsOpModel m(GetRegistration(), 0.1f,
1824                                 {TensorType_INT8, {1, 2, 1, 4}, -10, 10},
1825                                 TensorType_INT8);
1826   m.SetInput<int8_t>({
1827       0, -6, 2, 4,   // depth = 0
1828       3, -2, 10, 1,  // depth = 1
1829   });
1830   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1831   EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({
1832                                          -68, -95, -54, -38,  //
1833                                          -70, -93, -12, -81,  //
1834                                      }));
1835   EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
1836               ElementsAreArray(ArrayFloatNear(
1837                   {
1838                       .23463, .12877, .28658, .35003,  //
1839                       .22528, .13664, .45365, .18443,  //
1840                   },
1841                   kQuantizedTolerance)));
1842 
1843   // Same input, but a different shape.
1844   QuantizedActivationsOpModel m2(GetRegistration(), 0.1f,
1845                                  {TensorType_INT8, {4, 1, 1, 2}, -10, 10},
1846                                  TensorType_INT8);
1847   m2.SetInput<int8_t>({
1848       0, -6,  //
1849       2, 4,   //
1850       3, -2,  //
1851       10, 1,  //
1852   });
1853   ASSERT_EQ(m2.Invoke(), kTfLiteOk);
1854   EXPECT_THAT(m2.GetDequantizedOutput<int8_t>(),
1855               ElementsAreArray(ArrayFloatNear(
1856                   {
1857                       0.645656, 0.354344,  //
1858                       0.450166, 0.549834,  //
1859                       0.622459, 0.377541,  //
1860                       0.710949, 0.28905,   //
1861                   },
1862                   kQuantizedTolerance)));
1863 }
1864 
1865 // Test quantized softmax with int8 input and output. With the same input as in
1866 // QuantizedActivationsOpTest.Softmax4D, the dequantized output is identical.
TEST_P(SoftmaxOpTest,Softmax4DInt8Int16)1867 TEST_P(SoftmaxOpTest, Softmax4DInt8Int16) {
1868   QuantizedActivationsOpModel m(GetRegistration(), 0.1f,
1869                                 {TensorType_INT8, {1, 2, 1, 4}, -10, 10},
1870                                 TensorType_INT16);
1871   m.SetInput<int8_t>({
1872       0, -6, 2, 4,   // depth = 0
1873       3, -2, 10, 1,  // depth = 1
1874   });
1875   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1876   EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
1877               ElementsAreArray(ArrayFloatNear(
1878                   {
1879                       .23463, .12877, .28658, .35003,  //
1880                       .22528, .13664, .45365, .18443,  //
1881                   },
1882                   kQuantizedTolerance)));
1883 
1884   // Same input, but a different shape.
1885   QuantizedActivationsOpModel m2(GetRegistration(), 0.1f,
1886                                  {TensorType_INT8, {4, 1, 1, 2}, -10, 10},
1887                                  TensorType_INT16);
1888   m2.SetInput<int8_t>({
1889       0, -6,  //
1890       2, 4,   //
1891       3, -2,  //
1892       10, 1,  //
1893   });
1894   ASSERT_EQ(m2.Invoke(), kTfLiteOk);
1895   EXPECT_THAT(m2.GetDequantizedOutput<int16_t>(),
1896               ElementsAreArray(ArrayFloatNear(
1897                   {
1898                       0.645656, 0.354344,  //
1899                       0.450166, 0.549834,  //
1900                       0.622459, 0.377541,  //
1901                       0.710949, 0.28905,   //
1902                   },
1903                   kQuantizedTolerance)));
1904 }
1905 
TEST_P(SoftmaxOpTest, Softmax3D) {
  // Float softmax with beta = 0.1: each expected value equals
  // softmax(0.1 * x) along the innermost axis. Run as {1, 2, 4}, then with
  // the same eight values regrouped as {4, 1, 2}.
  const std::vector<float> expected_axis4 = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };
  const std::vector<float> expected_axis2 = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  FloatActivationsOpModel axis4_model(GetRegistration(), 0.1f,
                                      {TensorType_FLOAT32, {1, 2, 4}},
                                      TensorType_FLOAT32);
  axis4_model.SetInput({
      0, -6, 2, 4,   // depth = 0
      3, -2, 10, 1,  // depth = 1
  });
  ASSERT_EQ(axis4_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis4_model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_axis4)));

  FloatActivationsOpModel axis2_model(GetRegistration(), 0.1f,
                                      {TensorType_FLOAT32, {4, 1, 2}},
                                      TensorType_FLOAT32);
  axis2_model.SetInput({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(axis2_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(axis2_model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_axis2)));
}
1938 
TEST_P(SoftmaxOpTest, Softmax3DUint8) {
  // Quantized (uint8 in, uint8 out) variant of Softmax3D. The inputs are
  // quantized over [-10, 10], and the dequantized outputs are compared with
  // the float reference within kQuantizedTolerance.
  const std::vector<float> expected_groups_of_four = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };
  const std::vector<float> expected_groups_of_two = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  QuantizedActivationsOpModel model(GetRegistration(), 0.1f,
                                    {TensorType_UINT8, {1, 2, 4}, -10, 10},
                                    TensorType_UINT8);
  model.SetInput<uint8_t>({
      0, -6, 2, 4,   // depth = 0
      3, -2, 10, 1,  // depth = 1
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_groups_of_four,
                                              kQuantizedTolerance)));

  // Same input, but a different shape.
  QuantizedActivationsOpModel reshaped_model(
      GetRegistration(), 0.1f, {TensorType_UINT8, {4, 1, 2}, -10, 10},
      TensorType_UINT8);
  reshaped_model.SetInput<uint8_t>({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(reshaped_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(reshaped_model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_groups_of_two,
                                              kQuantizedTolerance)));
}
1977 
TEST_P(SoftmaxOpTest, Softmax3DUint8Int16) {
  // uint8 input / int16 output variant of Softmax3D. Expectations are the
  // float reference values, checked after dequantizing the int16 output.
  const std::vector<float> expected_groups_of_four = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };
  const std::vector<float> expected_groups_of_two = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  QuantizedActivationsOpModel model(GetRegistration(), 0.1f,
                                    {TensorType_UINT8, {1, 2, 4}, -10, 10},
                                    TensorType_INT16);
  model.SetInput<uint8_t>({
      0, -6, 2, 4,   // depth = 0
      3, -2, 10, 1,  // depth = 1
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear(expected_groups_of_four,
                                              kQuantizedTolerance)));

  // Same input, but a different shape.
  QuantizedActivationsOpModel reshaped_model(
      GetRegistration(), 0.1f, {TensorType_UINT8, {4, 1, 2}, -10, 10},
      TensorType_INT16);
  reshaped_model.SetInput<uint8_t>({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(reshaped_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(reshaped_model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear(expected_groups_of_two,
                                              kQuantizedTolerance)));
}
2016 
TEST_P(SoftmaxOpTest, Softmax1D) {
  // Softmax over a single rank-1 vector of eight logits; the expected
  // probabilities sum to 1.
  const std::vector<float> expected_probabilities = {
      .09752, .05352, .11911, .14548, .13164, .07984, .26509, .10778};

  FloatActivationsOpModel model(GetRegistration(), 0.1f,
                                {TensorType_FLOAT32, {8}}, TensorType_FLOAT32);
  model.SetInput({0, -6, 2, 4, 3, -2, 10, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_probabilities)));
}
2027 
TEST_P(SoftmaxOpTest, Softmax1DMax) {
  // The largest finite float as the first logit must not overflow: all of
  // the probability mass is expected on that element.
  const float huge_logit = std::numeric_limits<float>::max();
  const std::vector<float> one_hot = {1, 0, 0, 0, 0, 0, 0, 0};

  FloatActivationsOpModel model(GetRegistration(), 0.1f,
                                {TensorType_FLOAT32, {8}}, TensorType_FLOAT32);
  model.SetInput({huge_logit, -6, 2, 4, 3, -2, 10, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(one_hot)));
}
2036 
TEST_P(SoftmaxOpTest, Softmax1DInf) {
  // With +infinity as one of the logits, every output element is expected to
  // be NaN.
  FloatActivationsOpModel m(GetRegistration(), 0.1f, {TensorType_FLOAT32, {8}},
                            TensorType_FLOAT32);
  m.SetInput({std::numeric_limits<float>::infinity(), -6, 2, 4, 3, -2, 10, 1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  const auto output = m.GetOutput();
  // Check the element count first so a wrongly-sized output fails cleanly
  // instead of being indexed out of bounds.
  ASSERT_EQ(output.size(), 8u);
  for (size_t i = 0; i < output.size(); ++i) {
    EXPECT_TRUE(std::isnan(output[i]))
        << "index " << i << " has value " << output[i];
  }
}
2047 
TEST_P(SoftmaxOpTest, Softmax1DUint8) {
  // uint8 version of Softmax1D. The expected values are exact multiples of
  // 1/256 (e.g. 0.09766 == 25/256), i.e. already rounded to the output grid.
  const std::vector<float> expected_probabilities = {
      0.09766, 0.05469, 0.12109, 0.14453, 0.13281, 0.07813, 0.26563, 0.10938};

  QuantizedActivationsOpModel model(GetRegistration(), 0.1f,
                                    {TensorType_UINT8, {8}, -10, 10},
                                    TensorType_UINT8);
  model.SetInput<uint8_t>({0, -6, 2, 4, 3, -2, 10, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_probabilities,
                                              kQuantizedTolerance)));
}
2060 
TEST_P(SoftmaxOpTest, Softmax1DUint8Int16) {
  // uint8 input / int16 output. Reuses the uint8 expectations, which the
  // int16 result must match within kQuantizedTolerance.
  const std::vector<float> expected_probabilities = {
      0.09766, 0.05469, 0.12109, 0.14453, 0.13281, 0.07813, 0.26563, 0.10938};

  QuantizedActivationsOpModel model(GetRegistration(), 0.1f,
                                    {TensorType_UINT8, {8}, -10, 10},
                                    TensorType_INT16);
  model.SetInput<uint8_t>({0, -6, 2, 4, 3, -2, 10, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear(expected_probabilities,
                                              kQuantizedTolerance)));
}
2073 
TEST_P(SoftmaxOpTest, Softmax2D) {
  // Row-wise softmax (beta 0.1) over a rank-2 input, checked for two
  // different row lengths built from the same eight logits.
  const std::vector<float> expected_rows_of_four = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };
  const std::vector<float> expected_rows_of_two = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  FloatActivationsOpModel model(GetRegistration(), 0.1f,
                                {TensorType_FLOAT32, {2, 4}},
                                TensorType_FLOAT32);
  model.SetInput({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_rows_of_four)));

  // Same input, but a different shape.
  FloatActivationsOpModel reshaped_model(GetRegistration(), 0.1f,
                                         {TensorType_FLOAT32, {4, 2}},
                                         TensorType_FLOAT32);
  reshaped_model.SetInput({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(reshaped_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(reshaped_model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_rows_of_two)));
}
2104 
TEST_P(SoftmaxOpTest,Softmax2DMultithreading)2105 TEST_P(SoftmaxOpTest, Softmax2DMultithreading) {
2106   FloatActivationsOpModel m(GetRegistration(), 0.1f,
2107                             {TensorType_FLOAT32, {16, 4}}, TensorType_FLOAT32);
2108   m.SetInput({
2109       0, -6, 2,  4,  //  Thread 1.
2110       3, -2, 10, 1,  //
2111       0, -6, 2,  4,  //
2112       3, -2, 10, 1,  //
2113       0, -6, 2,  4,  //
2114       3, -2, 10, 1,  //
2115       0, -6, 2,  4,  //
2116       3, -2, 10, 1,  //
2117       3, -2, 10, 1,  //  Thread 2.
2118       0, -6, 2,  4,  //
2119       3, -2, 10, 1,  //
2120       0, -6, 2,  4,  //
2121       3, -2, 10, 1,  //
2122       0, -6, 2,  4,  //
2123       3, -2, 10, 1,  //
2124       0, -6, 2,  4,  //
2125   });
2126   m.SetNumThreads(2);
2127   ASSERT_EQ(m.Invoke(), kTfLiteOk);
2128   EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
2129                                  .23463, .12877, .28658, .35003,  //
2130                                  .22528, .13664, .45365, .18443,  //
2131                                  .23463, .12877, .28658, .35003,  //
2132                                  .22528, .13664, .45365, .18443,  //
2133                                  .23463, .12877, .28658, .35003,  //
2134                                  .22528, .13664, .45365, .18443,  //
2135                                  .23463, .12877, .28658, .35003,  //
2136                                  .22528, .13664, .45365, .18443,  //
2137                                  .22528, .13664, .45365, .18443,  //
2138                                  .23463, .12877, .28658, .35003,  //
2139                                  .22528, .13664, .45365, .18443,  //
2140                                  .23463, .12877, .28658, .35003,  //
2141                                  .22528, .13664, .45365, .18443,  //
2142                                  .23463, .12877, .28658, .35003,  //
2143                                  .22528, .13664, .45365, .18443,  //
2144                                  .23463, .12877, .28658, .35003,  //
2145                              })));
2146 
2147   // Same input, but a different shape.
2148   FloatActivationsOpModel m2(GetRegistration(), 0.1f,
2149                              {TensorType_FLOAT32, {16, 2}}, TensorType_FLOAT32);
2150   m2.SetInput({
2151       0,  -6,  // Thread 1
2152       2,  4,   //
2153       3,  -2,  //
2154       10, 1,   //
2155       0,  -6,  //
2156       2,  4,   //
2157       3,  -2,  //
2158       10, 1,   //
2159       10, 1,   // Thread 2
2160       3,  -2,  //
2161       2,  4,   //
2162       0,  -6,  //
2163       10, 1,   //
2164       3,  -2,  //
2165       2,  4,   //
2166       0,  -6,  //
2167   });
2168   m2.SetNumThreads(2);
2169   ASSERT_EQ(m2.Invoke(), kTfLiteOk);
2170   EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
2171                                   0.645656, 0.354344,  //
2172                                   0.450166, 0.549834,  //
2173                                   0.622459, 0.377541,  //
2174                                   0.710949, 0.28905,   //
2175                                   0.645656, 0.354344,  //
2176                                   0.450166, 0.549834,  //
2177                                   0.622459, 0.377541,  //
2178                                   0.710949, 0.28905,   //
2179                                   0.710949, 0.28905,   //
2180                                   0.622459, 0.377541,  //
2181                                   0.450166, 0.549834,  //
2182                                   0.645656, 0.354344,  //
2183                                   0.710949, 0.28905,   //
2184                                   0.622459, 0.377541,  //
2185                                   0.450166, 0.549834,  //
2186                                   0.645656, 0.354344,  //
2187                               })));
2188 }
2189 
TEST_P(SoftmaxOpTest, Softmax2DUint8) {
  // Quantized (uint8 -> uint8) row-wise softmax. The dequantized outputs are
  // compared with the float reference within kQuantizedTolerance.
  const std::vector<float> expected_rows_of_four = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };
  const std::vector<float> expected_rows_of_two = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  QuantizedActivationsOpModel model(GetRegistration(), 0.1f,
                                    {TensorType_UINT8, {2, 4}, -10, 10},
                                    TensorType_UINT8);
  model.SetInput<uint8_t>({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_rows_of_four,
                                              kQuantizedTolerance)));

  // Same input, but a different shape.
  QuantizedActivationsOpModel reshaped_model(
      GetRegistration(), 0.1f, {TensorType_UINT8, {4, 2}, -10, 10},
      TensorType_UINT8);
  reshaped_model.SetInput<uint8_t>({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(reshaped_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(reshaped_model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_rows_of_two,
                                              kQuantizedTolerance)));
}
2228 
TEST_P(SoftmaxOpTest, Softmax2DUint8Int16) {
  // uint8 input / int16 output row-wise softmax, compared with the float
  // reference after dequantization.
  const std::vector<float> expected_rows_of_four = {
      .23463, .12877, .28658, .35003,  //
      .22528, .13664, .45365, .18443,  //
  };
  const std::vector<float> expected_rows_of_two = {
      0.645656, 0.354344,  //
      0.450166, 0.549834,  //
      0.622459, 0.377541,  //
      0.710949, 0.28905,   //
  };

  QuantizedActivationsOpModel model(GetRegistration(), 0.1f,
                                    {TensorType_UINT8, {2, 4}, -10, 10},
                                    TensorType_INT16);
  model.SetInput<uint8_t>({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear(expected_rows_of_four,
                                              kQuantizedTolerance)));

  // Same input, but a different shape.
  QuantizedActivationsOpModel reshaped_model(
      GetRegistration(), 0.1f, {TensorType_UINT8, {4, 2}, -10, 10},
      TensorType_INT16);
  reshaped_model.SetInput<uint8_t>({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(reshaped_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(reshaped_model.GetDequantizedOutput<int16_t>(),
              ElementsAreArray(ArrayFloatNear(expected_rows_of_two,
                                              kQuantizedTolerance)));
}
2267 
// The following tests reuse the input values of the Softmax tests, but their
// reference answers were generated with this Python snippet (TensorFlow 1.x
// API, since it relies on tf.Session):
//   logits1 = tf.constant([[0, -6, 2, 4],[3, -2, 10, 1]], dtype=tf.float32)
//   logits2 = tf.constant([[0,-6],[2,4],[3,-2],[10,1]], dtype=tf.float32)
//   lsm1 = tf.nn.log_softmax(logits1)
//   lsm2 = tf.nn.log_softmax(logits2)
//   with tf.Session() as sess:
//     print('lsm1', sess.run(lsm1))
//     print('lsm2', sess.run(lsm2))
2277 
TEST_P(LogSoftmaxOpTest, LogSoftmax) {
  // Float log-softmax along the last dimension, for two row lengths.
  const std::vector<float> expected_rows_of_four = {
      -4.14297, -10.14297, -2.14297, -.142971,    //
      -7.00104, -12.00104, -.00104087, -9.00104,  //
  };
  const std::vector<float> expected_rows_of_two = {
      -.00247565, -6.00247,   //
      -2.12692, -.126928,     //
      -.00671534, -5.00671,   //
      -.000123374, -9.00012,  //
  };

  FloatActivationsOpModel model(GetRegistration(), BuiltinOperator_LOG_SOFTMAX,
                                /*input=*/{TensorType_FLOAT32, {2, 4}});
  model.SetInput({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_rows_of_four)));

  // Same input, but a different shape.
  FloatActivationsOpModel reshaped_model(
      GetRegistration(), BuiltinOperator_LOG_SOFTMAX,
      /*input=*/{TensorType_FLOAT32, {4, 2}});
  reshaped_model.SetInput({
      0, -6,  //
      2, 4,   //
      3, -2,  //
      10, 1,  //
  });
  ASSERT_EQ(reshaped_model.Invoke(), kTfLiteOk);
  EXPECT_THAT(reshaped_model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(expected_rows_of_two)));
}
2308 
TEST_P(LogSoftmaxOpTest, LogSoftmaxUint8) {
  // Tolerance is one output quantization step (the output scale is 16/256).
  const float kLogSoftmaxQuantizedTolerance = 16 / 256.0;
  const std::vector<float> expected_log_probs = {
      -4.14297, -10.14297, -2.14297, -.142971,    //
      -7.00104, -12.00104, -.00104087, -9.00104,  //
  };
  // Raw uint8 codes. With zero point 255, a log-probability of ~0 maps to 255.
  const std::vector<int> expected_codes = {189, 93, 221, 253,
                                           142, 63, 255, 111};

  // Corresponds to input scale of 20/255.
  QuantizedActivationsOpModel model(
      GetRegistration(), BuiltinOperator_LOG_SOFTMAX,
      /*input=*/{TensorType_UINT8, {2, 4}, -10, 10},
      /*output=*/{TensorType_UINT8, {}, 0, 0, 16. / 256, 255});
  model.SetInput<uint8_t>({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_log_probs,
                                              kLogSoftmaxQuantizedTolerance)));
  EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray(expected_codes));
}
2331 
TEST_P(LogSoftmaxOpTest, LogSoftmaxInt8) {
  // Slightly more than one output step (1/16 = 0.0625).
  const float kLogSoftmaxQuantizedTolerance = 0.06355;
  const std::vector<float> expected_log_probs = {
      -4.14297, -10.14297, -2.14297, -.142971,    //
      -7.00104, -12.00104, -.00104087, -9.00104,  //
  };
  // Raw int8 codes. With zero point 127, a log-probability of ~0 maps to 127.
  const std::vector<int> expected_codes = {
      61, -36, 93, 125,   //
      15, -65, 127, -16,  //
  };

  QuantizedActivationsOpModel model(
      GetRegistration(), BuiltinOperator_LOG_SOFTMAX,
      /*input=*/{TensorType_INT8, {2, 4}, -10, 10},
      /*output=*/{TensorType_INT8, {}, 0, 0, 16. / 256, 127});
  model.SetInput<int8_t>({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_log_probs,
                                              kLogSoftmaxQuantizedTolerance)));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray(expected_codes));
}
2355 
TEST(QuantizedActivationsOpTest, LogSoftmaxInt8LargeNegativeNumber) {
  // Inputs near the bottom of the [-10, 10] input range push some
  // log-probabilities down to about -10.6 on the int8 path.
  const float kLogSoftmaxQuantizedTolerance = 0.06355;
  const std::vector<float> expected_log_probs = {
      -10.5625, -10.5625, -0.6875, -0.6875, -0.004, -9.8125, -5.75, -6.75};
  const std::vector<int> expected_codes = {
      -42, -42, 116, 116,  //
      127, -30, 35, 19,    //
  };

  QuantizedActivationsOpModel model(
      BuiltinOperator_LOG_SOFTMAX,
      /*input=*/{TensorType_INT8, {2, 4}, -10, 10},
      /*output=*/{TensorType_INT8, {}, 0, 0, 16. / 256, 127});
  model.SetInput<int8_t>({
      -9.9, -9.9, 0, 0,  //
      7.8, -2, 2, 1,     //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(expected_log_probs,
                                              kLogSoftmaxQuantizedTolerance)));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray(expected_codes));
}
2377 
2378 const auto kPReluKernelMap = new std::map<string, TfLiteRegistration*>({
2379     {"Reference", ops::builtin::Register_PRELU_REF()},
2380     {"GenericOptimized", ops::builtin::Register_PRELU()},
2381 });
2382 
2383 // A base class of PRelu op model. It provides the constructor for
2384 // FloatPReluOpModel and QuantizedPReluOpModel.
2385 class BasePReluOpModel : public SingleOpModel {
2386  public:
BasePReluOpModel(const TensorData & input,const TensorData & alpha)2387   BasePReluOpModel(const TensorData& input, const TensorData& alpha) {
2388     input_ = AddInput(input);
2389     alpha_ = AddInput(alpha);
2390     output_ = AddOutput({input.type, input.shape, input.min, input.max});
2391     SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE, 0);
2392     BuildInterpreter({GetShape(input_), GetShape(alpha_)});
2393   }
2394 
2395  protected:
2396   int input_;
2397   int alpha_;
2398   int output_;
2399 };
2400 
2401 // The FloatPReluOpModel class handles float input and output.
2402 class FloatPReluOpModel : public BasePReluOpModel {
2403  public:
2404   using BasePReluOpModel::BasePReluOpModel;
2405 
SetInput(std::initializer_list<float> data)2406   void SetInput(std::initializer_list<float> data) {
2407     PopulateTensor(input_, data);
2408   }
SetAlpha(std::initializer_list<float> data)2409   void SetAlpha(std::initializer_list<float> data) {
2410     PopulateTensor(alpha_, data);
2411   }
GetOutput()2412   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
2413 };
2414 
2415 // The QuantizedPReluOpModel class handles quantized input and output.
2416 class QuantizedPReluOpModel : public BasePReluOpModel {
2417  public:
2418   using BasePReluOpModel::BasePReluOpModel;
2419 
2420   template <typename T>
SetInput(std::initializer_list<float> data)2421   void SetInput(std::initializer_list<float> data) {
2422     QuantizeAndPopulate<T>(input_, data);
2423   }
2424   template <typename T>
SetAlpha(std::initializer_list<float> data)2425   void SetAlpha(std::initializer_list<float> data) {
2426     QuantizeAndPopulate<T>(alpha_, data);
2427   }
2428   template <typename T>
GetOutput()2429   std::vector<T> GetOutput() {
2430     return ExtractVector<T>(output_);
2431   }
2432   template <typename T>
GetDequantizedOutput()2433   std::vector<float> GetDequantizedOutput() {
2434     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
2435                          GetZeroPoint(output_));
2436   }
2437 };
2438 
2439 class PReluOpTest : public SingleOpTest {
2440  protected:
GetKernelMap()2441   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
2442     return *kPReluKernelMap;
2443   }
2444 };
2445 
TEST_P(PReluOpTest, PReluFloat32) {
  // alpha has shape {1, 1, 3} and is broadcast along the last dimension:
  // channel c of a negative input is multiplied by alpha[c].
  FloatPReluOpModel model({TensorType_FLOAT32, {1, 2, 2, 3}},
                          {TensorType_FLOAT32, {1, 1, 3}});

  model.SetInput({
      0.0f, 0.0f, 0.0f,     // Row 1, Column 1
      1.0f, 1.0f, 1.0f,     // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,  // Row 2, Column 1
      -2.0f, -2.0f, -2.0f,  // Row 2, Column 2
  });
  model.SetAlpha({0.0f, 1.0f, 2.0f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // Non-negative values pass through; negative values are scaled by alpha.
  const std::vector<float> expected = {
      0.0f, 0.0f, 0.0f,    // Row 1, Column 1
      1.0f, 1.0f, 1.0f,    // Row 1, Column 2
      0.0f, -1.0f, -2.0f,  // Row 2, Column 1
      0.0f, -2.0f, -4.0f,  // Row 2, Column 2
  };
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
2465 
TEST_P(PReluOpTest, PReluFloat32SameShapes) {
  // Same data as PReluFloat32, but alpha has exactly the input's shape, so
  // no broadcasting is involved.
  FloatPReluOpModel model({TensorType_FLOAT32, {1, 2, 2, 3}},
                          {TensorType_FLOAT32, {1, 2, 2, 3}});

  model.SetInput({
      0.0f, 0.0f, 0.0f,     // Row 1, Column 1
      1.0f, 1.0f, 1.0f,     // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,  // Row 2, Column 1
      -2.0f, -2.0f, -2.0f,  // Row 2, Column 2
  });
  model.SetAlpha({
      0.0f, 1.0f, 2.0f,  // Row 1, Column 1
      0.0f, 1.0f, 2.0f,  // Row 1, Column 2
      0.0f, 1.0f, 2.0f,  // Row 2, Column 1
      0.0f, 1.0f, 2.0f,  // Row 2, Column 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      0.0f, 0.0f, 0.0f,    // Row 1, Column 1
      1.0f, 1.0f, 1.0f,    // Row 1, Column 2
      0.0f, -1.0f, -2.0f,  // Row 2, Column 1
      0.0f, -2.0f, -4.0f,  // Row 2, Column 2
  };
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
2490 
TEST_P(PReluOpTest, PReluUInt8) {
  // Over [-1, 127/128] the uint8 codes used below put 0.0 at 128 and 0.5 at
  // 192, i.e. one code step is 1/128.
  constexpr float kRangeMin = -1;
  constexpr float kRangeMax = 127.f / 128.f;
  QuantizedPReluOpModel model(
      {TensorType_UINT8, {1, 2, 2, 3}, kRangeMin, kRangeMax},
      {TensorType_UINT8, {1, 1, 3}, kRangeMin, kRangeMax});
  model.SetInput<uint8_t>({
      0.0f, 0.0f, 0.0f,        // Row 1, Column 1
      0.5f, 0.5f, 0.5f,        // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,     // Row 2, Column 1
      -0.25f, -0.25f, -0.25f,  // Row 2, Column 2
  });
  model.SetAlpha<uint8_t>({0.0f, 0.5f, -0.5f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {
      0.0f, 0.0f, 0.0f,       // Row 1, Column 1
      0.5f, 0.5f, 0.5f,       // Row 1, Column 2
      0.0f, -0.5f, 0.5f,      // Row 2, Column 1
      0.0f, -0.125f, 0.125f,  // Row 2, Column 2
  };
  const std::vector<int> expected_codes = {
      128, 128, 128,  // Row 1, Column 1
      192, 192, 192,  // Row 1, Column 2
      128, 64, 192,   // Row 2, Column 1
      128, 112, 144,  // Row 2, Column 2
  };
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_values, kQuantizedTolerance)));
  EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray(expected_codes));
}
2520 
TEST_P(PReluOpTest, PReluUInt8SameShapes) {
  // Same data as PReluUInt8, but alpha has exactly the input's shape (no
  // broadcasting).
  constexpr float kRangeMin = -1;
  constexpr float kRangeMax = 127.f / 128.f;
  QuantizedPReluOpModel model(
      {TensorType_UINT8, {1, 2, 2, 3}, kRangeMin, kRangeMax},
      {TensorType_UINT8, {1, 2, 2, 3}, kRangeMin, kRangeMax});
  model.SetInput<uint8_t>({
      0.0f, 0.0f, 0.0f,        // Row 1, Column 1
      0.5f, 0.5f, 0.5f,        // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,     // Row 2, Column 1
      -0.25f, -0.25f, -0.25f,  // Row 2, Column 2
  });
  model.SetAlpha<uint8_t>({
      0.0f, 0.5f, -0.5f,  // Row 1, Column 1
      0.0f, 0.5f, -0.5f,  // Row 1, Column 2
      0.0f, 0.5f, -0.5f,  // Row 2, Column 1
      0.0f, 0.5f, -0.5f,  // Row 2, Column 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {
      0.0f, 0.0f, 0.0f,       // Row 1, Column 1
      0.5f, 0.5f, 0.5f,       // Row 1, Column 2
      0.0f, -0.5f, 0.5f,      // Row 2, Column 1
      0.0f, -0.125f, 0.125f,  // Row 2, Column 2
  };
  const std::vector<int> expected_codes = {
      128, 128, 128,  // Row 1, Column 1
      192, 192, 192,  // Row 1, Column 2
      128, 64, 192,   // Row 2, Column 1
      128, 112, 144,  // Row 2, Column 2
  };
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_values, kQuantizedTolerance)));
  EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray(expected_codes));
}
2555 
TEST_P(PReluOpTest, PReluInt8) {
  // Over [-1, 127/128] the int8 codes used below put 0.0 at 0 and 0.5 at 64,
  // i.e. one code step is 1/128. alpha {1, 1, 3} is broadcast per channel.
  constexpr float kRangeMin = -1;
  constexpr float kRangeMax = 127.f / 128.f;
  QuantizedPReluOpModel model(
      {TensorType_INT8, {1, 2, 2, 3}, kRangeMin, kRangeMax},
      {TensorType_INT8, {1, 1, 3}, kRangeMin, kRangeMax});
  model.SetInput<int8_t>({
      0.0f, 0.0f, 0.0f,        // Row 1, Column 1
      0.5f, 0.5f, 0.5f,        // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,     // Row 2, Column 1
      -0.25f, -0.25f, -0.25f,  // Row 2, Column 2
  });
  model.SetAlpha<int8_t>({0.0f, 0.5f, -0.5f});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {
      0.0f, 0.0f, 0.0f,       // Row 1, Column 1
      0.5f, 0.5f, 0.5f,       // Row 1, Column 2
      0.0f, -0.5f, 0.5f,      // Row 2, Column 1
      0.0f, -0.125f, 0.125f,  // Row 2, Column 2
  };
  const std::vector<int> expected_codes = {
      0, 0, 0,     // Row 1, Column 1
      64, 64, 64,  // Row 1, Column 2
      0, -64, 64,  // Row 2, Column 1
      0, -16, 16,  // Row 2, Column 2
  };
  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(
                  ArrayFloatNear(expected_values, kQuantizedTolerance)));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray(expected_codes));
}
2585 
TEST_P(PReluOpTest, PReluInt8SameShapes) {
  // Same data as PReluInt8, but alpha has exactly the input's shape
  // {1, 2, 2, 3} (no broadcasting), mirroring PReluUInt8SameShapes. Every
  // element repeats the per-channel slopes {0, 0.5, -0.5}, so the expected
  // results are identical to the broadcast case.
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  QuantizedPReluOpModel m({TensorType_INT8, {1, 2, 2, 3}, kMin, kMax},
                          {TensorType_INT8, {1, 2, 2, 3}, kMin, kMax});
  m.SetInput<int8_t>({
      0.0f, 0.0f, 0.0f,        // Row 1, Column 1
      0.5f, 0.5f, 0.5f,        // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,     // Row 2, Column 1
      -0.25f, -0.25f, -0.25f,  // Row 2, Column 2
  });
  m.SetAlpha<int8_t>({
      0.0f, 0.5f, -0.5f,  // Row 1, Column 1
      0.0f, 0.5f, -0.5f,  // Row 1, Column 2
      0.0f, 0.5f, -0.5f,  // Row 2, Column 1
      0.0f, 0.5f, -0.5f,  // Row 2, Column 2
  });
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {
                      0.0f, 0.0f, 0.0f,       // Row 1, Column 1
                      0.5f, 0.5f, 0.5f,       // Row 1, Column 2
                      0.0f, -0.5f, 0.5f,      // Row 2, Column 1
                      0.0f, -0.125f, 0.125f,  // Row 2, Column 2
                  },
                  kQuantizedTolerance)));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({
                                         0, 0, 0,     // Row 1, Column 1
                                         64, 64, 64,  // Row 1, Column 2
                                         0, -64, 64,  // Row 2, Column 1
                                         0, -16, 16,  // Row 2, Column 2
                                     }));
}
2615 
2616 class LeakyReluOpModel : public SingleOpModel {
2617  public:
LeakyReluOpModel(const TensorData & input,float alpha)2618   LeakyReluOpModel(const TensorData& input, float alpha) {
2619     input_ = AddInput(input);
2620     output_ = AddOutput(input);
2621     SetBuiltinOp(BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions,
2622                  CreateLeakyReluOptions(builder_, alpha).Union());
2623     BuildInterpreter({GetShape(input_)});
2624   }
SetInput(std::initializer_list<float> data)2625   void SetInput(std::initializer_list<float> data) {
2626     PopulateTensor(input_, data);
2627   }
GetOutput()2628   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
2629 
2630  protected:
2631   int input_;
2632   int output_;
2633 };
2634 
TEST(FloatActivationsOpTest, LeakyRelu) {
  // With alpha = 0.5, non-negative values pass through and negative values
  // are halved.
  LeakyReluOpModel model({TensorType_FLOAT32, {2, 3}}, 0.5f);

  model.SetInput({
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -1.0f, -2.0f,  // Row 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -0.5f, -1.0f,  // Row 2
  };
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected));
}
2648 
2649 class GeluOpModel : public SingleOpModel {
2650  public:
GeluOpModel(const TensorData & input,bool approximate)2651   GeluOpModel(const TensorData& input, bool approximate) {
2652     input_ = AddInput(input);
2653     output_ = AddOutput(input);
2654     SetBuiltinOp(BuiltinOperator_GELU, BuiltinOptions_GeluOptions,
2655                  CreateGeluOptions(builder_, approximate).Union());
2656     BuildInterpreter({GetShape(input_)});
2657   }
SetInput(std::initializer_list<float> data)2658   void SetInput(std::initializer_list<float> data) {
2659     PopulateTensor(input_, data);
2660   }
GetOutput()2661   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
2662 
2663  protected:
2664   int input_;
2665   int output_;
2666 };
2667 
2668 class BaseGeluOpModel : public SingleOpModel {
2669  public:
BaseGeluOpModel(const TensorData & input,bool approximate)2670   BaseGeluOpModel(const TensorData& input, bool approximate) {
2671     input_ = AddInput(input);
2672     approximate_ = approximate;
2673     output_ = AddOutput({input.type, input.shape, input.min, input.max});
2674     SetBuiltinOp(BuiltinOperator_GELU, BuiltinOptions_GeluOptions,
2675                  CreateGeluOptions(builder_, approximate).Union());
2676     BuildInterpreter({GetShape(input_)});
2677   }
2678 
2679  protected:
2680   int input_;
2681   bool approximate_;
2682   int output_;
2683 };
2684 
2685 // The FloatGeluOpModel class handles float input and output.
2686 class FloatGeluOpModel : public BaseGeluOpModel {
2687  public:
2688   using BaseGeluOpModel::BaseGeluOpModel;
2689 
SetInput(std::initializer_list<float> data)2690   void SetInput(std::initializer_list<float> data) {
2691     PopulateTensor(input_, data);
2692   }
GetOutput()2693   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
2694 };
2695 
2696 // The QuantizedGeluOpModel class handles quantized input and output.
2697 class QuantizedGeluOpModel : public BaseGeluOpModel {
2698  public:
2699   using BaseGeluOpModel::BaseGeluOpModel;
2700 
2701   template <typename T>
SetInput(std::initializer_list<float> data)2702   void SetInput(std::initializer_list<float> data) {
2703     QuantizeAndPopulate<T>(input_, data);
2704   }
2705   template <typename T>
GetOutput()2706   std::vector<T> GetOutput() {
2707     return ExtractVector<T>(output_);
2708   }
2709   template <typename T>
GetDequantizedOutput()2710   std::vector<float> GetDequantizedOutput() {
2711     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
2712                          GetZeroPoint(output_));
2713   }
2714 };
2715 
// Exact GELU (approximate = false) on a small float tensor, compared with
// ArrayFloatNear's default tolerance.
TEST(FloatActivationsOpTest, Gelu) {
  FloatGeluOpModel model({TensorType_FLOAT32, {2, 3}}, /*approximate=*/false);

  model.SetInput({
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -1.0f, -2.0f,  // Row 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      0.0f, 0.841345f, 2.99595f,           // Row 1
      0.841345f, -0.158655f, -0.0455003f,  // Row 2
  };
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
2729 
// Approximate GELU (approximate = true) on the same input as the Gelu test.
// The expected values differ slightly from the exact variant.
TEST(FloatActivationsOpTest, GeluApproximate) {
  FloatGeluOpModel model({TensorType_FLOAT32, {2, 3}}, /*approximate=*/true);

  model.SetInput({
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -1.0f, -2.0f,  // Row 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {
      0.0f, 0.841192f, 2.99636f,           // Row 1
      0.841192f, -0.158808f, -0.0454023f,  // Row 2
  };
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
2743 
TEST(QuantizedGeluOpTest,GeluInt8)2744 TEST(QuantizedGeluOpTest, GeluInt8) {
2745   const float kMin = -1;
2746   const float kMax = 127.f / 128.f;
2747   QuantizedGeluOpModel m({TensorType_INT8, {2, 3}, 3 * kMin, 3 * kMax},
2748                          /*approximate=*/false);
2749   m.SetInput<int8_t>({
2750       0.0f, 1.0f, 3.0f,    // Row 1
2751       1.0f, -1.0f, -2.0f,  // Row 2
2752   });
2753   ASSERT_EQ(m.Invoke(), kTfLiteOk);
2754   EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
2755               ElementsAreArray(ArrayFloatNear({
2756                   0.f, 0.84375f, 2.97656f,          // Row 1
2757                   0.84375f, -0.164062f, -0.046875f  // Row 2
2758               })));
2759 }
2760 
TEST(QuantizedGeluOpTest,GeluInt8Approximate)2761 TEST(QuantizedGeluOpTest, GeluInt8Approximate) {
2762   const float kMin = -1;
2763   const float kMax = 127.f / 128.f;
2764   QuantizedGeluOpModel m({TensorType_INT8, {2, 3}, 3 * kMin, 3 * kMax},
2765                          /*approximate=*/true);
2766   m.SetInput<int8_t>({
2767       0.0f, 1.0f, 3.0f,    // Row 1
2768       1.0f, -1.0f, -2.0f,  // Row 2
2769   });
2770   ASSERT_EQ(m.Invoke(), kTfLiteOk);
2771   EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
2772               ElementsAreArray(ArrayFloatNear({
2773                   0.f, 0.84375f, 2.97656f,          // Row 1
2774                   0.84375f, -0.164062f, -0.046875f  // Row 2
2775               })));
2776 }
TEST(QuantizedGeluOpTest,GeluUInt8)2777 TEST(QuantizedGeluOpTest, GeluUInt8) {
2778   const float kMin = -1;
2779   const float kMax = 127.f / 128.f;
2780   QuantizedGeluOpModel m({TensorType_UINT8, {2, 3}, 3 * kMin, 3 * kMax},
2781                          /*approximate=*/false);
2782   m.SetInput<uint8_t>({
2783       0.0f, 1.0f, 3.0f,    // Row 1
2784       1.0f, -1.0f, -2.0f,  // Row 2
2785   });
2786   ASSERT_EQ(m.Invoke(), kTfLiteOk);
2787   EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
2788               ElementsAreArray(ArrayFloatNear({
2789                   0.f, 0.84375f, 2.97656f,          // Row 1
2790                   0.84375f, -0.164062f, -0.046875f  // Row 2
2791               })));
2792 }
2793 
TEST(QuantizedGeluOpTest,GeluUInt8Approximate)2794 TEST(QuantizedGeluOpTest, GeluUInt8Approximate) {
2795   const float kMin = -1;
2796   const float kMax = 127.f / 128.f;
2797   QuantizedGeluOpModel m({TensorType_UINT8, {2, 3}, 3 * kMin, 3 * kMax},
2798                          /*approximate=*/true);
2799   m.SetInput<uint8_t>({
2800       0.0f, 1.0f, 3.0f,    // Row 1
2801       1.0f, -1.0f, -2.0f,  // Row 2
2802   });
2803   ASSERT_EQ(m.Invoke(), kTfLiteOk);
2804   EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
2805               ElementsAreArray(ArrayFloatNear({
2806                   0.f, 0.84375f, 2.97656f,          // Row 1
2807                   0.84375f, -0.164062f, -0.046875f  // Row 2
2808               })));
2809 }
2810 
2811 INSTANTIATE_TEST_SUITE_P(
2812     TanhOpTest, TanhOpTest,
2813     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kTanhKernelMap)));
2814 
2815 INSTANTIATE_TEST_SUITE_P(
2816     LogisticOpTest, LogisticOpTest,
2817     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kLogisticKernelMap)));
2818 
2819 INSTANTIATE_TEST_SUITE_P(
2820     LogSoftmaxOpTest, LogSoftmaxOpTest,
2821     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kLogSoftmaxKernelMap)));
2822 
2823 INSTANTIATE_TEST_SUITE_P(
2824     SoftmaxOpTest, SoftmaxOpTest,
2825     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kSoftmaxKernelMap)));
2826 
2827 INSTANTIATE_TEST_SUITE_P(
2828     PReluOpTest, PReluOpTest,
2829     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kPReluKernelMap)));
2830 
2831 INSTANTIATE_TEST_SUITE_P(
2832     LeakyReluOpTest, LeakyReluOpTest,
2833     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kLeakyReluKernelMap)));
2834 
2835 }  // namespace
2836 }  // namespace tflite
2837