xref: /aosp_15_r20/external/XNNPACK/test/concatenate4.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/node-type.h>
#include <xnnpack/operator.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker template <typename T> class Concatenate4Test : public ::testing::Test {
23*4bdc9457SAndroid Build Coastguard Worker protected:
Concatenate4Test()24*4bdc9457SAndroid Build Coastguard Worker   Concatenate4Test()
25*4bdc9457SAndroid Build Coastguard Worker   {
26*4bdc9457SAndroid Build Coastguard Worker     random_device = std::unique_ptr<std::random_device>(new std::random_device());
27*4bdc9457SAndroid Build Coastguard Worker     rng = std::mt19937((*random_device)());
28*4bdc9457SAndroid Build Coastguard Worker     shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
29*4bdc9457SAndroid Build Coastguard Worker     dim_dist = std::uniform_int_distribution<size_t>(1, 9);
30*4bdc9457SAndroid Build Coastguard Worker     f32dist = std::uniform_real_distribution<float>();
31*4bdc9457SAndroid Build Coastguard Worker     i8dist =
32*4bdc9457SAndroid Build Coastguard Worker       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
33*4bdc9457SAndroid Build Coastguard Worker     u8dist =
34*4bdc9457SAndroid Build Coastguard Worker       std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
35*4bdc9457SAndroid Build Coastguard Worker     scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);
36*4bdc9457SAndroid Build Coastguard Worker 
37*4bdc9457SAndroid Build Coastguard Worker     input1_dims = RandomShape();
38*4bdc9457SAndroid Build Coastguard Worker     axis = RandomAxis(input1_dims);
39*4bdc9457SAndroid Build Coastguard Worker     input2_dims = RandomShape(input1_dims, axis);
40*4bdc9457SAndroid Build Coastguard Worker     input3_dims = RandomShape(input1_dims, axis);
41*4bdc9457SAndroid Build Coastguard Worker     input4_dims = RandomShape(input1_dims, axis);
42*4bdc9457SAndroid Build Coastguard Worker     output_dims = input1_dims;
43*4bdc9457SAndroid Build Coastguard Worker     output_dims[axis] = input1_dims[axis] + input2_dims[axis] + input3_dims[axis] + input4_dims[axis];
44*4bdc9457SAndroid Build Coastguard Worker 
45*4bdc9457SAndroid Build Coastguard Worker     input1 = std::vector<T>(NumElements(input1_dims));
46*4bdc9457SAndroid Build Coastguard Worker     input2 = std::vector<T>(NumElements(input2_dims));
47*4bdc9457SAndroid Build Coastguard Worker     input3 = std::vector<T>(NumElements(input3_dims));
48*4bdc9457SAndroid Build Coastguard Worker     input4 = std::vector<T>(NumElements(input4_dims));
49*4bdc9457SAndroid Build Coastguard Worker     operator_output = std::vector<T>(NumElements(output_dims));
50*4bdc9457SAndroid Build Coastguard Worker     subgraph_output = std::vector<T>(NumElements(output_dims));
51*4bdc9457SAndroid Build Coastguard Worker 
52*4bdc9457SAndroid Build Coastguard Worker     signed_zero_point = i8dist(rng);
53*4bdc9457SAndroid Build Coastguard Worker     unsigned_zero_point = u8dist(rng);
54*4bdc9457SAndroid Build Coastguard Worker     scale = scale_dist(rng);
55*4bdc9457SAndroid Build Coastguard Worker 
56*4bdc9457SAndroid Build Coastguard Worker     batch_size = 1;
57*4bdc9457SAndroid Build Coastguard Worker     channels_1 = 1;
58*4bdc9457SAndroid Build Coastguard Worker     channels_2 = 1;
59*4bdc9457SAndroid Build Coastguard Worker     channels_3 = 1;
60*4bdc9457SAndroid Build Coastguard Worker     channels_4 = 1;
61*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < axis; i++) {
62*4bdc9457SAndroid Build Coastguard Worker       batch_size *= output_dims[i];
63*4bdc9457SAndroid Build Coastguard Worker     }
64*4bdc9457SAndroid Build Coastguard Worker 
65*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = axis; i < input1_dims.size(); i++) {
66*4bdc9457SAndroid Build Coastguard Worker       channels_1 *= input1_dims[i];
67*4bdc9457SAndroid Build Coastguard Worker       channels_2 *= input2_dims[i];
68*4bdc9457SAndroid Build Coastguard Worker       channels_3 *= input3_dims[i];
69*4bdc9457SAndroid Build Coastguard Worker       channels_4 *= input4_dims[i];
70*4bdc9457SAndroid Build Coastguard Worker     }
71*4bdc9457SAndroid Build Coastguard Worker     output_stride = channels_1 + channels_2 + channels_3 + channels_4;
72*4bdc9457SAndroid Build Coastguard Worker   }
73*4bdc9457SAndroid Build Coastguard Worker 
RandomShape()74*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> RandomShape()
75*4bdc9457SAndroid Build Coastguard Worker   {
76*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> dims(shape_dist(rng));
77*4bdc9457SAndroid Build Coastguard Worker     std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
78*4bdc9457SAndroid Build Coastguard Worker     return dims;
79*4bdc9457SAndroid Build Coastguard Worker   }
80*4bdc9457SAndroid Build Coastguard Worker 
RandomShape(const std::vector<size_t> base_dims,size_t axis)81*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> RandomShape(const std::vector<size_t> base_dims, size_t axis)
82*4bdc9457SAndroid Build Coastguard Worker   {
83*4bdc9457SAndroid Build Coastguard Worker     auto dims = base_dims;
84*4bdc9457SAndroid Build Coastguard Worker     dims[axis] = dim_dist(rng);
85*4bdc9457SAndroid Build Coastguard Worker     return dims;
86*4bdc9457SAndroid Build Coastguard Worker   }
87*4bdc9457SAndroid Build Coastguard Worker 
RandomAxis(const std::vector<size_t> & dims)88*4bdc9457SAndroid Build Coastguard Worker   size_t RandomAxis(const std::vector<size_t>& dims)
89*4bdc9457SAndroid Build Coastguard Worker   {
90*4bdc9457SAndroid Build Coastguard Worker     return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
91*4bdc9457SAndroid Build Coastguard Worker   }
92*4bdc9457SAndroid Build Coastguard Worker 
NumElements(const std::vector<size_t> & dims)93*4bdc9457SAndroid Build Coastguard Worker   size_t NumElements(const std::vector<size_t>& dims)
94*4bdc9457SAndroid Build Coastguard Worker   {
95*4bdc9457SAndroid Build Coastguard Worker     return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
96*4bdc9457SAndroid Build Coastguard Worker   }
97*4bdc9457SAndroid Build Coastguard Worker 
98*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<std::random_device> random_device;
99*4bdc9457SAndroid Build Coastguard Worker   std::mt19937 rng;
100*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<size_t> shape_dist;
101*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<size_t> dim_dist;
102*4bdc9457SAndroid Build Coastguard Worker   std::uniform_real_distribution<float> f32dist;
103*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<int32_t> i8dist;
104*4bdc9457SAndroid Build Coastguard Worker   std::uniform_int_distribution<int32_t> u8dist;
105*4bdc9457SAndroid Build Coastguard Worker   std::uniform_real_distribution<float> scale_dist;
106*4bdc9457SAndroid Build Coastguard Worker 
107*4bdc9457SAndroid Build Coastguard Worker   uint32_t input1_id;
108*4bdc9457SAndroid Build Coastguard Worker   uint32_t input2_id;
109*4bdc9457SAndroid Build Coastguard Worker   uint32_t input3_id;
110*4bdc9457SAndroid Build Coastguard Worker   uint32_t input4_id;
111*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id;
112*4bdc9457SAndroid Build Coastguard Worker 
113*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> input1_dims;
114*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> input2_dims;
115*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> input3_dims;
116*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> input4_dims;
117*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> output_dims;
118*4bdc9457SAndroid Build Coastguard Worker 
119*4bdc9457SAndroid Build Coastguard Worker   size_t axis;
120*4bdc9457SAndroid Build Coastguard Worker   size_t batch_size;
121*4bdc9457SAndroid Build Coastguard Worker   size_t channels_1;
122*4bdc9457SAndroid Build Coastguard Worker   size_t channels_2;
123*4bdc9457SAndroid Build Coastguard Worker   size_t channels_3;
124*4bdc9457SAndroid Build Coastguard Worker   size_t channels_4;
125*4bdc9457SAndroid Build Coastguard Worker   size_t output_stride;
126*4bdc9457SAndroid Build Coastguard Worker 
127*4bdc9457SAndroid Build Coastguard Worker   int32_t signed_zero_point;
128*4bdc9457SAndroid Build Coastguard Worker   int32_t unsigned_zero_point;
129*4bdc9457SAndroid Build Coastguard Worker   float scale;
130*4bdc9457SAndroid Build Coastguard Worker 
131*4bdc9457SAndroid Build Coastguard Worker   std::vector<T> input1;
132*4bdc9457SAndroid Build Coastguard Worker   std::vector<T> input2;
133*4bdc9457SAndroid Build Coastguard Worker   std::vector<T> input3;
134*4bdc9457SAndroid Build Coastguard Worker   std::vector<T> input4;
135*4bdc9457SAndroid Build Coastguard Worker   std::vector<T> operator_output;
136*4bdc9457SAndroid Build Coastguard Worker   std::vector<T> subgraph_output;
137*4bdc9457SAndroid Build Coastguard Worker };
138*4bdc9457SAndroid Build Coastguard Worker 
139*4bdc9457SAndroid Build Coastguard Worker using Concatenate4TestQS8 = Concatenate4Test<int8_t>;
140*4bdc9457SAndroid Build Coastguard Worker using Concatenate4TestQU8 = Concatenate4Test<uint8_t>;
141*4bdc9457SAndroid Build Coastguard Worker using Concatenate4TestF32 = Concatenate4Test<float>;
142*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Concatenate4TestQS8,define)143*4bdc9457SAndroid Build Coastguard Worker TEST_F(Concatenate4TestQS8, define)
144*4bdc9457SAndroid Build Coastguard Worker {
145*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
146*4bdc9457SAndroid Build Coastguard Worker 
147*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
148*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
149*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
150*4bdc9457SAndroid Build Coastguard Worker 
151*4bdc9457SAndroid Build Coastguard Worker   input1_id = XNN_INVALID_NODE_ID;
152*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
153*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
154*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
155*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
156*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
157*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
158*4bdc9457SAndroid Build Coastguard Worker 
159*4bdc9457SAndroid Build Coastguard Worker   input2_id = XNN_INVALID_NODE_ID;
160*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
161*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
162*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
163*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
164*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
165*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
166*4bdc9457SAndroid Build Coastguard Worker 
167*4bdc9457SAndroid Build Coastguard Worker   input3_id = XNN_INVALID_NODE_ID;
168*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
169*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
170*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
171*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
172*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
173*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
174*4bdc9457SAndroid Build Coastguard Worker 
175*4bdc9457SAndroid Build Coastguard Worker   input4_id = XNN_INVALID_NODE_ID;
176*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
177*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
178*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
179*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
180*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
181*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);
182*4bdc9457SAndroid Build Coastguard Worker 
183*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
184*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
185*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
186*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
187*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
188*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
189*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
190*4bdc9457SAndroid Build Coastguard Worker 
191*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
192*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
193*4bdc9457SAndroid Build Coastguard Worker     xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));
194*4bdc9457SAndroid Build Coastguard Worker 
195*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
196*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
197*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_concatenate4);
198*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
199*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.concatenate.axis, axis);
200*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 4);
201*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input1_id);
202*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[1], input2_id);
203*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[2], input3_id);
204*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[3], input4_id);
205*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
206*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
207*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
208*4bdc9457SAndroid Build Coastguard Worker }
209*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Concatenate4TestQU8,define)210*4bdc9457SAndroid Build Coastguard Worker TEST_F(Concatenate4TestQU8, define)
211*4bdc9457SAndroid Build Coastguard Worker {
212*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
213*4bdc9457SAndroid Build Coastguard Worker 
214*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
215*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
216*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
217*4bdc9457SAndroid Build Coastguard Worker 
218*4bdc9457SAndroid Build Coastguard Worker   input1_id = XNN_INVALID_NODE_ID;
219*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
220*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
221*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
222*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
223*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
224*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
225*4bdc9457SAndroid Build Coastguard Worker 
226*4bdc9457SAndroid Build Coastguard Worker   input2_id = XNN_INVALID_NODE_ID;
227*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
228*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
229*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
230*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
231*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
232*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
233*4bdc9457SAndroid Build Coastguard Worker 
234*4bdc9457SAndroid Build Coastguard Worker   input3_id = XNN_INVALID_NODE_ID;
235*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
236*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
237*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
238*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
239*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
240*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
241*4bdc9457SAndroid Build Coastguard Worker 
242*4bdc9457SAndroid Build Coastguard Worker   input4_id = XNN_INVALID_NODE_ID;
243*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
244*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
245*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
246*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
247*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
248*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);
249*4bdc9457SAndroid Build Coastguard Worker 
250*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
251*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
252*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
253*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
254*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
255*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
256*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
257*4bdc9457SAndroid Build Coastguard Worker 
258*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
259*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
260*4bdc9457SAndroid Build Coastguard Worker     xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));
261*4bdc9457SAndroid Build Coastguard Worker 
262*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
263*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
264*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_concatenate4);
265*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
266*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.concatenate.axis, axis);
267*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 4);
268*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input1_id);
269*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[1], input2_id);
270*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[2], input3_id);
271*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[3], input4_id);
272*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
273*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
274*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
275*4bdc9457SAndroid Build Coastguard Worker }
276*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Concatenate4TestF32,define)277*4bdc9457SAndroid Build Coastguard Worker TEST_F(Concatenate4TestF32, define)
278*4bdc9457SAndroid Build Coastguard Worker {
279*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
280*4bdc9457SAndroid Build Coastguard Worker 
281*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
282*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
283*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
284*4bdc9457SAndroid Build Coastguard Worker 
285*4bdc9457SAndroid Build Coastguard Worker   input1_id = XNN_INVALID_NODE_ID;
286*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
287*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
288*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
289*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
290*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
291*4bdc9457SAndroid Build Coastguard Worker 
292*4bdc9457SAndroid Build Coastguard Worker   input2_id = XNN_INVALID_NODE_ID;
293*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
294*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
295*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
296*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
297*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
298*4bdc9457SAndroid Build Coastguard Worker 
299*4bdc9457SAndroid Build Coastguard Worker   input3_id = XNN_INVALID_NODE_ID;
300*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
301*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
302*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input3_dims.size(), input3_dims.data(), nullptr, 2,
303*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
304*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
305*4bdc9457SAndroid Build Coastguard Worker 
306*4bdc9457SAndroid Build Coastguard Worker   input4_id = XNN_INVALID_NODE_ID;
307*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
308*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
309*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input4_dims.size(), input4_dims.data(), nullptr, 3,
310*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
311*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);
312*4bdc9457SAndroid Build Coastguard Worker 
313*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
314*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
315*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
316*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 4,
317*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
318*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
319*4bdc9457SAndroid Build Coastguard Worker 
320*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
321*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
322*4bdc9457SAndroid Build Coastguard Worker     xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));
323*4bdc9457SAndroid Build Coastguard Worker 
324*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph->num_nodes, 1);
325*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_node* node = &subgraph->nodes[0];
326*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->type, xnn_node_type_concatenate4);
327*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
328*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->params.concatenate.axis, axis);
329*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_inputs, 4);
330*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[0], input1_id);
331*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[1], input2_id);
332*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[2], input3_id);
333*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->inputs[3], input4_id);
334*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->num_outputs, 1);
335*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->outputs[0], output_id);
336*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(node->flags, 0);
337*4bdc9457SAndroid Build Coastguard Worker }
338*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Concatenate4TestQS8,matches_operator_api)339*4bdc9457SAndroid Build Coastguard Worker TEST_F(Concatenate4TestQS8, matches_operator_api)
340*4bdc9457SAndroid Build Coastguard Worker {
341*4bdc9457SAndroid Build Coastguard Worker   std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
342*4bdc9457SAndroid Build Coastguard Worker   std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
343*4bdc9457SAndroid Build Coastguard Worker   std::generate(input3.begin(), input3.end(), [&]() { return i8dist(rng); });
344*4bdc9457SAndroid Build Coastguard Worker   std::generate(input4.begin(), input4.end(), [&]() { return i8dist(rng); });
345*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
346*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));
347*4bdc9457SAndroid Build Coastguard Worker 
348*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
349*4bdc9457SAndroid Build Coastguard Worker 
350*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op1 = nullptr;
351*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op2 = nullptr;
352*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op3 = nullptr;
353*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op4 = nullptr;
354*4bdc9457SAndroid Build Coastguard Worker 
355*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
356*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
357*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
358*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
359*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
360*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
361*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
362*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_4, channels_4, output_stride, /*flags=*/0, &op4));
363*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);
364*4bdc9457SAndroid Build Coastguard Worker 
365*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
366*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
367*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
368*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
369*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
370*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x8(
371*4bdc9457SAndroid Build Coastguard Worker       op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
372*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
373*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
374*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x8(
375*4bdc9457SAndroid Build Coastguard Worker       op3, batch_size, input3.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels,
376*4bdc9457SAndroid Build Coastguard Worker       nullptr /* thread pool */));
377*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
378*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
379*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x8(
380*4bdc9457SAndroid Build Coastguard Worker       op4, batch_size, input4.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels + op3->channels,
381*4bdc9457SAndroid Build Coastguard Worker       nullptr /* thread pool */));
382*4bdc9457SAndroid Build Coastguard Worker 
383*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
384*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
385*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
386*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));
387*4bdc9457SAndroid Build Coastguard Worker 
388*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
389*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
390*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
391*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
392*4bdc9457SAndroid Build Coastguard Worker 
393*4bdc9457SAndroid Build Coastguard Worker   input1_id = XNN_INVALID_NODE_ID;
394*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
395*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
396*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
397*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
398*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
399*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
400*4bdc9457SAndroid Build Coastguard Worker 
401*4bdc9457SAndroid Build Coastguard Worker   input2_id = XNN_INVALID_NODE_ID;
402*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
403*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
404*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
405*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
406*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
407*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
408*4bdc9457SAndroid Build Coastguard Worker 
409*4bdc9457SAndroid Build Coastguard Worker   input3_id = XNN_INVALID_NODE_ID;
410*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
411*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
412*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
413*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
414*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
415*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
416*4bdc9457SAndroid Build Coastguard Worker 
417*4bdc9457SAndroid Build Coastguard Worker   input4_id = XNN_INVALID_NODE_ID;
418*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
419*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
420*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
421*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, signed_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
422*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
423*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);
424*4bdc9457SAndroid Build Coastguard Worker 
425*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
426*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
427*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
428*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
429*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
430*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
431*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
432*4bdc9457SAndroid Build Coastguard Worker 
433*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
434*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
435*4bdc9457SAndroid Build Coastguard Worker     xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));
436*4bdc9457SAndroid Build Coastguard Worker 
437*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
438*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
439*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
440*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
441*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 5> external = {
442*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
443*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input3_id, input3.data()}, xnn_external_value{input4_id, input4.data()},
444*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{output_id, subgraph_output.data()}};
445*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
446*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
447*4bdc9457SAndroid Build Coastguard Worker 
448*4bdc9457SAndroid Build Coastguard Worker   // Check outputs match.
449*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
450*4bdc9457SAndroid Build Coastguard Worker }
451*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Concatenate4TestQU8,matches_operator_api)452*4bdc9457SAndroid Build Coastguard Worker TEST_F(Concatenate4TestQU8, matches_operator_api)
453*4bdc9457SAndroid Build Coastguard Worker {
454*4bdc9457SAndroid Build Coastguard Worker   std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
455*4bdc9457SAndroid Build Coastguard Worker   std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
456*4bdc9457SAndroid Build Coastguard Worker   std::generate(input3.begin(), input3.end(), [&]() { return u8dist(rng); });
457*4bdc9457SAndroid Build Coastguard Worker   std::generate(input4.begin(), input4.end(), [&]() { return u8dist(rng); });
458*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
459*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));
460*4bdc9457SAndroid Build Coastguard Worker 
461*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
462*4bdc9457SAndroid Build Coastguard Worker 
463*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op1 = nullptr;
464*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op2 = nullptr;
465*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op3 = nullptr;
466*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op4 = nullptr;
467*4bdc9457SAndroid Build Coastguard Worker 
468*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
469*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
470*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
471*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
472*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
473*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
474*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
475*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_4, channels_4, output_stride, /*flags=*/0, &op4));
476*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);
477*4bdc9457SAndroid Build Coastguard Worker 
478*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
479*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
480*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
481*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
482*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
483*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x8(
484*4bdc9457SAndroid Build Coastguard Worker       op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
485*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
486*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
487*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x8(
488*4bdc9457SAndroid Build Coastguard Worker       op3, batch_size, input3.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels,
489*4bdc9457SAndroid Build Coastguard Worker       nullptr /* thread pool */));
490*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
491*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
492*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x8(
493*4bdc9457SAndroid Build Coastguard Worker       op4, batch_size, input4.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels + op3->channels,
494*4bdc9457SAndroid Build Coastguard Worker       nullptr /* thread pool */));
495*4bdc9457SAndroid Build Coastguard Worker 
496*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
497*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
498*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
499*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));
500*4bdc9457SAndroid Build Coastguard Worker 
501*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
502*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
503*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
504*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
505*4bdc9457SAndroid Build Coastguard Worker 
506*4bdc9457SAndroid Build Coastguard Worker   input1_id = XNN_INVALID_NODE_ID;
507*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
508*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
509*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
510*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
511*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
512*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
513*4bdc9457SAndroid Build Coastguard Worker 
514*4bdc9457SAndroid Build Coastguard Worker   input2_id = XNN_INVALID_NODE_ID;
515*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
516*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
517*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
518*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
519*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
520*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
521*4bdc9457SAndroid Build Coastguard Worker 
522*4bdc9457SAndroid Build Coastguard Worker   input3_id = XNN_INVALID_NODE_ID;
523*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
524*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
525*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
526*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
527*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
528*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
529*4bdc9457SAndroid Build Coastguard Worker 
530*4bdc9457SAndroid Build Coastguard Worker   input4_id = XNN_INVALID_NODE_ID;
531*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
532*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
533*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
534*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
535*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
536*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);
537*4bdc9457SAndroid Build Coastguard Worker 
538*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
539*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
540*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
541*4bdc9457SAndroid Build Coastguard Worker     xnn_define_quantized_tensor_value(
542*4bdc9457SAndroid Build Coastguard Worker       subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
543*4bdc9457SAndroid Build Coastguard Worker       /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
544*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
545*4bdc9457SAndroid Build Coastguard Worker 
546*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
547*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
548*4bdc9457SAndroid Build Coastguard Worker     xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));
549*4bdc9457SAndroid Build Coastguard Worker 
550*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
551*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
552*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
553*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
554*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 5> external = {
555*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
556*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input3_id, input3.data()}, xnn_external_value{input4_id, input4.data()},
557*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{output_id, subgraph_output.data()}};
558*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
559*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
560*4bdc9457SAndroid Build Coastguard Worker 
561*4bdc9457SAndroid Build Coastguard Worker   // Check outputs match.
562*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
563*4bdc9457SAndroid Build Coastguard Worker }
564*4bdc9457SAndroid Build Coastguard Worker 
TEST_F(Concatenate4TestF32,matches_operator_api)565*4bdc9457SAndroid Build Coastguard Worker TEST_F(Concatenate4TestF32, matches_operator_api)
566*4bdc9457SAndroid Build Coastguard Worker {
567*4bdc9457SAndroid Build Coastguard Worker   std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
568*4bdc9457SAndroid Build Coastguard Worker   std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
569*4bdc9457SAndroid Build Coastguard Worker   std::generate(input3.begin(), input3.end(), [&]() { return f32dist(rng); });
570*4bdc9457SAndroid Build Coastguard Worker   std::generate(input4.begin(), input4.end(), [&]() { return f32dist(rng); });
571*4bdc9457SAndroid Build Coastguard Worker   std::fill(operator_output.begin(), operator_output.end(), std::nanf(""));
572*4bdc9457SAndroid Build Coastguard Worker   std::fill(subgraph_output.begin(), subgraph_output.end(), std::nanf(""));
573*4bdc9457SAndroid Build Coastguard Worker 
574*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
575*4bdc9457SAndroid Build Coastguard Worker 
576*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op1 = nullptr;
577*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op2 = nullptr;
578*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op3 = nullptr;
579*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t op4 = nullptr;
580*4bdc9457SAndroid Build Coastguard Worker 
581*4bdc9457SAndroid Build Coastguard Worker   // Call operator API.
582*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
583*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
584*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
585*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
586*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
587*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
588*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_4, channels_4, output_stride, /*flags=*/0, &op4));
589*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);
590*4bdc9457SAndroid Build Coastguard Worker 
591*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
592*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
593*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x32(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
594*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
595*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
596*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x32(
597*4bdc9457SAndroid Build Coastguard Worker       op2, batch_size, input2.data(), (float*) operator_output.data() + op1->channels, nullptr /* thread pool */));
598*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
599*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_setup_copy_nc_x32(
600*4bdc9457SAndroid Build Coastguard Worker                           op3, batch_size, input3.data(),
601*4bdc9457SAndroid Build Coastguard Worker                           (float*) operator_output.data() + op1->channels + op2->channels, nullptr /* thread pool */));
602*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
603*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
604*4bdc9457SAndroid Build Coastguard Worker     xnn_setup_copy_nc_x32(
605*4bdc9457SAndroid Build Coastguard Worker       op4, batch_size, input4.data(), (float*) operator_output.data() + op1->channels + op2->channels + op3->channels,
606*4bdc9457SAndroid Build Coastguard Worker       nullptr /* thread pool */));
607*4bdc9457SAndroid Build Coastguard Worker 
608*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
609*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
610*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
611*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));
612*4bdc9457SAndroid Build Coastguard Worker 
613*4bdc9457SAndroid Build Coastguard Worker   // Call subgraph API.
614*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph = nullptr;
615*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
616*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
617*4bdc9457SAndroid Build Coastguard Worker 
618*4bdc9457SAndroid Build Coastguard Worker   input1_id = XNN_INVALID_NODE_ID;
619*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
620*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
621*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
622*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
623*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);
624*4bdc9457SAndroid Build Coastguard Worker 
625*4bdc9457SAndroid Build Coastguard Worker   input2_id = XNN_INVALID_NODE_ID;
626*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
627*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
628*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
629*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
630*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);
631*4bdc9457SAndroid Build Coastguard Worker 
632*4bdc9457SAndroid Build Coastguard Worker   input3_id = XNN_INVALID_NODE_ID;
633*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
634*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
635*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input3_dims.size(), input3_dims.data(), nullptr, 2,
636*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
637*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);
638*4bdc9457SAndroid Build Coastguard Worker 
639*4bdc9457SAndroid Build Coastguard Worker   input4_id = XNN_INVALID_NODE_ID;
640*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
641*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
642*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, input4_dims.size(), input4_dims.data(), nullptr, 3,
643*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
644*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);
645*4bdc9457SAndroid Build Coastguard Worker 
646*4bdc9457SAndroid Build Coastguard Worker   output_id = XNN_INVALID_NODE_ID;
647*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
648*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success, xnn_define_tensor_value(
649*4bdc9457SAndroid Build Coastguard Worker                           subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 4,
650*4bdc9457SAndroid Build Coastguard Worker                           /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
651*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
652*4bdc9457SAndroid Build Coastguard Worker 
653*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(
654*4bdc9457SAndroid Build Coastguard Worker     xnn_status_success,
655*4bdc9457SAndroid Build Coastguard Worker     xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));
656*4bdc9457SAndroid Build Coastguard Worker 
657*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime = nullptr;
658*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
659*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, runtime);
660*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
661*4bdc9457SAndroid Build Coastguard Worker   std::array<xnn_external_value, 5> external = {
662*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
663*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{input3_id, input3.data()}, xnn_external_value{input4_id, input4.data()},
664*4bdc9457SAndroid Build Coastguard Worker     xnn_external_value{output_id, subgraph_output.data()}};
665*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
666*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
667*4bdc9457SAndroid Build Coastguard Worker 
668*4bdc9457SAndroid Build Coastguard Worker   // Check outputs match.
669*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(subgraph_output, operator_output);
670*4bdc9457SAndroid Build Coastguard Worker }
671