xref: /aosp_15_r20/external/angle/src/libANGLE/renderer/cl/CLKernelCL.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // CLKernelCL.cpp: Implements the class methods for CLKernelCL.
7 
8 #include "libANGLE/renderer/cl/CLKernelCL.h"
9 
10 #include "libANGLE/renderer/cl/CLCommandQueueCL.h"
11 #include "libANGLE/renderer/cl/CLContextCL.h"
12 #include "libANGLE/renderer/cl/CLDeviceCL.h"
13 #include "libANGLE/renderer/cl/CLMemoryCL.h"
14 #include "libANGLE/renderer/cl/CLSamplerCL.h"
15 
16 #include "libANGLE/CLCommandQueue.h"
17 #include "libANGLE/CLContext.h"
18 #include "libANGLE/CLKernel.h"
19 #include "libANGLE/CLMemory.h"
20 #include "libANGLE/CLPlatform.h"
21 #include "libANGLE/CLProgram.h"
22 #include "libANGLE/CLSampler.h"
23 #include "libANGLE/cl_utils.h"
24 
25 namespace rx
26 {
27 
28 namespace
29 {
30 
31 template <typename T>
GetWorkGroupInfo(cl_kernel kernel,cl_device_id device,cl::KernelWorkGroupInfo name,T & value,cl_int & errorCode)32 bool GetWorkGroupInfo(cl_kernel kernel,
33                       cl_device_id device,
34                       cl::KernelWorkGroupInfo name,
35                       T &value,
36                       cl_int &errorCode)
37 {
38     errorCode = kernel->getDispatch().clGetKernelWorkGroupInfo(kernel, device, cl::ToCLenum(name),
39                                                                sizeof(T), &value, nullptr);
40     return errorCode == CL_SUCCESS;
41 }
42 
43 template <typename T>
GetArgInfo(cl_kernel kernel,cl_uint index,cl::KernelArgInfo name,T & value,cl_int & errorCode)44 bool GetArgInfo(cl_kernel kernel,
45                 cl_uint index,
46                 cl::KernelArgInfo name,
47                 T &value,
48                 cl_int &errorCode)
49 {
50     errorCode = kernel->getDispatch().clGetKernelArgInfo(kernel, index, cl::ToCLenum(name),
51                                                          sizeof(T), &value, nullptr);
52     if (errorCode == CL_KERNEL_ARG_INFO_NOT_AVAILABLE)
53     {
54         errorCode = CL_SUCCESS;
55     }
56     return errorCode == CL_SUCCESS;
57 }
58 
59 template <typename T>
GetKernelInfo(cl_kernel kernel,cl::KernelInfo name,T & value,cl_int & errorCode)60 bool GetKernelInfo(cl_kernel kernel, cl::KernelInfo name, T &value, cl_int &errorCode)
61 {
62     errorCode = kernel->getDispatch().clGetKernelInfo(kernel, cl::ToCLenum(name), sizeof(T), &value,
63                                                       nullptr);
64     return errorCode == CL_SUCCESS;
65 }
66 
GetArgString(cl_kernel kernel,cl_uint index,cl::KernelArgInfo name,std::string & string,cl_int & errorCode)67 bool GetArgString(cl_kernel kernel,
68                   cl_uint index,
69                   cl::KernelArgInfo name,
70                   std::string &string,
71                   cl_int &errorCode)
72 {
73     size_t size = 0u;
74     errorCode   = kernel->getDispatch().clGetKernelArgInfo(kernel, index, cl::ToCLenum(name), 0u,
75                                                            nullptr, &size);
76     if (errorCode == CL_KERNEL_ARG_INFO_NOT_AVAILABLE)
77     {
78         errorCode = CL_SUCCESS;
79         return true;
80     }
81     else if (errorCode != CL_SUCCESS)
82     {
83         return false;
84     }
85     std::vector<char> valString(size, '\0');
86     errorCode = kernel->getDispatch().clGetKernelArgInfo(kernel, index, cl::ToCLenum(name), size,
87                                                          valString.data(), nullptr);
88     if (errorCode != CL_SUCCESS)
89     {
90         return false;
91     }
92     string.assign(valString.data(), valString.size() - 1u);
93     return true;
94 }
95 
GetKernelString(cl_kernel kernel,cl::KernelInfo name,std::string & string,cl_int & errorCode)96 bool GetKernelString(cl_kernel kernel, cl::KernelInfo name, std::string &string, cl_int &errorCode)
97 {
98     size_t size = 0u;
99     errorCode =
100         kernel->getDispatch().clGetKernelInfo(kernel, cl::ToCLenum(name), 0u, nullptr, &size);
101     if (errorCode != CL_SUCCESS)
102     {
103         return false;
104     }
105     std::vector<char> valString(size, '\0');
106     errorCode = kernel->getDispatch().clGetKernelInfo(kernel, cl::ToCLenum(name), size,
107                                                       valString.data(), nullptr);
108     if (errorCode != CL_SUCCESS)
109     {
110         return false;
111     }
112     string.assign(valString.data(), valString.size() - 1u);
113     return true;
114 }
115 
116 }  // namespace
117 
CLKernelCL(const cl::Kernel & kernel,cl_kernel native)118 CLKernelCL::CLKernelCL(const cl::Kernel &kernel, cl_kernel native)
119     : CLKernelImpl(kernel), mNative(native)
120 {}
121 
~CLKernelCL()122 CLKernelCL::~CLKernelCL()
123 {
124     if (mNative->getDispatch().clReleaseKernel(mNative) != CL_SUCCESS)
125     {
126         ERR() << "Error while releasing CL kernel";
127     }
128 }
129 
setArg(cl_uint argIndex,size_t argSize,const void * argValue)130 angle::Result CLKernelCL::setArg(cl_uint argIndex, size_t argSize, const void *argValue)
131 {
132     void *value = nullptr;
133     if (argValue != nullptr)
134     {
135         // If argument is a CL object, fetch the mapped value
136         const CLContextCL &ctx = mKernel.getProgram().getContext().getImpl<CLContextCL>();
137         if (argSize == sizeof(cl_mem))
138         {
139             cl_mem memory = *static_cast<const cl_mem *>(argValue);
140             if (ctx.hasMemory(memory))
141             {
142                 value = memory->cast<cl::Memory>().getImpl<CLMemoryCL>().getNative();
143             }
144         }
145         if (value == nullptr && argSize == sizeof(cl_sampler))
146         {
147             cl_sampler sampler = *static_cast<const cl_sampler *>(argValue);
148             if (ctx.hasSampler(sampler))
149             {
150                 value = sampler->cast<cl::Sampler>().getImpl<CLSamplerCL>().getNative();
151             }
152         }
153         if (value == nullptr && argSize == sizeof(cl_command_queue))
154         {
155             cl_command_queue queue = *static_cast<const cl_command_queue *>(argValue);
156             if (ctx.hasDeviceQueue(queue))
157             {
158                 value = queue->cast<cl::CommandQueue>().getImpl<CLCommandQueueCL>().getNative();
159             }
160         }
161     }
162 
163     // If mapped value was found, use it instead of original value
164     if (value != nullptr)
165     {
166         argValue = &value;
167     }
168     ANGLE_CL_TRY(mNative->getDispatch().clSetKernelArg(mNative, argIndex, argSize, argValue));
169     return angle::Result::Continue;
170 }
171 
createInfo(CLKernelImpl::Info * infoOut) const172 angle::Result CLKernelCL::createInfo(CLKernelImpl::Info *infoOut) const
173 {
174     cl_int errorCode       = CL_SUCCESS;
175     const cl::Context &ctx = mKernel.getProgram().getContext();
176 
177     if (!GetKernelString(mNative, cl::KernelInfo::FunctionName, infoOut->functionName, errorCode) ||
178         !GetKernelInfo(mNative, cl::KernelInfo::NumArgs, infoOut->numArgs, errorCode) ||
179         (ctx.getPlatform().isVersionOrNewer(1u, 2u) &&
180          !GetKernelString(mNative, cl::KernelInfo::Attributes, infoOut->attributes, errorCode)))
181     {
182         ANGLE_CL_RETURN_ERROR(errorCode);
183     }
184 
185     infoOut->workGroups.resize(ctx.getDevices().size());
186     for (size_t index = 0u; index < ctx.getDevices().size(); ++index)
187     {
188         const cl_device_id device = ctx.getDevices()[index]->getImpl<CLDeviceCL>().getNative();
189         WorkGroupInfo &workGroup  = infoOut->workGroups[index];
190 
191         if ((ctx.getPlatform().isVersionOrNewer(1u, 2u) &&
192              ctx.getDevices()[index]->supportsBuiltInKernel(infoOut->functionName) &&
193              !GetWorkGroupInfo(mNative, device, cl::KernelWorkGroupInfo::GlobalWorkSize,
194                                workGroup.globalWorkSize, errorCode)) ||
195             !GetWorkGroupInfo(mNative, device, cl::KernelWorkGroupInfo::WorkGroupSize,
196                               workGroup.workGroupSize, errorCode) ||
197             !GetWorkGroupInfo(mNative, device, cl::KernelWorkGroupInfo::CompileWorkGroupSize,
198                               workGroup.compileWorkGroupSize, errorCode) ||
199             !GetWorkGroupInfo(mNative, device, cl::KernelWorkGroupInfo::LocalMemSize,
200                               workGroup.localMemSize, errorCode) ||
201             !GetWorkGroupInfo(mNative, device,
202                               cl::KernelWorkGroupInfo::PreferredWorkGroupSizeMultiple,
203                               workGroup.prefWorkGroupSizeMultiple, errorCode) ||
204             !GetWorkGroupInfo(mNative, device, cl::KernelWorkGroupInfo::PrivateMemSize,
205                               workGroup.privateMemSize, errorCode))
206         {
207             ANGLE_CL_RETURN_ERROR(errorCode);
208         }
209     }
210 
211     infoOut->args.resize(infoOut->numArgs);
212     if (ctx.getPlatform().isVersionOrNewer(1u, 2u))
213     {
214         for (cl_uint index = 0u; index < infoOut->numArgs; ++index)
215         {
216             ArgInfo &arg = infoOut->args[index];
217             if (!GetArgInfo(mNative, index, cl::KernelArgInfo::AddressQualifier,
218                             arg.addressQualifier, errorCode) ||
219                 !GetArgInfo(mNative, index, cl::KernelArgInfo::AccessQualifier, arg.accessQualifier,
220                             errorCode) ||
221                 !GetArgString(mNative, index, cl::KernelArgInfo::TypeName, arg.typeName,
222                               errorCode) ||
223                 !GetArgInfo(mNative, index, cl::KernelArgInfo::TypeQualifier, arg.typeQualifier,
224                             errorCode) ||
225                 !GetArgString(mNative, index, cl::KernelArgInfo::Name, arg.name, errorCode))
226             {
227                 ANGLE_CL_RETURN_ERROR(errorCode);
228             }
229         }
230     }
231 
232     return angle::Result::Continue;
233 }
234 
235 }  // namespace rx
236