// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
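
// Subgraph-level tests for the 4-input Concatenate node: each element type
// gets a `define` test that inspects the recorded node, and a
// `matches_operator_api` test that cross-checks the subgraph runtime against
// the operator API.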

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <random>

#include <xnnpack.h>
#include <xnnpack/node-type.h>
#include <xnnpack/operator.h>
#include <xnnpack/subgraph.h>

#include <gtest/gtest.h>

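// Test fixture that generates four random input shapes agreeing on every
// dimension except the concatenation axis, plus the matching output shape.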
template <typename T> class Concatenate4Test : public ::testing::Test {
 protected:
  Concatenate4Test()
  {
    random_device = std::unique_ptr<std::random_device>(new std::random_device());
    rng = std::mt19937((*random_device)());
    shape_dist = std::uniform_int_distribution<size_t>(1, XNN_MAX_TENSOR_DIMS);
    dim_dist = std::uniform_int_distribution<size_t>(1, 9);
    f32dist = std::uniform_real_distribution<float>();
    i8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
    u8dist =
      std::uniform_int_distribution<int32_t>(std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
    scale_dist = std::uniform_real_distribution<float>(0.1f, 5.0f);

    input1_dims = RandomShape();
    axis = RandomAxis(input1_dims);
    input2_dims = RandomShape(input1_dims, axis);
    input3_dims = RandomShape(input1_dims, axis);
    input4_dims = RandomShape(input1_dims, axis);
    output_dims = input1_dims;
    output_dims[axis] = input1_dims[axis] + input2_dims[axis] + input3_dims[axis] + input4_dims[axis];

    input1 = std::vector<T>(NumElements(input1_dims));
    input2 = std::vector<T>(NumElements(input2_dims));
    input3 = std::vector<T>(NumElements(input3_dims));
    input4 = std::vector<T>(NumElements(input4_dims));
    operator_output = std::vector<T>(NumElements(output_dims));
    subgraph_output = std::vector<T>(NumElements(output_dims));

    signed_zero_point = i8dist(rng);
    unsigned_zero_point = u8dist(rng);
    scale = scale_dist(rng);

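    // Concatenation along `axis` can be viewed as four strided 2D copies:
    // the dimensions before the axis collapse into a batch size, and the
    // axis together with all trailing dimensions collapses into a per-input
    // channel count.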
    batch_size = 1;
    channels_1 = 1;
    channels_2 = 1;
    channels_3 = 1;
    channels_4 = 1;
    for (size_t i = 0; i < axis; i++) {
      batch_size *= output_dims[i];
    }

    for (size_t i = axis; i < input1_dims.size(); i++) {
      channels_1 *= input1_dims[i];
      channels_2 *= input2_dims[i];
      channels_3 *= input3_dims[i];
      channels_4 *= input4_dims[i];
    }
    output_stride = channels_1 + channels_2 + channels_3 + channels_4;
  }

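  // Returns a shape with random rank (up to XNN_MAX_TENSOR_DIMS) and random
  // extents between 1 and 9.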
  std::vector<size_t> RandomShape()
  {
    std::vector<size_t> dims(shape_dist(rng));
    std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); });
    return dims;
  }

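  // Returns a copy of base_dims with the extent along `axis` re-randomized,
  // i.e. a shape that is concatenable with base_dims along that axis.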
  std::vector<size_t> RandomShape(const std::vector<size_t> base_dims, size_t axis)
  {
    auto dims = base_dims;
    dims[axis] = dim_dist(rng);
    return dims;
  }

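  // Picks a valid concatenation axis for the given shape.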
  size_t RandomAxis(const std::vector<size_t>& dims)
  {
    return std::uniform_int_distribution<size_t>(0, dims.size() - 1)(rng);
  }

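  // Total number of elements in a shape, i.e. the product of its extents.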
  size_t NumElements(const std::vector<size_t>& dims)
  {
    return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>());
  }

  std::unique_ptr<std::random_device> random_device;
  std::mt19937 rng;
  std::uniform_int_distribution<size_t> shape_dist;
  std::uniform_int_distribution<size_t> dim_dist;
  std::uniform_real_distribution<float> f32dist;
  std::uniform_int_distribution<int32_t> i8dist;
  std::uniform_int_distribution<int32_t> u8dist;
  std::uniform_real_distribution<float> scale_dist;

  uint32_t input1_id;
  uint32_t input2_id;
  uint32_t input3_id;
  uint32_t input4_id;
  uint32_t output_id;

  std::vector<size_t> input1_dims;
  std::vector<size_t> input2_dims;
  std::vector<size_t> input3_dims;
  std::vector<size_t> input4_dims;
  std::vector<size_t> output_dims;

  size_t axis;
  size_t batch_size;
  size_t channels_1;
  size_t channels_2;
  size_t channels_3;
  size_t channels_4;
  size_t output_stride;

  int32_t signed_zero_point;
  int32_t unsigned_zero_point;
  float scale;

  std::vector<T> input1;
  std::vector<T> input2;
  std::vector<T> input3;
  std::vector<T> input4;
  std::vector<T> operator_output;
  std::vector<T> subgraph_output;
};

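// Instantiate the fixture for signed quantized, unsigned quantized, and fp32
// element types.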
using Concatenate4TestQS8 = Concatenate4Test<int8_t>;
using Concatenate4TestQU8 = Concatenate4Test<uint8_t>;
using Concatenate4TestF32 = Concatenate4Test<float>;

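// The define tests build a subgraph with a single Concatenate4 node and check
// that the recorded node carries the expected type, compute type, axis,
// inputs, outputs, and flags.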
TEST_F(Concatenate4TestQS8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qs8);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 4);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->inputs[2], input3_id);
  ASSERT_EQ(node->inputs[3], input4_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(Concatenate4TestQU8, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qu8);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 4);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->inputs[2], input3_id);
  ASSERT_EQ(node->inputs[3], input4_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

TEST_F(Concatenate4TestF32, define)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input3_dims.size(), input3_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input4_dims.size(), input4_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 4,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_concatenate4);
  ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
  ASSERT_EQ(node->params.concatenate.axis, axis);
  ASSERT_EQ(node->num_inputs, 4);
  ASSERT_EQ(node->inputs[0], input1_id);
  ASSERT_EQ(node->inputs[1], input2_id);
  ASSERT_EQ(node->inputs[2], input3_id);
  ASSERT_EQ(node->inputs[3], input4_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

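// The matches_operator_api tests compute a reference result with four Copy
// operators writing into a single strided output buffer, then check that a
// Concatenate4 subgraph produces bit-identical output.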
TEST_F(Concatenate4TestQS8, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
  std::generate(input3.begin(), input3.end(), [&]() { return i8dist(rng); });
  std::generate(input4.begin(), input4.end(), [&]() { return i8dist(rng); });
  std::fill(operator_output.begin(), operator_output.end(), INT8_C(0xA5));
  std::fill(subgraph_output.begin(), subgraph_output.end(), INT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_4, channels_4, output_stride, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

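  // Each copy writes its channel count per batch row at stride output_stride;
  // each successive input starts at the channel offset where the previous one
  // ended.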
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op3, batch_size, input3.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels,
      nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op4, batch_size, input4.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels + op3->channels,
      nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_qint8, signed_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{input3_id, input3.data()}, xnn_external_value{input4_id, input4.data()},
    xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}

TEST_F(Concatenate4TestQU8, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
  std::generate(input3.begin(), input3.end(), [&]() { return u8dist(rng); });
  std::generate(input4.begin(), input4.end(), [&]() { return u8dist(rng); });
  std::fill(operator_output.begin(), operator_output.end(), UINT8_C(0xA5));
  std::fill(subgraph_output.begin(), subgraph_output.end(), UINT8_C(0xA5));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x8(channels_4, channels_4, output_stride, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op2, batch_size, input2.data(), (uint8_t*) operator_output.data() + op1->channels, nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op3, batch_size, input3.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels,
      nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x8(
      op4, batch_size, input4.data(), (uint8_t*) operator_output.data() + op1->channels + op2->channels + op3->channels,
      nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input1_dims.size(), input1_dims.data(), nullptr, 0,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input2_dims.size(), input2_dims.data(), nullptr, 1,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input3_dims.size(), input3_dims.data(), nullptr, 2,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, input4_dims.size(), input4_dims.data(), nullptr, 3,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success,
    xnn_define_quantized_tensor_value(
      subgraph, xnn_datatype_quint8, unsigned_zero_point, scale, output_dims.size(), output_dims.data(), nullptr, 4,
      /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{input3_id, input3.data()}, xnn_external_value{input4_id, input4.data()},
    xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}

TEST_F(Concatenate4TestF32, matches_operator_api)
{
  std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
  std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
  std::generate(input3.begin(), input3.end(), [&]() { return f32dist(rng); });
  std::generate(input4.begin(), input4.end(), [&]() { return f32dist(rng); });
  std::fill(operator_output.begin(), operator_output.end(), std::nanf(""));
  std::fill(subgraph_output.begin(), subgraph_output.end(), std::nanf(""));

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_operator_t op1 = nullptr;
  xnn_operator_t op2 = nullptr;
  xnn_operator_t op3 = nullptr;
  xnn_operator_t op4 = nullptr;

  // Call operator API.
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_1, channels_1, output_stride, /*flags=*/0, &op1));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op1(op1, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_2, channels_2, output_stride, /*flags=*/0, &op2));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op2(op2, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_3, channels_3, output_stride, /*flags=*/0, &op3));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op3(op3, xnn_delete_operator);
  ASSERT_EQ(xnn_status_success, xnn_create_copy_nc_x32(channels_4, channels_4, output_stride, /*flags=*/0, &op4));
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op4(op4, xnn_delete_operator);

  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(op1, batch_size, input1.data(), operator_output.data(), nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op2, batch_size, input2.data(), (float*) operator_output.data() + op1->channels, nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op3, batch_size, input3.data(), (float*) operator_output.data() + op1->channels + op2->channels,
      nullptr /* thread pool */));
  ASSERT_EQ(
    xnn_status_success,
    xnn_setup_copy_nc_x32(
      op4, batch_size, input4.data(), (float*) operator_output.data() + op1->channels + op2->channels + op3->channels,
      nullptr /* thread pool */));

  ASSERT_EQ(xnn_status_success, xnn_run_operator(op1, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op2, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op3, nullptr /* thread pool */));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(op4, nullptr /* thread pool */));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/5, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  input1_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input1_dims.size(), input1_dims.data(), nullptr, 0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id));
  ASSERT_NE(input1_id, XNN_INVALID_NODE_ID);

  input2_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input2_dims.size(), input2_dims.data(), nullptr, 1,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id));
  ASSERT_NE(input2_id, XNN_INVALID_NODE_ID);

  input3_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input3_dims.size(), input3_dims.data(), nullptr, 2,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input3_id));
  ASSERT_NE(input3_id, XNN_INVALID_NODE_ID);

  input4_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input4_dims.size(), input4_dims.data(), nullptr, 3,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input4_id));
  ASSERT_NE(input4_id, XNN_INVALID_NODE_ID);

  output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 4,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 5> external = {
    xnn_external_value{input1_id, input1.data()}, xnn_external_value{input2_id, input2.data()},
    xnn_external_value{input3_id, input3.data()}, xnn_external_value{input4_id, input4.data()},
    xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Check outputs match.
  ASSERT_EQ(subgraph_output, operator_output);
}