// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <fp16.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/cache.h>

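// Test harness for the XNNPACK PReLU (Parametric ReLU) operator in NC layout.
// Test parameters are configured through the chainable setters below;
// TestF16() and TestF32() then create the operator, run it on random inputs,
// and compare its output against a scalar reference computed in this file.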
class PReLUOperatorTester {
 public:
  // Type in which the negative-slope weights are supplied to the operator:
  // Default matches the operator's data type, while FP32 passes fp32 weights
  // to the f16 operator (exercising XNN_FLAG_FP32_STATIC_WEIGHTS).
  enum class WeightsType {
    Default,
    FP32,
  };

  inline PReLUOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size != 0);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline PReLUOperatorTester& channels(size_t channels) {
    assert(channels != 0);
    this->channels_ = channels;
    return *this;
  }

  inline size_t channels() const {
    return this->channels_;
  }

  inline PReLUOperatorTester& x_stride(size_t x_stride) {
    assert(x_stride != 0);
    this->x_stride_ = x_stride;
    return *this;
  }

  // When no input stride was set explicitly, it defaults to the number of
  // channels (densely packed rows).
  inline size_t x_stride() const {
    if (this->x_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->x_stride_ >= this->channels_);
      return this->x_stride_;
    }
  }

  inline PReLUOperatorTester& y_stride(size_t y_stride) {
    assert(y_stride != 0);
    this->y_stride_ = y_stride;
    return *this;
  }

  // When no output stride was set explicitly, it defaults to the number of
  // channels (densely packed rows).
  inline size_t y_stride() const {
    if (this->y_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->y_stride_ >= this->channels_);
      return this->y_stride_;
    }
  }

  inline PReLUOperatorTester& weights_type(WeightsType weights_type) {
    this->weights_type_ = weights_type;
    return *this;
  }

  inline WeightsType weights_type() const {
    return this->weights_type_;
  }

  inline PReLUOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  inline PReLUOperatorTester& use_weights_cache(bool use_weights_cache) {
    this->use_weights_cache_ = use_weights_cache;
    return *this;
  }

  inline bool use_weights_cache() const {
    return this->use_weights_cache_;
  }

  void TestF16() const {
    switch (weights_type()) {
      case WeightsType::Default:
        break;
      case WeightsType::FP32:
        break;
      default:
        GTEST_FAIL() << "unexpected weights type";
    }

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32irng = std::uniform_real_distribution<float>(-1.0f, 1.0f);
    auto f32wrng = std::uniform_real_distribution<float>(0.25f, 0.75f);

    std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<uint16_t> w(channels());
    std::vector<float> w_as_float(channels());
    std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<float> y_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(x.begin(), x.end(), [&] { return fp16_ieee_from_fp32_value(f32irng(rng)); });
      std::generate(w.begin(), w.end(), [&] { return fp16_ieee_from_fp32_value(f32wrng(rng)); });
      std::transform(w.cbegin(), w.cend(), w_as_float.begin(), fp16_ieee_to_fp32_value);
      std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);

      // Compute reference results, without clamping.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]);
          const float w_value = w_as_float[c];
          y_ref[i * channels() + c] = std::signbit(x_value) ? x_value * w_value : x_value;
        }
      }

      // Create, setup, run, and destroy PReLU operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t prelu_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      const void* negative_slope_data = w.data();
      if (weights_type() == WeightsType::FP32) {
        negative_slope_data = w_as_float.data();
      }
      uint32_t flags = 0;
      if (weights_type() == WeightsType::FP32) {
        flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
      }
      ASSERT_EQ(xnn_status_success,
        xnn_create_prelu_nc_f16(
          channels(), x_stride(), y_stride(),
          negative_slope_data,
          flags, &caches, &prelu_op));
      ASSERT_NE(nullptr, prelu_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete prelu_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_prelu_nc_f16(
          prelu_op,
          batch_size(),
          x.data(), y.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(prelu_op, nullptr /* thread pool */));

      VerifyF16(y, y_ref);

      // Re-create the operator with identical parameters: its packed weights
      // should be served from the weights cache without growing it.
      if (use_weights_cache()) {
        xnn_operator_t prelu_op2 = nullptr;
        const size_t old_weights_cache_size = weights_cache.cache.weights.size;

        ASSERT_EQ(xnn_status_success,
                  xnn_create_prelu_nc_f16(
                      channels(), x_stride(), y_stride(),
                      negative_slope_data,
                      flags, &caches, &prelu_op2));
        ASSERT_NE(nullptr, prelu_op2);

        // Smart pointer to automatically delete prelu_op2.
        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op2(prelu_op2, xnn_delete_operator);

        std::vector<uint16_t> y2(y.size(), UINT16_C(0x7E00) /* NaN */);
        ASSERT_EQ(xnn_status_success,
                  xnn_setup_prelu_nc_f16(
                      prelu_op2,
                      batch_size(),
                      x.data(), y2.data(),
                      nullptr /* thread pool */));

        ASSERT_EQ(xnn_status_success,
                  xnn_run_operator(prelu_op2, nullptr /* thread pool */));

        VerifyF16(y2, y_ref);
        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);
      }
    }
  }

  void VerifyF16(const std::vector<uint16_t>& y, const std::vector<float>& y_ref) const {
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < channels(); c++) {
        ASSERT_NEAR(
            fp16_ieee_to_fp32_value(y[i * y_stride() + c]),
            y_ref[i * channels() + c],
            std::max(1.0e-4f, std::abs(y_ref[i * channels() + c]) * 1.0e-3f))
            << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
      }
    }
  }

  void TestF32() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32irng = std::uniform_real_distribution<float>(-1.0f, 1.0f);
    auto f32wrng = std::uniform_real_distribution<float>(0.25f, 0.75f);

    std::vector<float> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> w(channels());
    std::vector<float> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> y_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(x.begin(), x.end(), [&] { return f32irng(rng); });
      std::generate(w.begin(), w.end(), [&] { return f32wrng(rng); });
      std::fill(y.begin(), y.end(), nanf(""));

      // Compute reference results, without clamping.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          y_ref[i * channels() + c] = std::signbit(x[i * x_stride() + c]) ? x[i * x_stride() + c] * w[c] : x[i * x_stride() + c];
        }
      }

      // Create, setup, run, and destroy PReLU operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t prelu_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      ASSERT_EQ(xnn_status_success,
        xnn_create_prelu_nc_f32(
          channels(), x_stride(), y_stride(),
          w.data(),
          0, &caches, &prelu_op));
      ASSERT_NE(nullptr, prelu_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete prelu_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_prelu_nc_f32(
          prelu_op,
          batch_size(),
          x.data(), y.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(prelu_op, nullptr /* thread pool */));

      VerifyF32(y, y_ref);

      // Re-create the operator with identical parameters: its packed weights
      // should be served from the weights cache without growing it.
      if (use_weights_cache()) {
        xnn_operator_t prelu_op2 = nullptr;
        const size_t old_weights_cache_size = weights_cache.cache.weights.size;

        ASSERT_EQ(xnn_status_success,
                  xnn_create_prelu_nc_f32(
                      channels(), x_stride(), y_stride(),
                      w.data(),
                      0, &caches, &prelu_op2));
        ASSERT_NE(nullptr, prelu_op2);

        // Smart pointer to automatically delete prelu_op2.
        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op2(prelu_op2, xnn_delete_operator);
        std::vector<float> y2(y.size(), nanf(""));

        ASSERT_EQ(xnn_status_success,
                  xnn_setup_prelu_nc_f32(
                      prelu_op2,
                      batch_size(),
                      x.data(), y2.data(),
                      nullptr /* thread pool */));

        ASSERT_EQ(xnn_status_success,
                  xnn_run_operator(prelu_op2, nullptr /* thread pool */));

        // Verify the output of the second (cache-served) operator.
        VerifyF32(y2, y_ref);
        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);
      }
    }
  }

  void VerifyF32(const std::vector<float>& y, const std::vector<float>& y_ref) const {
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < channels(); c++) {
        ASSERT_NEAR(
            y[i * y_stride() + c],
            y_ref[i * channels() + c],
            std::max(1.0e-6f, std::abs(y_ref[i * channels() + c]) * 1.0e-6f))
          << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
      }
    }
  }

  void VerifyWeightsCache(const xnn_weights_cache& weights_cache, size_t old_size) const {
    ASSERT_EQ(weights_cache.cache.hits, 1);
    // Ensure that we did not write more weights to the cache because it was a cache hit.
    ASSERT_EQ(old_size, weights_cache.cache.weights.size);
  }

 private:
  size_t batch_size_{1};
  size_t channels_{1};
  size_t x_stride_{0};
  size_t y_stride_{0};
  WeightsType weights_type_{WeightsType::Default};
  bool use_weights_cache_{false};
  size_t iterations_{15};
};
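// A minimal usage sketch, assuming a gtest translation unit that includes this
// header (the test name and parameter values below are illustrative only, not
// taken from the actual test suite):
//
//   TEST(PRELU_NC_F32, batch_with_input_stride) {
//     PReLUOperatorTester()
//       .batch_size(3)
//       .channels(19)
//       .x_stride(23)
//       .iterations(3)
//       .TestF32();
//   }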