1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
17
18 #include <string>
19
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
23 #include "tensorflow/lite/delegates/gpu/common/status.h"
24 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
25
26 namespace tflite {
27 namespace gpu {
28
Release()29 void BufferDescriptor::Release() { data.clear(); }
30
GetGPUResources(const GpuInfo & gpu_info) const31 GPUResources BufferDescriptor::GetGPUResources(const GpuInfo& gpu_info) const {
32 GPUResources resources;
33 GPUBufferDescriptor desc;
34 desc.data_type = element_type;
35 desc.access_type = access_type_;
36 desc.element_size = element_size;
37 desc.memory_type = memory_type;
38 desc.attributes = attributes;
39 if (gpu_info.IsGlsl() && memory_type == tflite::gpu::MemoryType::CONSTANT) {
40 desc.attributes.push_back(
41 std::to_string(size / (element_size * SizeOf(element_type))));
42 }
43 resources.buffers.push_back({"buffer", desc});
44 return resources;
45 }
46
PerformSelector(const GpuInfo & gpu_info,const std::string & selector,const std::vector<std::string> & args,const std::vector<std::string> & template_args,std::string * result) const47 absl::Status BufferDescriptor::PerformSelector(
48 const GpuInfo& gpu_info, const std::string& selector,
49 const std::vector<std::string>& args,
50 const std::vector<std::string>& template_args, std::string* result) const {
51 if (selector == "Read") {
52 return PerformReadSelector(gpu_info, args, result);
53 } else if (selector == "GetPtr") {
54 return PerformGetPtrSelector(args, template_args, result);
55 } else {
56 return absl::NotFoundError(absl::StrCat(
57 "BufferDescriptor don't have selector with name - ", selector));
58 }
59 }
60
PerformReadSelector(const GpuInfo & gpu_info,const std::vector<std::string> & args,std::string * result) const61 absl::Status BufferDescriptor::PerformReadSelector(
62 const GpuInfo& gpu_info, const std::vector<std::string>& args,
63 std::string* result) const {
64 if (args.size() != 1) {
65 return absl::NotFoundError(
66 absl::StrCat("BufferDescriptor Read require one argument, but ",
67 args.size(), " was passed"));
68 }
69 if (gpu_info.IsGlsl()) {
70 if (element_type == DataType::FLOAT16 &&
71 !gpu_info.IsGlslSupportsExplicitFp16()) {
72 if (memory_type == MemoryType::CONSTANT) {
73 bool is_kernel_global_space = false;
74 for (const auto& attribute : attributes) {
75 if (attribute == "kernel_global_space") {
76 is_kernel_global_space = true;
77 break;
78 }
79 }
80 if (is_kernel_global_space) {
81 *result = absl::StrCat("buffer[", args[0], "]");
82 return absl::OkStatus();
83 }
84 const std::string arg0 = "(" + args[0] + ")";
85 *result =
86 absl::StrCat("vec4(unpackHalf2x16(buffer[", arg0, " / 2][", arg0,
87 " % 2 == 0 ? 0 : 2]), unpackHalf2x16(buffer[", arg0,
88 " / 2][", arg0, " % 2 == 0 ? 1 : 3]))");
89 } else {
90 if (element_size == 4) {
91 *result =
92 absl::StrCat("vec4(unpackHalf2x16(buffer[", args[0],
93 "].x), unpackHalf2x16(buffer[", args[0], "].y))");
94 } else if (element_size == 16) {
95 const std::string vec0 = absl::Substitute(
96 "vec4(unpackHalf2x16(buffer[$0].a.x), "
97 "unpackHalf2x16(buffer[$0].a.y))",
98 args[0]);
99 const std::string vec1 = absl::Substitute(
100 "vec4(unpackHalf2x16(buffer[$0].a.z), "
101 "unpackHalf2x16(buffer[$0].a.w))",
102 args[0]);
103 const std::string vec2 = absl::Substitute(
104 "vec4(unpackHalf2x16(buffer[$0].b.x), "
105 "unpackHalf2x16(buffer[$0].b.y))",
106 args[0]);
107 const std::string vec3 = absl::Substitute(
108 "vec4(unpackHalf2x16(buffer[$0].b.z), "
109 "unpackHalf2x16(buffer[$0].b.w))",
110 args[0]);
111 *result = absl::Substitute("mat4x4($0, $1, $2, $3)", vec0, vec1, vec2,
112 vec3);
113 }
114 }
115 } else {
116 *result = absl::StrCat("buffer[", args[0], "]");
117 }
118 return absl::OkStatus();
119 } else {
120 *result = absl::StrCat("buffer[", args[0], "]");
121 return absl::OkStatus();
122 }
123 }
124
PerformGetPtrSelector(const std::vector<std::string> & args,const std::vector<std::string> & template_args,std::string * result) const125 absl::Status BufferDescriptor::PerformGetPtrSelector(
126 const std::vector<std::string>& args,
127 const std::vector<std::string>& template_args, std::string* result) const {
128 if (args.size() > 1) {
129 return absl::NotFoundError(absl::StrCat(
130 "BufferDescriptor GetPtr require one or zero arguments, but ",
131 args.size(), " was passed"));
132 }
133 if (template_args.size() > 1) {
134 return absl::NotFoundError(
135 absl::StrCat("BufferDescriptor GetPtr require one or zero teemplate "
136 "arguments, but ",
137 template_args.size(), " was passed"));
138 }
139 std::string conversion;
140 if (template_args.size() == 1) {
141 const std::string type_name = ToCLDataType(element_type, element_size);
142 if (type_name != template_args[0]) {
143 conversion = absl::StrCat("(", MemoryTypeToCLType(memory_type), " ",
144 template_args[0], "*)&");
145 }
146 }
147 if (args.empty()) {
148 *result = absl::StrCat(conversion, "buffer");
149 } else if (conversion.empty()) {
150 *result = absl::StrCat("(buffer + ", args[0], ")");
151 } else {
152 *result = absl::StrCat(conversion, "buffer[", args[0], "]");
153 }
154 return absl::OkStatus();
155 }
156
157 } // namespace gpu
158 } // namespace tflite
159