xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"

#include <array>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
27 
28 namespace tflite {
29 namespace gpu {
30 namespace gl {
31 namespace variable_accessor_internal {
32 
33 // Parse the following regex manually
34 // name(\[index\])?(\.field)?
Parse(absl::string_view input)35 VariableReference Parse(absl::string_view input) {
36   VariableReference ref;
37   auto start_index = input.find('[');
38   if (start_index != std::string::npos) {
39     auto end_index = input.rfind(']');
40     if (end_index == std::string::npos) {
41       return ref;
42     }
43     ref.index = input.substr(start_index + 1, end_index - start_index - 1);
44     ref.name = input.substr(0, start_index);
45     ref.field = input.substr(end_index + 1);
46   } else {
47     auto dot = input.find('.');
48     if (dot != std::string::npos) {
49       ref.name = input.substr(0, dot);
50       ref.field = input.substr(dot);
51     } else {
52       ref.name = input;
53     }
54   }
55   return ref;
56 }
57 
58 }  // namespace variable_accessor_internal
59 
60 namespace {
61 
62 struct VariableTypeGetter {
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter63   std::string operator()(int) const { return "int"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter64   std::string operator()(const int2&) const { return "ivec2"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter65   std::string operator()(const std::vector<int2>&) const { return "ivec2"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter66   std::string operator()(const int4&) const { return "ivec4"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter67   std::string operator()(unsigned int) const { return "uint"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter68   std::string operator()(const uint4&) const { return "uvec4"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter69   std::string operator()(float) const { return "float"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter70   std::string operator()(const float2&) const { return "vec2"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter71   std::string operator()(const float4&) const { return "vec4"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter72   std::string operator()(const std::vector<float4>&) const { return "vec4"; }
73 };
74 
75 // Returns GLSL uniform type of the given variable.
GetVariableType(const Variable::ValueType & value)76 std::string GetVariableType(const Variable::ValueType& value) {
77   return std::visit(VariableTypeGetter(), value);
78 }
79 
// Visitor reporting how many elements a value holds: 1 for every non-array
// alternative, and the element count for std::vector alternatives.
struct LengthGetter {
  template <typename Value>
  int operator()(const Value& /*single_value*/) const {
    return 1;
  }

  template <typename Element>
  int operator()(const std::vector<Element>& elements) const {
    return static_cast<int>(elements.size());
  }
};
90 
GetLength(const Variable::ValueType & value)91 int GetLength(const Variable::ValueType& value) {
92   return std::visit(LengthGetter(), value);
93 }
94 
95 template <typename T>
FormatValue(std::string * result,T t)96 void FormatValue(std::string* result, T t) {
97   absl::StrAppend(result, t);
98 }
99 
100 template <>
FormatValue(std::string * result,float t)101 void FormatValue(std::string* result, float t) {
102   absl::StrAppend(result, absl::StrFormat("%.9ff", t));
103 }
104 
// Formats every element of `data` with FormatValue and returns the pieces as
// separate strings, so they can be combined with absl::StrJoin's default
// formatter.
//
// N is declared as std::size_t to match std::array's own size parameter. This
// lets N be deduced from the argument, and explicit calls such as
// ToString<T, 2>(...) keep working.
template <typename T, std::size_t N>
std::vector<std::string> ToString(const std::array<T, N>& data) {
  std::vector<std::string> result(N);
  for (std::size_t i = 0; i < N; ++i) {
    FormatValue(&result[i], data[i]);
  }
  return result;
}
116 
117 struct ConstGenerator {
118   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::ConstGenerator119   void operator()(T t) const {
120     FormatValue(result, t);
121   }
122 
123   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::ConstGenerator124   void operator()(const Vec2<T>& v) const {
125     absl::StrAppend(result, VariableTypeGetter()(v), "(",
126                     absl::StrJoin(ToString<T, 2>(v.data_), ","), ")");
127   }
128 
129   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::ConstGenerator130   void operator()(const Vec3<T>& v) const {
131     absl::StrAppend(result, VariableTypeGetter()(v), "(",
132                     absl::StrJoin(ToString<T, 3>(v.data_), ","), ")");
133   }
134 
135   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::ConstGenerator136   void operator()(const Vec4<T>& v) const {
137     absl::StrAppend(result, VariableTypeGetter()(v), "(",
138                     absl::StrJoin(ToString<T, 4>(v.data_), ","), ")");
139   }
140 
141   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::ConstGenerator142   void operator()(const std::vector<T>& v) const {
143     std::string type = VariableTypeGetter()(v);
144     absl::StrAppend(result, type, "[", v.size(), "](");
145     bool first = true;
146     for (const auto& i : v) {
147       if (first) {
148         first = false;
149       } else {
150         absl::StrAppend(result, ",");
151       }
152       (*this)(i);
153     }
154     absl::StrAppend(result, ")");
155   }
156 
157   std::string* result;
158 };
159 
160 // Appends string representation of a variable value.
GetValue(const Variable::ValueType & value,std::string * result)161 void GetValue(const Variable::ValueType& value, std::string* result) {
162   std::visit(ConstGenerator{result}, value);
163 }
164 
165 struct SharedVariableDeclarationGenerator {
166   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::SharedVariableDeclarationGenerator167   void operator()(const T&) const {
168     absl::StrAppend(result, "shared highp ", GetVariableType(variable.value),
169                     " ", variable.name, ";\n");
170   }
171 
172   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::SharedVariableDeclarationGenerator173   void operator()(const std::vector<T>& v) const {
174     absl::StrAppend(result, "shared highp ", GetVariableType(variable.value),
175                     " ", variable.name);
176     if (v.empty()) {
177       // Normalize the size of the shared array to that of the WorkGroupSize
178       absl::StrAppend(
179           result,
180           "[gl_WorkGroupSize.z * gl_WorkGroupSize.y * gl_WorkGroupSize.x];\n");
181     } else {
182       // Use the specified size
183       absl::StrAppend(result, "[", v.size(), "];\n");
184     }
185   }
186 
187   const Variable& variable;
188   std::string* result;
189 };
190 
GenerateSharedVariableDeclaration(const Variable & variable,std::string * result)191 void GenerateSharedVariableDeclaration(const Variable& variable,
192                                        std::string* result) {
193   std::visit(SharedVariableDeclarationGenerator{variable, result},
194              variable.value);
195 }
196 
197 struct UniformParameterDeclarationGenerator {
198   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::UniformParameterDeclarationGenerator199   void operator()(const T&) const {
200     absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ",
201                     variable.name, ";\n");
202   }
203 
204   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::UniformParameterDeclarationGenerator205   void operator()(const std::vector<T>& v) const {
206     absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ",
207                     variable.name, "[", v.size(), "];\n");
208   }
209 
210   const Variable& variable;
211   std::string* result;
212 };
213 
GenerateUniformParameterDeclaration(const Variable & variable,std::string * result)214 void GenerateUniformParameterDeclaration(const Variable& variable,
215                                          std::string* result) {
216   std::visit(UniformParameterDeclarationGenerator{variable, result},
217              variable.value);
218 }
219 
220 struct VulkanPushConstantGenerator {
221   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::VulkanPushConstantGenerator222   void operator()(const T&) const {
223     absl::StrAppend(result, "  ", GetVariableType(variable.value), " ",
224                     variable.name, ";\n");
225   }
226 
227   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::VulkanPushConstantGenerator228   void operator()(const std::vector<T>& v) const {
229     absl::StrAppend(result, "  ", GetVariableType(variable.value), " ",
230                     variable.name, "[", v.size(), "];\n");
231   }
232 
233   const Variable& variable;
234   std::string* result;
235 };
236 
GenerateVulkanPushConstant(const Variable & variable,std::string * result)237 void GenerateVulkanPushConstant(const Variable& variable, std::string* result) {
238   std::visit(VulkanPushConstantGenerator{variable, result}, variable.value);
239 }
240 
// Visitor answering "is this a variable-length array?": true only for
// std::vector alternatives, false for everything else.
struct VariableLengthGetter {
  template <typename Value>
  bool operator()(const Value& /*fixed_size_value*/) const {
    return false;
  }

  template <typename Element>
  bool operator()(const std::vector<Element>& /*array_value*/) const {
    return true;
  }
};
251 
252 struct VulkanConstantGenerator {
253   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::VulkanConstantGenerator254   void operator()(const T&) const {
255     const std::string variable_type = GetVariableType(variable.value);
256 
257     // Vulkan specialization constants are used for scalar types, all other
258     // types go in push (uniform) constants.
259     if (variable_type == "int" || variable_type == "uint" ||
260         variable_type == "float") {
261       absl::StrAppend(result, "layout(constant_id = ", *constant_id, ") const ",
262                       variable_type, " ", variable.name, " = ");
263       // Always set the default values to zero to generate generic cacheable
264       // shaders.
265       absl::StrAppend(result, (variable_type == "float" ? "0.0" : "0"), ";\n");
266       (*constant_id)++;
267     } else {
268       non_scalar_variables->push_back(variable);
269     }
270   }
271 
272   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::VulkanConstantGenerator273   void operator()(const std::vector<T>& v) const {
274     non_scalar_variables->push_back(variable);
275   }
276 
277   const Variable& variable;
278   int* const constant_id;
279   std::vector<Variable>* non_scalar_variables;
280   std::string* result;
281 };
282 
GenerateVulkanConstant(const Variable & variable,int * constant_id,std::vector<Variable> * non_scalar_variables,std::string * result)283 void GenerateVulkanConstant(const Variable& variable, int* constant_id,
284                             std::vector<Variable>* non_scalar_variables,
285                             std::string* result) {
286   std::visit(VulkanConstantGenerator{variable, constant_id,
287                                      non_scalar_variables, result},
288              variable.value);
289 }
290 
291 class VulkanConstantsProcessor {
292  public:
ProcessVulkanConstant(const Variable & variable,std::string * result)293   void ProcessVulkanConstant(const Variable& variable, std::string* result) {
294     GenerateVulkanConstant(variable, &constant_id_, &non_scalar_variables_,
295                            result);
296   }
297 
GeneratePushConstantsDeclarations(std::string * result)298   void GeneratePushConstantsDeclarations(std::string* result) {
299     if (!non_scalar_variables_.empty()) {
300       *result += "\nlayout(push_constant) uniform pushConstants {\n";
301       for (const auto& variable : non_scalar_variables_) {
302         GenerateVulkanPushConstant(variable, result);
303       }
304       *result += "};\n";
305     }
306   }
307 
308  protected:
309   // Reserve the first three specialization constants slots for the
310   // workgroup size.
311   int constant_id_ = 3;
312   std::vector<Variable> non_scalar_variables_;
313 };
314 
315 // Returns true if value is a vector
IsVariableLength(const Variable::ValueType & value)316 bool IsVariableLength(const Variable::ValueType& value) {
317   return std::visit(VariableLengthGetter(), value);
318 }
319 
// Component selected by a ".x" / ".y" / ".z" / ".w" accessor. Each value is
// also the component index used to subscript a vector. UNKNOWN marks any
// other (or absent) field.
enum Field : uint8_t { X = 0, Y = 1, Z = 2, W = 3, UNKNOWN = 4 };
321 
ToField(absl::string_view field_name)322 Field ToField(absl::string_view field_name) {
323   if (field_name.size() == 2 && field_name[0] == '.') {
324     switch (field_name[1]) {
325       case 'x':
326         return Field::X;
327       case 'y':
328         return Field::Y;
329       case 'z':
330         return Field::Z;
331       case 'w':
332         return Field::W;
333     }
334   }
335   return Field::UNKNOWN;
336 }
337 
338 struct FieldAccessor {
339   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldAccessor340   void operator()(const T&) const {}
341 
342   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldAccessor343   void operator()(const Vec2<T>& v) const {
344     FormatValue(result, v[field]);
345   }
346 
347   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldAccessor348   void operator()(const Vec3<T>& v) const {
349     FormatValue(result, v[field]);
350   }
351 
352   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldAccessor353   void operator()(const Vec4<T>& v) const {
354     FormatValue(result, v[field]);
355   }
356 
357   Field field;
358   std::string* result;
359 };
360 
361 // Appends formatted value of the given field.
GetValue(const Variable::ValueType & value,Field field,std::string * result)362 void GetValue(const Variable::ValueType& value, Field field,
363               std::string* result) {
364   std::visit(FieldAccessor{field, result}, value);
365 }
366 
367 struct FieldChecker {
368   // For trivial as well as variable-length types indexed access is not allowed.
369   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldChecker370   bool operator()(const T&) const {
371     return false;
372   }
373 
374   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldChecker375   bool operator()(const Vec2<T>& v) const {
376     return field < v.size();
377   }
378 
379   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldChecker380   bool operator()(const Vec3<T>& v) const {
381     return field < v.size();
382   }
383 
384   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldChecker385   bool operator()(const Vec4<T>& v) const {
386     return field < v.size();
387   }
388 
389   template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldChecker390   bool operator()(const std::vector<T>&) const {
391     // technically accessing [0] element of an empty vector is UB, but we need
392     // only type information for this check. Therefore, construct default T and
393     // use it instead.
394     T t;
395     return (*this)(t);
396   }
397 
398   Field field;
399 };
400 
401 // Returns true if field has field access and field is not out of bounds.
HasField(const Variable::ValueType & value,Field field)402 bool HasField(const Variable::ValueType& value, Field field) {
403   return std::visit(FieldChecker{field}, value);
404 }
405 
AssembleAccessor(absl::string_view name,absl::string_view index,absl::string_view field,std::string * result)406 void AssembleAccessor(absl::string_view name, absl::string_view index,
407                       absl::string_view field, std::string* result) {
408   if (index.empty()) {
409     absl::StrAppend(result, name, field);
410   } else {
411     absl::StrAppend(result, name, "[", index, "]", field);
412   }
413 }
414 
415 }  // namespace
416 
Rewrite(absl::string_view input,std::string * output)417 RewriteStatus VariableAccessor::Rewrite(absl::string_view input,
418                                         std::string* output) {
419   auto ref = variable_accessor_internal::Parse(input);
420   if (ref.name.empty()) {
421     absl::StrAppend(output, "INVALID_SYNTAX");
422     return RewriteStatus::ERROR;
423   }
424 
425   auto it =
426       name_to_variable_.find(std::string(ref.name.data(), ref.name.size()));
427   if (it == name_to_variable_.end()) {
428     // Uniform with this name is not registered.
429     return RewriteStatus::NOT_RECOGNIZED;
430   }
431   const auto& value = it->second.value;
432 
433   if (!ref.index.empty() && !IsVariableLength(value)) {
434     // Trying to access variable by index, but it is not variable-length.
435     absl::StrAppend(output, "INVALID_ACCESS_BY_INDEX");
436     return RewriteStatus::ERROR;
437   }
438 
439   Field f = ToField(ref.field);
440   if (!ref.field.empty() && !HasField(value, f)) {
441     // Trying to access a variable by field, but it does not have it.
442     absl::StrAppend(output, "INVALID_ACCESS_BY_FIELD");
443     return RewriteStatus::ERROR;
444   }
445 
446   // Error checks are complete now.
447 
448   // All variable-length variables are encoded as-is without inlining.
449   if (!inline_values_ || IsVariableLength(value)) {
450     AssembleAccessor(it->second.name, ref.index, ref.field, output);
451   } else {
452     // Parameter + field is replaced with field value.
453     if (f != Field::UNKNOWN) {
454       GetValue(value, f, output);
455     } else {
456       // Parameter is accessed directly.
457       GetValue(value, output);
458     }
459   }
460   return RewriteStatus::SUCCESS;
461 }
462 
AddSharedVariable(Variable && variable)463 bool VariableAccessor::AddSharedVariable(Variable&& variable) {
464   const std::string name = variable.name;
465   if (!name_to_variable_.insert({name, std::move(variable)}).second) {
466     return false;
467   }
468   shared_variables_.insert(name);
469   return true;
470 }
471 
AddUniformParameter(Variable && variable)472 bool VariableAccessor::AddUniformParameter(Variable&& variable) {
473   const std::string name = variable.name;
474   if (!name_to_variable_.insert({name, std::move(variable)}).second) {
475     return false;
476   }
477   uniform_parameters_.insert(name);
478   return true;
479 }
480 
IsEmptyVariableLength(const Variable & variable) const481 bool VariableAccessor::IsEmptyVariableLength(const Variable& variable) const {
482   const auto& value = variable.value;
483   return IsVariableLength(value) && GetLength(value) == 0;
484 }
485 
GetConstDeclarations() const486 std::string VariableAccessor::GetConstDeclarations() const {
487   // Variable length variables are declared as const and accessed via variable
488   // with index.
489   std::string declarations;
490   for (const auto& variable : name_to_variable_) {
491     // Skip shared variables.
492     const std::string& variable_name = variable.second.name;
493     if (shared_variables_.find(variable_name) != shared_variables_.end()) {
494       continue;
495     }
496 
497     const auto& value = variable.second.value;
498     if (IsVariableLength(value)) {
499       absl::StrAppend(&declarations, "const ", GetVariableType(value), " ",
500                       variable_name, "[] = ");
501       GetValue(value, &declarations);
502       absl::StrAppend(&declarations, ";\n");
503     }
504   }
505   return declarations;
506 }
507 
GetSharedVariableDeclarations() const508 std::string VariableAccessor::GetSharedVariableDeclarations() const {
509   std::string declarations;
510   for (const auto& name : shared_variables_) {
511     const auto& variable = name_to_variable_.at(name);
512     GenerateSharedVariableDeclaration(variable, &declarations);
513   }
514   return declarations;
515 }
516 
GetUniformParameterDeclarations() const517 std::string VariableAccessor::GetUniformParameterDeclarations() const {
518   std::string declarations;
519   if (!inline_values_) {
520     if (vulkan_support_) {
521       VulkanConstantsProcessor processor;
522       for (const auto& name : uniform_parameters_) {
523         const auto& variable = name_to_variable_.at(name);
524         processor.ProcessVulkanConstant(variable, &declarations);
525       }
526       processor.GeneratePushConstantsDeclarations(&declarations);
527     } else {
528       for (const auto& name : uniform_parameters_) {
529         const auto& variable = name_to_variable_.at(name);
530         GenerateUniformParameterDeclaration(variable, &declarations);
531       }
532     }
533   }
534   return declarations;
535 }
536 
GetUniformParameters() const537 std::vector<Variable> VariableAccessor::GetUniformParameters() const {
538   std::vector<Variable> variables;
539   if (!inline_values_) {
540     variables.reserve(name_to_variable_.size());
541     // Keep the order of the variables consistent with that of the declarations
542     for (const auto& name : uniform_parameters_) {
543       const auto& variable = name_to_variable_.at(name);
544       variables.push_back(variable);
545     }
546   }
547   return variables;
548 }
549 
550 }  // namespace gl
551 }  // namespace gpu
552 }  // namespace tflite
553