1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"

#include <array>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
27
28 namespace tflite {
29 namespace gpu {
30 namespace gl {
31 namespace variable_accessor_internal {
32
33 // Parse the following regex manually
34 // name(\[index\])?(\.field)?
Parse(absl::string_view input)35 VariableReference Parse(absl::string_view input) {
36 VariableReference ref;
37 auto start_index = input.find('[');
38 if (start_index != std::string::npos) {
39 auto end_index = input.rfind(']');
40 if (end_index == std::string::npos) {
41 return ref;
42 }
43 ref.index = input.substr(start_index + 1, end_index - start_index - 1);
44 ref.name = input.substr(0, start_index);
45 ref.field = input.substr(end_index + 1);
46 } else {
47 auto dot = input.find('.');
48 if (dot != std::string::npos) {
49 ref.name = input.substr(0, dot);
50 ref.field = input.substr(dot);
51 } else {
52 ref.name = input;
53 }
54 }
55 return ref;
56 }
57
58 } // namespace variable_accessor_internal
59
60 namespace {
61
62 struct VariableTypeGetter {
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter63 std::string operator()(int) const { return "int"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter64 std::string operator()(const int2&) const { return "ivec2"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter65 std::string operator()(const std::vector<int2>&) const { return "ivec2"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter66 std::string operator()(const int4&) const { return "ivec4"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter67 std::string operator()(unsigned int) const { return "uint"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter68 std::string operator()(const uint4&) const { return "uvec4"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter69 std::string operator()(float) const { return "float"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter70 std::string operator()(const float2&) const { return "vec2"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter71 std::string operator()(const float4&) const { return "vec4"; }
operator ()tflite::gpu::gl::__anona39f54780111::VariableTypeGetter72 std::string operator()(const std::vector<float4>&) const { return "vec4"; }
73 };
74
75 // Returns GLSL uniform type of the given variable.
GetVariableType(const Variable::ValueType & value)76 std::string GetVariableType(const Variable::ValueType& value) {
77 return std::visit(VariableTypeGetter(), value);
78 }
79
// Visitor returning the number of elements a variable value holds: 1 for any
// non-vector alternative, the element count for std::vector alternatives.
struct LengthGetter {
  template <typename T>
  int operator()(const T& /*param*/) const {
    return 1;
  }
  template <typename T>
  int operator()(const std::vector<T>& param) const {
    // The visitor's result type is int; make the size_t -> int narrowing
    // explicit instead of relying on an implicit conversion.
    return static_cast<int>(param.size());
  }
};
90
GetLength(const Variable::ValueType & value)91 int GetLength(const Variable::ValueType& value) {
92 return std::visit(LengthGetter(), value);
93 }
94
95 template <typename T>
FormatValue(std::string * result,T t)96 void FormatValue(std::string* result, T t) {
97 absl::StrAppend(result, t);
98 }
99
100 template <>
FormatValue(std::string * result,float t)101 void FormatValue(std::string* result, float t) {
102 absl::StrAppend(result, absl::StrFormat("%.9ff", t));
103 }
104
// Formats every element of a fixed-size array with FormatValue and returns the
// pieces, ready to be joined (e.g. with absl::StrJoin) into a component list.
//
// N is a std::size_t to match std::array's own size parameter. With an `int`
// parameter N could never be deduced from the argument and had to be spelled
// out at every call; explicit arguments such as ToString<T, 2>(...) still work.
template <typename T, std::size_t N>
std::vector<std::string> ToString(const std::array<T, N>& data) {
  std::vector<std::string> result(N);
  for (std::size_t i = 0; i < N; ++i) {
    FormatValue(&result[i], data[i]);
  }
  return result;
}
116
117 struct ConstGenerator {
118 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::ConstGenerator119 void operator()(T t) const {
120 FormatValue(result, t);
121 }
122
123 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::ConstGenerator124 void operator()(const Vec2<T>& v) const {
125 absl::StrAppend(result, VariableTypeGetter()(v), "(",
126 absl::StrJoin(ToString<T, 2>(v.data_), ","), ")");
127 }
128
129 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::ConstGenerator130 void operator()(const Vec3<T>& v) const {
131 absl::StrAppend(result, VariableTypeGetter()(v), "(",
132 absl::StrJoin(ToString<T, 3>(v.data_), ","), ")");
133 }
134
135 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::ConstGenerator136 void operator()(const Vec4<T>& v) const {
137 absl::StrAppend(result, VariableTypeGetter()(v), "(",
138 absl::StrJoin(ToString<T, 4>(v.data_), ","), ")");
139 }
140
141 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::ConstGenerator142 void operator()(const std::vector<T>& v) const {
143 std::string type = VariableTypeGetter()(v);
144 absl::StrAppend(result, type, "[", v.size(), "](");
145 bool first = true;
146 for (const auto& i : v) {
147 if (first) {
148 first = false;
149 } else {
150 absl::StrAppend(result, ",");
151 }
152 (*this)(i);
153 }
154 absl::StrAppend(result, ")");
155 }
156
157 std::string* result;
158 };
159
160 // Appends string representation of a variable value.
GetValue(const Variable::ValueType & value,std::string * result)161 void GetValue(const Variable::ValueType& value, std::string* result) {
162 std::visit(ConstGenerator{result}, value);
163 }
164
165 struct SharedVariableDeclarationGenerator {
166 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::SharedVariableDeclarationGenerator167 void operator()(const T&) const {
168 absl::StrAppend(result, "shared highp ", GetVariableType(variable.value),
169 " ", variable.name, ";\n");
170 }
171
172 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::SharedVariableDeclarationGenerator173 void operator()(const std::vector<T>& v) const {
174 absl::StrAppend(result, "shared highp ", GetVariableType(variable.value),
175 " ", variable.name);
176 if (v.empty()) {
177 // Normalize the size of the shared array to that of the WorkGroupSize
178 absl::StrAppend(
179 result,
180 "[gl_WorkGroupSize.z * gl_WorkGroupSize.y * gl_WorkGroupSize.x];\n");
181 } else {
182 // Use the specified size
183 absl::StrAppend(result, "[", v.size(), "];\n");
184 }
185 }
186
187 const Variable& variable;
188 std::string* result;
189 };
190
GenerateSharedVariableDeclaration(const Variable & variable,std::string * result)191 void GenerateSharedVariableDeclaration(const Variable& variable,
192 std::string* result) {
193 std::visit(SharedVariableDeclarationGenerator{variable, result},
194 variable.value);
195 }
196
197 struct UniformParameterDeclarationGenerator {
198 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::UniformParameterDeclarationGenerator199 void operator()(const T&) const {
200 absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ",
201 variable.name, ";\n");
202 }
203
204 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::UniformParameterDeclarationGenerator205 void operator()(const std::vector<T>& v) const {
206 absl::StrAppend(result, "uniform ", GetVariableType(variable.value), " ",
207 variable.name, "[", v.size(), "];\n");
208 }
209
210 const Variable& variable;
211 std::string* result;
212 };
213
GenerateUniformParameterDeclaration(const Variable & variable,std::string * result)214 void GenerateUniformParameterDeclaration(const Variable& variable,
215 std::string* result) {
216 std::visit(UniformParameterDeclarationGenerator{variable, result},
217 variable.value);
218 }
219
220 struct VulkanPushConstantGenerator {
221 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::VulkanPushConstantGenerator222 void operator()(const T&) const {
223 absl::StrAppend(result, " ", GetVariableType(variable.value), " ",
224 variable.name, ";\n");
225 }
226
227 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::VulkanPushConstantGenerator228 void operator()(const std::vector<T>& v) const {
229 absl::StrAppend(result, " ", GetVariableType(variable.value), " ",
230 variable.name, "[", v.size(), "];\n");
231 }
232
233 const Variable& variable;
234 std::string* result;
235 };
236
GenerateVulkanPushConstant(const Variable & variable,std::string * result)237 void GenerateVulkanPushConstant(const Variable& variable, std::string* result) {
238 std::visit(VulkanPushConstantGenerator{variable, result}, variable.value);
239 }
240
// Visitor answering whether a value is variable-length, i.e. held as a
// std::vector rather than a scalar or fixed-size vector.
struct VariableLengthGetter {
  template <typename T>
  bool operator()(const std::vector<T>& /*values*/) const {
    return true;
  }

  template <typename T>
  bool operator()(const T& /*value*/) const {
    return false;
  }
};
251
252 struct VulkanConstantGenerator {
253 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::VulkanConstantGenerator254 void operator()(const T&) const {
255 const std::string variable_type = GetVariableType(variable.value);
256
257 // Vulkan specialization constants are used for scalar types, all other
258 // types go in push (uniform) constants.
259 if (variable_type == "int" || variable_type == "uint" ||
260 variable_type == "float") {
261 absl::StrAppend(result, "layout(constant_id = ", *constant_id, ") const ",
262 variable_type, " ", variable.name, " = ");
263 // Always set the default values to zero to generate generic cacheable
264 // shaders.
265 absl::StrAppend(result, (variable_type == "float" ? "0.0" : "0"), ";\n");
266 (*constant_id)++;
267 } else {
268 non_scalar_variables->push_back(variable);
269 }
270 }
271
272 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::VulkanConstantGenerator273 void operator()(const std::vector<T>& v) const {
274 non_scalar_variables->push_back(variable);
275 }
276
277 const Variable& variable;
278 int* const constant_id;
279 std::vector<Variable>* non_scalar_variables;
280 std::string* result;
281 };
282
GenerateVulkanConstant(const Variable & variable,int * constant_id,std::vector<Variable> * non_scalar_variables,std::string * result)283 void GenerateVulkanConstant(const Variable& variable, int* constant_id,
284 std::vector<Variable>* non_scalar_variables,
285 std::string* result) {
286 std::visit(VulkanConstantGenerator{variable, constant_id,
287 non_scalar_variables, result},
288 variable.value);
289 }
290
291 class VulkanConstantsProcessor {
292 public:
ProcessVulkanConstant(const Variable & variable,std::string * result)293 void ProcessVulkanConstant(const Variable& variable, std::string* result) {
294 GenerateVulkanConstant(variable, &constant_id_, &non_scalar_variables_,
295 result);
296 }
297
GeneratePushConstantsDeclarations(std::string * result)298 void GeneratePushConstantsDeclarations(std::string* result) {
299 if (!non_scalar_variables_.empty()) {
300 *result += "\nlayout(push_constant) uniform pushConstants {\n";
301 for (const auto& variable : non_scalar_variables_) {
302 GenerateVulkanPushConstant(variable, result);
303 }
304 *result += "};\n";
305 }
306 }
307
308 protected:
309 // Reserve the first three specialization constants slots for the
310 // workgroup size.
311 int constant_id_ = 3;
312 std::vector<Variable> non_scalar_variables_;
313 };
314
315 // Returns true if value is a vector
IsVariableLength(const Variable::ValueType & value)316 bool IsVariableLength(const Variable::ValueType& value) {
317 return std::visit(VariableLengthGetter(), value);
318 }
319
// Component selected by a ".x"/".y"/".z"/".w" accessor. Each enumerator's value
// is the component index, so it is used directly for indexing and for size
// comparisons; UNKNOWN (4) lies past the end of every vector type used here.
enum Field : uint8_t { X = 0, Y = 1, Z = 2, W = 3, UNKNOWN = 4 };
321
ToField(absl::string_view field_name)322 Field ToField(absl::string_view field_name) {
323 if (field_name.size() == 2 && field_name[0] == '.') {
324 switch (field_name[1]) {
325 case 'x':
326 return Field::X;
327 case 'y':
328 return Field::Y;
329 case 'z':
330 return Field::Z;
331 case 'w':
332 return Field::W;
333 }
334 }
335 return Field::UNKNOWN;
336 }
337
338 struct FieldAccessor {
339 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldAccessor340 void operator()(const T&) const {}
341
342 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldAccessor343 void operator()(const Vec2<T>& v) const {
344 FormatValue(result, v[field]);
345 }
346
347 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldAccessor348 void operator()(const Vec3<T>& v) const {
349 FormatValue(result, v[field]);
350 }
351
352 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldAccessor353 void operator()(const Vec4<T>& v) const {
354 FormatValue(result, v[field]);
355 }
356
357 Field field;
358 std::string* result;
359 };
360
361 // Appends formatted value of the given field.
GetValue(const Variable::ValueType & value,Field field,std::string * result)362 void GetValue(const Variable::ValueType& value, Field field,
363 std::string* result) {
364 std::visit(FieldAccessor{field, result}, value);
365 }
366
367 struct FieldChecker {
368 // For trivial as well as variable-length types indexed access is not allowed.
369 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldChecker370 bool operator()(const T&) const {
371 return false;
372 }
373
374 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldChecker375 bool operator()(const Vec2<T>& v) const {
376 return field < v.size();
377 }
378
379 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldChecker380 bool operator()(const Vec3<T>& v) const {
381 return field < v.size();
382 }
383
384 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldChecker385 bool operator()(const Vec4<T>& v) const {
386 return field < v.size();
387 }
388
389 template <typename T>
operator ()tflite::gpu::gl::__anona39f54780111::FieldChecker390 bool operator()(const std::vector<T>&) const {
391 // technically accessing [0] element of an empty vector is UB, but we need
392 // only type information for this check. Therefore, construct default T and
393 // use it instead.
394 T t;
395 return (*this)(t);
396 }
397
398 Field field;
399 };
400
401 // Returns true if field has field access and field is not out of bounds.
HasField(const Variable::ValueType & value,Field field)402 bool HasField(const Variable::ValueType& value, Field field) {
403 return std::visit(FieldChecker{field}, value);
404 }
405
AssembleAccessor(absl::string_view name,absl::string_view index,absl::string_view field,std::string * result)406 void AssembleAccessor(absl::string_view name, absl::string_view index,
407 absl::string_view field, std::string* result) {
408 if (index.empty()) {
409 absl::StrAppend(result, name, field);
410 } else {
411 absl::StrAppend(result, name, "[", index, "]", field);
412 }
413 }
414
415 } // namespace
416
Rewrite(absl::string_view input,std::string * output)417 RewriteStatus VariableAccessor::Rewrite(absl::string_view input,
418 std::string* output) {
419 auto ref = variable_accessor_internal::Parse(input);
420 if (ref.name.empty()) {
421 absl::StrAppend(output, "INVALID_SYNTAX");
422 return RewriteStatus::ERROR;
423 }
424
425 auto it =
426 name_to_variable_.find(std::string(ref.name.data(), ref.name.size()));
427 if (it == name_to_variable_.end()) {
428 // Uniform with this name is not registered.
429 return RewriteStatus::NOT_RECOGNIZED;
430 }
431 const auto& value = it->second.value;
432
433 if (!ref.index.empty() && !IsVariableLength(value)) {
434 // Trying to access variable by index, but it is not variable-length.
435 absl::StrAppend(output, "INVALID_ACCESS_BY_INDEX");
436 return RewriteStatus::ERROR;
437 }
438
439 Field f = ToField(ref.field);
440 if (!ref.field.empty() && !HasField(value, f)) {
441 // Trying to access a variable by field, but it does not have it.
442 absl::StrAppend(output, "INVALID_ACCESS_BY_FIELD");
443 return RewriteStatus::ERROR;
444 }
445
446 // Error checks are complete now.
447
448 // All variable-length variables are encoded as-is without inlining.
449 if (!inline_values_ || IsVariableLength(value)) {
450 AssembleAccessor(it->second.name, ref.index, ref.field, output);
451 } else {
452 // Parameter + field is replaced with field value.
453 if (f != Field::UNKNOWN) {
454 GetValue(value, f, output);
455 } else {
456 // Parameter is accessed directly.
457 GetValue(value, output);
458 }
459 }
460 return RewriteStatus::SUCCESS;
461 }
462
AddSharedVariable(Variable && variable)463 bool VariableAccessor::AddSharedVariable(Variable&& variable) {
464 const std::string name = variable.name;
465 if (!name_to_variable_.insert({name, std::move(variable)}).second) {
466 return false;
467 }
468 shared_variables_.insert(name);
469 return true;
470 }
471
AddUniformParameter(Variable && variable)472 bool VariableAccessor::AddUniformParameter(Variable&& variable) {
473 const std::string name = variable.name;
474 if (!name_to_variable_.insert({name, std::move(variable)}).second) {
475 return false;
476 }
477 uniform_parameters_.insert(name);
478 return true;
479 }
480
IsEmptyVariableLength(const Variable & variable) const481 bool VariableAccessor::IsEmptyVariableLength(const Variable& variable) const {
482 const auto& value = variable.value;
483 return IsVariableLength(value) && GetLength(value) == 0;
484 }
485
GetConstDeclarations() const486 std::string VariableAccessor::GetConstDeclarations() const {
487 // Variable length variables are declared as const and accessed via variable
488 // with index.
489 std::string declarations;
490 for (const auto& variable : name_to_variable_) {
491 // Skip shared variables.
492 const std::string& variable_name = variable.second.name;
493 if (shared_variables_.find(variable_name) != shared_variables_.end()) {
494 continue;
495 }
496
497 const auto& value = variable.second.value;
498 if (IsVariableLength(value)) {
499 absl::StrAppend(&declarations, "const ", GetVariableType(value), " ",
500 variable_name, "[] = ");
501 GetValue(value, &declarations);
502 absl::StrAppend(&declarations, ";\n");
503 }
504 }
505 return declarations;
506 }
507
GetSharedVariableDeclarations() const508 std::string VariableAccessor::GetSharedVariableDeclarations() const {
509 std::string declarations;
510 for (const auto& name : shared_variables_) {
511 const auto& variable = name_to_variable_.at(name);
512 GenerateSharedVariableDeclaration(variable, &declarations);
513 }
514 return declarations;
515 }
516
GetUniformParameterDeclarations() const517 std::string VariableAccessor::GetUniformParameterDeclarations() const {
518 std::string declarations;
519 if (!inline_values_) {
520 if (vulkan_support_) {
521 VulkanConstantsProcessor processor;
522 for (const auto& name : uniform_parameters_) {
523 const auto& variable = name_to_variable_.at(name);
524 processor.ProcessVulkanConstant(variable, &declarations);
525 }
526 processor.GeneratePushConstantsDeclarations(&declarations);
527 } else {
528 for (const auto& name : uniform_parameters_) {
529 const auto& variable = name_to_variable_.at(name);
530 GenerateUniformParameterDeclaration(variable, &declarations);
531 }
532 }
533 }
534 return declarations;
535 }
536
GetUniformParameters() const537 std::vector<Variable> VariableAccessor::GetUniformParameters() const {
538 std::vector<Variable> variables;
539 if (!inline_values_) {
540 variables.reserve(name_to_variable_.size());
541 // Keep the order of the variables consistent with that of the declarations
542 for (const auto& name : uniform_parameters_) {
543 const auto& variable = name_to_variable_.at(name);
544 variables.push_back(variable);
545 }
546 }
547 return variables;
548 }
549
550 } // namespace gl
551 } // namespace gpu
552 } // namespace tflite
553