xref: /aosp_15_r20/external/XNNPACK/test/subgraph-unary-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>
13*4bdc9457SAndroid Build Coastguard Worker 
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/node-type.h>
16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/requantization.h>
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
19*4bdc9457SAndroid Build Coastguard Worker 
20*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker template <typename InputType, typename OutputType = InputType, size_t min_dim = 0> class UnaryTest : public ::testing::Test {
23*4bdc9457SAndroid Build Coastguard Worker protected:
UnaryTest()24*4bdc9457SAndroid Build Coastguard Worker   UnaryTest()
25*4bdc9457SAndroid Build Coastguard Worker   {
26*4bdc9457SAndroid Build Coastguard Worker     random_device = std::unique_ptr<std::random_device>(new std::random_device());
27*4bdc9457SAndroid Build Coastguard Worker     rng = std::mt19937((*random_device)());
28*4bdc9457SAndroid Build Coastguard Worker     shape_dist = std::uniform_int_distribution<size_t>(min_dim, XNN_MAX_TENSOR_DIMS);
29*4bdc9457SAndroid Build Coastguard Worker     dim_dist = std::uniform_int_distribution<size_t>(1, 9);
30*4bdc9457SAndroid Build Coastguard Worker     i8dist =
31*4bdc9457SAndroid Build Coastguard Worker       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
32*4bdc9457SAndroid Build Coastguard Worker     u8dist =
33*4bdc9457SAndroid Build Coastguard Worker       std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
34*4bdc9457SAndroid Build Coastguard Worker     u32dist = std::uniform_int_distribution<uint32_t>();
35*4bdc9457SAndroid Build Coastguard Worker     scale_dist = std::uniform_real_distribution<float>(0.1f, 10.0f);
36*4bdc9457SAndroid Build Coastguard Worker     f32dist = std::uniform_real_distribution<float>(0.01f, 1.0f);
37*4bdc9457SAndroid Build Coastguard Worker     dims = RandomShape();
38*4bdc9457SAndroid Build Coastguard Worker     channels = dims.empty() ? 1 : dims.back();
39*4bdc9457SAndroid Build Coastguard Worker     xnn_shape shape = {
40*4bdc9457SAndroid Build Coastguard Worker       .num_dims = dims.size(),
41*4bdc9457SAndroid Build Coastguard Worker     };
42*4bdc9457SAndroid Build Coastguard Worker     memcpy(shape.dim, dims.data(), dims.size() * sizeof(size_t));
43*4bdc9457SAndroid Build Coastguard Worker     batch_size = xnn_shape_multiply_non_channel_dims(&shape);
44*4bdc9457SAndroid Build Coastguard Worker     num_output_elements = batch_size * channels;
45*4bdc9457SAndroid Build Coastguard Worker     scale = scale_dist(rng);
46*4bdc9457SAndroid Build Coastguard Worker     signed_zero_point = i8dist(rng);
47*4bdc9457SAndroid Build Coastguard Worker     unsigned_zero_point = u8dist(rng);
48*4bdc9457SAndroid Build Coastguard Worker 
49*4bdc9457SAndroid Build Coastguard Worker     input = std::vector<InputType>(num_output_elements + XNN_EXTRA_BYTES / sizeof(InputType));
50*4bdc9457SAndroid Build Coastguard Worker     operator_output = std::vector<OutputType>(num_output_elements);
51*4bdc9457SAndroid Build Coastguard Worker     subgraph_output = std::vector<OutputType>(num_output_elements);
52*4bdc9457SAndroid Build Coastguard Worker   }
53*4bdc9457SAndroid Build Coastguard Worker 
RandomShape()54*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> RandomShape() {
55*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> dims(shape_dist(rng));
56*4bdc9457SAndroid Build Coastguard Worker     std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
57*4bdc9457SAndroid Build Coastguard Worker     return dims;
58*4bdc9457SAndroid Build Coastguard Worker   }
59*4bdc9457SAndroid Build Coastguard Worker 
NumElements(std::vector<size_t> & dims)60*4bdc9457SAndroid Build Coastguard Worker   size_t NumElements(std::vector<size_t>& dims)
61*4bdc9457SAndroid Build Coastguard Worker   {
62*4bdc9457SAndroid Build Coastguard Worker     return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
63*4bdc9457SAndroid Build Coastguard Worker   }
64*4bdc9457SAndroid Build Coastguard Worker 
65*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<std::random_device> random_device;
66*4bdc9457SAndroid Build Coastguard Worker   std::mt19937 rng;
67*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<size_t> shape_dist;
68*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<size_t> dim_dist;
69*4bdc9457SAndroid Build Coastguard Worker   std::uniform_real_distribution<float> scale_dist;
70*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<int32_t> i8dist;
71*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<int32_t> u8dist;
72*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<uint32_t> u32dist;
73*4bdc9457SAndroid Build Coastguard Worker   std::uniform_real_distribution<float> f32dist;
74*4bdc9457SAndroid Build Coastguard Worker 
75*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> dims;
76*4bdc9457SAndroid Build Coastguard Worker 
77*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_id;
78*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id;
79*4bdc9457SAndroid Build Coastguard Worker 
80*4bdc9457SAndroid Build Coastguard Worker   size_t channels;
81*4bdc9457SAndroid Build Coastguard Worker   size_t batch_size;
82*4bdc9457SAndroid Build Coastguard Worker   size_t num_output_elements;
83*4bdc9457SAndroid Build Coastguard Worker   float scale;
84*4bdc9457SAndroid Build Coastguard Worker   int32_t signed_zero_point;
85*4bdc9457SAndroid Build Coastguard Worker   int32_t unsigned_zero_point;
86*4bdc9457SAndroid Build Coastguard Worker 
87*4bdc9457SAndroid Build Coastguard Worker   std::vector<InputType> input;
88*4bdc9457SAndroid Build Coastguard Worker   std::vector<OutputType> operator_output;
89*4bdc9457SAndroid Build Coastguard Worker   std::vector<OutputType> subgraph_output;
90*4bdc9457SAndroid Build Coastguard Worker };
91