1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2021 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 14*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 15*4bdc9457SAndroid Build Coastguard Worker #include <random> 16*4bdc9457SAndroid Build Coastguard Worker #include <vector> 17*4bdc9457SAndroid Build Coastguard Worker 18*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker class ConvertOperatorTester { 24*4bdc9457SAndroid Build Coastguard Worker public: channels(size_t channels)25*4bdc9457SAndroid Build Coastguard Worker inline ConvertOperatorTester& channels(size_t channels) { 26*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 27*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 28*4bdc9457SAndroid Build Coastguard Worker return *this; 29*4bdc9457SAndroid Build Coastguard Worker } 30*4bdc9457SAndroid Build Coastguard Worker channels()31*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 32*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 33*4bdc9457SAndroid Build Coastguard Worker } 34*4bdc9457SAndroid Build Coastguard Worker input_stride(size_t input_stride)35*4bdc9457SAndroid Build Coastguard Worker inline ConvertOperatorTester& input_stride(size_t input_stride) { 36*4bdc9457SAndroid Build Coastguard Worker assert(input_stride != 0); 37*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 38*4bdc9457SAndroid Build Coastguard Worker return *this; 39*4bdc9457SAndroid Build Coastguard Worker } 40*4bdc9457SAndroid Build Coastguard Worker input_stride()41*4bdc9457SAndroid Build Coastguard Worker inline size_t input_stride() const { 42*4bdc9457SAndroid Build Coastguard Worker if (this->input_stride_ == 0) { 43*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 44*4bdc9457SAndroid Build Coastguard Worker } else { 45*4bdc9457SAndroid Build Coastguard Worker assert(this->input_stride_ >= this->channels_); 46*4bdc9457SAndroid Build Coastguard Worker return this->input_stride_; 47*4bdc9457SAndroid Build Coastguard Worker } 48*4bdc9457SAndroid Build Coastguard Worker } 49*4bdc9457SAndroid Build Coastguard Worker output_stride(size_t output_stride)50*4bdc9457SAndroid Build Coastguard Worker inline ConvertOperatorTester& output_stride(size_t output_stride) { 51*4bdc9457SAndroid Build Coastguard Worker assert(output_stride != 0); 52*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 53*4bdc9457SAndroid Build Coastguard Worker return *this; 54*4bdc9457SAndroid Build Coastguard Worker } 55*4bdc9457SAndroid Build Coastguard Worker output_stride()56*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const { 57*4bdc9457SAndroid Build Coastguard Worker if (this->output_stride_ == 0) { 58*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 59*4bdc9457SAndroid Build Coastguard Worker } else { 60*4bdc9457SAndroid Build Coastguard Worker assert(this->output_stride_ >= this->channels_); 61*4bdc9457SAndroid Build Coastguard Worker return this->output_stride_; 62*4bdc9457SAndroid Build Coastguard Worker } 63*4bdc9457SAndroid Build Coastguard Worker } 64*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)65*4bdc9457SAndroid Build Coastguard Worker inline ConvertOperatorTester& batch_size(size_t batch_size) { 66*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 67*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 68*4bdc9457SAndroid Build Coastguard Worker return *this; 69*4bdc9457SAndroid Build Coastguard Worker } 70*4bdc9457SAndroid Build Coastguard Worker batch_size()71*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 72*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 73*4bdc9457SAndroid Build Coastguard Worker } 74*4bdc9457SAndroid Build Coastguard Worker scale(float scale)75*4bdc9457SAndroid Build Coastguard Worker inline ConvertOperatorTester& scale(float scale) { 76*4bdc9457SAndroid Build Coastguard Worker assert(scale >= 0.0f); 77*4bdc9457SAndroid Build Coastguard Worker assert(std::isnormal(scale)); 78*4bdc9457SAndroid Build Coastguard Worker this->scale_ = scale; 79*4bdc9457SAndroid Build Coastguard Worker return *this; 80*4bdc9457SAndroid Build Coastguard Worker } 81*4bdc9457SAndroid Build Coastguard Worker scale()82*4bdc9457SAndroid Build Coastguard Worker inline float scale() const { 83*4bdc9457SAndroid Build Coastguard Worker return this->scale_; 84*4bdc9457SAndroid Build Coastguard Worker } 85*4bdc9457SAndroid Build Coastguard Worker zero_point(int16_t zero_point)86*4bdc9457SAndroid Build Coastguard Worker inline ConvertOperatorTester& zero_point(int16_t zero_point) { 87*4bdc9457SAndroid Build Coastguard Worker this->zero_point_ = zero_point; 88*4bdc9457SAndroid Build Coastguard Worker return *this; 89*4bdc9457SAndroid Build Coastguard Worker } 90*4bdc9457SAndroid Build Coastguard Worker zero_point()91*4bdc9457SAndroid Build Coastguard Worker inline int16_t zero_point() const { 92*4bdc9457SAndroid Build Coastguard Worker return this->zero_point_; 93*4bdc9457SAndroid Build Coastguard Worker } 94*4bdc9457SAndroid Build Coastguard Worker qmin(int16_t qmin)95*4bdc9457SAndroid Build Coastguard Worker inline ConvertOperatorTester& qmin(int16_t qmin) { 96*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 97*4bdc9457SAndroid Build Coastguard Worker return *this; 98*4bdc9457SAndroid Build Coastguard Worker } 99*4bdc9457SAndroid Build Coastguard Worker qmin()100*4bdc9457SAndroid Build Coastguard Worker inline int16_t qmin() const { 101*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 102*4bdc9457SAndroid Build Coastguard Worker } 103*4bdc9457SAndroid Build Coastguard Worker qmax(int16_t qmax)104*4bdc9457SAndroid Build Coastguard Worker inline ConvertOperatorTester& qmax(int16_t qmax) { 105*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 106*4bdc9457SAndroid Build Coastguard Worker return *this; 107*4bdc9457SAndroid Build Coastguard Worker } 108*4bdc9457SAndroid Build Coastguard Worker qmax()109*4bdc9457SAndroid Build Coastguard Worker inline int16_t qmax() const { 110*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 111*4bdc9457SAndroid Build Coastguard Worker } 112*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)113*4bdc9457SAndroid Build Coastguard Worker inline ConvertOperatorTester& iterations(size_t iterations) { 114*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 115*4bdc9457SAndroid Build Coastguard Worker return *this; 116*4bdc9457SAndroid Build Coastguard Worker } 117*4bdc9457SAndroid Build Coastguard Worker iterations()118*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 119*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 120*4bdc9457SAndroid Build Coastguard Worker } 121*4bdc9457SAndroid Build Coastguard Worker TestF16toF32()122*4bdc9457SAndroid Build Coastguard Worker void TestF16toF32() const { 123*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 124*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 125*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 126*4bdc9457SAndroid Build Coastguard Worker 127*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 128*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 129*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() - 1) * output_stride() + channels()); 130*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 131*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 132*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 133*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 134*4bdc9457SAndroid Build Coastguard Worker 135*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 136*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 137*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 138*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = fp16_ieee_to_fp32_value(input[i * input_stride() + c]); 139*4bdc9457SAndroid Build Coastguard Worker } 140*4bdc9457SAndroid Build Coastguard Worker } 141*4bdc9457SAndroid Build Coastguard Worker 142*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convert operator. 143*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 144*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convert_op = nullptr; 145*4bdc9457SAndroid Build Coastguard Worker 146*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 147*4bdc9457SAndroid Build Coastguard Worker xnn_create_convert_nc_f16_f32( 148*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 149*4bdc9457SAndroid Build Coastguard Worker 0, &convert_op)); 150*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convert_op); 151*4bdc9457SAndroid Build Coastguard Worker 152*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convert op. 153*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convert_op(convert_op, xnn_delete_operator); 154*4bdc9457SAndroid Build Coastguard Worker 155*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 156*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convert_nc_f16_f32( 157*4bdc9457SAndroid Build Coastguard Worker convert_op, 158*4bdc9457SAndroid Build Coastguard Worker batch_size(), 159*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 160*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 161*4bdc9457SAndroid Build Coastguard Worker 162*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 163*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convert_op, nullptr /* thread pool */)); 164*4bdc9457SAndroid Build Coastguard Worker 165*4bdc9457SAndroid Build Coastguard Worker // Verify results. 166*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 167*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 168*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 169*4bdc9457SAndroid Build Coastguard Worker << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 170*4bdc9457SAndroid Build Coastguard Worker } 171*4bdc9457SAndroid Build Coastguard Worker } 172*4bdc9457SAndroid Build Coastguard Worker } 173*4bdc9457SAndroid Build Coastguard Worker } 174*4bdc9457SAndroid Build Coastguard Worker TestF32toF16()175*4bdc9457SAndroid Build Coastguard Worker void TestF32toF16() const { 176*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 177*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 178*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 179*4bdc9457SAndroid Build Coastguard Worker 180*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 181*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 182*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels()); 183*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output_ref(batch_size() * channels()); 184*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 185*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 186*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 187*4bdc9457SAndroid Build Coastguard Worker 188*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 189*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 190*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 191*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = fp16_ieee_from_fp32_value(input[i * input_stride() + c]); 192*4bdc9457SAndroid Build Coastguard Worker } 193*4bdc9457SAndroid Build Coastguard Worker } 194*4bdc9457SAndroid Build Coastguard Worker 195*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convert operator. 196*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 197*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convert_op = nullptr; 198*4bdc9457SAndroid Build Coastguard Worker 199*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 200*4bdc9457SAndroid Build Coastguard Worker xnn_create_convert_nc_f32_f16( 201*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 202*4bdc9457SAndroid Build Coastguard Worker 0, &convert_op)); 203*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convert_op); 204*4bdc9457SAndroid Build Coastguard Worker 205*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convert op. 206*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convert_op(convert_op, xnn_delete_operator); 207*4bdc9457SAndroid Build Coastguard Worker 208*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 209*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convert_nc_f32_f16( 210*4bdc9457SAndroid Build Coastguard Worker convert_op, 211*4bdc9457SAndroid Build Coastguard Worker batch_size(), 212*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 213*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 214*4bdc9457SAndroid Build Coastguard Worker 215*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 216*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convert_op, nullptr /* thread pool */)); 217*4bdc9457SAndroid Build Coastguard Worker 218*4bdc9457SAndroid Build Coastguard Worker // Verify results. 219*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 220*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 221*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 222*4bdc9457SAndroid Build Coastguard Worker << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 223*4bdc9457SAndroid Build Coastguard Worker } 224*4bdc9457SAndroid Build Coastguard Worker } 225*4bdc9457SAndroid Build Coastguard Worker } 226*4bdc9457SAndroid Build Coastguard Worker } 227*4bdc9457SAndroid Build Coastguard Worker TestF32toQS8()228*4bdc9457SAndroid Build Coastguard Worker void TestF32toQS8() const { 229*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(qmin(), std::numeric_limits<int8_t>::min()); 230*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(qmax(), std::numeric_limits<int8_t>::max()); 231*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(qmin(), qmax()); 232*4bdc9457SAndroid Build Coastguard Worker 233*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(zero_point(), std::numeric_limits<int8_t>::min()); 234*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(zero_point(), std::numeric_limits<int8_t>::max()); 235*4bdc9457SAndroid Build Coastguard Worker 236*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 237*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 238*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 239*4bdc9457SAndroid Build Coastguard Worker 240*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 241*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 242*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels()); 243*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output_ref(batch_size() * channels()); 244*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 245*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 246*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 247*4bdc9457SAndroid Build Coastguard Worker 248*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 249*4bdc9457SAndroid Build Coastguard Worker const float inv_scale = 1.0f / scale(); 250*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 251*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 252*4bdc9457SAndroid Build Coastguard Worker float scaled_input = input[i * input_stride() + c] * inv_scale; 253*4bdc9457SAndroid Build Coastguard Worker scaled_input = std::min<float>(scaled_input, float(qmax() - zero_point())); 254*4bdc9457SAndroid Build Coastguard Worker scaled_input = std::max<float>(scaled_input, float(qmin() - zero_point())); 255*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = int8_t(std::lrintf(scaled_input) + long(zero_point())); 256*4bdc9457SAndroid Build Coastguard Worker } 257*4bdc9457SAndroid Build Coastguard Worker } 258*4bdc9457SAndroid Build Coastguard Worker 259*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convert operator. 260*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 261*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convert_op = nullptr; 262*4bdc9457SAndroid Build Coastguard Worker 263*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 264*4bdc9457SAndroid Build Coastguard Worker xnn_create_convert_nc_f32_qs8( 265*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 266*4bdc9457SAndroid Build Coastguard Worker scale(), int8_t(zero_point()), int8_t(qmin()), int8_t(qmax()), 267*4bdc9457SAndroid Build Coastguard Worker 0, &convert_op)); 268*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convert_op); 269*4bdc9457SAndroid Build Coastguard Worker 270*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convert op. 271*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convert_op(convert_op, xnn_delete_operator); 272*4bdc9457SAndroid Build Coastguard Worker 273*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 274*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convert_nc_f32_qs8( 275*4bdc9457SAndroid Build Coastguard Worker convert_op, 276*4bdc9457SAndroid Build Coastguard Worker batch_size(), 277*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 278*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 279*4bdc9457SAndroid Build Coastguard Worker 280*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 281*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convert_op, nullptr /* thread pool */)); 282*4bdc9457SAndroid Build Coastguard Worker 283*4bdc9457SAndroid Build Coastguard Worker // Verify results. 284*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 285*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 286*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(output_ref[i * channels() + c]), int32_t(output[i * output_stride() + c])) 287*4bdc9457SAndroid Build Coastguard Worker << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 288*4bdc9457SAndroid Build Coastguard Worker } 289*4bdc9457SAndroid Build Coastguard Worker } 290*4bdc9457SAndroid Build Coastguard Worker } 291*4bdc9457SAndroid Build Coastguard Worker } 292*4bdc9457SAndroid Build Coastguard Worker TestF32toQU8()293*4bdc9457SAndroid Build Coastguard Worker void TestF32toQU8() const { 294*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(qmin(), std::numeric_limits<uint8_t>::min()); 295*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(qmax(), std::numeric_limits<uint8_t>::max()); 296*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(qmin(), qmax()); 297*4bdc9457SAndroid Build Coastguard Worker 298*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(zero_point(), std::numeric_limits<uint8_t>::min()); 299*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(zero_point(), std::numeric_limits<uint8_t>::max()); 300*4bdc9457SAndroid Build Coastguard Worker 301*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 302*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 303*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 304*4bdc9457SAndroid Build Coastguard Worker 305*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 306*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 307*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels()); 308*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output_ref(batch_size() * channels()); 309*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 310*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 311*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 312*4bdc9457SAndroid Build Coastguard Worker 313*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 314*4bdc9457SAndroid Build Coastguard Worker const float inv_scale = 1.0f / scale(); 315*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 316*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 317*4bdc9457SAndroid Build Coastguard Worker float scaled_input = input[i * input_stride() + c] * inv_scale; 318*4bdc9457SAndroid Build Coastguard Worker scaled_input = std::min<float>(scaled_input, float(qmax() - zero_point())); 319*4bdc9457SAndroid Build Coastguard Worker scaled_input = std::max<float>(scaled_input, float(qmin() - zero_point())); 320*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = uint8_t(std::lrintf(scaled_input) + long(zero_point())); 321*4bdc9457SAndroid Build Coastguard Worker } 322*4bdc9457SAndroid Build Coastguard Worker } 323*4bdc9457SAndroid Build Coastguard Worker 324*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convert operator. 325*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 326*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convert_op = nullptr; 327*4bdc9457SAndroid Build Coastguard Worker 328*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 329*4bdc9457SAndroid Build Coastguard Worker xnn_create_convert_nc_f32_qu8( 330*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 331*4bdc9457SAndroid Build Coastguard Worker scale(), uint8_t(zero_point()), uint8_t(qmin()), uint8_t(qmax()), 332*4bdc9457SAndroid Build Coastguard Worker 0, &convert_op)); 333*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convert_op); 334*4bdc9457SAndroid Build Coastguard Worker 335*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convert op. 336*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convert_op(convert_op, xnn_delete_operator); 337*4bdc9457SAndroid Build Coastguard Worker 338*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 339*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convert_nc_f32_qu8( 340*4bdc9457SAndroid Build Coastguard Worker convert_op, 341*4bdc9457SAndroid Build Coastguard Worker batch_size(), 342*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 343*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 344*4bdc9457SAndroid Build Coastguard Worker 345*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 346*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convert_op, nullptr /* thread pool */)); 347*4bdc9457SAndroid Build Coastguard Worker 348*4bdc9457SAndroid Build Coastguard Worker // Verify results. 349*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 350*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 351*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(uint32_t(output_ref[i * channels() + c]), uint32_t(output[i * output_stride() + c])) 352*4bdc9457SAndroid Build Coastguard Worker << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 353*4bdc9457SAndroid Build Coastguard Worker } 354*4bdc9457SAndroid Build Coastguard Worker } 355*4bdc9457SAndroid Build Coastguard Worker } 356*4bdc9457SAndroid Build Coastguard Worker } 357*4bdc9457SAndroid Build Coastguard Worker TestQS8toF32()358*4bdc9457SAndroid Build Coastguard Worker void TestQS8toF32() const { 359*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(zero_point(), std::numeric_limits<int8_t>::min()); 360*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(zero_point(), std::numeric_limits<int8_t>::max()); 361*4bdc9457SAndroid Build Coastguard Worker 362*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 363*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 364*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 365*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 366*4bdc9457SAndroid Build Coastguard Worker 367*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + 368*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 369*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() - 1) * output_stride() + channels()); 370*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 371*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 372*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 373*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 374*4bdc9457SAndroid Build Coastguard Worker 375*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 376*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 377*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 378*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = float(input[i * input_stride() + c] - zero_point()) * scale(); 379*4bdc9457SAndroid Build Coastguard Worker } 380*4bdc9457SAndroid Build Coastguard Worker } 381*4bdc9457SAndroid Build Coastguard Worker 382*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convert operator. 383*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 384*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convert_op = nullptr; 385*4bdc9457SAndroid Build Coastguard Worker 386*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 387*4bdc9457SAndroid Build Coastguard Worker xnn_create_convert_nc_qs8_f32( 388*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 389*4bdc9457SAndroid Build Coastguard Worker scale(), int8_t(zero_point()), 390*4bdc9457SAndroid Build Coastguard Worker 0, &convert_op)); 391*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convert_op); 392*4bdc9457SAndroid Build Coastguard Worker 393*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convert op. 394*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convert_op(convert_op, xnn_delete_operator); 395*4bdc9457SAndroid Build Coastguard Worker 396*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 397*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convert_nc_qs8_f32( 398*4bdc9457SAndroid Build Coastguard Worker convert_op, 399*4bdc9457SAndroid Build Coastguard Worker batch_size(), 400*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 401*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 402*4bdc9457SAndroid Build Coastguard Worker 403*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 404*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convert_op, nullptr /* thread pool */)); 405*4bdc9457SAndroid Build Coastguard Worker 406*4bdc9457SAndroid Build Coastguard Worker // Verify results. 407*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 408*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 409*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 410*4bdc9457SAndroid Build Coastguard Worker << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 411*4bdc9457SAndroid Build Coastguard Worker } 412*4bdc9457SAndroid Build Coastguard Worker } 413*4bdc9457SAndroid Build Coastguard Worker } 414*4bdc9457SAndroid Build Coastguard Worker } 415*4bdc9457SAndroid Build Coastguard Worker TestQU8toF32()416*4bdc9457SAndroid Build Coastguard Worker void TestQU8toF32() const { 417*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(zero_point(), std::numeric_limits<uint8_t>::min()); 418*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(zero_point(), std::numeric_limits<uint8_t>::max()); 419*4bdc9457SAndroid Build Coastguard Worker 420*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 421*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 422*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 423*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 424*4bdc9457SAndroid Build Coastguard Worker 425*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + 426*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 427*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() - 1) * output_stride() + channels()); 428*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 429*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 430*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 431*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 432*4bdc9457SAndroid Build Coastguard Worker 433*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 434*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 435*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 436*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = float(input[i * input_stride() + c] - zero_point()) * scale(); 437*4bdc9457SAndroid Build Coastguard Worker } 438*4bdc9457SAndroid Build Coastguard Worker } 439*4bdc9457SAndroid Build Coastguard Worker 440*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convert operator. 441*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 442*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convert_op = nullptr; 443*4bdc9457SAndroid Build Coastguard Worker 444*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 445*4bdc9457SAndroid Build Coastguard Worker xnn_create_convert_nc_qu8_f32( 446*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 447*4bdc9457SAndroid Build Coastguard Worker scale(), uint8_t(zero_point()), 448*4bdc9457SAndroid Build Coastguard Worker 0, &convert_op)); 449*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convert_op); 450*4bdc9457SAndroid Build Coastguard Worker 451*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convert op. 452*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convert_op(convert_op, xnn_delete_operator); 453*4bdc9457SAndroid Build Coastguard Worker 454*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 455*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convert_nc_qu8_f32( 456*4bdc9457SAndroid Build Coastguard Worker convert_op, 457*4bdc9457SAndroid Build Coastguard Worker batch_size(), 458*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 459*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 460*4bdc9457SAndroid Build Coastguard Worker 461*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 462*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convert_op, nullptr /* thread pool */)); 463*4bdc9457SAndroid Build Coastguard Worker 464*4bdc9457SAndroid Build Coastguard Worker // Verify results. 465*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 466*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 467*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 468*4bdc9457SAndroid Build Coastguard Worker << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 469*4bdc9457SAndroid Build Coastguard Worker } 470*4bdc9457SAndroid Build Coastguard Worker } 471*4bdc9457SAndroid Build Coastguard Worker } 472*4bdc9457SAndroid Build Coastguard Worker } 473*4bdc9457SAndroid Build Coastguard Worker 474*4bdc9457SAndroid Build Coastguard Worker private: 475*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 476*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 477*4bdc9457SAndroid Build Coastguard Worker size_t input_stride_{0}; 478*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_{0}; 479*4bdc9457SAndroid Build Coastguard Worker float scale_{150.0f}; 480*4bdc9457SAndroid Build Coastguard Worker int16_t zero_point_{1}; 481*4bdc9457SAndroid Build Coastguard Worker int16_t qmin_{std::numeric_limits<int16_t>::min()}; 482*4bdc9457SAndroid Build Coastguard Worker int16_t qmax_{std::numeric_limits<int16_t>::max()}; 483*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 484*4bdc9457SAndroid Build Coastguard Worker }; 485