1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/lite/delegates/gpu/cl/cl_image_format.h"
20 #include "tensorflow/lite/delegates/gpu/cl/util.h"
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22
23 namespace tflite {
24 namespace gpu {
25 namespace cl {
26 namespace {
27
GetSupportedImage2DFormats(cl_context context,cl_mem_flags flags)28 std::vector<cl_image_format> GetSupportedImage2DFormats(cl_context context,
29 cl_mem_flags flags) {
30 cl_uint num_image_formats;
31 cl_int error = clGetSupportedImageFormats(
32 context, flags, CL_MEM_OBJECT_IMAGE2D, 0, nullptr, &num_image_formats);
33 if (error != CL_SUCCESS) {
34 return {};
35 }
36
37 std::vector<cl_image_format> result(num_image_formats);
38 error = clGetSupportedImageFormats(context, flags, CL_MEM_OBJECT_IMAGE2D,
39 num_image_formats, &result[0], nullptr);
40 if (error != CL_SUCCESS) {
41 return {};
42 }
43 return result;
44 }
45
IsEqualToImageFormat(cl_image_format image_format,DataType data_type,int num_channels)46 bool IsEqualToImageFormat(cl_image_format image_format, DataType data_type,
47 int num_channels) {
48 return image_format.image_channel_data_type ==
49 DataTypeToChannelType(data_type) &&
50 image_format.image_channel_order == ToChannelOrder(num_channels);
51 }
52
AddSupportedImageFormats(cl_context context,GpuInfo * info)53 void AddSupportedImageFormats(cl_context context, GpuInfo* info) {
54 auto supported_formats =
55 GetSupportedImage2DFormats(context, CL_MEM_READ_WRITE);
56 const std::vector<DataType> kPossibleDataTypes = {
57 DataType::FLOAT16, DataType::FLOAT32, DataType::INT8, DataType::UINT8,
58 DataType::INT16, DataType::UINT16, DataType::INT32, DataType::UINT32};
59 for (auto format : supported_formats) {
60 for (auto data_type : kPossibleDataTypes) {
61 if (IsEqualToImageFormat(format, data_type, 1)) {
62 info->opencl_info.supported_images_2d.r_layout.insert(data_type);
63 } else if (IsEqualToImageFormat(format, data_type, 2)) {
64 info->opencl_info.supported_images_2d.rg_layout.insert(data_type);
65 } else if (IsEqualToImageFormat(format, data_type, 3)) {
66 info->opencl_info.supported_images_2d.rgb_layout.insert(data_type);
67 } else if (IsEqualToImageFormat(format, data_type, 4)) {
68 info->opencl_info.supported_images_2d.rgba_layout.insert(data_type);
69 }
70 }
71 }
72 }
73
CreateCLContext(const CLDevice & device,cl_context_properties * properties,CLContext * result)74 absl::Status CreateCLContext(const CLDevice& device,
75 cl_context_properties* properties,
76 CLContext* result) {
77 int error_code;
78 cl_device_id device_id = device.id();
79 cl_context context =
80 clCreateContext(properties, 1, &device_id, nullptr, nullptr, &error_code);
81 if (!context) {
82 return absl::UnknownError(
83 absl::StrCat("Failed to create a compute context - ",
84 CLErrorCodeToString(error_code)));
85 }
86 AddSupportedImageFormats(context, &device.info_);
87
88 *result = CLContext(context, true);
89 return absl::OkStatus();
90 }
91
92 } // namespace
93
CLContext(cl_context context,bool has_ownership)94 CLContext::CLContext(cl_context context, bool has_ownership)
95 : context_(context), has_ownership_(has_ownership) {}
96
CLContext(CLContext && context)97 CLContext::CLContext(CLContext&& context)
98 : context_(context.context_), has_ownership_(context.has_ownership_) {
99 context.context_ = nullptr;
100 }
101
operator =(CLContext && context)102 CLContext& CLContext::operator=(CLContext&& context) {
103 if (this != &context) {
104 Release();
105 std::swap(context_, context.context_);
106 has_ownership_ = context.has_ownership_;
107 }
108 return *this;
109 }
110
~CLContext()111 CLContext::~CLContext() { Release(); }
112
Release()113 void CLContext::Release() {
114 if (has_ownership_ && context_) {
115 clReleaseContext(context_);
116 context_ = nullptr;
117 }
118 }
119
IsFloatTexture2DSupported(int num_channels,DataType data_type,cl_mem_flags flags) const120 bool CLContext::IsFloatTexture2DSupported(int num_channels, DataType data_type,
121 cl_mem_flags flags) const {
122 auto supported_formats = GetSupportedImage2DFormats(context_, flags);
123 for (auto format : supported_formats) {
124 if (format.image_channel_data_type == DataTypeToChannelType(data_type) &&
125 format.image_channel_order == ToChannelOrder(num_channels)) {
126 return true;
127 }
128 }
129
130 return false;
131 }
132
CreateCLContext(const CLDevice & device,CLContext * result)133 absl::Status CreateCLContext(const CLDevice& device, CLContext* result) {
134 return CreateCLContext(device, nullptr, result);
135 }
136
CreateCLGLContext(const CLDevice & device,cl_context_properties egl_context,cl_context_properties egl_display,CLContext * result)137 absl::Status CreateCLGLContext(const CLDevice& device,
138 cl_context_properties egl_context,
139 cl_context_properties egl_display,
140 CLContext* result) {
141 if (!device.GetInfo().SupportsExtension("cl_khr_gl_sharing")) {
142 return absl::UnavailableError("Device doesn't support CL-GL sharing.");
143 }
144 cl_context_properties platform =
145 reinterpret_cast<cl_context_properties>(device.platform());
146 cl_context_properties props[] = {CL_GL_CONTEXT_KHR,
147 egl_context,
148 CL_EGL_DISPLAY_KHR,
149 egl_display,
150 CL_CONTEXT_PLATFORM,
151 platform,
152 0};
153 return CreateCLContext(device, props, result);
154 }
155
156 } // namespace cl
157 } // namespace gpu
158 } // namespace tflite
159