xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/elementwise.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_replace.h"
24 #include "absl/strings/substitute.h"
25 
26 namespace tflite {
27 namespace gpu {
28 
29 namespace {
GetOneInputCode(const GpuInfo & gpu_info,const OperationType & op_type,CalculationsPrecision precision,const std::string & input_value,const std::string & output_value)30 std::string GetOneInputCode(const GpuInfo& gpu_info,
31                             const OperationType& op_type,
32                             CalculationsPrecision precision,
33                             const std::string& input_value,
34                             const std::string& output_value) {
35   const bool use_native_opencl_functions =
36       gpu_info.IsApiOpenCl() && precision != CalculationsPrecision::F32 &&
37       gpu_info.IsAdreno();
38   std::string result;
39   switch (op_type) {
40     case OperationType::ABS:
41       result = "$0 = fabs($1);";
42       break;
43     case OperationType::COS:
44       if (use_native_opencl_functions) {
45         result = "$0 = convert_half4(native_cos(convert_float4($1)));";
46       } else {
47         result = "$0 = cos($1);";
48       }
49       break;
50     case OperationType::COPY:
51       result = "$0 = $1;";
52       break;
53     case OperationType::ELU:
54       if (gpu_info.IsApiOpenCl()) {
55         result = R"(
56 $0.x = $1.x < INIT_FLT(0.0f) ? expm1($1.x) : $1.x;
57 $0.y = $1.y < INIT_FLT(0.0f) ? expm1($1.y) : $1.y;
58 $0.z = $1.z < INIT_FLT(0.0f) ? expm1($1.z) : $1.z;
59 $0.w = $1.w < INIT_FLT(0.0f) ? expm1($1.w) : $1.w;)";
60       } else {
61         result = R"(
62 $0.x = $1.x < INIT_FLT(0.0f) ? exp($1.x) - INIT_FLT(1.0f) : $1.x;
63 $0.y = $1.y < INIT_FLT(0.0f) ? exp($1.y) - INIT_FLT(1.0f) : $1.y;
64 $0.z = $1.z < INIT_FLT(0.0f) ? exp($1.z) - INIT_FLT(1.0f) : $1.z;
65 $0.w = $1.w < INIT_FLT(0.0f) ? exp($1.w) - INIT_FLT(1.0f) : $1.w;)";
66       }
67       break;
68     case OperationType::EXP:
69       if (use_native_opencl_functions) {
70         result = "$0 = convert_half4(native_exp(convert_float4($1)));";
71       } else {
72         result = "$0 = exp($1);";
73       }
74       break;
75     case OperationType::FLOOR:
76       result = "$0 = floor($1);";
77       break;
78     case OperationType::HARD_SWISH:
79       result =
80           "$0 = $1 * clamp($1 * INIT_FLT(0.16666667f) + INIT_FLT(0.5f), "
81           "INIT_FLT4(0.0f), "
82           "INIT_FLT4(1.0f));";
83       break;
84     case OperationType::LOG:
85       if (use_native_opencl_functions) {
86         result = "$0 = convert_half4(native_log(convert_float4($1)));";
87       } else {
88         result = "$0 = log($1);";
89       }
90       break;
91     case OperationType::NEG:
92       result = "$0 = -($1);";
93       break;
94     case OperationType::RSQRT:
95       if (use_native_opencl_functions) {
96         result = "$0 = convert_half4(native_rsqrt(convert_float4($1)));";
97       } else {
98         result = "$0 = rsqrt($1);";
99       }
100       break;
101     case OperationType::SIGMOID:
102       if (use_native_opencl_functions) {
103         result =
104             "$0 = convert_half4(native_recip(1.0f + "
105             "native_exp(convert_float4(-$1))));";
106       } else {
107         result = "$0 = INIT_FLT4(1.0f) / (INIT_FLT4(1.0f) + exp(-($1)));";
108       }
109       break;
110     case OperationType::SIN:
111       if (use_native_opencl_functions) {
112         result = "$0 = convert_half4(native_sin(convert_float4($1)));";
113       } else {
114         result = "$0 = sin($1);";
115       }
116       break;
117     case OperationType::SQRT:
118       if (use_native_opencl_functions) {
119         result = "$0 = convert_half4(native_sqrt(convert_float4($1)));";
120       } else {
121         result = "$0 = sqrt($1);";
122       }
123       break;
124     case OperationType::SQUARE:
125       result = "$0 = $1 * $1;";
126       break;
127     case OperationType::TANH:
128       if (use_native_opencl_functions) {
129         result =
130             "FLT4 exp_val = convert_half4(native_exp(2.0f * "
131             "convert_float4($1)));\n";
132         result +=
133             "$0 = ((exp_val - INIT_FLT4(1.0f)) / (exp_val + "
134             "INIT_FLT4(1.0f)));";
135       } else {
136         result = "$0 = tanh($1);";
137       }
138       break;
139     default:
140       return "Unknown operation type;";
141   }
142   return absl::Substitute(result, output_value, input_value);
143 }
144 
GetTwoInputCode(const OperationType & op_type,const std::string & result_var,const std::string & input0,const std::string & input1,bool swap_inputs=false)145 std::string GetTwoInputCode(const OperationType& op_type,
146                             const std::string& result_var,
147                             const std::string& input0,
148                             const std::string& input1,
149                             bool swap_inputs = false) {
150   std::string result;
151   switch (op_type) {
152     case OperationType::ADD:
153       result += "$0 = $1 + $2;";
154       break;
155     case OperationType::DIV:
156       result += "$0 = $1 / $2;";
157       break;
158     case OperationType::FLOOR_DIV:
159       result = "$0 = floor($1 / $2);";
160       break;
161     case OperationType::FLOOR_MOD:
162       result = "$0 = $1 - floor($1 / $2) * $2;";
163       break;
164     case OperationType::MAXIMUM:
165       result += "$0 = max($1, $2);";
166       break;
167     case OperationType::MINIMUM:
168       result += "$0 = min($1, $2);";
169       break;
170     case OperationType::MUL:
171       result += "$0 = $1 * $2;";
172       break;
173     case OperationType::POW:
174       result += "$0 = pow($1, $2);";
175       break;
176     case OperationType::SQUARED_DIFF:
177       result += "$0 = ($1 - $2) * ($1 - $2);";
178       break;
179     case OperationType::SUB:
180       result += "$0 = $1 - $2;";
181       break;
182     // Comparison operators
183     case OperationType::LESS:
184       result = "$0.x = $1.x < $2.x;\n";
185       result += "$0.y = $1.y < $2.y;\n";
186       result += "$0.z = $1.z < $2.z;\n";
187       result += "$0.w = $1.w < $2.w;";
188       break;
189     case OperationType::LESS_EQUAL:
190       result = "$0.x = $1.x <= $2.x;\n";
191       result += "$0.y = $1.y <= $2.y;\n";
192       result += "$0.z = $1.z <= $2.z;\n";
193       result += "$0.w = $1.w <= $2.w;";
194       break;
195     case OperationType::GREATER:
196       result = "$0.x = $1.x > $2.x;\n";
197       result += "$0.y = $1.y > $2.y;\n";
198       result += "$0.z = $1.z > $2.z;\n";
199       result += "$0.w = $1.w > $2.w;";
200       break;
201     case OperationType::GREATER_EQUAL:
202       result = "$0.x = $1.x >= $2.x;\n";
203       result += "$0.y = $1.y >= $2.y;\n";
204       result += "$0.z = $1.z >= $2.z;\n";
205       result += "$0.w = $1.w >= $2.w;";
206       break;
207     case OperationType::EQUAL:
208       result = "$0.x = $1.x == $2.x;\n";
209       result += "$0.y = $1.y == $2.y;\n";
210       result += "$0.z = $1.z == $2.z;\n";
211       result += "$0.w = $1.w == $2.w;";
212       break;
213     case OperationType::NOT_EQUAL:
214       result = "$0.x = $1.x != $2.x;\n";
215       result += "$0.y = $1.y != $2.y;\n";
216       result += "$0.z = $1.z != $2.z;\n";
217       result += "$0.w = $1.w != $2.w;";
218       break;
219     default:
220       return "Unknown operation type;";
221   }
222   if (swap_inputs) {
223     return absl::Substitute(result, result_var, input1, input0);
224   } else {
225     return absl::Substitute(result, result_var, input0, input1);
226   }
227 }
228 
229 // Creates simple two input (first input is runtime tensor and second input is
230 // scalar argument) operation, for example sub, div, pow, etc.
CreateElementwiseOneRuntimeOneScalar(const OperationDef & definition,const OperationType & op_type,float scalar_parameter,bool swap_inputs)231 ElementwiseDescriptor CreateElementwiseOneRuntimeOneScalar(
232     const OperationDef& definition, const OperationType& op_type,
233     float scalar_parameter, bool swap_inputs) {
234   ElementwiseDescriptor op_desc;
235   if (definition.precision == CalculationsPrecision::F32) {
236     op_desc.args.AddFloat("scalar", scalar_parameter);
237   } else {
238     op_desc.args.AddHalf("scalar", half(scalar_parameter));
239   }
240   op_desc.code = "FLT4 second_val = INIT_FLT4(args.scalar);\n";
241   op_desc.code += GetTwoInputCode(op_type, "out_value", "in_value",
242                                   "second_val", swap_inputs);
243   return op_desc;
244 }
245 
246 // Creates simple two input(first input is runtime tensor and second input is
247 // constant linear tensor) operation, for example sub, div and etc.
CreateElementwiseTwoInput(const GpuInfo & gpu_info,const OperationDef & definition,const OperationType & op_type,const tflite::gpu::Tensor<Linear,DataType::FLOAT32> & constant_tensor,bool swap_inputs)248 ElementwiseDescriptor CreateElementwiseTwoInput(
249     const GpuInfo& gpu_info, const OperationDef& definition,
250     const OperationType& op_type,
251     const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& constant_tensor,
252     bool swap_inputs) {
253   TensorDescriptor const_tensor_desc = CreateConstantLinearTensorDescriptor(
254       gpu_info, definition.src_tensors[0].GetDataType(), constant_tensor);
255   ElementwiseDescriptor op_desc;
256   op_desc.args.AddObject("second_tensor", std::make_unique<TensorDescriptor>(
257                                               std::move(const_tensor_desc)));
258   const std::string s_coord = constant_tensor.shape.v == 1 ? "0" : "S_COORD";
259   op_desc.code = absl::StrCat(
260       "args.second_tensor::type second_val = args.second_tensor.Read(", s_coord,
261       ");\n");
262   if (constant_tensor.shape.v == 1) {
263     op_desc.code += "  second_val.y = second_val.x;\n";
264     op_desc.code += "  second_val.z = second_val.x;\n";
265     op_desc.code += "  second_val.w = second_val.x;\n";
266   }
267   op_desc.code += GetTwoInputCode(op_type, "out_value", "in_value",
268                                   "second_val", swap_inputs);
269   return op_desc;
270 }
271 
272 // Creates simple two input(first input is runtime tensor and second input is
273 // constant HWC tensor) operation, for example sub, div and etc.
CreateElementwiseTwoInput(const GpuInfo & gpu_info,const OperationDef & definition,const OperationType & op_type,const tflite::gpu::Tensor<HWC,DataType::FLOAT32> & constant_tensor,bool swap_inputs)274 ElementwiseDescriptor CreateElementwiseTwoInput(
275     const GpuInfo& gpu_info, const OperationDef& definition,
276     const OperationType& op_type,
277     const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& constant_tensor,
278     bool swap_inputs) {
279   const BHWC shape = BHWC(1, constant_tensor.shape.h, constant_tensor.shape.w,
280                           constant_tensor.shape.c);
281   TensorDescriptor const_tensor_desc = definition.src_tensors[0];
282   auto status = const_tensor_desc.UpdateToSupportedStorageType(gpu_info, shape);
283   const_tensor_desc.UploadData(constant_tensor);
284 
285   ElementwiseDescriptor op_desc;
286   op_desc.args.AddObject("second_tensor", std::make_unique<TensorDescriptor>(
287                                               std::move(const_tensor_desc)));
288   const std::string x_coord = shape.w == 1 ? "0" : "X_COORD";
289   const std::string y_coord = shape.h == 1 ? "0" : "Y_COORD";
290   const std::string s_coord = shape.c == 1 ? "0" : "S_COORD";
291   op_desc.code = absl::StrCat(
292       "args.second_tensor::type second_val = args.second_tensor.Read(", x_coord,
293       ", ", y_coord, ", ", s_coord, ");\n");
294   if (shape.c == 1) {
295     op_desc.code += "  second_val.y = second_val.x;\n";
296     op_desc.code += "  second_val.z = second_val.x;\n";
297     op_desc.code += "  second_val.w = second_val.x;\n";
298   }
299   op_desc.code += GetTwoInputCode(op_type, "out_value", "in_value",
300                                   "second_val", swap_inputs);
301 
302   return op_desc;
303 }
304 
CreateElementwiseDesc(const GpuInfo & gpu_info,const OperationDef & definition,const OperationType & op_type,const ElementwiseAttributes & attr)305 ElementwiseDescriptor CreateElementwiseDesc(const GpuInfo& gpu_info,
306                                             const OperationDef& definition,
307                                             const OperationType& op_type,
308                                             const ElementwiseAttributes& attr) {
309   const float* scalar = absl::get_if<float>(&attr.param);
310   const auto* linear_tensor =
311       absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.param);
312   const auto* hwc_tensor =
313       absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(&attr.param);
314 
315   if (scalar) {
316     return CreateElementwiseOneRuntimeOneScalar(definition, op_type, *scalar,
317                                                 attr.runtime_tensor_is_second);
318   } else if (linear_tensor) {
319     return CreateElementwiseTwoInput(gpu_info, definition, op_type,
320                                      *linear_tensor,
321                                      attr.runtime_tensor_is_second);
322   } else if (hwc_tensor) {
323     return CreateElementwiseTwoInput(gpu_info, definition, op_type, *hwc_tensor,
324                                      attr.runtime_tensor_is_second);
325   } else {
326     return ElementwiseDescriptor();
327   }
328 }
329 
330 }  // namespace
331 
CreateElementwiseOneInput(const GpuInfo & gpu_info,const OperationDef & definition,const OperationType & op_type)332 GPUOperation CreateElementwiseOneInput(const GpuInfo& gpu_info,
333                                        const OperationDef& definition,
334                                        const OperationType& op_type) {
335   ElementwiseDescriptor op_desc;
336   op_desc.code = GetOneInputCode(gpu_info, op_type, definition.precision,
337                                  "in_value", "out_value");
338   return CreateGpuOperation(definition, std::move(op_desc));
339 }
340 
CreateElementwise(const GpuInfo & gpu_info,const OperationDef & definition,const OperationType & op_type,const ElementwiseAttributes & attr)341 GPUOperation CreateElementwise(const GpuInfo& gpu_info,
342                                const OperationDef& definition,
343                                const OperationType& op_type,
344                                const ElementwiseAttributes& attr) {
345   return CreateGpuOperation(
346       definition, CreateElementwiseDesc(gpu_info, definition, op_type, attr));
347 }
348 
CreateElementwiseTwoInput(const OperationDef & definition,const OperationType & op_type,const BHWC & shape)349 GPUOperation CreateElementwiseTwoInput(const OperationDef& definition,
350                                        const OperationType& op_type,
351                                        const BHWC& shape) {
352   ElementwiseDescriptor op_desc;
353   op_desc.code =
354       GetTwoInputCode(op_type, "out_value", "in_value", "in2_value", false);
355   return CreateGpuOperation(definition, std::move(op_desc), shape);
356 }
357 
358 namespace {
GetKernelBodyCode(const TensorDescriptor & dst_desc)359 std::string GetKernelBodyCode(const TensorDescriptor& dst_desc) {
360   std::string c;
361   c += "MAIN_FUNCTION($$0) {\n";
362   if (dst_desc.HasAxis(Axis::BATCH)) {
363     c += "  int linear_id = GLOBAL_ID_0;\n";
364     c += "  int X = linear_id / args.dst_tensor.Batch();\n";
365     c += "  int B = linear_id % args.dst_tensor.Batch();\n";
366     c += "  args.dst_tensor.SetBatchRef(B);\n";
367   } else {
368     c += "  int X = GLOBAL_ID_0;\n";
369   }
370   c += "  int Y = GLOBAL_ID_1;\n";
371   c += "  int S = GLOBAL_ID_2;\n";
372   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
373        "S >= args.dst_tensor.Slices()) return; \n";
374   c += "  args.dst_tensor::type result;\n";
375   c += "  $0\n";
376   c += "  args.dst_tensor.Write(result, X, Y, S);\n";
377   c += "} \n";
378   return c;
379 }
GetReadBroadcastedValueCode(const BHWC & src_shape,const TensorDescriptor & src_desc,const BHWC & dst_shape)380 std::string GetReadBroadcastedValueCode(const BHWC& src_shape,
381                                         const TensorDescriptor& src_desc,
382                                         const BHWC& dst_shape) {
383   const std::string x_coord = src_shape.w != dst_shape.w ? "0" : "X";
384   const std::string y_coord = src_shape.h != dst_shape.h ? "0" : "Y";
385   const std::string s_coord = src_shape.c != dst_shape.c ? "0" : "S";
386   std::string coords = absl::StrCat(x_coord, ", ", y_coord, ", ", s_coord);
387   if (src_desc.HasAxis(Axis::BATCH)) {
388     const std::string b_coord = src_shape.b != dst_shape.b ? "0" : "B";
389     coords += ", " + b_coord;
390   }
391   std::string read_value_code =
392       absl::StrCat("args.$0::type $1 = args.$0.Read(", coords, ");\n");
393   if (src_shape.c != dst_shape.c) {
394     read_value_code += "  $1.y = $1.x;\n";
395     read_value_code += "  $1.z = $1.x;\n";
396     read_value_code += "  $1.w = $1.x;\n";
397   }
398   return read_value_code;
399 }
400 }  // namespace
401 
CreateElementwiseOneInputWithBroadcast(const GpuInfo & gpu_info,const OperationDef & definition,const OperationType & op_type,const BHWC & input_shape,const BHWC & output_shape)402 GPUOperation CreateElementwiseOneInputWithBroadcast(
403     const GpuInfo& gpu_info, const OperationDef& definition,
404     const OperationType& op_type, const BHWC& input_shape,
405     const BHWC& output_shape) {
406   GPUOperation op(definition);
407   op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
408   op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
409   op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
410   std::string c;
411   c += "  " + absl::Substitute(
412                   GetReadBroadcastedValueCode(
413                       input_shape, definition.src_tensors[0], output_shape),
414                   "src_tensor", "first_value");
415   c += "  " + GetOneInputCode(gpu_info, op_type, definition.precision,
416                               "first_value", "result");
417   op.code_ = absl::Substitute(GetKernelBodyCode(definition.dst_tensors[0]), c);
418   return op;
419 }
420 
CreateElementwiseWithBroadcast(const GpuInfo & gpu_info,const OperationDef & definition,const OperationType & op_type,const ElementwiseAttributes & attr,const BHWC & input_shape,const BHWC & output_shape)421 GPUOperation CreateElementwiseWithBroadcast(const GpuInfo& gpu_info,
422                                             const OperationDef& definition,
423                                             const OperationType& op_type,
424                                             const ElementwiseAttributes& attr,
425                                             const BHWC& input_shape,
426                                             const BHWC& output_shape) {
427   ElementwiseDescriptor op_desc =
428       CreateElementwiseDesc(gpu_info, definition, op_type, attr);
429 
430   GPUOperation op(definition);
431   op.args_ = std::move(op_desc.args);
432   op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
433   op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
434   op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
435   std::string c;
436   c += "  " + absl::Substitute(
437                   GetReadBroadcastedValueCode(
438                       input_shape, definition.src_tensors[0], output_shape),
439                   "src_tensor", "first_value");
440   c += "  " + absl::StrReplaceAll(op_desc.code, {{"in_value", "first_value"},
441                                                  {"out_value", "result"},
442                                                  {"X_COORD", "X"},
443                                                  {"Y_COORD", "Y"},
444                                                  {"S_COORD", "S"},
445                                                  {"B_COORD", "B"}});
446   op.code_ = absl::Substitute(GetKernelBodyCode(definition.dst_tensors[0]), c);
447   return op;
448 }
449 
CreateElementwiseTwoInputWithBroadcast(const OperationDef & definition,const OperationType & op_type,const BHWC & first_input_shape,const BHWC & second_input_shape,const BHWC & output_shape)450 GPUOperation CreateElementwiseTwoInputWithBroadcast(
451     const OperationDef& definition, const OperationType& op_type,
452     const BHWC& first_input_shape, const BHWC& second_input_shape,
453     const BHWC& output_shape) {
454   GPUOperation op(definition);
455   op.AddSrcTensor("src0_tensor", definition.src_tensors[0]);
456   op.AddSrcTensor("src1_tensor", definition.src_tensors[1]);
457   op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
458   op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
459   std::string c;
460   c += "  " + absl::Substitute(GetReadBroadcastedValueCode(
461                                    first_input_shape, definition.src_tensors[0],
462                                    output_shape),
463                                "src0_tensor", "first_value");
464   c += "  " + absl::Substitute(GetReadBroadcastedValueCode(
465                                    second_input_shape,
466                                    definition.src_tensors[1], output_shape),
467                                "src1_tensor", "second_value");
468   c += "  " +
469        GetTwoInputCode(op_type, "result", "first_value", "second_value", false);
470   op.code_ = absl::Substitute(GetKernelBodyCode(definition.dst_tensors[0]), c);
471   return op;
472 }
473 
474 }  // namespace gpu
475 }  // namespace tflite
476