1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <array> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 14*4bdc9457SAndroid Build Coastguard Worker #include <initializer_list> 15*4bdc9457SAndroid Build Coastguard Worker #include <numeric> 16*4bdc9457SAndroid Build Coastguard Worker #include <random> 17*4bdc9457SAndroid Build Coastguard Worker #include <vector> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker class ConstantPadOperatorTester { 23*4bdc9457SAndroid Build Coastguard Worker public: input_shape(std::initializer_list<size_t> input_shape)24*4bdc9457SAndroid Build Coastguard Worker inline ConstantPadOperatorTester& input_shape(std::initializer_list<size_t> input_shape) { 25*4bdc9457SAndroid Build Coastguard Worker assert(input_shape.size() <= XNN_MAX_TENSOR_DIMS); 26*4bdc9457SAndroid Build Coastguard Worker input_shape_ = std::vector<size_t>(input_shape); 27*4bdc9457SAndroid Build Coastguard Worker return *this; 28*4bdc9457SAndroid Build Coastguard Worker } 29*4bdc9457SAndroid Build Coastguard Worker input_shape()30*4bdc9457SAndroid Build 
Coastguard Worker inline const std::vector<size_t>& input_shape() const { 31*4bdc9457SAndroid Build Coastguard Worker return input_shape_; 32*4bdc9457SAndroid Build Coastguard Worker } 33*4bdc9457SAndroid Build Coastguard Worker input_dim(size_t i)34*4bdc9457SAndroid Build Coastguard Worker inline size_t input_dim(size_t i) const { 35*4bdc9457SAndroid Build Coastguard Worker return i < input_shape_.size() ? input_shape_[i] : 1; 36*4bdc9457SAndroid Build Coastguard Worker } 37*4bdc9457SAndroid Build Coastguard Worker num_dims()38*4bdc9457SAndroid Build Coastguard Worker inline size_t num_dims() const { 39*4bdc9457SAndroid Build Coastguard Worker return input_shape_.size(); 40*4bdc9457SAndroid Build Coastguard Worker } 41*4bdc9457SAndroid Build Coastguard Worker num_input_elements()42*4bdc9457SAndroid Build Coastguard Worker inline size_t num_input_elements() const { 43*4bdc9457SAndroid Build Coastguard Worker return std::accumulate( 44*4bdc9457SAndroid Build Coastguard Worker input_shape_.cbegin(), input_shape_.cend(), size_t(1), std::multiplies<size_t>()); 45*4bdc9457SAndroid Build Coastguard Worker } 46*4bdc9457SAndroid Build Coastguard Worker pre_paddings(std::initializer_list<size_t> pre_paddings)47*4bdc9457SAndroid Build Coastguard Worker inline ConstantPadOperatorTester& pre_paddings(std::initializer_list<size_t> pre_paddings) { 48*4bdc9457SAndroid Build Coastguard Worker assert(pre_paddings.size() <= XNN_MAX_TENSOR_DIMS); 49*4bdc9457SAndroid Build Coastguard Worker pre_paddings_ = std::vector<size_t>(pre_paddings); 50*4bdc9457SAndroid Build Coastguard Worker return *this; 51*4bdc9457SAndroid Build Coastguard Worker } 52*4bdc9457SAndroid Build Coastguard Worker pre_paddings()53*4bdc9457SAndroid Build Coastguard Worker inline const std::vector<size_t>& pre_paddings() const { 54*4bdc9457SAndroid Build Coastguard Worker return pre_paddings_; 55*4bdc9457SAndroid Build Coastguard Worker } 56*4bdc9457SAndroid Build Coastguard Worker pre_padding(size_t 
i)57*4bdc9457SAndroid Build Coastguard Worker inline size_t pre_padding(size_t i) const { 58*4bdc9457SAndroid Build Coastguard Worker return i < pre_paddings_.size() ? pre_paddings_[i] : 0; 59*4bdc9457SAndroid Build Coastguard Worker } 60*4bdc9457SAndroid Build Coastguard Worker num_pre_paddings()61*4bdc9457SAndroid Build Coastguard Worker inline size_t num_pre_paddings() const { 62*4bdc9457SAndroid Build Coastguard Worker return pre_paddings_.size(); 63*4bdc9457SAndroid Build Coastguard Worker } 64*4bdc9457SAndroid Build Coastguard Worker post_paddings(std::initializer_list<size_t> post_paddings)65*4bdc9457SAndroid Build Coastguard Worker inline ConstantPadOperatorTester& post_paddings(std::initializer_list<size_t> post_paddings) { 66*4bdc9457SAndroid Build Coastguard Worker assert(post_paddings.size() <= XNN_MAX_TENSOR_DIMS); 67*4bdc9457SAndroid Build Coastguard Worker post_paddings_ = std::vector<size_t>(post_paddings); 68*4bdc9457SAndroid Build Coastguard Worker return *this; 69*4bdc9457SAndroid Build Coastguard Worker } 70*4bdc9457SAndroid Build Coastguard Worker post_paddings()71*4bdc9457SAndroid Build Coastguard Worker inline const std::vector<size_t>& post_paddings() const { 72*4bdc9457SAndroid Build Coastguard Worker return post_paddings_; 73*4bdc9457SAndroid Build Coastguard Worker } 74*4bdc9457SAndroid Build Coastguard Worker post_padding(size_t i)75*4bdc9457SAndroid Build Coastguard Worker inline size_t post_padding(size_t i) const { 76*4bdc9457SAndroid Build Coastguard Worker return i < post_paddings_.size() ? 
post_paddings_[i] : 0; 77*4bdc9457SAndroid Build Coastguard Worker } 78*4bdc9457SAndroid Build Coastguard Worker num_post_paddings()79*4bdc9457SAndroid Build Coastguard Worker inline size_t num_post_paddings() const { 80*4bdc9457SAndroid Build Coastguard Worker return post_paddings_.size(); 81*4bdc9457SAndroid Build Coastguard Worker } 82*4bdc9457SAndroid Build Coastguard Worker output_dim(size_t i)83*4bdc9457SAndroid Build Coastguard Worker inline size_t output_dim(size_t i) const { 84*4bdc9457SAndroid Build Coastguard Worker return pre_padding(i) + input_dim(i) + post_padding(i); 85*4bdc9457SAndroid Build Coastguard Worker } 86*4bdc9457SAndroid Build Coastguard Worker num_output_elements()87*4bdc9457SAndroid Build Coastguard Worker inline size_t num_output_elements() const { 88*4bdc9457SAndroid Build Coastguard Worker size_t elements = 1; 89*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < num_dims(); i++) { 90*4bdc9457SAndroid Build Coastguard Worker elements *= output_dim(i); 91*4bdc9457SAndroid Build Coastguard Worker } 92*4bdc9457SAndroid Build Coastguard Worker return elements; 93*4bdc9457SAndroid Build Coastguard Worker } 94*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)95*4bdc9457SAndroid Build Coastguard Worker inline ConstantPadOperatorTester& iterations(size_t iterations) { 96*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 97*4bdc9457SAndroid Build Coastguard Worker return *this; 98*4bdc9457SAndroid Build Coastguard Worker } 99*4bdc9457SAndroid Build Coastguard Worker iterations()100*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 101*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 102*4bdc9457SAndroid Build Coastguard Worker } 103*4bdc9457SAndroid Build Coastguard Worker TestX8()104*4bdc9457SAndroid Build Coastguard Worker void TestX8() const { 105*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(num_dims(), num_pre_paddings()); 
106*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(num_dims(), num_post_paddings()); 107*4bdc9457SAndroid Build Coastguard Worker 108*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 109*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 110*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 111*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 112*4bdc9457SAndroid Build Coastguard Worker 113*4bdc9457SAndroid Build Coastguard Worker // Compute generalized shapes. 114*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_dims; 115*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_pre_paddings; 116*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_post_paddings; 117*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims; 118*4bdc9457SAndroid Build Coastguard Worker std::fill(input_dims.begin(), input_dims.end(), 1); 119*4bdc9457SAndroid Build Coastguard Worker std::fill(input_pre_paddings.begin(), input_pre_paddings.end(), 0); 120*4bdc9457SAndroid Build Coastguard Worker std::fill(input_post_paddings.begin(), input_post_paddings.end(), 0); 121*4bdc9457SAndroid Build Coastguard Worker std::fill(output_dims.begin(), output_dims.end(), 1); 122*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < num_dims(); i++) { 123*4bdc9457SAndroid Build Coastguard Worker input_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = input_dim(i); 124*4bdc9457SAndroid Build Coastguard Worker input_pre_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = pre_padding(i); 125*4bdc9457SAndroid Build Coastguard Worker input_post_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = post_padding(i); 126*4bdc9457SAndroid Build Coastguard Worker output_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = 
output_dim(i); 127*4bdc9457SAndroid Build Coastguard Worker } 128*4bdc9457SAndroid Build Coastguard Worker 129*4bdc9457SAndroid Build Coastguard Worker // Compute generalized strides. 130*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_strides; 131*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides; 132*4bdc9457SAndroid Build Coastguard Worker size_t input_stride = 1, output_stride = 1; 133*4bdc9457SAndroid Build Coastguard Worker for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) { 134*4bdc9457SAndroid Build Coastguard Worker input_strides[i - 1] = input_stride; 135*4bdc9457SAndroid Build Coastguard Worker output_strides[i - 1] = output_stride; 136*4bdc9457SAndroid Build Coastguard Worker input_stride *= input_dims[i - 1]; 137*4bdc9457SAndroid Build Coastguard Worker output_stride *= output_dims[i - 1]; 138*4bdc9457SAndroid Build Coastguard Worker } 139*4bdc9457SAndroid Build Coastguard Worker 140*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + num_input_elements()); 141*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(num_output_elements()); 142*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output_ref(num_output_elements()); 143*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 144*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 145*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT32_C(0xAA)); 146*4bdc9457SAndroid Build Coastguard Worker const uint8_t padding_value = u8dist(rng); 147*4bdc9457SAndroid Build Coastguard Worker 148*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 
149*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), padding_value); 150*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < input_dims[0]; i++) { 151*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < input_dims[1]; j++) { 152*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < input_dims[2]; k++) { 153*4bdc9457SAndroid Build Coastguard Worker for (size_t l = 0; l < input_dims[3]; l++) { 154*4bdc9457SAndroid Build Coastguard Worker for (size_t m = 0; m < input_dims[4]; m++) { 155*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < input_dims[5]; n++) { 156*4bdc9457SAndroid Build Coastguard Worker const size_t output_index = 157*4bdc9457SAndroid Build Coastguard Worker (i + input_pre_paddings[0]) * output_strides[0] + 158*4bdc9457SAndroid Build Coastguard Worker (j + input_pre_paddings[1]) * output_strides[1] + 159*4bdc9457SAndroid Build Coastguard Worker (k + input_pre_paddings[2]) * output_strides[2] + 160*4bdc9457SAndroid Build Coastguard Worker (l + input_pre_paddings[3]) * output_strides[3] + 161*4bdc9457SAndroid Build Coastguard Worker (m + input_pre_paddings[4]) * output_strides[4] + 162*4bdc9457SAndroid Build Coastguard Worker (n + input_pre_paddings[5]) * output_strides[5]; 163*4bdc9457SAndroid Build Coastguard Worker const size_t input_index = 164*4bdc9457SAndroid Build Coastguard Worker i * input_strides[0] + j * input_strides[1] + k * input_strides[2] + 165*4bdc9457SAndroid Build Coastguard Worker l * input_strides[3] + m * input_strides[4] + n * input_strides[5]; 166*4bdc9457SAndroid Build Coastguard Worker output_ref[output_index] = input[input_index]; 167*4bdc9457SAndroid Build Coastguard Worker } 168*4bdc9457SAndroid Build Coastguard Worker } 169*4bdc9457SAndroid Build Coastguard Worker } 170*4bdc9457SAndroid Build Coastguard Worker } 171*4bdc9457SAndroid Build Coastguard Worker } 172*4bdc9457SAndroid Build Coastguard Worker } 173*4bdc9457SAndroid Build 
Coastguard Worker 174*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy a binary elementwise operator. 175*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 176*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t pad_op = nullptr; 177*4bdc9457SAndroid Build Coastguard Worker 178*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 179*4bdc9457SAndroid Build Coastguard Worker xnn_create_constant_pad_nd_x8( 180*4bdc9457SAndroid Build Coastguard Worker &padding_value, 0, &pad_op)); 181*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, pad_op); 182*4bdc9457SAndroid Build Coastguard Worker 183*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete pad_op. 184*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_pad_op(pad_op, xnn_delete_operator); 185*4bdc9457SAndroid Build Coastguard Worker 186*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 187*4bdc9457SAndroid Build Coastguard Worker xnn_setup_constant_pad_nd_x8( 188*4bdc9457SAndroid Build Coastguard Worker pad_op, 189*4bdc9457SAndroid Build Coastguard Worker num_dims(), 190*4bdc9457SAndroid Build Coastguard Worker input_shape().data(), pre_paddings().data(), post_paddings().data(), 191*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 192*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 193*4bdc9457SAndroid Build Coastguard Worker 194*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 195*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(pad_op, nullptr /* thread pool */)); 196*4bdc9457SAndroid Build Coastguard Worker 197*4bdc9457SAndroid Build Coastguard Worker // Verify results. 
198*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < output_dims[0]; i++) { 199*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < output_dims[1]; j++) { 200*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < output_dims[2]; k++) { 201*4bdc9457SAndroid Build Coastguard Worker for (size_t l = 0; l < output_dims[3]; l++) { 202*4bdc9457SAndroid Build Coastguard Worker for (size_t m = 0; m < output_dims[4]; m++) { 203*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < output_dims[5]; n++) { 204*4bdc9457SAndroid Build Coastguard Worker const size_t index = 205*4bdc9457SAndroid Build Coastguard Worker i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + 206*4bdc9457SAndroid Build Coastguard Worker l * output_strides[3] + m * output_strides[4] + n * output_strides[5]; 207*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output[index], output_ref[index]) 208*4bdc9457SAndroid Build Coastguard Worker << "(i, j, k, l, m, n) = (" 209*4bdc9457SAndroid Build Coastguard Worker << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")" 210*4bdc9457SAndroid Build Coastguard Worker << ", padding value = " << padding_value; 211*4bdc9457SAndroid Build Coastguard Worker } 212*4bdc9457SAndroid Build Coastguard Worker } 213*4bdc9457SAndroid Build Coastguard Worker } 214*4bdc9457SAndroid Build Coastguard Worker } 215*4bdc9457SAndroid Build Coastguard Worker } 216*4bdc9457SAndroid Build Coastguard Worker } 217*4bdc9457SAndroid Build Coastguard Worker } 218*4bdc9457SAndroid Build Coastguard Worker } 219*4bdc9457SAndroid Build Coastguard Worker TestX16()220*4bdc9457SAndroid Build Coastguard Worker void TestX16() const { 221*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(num_dims(), num_pre_paddings()); 222*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(num_dims(), num_post_paddings()); 223*4bdc9457SAndroid Build Coastguard Worker 224*4bdc9457SAndroid Build Coastguard Worker std::random_device 
random_device; 225*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 226*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint16_t> u16dist; 227*4bdc9457SAndroid Build Coastguard Worker 228*4bdc9457SAndroid Build Coastguard Worker // Compute generalized shapes. 229*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_dims; 230*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_pre_paddings; 231*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_post_paddings; 232*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims; 233*4bdc9457SAndroid Build Coastguard Worker std::fill(input_dims.begin(), input_dims.end(), 1); 234*4bdc9457SAndroid Build Coastguard Worker std::fill(input_pre_paddings.begin(), input_pre_paddings.end(), 0); 235*4bdc9457SAndroid Build Coastguard Worker std::fill(input_post_paddings.begin(), input_post_paddings.end(), 0); 236*4bdc9457SAndroid Build Coastguard Worker std::fill(output_dims.begin(), output_dims.end(), 1); 237*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < num_dims(); i++) { 238*4bdc9457SAndroid Build Coastguard Worker input_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = input_dim(i); 239*4bdc9457SAndroid Build Coastguard Worker input_pre_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = pre_padding(i); 240*4bdc9457SAndroid Build Coastguard Worker input_post_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = post_padding(i); 241*4bdc9457SAndroid Build Coastguard Worker output_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = output_dim(i); 242*4bdc9457SAndroid Build Coastguard Worker } 243*4bdc9457SAndroid Build Coastguard Worker 244*4bdc9457SAndroid Build Coastguard Worker // Compute generalized strides. 
245*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_strides; 246*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides; 247*4bdc9457SAndroid Build Coastguard Worker size_t input_stride = 1, output_stride = 1; 248*4bdc9457SAndroid Build Coastguard Worker for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) { 249*4bdc9457SAndroid Build Coastguard Worker input_strides[i - 1] = input_stride; 250*4bdc9457SAndroid Build Coastguard Worker output_strides[i - 1] = output_stride; 251*4bdc9457SAndroid Build Coastguard Worker input_stride *= input_dims[i - 1]; 252*4bdc9457SAndroid Build Coastguard Worker output_stride *= output_dims[i - 1]; 253*4bdc9457SAndroid Build Coastguard Worker } 254*4bdc9457SAndroid Build Coastguard Worker 255*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input_elements()); 256*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output(num_output_elements()); 257*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output_ref(num_output_elements()); 258*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 259*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u16dist(rng); }); 260*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0xDEAD)); 261*4bdc9457SAndroid Build Coastguard Worker const uint16_t padding_value = u16dist(rng); 262*4bdc9457SAndroid Build Coastguard Worker 263*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 
264*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), padding_value); 265*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < input_dims[0]; i++) { 266*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < input_dims[1]; j++) { 267*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < input_dims[2]; k++) { 268*4bdc9457SAndroid Build Coastguard Worker for (size_t l = 0; l < input_dims[3]; l++) { 269*4bdc9457SAndroid Build Coastguard Worker for (size_t m = 0; m < input_dims[4]; m++) { 270*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < input_dims[5]; n++) { 271*4bdc9457SAndroid Build Coastguard Worker const size_t output_index = 272*4bdc9457SAndroid Build Coastguard Worker (i + input_pre_paddings[0]) * output_strides[0] + 273*4bdc9457SAndroid Build Coastguard Worker (j + input_pre_paddings[1]) * output_strides[1] + 274*4bdc9457SAndroid Build Coastguard Worker (k + input_pre_paddings[2]) * output_strides[2] + 275*4bdc9457SAndroid Build Coastguard Worker (l + input_pre_paddings[3]) * output_strides[3] + 276*4bdc9457SAndroid Build Coastguard Worker (m + input_pre_paddings[4]) * output_strides[4] + 277*4bdc9457SAndroid Build Coastguard Worker (n + input_pre_paddings[5]) * output_strides[5]; 278*4bdc9457SAndroid Build Coastguard Worker const size_t input_index = 279*4bdc9457SAndroid Build Coastguard Worker i * input_strides[0] + j * input_strides[1] + k * input_strides[2] + 280*4bdc9457SAndroid Build Coastguard Worker l * input_strides[3] + m * input_strides[4] + n * input_strides[5]; 281*4bdc9457SAndroid Build Coastguard Worker output_ref[output_index] = input[input_index]; 282*4bdc9457SAndroid Build Coastguard Worker } 283*4bdc9457SAndroid Build Coastguard Worker } 284*4bdc9457SAndroid Build Coastguard Worker } 285*4bdc9457SAndroid Build Coastguard Worker } 286*4bdc9457SAndroid Build Coastguard Worker } 287*4bdc9457SAndroid Build Coastguard Worker } 288*4bdc9457SAndroid Build 
Coastguard Worker 289*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy a binary elementwise operator. 290*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 291*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t pad_op = nullptr; 292*4bdc9457SAndroid Build Coastguard Worker 293*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 294*4bdc9457SAndroid Build Coastguard Worker xnn_create_constant_pad_nd_x16( 295*4bdc9457SAndroid Build Coastguard Worker &padding_value, 0, &pad_op)); 296*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, pad_op); 297*4bdc9457SAndroid Build Coastguard Worker 298*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete pad_op. 299*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_pad_op(pad_op, xnn_delete_operator); 300*4bdc9457SAndroid Build Coastguard Worker 301*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 302*4bdc9457SAndroid Build Coastguard Worker xnn_setup_constant_pad_nd_x16( 303*4bdc9457SAndroid Build Coastguard Worker pad_op, 304*4bdc9457SAndroid Build Coastguard Worker num_dims(), 305*4bdc9457SAndroid Build Coastguard Worker input_shape().data(), pre_paddings().data(), post_paddings().data(), 306*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 307*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 308*4bdc9457SAndroid Build Coastguard Worker 309*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 310*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(pad_op, nullptr /* thread pool */)); 311*4bdc9457SAndroid Build Coastguard Worker 312*4bdc9457SAndroid Build Coastguard Worker // Verify results. 
313*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < output_dims[0]; i++) { 314*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < output_dims[1]; j++) { 315*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < output_dims[2]; k++) { 316*4bdc9457SAndroid Build Coastguard Worker for (size_t l = 0; l < output_dims[3]; l++) { 317*4bdc9457SAndroid Build Coastguard Worker for (size_t m = 0; m < output_dims[4]; m++) { 318*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < output_dims[5]; n++) { 319*4bdc9457SAndroid Build Coastguard Worker const size_t index = 320*4bdc9457SAndroid Build Coastguard Worker i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + 321*4bdc9457SAndroid Build Coastguard Worker l * output_strides[3] + m * output_strides[4] + n * output_strides[5]; 322*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output[index], output_ref[index]) 323*4bdc9457SAndroid Build Coastguard Worker << "(i, j, k, l, m, n) = (" 324*4bdc9457SAndroid Build Coastguard Worker << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")" 325*4bdc9457SAndroid Build Coastguard Worker << ", padding value = " << padding_value; 326*4bdc9457SAndroid Build Coastguard Worker } 327*4bdc9457SAndroid Build Coastguard Worker } 328*4bdc9457SAndroid Build Coastguard Worker } 329*4bdc9457SAndroid Build Coastguard Worker } 330*4bdc9457SAndroid Build Coastguard Worker } 331*4bdc9457SAndroid Build Coastguard Worker } 332*4bdc9457SAndroid Build Coastguard Worker } 333*4bdc9457SAndroid Build Coastguard Worker } 334*4bdc9457SAndroid Build Coastguard Worker TestX32()335*4bdc9457SAndroid Build Coastguard Worker void TestX32() const { 336*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(num_dims(), num_pre_paddings()); 337*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(num_dims(), num_post_paddings()); 338*4bdc9457SAndroid Build Coastguard Worker 339*4bdc9457SAndroid Build Coastguard Worker std::random_device 
random_device; 340*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 341*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint32_t> u32dist; 342*4bdc9457SAndroid Build Coastguard Worker 343*4bdc9457SAndroid Build Coastguard Worker // Compute generalized shapes. 344*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_dims; 345*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_pre_paddings; 346*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_post_paddings; 347*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims; 348*4bdc9457SAndroid Build Coastguard Worker std::fill(input_dims.begin(), input_dims.end(), 1); 349*4bdc9457SAndroid Build Coastguard Worker std::fill(input_pre_paddings.begin(), input_pre_paddings.end(), 0); 350*4bdc9457SAndroid Build Coastguard Worker std::fill(input_post_paddings.begin(), input_post_paddings.end(), 0); 351*4bdc9457SAndroid Build Coastguard Worker std::fill(output_dims.begin(), output_dims.end(), 1); 352*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < num_dims(); i++) { 353*4bdc9457SAndroid Build Coastguard Worker input_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = input_dim(i); 354*4bdc9457SAndroid Build Coastguard Worker input_pre_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = pre_padding(i); 355*4bdc9457SAndroid Build Coastguard Worker input_post_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = post_padding(i); 356*4bdc9457SAndroid Build Coastguard Worker output_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = output_dim(i); 357*4bdc9457SAndroid Build Coastguard Worker } 358*4bdc9457SAndroid Build Coastguard Worker 359*4bdc9457SAndroid Build Coastguard Worker // Compute generalized strides. 
360*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> input_strides; 361*4bdc9457SAndroid Build Coastguard Worker std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides; 362*4bdc9457SAndroid Build Coastguard Worker size_t input_stride = 1, output_stride = 1; 363*4bdc9457SAndroid Build Coastguard Worker for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) { 364*4bdc9457SAndroid Build Coastguard Worker input_strides[i - 1] = input_stride; 365*4bdc9457SAndroid Build Coastguard Worker output_strides[i - 1] = output_stride; 366*4bdc9457SAndroid Build Coastguard Worker input_stride *= input_dims[i - 1]; 367*4bdc9457SAndroid Build Coastguard Worker output_stride *= output_dims[i - 1]; 368*4bdc9457SAndroid Build Coastguard Worker } 369*4bdc9457SAndroid Build Coastguard Worker 370*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> input(XNN_EXTRA_BYTES / sizeof(uint32_t) + num_input_elements()); 371*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> output(num_output_elements()); 372*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> output_ref(num_output_elements()); 373*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 374*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u32dist(rng); }); 375*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF)); 376*4bdc9457SAndroid Build Coastguard Worker const uint32_t padding_value = u32dist(rng); 377*4bdc9457SAndroid Build Coastguard Worker 378*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 
379*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), padding_value); 380*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < input_dims[0]; i++) { 381*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < input_dims[1]; j++) { 382*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < input_dims[2]; k++) { 383*4bdc9457SAndroid Build Coastguard Worker for (size_t l = 0; l < input_dims[3]; l++) { 384*4bdc9457SAndroid Build Coastguard Worker for (size_t m = 0; m < input_dims[4]; m++) { 385*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < input_dims[5]; n++) { 386*4bdc9457SAndroid Build Coastguard Worker const size_t output_index = 387*4bdc9457SAndroid Build Coastguard Worker (i + input_pre_paddings[0]) * output_strides[0] + 388*4bdc9457SAndroid Build Coastguard Worker (j + input_pre_paddings[1]) * output_strides[1] + 389*4bdc9457SAndroid Build Coastguard Worker (k + input_pre_paddings[2]) * output_strides[2] + 390*4bdc9457SAndroid Build Coastguard Worker (l + input_pre_paddings[3]) * output_strides[3] + 391*4bdc9457SAndroid Build Coastguard Worker (m + input_pre_paddings[4]) * output_strides[4] + 392*4bdc9457SAndroid Build Coastguard Worker (n + input_pre_paddings[5]) * output_strides[5]; 393*4bdc9457SAndroid Build Coastguard Worker const size_t input_index = 394*4bdc9457SAndroid Build Coastguard Worker i * input_strides[0] + j * input_strides[1] + k * input_strides[2] + 395*4bdc9457SAndroid Build Coastguard Worker l * input_strides[3] + m * input_strides[4] + n * input_strides[5]; 396*4bdc9457SAndroid Build Coastguard Worker output_ref[output_index] = input[input_index]; 397*4bdc9457SAndroid Build Coastguard Worker } 398*4bdc9457SAndroid Build Coastguard Worker } 399*4bdc9457SAndroid Build Coastguard Worker } 400*4bdc9457SAndroid Build Coastguard Worker } 401*4bdc9457SAndroid Build Coastguard Worker } 402*4bdc9457SAndroid Build Coastguard Worker } 403*4bdc9457SAndroid Build 
Coastguard Worker 404*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy a constant pad operator. 405*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 406*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t pad_op = nullptr; 407*4bdc9457SAndroid Build Coastguard Worker 408*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 409*4bdc9457SAndroid Build Coastguard Worker xnn_create_constant_pad_nd_x32( 410*4bdc9457SAndroid Build Coastguard Worker &padding_value, 0, &pad_op)); 411*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, pad_op); 412*4bdc9457SAndroid Build Coastguard Worker 413*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete pad_op. 414*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_pad_op(pad_op, xnn_delete_operator); 415*4bdc9457SAndroid Build Coastguard Worker 416*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 417*4bdc9457SAndroid Build Coastguard Worker xnn_setup_constant_pad_nd_x32( 418*4bdc9457SAndroid Build Coastguard Worker pad_op, 419*4bdc9457SAndroid Build Coastguard Worker num_dims(), 420*4bdc9457SAndroid Build Coastguard Worker input_shape().data(), pre_paddings().data(), post_paddings().data(), 421*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 422*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 423*4bdc9457SAndroid Build Coastguard Worker 424*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 425*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(pad_op, nullptr /* thread pool */)); 426*4bdc9457SAndroid Build Coastguard Worker 427*4bdc9457SAndroid Build Coastguard Worker // Verify results. 
428*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < output_dims[0]; i++) { 429*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < output_dims[1]; j++) { 430*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < output_dims[2]; k++) { 431*4bdc9457SAndroid Build Coastguard Worker for (size_t l = 0; l < output_dims[3]; l++) { 432*4bdc9457SAndroid Build Coastguard Worker for (size_t m = 0; m < output_dims[4]; m++) { 433*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < output_dims[5]; n++) { 434*4bdc9457SAndroid Build Coastguard Worker const size_t index = 435*4bdc9457SAndroid Build Coastguard Worker i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + 436*4bdc9457SAndroid Build Coastguard Worker l * output_strides[3] + m * output_strides[4] + n * output_strides[5]; 437*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output[index], output_ref[index]) 438*4bdc9457SAndroid Build Coastguard Worker << "(i, j, k, l, m, n) = (" 439*4bdc9457SAndroid Build Coastguard Worker << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")" 440*4bdc9457SAndroid Build Coastguard Worker << ", padding value = " << padding_value; 441*4bdc9457SAndroid Build Coastguard Worker } 442*4bdc9457SAndroid Build Coastguard Worker } 443*4bdc9457SAndroid Build Coastguard Worker } 444*4bdc9457SAndroid Build Coastguard Worker } 445*4bdc9457SAndroid Build Coastguard Worker } 446*4bdc9457SAndroid Build Coastguard Worker } 447*4bdc9457SAndroid Build Coastguard Worker } 448*4bdc9457SAndroid Build Coastguard Worker } 449*4bdc9457SAndroid Build Coastguard Worker 450*4bdc9457SAndroid Build Coastguard Worker private: 451*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> input_shape_; 452*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> pre_paddings_; 453*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> post_paddings_; 454*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{3}; 
455*4bdc9457SAndroid Build Coastguard Worker }; 456