xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/lite/delegates/gpu/gl/kernels/conv.h"
17 
18 #include <any>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/lite/delegates/gpu/common/convert.h"
27 #include "tensorflow/lite/delegates/gpu/common/operations.h"
28 #include "tensorflow/lite/delegates/gpu/common/shape.h"
29 #include "tensorflow/lite/delegates/gpu/common/status.h"
30 #include "tensorflow/lite/delegates/gpu/common/types.h"
31 #include "tensorflow/lite/delegates/gpu/common/util.h"
32 #include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
33 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
34 #include "tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h"
35 
36 namespace tflite {
37 namespace gpu {
38 namespace gl {
39 namespace {
40 
41 class Convolution : public NodeShader {
42  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const43   absl::Status GenerateCode(const GenerationContext& ctx,
44                             GeneratedCode* generated_code) const final {
45     if (ctx.input_shapes.size() != 1) {
46       return absl::UnimplementedError(
47           "Convolution does not support more than 1 runtime tensor");
48     }
49     const auto& attr =
50         std::any_cast<const Convolution2DAttributes&>(ctx.op_attr);
51     if (attr.groups != 1) {
52       return absl::UnimplementedError(
53           "Convolution does not support more than 1 group");
54     }
55     auto weights = attr.weights.shape;
56     const int offsets_count = weights.h * weights.w;
57     const bool offsets_count_too_large = offsets_count > kMaxConstArraySize;
58     std::vector<Variable> parameters;
59     if (offsets_count_too_large) {
60       parameters = {
61           {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
62           {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
63           {"padding_w", attr.padding.prepended.w},
64           {"padding_h", attr.padding.prepended.h},
65           {"dilation_w", attr.dilations.w},
66           {"dilation_h", attr.dilations.h},
67           {"kernel_w", weights.w},
68           {"kernel_h", weights.h},
69           {"src_depth", DivideRoundUp(weights.i, 4)},
70           {"stride", int2(attr.strides.w, attr.strides.h)},
71       };
72     } else {
73       std::vector<int2> offsets;
74       for (int h = 0; h < weights.h; ++h) {
75         for (int w = 0; w < weights.w; ++w) {
76           offsets.emplace_back(w * attr.dilations.w - attr.padding.prepended.w,
77                                h * attr.dilations.h - attr.padding.prepended.h);
78         }
79       }
80       parameters = {
81           {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
82           {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
83           {"offsets_count", offsets_count},
84           {"offsets", offsets},
85           {"src_depth", DivideRoundUp(weights.i, 4)},
86           {"stride", int2(attr.strides.w, attr.strides.h)},
87       };
88     }
89 
90     // at least one padding is not empty
91     bool non_empty_padding =
92         attr.padding.appended.h != 0 || attr.padding.appended.w != 0 ||
93         attr.padding.prepended.h != 0 || attr.padding.prepended.w != 0;
94 
95     std::vector<std::pair<std::string, Object>> objects = {
96         {"weights", MakeReadonlyObject(Get3DSizeForPHWO4I4(attr.weights.shape),
97                                        ConvertToPHWO4I4(attr.weights))}};
98 
99     std::string source;
100     if (offsets_count_too_large) {
101       source = R"(
102       int i = 0;
103       for (int ky = 0; ky < $kernel_h$; ky++) {
104         for (int kx = 0; kx < $kernel_w$; kx++, i++) {
105           ivec2 coord = gid.xy * $stride$ + ivec2(kx * $dilation_w$ - $padding_w$, ky * $dilation_h$ - $padding_h$);)";
106     } else {
107       source = R"(
108         for (int i = 0; i < $offsets_count$; ++i) {
109           ivec2 coord = gid.xy * $stride$ + $offsets[i]$;)";
110     }
111     if (non_empty_padding) {
112       source += R"(
113         if (coord.x < 0 || coord.y < 0 || coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$) {
114           continue;
115         })";
116     }
117     source += R"(
118           for (int l = 0; l < $src_depth$; ++l) {
119             vec4 input_ = $input_data_0[coord.x, coord.y, l]$;
120             value_0.x += dot(input_, $weights[l * 4 + 0, i, gid.z]$);
121             value_0.y += dot(input_, $weights[l * 4 + 1, i, gid.z]$);
122             value_0.z += dot(input_, $weights[l * 4 + 2, i, gid.z]$);
123             value_0.w += dot(input_, $weights[l * 4 + 3, i, gid.z]$);
124           }
125         }
126 )";
127     if (offsets_count_too_large) {
128       source += R"(
129       }
130 )";
131     }
132     if (!attr.bias.data.empty()) {
133       source += "value_0 += $bias[gid.z]$;\n";
134       objects.push_back({"bias", MakeReadonlyObject(attr.bias.data)});
135     }
136 
137     *generated_code = {
138         /*parameters=*/std::move(parameters),
139         /*objects=*/std::move(objects),
140         /*shared_variables=*/{},
141         /*workload=*/uint3(),
142         /*workgroup=*/
143         GetIdealWorkgroupIfPossible(
144             *ctx.gpu_info, OperationType::CONVOLUTION_2D,
145             HW(weights.h, weights.w), attr.strides, uint3(0, 0, 0),
146             OHWI(weights.o, ctx.input_shapes[0][1], ctx.input_shapes[0][2],
147                  ctx.input_shapes[0][3])),
148         /*source_code=*/std::move(source),
149         /*input=*/IOStructure::ONLY_DEFINITIONS,
150         /*output=*/IOStructure::AUTO,
151     };
152     return absl::OkStatus();
153   }
154 };
155 
SelectMultiplier(int32_t input_width,const NodeShader::GenerationContext & ctx)156 int SelectMultiplier(int32_t input_width,
157                      const NodeShader::GenerationContext& ctx) {
158   std::vector<int> multipliers = {4, 2};
159   if (ctx.gpu_info->IsAMD()) {
160     return 1;
161   }
162   if (!ctx.compiler_options.allow_precision_loss && ctx.gpu_info->IsMali()) {
163     multipliers = {2};
164   }
165   for (int i : multipliers) {
166     if (input_width % i == 0) {
167       return i;
168     }
169   }
170   return 1;
171 }
172 
173 class Convolution1x1 : public NodeShader {
174  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const175   absl::Status GenerateCode(const GenerationContext& ctx,
176                             GeneratedCode* generated_code) const final {
177     if (ctx.input_shapes.size() != 1) {
178       return absl::UnimplementedError(
179           "Convolution does not support more than 1 runtime tensor");
180     }
181     const auto& attr =
182         std::any_cast<const Convolution2DAttributes&>(ctx.op_attr);
183     if (attr.weights.shape.h != 1 || attr.weights.shape.w != 1) {
184       return absl::UnimplementedError("Height and width should be 1.");
185     }
186     if (attr.dilations.h != 1 || attr.dilations.w != 1) {
187       return absl::UnimplementedError("Dilations are not supported.");
188     }
189     if (attr.strides.h != 1 || attr.strides.w != 1) {
190       return absl::UnimplementedError("Strides are not supported.");
191     }
192     if (attr.padding.appended.h != 0 || attr.padding.appended.w != 0 ||
193         attr.padding.prepended.h != 0 || attr.padding.prepended.w != 0) {
194       return absl::UnimplementedError("Padding is not supported.");
195     }
196 
197     int multiplier = SelectMultiplier(ctx.input_shapes[0][2], ctx);
198 
199     std::vector<Variable> parameters = {
200         {"src_depth",
201          DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)},
202     };
203 
204     std::vector<std::pair<std::string, Object>> objects = {
205         {"weights",
206          MakeReadonlyObject(uint3(4, DivideRoundUp(attr.weights.shape.i, 4),
207                                   DivideRoundUp(attr.weights.shape.o, 4)),
208                             ConvertToPHWO4I4(attr.weights))}};
209     std::string source;
210     for (int i = 0; i < multiplier; i++) {
211       absl::StrAppend(&source, "highp vec4 result", i, " = vec4(0);\n");
212     }
213     absl::StrAppend(&source, "vec4 f;\n");
214     absl::StrAppend(&source, "for (int l = 0; l < $src_depth$; ++l) {\n");
215     for (int i = 0; i < multiplier; i++) {
216       absl::StrAppend(&source, "  vec4 input", i, " = $input_data_0[gid.x * ",
217                       multiplier, " + ", i, ",gid.y,l]$;\n");
218     }
219     for (int k = 0; k < 4; k++) {
220       absl::StrAppend(&source, "  f = $weights[", k, ", l, gid.z]$;\n");
221       for (int i = 0; i < multiplier; i++) {
222         absl::StrAppend(&source, "  result", i, "[", k, "] += dot(input", i,
223                         ", f);\n");
224       }
225     }
226     absl::StrAppend(&source, "}\n");
227     if (!attr.bias.data.empty()) {
228       objects.push_back({"bias", MakeReadonlyObject(attr.bias.data)});
229       absl::StrAppend(&source, "vec4 b = $bias[gid.z]$;\n");
230       for (int i = 0; i < multiplier; i++) {
231         absl::StrAppend(&source, "result", i, " += b;\n");
232       }
233     }
234     if (multiplier != 1) {
235       for (int i = 0; i < multiplier; i++) {
236         absl::StrAppend(&source, "$inplace_update:result", i, "$\n");
237         absl::StrAppend(&source, "$output_data_0[gid.x * ", multiplier, " + ",
238                         i, ",gid.y,gid.z] = result", i, "$;\n");
239       }
240     } else {
241       absl::StrAppend(&source, "value_0 = result0;\n");
242     }
243 
244     auto dst_depth = DivideRoundUp(ctx.output_shapes[0][3], 4);
245     uint3 workgroup = uint3(16, 16, 1);
246     if (ctx.gpu_info->IsAdreno()) {
247       if (dst_depth >= 2) {
248         workgroup = uint3(8, 8, 2);
249       }
250       if (dst_depth >= 4) {
251         workgroup = uint3(4, 8, 4);
252       }
253       if (dst_depth >= 8) {
254         workgroup = uint3(4, 4, 8);
255       }
256       if (dst_depth >= 32) {
257         workgroup = uint3(4, 4, 16);
258       }
259       if (dst_depth >= 64) {
260         workgroup = uint3(2, 8, 16);
261       }
262     } else {
263       if (dst_depth >= 2) {
264         workgroup = uint3(16, 8, 2);
265       }
266       if (dst_depth >= 4) {
267         workgroup = uint3(16, 4, 4);
268       }
269       if (dst_depth >= 8) {
270         workgroup = uint3(8, 4, 8);
271       }
272       if (dst_depth >= 32) {
273         workgroup = uint3(8, 4, 8);
274       }
275       if (dst_depth >= 64) {
276         workgroup = uint3(8, 4, 8);
277       }
278     }
279     *generated_code = {
280         /*parameters=*/std::move(parameters),
281         /*objects=*/std::move(objects),
282         /*shared_variables=*/{},
283         /*workload=*/
284         uint3(ctx.output_shapes[0][2] / multiplier, ctx.output_shapes[0][1],
285               DivideRoundUp(ctx.output_shapes[0][3], 4)),
286         /*workgroup=*/
287         GetIdealWorkgroupIfPossible(
288             *ctx.gpu_info, OperationType::CONVOLUTION_2D,
289             HW(attr.weights.shape.h, attr.weights.shape.w), attr.strides,
290             workgroup,
291             OHWI(attr.weights.shape.o, ctx.input_shapes[0][1],
292                  ctx.input_shapes[0][2], ctx.input_shapes[0][3])),
293         /*source_code=*/std::move(source),
294         /*input=*/IOStructure::ONLY_DEFINITIONS,
295         /*output=*/multiplier == 1 ? IOStructure::AUTO
296                                    : IOStructure::ONLY_DEFINITIONS,
297     };
298     return absl::OkStatus();
299   }
300 };
301 
302 }  // namespace
303 
NewConvolutionNodeShader()304 std::unique_ptr<NodeShader> NewConvolutionNodeShader() {
305   return std::make_unique<Convolution>();
306 }
307 
NewConvolution1x1NodeShader()308 std::unique_ptr<NodeShader> NewConvolution1x1NodeShader() {
309   return std::make_unique<Convolution1x1>();
310 }
311 
312 }  // namespace gl
313 }  // namespace gpu
314 }  // namespace tflite
315