1 // Copyright 2021 Google LLC 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <gtest/gtest.h> 9 10 #include <algorithm> 11 #include <cassert> 12 #include <cmath> 13 #include <cstddef> 14 #include <cstdlib> 15 #include <functional> 16 #include <limits> 17 #include <random> 18 #include <vector> 19 20 #include <xnnpack.h> 21 22 23 class TanhOperatorTester { 24 public: channels(size_t channels)25 inline TanhOperatorTester& channels(size_t channels) { 26 assert(channels != 0); 27 this->channels_ = channels; 28 return *this; 29 } 30 channels()31 inline size_t channels() const { 32 return this->channels_; 33 } 34 input_stride(size_t input_stride)35 inline TanhOperatorTester& input_stride(size_t input_stride) { 36 assert(input_stride != 0); 37 this->input_stride_ = input_stride; 38 return *this; 39 } 40 input_stride()41 inline size_t input_stride() const { 42 if (this->input_stride_ == 0) { 43 return this->channels_; 44 } else { 45 assert(this->input_stride_ >= this->channels_); 46 return this->input_stride_; 47 } 48 } 49 output_stride(size_t output_stride)50 inline TanhOperatorTester& output_stride(size_t output_stride) { 51 assert(output_stride != 0); 52 this->output_stride_ = output_stride; 53 return *this; 54 } 55 output_stride()56 inline size_t output_stride() const { 57 if (this->output_stride_ == 0) { 58 return this->channels_; 59 } else { 60 assert(this->output_stride_ >= this->channels_); 61 return this->output_stride_; 62 } 63 } 64 batch_size(size_t batch_size)65 inline TanhOperatorTester& batch_size(size_t batch_size) { 66 assert(batch_size != 0); 67 this->batch_size_ = batch_size; 68 return *this; 69 } 70 batch_size()71 inline size_t batch_size() const { 72 return this->batch_size_; 73 } 74 input_scale(float input_scale)75 inline TanhOperatorTester& input_scale(float input_scale) { 76 assert(input_scale > 0.0f); 77 assert(std::isnormal(input_scale)); 78 this->input_scale_ = input_scale; 79 return *this; 80 } 81 input_scale()82 inline float input_scale() const { 83 return this->input_scale_; 84 } 85 input_zero_point(uint8_t input_zero_point)86 inline TanhOperatorTester& input_zero_point(uint8_t input_zero_point) { 87 this->input_zero_point_ = input_zero_point; 88 return *this; 89 } 90 input_zero_point()91 inline uint8_t input_zero_point() const { 92 return this->input_zero_point_; 93 } 94 output_scale()95 inline float output_scale() const { 96 return 1.0f / 128.0f; 97 } 98 output_zero_point()99 inline uint8_t output_zero_point() const { 100 return 128; 101 } 102 qmin(uint8_t qmin)103 inline TanhOperatorTester& qmin(uint8_t qmin) { 104 this->qmin_ = qmin; 105 return *this; 106 } 107 qmin()108 inline uint8_t qmin() const { 109 return this->qmin_; 110 } 111 qmax(uint8_t qmax)112 inline TanhOperatorTester& qmax(uint8_t qmax) { 113 this->qmax_ = qmax; 114 return *this; 115 } 116 qmax()117 inline uint8_t qmax() const { 118 return this->qmax_; 119 } 120 iterations(size_t iterations)121 inline TanhOperatorTester& iterations(size_t iterations) { 122 this->iterations_ = iterations; 123 return *this; 124 } 125 iterations()126 inline size_t iterations() const { 127 return this->iterations_; 128 } 129 TestQS8()130 void TestQS8() const { 131 std::random_device random_device; 132 auto rng = std::mt19937(random_device()); 133 auto i8rng = std::bind( 134 std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()), 135 std::ref(rng)); 136 137 std::vector<int8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t)); 138 std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels()); 139 std::vector<float> output_ref(batch_size() * channels()); 140 for (size_t iteration = 0; iteration < iterations(); iteration++) { 141 std::generate(input.begin(), input.end(), std::ref(i8rng)); 142 std::fill(output.begin(), output.end(), 0xA5); 143 144 // Compute reference results. 145 for (size_t i = 0; i < batch_size(); i++) { 146 for (size_t c = 0; c < channels(); c++) { 147 const float x = input_scale() * 148 (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point() - 0x80)); 149 const float tanh_x = std::tanh(x); 150 const float scaled_tanh_x = tanh_x / output_scale(); 151 float y = scaled_tanh_x; 152 y = std::min<float>(y, int32_t(qmax() - 0x80) - int32_t(output_zero_point() - 0x80)); 153 y = std::max<float>(y, int32_t(qmin() - 0x80) - int32_t(output_zero_point() - 0x80)); 154 output_ref[i * channels() + c] = y + int32_t(output_zero_point() - 0x80); 155 } 156 } 157 158 // Create, setup, run, and destroy Sigmoid operator. 159 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 160 xnn_operator_t tanh_op = nullptr; 161 162 ASSERT_EQ(xnn_status_success, 163 xnn_create_tanh_nc_qs8( 164 channels(), input_stride(), output_stride(), 165 int8_t(input_zero_point() - 0x80), input_scale(), 166 int8_t(output_zero_point() - 0x80), output_scale(), 167 int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 168 0, &tanh_op)); 169 ASSERT_NE(nullptr, tanh_op); 170 171 // Smart pointer to automatically delete tanh_op. 172 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_tanh_op(tanh_op, xnn_delete_operator); 173 174 ASSERT_EQ(xnn_status_success, 175 xnn_setup_tanh_nc_qs8( 176 tanh_op, 177 batch_size(), 178 input.data(), output.data(), 179 nullptr /* thread pool */)); 180 181 ASSERT_EQ(xnn_status_success, 182 xnn_run_operator(tanh_op, nullptr /* thread pool */)); 183 184 // Verify results. 185 for (size_t i = 0; i < batch_size(); i++) { 186 for (size_t c = 0; c < channels(); c++) { 187 ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f); 188 } 189 } 190 } 191 } 192 TestQU8()193 void TestQU8() const { 194 std::random_device random_device; 195 auto rng = std::mt19937(random_device()); 196 auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng); 197 198 std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 199 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels()); 200 std::vector<float> output_ref(batch_size() * channels()); 201 for (size_t iteration = 0; iteration < iterations(); iteration++) { 202 std::generate(input.begin(), input.end(), std::ref(u8rng)); 203 std::fill(output.begin(), output.end(), 0xA5); 204 205 // Compute reference results. 206 for (size_t i = 0; i < batch_size(); i++) { 207 for (size_t c = 0; c < channels(); c++) { 208 const float x = input_scale() * 209 (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point())); 210 const float tanh_x = std::tanh(x); 211 const float scaled_tanh_x = tanh_x / output_scale(); 212 float y = scaled_tanh_x; 213 y = std::min<float>(y, int32_t(qmax()) - int32_t(output_zero_point())); 214 y = std::max<float>(y, int32_t(qmin()) - int32_t(output_zero_point())); 215 output_ref[i * channels() + c] = y + int32_t(output_zero_point()); 216 } 217 } 218 219 // Create, setup, run, and destroy Sigmoid operator. 220 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 221 xnn_operator_t tanh_op = nullptr; 222 223 ASSERT_EQ(xnn_status_success, 224 xnn_create_tanh_nc_qu8( 225 channels(), input_stride(), output_stride(), 226 input_zero_point(), input_scale(), 227 output_zero_point(), output_scale(), 228 qmin(), qmax(), 229 0, &tanh_op)); 230 ASSERT_NE(nullptr, tanh_op); 231 232 // Smart pointer to automatically delete tanh_op. 233 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_tanh_op(tanh_op, xnn_delete_operator); 234 235 ASSERT_EQ(xnn_status_success, 236 xnn_setup_tanh_nc_qu8( 237 tanh_op, 238 batch_size(), 239 input.data(), output.data(), 240 nullptr /* thread pool */)); 241 242 ASSERT_EQ(xnn_status_success, 243 xnn_run_operator(tanh_op, nullptr /* thread pool */)); 244 245 // Verify results. 246 for (size_t i = 0; i < batch_size(); i++) { 247 for (size_t c = 0; c < channels(); c++) { 248 ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f); 249 } 250 } 251 } 252 } 253 254 private: 255 size_t batch_size_{1}; 256 size_t channels_{1}; 257 size_t input_stride_{0}; 258 size_t output_stride_{0}; 259 float input_scale_{0.75f}; 260 uint8_t input_zero_point_{121}; 261 uint8_t qmin_{0}; 262 uint8_t qmax_{255}; 263 size_t iterations_{15}; 264 }; 265