/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h"

#include <algorithm>
#include <any>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

class FullyConnectedBuffers : public NodeShader {
 public:
  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    const auto& attr =
        std::any_cast<const FullyConnectedAttributes&>(ctx.op_attr);

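    // Tensors are packed four channels per vec4 element, so the reduction and
    // output dimensions below are counted in slices of 4 channels.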
    const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
    const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);

    // This shader can work with any workgroup size; the values below work
    // well for OpenGL.
    constexpr int kWorkgroupHintX = 4;
    constexpr int kWorkgroupHintY = 4;

    // TODO(akulik): check that input has h,w == 1,1
    std::vector<Variable> parameters = {
        {"src_depth", src_depth},
        {"dst_depth", dst_depth},
    };

    // TODO(akulik): refactor indexed access to weights.
    std::vector<std::pair<std::string, Object>> objects = {
        {"weights", MakeReadonlyObject(ConvertToPHWO4I4(attr.weights))}};

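    // Shader strategy: each invocation along x produces one vec4 slice of the
    // output. Invocations along y split the dot product over src_depth, stage
    // their partial sums in shared memory, and the y == 0 thread of each
    // column accumulates the final result.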
    std::string source = R"(
  const int threads = int(gl_WorkGroupSize.y);
  const int workers = int(gl_WorkGroupSize.x);
  ivec3 tid = ivec3(gl_LocalInvocationID);

  if (gid.x < $dst_depth$) {
    int offset = 4 * gid.x * $src_depth$ + 4 * tid.y;
    for (int d = tid.y; d < $src_depth$; d += threads, offset += 4 * threads) {
      vec4 src = $input_data_0[0, 0, d]$;
      value_0.x += dot(src, $weights[offset + 0]$);
      value_0.y += dot(src, $weights[offset + 1]$);
      value_0.z += dot(src, $weights[offset + 2]$);
      value_0.w += dot(src, $weights[offset + 3]$);
    }
    sh_mem[workers * tid.y + tid.x] = value_0;
  }
  memoryBarrierShared();
  barrier();

  if (tid.y > 0 || gid.x >= $dst_depth$) {
    return;
  }

  for (int t = 1; t < threads; t++) {
    value_0 += sh_mem[workers * t + tid.x];
  }
)";
    if (!attr.bias.data.empty()) {
      source += "  value_0 += $bias[gid.x]$;\n";
      objects.push_back({"bias", MakeReadonlyObject(attr.bias.data)});
    }
    source += "  $output_data_0[0, 0, gid.x] = value_0$;";

    std::vector<Variable> shared_variables = {
#ifdef __APPLE__
        // MoltenVK has problems with shared memory sized using the workgroup
        // size. Fortunately with Metal a fixed workgroup size of 32 seems to
        // give optimal results.
        {"sh_mem", std::vector<float4>(32)},
#else
        // The actual size of sh_mem depends on the WorkgroupSize
        {"sh_mem", std::vector<float4>(0)},
#endif
    };

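    // The workload spans dst_depth along x. Its y extent matches the workgroup
    // hint so that every output slice gets a full column of cooperating
    // invocations for the shared-memory reduction above.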
    *generated_code = {
        /*parameters=*/std::move(parameters),
        /*objects=*/std::move(objects),
        /*shared_variables=*/std::move(shared_variables),
        /*workload=*/uint3(dst_depth, kWorkgroupHintY, 1),
        /*workgroup=*/uint3(kWorkgroupHintX, kWorkgroupHintY, 1),
        /*source_code=*/std::move(source),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::ONLY_DEFINITIONS,
    };
    return absl::OkStatus();
  }
};

}  // namespace

std::unique_ptr<NodeShader> NewFullyConnectedNodeShader() {
  return std::make_unique<FullyConnectedBuffers>();
}
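
// A minimal usage sketch (not part of the original file), assuming the caller
// already holds a populated GenerationContext `ctx`; it only exercises the
// NodeShader interface implemented above:
//
//   std::unique_ptr<NodeShader> shader = NewFullyConnectedNodeShader();
//   GeneratedCode code;
//   absl::Status status = shader->GenerateCode(ctx, &code);
//   if (!status.ok()) {
//     // Fall back to another kernel or reject the node.
//   }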

}  // namespace gl
}  // namespace gpu
}  // namespace tflite