1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_BINARY_ELEMENTWISE_TESTER_H_ 17 #define TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_BINARY_ELEMENTWISE_TESTER_H_ 18 19 #include <cstdint> 20 #include <vector> 21 22 #include <gtest/gtest.h> 23 #include "tensorflow/lite/c/common.h" 24 #include "tensorflow/lite/interpreter.h" 25 #include "tensorflow/lite/schema/schema_generated.h" 26 27 namespace tflite { 28 namespace xnnpack { 29 30 class QuantizedBinaryElementwiseTester { 31 public: 32 QuantizedBinaryElementwiseTester() = default; 33 QuantizedBinaryElementwiseTester(const QuantizedBinaryElementwiseTester&) = 34 delete; 35 QuantizedBinaryElementwiseTester& operator=( 36 const QuantizedBinaryElementwiseTester&) = delete; 37 Input1Shape(std::initializer_list<int32_t> shape)38 inline QuantizedBinaryElementwiseTester& Input1Shape( 39 std::initializer_list<int32_t> shape) { 40 for (auto it = shape.begin(); it != shape.end(); ++it) { 41 EXPECT_GT(*it, 0); 42 } 43 input1_shape_ = std::vector<int32_t>(shape.begin(), shape.end()); 44 return *this; 45 } 46 Input1Shape()47 inline const std::vector<int32_t>& Input1Shape() const { 48 return input1_shape_; 49 } 50 Input2Shape(std::initializer_list<int32_t> shape)51 inline QuantizedBinaryElementwiseTester& Input2Shape( 52 std::initializer_list<int32_t> shape) { 53 for (auto it = shape.begin(); it != shape.end(); ++it) { 54 EXPECT_GT(*it, 0); 55 } 56 input2_shape_ = std::vector<int32_t>(shape.begin(), shape.end()); 57 return *this; 58 } 59 Input2Shape()60 inline const std::vector<int32_t>& Input2Shape() const { 61 return input2_shape_; 62 } 63 64 std::vector<int32_t> OutputShape() const; 65 Input1Static(bool is_static)66 inline QuantizedBinaryElementwiseTester& Input1Static(bool is_static) { 67 input1_static_ = is_static; 68 return *this; 69 } 70 Input1Static()71 inline bool Input1Static() const { return input1_static_; } 72 Input2Static(bool is_static)73 inline QuantizedBinaryElementwiseTester& Input2Static(bool is_static) { 74 input2_static_ = is_static; 75 return *this; 76 } 77 Input2Static()78 inline bool Input2Static() const { return input2_static_; } 79 Input1ZeroPoint(int32_t input1_zero_point)80 inline QuantizedBinaryElementwiseTester& Input1ZeroPoint( 81 int32_t input1_zero_point) { 82 input1_zero_point_ = input1_zero_point; 83 return *this; 84 } 85 Input1ZeroPoint()86 inline int32_t Input1ZeroPoint() const { return input1_zero_point_; } 87 Input2ZeroPoint(int32_t input2_zero_point)88 inline QuantizedBinaryElementwiseTester& Input2ZeroPoint( 89 int32_t input2_zero_point) { 90 input2_zero_point_ = input2_zero_point; 91 return *this; 92 } 93 Input2ZeroPoint()94 inline int32_t Input2ZeroPoint() const { return input2_zero_point_; } 95 OutputZeroPoint(int32_t output_zero_point)96 inline QuantizedBinaryElementwiseTester& OutputZeroPoint( 97 int32_t output_zero_point) { 98 output_zero_point_ = output_zero_point; 99 return *this; 100 } 101 OutputZeroPoint()102 inline int32_t OutputZeroPoint() const { return output_zero_point_; } 103 Input1Scale(float input1_scale)104 inline QuantizedBinaryElementwiseTester& Input1Scale(float input1_scale) { 105 input1_scale_ = input1_scale; 106 return *this; 107 } 108 Input1Scale()109 inline float Input1Scale() const { return input1_scale_; } 110 Input2Scale(float input2_scale)111 inline QuantizedBinaryElementwiseTester& Input2Scale(float input2_scale) { 112 input2_scale_ = input2_scale; 113 return *this; 114 } 115 Input2Scale()116 inline float Input2Scale() const { return input2_scale_; } 117 OutputScale(float output_scale)118 inline QuantizedBinaryElementwiseTester& OutputScale(float output_scale) { 119 output_scale_ = output_scale; 120 return *this; 121 } 122 OutputScale()123 inline float OutputScale() const { return output_scale_; } 124 Unsigned(bool is_unsigned)125 inline QuantizedBinaryElementwiseTester& Unsigned(bool is_unsigned) { 126 unsigned_ = is_unsigned; 127 return *this; 128 } 129 Unsigned()130 inline bool Unsigned() const { return unsigned_; } 131 ReluActivation()132 inline QuantizedBinaryElementwiseTester& ReluActivation() { 133 activation_ = ::tflite::ActivationFunctionType_RELU; 134 return *this; 135 } 136 Relu6Activation()137 inline QuantizedBinaryElementwiseTester& Relu6Activation() { 138 activation_ = ::tflite::ActivationFunctionType_RELU6; 139 return *this; 140 } 141 ReluMinus1To1Activation()142 inline QuantizedBinaryElementwiseTester& ReluMinus1To1Activation() { 143 activation_ = ::tflite::ActivationFunctionType_RELU_N1_TO_1; 144 return *this; 145 } 146 147 template <class T> 148 void Test(Interpreter* delegate_interpreter, 149 Interpreter* default_interpreter) const; 150 151 void Test(tflite::BuiltinOperator binary_op, TfLiteDelegate* delegate) const; 152 153 private: 154 std::vector<char> CreateTfLiteModel(tflite::BuiltinOperator binary_op) const; 155 Activation()156 inline ::tflite::ActivationFunctionType Activation() const { 157 return activation_; 158 } 159 160 static int32_t ComputeSize(const std::vector<int32_t>& shape); 161 162 std::vector<int32_t> input1_shape_; 163 std::vector<int32_t> input2_shape_; 164 bool input1_static_ = false; 165 bool input2_static_ = false; 166 int32_t input1_zero_point_ = 0; 167 int32_t input2_zero_point_ = 0; 168 int32_t output_zero_point_ = 0; 169 float input1_scale_ = 0.75f; 170 float input2_scale_ = 1.0f; 171 float output_scale_ = 1.75f; 172 bool unsigned_ = false; 173 ::tflite::ActivationFunctionType activation_ = 174 ::tflite::ActivationFunctionType_NONE; 175 }; 176 177 } // namespace xnnpack 178 } // namespace tflite 179 180 #endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_BINARY_ELEMENTWISE_TESTER_H_ 181