1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2020 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 9*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 10*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 11*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 13*4bdc9457SAndroid Build Coastguard Worker #include <unordered_map> 14*4bdc9457SAndroid Build Coastguard Worker #include <numeric> 15*4bdc9457SAndroid Build Coastguard Worker #include <random> 16*4bdc9457SAndroid Build Coastguard Worker #include <vector> 17*4bdc9457SAndroid Build Coastguard Worker #include <type_traits> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h> 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 23*4bdc9457SAndroid Build Coastguard Worker 24*4bdc9457SAndroid Build Coastguard Worker namespace xnnpack { 25*4bdc9457SAndroid Build Coastguard Worker 26*4bdc9457SAndroid Build Coastguard Worker enum class TensorType { 27*4bdc9457SAndroid Build Coastguard Worker kDense, 28*4bdc9457SAndroid Build Coastguard Worker kSparse, 29*4bdc9457SAndroid Build Coastguard Worker }; 30*4bdc9457SAndroid Build Coastguard Worker 31*4bdc9457SAndroid Build Coastguard Worker struct Padding { 32*4bdc9457SAndroid Build Coastguard Worker uint32_t top; 33*4bdc9457SAndroid Build Coastguard Worker uint32_t right; 34*4bdc9457SAndroid Build Coastguard Worker uint32_t bottom; 35*4bdc9457SAndroid Build Coastguard Worker uint32_t left; 36*4bdc9457SAndroid Build Coastguard Worker }; 37*4bdc9457SAndroid Build Coastguard Worker 38*4bdc9457SAndroid Build Coastguard Worker struct HeightWidth { 39*4bdc9457SAndroid Build Coastguard Worker uint32_t height; 40*4bdc9457SAndroid Build Coastguard Worker uint32_t width; 41*4bdc9457SAndroid Build Coastguard Worker }; 42*4bdc9457SAndroid Build Coastguard Worker 43*4bdc9457SAndroid Build Coastguard Worker using Kernel = HeightWidth; 44*4bdc9457SAndroid Build Coastguard Worker using Subsampling = HeightWidth; 45*4bdc9457SAndroid Build Coastguard Worker using Dilation = HeightWidth; 46*4bdc9457SAndroid Build Coastguard Worker using Upsampling = HeightWidth; 47*4bdc9457SAndroid Build Coastguard Worker using Adjustment = HeightWidth; 48*4bdc9457SAndroid Build Coastguard Worker 49*4bdc9457SAndroid Build Coastguard Worker struct ConvolutionParams { 50*4bdc9457SAndroid Build Coastguard Worker Padding padding; 51*4bdc9457SAndroid Build Coastguard Worker Kernel kernel; 52*4bdc9457SAndroid Build Coastguard Worker Subsampling subsampling; 53*4bdc9457SAndroid Build Coastguard Worker Dilation dilation; 54*4bdc9457SAndroid Build Coastguard Worker uint32_t groups; 55*4bdc9457SAndroid Build Coastguard Worker uint32_t group_input_channels; 56*4bdc9457SAndroid Build Coastguard Worker uint32_t group_output_channels; 57*4bdc9457SAndroid Build Coastguard Worker }; 58*4bdc9457SAndroid Build Coastguard Worker 59*4bdc9457SAndroid Build Coastguard Worker struct DeconvolutionParams { 60*4bdc9457SAndroid Build Coastguard Worker Padding padding; 61*4bdc9457SAndroid Build Coastguard Worker Adjustment adjustment; 62*4bdc9457SAndroid Build Coastguard Worker Kernel kernel; 63*4bdc9457SAndroid Build Coastguard Worker Upsampling upsampling; 64*4bdc9457SAndroid Build Coastguard Worker Dilation dilation; 65*4bdc9457SAndroid Build Coastguard Worker uint32_t groups; 66*4bdc9457SAndroid Build Coastguard Worker uint32_t group_input_channels; 67*4bdc9457SAndroid Build Coastguard Worker uint32_t group_output_channels; 68*4bdc9457SAndroid Build Coastguard Worker }; 69*4bdc9457SAndroid Build Coastguard Worker 70*4bdc9457SAndroid Build Coastguard Worker struct DepthwiseConvolutionParams { 71*4bdc9457SAndroid Build Coastguard Worker Padding padding; 72*4bdc9457SAndroid Build Coastguard Worker Kernel kernel; 73*4bdc9457SAndroid Build Coastguard Worker Subsampling subsampling; 74*4bdc9457SAndroid Build Coastguard Worker Dilation dilation; 75*4bdc9457SAndroid Build Coastguard Worker uint32_t depth_multiplier; 76*4bdc9457SAndroid Build Coastguard Worker uint32_t input_channels; 77*4bdc9457SAndroid Build Coastguard Worker }; 78*4bdc9457SAndroid Build Coastguard Worker 79*4bdc9457SAndroid Build Coastguard Worker class SubgraphTester { 80*4bdc9457SAndroid Build Coastguard Worker public: SubgraphTester(uint32_t external_value_ids)81*4bdc9457SAndroid Build Coastguard Worker explicit SubgraphTester(uint32_t external_value_ids) { 82*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_initialize(nullptr); 83*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 84*4bdc9457SAndroid Build Coastguard Worker 85*4bdc9457SAndroid Build Coastguard Worker xnn_subgraph_t subgraph_ptr = nullptr; 86*4bdc9457SAndroid Build Coastguard Worker status = xnn_create_subgraph(external_value_ids, 0 /* flags */, &subgraph_ptr); 87*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 88*4bdc9457SAndroid Build Coastguard Worker subgraph_.reset(subgraph_ptr); 89*4bdc9457SAndroid Build Coastguard Worker 90*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 91*4bdc9457SAndroid Build Coastguard Worker rng_ = std::mt19937(random_device()); 92*4bdc9457SAndroid Build Coastguard Worker } 93*4bdc9457SAndroid Build Coastguard Worker 94*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddDynamicTensorF32(const std::vector<size_t>& dims, 95*4bdc9457SAndroid Build Coastguard Worker uint32_t external_id, 96*4bdc9457SAndroid Build Coastguard Worker uint32_t flags = 0) { 97*4bdc9457SAndroid Build Coastguard Worker uint32_t id_out = 0; 98*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = 99*4bdc9457SAndroid Build Coastguard Worker xnn_define_tensor_value(subgraph_.get(), xnn_datatype_fp32, dims.size(), 100*4bdc9457SAndroid Build Coastguard Worker dims.data(), nullptr, external_id, flags, &id_out); 101*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 102*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(id_out, external_id); 103*4bdc9457SAndroid Build Coastguard Worker 104*4bdc9457SAndroid Build Coastguard Worker return *this; 105*4bdc9457SAndroid Build Coastguard Worker } 106*4bdc9457SAndroid Build Coastguard Worker 107*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddStaticTensorF32(const std::vector<size_t>& dims, 108*4bdc9457SAndroid Build Coastguard Worker TensorType tensor_type, 109*4bdc9457SAndroid Build Coastguard Worker uint32_t external_id, 110*4bdc9457SAndroid Build Coastguard Worker uint32_t flags = 0) { 111*4bdc9457SAndroid Build Coastguard Worker const size_t num_elements = NumElements(dims); 112*4bdc9457SAndroid Build Coastguard Worker static_data_.emplace_back(num_elements * sizeof(float)); 113*4bdc9457SAndroid Build Coastguard Worker float* data = reinterpret_cast<float*>(static_data_.back().data()); 114*4bdc9457SAndroid Build Coastguard Worker 115*4bdc9457SAndroid Build Coastguard Worker if (tensor_type == TensorType::kDense) { 116*4bdc9457SAndroid Build Coastguard Worker std::generate(data, data + num_elements, [&]() { return f32dist(rng_); }); 117*4bdc9457SAndroid Build Coastguard Worker } else { 118*4bdc9457SAndroid Build Coastguard Worker // Create tensor with 90% sparsity in two steps: 119*4bdc9457SAndroid Build Coastguard Worker // 1. Generate non-zero elements in the beginning of the vector 120*4bdc9457SAndroid Build Coastguard Worker // 2. Randomize positions of non-zero elements 121*4bdc9457SAndroid Build Coastguard Worker const size_t num_nonzero_elements = num_elements / 10; 122*4bdc9457SAndroid Build Coastguard Worker std::generate(data, data + num_nonzero_elements, [&]() { return f32dist(rng_); }); 123*4bdc9457SAndroid Build Coastguard Worker std::shuffle(data, data + num_elements, rng_); 124*4bdc9457SAndroid Build Coastguard Worker } 125*4bdc9457SAndroid Build Coastguard Worker uint32_t id_out; 126*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = 127*4bdc9457SAndroid Build Coastguard Worker xnn_define_tensor_value(subgraph_.get(), xnn_datatype_fp32, dims.size(), 128*4bdc9457SAndroid Build Coastguard Worker dims.data(), data, external_id, flags, &id_out); 129*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 130*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(id_out, external_id); 131*4bdc9457SAndroid Build Coastguard Worker return *this; 132*4bdc9457SAndroid Build Coastguard Worker } 133*4bdc9457SAndroid Build Coastguard Worker 134*4bdc9457SAndroid Build Coastguard Worker AddInputTensorF32(const std::vector<size_t> & dims,uint32_t external_id)135*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddInputTensorF32(const std::vector<size_t>& dims, uint32_t external_id) { 136*4bdc9457SAndroid Build Coastguard Worker AddDynamicTensorF32(dims, external_id, XNN_VALUE_FLAG_EXTERNAL_INPUT); 137*4bdc9457SAndroid Build Coastguard Worker size_t num_elements = NumElements(dims); 138*4bdc9457SAndroid Build Coastguard Worker auto input = std::vector<char>(num_elements * sizeof(float) + XNN_EXTRA_BYTES * sizeof(char)); 139*4bdc9457SAndroid Build Coastguard Worker float* data = reinterpret_cast<float*>(input.data()); 140*4bdc9457SAndroid Build Coastguard Worker std::generate(data, data + num_elements, [&]() { return f32dist(rng_); }); 141*4bdc9457SAndroid Build Coastguard Worker auto it = external_tensors_.insert({external_id, input}); 142*4bdc9457SAndroid Build Coastguard Worker EXPECT_TRUE(it.second); 143*4bdc9457SAndroid Build Coastguard Worker return *this; 144*4bdc9457SAndroid Build Coastguard Worker } 145*4bdc9457SAndroid Build Coastguard Worker AddOutputTensorF32(const std::vector<size_t> & dims,uint32_t external_id)146*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddOutputTensorF32(const std::vector<size_t>& dims, uint32_t external_id) { 147*4bdc9457SAndroid Build Coastguard Worker output_id_ = external_id; 148*4bdc9457SAndroid Build Coastguard Worker AddDynamicTensorF32(dims, external_id, XNN_VALUE_FLAG_EXTERNAL_OUTPUT); 149*4bdc9457SAndroid Build Coastguard Worker size_t num_elements = NumElements(dims); 150*4bdc9457SAndroid Build Coastguard Worker auto output = std::vector<char>(num_elements * sizeof(float)); 151*4bdc9457SAndroid Build Coastguard Worker float* data = reinterpret_cast<float*>(output.data()); 152*4bdc9457SAndroid Build Coastguard Worker std::fill(data, data + num_elements, std::nanf("")); 153*4bdc9457SAndroid Build Coastguard Worker auto it = external_tensors_.insert({external_id, output}); 154*4bdc9457SAndroid Build Coastguard Worker EXPECT_TRUE(it.second); 155*4bdc9457SAndroid Build Coastguard Worker return *this; 156*4bdc9457SAndroid Build Coastguard Worker } 157*4bdc9457SAndroid Build Coastguard Worker AddConstantPad(const size_t * pre_paddings,const size_t * post_paddings,float padding_value,uint32_t input_id,uint32_t output_id)158*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddConstantPad( 159*4bdc9457SAndroid Build Coastguard Worker const size_t *pre_paddings, const size_t *post_paddings, 160*4bdc9457SAndroid Build Coastguard Worker float padding_value, uint32_t input_id, uint32_t output_id) { 161*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_define_static_constant_pad( 162*4bdc9457SAndroid Build Coastguard Worker subgraph_.get(), pre_paddings, post_paddings, padding_value, input_id, 163*4bdc9457SAndroid Build Coastguard Worker output_id, 0 /* flags */); 164*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 165*4bdc9457SAndroid Build Coastguard Worker return *this; 166*4bdc9457SAndroid Build Coastguard Worker } 167*4bdc9457SAndroid Build Coastguard Worker AddConvolution2D(ConvolutionParams params,uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id)168*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddConvolution2D( 169*4bdc9457SAndroid Build Coastguard Worker ConvolutionParams params, 170*4bdc9457SAndroid Build Coastguard Worker uint32_t input_id, uint32_t filter_id, uint32_t bias_id, 171*4bdc9457SAndroid Build Coastguard Worker uint32_t output_id) { 172*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_define_convolution_2d( 173*4bdc9457SAndroid Build Coastguard Worker subgraph_.get(), params.padding.top, params.padding.right, 174*4bdc9457SAndroid Build Coastguard Worker params.padding.bottom, params.padding.left, params.kernel.height, params.kernel.width, 175*4bdc9457SAndroid Build Coastguard Worker params.subsampling.height, params.subsampling.width, params.dilation.height, params.dilation.width, 176*4bdc9457SAndroid Build Coastguard Worker params.groups, params.group_input_channels, params.group_output_channels, 177*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity(), 178*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id, 179*4bdc9457SAndroid Build Coastguard Worker output_id, 0 /* flags */); 180*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 181*4bdc9457SAndroid Build Coastguard Worker 182*4bdc9457SAndroid Build Coastguard Worker return *this; 183*4bdc9457SAndroid Build Coastguard Worker } 184*4bdc9457SAndroid Build Coastguard Worker AddDepthwiseConvolution2D(DepthwiseConvolutionParams params,uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id)185*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddDepthwiseConvolution2D( 186*4bdc9457SAndroid Build Coastguard Worker DepthwiseConvolutionParams params, 187*4bdc9457SAndroid Build Coastguard Worker uint32_t input_id, uint32_t filter_id, uint32_t bias_id, uint32_t output_id) { 188*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_define_depthwise_convolution_2d( 189*4bdc9457SAndroid Build Coastguard Worker subgraph_.get(), params.padding.top, params.padding.right, 190*4bdc9457SAndroid Build Coastguard Worker params.padding.bottom, params.padding.left, params.kernel.height, params.kernel.width, 191*4bdc9457SAndroid Build Coastguard Worker params.subsampling.height, params.subsampling.width, params.dilation.height, params.dilation.width, 192*4bdc9457SAndroid Build Coastguard Worker params.depth_multiplier, params.input_channels, 193*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity(), 194*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id, 195*4bdc9457SAndroid Build Coastguard Worker output_id, 0 /* flags */); 196*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 197*4bdc9457SAndroid Build Coastguard Worker 198*4bdc9457SAndroid Build Coastguard Worker return *this; 199*4bdc9457SAndroid Build Coastguard Worker } 200*4bdc9457SAndroid Build Coastguard Worker AddAddition(uint32_t input_id1,uint32_t input_id2,uint32_t output_id)201*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddAddition(uint32_t input_id1, uint32_t input_id2, uint32_t output_id) { 202*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = 203*4bdc9457SAndroid Build Coastguard Worker xnn_define_add2(subgraph_.get(), -std::numeric_limits<float>::infinity(), 204*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id1, 205*4bdc9457SAndroid Build Coastguard Worker input_id2, output_id, 0 /* flags */); 206*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 207*4bdc9457SAndroid Build Coastguard Worker 208*4bdc9457SAndroid Build Coastguard Worker return *this; 209*4bdc9457SAndroid Build Coastguard Worker } 210*4bdc9457SAndroid Build Coastguard Worker AddAveragePooling2D(uint32_t input_padding_top,uint32_t input_padding_right,uint32_t input_padding_bottom,uint32_t input_padding_left,uint32_t pooling_height,uint32_t pooling_width,uint32_t stride_height,uint32_t stride_width,uint32_t input_id,uint32_t output_id)211*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddAveragePooling2D( 212*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_top, uint32_t input_padding_right, 213*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_bottom, uint32_t input_padding_left, 214*4bdc9457SAndroid Build Coastguard Worker uint32_t pooling_height, uint32_t pooling_width, uint32_t stride_height, 215*4bdc9457SAndroid Build Coastguard Worker uint32_t stride_width, uint32_t input_id, uint32_t output_id) { 216*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_define_average_pooling_2d( 217*4bdc9457SAndroid Build Coastguard Worker subgraph_.get(), input_padding_top, input_padding_right, 218*4bdc9457SAndroid Build Coastguard Worker input_padding_bottom, input_padding_left, pooling_height, pooling_width, 219*4bdc9457SAndroid Build Coastguard Worker stride_height, stride_width, -std::numeric_limits<float>::infinity(), 220*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id, output_id, 221*4bdc9457SAndroid Build Coastguard Worker 0 /* flags */); 222*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 223*4bdc9457SAndroid Build Coastguard Worker 224*4bdc9457SAndroid Build Coastguard Worker return *this; 225*4bdc9457SAndroid Build Coastguard Worker } 226*4bdc9457SAndroid Build Coastguard Worker AddClamp(float output_min,float output_max,uint32_t input_id,uint32_t output_id)227*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddClamp(float output_min, float output_max, uint32_t input_id, uint32_t output_id) { 228*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = 229*4bdc9457SAndroid Build Coastguard Worker xnn_define_clamp(subgraph_.get(), output_min, output_max, input_id, output_id, 0 /* flags */); 230*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 231*4bdc9457SAndroid Build Coastguard Worker 232*4bdc9457SAndroid Build Coastguard Worker return *this; 233*4bdc9457SAndroid Build Coastguard Worker } 234*4bdc9457SAndroid Build Coastguard Worker AddDeconvolution2D(uint32_t input_padding_top,uint32_t input_padding_right,uint32_t input_padding_bottom,uint32_t input_padding_left,uint32_t adjustment_height,uint32_t adjustment_width,uint32_t kernel_height,uint32_t kernel_width,uint32_t upsampling_height,uint32_t upsampling_width,uint32_t dilation_height,uint32_t dilation_width,uint32_t groups,size_t group_input_channels,size_t group_output_channels,uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id)235*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddDeconvolution2D( 236*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_top, uint32_t input_padding_right, 237*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_bottom, uint32_t input_padding_left, 238*4bdc9457SAndroid Build Coastguard Worker uint32_t adjustment_height, uint32_t adjustment_width, 239*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_height, uint32_t kernel_width, 240*4bdc9457SAndroid Build Coastguard Worker uint32_t upsampling_height, uint32_t upsampling_width, 241*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_height, uint32_t dilation_width, uint32_t groups, 242*4bdc9457SAndroid Build Coastguard Worker size_t group_input_channels, size_t group_output_channels, 243*4bdc9457SAndroid Build Coastguard Worker uint32_t input_id, uint32_t filter_id, uint32_t bias_id, 244*4bdc9457SAndroid Build Coastguard Worker uint32_t output_id) { 245*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_define_deconvolution_2d( 246*4bdc9457SAndroid Build Coastguard Worker subgraph_.get(), input_padding_top, input_padding_right, 247*4bdc9457SAndroid Build Coastguard Worker input_padding_bottom, input_padding_left, adjustment_height, 248*4bdc9457SAndroid Build Coastguard Worker adjustment_width, kernel_height, kernel_width, upsampling_height, 249*4bdc9457SAndroid Build Coastguard Worker upsampling_width, dilation_height, dilation_width, groups, 250*4bdc9457SAndroid Build Coastguard Worker group_input_channels, group_output_channels, 251*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity(), 252*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id, 253*4bdc9457SAndroid Build Coastguard Worker output_id, 0 /* flags */); 254*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 255*4bdc9457SAndroid Build Coastguard Worker 256*4bdc9457SAndroid Build Coastguard Worker return *this; 257*4bdc9457SAndroid Build Coastguard Worker } 258*4bdc9457SAndroid Build Coastguard Worker AddDeconvolution2D(DeconvolutionParams params,uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id)259*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddDeconvolution2D( 260*4bdc9457SAndroid Build Coastguard Worker DeconvolutionParams params, 261*4bdc9457SAndroid Build Coastguard Worker uint32_t input_id, uint32_t filter_id, uint32_t bias_id, 262*4bdc9457SAndroid Build Coastguard Worker uint32_t output_id) { 263*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_define_deconvolution_2d( 264*4bdc9457SAndroid Build Coastguard Worker subgraph_.get(), params.padding.top, params.padding.right, 265*4bdc9457SAndroid Build Coastguard Worker params.padding.bottom, params.padding.left, params.adjustment.height, 266*4bdc9457SAndroid Build Coastguard Worker params.adjustment.width, params.kernel.height, params.kernel.width, params.upsampling.height, 267*4bdc9457SAndroid Build Coastguard Worker params.upsampling.width, params.dilation.height, params.dilation.width, params.groups, 268*4bdc9457SAndroid Build Coastguard Worker params.group_input_channels, params.group_output_channels, 269*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity(), 270*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id, 271*4bdc9457SAndroid Build Coastguard Worker output_id, 0 /* flags */); 272*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 273*4bdc9457SAndroid Build Coastguard Worker 274*4bdc9457SAndroid Build Coastguard Worker return *this; 275*4bdc9457SAndroid Build Coastguard Worker } 276*4bdc9457SAndroid Build Coastguard Worker AddDivide(uint32_t input_id1,uint32_t input_id2,uint32_t output_id)277*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddDivide(uint32_t input_id1, uint32_t input_id2, uint32_t output_id) { 278*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = 279*4bdc9457SAndroid Build Coastguard Worker xnn_define_divide(subgraph_.get(), -std::numeric_limits<float>::infinity(), 280*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id1, 281*4bdc9457SAndroid Build Coastguard Worker input_id2, output_id, 0 /* flags */); 282*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 283*4bdc9457SAndroid Build Coastguard Worker 284*4bdc9457SAndroid Build Coastguard Worker return *this; 285*4bdc9457SAndroid Build Coastguard Worker } 286*4bdc9457SAndroid Build Coastguard Worker AddFullyConnected(uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id)287*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddFullyConnected( 288*4bdc9457SAndroid Build Coastguard Worker uint32_t input_id, uint32_t filter_id, uint32_t bias_id, uint32_t output_id) { 289*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_define_fully_connected( 290*4bdc9457SAndroid Build Coastguard Worker subgraph_.get(), 291*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity(), 292*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id, 293*4bdc9457SAndroid Build Coastguard Worker output_id, 0 /* flags */); 294*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 295*4bdc9457SAndroid Build Coastguard Worker 296*4bdc9457SAndroid Build Coastguard Worker return *this; 297*4bdc9457SAndroid Build Coastguard Worker } 298*4bdc9457SAndroid Build Coastguard Worker 299*4bdc9457SAndroid Build Coastguard Worker AddGlobalAveragePooling(uint32_t input_id,uint32_t output_id)300*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddGlobalAveragePooling(uint32_t input_id, uint32_t output_id) { 301*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_define_global_average_pooling_2d( 302*4bdc9457SAndroid Build Coastguard Worker subgraph_.get(), -std::numeric_limits<float>::infinity(), 303*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id, output_id, 0 /* flags */); 304*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 305*4bdc9457SAndroid Build Coastguard Worker 306*4bdc9457SAndroid Build Coastguard Worker return *this; 307*4bdc9457SAndroid Build Coastguard Worker } 308*4bdc9457SAndroid Build Coastguard Worker AddMaxPooling2D(uint32_t input_padding_top,uint32_t input_padding_right,uint32_t input_padding_bottom,uint32_t input_padding_left,uint32_t pooling_height,uint32_t pooling_width,uint32_t stride_height,uint32_t stride_width,uint32_t dilation_height,uint32_t dilation_width,uint32_t input_id,uint32_t output_id)309*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddMaxPooling2D( 310*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_top, uint32_t input_padding_right, 311*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_bottom, uint32_t input_padding_left, 312*4bdc9457SAndroid Build Coastguard Worker uint32_t pooling_height, uint32_t pooling_width, uint32_t stride_height, 313*4bdc9457SAndroid Build Coastguard Worker uint32_t stride_width, uint32_t dilation_height, uint32_t dilation_width, uint32_t input_id, uint32_t output_id) { 314*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_define_max_pooling_2d( 315*4bdc9457SAndroid Build Coastguard Worker subgraph_.get(), input_padding_top, input_padding_right, 316*4bdc9457SAndroid Build Coastguard Worker input_padding_bottom, input_padding_left, pooling_height, pooling_width, 317*4bdc9457SAndroid Build Coastguard Worker stride_height, stride_width, dilation_height, dilation_width, -std::numeric_limits<float>::infinity(), 318*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id, output_id, 319*4bdc9457SAndroid Build Coastguard Worker 0 /* flags */); 320*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 321*4bdc9457SAndroid Build Coastguard Worker 322*4bdc9457SAndroid Build Coastguard Worker return *this; 323*4bdc9457SAndroid Build Coastguard Worker } 324*4bdc9457SAndroid Build Coastguard Worker AddMultiply(uint32_t input_id1,uint32_t input_id2,uint32_t output_id)325*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddMultiply(uint32_t input_id1, uint32_t input_id2, uint32_t output_id) { 326*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = 327*4bdc9457SAndroid Build Coastguard Worker xnn_define_multiply2(subgraph_.get(), -std::numeric_limits<float>::infinity(), 328*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id1, 329*4bdc9457SAndroid Build Coastguard Worker input_id2, output_id, 0 /* flags */); 330*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 331*4bdc9457SAndroid Build Coastguard Worker 332*4bdc9457SAndroid Build Coastguard Worker return *this; 333*4bdc9457SAndroid Build Coastguard Worker } 334*4bdc9457SAndroid Build Coastguard Worker AddSubtract(uint32_t input_id1,uint32_t input_id2,uint32_t output_id)335*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& AddSubtract(uint32_t input_id1, uint32_t input_id2, uint32_t output_id) { 336*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = 337*4bdc9457SAndroid Build Coastguard Worker xnn_define_subtract(subgraph_.get(), -std::numeric_limits<float>::infinity(), 338*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<float>::infinity(), input_id1, 339*4bdc9457SAndroid Build Coastguard Worker input_id2, output_id, 0 /* flags */); 340*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 341*4bdc9457SAndroid Build Coastguard Worker 342*4bdc9457SAndroid Build Coastguard Worker return *this; 343*4bdc9457SAndroid Build Coastguard Worker } 344*4bdc9457SAndroid Build Coastguard Worker Optimize()345*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& Optimize() { 346*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_subgraph_optimize(subgraph_.get(), 0 /* flags */); 347*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(status, xnn_status_success); 348*4bdc9457SAndroid Build Coastguard Worker 349*4bdc9457SAndroid Build Coastguard Worker return *this; 350*4bdc9457SAndroid Build Coastguard Worker } 351*4bdc9457SAndroid Build Coastguard Worker RewriteForNchw()352*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& RewriteForNchw() { 353*4bdc9457SAndroid Build Coastguard Worker xnn_subgraph_rewrite_for_nchw(subgraph_.get()); 354*4bdc9457SAndroid Build Coastguard Worker 355*4bdc9457SAndroid Build Coastguard Worker return *this; 356*4bdc9457SAndroid Build Coastguard Worker } 357*4bdc9457SAndroid Build Coastguard Worker RewriteForFp16()358*4bdc9457SAndroid Build Coastguard Worker inline SubgraphTester& RewriteForFp16() { 359*4bdc9457SAndroid Build Coastguard Worker EXPECT_TRUE(xnn_subgraph_rewrite_for_fp16(subgraph_.get())); 360*4bdc9457SAndroid Build Coastguard Worker 361*4bdc9457SAndroid Build Coastguard Worker return *this; 362*4bdc9457SAndroid Build Coastguard Worker } 363*4bdc9457SAndroid Build Coastguard Worker GetLayout(uint32_t value_id)364*4bdc9457SAndroid Build Coastguard Worker inline xnn_layout_type GetLayout(uint32_t value_id) const { 365*4bdc9457SAndroid Build Coastguard Worker return subgraph_->values[value_id].layout; 366*4bdc9457SAndroid Build Coastguard Worker } 367*4bdc9457SAndroid Build Coastguard Worker Value(uint32_t value_id)368*4bdc9457SAndroid Build Coastguard Worker inline const xnn_value* const Value(uint32_t value_id) const { 369*4bdc9457SAndroid Build Coastguard Worker return &subgraph_->values[value_id]; 370*4bdc9457SAndroid Build Coastguard Worker } 371*4bdc9457SAndroid Build Coastguard Worker Node(uint32_t node_id)372*4bdc9457SAndroid Build Coastguard Worker inline const xnn_node* const Node(uint32_t node_id) const { 373*4bdc9457SAndroid Build Coastguard Worker return &subgraph_->nodes[node_id]; 374*4bdc9457SAndroid Build Coastguard Worker } 375*4bdc9457SAndroid Build Coastguard Worker NumNodes()376*4bdc9457SAndroid Build Coastguard Worker inline size_t NumNodes() const { 377*4bdc9457SAndroid Build Coastguard Worker return subgraph_->num_nodes; 378*4bdc9457SAndroid Build Coastguard Worker } 379*4bdc9457SAndroid Build Coastguard Worker 380*4bdc9457SAndroid Build Coastguard Worker protected: 381*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph_{nullptr, xnn_delete_subgraph}; 382*4bdc9457SAndroid Build Coastguard Worker std::unordered_map<uint32_t, std::vector<char>> external_tensors_; 383*4bdc9457SAndroid Build Coastguard Worker uint32_t output_id_; 384*4bdc9457SAndroid Build Coastguard Worker 385*4bdc9457SAndroid Build Coastguard Worker private: NumElements(const std::vector<size_t> & dims)386*4bdc9457SAndroid Build Coastguard Worker static inline size_t NumElements(const std::vector<size_t>& dims) { 387*4bdc9457SAndroid Build Coastguard Worker return std::accumulate(std::begin(dims), std::end(dims), size_t(1), std::multiplies<size_t>()); 388*4bdc9457SAndroid Build Coastguard Worker } 389*4bdc9457SAndroid Build Coastguard Worker 390*4bdc9457SAndroid Build Coastguard Worker std::vector<std::vector<char>> static_data_; 391*4bdc9457SAndroid Build Coastguard Worker std::mt19937 rng_; 392*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist = std::uniform_real_distribution<float>(-1.0f, +1.0f); 393*4bdc9457SAndroid Build Coastguard Worker 394*4bdc9457SAndroid Build Coastguard Worker }; 395*4bdc9457SAndroid Build Coastguard Worker 396*4bdc9457SAndroid Build Coastguard Worker } // namespace xnnpack 397