#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/vk_api.h>

#include <ATen/native/vulkan/api/Adapter.h>
#include <ATen/native/vulkan/api/Command.h>
#include <ATen/native/vulkan/api/Descriptor.h>
#include <ATen/native/vulkan/api/Pipeline.h>
#include <ATen/native/vulkan/api/QueryPool.h>
#include <ATen/native/vulkan/api/Resource.h>
#include <ATen/native/vulkan/api/Runtime.h>
#include <ATen/native/vulkan/api/Shader.h>
#include <ATen/native/vulkan/api/Utils.h>

namespace at {
namespace native {
namespace vulkan {
namespace api {

struct ContextConfig final {
  uint32_t cmdSubmitFrequency;
  CommandPoolConfig cmdPoolConfig;
  DescriptorPoolConfig descriptorPoolConfig;
  QueryPoolConfig queryPoolConfig;
};
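
//
// Illustrative-only sketch of constructing a Context from a ContextConfig. The
// values and the nested pool configs below are hypothetical placeholders; real
// values are supplied by the runtime setup code.
//
//   const ContextConfig config{
//       8u, // cmdSubmitFrequency
//       CommandPoolConfig{/* ... */},
//       DescriptorPoolConfig{/* ... */},
//       QueryPoolConfig{/* ... */},
//   };
//   Context vulkan_context(0u, config); // use the adapter at index 0
//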

//
// A Context holds all of the Vulkan state relevant to PyTorch's use of the
// Vulkan API. A Context is associated with one, and only one, Adapter as a
// precursor to multi-GPU support. All Vulkan tensors in PyTorch are associated
// with a Context to make tensor <-> device affinity explicit. The Context is
// currently a global object, but technically it would not need to be if it
// were exposed explicitly to the user.
//

class Context final {
 public:
  explicit Context(size_t adapter_i, const ContextConfig&);

  Context(const Context&) = delete;
  Context& operator=(const Context&) = delete;

  Context(Context&&) = delete;
  Context& operator=(Context&&) = delete;

  ~Context();

 private:
  // Config
  ContextConfig config_;
  // Important handles
  Adapter* adapter_p_;
  VkDevice device_;
  Adapter::Queue queue_;
  // Resource Pools
  CommandPool command_pool_;
  DescriptorPool descriptor_pool_;
  FencePool fences_;
  // Diagnostics
  // TODO: remove USE_VULKAN_GPU_DIAGNOSTICS
  bool enable_op_profiling_{false};
#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  QueryPool querypool_;
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */
  // Command buffer submission
  std::mutex cmd_mutex_;
  CommandBuffer cmd_;
  uint32_t submit_count_;
  // Memory Management
  std::mutex buffer_clearlist_mutex_;
  std::vector<VulkanBuffer> buffers_to_clear_;
  std::mutex image_clearlist_mutex_;
  std::vector<VulkanImage> images_to_clear_;

 public:
  // Adapter access

  inline Adapter* adapter_ptr() {
    return adapter_p_;
  }

  inline void enable_op_profiling() {
    enable_op_profiling_ = true;
  }

  inline void disable_op_profiling() {
    enable_op_profiling_ = false;
  }

  inline bool op_profiling_enabled() {
    return enable_op_profiling_;
  }

  inline VkDevice device() {
    return device_;
  }

  inline VkQueue queue() {
    return queue_.handle;
  }

  // Device Caches

  inline ShaderLayoutCache& shader_layout_cache() {
    return adapter_ptr()->shader_layout_cache();
  }

  inline ShaderCache& shader_cache() {
    return adapter_ptr()->shader_cache();
  }

  inline PipelineLayoutCache& pipeline_layout_cache() {
    return adapter_ptr()->pipeline_layout_cache();
  }

  inline ComputePipelineCache& pipeline_cache() {
    return adapter_ptr()->compute_pipeline_cache();
  }

  // Resource Pools

  inline DescriptorPool& descriptor_pool() {
    return descriptor_pool_;
  }

  inline FencePool& fences() {
    return fences_;
  }

  // Diagnostics

#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  inline QueryPool& querypool() {
    return querypool_;
  }

  inline void reset_querypool() {
    set_cmd();
    querypool_.reset(cmd_);
  }
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

  // Memory Management
  void register_buffer_cleanup(VulkanBuffer& buffer) {
    std::lock_guard<std::mutex> bufferlist_lock(buffer_clearlist_mutex_);
    buffers_to_clear_.emplace_back(std::move(buffer));
  }

  void register_image_cleanup(VulkanImage& image) {
    std::lock_guard<std::mutex> imagelist_lock(image_clearlist_mutex_);
    images_to_clear_.emplace_back(std::move(image));
  }

  // GPU RPC

  inline std::unique_lock<std::mutex> dispatch_lock() {
    return std::unique_lock<std::mutex>(cmd_mutex_);
  }

  inline void set_cmd(bool reusable = false) {
    if (!cmd_) {
      cmd_ = command_pool_.get_new_cmd(reusable);
      cmd_.begin();
    }
  }

  DescriptorSet get_descriptor_set(const ShaderInfo&, const utils::uvec3&);

  void register_shader_dispatch(
      const DescriptorSet&,
      PipelineBarrier&,
      const ShaderInfo&,
      const utils::uvec3&);

  template <class S, class D>
  bool submit_copy(
      PipelineBarrier&,
      const S&,
      const D&,
      const api::utils::uvec3&,
      const api::utils::uvec3&,
      const api::utils::uvec3&,
      VkFence fence_handle);

  template <typename... Arguments>
  bool submit_compute_job(
      const ShaderInfo&,
      PipelineBarrier&,
      const utils::uvec3&,
      const utils::uvec3&,
      VkFence fence_handle,
      Arguments&&...);

  void submit_cmd_to_gpu(
      VkFence fence_handle = VK_NULL_HANDLE,
      const bool final_use = false);

  void flush();
};
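
//
// Usage sketch (illustrative only). The shader, images, and params below are
// hypothetical placeholders; only the Context member functions shown are
// declared in this header.
//
//   api::Context* const ctx = api::context();
//   api::PipelineBarrier pipeline_barrier{};
//   ctx->submit_compute_job(
//       compute_shader_info, // some ShaderInfo describing a compute kernel
//       pipeline_barrier,
//       {64u, 64u, 1u}, // global work group
//       {8u, 8u, 1u}, // local work group size
//       VK_NULL_HANDLE, // no fence: the dispatch is batched
//       input_image, // VulkanImage/VulkanBuffer arguments are bound to
//       output_image, // descriptor slots in argument order
//       params.buffer()); // e.g. a UniformParamsBuffer's backing buffer
//
// Without a fence, the dispatch is only recorded; the command buffer is
// submitted once cmdSubmitFrequency dispatches have accumulated, or when a
// later call passes a fence.
//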

class UniformParamsBuffer final {
 private:
  Context* context_p_;
  size_t nbytes_;
  VulkanBuffer vulkan_buffer_;

 public:
  UniformParamsBuffer()
      : context_p_{nullptr}, nbytes_{0u}, vulkan_buffer_{} {}

  template <typename Block>
  UniformParamsBuffer(Context* context_p, const Block& block)
      : context_p_(context_p),
        nbytes_(sizeof(block)),
        vulkan_buffer_(
            context_p_->adapter_ptr()->vma().create_params_buffer(block)) {}

  UniformParamsBuffer(const UniformParamsBuffer&);
  UniformParamsBuffer& operator=(const UniformParamsBuffer&);

  UniformParamsBuffer(UniformParamsBuffer&&) = default;
  UniformParamsBuffer& operator=(UniformParamsBuffer&&) = default;

  ~UniformParamsBuffer() {
    if (vulkan_buffer_) {
      context_p_->register_buffer_cleanup(vulkan_buffer_);
    }
  }

  VulkanBuffer& buffer() {
    return vulkan_buffer_;
  }

  template <typename Block>
  void update(const Block& block) {
    if (sizeof(block) != nbytes_) {
      VK_THROW(
          "Attempted to update UniformParamsBuffer with data of different size");
    }
    // Fill the uniform buffer with data in block
    {
      MemoryMap mapping(vulkan_buffer_, MemoryAccessType::WRITE);
      Block* data_ptr = mapping.template data<Block>();

      *data_ptr = block;
    }
  }
};
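
//
// Usage sketch (illustrative only; Block is a hypothetical packed parameter
// struct mirroring a shader's uniform block):
//
//   struct Block final {
//     api::utils::uvec3 extents;
//     float alpha;
//   } block{extents, 1.0f};
//
//   api::UniformParamsBuffer params(api::context(), block);
//   // ...later, reuse the same allocation with updated values:
//   block.alpha = 0.5f;
//   params.update(block);
//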

class StorageBuffer final {
 private:
  Context* context_p_;
  ScalarType dtype_;
  size_t numel_;
  size_t nbytes_;
  VulkanBuffer vulkan_buffer_;

 public:
  StorageBuffer(
      Context* context_p,
      const ScalarType dtype,
      const size_t numel,
      const bool gpuonly = false)
      : context_p_(context_p),
        dtype_(dtype),
        numel_(numel),
        nbytes_(element_size(dtype_) * numel_),
        vulkan_buffer_(context_p_->adapter_ptr()->vma().create_storage_buffer(
            nbytes_,
            gpuonly)) {}

  StorageBuffer(const StorageBuffer&) = delete;
  StorageBuffer& operator=(const StorageBuffer&) = delete;

  StorageBuffer(StorageBuffer&&) = default;
  StorageBuffer& operator=(StorageBuffer&&) = default;

  ~StorageBuffer() {
    context_p_->register_buffer_cleanup(vulkan_buffer_);
  }

  inline ScalarType dtype() {
    return dtype_;
  }

  inline VulkanBuffer& buffer() {
    return vulkan_buffer_;
  }

  inline size_t numel() {
    return numel_;
  }

  inline size_t nbytes() {
    return nbytes_;
  }
};
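
//
// Usage sketch (illustrative only; assumes api::kFloat names the float
// ScalarType):
//
//   // Allocates a 1024-element float buffer. Its backing memory is released
//   // through the Context's cleanup list when the StorageBuffer is destroyed
//   // (see register_buffer_cleanup above).
//   api::StorageBuffer staging(api::context(), api::kFloat, 1024);
//   api::VulkanBuffer& buffer = staging.buffer();
//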

bool available();

// Returns the global Context, which is declared as a static local variable
// inside this function.
Context* context();
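
//
// Typical access pattern (illustrative only):
//
//   if (api::available()) {
//     api::Context* const ctx = api::context();
//     // ...record work against ctx...
//   }
//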

namespace detail {

inline void arg_is_empty(bool& any_is_empty, const VulkanBuffer& buffer) {
  // bool(buffer) will evaluate to false if no memory has been allocated
  any_is_empty = any_is_empty || !buffer;
}

inline void arg_is_empty(bool& any_is_empty, const VulkanImage& image) {
  // bool(image) will evaluate to false if no memory has been allocated
  any_is_empty = any_is_empty || !image;
}

/*
  Reports if any VulkanBuffer or VulkanImage argument in a variadic argument
  list does not have any memory associated with it.
 */
template <typename... Arguments>
inline bool any_arg_is_empty(Arguments&&... arguments) {
  bool any_is_empty = false;
  VK_UNUSED const int _[]{
      0,
      (arg_is_empty(any_is_empty, std::forward<Arguments>(arguments)), 0)...,
  };

  return any_is_empty;
}

template <size_t... Indices, typename... Arguments>
inline void bind(
    DescriptorSet& descriptor_set,
    const std::index_sequence<Indices...>&,
    Arguments&&... arguments) {
  VK_UNUSED const int _[]{
      0,
      (descriptor_set.bind(Indices, std::forward<Arguments>(arguments)), 0)...,
  };
}
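
//
// For example (illustrative only), bind(set, std::index_sequence_for<A, B>{},
// a, b) expands to the equivalent of:
//
//   set.bind(0, a);
//   set.bind(1, b);
//
// i.e. each forwarded argument is bound to the descriptor slot matching its
// position in the argument list.
//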

} // namespace detail

template <class S, class D>
inline void record_copy(
    CommandBuffer& cmd,
    const S& source,
    const D& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset) = delete;

template <>
inline void record_copy<VulkanBuffer, VulkanBuffer>(
    CommandBuffer& cmd,
    const VulkanBuffer& source,
    const VulkanBuffer& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset) {
  cmd.copy_buffer_to_buffer(
      source, destination, copy_range, src_offset, dst_offset);
}

template <>
inline void record_copy<VulkanImage, VulkanImage>(
    CommandBuffer& cmd,
    const VulkanImage& source,
    const VulkanImage& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset) {
  cmd.copy_texture_to_texture(
      source, destination, copy_range, src_offset, dst_offset);
}

template <>
inline void record_copy<VulkanImage, VulkanBuffer>(
    CommandBuffer& cmd,
    const VulkanImage& source,
    const VulkanBuffer& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset) {
  cmd.copy_texture_to_buffer(
      source, destination, copy_range, src_offset, dst_offset);
}

template <>
inline void record_copy<VulkanBuffer, VulkanImage>(
    CommandBuffer& cmd,
    const VulkanBuffer& source,
    const VulkanImage& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset) {
  cmd.copy_buffer_to_texture(
      source, destination, copy_range, src_offset, dst_offset);
}

/*
  Records a GPU data copy into the current command buffer. If the number of
  submit_*_job calls meets or exceeds the configured frequency, or if a fence
  is provided, then the command buffer is submitted to the GPU for execution.
  Returns a bool indicating whether or not the function call resulted in a GPU
  queue submission.
 */
template <class S, class D>
inline bool Context::submit_copy(
    PipelineBarrier& pipeline_barrier,
    const S& source,
    const D& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset,
    VkFence fence_handle) {
  // If any of the provided arguments does not have memory associated with it,
  // then exit early, as there is no work to be done. However, if a fence has
  // been passed and the command buffer is not empty, then the current command
  // buffer must still be submitted so that the fence can be signaled.
  if (!source || !destination) {
    if (fence_handle != VK_NULL_HANDLE && submit_count_ > 0) {
      submit_cmd_to_gpu(fence_handle);
      return true;
    }
    return false;
  }

  // Serialize recording to the shared command buffer. Do not initialize the
  // lock with a mutex just yet, since in some cases the mutex is managed
  // externally.
  std::unique_lock<std::mutex> cmd_lock;
  // Refer to the comments in submit_compute_job for an explanation.
  if (fence_handle == VK_NULL_HANDLE) {
    cmd_lock = std::unique_lock<std::mutex>(cmd_mutex_);
  }

  set_cmd();

#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  uint32_t log_idx = UINT32_MAX;
  if (enable_op_profiling_) {
    std::string label = "cmd_copy";
    log_idx = querypool_.shader_profile_begin(
        cmd_, label, create_extent3d({0, 0, 0}), create_extent3d({0, 0, 0}));
  }
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

  cmd_.insert_barrier(pipeline_barrier);

  record_copy(cmd_, source, destination, copy_range, src_offset, dst_offset);

#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  if (enable_op_profiling_) {
    querypool_.shader_profile_end(cmd_, log_idx);
  }
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

  submit_count_++;
  if (fence_handle != VK_NULL_HANDLE ||
      submit_count_ >= config_.cmdSubmitFrequency) {
    submit_cmd_to_gpu(fence_handle);
    return true;
  }
  return false;
}
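
//
// Usage sketch (illustrative only; src and dst are hypothetical VulkanImage
// handles that already have memory bound, and the extents are placeholders):
//
//   api::PipelineBarrier pipeline_barrier{};
//   const api::utils::uvec3 copy_range{width, height, depth};
//   const api::utils::uvec3 zero_offset{0u, 0u, 0u};
//   const bool submitted = api::context()->submit_copy(
//       pipeline_barrier, src, dst, copy_range, zero_offset, zero_offset,
//       VK_NULL_HANDLE);
//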

/*
  Records a compute shader dispatch into the current command buffer. If the
  number of submit_*_job calls meets or exceeds the configured frequency, or if
  a fence is provided, then the command buffer is submitted to the GPU for
  execution. Returns a bool indicating whether or not the function call
  resulted in a GPU queue submission.
 */
template <typename... Arguments>
inline bool Context::submit_compute_job(
    const ShaderInfo& shader,
    PipelineBarrier& pipeline_barrier,
    const utils::uvec3& global_work_group,
    const utils::uvec3& local_work_group_size,
    VkFence fence_handle,
    Arguments&&... arguments) {
  // If any of the provided arguments does not have memory associated with it,
  // then exit early, as there is no work to be done. However, if a fence has
  // been passed and the command buffer is not empty, then the current command
  // buffer must still be submitted so that the fence can be signaled.
  if (detail::any_arg_is_empty(arguments...)) {
    if (fence_handle != VK_NULL_HANDLE && submit_count_ > 0) {
      submit_cmd_to_gpu(fence_handle);
      return true;
    }
    return false;
  }

  // Serialize recording to the shared command buffer. Do not initialize the
  // lock with a mutex just yet, since in some cases the mutex is managed
  // externally.
  std::unique_lock<std::mutex> cmd_lock;
  // If a fence was passed, then assume that the host intends to sync with
  // the GPU, implying there will be imminent calls to fence.wait() and flush().
  // We therefore assume the mutex is externally managed in this case: the
  // calling thread has already locked the mutex prior to calling this function
  // and will release it manually after calling flush(). This prevents more
  // dispatches from being recorded until the Context has been flushed.
  if (fence_handle == VK_NULL_HANDLE) {
    cmd_lock = std::unique_lock<std::mutex>(cmd_mutex_);
  }

  set_cmd();

#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  uint32_t log_idx = UINT32_MAX;
  if (enable_op_profiling_) {
    log_idx = querypool_.shader_profile_begin(
        cmd_,
        shader.kernel_name,
        create_extent3d(global_work_group),
        create_extent3d(local_work_group_size));
  }
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

  // Factor out template parameter independent code to minimize code bloat.
  DescriptorSet descriptor_set =
      get_descriptor_set(shader, local_work_group_size);

  detail::bind(
      descriptor_set,
      std::index_sequence_for<Arguments...>{},
      std::forward<Arguments>(arguments)...);

  // Factor out template parameter independent code to minimize code bloat.
  register_shader_dispatch(
      descriptor_set, pipeline_barrier, shader, global_work_group);

#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  if (enable_op_profiling_) {
    querypool_.shader_profile_end(cmd_, log_idx);
  }
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

  submit_count_++;
  if (fence_handle != VK_NULL_HANDLE ||
      submit_count_ >= config_.cmdSubmitFrequency) {
    submit_cmd_to_gpu(fence_handle);
    return true;
  }

  return false;
}
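
//
// Host-synchronization sketch (illustrative only). It assumes, as elsewhere in
// this API layer, that FencePool exposes get_fence()/return_fence() and that
// VulkanFence exposes get_submit_handle()/wait(); shader, pipeline_barrier,
// the work group sizes, and arguments are hypothetical placeholders.
//
//   api::Context* const ctx = api::context();
//   api::VulkanFence fence = ctx->fences().get_fence();
//   {
//     // A fence is passed, so lock recording manually, as described above.
//     std::unique_lock<std::mutex> cmd_lock = ctx->dispatch_lock();
//     ctx->submit_compute_job(
//         shader,
//         pipeline_barrier,
//         global_work_group,
//         local_work_group_size,
//         fence.get_submit_handle(),
//         arguments...);
//     fence.wait();
//     ctx->flush();
//   } // the lock is released only after flush()
//   ctx->fences().return_fence(fence);
//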

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */