xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/tasks/tile_test_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/tile_test_util.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "tensorflow/lite/delegates/gpu/common/operations.h"
22 #include "tensorflow/lite/delegates/gpu/common/status.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
24 #include "tensorflow/lite/delegates/gpu/common/tasks/tile.h"
25 
26 namespace tflite {
27 namespace gpu {
28 
TileChannelsTest(TestExecutionEnvironment * env)29 absl::Status TileChannelsTest(TestExecutionEnvironment* env) {
30   TensorFloat32 src_tensor;
31   src_tensor.shape = BHWC(1, 2, 1, 3);
32   src_tensor.data = {half(1.0f), half(2.0f), half(3.0f),
33                      half(4.0f), half(5.0f), half(6.0f)};
34   for (auto precision : env->GetSupportedPrecisions()) {
35     auto data_type = DeduceDataTypeFromPrecision(precision);
36     for (auto storage : env->GetSupportedStorages(data_type)) {
37       OperationDef op_def;
38       op_def.precision = precision;
39       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
40       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
41       TensorFloat32 dst_tensor;
42       GPUOperation operation = CreateTile(op_def, src_tensor.shape.c);
43       RETURN_IF_ERROR(env->ExecuteGPUOperation(
44           src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
45           BHWC(1, 2, 1, 6), &dst_tensor));
46       RETURN_IF_ERROR(
47           PointWiseNear({half(1.0f), half(2.0f), half(3.0f), half(1.0f),
48                          half(2.0f), half(3.0f), half(4.0f), half(5.0f),
49                          half(6.0f), half(4.0f), half(5.0f), half(6.0f)},
50                         dst_tensor.data, 0.0f));
51     }
52   }
53   return absl::OkStatus();
54 }
55 
TileChannelsX4Test(TestExecutionEnvironment * env)56 absl::Status TileChannelsX4Test(TestExecutionEnvironment* env) {
57   TensorFloat32 src_tensor;
58   src_tensor.shape = BHWC(1, 2, 1, 4);
59   src_tensor.data = {half(1.0f), half(2.0f), half(3.0f), half(7.0f),
60                      half(4.0f), half(5.0f), half(6.0f), half(8.0f)};
61   for (auto precision : env->GetSupportedPrecisions()) {
62     auto data_type = DeduceDataTypeFromPrecision(precision);
63     for (auto storage : env->GetSupportedStorages(data_type)) {
64       OperationDef op_def;
65       op_def.precision = precision;
66       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
67       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
68       TensorFloat32 dst_tensor;
69       GPUOperation operation = CreateTile(op_def, src_tensor.shape.c);
70       RETURN_IF_ERROR(env->ExecuteGPUOperation(
71           src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
72           BHWC(1, 2, 1, 8), &dst_tensor));
73       RETURN_IF_ERROR(
74           PointWiseNear({half(1.0f), half(2.0f), half(3.0f), half(7.0f),
75                          half(1.0f), half(2.0f), half(3.0f), half(7.0f),
76                          half(4.0f), half(5.0f), half(6.0f), half(8.0f),
77                          half(4.0f), half(5.0f), half(6.0f), half(8.0f)},
78                         dst_tensor.data, 0.0f));
79     }
80   }
81   return absl::OkStatus();
82 }
83 
TileWidthTest(TestExecutionEnvironment * env)84 absl::Status TileWidthTest(TestExecutionEnvironment* env) {
85   TensorFloat32 src_tensor;
86   src_tensor.shape = BHWC(1, 1, 2, 3);
87   src_tensor.data = {half(1.0f), half(2.0f), half(3.0f),
88                      half(4.0f), half(5.0f), half(6.0f)};
89   for (auto precision : env->GetSupportedPrecisions()) {
90     auto data_type = DeduceDataTypeFromPrecision(precision);
91     for (auto storage : env->GetSupportedStorages(data_type)) {
92       OperationDef op_def;
93       op_def.precision = precision;
94       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
95       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
96       TensorFloat32 dst_tensor;
97       GPUOperation operation = CreateTile(op_def, src_tensor.shape.c);
98       RETURN_IF_ERROR(env->ExecuteGPUOperation(
99           src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
100           BHWC(1, 1, 4, 3), &dst_tensor));
101       RETURN_IF_ERROR(
102           PointWiseNear({half(1.0f), half(2.0f), half(3.0f), half(4.0f),
103                          half(5.0f), half(6.0f), half(1.0f), half(2.0f),
104                          half(3.0f), half(4.0f), half(5.0f), half(6.0f)},
105                         dst_tensor.data, 0.0f));
106     }
107   }
108   return absl::OkStatus();
109 }
110 
TileHeightTest(TestExecutionEnvironment * env)111 absl::Status TileHeightTest(TestExecutionEnvironment* env) {
112   TensorFloat32 src_tensor;
113   src_tensor.shape = BHWC(1, 2, 1, 3);
114   src_tensor.data = {half(1.0f), half(2.0f), half(3.0f),
115                      half(4.0f), half(5.0f), half(6.0f)};
116   for (auto precision : env->GetSupportedPrecisions()) {
117     auto data_type = DeduceDataTypeFromPrecision(precision);
118     for (auto storage : env->GetSupportedStorages(data_type)) {
119       OperationDef op_def;
120       op_def.precision = precision;
121       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
122       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
123       TensorFloat32 dst_tensor;
124       GPUOperation operation = CreateTile(op_def, src_tensor.shape.c);
125       RETURN_IF_ERROR(env->ExecuteGPUOperation(
126           src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
127           BHWC(1, 4, 1, 3), &dst_tensor));
128       RETURN_IF_ERROR(
129           PointWiseNear({half(1.0f), half(2.0f), half(3.0f), half(4.0f),
130                          half(5.0f), half(6.0f), half(1.0f), half(2.0f),
131                          half(3.0f), half(4.0f), half(5.0f), half(6.0f)},
132                         dst_tensor.data, 0.0f));
133     }
134   }
135   return absl::OkStatus();
136 }
137 
TileHWCTest(TestExecutionEnvironment * env)138 absl::Status TileHWCTest(TestExecutionEnvironment* env) {
139   TensorFloat32 src_tensor;
140   src_tensor.shape = BHWC(1, 2, 2, 3);
141   src_tensor.data = {half(1.0f), half(2.0f),  half(3.0f),  half(4.0f),
142                      half(5.0f), half(6.0f),  half(7.0f),  half(8.0f),
143                      half(9.0f), half(10.0f), half(11.0f), half(12.0f)};
144   for (auto precision : env->GetSupportedPrecisions()) {
145     auto data_type = DeduceDataTypeFromPrecision(precision);
146     for (auto storage : env->GetSupportedStorages(data_type)) {
147       OperationDef op_def;
148       op_def.precision = precision;
149       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
150       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
151       TensorFloat32 dst_tensor;
152       GPUOperation operation = CreateTile(op_def, src_tensor.shape.c);
153       RETURN_IF_ERROR(env->ExecuteGPUOperation(
154           src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
155           BHWC(1, 4, 4, 6), &dst_tensor));
156       RETURN_IF_ERROR(PointWiseNear(
157           {half(1.0f),  half(2.0f),  half(3.0f),  half(1.0f),  half(2.0f),
158            half(3.0f),  half(4.0f),  half(5.0f),  half(6.0f),  half(4.0f),
159            half(5.0f),  half(6.0f),  half(1.0f),  half(2.0f),  half(3.0f),
160            half(1.0f),  half(2.0f),  half(3.0f),  half(4.0f),  half(5.0f),
161            half(6.0f),  half(4.0f),  half(5.0f),  half(6.0f),  half(7.0f),
162            half(8.0f),  half(9.0f),  half(7.0f),  half(8.0f),  half(9.0f),
163            half(10.0f), half(11.0f), half(12.0f), half(10.0f), half(11.0f),
164            half(12.0f), half(7.0f),  half(8.0f),  half(9.0f),  half(7.0f),
165            half(8.0f),  half(9.0f),  half(10.0f), half(11.0f), half(12.0f),
166            half(10.0f), half(11.0f), half(12.0f), half(1.0f),  half(2.0f),
167            half(3.0f),  half(1.0f),  half(2.0f),  half(3.0f),  half(4.0f),
168            half(5.0f),  half(6.0f),  half(4.0f),  half(5.0f),  half(6.0f),
169            half(1.0f),  half(2.0f),  half(3.0f),  half(1.0f),  half(2.0f),
170            half(3.0f),  half(4.0f),  half(5.0f),  half(6.0f),  half(4.0f),
171            half(5.0f),  half(6.0f),  half(7.0f),  half(8.0f),  half(9.0f),
172            half(7.0f),  half(8.0f),  half(9.0f),  half(10.0f), half(11.0f),
173            half(12.0f), half(10.0f), half(11.0f), half(12.0f), half(7.0f),
174            half(8.0f),  half(9.0f),  half(7.0f),  half(8.0f),  half(9.0f),
175            half(10.0f), half(11.0f), half(12.0f), half(10.0f), half(11.0f),
176            half(12.0f)},
177           dst_tensor.data, 0.0f));
178     }
179   }
180   return absl::OkStatus();
181 }
182 
183 }  // namespace gpu
184 }  // namespace tflite
185