#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/vk_api.h>

#include <ATen/native/vulkan/api/Pipeline.h>
#include <ATen/native/vulkan/api/Shader.h>
#include <ATen/native/vulkan/api/Utils.h>

#include <array>
#include <cstdint>
#include <mutex>
#include <ostream>
#include <string>
#include <vector>

namespace at {
namespace native {
namespace vulkan {
namespace api {

//
// Snapshot of a VkPhysicalDevice: the raw handle, the properties queried
// from Vulkan, and metadata derived from them. The constructor that fills in
// these fields is defined out of line.
//
struct PhysicalDevice final {
  // Handle
  VkPhysicalDevice handle{};

  // Properties obtained from Vulkan
  VkPhysicalDeviceProperties properties{};
  VkPhysicalDeviceMemoryProperties memory_properties{};
  std::vector<VkQueueFamilyProperties> queue_families;

  // Metadata
  // The zero/false defaults only guarantee a defined value for any field the
  // out-of-line constructor does not set; whatever the constructor assigns
  // takes precedence.
  uint32_t num_compute_queues{0u};
  bool has_unified_memory{false};
  bool has_timestamps{false};
  float timestamp_period{0.0f};

  explicit PhysicalDevice(VkPhysicalDevice);
};

//
// Move-only wrapper around a VkDevice handle: copy construction/assignment
// and move assignment are deleted, only move construction is allowed. The
// destructor and move constructor are defined out of line (presumably the
// destructor releases the device and the move constructor leaves the source
// empty -- confirm in the implementation file). Adapter is a friend so it can
// read the raw handle.
//
class DeviceHandle final {
 public:
  explicit DeviceHandle(VkDevice device);

  DeviceHandle(const DeviceHandle&) = delete;
  DeviceHandle& operator=(const DeviceHandle&) = delete;

  DeviceHandle(DeviceHandle&&) noexcept;
  DeviceHandle& operator=(DeviceHandle&&) = delete;

  ~DeviceHandle();

 private:
  VkDevice handle_;

  friend class Adapter;
};

//
// A Vulkan Adapter represents a logical device and all its properties. It
// manages all relevant properties of the underlying physical device, a
// handle to the logical device, and a number of compute queues available to
// the device. It is primarily responsible for managing the VkDevice handle
// which points to the logical device object on the GPU.
//
// This class is primarily used by the Runtime class, which holds one Adapter
// instance for each physical device visible to the VkInstance.
//
// There is no separate initialization step. The logical device handle lives
// in `device_`, a DeviceHandle (which has no default constructor), so it is
// established while the Adapter is being constructed. The `num_queues`
// constructor argument is the number of compute queues requested; up to that
// many compute queues are made available through request_queue().
//
// Contexts (which represent one thread of execution) request a compute queue
// from an Adapter via request_queue() and hand it back via return_queue().
// The Adapter selects a compute queue to assign to the Context, attempting to
// balance load between all available queues. This allows different Contexts
// (which typically execute on separate threads) to run concurrently.
//

// Number of mutexes held in Adapter::queue_mutexes_.
#define NUM_QUEUE_MUTEXES 4

class Adapter final {
 public:
  explicit Adapter(
      VkInstance instance,
      PhysicalDevice physical_device,
      const uint32_t num_queues);

  // Adapters are neither copyable nor movable.
  Adapter(const Adapter&) = delete;
  Adapter& operator=(const Adapter&) = delete;

  Adapter(Adapter&&) = delete;
  Adapter& operator=(Adapter&&) = delete;

  ~Adapter() = default;

  // A single device queue handed out by request_queue().
  struct Queue {
    uint32_t family_index;
    uint32_t queue_index;
    VkQueueFlags capabilities;
    VkQueue handle;
  };

 private:
  // NOTE: declaration order matters. C++ constructs members top-to-bottom and
  // destroys them bottom-to-top, so everything declared after `device_` (the
  // caches and the allocator, presumably all device-dependent) is destroyed
  // before the device handle. Do not reorder these members.

  // Use a mutex to manage queue usage info since
  // it can be accessed from multiple threads
  std::mutex queue_usage_mutex_;
  // Physical Device Info
  PhysicalDevice physical_device_;
  // Queue Management
  std::vector<Queue> queues_;
  std::vector<uint32_t> queue_usage_;
  std::array<std::mutex, NUM_QUEUE_MUTEXES> queue_mutexes_;
  // Handles
  VkInstance instance_;
  DeviceHandle device_;
  // Device-level resource caches
  ShaderLayoutCache shader_layout_cache_;
  ShaderCache shader_cache_;
  PipelineLayoutCache pipeline_layout_cache_;
  ComputePipelineCache compute_pipeline_cache_;
  // Memory Management
  SamplerCache sampler_cache_;
  MemoryAllocator vma_;

 public:
  // Physical Device metadata

  inline VkPhysicalDevice physical_handle() const {
    return physical_device_.handle;
  }

  inline VkDevice device_handle() const {
    return device_.handle_;
  }

  inline bool has_unified_memory() const {
    return physical_device_.has_unified_memory;
  }

  inline uint32_t num_compute_queues() const {
    return physical_device_.num_compute_queues;
  }

  // Reports PhysicalDevice::has_timestamps.
  inline bool timestamp_compute_and_graphics() const {
    return physical_device_.has_timestamps;
  }

  inline float timestamp_period() const {
    return physical_device_.timestamp_period;
  }

  // Queue Management

  // Selects one of the available compute queues for the caller.
  Queue request_queue();
  // Hands back a queue previously obtained from request_queue().
  void return_queue(Queue&);

  // Caches

  inline ShaderLayoutCache& shader_layout_cache() {
    return shader_layout_cache_;
  }

  inline ShaderCache& shader_cache() {
    return shader_cache_;
  }

  inline PipelineLayoutCache& pipeline_layout_cache() {
    return pipeline_layout_cache_;
  }

  inline ComputePipelineCache& compute_pipeline_cache() {
    return compute_pipeline_cache_;
  }

  // Memory Allocation

  inline SamplerCache& sampler_cache() {
    return sampler_cache_;
  }

  inline MemoryAllocator& vma() {
    return vma_;
  }

  // Command Buffer Submission
  // `fence` defaults to VK_NULL_HANDLE (no fence); presumably a non-null
  // fence is passed through to the queue submission -- confirm in the .cpp.

  void submit_cmd(
      const Queue&,
      VkCommandBuffer,
      VkFence fence = VK_NULL_HANDLE);

  void submit_cmds(
      const Queue&,
      const std::vector<VkCommandBuffer>&,
      VkFence fence = VK_NULL_HANDLE);

  // Miscellaneous

  // Fixed local work group size of {4, 4, 4}.
  inline utils::uvec3 local_work_group_size() const {
    return {
        4u,
        4u,
        4u,
    };
  }

  // String description of this Adapter (defined out of line); operator<<
  // is declared alongside it for streaming.
  std::string stringize() const;
  friend std::ostream& operator<<(std::ostream&, const Adapter&);
};

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */