xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/vk_api/Adapter.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // @lint-ignore-every CLANGTIDY clang-diagnostic-missing-field-initializers
10 
11 #include <executorch/backends/vulkan/runtime/vk_api/Adapter.h>
12 
13 #include <iomanip>
14 
15 namespace vkcompute {
16 namespace vkapi {
17 
18 namespace {
19 
create_logical_device(const PhysicalDevice & physical_device,const uint32_t num_queues_to_create,std::vector<Adapter::Queue> & queues,std::vector<uint32_t> & queue_usage)20 VkDevice create_logical_device(
21     const PhysicalDevice& physical_device,
22     const uint32_t num_queues_to_create,
23     std::vector<Adapter::Queue>& queues,
24     std::vector<uint32_t>& queue_usage) {
25   // Find compute queues up to the requested number of queues
26 
27   std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
28   queue_create_infos.reserve(num_queues_to_create);
29 
30   std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
31   queues_to_get.reserve(num_queues_to_create);
32 
33   uint32_t remaining_queues = num_queues_to_create;
34   for (uint32_t family_i = 0; family_i < physical_device.queue_families.size();
35        ++family_i) {
36     const VkQueueFamilyProperties& queue_properties =
37         physical_device.queue_families.at(family_i);
38     // Check if this family has compute capability
39     if (queue_properties.queueFlags & VK_QUEUE_COMPUTE_BIT) {
40       const uint32_t queues_to_init =
41           std::min(remaining_queues, queue_properties.queueCount);
42 
43       const std::vector<float> queue_priorities(queues_to_init, 1.0f);
44       queue_create_infos.push_back({
45           VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO, // sType
46           nullptr, // pNext
47           0u, // flags
48           family_i, // queueFamilyIndex
49           queues_to_init, // queueCount
50           queue_priorities.data(), // pQueuePriorities
51       });
52 
53       for (size_t queue_i = 0; queue_i < queues_to_init; ++queue_i) {
54         // Use this to get the queue handle once device is created
55         queues_to_get.emplace_back(family_i, queue_i);
56       }
57       remaining_queues -= queues_to_init;
58     }
59     if (remaining_queues == 0) {
60       break;
61     }
62   }
63 
64   queues.reserve(queues_to_get.size());
65   queue_usage.reserve(queues_to_get.size());
66 
67   // Create the VkDevice
68 
69   std::vector<const char*> requested_device_extensions{
70 #ifdef VK_KHR_portability_subset
71       VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
72 #endif /* VK_KHR_portability_subset */
73 #ifdef VK_ANDROID_external_memory_android_hardware_buffer
74       VK_ANDROID_EXTERNAL_MEMORY_ANDROID_HARDWARE_BUFFER_EXTENSION_NAME,
75 #endif /* VK_ANDROID_external_memory_android_hardware_buffer */
76 #ifdef VK_KHR_16bit_storage
77       VK_KHR_16BIT_STORAGE_EXTENSION_NAME,
78 #endif /* VK_KHR_16bit_storage */
79 #ifdef VK_KHR_8bit_storage
80       VK_KHR_8BIT_STORAGE_EXTENSION_NAME,
81 #endif /* VK_KHR_8bit_storage */
82 #ifdef VK_KHR_shader_float16_int8
83       VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME,
84 #endif /* VK_KHR_shader_float16_int8 */
85   };
86 
87   std::vector<const char*> enabled_device_extensions;
88   find_requested_device_extensions(
89       physical_device.handle,
90       enabled_device_extensions,
91       requested_device_extensions);
92 
93   VkDeviceCreateInfo device_create_info{
94       VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType
95       nullptr, // pNext
96       0u, // flags
97       static_cast<uint32_t>(queue_create_infos.size()), // queueCreateInfoCount
98       queue_create_infos.data(), // pQueueCreateInfos
99       0u, // enabledLayerCount
100       nullptr, // ppEnabledLayerNames
101       static_cast<uint32_t>(
102           enabled_device_extensions.size()), // enabledExtensionCount
103       enabled_device_extensions.data(), // ppEnabledExtensionNames
104       nullptr, // pEnabledFeatures
105   };
106 
107   device_create_info.pNext = physical_device.extension_features;
108 
109   VkDevice handle = nullptr;
110   VK_CHECK(vkCreateDevice(
111       physical_device.handle, &device_create_info, nullptr, &handle));
112 
113 #ifdef USE_VULKAN_VOLK
114   volkLoadDevice(handle);
115 #endif /* USE_VULKAN_VOLK */
116 
117   // Obtain handles for the created queues and initialize queue usage heuristic
118 
119   for (const std::pair<uint32_t, uint32_t>& queue_idx : queues_to_get) {
120     VkQueue queue_handle = VK_NULL_HANDLE;
121     VkQueueFlags flags =
122         physical_device.queue_families.at(queue_idx.first).queueFlags;
123     vkGetDeviceQueue(handle, queue_idx.first, queue_idx.second, &queue_handle);
124     queues.push_back({queue_idx.first, queue_idx.second, flags, queue_handle});
125     // Initial usage value
126     queue_usage.push_back(0);
127   }
128 
129   return handle;
130 }
131 
132 } // namespace
133 
134 //
135 // Adapter
136 //
137 
Adapter(VkInstance instance,PhysicalDevice physical_device,const uint32_t num_queues,const std::string & cache_data_path)138 Adapter::Adapter(
139     VkInstance instance,
140     PhysicalDevice physical_device,
141     const uint32_t num_queues,
142     const std::string& cache_data_path)
143     : queue_usage_mutex_{},
144       physical_device_(std::move(physical_device)),
145       queues_{},
146       queue_usage_{},
147       queue_mutexes_{},
148       instance_(instance),
149       device_(create_logical_device(
150           physical_device_,
151           num_queues,
152           queues_,
153           queue_usage_)),
154       shader_layout_cache_(device_.handle),
155       shader_cache_(device_.handle),
156       pipeline_layout_cache_(device_.handle),
157       compute_pipeline_cache_(device_.handle, cache_data_path),
158       sampler_cache_(device_.handle),
159       vma_(instance_, physical_device_.handle, device_.handle),
160       linear_tiling_3d_enabled_{true} {
161   // Test creating a 3D image with linear tiling to see if it is supported.
162   // According to the Vulkan spec, linear tiling may not be supported for 3D
163   // images.
164   VkExtent3D image_extents{1u, 1u, 1u};
165   const VkImageCreateInfo image_create_info{
166       VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType
167       nullptr, // pNext
168       0u, // flags
169       VK_IMAGE_TYPE_3D, // imageType
170       VK_FORMAT_R32G32B32A32_SFLOAT, // format
171       image_extents, // extents
172       1u, // mipLevels
173       1u, // arrayLayers
174       VK_SAMPLE_COUNT_1_BIT, // samples
175       VK_IMAGE_TILING_LINEAR, // tiling
176       VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT, // usage
177       VK_SHARING_MODE_EXCLUSIVE, // sharingMode
178       0u, // queueFamilyIndexCount
179       nullptr, // pQueueFamilyIndices
180       VK_IMAGE_LAYOUT_UNDEFINED, // initialLayout
181   };
182   VkImage image = VK_NULL_HANDLE;
183   VkResult res =
184       vkCreateImage(device_.handle, &image_create_info, nullptr, &image);
185   if (res == VK_ERROR_FEATURE_NOT_PRESENT) {
186     linear_tiling_3d_enabled_ = false;
187   } else if (res == VK_SUCCESS) {
188     vkDestroyImage(device_.handle, image, nullptr);
189   }
190   return;
191 }
192 
request_queue()193 Adapter::Queue Adapter::request_queue() {
194   // Lock the mutex as multiple threads can request a queue at the same time
195   std::lock_guard<std::mutex> lock(queue_usage_mutex_);
196 
197   uint32_t min_usage = UINT32_MAX;
198   uint32_t min_used_i = 0;
199   for (size_t i = 0; i < queues_.size(); ++i) {
200     if (queue_usage_[i] < min_usage) {
201       min_used_i = i;
202       min_usage = queue_usage_[i];
203     }
204   }
205   queue_usage_[min_used_i] += 1;
206 
207   return queues_[min_used_i];
208 }
209 
return_queue(Adapter::Queue & compute_queue)210 void Adapter::return_queue(Adapter::Queue& compute_queue) {
211   for (size_t i = 0; i < queues_.size(); ++i) {
212     if ((queues_[i].family_index == compute_queue.family_index) &&
213         (queues_[i].queue_index == compute_queue.queue_index)) {
214       std::lock_guard<std::mutex> lock(queue_usage_mutex_);
215       queue_usage_[i] -= 1;
216       break;
217     }
218   }
219 }
220 
submit_cmd(const Adapter::Queue & device_queue,VkCommandBuffer cmd,VkFence fence)221 void Adapter::submit_cmd(
222     const Adapter::Queue& device_queue,
223     VkCommandBuffer cmd,
224     VkFence fence) {
225   const VkSubmitInfo submit_info{
226       VK_STRUCTURE_TYPE_SUBMIT_INFO, // sType
227       nullptr, // pNext
228       0u, // waitSemaphoreCount
229       nullptr, // pWaitSemaphores
230       nullptr, // pWaitDstStageMask
231       1u, // commandBufferCount
232       &cmd, // pCommandBuffers
233       0u, // signalSemaphoreCount
234       nullptr, // pSignalSemaphores
235   };
236 
237   std::lock_guard<std::mutex> queue_lock(
238       queue_mutexes_[device_queue.queue_index % NUM_QUEUE_MUTEXES]);
239 
240   VK_CHECK(vkQueueSubmit(device_queue.handle, 1u, &submit_info, fence));
241 }
242 
stringize() const243 std::string Adapter::stringize() const {
244   std::stringstream ss;
245 
246   VkPhysicalDeviceProperties properties = physical_device_.properties;
247   uint32_t v_major = VK_VERSION_MAJOR(properties.apiVersion);
248   uint32_t v_minor = VK_VERSION_MINOR(properties.apiVersion);
249   std::string device_type = get_device_type_str(properties.deviceType);
250   VkPhysicalDeviceLimits limits = properties.limits;
251 
252   ss << "{" << std::endl;
253   ss << "  Physical Device Info {" << std::endl;
254   ss << "    apiVersion:    " << v_major << "." << v_minor << std::endl;
255   ss << "    driverversion: " << properties.driverVersion << std::endl;
256   ss << "    deviceType:    " << device_type << std::endl;
257   ss << "    deviceName:    " << properties.deviceName << std::endl;
258 
259 #define PRINT_PROP(struct, name)                                       \
260   ss << "      " << std::left << std::setw(36) << #name << struct.name \
261      << std::endl;
262 
263 #define PRINT_PROP_VEC3(struct, name)                                     \
264   ss << "      " << std::left << std::setw(36) << #name << struct.name[0] \
265      << "," << struct.name[1] << "," << struct.name[2] << std::endl;
266 
267   ss << "    Physical Device Limits {" << std::endl;
268   PRINT_PROP(limits, maxImageDimension1D);
269   PRINT_PROP(limits, maxImageDimension2D);
270   PRINT_PROP(limits, maxImageDimension3D);
271   PRINT_PROP(limits, maxTexelBufferElements);
272   PRINT_PROP(limits, maxPushConstantsSize);
273   PRINT_PROP(limits, maxMemoryAllocationCount);
274   PRINT_PROP(limits, maxSamplerAllocationCount);
275   PRINT_PROP(limits, maxComputeSharedMemorySize);
276   PRINT_PROP_VEC3(limits, maxComputeWorkGroupCount);
277   PRINT_PROP(limits, maxComputeWorkGroupInvocations);
278   PRINT_PROP_VEC3(limits, maxComputeWorkGroupSize);
279   ss << "    }" << std::endl;
280 
281 #ifdef VK_KHR_16bit_storage
282   ss << "    16bit Storage Features {" << std::endl;
283   PRINT_PROP(physical_device_.shader_16bit_storage, storageBuffer16BitAccess);
284   PRINT_PROP(
285       physical_device_.shader_16bit_storage,
286       uniformAndStorageBuffer16BitAccess);
287   PRINT_PROP(physical_device_.shader_16bit_storage, storagePushConstant16);
288   PRINT_PROP(physical_device_.shader_16bit_storage, storageInputOutput16);
289   ss << "    }" << std::endl;
290 #endif /* VK_KHR_16bit_storage */
291 
292 #ifdef VK_KHR_8bit_storage
293   ss << "    8bit Storage Features {" << std::endl;
294   PRINT_PROP(physical_device_.shader_8bit_storage, storageBuffer8BitAccess);
295   PRINT_PROP(
296       physical_device_.shader_8bit_storage, uniformAndStorageBuffer8BitAccess);
297   PRINT_PROP(physical_device_.shader_8bit_storage, storagePushConstant8);
298   ss << "    }" << std::endl;
299 #endif /* VK_KHR_8bit_storage */
300 
301 #ifdef VK_KHR_shader_float16_int8
302   ss << "    Shader 16bit and 8bit Features {" << std::endl;
303   PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16);
304   PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8);
305   ss << "    }" << std::endl;
306 #endif /* VK_KHR_shader_float16_int8 */
307 
308   const VkPhysicalDeviceMemoryProperties& mem_props =
309       physical_device_.memory_properties;
310 
311   ss << "  }" << std::endl;
312   ss << "  Memory Info {" << std::endl;
313   ss << "    Memory Types [" << std::endl;
314   for (size_t i = 0; i < mem_props.memoryTypeCount; ++i) {
315     ss << "      " << " [Heap " << mem_props.memoryTypes[i].heapIndex << "] "
316        << get_memory_properties_str(mem_props.memoryTypes[i].propertyFlags)
317        << std::endl;
318   }
319   ss << "    ]" << std::endl;
320   ss << "    Memory Heaps [" << std::endl;
321   for (size_t i = 0; i < mem_props.memoryHeapCount; ++i) {
322     ss << "      " << mem_props.memoryHeaps[i].size << std::endl;
323   }
324   ss << "    ]" << std::endl;
325   ss << "  }" << std::endl;
326 
327   ss << "  Queue Families {" << std::endl;
328   for (const VkQueueFamilyProperties& queue_family_props :
329        physical_device_.queue_families) {
330     ss << "    (" << queue_family_props.queueCount << " Queues) "
331        << get_queue_family_properties_str(queue_family_props.queueFlags)
332        << std::endl;
333   }
334   ss << "  }" << std::endl;
335   ss << "  VkDevice: " << device_.handle << std::endl;
336   ss << "  Compute Queues [" << std::endl;
337   for (const Adapter::Queue& compute_queue : queues_) {
338     ss << "    Family " << compute_queue.family_index << ", Queue "
339        << compute_queue.queue_index << ": " << compute_queue.handle
340        << std::endl;
341     ;
342   }
343   ss << "  ]" << std::endl;
344   ss << "}";
345 
346 #undef PRINT_PROP
347 #undef PRINT_PROP_VEC3
348 
349   return ss.str();
350 }
351 
operator <<(std::ostream & os,const Adapter & adapter)352 std::ostream& operator<<(std::ostream& os, const Adapter& adapter) {
353   os << adapter.stringize() << std::endl;
354   return os;
355 }
356 
357 } // namespace vkapi
358 } // namespace vkcompute
359