/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

namespace tflite {
namespace gpu {
namespace {
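// Divides each active grid dimension by the work group size (rounding up) and
// permutes the resulting counts according to work_group_launch_order.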
int3 GetWorkGroupsCountInternal(int grid_dimension, const int3& grid_size,
                                const int3& work_group_size,
                                const int3& work_group_launch_order) {
  int3 work_groups_count;
  if (grid_dimension == 1) {
    work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
    work_groups_count.y = 1;
    work_groups_count.z = 1;
  } else if (grid_dimension == 2) {
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = 1;
  } else {  // grid_dimension == 3
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    wgs.z = DivideRoundUp(grid_size.z, work_group_size.z);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = wgs[work_group_launch_order[2]];
  }
  return work_groups_count;
}

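// Builds the kernel source for a standalone elementwise operation: global ids
// are mapped to (X, Y, Z) destination coordinates (plus batch, if present),
// a value is read from src_tensor and written to dst_tensor. The actual
// elementwise code is linked in later by args_.Compile() in AssembleCode().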
std::string GetElementWiseCode(const OperationDef& op_def) {
  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
    c += " int linear_id = GLOBAL_ID_0;\n";
    c += " int X = linear_id / args.dst_tensor.Batch();\n";
    c += " int B = linear_id % args.dst_tensor.Batch();\n";
    c += " args.dst_tensor.SetBatchRef(B);\n";
    c += " args.src_tensor.SetBatchRef(B);\n";
  } else {
    c += " int X = GLOBAL_ID_0;\n";
  }
  c += " int Y = GLOBAL_ID_1;\n";
  c += " int Z = GLOBAL_ID_2;\n";
  c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
       "Z >= args.dst_tensor.Slices()) return; \n";
  c += " args.src_tensor::type src = args.src_tensor.Read(X, Y, Z);\n";
  c += " args.dst_tensor.Write(src, X, Y, Z);\n";
  c += "} \n";
  return c;
}

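// Returns true if the second input with the given shape has to be broadcast,
// i.e. any of its W/H/C dimensions (and B, when the descriptor has a batch
// axis) equals 1.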
bool NeedsBroadcast(const TensorDescriptor& desc, const BHWC& shape) {
  bool needs_broadcast = shape.w == 1 || shape.h == 1 || shape.c == 1;
  if (desc.HasAxis(Axis::BATCH)) {
    needs_broadcast = needs_broadcast || shape.b == 1;
  }
  return needs_broadcast;
}

}  // namespace

DataType OperationDef::GetDataType() const {
  return DeduceDataTypeFromPrecision(precision);
}

DataType OperationDef::GetPrimaryDataType() const {
  return src_tensors[0].GetDataType();
}
TensorStorageType OperationDef::GetPrimaryStorageType() const {
  return src_tensors[0].GetStorageType();
}

bool OperationDef::IsBatchSupported() const {
  for (const auto& src : src_tensors) {
    if (src.HasAxis(Axis::BATCH)) {
      return true;
    }
  }
  for (const auto& dst : dst_tensors) {
    if (dst.HasAxis(Axis::BATCH)) {
      return true;
    }
  }
  return false;
}

GPUOperation::GPUOperation(const OperationDef& definition)
    : definition_(definition) {}

void GPUOperation::SetSrc(GpuSpatialTensor* ptr, int index) {
  if (index >= src_.size()) {
    src_.resize(index + 1, nullptr);
  }
  src_[index] = ptr;
}

void GPUOperation::SetDst(GpuSpatialTensor* ptr, int index) {
  if (index >= dst_.size()) {
    dst_.resize(index + 1, nullptr);
  }
  dst_[index] = ptr;
}

GPUOperation::GPUOperation(GPUOperation&& operation)
    : args_(std::move(operation.args_)),
      code_(std::move(operation.code_)),
      work_group_size_(operation.work_group_size_),
      compiler_options_(std::move(operation.compiler_options_)),
      tensor_to_grid_(operation.tensor_to_grid_),
      flops_(operation.flops_),
      const_args_size_(operation.const_args_size_),
      definition_(std::move(operation.definition_)),
      src_(std::move(operation.src_)),
      dst_(std::move(operation.dst_)),
      grid_dimension_(operation.grid_dimension_),
      work_group_launch_order_(operation.work_group_launch_order_),
      grid_size_(operation.grid_size_),
      src_tensors_names_(std::move(operation.src_tensors_names_)),
      dst_tensors_names_(std::move(operation.dst_tensors_names_)),
      work_groups_count_(operation.work_groups_count_),
      elementwise_(operation.elementwise_),
      elementwise_inputs_(operation.elementwise_inputs_),
      second_elementwise_tensor_name_(
          operation.second_elementwise_tensor_name_),
      linkable_count_(operation.linkable_count_),
      elementwise_code_(std::move(operation.elementwise_code_)) {}

GPUOperation& GPUOperation::operator=(GPUOperation&& operation) {
  if (this != &operation) {
    args_ = std::move(operation.args_);
    code_ = std::move(operation.code_);
    std::swap(work_group_size_, operation.work_group_size_);
    compiler_options_ = std::move(operation.compiler_options_);
    tensor_to_grid_ = operation.tensor_to_grid_;
    flops_ = operation.flops_;
    const_args_size_ = operation.const_args_size_;
    definition_ = std::move(operation.definition_);
    src_ = std::move(operation.src_);
    dst_ = std::move(operation.dst_);
    std::swap(grid_dimension_, operation.grid_dimension_);
    std::swap(work_group_launch_order_, operation.work_group_launch_order_);
    std::swap(grid_size_, operation.grid_size_);
    src_tensors_names_ = std::move(operation.src_tensors_names_);
    dst_tensors_names_ = std::move(operation.dst_tensors_names_);
    std::swap(work_groups_count_, operation.work_groups_count_);
    elementwise_ = operation.elementwise_;
    std::swap(elementwise_inputs_, operation.elementwise_inputs_);
    std::swap(second_elementwise_tensor_name_,
              operation.second_elementwise_tensor_name_);
    std::swap(linkable_count_, operation.linkable_count_);
    elementwise_code_ = std::move(operation.elementwise_code_);
  }
  return *this;
}

//     input               input
//       |                   |
//     elem0                 |
//       |         -->     elem
//     elem1                 |
//       |                   |
//     output              output
// GPUOperation* operation is elem1
// *this is elem0
absl::Status GPUOperation::FuseSimpleElemWithSimpleElem(
    const GpuInfo& gpu_info, GPUOperation* operation) {
  GPUOperation& elem0 = *this;
  GPUOperation& elem1 = *operation;
  elem0.definition_.dst_tensors[0] = elem1.definition_.dst_tensors[0];
  const auto link_value_type = elem1.definition_.src_tensors[0].GetDataType();
  elem0.linkable_count_ += (elem1.linkable_count_ + 1);
  std::string unique_postfix = absl::StrCat("_link", elem0.linkable_count_);
  elem1.args_.RenameArgs(unique_postfix, &elem1.elementwise_code_);
  const std::string link_value_name = "interm_value" + unique_postfix;
  const std::string value_declaration =
      "\n" + GetTypeDeclaration(gpu_info, link_value_type, 4) + " " +
      link_value_name + ";\n";
  elem1.elementwise_code_ = absl::StrReplaceAll(
      elem1.elementwise_code_, {{"in_value", link_value_name}});
  elem0.elementwise_code_ = absl::StrReplaceAll(
      elem0.elementwise_code_, {{"out_value", link_value_name}});
  elem0.elementwise_code_ =
      absl::Substitute(elem0.elementwise_code_, value_declaration);
  elem0.elementwise_code_ += "\n" + elem1.elementwise_code_;
  return args_.Merge(std::move(elem1.args_), unique_postfix);
}

//     input               input
//     /    \                 |
//  elem0    |                |
//     \    /       -->     elem
//     elem1                  |
//       |                    |
//     output               output
// GPUOperation* operation is elem1
// *this is elem0
absl::Status GPUOperation::Fuse2InputElemWithSimpleElemAsFirstInput(
    const GpuInfo& gpu_info, GPUOperation* operation) {
  GPUOperation& elem0 = *this;
  GPUOperation& elem1 = *operation;
  const auto link_value_type = elem0.definition_.dst_tensors[0].GetDataType();
  elem0.definition_.dst_tensors[0] = elem1.definition_.dst_tensors[0];
  elem0.linkable_count_ += (elem1.linkable_count_ + 1);
  std::string unique_postfix = absl::StrCat("_link", elem0.linkable_count_);
  elem1.args_.RenameArgs(unique_postfix, &elem1.elementwise_code_);
  const std::string link_value_name = "interm_value" + unique_postfix;
  const std::string value_declaration =
      "\n" + GetTypeDeclaration(gpu_info, link_value_type, 4) + " " +
      link_value_name + ";\n";
  elem0.elementwise_code_ = absl::StrReplaceAll(
      elem0.elementwise_code_, {{"out_value", link_value_name}});
  elem0.elementwise_code_ =
      absl::Substitute(elem0.elementwise_code_, value_declaration);
  elem1.elementwise_code_ = absl::StrReplaceAll(elem1.elementwise_code_,
                                                {{"in_value", link_value_name},
                                                 {"READ_SECOND_VALUE", ""},
                                                 {"in2_value", "in_value"}});
  elem0.elementwise_code_ += "\n" + elem1.elementwise_code_;
  return elem0.args_.Merge(std::move(elem1.args_), unique_postfix,
                           {elem1.second_elementwise_tensor_name_});
}

//     input               input
//     /    \                 |
//    |    elem0              |
//     \    /       -->     elem
//     elem1                  |
//       |                    |
//     output               output
// GPUOperation* operation is elem1
// *this is elem0
absl::Status GPUOperation::Fuse2InputElemWithSimpleElemAsSecondInput(
    const GpuInfo& gpu_info, GPUOperation* operation) {
  GPUOperation& elem0 = *this;
  GPUOperation& elem1 = *operation;
  const auto link_value_type = elem0.definition_.dst_tensors[0].GetDataType();
  elem0.definition_.dst_tensors[0] = elem1.definition_.dst_tensors[0];
  elem0.linkable_count_ += (elem1.linkable_count_ + 1);
  std::string unique_postfix = absl::StrCat("_link", elem0.linkable_count_);
  elem1.args_.RenameArgs(unique_postfix, &elem1.elementwise_code_);
  const std::string link_value_name = "interm_value" + unique_postfix;
  const std::string value_declaration =
      "\n" + GetTypeDeclaration(gpu_info, link_value_type, 4) + " " +
      link_value_name + ";\n";
  elem0.elementwise_code_ = absl::StrReplaceAll(
      elem0.elementwise_code_, {{"out_value", link_value_name}});
  elem0.elementwise_code_ =
      absl::Substitute(elem0.elementwise_code_, value_declaration);
  elem1.elementwise_code_ = absl::StrReplaceAll(
      elem1.elementwise_code_,
      {{"in2_value", link_value_name}, {"READ_SECOND_VALUE", ""}});
  elem0.elementwise_code_ += "\n" + elem1.elementwise_code_;
  return elem0.args_.Merge(std::move(elem1.args_), unique_postfix,
                           {elem1.second_elementwise_tensor_name_});
}

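// Appends an elementwise operation after this one: the appended operation's
// arguments get a unique postfix, its code is chained so that this operation's
// out_value becomes its in_value, and its arguments and extra source tensors
// are merged into this operation.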
absl::Status GPUOperation::AddOperation(const GpuInfo& gpu_info,
                                        GPUOperation* operation) {
  const auto prev_type = definition_.dst_tensors[0].GetDataType();
  definition_.dst_tensors[0] = operation->definition_.dst_tensors[0];
  linkable_count_ += (operation->linkable_count_ + 1);
  std::string code = operation->elementwise_code_;
  std::string unique_postfix = absl::StrCat("_link", linkable_count_);
  operation->args_.RenameArgs(unique_postfix, &code);
  operation->second_elementwise_tensor_name_ += unique_postfix;
  if (elementwise_code_.empty()) {
    elementwise_code_ = code;
    elementwise_inputs_ = operation->elementwise_inputs_;
    second_elementwise_tensor_name_ =
        operation->second_elementwise_tensor_name_;
  } else {
    if (operation->elementwise_inputs_ == 2) {
      if (elementwise_inputs_ == 2) {
        // Fusing two 2-input elementwise ops would produce a 3-input
        // elementwise op, but currently at most 2-input elementwise ops are
        // supported, so resolve one input here.
        RETURN_IF_ERROR(ResolveSecondElementwiseInput());
      }
      second_elementwise_tensor_name_ =
          operation->second_elementwise_tensor_name_;
      elementwise_inputs_ = 2;
    }
    const std::string new_value_name = "interm_value" + unique_postfix;
    code = absl::StrReplaceAll(code, {{"in_value", new_value_name}});
    elementwise_code_ =
        absl::StrReplaceAll(elementwise_code_, {{"out_value", new_value_name}});
    const std::string out_var_declaration =
        "\n" + GetTypeDeclaration(gpu_info, prev_type, 4) + " " +
        new_value_name + ";\n";
    elementwise_code_ =
        absl::Substitute(elementwise_code_, out_var_declaration);
    elementwise_code_ = elementwise_code_ + "\n" + code;
  }
  RETURN_IF_ERROR(args_.Merge(std::move(operation->args_), unique_postfix));
  for (int i = 0; i < operation->src_tensors_names_.size(); ++i) {
    definition_.src_tensors.push_back(
        operation->definition_.src_tensors[i + 1]);
    src_tensors_names_.push_back(operation->src_tensors_names_[i] +
                                 unique_postfix);
  }
  for (int i = 0; i < operation->dst_tensors_names_.size(); ++i) {
    dst_tensors_names_.push_back(operation->dst_tensors_names_[i] +
                                 unique_postfix);
  }
  return absl::OkStatus();
}

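// Converts a 2-input elementwise operation into a 1-input one by emitting an
// explicit read of the second tensor in place of the READ_SECOND_VALUE marker
// and redirecting in2_value to the value read.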
absl::Status GPUOperation::ResolveSecondElementwiseInput() {
  if (elementwise_inputs_ != 2) {
    return absl::FailedPreconditionError(
        "Can not apply ResolveSecondElementwiseInput for non 2 input "
        "elementwise");
  }
  TensorDescriptor* tensor_desc;
  RETURN_IF_ERROR(
      GetTensorDescriptor(second_elementwise_tensor_name_, &tensor_desc));
  std::string coords = "X_COORD, Y_COORD, S_COORD";
  if (tensor_desc->HasAxis(Axis::BATCH)) {
    coords += ", B_COORD";
  }
  const std::string read_code = "args." + second_elementwise_tensor_name_ +
                                "::type second_value = args." +
                                second_elementwise_tensor_name_ + ".Read(" +
                                coords + ");\n";
  elementwise_code_ = absl::StrReplaceAll(
      elementwise_code_,
      {{"in2_value", "second_value"}, {"READ_SECOND_VALUE", read_code}});
  elementwise_inputs_ = 1;
  return absl::OkStatus();
}

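// Looks up the descriptor of the tensor with the given name among the source
// and destination tensors. For elementwise operations index 0 of the
// definition is the implicitly linked tensor, so named tensors start at 1.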
absl::Status GPUOperation::GetTensorDescriptor(const std::string& tensor_name,
                                               TensorDescriptor** result) {
  for (int i = 0; i < src_tensors_names_.size(); ++i) {
    if (src_tensors_names_[i] == tensor_name) {
      int index = elementwise_ ? i + 1 : i;
      *result = &definition_.src_tensors[index];
      return absl::OkStatus();
    }
  }
  for (int i = 0; i < dst_tensors_names_.size(); ++i) {
    if (dst_tensors_names_[i] == tensor_name) {
      int index = elementwise_ ? i + 1 : i;
      *result = &definition_.dst_tensors[index];
      return absl::OkStatus();
    }
  }
  return absl::NotFoundError("Can not find tensor with this name");
}

void GPUOperation::AddSrcTensor(const std::string& tensor_name,
                                const TensorDescriptor& desc) {
  src_tensors_names_.push_back(tensor_name);
  auto desc_new = std::make_unique<TensorDescriptor>(desc);
  args_.AddObjectRef(tensor_name, AccessType::READ, std::move(desc_new));
}

void GPUOperation::AddSrcBuffer(const std::string& buffer_name,
                                const BufferDescriptor& desc) {
  src_tensors_names_.push_back(buffer_name);
  auto desc_new = std::make_unique<BufferDescriptor>(desc);
  args_.AddObjectRef(buffer_name, AccessType::READ, std::move(desc_new));
}

void GPUOperation::AddDstTensor(const std::string& tensor_name,
                                const TensorDescriptor& desc) {
  dst_tensors_names_.push_back(tensor_name);
  auto desc_new = std::make_unique<TensorDescriptor>(desc);
  args_.AddObjectRef(tensor_name, AccessType::WRITE, std::move(desc_new));
}

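// Finalizes the kernel source. For elementwise operations the implicit
// src_tensor/dst_tensor argument refs and the generic elementwise main
// function are created first; then the arguments are compiled into code_ with
// elementwise_code_ linked to the first destination tensor.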
absl::Status GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
  if (elementwise_inputs_ == 2) {
    RETURN_IF_ERROR(ResolveSecondElementwiseInput());
  }
  if (elementwise_) {
    src_tensors_names_.insert(src_tensors_names_.begin(), "src_tensor");
    args_.AddObjectRef(
        "src_tensor", AccessType::READ,
        std::make_unique<TensorDescriptor>(definition_.src_tensors[0]));

    dst_tensors_names_.insert(dst_tensors_names_.begin(), "dst_tensor");
    args_.AddObjectRef(
        "dst_tensor", AccessType::WRITE,
        std::make_unique<TensorDescriptor>(definition_.dst_tensors[0]));

    code_ = GetElementWiseCode(definition_);
  }
  RETURN_IF_ERROR(args_.Compile(
      gpu_info, {{dst_tensors_names_[0], elementwise_code_}}, &code_));
  CalculateConstArgsSize();
  return absl::OkStatus();
}

void GPUOperation::RecalculateWorkGroupsCount() {
  work_groups_count_ = GetWorkGroupsCountInternal(
      grid_dimension_, grid_size_, work_group_size_, work_group_launch_order_);
}

void GPUOperation::CalculateConstArgsSize() {
  const_args_size_ = 0;
  for (const auto& obj : args_.GetObjects()) {
    const_args_size_ += obj.second->GetSizeInBytes();
  }
}

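// Enumerates candidate work group sizes for tuning and computes the matching
// work group counts for the current grid.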
void GPUOperation::GetPossibleDispatches(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info,
    std::vector<DispatchInfo>* dispatches) const {
  std::vector<int3> work_group_sizes;
  GetPossibleKernelWorkGroups(tuning_type, gpu_info, kernel_info,
                              &work_group_sizes);
  dispatches->resize(work_group_sizes.size());
  for (int i = 0; i < work_group_sizes.size(); ++i) {
    auto& dispatch_info = (*dispatches)[i];
    dispatch_info.work_group_size = work_group_sizes[i];
    dispatch_info.work_groups_count = GetWorkGroupsCountInternal(
        grid_dimension_, grid_size_, work_group_sizes[i],
        work_group_launch_order_);
  }
}

void GPUOperation::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
                        work_groups);
}

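// Maps the shape of the first destination tensor to a dispatch grid according
// to tensor_to_grid_; otherwise returns the explicitly set grid_size_.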
int3 GPUOperation::GetGridSize() const {
  if (tensor_to_grid_ == TensorToGrid::kWBToX_HDToY_SToZ) {
    const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
    const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
    const int grid_z = dst_[0]->Slices();
    return int3(grid_x, grid_y, grid_z);
  }
  if (tensor_to_grid_ == TensorToGrid::kWBToX_HDToY_ZIs1) {
    const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
    const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
    const int grid_z = 1;
    return int3(grid_x, grid_y, grid_z);
  }
  if (tensor_to_grid_ == TensorToGrid::kWBToX_HToY_DToZ) {
    const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
    const int grid_y = dst_[0]->Height();
    const int grid_z = dst_[0]->Depth();
    return int3(grid_x, grid_y, grid_z);
  }
  if (tensor_to_grid_ == TensorToGrid::kBToX_YIs1_ZIs1) {
    const int grid_x = dst_[0]->Batch();
    const int grid_y = 1;
    const int grid_z = 1;
    return int3(grid_x, grid_y, grid_z);
  }
  return grid_size_;
}

GPUOperation CreateGpuOperation(const OperationDef& definition,
                                ElementwiseDescriptor&& descriptor) {
  const BHWC second_shape(2, 2, 2, 2);  // dummy non-broadcasted shape
  return CreateGpuOperation(definition, std::move(descriptor), second_shape);
}

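// Creates an elementwise GPUOperation. If the second input must be broadcast
// (size-1 dimensions in second_shape), its read is materialized immediately
// with the broadcast coordinates; otherwise the read is deferred behind the
// READ_SECOND_VALUE marker so the operation can still participate in fusion.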
GPUOperation CreateGpuOperation(const OperationDef& definition,
                                ElementwiseDescriptor&& descriptor,
                                const BHWC& second_shape) {
  GPUOperation op(definition);
  op.elementwise_code_ = std::move(descriptor.code);
  op.elementwise_ = true;
  if (definition.src_tensors.size() > 1 &&
      op.elementwise_code_.find("in2_value") != std::string::npos) {
    const auto second_tensor_def = definition.src_tensors[1];
    if (NeedsBroadcast(second_tensor_def, second_shape)) {
      const std::string x_coord = second_shape.w == 1 ? "0" : "X_COORD";
      const std::string y_coord = second_shape.h == 1 ? "0" : "Y_COORD";
      const std::string s_coord = second_shape.c == 1 ? "0" : "S_COORD";
      std::string coords = absl::StrCat(x_coord, ", ", y_coord, ", ", s_coord);
      if (second_tensor_def.HasAxis(Axis::BATCH)) {
        const std::string b_coord = second_shape.b == 1 ? "0" : "B_COORD";
        coords += ", " + b_coord;
      }
      std::string read_value_code = absl::StrCat(
          "args.src_tensor_1::type in2_value = args.src_tensor_1.Read(",
          coords, ");\n");
      if (second_shape.c == 1) {
        read_value_code += " in2_value.y = in2_value.x;\n";
        read_value_code += " in2_value.z = in2_value.x;\n";
        read_value_code += " in2_value.w = in2_value.x;\n";
      }
      op.elementwise_code_ =
          "$0{" + read_value_code + op.elementwise_code_ + "}";
      op.elementwise_code_ = absl::StrReplaceAll(
          op.elementwise_code_, {{"in2_value", "second_value"}});
      op.elementwise_inputs_ = 1;
    } else {
      op.elementwise_code_ =
          "$0{READ_SECOND_VALUE" + op.elementwise_code_ + "}";
      op.elementwise_inputs_ = 2;
      op.second_elementwise_tensor_name_ = "src_tensor_1";
    }
  } else {
    op.elementwise_code_ = "$0{" + op.elementwise_code_ + "}";
    op.elementwise_inputs_ = 1;
  }
  op.args_ = std::move(descriptor.args);
  for (int i = 1; i < definition.src_tensors.size(); ++i) {
    const std::string tensor_name = "src_tensor_" + std::to_string(i);
    op.AddSrcTensor(tensor_name, definition.src_tensors[i]);
  }
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
  return op;
}

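//      input               input
//      /   \                  |
//  elem0   elem1              |
//      \   /        -->     elem
//    elem_root                |
//        |                    |
//      output               output
// Fuses two single-input elementwise operations that share an input (elem0
// and elem1) with a 2-input elementwise operation (elem_root) that consumes
// their outputs, producing a single elementwise operation.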
absl::Status Fuse2InputElemWith2SimpleElem(const GpuInfo& gpu_info,
                                           GPUOperation&& elem0,
                                           GPUOperation&& elem1,
                                           GPUOperation&& elem_root,
                                           GPUOperation* result) {
  int linkable_count = std::max(elem0.linkable_count_, elem1.linkable_count_);
  linkable_count = std::max(linkable_count, elem_root.linkable_count_);
  linkable_count += 1;

  std::string unique_postfix = absl::StrCat("_link", linkable_count);
  elem0.args_.RenameArgs(unique_postfix + "l", &elem0.elementwise_code_);
  elem1.args_.RenameArgs(unique_postfix + "r", &elem1.elementwise_code_);
  elem_root.args_.RenameArgs(unique_postfix, &elem_root.elementwise_code_);
  const std::string link_left_value_name = "interm_value_left" + unique_postfix;
  const std::string link_right_value_name =
      "interm_value_right" + unique_postfix;
  const auto link_left_value_type =
      elem0.definition_.dst_tensors[0].GetDataType();
  const std::string left_value_declaration =
      "\n" + GetTypeDeclaration(gpu_info, link_left_value_type, 4) + " " +
      link_left_value_name + ";\n";
  const auto link_right_value_type =
      elem1.definition_.dst_tensors[0].GetDataType();
  const std::string right_value_declaration =
      "\n" + GetTypeDeclaration(gpu_info, link_right_value_type, 4) + " " +
      link_right_value_name + ";\n";
  elem0.elementwise_code_ = absl::StrReplaceAll(
      elem0.elementwise_code_, {{"out_value", link_left_value_name}});
  elem1.elementwise_code_ = absl::StrReplaceAll(
      elem1.elementwise_code_, {{"out_value", link_right_value_name}});
  elem0.elementwise_code_ =
      absl::Substitute(elem0.elementwise_code_, left_value_declaration);
  elem1.elementwise_code_ =
      absl::Substitute(elem1.elementwise_code_, right_value_declaration);
  elem_root.elementwise_code_ = absl::StrReplaceAll(
      elem_root.elementwise_code_, {{"in_value", link_left_value_name},
                                    {"READ_SECOND_VALUE", ""},
                                    {"in2_value", link_right_value_name}});

  OperationDef new_definition = elem0.definition_;
  new_definition.dst_tensors[0] = elem_root.definition_.dst_tensors[0];

  *result = GPUOperation(new_definition);
  result->elementwise_ = true;
  result->elementwise_inputs_ = 1;
  result->tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
  result->elementwise_code_ = elem0.elementwise_code_ + "\n" +
                              elem1.elementwise_code_ + "\n" +
                              elem_root.elementwise_code_;
  result->linkable_count_ = linkable_count;
  RETURN_IF_ERROR(
      result->args_.Merge(std::move(elem0.args_), unique_postfix + "l"));
  RETURN_IF_ERROR(
      result->args_.Merge(std::move(elem1.args_), unique_postfix + "r"));
  RETURN_IF_ERROR(
      result->args_.Merge(std::move(elem_root.args_), unique_postfix,
                          {elem_root.second_elementwise_tensor_name_}));
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite