xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h"
17 
18 #include <algorithm>
19 #include <any>
20 #include <cstdint>
21 #include <cstring>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/memory/memory.h"
28 #include "tensorflow/lite/delegates/gpu/common/status.h"
29 #include "tensorflow/lite/delegates/gpu/common/types.h"
30 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
31 
32 namespace tflite {
33 namespace gpu {
34 namespace gl {
35 namespace {
36 
GenerateMaxPoolingCode(const Pooling2DAttributes & attr,const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)37 absl::Status GenerateMaxPoolingCode(const Pooling2DAttributes& attr,
38                                     const NodeShader::GenerationContext& ctx,
39                                     GeneratedCode* generated_code) {
40   if (attr.padding.prepended.h > attr.kernel.h ||
41       attr.padding.prepended.w > attr.kernel.w) {
42     return absl::InvalidArgumentError("Padding is bigger than kernel.");
43   }
44 
45   std::vector<Variable> parameters = {
46       {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
47       {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
48       {"stride", int2(attr.strides.w, attr.strides.h)},
49       {"offset", int2(attr.padding.prepended.w, attr.padding.prepended.h)},
50       {"window_h", attr.kernel.h},
51       {"window_w", attr.kernel.w},
52   };
53 
54   // Per GLSL_ES 3.1 spec in Issue 13.4
55   // "Floating Point Representation and Functionality" highp floats are
56   // expected to behave as defined in IEEE 754. In particular, signed
57   // infinities are mandated and defined as a number divided by 0.
58   std::string source = R"(
59   const highp float inf = -(1.0f / 0.0f);
60   value_0 = vec4(inf);)";
61   if (attr.output_indices) {
62     source += R"(
63   ivec4 value_1;
64 )";
65   }
66   source += R"(
67   ivec2 base_coord = gid.xy * $stride$ - $offset$;
68   for (int a = 0; a < $window_h$; ++a) {
69     for (int b = 0; b < $window_w$; ++b) {
70       ivec2 coord = base_coord + ivec2(b, a);
71       if (coord.x < 0 || coord.y < 0 || coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$) {
72         continue;
73       }
74       vec4 input_ = $input_data_0[coord.x, coord.y, gid.z]$;)";
75   if (attr.output_indices) {
76     source += R"(
77       int window_index = a * $window_w$ + b;
78       if (input_.x > value_0.x) value_1.x = window_index;
79       if (input_.y > value_0.y) value_1.y = window_index;
80       if (input_.z > value_0.z) value_1.z = window_index;
81       if (input_.w > value_0.w) value_1.w = window_index;)";
82   }
83   source += R"(
84       value_0 = max(value_0, input_);
85     }
86   }
87 )";
88   *generated_code = {
89       /*parameters=*/std::move(parameters),
90       /*objects=*/{},
91       /*shared_variables=*/{},
92       /*workload=*/uint3(),
93       /*workgroup=*/uint3(),
94       /*source_code=*/std::move(source),
95       /*input=*/IOStructure::ONLY_DEFINITIONS,
96       /*output=*/IOStructure::AUTO,
97   };
98   return absl::OkStatus();
99 }
100 
GenerateAveragePoolingCode(const Pooling2DAttributes & attr,const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)101 absl::Status GenerateAveragePoolingCode(
102     const Pooling2DAttributes& attr, const NodeShader::GenerationContext& ctx,
103     GeneratedCode* generated_code) {
104   std::vector<Variable> parameters = {
105       {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
106       {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
107       {"stride", int2(attr.strides.w, attr.strides.h)},
108       {"offset", int2(attr.padding.prepended.w, attr.padding.prepended.h)},
109       {"window_h", attr.kernel.h},
110       {"window_w", attr.kernel.w},
111   };
112 
113   // Bounds checking helper functions.
114   auto x_in_bounds = [input_width = ctx.input_shapes[0][2],
115                       kernel_width = attr.kernel.w](int64_t x) -> bool {
116     return 0 <= x && x + kernel_width <= input_width;
117   };
118   auto y_in_bounds = [input_height = ctx.input_shapes[0][1],
119                       kernel_height = attr.kernel.h](int64_t y) -> bool {
120     return 0 <= y && y + kernel_height <= input_height;
121   };
122 
123   // Only include a bounds check in the shader if it will actually be necessary
124   // at run time.
125   const int64_t output_shape_max_y = ctx.output_shapes[0][1] - 1;
126   const int64_t output_shape_max_x = ctx.output_shapes[0][2] - 1;
127   const int64_t base_x = -attr.padding.prepended.w;
128   const int64_t base_y = -attr.padding.prepended.h;
129   const bool bounds_check_necessary =
130       !(x_in_bounds(base_x) &&
131         x_in_bounds(base_x + output_shape_max_x * attr.strides.w) &&
132         y_in_bounds(base_y) &&
133         y_in_bounds(base_y + output_shape_max_y * attr.strides.h));
134 
135   std::string source = bounds_check_necessary ?
136                                               R"(
137   int window_size = 0;
138   for (int a = 0; a < $window_h$; ++a) {
139     for (int b = 0; b < $window_w$; ++b) {
140       ivec2 coord = gid.xy * $stride$ - $offset$ + ivec2(b, a);
141       if (coord.x >= 0 && coord.y >= 0 && coord.x < $input_data_0_w$ && coord.y < $input_data_0_h$) {
142         value_0 += $input_data_0[coord.x, coord.y, gid.z]$;
143         window_size++;
144       }
145     }
146   }
147   // If window_size==0, window covered nothing. This situation is a sign of
148   // incorrectly constructed operation. NaNs are expected as output.
149   value_0 /= float(window_size);
150 )"
151                                               :
152                                               R"(
153   for (int a = 0; a < $window_h$; ++a) {
154     for (int b = 0; b < $window_w$; ++b) {
155       ivec2 coord = gid.xy * $stride$ - $offset$ + ivec2(b, a);
156       value_0 += $input_data_0[coord.x, coord.y, gid.z]$;
157     }
158   }
159   // If the denominator is 0, that is a sign of an incorrectly constructed
160   // operation. NaNs are expected as output.
161   value_0 /= float($window_h$ * $window_w$);
162 )";
163 
164   *generated_code = {
165       /*parameters=*/std::move(parameters),
166       /*objects=*/{},
167       /*shared_variables=*/{},
168       /*workload=*/uint3(),
169       /*workgroup=*/uint3(),
170       /*source_code=*/std::move(source),
171       /*input=*/IOStructure::ONLY_DEFINITIONS,
172       /*output=*/IOStructure::AUTO,
173   };
174   return absl::OkStatus();
175 }
176 
177 class Pooling : public NodeShader {
178  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const179   absl::Status GenerateCode(const GenerationContext& ctx,
180                             GeneratedCode* generated_code) const final {
181     const auto& attr = std::any_cast<const Pooling2DAttributes&>(ctx.op_attr);
182     switch (attr.type) {
183       case PoolingType::AVERAGE:
184         return GenerateAveragePoolingCode(attr, ctx, generated_code);
185       case PoolingType::MAX:
186         return GenerateMaxPoolingCode(attr, ctx, generated_code);
187       default:
188         return absl::InvalidArgumentError("Incorrect attributes' type.");
189     }
190   }
191 };
192 
193 }  // namespace
194 
NewPoolingNodeShader()195 std::unique_ptr<NodeShader> NewPoolingNodeShader() {
196   return std::make_unique<Pooling>();
197 }
198 
199 }  // namespace gl
200 }  // namespace gpu
201 }  // namespace tflite
202