xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/cl/kernels/fully_connected_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/fully_connected.h"
17 
18 #include <vector>
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "tensorflow/lite/delegates/gpu/cl/environment.h"
23 #include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
24 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
25 #include "tensorflow/lite/delegates/gpu/common/operations.h"
26 #include "tensorflow/lite/delegates/gpu/common/shape.h"
27 #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
28 #include "tensorflow/lite/delegates/gpu/common/tasks/fully_connected_test_util.h"
29 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
30 
31 using ::testing::ElementsAreArray;
32 
33 namespace tflite {
34 namespace gpu {
35 namespace cl {
36 namespace {
37 
TEST_F(OpenCLOperationTest, FullyConnected) {
  // Runs the FullyConnectedTest helper against this fixture's execution
  // environment; a failing status aborts the test and reports its message.
  auto driver_status = FullyConnectedTest(&exec_env_);
  ASSERT_TRUE(driver_status.ok()) << driver_status.error_message();
}
42 
TEST_F(OpenCLOperationTest, FullyConnectedLarge) {
  // Runs the FullyConnectedLargeTest helper against this fixture's execution
  // environment; a failing status aborts the test and reports its message.
  auto driver_status = FullyConnectedLargeTest(&exec_env_);
  ASSERT_TRUE(driver_status.ok()) << driver_status.error_message();
}
47 
TEST_F(OpenCLOperationTest, FullyConnectedExtraLarge) {
  // Runs the FullyConnectedExtraLargeTest helper against this fixture's
  // execution environment; a failing status aborts the test and reports its
  // message.
  auto driver_status = FullyConnectedExtraLargeTest(&exec_env_);
  ASSERT_TRUE(driver_status.ok()) << driver_status.error_message();
}
52 
TEST_F(OpenCLOperationTest, FullyConnectedInt8) {
  // Runs the FullyConnectedInt8Test helper against this fixture's execution
  // environment; a failing status aborts the test and reports its message.
  auto driver_status = FullyConnectedInt8Test(&exec_env_);
  ASSERT_TRUE(driver_status.ok()) << driver_status.error_message();
}
57 
TEST_F(OpenCLOperationTest, RearrageWeights) {
  // 8 outputs x 8 inputs: both dimensions are multiples of 4, so no padding is
  // expected and the rearranged buffer has exactly 8 * 8 entries.
  constexpr int kOutputs = 8;
  constexpr int kInputs = 8;

  // Each weight encodes its own coordinates as 10 * o + i, so every value in
  // the rearranged buffer identifies the source element that landed there.
  tflite::gpu::Tensor<OHWI, DataType::FLOAT32> weights;
  weights.shape = OHWI(kOutputs, 1, 1, kInputs);
  weights.data.reserve(kOutputs * kInputs);
  for (int o = 0; o < kOutputs; ++o) {
    for (int i = 0; i < kInputs; ++i) {
      weights.data.push_back(10.0f * o + i);
    }
  }

  // Expected layout: 4x4 blocks, with all output blocks for inputs 0-3 first,
  // then those for inputs 4-7. Inside a block, each group of four values is
  // one input column across four consecutive outputs.
  const std::vector<float> expected = {
      // Top-left block (outputs 0-3, inputs 0-3).
      0.0f, 10.0f, 20.0f, 30.0f,  //
      1.0f, 11.0f, 21.0f, 31.0f,  //
      2.0f, 12.0f, 22.0f, 32.0f,  //
      3.0f, 13.0f, 23.0f, 33.0f,  //
      // Bottom-left block (outputs 4-7, inputs 0-3).
      40.0f, 50.0f, 60.0f, 70.0f,  //
      41.0f, 51.0f, 61.0f, 71.0f,  //
      42.0f, 52.0f, 62.0f, 72.0f,  //
      43.0f, 53.0f, 63.0f, 73.0f,  //
      // Top-right block (outputs 0-3, inputs 4-7).
      4.0f, 14.0f, 24.0f, 34.0f,  //
      5.0f, 15.0f, 25.0f, 35.0f,  //
      6.0f, 16.0f, 26.0f, 36.0f,  //
      7.0f, 17.0f, 27.0f, 37.0f,  //
      // Bottom-right block (outputs 4-7, inputs 4-7).
      44.0f, 54.0f, 64.0f, 74.0f,  //
      45.0f, 55.0f, 65.0f, 75.0f,  //
      46.0f, 56.0f, 66.0f, 76.0f,  //
      47.0f, 57.0f, 67.0f, 77.0f   //
  };

  std::vector<float> rearranged(kOutputs * kInputs);
  RearrangeFCWeightsToIOO4I4(weights, rearranged.data());

  EXPECT_THAT(rearranged, ElementsAreArray(expected));
}
91 
TEST_F(OpenCLOperationTest, RearrageWeightsWhenPaddingIsRequired) {
  // 9 outputs x 7 inputs: neither dimension is a multiple of 4, so the
  // expected output is zero-padded to 12 outputs x 8 inputs (3 x 2 blocks of
  // 4x4). Each weight encodes its coordinates as 10 * o + i.
  tflite::gpu::Tensor<OHWI, DataType::FLOAT32> weights;
  weights.shape = OHWI(9, 1, 1, 7);
  weights.data = {
      0.0f,  1.0f,  2.0f,  3.0f,  4.0f,  5.0f,  6.0f,   //
      10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,  //
      20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f,  //
      30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f,  //
      40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f,  //
      50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f,  //
      60.0f, 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, 66.0f,  //
      70.0f, 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, 76.0f,  //
      80.0f, 81.0f, 82.0f, 83.0f, 84.0f, 85.0f, 86.0f,  //
  };

  // Expected layout: 4x4 blocks, with all output blocks for inputs 0-3 first,
  // then those for inputs 4-7. Inside a block, each group of four values is
  // one input column across four consecutive outputs; positions past output 8
  // or input 6 are zero padding.
  std::vector<float> expected_rearranged_data = {
      // Top-left block (outputs 0-3, inputs 0-3).
      0.0f, 10.0f, 20.0f, 30.0f,  //
      1.0f, 11.0f, 21.0f, 31.0f,  //
      2.0f, 12.0f, 22.0f, 32.0f,  //
      3.0f, 13.0f, 23.0f, 33.0f,  //
      // Mid-left block (outputs 4-7, inputs 0-3).
      40.0f, 50.0f, 60.0f, 70.0f,  //
      41.0f, 51.0f, 61.0f, 71.0f,  //
      42.0f, 52.0f, 62.0f, 72.0f,  //
      43.0f, 53.0f, 63.0f, 73.0f,  //
      // Bottom-left block (output 8 plus three padded outputs, inputs 0-3).
      80.0f, 0.0f, 0.0f, 0.0f,  //
      81.0f, 0.0f, 0.0f, 0.0f,  //
      82.0f, 0.0f, 0.0f, 0.0f,  //
      83.0f, 0.0f, 0.0f, 0.0f,  //
      // Top-right block (outputs 0-3, inputs 4-6 plus one padded input).
      4.0f, 14.0f, 24.0f, 34.0f,  //
      5.0f, 15.0f, 25.0f, 35.0f,  //
      6.0f, 16.0f, 26.0f, 36.0f,  //
      0.0f, 0.0f, 0.0f, 0.0f,     //
      // Mid-right block (outputs 4-7, inputs 4-6 plus one padded input).
      44.0f, 54.0f, 64.0f, 74.0f,  //
      45.0f, 55.0f, 65.0f, 75.0f,  //
      46.0f, 56.0f, 66.0f, 76.0f,  //
      0.0f, 0.0f, 0.0f, 0.0f,      //
      // Bottom-right block (output 8, inputs 4-6; everything else padding).
      84.0f, 0.0f, 0.0f, 0.0f,  //
      85.0f, 0.0f, 0.0f, 0.0f,  //
      86.0f, 0.0f, 0.0f, 0.0f,  //
      0.0f, 0.0f, 0.0f, 0.0f    //
  };

  // Pre-fill the destination with a sentinel that never appears in the
  // expected output. With a zero-initialized buffer, the padding slots would
  // match the expected zeros even if the rearrangement never wrote them; the
  // sentinel makes any unwritten slot a visible failure.
  constexpr float kUnwritten = -1.0f;
  std::vector<float> data(12 * 8, kUnwritten);
  RearrangeFCWeightsToIOO4I4(weights, data.data());

  EXPECT_THAT(data, ElementsAreArray(expected_rearranged_data));
}
132 
133 }  // namespace
134 }  // namespace cl
135 }  // namespace gpu
136 }  // namespace tflite
137