#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/vk_api.h>

#include <ATen/native/vulkan/api/Resource.h>
#include <ATen/native/vulkan/api/Shader.h>

// Standard headers for the names this file uses directly (std::vector,
// std::mutex, std::unordered_map, uint32_t, size_t), so the header does not
// depend on transitive includes.
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <unordered_map>
#include <vector>

namespace at {
namespace native {
namespace vulkan {
namespace api {

/*
 * Move-only wrapper around a VkDescriptorSet handle. Alongside the handle it
 * stores the VkDevice, a ShaderLayout::Signature, and the list of resource
 * bindings recorded for the set.
 *
 * NOTE(review): the out-of-line definitions are not visible from this header;
 * presumably the bindings are written to the device when get_bind_handle() is
 * called -- confirm against the implementation.
 */
class DescriptorSet final {
 public:
  explicit DescriptorSet(VkDevice, VkDescriptorSet, ShaderLayout::Signature);

  // Copying is disabled; instances can only be moved.
  DescriptorSet(const DescriptorSet&) = delete;
  DescriptorSet& operator=(const DescriptorSet&) = delete;

  DescriptorSet(DescriptorSet&&) noexcept;
  DescriptorSet& operator=(DescriptorSet&&) noexcept;

  // Defaulted destructor: this class does not release handle_ itself.
  ~DescriptorSet() = default;

  /*
   * Description of a single resource attached to a binding slot.
   */
  struct ResourceBinding final {
    // Binding slot index.
    uint32_t binding_idx;
    // Vulkan descriptor type of the bound resource.
    VkDescriptorType descriptor_type;
    // presumably selects the active member of resource_info (image_info when
    // true, buffer_info otherwise) -- TODO confirm against the implementation
    bool is_image;

    // Buffer or image descriptor info; the two share storage.
    union {
      VkDescriptorBufferInfo buffer_info;
      VkDescriptorImageInfo image_info;
    } resource_info;
  };

 private:
  VkDevice device_;
  VkDescriptorSet handle_;
  ShaderLayout::Signature shader_layout_signature_;
  std::vector<ResourceBinding> bindings_;

 public:
  // Attach a buffer / image at the given binding index. Both overloads return
  // a DescriptorSet& (presumably *this, so calls can be chained -- confirm).
  DescriptorSet& bind(const uint32_t, const VulkanBuffer&);
  DescriptorSet& bind(const uint32_t, const VulkanImage&);

  // Returns a VkDescriptorSet handle for use when binding the set.
  VkDescriptorSet get_bind_handle() const;

 private:
  // Helper that records a ResourceBinding.
  // NOTE(review): how an already-present binding index is handled is defined
  // in the implementation file -- not visible here.
  void add_binding(const ResourceBinding& resource);
};

/*
 * Move-only collection of VkDescriptorSet handles that all share one
 * VkDescriptorSetLayout and come from one VkDescriptorPool.
 *
 * NOTE(review): judging by the member names, handles appear to be allocated in
 * batches of pile_size_ and handed out one at a time -- confirm against the
 * implementation.
 */
class DescriptorSetPile final {
 public:
  DescriptorSetPile(
      const uint32_t,
      VkDescriptorSetLayout,
      VkDevice,
      VkDescriptorPool);

  // Copying is disabled; instances can only be moved.
  DescriptorSetPile(const DescriptorSetPile&) = delete;
  DescriptorSetPile& operator=(const DescriptorSetPile&) = delete;

  DescriptorSetPile(DescriptorSetPile&&) = default;
  DescriptorSetPile& operator=(DescriptorSetPile&&) = default;

  ~DescriptorSetPile() = default;

 private:
  uint32_t pile_size_;
  VkDescriptorSetLayout set_layout_;
  VkDevice device_;
  VkDescriptorPool pool_;
  std::vector<VkDescriptorSet> descriptors_;
  // presumably the number of entries of descriptors_ already handed out --
  // TODO confirm
  size_t in_use_;

 public:
  // Returns a descriptor set handle from the pile.
  VkDescriptorSet get_descriptor_set();

 private:
  void allocate_new_batch();
};

/*
 * Sizing parameters for a DescriptorPool.
 */
struct DescriptorPoolConfig final {
  // Overall Pool capacity
  uint32_t descriptorPoolMaxSets;
  // DescriptorCounts by type
  uint32_t descriptorUniformBufferCount;
  uint32_t descriptorStorageBufferCount;
  uint32_t descriptorCombinedSamplerCount;
  uint32_t descriptorStorageImageCount;
  // Pile size for pre-allocating descriptor sets
  uint32_t descriptorPileSizes;
};

/*
 * Holds a VkDescriptorPool together with one DescriptorSetPile per
 * VkDescriptorSetLayout. Instances can be neither copied nor moved.
 */
class DescriptorPool final {
 public:
  explicit DescriptorPool(VkDevice, const DescriptorPoolConfig&);

  DescriptorPool(const DescriptorPool&) = delete;
  DescriptorPool& operator=(const DescriptorPool&) = delete;

  DescriptorPool(DescriptorPool&&) = delete;
  DescriptorPool& operator=(DescriptorPool&&) = delete;

  ~DescriptorPool();

 private:
  VkDevice device_;
  VkDescriptorPool pool_;
  DescriptorPoolConfig config_;
  // New Descriptors
  // NOTE(review): mutex_ presumably guards piles_ -- confirm in the
  // implementation.
  std::mutex mutex_;
  std::unordered_map<VkDescriptorSetLayout, DescriptorSetPile> piles_;

 public:
  // True when the underlying VkDescriptorPool handle is not VK_NULL_HANDLE.
  operator bool() const {
    return (pool_ != VK_NULL_HANDLE);
  }

  // NOTE(review): presumably creates the pool from `config` when it was not
  // created at construction time -- confirm against the implementation.
  void init(const DescriptorPoolConfig& config);

  // Returns a DescriptorSet for the given layout handle and shader layout
  // signature.
  DescriptorSet get_descriptor_set(
      VkDescriptorSetLayout handle,
      const ShaderLayout::Signature& signature);

  // NOTE(review): presumably resets the pool and discards the cached piles --
  // confirm against the implementation.
  void flush();
};

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */