xref: /aosp_15_r20/external/XNNPACK/test/subgraph-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker namespace xnnpack {
25*4bdc9457SAndroid Build Coastguard Worker 
// Selects how SubgraphTester::AddStaticTensorF32 fills a static tensor.
enum class TensorType {
  kDense,   // Every element is drawn from the random distribution.
  kSparse,  // Only about one element in ten is non-zero.
};
30*4bdc9457SAndroid Build Coastguard Worker 
// Implicit input padding on each side of a 2D operator.
// Member order (top, right, bottom, left) matters for aggregate
// initialization, e.g. Padding{1, 2, 3, 4}.
struct Padding {
  uint32_t top;
  uint32_t right;
  uint32_t bottom;
  uint32_t left;
};
37*4bdc9457SAndroid Build Coastguard Worker 
// A (height, width) pair. Member order matters for aggregate
// initialization, e.g. HeightWidth{3, 5}.
struct HeightWidth {
  uint32_t height;
  uint32_t width;
};

// Role-specific names for the same (height, width) pair, so that the
// parameter structs below read naturally at the point of use.
using Kernel = HeightWidth;
using Subsampling = HeightWidth;
using Dilation = HeightWidth;
using Upsampling = HeightWidth;
using Adjustment = HeightWidth;
48*4bdc9457SAndroid Build Coastguard Worker 
49*4bdc9457SAndroid Build Coastguard Worker struct ConvolutionParams {
50*4bdc9457SAndroid Build Coastguard Worker   Padding padding;
51*4bdc9457SAndroid Build Coastguard Worker   Kernel kernel;
52*4bdc9457SAndroid Build Coastguard Worker   Subsampling subsampling;
53*4bdc9457SAndroid Build Coastguard Worker   Dilation dilation;
54*4bdc9457SAndroid Build Coastguard Worker   uint32_t groups;
55*4bdc9457SAndroid Build Coastguard Worker   uint32_t group_input_channels;
56*4bdc9457SAndroid Build Coastguard Worker   uint32_t group_output_channels;
57*4bdc9457SAndroid Build Coastguard Worker };
58*4bdc9457SAndroid Build Coastguard Worker 
59*4bdc9457SAndroid Build Coastguard Worker struct DeconvolutionParams {
60*4bdc9457SAndroid Build Coastguard Worker   Padding padding;
61*4bdc9457SAndroid Build Coastguard Worker   Adjustment adjustment;
62*4bdc9457SAndroid Build Coastguard Worker   Kernel kernel;
63*4bdc9457SAndroid Build Coastguard Worker   Upsampling upsampling;
64*4bdc9457SAndroid Build Coastguard Worker   Dilation dilation;
65*4bdc9457SAndroid Build Coastguard Worker   uint32_t groups;
66*4bdc9457SAndroid Build Coastguard Worker   uint32_t group_input_channels;
67*4bdc9457SAndroid Build Coastguard Worker   uint32_t group_output_channels;
68*4bdc9457SAndroid Build Coastguard Worker };
69*4bdc9457SAndroid Build Coastguard Worker 
70*4bdc9457SAndroid Build Coastguard Worker struct DepthwiseConvolutionParams {
71*4bdc9457SAndroid Build Coastguard Worker   Padding padding;
72*4bdc9457SAndroid Build Coastguard Worker   Kernel kernel;
73*4bdc9457SAndroid Build Coastguard Worker   Subsampling subsampling;
74*4bdc9457SAndroid Build Coastguard Worker   Dilation dilation;
75*4bdc9457SAndroid Build Coastguard Worker   uint32_t depth_multiplier;
76*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_channels;
77*4bdc9457SAndroid Build Coastguard Worker };
78*4bdc9457SAndroid Build Coastguard Worker 
79*4bdc9457SAndroid Build Coastguard Worker class SubgraphTester {
80*4bdc9457SAndroid Build Coastguard Worker  public:
SubgraphTester(uint32_t external_value_ids)81*4bdc9457SAndroid Build Coastguard Worker   explicit SubgraphTester(uint32_t external_value_ids) {
82*4bdc9457SAndroid Build Coastguard Worker     xnn_status status = xnn_initialize(nullptr);
83*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
84*4bdc9457SAndroid Build Coastguard Worker 
85*4bdc9457SAndroid Build Coastguard Worker     xnn_subgraph_t subgraph_ptr = nullptr;
86*4bdc9457SAndroid Build Coastguard Worker     status = xnn_create_subgraph(external_value_ids, 0 /* flags */, &subgraph_ptr);
87*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
88*4bdc9457SAndroid Build Coastguard Worker     subgraph_.reset(subgraph_ptr);
89*4bdc9457SAndroid Build Coastguard Worker 
90*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
91*4bdc9457SAndroid Build Coastguard Worker     rng_ = std::mt19937(random_device());
92*4bdc9457SAndroid Build Coastguard Worker   }
93*4bdc9457SAndroid Build Coastguard Worker 
94*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddDynamicTensorF32(const std::vector<size_t>& dims,
95*4bdc9457SAndroid Build Coastguard Worker                                    uint32_t external_id,
96*4bdc9457SAndroid Build Coastguard Worker                                    uint32_t flags = 0) {
97*4bdc9457SAndroid Build Coastguard Worker     uint32_t id_out = 0;
98*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status =
99*4bdc9457SAndroid Build Coastguard Worker         xnn_define_tensor_value(subgraph_.get(), xnn_datatype_fp32, dims.size(),
100*4bdc9457SAndroid Build Coastguard Worker                                 dims.data(), nullptr, external_id, flags, &id_out);
101*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
102*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(id_out, external_id);
103*4bdc9457SAndroid Build Coastguard Worker 
104*4bdc9457SAndroid Build Coastguard Worker     return *this;
105*4bdc9457SAndroid Build Coastguard Worker   }
106*4bdc9457SAndroid Build Coastguard Worker 
107*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddStaticTensorF32(const std::vector<size_t>& dims,
108*4bdc9457SAndroid Build Coastguard Worker                                             TensorType tensor_type,
109*4bdc9457SAndroid Build Coastguard Worker                                             uint32_t external_id,
110*4bdc9457SAndroid Build Coastguard Worker                                             uint32_t flags = 0) {
111*4bdc9457SAndroid Build Coastguard Worker     const size_t num_elements = NumElements(dims);
112*4bdc9457SAndroid Build Coastguard Worker     static_data_.emplace_back(num_elements * sizeof(float));
113*4bdc9457SAndroid Build Coastguard Worker     float* data = reinterpret_cast<float*>(static_data_.back().data());
114*4bdc9457SAndroid Build Coastguard Worker 
115*4bdc9457SAndroid Build Coastguard Worker     if (tensor_type == TensorType::kDense) {
116*4bdc9457SAndroid Build Coastguard Worker       std::generate(data, data + num_elements, [&]() { return f32dist(rng_); });
117*4bdc9457SAndroid Build Coastguard Worker     } else {
118*4bdc9457SAndroid Build Coastguard Worker       // Create tensor with 90% sparsity in two steps:
119*4bdc9457SAndroid Build Coastguard Worker       // 1. Generate non-zero elements in the beginning of the vector
120*4bdc9457SAndroid Build Coastguard Worker       // 2. Randomize positions of non-zero elements
121*4bdc9457SAndroid Build Coastguard Worker       const size_t num_nonzero_elements = num_elements / 10;
122*4bdc9457SAndroid Build Coastguard Worker       std::generate(data, data + num_nonzero_elements, [&]() { return f32dist(rng_); });
123*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(data, data + num_elements, rng_);
124*4bdc9457SAndroid Build Coastguard Worker     }
125*4bdc9457SAndroid Build Coastguard Worker     uint32_t id_out;
126*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status =
127*4bdc9457SAndroid Build Coastguard Worker         xnn_define_tensor_value(subgraph_.get(), xnn_datatype_fp32, dims.size(),
128*4bdc9457SAndroid Build Coastguard Worker                                 dims.data(), data, external_id, flags, &id_out);
129*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
130*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(id_out, external_id);
131*4bdc9457SAndroid Build Coastguard Worker     return *this;
132*4bdc9457SAndroid Build Coastguard Worker   }
133*4bdc9457SAndroid Build Coastguard Worker 
134*4bdc9457SAndroid Build Coastguard Worker 
AddInputTensorF32(const std::vector<size_t> & dims,uint32_t external_id)135*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddInputTensorF32(const std::vector<size_t>& dims, uint32_t external_id) {
136*4bdc9457SAndroid Build Coastguard Worker     AddDynamicTensorF32(dims, external_id, XNN_VALUE_FLAG_EXTERNAL_INPUT);
137*4bdc9457SAndroid Build Coastguard Worker     size_t num_elements = NumElements(dims);
138*4bdc9457SAndroid Build Coastguard Worker     auto input = std::vector<char>(num_elements * sizeof(float) + XNN_EXTRA_BYTES * sizeof(char));
139*4bdc9457SAndroid Build Coastguard Worker     float* data = reinterpret_cast<float*>(input.data());
140*4bdc9457SAndroid Build Coastguard Worker     std::generate(data, data + num_elements, [&]() { return f32dist(rng_); });
141*4bdc9457SAndroid Build Coastguard Worker     auto it = external_tensors_.insert({external_id, input});
142*4bdc9457SAndroid Build Coastguard Worker     EXPECT_TRUE(it.second);
143*4bdc9457SAndroid Build Coastguard Worker     return *this;
144*4bdc9457SAndroid Build Coastguard Worker   }
145*4bdc9457SAndroid Build Coastguard Worker 
AddOutputTensorF32(const std::vector<size_t> & dims,uint32_t external_id)146*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddOutputTensorF32(const std::vector<size_t>& dims, uint32_t external_id) {
147*4bdc9457SAndroid Build Coastguard Worker     output_id_ = external_id;
148*4bdc9457SAndroid Build Coastguard Worker     AddDynamicTensorF32(dims, external_id, XNN_VALUE_FLAG_EXTERNAL_OUTPUT);
149*4bdc9457SAndroid Build Coastguard Worker     size_t num_elements = NumElements(dims);
150*4bdc9457SAndroid Build Coastguard Worker     auto output = std::vector<char>(num_elements * sizeof(float));
151*4bdc9457SAndroid Build Coastguard Worker     float* data = reinterpret_cast<float*>(output.data());
152*4bdc9457SAndroid Build Coastguard Worker     std::fill(data, data + num_elements, std::nanf(""));
153*4bdc9457SAndroid Build Coastguard Worker     auto it = external_tensors_.insert({external_id, output});
154*4bdc9457SAndroid Build Coastguard Worker     EXPECT_TRUE(it.second);
155*4bdc9457SAndroid Build Coastguard Worker     return *this;
156*4bdc9457SAndroid Build Coastguard Worker   }
157*4bdc9457SAndroid Build Coastguard Worker 
AddConstantPad(const size_t * pre_paddings,const size_t * post_paddings,float padding_value,uint32_t input_id,uint32_t output_id)158*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddConstantPad(
159*4bdc9457SAndroid Build Coastguard Worker       const size_t *pre_paddings, const size_t *post_paddings,
160*4bdc9457SAndroid Build Coastguard Worker       float padding_value, uint32_t input_id, uint32_t output_id) {
161*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status = xnn_define_static_constant_pad(
162*4bdc9457SAndroid Build Coastguard Worker         subgraph_.get(), pre_paddings, post_paddings, padding_value, input_id,
163*4bdc9457SAndroid Build Coastguard Worker         output_id, 0 /* flags */);
164*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
165*4bdc9457SAndroid Build Coastguard Worker     return *this;
166*4bdc9457SAndroid Build Coastguard Worker   }
167*4bdc9457SAndroid Build Coastguard Worker 
AddConvolution2D(ConvolutionParams params,uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id)168*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddConvolution2D(
169*4bdc9457SAndroid Build Coastguard Worker       ConvolutionParams params,
170*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_id, uint32_t filter_id, uint32_t bias_id,
171*4bdc9457SAndroid Build Coastguard Worker       uint32_t output_id) {
172*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status = xnn_define_convolution_2d(
173*4bdc9457SAndroid Build Coastguard Worker         subgraph_.get(), params.padding.top, params.padding.right,
174*4bdc9457SAndroid Build Coastguard Worker         params.padding.bottom, params.padding.left, params.kernel.height, params.kernel.width,
175*4bdc9457SAndroid Build Coastguard Worker         params.subsampling.height, params.subsampling.width, params.dilation.height, params.dilation.width,
176*4bdc9457SAndroid Build Coastguard Worker         params.groups, params.group_input_channels, params.group_output_channels,
177*4bdc9457SAndroid Build Coastguard Worker         -std::numeric_limits<float>::infinity(),
178*4bdc9457SAndroid Build Coastguard Worker         std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id,
179*4bdc9457SAndroid Build Coastguard Worker         output_id, 0 /* flags */);
180*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
181*4bdc9457SAndroid Build Coastguard Worker 
182*4bdc9457SAndroid Build Coastguard Worker     return *this;
183*4bdc9457SAndroid Build Coastguard Worker   }
184*4bdc9457SAndroid Build Coastguard Worker 
AddDepthwiseConvolution2D(DepthwiseConvolutionParams params,uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id)185*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddDepthwiseConvolution2D(
186*4bdc9457SAndroid Build Coastguard Worker       DepthwiseConvolutionParams params,
187*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_id, uint32_t filter_id, uint32_t bias_id, uint32_t output_id) {
188*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status = xnn_define_depthwise_convolution_2d(
189*4bdc9457SAndroid Build Coastguard Worker         subgraph_.get(), params.padding.top, params.padding.right,
190*4bdc9457SAndroid Build Coastguard Worker         params.padding.bottom, params.padding.left, params.kernel.height, params.kernel.width,
191*4bdc9457SAndroid Build Coastguard Worker         params.subsampling.height, params.subsampling.width, params.dilation.height, params.dilation.width,
192*4bdc9457SAndroid Build Coastguard Worker         params.depth_multiplier, params.input_channels,
193*4bdc9457SAndroid Build Coastguard Worker         -std::numeric_limits<float>::infinity(),
194*4bdc9457SAndroid Build Coastguard Worker         std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id,
195*4bdc9457SAndroid Build Coastguard Worker         output_id, 0 /* flags */);
196*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
197*4bdc9457SAndroid Build Coastguard Worker 
198*4bdc9457SAndroid Build Coastguard Worker     return *this;
199*4bdc9457SAndroid Build Coastguard Worker   }
200*4bdc9457SAndroid Build Coastguard Worker 
AddAddition(uint32_t input_id1,uint32_t input_id2,uint32_t output_id)201*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddAddition(uint32_t input_id1, uint32_t input_id2, uint32_t output_id) {
202*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status =
203*4bdc9457SAndroid Build Coastguard Worker         xnn_define_add2(subgraph_.get(), -std::numeric_limits<float>::infinity(),
204*4bdc9457SAndroid Build Coastguard Worker                         std::numeric_limits<float>::infinity(), input_id1,
205*4bdc9457SAndroid Build Coastguard Worker                         input_id2, output_id, 0 /* flags */);
206*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
207*4bdc9457SAndroid Build Coastguard Worker 
208*4bdc9457SAndroid Build Coastguard Worker     return *this;
209*4bdc9457SAndroid Build Coastguard Worker   }
210*4bdc9457SAndroid Build Coastguard Worker 
AddAveragePooling2D(uint32_t input_padding_top,uint32_t input_padding_right,uint32_t input_padding_bottom,uint32_t input_padding_left,uint32_t pooling_height,uint32_t pooling_width,uint32_t stride_height,uint32_t stride_width,uint32_t input_id,uint32_t output_id)211*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddAveragePooling2D(
212*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_padding_top, uint32_t input_padding_right,
213*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_padding_bottom, uint32_t input_padding_left,
214*4bdc9457SAndroid Build Coastguard Worker       uint32_t pooling_height, uint32_t pooling_width, uint32_t stride_height,
215*4bdc9457SAndroid Build Coastguard Worker       uint32_t stride_width, uint32_t input_id, uint32_t output_id) {
216*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status = xnn_define_average_pooling_2d(
217*4bdc9457SAndroid Build Coastguard Worker         subgraph_.get(), input_padding_top, input_padding_right,
218*4bdc9457SAndroid Build Coastguard Worker         input_padding_bottom, input_padding_left, pooling_height, pooling_width,
219*4bdc9457SAndroid Build Coastguard Worker         stride_height, stride_width, -std::numeric_limits<float>::infinity(),
220*4bdc9457SAndroid Build Coastguard Worker         std::numeric_limits<float>::infinity(), input_id, output_id,
221*4bdc9457SAndroid Build Coastguard Worker         0 /* flags */);
222*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
223*4bdc9457SAndroid Build Coastguard Worker 
224*4bdc9457SAndroid Build Coastguard Worker     return *this;
225*4bdc9457SAndroid Build Coastguard Worker   }
226*4bdc9457SAndroid Build Coastguard Worker 
AddClamp(float output_min,float output_max,uint32_t input_id,uint32_t output_id)227*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddClamp(float output_min, float output_max, uint32_t input_id, uint32_t output_id) {
228*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status =
229*4bdc9457SAndroid Build Coastguard Worker         xnn_define_clamp(subgraph_.get(), output_min, output_max, input_id, output_id, 0 /* flags */);
230*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
231*4bdc9457SAndroid Build Coastguard Worker 
232*4bdc9457SAndroid Build Coastguard Worker     return *this;
233*4bdc9457SAndroid Build Coastguard Worker   }
234*4bdc9457SAndroid Build Coastguard Worker 
AddDeconvolution2D(uint32_t input_padding_top,uint32_t input_padding_right,uint32_t input_padding_bottom,uint32_t input_padding_left,uint32_t adjustment_height,uint32_t adjustment_width,uint32_t kernel_height,uint32_t kernel_width,uint32_t upsampling_height,uint32_t upsampling_width,uint32_t dilation_height,uint32_t dilation_width,uint32_t groups,size_t group_input_channels,size_t group_output_channels,uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id)235*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddDeconvolution2D(
236*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_padding_top, uint32_t input_padding_right,
237*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_padding_bottom, uint32_t input_padding_left,
238*4bdc9457SAndroid Build Coastguard Worker       uint32_t adjustment_height, uint32_t adjustment_width,
239*4bdc9457SAndroid Build Coastguard Worker       uint32_t kernel_height, uint32_t kernel_width,
240*4bdc9457SAndroid Build Coastguard Worker       uint32_t upsampling_height, uint32_t upsampling_width,
241*4bdc9457SAndroid Build Coastguard Worker       uint32_t dilation_height, uint32_t dilation_width, uint32_t groups,
242*4bdc9457SAndroid Build Coastguard Worker       size_t group_input_channels, size_t group_output_channels,
243*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_id, uint32_t filter_id, uint32_t bias_id,
244*4bdc9457SAndroid Build Coastguard Worker       uint32_t output_id) {
245*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status = xnn_define_deconvolution_2d(
246*4bdc9457SAndroid Build Coastguard Worker         subgraph_.get(), input_padding_top, input_padding_right,
247*4bdc9457SAndroid Build Coastguard Worker         input_padding_bottom, input_padding_left, adjustment_height,
248*4bdc9457SAndroid Build Coastguard Worker         adjustment_width, kernel_height, kernel_width, upsampling_height,
249*4bdc9457SAndroid Build Coastguard Worker         upsampling_width, dilation_height, dilation_width, groups,
250*4bdc9457SAndroid Build Coastguard Worker         group_input_channels, group_output_channels,
251*4bdc9457SAndroid Build Coastguard Worker         -std::numeric_limits<float>::infinity(),
252*4bdc9457SAndroid Build Coastguard Worker         std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id,
253*4bdc9457SAndroid Build Coastguard Worker         output_id, 0 /* flags */);
254*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
255*4bdc9457SAndroid Build Coastguard Worker 
256*4bdc9457SAndroid Build Coastguard Worker     return *this;
257*4bdc9457SAndroid Build Coastguard Worker   }
258*4bdc9457SAndroid Build Coastguard Worker 
AddDeconvolution2D(DeconvolutionParams params,uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id)259*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddDeconvolution2D(
260*4bdc9457SAndroid Build Coastguard Worker       DeconvolutionParams params,
261*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_id, uint32_t filter_id, uint32_t bias_id,
262*4bdc9457SAndroid Build Coastguard Worker       uint32_t output_id) {
263*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status = xnn_define_deconvolution_2d(
264*4bdc9457SAndroid Build Coastguard Worker         subgraph_.get(), params.padding.top, params.padding.right,
265*4bdc9457SAndroid Build Coastguard Worker         params.padding.bottom, params.padding.left, params.adjustment.height,
266*4bdc9457SAndroid Build Coastguard Worker         params.adjustment.width, params.kernel.height, params.kernel.width, params.upsampling.height,
267*4bdc9457SAndroid Build Coastguard Worker         params.upsampling.width, params.dilation.height, params.dilation.width, params.groups,
268*4bdc9457SAndroid Build Coastguard Worker         params.group_input_channels, params.group_output_channels,
269*4bdc9457SAndroid Build Coastguard Worker         -std::numeric_limits<float>::infinity(),
270*4bdc9457SAndroid Build Coastguard Worker         std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id,
271*4bdc9457SAndroid Build Coastguard Worker         output_id, 0 /* flags */);
272*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
273*4bdc9457SAndroid Build Coastguard Worker 
274*4bdc9457SAndroid Build Coastguard Worker     return *this;
275*4bdc9457SAndroid Build Coastguard Worker   }
276*4bdc9457SAndroid Build Coastguard Worker 
AddDivide(uint32_t input_id1,uint32_t input_id2,uint32_t output_id)277*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddDivide(uint32_t input_id1, uint32_t input_id2, uint32_t output_id) {
278*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status =
279*4bdc9457SAndroid Build Coastguard Worker         xnn_define_divide(subgraph_.get(), -std::numeric_limits<float>::infinity(),
280*4bdc9457SAndroid Build Coastguard Worker                         std::numeric_limits<float>::infinity(), input_id1,
281*4bdc9457SAndroid Build Coastguard Worker                         input_id2, output_id, 0 /* flags */);
282*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
283*4bdc9457SAndroid Build Coastguard Worker 
284*4bdc9457SAndroid Build Coastguard Worker     return *this;
285*4bdc9457SAndroid Build Coastguard Worker   }
286*4bdc9457SAndroid Build Coastguard Worker 
AddFullyConnected(uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id)287*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddFullyConnected(
288*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_id, uint32_t filter_id, uint32_t bias_id, uint32_t output_id) {
289*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status = xnn_define_fully_connected(
290*4bdc9457SAndroid Build Coastguard Worker         subgraph_.get(),
291*4bdc9457SAndroid Build Coastguard Worker         -std::numeric_limits<float>::infinity(),
292*4bdc9457SAndroid Build Coastguard Worker         std::numeric_limits<float>::infinity(), input_id, filter_id, bias_id,
293*4bdc9457SAndroid Build Coastguard Worker         output_id, 0 /* flags */);
294*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
295*4bdc9457SAndroid Build Coastguard Worker 
296*4bdc9457SAndroid Build Coastguard Worker     return *this;
297*4bdc9457SAndroid Build Coastguard Worker   }
298*4bdc9457SAndroid Build Coastguard Worker 
299*4bdc9457SAndroid Build Coastguard Worker 
AddGlobalAveragePooling(uint32_t input_id,uint32_t output_id)300*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddGlobalAveragePooling(uint32_t input_id, uint32_t output_id) {
301*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status = xnn_define_global_average_pooling_2d(
302*4bdc9457SAndroid Build Coastguard Worker         subgraph_.get(), -std::numeric_limits<float>::infinity(),
303*4bdc9457SAndroid Build Coastguard Worker         std::numeric_limits<float>::infinity(), input_id, output_id, 0 /* flags */);
304*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
305*4bdc9457SAndroid Build Coastguard Worker 
306*4bdc9457SAndroid Build Coastguard Worker     return *this;
307*4bdc9457SAndroid Build Coastguard Worker   }
308*4bdc9457SAndroid Build Coastguard Worker 
AddMaxPooling2D(uint32_t input_padding_top,uint32_t input_padding_right,uint32_t input_padding_bottom,uint32_t input_padding_left,uint32_t pooling_height,uint32_t pooling_width,uint32_t stride_height,uint32_t stride_width,uint32_t dilation_height,uint32_t dilation_width,uint32_t input_id,uint32_t output_id)309*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddMaxPooling2D(
310*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_padding_top, uint32_t input_padding_right,
311*4bdc9457SAndroid Build Coastguard Worker       uint32_t input_padding_bottom, uint32_t input_padding_left,
312*4bdc9457SAndroid Build Coastguard Worker       uint32_t pooling_height, uint32_t pooling_width, uint32_t stride_height,
313*4bdc9457SAndroid Build Coastguard Worker       uint32_t stride_width, uint32_t dilation_height, uint32_t dilation_width, uint32_t input_id, uint32_t output_id) {
314*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status = xnn_define_max_pooling_2d(
315*4bdc9457SAndroid Build Coastguard Worker         subgraph_.get(), input_padding_top, input_padding_right,
316*4bdc9457SAndroid Build Coastguard Worker         input_padding_bottom, input_padding_left, pooling_height, pooling_width,
317*4bdc9457SAndroid Build Coastguard Worker         stride_height, stride_width, dilation_height, dilation_width, -std::numeric_limits<float>::infinity(),
318*4bdc9457SAndroid Build Coastguard Worker         std::numeric_limits<float>::infinity(), input_id, output_id,
319*4bdc9457SAndroid Build Coastguard Worker         0 /* flags */);
320*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
321*4bdc9457SAndroid Build Coastguard Worker 
322*4bdc9457SAndroid Build Coastguard Worker     return *this;
323*4bdc9457SAndroid Build Coastguard Worker   }
324*4bdc9457SAndroid Build Coastguard Worker 
AddMultiply(uint32_t input_id1,uint32_t input_id2,uint32_t output_id)325*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddMultiply(uint32_t input_id1, uint32_t input_id2, uint32_t output_id) {
326*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status =
327*4bdc9457SAndroid Build Coastguard Worker         xnn_define_multiply2(subgraph_.get(), -std::numeric_limits<float>::infinity(),
328*4bdc9457SAndroid Build Coastguard Worker                         std::numeric_limits<float>::infinity(), input_id1,
329*4bdc9457SAndroid Build Coastguard Worker                         input_id2, output_id, 0 /* flags */);
330*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
331*4bdc9457SAndroid Build Coastguard Worker 
332*4bdc9457SAndroid Build Coastguard Worker     return *this;
333*4bdc9457SAndroid Build Coastguard Worker   }
334*4bdc9457SAndroid Build Coastguard Worker 
AddSubtract(uint32_t input_id1,uint32_t input_id2,uint32_t output_id)335*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& AddSubtract(uint32_t input_id1, uint32_t input_id2, uint32_t output_id) {
336*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status =
337*4bdc9457SAndroid Build Coastguard Worker         xnn_define_subtract(subgraph_.get(), -std::numeric_limits<float>::infinity(),
338*4bdc9457SAndroid Build Coastguard Worker                         std::numeric_limits<float>::infinity(), input_id1,
339*4bdc9457SAndroid Build Coastguard Worker                         input_id2, output_id, 0 /* flags */);
340*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
341*4bdc9457SAndroid Build Coastguard Worker 
342*4bdc9457SAndroid Build Coastguard Worker     return *this;
343*4bdc9457SAndroid Build Coastguard Worker   }
344*4bdc9457SAndroid Build Coastguard Worker 
Optimize()345*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& Optimize() {
346*4bdc9457SAndroid Build Coastguard Worker     const xnn_status status = xnn_subgraph_optimize(subgraph_.get(), 0 /* flags */);
347*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(status, xnn_status_success);
348*4bdc9457SAndroid Build Coastguard Worker 
349*4bdc9457SAndroid Build Coastguard Worker     return *this;
350*4bdc9457SAndroid Build Coastguard Worker   }
351*4bdc9457SAndroid Build Coastguard Worker 
RewriteForNchw()352*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& RewriteForNchw() {
353*4bdc9457SAndroid Build Coastguard Worker     xnn_subgraph_rewrite_for_nchw(subgraph_.get());
354*4bdc9457SAndroid Build Coastguard Worker 
355*4bdc9457SAndroid Build Coastguard Worker     return *this;
356*4bdc9457SAndroid Build Coastguard Worker   }
357*4bdc9457SAndroid Build Coastguard Worker 
RewriteForFp16()358*4bdc9457SAndroid Build Coastguard Worker   inline SubgraphTester& RewriteForFp16() {
359*4bdc9457SAndroid Build Coastguard Worker     EXPECT_TRUE(xnn_subgraph_rewrite_for_fp16(subgraph_.get()));
360*4bdc9457SAndroid Build Coastguard Worker 
361*4bdc9457SAndroid Build Coastguard Worker     return *this;
362*4bdc9457SAndroid Build Coastguard Worker   }
363*4bdc9457SAndroid Build Coastguard Worker 
GetLayout(uint32_t value_id)364*4bdc9457SAndroid Build Coastguard Worker   inline xnn_layout_type GetLayout(uint32_t value_id) const {
365*4bdc9457SAndroid Build Coastguard Worker     return subgraph_->values[value_id].layout;
366*4bdc9457SAndroid Build Coastguard Worker   }
367*4bdc9457SAndroid Build Coastguard Worker 
Value(uint32_t value_id)368*4bdc9457SAndroid Build Coastguard Worker   inline const xnn_value* const Value(uint32_t value_id) const {
369*4bdc9457SAndroid Build Coastguard Worker     return &subgraph_->values[value_id];
370*4bdc9457SAndroid Build Coastguard Worker   }
371*4bdc9457SAndroid Build Coastguard Worker 
Node(uint32_t node_id)372*4bdc9457SAndroid Build Coastguard Worker   inline const xnn_node* const Node(uint32_t node_id) const {
373*4bdc9457SAndroid Build Coastguard Worker     return &subgraph_->nodes[node_id];
374*4bdc9457SAndroid Build Coastguard Worker   }
375*4bdc9457SAndroid Build Coastguard Worker 
NumNodes()376*4bdc9457SAndroid Build Coastguard Worker   inline size_t NumNodes() const {
377*4bdc9457SAndroid Build Coastguard Worker     return subgraph_->num_nodes;
378*4bdc9457SAndroid Build Coastguard Worker   }
379*4bdc9457SAndroid Build Coastguard Worker 
380*4bdc9457SAndroid Build Coastguard Worker  protected:
381*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph_{nullptr, xnn_delete_subgraph};
382*4bdc9457SAndroid Build Coastguard Worker   std::unordered_map<uint32_t, std::vector<char>> external_tensors_;
383*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id_;
384*4bdc9457SAndroid Build Coastguard Worker 
385*4bdc9457SAndroid Build Coastguard Worker  private:
NumElements(const std::vector<size_t> & dims)386*4bdc9457SAndroid Build Coastguard Worker   static inline size_t NumElements(const std::vector<size_t>& dims) {
387*4bdc9457SAndroid Build Coastguard Worker     return std::accumulate(std::begin(dims), std::end(dims), size_t(1), std::multiplies<size_t>());
388*4bdc9457SAndroid Build Coastguard Worker   }
389*4bdc9457SAndroid Build Coastguard Worker 
390*4bdc9457SAndroid Build Coastguard Worker   std::vector<std::vector<char>> static_data_;
391*4bdc9457SAndroid Build Coastguard Worker   std::mt19937 rng_;
392*4bdc9457SAndroid Build Coastguard Worker   std::uniform_real_distribution<float> f32dist = std::uniform_real_distribution<float>(-1.0f, +1.0f);
393*4bdc9457SAndroid Build Coastguard Worker 
394*4bdc9457SAndroid Build Coastguard Worker };
395*4bdc9457SAndroid Build Coastguard Worker 
396*4bdc9457SAndroid Build Coastguard Worker }  // namespace xnnpack
397