/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_replace.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/access_type.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

namespace tflite {
namespace gpu {
namespace {
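// Rounds each used grid dimension up to a whole number of work groups
// (DivideRoundUp) and permutes the result by work_group_launch_order.
// Illustrative values (not from this file): grid_size = (100, 40, 8) with
// work_group_size = (8, 8, 4) yields counts of (13, 5, 2) before the
// launch-order permutation is applied.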
int3 GetWorkGroupsCountInternal(int grid_dimension, const int3& grid_size,
                                const int3& work_group_size,
                                const int3& work_group_launch_order) {
  int3 work_groups_count;
  if (grid_dimension == 1) {
    work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
    work_groups_count.y = 1;
    work_groups_count.z = 1;
  } else if (grid_dimension == 2) {
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = 1;
  } else {  // grid_dimension == 3
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    wgs.z = DivideRoundUp(grid_size.z, work_group_size.z);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = wgs[work_group_launch_order[2]];
  }
  return work_groups_count;
}

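// Emits the generic element-wise kernel skeleton. The emitted text is not
// final source: tokens such as MAIN_FUNCTION and GLOBAL_ID_0 and the args.*
// accessors are resolved later in the backend-specific compilation pipeline
// (see AssembleCode below). For each (X, Y, Z) grid point the skeleton reads
// one value from src_tensor and writes it to dst_tensor; the actual
// elementwise arithmetic is spliced in during args_.Compile().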
std::string GetElementWiseCode(const OperationDef& op_def) {
  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
    c += "  args.src_tensor.SetBatchRef(B);\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
  }
  c += "  int Y = GLOBAL_ID_1;\n";
  c += "  int Z = GLOBAL_ID_2;\n";
  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
       "Z >= args.dst_tensor.Slices()) return; \n";
  c += "  args.src_tensor::type src = args.src_tensor.Read(X, Y, Z);\n";
  c += "  args.dst_tensor.Write(src, X, Y, Z);\n";
  c += "} \n";
  return c;
}

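// A second elementwise input needs broadcasting when any of its W/H/C
// dimensions (and B, if the descriptor has a batch axis) equals 1, i.e. one
// value has to be reused across that whole axis.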
bool NeedsBroadcast(const TensorDescriptor& desc, const BHWC& shape) {
  bool needs_broadcast = shape.w == 1 || shape.h == 1 || shape.c == 1;
  if (desc.HasAxis(Axis::BATCH)) {
    needs_broadcast = needs_broadcast || shape.b == 1;
  }
  return needs_broadcast;
}

}  // namespace

DataType OperationDef::GetDataType() const {
  return DeduceDataTypeFromPrecision(precision);
}

DataType OperationDef::GetPrimaryDataType() const {
  return src_tensors[0].GetDataType();
}
TensorStorageType OperationDef::GetPrimaryStorageType() const {
  return src_tensors[0].GetStorageType();
}

bool OperationDef::IsBatchSupported() const {
  for (const auto& src : src_tensors) {
    if (src.HasAxis(Axis::BATCH)) {
      return true;
    }
  }
  for (const auto& dst : dst_tensors) {
    if (dst.HasAxis(Axis::BATCH)) {
      return true;
    }
  }
  return false;
}

GPUOperation::GPUOperation(const OperationDef& definition)
    : definition_(definition) {}

void GPUOperation::SetSrc(GpuSpatialTensor* ptr, int index) {
  if (index >= src_.size()) {
    src_.resize(index + 1, nullptr);
  }
  src_[index] = ptr;
}

void GPUOperation::SetDst(GpuSpatialTensor* ptr, int index) {
  if (index >= dst_.size()) {
    dst_.resize(index + 1, nullptr);
  }
  dst_[index] = ptr;
}

GPUOperation::GPUOperation(GPUOperation&& operation)
    : args_(std::move(operation.args_)),
      code_(std::move(operation.code_)),
      work_group_size_(operation.work_group_size_),
      compiler_options_(std::move(operation.compiler_options_)),
      tensor_to_grid_(operation.tensor_to_grid_),
      flops_(operation.flops_),
      const_args_size_(operation.const_args_size_),
      definition_(std::move(operation.definition_)),
      src_(std::move(operation.src_)),
      dst_(std::move(operation.dst_)),
      grid_dimension_(operation.grid_dimension_),
      work_group_launch_order_(operation.work_group_launch_order_),
      grid_size_(operation.grid_size_),
      src_tensors_names_(std::move(operation.src_tensors_names_)),
      dst_tensors_names_(std::move(operation.dst_tensors_names_)),
      work_groups_count_(operation.work_groups_count_),
      elementwise_(operation.elementwise_),
      elementwise_inputs_(operation.elementwise_inputs_),
      second_elementwise_tensor_name_(
          operation.second_elementwise_tensor_name_),
      linkable_count_(operation.linkable_count_),
      elementwise_code_(std::move(operation.elementwise_code_)) {}

GPUOperation& GPUOperation::operator=(GPUOperation&& operation) {
  if (this != &operation) {
    args_ = std::move(operation.args_);
    code_ = std::move(operation.code_);
    std::swap(work_group_size_, operation.work_group_size_);
    compiler_options_ = std::move(operation.compiler_options_);
    tensor_to_grid_ = operation.tensor_to_grid_;
    flops_ = operation.flops_;
    const_args_size_ = operation.const_args_size_;
    definition_ = std::move(operation.definition_);
    src_ = std::move(operation.src_);
    dst_ = std::move(operation.dst_);
    std::swap(grid_dimension_, operation.grid_dimension_);
    std::swap(work_group_launch_order_, operation.work_group_launch_order_);
    std::swap(grid_size_, operation.grid_size_);
    src_tensors_names_ = std::move(operation.src_tensors_names_);
    dst_tensors_names_ = std::move(operation.dst_tensors_names_);
    std::swap(work_groups_count_, operation.work_groups_count_);
    elementwise_ = operation.elementwise_;
    std::swap(elementwise_inputs_, operation.elementwise_inputs_);
    std::swap(second_elementwise_tensor_name_,
              operation.second_elementwise_tensor_name_);
    std::swap(linkable_count_, operation.linkable_count_);
    elementwise_code_ = std::move(operation.elementwise_code_);
  }
  return *this;
}

//    input       input
//      |           |
//    elem0         |
//      |    -->  elem
//    elem1         |
//      |           |
//    output      output
// GPUOperation* operation is elem1
// *this is elem0
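// The fusion is purely textual. elem1's args get a unique "_linkN" postfix,
// then elem0's "out_value" and elem1's "in_value" are both rewritten to a
// freshly declared 4-channel interim variable (type name depends on backend
// and precision). Illustrative sketch with hypothetical codes:
//   elem0: "$0{ out_value = in_value * args.scale; }"
//   elem1: "$0{ out_value = in_value + args.bias; }"
// fuse into
//   FLT4 interm_value_link1;
//   { interm_value_link1 = in_value * args.scale; }
//   $0{ out_value = interm_value_link1 + args.bias_link1; }
// elem1's "$0" and "out_value" survive, so the fused op remains linkable.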
absl::Status GPUOperation::FuseSimpleElemWithSimpleElem(
    const GpuInfo& gpu_info, GPUOperation* operation) {
  GPUOperation& elem0 = *this;
  GPUOperation& elem1 = *operation;
  elem0.definition_.dst_tensors[0] = elem1.definition_.dst_tensors[0];
  const auto link_value_type = elem1.definition_.src_tensors[0].GetDataType();
  elem0.linkable_count_ += (elem1.linkable_count_ + 1);
  std::string unique_postfix = absl::StrCat("_link", elem0.linkable_count_);
  elem1.args_.RenameArgs(unique_postfix, &elem1.elementwise_code_);
  const std::string link_value_name = "interm_value" + unique_postfix;
  const std::string value_declaration =
      "\n" + GetTypeDeclaration(gpu_info, link_value_type, 4) + " " +
      link_value_name + ";\n";
  elem1.elementwise_code_ = absl::StrReplaceAll(
      elem1.elementwise_code_, {{"in_value", link_value_name}});
  elem0.elementwise_code_ = absl::StrReplaceAll(
      elem0.elementwise_code_, {{"out_value", link_value_name}});
  elem0.elementwise_code_ =
      absl::Substitute(elem0.elementwise_code_, value_declaration);
  elem0.elementwise_code_ += "\n" + elem1.elementwise_code_;
  return args_.Merge(std::move(elem1.args_), unique_postfix);
}

//      input           input
//     /    \             |
//  elem0    |            |
//     \    /      -->  elem
//     elem1              |
//       |                |
//     output           output
// GPUOperation* operation is elem1
// *this is elem0
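// Here elem1 consumes two values: elem0's output as its first input
// ("in_value") and the shared input as its second ("in2_value"). elem0's
// result is routed into "in_value" through the interim variable, the pending
// READ_SECOND_VALUE hook is dropped, and "in2_value" is renamed to
// "in_value" so it reads the fused op's single input.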
absl::Status GPUOperation::Fuse2InputElemWithSimpleElemAsFirstInput(
    const GpuInfo& gpu_info, GPUOperation* operation) {
  GPUOperation& elem0 = *this;
  GPUOperation& elem1 = *operation;
  const auto link_value_type = elem0.definition_.dst_tensors[0].GetDataType();
  elem0.definition_.dst_tensors[0] = elem1.definition_.dst_tensors[0];
  elem0.linkable_count_ += (elem1.linkable_count_ + 1);
  std::string unique_postfix = absl::StrCat("_link", elem0.linkable_count_);
  elem1.args_.RenameArgs(unique_postfix, &elem1.elementwise_code_);
  const std::string link_value_name = "interm_value" + unique_postfix;
  const std::string value_declaration =
      "\n" + GetTypeDeclaration(gpu_info, link_value_type, 4) + " " +
      link_value_name + ";\n";
  elem0.elementwise_code_ = absl::StrReplaceAll(
      elem0.elementwise_code_, {{"out_value", link_value_name}});
  elem0.elementwise_code_ =
      absl::Substitute(elem0.elementwise_code_, value_declaration);
  elem1.elementwise_code_ = absl::StrReplaceAll(elem1.elementwise_code_,
                                                {{"in_value", link_value_name},
                                                 {"READ_SECOND_VALUE", ""},
                                                 {"in2_value", "in_value"}});
  elem0.elementwise_code_ += "\n" + elem1.elementwise_code_;
  return elem0.args_.Merge(std::move(elem1.args_), unique_postfix,
                           {elem1.second_elementwise_tensor_name_});
}

//      input           input
//     /    \             |
//    |    elem0          |
//     \    /      -->  elem
//     elem1              |
//       |                |
//     output           output
// GPUOperation* operation is elem1
// *this is elem0
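// Mirror image of the previous case: elem0 now feeds elem1's second input,
// so "in2_value" (rather than "in_value") is rewritten to the interim
// variable, while elem1's first input keeps reading the fused op's input.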
absl::Status GPUOperation::Fuse2InputElemWithSimpleElemAsSecondInput(
    const GpuInfo& gpu_info, GPUOperation* operation) {
  GPUOperation& elem0 = *this;
  GPUOperation& elem1 = *operation;
  const auto link_value_type = elem0.definition_.dst_tensors[0].GetDataType();
  elem0.definition_.dst_tensors[0] = elem1.definition_.dst_tensors[0];
  elem0.linkable_count_ += (elem1.linkable_count_ + 1);
  std::string unique_postfix = absl::StrCat("_link", elem0.linkable_count_);
  elem1.args_.RenameArgs(unique_postfix, &elem1.elementwise_code_);
  const std::string link_value_name = "interm_value" + unique_postfix;
  const std::string value_declaration =
      "\n" + GetTypeDeclaration(gpu_info, link_value_type, 4) + " " +
      link_value_name + ";\n";
  elem0.elementwise_code_ = absl::StrReplaceAll(
      elem0.elementwise_code_, {{"out_value", link_value_name}});
  elem0.elementwise_code_ =
      absl::Substitute(elem0.elementwise_code_, value_declaration);
  elem1.elementwise_code_ = absl::StrReplaceAll(
      elem1.elementwise_code_,
      {{"in2_value", link_value_name}, {"READ_SECOND_VALUE", ""}});
  elem0.elementwise_code_ += "\n" + elem1.elementwise_code_;
  return elem0.args_.Merge(std::move(elem1.args_), unique_postfix,
                           {elem1.second_elementwise_tensor_name_});
}

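// Chains another elementwise operation after this one inside the same
// kernel: the incoming op's args and tensor names get a unique "_linkN"
// postfix, this op's "out_value" is wired to the next op's "in_value"
// through a declared interim value, and the incoming op's extra source
// tensors are appended to this op's definition.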
absl::Status GPUOperation::AddOperation(const GpuInfo& gpu_info,
                                        GPUOperation* operation) {
  const auto prev_type = definition_.dst_tensors[0].GetDataType();
  definition_.dst_tensors[0] = operation->definition_.dst_tensors[0];
  linkable_count_ += (operation->linkable_count_ + 1);
  std::string code = operation->elementwise_code_;
  std::string unique_postfix = absl::StrCat("_link", linkable_count_);
  operation->args_.RenameArgs(unique_postfix, &code);
  operation->second_elementwise_tensor_name_ += unique_postfix;
  if (elementwise_code_.empty()) {
    elementwise_code_ = code;
    elementwise_inputs_ = operation->elementwise_inputs_;
    second_elementwise_tensor_name_ =
        operation->second_elementwise_tensor_name_;
  } else {
    if (operation->elementwise_inputs_ == 2) {
      if (elementwise_inputs_ == 2) {
        // Fusing two 2-input elementwise ops would produce a 3-input
        // elementwise op, but currently at most 2-input elementwise ops are
        // supported, so one input is resolved here.
        RETURN_IF_ERROR(ResolveSecondElementwiseInput());
      }
      second_elementwise_tensor_name_ =
          operation->second_elementwise_tensor_name_;
      elementwise_inputs_ = 2;
    }
    const std::string new_value_name = "interm_value" + unique_postfix;
    code = absl::StrReplaceAll(code, {{"in_value", new_value_name}});
    elementwise_code_ =
        absl::StrReplaceAll(elementwise_code_, {{"out_value", new_value_name}});
    const std::string out_var_declaration =
        "\n" + GetTypeDeclaration(gpu_info, prev_type, 4) + " " +
        new_value_name + ";\n";
    elementwise_code_ =
        absl::Substitute(elementwise_code_, out_var_declaration);
    elementwise_code_ = elementwise_code_ + "\n" + code;
  }
  RETURN_IF_ERROR(args_.Merge(std::move(operation->args_), unique_postfix));
  for (int i = 0; i < operation->src_tensors_names_.size(); ++i) {
    definition_.src_tensors.push_back(
        operation->definition_.src_tensors[i + 1]);
    src_tensors_names_.push_back(operation->src_tensors_names_[i] +
                                 unique_postfix);
  }
  for (int i = 0; i < operation->dst_tensors_names_.size(); ++i) {
    dst_tensors_names_.push_back(operation->dst_tensors_names_[i] +
                                 unique_postfix);
  }
  return absl::OkStatus();
}

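// Demotes a 2-input elementwise op to a 1-input one by materializing an
// explicit read of the second tensor. The generated snippet has the form
//   args.<name>::type second_value = args.<name>.Read(X_COORD, Y_COORD,
//                                                     S_COORD[, B_COORD]);
// and every "in2_value" in the existing code is redirected to second_value.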
absl::Status GPUOperation::ResolveSecondElementwiseInput() {
  if (elementwise_inputs_ != 2) {
    return absl::FailedPreconditionError(
        "Cannot apply ResolveSecondElementwiseInput to a non-2-input "
        "elementwise operation");
  }
  TensorDescriptor* tensor_desc;
  RETURN_IF_ERROR(
      GetTensorDescriptor(second_elementwise_tensor_name_, &tensor_desc));
  std::string coords = "X_COORD, Y_COORD, S_COORD";
  if (tensor_desc->HasAxis(Axis::BATCH)) {
    coords += ", B_COORD";
  }
  const std::string read_code = "args." + second_elementwise_tensor_name_ +
                                "::type second_value = args." +
                                second_elementwise_tensor_name_ + ".Read(" +
                                coords + ");\n";
  elementwise_code_ = absl::StrReplaceAll(
      elementwise_code_,
      {{"in2_value", "second_value"}, {"READ_SECOND_VALUE", read_code}});
  elementwise_inputs_ = 1;
  return absl::OkStatus();
}

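// Looks up a named tensor in this operation's definition. For elementwise
// operations the index is shifted by one because slot 0 of
// definition_.src_tensors/dst_tensors is reserved for the implicit
// "src_tensor"/"dst_tensor" link added in AssembleCode().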
absl::Status GPUOperation::GetTensorDescriptor(const std::string& tensor_name,
                                               TensorDescriptor** result) {
  for (int i = 0; i < src_tensors_names_.size(); ++i) {
    if (src_tensors_names_[i] == tensor_name) {
      int index = elementwise_ ? i + 1 : i;
      *result = &definition_.src_tensors[index];
      return absl::OkStatus();
    }
  }
  for (int i = 0; i < dst_tensors_names_.size(); ++i) {
    if (dst_tensors_names_[i] == tensor_name) {
      int index = elementwise_ ? i + 1 : i;
      *result = &definition_.dst_tensors[index];
      return absl::OkStatus();
    }
  }
  return absl::NotFoundError("Cannot find tensor with this name");
}

void GPUOperation::AddSrcTensor(const std::string& tensor_name,
                                const TensorDescriptor& desc) {
  src_tensors_names_.push_back(tensor_name);
  auto desc_new = std::make_unique<TensorDescriptor>(desc);
  args_.AddObjectRef(tensor_name, AccessType::READ, std::move(desc_new));
}

void GPUOperation::AddSrcBuffer(const std::string& buffer_name,
                                const BufferDescriptor& desc) {
  src_tensors_names_.push_back(buffer_name);
  auto desc_new = std::make_unique<BufferDescriptor>(desc);
  args_.AddObjectRef(buffer_name, AccessType::READ, std::move(desc_new));
}

void GPUOperation::AddDstTensor(const std::string& tensor_name,
                                const TensorDescriptor& desc) {
  dst_tensors_names_.push_back(tensor_name);
  auto desc_new = std::make_unique<TensorDescriptor>(desc);
  args_.AddObjectRef(tensor_name, AccessType::WRITE, std::move(desc_new));
}

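// Finalizes the kernel source. Elementwise operations first receive the
// generic wrapper from GetElementWiseCode() plus implicit "src_tensor"/
// "dst_tensor" object refs; args_.Compile() then resolves the args.*
// references and splices elementwise_code_ into the write of the first
// destination tensor.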
absl::Status GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
  if (elementwise_inputs_ == 2) {
    RETURN_IF_ERROR(ResolveSecondElementwiseInput());
  }
  if (elementwise_) {
    src_tensors_names_.insert(src_tensors_names_.begin(), "src_tensor");
    args_.AddObjectRef(
        "src_tensor", AccessType::READ,
        std::make_unique<TensorDescriptor>(definition_.src_tensors[0]));

    dst_tensors_names_.insert(dst_tensors_names_.begin(), "dst_tensor");
    args_.AddObjectRef(
        "dst_tensor", AccessType::WRITE,
        std::make_unique<TensorDescriptor>(definition_.dst_tensors[0]));

    code_ = GetElementWiseCode(definition_);
  }
  RETURN_IF_ERROR(args_.Compile(
      gpu_info, {{dst_tensors_names_[0], elementwise_code_}}, &code_));
  CalculateConstArgsSize();
  return absl::OkStatus();
}

void GPUOperation::RecalculateWorkGroupsCount() {
  work_groups_count_ = GetWorkGroupsCountInternal(
      grid_dimension_, grid_size_, work_group_size_, work_group_launch_order_);
}

void GPUOperation::CalculateConstArgsSize() {
  const_args_size_ = 0;
  for (const auto& obj : args_.GetObjects()) {
    const_args_size_ += obj.second->GetSizeInBytes();
  }
}

void GPUOperation::GetPossibleDispatches(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info,
    std::vector<DispatchInfo>* dispatches) const {
  std::vector<int3> work_group_sizes;
  GetPossibleKernelWorkGroups(tuning_type, gpu_info, kernel_info,
                              &work_group_sizes);
  dispatches->resize(work_group_sizes.size());
  for (int i = 0; i < work_group_sizes.size(); ++i) {
    auto& dispatch_info = (*dispatches)[i];
    dispatch_info.work_group_size = work_group_sizes[i];
    dispatch_info.work_groups_count = GetWorkGroupsCountInternal(
        grid_dimension_, grid_size_, work_group_sizes[i],
        work_group_launch_order_);
  }
}

void GPUOperation::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
                        work_groups);
}

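// Maps the destination tensor's shape onto the dispatch grid according to
// tensor_to_grid_; e.g. kWBToX_HDToY_SToZ packs width*batch into X,
// height*depth into Y and slices (groups of 4 channels) into Z. Operations
// that manage their own grid keep grid_size_ as-is.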
int3 GPUOperation::GetGridSize() const {
  if (tensor_to_grid_ == TensorToGrid::kWBToX_HDToY_SToZ) {
    const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
    const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
    const int grid_z = dst_[0]->Slices();
    return int3(grid_x, grid_y, grid_z);
  }
  if (tensor_to_grid_ == TensorToGrid::kWBToX_HDToY_ZIs1) {
    const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
    const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
    const int grid_z = 1;
    return int3(grid_x, grid_y, grid_z);
  }
  if (tensor_to_grid_ == TensorToGrid::kWBToX_HToY_DToZ) {
    const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
    const int grid_y = dst_[0]->Height();
    const int grid_z = dst_[0]->Depth();
    return int3(grid_x, grid_y, grid_z);
  }
  if (tensor_to_grid_ == TensorToGrid::kBToX_YIs1_ZIs1) {
    const int grid_x = dst_[0]->Batch();
    const int grid_y = 1;
    const int grid_z = 1;
    return int3(grid_x, grid_y, grid_z);
  }
  return grid_size_;
}

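// Builds an elementwise GPUOperation from an ElementwiseDescriptor. The
// two-argument overload assumes the optional second input needs no
// broadcasting: BHWC(2, 2, 2, 2) is just a dummy shape with no unit
// dimension. In the broadcasting path, unit dimensions pin the matching read
// coordinate to 0; e.g. a hypothetical second input of shape BHWC(1, 1, 1, 8)
// without a batch axis is read at (0, 0, S_COORD), and when c == 1 the
// scalar x channel is additionally splatted into y/z/w.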
GPUOperation CreateGpuOperation(const OperationDef& definition,
                                ElementwiseDescriptor&& descriptor) {
  const BHWC second_shape(2, 2, 2, 2);  // dummy non-broadcasted shape
  return CreateGpuOperation(definition, std::move(descriptor), second_shape);
}

GPUOperation CreateGpuOperation(const OperationDef& definition,
                                ElementwiseDescriptor&& descriptor,
                                const BHWC& second_shape) {
  GPUOperation op(definition);
  op.elementwise_code_ = std::move(descriptor.code);
  op.elementwise_ = true;
  if (definition.src_tensors.size() > 1 &&
      op.elementwise_code_.find("in2_value") != std::string::npos) {
    const auto second_tensor_def = definition.src_tensors[1];
    if (NeedsBroadcast(second_tensor_def, second_shape)) {
      const std::string x_coord = second_shape.w == 1 ? "0" : "X_COORD";
      const std::string y_coord = second_shape.h == 1 ? "0" : "Y_COORD";
      const std::string s_coord = second_shape.c == 1 ? "0" : "S_COORD";
      std::string coords = absl::StrCat(x_coord, ", ", y_coord, ", ", s_coord);
      if (second_tensor_def.HasAxis(Axis::BATCH)) {
        const std::string b_coord = second_shape.b == 1 ? "0" : "B_COORD";
        coords += ", " + b_coord;
      }
      std::string read_value_code = absl::StrCat(
          "args.src_tensor_1::type in2_value = args.src_tensor_1.Read(", coords,
          ");\n");
      if (second_shape.c == 1) {
        read_value_code += "  in2_value.y = in2_value.x;\n";
        read_value_code += "  in2_value.z = in2_value.x;\n";
        read_value_code += "  in2_value.w = in2_value.x;\n";
      }
      op.elementwise_code_ =
          "$0{" + read_value_code + op.elementwise_code_ + "}";
      op.elementwise_code_ = absl::StrReplaceAll(
          op.elementwise_code_, {{"in2_value", "second_value"}});
      op.elementwise_inputs_ = 1;
    } else {
      op.elementwise_code_ =
          "$0{READ_SECOND_VALUE" + op.elementwise_code_ + "}";
      op.elementwise_inputs_ = 2;
      op.second_elementwise_tensor_name_ = "src_tensor_1";
    }
  } else {
    op.elementwise_code_ = "$0{" + op.elementwise_code_ + "}";
    op.elementwise_inputs_ = 1;
  }
  op.args_ = std::move(descriptor.args);
  for (int i = 1; i < definition.src_tensors.size(); ++i) {
    const std::string tensor_name = "src_tensor_" + std::to_string(i);
    op.AddSrcTensor(tensor_name, definition.src_tensors[i]);
  }
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
  return op;
}

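// Fuses a diamond: two simple elementwise ops that share an input and feed a
// 2-input elementwise root. The "l"/"r" suffixes keep the renamed args of
// the two branches distinct; the root's "in_value"/"in2_value" are wired to
// the left/right interim values and its READ_SECOND_VALUE hook is dropped,
// leaving a single 1-input elementwise operation.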
absl::Status Fuse2InputElemWith2SimpleElem(const GpuInfo& gpu_info,
                                           GPUOperation&& elem0,
                                           GPUOperation&& elem1,
                                           GPUOperation&& elem_root,
                                           GPUOperation* result) {
  int linkable_count = std::max(elem0.linkable_count_, elem1.linkable_count_);
  linkable_count = std::max(linkable_count, elem_root.linkable_count_);
  linkable_count += 1;

  std::string unique_postfix = absl::StrCat("_link", linkable_count);
  elem0.args_.RenameArgs(unique_postfix + "l", &elem0.elementwise_code_);
  elem1.args_.RenameArgs(unique_postfix + "r", &elem1.elementwise_code_);
  elem_root.args_.RenameArgs(unique_postfix, &elem_root.elementwise_code_);
  const std::string link_left_value_name = "interm_value_left" + unique_postfix;
  const std::string link_right_value_name =
      "interm_value_right" + unique_postfix;
  const auto link_left_value_type =
      elem0.definition_.dst_tensors[0].GetDataType();
  const std::string left_value_declaration =
      "\n" + GetTypeDeclaration(gpu_info, link_left_value_type, 4) + " " +
      link_left_value_name + ";\n";
  const auto link_right_value_type =
      elem1.definition_.dst_tensors[0].GetDataType();
  const std::string right_value_declaration =
      "\n" + GetTypeDeclaration(gpu_info, link_right_value_type, 4) + " " +
      link_right_value_name + ";\n";
  elem0.elementwise_code_ = absl::StrReplaceAll(
      elem0.elementwise_code_, {{"out_value", link_left_value_name}});
  elem1.elementwise_code_ = absl::StrReplaceAll(
      elem1.elementwise_code_, {{"out_value", link_right_value_name}});
  elem0.elementwise_code_ =
      absl::Substitute(elem0.elementwise_code_, left_value_declaration);
  elem1.elementwise_code_ =
      absl::Substitute(elem1.elementwise_code_, right_value_declaration);
  elem_root.elementwise_code_ = absl::StrReplaceAll(
      elem_root.elementwise_code_, {{"in_value", link_left_value_name},
                                    {"READ_SECOND_VALUE", ""},
                                    {"in2_value", link_right_value_name}});

  OperationDef new_definition = elem0.definition_;
  new_definition.dst_tensors[0] = elem_root.definition_.dst_tensors[0];

  *result = GPUOperation(new_definition);
  result->elementwise_ = true;
  result->elementwise_inputs_ = 1;
  result->tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
  result->elementwise_code_ = elem0.elementwise_code_ + "\n" +
                              elem1.elementwise_code_ + "\n" +
                              elem_root.elementwise_code_;
  result->linkable_count_ = linkable_count;
  RETURN_IF_ERROR(
      result->args_.Merge(std::move(elem0.args_), unique_postfix + "l"));
  RETURN_IF_ERROR(
      result->args_.Merge(std::move(elem1.args_), unique_postfix + "r"));
  RETURN_IF_ERROR(
      result->args_.Merge(std::move(elem_root.args_), unique_postfix,
                          {elem_root.second_elementwise_tensor_name_}));
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite