// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <fp16.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/cache.h>


// Test harness for the PReLU (parametric ReLU) operator: y = x for x >= 0,
// y = x * slope for x < 0, with one slope value per channel.
class PReLUOperatorTester {
 public:
  enum class WeightsType {
    Default,
    FP32,
  };

  inline PReLUOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size != 0);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline PReLUOperatorTester& channels(size_t channels) {
    assert(channels != 0);
    this->channels_ = channels;
    return *this;
  }

  inline size_t channels() const {
    return this->channels_;
  }

  inline PReLUOperatorTester& x_stride(size_t x_stride) {
    assert(x_stride != 0);
    this->x_stride_ = x_stride;
    return *this;
  }

  inline size_t x_stride() const {
    if (this->x_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->x_stride_ >= this->channels_);
      return this->x_stride_;
    }
  }

  inline PReLUOperatorTester& y_stride(size_t y_stride) {
    assert(y_stride != 0);
    this->y_stride_ = y_stride;
    return *this;
  }

  inline size_t y_stride() const {
    if (this->y_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->y_stride_ >= this->channels_);
      return this->y_stride_;
    }
  }

  inline PReLUOperatorTester& weights_type(WeightsType weights_type) {
    this->weights_type_ = weights_type;
    return *this;
  }

  inline WeightsType weights_type() const {
    return this->weights_type_;
  }

  inline PReLUOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  inline PReLUOperatorTester& use_weights_cache(bool use_weights_cache) {
    this->use_weights_cache_ = use_weights_cache;
    return *this;
  }

  inline bool use_weights_cache() const {
    return this->use_weights_cache_;
  }

  void TestF16() const {
    switch (weights_type()) {
      case WeightsType::Default:
        break;
      case WeightsType::FP32:
        break;
      default:
        GTEST_FAIL() << "unexpected weights type";
    }

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32irng = std::uniform_real_distribution<float>(-1.0f, 1.0f);
    auto f32wrng = std::uniform_real_distribution<float>(0.25f, 0.75f);

    std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<uint16_t> w(channels());
    std::vector<float> w_as_float(channels());
    std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<float> y_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(x.begin(), x.end(), [&] { return fp16_ieee_from_fp32_value(f32irng(rng)); });
      std::generate(w.begin(), w.end(), [&] { return fp16_ieee_from_fp32_value(f32wrng(rng)); });
      std::transform(w.cbegin(), w.cend(), w_as_float.begin(), fp16_ieee_to_fp32_value);
      std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);

      // Compute reference results, without clamping.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]);
          const float w_value = w_as_float[c];
          y_ref[i * channels() + c] = std::signbit(x_value) ? x_value * w_value : x_value;
        }
      }

      // Create, setup, run, and destroy PReLU operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t prelu_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      const void* negative_slope_data = w.data();
      if (weights_type() == WeightsType::FP32) {
        negative_slope_data = w_as_float.data();
      }
      uint32_t flags = 0;
      if (weights_type() == WeightsType::FP32) {
        flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
      }
      ASSERT_EQ(xnn_status_success,
        xnn_create_prelu_nc_f16(
          channels(), x_stride(), y_stride(),
          negative_slope_data,
          flags, &caches, &prelu_op));
      ASSERT_NE(nullptr, prelu_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete prelu_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_prelu_nc_f16(
          prelu_op,
          batch_size(),
          x.data(), y.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(prelu_op, nullptr /* thread pool */));

      VerifyF16(y, y_ref);

      if (use_weights_cache()) {
        xnn_operator_t prelu_op2 = nullptr;
        const size_t old_weights_cache_size = weights_cache.cache.weights.size;

        ASSERT_EQ(xnn_status_success,
                  xnn_create_prelu_nc_f16(
                      channels(), x_stride(), y_stride(),
                      negative_slope_data,
                      flags, &caches, &prelu_op2));
        ASSERT_NE(nullptr, prelu_op2);

        // Smart pointer to automatically delete prelu_op2.
        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op2, xnn_delete_operator);

        std::vector<uint16_t> y2(y.size(), UINT16_C(0x7E00) /* NaN */);
        ASSERT_EQ(xnn_status_success,
                  xnn_setup_prelu_nc_f16(
                      prelu_op2,
                      batch_size(),
                      x.data(), y2.data(),
                      nullptr /* thread pool */));

        ASSERT_EQ(xnn_status_success,
                  xnn_run_operator(prelu_op2, nullptr /* thread pool */));

        VerifyF16(y2, y_ref);
        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);
      }
    }
  }

  void VerifyF16(const std::vector<uint16_t>& y, const std::vector<float>& y_ref) const {
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < channels(); c++) {
        ASSERT_NEAR(
            fp16_ieee_to_fp32_value(y[i * y_stride() + c]),
            y_ref[i * channels() + c],
            std::max(1.0e-4f, std::abs(y_ref[i * channels() + c]) * 1.0e-3f))
          << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
      }
    }
  }

  void TestF32() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32irng = std::uniform_real_distribution<float>(-1.0f, 1.0f);
    auto f32wrng = std::uniform_real_distribution<float>(0.25f, 0.75f);

    std::vector<float> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> w(channels());
    std::vector<float> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> y_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(x.begin(), x.end(), [&] { return f32irng(rng); });
      std::generate(w.begin(), w.end(), [&] { return f32wrng(rng); });
      std::fill(y.begin(), y.end(), nanf(""));

      // Compute reference results, without clamping.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          y_ref[i * channels() + c] = std::signbit(x[i * x_stride() + c]) ? x[i * x_stride() + c] * w[c] : x[i * x_stride() + c];
        }
      }

      // Create, setup, run, and destroy PReLU operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t prelu_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      ASSERT_EQ(xnn_status_success,
        xnn_create_prelu_nc_f32(
          channels(), x_stride(), y_stride(),
          w.data(),
          0, &caches, &prelu_op));
      ASSERT_NE(nullptr, prelu_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete prelu_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_prelu_nc_f32(
          prelu_op,
          batch_size(),
          x.data(), y.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(prelu_op, nullptr /* thread pool */));

      VerifyF32(y, y_ref);

      if (use_weights_cache()) {
        xnn_operator_t prelu_op2 = nullptr;
        const size_t old_weights_cache_size = weights_cache.cache.weights.size;

        ASSERT_EQ(xnn_status_success,
                  xnn_create_prelu_nc_f32(
                      channels(), x_stride(), y_stride(),
                      w.data(),
                      0, &caches, &prelu_op2));
        ASSERT_NE(nullptr, prelu_op2);

        // Smart pointer to automatically delete prelu_op2.
        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op2, xnn_delete_operator);
        std::vector<float> y2(y.size(), nanf(""));

        ASSERT_EQ(xnn_status_success,
                  xnn_setup_prelu_nc_f32(
                      prelu_op2,
                      batch_size(),
                      x.data(), y2.data(),
                      nullptr /* thread pool */));

        ASSERT_EQ(xnn_status_success,
                  xnn_run_operator(prelu_op2, nullptr /* thread pool */));

        // Verify the output of the second (cache-hit) run, not the first one.
        VerifyF32(y2, y_ref);
        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);
      }
    }
  }

  void VerifyF32(const std::vector<float>& y, const std::vector<float>& y_ref) const {
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < channels(); c++) {
        ASSERT_NEAR(
            y[i * y_stride() + c],
            y_ref[i * channels() + c],
            std::max(1.0e-6f, std::abs(y_ref[i * channels() + c]) * 1.0e-6f))
          << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
      }
    }
  }

  void VerifyWeightsCache(const xnn_weights_cache& weights_cache, size_t old_size) const {
    ASSERT_EQ(weights_cache.cache.hits, 1);
    // Ensure that we did not write more weights to the cache because it was a cache hit.
    ASSERT_EQ(old_size, weights_cache.cache.weights.size);
  }

 private:
  size_t batch_size_{1};
  size_t channels_{1};
  size_t x_stride_{0};
  size_t y_stride_{0};
  WeightsType weights_type_{WeightsType::Default};
  bool use_weights_cache_{false};
  size_t iterations_{15};
};
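
// A minimal usage sketch (illustrative only, not part of the original header):
// a test translation unit would include this file and drive the tester through
// its fluent setters before calling TestF32() or TestF16(). The test suite and
// case names below are hypothetical.
//
//   #include "prelu-operator-tester.h"
//
//   TEST(PRELU_NC_F32, small_batch_with_x_stride) {
//     PReLUOperatorTester()
//       .batch_size(3)
//       .channels(17)
//       .x_stride(23)
//       .iterations(3)
//       .TestF32();
//   }
//
//   TEST(PRELU_NC_F16, fp32_weights_with_weights_cache) {
//     PReLUOperatorTester()
//       .batch_size(3)
//       .channels(19)
//       .weights_type(PReLUOperatorTester::WeightsType::FP32)
//       .use_weights_cache(true)
//       .TestF16();
//   }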