xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/tasks/concat_test_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/concat_test_util.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "tensorflow/lite/delegates/gpu/common/operations.h"
22 #include "tensorflow/lite/delegates/gpu/common/status.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
24 #include "tensorflow/lite/delegates/gpu/common/tasks/concat_xy.h"
25 #include "tensorflow/lite/delegates/gpu/common/tasks/concat_z.h"
26 
27 namespace tflite {
28 namespace gpu {
29 
ConcatWidthTest(TestExecutionEnvironment * env)30 absl::Status ConcatWidthTest(TestExecutionEnvironment* env) {
31   TensorFloat32 src0, src1;
32   src0.shape = BHWC(1, 2, 1, 2);
33   src0.data = {half(0.0f), half(-1.0f), half(-0.05f), half(0.045f)};
34   src1.shape = BHWC(1, 2, 2, 2);
35   src1.data = {half(1.0f), half(-1.2f), half(-0.45f), half(1.045f),
36                half(1.1f), half(-1.3f), half(-0.55f), half(2.045f)};
37 
38   ConcatAttributes attr;
39   attr.axis = Axis::WIDTH;
40 
41   for (auto precision : env->GetSupportedPrecisions()) {
42     auto data_type = DeduceDataTypeFromPrecision(precision);
43     for (auto storage : env->GetSupportedStorages(data_type)) {
44       OperationDef op_def;
45       op_def.precision = precision;
46       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
47       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
48       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
49       TensorFloat32 dst_tensor;
50       GPUOperation operation = CreateConcatXY(op_def, attr);
51       RETURN_IF_ERROR(env->ExecuteGPUOperation(
52           {src0, src1}, std::make_unique<GPUOperation>(std::move(operation)),
53           BHWC(1, 2, 3, 2), &dst_tensor));
54       RETURN_IF_ERROR(
55           PointWiseNear({half(0.0f), half(-1.0f), half(1.0f), half(-1.2f),
56                          half(-0.45f), half(1.045f), half(-0.05f), half(0.045f),
57                          half(1.1f), half(-1.3f), half(-0.55f), half(2.045f)},
58                         dst_tensor.data, 0.0f));
59     }
60   }
61   return absl::OkStatus();
62 }
63 
ConcatHeightTest(TestExecutionEnvironment * env)64 absl::Status ConcatHeightTest(TestExecutionEnvironment* env) {
65   TensorFloat32 src0, src1;
66   src0.shape = BHWC(1, 2, 1, 2);
67   src0.data = {half(0.0f), half(-1.0f), half(-0.05f), half(0.045f)};
68   src1.shape = BHWC(1, 1, 1, 2);
69   src1.data = {half(1.0f), half(-1.2f)};
70 
71   ConcatAttributes attr;
72   attr.axis = Axis::HEIGHT;
73 
74   for (auto precision : env->GetSupportedPrecisions()) {
75     auto data_type = DeduceDataTypeFromPrecision(precision);
76     for (auto storage : env->GetSupportedStorages(data_type)) {
77       OperationDef op_def;
78       op_def.precision = precision;
79       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
80       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
81       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
82       TensorFloat32 dst_tensor;
83       GPUOperation operation = CreateConcatXY(op_def, attr);
84       RETURN_IF_ERROR(env->ExecuteGPUOperation(
85           {src0, src1}, std::make_unique<GPUOperation>(std::move(operation)),
86           BHWC(1, 3, 1, 2), &dst_tensor));
87       RETURN_IF_ERROR(PointWiseNear({half(0.0f), half(-1.0f), half(-0.05f),
88                                      half(0.045f), half(1.0f), half(-1.2f)},
89                                     dst_tensor.data, 0.0f));
90     }
91   }
92   return absl::OkStatus();
93 }
94 
ConcatChannelsTest(TestExecutionEnvironment * env)95 absl::Status ConcatChannelsTest(TestExecutionEnvironment* env) {
96   TensorFloat32 src0, src1, src2;
97   src0.shape = BHWC(1, 2, 1, 1);
98   src0.data = {half(0.0f), half(-1.0f)};
99   src1.shape = BHWC(1, 2, 1, 2);
100   src1.data = {half(1.0f), half(2.0f), half(3.0f), half(4.0f)};
101   src2.shape = BHWC(1, 2, 1, 3);
102   src2.data = {half(5.0f), half(6.0f), half(7.0f),
103                half(8.0f), half(9.0),  half(10.0f)};
104 
105   ConcatAttributes attr;
106   attr.axis = Axis::CHANNELS;
107 
108   for (auto precision : env->GetSupportedPrecisions()) {
109     auto data_type = DeduceDataTypeFromPrecision(precision);
110     for (auto storage : env->GetSupportedStorages(data_type)) {
111       OperationDef op_def;
112       op_def.precision = precision;
113       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
114       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
115       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
116       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
117       TensorFloat32 dst_tensor;
118       GPUOperation operation =
119           CreateConcatZ(op_def, {1, 2, 3}, env->GetGpuInfo());
120       RETURN_IF_ERROR(env->ExecuteGPUOperation(
121           {src0, src1, src2},
122           std::make_unique<GPUOperation>(std::move(operation)),
123           BHWC(1, 2, 1, 6), &dst_tensor));
124       RETURN_IF_ERROR(
125           PointWiseNear({half(0.0f), half(1.0f), half(2.0f), half(5.0f),
126                          half(6.0f), half(7.0f), half(-1.0f), half(3.0f),
127                          half(4.0f), half(8.0f), half(9.0), half(10.0f)},
128                         dst_tensor.data, 0.0f));
129     }
130   }
131   return absl::OkStatus();
132 }
133 
ConcatChannelsAlignedx4Test(TestExecutionEnvironment * env)134 absl::Status ConcatChannelsAlignedx4Test(TestExecutionEnvironment* env) {
135   TensorFloat32 src0, src1;
136   src0.shape = BHWC(1, 2, 1, 4);
137   src0.data = {half(-1.0f), half(-2.0f), half(-3.0f), half(-4.0f),
138                half(1.0f),  half(2.0f),  half(3.0f),  half(4.0f)};
139   src1.shape = BHWC(1, 2, 1, 4);
140   src1.data = {half(5.0f),  half(6.0f),  half(7.0f),  half(8.0f),
141                half(-5.0f), half(-6.0f), half(-7.0f), half(-8.0f)};
142 
143   ConcatAttributes attr;
144   attr.axis = Axis::CHANNELS;
145 
146   for (auto precision : env->GetSupportedPrecisions()) {
147     auto data_type = DeduceDataTypeFromPrecision(precision);
148     for (auto storage : env->GetSupportedStorages(data_type)) {
149       OperationDef op_def;
150       op_def.precision = precision;
151       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
152       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
153       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
154       TensorFloat32 dst_tensor;
155       GPUOperation operation = CreateConcatZ(op_def, {4, 4}, env->GetGpuInfo());
156       RETURN_IF_ERROR(env->ExecuteGPUOperation(
157           {src0, src1}, std::make_unique<GPUOperation>(std::move(operation)),
158           BHWC(1, 2, 1, 8), &dst_tensor));
159       RETURN_IF_ERROR(
160           PointWiseNear({half(-1.0f), half(-2.0f), half(-3.0f), half(-4.0f),
161                          half(5.0f), half(6.0f), half(7.0f), half(8.0f),
162                          half(1.0f), half(2.0f), half(3.0f), half(4.0f),
163                          half(-5.0f), half(-6.0f), half(-7.0f), half(-8.0f)},
164                         dst_tensor.data, 0.0f));
165     }
166   }
167   return absl::OkStatus();
168 }
169 
170 }  // namespace gpu
171 }  // namespace tflite
172