xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/Adapter.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
4 
5 #ifdef USE_VULKAN_API
6 
#include <ATen/native/vulkan/api/vk_api.h>

#include <ATen/native/vulkan/api/Pipeline.h>
#include <ATen/native/vulkan/api/Shader.h>
#include <ATen/native/vulkan/api/Utils.h>

#include <array>
#include <cstdint>
#include <mutex>
#include <ostream>
#include <string>
#include <vector>
16 
17 namespace at {
18 namespace native {
19 namespace vulkan {
20 namespace api {
21 
22 struct PhysicalDevice final {
23   // Handle
24   VkPhysicalDevice handle;
25 
26   // Properties obtained from Vulkan
27   VkPhysicalDeviceProperties properties;
28   VkPhysicalDeviceMemoryProperties memory_properties;
29   std::vector<VkQueueFamilyProperties> queue_families;
30 
31   // Metadata
32   uint32_t num_compute_queues;
33   bool has_unified_memory;
34   bool has_timestamps;
35   float timestamp_period;
36 
37   explicit PhysicalDevice(VkPhysicalDevice);
38 };
39 
40 class DeviceHandle final {
41  public:
42   explicit DeviceHandle(VkDevice device);
43 
44   DeviceHandle(const DeviceHandle&) = delete;
45   DeviceHandle& operator=(const DeviceHandle&) = delete;
46 
47   DeviceHandle(DeviceHandle&&) noexcept;
48   DeviceHandle& operator=(DeviceHandle&&) = delete;
49 
50   ~DeviceHandle();
51 
52  private:
53   VkDevice handle_;
54 
55   friend class Adapter;
56 };
57 
58 //
59 // A Vulkan Adapter represents a logical device and all its properties. It
60 // manages all relevant properties of the underlying physical device, a
61 // handle to the logical device, and a number of compute queues available to
62 // the device. It is primarily responsible for managing the VkDevice handle
63 // which points to the logical device object on the GPU.
64 //
65 // This class is primarily used by the Runtime class, which holds one Adapter
66 // instance for each physical device visible to the VkInstance. Upon
67 // construction, this class will populate the physical device properties, but
68 // will not create the logical device until specifically requested via the
69 // init_device() function.
70 //
71 // init_device() will create the logical device and obtain the VkDevice handle
72 // for it. It will also create a number of compute queues up to the amount
73 // requested when the Adapter instance was constructed.
74 //
75 // Contexts (which represent one thread of execution) will request a compute
76 // queue from an Adapter. The Adapter will then select a compute queue to
77 // assign to the Context, attempting to balance load between all available
78 // queues. This will allow different Contexts (which typically execute on
79 // separate threads) to run concurrently.
80 //
81 
// Number of elements in the Adapter::queue_mutexes_ array.
#define NUM_QUEUE_MUTEXES 4
83 
84 class Adapter final {
85  public:
86   explicit Adapter(
87       VkInstance instance,
88       PhysicalDevice physical_device,
89       const uint32_t num_queues);
90 
91   Adapter(const Adapter&) = delete;
92   Adapter& operator=(const Adapter&) = delete;
93 
94   Adapter(Adapter&&) = delete;
95   Adapter& operator=(Adapter&&) = delete;
96 
97   ~Adapter() = default;
98 
99   struct Queue {
100     uint32_t family_index;
101     uint32_t queue_index;
102     VkQueueFlags capabilities;
103     VkQueue handle;
104   };
105 
106  private:
107   // Use a mutex to manage queue usage info since
108   // it can be accessed from multiple threads
109   std::mutex queue_usage_mutex_;
110   // Physical Device Info
111   PhysicalDevice physical_device_;
112   // Queue Management
113   std::vector<Queue> queues_;
114   std::vector<uint32_t> queue_usage_;
115   std::array<std::mutex, NUM_QUEUE_MUTEXES> queue_mutexes_;
116   // Handles
117   VkInstance instance_;
118   DeviceHandle device_;
119   // Device-level resource caches
120   ShaderLayoutCache shader_layout_cache_;
121   ShaderCache shader_cache_;
122   PipelineLayoutCache pipeline_layout_cache_;
123   ComputePipelineCache compute_pipeline_cache_;
124   // Memory Management
125   SamplerCache sampler_cache_;
126   MemoryAllocator vma_;
127 
128  public:
129   // Physical Device metadata
130 
physical_handle()131   inline VkPhysicalDevice physical_handle() const {
132     return physical_device_.handle;
133   }
134 
device_handle()135   inline VkDevice device_handle() const {
136     return device_.handle_;
137   }
138 
has_unified_memory()139   inline bool has_unified_memory() const {
140     return physical_device_.has_unified_memory;
141   }
142 
num_compute_queues()143   inline uint32_t num_compute_queues() const {
144     return physical_device_.num_compute_queues;
145   }
146 
timestamp_compute_and_graphics()147   inline bool timestamp_compute_and_graphics() const {
148     return physical_device_.has_timestamps;
149   }
150 
timestamp_period()151   inline float timestamp_period() const {
152     return physical_device_.timestamp_period;
153   }
154 
155   // Queue Management
156 
157   Queue request_queue();
158   void return_queue(Queue&);
159 
160   // Caches
161 
shader_layout_cache()162   inline ShaderLayoutCache& shader_layout_cache() {
163     return shader_layout_cache_;
164   }
165 
shader_cache()166   inline ShaderCache& shader_cache() {
167     return shader_cache_;
168   }
169 
pipeline_layout_cache()170   inline PipelineLayoutCache& pipeline_layout_cache() {
171     return pipeline_layout_cache_;
172   }
173 
compute_pipeline_cache()174   inline ComputePipelineCache& compute_pipeline_cache() {
175     return compute_pipeline_cache_;
176   }
177 
178   // Memory Allocation
179 
sampler_cache()180   inline SamplerCache& sampler_cache() {
181     return sampler_cache_;
182   }
183 
vma()184   inline MemoryAllocator& vma() {
185     return vma_;
186   }
187 
188   // Command Buffer Submission
189 
190   void submit_cmd(
191       const Queue&,
192       VkCommandBuffer,
193       VkFence fence = VK_NULL_HANDLE);
194 
195   void submit_cmds(
196       const Adapter::Queue&,
197       const std::vector<VkCommandBuffer>&,
198       VkFence fence = VK_NULL_HANDLE);
199 
200   // Miscellaneous
201 
local_work_group_size()202   inline utils::uvec3 local_work_group_size() const {
203     return {
204         4u,
205         4u,
206         4u,
207     };
208   }
209 
210   std::string stringize() const;
211   friend std::ostream& operator<<(std::ostream&, const Adapter&);
212 };
213 
214 } // namespace api
215 } // namespace vulkan
216 } // namespace native
217 } // namespace at
218 
219 #endif /* USE_VULKAN_API */
220