1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #include <array> 7*4bdc9457SAndroid Build Coastguard Worker #include <cstdint> 8*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 9*4bdc9457SAndroid Build Coastguard Worker #include <limits> 10*4bdc9457SAndroid Build Coastguard Worker #include <memory> 11*4bdc9457SAndroid Build Coastguard Worker #include <numeric> 12*4bdc9457SAndroid Build Coastguard Worker #include <random> 13*4bdc9457SAndroid Build Coastguard Worker 14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/node-type.h> 16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h> 17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/requantization.h> 18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h> 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker template <typename InputType, typename OutputType = InputType, size_t min_dim = 0> class UnaryTest : public ::testing::Test { 23*4bdc9457SAndroid Build Coastguard Worker protected: UnaryTest()24*4bdc9457SAndroid Build Coastguard Worker UnaryTest() 25*4bdc9457SAndroid Build Coastguard Worker { 26*4bdc9457SAndroid Build Coastguard Worker random_device = std::unique_ptr<std::random_device>(new std::random_device()); 27*4bdc9457SAndroid Build Coastguard Worker rng = std::mt19937((*random_device)()); 28*4bdc9457SAndroid Build Coastguard Worker shape_dist = std::uniform_int_distribution<size_t>(min_dim, XNN_MAX_TENSOR_DIMS); 29*4bdc9457SAndroid Build Coastguard Worker dim_dist = std::uniform_int_distribution<size_t>(1, 9); 30*4bdc9457SAndroid Build Coastguard Worker i8dist = 31*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 32*4bdc9457SAndroid Build Coastguard Worker u8dist = 33*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 34*4bdc9457SAndroid Build Coastguard Worker u32dist = std::uniform_int_distribution<uint32_t>(); 35*4bdc9457SAndroid Build Coastguard Worker scale_dist = std::uniform_real_distribution<float>(0.1f, 10.0f); 36*4bdc9457SAndroid Build Coastguard Worker f32dist = std::uniform_real_distribution<float>(0.01f, 1.0f); 37*4bdc9457SAndroid Build Coastguard Worker dims = RandomShape(); 38*4bdc9457SAndroid Build Coastguard Worker channels = dims.empty() ? 1 : dims.back(); 39*4bdc9457SAndroid Build Coastguard Worker xnn_shape shape = { 40*4bdc9457SAndroid Build Coastguard Worker .num_dims = dims.size(), 41*4bdc9457SAndroid Build Coastguard Worker }; 42*4bdc9457SAndroid Build Coastguard Worker memcpy(shape.dim, dims.data(), dims.size() * sizeof(size_t)); 43*4bdc9457SAndroid Build Coastguard Worker batch_size = xnn_shape_multiply_non_channel_dims(&shape); 44*4bdc9457SAndroid Build Coastguard Worker num_output_elements = batch_size * channels; 45*4bdc9457SAndroid Build Coastguard Worker scale = scale_dist(rng); 46*4bdc9457SAndroid Build Coastguard Worker signed_zero_point = i8dist(rng); 47*4bdc9457SAndroid Build Coastguard Worker unsigned_zero_point = u8dist(rng); 48*4bdc9457SAndroid Build Coastguard Worker 49*4bdc9457SAndroid Build Coastguard Worker input = std::vector<InputType>(num_output_elements + XNN_EXTRA_BYTES / sizeof(InputType)); 50*4bdc9457SAndroid Build Coastguard Worker operator_output = std::vector<OutputType>(num_output_elements); 51*4bdc9457SAndroid Build Coastguard Worker subgraph_output = std::vector<OutputType>(num_output_elements); 52*4bdc9457SAndroid Build Coastguard Worker } 53*4bdc9457SAndroid Build Coastguard Worker RandomShape()54*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> RandomShape() { 55*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> dims(shape_dist(rng)); 56*4bdc9457SAndroid Build Coastguard Worker std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); }); 57*4bdc9457SAndroid Build Coastguard Worker return dims; 58*4bdc9457SAndroid Build Coastguard Worker } 59*4bdc9457SAndroid Build Coastguard Worker NumElements(std::vector<size_t> & dims)60*4bdc9457SAndroid Build Coastguard Worker size_t NumElements(std::vector<size_t>& dims) 61*4bdc9457SAndroid Build Coastguard Worker { 62*4bdc9457SAndroid Build Coastguard Worker return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>()); 63*4bdc9457SAndroid Build Coastguard Worker } 64*4bdc9457SAndroid Build Coastguard Worker 65*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<std::random_device> random_device; 66*4bdc9457SAndroid Build Coastguard Worker std::mt19937 rng; 67*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<size_t> shape_dist; 68*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<size_t> dim_dist; 69*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> scale_dist; 70*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist; 71*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist; 72*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint32_t> u32dist; 73*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 74*4bdc9457SAndroid Build Coastguard Worker 75*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> dims; 76*4bdc9457SAndroid Build Coastguard Worker 77*4bdc9457SAndroid Build Coastguard Worker uint32_t input_id; 78*4bdc9457SAndroid Build Coastguard Worker uint32_t output_id; 79*4bdc9457SAndroid Build Coastguard Worker 80*4bdc9457SAndroid Build Coastguard Worker size_t channels; 81*4bdc9457SAndroid Build Coastguard Worker size_t batch_size; 82*4bdc9457SAndroid Build Coastguard Worker size_t num_output_elements; 83*4bdc9457SAndroid Build Coastguard Worker float scale; 84*4bdc9457SAndroid Build Coastguard Worker int32_t signed_zero_point; 85*4bdc9457SAndroid Build Coastguard Worker int32_t unsigned_zero_point; 86*4bdc9457SAndroid Build Coastguard Worker 87*4bdc9457SAndroid Build Coastguard Worker std::vector<InputType> input; 88*4bdc9457SAndroid Build Coastguard Worker std::vector<OutputType> operator_output; 89*4bdc9457SAndroid Build Coastguard Worker std::vector<OutputType> subgraph_output; 90*4bdc9457SAndroid Build Coastguard Worker }; 91