xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/task/arguments.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
17 
18 #include <algorithm>
19 #include <map>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/strings/ascii.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_replace.h"
29 #include "absl/strings/substitute.h"
30 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
31 #include "tensorflow/lite/delegates/gpu/common/status.h"
32 #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
33 #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
34 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
35 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
36 
37 namespace tflite {
38 namespace gpu {
39 namespace {
IsWordSymbol(char symbol)40 bool IsWordSymbol(char symbol) {
41   return absl::ascii_isalnum(symbol) || symbol == '_';
42 }
43 
ReplaceAllWords(const std::string & old_word,const std::string & new_word,std::string * str)44 void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
45                      std::string* str) {
46   size_t position = str->find(old_word);
47   while (position != std::string::npos) {
48     char prev = position == 0 ? '.' : (*str)[position - 1];
49     char next = position + old_word.size() < str->size()
50                     ? (*str)[position + old_word.size()]
51                     : '.';
52     if (IsWordSymbol(prev) || IsWordSymbol(next)) {
53       position = str->find(old_word, position + 1);
54       continue;
55     }
56     str->replace(position, old_word.size(), new_word);
57     position = str->find(old_word, position + new_word.size());
58   }
59 }
60 
GetNextWord(const std::string & code,size_t first_position)61 std::string GetNextWord(const std::string& code, size_t first_position) {
62   size_t pos = first_position;
63   char t = code[pos];
64   while (IsWordSymbol(t)) {
65     pos++;
66     t = code[pos];
67   }
68   return code.substr(first_position, pos - first_position);
69 }
70 
HasWord(const std::string & word,const std::string & text)71 bool HasWord(const std::string& word, const std::string& text) {
72   size_t pos = text.find(word);
73   while (pos != std::string::npos) {
74     char prev = pos == 0 ? '.' : text[pos - 1];
75     char next = pos + word.size() < text.size() ? text[pos + word.size()] : '.';
76     if (!IsWordSymbol(prev) && !IsWordSymbol(next)) {
77       return true;
78     }
79     pos = text.find(word, pos + 1);
80   }
81   return false;
82 }
83 
// Returns the new name for scalar argument `arg_name` after merging with
// `postfix`.
//
// If the argument belongs to one of `object_names` (its name is
// "<object>_<rest>"), the postfix goes right after the object name:
// "<object><postfix>_<rest>". Otherwise the postfix is appended:
// "<arg_name><postfix>". The first matching object in `object_names` wins.
std::string RenameArg(const std::vector<std::string>& object_names,
                      const std::string& postfix, const std::string& arg_name) {
  for (const std::string& owner : object_names) {
    const size_t owner_len = owner.size();
    const bool owned_by_object = arg_name.size() > owner_len &&
                                 arg_name.compare(0, owner_len, owner) == 0 &&
                                 arg_name[owner_len] == '_';
    if (owned_by_object) {
      return owner + postfix + arg_name.substr(owner_len);
    }
  }
  return arg_name + postfix;
}
97 
// Finds the bracket that closes an already-opened `bracket` ('(', '{', '[' or
// '<'). Scanning starts at `first_pos`, i.e. just after the opening bracket.
// Nested pairs of the same bracket kind are balanced.
//
// Returns the index one past the matching closing bracket. Returns
// std::string::npos if `bracket` is not a supported opening bracket or if the
// text ends before the bracket is closed.
size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
                            char bracket) {
  // A switch avoids building a lookup map on every call.
  char b_close = '\0';
  switch (bracket) {
    case '(':
      b_close = ')';
      break;
    case '{':
      b_close = '}';
      break;
    case '[':
      b_close = ']';
      break;
    case '<':
      b_close = '>';
      break;
    default:
      return std::string::npos;
  }
  int depth = 1;  // The opening bracket itself lies before `first_pos`.
  for (size_t pos = first_pos; pos < text.size(); ++pos) {
    if (text[pos] == bracket) {
      ++depth;
    } else if (text[pos] == b_close && --depth == 0) {
      return pos + 1;
    }
  }
  return std::string::npos;
}
129 
ParseArgsInsideBrackets(const std::string & text,size_t open_bracket_pos,size_t * close_bracket_pos,std::vector<std::string> * args)130 absl::Status ParseArgsInsideBrackets(const std::string& text,
131                                      size_t open_bracket_pos,
132                                      size_t* close_bracket_pos,
133                                      std::vector<std::string>* args) {
134   *close_bracket_pos =
135       FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
136   if (*close_bracket_pos == -1) {
137     return absl::NotFoundError("Not found enclosing bracket");
138   }
139   std::string str_args = text.substr(open_bracket_pos + 1,
140                                      *close_bracket_pos - open_bracket_pos - 2);
141   std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
142   args->reserve(words.size());
143   for (const auto& word : words) {
144     absl::string_view arg = absl::StripAsciiWhitespace(word);
145     if (!arg.empty()) {
146       args->push_back(std::string(arg));
147     }
148   }
149   return absl::OkStatus();
150 }
151 
DataTypeToGlType(DataType data_type,int vec_size,bool explicit_f16)152 std::string DataTypeToGlType(DataType data_type, int vec_size,
153                              bool explicit_f16) {
154   if (data_type == DataType::FLOAT32) {
155     if (vec_size == 1) {
156       return "float";
157     } else {
158       return "vec" + std::to_string(vec_size);
159     }
160   } else if (data_type == DataType::FLOAT16) {
161     if (vec_size == 1) {
162       return explicit_f16 ? "float16_t" : "float";
163     } else {
164       if (explicit_f16) {
165         return "f16vec" + std::to_string(vec_size);
166       } else {
167         return "vec" + std::to_string(vec_size);
168       }
169     }
170   } else if (data_type == DataType::INT32) {
171     if (vec_size == 1) {
172       return "int";
173     } else {
174       return "ivec" + std::to_string(vec_size);
175     }
176   } else if (data_type == DataType::UINT32) {
177     if (vec_size == 1) {
178       return "uint";
179     } else {
180       return "uvec" + std::to_string(vec_size);
181     }
182   }
183   return "unsupported_type";
184 }
185 
BufferToKernelLanguage(const GpuInfo & gpu_info,const std::string & buffer_name,const BufferDescriptor * buffer_desc,std::string * result)186 absl::Status BufferToKernelLanguage(const GpuInfo& gpu_info,
187                                     const std::string& buffer_name,
188                                     const BufferDescriptor* buffer_desc,
189                                     std::string* result) {
190   if (buffer_desc->element_size != 1) {
191     return absl::UnimplementedError("No support of vector types.");
192   }
193   const int elements_count =
194       buffer_desc->size /
195       (buffer_desc->element_size * SizeOf(buffer_desc->element_type));
196   if (gpu_info.IsGlsl()) {
197     const std::string gl_type =
198         DataTypeToGlType(buffer_desc->element_type, buffer_desc->element_size,
199                          gpu_info.IsGlslSupportsExplicitFp16());
200     *result = "const ";
201     if (buffer_desc->element_type == DataType::FLOAT16 &&
202         !gpu_info.IsGlslSupportsExplicitFp16()) {
203       *result += "mediump ";
204     }
205     *result += gl_type + " " + buffer_name + "_buffer[] = " + gl_type + "[](\n";
206   } else if (gpu_info.IsApiMetal()) {
207     const std::string metal_type =
208         ToMetalDataType(buffer_desc->element_type, buffer_desc->element_size);
209     *result = "constant " + metal_type + " " + buffer_name + "_buffer[" +
210               std::to_string(elements_count) + "] = {\n";
211   } else if (gpu_info.IsApiOpenCl()) {
212     const std::string cl_type =
213         ToCLDataType(buffer_desc->element_type, buffer_desc->element_size);
214     *result = "__constant " + cl_type + " " + buffer_name + "_buffer[" +
215               std::to_string(elements_count) + "] = {\n";
216   } else {
217     return absl::UnimplementedError("Not supported API.");
218   }
219   if (buffer_desc->element_type == DataType::FLOAT16) {
220     std::string postfix = "f";
221     if (gpu_info.IsGlsl() && gpu_info.IsGlslSupportsExplicitFp16()) {
222       postfix = "hf";
223     }
224     const half* data_ptr =
225         reinterpret_cast<const half*>(buffer_desc->data.data());
226     for (int i = 0; i < elements_count; ++i) {
227       *result += "  " +
228                  absl::StrFormat("%.10f", static_cast<float>(data_ptr[i])) +
229                  postfix;
230       if (i != elements_count - 1) {
231         *result += ",\n";
232       }
233     }
234   } else if (buffer_desc->element_type == DataType::FLOAT32) {
235     const float* data_ptr =
236         reinterpret_cast<const float*>(buffer_desc->data.data());
237     for (int i = 0; i < elements_count; ++i) {
238       *result += "  " + absl::StrFormat("%.10f", data_ptr[i]) + "f";
239       if (i != elements_count - 1) {
240         *result += ",\n";
241       }
242     }
243   } else {
244     return absl::UnimplementedError("Not supported type.");
245   }
246   if (gpu_info.IsGlsl()) {
247     *result += ");\n";
248   } else {
249     *result += "};\n";
250   }
251 
252   return absl::OkStatus();
253 }
254 
255 }  // namespace
256 
// Out-of-line definition of the static constexpr member kArgsPrefix. It is
// required whenever the member is ODR-used (e.g. passed to strlen() or
// std::string::find() below) in pre-C++17 builds. Since C++17 the member is
// implicitly inline, so this line is redundant but harmless.
constexpr char Arguments::kArgsPrefix[];
259 
AddFloat(const std::string & name,float value)260 void Arguments::AddFloat(const std::string& name, float value) {
261   float_values_[name].value = value;
262 }
AddHalf(const std::string & name,half value)263 void Arguments::AddHalf(const std::string& name, half value) {
264   half_values_[name].value = value;
265 }
AddInt(const std::string & name,int value)266 void Arguments::AddInt(const std::string& name, int value) {
267   int_values_[name].value = value;
268 }
269 
SetInt(const std::string & name,int value)270 absl::Status Arguments::SetInt(const std::string& name, int value) {
271   auto it = int_values_.find(name);
272   if (it == int_values_.end()) {
273     return absl::NotFoundError(
274         absl::StrCat("No int argument with name - ", name));
275   }
276   it->second.value = value;
277   return absl::OkStatus();
278 }
SetFloat(const std::string & name,float value)279 absl::Status Arguments::SetFloat(const std::string& name, float value) {
280   auto it = float_values_.find(name);
281   if (it == float_values_.end()) {
282     return absl::NotFoundError(
283         absl::StrCat("No float argument with name - ", name));
284   }
285   it->second.value = value;
286   return absl::OkStatus();
287 }
288 
SetHalf(const std::string & name,half value)289 absl::Status Arguments::SetHalf(const std::string& name, half value) {
290   auto it = half_values_.find(name);
291   if (it == half_values_.end()) {
292     return absl::NotFoundError(
293         absl::StrCat("No half argument with name - ", name));
294   }
295   it->second.value = value;
296   return absl::OkStatus();
297 }
298 
AddObjectRef(const std::string & name,AccessType access_type,GPUObjectDescriptorPtr && descriptor_ptr)299 void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
300                              GPUObjectDescriptorPtr&& descriptor_ptr) {
301   descriptor_ptr->SetAccess(access_type);
302   object_refs_[name] = {std::move(descriptor_ptr)};
303 }
304 
AddObject(const std::string & name,GPUObjectDescriptorPtr && descriptor_ptr)305 void Arguments::AddObject(const std::string& name,
306                           GPUObjectDescriptorPtr&& descriptor_ptr) {
307   descriptor_ptr->SetAccess(AccessType::READ);
308   objects_[name] = {std::move(descriptor_ptr)};
309 }
310 
RenameArgs(const std::string & postfix,std::string * code) const311 void Arguments::RenameArgs(const std::string& postfix,
312                            std::string* code) const {
313   size_t next_position = code->find(kArgsPrefix);
314   while (next_position != std::string::npos) {
315     size_t arg_pos = next_position + strlen(kArgsPrefix);
316     std::string arg_name = GetNextWord(*code, arg_pos);
317     code->replace(arg_pos, arg_name.size(), arg_name + postfix);
318     next_position = code->find(kArgsPrefix, arg_pos + arg_name.size());
319   }
320 }
321 
Merge(Arguments && args,const std::string & postfix,const std::vector<std::string> & exception_names)322 absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix,
323                               const std::vector<std::string>& exception_names) {
324   std::vector<std::string> object_names;
325   object_names.reserve(args.object_refs_.size() + args.objects_.size());
326   for (auto& v : args.object_refs_) {
327     if (std::find(exception_names.begin(), exception_names.end(), v.first) !=
328         exception_names.end()) {
329       continue;
330     }
331     object_names.push_back(v.first);
332     const std::string name = v.first + postfix;
333     if (object_refs_.find(name) != object_refs_.end()) {
334       return absl::InvalidArgumentError(
335           absl::StrCat("Object reference name collision. Name - ", name));
336     }
337     object_refs_[name] = {std::move(v.second)};
338   }
339   for (auto& v : args.objects_) {
340     if (std::find(exception_names.begin(), exception_names.end(), v.first) !=
341         exception_names.end()) {
342       continue;
343     }
344     object_names.push_back(v.first);
345     const std::string name = v.first + postfix;
346     if (objects_.find(name) != objects_.end()) {
347       return absl::InvalidArgumentError(
348           absl::StrCat("Object name collision. Name - ", name));
349     }
350     objects_[name] = {std::move(v.second)};
351   }
352   for (const auto& v : args.int_values_) {
353     AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
354   }
355   for (const auto& v : args.float_values_) {
356     AddFloat(RenameArg(object_names, postfix, v.first), v.second.value);
357   }
358   for (const auto& v : args.half_values_) {
359     AddHalf(RenameArg(object_names, postfix, v.first), v.second.value);
360   }
361   return absl::OkStatus();
362 }
363 
GetDescriptor(const std::string & name,GPUObjectDescriptor ** descriptor) const364 absl::Status Arguments::GetDescriptor(const std::string& name,
365                                       GPUObjectDescriptor** descriptor) const {
366   auto it_ref = object_refs_.find(name);
367   if (it_ref != object_refs_.end()) {
368     *descriptor = it_ref->second.get();
369     return absl::OkStatus();
370   }
371   auto it = objects_.find(name);
372   if (it != objects_.end()) {
373     *descriptor = it->second.get();
374     return absl::OkStatus();
375   }
376   return absl::NotFoundError(absl::StrCat("No GPU object with name - ", name));
377 }
378 
ReleaseCPURepresentation()379 void Arguments::ReleaseCPURepresentation() {
380   for (auto& t : objects_) {
381     t.second->Release();
382   }
383 }
384 
GetActiveArguments(const std::string & code)385 void Arguments::GetActiveArguments(const std::string& code) {
386   for (auto& float_val : float_values_) {
387     float_val.second.active = HasWord(kArgsPrefix + float_val.first, code);
388   }
389   for (auto& int_val : int_values_) {
390     int_val.second.active = HasWord(kArgsPrefix + int_val.first, code);
391   }
392   for (auto& half_val : half_values_) {
393     half_val.second.active = HasWord(kArgsPrefix + half_val.first, code);
394   }
395 }
396 
GetReadTexturesCount(const GpuInfo & gpu_info) const397 int Arguments::GetReadTexturesCount(const GpuInfo& gpu_info) const {
398   int counter = 0;
399   for (auto& t : objects_) {
400     counter += t.second->GetGPUResources(gpu_info).GetReadImagesCount();
401   }
402   for (auto& t : object_refs_) {
403     counter += t.second->GetGPUResources(gpu_info).GetReadImagesCount();
404   }
405   return counter;
406 }
407 
GetWriteTexturesCount(const GpuInfo & gpu_info) const408 int Arguments::GetWriteTexturesCount(const GpuInfo& gpu_info) const {
409   int counter = 0;
410   for (auto& t : objects_) {
411     counter += t.second->GetGPUResources(gpu_info).GetWriteImagesCount();
412   }
413   for (auto& t : object_refs_) {
414     counter += t.second->GetGPUResources(gpu_info).GetWriteImagesCount();
415   }
416   return counter;
417 }
418 
SetStateValueForAllObjects(const std::string & key,const std::string & value)419 void Arguments::SetStateValueForAllObjects(const std::string& key,
420                                            const std::string& value) {
421   for (auto& obj : object_refs_) {
422     obj.second->SetStateVar(key, value);
423   }
424   for (auto& obj : objects_) {
425     obj.second->SetStateVar(key, value);
426   }
427 }
428 
// Rewrites `*code` into final kernel source for `gpu_info`. The steps are:
//   1. register the scalar arguments that every object's GPU resources need
//      ("<object>_<resource>");
//   2. expand constant expressions `<kArgsPrefix><object>::<expr>`;
//   3. expand selector calls `<kArgsPrefix><object>.<selector>(...)`, inlining
//      the elementwise code from `linkables` into tensor Write selectors;
//   4. flag which scalar arguments the resulting code still references;
//   5. embed constant buffers tagged "kernel_global_space" directly in the
//      source, when they can be rendered.
// The order matters. Step 3 emits references to the scalars registered in
// step 1, and step 4 must run after all expansions so those references are
// seen.
absl::Status Arguments::Compile(
    const GpuInfo& gpu_info,
    const std::map<std::string, std::string>& linkables, std::string* code) {
  RETURN_IF_ERROR(AddObjectsScalarArgs(gpu_info));
  RETURN_IF_ERROR(ResolveConstExprPass(gpu_info, code));
  RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, linkables, code));
  GetActiveArguments(*code);
  RETURN_IF_ERROR(ResolveKernelGlobalSpaceBuffers(gpu_info, code));
  return absl::OkStatus();
}
439 
ResolveConstExprPass(const GpuInfo & gpu_info,std::string * code) const440 absl::Status Arguments::ResolveConstExprPass(const GpuInfo& gpu_info,
441                                              std::string* code) const {
442   std::string result;
443   size_t position = 0;
444   size_t next_position = code->find(kArgsPrefix);
445   while (next_position != std::string::npos) {
446     size_t arg_pos = next_position;
447     next_position += strlen(kArgsPrefix);
448     std::string object_name = GetNextWord(*code, next_position);
449     if (next_position + object_name.size() > code->size() - 2) {
450       next_position = code->find(kArgsPrefix, next_position);
451       continue;
452     }
453     char next0 = (*code)[next_position + object_name.size()];
454     char next1 = (*code)[next_position + object_name.size() + 1];
455     if (next0 == ':' && next1 == ':') {
456       next_position += object_name.size() + 2;
457       std::string const_expr_name = GetNextWord(*code, next_position);
458       next_position += const_expr_name.size();
459       std::string patch;
460       RETURN_IF_ERROR(
461           ResolveConstExpr(gpu_info, object_name, const_expr_name, &patch));
462       code->replace(arg_pos, next_position - arg_pos, patch);
463       position = arg_pos + patch.size();
464     } else {
465       position = arg_pos + strlen(kArgsPrefix);
466     }
467     next_position = code->find(kArgsPrefix, position);
468   }
469   return absl::OkStatus();
470 }
471 
ResolveConstExpr(const GpuInfo & gpu_info,const std::string & object_name,const std::string & const_expr,std::string * result) const472 absl::Status Arguments::ResolveConstExpr(const GpuInfo& gpu_info,
473                                          const std::string& object_name,
474                                          const std::string& const_expr,
475                                          std::string* result) const {
476   tflite::gpu::GPUObjectDescriptor* desc_ptr;
477   RETURN_IF_ERROR(GetDescriptor(object_name, &desc_ptr));
478   RETURN_IF_ERROR(desc_ptr->PerformConstExpr(gpu_info, const_expr, result));
479   return absl::OkStatus();
480 }
481 
ResolveSelectorsPass(const GpuInfo & gpu_info,const std::map<std::string,std::string> & linkables,std::string * code) const482 absl::Status Arguments::ResolveSelectorsPass(
483     const GpuInfo& gpu_info,
484     const std::map<std::string, std::string>& linkables,
485     std::string* code) const {
486   std::string result;
487   size_t position = 0;
488   size_t next_position = code->find(kArgsPrefix);
489   while (next_position != std::string::npos) {
490     size_t arg_pos = next_position;
491     next_position += strlen(kArgsPrefix);
492     std::string object_name = GetNextWord(*code, next_position);
493     char next = (*code)[next_position + object_name.size()];
494     if (next == '.') {
495       next_position += object_name.size() + 1;
496       std::string selector_name = GetNextWord(*code, next_position);
497       next_position += selector_name.size();
498       next = (*code)[next_position];
499       std::vector<std::string> template_args;
500       if (next == '<') {
501         size_t close_bracket_pos;
502         RETURN_IF_ERROR(ParseArgsInsideBrackets(
503             *code, next_position, &close_bracket_pos, &template_args));
504         next_position = close_bracket_pos;
505         next = (*code)[next_position];
506       }
507       if (next != '(') {
508         return absl::NotFoundError(absl::StrCat(
509             "Expected ( after ", object_name, ".", selector_name, " call"));
510       }
511       std::vector<std::string> function_args;
512       size_t close_bracket_pos;
513       RETURN_IF_ERROR(ParseArgsInsideBrackets(
514           *code, next_position, &close_bracket_pos, &function_args));
515       for (auto& arg : function_args) {
516         RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, {}, &arg));
517       }
518       std::string patch;
519       RETURN_IF_ERROR(ResolveSelector(gpu_info, linkables, object_name,
520                                       selector_name, function_args,
521                                       template_args, &patch));
522       code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
523       position = arg_pos + patch.size();
524     } else {
525       position = arg_pos + strlen(kArgsPrefix);
526     }
527     next_position = code->find(kArgsPrefix, position);
528   }
529   return absl::OkStatus();
530 }
531 
ResolveSelector(const GpuInfo & gpu_info,const std::map<std::string,std::string> & linkables,const std::string & object_name,const std::string & selector,const std::vector<std::string> & function_args,const std::vector<std::string> & template_args,std::string * result) const532 absl::Status Arguments::ResolveSelector(
533     const GpuInfo& gpu_info,
534     const std::map<std::string, std::string>& linkables,
535     const std::string& object_name, const std::string& selector,
536     const std::vector<std::string>& function_args,
537     const std::vector<std::string>& template_args, std::string* result) const {
538   GPUObjectDescriptor* desc_ptr;
539   RETURN_IF_ERROR(GetDescriptor(object_name, &desc_ptr));
540   auto names = desc_ptr->GetGPUResources(gpu_info).GetNames();
541   const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc_ptr);
542   std::vector<std::string> function_args_new = function_args;
543   if (tensor_desc && !linkables.empty() && selector == "Write") {
544     auto it = linkables.find(object_name);
545     if (it != linkables.end() && !it->second.empty()) {
546       if (desc_ptr->GetAccess() != AccessType::WRITE &&
547           desc_ptr->GetAccess() != AccessType::READ_WRITE) {
548         return absl::FailedPreconditionError(absl::StrCat(
549             "Object with name - ", object_name, " should have Write access."));
550       }
551       std::string value_name, x_coord, y_coord, z_coord, s_coord, b_coord;
552       RETURN_IF_ERROR(tensor_desc->GetLinkingContextFromWriteSelector(
553           function_args_new, &value_name, &x_coord, &y_coord, &z_coord,
554           &s_coord, &b_coord));
555       const std::string new_value_name = value_name + "_final";
556       const std::string out_var_declaration =
557           "\n" + GetTypeDeclaration(gpu_info, tensor_desc->GetDataType(), 4) +
558           " " + new_value_name + ";\n";
559       *result = "{  // elementwise code with input:" + value_name +
560                 absl::Substitute(it->second, out_var_declaration) + "\n";
561       *result = absl::StrReplaceAll(*result, {{"\n", "\n  "}});
562       ReplaceAllWords("in_value", value_name, result);
563       ReplaceAllWords("out_value", new_value_name, result);
564       ReplaceAllWords("X_COORD", x_coord, result);
565       ReplaceAllWords("Y_COORD", y_coord, result);
566       ReplaceAllWords("Z_COORD", z_coord, result);
567       ReplaceAllWords("S_COORD", s_coord, result);
568       ReplaceAllWords("B_COORD", b_coord, result);
569       function_args_new[0] = new_value_name;
570       RETURN_IF_ERROR(ResolveConstExprPass(gpu_info, result));
571       RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, {}, result));
572     }
573   }
574   std::string patch;
575   RETURN_IF_ERROR(desc_ptr->PerformSelector(
576       gpu_info, selector, function_args_new, template_args, &patch));
577   ResolveObjectNames(object_name, names, &patch);
578   if (result->empty()) {
579     *result += patch;
580   } else {
581     // result has elementwise code
582     *result += "// write result to tensor\n  " + patch + ";\n}";
583   }
584   return absl::OkStatus();
585 }
586 
ResolveObjectNames(const std::string & object_name,const std::vector<std::string> & member_names,std::string * code) const587 void Arguments::ResolveObjectNames(const std::string& object_name,
588                                    const std::vector<std::string>& member_names,
589                                    std::string* code) const {
590   for (const auto& member_name : member_names) {
591     const std::string new_name = kArgsPrefix + object_name + "_" + member_name;
592     ReplaceAllWords(member_name, new_name, code);
593   }
594 }
595 
AddObjectsScalarArgs(const GpuInfo & gpu_info)596 absl::Status Arguments::AddObjectsScalarArgs(const GpuInfo& gpu_info) {
597   for (auto& t : objects_) {
598     const auto resources = t.second->GetGPUResources(gpu_info);
599     for (const auto& r : resources.ints) {
600       AddInt(absl::StrCat(t.first, "_", r));
601     }
602     for (const auto& r : resources.floats) {
603       AddFloat(absl::StrCat(t.first, "_", r));
604     }
605   }
606   for (auto& t : object_refs_) {
607     const auto resources = t.second->GetGPUResources(gpu_info);
608     for (const auto& r : resources.ints) {
609       AddInt(absl::StrCat(t.first, "_", r));
610     }
611     for (const auto& r : resources.floats) {
612       AddFloat(absl::StrCat(t.first, "_", r));
613     }
614   }
615   return absl::OkStatus();
616 }
617 
ResolveArgsPass(std::string * code) const618 void Arguments::ResolveArgsPass(std::string* code) const {
619   size_t position = 0;
620   size_t next_position = code->find(kArgsPrefix);
621   while (next_position != std::string::npos) {
622     size_t arg_pos = next_position;
623     next_position += strlen(kArgsPrefix);
624     std::string object_name = GetNextWord(*code, next_position);
625     std::string new_name = object_name;
626     code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name);
627     position = arg_pos + new_name.size();
628     next_position = code->find(kArgsPrefix, position);
629   }
630 }
631 
ResolveKernelGlobalSpaceBuffers(const GpuInfo & gpu_info,std::string * code)632 absl::Status Arguments::ResolveKernelGlobalSpaceBuffers(const GpuInfo& gpu_info,
633                                                         std::string* code) {
634   for (auto it = objects_.begin(); it != objects_.end();) {
635     const auto* buffer_desc =
636         dynamic_cast<const BufferDescriptor*>(it->second.get());
637     if (!buffer_desc || buffer_desc->memory_type != MemoryType::CONSTANT) {
638       ++it;
639       continue;
640     }
641     bool is_kernel_global_space = false;
642     for (const auto& attribute : buffer_desc->attributes) {
643       if (attribute == "kernel_global_space") {
644         is_kernel_global_space = true;
645         break;
646       }
647     }
648     if (!is_kernel_global_space) {
649       ++it;
650       continue;
651     }
652     std::string declaration;
653     if (!BufferToKernelLanguage(gpu_info, it->first, buffer_desc, &declaration)
654              .ok()) {
655       ++it;
656       continue;
657     }
658     *code = declaration + *code;
659     objects_.erase(it++);
660   }
661   return absl::OkStatus();
662 }
663 
664 }  // namespace gpu
665 }  // namespace tflite
666