// xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/CL/Permute.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/runtime/CL/CLTensor.h"
26 #include "arm_compute/runtime/CL/CLTensorAllocator.h"
27 #include "arm_compute/runtime/CL/functions/CLPermute.h"
28 #include "tests/CL/CLAccessor.h"
29 #include "tests/PaddingCalculator.h"
30 #include "tests/datasets/ShapeDatasets.h"
31 #include "tests/framework/Asserts.h"
32 #include "tests/framework/Macros.h"
33 #include "tests/framework/datasets/Datasets.h"
34 #include "tests/validation/Validation.h"
35 #include "tests/validation/fixtures/PermuteFixture.h"
36 
37 namespace arm_compute
38 {
39 namespace test
40 {
41 namespace validation
42 {
43 namespace
44 {
45 const auto PermuteVectors3 = framework::dataset::make("PermutationVector",
46 {
47     PermutationVector(2U, 0U, 1U),
48     PermutationVector(1U, 2U, 0U),
49     PermutationVector(0U, 1U, 2U),
50     PermutationVector(0U, 2U, 1U),
51     PermutationVector(1U, 0U, 2U),
52     PermutationVector(2U, 1U, 0U),
53 });
54 const auto PermuteVectors4 = framework::dataset::make("PermutationVector",
55 {
56     PermutationVector(3U, 2U, 0U, 1U),
57     PermutationVector(3U, 2U, 1U, 0U),
58     PermutationVector(2U, 3U, 1U, 0U),
59     PermutationVector(1U, 3U, 2U, 0U),
60     PermutationVector(3U, 1U, 2U, 0U),
61     PermutationVector(3U, 0U, 2U, 1U),
62     PermutationVector(0U, 3U, 2U, 1U)
63 });
64 const auto PermuteVectors         = concat(PermuteVectors3, PermuteVectors4);
65 const auto PermuteParametersSmall = concat(concat(datasets::Small2DShapes(), datasets::Small3DShapes()), datasets::Small4DShapes()) * PermuteVectors;
66 const auto PermuteParametersLarge = datasets::Large4DShapes() * PermuteVectors;
67 } // namespace
68 TEST_SUITE(CL)
TEST_SUITE(Permute)69 TEST_SUITE(Permute)
70 
71 // *INDENT-OFF*
72 // clang-format off
73 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
74         framework::dataset::make("InputInfo",{
75                 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),     // valid
76                 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),     // permutation not supported
77                 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),     // permutation not supported
78                 TensorInfo(TensorShape(1U, 7U), 1, DataType::U8),              // invalid input size
79                 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),     // valid
80                 TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32),  // valid
81                 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),     // permutation not supported
82                 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::S16),     // valid
83                 TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32),  // permutation not supported
84                 TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32),  // valid
85                 TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32)   // permutation not supported
86 
87         }),
88         framework::dataset::make("OutputInfo", {
89                 TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
90                 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),
91                 TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),
92                 TensorInfo(TensorShape(5U, 7U), 1, DataType::U8),
93                 TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
94                 TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),
95                 TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
96                 TensorInfo(TensorShape(3U, 5U, 7U, 7U), 1, DataType::S16),
97                 TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),
98                 TensorInfo(TensorShape(37U, 2U, 13U, 27U), 1, DataType::F32),
99                 TensorInfo(TensorShape(37U, 2U, 13U, 27U), 1, DataType::F32)
100 
101         })),
102         framework::dataset::make("PermutationVector", {
103                 PermutationVector(2U, 1U, 0U),
104                 PermutationVector(2U, 2U, 1U),
105                 PermutationVector(1U, 1U, 1U),
106                 PermutationVector(2U, 0U, 1U),
107                 PermutationVector(2U, 0U, 1U),
108                 PermutationVector(1U, 2U, 0U),
109                 PermutationVector(3U, 2U, 0U, 1U),
110                 PermutationVector(3U, 2U, 0U, 1U),
111                 PermutationVector(2U, 3U, 1U, 0U),
112                 PermutationVector(2U, 3U, 1U, 0U),
113                 PermutationVector(0U, 0U, 0U, 1000U)
114         })),
115         framework::dataset::make("Expected", { true, false, false, false, true, true, false, true, false, true, false })),
116         input_info, output_info, perm_vect, expected)
117 {
118     ARM_COMPUTE_EXPECT(bool(CLPermute::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), perm_vect)) == expected, framework::LogLevel::ERRORS);
119 }
120 // clang-format on
121 // *INDENT-ON*
122 
123 #ifndef DOXYGEN_SKIP_THIS
124 
125 template <typename T>
126 using CLPermuteFixture = PermuteValidationFixture<CLTensor, CLAccessor, CLPermute, T>;
127 
128 TEST_SUITE(U8)
129 FIXTURE_DATA_TEST_CASE(RunSmall, CLPermuteFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
130                        PermuteParametersSmall * framework::dataset::make("DataType", DataType::U8))
131 {
132     // Validate output
133     validate(CLAccessor(_target), _reference);
134 }
135 
136 FIXTURE_DATA_TEST_CASE(RunLarge, CLPermuteFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
137                        PermuteParametersLarge * framework::dataset::make("DataType", DataType::U8))
138 {
139     // Validate output
140     validate(CLAccessor(_target), _reference);
141 }
142 TEST_SUITE_END() // U8
143 
TEST_SUITE(U16)144 TEST_SUITE(U16)
145 FIXTURE_DATA_TEST_CASE(RunSmall, CLPermuteFixture<uint16_t>, framework::DatasetMode::PRECOMMIT,
146                        PermuteParametersSmall * framework::dataset::make("DataType", DataType::U16))
147 {
148     // Validate output
149     validate(CLAccessor(_target), _reference);
150 }
151 FIXTURE_DATA_TEST_CASE(RunLarge, CLPermuteFixture<uint16_t>, framework::DatasetMode::NIGHTLY,
152                        PermuteParametersLarge * framework::dataset::make("DataType", DataType::U16))
153 {
154     // Validate output
155     validate(CLAccessor(_target), _reference);
156 }
157 TEST_SUITE_END() // U16
158 
TEST_SUITE(U32)159 TEST_SUITE(U32)
160 FIXTURE_DATA_TEST_CASE(RunSmall, CLPermuteFixture<uint32_t>, framework::DatasetMode::PRECOMMIT,
161                        PermuteParametersSmall * framework::dataset::make("DataType", DataType::U32))
162 {
163     // Validate output
164     validate(CLAccessor(_target), _reference);
165 }
166 FIXTURE_DATA_TEST_CASE(RunLarge, CLPermuteFixture<uint32_t>, framework::DatasetMode::NIGHTLY,
167                        PermuteParametersLarge * framework::dataset::make("DataType", DataType::U32))
168 {
169     // Validate output
170     validate(CLAccessor(_target), _reference);
171 }
172 TEST_SUITE_END() // U32
173 
174 #endif /* DOXYGEN_SKIP_THIS */
175 
176 TEST_SUITE_END() // Permute
177 TEST_SUITE_END() // CL
178 } // namespace validation
179 } // namespace test
180 } // namespace arm_compute
181