xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/cl/cl_context.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/lite/delegates/gpu/cl/cl_image_format.h"
20 #include "tensorflow/lite/delegates/gpu/cl/util.h"
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22 
23 namespace tflite {
24 namespace gpu {
25 namespace cl {
26 namespace {
27 
GetSupportedImage2DFormats(cl_context context,cl_mem_flags flags)28 std::vector<cl_image_format> GetSupportedImage2DFormats(cl_context context,
29                                                         cl_mem_flags flags) {
30   cl_uint num_image_formats;
31   cl_int error = clGetSupportedImageFormats(
32       context, flags, CL_MEM_OBJECT_IMAGE2D, 0, nullptr, &num_image_formats);
33   if (error != CL_SUCCESS) {
34     return {};
35   }
36 
37   std::vector<cl_image_format> result(num_image_formats);
38   error = clGetSupportedImageFormats(context, flags, CL_MEM_OBJECT_IMAGE2D,
39                                      num_image_formats, &result[0], nullptr);
40   if (error != CL_SUCCESS) {
41     return {};
42   }
43   return result;
44 }
45 
IsEqualToImageFormat(cl_image_format image_format,DataType data_type,int num_channels)46 bool IsEqualToImageFormat(cl_image_format image_format, DataType data_type,
47                           int num_channels) {
48   return image_format.image_channel_data_type ==
49              DataTypeToChannelType(data_type) &&
50          image_format.image_channel_order == ToChannelOrder(num_channels);
51 }
52 
AddSupportedImageFormats(cl_context context,GpuInfo * info)53 void AddSupportedImageFormats(cl_context context, GpuInfo* info) {
54   auto supported_formats =
55       GetSupportedImage2DFormats(context, CL_MEM_READ_WRITE);
56   const std::vector<DataType> kPossibleDataTypes = {
57       DataType::FLOAT16, DataType::FLOAT32, DataType::INT8,  DataType::UINT8,
58       DataType::INT16,   DataType::UINT16,  DataType::INT32, DataType::UINT32};
59   for (auto format : supported_formats) {
60     for (auto data_type : kPossibleDataTypes) {
61       if (IsEqualToImageFormat(format, data_type, 1)) {
62         info->opencl_info.supported_images_2d.r_layout.insert(data_type);
63       } else if (IsEqualToImageFormat(format, data_type, 2)) {
64         info->opencl_info.supported_images_2d.rg_layout.insert(data_type);
65       } else if (IsEqualToImageFormat(format, data_type, 3)) {
66         info->opencl_info.supported_images_2d.rgb_layout.insert(data_type);
67       } else if (IsEqualToImageFormat(format, data_type, 4)) {
68         info->opencl_info.supported_images_2d.rgba_layout.insert(data_type);
69       }
70     }
71   }
72 }
73 
CreateCLContext(const CLDevice & device,cl_context_properties * properties,CLContext * result)74 absl::Status CreateCLContext(const CLDevice& device,
75                              cl_context_properties* properties,
76                              CLContext* result) {
77   int error_code;
78   cl_device_id device_id = device.id();
79   cl_context context =
80       clCreateContext(properties, 1, &device_id, nullptr, nullptr, &error_code);
81   if (!context) {
82     return absl::UnknownError(
83         absl::StrCat("Failed to create a compute context - ",
84                      CLErrorCodeToString(error_code)));
85   }
86   AddSupportedImageFormats(context, &device.info_);
87 
88   *result = CLContext(context, true);
89   return absl::OkStatus();
90 }
91 
92 }  // namespace
93 
CLContext(cl_context context,bool has_ownership)94 CLContext::CLContext(cl_context context, bool has_ownership)
95     : context_(context), has_ownership_(has_ownership) {}
96 
CLContext(CLContext && context)97 CLContext::CLContext(CLContext&& context)
98     : context_(context.context_), has_ownership_(context.has_ownership_) {
99   context.context_ = nullptr;
100 }
101 
operator =(CLContext && context)102 CLContext& CLContext::operator=(CLContext&& context) {
103   if (this != &context) {
104     Release();
105     std::swap(context_, context.context_);
106     has_ownership_ = context.has_ownership_;
107   }
108   return *this;
109 }
110 
~CLContext()111 CLContext::~CLContext() { Release(); }
112 
Release()113 void CLContext::Release() {
114   if (has_ownership_ && context_) {
115     clReleaseContext(context_);
116     context_ = nullptr;
117   }
118 }
119 
IsFloatTexture2DSupported(int num_channels,DataType data_type,cl_mem_flags flags) const120 bool CLContext::IsFloatTexture2DSupported(int num_channels, DataType data_type,
121                                           cl_mem_flags flags) const {
122   auto supported_formats = GetSupportedImage2DFormats(context_, flags);
123   for (auto format : supported_formats) {
124     if (format.image_channel_data_type == DataTypeToChannelType(data_type) &&
125         format.image_channel_order == ToChannelOrder(num_channels)) {
126       return true;
127     }
128   }
129 
130   return false;
131 }
132 
CreateCLContext(const CLDevice & device,CLContext * result)133 absl::Status CreateCLContext(const CLDevice& device, CLContext* result) {
134   return CreateCLContext(device, nullptr, result);
135 }
136 
CreateCLGLContext(const CLDevice & device,cl_context_properties egl_context,cl_context_properties egl_display,CLContext * result)137 absl::Status CreateCLGLContext(const CLDevice& device,
138                                cl_context_properties egl_context,
139                                cl_context_properties egl_display,
140                                CLContext* result) {
141   if (!device.GetInfo().SupportsExtension("cl_khr_gl_sharing")) {
142     return absl::UnavailableError("Device doesn't support CL-GL sharing.");
143   }
144   cl_context_properties platform =
145       reinterpret_cast<cl_context_properties>(device.platform());
146   cl_context_properties props[] = {CL_GL_CONTEXT_KHR,
147                                    egl_context,
148                                    CL_EGL_DISPLAY_KHR,
149                                    egl_display,
150                                    CL_CONTEXT_PLATFORM,
151                                    platform,
152                                    0};
153   return CreateCLContext(device, props, result);
154 }
155 
156 }  // namespace cl
157 }  // namespace gpu
158 }  // namespace tflite
159