xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
17 
18 #include <any>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <variant>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/lite/delegates/gpu/common/status.h"
27 #include "tensorflow/lite/delegates/gpu/common/types.h"
28 
29 namespace tflite {
30 namespace gpu {
31 namespace gl {
32 namespace {
33 
34 class ElementwiseOneArgument : public NodeShader {
35  public:
ElementwiseOneArgument(OperationType operation_type)36   explicit ElementwiseOneArgument(OperationType operation_type)
37       : operation_type_(operation_type) {}
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const38   absl::Status GenerateCode(const GenerationContext& ctx,
39                             GeneratedCode* generated_code) const final {
40     std::string source;
41     switch (operation_type_) {
42       case OperationType::ABS:
43         source = "value_0 = abs(value_0);";
44         break;
45       case OperationType::COS:
46         source = "value_0 = cos(value_0);";
47         break;
48       case OperationType::COPY:
49         source = "value_0 = value_0;";
50         break;
51       case OperationType::ELU:
52         source = R"(
53             value_0.x = value_0.x < 0.0 ? exp(value_0.x) - 1.0 : value_0.x;
54             value_0.y = value_0.y < 0.0 ? exp(value_0.y) - 1.0 : value_0.y;
55             value_0.z = value_0.z < 0.0 ? exp(value_0.z) - 1.0 : value_0.z;
56             value_0.w = value_0.w < 0.0 ? exp(value_0.w) - 1.0 : value_0.w;
57         )";
58         break;
59       case OperationType::EXP:
60         source = "value_0 = exp(value_0);";
61         break;
62       case tflite::gpu::OperationType::FLOOR:
63         source = "value_0 = floor(value_0);";
64         break;
65       case OperationType::HARD_SWISH:
66         source =
67             "value_0 *= clamp(value_0 / 6.0 + vec4(0.5), vec4(0.0), "
68             "vec4(1.0));";
69         break;
70       case OperationType::LOG:
71         source = R"(
72             const float nan = normalize(vec4(0, 0, 0, 0)).x;
73             value_0.x = value_0.x > 0.0 ? log(value_0.x) : nan;
74             value_0.y = value_0.y > 0.0 ? log(value_0.y) : nan;
75             value_0.z = value_0.z > 0.0 ? log(value_0.z) : nan;
76             value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan;
77         )";
78         break;
79       case OperationType::NEG:
80         source = "value_0 = -(value_0);";
81         break;
82       case OperationType::RSQRT:
83         source = R"(
84             const float nan = normalize(vec4(0, 0, 0, 0)).x;
85             value_0.x = value_0.x > 0.0 ? 1.0 / sqrt(value_0.x) : nan;
86             value_0.y = value_0.y > 0.0 ? 1.0 / sqrt(value_0.y) : nan;
87             value_0.z = value_0.z > 0.0 ? 1.0 / sqrt(value_0.z) : nan;
88             value_0.w = value_0.w > 0.0 ? 1.0 / sqrt(value_0.w) : nan;
89         )";
90         break;
91       case OperationType::SIGMOID:
92         source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));";
93         break;
94       case OperationType::SIN:
95         source = "value_0 = sin(value_0);";
96         break;
97       case OperationType::SQRT:
98         source = R"(
99             const float nan = normalize(vec4(0, 0, 0, 0)).x;
100             value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan;
101             value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan;
102             value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan;
103             value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan;
104         )";
105         break;
106       case OperationType::SQUARE:
107         source = "value_0 = value_0 * value_0;";
108         break;
109       case OperationType::TANH:
110         source = "value_0 = tanh(value_0);";
111         break;
112       default:
113         return absl::InvalidArgumentError(
114             "Incorrect elementwise operation type.");
115     }
116     *generated_code = {
117         /*parameters=*/{},
118         /*objects=*/{},
119         /*shared_variables=*/{},
120         /*workload=*/uint3(),
121         /*workgroup=*/uint3(),
122         source,
123         /*input=*/IOStructure::AUTO,
124         /*output=*/IOStructure::AUTO,
125     };
126     return absl::OkStatus();
127   }
128 
129  private:
130   OperationType operation_type_;
131 };
132 
133 class ElementwiseTwoArguments : public NodeShader {
134  public:
ElementwiseTwoArguments(OperationType operation_type)135   explicit ElementwiseTwoArguments(OperationType operation_type)
136       : operation_type_(operation_type) {}
137 
IsElementwiseSupported(const GenerationContext & ctx) const138   inline bool IsElementwiseSupported(const GenerationContext& ctx) const {
139     return ctx.input_shapes.size() == 2 &&
140            ctx.input_shapes[0] == ctx.input_shapes[1];
141   }
142 
IsBroadcastSupported(const GenerationContext & ctx) const143   inline bool IsBroadcastSupported(const GenerationContext& ctx) const {
144     return ctx.input_shapes.size() == 2 && ctx.input_shapes[1][1] == 1 &&
145            ctx.input_shapes[1][2] == 1 &&
146            ctx.input_shapes[0][3] == ctx.input_shapes[1][3];
147   }
148 
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const149   absl::Status GenerateCode(const GenerationContext& ctx,
150                             GeneratedCode* generated_code) const final {
151     std::vector<Variable> parameters;
152     std::vector<std::pair<std::string, Object>> objects;
153     std::string argument0, argument1;
154     if (IsElementwiseSupported(ctx)) {
155       argument0 = "value_0";
156       argument1 = "value_1";
157     } else if (IsBroadcastSupported(ctx)) {
158       argument0 = "$input_data_0[gid.x, gid.y, gid.z]$";
159       argument1 = "$input_data_1[0, 0, gid.z]$";
160     } else {  // Scalar of const vector case
161       const auto& attr =
162           std::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
163       const auto* tensor =
164           std::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
165       const auto* scalar = std::get_if<float>(&attr.param);
166       if (!tensor && !scalar) {
167         return absl::InvalidArgumentError(
168             "Couldn't read scalar of const vector data from the attributes.");
169       }
170 
171       argument0 = "value_0";
172       if (tensor) {
173         argument1 = "$const_data[gid.z]$";
174         objects.push_back({"const_data", MakeReadonlyObject(tensor->data)});
175       } else {
176         argument1 = "vec4($const_data$)";
177         parameters.push_back({"const_data", *scalar});
178       }
179     }
180 
181     std::string source;
182     switch (operation_type_) {
183       case OperationType::DIV: {
184         source = "value_0 = $0/$1;";
185         break;
186       }
187       case tflite::gpu::OperationType::FLOOR_DIV:
188         source = "value_0 = floor($0 / $1);";
189         break;
190       case tflite::gpu::OperationType::FLOOR_MOD:
191         source = "value_0 = $0 - floor($0 / $1) * $1;";
192         break;
193       case OperationType::MAXIMUM: {
194         source = "value_0 = max($0, $1);";
195         break;
196       }
197       case OperationType::MINIMUM: {
198         source = "value_0 = min($0, $1);";
199         break;
200       }
201       case OperationType::SQUARED_DIFF: {
202         source = "value_0 = ($0 - $1) * ($0 - $1);";
203         break;
204       }
205       case OperationType::SUB: {
206         source = "value_0 = $0 - $1;";
207         break;
208       }
209       case OperationType::POW: {
210         source = "value_0 = pow($0, $1);";
211         break;
212       }
213       default:
214         return absl::InvalidArgumentError(
215             "Incorrect elementwise with scalar operation type.");
216     }
217     source = absl::Substitute(source, argument0, argument1);
218     *generated_code = {
219         /*parameters=*/std::move(parameters),
220         /*objects=*/std::move(objects),
221         /*shared_variables=*/{},
222         /*workload=*/uint3(),
223         /*workgroup=*/uint3(),
224         /*source_code=*/source,
225         /*input=*/IOStructure::AUTO,
226         /*output=*/IOStructure::AUTO,
227     };
228     return absl::OkStatus();
229   }
230 
231  private:
232   OperationType operation_type_;
233 };
234 
235 }  // namespace
236 
NewElementwiseNodeShader(OperationType operation_type)237 std::unique_ptr<NodeShader> NewElementwiseNodeShader(
238     OperationType operation_type) {
239   switch (operation_type) {
240     case OperationType::ABS:
241     case OperationType::COS:
242     case OperationType::COPY:
243     case OperationType::ELU:
244     case OperationType::EXP:
245     case OperationType::FLOOR:
246     case OperationType::HARD_SWISH:
247     case OperationType::LOG:
248     case OperationType::NEG:
249     case OperationType::RSQRT:
250     case OperationType::SIGMOID:
251     case OperationType::SIN:
252     case OperationType::SQRT:
253     case OperationType::SQUARE:
254     case OperationType::TANH:
255       return std::make_unique<ElementwiseOneArgument>(operation_type);
256     case OperationType::DIV:
257     case OperationType::FLOOR_DIV:
258     case OperationType::FLOOR_MOD:
259     case OperationType::MAXIMUM:
260     case OperationType::MINIMUM:
261     case OperationType::POW:
262     case OperationType::SQUARED_DIFF:
263     case OperationType::SUB:
264       return std::make_unique<ElementwiseTwoArguments>(operation_type);
265     default:
266       return nullptr;
267   }
268 }
269 
270 }  // namespace gl
271 }  // namespace gpu
272 }  // namespace tflite
273