// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <random>

#include <xnnpack.h>
#include <xnnpack/node-type.h>
#include <xnnpack/operator.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>

template <typename T> class Concatenate4Test : public ::testing::Test {
protected:
  Concatenate4Test()
  {
    random_device = std::unique_ptr<std::random_device>(new std::random_device());
    rng = std::mt19937((*random_device)());
    shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
    dim_dist = std::uniform_int_distribution<size_t>(1, 9);
    f32dist = std::uniform_real_distribution<float>();
    i8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
    u8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
    scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);

    input1_dims = RandomShape();
    axis = RandomAxis(input1_dims);
    input2_dims = RandomShape(input1_dims, axis);
    input3_dims = RandomShape(input1_dims, axis);
    input4_dims = RandomShape(input1_dims, axis);
    output_dims = input1_dims;
    output_dims[axis] = input1_dims[axis] + input2_dims[axis] + input3_dims[axis] + input4_dims[axis];

    input1 = std::vector<T>(NumElements(input1_dims));
    input2 = std::vector<T>(NumElements(input2_dims));
    input3 = std::vector<T>(NumElements(input3_dims));
    input4 = std::vector<T>(NumElements(input4_dims));
    operator_output = std::vector<T>(NumElements(output_dims));
    subgraph_output = std::vector<T>(NumElements(output_dims));

    signed_zero_point = i8dist(rng);
    unsigned_zero_point = u8dist(rng);
    scale = scale_dist(rng);

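    // Concatenation along `axis` reduces to interleaved 2-D copies: every
    // dimension before the axis folds into the batch size, and the axis
    // dimension together with all trailing dimensions folds into a per-input
    // channel count; the output row stride is the sum of the four channel
    // counts. For example, with input1_dims = {2, 3, 5}, axis = 1, and axis
    // sizes 3, 1, 4, 2 across the four inputs: batch_size = 2, the channel
    // counts are 15, 5, 20, 10, and output_stride = 50.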
    batch_size = 1;
    channels_1 = 1;
    channels_2 = 1;
    channels_3 = 1;
    channels_4 = 1;
    for (size_t i = 0; i < axis; i++) {
      batch_size *= output_dims[i];
    }

    for (size_t i = axis; i < input1_dims.size(); i++) {
      channels_1 *= input1_dims[i];
      channels_2 *= input2_dims[i];
      channels_3 *= input3_dims[i];
      channels_4 *= input4_dims[i];
    }
    output_stride = channels_1 + channels_2 + channels_3 + channels_4;
  }

  std::vector<size_t> RandomShape()
  {
    std::vector<size_t> dims(shape_dist(rng));
    std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
    return dims;
  }

  std::vector<size_t> RandomShape(const std::vector<size_t> base_dims, size_t axis)
  {
    auto dims = base_dims;
    dims[axis] = dim_dist(rng);
    return dims;
  }

  size_t RandomAxis(const std::vector<size_t>& dims)
  {
    return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
  }

  size_t NumElements(const std::vector<size_t>& dims)
  {
    return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
  }

  std::unique_ptr<std::random_device> random_device;
  std::mt19937 rng;
  std::uniform_int_distribution<size_t> shape_dist;
  std::uniform_int_distribution<size_t> dim_dist;
  std::uniform_real_distribution<float> f32dist;
  std::uniform_int_distribution<int32_t> i8dist;
  std::uniform_int_distribution<int32_t> u8dist;
  std::uniform_real_distribution<float> scale_dist;

  uint32_t input1_id;
  uint32_t input2_id;
  uint32_t input3_id;
  uint32_t input4_id;
  uint32_t output_id;

  std::vector<size_t> input1_dims;
  std::vector<size_t> input2_dims;
  std::vector<size_t> input3_dims;
  std::vector<size_t> input4_dims;
  std::vector<size_t> output_dims;

  size_t axis;
  size_t batch_size;
  size_t channels_1;
  size_t channels_2;
  size_t channels_3;
  size_t channels_4;
  size_t output_stride;

  int32_t signed_zero_point;
  int32_t unsigned_zero_point;
  float scale;

  std::vector<T> input1;
  std::vector<T> input2;
  std::vector<T> input3;
  std::vector<T> input4;
  std::vector<T> operator_output;
  std::vector<T> subgraph_output;
};

using Concatenate4TestQS8 = Concatenate4Test<int8_t>;
using Concatenate4TestQU8 = Concatenate4Test<uint8_t>;
using Concatenate4TestF32 = Concatenate4Test<float>;

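// The `define` tests below build a single concatenate4 node and then inspect
// the metadata recorded in the subgraph; no runtime is created. The QS8, QU8,
// and F32 variants differ only in datatype and quantization parameters.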
TEST_F(Concatenate4TestQS8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 4);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->inputs[2], input3_id);
  ASSERT_EQ(node->inputs[3], input4_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(Concatenate4TestQU8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 4);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->inputs[2], input3_id);
  ASSERT_EQ(node->inputs[3], input4_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(Concatenate4TestF32, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input3_dims.size(), input3_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input4_dims.size(), input4_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 4,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 4);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->inputs[2], input3_id);
  ASSERT_EQ(node->inputs[3], input4_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

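// The matches_operator_api tests compute a reference result with the operator
// API: a four-way concatenation is equivalent to four strided copies into a
// shared output buffer, each copy starting at the running channel offset of
// the preceding inputs. The subgraph runtime's output must match it exactly.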
TEST_F(Concatenate4TestQS8, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
  std::generate(input3.begin(), input3.end(), [&]() { return i8dist(rng); });
  std::generate(input4.begin(), input4.end(), [&]() { return i8dist(rng); });
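  // Pre-fill both outputs with a sentinel so elements the copies fail to
  // overwrite cannot accidentally compare equal.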
  std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
  std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_4, channels_4, output_stride, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

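  // Offsets into the shared output are in bytes (x8 elements are one byte
  // wide); the common output_stride keeps each batch row laid out exactly as
  // concatenate4 would produce it.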
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op3, batch_size, input3.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels,
      nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op4, batch_size, input4.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels + op3->channels,
      nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{input3_id, input3.data()}, xnn_external_value{input4_id, input4.data()},
    xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}

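// The QU8 variant reuses the same 8-bit copy operators: concatenation is a
// bitwise copy, so only the tensor metadata (datatype and quantization
// parameters) differs from the QS8 test above.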
TEST_F(Concatenate4TestQU8, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
  std::generate(input3.begin(), input3.end(), [&]() { return u8dist(rng); });
  std::generate(input4.begin(), input4.end(), [&]() { return u8dist(rng); });
  std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
  std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_4, channels_4, output_stride, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op3, batch_size, input3.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels,
      nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op4, batch_size, input4.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels + op3->channels,
      nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{input3_id, input3.data()}, xnn_external_value{input4_id, input4.data()},
    xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}

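// The F32 variant uses the 32-bit copy operator, so output offsets are
// computed in float elements rather than bytes, and NaN serves as the
// sentinel fill value.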
TEST_F(Concatenate4TestF32, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
  std::generate(input3.begin(), input3.end(), [&]() { return f32dist(rng); });
  std::generate(input4.begin(), input4.end(), [&]() { return f32dist(rng); });
  std::fill(operator_output.begin(), operator_output.end(), std::nanf(""));
  std::fill(subgraph_output.begin(), subgraph_output.end(), std::nanf(""));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_4, channels_4, output_stride, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op2, batch_size, input2.data(), (float*) operator_output.data() + op1->channels, nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success, xnn_setup_copy_nc_x32(
                          op3, batch_size, input3.data(),
                          (float*) operator_output.data() + op1->channels + op2->channels, nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op4, batch_size, input4.data(), (float*) operator_output.data() + op1->channels + op2->channels + op3->channels,
      nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input3_dims.size(), input3_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input4_dims.size(), input4_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 4,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{input3_id, input3.data()}, xnn_external_value{input4_id, input4.data()},
    xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}