xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h"
17 
18 #include <algorithm>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/status/status.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h"
27 #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h"
28 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
29 
30 namespace tflite {
31 namespace gpu {
32 namespace gl {
33 
ShaderCodegen(const CompilationOptions & options,const GpuInfo & gpu_info)34 ShaderCodegen::ShaderCodegen(const CompilationOptions& options,
35                              const GpuInfo& gpu_info)
36     : options_(options), gpu_type_(gpu_info.vendor) {}
37 
Build(CompiledNodeAttributes attr,ShaderCode * shader_code) const38 absl::Status ShaderCodegen::Build(CompiledNodeAttributes attr,
39                                   ShaderCode* shader_code) const {
40   VariableAccessor variable_accessor(options_.inline_parameters,
41                                      options_.vulkan_support);
42   ObjectAccessor object_accessor(gpu_type_ == GpuVendor::kMali,
43                                  options_.sampler_textures, &variable_accessor);
44 
45   const auto add_object = [&](const std::string& name, Object&& object) {
46     if (!object_accessor.AddObject(name, std::forward<Object>(object))) {
47       return absl::AlreadyExistsError(absl::StrCat("Object \"", name, "\""));
48     }
49     return absl::OkStatus();
50   };
51 
52   const auto add_uniform_parameter = [&](Variable&& variable) {
53     const std::string name = variable.name;
54     const Variable& const_ref = variable;
55     if (variable_accessor.IsEmptyVariableLength(const_ref)) {
56       return absl::InvalidArgumentError(
57           absl::StrCat("Empty uniform vector value \"", name, "\""));
58     }
59     if (!variable_accessor.AddUniformParameter(std::move(variable))) {
60       return absl::AlreadyExistsError(
61           absl::StrCat("Uniform parameter \"", name, "\""));
62     }
63     return absl::OkStatus();
64   };
65 
66   for (auto&& object : attr.code.objects) {
67     RETURN_IF_ERROR(add_object(object.first, std::move(object.second)));
68   }
69 
70   for (auto&& variable : attr.code.shared_variables) {
71     const std::string name = variable.name;
72     if (!variable_accessor.AddSharedVariable(std::move(variable))) {
73       return absl::AlreadyExistsError(
74           absl::StrCat("Shared variable \"", name, "\""));
75     }
76   }
77 
78   for (auto&& variable : attr.code.parameters) {
79     RETURN_IF_ERROR(add_uniform_parameter(std::move(variable)));
80   }
81 
82   int index = 0;
83   for (auto&& input : attr.inputs) {
84     RETURN_IF_ERROR(
85         add_object(absl::StrCat("input_data_", index++), std::move(input)));
86   }
87   index = 0;
88   for (auto&& output : attr.outputs) {
89     RETURN_IF_ERROR(
90         add_object(absl::StrCat("output_data_", index++), std::move(output)));
91   }
92 
93   // TODO(akulik): workload params need to go away and be replaced with
94   // output_data_0_w
95   RETURN_IF_ERROR(add_uniform_parameter(
96       {"workload_x", static_cast<int32_t>(attr.code.workload.x)}));
97   RETURN_IF_ERROR(add_uniform_parameter(
98       {"workload_y", static_cast<int32_t>(attr.code.workload.y)}));
99   RETURN_IF_ERROR(add_uniform_parameter(
100       {"workload_z", static_cast<int32_t>(attr.code.workload.z)}));
101 
102   // NOTE: If the shader has shared variables it will have to use barriers,
103   //       which will conflict with a return at this stage.
104   // Let the user deal with the geometry constraints.
105   const bool has_shared_variables = !attr.code.shared_variables.empty();
106   std::string main_source_code = has_shared_variables ? R"(
107   ivec3 gid = ivec3(gl_GlobalInvocationID.xyz);
108 )"
109                                                       : R"(
110   ivec3 gid = ivec3(gl_GlobalInvocationID.xyz);
111   if (gid.x >= $workload_x$ || gid.y >= $workload_y$ || gid.z >= $workload_z$) {
112     return;
113   }
114 )";
115 
116   switch (attr.code.input) {
117     case IOStructure::ONLY_DEFINITIONS:
118       for (int i = 0; i < attr.inputs.size(); ++i) {
119         absl::StrAppend(&main_source_code, "  highp vec4 value_", i,
120                         " = vec4(0);\n");
121       }
122       break;
123     case IOStructure::AUTO: {
124       for (int i = 0; i < attr.inputs.size(); ++i) {
125         absl::StrAppend(&main_source_code, "  highp vec4 value_", i,
126                         " = $input_data_", i, "[gid.x, gid.y, gid.z]$;\n");
127       }
128       break;
129     }
130   }
131 
132   main_source_code.append(attr.code.source_code);
133 
134   if (attr.code.output == IOStructure::AUTO) {
135     for (int i = 0; i < attr.outputs.size(); ++i) {
136       absl::StrAppend(&main_source_code, "  $output_data_", i,
137                       "[gid.x, gid.y, gid.z] = value_", i, "$;\n");
138     }
139   }
140 
141   // At this point main function is already generated. Now we need to process
142   // object and variable accessors.
143 
144   // process objects first. Object accessor may introduce new uniform
145   // parameters that need to be rewritten in the subsequent pass.
146   {
147     TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true);
148     preprocessor.AddRewrite(&object_accessor);
149     RETURN_IF_ERROR(preprocessor.Rewrite(main_source_code, &main_source_code));
150   }
151 
152   {
153     TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/false);
154     preprocessor.AddRewrite(&variable_accessor);
155     RETURN_IF_ERROR(preprocessor.Rewrite(main_source_code, &main_source_code));
156   }
157 
158   if (options_.inline_parameters) {
159     main_source_code = absl::StrCat(variable_accessor.GetConstDeclarations(),
160                                     main_source_code);
161   }
162 
163   // partial_source_code is only missing the following which is added later:
164   // #version 310 es
165   // layout(local_size_x = ..., local_size_y = ..., local_size_z = ...) in;
166   const char* precision = options_.allow_precision_loss ? "mediump" : "highp";
167   const std::string partial_source_code = absl::StrCat(
168       "layout(std430) buffer;\n",                                 //
169       "precision ", precision, " float;\n",                       //
170       object_accessor.GetFunctionsDeclarations(), "\n",           //
171       object_accessor.GetObjectDeclarations(), "\n",              //
172       variable_accessor.GetUniformParameterDeclarations(), "\n",  //
173       variable_accessor.GetSharedVariableDeclarations(), "\n",    //
174       "void main() {\n",                                          //
175       main_source_code,                                           //
176       "}");
177   *shader_code =
178       ShaderCode(variable_accessor.GetUniformParameters(),
179                  object_accessor.GetObjects(), attr.code.workload,
180                  attr.code.workgroup, partial_source_code, attr.node_indices);
181   return absl::OkStatus();
182 }
183 
184 }  // namespace gl
185 }  // namespace gpu
186 }  // namespace tflite
187