1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/lite/delegates/gpu/common/task/arguments.h"

#include <algorithm>
#include <cstring>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
36
37 namespace tflite {
38 namespace gpu {
39 namespace {
// Returns true if `symbol` may appear inside an identifier-like word: an ASCII
// letter, an ASCII digit or an underscore. Bytes outside the ASCII range are
// never word symbols.
bool IsWordSymbol(char symbol) {
  const bool is_lower = symbol >= 'a' && symbol <= 'z';
  const bool is_upper = symbol >= 'A' && symbol <= 'Z';
  const bool is_digit = symbol >= '0' && symbol <= '9';
  return is_lower || is_upper || is_digit || symbol == '_';
}
43
ReplaceAllWords(const std::string & old_word,const std::string & new_word,std::string * str)44 void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
45 std::string* str) {
46 size_t position = str->find(old_word);
47 while (position != std::string::npos) {
48 char prev = position == 0 ? '.' : (*str)[position - 1];
49 char next = position + old_word.size() < str->size()
50 ? (*str)[position + old_word.size()]
51 : '.';
52 if (IsWordSymbol(prev) || IsWordSymbol(next)) {
53 position = str->find(old_word, position + 1);
54 continue;
55 }
56 str->replace(position, old_word.size(), new_word);
57 position = str->find(old_word, position + new_word.size());
58 }
59 }
60
GetNextWord(const std::string & code,size_t first_position)61 std::string GetNextWord(const std::string& code, size_t first_position) {
62 size_t pos = first_position;
63 char t = code[pos];
64 while (IsWordSymbol(t)) {
65 pos++;
66 t = code[pos];
67 }
68 return code.substr(first_position, pos - first_position);
69 }
70
HasWord(const std::string & word,const std::string & text)71 bool HasWord(const std::string& word, const std::string& text) {
72 size_t pos = text.find(word);
73 while (pos != std::string::npos) {
74 char prev = pos == 0 ? '.' : text[pos - 1];
75 char next = pos + word.size() < text.size() ? text[pos + word.size()] : '.';
76 if (!IsWordSymbol(prev) && !IsWordSymbol(next)) {
77 return true;
78 }
79 pos = text.find(word, pos + 1);
80 }
81 return false;
82 }
83
// Appends `postfix` to the object part of a scalar argument name.
// If `arg_name` has the form "<object>_<rest>" for some entry of
// `object_names` (the first matching entry wins), returns
// "<object><postfix>_<rest>"; otherwise returns `arg_name` + `postfix`.
std::string RenameArg(const std::vector<std::string>& object_names,
                      const std::string& postfix, const std::string& arg_name) {
  for (const std::string& object_name : object_names) {
    const size_t prefix_len = object_name.size();
    const bool has_object_prefix =
        arg_name.size() > prefix_len &&
        arg_name.compare(0, prefix_len, object_name) == 0 &&
        arg_name[prefix_len] == '_';
    if (has_object_prefix) {
      return object_name + postfix + arg_name.substr(prefix_len);
    }
  }
  return arg_name + postfix;
}
97
// Finds the bracket that closes an already-opened `bracket`, scanning `text`
// from `first_pos` (callers pass the index just after the opening bracket).
// Supported opening brackets are '(', '{', '[' and '<'; nested brackets of the
// same kind are skipped over.
// Returns the index ONE PAST the matching closing bracket, or
// std::string::npos if `bracket` is not a supported opening bracket or no
// matching closing bracket exists.
size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
                            char bracket) {
  // A switch instead of a std::map built (and heap-allocated) on every call.
  char b_close = '\0';
  switch (bracket) {
    case '(':
      b_close = ')';
      break;
    case '{':
      b_close = '}';
      break;
    case '[':
      b_close = ']';
      break;
    case '<':
      b_close = '>';
      break;
    default:
      return std::string::npos;
  }
  // The opening bracket located before `first_pos` is still unmatched.
  int depth = 1;
  for (size_t pos = first_pos; pos < text.size(); ++pos) {
    if (text[pos] == bracket) {
      ++depth;
    } else if (text[pos] == b_close && --depth == 0) {
      return pos + 1;
    }
  }
  return std::string::npos;
}
129
ParseArgsInsideBrackets(const std::string & text,size_t open_bracket_pos,size_t * close_bracket_pos,std::vector<std::string> * args)130 absl::Status ParseArgsInsideBrackets(const std::string& text,
131 size_t open_bracket_pos,
132 size_t* close_bracket_pos,
133 std::vector<std::string>* args) {
134 *close_bracket_pos =
135 FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
136 if (*close_bracket_pos == -1) {
137 return absl::NotFoundError("Not found enclosing bracket");
138 }
139 std::string str_args = text.substr(open_bracket_pos + 1,
140 *close_bracket_pos - open_bracket_pos - 2);
141 std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
142 args->reserve(words.size());
143 for (const auto& word : words) {
144 absl::string_view arg = absl::StripAsciiWhitespace(word);
145 if (!arg.empty()) {
146 args->push_back(std::string(arg));
147 }
148 }
149 return absl::OkStatus();
150 }
151
DataTypeToGlType(DataType data_type,int vec_size,bool explicit_f16)152 std::string DataTypeToGlType(DataType data_type, int vec_size,
153 bool explicit_f16) {
154 if (data_type == DataType::FLOAT32) {
155 if (vec_size == 1) {
156 return "float";
157 } else {
158 return "vec" + std::to_string(vec_size);
159 }
160 } else if (data_type == DataType::FLOAT16) {
161 if (vec_size == 1) {
162 return explicit_f16 ? "float16_t" : "float";
163 } else {
164 if (explicit_f16) {
165 return "f16vec" + std::to_string(vec_size);
166 } else {
167 return "vec" + std::to_string(vec_size);
168 }
169 }
170 } else if (data_type == DataType::INT32) {
171 if (vec_size == 1) {
172 return "int";
173 } else {
174 return "ivec" + std::to_string(vec_size);
175 }
176 } else if (data_type == DataType::UINT32) {
177 if (vec_size == 1) {
178 return "uint";
179 } else {
180 return "uvec" + std::to_string(vec_size);
181 }
182 }
183 return "unsupported_type";
184 }
185
BufferToKernelLanguage(const GpuInfo & gpu_info,const std::string & buffer_name,const BufferDescriptor * buffer_desc,std::string * result)186 absl::Status BufferToKernelLanguage(const GpuInfo& gpu_info,
187 const std::string& buffer_name,
188 const BufferDescriptor* buffer_desc,
189 std::string* result) {
190 if (buffer_desc->element_size != 1) {
191 return absl::UnimplementedError("No support of vector types.");
192 }
193 const int elements_count =
194 buffer_desc->size /
195 (buffer_desc->element_size * SizeOf(buffer_desc->element_type));
196 if (gpu_info.IsGlsl()) {
197 const std::string gl_type =
198 DataTypeToGlType(buffer_desc->element_type, buffer_desc->element_size,
199 gpu_info.IsGlslSupportsExplicitFp16());
200 *result = "const ";
201 if (buffer_desc->element_type == DataType::FLOAT16 &&
202 !gpu_info.IsGlslSupportsExplicitFp16()) {
203 *result += "mediump ";
204 }
205 *result += gl_type + " " + buffer_name + "_buffer[] = " + gl_type + "[](\n";
206 } else if (gpu_info.IsApiMetal()) {
207 const std::string metal_type =
208 ToMetalDataType(buffer_desc->element_type, buffer_desc->element_size);
209 *result = "constant " + metal_type + " " + buffer_name + "_buffer[" +
210 std::to_string(elements_count) + "] = {\n";
211 } else if (gpu_info.IsApiOpenCl()) {
212 const std::string cl_type =
213 ToCLDataType(buffer_desc->element_type, buffer_desc->element_size);
214 *result = "__constant " + cl_type + " " + buffer_name + "_buffer[" +
215 std::to_string(elements_count) + "] = {\n";
216 } else {
217 return absl::UnimplementedError("Not supported API.");
218 }
219 if (buffer_desc->element_type == DataType::FLOAT16) {
220 std::string postfix = "f";
221 if (gpu_info.IsGlsl() && gpu_info.IsGlslSupportsExplicitFp16()) {
222 postfix = "hf";
223 }
224 const half* data_ptr =
225 reinterpret_cast<const half*>(buffer_desc->data.data());
226 for (int i = 0; i < elements_count; ++i) {
227 *result += " " +
228 absl::StrFormat("%.10f", static_cast<float>(data_ptr[i])) +
229 postfix;
230 if (i != elements_count - 1) {
231 *result += ",\n";
232 }
233 }
234 } else if (buffer_desc->element_type == DataType::FLOAT32) {
235 const float* data_ptr =
236 reinterpret_cast<const float*>(buffer_desc->data.data());
237 for (int i = 0; i < elements_count; ++i) {
238 *result += " " + absl::StrFormat("%.10f", data_ptr[i]) + "f";
239 if (i != elements_count - 1) {
240 *result += ",\n";
241 }
242 }
243 } else {
244 return absl::UnimplementedError("Not supported type.");
245 }
246 if (gpu_info.IsGlsl()) {
247 *result += ");\n";
248 } else {
249 *result += "};\n";
250 }
251
252 return absl::OkStatus();
253 }
254
255 } // namespace
256
// Out-of-class definition of the static constexpr member Arguments::kArgsPrefix.
// Before C++17 a static constexpr data member that is odr-used (in this file it
// is passed to strlen() and std::string::find()) needs exactly one
// namespace-scope definition. From C++17 on the member is implicitly inline and
// this line is a redundant (deprecated but harmless) redeclaration.
constexpr char Arguments::kArgsPrefix[];
259
AddFloat(const std::string & name,float value)260 void Arguments::AddFloat(const std::string& name, float value) {
261 float_values_[name].value = value;
262 }
AddHalf(const std::string & name,half value)263 void Arguments::AddHalf(const std::string& name, half value) {
264 half_values_[name].value = value;
265 }
AddInt(const std::string & name,int value)266 void Arguments::AddInt(const std::string& name, int value) {
267 int_values_[name].value = value;
268 }
269
SetInt(const std::string & name,int value)270 absl::Status Arguments::SetInt(const std::string& name, int value) {
271 auto it = int_values_.find(name);
272 if (it == int_values_.end()) {
273 return absl::NotFoundError(
274 absl::StrCat("No int argument with name - ", name));
275 }
276 it->second.value = value;
277 return absl::OkStatus();
278 }
SetFloat(const std::string & name,float value)279 absl::Status Arguments::SetFloat(const std::string& name, float value) {
280 auto it = float_values_.find(name);
281 if (it == float_values_.end()) {
282 return absl::NotFoundError(
283 absl::StrCat("No float argument with name - ", name));
284 }
285 it->second.value = value;
286 return absl::OkStatus();
287 }
288
SetHalf(const std::string & name,half value)289 absl::Status Arguments::SetHalf(const std::string& name, half value) {
290 auto it = half_values_.find(name);
291 if (it == half_values_.end()) {
292 return absl::NotFoundError(
293 absl::StrCat("No half argument with name - ", name));
294 }
295 it->second.value = value;
296 return absl::OkStatus();
297 }
298
AddObjectRef(const std::string & name,AccessType access_type,GPUObjectDescriptorPtr && descriptor_ptr)299 void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
300 GPUObjectDescriptorPtr&& descriptor_ptr) {
301 descriptor_ptr->SetAccess(access_type);
302 object_refs_[name] = {std::move(descriptor_ptr)};
303 }
304
AddObject(const std::string & name,GPUObjectDescriptorPtr && descriptor_ptr)305 void Arguments::AddObject(const std::string& name,
306 GPUObjectDescriptorPtr&& descriptor_ptr) {
307 descriptor_ptr->SetAccess(AccessType::READ);
308 objects_[name] = {std::move(descriptor_ptr)};
309 }
310
RenameArgs(const std::string & postfix,std::string * code) const311 void Arguments::RenameArgs(const std::string& postfix,
312 std::string* code) const {
313 size_t next_position = code->find(kArgsPrefix);
314 while (next_position != std::string::npos) {
315 size_t arg_pos = next_position + strlen(kArgsPrefix);
316 std::string arg_name = GetNextWord(*code, arg_pos);
317 code->replace(arg_pos, arg_name.size(), arg_name + postfix);
318 next_position = code->find(kArgsPrefix, arg_pos + arg_name.size());
319 }
320 }
321
Merge(Arguments && args,const std::string & postfix,const std::vector<std::string> & exception_names)322 absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix,
323 const std::vector<std::string>& exception_names) {
324 std::vector<std::string> object_names;
325 object_names.reserve(args.object_refs_.size() + args.objects_.size());
326 for (auto& v : args.object_refs_) {
327 if (std::find(exception_names.begin(), exception_names.end(), v.first) !=
328 exception_names.end()) {
329 continue;
330 }
331 object_names.push_back(v.first);
332 const std::string name = v.first + postfix;
333 if (object_refs_.find(name) != object_refs_.end()) {
334 return absl::InvalidArgumentError(
335 absl::StrCat("Object reference name collision. Name - ", name));
336 }
337 object_refs_[name] = {std::move(v.second)};
338 }
339 for (auto& v : args.objects_) {
340 if (std::find(exception_names.begin(), exception_names.end(), v.first) !=
341 exception_names.end()) {
342 continue;
343 }
344 object_names.push_back(v.first);
345 const std::string name = v.first + postfix;
346 if (objects_.find(name) != objects_.end()) {
347 return absl::InvalidArgumentError(
348 absl::StrCat("Object name collision. Name - ", name));
349 }
350 objects_[name] = {std::move(v.second)};
351 }
352 for (const auto& v : args.int_values_) {
353 AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
354 }
355 for (const auto& v : args.float_values_) {
356 AddFloat(RenameArg(object_names, postfix, v.first), v.second.value);
357 }
358 for (const auto& v : args.half_values_) {
359 AddHalf(RenameArg(object_names, postfix, v.first), v.second.value);
360 }
361 return absl::OkStatus();
362 }
363
GetDescriptor(const std::string & name,GPUObjectDescriptor ** descriptor) const364 absl::Status Arguments::GetDescriptor(const std::string& name,
365 GPUObjectDescriptor** descriptor) const {
366 auto it_ref = object_refs_.find(name);
367 if (it_ref != object_refs_.end()) {
368 *descriptor = it_ref->second.get();
369 return absl::OkStatus();
370 }
371 auto it = objects_.find(name);
372 if (it != objects_.end()) {
373 *descriptor = it->second.get();
374 return absl::OkStatus();
375 }
376 return absl::NotFoundError(absl::StrCat("No GPU object with name - ", name));
377 }
378
ReleaseCPURepresentation()379 void Arguments::ReleaseCPURepresentation() {
380 for (auto& t : objects_) {
381 t.second->Release();
382 }
383 }
384
GetActiveArguments(const std::string & code)385 void Arguments::GetActiveArguments(const std::string& code) {
386 for (auto& float_val : float_values_) {
387 float_val.second.active = HasWord(kArgsPrefix + float_val.first, code);
388 }
389 for (auto& int_val : int_values_) {
390 int_val.second.active = HasWord(kArgsPrefix + int_val.first, code);
391 }
392 for (auto& half_val : half_values_) {
393 half_val.second.active = HasWord(kArgsPrefix + half_val.first, code);
394 }
395 }
396
GetReadTexturesCount(const GpuInfo & gpu_info) const397 int Arguments::GetReadTexturesCount(const GpuInfo& gpu_info) const {
398 int counter = 0;
399 for (auto& t : objects_) {
400 counter += t.second->GetGPUResources(gpu_info).GetReadImagesCount();
401 }
402 for (auto& t : object_refs_) {
403 counter += t.second->GetGPUResources(gpu_info).GetReadImagesCount();
404 }
405 return counter;
406 }
407
GetWriteTexturesCount(const GpuInfo & gpu_info) const408 int Arguments::GetWriteTexturesCount(const GpuInfo& gpu_info) const {
409 int counter = 0;
410 for (auto& t : objects_) {
411 counter += t.second->GetGPUResources(gpu_info).GetWriteImagesCount();
412 }
413 for (auto& t : object_refs_) {
414 counter += t.second->GetGPUResources(gpu_info).GetWriteImagesCount();
415 }
416 return counter;
417 }
418
SetStateValueForAllObjects(const std::string & key,const std::string & value)419 void Arguments::SetStateValueForAllObjects(const std::string& key,
420 const std::string& value) {
421 for (auto& obj : object_refs_) {
422 obj.second->SetStateVar(key, value);
423 }
424 for (auto& obj : objects_) {
425 obj.second->SetStateVar(key, value);
426 }
427 }
428
// Rewrites `*code` for `gpu_info` by running these passes in order:
//  1. AddObjectsScalarArgs registers "<object>_<resource>" int/float scalar
//     arguments for every object and object reference.
//  2. ResolveConstExprPass expands "<kArgsPrefix><object>::<name>" constants.
//  3. ResolveSelectorsPass expands "<kArgsPrefix><object>.<selector>(...)"
//     calls. `linkables` maps an object name to code that is fused into that
//     tensor's Write selector (see ResolveSelector).
//  4. GetActiveArguments flags the scalar arguments still referenced in the
//     code. It must run after step 3, whose expansions add new
//     "<kArgsPrefix><object>_<member>" references (ResolveObjectNames).
//  5. ResolveKernelGlobalSpaceBuffers prepends declarations of eligible
//     constant buffers and removes them from objects_.
// Returns the first error reported by any pass.
absl::Status Arguments::Compile(
    const GpuInfo& gpu_info,
    const std::map<std::string, std::string>& linkables, std::string* code) {
  RETURN_IF_ERROR(AddObjectsScalarArgs(gpu_info));
  RETURN_IF_ERROR(ResolveConstExprPass(gpu_info, code));
  RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, linkables, code));
  GetActiveArguments(*code);
  RETURN_IF_ERROR(ResolveKernelGlobalSpaceBuffers(gpu_info, code));
  return absl::OkStatus();
}
439
ResolveConstExprPass(const GpuInfo & gpu_info,std::string * code) const440 absl::Status Arguments::ResolveConstExprPass(const GpuInfo& gpu_info,
441 std::string* code) const {
442 std::string result;
443 size_t position = 0;
444 size_t next_position = code->find(kArgsPrefix);
445 while (next_position != std::string::npos) {
446 size_t arg_pos = next_position;
447 next_position += strlen(kArgsPrefix);
448 std::string object_name = GetNextWord(*code, next_position);
449 if (next_position + object_name.size() > code->size() - 2) {
450 next_position = code->find(kArgsPrefix, next_position);
451 continue;
452 }
453 char next0 = (*code)[next_position + object_name.size()];
454 char next1 = (*code)[next_position + object_name.size() + 1];
455 if (next0 == ':' && next1 == ':') {
456 next_position += object_name.size() + 2;
457 std::string const_expr_name = GetNextWord(*code, next_position);
458 next_position += const_expr_name.size();
459 std::string patch;
460 RETURN_IF_ERROR(
461 ResolveConstExpr(gpu_info, object_name, const_expr_name, &patch));
462 code->replace(arg_pos, next_position - arg_pos, patch);
463 position = arg_pos + patch.size();
464 } else {
465 position = arg_pos + strlen(kArgsPrefix);
466 }
467 next_position = code->find(kArgsPrefix, position);
468 }
469 return absl::OkStatus();
470 }
471
ResolveConstExpr(const GpuInfo & gpu_info,const std::string & object_name,const std::string & const_expr,std::string * result) const472 absl::Status Arguments::ResolveConstExpr(const GpuInfo& gpu_info,
473 const std::string& object_name,
474 const std::string& const_expr,
475 std::string* result) const {
476 tflite::gpu::GPUObjectDescriptor* desc_ptr;
477 RETURN_IF_ERROR(GetDescriptor(object_name, &desc_ptr));
478 RETURN_IF_ERROR(desc_ptr->PerformConstExpr(gpu_info, const_expr, result));
479 return absl::OkStatus();
480 }
481
ResolveSelectorsPass(const GpuInfo & gpu_info,const std::map<std::string,std::string> & linkables,std::string * code) const482 absl::Status Arguments::ResolveSelectorsPass(
483 const GpuInfo& gpu_info,
484 const std::map<std::string, std::string>& linkables,
485 std::string* code) const {
486 std::string result;
487 size_t position = 0;
488 size_t next_position = code->find(kArgsPrefix);
489 while (next_position != std::string::npos) {
490 size_t arg_pos = next_position;
491 next_position += strlen(kArgsPrefix);
492 std::string object_name = GetNextWord(*code, next_position);
493 char next = (*code)[next_position + object_name.size()];
494 if (next == '.') {
495 next_position += object_name.size() + 1;
496 std::string selector_name = GetNextWord(*code, next_position);
497 next_position += selector_name.size();
498 next = (*code)[next_position];
499 std::vector<std::string> template_args;
500 if (next == '<') {
501 size_t close_bracket_pos;
502 RETURN_IF_ERROR(ParseArgsInsideBrackets(
503 *code, next_position, &close_bracket_pos, &template_args));
504 next_position = close_bracket_pos;
505 next = (*code)[next_position];
506 }
507 if (next != '(') {
508 return absl::NotFoundError(absl::StrCat(
509 "Expected ( after ", object_name, ".", selector_name, " call"));
510 }
511 std::vector<std::string> function_args;
512 size_t close_bracket_pos;
513 RETURN_IF_ERROR(ParseArgsInsideBrackets(
514 *code, next_position, &close_bracket_pos, &function_args));
515 for (auto& arg : function_args) {
516 RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, {}, &arg));
517 }
518 std::string patch;
519 RETURN_IF_ERROR(ResolveSelector(gpu_info, linkables, object_name,
520 selector_name, function_args,
521 template_args, &patch));
522 code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
523 position = arg_pos + patch.size();
524 } else {
525 position = arg_pos + strlen(kArgsPrefix);
526 }
527 next_position = code->find(kArgsPrefix, position);
528 }
529 return absl::OkStatus();
530 }
531
ResolveSelector(const GpuInfo & gpu_info,const std::map<std::string,std::string> & linkables,const std::string & object_name,const std::string & selector,const std::vector<std::string> & function_args,const std::vector<std::string> & template_args,std::string * result) const532 absl::Status Arguments::ResolveSelector(
533 const GpuInfo& gpu_info,
534 const std::map<std::string, std::string>& linkables,
535 const std::string& object_name, const std::string& selector,
536 const std::vector<std::string>& function_args,
537 const std::vector<std::string>& template_args, std::string* result) const {
538 GPUObjectDescriptor* desc_ptr;
539 RETURN_IF_ERROR(GetDescriptor(object_name, &desc_ptr));
540 auto names = desc_ptr->GetGPUResources(gpu_info).GetNames();
541 const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc_ptr);
542 std::vector<std::string> function_args_new = function_args;
543 if (tensor_desc && !linkables.empty() && selector == "Write") {
544 auto it = linkables.find(object_name);
545 if (it != linkables.end() && !it->second.empty()) {
546 if (desc_ptr->GetAccess() != AccessType::WRITE &&
547 desc_ptr->GetAccess() != AccessType::READ_WRITE) {
548 return absl::FailedPreconditionError(absl::StrCat(
549 "Object with name - ", object_name, " should have Write access."));
550 }
551 std::string value_name, x_coord, y_coord, z_coord, s_coord, b_coord;
552 RETURN_IF_ERROR(tensor_desc->GetLinkingContextFromWriteSelector(
553 function_args_new, &value_name, &x_coord, &y_coord, &z_coord,
554 &s_coord, &b_coord));
555 const std::string new_value_name = value_name + "_final";
556 const std::string out_var_declaration =
557 "\n" + GetTypeDeclaration(gpu_info, tensor_desc->GetDataType(), 4) +
558 " " + new_value_name + ";\n";
559 *result = "{ // elementwise code with input:" + value_name +
560 absl::Substitute(it->second, out_var_declaration) + "\n";
561 *result = absl::StrReplaceAll(*result, {{"\n", "\n "}});
562 ReplaceAllWords("in_value", value_name, result);
563 ReplaceAllWords("out_value", new_value_name, result);
564 ReplaceAllWords("X_COORD", x_coord, result);
565 ReplaceAllWords("Y_COORD", y_coord, result);
566 ReplaceAllWords("Z_COORD", z_coord, result);
567 ReplaceAllWords("S_COORD", s_coord, result);
568 ReplaceAllWords("B_COORD", b_coord, result);
569 function_args_new[0] = new_value_name;
570 RETURN_IF_ERROR(ResolveConstExprPass(gpu_info, result));
571 RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, {}, result));
572 }
573 }
574 std::string patch;
575 RETURN_IF_ERROR(desc_ptr->PerformSelector(
576 gpu_info, selector, function_args_new, template_args, &patch));
577 ResolveObjectNames(object_name, names, &patch);
578 if (result->empty()) {
579 *result += patch;
580 } else {
581 // result has elementwise code
582 *result += "// write result to tensor\n " + patch + ";\n}";
583 }
584 return absl::OkStatus();
585 }
586
ResolveObjectNames(const std::string & object_name,const std::vector<std::string> & member_names,std::string * code) const587 void Arguments::ResolveObjectNames(const std::string& object_name,
588 const std::vector<std::string>& member_names,
589 std::string* code) const {
590 for (const auto& member_name : member_names) {
591 const std::string new_name = kArgsPrefix + object_name + "_" + member_name;
592 ReplaceAllWords(member_name, new_name, code);
593 }
594 }
595
AddObjectsScalarArgs(const GpuInfo & gpu_info)596 absl::Status Arguments::AddObjectsScalarArgs(const GpuInfo& gpu_info) {
597 for (auto& t : objects_) {
598 const auto resources = t.second->GetGPUResources(gpu_info);
599 for (const auto& r : resources.ints) {
600 AddInt(absl::StrCat(t.first, "_", r));
601 }
602 for (const auto& r : resources.floats) {
603 AddFloat(absl::StrCat(t.first, "_", r));
604 }
605 }
606 for (auto& t : object_refs_) {
607 const auto resources = t.second->GetGPUResources(gpu_info);
608 for (const auto& r : resources.ints) {
609 AddInt(absl::StrCat(t.first, "_", r));
610 }
611 for (const auto& r : resources.floats) {
612 AddFloat(absl::StrCat(t.first, "_", r));
613 }
614 }
615 return absl::OkStatus();
616 }
617
ResolveArgsPass(std::string * code) const618 void Arguments::ResolveArgsPass(std::string* code) const {
619 size_t position = 0;
620 size_t next_position = code->find(kArgsPrefix);
621 while (next_position != std::string::npos) {
622 size_t arg_pos = next_position;
623 next_position += strlen(kArgsPrefix);
624 std::string object_name = GetNextWord(*code, next_position);
625 std::string new_name = object_name;
626 code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name);
627 position = arg_pos + new_name.size();
628 next_position = code->find(kArgsPrefix, position);
629 }
630 }
631
ResolveKernelGlobalSpaceBuffers(const GpuInfo & gpu_info,std::string * code)632 absl::Status Arguments::ResolveKernelGlobalSpaceBuffers(const GpuInfo& gpu_info,
633 std::string* code) {
634 for (auto it = objects_.begin(); it != objects_.end();) {
635 const auto* buffer_desc =
636 dynamic_cast<const BufferDescriptor*>(it->second.get());
637 if (!buffer_desc || buffer_desc->memory_type != MemoryType::CONSTANT) {
638 ++it;
639 continue;
640 }
641 bool is_kernel_global_space = false;
642 for (const auto& attribute : buffer_desc->attributes) {
643 if (attribute == "kernel_global_space") {
644 is_kernel_global_space = true;
645 break;
646 }
647 }
648 if (!is_kernel_global_space) {
649 ++it;
650 continue;
651 }
652 std::string declaration;
653 if (!BufferToKernelLanguage(gpu_info, it->first, buffer_desc, &declaration)
654 .ok()) {
655 ++it;
656 continue;
657 }
658 *code = declaration + *code;
659 objects_.erase(it++);
660 }
661 return absl::OkStatus();
662 }
663
664 } // namespace gpu
665 } // namespace tflite
666