#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/vk_api.h>

#include <ATen/native/vulkan/api/Adapter.h>
#include <ATen/native/vulkan/api/Command.h>
#include <ATen/native/vulkan/api/Descriptor.h>
#include <ATen/native/vulkan/api/Pipeline.h>
#include <ATen/native/vulkan/api/QueryPool.h>
#include <ATen/native/vulkan/api/Resource.h>
#include <ATen/native/vulkan/api/Runtime.h>
#include <ATen/native/vulkan/api/Shader.h>
#include <ATen/native/vulkan/api/Utils.h>

namespace at {
namespace native {
namespace vulkan {
namespace api {

struct ContextConfig final {
  uint32_t cmdSubmitFrequency;
  CommandPoolConfig cmdPoolConfig;
  DescriptorPoolConfig descriptorPoolConfig;
  QueryPoolConfig queryPoolConfig;
};

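// Example (illustrative sketch): the nested *Config structs are defined in
// Command.h, Descriptor.h and QueryPool.h respectively, and the values below
// are placeholders rather than recommended defaults.
//
//   CommandPoolConfig cmd_config{/* ... */};
//   DescriptorPoolConfig descriptor_config{/* ... */};
//   QueryPoolConfig query_config{/* ... */};
//
//   const ContextConfig config{
//       16u, // cmdSubmitFrequency: submit to the GPU every 16 recorded jobs
//       cmd_config,
//       descriptor_config,
//       query_config,
//   };
//   Context vulkan_context(/*adapter_i=*/0u, config);
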
//
// Vulkan Context holds onto all relevant Vulkan state as it pertains to our
// use of Vulkan in PyTorch. A Context is associated with one, and only one,
// Adapter as a precursor to multi-GPU support. All Vulkan tensors in PyTorch
// are associated with a Context to make tensor <-> device affinity explicit.
// The context is currently a global object, but technically it does not need
// to be if we were to make it explicit to the user.
//

class Context final {
 public:
  explicit Context(size_t adapter_i, const ContextConfig&);

  Context(const Context&) = delete;
  Context& operator=(const Context&) = delete;

  Context(Context&&) = delete;
  Context& operator=(Context&&) = delete;

  ~Context();

 private:
  // Config
  ContextConfig config_;
  // Important handles
  Adapter* adapter_p_;
  VkDevice device_;
  Adapter::Queue queue_;
  // Resource Pools
  CommandPool command_pool_;
  DescriptorPool descriptor_pool_;
  FencePool fences_;
  // Diagnostics
  // TODO: remove USE_VULKAN_GPU_DIAGNOSTICS
  bool enable_op_profiling_{false};
#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  QueryPool querypool_;
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */
  // Command buffers submission
  std::mutex cmd_mutex_;
  CommandBuffer cmd_;
  uint32_t submit_count_;
  // Memory Management
  std::mutex buffer_clearlist_mutex_;
  std::vector<VulkanBuffer> buffers_to_clear_;
  std::mutex image_clearlist_mutex_;
  std::vector<VulkanImage> images_to_clear_;

 public:
  // Adapter access

  inline Adapter* adapter_ptr() {
    return adapter_p_;
  }

  inline void enable_op_profiling() {
    enable_op_profiling_ = true;
  }

  inline void disable_op_profiling() {
    enable_op_profiling_ = false;
  }

  inline bool op_profiling_enabled() {
    return enable_op_profiling_;
  }

  inline VkDevice device() {
    return device_;
  }

  inline VkQueue queue() {
    return queue_.handle;
  }

  // Device Caches

  inline ShaderLayoutCache& shader_layout_cache() {
    return adapter_ptr()->shader_layout_cache();
  }

  inline ShaderCache& shader_cache() {
    return adapter_ptr()->shader_cache();
  }

  inline PipelineLayoutCache& pipeline_layout_cache() {
    return adapter_ptr()->pipeline_layout_cache();
  }

  inline ComputePipelineCache& pipeline_cache() {
    return adapter_ptr()->compute_pipeline_cache();
  }

  // Resource Pools

  inline DescriptorPool& descriptor_pool() {
    return descriptor_pool_;
  }

  inline FencePool& fences() {
    return fences_;
  }

  // Diagnostics

#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  inline QueryPool& querypool() {
    return querypool_;
  }

  inline void reset_querypool() {
    set_cmd();
    querypool_.reset(cmd_);
  }
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

  // Memory Management
  void register_buffer_cleanup(VulkanBuffer& buffer) {
    std::lock_guard<std::mutex> bufferlist_lock(buffer_clearlist_mutex_);
    buffers_to_clear_.emplace_back(std::move(buffer));
  }

  void register_image_cleanup(VulkanImage& image) {
    std::lock_guard<std::mutex> imagelist_lock(image_clearlist_mutex_);
    images_to_clear_.emplace_back(std::move(image));
  }

  // GPU RPC

  inline std::unique_lock<std::mutex> dispatch_lock() {
    return std::unique_lock<std::mutex>(cmd_mutex_);
  }

  inline void set_cmd(bool reusable = false) {
    if (!cmd_) {
      cmd_ = command_pool_.get_new_cmd(reusable);
      cmd_.begin();
    }
  }

  DescriptorSet get_descriptor_set(const ShaderInfo&, const utils::uvec3&);

  void register_shader_dispatch(
      const DescriptorSet&,
      PipelineBarrier&,
      const ShaderInfo&,
      const utils::uvec3&);

  template <class S, class D>
  bool submit_copy(
      PipelineBarrier&,
      const S&,
      const D&,
      const api::utils::uvec3&,
      const api::utils::uvec3&,
      const api::utils::uvec3&,
      VkFence fence_handle);

  template <typename... Arguments>
  bool submit_compute_job(
      const ShaderInfo&,
      PipelineBarrier&,
      const utils::uvec3&,
      const utils::uvec3&,
      VkFence fence_handle,
      Arguments&&...);

  void submit_cmd_to_gpu(
      VkFence fence_handle = VK_NULL_HANDLE,
      const bool final_use = false);

  void flush();
};

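// Deferred destruction sketch (illustrative): GPU resources should not be
// destroyed while the GPU may still be reading from them, so ownership is
// moved into the Context via register_*_cleanup and released later (the
// assumption here is that flush() is what eventually clears these lists).
//
//   VulkanBuffer staging = /* ... */;
//   // ... record work that consumes `staging` ...
//   api::context()->register_buffer_cleanup(staging);
//   // `staging` has been moved from; the allocation is freed on a later flush
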
class UniformParamsBuffer final {
 private:
  Context* context_p_;
  size_t nbytes_;
  VulkanBuffer vulkan_buffer_;

 public:
  UniformParamsBuffer() : context_p_{nullptr}, vulkan_buffer_{} {}

  template <typename Block>
  UniformParamsBuffer(Context* context_p, const Block& block)
      : context_p_(context_p),
        nbytes_(sizeof(block)),
        vulkan_buffer_(
            context_p_->adapter_ptr()->vma().create_params_buffer(block)) {}

  UniformParamsBuffer(const UniformParamsBuffer&);
  UniformParamsBuffer& operator=(const UniformParamsBuffer&);

  UniformParamsBuffer(UniformParamsBuffer&&) = default;
  UniformParamsBuffer& operator=(UniformParamsBuffer&&) = default;

  ~UniformParamsBuffer() {
    if (vulkan_buffer_) {
      context_p_->register_buffer_cleanup(vulkan_buffer_);
    }
  }

  VulkanBuffer& buffer() {
    return vulkan_buffer_;
  }

  template <typename Block>
  void update(const Block& block) {
    if (sizeof(block) != nbytes_) {
      VK_THROW(
          "Attempted to update UniformParamsBuffer with data of different size");
    }
    // Fill the uniform buffer with data in block
    {
      MemoryMap mapping(vulkan_buffer_, MemoryAccessType::WRITE);
      Block* data_ptr = mapping.template data<Block>();

      *data_ptr = block;
    }
  }
};

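// Example (illustrative sketch; `ParamsBlock` is a hypothetical POD struct
// whose layout matches the shader's uniform block):
//
//   struct ParamsBlock {
//     api::utils::uvec3 extents;
//     uint32_t plane_size;
//   };
//
//   ParamsBlock block{/* ... */};
//   UniformParamsBuffer params(api::context(), block);
//   // Later, re-upload new values of the same type (and therefore same size):
//   block.plane_size = 128u;
//   params.update(block);
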
class StorageBuffer final {
 private:
  Context* context_p_;
  ScalarType dtype_;
  size_t numel_;
  size_t nbytes_;
  VulkanBuffer vulkan_buffer_;

 public:
  StorageBuffer(
      Context* context_p,
      const ScalarType dtype,
      const size_t numel,
      const bool gpuonly = false)
      : context_p_(context_p),
        dtype_(dtype),
        numel_(numel),
        nbytes_(element_size(dtype_) * numel_),
        vulkan_buffer_(context_p_->adapter_ptr()->vma().create_storage_buffer(
            nbytes_,
            gpuonly)) {}

  StorageBuffer(const StorageBuffer&) = delete;
  StorageBuffer& operator=(const StorageBuffer&) = delete;

  StorageBuffer(StorageBuffer&&) = default;
  StorageBuffer& operator=(StorageBuffer&&) = default;

  ~StorageBuffer() {
    context_p_->register_buffer_cleanup(vulkan_buffer_);
  }

  inline ScalarType dtype() {
    return dtype_;
  }

  inline VulkanBuffer& buffer() {
    return vulkan_buffer_;
  }

  inline size_t numel() {
    return numel_;
  }

  inline size_t nbytes() {
    return nbytes_;
  }
};

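// Example (illustrative sketch; assumes a ScalarType enumerator kFloat is
// available from the api Types header): a host-visible staging buffer and a
// GPU-only scratch buffer holding 1024 elements each.
//
//   StorageBuffer staging(api::context(), api::kFloat, 1024);
//   StorageBuffer scratch(api::context(), api::kFloat, 1024, /*gpuonly=*/true);
//   staging.nbytes(); // 1024 * element_size(api::kFloat)
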
bool available();

// The global context is retrieved using this function, where it is declared
// as a static local variable.
Context* context();

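// Example (illustrative): check for Vulkan support before touching the global
// context.
//
//   if (api::available()) {
//     api::Context* const ctx = api::context();
//     ctx->enable_op_profiling();
//   }
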
namespace detail {

inline void arg_is_empty(bool& any_is_empty, const VulkanBuffer& buffer) {
  // bool(buffer) will evaluate to false if no memory has been allocated
  any_is_empty = any_is_empty || !buffer;
}

inline void arg_is_empty(bool& any_is_empty, const VulkanImage& image) {
  // bool(image) will evaluate to false if no memory has been allocated
  any_is_empty = any_is_empty || !image;
}

/*
  Reports if any VulkanBuffer or VulkanImage argument in a variadic argument
  list does not have any memory associated with it.
*/
template <typename... Arguments>
inline bool any_arg_is_empty(Arguments&&... arguments) {
  bool any_is_empty = false;
  VK_UNUSED const int _[]{
      0,
      (arg_is_empty(any_is_empty, std::forward<Arguments>(arguments)), 0)...,
  };

  return any_is_empty;
}

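/*
  Binds each argument to the descriptor set slot matching its position in the
  argument pack. For example (illustrative),
    bind(set, std::index_sequence_for<A, B>{}, a, b)
  expands to set.bind(0, a) followed by set.bind(1, b).
*/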
template <size_t... Indices, typename... Arguments>
inline void bind(
    DescriptorSet& descriptor_set,
    const std::index_sequence<Indices...>&,
    Arguments&&... arguments) {
  VK_UNUSED const int _[]{
      0,
      (descriptor_set.bind(Indices, std::forward<Arguments>(arguments)), 0)...,
  };
}

} // namespace detail

template <class S, class D>
inline void record_copy(
    CommandBuffer& cmd,
    const S& source,
    const D& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset) = delete;

template <>
inline void record_copy<VulkanBuffer, VulkanBuffer>(
    CommandBuffer& cmd,
    const VulkanBuffer& source,
    const VulkanBuffer& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset) {
  cmd.copy_buffer_to_buffer(
      source, destination, copy_range, src_offset, dst_offset);
}

template <>
inline void record_copy<VulkanImage, VulkanImage>(
    CommandBuffer& cmd,
    const VulkanImage& source,
    const VulkanImage& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset) {
  cmd.copy_texture_to_texture(
      source, destination, copy_range, src_offset, dst_offset);
}

template <>
inline void record_copy<VulkanImage, VulkanBuffer>(
    CommandBuffer& cmd,
    const VulkanImage& source,
    const VulkanBuffer& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset) {
  cmd.copy_texture_to_buffer(
      source, destination, copy_range, src_offset, dst_offset);
}

template <>
inline void record_copy<VulkanBuffer, VulkanImage>(
    CommandBuffer& cmd,
    const VulkanBuffer& source,
    const VulkanImage& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset) {
  cmd.copy_buffer_to_texture(
      source, destination, copy_range, src_offset, dst_offset);
}

/*
  Records a GPU data copy into the current command buffer. If the number of
  submit_*_job calls exceeds the configured frequency, or if a fence is
  provided, then the command buffer is submitted to the GPU for execution.
  Returns a bool indicating whether or not the function call resulted in a GPU
  queue submission.
*/
template <class S, class D>
inline bool Context::submit_copy(
    PipelineBarrier& pipeline_barrier,
    const S& source,
    const D& destination,
    const api::utils::uvec3& copy_range,
    const api::utils::uvec3& src_offset,
    const api::utils::uvec3& dst_offset,
    VkFence fence_handle) {
  // If any of the provided arguments does not have memory associated with it,
  // then exit early, as there is no work to be done. However, if a fence has
  // been passed and the command buffer is not empty, then the current command
  // buffer must still be submitted so that the fence can be signaled.
  if (!source || !destination) {
    if (fence_handle != VK_NULL_HANDLE && submit_count_ > 0) {
      submit_cmd_to_gpu(fence_handle);
      return true;
    }
    return false;
  }

  // Serialize recording to the shared command buffer. Do not acquire the mutex
  // just yet, since in some cases locking is managed externally.
  std::unique_lock<std::mutex> cmd_lock;
  // Refer to comments in submit_compute_job for explanation.
  if (fence_handle == VK_NULL_HANDLE) {
    cmd_lock = std::unique_lock<std::mutex>(cmd_mutex_);
  }

  set_cmd();

#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  uint32_t log_idx = UINT32_MAX;
  if (enable_op_profiling_) {
    std::string label = "cmd_copy";
    log_idx = querypool_.shader_profile_begin(
        cmd_, label, create_extent3d({0, 0, 0}), create_extent3d({0, 0, 0}));
  }
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

  cmd_.insert_barrier(pipeline_barrier);

  record_copy(cmd_, source, destination, copy_range, src_offset, dst_offset);

#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  if (enable_op_profiling_) {
    querypool_.shader_profile_end(cmd_, log_idx);
  }
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

  submit_count_++;
  if (fence_handle != VK_NULL_HANDLE ||
      submit_count_ >= config_.cmdSubmitFrequency) {
    submit_cmd_to_gpu(fence_handle);
    return true;
  }
  return false;
}

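// Example usage (illustrative sketch): copy between two VulkanBuffer objects
// using the buffer-to-buffer record_copy specialization above. The buffer
// handles and the interpretation of the copy extents are assumptions; consult
// CommandBuffer::copy_buffer_to_buffer for the exact semantics.
//
//   api::PipelineBarrier pipeline_barrier{};
//   const api::utils::uvec3 zero_offset{0u, 0u, 0u};
//   const bool submitted = api::context()->submit_copy(
//       pipeline_barrier,
//       src_buffer,
//       dst_buffer,
//       copy_range,
//       zero_offset, // src_offset
//       zero_offset, // dst_offset
//       VK_NULL_HANDLE); // no fence: submission is deferred until
//                        // cmdSubmitFrequency jobs have been recorded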

/*
  Records a compute shader dispatch into the current command buffer. If the
  number of submit_*_job calls exceeds the configured frequency, or if a fence
  is provided, then the command buffer is submitted to the GPU for execution.
  Returns a bool indicating whether or not the function call resulted in a GPU
  queue submission.
*/
template <typename... Arguments>
inline bool Context::submit_compute_job(
    const ShaderInfo& shader,
    PipelineBarrier& pipeline_barrier,
    const utils::uvec3& global_work_group,
    const utils::uvec3& local_work_group_size,
    VkFence fence_handle,
    Arguments&&... arguments) {
  // If any of the provided arguments does not have memory associated with it,
  // then exit early, as there is no work to be done. However, if a fence has
  // been passed and the command buffer is not empty, then the current command
  // buffer must still be submitted so that the fence can be signaled.
  if (detail::any_arg_is_empty(arguments...)) {
    if (fence_handle != VK_NULL_HANDLE && submit_count_ > 0) {
      submit_cmd_to_gpu(fence_handle);
      return true;
    }
    return false;
  }

  // Serialize recording to the shared command buffer. Do not acquire the mutex
  // just yet, since in some cases locking is managed externally.
  std::unique_lock<std::mutex> cmd_lock;
  // If a fence was passed, then assume that the host intends to sync with
  // the GPU, implying there will be imminent calls to fence.wait() and flush().
  // We therefore assume the mutex is externally managed in this case: the
  // calling thread has already locked the mutex prior to calling this function
  // and will release it manually after calling flush(). This prevents more
  // dispatches from being recorded until the Context has been flushed.
  if (fence_handle == VK_NULL_HANDLE) {
    cmd_lock = std::unique_lock<std::mutex>(cmd_mutex_);
  }

  set_cmd();

#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  uint32_t log_idx = UINT32_MAX;
  if (enable_op_profiling_) {
    log_idx = querypool_.shader_profile_begin(
        cmd_,
        shader.kernel_name,
        create_extent3d(global_work_group),
        create_extent3d(local_work_group_size));
  }
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

  // Factor out template parameter independent code to minimize code bloat.
  DescriptorSet descriptor_set =
      get_descriptor_set(shader, local_work_group_size);

  detail::bind(
      descriptor_set,
      std::index_sequence_for<Arguments...>{},
      std::forward<Arguments>(arguments)...);

  // Factor out template parameter independent code to minimize code bloat.
  register_shader_dispatch(
      descriptor_set, pipeline_barrier, shader, global_work_group);

#ifdef USE_VULKAN_GPU_DIAGNOSTICS
  if (enable_op_profiling_) {
    querypool_.shader_profile_end(cmd_, log_idx);
  }
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

  submit_count_++;
  if (fence_handle != VK_NULL_HANDLE ||
      submit_count_ >= config_.cmdSubmitFrequency) {
    submit_cmd_to_gpu(fence_handle);
    return true;
  }

  return false;
}

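// Example usage (illustrative sketch). VK_KERNEL, the image/buffer arguments,
// and the work group sizes below are assumptions borrowed from typical call
// sites rather than declarations in this header.
//
//   api::Context* const ctx = api::context();
//   api::PipelineBarrier pipeline_barrier{};
//
//   ctx->submit_compute_job(
//       VK_KERNEL(my_op), // ShaderInfo describing the compute shader
//       pipeline_barrier,
//       global_work_group, // e.g. the output texture extents
//       local_work_group_size, // e.g. {4u, 4u, 4u}
//       VK_NULL_HANDLE, // no fence; let the Context batch submissions
//       v_output_image, // VulkanImage bound to descriptor slot 0
//       v_input_image, // VulkanImage bound to descriptor slot 1
//       params.buffer()); // uniform VulkanBuffer bound to descriptor slot 2
//
// When a real fence is passed instead of VK_NULL_HANDLE, the caller is
// expected to hold the lock returned by dispatch_lock(), wait on the fence,
// and then call flush(), as described in the comments above.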

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */