xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/cl/cl_kernel.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/lite/delegates/gpu/cl/cl_program.h"
23 #include "tensorflow/lite/delegates/gpu/cl/util.h"
24 #include "tensorflow/lite/delegates/gpu/common/status.h"
25 
26 namespace tflite {
27 namespace gpu {
28 namespace cl {
29 namespace {
30 
GetKernelMaxWorkGroupSize(cl_kernel kernel,cl_device_id device_id,int * result)31 absl::Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id,
32                                        int* result) {
33   size_t max_work_group_size;
34   cl_int error_code =
35       clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_WORK_GROUP_SIZE,
36                                sizeof(size_t), &max_work_group_size, nullptr);
37   if (error_code != CL_SUCCESS) {
38     return absl::UnknownError(
39         absl::StrCat("Failed to get info CL_KERNEL_WORK_GROUP_SIZE ",
40                      CLErrorCodeToString(error_code)));
41   }
42   *result = static_cast<int>(max_work_group_size);
43   return absl::OkStatus();
44 }
45 
GetKernelPrivateMemorySize(cl_kernel kernel,cl_device_id device_id,int * result)46 absl::Status GetKernelPrivateMemorySize(cl_kernel kernel,
47                                         cl_device_id device_id, int* result) {
48   cl_ulong private_mem_size;
49   cl_int error_code =
50       clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_PRIVATE_MEM_SIZE,
51                                sizeof(cl_ulong), &private_mem_size, nullptr);
52   if (error_code != CL_SUCCESS) {
53     return absl::UnknownError(
54         absl::StrCat("Failed to get info CL_KERNEL_PRIVATE_MEM_SIZE ",
55                      CLErrorCodeToString(error_code)));
56   }
57   *result = static_cast<int>(private_mem_size);
58   return absl::OkStatus();
59 }
60 
61 }  // namespace
62 
CLKernel(CLKernel && kernel)63 CLKernel::CLKernel(CLKernel&& kernel)
64     : info_(kernel.info_),
65       binding_counter_(kernel.binding_counter_),
66       function_name_(std::move(kernel.function_name_)),
67       program_(kernel.program_),
68       kernel_(kernel.kernel_) {
69   kernel.kernel_ = nullptr;
70 }
71 
operator =(CLKernel && kernel)72 CLKernel& CLKernel::operator=(CLKernel&& kernel) {
73   if (this != &kernel) {
74     Release();
75     std::swap(info_, kernel.info_);
76     std::swap(binding_counter_, kernel.binding_counter_);
77     function_name_ = std::move(kernel.function_name_);
78     std::swap(program_, kernel.program_);
79     std::swap(kernel_, kernel.kernel_);
80   }
81   return *this;
82 }
83 
~CLKernel()84 CLKernel::~CLKernel() { Release(); }
85 
ReInit() const86 absl::Status CLKernel::ReInit() const {
87   clReleaseKernel(kernel_);
88   cl_kernel* kern_ptr = const_cast<cl_kernel*>(&kernel_);
89   int error_code;
90   *kern_ptr = clCreateKernel(program_, function_name_.c_str(), &error_code);
91   if (!kernel_ || error_code != CL_SUCCESS) {
92     *kern_ptr = nullptr;
93     return absl::UnknownError(absl::StrCat("Failed to create ", function_name_,
94                                            CLErrorCodeToString(error_code)));
95   }
96   return absl::OkStatus();
97 }
98 
Release()99 void CLKernel::Release() {
100   if (kernel_) {
101     clReleaseKernel(kernel_);
102     clReleaseProgram(program_);
103     kernel_ = nullptr;
104   }
105 }
106 
CreateFromProgram(const CLProgram & program,const std::string & function_name)107 absl::Status CLKernel::CreateFromProgram(const CLProgram& program,
108                                          const std::string& function_name) {
109   int error_code;
110   function_name_ = function_name;
111   kernel_ =
112       clCreateKernel(program.program(), function_name.c_str(), &error_code);
113   if (!kernel_ || error_code != CL_SUCCESS) {
114     kernel_ = nullptr;
115     return absl::UnknownError(absl::StrCat("Failed to create ", function_name,
116                                            CLErrorCodeToString(error_code)));
117   }
118 
119   program_ = program.program();
120   clRetainProgram(program_);
121 
122   RETURN_IF_ERROR(GetKernelPrivateMemorySize(kernel_, program.GetDeviceId(),
123                                              &info_.private_memory_size));
124   RETURN_IF_ERROR(GetKernelMaxWorkGroupSize(kernel_, program.GetDeviceId(),
125                                             &info_.max_work_group_size));
126   return absl::OkStatus();
127 }
128 
SetMemory(int index,cl_mem memory)129 absl::Status CLKernel::SetMemory(int index, cl_mem memory) {
130   return SetBytes(index, &memory, sizeof(cl_mem));
131 }
132 
SetMemoryAuto(cl_mem memory)133 absl::Status CLKernel::SetMemoryAuto(cl_mem memory) {
134   return SetBytesAuto(&memory, sizeof(cl_mem));
135 }
136 
SetBytes(int index,const void * ptr,int length) const137 absl::Status CLKernel::SetBytes(int index, const void* ptr, int length) const {
138   const int error_code = clSetKernelArg(kernel_, index, length, ptr);
139   if (error_code != CL_SUCCESS) {
140     return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
141                                            CLErrorCodeToString(error_code)));
142   }
143   return absl::OkStatus();
144 }
145 
SetBytesAuto(const void * ptr,int length)146 absl::Status CLKernel::SetBytesAuto(const void* ptr, int length) {
147   const int error_code = clSetKernelArg(kernel_, binding_counter_, length, ptr);
148   if (error_code != CL_SUCCESS) {
149     return absl::UnknownError(absl::StrCat(
150         "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
151         "(at index - ", binding_counter_, ")"));
152   }
153   binding_counter_++;
154   return absl::OkStatus();
155 }
156 
157 }  // namespace cl
158 }  // namespace gpu
159 }  // namespace tflite
160