// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>

#include <xnnpack.h>

#include <gtest/gtest.h>

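// Computes the linear index into the input tensor that corresponds to the
// linear output position `pos`, given row-major strides for both tensors and
// the permutation mapping output dimensions to input dimensions.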
inline size_t reference_index(
    const size_t* input_stride,
    const size_t* output_stride,
    const size_t* perm,
    const size_t num_dims,
    size_t pos)
{
  size_t in_pos = 0;
  for (size_t j = 0; j < num_dims; ++j) {
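    // Peel the index along output dimension j off the linear position
    // (mixed-radix decomposition), then step by the matching input stride.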
    const size_t idx = pos / output_stride[j];
    pos = pos % output_stride[j];
    in_pos += idx * input_stride[perm[j]];
  }
  return in_pos;
}

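// Helper for testing XNNPACK N-dimensional transpose operators: configure a
// shape and a permutation through the fluent setters, then call one of the
// Test* methods, e.g.
//
//   TransposeOperatorTester()
//     .num_dims(3)
//     .shape({2, 3, 4})
//     .perm({1, 0, 2})
//     .TestX8();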
class TransposeOperatorTester {
 public:
  inline TransposeOperatorTester& num_dims(size_t num_dims) {
    assert(num_dims != 0);
    this->num_dims_ = num_dims;
    return *this;
  }

  inline size_t num_dims() const { return this->num_dims_; }

  inline TransposeOperatorTester& shape(std::vector<size_t> shape) {
    assert(shape.size() <= XNN_MAX_TENSOR_DIMS);
    this->shape_ = shape;
    return *this;
  }

  inline const std::vector<size_t>& dims() const { return this->shape_; }

  inline TransposeOperatorTester& perm(std::vector<size_t> perm) {
    assert(perm.size() <= XNN_MAX_TENSOR_DIMS);
    this->perm_ = perm;
    return *this;
  }

  inline const std::vector<size_t>& perm() const { return this->perm_; }

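  // Creates an 8-bit transpose operator, runs it once, and verifies every
  // output element against the reference index mapping.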
  void TestX8() const {
    const size_t count = std::accumulate(dims().cbegin(), dims().cend(), size_t(1), std::multiplies<size_t>());
    std::vector<uint8_t> input(count + XNN_EXTRA_BYTES / sizeof(uint8_t));
    std::vector<uint8_t> output(count);
    std::vector<size_t> input_stride(input.size(), 1);
    std::vector<size_t> output_stride(input.size(), 1);
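    // Compute row-major strides; the output shape is the input shape
    // rearranged by perm(), so output strides use the permuted dimensions.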
    for (size_t i = num_dims() - 1; i > 0; --i) {
      input_stride[i - 1] = input_stride[i] * shape_[i];
      output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
    }
    ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
    xnn_operator_t transpose_op = nullptr;
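    // Fill the input with a ramp and prime the output with a canary value so
    // unwritten elements are detectable.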
    std::iota(input.begin(), input.end(), 0);
    std::fill(output.begin(), output.end(), UINT8_C(0xA5));

    ASSERT_EQ(xnn_status_success,
              xnn_create_transpose_nd_x8(0 /* flags */, &transpose_op));
    ASSERT_NE(nullptr, transpose_op);

    // Smart pointer to automatically delete transpose op.
    std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_transpose_op(transpose_op, xnn_delete_operator);

    ASSERT_EQ(xnn_status_success,
              xnn_setup_transpose_nd_x8(
                  transpose_op,
                  input.data(), output.data(),
                  num_dims(), shape_.data(), perm_.data(),
                  nullptr /* thread pool */));

    // Run operator.
    ASSERT_EQ(xnn_status_success,
              xnn_run_operator(transpose_op, nullptr /* thread pool */));

    // Verify results.
    for (size_t i = 0; i < count; ++i) {
      const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
      ASSERT_EQ(input[in_idx], output[i]);
    }
  }

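  // Exercises the single-shot (eager) x8 transpose API, which performs setup
  // and execution in one call, without a persistent operator object.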
  void TestRunX8() const {
    const size_t count = std::accumulate(dims().cbegin(), dims().cend(), size_t(1), std::multiplies<size_t>());
    std::vector<uint8_t> input(count + XNN_EXTRA_BYTES / sizeof(uint8_t));
    std::vector<uint8_t> output(count);
    std::vector<size_t> input_stride(input.size(), 1);
    std::vector<size_t> output_stride(input.size(), 1);
    for (size_t i = num_dims() - 1; i > 0; --i) {
      input_stride[i - 1] = input_stride[i] * shape_[i];
      output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
    }
    ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
    std::iota(input.begin(), input.end(), 0);
    std::fill(output.begin(), output.end(), UINT8_C(0xA5));

    // Call the transpose eager API.
    ASSERT_EQ(xnn_status_success,
              xnn_run_transpose_nd_x8(
                  0 /* flags */,
                  input.data(), output.data(),
                  num_dims(), shape_.data(), perm_.data(),
                  nullptr /* thread pool */));

    // Verify results.
    for (size_t i = 0; i < count; ++i) {
      const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
      ASSERT_EQ(input[in_idx], output[i]);
    }
  }

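  // Same flow as TestX8, but for 16-bit elements.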
  void TestX16() const {
    const size_t count = std::accumulate(dims().cbegin(), dims().cend(), size_t(1), std::multiplies<size_t>());
    std::vector<uint16_t> input(count + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<uint16_t> output(count);
    std::vector<size_t> input_stride(input.size(), 1);
    std::vector<size_t> output_stride(input.size(), 1);
    for (size_t i = num_dims() - 1; i > 0; --i) {
      input_stride[i - 1] = input_stride[i] * shape_[i];
      output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
    }
    ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
    xnn_operator_t transpose_op = nullptr;
    std::iota(input.begin(), input.end(), 0);
    std::fill(output.begin(), output.end(), UINT16_C(0xDEAD));

    ASSERT_EQ(xnn_status_success,
              xnn_create_transpose_nd_x16(0 /* flags */, &transpose_op));
    ASSERT_NE(nullptr, transpose_op);

    // Smart pointer to automatically delete transpose op.
    std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_transpose_op(transpose_op, xnn_delete_operator);

    ASSERT_EQ(xnn_status_success,
              xnn_setup_transpose_nd_x16(
                  transpose_op,
                  input.data(), output.data(),
                  num_dims(), shape_.data(), perm_.data(),
                  nullptr /* thread pool */));

    // Run operator.
    ASSERT_EQ(xnn_status_success,
              xnn_run_operator(transpose_op, nullptr /* thread pool */));

    // Verify results.
    for (size_t i = 0; i < count; ++i) {
      const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
      ASSERT_EQ(input[in_idx], output[i]);
    }
  }

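  // Same flow as TestRunX8, but for 16-bit elements.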
  void TestRunX16() const {
    const size_t count = std::accumulate(dims().cbegin(), dims().cend(), size_t(1), std::multiplies<size_t>());
    std::vector<uint16_t> input(count + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<uint16_t> output(count);
    std::vector<size_t> input_stride(input.size(), 1);
    std::vector<size_t> output_stride(input.size(), 1);
    for (size_t i = num_dims() - 1; i > 0; --i) {
      input_stride[i - 1] = input_stride[i] * shape_[i];
      output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
    }
    ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
    std::iota(input.begin(), input.end(), 0);
    std::fill(output.begin(), output.end(), UINT16_C(0xDEAD));

    // Call the transpose eager API.
    ASSERT_EQ(xnn_status_success,
              xnn_run_transpose_nd_x16(
                  0 /* flags */,
                  input.data(), output.data(),
                  num_dims(), shape_.data(), perm_.data(),
                  nullptr /* thread pool */));

    // Verify results.
    for (size_t i = 0; i < count; ++i) {
      const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
      ASSERT_EQ(input[in_idx], output[i]);
    }
  }

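  // Same flow as TestX8, but for 32-bit elements.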
  void TestX32() const {
    const size_t count = std::accumulate(dims().cbegin(), dims().cend(), size_t(1), std::multiplies<size_t>());
    std::vector<uint32_t> input(count + XNN_EXTRA_BYTES / sizeof(uint32_t));
    std::vector<uint32_t> output(count);
    std::vector<size_t> input_stride(input.size(), 1);
    std::vector<size_t> output_stride(input.size(), 1);
    for (size_t i = num_dims() - 1; i > 0; --i) {
      input_stride[i - 1] = input_stride[i] * shape_[i];
      output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
    }
    ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
    xnn_operator_t transpose_op = nullptr;
    std::iota(input.begin(), input.end(), 0);
    std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF));

    ASSERT_EQ(xnn_status_success,
              xnn_create_transpose_nd_x32(0 /* flags */, &transpose_op));
    ASSERT_NE(nullptr, transpose_op);

    // Smart pointer to automatically delete transpose op.
    std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_transpose_op(transpose_op, xnn_delete_operator);

    ASSERT_EQ(xnn_status_success,
              xnn_setup_transpose_nd_x32(
                  transpose_op,
                  input.data(), output.data(),
                  num_dims(), shape_.data(), perm_.data(),
                  nullptr /* thread pool */));

    // Run operator.
    ASSERT_EQ(xnn_status_success,
              xnn_run_operator(transpose_op, nullptr /* thread pool */));

    // Verify results.
    for (size_t i = 0; i < count; ++i) {
      const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
      ASSERT_EQ(input[in_idx], output[i]);
    }
  }

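  // Same flow as TestRunX8, but for 32-bit elements.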
  void TestRunX32() const {
    const size_t count = std::accumulate(dims().cbegin(), dims().cend(), size_t(1), std::multiplies<size_t>());
    std::vector<uint32_t> input(count + XNN_EXTRA_BYTES / sizeof(uint32_t));
    std::vector<uint32_t> output(count);
    std::vector<size_t> input_stride(input.size(), 1);
    std::vector<size_t> output_stride(input.size(), 1);
    for (size_t i = num_dims() - 1; i > 0; --i) {
      input_stride[i - 1] = input_stride[i] * shape_[i];
      output_stride[i - 1] = output_stride[i] * shape_[perm()[i]];
    }
    ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
    std::iota(input.begin(), input.end(), 0);
    std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF));

    // Call the transpose eager API.
    ASSERT_EQ(xnn_status_success,
              xnn_run_transpose_nd_x32(
                  0 /* flags */,
                  input.data(), output.data(),
                  num_dims(), shape_.data(), perm_.data(),
                  nullptr /* thread pool */));

    // Verify results.
    for (size_t i = 0; i < count; ++i) {
      const size_t in_idx = reference_index(input_stride.data(), output_stride.data(), perm_.data(), num_dims(), i);
      ASSERT_EQ(input[in_idx], output[i]);
    }
  }

 private:
  size_t num_dims_ = 1;
  std::vector<size_t> shape_;
  std::vector<size_t> perm_;
};