1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates. 2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved. 3*4bdc9457SAndroid Build Coastguard Worker // 4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 5*4bdc9457SAndroid Build Coastguard Worker // 6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #pragma once 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 14*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 15*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 16*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 17*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 18*4bdc9457SAndroid Build Coastguard Worker #include <limits> 19*4bdc9457SAndroid Build Coastguard Worker #include <random> 20*4bdc9457SAndroid Build Coastguard Worker #include <vector> 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 23*4bdc9457SAndroid Build Coastguard Worker 24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 25*4bdc9457SAndroid Build Coastguard Worker 26*4bdc9457SAndroid Build Coastguard Worker 27*4bdc9457SAndroid Build Coastguard Worker class SigmoidOperatorTester { 28*4bdc9457SAndroid Build Coastguard Worker public: channels(size_t channels)29*4bdc9457SAndroid Build Coastguard Worker inline SigmoidOperatorTester& channels(size_t channels) { 30*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 31*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 32*4bdc9457SAndroid Build Coastguard Worker return *this; 33*4bdc9457SAndroid Build Coastguard Worker } 34*4bdc9457SAndroid Build Coastguard Worker channels()35*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 36*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 37*4bdc9457SAndroid Build Coastguard Worker } 38*4bdc9457SAndroid Build Coastguard Worker input_stride(size_t input_stride)39*4bdc9457SAndroid Build Coastguard Worker inline SigmoidOperatorTester& input_stride(size_t input_stride) { 40*4bdc9457SAndroid Build Coastguard Worker assert(input_stride != 0); 41*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 42*4bdc9457SAndroid Build Coastguard Worker return *this; 43*4bdc9457SAndroid Build Coastguard Worker } 44*4bdc9457SAndroid Build Coastguard Worker input_stride()45*4bdc9457SAndroid Build Coastguard Worker inline size_t input_stride() const { 46*4bdc9457SAndroid Build Coastguard Worker if (this->input_stride_ == 0) { 47*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 48*4bdc9457SAndroid Build Coastguard Worker } else { 49*4bdc9457SAndroid Build Coastguard Worker assert(this->input_stride_ >= this->channels_); 50*4bdc9457SAndroid Build Coastguard Worker return this->input_stride_; 51*4bdc9457SAndroid Build Coastguard Worker } 52*4bdc9457SAndroid Build Coastguard Worker } 53*4bdc9457SAndroid Build Coastguard Worker output_stride(size_t output_stride)54*4bdc9457SAndroid Build Coastguard Worker inline SigmoidOperatorTester& output_stride(size_t output_stride) { 55*4bdc9457SAndroid Build Coastguard Worker assert(output_stride != 0); 56*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 57*4bdc9457SAndroid Build Coastguard Worker return *this; 58*4bdc9457SAndroid Build Coastguard Worker } 59*4bdc9457SAndroid Build Coastguard Worker output_stride()60*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const { 61*4bdc9457SAndroid Build Coastguard Worker if (this->output_stride_ == 0) { 62*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 63*4bdc9457SAndroid Build Coastguard Worker } else { 64*4bdc9457SAndroid Build Coastguard Worker assert(this->output_stride_ >= this->channels_); 65*4bdc9457SAndroid Build Coastguard Worker return this->output_stride_; 66*4bdc9457SAndroid Build Coastguard Worker } 67*4bdc9457SAndroid Build Coastguard Worker } 68*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)69*4bdc9457SAndroid Build Coastguard Worker inline SigmoidOperatorTester& batch_size(size_t batch_size) { 70*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 71*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 72*4bdc9457SAndroid Build Coastguard Worker return *this; 73*4bdc9457SAndroid Build Coastguard Worker } 74*4bdc9457SAndroid Build Coastguard Worker batch_size()75*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 76*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 77*4bdc9457SAndroid Build Coastguard Worker } 78*4bdc9457SAndroid Build Coastguard Worker input_scale(float input_scale)79*4bdc9457SAndroid Build Coastguard Worker inline SigmoidOperatorTester& input_scale(float input_scale) { 80*4bdc9457SAndroid Build Coastguard Worker assert(input_scale > 0.0f); 81*4bdc9457SAndroid Build Coastguard Worker assert(std::isnormal(input_scale)); 82*4bdc9457SAndroid Build Coastguard Worker this->input_scale_ = input_scale; 83*4bdc9457SAndroid Build Coastguard Worker return *this; 84*4bdc9457SAndroid Build Coastguard Worker } 85*4bdc9457SAndroid Build Coastguard Worker input_scale()86*4bdc9457SAndroid Build Coastguard Worker inline float input_scale() const { 87*4bdc9457SAndroid Build Coastguard Worker return this->input_scale_; 88*4bdc9457SAndroid Build Coastguard Worker } 89*4bdc9457SAndroid Build Coastguard Worker input_zero_point(uint8_t input_zero_point)90*4bdc9457SAndroid Build Coastguard Worker inline SigmoidOperatorTester& input_zero_point(uint8_t input_zero_point) { 91*4bdc9457SAndroid Build Coastguard Worker this->input_zero_point_ = input_zero_point; 92*4bdc9457SAndroid Build Coastguard Worker return *this; 93*4bdc9457SAndroid Build Coastguard Worker } 94*4bdc9457SAndroid Build Coastguard Worker input_zero_point()95*4bdc9457SAndroid Build Coastguard Worker inline uint8_t input_zero_point() const { 96*4bdc9457SAndroid Build Coastguard Worker return this->input_zero_point_; 97*4bdc9457SAndroid Build Coastguard Worker } 98*4bdc9457SAndroid Build Coastguard Worker output_scale()99*4bdc9457SAndroid Build Coastguard Worker inline float output_scale() const { 100*4bdc9457SAndroid Build Coastguard Worker return 1.0f / 256.0f; 101*4bdc9457SAndroid Build Coastguard Worker } 102*4bdc9457SAndroid Build Coastguard Worker output_zero_point()103*4bdc9457SAndroid Build Coastguard Worker inline uint8_t output_zero_point() const { 104*4bdc9457SAndroid Build Coastguard Worker return 0; 105*4bdc9457SAndroid Build Coastguard Worker } 106*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)107*4bdc9457SAndroid Build Coastguard Worker inline SigmoidOperatorTester& qmin(uint8_t qmin) { 108*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 109*4bdc9457SAndroid Build Coastguard Worker return *this; 110*4bdc9457SAndroid Build Coastguard Worker } 111*4bdc9457SAndroid Build Coastguard Worker qmin()112*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 113*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 114*4bdc9457SAndroid Build Coastguard Worker } 115*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)116*4bdc9457SAndroid Build Coastguard Worker inline SigmoidOperatorTester& qmax(uint8_t qmax) { 117*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 118*4bdc9457SAndroid Build Coastguard Worker return *this; 119*4bdc9457SAndroid Build Coastguard Worker } 120*4bdc9457SAndroid Build Coastguard Worker qmax()121*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 122*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 123*4bdc9457SAndroid Build Coastguard Worker } 124*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)125*4bdc9457SAndroid Build Coastguard Worker inline SigmoidOperatorTester& iterations(size_t iterations) { 126*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 127*4bdc9457SAndroid Build Coastguard Worker return *this; 128*4bdc9457SAndroid Build Coastguard Worker } 129*4bdc9457SAndroid Build Coastguard Worker iterations()130*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 131*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 132*4bdc9457SAndroid Build Coastguard Worker } 133*4bdc9457SAndroid Build Coastguard Worker TestF16()134*4bdc9457SAndroid Build Coastguard Worker void TestF16() const { 135*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 136*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 137*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-25.0f, 25.0f); 138*4bdc9457SAndroid Build Coastguard Worker 139*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 140*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels()); 141*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 142*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 143*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 144*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 145*4bdc9457SAndroid Build Coastguard Worker 146*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 147*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 148*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 149*4bdc9457SAndroid Build Coastguard Worker const float x = fp16_ieee_to_fp32_value(input[i * input_stride() + c]); 150*4bdc9457SAndroid Build Coastguard Worker const float exp_x = std::exp(x); 151*4bdc9457SAndroid Build Coastguard Worker const float sigmoid_x = exp_x / (1.0 + exp_x); 152*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = sigmoid_x; 153*4bdc9457SAndroid Build Coastguard Worker } 154*4bdc9457SAndroid Build Coastguard Worker } 155*4bdc9457SAndroid Build Coastguard Worker 156*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Sigmoid operator. 157*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 158*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t sigmoid_op = nullptr; 159*4bdc9457SAndroid Build Coastguard Worker 160*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_sigmoid_nc_f16( 161*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 162*4bdc9457SAndroid Build Coastguard Worker 0, &sigmoid_op); 163*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 164*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 165*4bdc9457SAndroid Build Coastguard Worker } 166*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 167*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, sigmoid_op); 168*4bdc9457SAndroid Build Coastguard Worker 169*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete sigmoid_op. 170*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sigmoid_op(sigmoid_op, xnn_delete_operator); 171*4bdc9457SAndroid Build Coastguard Worker 172*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 173*4bdc9457SAndroid Build Coastguard Worker xnn_setup_sigmoid_nc_f16( 174*4bdc9457SAndroid Build Coastguard Worker sigmoid_op, 175*4bdc9457SAndroid Build Coastguard Worker batch_size(), 176*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 177*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 178*4bdc9457SAndroid Build Coastguard Worker 179*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 180*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(sigmoid_op, nullptr /* thread pool */)); 181*4bdc9457SAndroid Build Coastguard Worker 182*4bdc9457SAndroid Build Coastguard Worker // Verify results. 183*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 184*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 185*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 186*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[i * output_stride() + c]), 187*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c], 188*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-4f, std::abs(output_ref[i * channels() + c]) * 5.0e-3f)); 189*4bdc9457SAndroid Build Coastguard Worker } 190*4bdc9457SAndroid Build Coastguard Worker } 191*4bdc9457SAndroid Build Coastguard Worker } 192*4bdc9457SAndroid Build Coastguard Worker } 193*4bdc9457SAndroid Build Coastguard Worker TestF32()194*4bdc9457SAndroid Build Coastguard Worker void TestF32() const { 195*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 196*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 197*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-25.0f, 25.0f); 198*4bdc9457SAndroid Build Coastguard Worker 199*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 200*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() - 1) * output_stride() + channels()); 201*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * channels()); 202*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 203*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 204*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 205*4bdc9457SAndroid Build Coastguard Worker 206*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 207*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 208*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 209*4bdc9457SAndroid Build Coastguard Worker const double x = input[i * input_stride() + c]; 210*4bdc9457SAndroid Build Coastguard Worker const double exp_x = std::exp(x); 211*4bdc9457SAndroid Build Coastguard Worker const double sigmoid_x = exp_x / (1.0 + exp_x); 212*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = sigmoid_x; 213*4bdc9457SAndroid Build Coastguard Worker } 214*4bdc9457SAndroid Build Coastguard Worker } 215*4bdc9457SAndroid Build Coastguard Worker 216*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Sigmoid operator. 217*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 218*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t sigmoid_op = nullptr; 219*4bdc9457SAndroid Build Coastguard Worker 220*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_sigmoid_nc_f32( 221*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 222*4bdc9457SAndroid Build Coastguard Worker 0, &sigmoid_op); 223*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 224*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, sigmoid_op); 225*4bdc9457SAndroid Build Coastguard Worker 226*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete sigmoid_op. 227*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sigmoid_op(sigmoid_op, xnn_delete_operator); 228*4bdc9457SAndroid Build Coastguard Worker 229*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 230*4bdc9457SAndroid Build Coastguard Worker xnn_setup_sigmoid_nc_f32( 231*4bdc9457SAndroid Build Coastguard Worker sigmoid_op, 232*4bdc9457SAndroid Build Coastguard Worker batch_size(), 233*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 234*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 235*4bdc9457SAndroid Build Coastguard Worker 236*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 237*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(sigmoid_op, nullptr /* thread pool */)); 238*4bdc9457SAndroid Build Coastguard Worker 239*4bdc9457SAndroid Build Coastguard Worker // Verify results. 240*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 241*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 242*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 243*4bdc9457SAndroid Build Coastguard Worker output[i * output_stride() + c], 244*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c], 245*4bdc9457SAndroid Build Coastguard Worker 5.0e-6); 246*4bdc9457SAndroid Build Coastguard Worker } 247*4bdc9457SAndroid Build Coastguard Worker } 248*4bdc9457SAndroid Build Coastguard Worker } 249*4bdc9457SAndroid Build Coastguard Worker } 250*4bdc9457SAndroid Build Coastguard Worker TestQS8()251*4bdc9457SAndroid Build Coastguard Worker void TestQS8() const { 252*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 253*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 254*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 255*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 256*4bdc9457SAndroid Build Coastguard Worker 257*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t)); 258*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels()); 259*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 260*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 261*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 262*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 263*4bdc9457SAndroid Build Coastguard Worker 264*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 265*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 266*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 267*4bdc9457SAndroid Build Coastguard Worker const float x = input_scale() * 268*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point() - 0x80)); 269*4bdc9457SAndroid Build Coastguard Worker const float sigmoid_x = 1.0f / (1.0f + std::exp(-x)); 270*4bdc9457SAndroid Build Coastguard Worker const float scaled_sigmoid_x = sigmoid_x / output_scale(); 271*4bdc9457SAndroid Build Coastguard Worker float y = scaled_sigmoid_x; 272*4bdc9457SAndroid Build Coastguard Worker y = std::min<float>(y, int32_t(qmax() - 0x80) - int32_t(output_zero_point() - 0x80)); 273*4bdc9457SAndroid Build Coastguard Worker y = std::max<float>(y, int32_t(qmin() - 0x80) - int32_t(output_zero_point() - 0x80)); 274*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = y + int32_t(output_zero_point() - 0x80); 275*4bdc9457SAndroid Build Coastguard Worker } 276*4bdc9457SAndroid Build Coastguard Worker } 277*4bdc9457SAndroid Build Coastguard Worker 278*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Sigmoid operator. 279*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 280*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t sigmoid_op = nullptr; 281*4bdc9457SAndroid Build Coastguard Worker 282*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 283*4bdc9457SAndroid Build Coastguard Worker xnn_create_sigmoid_nc_qs8( 284*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 285*4bdc9457SAndroid Build Coastguard Worker int8_t(input_zero_point() - 0x80), input_scale(), 286*4bdc9457SAndroid Build Coastguard Worker int8_t(output_zero_point() - 0x80), output_scale(), 287*4bdc9457SAndroid Build Coastguard Worker int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 288*4bdc9457SAndroid Build Coastguard Worker 0, &sigmoid_op)); 289*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, sigmoid_op); 290*4bdc9457SAndroid Build Coastguard Worker 291*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete sigmoid_op. 292*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sigmoid_op(sigmoid_op, xnn_delete_operator); 293*4bdc9457SAndroid Build Coastguard Worker 294*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 295*4bdc9457SAndroid Build Coastguard Worker xnn_setup_sigmoid_nc_qs8( 296*4bdc9457SAndroid Build Coastguard Worker sigmoid_op, 297*4bdc9457SAndroid Build Coastguard Worker batch_size(), 298*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 299*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 300*4bdc9457SAndroid Build Coastguard Worker 301*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 302*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(sigmoid_op, nullptr /* thread pool */)); 303*4bdc9457SAndroid Build Coastguard Worker 304*4bdc9457SAndroid Build Coastguard Worker // Verify results. 305*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 306*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 307*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f); 308*4bdc9457SAndroid Build Coastguard Worker } 309*4bdc9457SAndroid Build Coastguard Worker } 310*4bdc9457SAndroid Build Coastguard Worker } 311*4bdc9457SAndroid Build Coastguard Worker } 312*4bdc9457SAndroid Build Coastguard Worker TestQU8()313*4bdc9457SAndroid Build Coastguard Worker void TestQU8() const { 314*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 315*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 316*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 317*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 318*4bdc9457SAndroid Build Coastguard Worker 319*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 320*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels()); 321*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 322*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 323*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 324*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 325*4bdc9457SAndroid Build Coastguard Worker 326*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 327*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 328*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 329*4bdc9457SAndroid Build Coastguard Worker const float x = input_scale() * 330*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point())); 331*4bdc9457SAndroid Build Coastguard Worker const float sigmoid_x = 1.0f / (1.0f + std::exp(-x)); 332*4bdc9457SAndroid Build Coastguard Worker const float scaled_sigmoid_x = sigmoid_x / output_scale(); 333*4bdc9457SAndroid Build Coastguard Worker float y = scaled_sigmoid_x; 334*4bdc9457SAndroid Build Coastguard Worker y = std::min<float>(y, int32_t(qmax()) - int32_t(output_zero_point())); 335*4bdc9457SAndroid Build Coastguard Worker y = std::max<float>(y, int32_t(qmin()) - int32_t(output_zero_point())); 336*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = y + int32_t(output_zero_point()); 337*4bdc9457SAndroid Build Coastguard Worker } 338*4bdc9457SAndroid Build Coastguard Worker } 339*4bdc9457SAndroid Build Coastguard Worker 340*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Sigmoid operator. 341*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 342*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t sigmoid_op = nullptr; 343*4bdc9457SAndroid Build Coastguard Worker 344*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 345*4bdc9457SAndroid Build Coastguard Worker xnn_create_sigmoid_nc_qu8( 346*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 347*4bdc9457SAndroid Build Coastguard Worker input_zero_point(), input_scale(), 348*4bdc9457SAndroid Build Coastguard Worker output_zero_point(), output_scale(), 349*4bdc9457SAndroid Build Coastguard Worker qmin(), qmax(), 350*4bdc9457SAndroid Build Coastguard Worker 0, &sigmoid_op)); 351*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, sigmoid_op); 352*4bdc9457SAndroid Build Coastguard Worker 353*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete sigmoid_op. 354*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sigmoid_op(sigmoid_op, xnn_delete_operator); 355*4bdc9457SAndroid Build Coastguard Worker 356*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 357*4bdc9457SAndroid Build Coastguard Worker xnn_setup_sigmoid_nc_qu8( 358*4bdc9457SAndroid Build Coastguard Worker sigmoid_op, 359*4bdc9457SAndroid Build Coastguard Worker batch_size(), 360*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 361*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 362*4bdc9457SAndroid Build Coastguard Worker 363*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 364*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(sigmoid_op, nullptr /* thread pool */)); 365*4bdc9457SAndroid Build Coastguard Worker 366*4bdc9457SAndroid Build Coastguard Worker // Verify results. 367*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 368*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 369*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f); 370*4bdc9457SAndroid Build Coastguard Worker } 371*4bdc9457SAndroid Build Coastguard Worker } 372*4bdc9457SAndroid Build Coastguard Worker } 373*4bdc9457SAndroid Build Coastguard Worker } 374*4bdc9457SAndroid Build Coastguard Worker 375*4bdc9457SAndroid Build Coastguard Worker private: 376*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 377*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 378*4bdc9457SAndroid Build Coastguard Worker size_t input_stride_{0}; 379*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_{0}; 380*4bdc9457SAndroid Build Coastguard Worker float input_scale_{0.75f}; 381*4bdc9457SAndroid Build Coastguard Worker uint8_t input_zero_point_{121}; 382*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 383*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 384*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 385*4bdc9457SAndroid Build Coastguard Worker }; 386