xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/Descriptor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
4 
5 #ifdef USE_VULKAN_API
6 
7 #include <ATen/native/vulkan/api/vk_api.h>
8 
9 #include <ATen/native/vulkan/api/Resource.h>
10 #include <ATen/native/vulkan/api/Shader.h>
11 
12 #include <unordered_map>
13 
14 namespace at {
15 namespace native {
16 namespace vulkan {
17 namespace api {
18 
19 class DescriptorSet final {
20  public:
21   explicit DescriptorSet(VkDevice, VkDescriptorSet, ShaderLayout::Signature);
22 
23   DescriptorSet(const DescriptorSet&) = delete;
24   DescriptorSet& operator=(const DescriptorSet&) = delete;
25 
26   DescriptorSet(DescriptorSet&&) noexcept;
27   DescriptorSet& operator=(DescriptorSet&&) noexcept;
28 
29   ~DescriptorSet() = default;
30 
31   struct ResourceBinding final {
32     uint32_t binding_idx;
33     VkDescriptorType descriptor_type;
34     bool is_image;
35 
36     union {
37       VkDescriptorBufferInfo buffer_info;
38       VkDescriptorImageInfo image_info;
39     } resource_info;
40   };
41 
42  private:
43   VkDevice device_;
44   VkDescriptorSet handle_;
45   ShaderLayout::Signature shader_layout_signature_;
46   std::vector<ResourceBinding> bindings_;
47 
48  public:
49   DescriptorSet& bind(const uint32_t, const VulkanBuffer&);
50   DescriptorSet& bind(const uint32_t, const VulkanImage&);
51 
52   VkDescriptorSet get_bind_handle() const;
53 
54  private:
55   void add_binding(const ResourceBinding& resource);
56 };
57 
58 class DescriptorSetPile final {
59  public:
60   DescriptorSetPile(
61       const uint32_t,
62       VkDescriptorSetLayout,
63       VkDevice,
64       VkDescriptorPool);
65 
66   DescriptorSetPile(const DescriptorSetPile&) = delete;
67   DescriptorSetPile& operator=(const DescriptorSetPile&) = delete;
68 
69   DescriptorSetPile(DescriptorSetPile&&) = default;
70   DescriptorSetPile& operator=(DescriptorSetPile&&) = default;
71 
72   ~DescriptorSetPile() = default;
73 
74  private:
75   uint32_t pile_size_;
76   VkDescriptorSetLayout set_layout_;
77   VkDevice device_;
78   VkDescriptorPool pool_;
79   std::vector<VkDescriptorSet> descriptors_;
80   size_t in_use_;
81 
82  public:
83   VkDescriptorSet get_descriptor_set();
84 
85  private:
86   void allocate_new_batch();
87 };
88 
89 struct DescriptorPoolConfig final {
90   // Overall Pool capacity
91   uint32_t descriptorPoolMaxSets;
92   // DescriptorCounts by type
93   uint32_t descriptorUniformBufferCount;
94   uint32_t descriptorStorageBufferCount;
95   uint32_t descriptorCombinedSamplerCount;
96   uint32_t descriptorStorageImageCount;
97   // Pile size for pre-allocating descriptor sets
98   uint32_t descriptorPileSizes;
99 };
100 
101 class DescriptorPool final {
102  public:
103   explicit DescriptorPool(VkDevice, const DescriptorPoolConfig&);
104 
105   DescriptorPool(const DescriptorPool&) = delete;
106   DescriptorPool& operator=(const DescriptorPool&) = delete;
107 
108   DescriptorPool(DescriptorPool&&) = delete;
109   DescriptorPool& operator=(DescriptorPool&&) = delete;
110 
111   ~DescriptorPool();
112 
113  private:
114   VkDevice device_;
115   VkDescriptorPool pool_;
116   DescriptorPoolConfig config_;
117   // New Descriptors
118   std::mutex mutex_;
119   std::unordered_map<VkDescriptorSetLayout, DescriptorSetPile> piles_;
120 
121  public:
122   operator bool() const {
123     return (pool_ != VK_NULL_HANDLE);
124   }
125 
126   void init(const DescriptorPoolConfig& config);
127 
128   DescriptorSet get_descriptor_set(
129       VkDescriptorSetLayout handle,
130       const ShaderLayout::Signature& signature);
131 
132   void flush();
133 };
134 
135 } // namespace api
136 } // namespace vulkan
137 } // namespace native
138 } // namespace at
139 
140 #endif /* USE_VULKAN_API */
141