1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates. 2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved. 3*4bdc9457SAndroid Build Coastguard Worker // 4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 5*4bdc9457SAndroid Build Coastguard Worker // 6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #pragma once 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 14*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 15*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 16*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 17*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 18*4bdc9457SAndroid Build Coastguard Worker #include <limits> 19*4bdc9457SAndroid Build Coastguard Worker #include <random> 20*4bdc9457SAndroid Build Coastguard Worker #include <vector> 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 23*4bdc9457SAndroid Build Coastguard Worker 24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/cache.h> 26*4bdc9457SAndroid Build Coastguard Worker 27*4bdc9457SAndroid Build Coastguard Worker 28*4bdc9457SAndroid Build Coastguard Worker class FullyConnectedOperatorTester { 29*4bdc9457SAndroid Build Coastguard Worker public: 30*4bdc9457SAndroid Build Coastguard Worker enum class WeightsType { 31*4bdc9457SAndroid Build Coastguard Worker Default, 32*4bdc9457SAndroid Build Coastguard Worker FP32, 33*4bdc9457SAndroid Build Coastguard Worker }; 34*4bdc9457SAndroid Build Coastguard Worker input_channels(size_t input_channels)35*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& input_channels(size_t input_channels) { 36*4bdc9457SAndroid Build Coastguard Worker assert(input_channels >= 1); 37*4bdc9457SAndroid Build Coastguard Worker this->input_channels_ = input_channels; 38*4bdc9457SAndroid Build Coastguard Worker return *this; 39*4bdc9457SAndroid Build Coastguard Worker } 40*4bdc9457SAndroid Build Coastguard Worker input_channels()41*4bdc9457SAndroid Build Coastguard Worker inline size_t input_channels() const { 42*4bdc9457SAndroid Build Coastguard Worker return this->input_channels_; 43*4bdc9457SAndroid Build Coastguard Worker } 44*4bdc9457SAndroid Build Coastguard Worker output_channels(size_t output_channels)45*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& output_channels(size_t output_channels) { 46*4bdc9457SAndroid Build Coastguard Worker assert(output_channels >= 1); 47*4bdc9457SAndroid Build Coastguard Worker this->output_channels_ = output_channels; 48*4bdc9457SAndroid Build Coastguard Worker return *this; 49*4bdc9457SAndroid Build Coastguard Worker } 50*4bdc9457SAndroid Build Coastguard Worker output_channels()51*4bdc9457SAndroid Build Coastguard Worker inline size_t output_channels() const { 52*4bdc9457SAndroid Build Coastguard Worker return this->output_channels_; 53*4bdc9457SAndroid Build Coastguard Worker } 54*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)55*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& batch_size(size_t batch_size) { 56*4bdc9457SAndroid Build Coastguard Worker assert(batch_size >= 1); 57*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 58*4bdc9457SAndroid Build Coastguard Worker return *this; 59*4bdc9457SAndroid Build Coastguard Worker } 60*4bdc9457SAndroid Build Coastguard Worker batch_size()61*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 62*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 63*4bdc9457SAndroid Build Coastguard Worker } 64*4bdc9457SAndroid Build Coastguard Worker input_stride(size_t input_stride)65*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& input_stride(size_t input_stride) { 66*4bdc9457SAndroid Build Coastguard Worker assert(input_stride >= 1); 67*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 68*4bdc9457SAndroid Build Coastguard Worker return *this; 69*4bdc9457SAndroid Build Coastguard Worker } 70*4bdc9457SAndroid Build Coastguard Worker input_stride()71*4bdc9457SAndroid Build Coastguard Worker inline size_t input_stride() const { 72*4bdc9457SAndroid Build Coastguard Worker if (this->input_stride_ == 0) { 73*4bdc9457SAndroid Build Coastguard Worker return input_channels(); 74*4bdc9457SAndroid Build Coastguard Worker } else { 75*4bdc9457SAndroid Build Coastguard Worker assert(this->input_stride_ >= input_channels()); 76*4bdc9457SAndroid Build Coastguard Worker return this->input_stride_; 77*4bdc9457SAndroid Build Coastguard Worker } 78*4bdc9457SAndroid Build Coastguard Worker } 79*4bdc9457SAndroid Build Coastguard Worker output_stride(size_t output_stride)80*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& output_stride(size_t output_stride) { 81*4bdc9457SAndroid Build Coastguard Worker assert(output_stride >= 1); 82*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 83*4bdc9457SAndroid Build Coastguard Worker return *this; 84*4bdc9457SAndroid Build Coastguard Worker } 85*4bdc9457SAndroid Build Coastguard Worker output_stride()86*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const { 87*4bdc9457SAndroid Build Coastguard Worker if (this->output_stride_ == 0) { 88*4bdc9457SAndroid Build Coastguard Worker return output_channels(); 89*4bdc9457SAndroid Build Coastguard Worker } else { 90*4bdc9457SAndroid Build Coastguard Worker assert(this->output_stride_ >= output_channels()); 91*4bdc9457SAndroid Build Coastguard Worker return this->output_stride_; 92*4bdc9457SAndroid Build Coastguard Worker } 93*4bdc9457SAndroid Build Coastguard Worker } 94*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)95*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& qmin(uint8_t qmin) { 96*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 97*4bdc9457SAndroid Build Coastguard Worker return *this; 98*4bdc9457SAndroid Build Coastguard Worker } 99*4bdc9457SAndroid Build Coastguard Worker qmin()100*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 101*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 102*4bdc9457SAndroid Build Coastguard Worker } 103*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)104*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& qmax(uint8_t qmax) { 105*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 106*4bdc9457SAndroid Build Coastguard Worker return *this; 107*4bdc9457SAndroid Build Coastguard Worker } 108*4bdc9457SAndroid Build Coastguard Worker qmax()109*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 110*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 111*4bdc9457SAndroid Build Coastguard Worker } 112*4bdc9457SAndroid Build Coastguard Worker transpose_weights(bool transpose_weights)113*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) { 114*4bdc9457SAndroid Build Coastguard Worker this->transpose_weights_ = transpose_weights; 115*4bdc9457SAndroid Build Coastguard Worker return *this; 116*4bdc9457SAndroid Build Coastguard Worker } 117*4bdc9457SAndroid Build Coastguard Worker transpose_weights()118*4bdc9457SAndroid Build Coastguard Worker inline bool transpose_weights() const { 119*4bdc9457SAndroid Build Coastguard Worker return this->transpose_weights_; 120*4bdc9457SAndroid Build Coastguard Worker } 121*4bdc9457SAndroid Build Coastguard Worker has_bias(bool has_bias)122*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& has_bias(bool has_bias) { 123*4bdc9457SAndroid Build Coastguard Worker this->has_bias_ = has_bias; 124*4bdc9457SAndroid Build Coastguard Worker return *this; 125*4bdc9457SAndroid Build Coastguard Worker } 126*4bdc9457SAndroid Build Coastguard Worker has_bias()127*4bdc9457SAndroid Build Coastguard Worker inline bool has_bias() const { 128*4bdc9457SAndroid Build Coastguard Worker return this->has_bias_; 129*4bdc9457SAndroid Build Coastguard Worker } 130*4bdc9457SAndroid Build Coastguard Worker weights_type(WeightsType weights_type)131*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& weights_type(WeightsType weights_type) { 132*4bdc9457SAndroid Build Coastguard Worker this->weights_type_ = weights_type; 133*4bdc9457SAndroid Build Coastguard Worker return *this; 134*4bdc9457SAndroid Build Coastguard Worker } 135*4bdc9457SAndroid Build Coastguard Worker weights_type()136*4bdc9457SAndroid Build Coastguard Worker inline WeightsType weights_type() const { 137*4bdc9457SAndroid Build Coastguard Worker return this->weights_type_; 138*4bdc9457SAndroid Build Coastguard Worker } 139*4bdc9457SAndroid Build Coastguard Worker use_weights_cache(bool use_weights_cache)140*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& use_weights_cache(bool use_weights_cache) { 141*4bdc9457SAndroid Build Coastguard Worker this->use_weights_cache_ = use_weights_cache; 142*4bdc9457SAndroid Build Coastguard Worker return *this; 143*4bdc9457SAndroid Build Coastguard Worker } 144*4bdc9457SAndroid Build Coastguard Worker use_weights_cache()145*4bdc9457SAndroid Build Coastguard Worker inline bool use_weights_cache() const { 146*4bdc9457SAndroid Build Coastguard Worker return this->use_weights_cache_; 147*4bdc9457SAndroid Build Coastguard Worker } 148*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)149*4bdc9457SAndroid Build Coastguard Worker inline FullyConnectedOperatorTester& iterations(size_t iterations) { 150*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 151*4bdc9457SAndroid Build Coastguard Worker return *this; 152*4bdc9457SAndroid Build Coastguard Worker } 153*4bdc9457SAndroid Build Coastguard Worker iterations()154*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 155*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 156*4bdc9457SAndroid Build Coastguard Worker } 157*4bdc9457SAndroid Build Coastguard Worker TestQS8()158*4bdc9457SAndroid Build Coastguard Worker void TestQS8() const { 159*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 160*4bdc9457SAndroid Build Coastguard Worker 161*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 162*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 163*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000); 164*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 165*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 166*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> w8dist( 167*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()); 168*4bdc9457SAndroid Build Coastguard Worker 169*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + 170*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + input_channels()); 171*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> kernel(output_channels() * input_channels()); 172*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(output_channels()); 173*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output((batch_size() - 1) * output_stride() + output_channels()); 174*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_channels()); 175*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_channels()); 176*4bdc9457SAndroid Build Coastguard Worker 177*4bdc9457SAndroid Build Coastguard Worker const int8_t input_zero_point = 127; 178*4bdc9457SAndroid Build Coastguard Worker 179*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 180*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 181*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); 182*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); }); 183*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 184*4bdc9457SAndroid Build Coastguard Worker 185*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization. 186*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 187*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 188*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 189*4bdc9457SAndroid Build Coastguard Worker accumulators[i * output_channels() + oc] = bias[oc]; 190*4bdc9457SAndroid Build Coastguard Worker } 191*4bdc9457SAndroid Build Coastguard Worker } 192*4bdc9457SAndroid Build Coastguard Worker } else { 193*4bdc9457SAndroid Build Coastguard Worker std::fill(accumulators.begin(), accumulators.end(), 0); 194*4bdc9457SAndroid Build Coastguard Worker } 195*4bdc9457SAndroid Build Coastguard Worker if (transpose_weights()) { 196*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 197*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 198*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < input_channels(); ic++) { 199*4bdc9457SAndroid Build Coastguard Worker accumulators[i * output_channels() + oc] += 200*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) * 201*4bdc9457SAndroid Build Coastguard Worker int32_t(kernel[ic * output_channels() + oc]); 202*4bdc9457SAndroid Build Coastguard Worker } 203*4bdc9457SAndroid Build Coastguard Worker } 204*4bdc9457SAndroid Build Coastguard Worker } 205*4bdc9457SAndroid Build Coastguard Worker } else { 206*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 207*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 208*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < input_channels(); ic++) { 209*4bdc9457SAndroid Build Coastguard Worker accumulators[i * output_channels() + oc] += 210*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) * 211*4bdc9457SAndroid Build Coastguard Worker int32_t(kernel[oc * input_channels() + ic]); 212*4bdc9457SAndroid Build Coastguard Worker } 213*4bdc9457SAndroid Build Coastguard Worker } 214*4bdc9457SAndroid Build Coastguard Worker } 215*4bdc9457SAndroid Build Coastguard Worker } 216*4bdc9457SAndroid Build Coastguard Worker 217*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters. 218*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend()); 219*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend()); 220*4bdc9457SAndroid Build Coastguard Worker 221*4bdc9457SAndroid Build Coastguard Worker const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0; 222*4bdc9457SAndroid Build Coastguard Worker const int8_t output_zero_point = int8_t(std::max(std::min( 223*4bdc9457SAndroid Build Coastguard Worker lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale), 224*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min()))); 225*4bdc9457SAndroid Build Coastguard Worker 226*4bdc9457SAndroid Build Coastguard Worker // Renormalize reference results. 227*4bdc9457SAndroid Build Coastguard Worker std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(), 228*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double { 229*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point); 230*4bdc9457SAndroid Build Coastguard Worker }); 231*4bdc9457SAndroid Build Coastguard Worker 232*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Fully Connected operator. 233*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 234*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t fully_connected_op = nullptr; 235*4bdc9457SAndroid Build Coastguard Worker 236*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 237*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 238*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 239*4bdc9457SAndroid Build Coastguard Worker }; 240*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 241*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 242*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 243*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 244*4bdc9457SAndroid Build Coastguard Worker } 245*4bdc9457SAndroid Build Coastguard Worker 246*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_fully_connected_nc_qs8( 247*4bdc9457SAndroid Build Coastguard Worker input_channels(), output_channels(), 248*4bdc9457SAndroid Build Coastguard Worker input_stride(), output_stride(), 249*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */, 250*4bdc9457SAndroid Build Coastguard Worker 1.0f /* kernel scale */, 251*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr, 252*4bdc9457SAndroid Build Coastguard Worker output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 253*4bdc9457SAndroid Build Coastguard Worker transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0, 254*4bdc9457SAndroid Build Coastguard Worker &caches, 255*4bdc9457SAndroid Build Coastguard Worker &fully_connected_op); 256*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 257*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 258*4bdc9457SAndroid Build Coastguard Worker } 259*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 260*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, fully_connected_op); 261*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 262*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 263*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 264*4bdc9457SAndroid Build Coastguard Worker } 265*4bdc9457SAndroid Build Coastguard Worker 266*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete fully_connected_op. 267*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator); 268*4bdc9457SAndroid Build Coastguard Worker 269*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 270*4bdc9457SAndroid Build Coastguard Worker xnn_setup_fully_connected_nc_qs8( 271*4bdc9457SAndroid Build Coastguard Worker fully_connected_op, 272*4bdc9457SAndroid Build Coastguard Worker batch_size(), 273*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 274*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 275*4bdc9457SAndroid Build Coastguard Worker 276*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 277*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(fully_connected_op, nullptr /* thread pool */)); 278*4bdc9457SAndroid Build Coastguard Worker 279*4bdc9457SAndroid Build Coastguard Worker // Verify results. 280*4bdc9457SAndroid Build Coastguard Worker VerifyQS8(output, output_ref, double(output_zero_point)); 281*4bdc9457SAndroid Build Coastguard Worker 282*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 283*4bdc9457SAndroid Build Coastguard Worker // Create another operator with the same weights cache. 284*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t fully_connected_op2 = nullptr; 285*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size; 286*4bdc9457SAndroid Build Coastguard Worker 287*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 288*4bdc9457SAndroid Build Coastguard Worker xnn_create_fully_connected_nc_qs8( 289*4bdc9457SAndroid Build Coastguard Worker input_channels(), output_channels(), input_stride(), 290*4bdc9457SAndroid Build Coastguard Worker output_stride(), input_zero_point, 1.0f /* input scale */, 291*4bdc9457SAndroid Build Coastguard Worker 1.0f /* kernel scale */, kernel.data(), 292*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_zero_point, 293*4bdc9457SAndroid Build Coastguard Worker output_scale, int8_t(qmin() - 0x80), 294*4bdc9457SAndroid Build Coastguard Worker int8_t(qmax() - 0x80), 295*4bdc9457SAndroid Build Coastguard Worker transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0, 296*4bdc9457SAndroid Build Coastguard Worker &caches, &fully_connected_op2)); 297*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, fully_connected_op2); 298*4bdc9457SAndroid Build Coastguard Worker 299*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete fully_connected_op. 300*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> 301*4bdc9457SAndroid Build Coastguard Worker auto_fully_connected_op(fully_connected_op2, xnn_delete_operator); 302*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output2(output.size(), INT8_C(0xA5)); 303*4bdc9457SAndroid Build Coastguard Worker 304*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 305*4bdc9457SAndroid Build Coastguard Worker xnn_setup_fully_connected_nc_qs8( 306*4bdc9457SAndroid Build Coastguard Worker fully_connected_op2, 307*4bdc9457SAndroid Build Coastguard Worker batch_size(), 308*4bdc9457SAndroid Build Coastguard Worker input.data(), output2.data(), 309*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 310*4bdc9457SAndroid Build Coastguard Worker 311*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 312*4bdc9457SAndroid Build Coastguard Worker xnn_status_success, 313*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(fully_connected_op2, nullptr /* thread pool */)); 314*4bdc9457SAndroid Build Coastguard Worker 315*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 316*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 317*4bdc9457SAndroid Build Coastguard Worker 318*4bdc9457SAndroid Build Coastguard Worker VerifyQS8(output, output_ref, double(output_zero_point)); 319*4bdc9457SAndroid Build Coastguard Worker } 320*4bdc9457SAndroid Build Coastguard Worker } 321*4bdc9457SAndroid Build Coastguard Worker } 322*4bdc9457SAndroid Build Coastguard Worker VerifyQS8(const std::vector<int8_t> & output,const std::vector<double> & output_ref,double output_zero_point)323*4bdc9457SAndroid Build Coastguard Worker void VerifyQS8(const std::vector<int8_t>& output, 324*4bdc9457SAndroid Build Coastguard Worker const std::vector<double>& output_ref, 325*4bdc9457SAndroid Build Coastguard Worker double output_zero_point) const { 326*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 327*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < output_channels(); c++) { 328*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80)) 329*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 330*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80)) 331*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 332*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output_ref[i * output_channels() + c], 333*4bdc9457SAndroid Build Coastguard Worker double(output[i * output_stride() + c]) - output_zero_point, 334*4bdc9457SAndroid Build Coastguard Worker 0.9) 335*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 336*4bdc9457SAndroid Build Coastguard Worker } 337*4bdc9457SAndroid Build Coastguard Worker } 338*4bdc9457SAndroid Build Coastguard Worker } 339*4bdc9457SAndroid Build Coastguard Worker TestQU8()340*4bdc9457SAndroid Build Coastguard Worker void TestQU8() const { 341*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 342*4bdc9457SAndroid Build Coastguard Worker 343*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 344*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 345*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000); 346*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 347*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 348*4bdc9457SAndroid Build Coastguard Worker 349*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + 350*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + input_channels()); 351*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> kernel(output_channels() * input_channels()); 352*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(output_channels()); 353*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output((batch_size() - 1) * output_stride() + output_channels()); 354*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_channels()); 355*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_channels()); 356*4bdc9457SAndroid Build Coastguard Worker 357*4bdc9457SAndroid Build Coastguard Worker const uint8_t input_zero_point = 127; 358*4bdc9457SAndroid Build Coastguard Worker const uint8_t kernel_zero_point = 127; 359*4bdc9457SAndroid Build Coastguard Worker 360*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 361*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 362*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); }); 363*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); }); 364*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 365*4bdc9457SAndroid Build Coastguard Worker 366*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization. 367*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 368*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 369*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 370*4bdc9457SAndroid Build Coastguard Worker accumulators[i * output_channels() + oc] = bias[oc]; 371*4bdc9457SAndroid Build Coastguard Worker } 372*4bdc9457SAndroid Build Coastguard Worker } 373*4bdc9457SAndroid Build Coastguard Worker } else { 374*4bdc9457SAndroid Build Coastguard Worker std::fill(accumulators.begin(), accumulators.end(), 0); 375*4bdc9457SAndroid Build Coastguard Worker } 376*4bdc9457SAndroid Build Coastguard Worker if (transpose_weights()) { 377*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 378*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 379*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < input_channels(); ic++) { 380*4bdc9457SAndroid Build Coastguard Worker accumulators[i * output_channels() + oc] += 381*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) * 382*4bdc9457SAndroid Build Coastguard Worker (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point)); 383*4bdc9457SAndroid Build Coastguard Worker } 384*4bdc9457SAndroid Build Coastguard Worker } 385*4bdc9457SAndroid Build Coastguard Worker } 386*4bdc9457SAndroid Build Coastguard Worker } else { 387*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 388*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 389*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < input_channels(); ic++) { 390*4bdc9457SAndroid Build Coastguard Worker accumulators[i * output_channels() + oc] += 391*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) * 392*4bdc9457SAndroid Build Coastguard Worker (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point)); 393*4bdc9457SAndroid Build Coastguard Worker } 394*4bdc9457SAndroid Build Coastguard Worker } 395*4bdc9457SAndroid Build Coastguard Worker } 396*4bdc9457SAndroid Build Coastguard Worker } 397*4bdc9457SAndroid Build Coastguard Worker 398*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters. 399*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend()); 400*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend()); 401*4bdc9457SAndroid Build Coastguard Worker 402*4bdc9457SAndroid Build Coastguard Worker const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0; 403*4bdc9457SAndroid Build Coastguard Worker const uint8_t output_zero_point = uint8_t(std::max(std::min( 404*4bdc9457SAndroid Build Coastguard Worker lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale), 405*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min()))); 406*4bdc9457SAndroid Build Coastguard Worker 407*4bdc9457SAndroid Build Coastguard Worker // Renormalize reference results. 408*4bdc9457SAndroid Build Coastguard Worker std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(), 409*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double { 410*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point); 411*4bdc9457SAndroid Build Coastguard Worker }); 412*4bdc9457SAndroid Build Coastguard Worker 413*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Fully Connected operator. 414*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 415*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t fully_connected_op = nullptr; 416*4bdc9457SAndroid Build Coastguard Worker 417*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 418*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 419*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 420*4bdc9457SAndroid Build Coastguard Worker }; 421*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 422*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 423*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 424*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 425*4bdc9457SAndroid Build Coastguard Worker } 426*4bdc9457SAndroid Build Coastguard Worker 427*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_fully_connected_nc_qu8( 428*4bdc9457SAndroid Build Coastguard Worker input_channels(), output_channels(), 429*4bdc9457SAndroid Build Coastguard Worker input_stride(), output_stride(), 430*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */, 431*4bdc9457SAndroid Build Coastguard Worker kernel_zero_point, 1.0f /* kernel scale */, 432*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr, 433*4bdc9457SAndroid Build Coastguard Worker output_zero_point, output_scale, qmin(), qmax(), 434*4bdc9457SAndroid Build Coastguard Worker transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0, 435*4bdc9457SAndroid Build Coastguard Worker &caches, 436*4bdc9457SAndroid Build Coastguard Worker &fully_connected_op); 437*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 438*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 439*4bdc9457SAndroid Build Coastguard Worker } 440*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 441*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, fully_connected_op); 442*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 443*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 444*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 445*4bdc9457SAndroid Build Coastguard Worker } 446*4bdc9457SAndroid Build Coastguard Worker 447*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete fully_connected_op. 448*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator); 449*4bdc9457SAndroid Build Coastguard Worker 450*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 451*4bdc9457SAndroid Build Coastguard Worker xnn_setup_fully_connected_nc_qu8( 452*4bdc9457SAndroid Build Coastguard Worker fully_connected_op, 453*4bdc9457SAndroid Build Coastguard Worker batch_size(), 454*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 455*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 456*4bdc9457SAndroid Build Coastguard Worker 457*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 458*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(fully_connected_op, nullptr /* thread pool */)); 459*4bdc9457SAndroid Build Coastguard Worker 460*4bdc9457SAndroid Build Coastguard Worker VerifyQU8(output, output_ref, double(output_zero_point)); 461*4bdc9457SAndroid Build Coastguard Worker 462*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 463*4bdc9457SAndroid Build Coastguard Worker // Create another operator with the same weights cache. 464*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t fully_connected_op2 = nullptr; 465*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size; 466*4bdc9457SAndroid Build Coastguard Worker 467*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 468*4bdc9457SAndroid Build Coastguard Worker xnn_create_fully_connected_nc_qu8( 469*4bdc9457SAndroid Build Coastguard Worker input_channels(), output_channels(), input_stride(), 470*4bdc9457SAndroid Build Coastguard Worker output_stride(), input_zero_point, 1.0f /* input scale */, 471*4bdc9457SAndroid Build Coastguard Worker kernel_zero_point, 1.0f /* kernel scale */, kernel.data(), 472*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_zero_point, 473*4bdc9457SAndroid Build Coastguard Worker output_scale, qmin(), qmax(), 474*4bdc9457SAndroid Build Coastguard Worker transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0, 475*4bdc9457SAndroid Build Coastguard Worker &caches, &fully_connected_op2)); 476*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, fully_connected_op2); 477*4bdc9457SAndroid Build Coastguard Worker 478*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete fully_connected_op. 479*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> 480*4bdc9457SAndroid Build Coastguard Worker auto_fully_connected_op(fully_connected_op2, xnn_delete_operator); 481*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output2(output.size(), UINT8_C(0xA5)); 482*4bdc9457SAndroid Build Coastguard Worker 483*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 484*4bdc9457SAndroid Build Coastguard Worker xnn_setup_fully_connected_nc_qu8( 485*4bdc9457SAndroid Build Coastguard Worker fully_connected_op2, batch_size(), input.data(), 486*4bdc9457SAndroid Build Coastguard Worker output2.data(), nullptr /* thread pool */)); 487*4bdc9457SAndroid Build Coastguard Worker 488*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 489*4bdc9457SAndroid Build Coastguard Worker xnn_status_success, 490*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(fully_connected_op2, nullptr /* thread pool */)); 491*4bdc9457SAndroid Build Coastguard Worker 492*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 493*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 494*4bdc9457SAndroid Build Coastguard Worker 495*4bdc9457SAndroid Build Coastguard Worker VerifyQU8(output2, output_ref, double(output_zero_point)); 496*4bdc9457SAndroid Build Coastguard Worker } 497*4bdc9457SAndroid Build Coastguard Worker 498*4bdc9457SAndroid Build Coastguard Worker } 499*4bdc9457SAndroid Build Coastguard Worker } 500*4bdc9457SAndroid Build Coastguard Worker VerifyQU8(const std::vector<uint8_t> & output,const std::vector<double> & output_ref,double output_zero_point)501*4bdc9457SAndroid Build Coastguard Worker void VerifyQU8(const std::vector<uint8_t>& output, 502*4bdc9457SAndroid Build Coastguard Worker const std::vector<double>& output_ref, 503*4bdc9457SAndroid Build Coastguard Worker double output_zero_point) const { 504*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 505*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < output_channels(); c++) { 506*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax())) 507*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 508*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin())) 509*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 510*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output_ref[i * output_channels() + c], 511*4bdc9457SAndroid Build Coastguard Worker double(output[i * output_stride() + c]) - output_zero_point, 512*4bdc9457SAndroid Build Coastguard Worker 0.9) 513*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 514*4bdc9457SAndroid Build Coastguard Worker } 515*4bdc9457SAndroid Build Coastguard Worker } 516*4bdc9457SAndroid Build Coastguard Worker } 517*4bdc9457SAndroid Build Coastguard Worker TestF32()518*4bdc9457SAndroid Build Coastguard Worker void TestF32() const { 519*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 520*4bdc9457SAndroid Build Coastguard Worker 521*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 522*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 523*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f); 524*4bdc9457SAndroid Build Coastguard Worker 525*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 526*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + input_channels()); 527*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel(output_channels() * input_channels()); 528*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(output_channels()); 529*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() - 1) * output_stride() + output_channels()); 530*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_channels()); 531*4bdc9457SAndroid Build Coastguard Worker 532*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 533*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 534*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); }); 535*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); 536*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 537*4bdc9457SAndroid Build Coastguard Worker 538*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization. 539*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 540*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 541*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 542*4bdc9457SAndroid Build Coastguard Worker output_ref[i * output_channels() + oc] = bias[oc]; 543*4bdc9457SAndroid Build Coastguard Worker } 544*4bdc9457SAndroid Build Coastguard Worker } 545*4bdc9457SAndroid Build Coastguard Worker } else { 546*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f); 547*4bdc9457SAndroid Build Coastguard Worker } 548*4bdc9457SAndroid Build Coastguard Worker if (transpose_weights()) { 549*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 550*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 551*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < input_channels(); ic++) { 552*4bdc9457SAndroid Build Coastguard Worker output_ref[i * output_channels() + oc] += 553*4bdc9457SAndroid Build Coastguard Worker input[i * input_stride() + ic] * kernel[ic * output_channels() + oc]; 554*4bdc9457SAndroid Build Coastguard Worker } 555*4bdc9457SAndroid Build Coastguard Worker } 556*4bdc9457SAndroid Build Coastguard Worker } 557*4bdc9457SAndroid Build Coastguard Worker } else { 558*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 559*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 560*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < input_channels(); ic++) { 561*4bdc9457SAndroid Build Coastguard Worker output_ref[i * output_channels() + oc] += 562*4bdc9457SAndroid Build Coastguard Worker input[i * input_stride() + ic] * kernel[oc * input_channels() + ic]; 563*4bdc9457SAndroid Build Coastguard Worker } 564*4bdc9457SAndroid Build Coastguard Worker } 565*4bdc9457SAndroid Build Coastguard Worker } 566*4bdc9457SAndroid Build Coastguard Worker } 567*4bdc9457SAndroid Build Coastguard Worker 568*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 569*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 570*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 571*4bdc9457SAndroid Build Coastguard Worker 572*4bdc9457SAndroid Build Coastguard Worker const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() : 573*4bdc9457SAndroid Build Coastguard Worker accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin()); 574*4bdc9457SAndroid Build Coastguard Worker const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() : 575*4bdc9457SAndroid Build Coastguard Worker accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax()); 576*4bdc9457SAndroid Build Coastguard Worker 577*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 578*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 579*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 580*4bdc9457SAndroid Build Coastguard Worker } 581*4bdc9457SAndroid Build Coastguard Worker 582*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Fully Connected operator. 583*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 584*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t fully_connected_op = nullptr; 585*4bdc9457SAndroid Build Coastguard Worker 586*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 587*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 588*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 589*4bdc9457SAndroid Build Coastguard Worker }; 590*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 591*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 592*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 593*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 594*4bdc9457SAndroid Build Coastguard Worker } 595*4bdc9457SAndroid Build Coastguard Worker 596*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_fully_connected_nc_f32( 597*4bdc9457SAndroid Build Coastguard Worker input_channels(), output_channels(), 598*4bdc9457SAndroid Build Coastguard Worker input_stride(), output_stride(), 599*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr, 600*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 601*4bdc9457SAndroid Build Coastguard Worker transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0, 602*4bdc9457SAndroid Build Coastguard Worker &caches, 603*4bdc9457SAndroid Build Coastguard Worker &fully_connected_op); 604*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 605*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 606*4bdc9457SAndroid Build Coastguard Worker } 607*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 608*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, fully_connected_op); 609*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 610*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 611*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 612*4bdc9457SAndroid Build Coastguard Worker } 613*4bdc9457SAndroid Build Coastguard Worker 614*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete fully_connected_op. 615*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator); 616*4bdc9457SAndroid Build Coastguard Worker 617*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 618*4bdc9457SAndroid Build Coastguard Worker xnn_setup_fully_connected_nc_f32( 619*4bdc9457SAndroid Build Coastguard Worker fully_connected_op, 620*4bdc9457SAndroid Build Coastguard Worker batch_size(), 621*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 622*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 623*4bdc9457SAndroid Build Coastguard Worker 624*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 625*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(fully_connected_op, nullptr /* thread pool */)); 626*4bdc9457SAndroid Build Coastguard Worker 627*4bdc9457SAndroid Build Coastguard Worker VerifyF32(output, output_ref, output_max, output_min); 628*4bdc9457SAndroid Build Coastguard Worker 629*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 630*4bdc9457SAndroid Build Coastguard Worker // Create another operator with the same weights cache. 631*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t fully_connected_op2 = nullptr; 632*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size; 633*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 634*4bdc9457SAndroid Build Coastguard Worker xnn_create_fully_connected_nc_f32( 635*4bdc9457SAndroid Build Coastguard Worker input_channels(), output_channels(), input_stride(), 636*4bdc9457SAndroid Build Coastguard Worker output_stride(), kernel.data(), 637*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_min, 638*4bdc9457SAndroid Build Coastguard Worker output_max, 639*4bdc9457SAndroid Build Coastguard Worker transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0, 640*4bdc9457SAndroid Build Coastguard Worker &caches, &fully_connected_op2)); 641*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, fully_connected_op2); 642*4bdc9457SAndroid Build Coastguard Worker 643*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op2, xnn_delete_operator); 644*4bdc9457SAndroid Build Coastguard Worker 645*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output2(output.size(), nanf("")); 646*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 647*4bdc9457SAndroid Build Coastguard Worker xnn_setup_fully_connected_nc_f32( 648*4bdc9457SAndroid Build Coastguard Worker fully_connected_op2, 649*4bdc9457SAndroid Build Coastguard Worker batch_size(), 650*4bdc9457SAndroid Build Coastguard Worker input.data(), output2.data(), 651*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 652*4bdc9457SAndroid Build Coastguard Worker 653*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 654*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(fully_connected_op2, nullptr /* thread pool */)); 655*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 656*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 657*4bdc9457SAndroid Build Coastguard Worker 658*4bdc9457SAndroid Build Coastguard Worker VerifyF32(output, output_ref, output_max, output_min); 659*4bdc9457SAndroid Build Coastguard Worker } 660*4bdc9457SAndroid Build Coastguard Worker } 661*4bdc9457SAndroid Build Coastguard Worker } 662*4bdc9457SAndroid Build Coastguard Worker VerifyF32(const std::vector<float> & output,const std::vector<float> & output_ref,float output_max,float output_min)663*4bdc9457SAndroid Build Coastguard Worker void VerifyF32(const std::vector<float>& output, 664*4bdc9457SAndroid Build Coastguard Worker const std::vector<float>& output_ref, 665*4bdc9457SAndroid Build Coastguard Worker float output_max, 666*4bdc9457SAndroid Build Coastguard Worker float output_min) const { 667*4bdc9457SAndroid Build Coastguard Worker // Verify results. 668*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 669*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < output_channels(); c++) { 670*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[i * output_stride() + c], output_max) 671*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 672*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[i * output_stride() + c], output_min) 673*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 674*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output_ref[i * output_channels() + c], 675*4bdc9457SAndroid Build Coastguard Worker output[i * output_stride() + c], 676*4bdc9457SAndroid Build Coastguard Worker 1.0e-4 * std::abs(output_ref[i * output_channels() + c])) 677*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 678*4bdc9457SAndroid Build Coastguard Worker } 679*4bdc9457SAndroid Build Coastguard Worker } 680*4bdc9457SAndroid Build Coastguard Worker } 681*4bdc9457SAndroid Build Coastguard Worker TestF16()682*4bdc9457SAndroid Build Coastguard Worker void TestF16() const { 683*4bdc9457SAndroid Build Coastguard Worker switch (weights_type()) { 684*4bdc9457SAndroid Build Coastguard Worker case WeightsType::Default: 685*4bdc9457SAndroid Build Coastguard Worker break; 686*4bdc9457SAndroid Build Coastguard Worker case WeightsType::FP32: 687*4bdc9457SAndroid Build Coastguard Worker break; 688*4bdc9457SAndroid Build Coastguard Worker default: 689*4bdc9457SAndroid Build Coastguard Worker GTEST_FAIL() << "unexpected weights type"; 690*4bdc9457SAndroid Build Coastguard Worker } 691*4bdc9457SAndroid Build Coastguard Worker 692*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 693*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 694*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f); 695*4bdc9457SAndroid Build Coastguard Worker 696*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 697*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + input_channels()); 698*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> kernel(output_channels() * input_channels()); 699*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel_as_float(kernel.size()); 700*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> bias(output_channels()); 701*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias_as_float(bias.size()); 702*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((batch_size() - 1) * output_stride() + output_channels()); 703*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_channels()); 704*4bdc9457SAndroid Build Coastguard Worker 705*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 706*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 707*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 708*4bdc9457SAndroid Build Coastguard Worker std::transform(kernel.cbegin(), kernel.cend(), kernel_as_float.begin(), fp16_ieee_to_fp32_value); 709*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 710*4bdc9457SAndroid Build Coastguard Worker std::transform(bias.cbegin(), bias.cend(), bias_as_float.begin(), fp16_ieee_to_fp32_value); 711*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 712*4bdc9457SAndroid Build Coastguard Worker 713*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization. 714*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 715*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 716*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 717*4bdc9457SAndroid Build Coastguard Worker output_ref[i * output_channels() + oc] = fp16_ieee_to_fp32_value(bias[oc]); 718*4bdc9457SAndroid Build Coastguard Worker } 719*4bdc9457SAndroid Build Coastguard Worker } 720*4bdc9457SAndroid Build Coastguard Worker } else { 721*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f); 722*4bdc9457SAndroid Build Coastguard Worker } 723*4bdc9457SAndroid Build Coastguard Worker if (transpose_weights()) { 724*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 725*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 726*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < input_channels(); ic++) { 727*4bdc9457SAndroid Build Coastguard Worker output_ref[i * output_channels() + oc] += 728*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[i * input_stride() + ic]) * fp16_ieee_to_fp32_value(kernel[ic * output_channels() + oc]); 729*4bdc9457SAndroid Build Coastguard Worker } 730*4bdc9457SAndroid Build Coastguard Worker } 731*4bdc9457SAndroid Build Coastguard Worker } 732*4bdc9457SAndroid Build Coastguard Worker } else { 733*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 734*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 735*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < input_channels(); ic++) { 736*4bdc9457SAndroid Build Coastguard Worker output_ref[i * output_channels() + oc] += 737*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[i * input_stride() + ic]) * fp16_ieee_to_fp32_value(kernel[oc * input_channels() + ic]); 738*4bdc9457SAndroid Build Coastguard Worker } 739*4bdc9457SAndroid Build Coastguard Worker } 740*4bdc9457SAndroid Build Coastguard Worker } 741*4bdc9457SAndroid Build Coastguard Worker } 742*4bdc9457SAndroid Build Coastguard Worker 743*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 744*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 745*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 746*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 747*4bdc9457SAndroid Build Coastguard Worker const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin()))); 748*4bdc9457SAndroid Build Coastguard Worker const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax()))); 749*4bdc9457SAndroid Build Coastguard Worker const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min; 750*4bdc9457SAndroid Build Coastguard Worker const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max; 751*4bdc9457SAndroid Build Coastguard Worker 752*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 753*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 754*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 755*4bdc9457SAndroid Build Coastguard Worker } 756*4bdc9457SAndroid Build Coastguard Worker 757*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Fully Connected operator. 758*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 759*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t fully_connected_op = nullptr; 760*4bdc9457SAndroid Build Coastguard Worker 761*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 762*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 763*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 764*4bdc9457SAndroid Build Coastguard Worker }; 765*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 766*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 767*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 768*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 769*4bdc9457SAndroid Build Coastguard Worker } 770*4bdc9457SAndroid Build Coastguard Worker 771*4bdc9457SAndroid Build Coastguard Worker const void* kernel_data = kernel.data(); 772*4bdc9457SAndroid Build Coastguard Worker const void* bias_data = bias.data(); 773*4bdc9457SAndroid Build Coastguard Worker if (weights_type() == WeightsType::FP32) { 774*4bdc9457SAndroid Build Coastguard Worker kernel_data = kernel_as_float.data(); 775*4bdc9457SAndroid Build Coastguard Worker bias_data = bias_as_float.data(); 776*4bdc9457SAndroid Build Coastguard Worker } 777*4bdc9457SAndroid Build Coastguard Worker uint32_t flags = 0; 778*4bdc9457SAndroid Build Coastguard Worker if (transpose_weights()) { 779*4bdc9457SAndroid Build Coastguard Worker flags |= XNN_FLAG_TRANSPOSE_WEIGHTS; 780*4bdc9457SAndroid Build Coastguard Worker } 781*4bdc9457SAndroid Build Coastguard Worker if (weights_type() == WeightsType::FP32) { 782*4bdc9457SAndroid Build Coastguard Worker flags |= XNN_FLAG_FP32_STATIC_WEIGHTS; 783*4bdc9457SAndroid Build Coastguard Worker } 784*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_fully_connected_nc_f16( 785*4bdc9457SAndroid Build Coastguard Worker input_channels(), output_channels(), 786*4bdc9457SAndroid Build Coastguard Worker input_stride(), output_stride(), 787*4bdc9457SAndroid Build Coastguard Worker kernel_data, has_bias() ? bias_data : nullptr, 788*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 789*4bdc9457SAndroid Build Coastguard Worker flags, 790*4bdc9457SAndroid Build Coastguard Worker &caches, 791*4bdc9457SAndroid Build Coastguard Worker &fully_connected_op); 792*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 793*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 794*4bdc9457SAndroid Build Coastguard Worker } 795*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 796*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, fully_connected_op); 797*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 798*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 799*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 800*4bdc9457SAndroid Build Coastguard Worker } 801*4bdc9457SAndroid Build Coastguard Worker 802*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete fully_connected_op. 803*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator); 804*4bdc9457SAndroid Build Coastguard Worker 805*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 806*4bdc9457SAndroid Build Coastguard Worker xnn_setup_fully_connected_nc_f16( 807*4bdc9457SAndroid Build Coastguard Worker fully_connected_op, 808*4bdc9457SAndroid Build Coastguard Worker batch_size(), 809*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 810*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 811*4bdc9457SAndroid Build Coastguard Worker 812*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 813*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(fully_connected_op, nullptr /* thread pool */)); 814*4bdc9457SAndroid Build Coastguard Worker 815*4bdc9457SAndroid Build Coastguard Worker // Verify results. 816*4bdc9457SAndroid Build Coastguard Worker VerifyF16(output, output_ref, output_max, output_min); 817*4bdc9457SAndroid Build Coastguard Worker 818*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 819*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t fully_connected_op2 = nullptr; 820*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size; 821*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 822*4bdc9457SAndroid Build Coastguard Worker xnn_create_fully_connected_nc_f16( 823*4bdc9457SAndroid Build Coastguard Worker input_channels(), output_channels(), input_stride(), 824*4bdc9457SAndroid Build Coastguard Worker output_stride(), kernel_data, 825*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias_data : nullptr, output_min, output_max, 826*4bdc9457SAndroid Build Coastguard Worker flags, &caches, &fully_connected_op2)); 827*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 828*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 829*4bdc9457SAndroid Build Coastguard Worker } 830*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 831*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, fully_connected_op2); 832*4bdc9457SAndroid Build Coastguard Worker 833*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete fully_connected_op2. 834*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op2, xnn_delete_operator); 835*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output2(output.size(), UINT16_C(0x7E00) /* NaN */); 836*4bdc9457SAndroid Build Coastguard Worker 837*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 838*4bdc9457SAndroid Build Coastguard Worker xnn_setup_fully_connected_nc_f16( 839*4bdc9457SAndroid Build Coastguard Worker fully_connected_op2, 840*4bdc9457SAndroid Build Coastguard Worker batch_size(), 841*4bdc9457SAndroid Build Coastguard Worker input.data(), output2.data(), 842*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 843*4bdc9457SAndroid Build Coastguard Worker 844*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 845*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(fully_connected_op2, nullptr /* thread pool */)); 846*4bdc9457SAndroid Build Coastguard Worker 847*4bdc9457SAndroid Build Coastguard Worker // Verify results. 848*4bdc9457SAndroid Build Coastguard Worker VerifyF16(output2, output_ref, output_max, output_min); 849*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 850*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 851*4bdc9457SAndroid Build Coastguard Worker } 852*4bdc9457SAndroid Build Coastguard Worker } 853*4bdc9457SAndroid Build Coastguard Worker } 854*4bdc9457SAndroid Build Coastguard Worker VerifyF16(const std::vector<uint16_t> & output,const std::vector<float> & output_ref,const float output_max,const float output_min)855*4bdc9457SAndroid Build Coastguard Worker void VerifyF16(const std::vector<uint16_t>& output, 856*4bdc9457SAndroid Build Coastguard Worker const std::vector<float>& output_ref, 857*4bdc9457SAndroid Build Coastguard Worker const float output_max, 858*4bdc9457SAndroid Build Coastguard Worker const float output_min) const { 859*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 860*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < output_channels(); c++) { 861*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_max) 862*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 863*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_min) 864*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 865*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 866*4bdc9457SAndroid Build Coastguard Worker output_ref[i * output_channels() + c], 867*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[i * output_stride() + c]), 868*4bdc9457SAndroid Build Coastguard Worker 1.0e-2f * std::abs(output_ref[i * output_channels() + c])) 869*4bdc9457SAndroid Build Coastguard Worker << "batch index = " << i << ", channel = " << c; 870*4bdc9457SAndroid Build Coastguard Worker } 871*4bdc9457SAndroid Build Coastguard Worker } 872*4bdc9457SAndroid Build Coastguard Worker } 873*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(const xnn_weights_cache & weights_cache,size_t old_size)874*4bdc9457SAndroid Build Coastguard Worker void VerifyWeightsCache(const xnn_weights_cache& weights_cache, size_t old_size) const { 875*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_cache.cache.hits, 1); 876*4bdc9457SAndroid Build Coastguard Worker // Ensure that we did not write more weights to the cache because it was a cache hit. 877*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(old_size, weights_cache.cache.weights.size); 878*4bdc9457SAndroid Build Coastguard Worker }; 879*4bdc9457SAndroid Build Coastguard Worker 880*4bdc9457SAndroid Build Coastguard Worker private: 881*4bdc9457SAndroid Build Coastguard Worker size_t input_channels_{1}; 882*4bdc9457SAndroid Build Coastguard Worker size_t input_stride_{0}; 883*4bdc9457SAndroid Build Coastguard Worker size_t output_channels_{1}; 884*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_{0}; 885*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 886*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 887*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 888*4bdc9457SAndroid Build Coastguard Worker bool transpose_weights_{false}; 889*4bdc9457SAndroid Build Coastguard Worker bool has_bias_{true}; 890*4bdc9457SAndroid Build Coastguard Worker WeightsType weights_type_{WeightsType::Default}; 891*4bdc9457SAndroid Build Coastguard Worker bool use_weights_cache_{false}; 892*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 893*4bdc9457SAndroid Build Coastguard Worker }; 894