1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
17
18 #include <any>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <variant>
23
24 #include "absl/memory/memory.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/lite/delegates/gpu/common/status.h"
27 #include "tensorflow/lite/delegates/gpu/common/types.h"
28
29 namespace tflite {
30 namespace gpu {
31 namespace gl {
32 namespace {
33
34 class ElementwiseOneArgument : public NodeShader {
35 public:
ElementwiseOneArgument(OperationType operation_type)36 explicit ElementwiseOneArgument(OperationType operation_type)
37 : operation_type_(operation_type) {}
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const38 absl::Status GenerateCode(const GenerationContext& ctx,
39 GeneratedCode* generated_code) const final {
40 std::string source;
41 switch (operation_type_) {
42 case OperationType::ABS:
43 source = "value_0 = abs(value_0);";
44 break;
45 case OperationType::COS:
46 source = "value_0 = cos(value_0);";
47 break;
48 case OperationType::COPY:
49 source = "value_0 = value_0;";
50 break;
51 case OperationType::ELU:
52 source = R"(
53 value_0.x = value_0.x < 0.0 ? exp(value_0.x) - 1.0 : value_0.x;
54 value_0.y = value_0.y < 0.0 ? exp(value_0.y) - 1.0 : value_0.y;
55 value_0.z = value_0.z < 0.0 ? exp(value_0.z) - 1.0 : value_0.z;
56 value_0.w = value_0.w < 0.0 ? exp(value_0.w) - 1.0 : value_0.w;
57 )";
58 break;
59 case OperationType::EXP:
60 source = "value_0 = exp(value_0);";
61 break;
62 case tflite::gpu::OperationType::FLOOR:
63 source = "value_0 = floor(value_0);";
64 break;
65 case OperationType::HARD_SWISH:
66 source =
67 "value_0 *= clamp(value_0 / 6.0 + vec4(0.5), vec4(0.0), "
68 "vec4(1.0));";
69 break;
70 case OperationType::LOG:
71 source = R"(
72 const float nan = normalize(vec4(0, 0, 0, 0)).x;
73 value_0.x = value_0.x > 0.0 ? log(value_0.x) : nan;
74 value_0.y = value_0.y > 0.0 ? log(value_0.y) : nan;
75 value_0.z = value_0.z > 0.0 ? log(value_0.z) : nan;
76 value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan;
77 )";
78 break;
79 case OperationType::NEG:
80 source = "value_0 = -(value_0);";
81 break;
82 case OperationType::RSQRT:
83 source = R"(
84 const float nan = normalize(vec4(0, 0, 0, 0)).x;
85 value_0.x = value_0.x > 0.0 ? 1.0 / sqrt(value_0.x) : nan;
86 value_0.y = value_0.y > 0.0 ? 1.0 / sqrt(value_0.y) : nan;
87 value_0.z = value_0.z > 0.0 ? 1.0 / sqrt(value_0.z) : nan;
88 value_0.w = value_0.w > 0.0 ? 1.0 / sqrt(value_0.w) : nan;
89 )";
90 break;
91 case OperationType::SIGMOID:
92 source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));";
93 break;
94 case OperationType::SIN:
95 source = "value_0 = sin(value_0);";
96 break;
97 case OperationType::SQRT:
98 source = R"(
99 const float nan = normalize(vec4(0, 0, 0, 0)).x;
100 value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan;
101 value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan;
102 value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan;
103 value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan;
104 )";
105 break;
106 case OperationType::SQUARE:
107 source = "value_0 = value_0 * value_0;";
108 break;
109 case OperationType::TANH:
110 source = "value_0 = tanh(value_0);";
111 break;
112 default:
113 return absl::InvalidArgumentError(
114 "Incorrect elementwise operation type.");
115 }
116 *generated_code = {
117 /*parameters=*/{},
118 /*objects=*/{},
119 /*shared_variables=*/{},
120 /*workload=*/uint3(),
121 /*workgroup=*/uint3(),
122 source,
123 /*input=*/IOStructure::AUTO,
124 /*output=*/IOStructure::AUTO,
125 };
126 return absl::OkStatus();
127 }
128
129 private:
130 OperationType operation_type_;
131 };
132
133 class ElementwiseTwoArguments : public NodeShader {
134 public:
ElementwiseTwoArguments(OperationType operation_type)135 explicit ElementwiseTwoArguments(OperationType operation_type)
136 : operation_type_(operation_type) {}
137
IsElementwiseSupported(const GenerationContext & ctx) const138 inline bool IsElementwiseSupported(const GenerationContext& ctx) const {
139 return ctx.input_shapes.size() == 2 &&
140 ctx.input_shapes[0] == ctx.input_shapes[1];
141 }
142
IsBroadcastSupported(const GenerationContext & ctx) const143 inline bool IsBroadcastSupported(const GenerationContext& ctx) const {
144 return ctx.input_shapes.size() == 2 && ctx.input_shapes[1][1] == 1 &&
145 ctx.input_shapes[1][2] == 1 &&
146 ctx.input_shapes[0][3] == ctx.input_shapes[1][3];
147 }
148
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const149 absl::Status GenerateCode(const GenerationContext& ctx,
150 GeneratedCode* generated_code) const final {
151 std::vector<Variable> parameters;
152 std::vector<std::pair<std::string, Object>> objects;
153 std::string argument0, argument1;
154 if (IsElementwiseSupported(ctx)) {
155 argument0 = "value_0";
156 argument1 = "value_1";
157 } else if (IsBroadcastSupported(ctx)) {
158 argument0 = "$input_data_0[gid.x, gid.y, gid.z]$";
159 argument1 = "$input_data_1[0, 0, gid.z]$";
160 } else { // Scalar of const vector case
161 const auto& attr =
162 std::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
163 const auto* tensor =
164 std::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
165 const auto* scalar = std::get_if<float>(&attr.param);
166 if (!tensor && !scalar) {
167 return absl::InvalidArgumentError(
168 "Couldn't read scalar of const vector data from the attributes.");
169 }
170
171 argument0 = "value_0";
172 if (tensor) {
173 argument1 = "$const_data[gid.z]$";
174 objects.push_back({"const_data", MakeReadonlyObject(tensor->data)});
175 } else {
176 argument1 = "vec4($const_data$)";
177 parameters.push_back({"const_data", *scalar});
178 }
179 }
180
181 std::string source;
182 switch (operation_type_) {
183 case OperationType::DIV: {
184 source = "value_0 = $0/$1;";
185 break;
186 }
187 case tflite::gpu::OperationType::FLOOR_DIV:
188 source = "value_0 = floor($0 / $1);";
189 break;
190 case tflite::gpu::OperationType::FLOOR_MOD:
191 source = "value_0 = $0 - floor($0 / $1) * $1;";
192 break;
193 case OperationType::MAXIMUM: {
194 source = "value_0 = max($0, $1);";
195 break;
196 }
197 case OperationType::MINIMUM: {
198 source = "value_0 = min($0, $1);";
199 break;
200 }
201 case OperationType::SQUARED_DIFF: {
202 source = "value_0 = ($0 - $1) * ($0 - $1);";
203 break;
204 }
205 case OperationType::SUB: {
206 source = "value_0 = $0 - $1;";
207 break;
208 }
209 case OperationType::POW: {
210 source = "value_0 = pow($0, $1);";
211 break;
212 }
213 default:
214 return absl::InvalidArgumentError(
215 "Incorrect elementwise with scalar operation type.");
216 }
217 source = absl::Substitute(source, argument0, argument1);
218 *generated_code = {
219 /*parameters=*/std::move(parameters),
220 /*objects=*/std::move(objects),
221 /*shared_variables=*/{},
222 /*workload=*/uint3(),
223 /*workgroup=*/uint3(),
224 /*source_code=*/source,
225 /*input=*/IOStructure::AUTO,
226 /*output=*/IOStructure::AUTO,
227 };
228 return absl::OkStatus();
229 }
230
231 private:
232 OperationType operation_type_;
233 };
234
235 } // namespace
236
NewElementwiseNodeShader(OperationType operation_type)237 std::unique_ptr<NodeShader> NewElementwiseNodeShader(
238 OperationType operation_type) {
239 switch (operation_type) {
240 case OperationType::ABS:
241 case OperationType::COS:
242 case OperationType::COPY:
243 case OperationType::ELU:
244 case OperationType::EXP:
245 case OperationType::FLOOR:
246 case OperationType::HARD_SWISH:
247 case OperationType::LOG:
248 case OperationType::NEG:
249 case OperationType::RSQRT:
250 case OperationType::SIGMOID:
251 case OperationType::SIN:
252 case OperationType::SQRT:
253 case OperationType::SQUARE:
254 case OperationType::TANH:
255 return std::make_unique<ElementwiseOneArgument>(operation_type);
256 case OperationType::DIV:
257 case OperationType::FLOOR_DIV:
258 case OperationType::FLOOR_MOD:
259 case OperationType::MAXIMUM:
260 case OperationType::MINIMUM:
261 case OperationType::POW:
262 case OperationType::SQUARED_DIFF:
263 case OperationType::SUB:
264 return std::make_unique<ElementwiseTwoArguments>(operation_type);
265 default:
266 return nullptr;
267 }
268 }
269
270 } // namespace gl
271 } // namespace gpu
272 } // namespace tflite
273