xref: /aosp_15_r20/external/XNNPACK/test/tanh-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
21 
22 
23 class TanhOperatorTester {
24  public:
channels(size_t channels)25   inline TanhOperatorTester& channels(size_t channels) {
26     assert(channels != 0);
27     this->channels_ = channels;
28     return *this;
29   }
30 
channels()31   inline size_t channels() const {
32     return this->channels_;
33   }
34 
input_stride(size_t input_stride)35   inline TanhOperatorTester& input_stride(size_t input_stride) {
36     assert(input_stride != 0);
37     this->input_stride_ = input_stride;
38     return *this;
39   }
40 
input_stride()41   inline size_t input_stride() const {
42     if (this->input_stride_ == 0) {
43       return this->channels_;
44     } else {
45       assert(this->input_stride_ >= this->channels_);
46       return this->input_stride_;
47     }
48   }
49 
output_stride(size_t output_stride)50   inline TanhOperatorTester& output_stride(size_t output_stride) {
51     assert(output_stride != 0);
52     this->output_stride_ = output_stride;
53     return *this;
54   }
55 
output_stride()56   inline size_t output_stride() const {
57     if (this->output_stride_ == 0) {
58       return this->channels_;
59     } else {
60       assert(this->output_stride_ >= this->channels_);
61       return this->output_stride_;
62     }
63   }
64 
batch_size(size_t batch_size)65   inline TanhOperatorTester& batch_size(size_t batch_size) {
66     assert(batch_size != 0);
67     this->batch_size_ = batch_size;
68     return *this;
69   }
70 
batch_size()71   inline size_t batch_size() const {
72     return this->batch_size_;
73   }
74 
input_scale(float input_scale)75   inline TanhOperatorTester& input_scale(float input_scale) {
76     assert(input_scale > 0.0f);
77     assert(std::isnormal(input_scale));
78     this->input_scale_ = input_scale;
79     return *this;
80   }
81 
input_scale()82   inline float input_scale() const {
83     return this->input_scale_;
84   }
85 
input_zero_point(uint8_t input_zero_point)86   inline TanhOperatorTester& input_zero_point(uint8_t input_zero_point) {
87     this->input_zero_point_ = input_zero_point;
88     return *this;
89   }
90 
input_zero_point()91   inline uint8_t input_zero_point() const {
92     return this->input_zero_point_;
93   }
94 
output_scale()95   inline float output_scale() const {
96     return 1.0f / 128.0f;
97   }
98 
output_zero_point()99   inline uint8_t output_zero_point() const {
100     return 128;
101   }
102 
qmin(uint8_t qmin)103   inline TanhOperatorTester& qmin(uint8_t qmin) {
104     this->qmin_ = qmin;
105     return *this;
106   }
107 
qmin()108   inline uint8_t qmin() const {
109     return this->qmin_;
110   }
111 
qmax(uint8_t qmax)112   inline TanhOperatorTester& qmax(uint8_t qmax) {
113     this->qmax_ = qmax;
114     return *this;
115   }
116 
qmax()117   inline uint8_t qmax() const {
118     return this->qmax_;
119   }
120 
iterations(size_t iterations)121   inline TanhOperatorTester& iterations(size_t iterations) {
122     this->iterations_ = iterations;
123     return *this;
124   }
125 
iterations()126   inline size_t iterations() const {
127     return this->iterations_;
128   }
129 
TestQS8()130   void TestQS8() const {
131     std::random_device random_device;
132     auto rng = std::mt19937(random_device());
133     auto i8rng = std::bind(
134       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
135       std::ref(rng));
136 
137     std::vector<int8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
138     std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels());
139     std::vector<float> output_ref(batch_size() * channels());
140     for (size_t iteration = 0; iteration < iterations(); iteration++) {
141       std::generate(input.begin(), input.end(), std::ref(i8rng));
142       std::fill(output.begin(), output.end(), 0xA5);
143 
144       // Compute reference results.
145       for (size_t i = 0; i < batch_size(); i++) {
146         for (size_t c = 0; c < channels(); c++) {
147           const float x = input_scale() *
148             (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point() - 0x80));
149           const float tanh_x = std::tanh(x);
150           const float scaled_tanh_x = tanh_x / output_scale();
151           float y = scaled_tanh_x;
152           y = std::min<float>(y, int32_t(qmax() - 0x80) - int32_t(output_zero_point() - 0x80));
153           y = std::max<float>(y, int32_t(qmin() - 0x80) - int32_t(output_zero_point() - 0x80));
154           output_ref[i * channels() + c] = y + int32_t(output_zero_point() - 0x80);
155         }
156       }
157 
158       // Create, setup, run, and destroy Sigmoid operator.
159       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
160       xnn_operator_t tanh_op = nullptr;
161 
162       ASSERT_EQ(xnn_status_success,
163         xnn_create_tanh_nc_qs8(
164           channels(), input_stride(), output_stride(),
165           int8_t(input_zero_point() - 0x80), input_scale(),
166           int8_t(output_zero_point() - 0x80), output_scale(),
167           int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
168           0, &tanh_op));
169       ASSERT_NE(nullptr, tanh_op);
170 
171       // Smart pointer to automatically delete tanh_op.
172       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_tanh_op(tanh_op, xnn_delete_operator);
173 
174       ASSERT_EQ(xnn_status_success,
175         xnn_setup_tanh_nc_qs8(
176           tanh_op,
177           batch_size(),
178           input.data(), output.data(),
179           nullptr /* thread pool */));
180 
181       ASSERT_EQ(xnn_status_success,
182         xnn_run_operator(tanh_op, nullptr /* thread pool */));
183 
184       // Verify results.
185       for (size_t i = 0; i < batch_size(); i++) {
186         for (size_t c = 0; c < channels(); c++) {
187           ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
188         }
189       }
190     }
191   }
192 
TestQU8()193   void TestQU8() const {
194     std::random_device random_device;
195     auto rng = std::mt19937(random_device());
196     auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
197 
198     std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
199     std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
200     std::vector<float> output_ref(batch_size() * channels());
201     for (size_t iteration = 0; iteration < iterations(); iteration++) {
202       std::generate(input.begin(), input.end(), std::ref(u8rng));
203       std::fill(output.begin(), output.end(), 0xA5);
204 
205       // Compute reference results.
206       for (size_t i = 0; i < batch_size(); i++) {
207         for (size_t c = 0; c < channels(); c++) {
208           const float x = input_scale() *
209             (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point()));
210           const float tanh_x = std::tanh(x);
211           const float scaled_tanh_x = tanh_x / output_scale();
212           float y = scaled_tanh_x;
213           y = std::min<float>(y, int32_t(qmax()) - int32_t(output_zero_point()));
214           y = std::max<float>(y, int32_t(qmin()) - int32_t(output_zero_point()));
215           output_ref[i * channels() + c] = y + int32_t(output_zero_point());
216         }
217       }
218 
219       // Create, setup, run, and destroy Sigmoid operator.
220       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
221       xnn_operator_t tanh_op = nullptr;
222 
223       ASSERT_EQ(xnn_status_success,
224         xnn_create_tanh_nc_qu8(
225           channels(), input_stride(), output_stride(),
226           input_zero_point(), input_scale(),
227           output_zero_point(), output_scale(),
228           qmin(), qmax(),
229           0, &tanh_op));
230       ASSERT_NE(nullptr, tanh_op);
231 
232       // Smart pointer to automatically delete tanh_op.
233       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_tanh_op(tanh_op, xnn_delete_operator);
234 
235       ASSERT_EQ(xnn_status_success,
236         xnn_setup_tanh_nc_qu8(
237           tanh_op,
238           batch_size(),
239           input.data(), output.data(),
240           nullptr /* thread pool */));
241 
242       ASSERT_EQ(xnn_status_success,
243         xnn_run_operator(tanh_op, nullptr /* thread pool */));
244 
245       // Verify results.
246       for (size_t i = 0; i < batch_size(); i++) {
247         for (size_t c = 0; c < channels(); c++) {
248           ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
249         }
250       }
251     }
252   }
253 
254  private:
255   size_t batch_size_{1};
256   size_t channels_{1};
257   size_t input_stride_{0};
258   size_t output_stride_{0};
259   float input_scale_{0.75f};
260   uint8_t input_zero_point_{121};
261   uint8_t qmin_{0};
262   uint8_t qmax_{255};
263   size_t iterations_{15};
264 };
265