1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/tile_test_util.h"
17
18 #include <memory>
19 #include <vector>
20
21 #include "tensorflow/lite/delegates/gpu/common/operations.h"
22 #include "tensorflow/lite/delegates/gpu/common/status.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/testing_util.h"
24 #include "tensorflow/lite/delegates/gpu/common/tasks/tile.h"
25
26 namespace tflite {
27 namespace gpu {
28
TileChannelsTest(TestExecutionEnvironment * env)29 absl::Status TileChannelsTest(TestExecutionEnvironment* env) {
30 TensorFloat32 src_tensor;
31 src_tensor.shape = BHWC(1, 2, 1, 3);
32 src_tensor.data = {half(1.0f), half(2.0f), half(3.0f),
33 half(4.0f), half(5.0f), half(6.0f)};
34 for (auto precision : env->GetSupportedPrecisions()) {
35 auto data_type = DeduceDataTypeFromPrecision(precision);
36 for (auto storage : env->GetSupportedStorages(data_type)) {
37 OperationDef op_def;
38 op_def.precision = precision;
39 op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
40 op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
41 TensorFloat32 dst_tensor;
42 GPUOperation operation = CreateTile(op_def, src_tensor.shape.c);
43 RETURN_IF_ERROR(env->ExecuteGPUOperation(
44 src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
45 BHWC(1, 2, 1, 6), &dst_tensor));
46 RETURN_IF_ERROR(
47 PointWiseNear({half(1.0f), half(2.0f), half(3.0f), half(1.0f),
48 half(2.0f), half(3.0f), half(4.0f), half(5.0f),
49 half(6.0f), half(4.0f), half(5.0f), half(6.0f)},
50 dst_tensor.data, 0.0f));
51 }
52 }
53 return absl::OkStatus();
54 }
55
TileChannelsX4Test(TestExecutionEnvironment * env)56 absl::Status TileChannelsX4Test(TestExecutionEnvironment* env) {
57 TensorFloat32 src_tensor;
58 src_tensor.shape = BHWC(1, 2, 1, 4);
59 src_tensor.data = {half(1.0f), half(2.0f), half(3.0f), half(7.0f),
60 half(4.0f), half(5.0f), half(6.0f), half(8.0f)};
61 for (auto precision : env->GetSupportedPrecisions()) {
62 auto data_type = DeduceDataTypeFromPrecision(precision);
63 for (auto storage : env->GetSupportedStorages(data_type)) {
64 OperationDef op_def;
65 op_def.precision = precision;
66 op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
67 op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
68 TensorFloat32 dst_tensor;
69 GPUOperation operation = CreateTile(op_def, src_tensor.shape.c);
70 RETURN_IF_ERROR(env->ExecuteGPUOperation(
71 src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
72 BHWC(1, 2, 1, 8), &dst_tensor));
73 RETURN_IF_ERROR(
74 PointWiseNear({half(1.0f), half(2.0f), half(3.0f), half(7.0f),
75 half(1.0f), half(2.0f), half(3.0f), half(7.0f),
76 half(4.0f), half(5.0f), half(6.0f), half(8.0f),
77 half(4.0f), half(5.0f), half(6.0f), half(8.0f)},
78 dst_tensor.data, 0.0f));
79 }
80 }
81 return absl::OkStatus();
82 }
83
TileWidthTest(TestExecutionEnvironment * env)84 absl::Status TileWidthTest(TestExecutionEnvironment* env) {
85 TensorFloat32 src_tensor;
86 src_tensor.shape = BHWC(1, 1, 2, 3);
87 src_tensor.data = {half(1.0f), half(2.0f), half(3.0f),
88 half(4.0f), half(5.0f), half(6.0f)};
89 for (auto precision : env->GetSupportedPrecisions()) {
90 auto data_type = DeduceDataTypeFromPrecision(precision);
91 for (auto storage : env->GetSupportedStorages(data_type)) {
92 OperationDef op_def;
93 op_def.precision = precision;
94 op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
95 op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
96 TensorFloat32 dst_tensor;
97 GPUOperation operation = CreateTile(op_def, src_tensor.shape.c);
98 RETURN_IF_ERROR(env->ExecuteGPUOperation(
99 src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
100 BHWC(1, 1, 4, 3), &dst_tensor));
101 RETURN_IF_ERROR(
102 PointWiseNear({half(1.0f), half(2.0f), half(3.0f), half(4.0f),
103 half(5.0f), half(6.0f), half(1.0f), half(2.0f),
104 half(3.0f), half(4.0f), half(5.0f), half(6.0f)},
105 dst_tensor.data, 0.0f));
106 }
107 }
108 return absl::OkStatus();
109 }
110
TileHeightTest(TestExecutionEnvironment * env)111 absl::Status TileHeightTest(TestExecutionEnvironment* env) {
112 TensorFloat32 src_tensor;
113 src_tensor.shape = BHWC(1, 2, 1, 3);
114 src_tensor.data = {half(1.0f), half(2.0f), half(3.0f),
115 half(4.0f), half(5.0f), half(6.0f)};
116 for (auto precision : env->GetSupportedPrecisions()) {
117 auto data_type = DeduceDataTypeFromPrecision(precision);
118 for (auto storage : env->GetSupportedStorages(data_type)) {
119 OperationDef op_def;
120 op_def.precision = precision;
121 op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
122 op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
123 TensorFloat32 dst_tensor;
124 GPUOperation operation = CreateTile(op_def, src_tensor.shape.c);
125 RETURN_IF_ERROR(env->ExecuteGPUOperation(
126 src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
127 BHWC(1, 4, 1, 3), &dst_tensor));
128 RETURN_IF_ERROR(
129 PointWiseNear({half(1.0f), half(2.0f), half(3.0f), half(4.0f),
130 half(5.0f), half(6.0f), half(1.0f), half(2.0f),
131 half(3.0f), half(4.0f), half(5.0f), half(6.0f)},
132 dst_tensor.data, 0.0f));
133 }
134 }
135 return absl::OkStatus();
136 }
137
TileHWCTest(TestExecutionEnvironment * env)138 absl::Status TileHWCTest(TestExecutionEnvironment* env) {
139 TensorFloat32 src_tensor;
140 src_tensor.shape = BHWC(1, 2, 2, 3);
141 src_tensor.data = {half(1.0f), half(2.0f), half(3.0f), half(4.0f),
142 half(5.0f), half(6.0f), half(7.0f), half(8.0f),
143 half(9.0f), half(10.0f), half(11.0f), half(12.0f)};
144 for (auto precision : env->GetSupportedPrecisions()) {
145 auto data_type = DeduceDataTypeFromPrecision(precision);
146 for (auto storage : env->GetSupportedStorages(data_type)) {
147 OperationDef op_def;
148 op_def.precision = precision;
149 op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
150 op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
151 TensorFloat32 dst_tensor;
152 GPUOperation operation = CreateTile(op_def, src_tensor.shape.c);
153 RETURN_IF_ERROR(env->ExecuteGPUOperation(
154 src_tensor, std::make_unique<GPUOperation>(std::move(operation)),
155 BHWC(1, 4, 4, 6), &dst_tensor));
156 RETURN_IF_ERROR(PointWiseNear(
157 {half(1.0f), half(2.0f), half(3.0f), half(1.0f), half(2.0f),
158 half(3.0f), half(4.0f), half(5.0f), half(6.0f), half(4.0f),
159 half(5.0f), half(6.0f), half(1.0f), half(2.0f), half(3.0f),
160 half(1.0f), half(2.0f), half(3.0f), half(4.0f), half(5.0f),
161 half(6.0f), half(4.0f), half(5.0f), half(6.0f), half(7.0f),
162 half(8.0f), half(9.0f), half(7.0f), half(8.0f), half(9.0f),
163 half(10.0f), half(11.0f), half(12.0f), half(10.0f), half(11.0f),
164 half(12.0f), half(7.0f), half(8.0f), half(9.0f), half(7.0f),
165 half(8.0f), half(9.0f), half(10.0f), half(11.0f), half(12.0f),
166 half(10.0f), half(11.0f), half(12.0f), half(1.0f), half(2.0f),
167 half(3.0f), half(1.0f), half(2.0f), half(3.0f), half(4.0f),
168 half(5.0f), half(6.0f), half(4.0f), half(5.0f), half(6.0f),
169 half(1.0f), half(2.0f), half(3.0f), half(1.0f), half(2.0f),
170 half(3.0f), half(4.0f), half(5.0f), half(6.0f), half(4.0f),
171 half(5.0f), half(6.0f), half(7.0f), half(8.0f), half(9.0f),
172 half(7.0f), half(8.0f), half(9.0f), half(10.0f), half(11.0f),
173 half(12.0f), half(10.0f), half(11.0f), half(12.0f), half(7.0f),
174 half(8.0f), half(9.0f), half(7.0f), half(8.0f), half(9.0f),
175 half(10.0f), half(11.0f), half(12.0f), half(10.0f), half(11.0f),
176 half(12.0f)},
177 dst_tensor.data, 0.0f));
178 }
179 }
180 return absl::OkStatus();
181 }
182
183 } // namespace gpu
184 } // namespace tflite
185