xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/task/buffer_desc.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
17 
18 #include <string>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
23 #include "tensorflow/lite/delegates/gpu/common/status.h"
24 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
25 
26 namespace tflite {
27 namespace gpu {
28 
Release()29 void BufferDescriptor::Release() { data.clear(); }
30 
GetGPUResources(const GpuInfo & gpu_info) const31 GPUResources BufferDescriptor::GetGPUResources(const GpuInfo& gpu_info) const {
32   GPUResources resources;
33   GPUBufferDescriptor desc;
34   desc.data_type = element_type;
35   desc.access_type = access_type_;
36   desc.element_size = element_size;
37   desc.memory_type = memory_type;
38   desc.attributes = attributes;
39   if (gpu_info.IsGlsl() && memory_type == tflite::gpu::MemoryType::CONSTANT) {
40     desc.attributes.push_back(
41         std::to_string(size / (element_size * SizeOf(element_type))));
42   }
43   resources.buffers.push_back({"buffer", desc});
44   return resources;
45 }
46 
PerformSelector(const GpuInfo & gpu_info,const std::string & selector,const std::vector<std::string> & args,const std::vector<std::string> & template_args,std::string * result) const47 absl::Status BufferDescriptor::PerformSelector(
48     const GpuInfo& gpu_info, const std::string& selector,
49     const std::vector<std::string>& args,
50     const std::vector<std::string>& template_args, std::string* result) const {
51   if (selector == "Read") {
52     return PerformReadSelector(gpu_info, args, result);
53   } else if (selector == "GetPtr") {
54     return PerformGetPtrSelector(args, template_args, result);
55   } else {
56     return absl::NotFoundError(absl::StrCat(
57         "BufferDescriptor don't have selector with name - ", selector));
58   }
59 }
60 
PerformReadSelector(const GpuInfo & gpu_info,const std::vector<std::string> & args,std::string * result) const61 absl::Status BufferDescriptor::PerformReadSelector(
62     const GpuInfo& gpu_info, const std::vector<std::string>& args,
63     std::string* result) const {
64   if (args.size() != 1) {
65     return absl::NotFoundError(
66         absl::StrCat("BufferDescriptor Read require one argument, but ",
67                      args.size(), " was passed"));
68   }
69   if (gpu_info.IsGlsl()) {
70     if (element_type == DataType::FLOAT16 &&
71         !gpu_info.IsGlslSupportsExplicitFp16()) {
72       if (memory_type == MemoryType::CONSTANT) {
73         bool is_kernel_global_space = false;
74         for (const auto& attribute : attributes) {
75           if (attribute == "kernel_global_space") {
76             is_kernel_global_space = true;
77             break;
78           }
79         }
80         if (is_kernel_global_space) {
81           *result = absl::StrCat("buffer[", args[0], "]");
82           return absl::OkStatus();
83         }
84         const std::string arg0 = "(" + args[0] + ")";
85         *result =
86             absl::StrCat("vec4(unpackHalf2x16(buffer[", arg0, " / 2][", arg0,
87                          " % 2 == 0 ? 0 : 2]), unpackHalf2x16(buffer[", arg0,
88                          " / 2][", arg0, " % 2 == 0 ? 1 : 3]))");
89       } else {
90         if (element_size == 4) {
91           *result =
92               absl::StrCat("vec4(unpackHalf2x16(buffer[", args[0],
93                            "].x), unpackHalf2x16(buffer[", args[0], "].y))");
94         } else if (element_size == 16) {
95           const std::string vec0 = absl::Substitute(
96               "vec4(unpackHalf2x16(buffer[$0].a.x), "
97               "unpackHalf2x16(buffer[$0].a.y))",
98               args[0]);
99           const std::string vec1 = absl::Substitute(
100               "vec4(unpackHalf2x16(buffer[$0].a.z), "
101               "unpackHalf2x16(buffer[$0].a.w))",
102               args[0]);
103           const std::string vec2 = absl::Substitute(
104               "vec4(unpackHalf2x16(buffer[$0].b.x), "
105               "unpackHalf2x16(buffer[$0].b.y))",
106               args[0]);
107           const std::string vec3 = absl::Substitute(
108               "vec4(unpackHalf2x16(buffer[$0].b.z), "
109               "unpackHalf2x16(buffer[$0].b.w))",
110               args[0]);
111           *result = absl::Substitute("mat4x4($0, $1, $2, $3)", vec0, vec1, vec2,
112                                      vec3);
113         }
114       }
115     } else {
116       *result = absl::StrCat("buffer[", args[0], "]");
117     }
118     return absl::OkStatus();
119   } else {
120     *result = absl::StrCat("buffer[", args[0], "]");
121     return absl::OkStatus();
122   }
123 }
124 
PerformGetPtrSelector(const std::vector<std::string> & args,const std::vector<std::string> & template_args,std::string * result) const125 absl::Status BufferDescriptor::PerformGetPtrSelector(
126     const std::vector<std::string>& args,
127     const std::vector<std::string>& template_args, std::string* result) const {
128   if (args.size() > 1) {
129     return absl::NotFoundError(absl::StrCat(
130         "BufferDescriptor GetPtr require one or zero arguments, but ",
131         args.size(), " was passed"));
132   }
133   if (template_args.size() > 1) {
134     return absl::NotFoundError(
135         absl::StrCat("BufferDescriptor GetPtr require one or zero teemplate "
136                      "arguments, but ",
137                      template_args.size(), " was passed"));
138   }
139   std::string conversion;
140   if (template_args.size() == 1) {
141     const std::string type_name = ToCLDataType(element_type, element_size);
142     if (type_name != template_args[0]) {
143       conversion = absl::StrCat("(", MemoryTypeToCLType(memory_type), " ",
144                                 template_args[0], "*)&");
145     }
146   }
147   if (args.empty()) {
148     *result = absl::StrCat(conversion, "buffer");
149   } else if (conversion.empty()) {
150     *result = absl::StrCat("(buffer + ", args[0], ")");
151   } else {
152     *result = absl::StrCat(conversion, "buffer[", args[0], "]");
153   }
154   return absl::OkStatus();
155 }
156 
157 }  // namespace gpu
158 }  // namespace tflite
159