1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/lite/delegates/gpu/cl/cl_program.h"
23 #include "tensorflow/lite/delegates/gpu/cl/util.h"
24 #include "tensorflow/lite/delegates/gpu/common/status.h"
25
26 namespace tflite {
27 namespace gpu {
28 namespace cl {
29 namespace {
30
GetKernelMaxWorkGroupSize(cl_kernel kernel,cl_device_id device_id,int * result)31 absl::Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id,
32 int* result) {
33 size_t max_work_group_size;
34 cl_int error_code =
35 clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_WORK_GROUP_SIZE,
36 sizeof(size_t), &max_work_group_size, nullptr);
37 if (error_code != CL_SUCCESS) {
38 return absl::UnknownError(
39 absl::StrCat("Failed to get info CL_KERNEL_WORK_GROUP_SIZE ",
40 CLErrorCodeToString(error_code)));
41 }
42 *result = static_cast<int>(max_work_group_size);
43 return absl::OkStatus();
44 }
45
GetKernelPrivateMemorySize(cl_kernel kernel,cl_device_id device_id,int * result)46 absl::Status GetKernelPrivateMemorySize(cl_kernel kernel,
47 cl_device_id device_id, int* result) {
48 cl_ulong private_mem_size;
49 cl_int error_code =
50 clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_PRIVATE_MEM_SIZE,
51 sizeof(cl_ulong), &private_mem_size, nullptr);
52 if (error_code != CL_SUCCESS) {
53 return absl::UnknownError(
54 absl::StrCat("Failed to get info CL_KERNEL_PRIVATE_MEM_SIZE ",
55 CLErrorCodeToString(error_code)));
56 }
57 *result = static_cast<int>(private_mem_size);
58 return absl::OkStatus();
59 }
60
61 } // namespace
62
CLKernel(CLKernel && kernel)63 CLKernel::CLKernel(CLKernel&& kernel)
64 : info_(kernel.info_),
65 binding_counter_(kernel.binding_counter_),
66 function_name_(std::move(kernel.function_name_)),
67 program_(kernel.program_),
68 kernel_(kernel.kernel_) {
69 kernel.kernel_ = nullptr;
70 }
71
operator =(CLKernel && kernel)72 CLKernel& CLKernel::operator=(CLKernel&& kernel) {
73 if (this != &kernel) {
74 Release();
75 std::swap(info_, kernel.info_);
76 std::swap(binding_counter_, kernel.binding_counter_);
77 function_name_ = std::move(kernel.function_name_);
78 std::swap(program_, kernel.program_);
79 std::swap(kernel_, kernel.kernel_);
80 }
81 return *this;
82 }
83
~CLKernel()84 CLKernel::~CLKernel() { Release(); }
85
ReInit() const86 absl::Status CLKernel::ReInit() const {
87 clReleaseKernel(kernel_);
88 cl_kernel* kern_ptr = const_cast<cl_kernel*>(&kernel_);
89 int error_code;
90 *kern_ptr = clCreateKernel(program_, function_name_.c_str(), &error_code);
91 if (!kernel_ || error_code != CL_SUCCESS) {
92 *kern_ptr = nullptr;
93 return absl::UnknownError(absl::StrCat("Failed to create ", function_name_,
94 CLErrorCodeToString(error_code)));
95 }
96 return absl::OkStatus();
97 }
98
Release()99 void CLKernel::Release() {
100 if (kernel_) {
101 clReleaseKernel(kernel_);
102 clReleaseProgram(program_);
103 kernel_ = nullptr;
104 }
105 }
106
CreateFromProgram(const CLProgram & program,const std::string & function_name)107 absl::Status CLKernel::CreateFromProgram(const CLProgram& program,
108 const std::string& function_name) {
109 int error_code;
110 function_name_ = function_name;
111 kernel_ =
112 clCreateKernel(program.program(), function_name.c_str(), &error_code);
113 if (!kernel_ || error_code != CL_SUCCESS) {
114 kernel_ = nullptr;
115 return absl::UnknownError(absl::StrCat("Failed to create ", function_name,
116 CLErrorCodeToString(error_code)));
117 }
118
119 program_ = program.program();
120 clRetainProgram(program_);
121
122 RETURN_IF_ERROR(GetKernelPrivateMemorySize(kernel_, program.GetDeviceId(),
123 &info_.private_memory_size));
124 RETURN_IF_ERROR(GetKernelMaxWorkGroupSize(kernel_, program.GetDeviceId(),
125 &info_.max_work_group_size));
126 return absl::OkStatus();
127 }
128
SetMemory(int index,cl_mem memory)129 absl::Status CLKernel::SetMemory(int index, cl_mem memory) {
130 return SetBytes(index, &memory, sizeof(cl_mem));
131 }
132
SetMemoryAuto(cl_mem memory)133 absl::Status CLKernel::SetMemoryAuto(cl_mem memory) {
134 return SetBytesAuto(&memory, sizeof(cl_mem));
135 }
136
SetBytes(int index,const void * ptr,int length) const137 absl::Status CLKernel::SetBytes(int index, const void* ptr, int length) const {
138 const int error_code = clSetKernelArg(kernel_, index, length, ptr);
139 if (error_code != CL_SUCCESS) {
140 return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
141 CLErrorCodeToString(error_code)));
142 }
143 return absl::OkStatus();
144 }
145
SetBytesAuto(const void * ptr,int length)146 absl::Status CLKernel::SetBytesAuto(const void* ptr, int length) {
147 const int error_code = clSetKernelArg(kernel_, binding_counter_, length, ptr);
148 if (error_code != CL_SUCCESS) {
149 return absl::UnknownError(absl::StrCat(
150 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
151 "(at index - ", binding_counter_, ")"));
152 }
153 binding_counter_++;
154 return absl::OkStatus();
155 }
156
157 } // namespace cl
158 } // namespace gpu
159 } // namespace tflite
160