xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/Shader.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
4 
5 #ifdef USE_VULKAN_API
6 
7 #include <ATen/native/vulkan/api/vk_api.h>
8 
9 #include <ATen/native/vulkan/api/Types.h>
10 #include <ATen/native/vulkan/api/Utils.h>
11 
12 #include <mutex>
13 #include <unordered_map>
14 
15 namespace at {
16 namespace native {
17 namespace vulkan {
18 namespace api {
19 
20 class ShaderLayout final {
21  public:
22   using Signature = std::vector<VkDescriptorType>;
23 
24   explicit ShaderLayout(VkDevice, const Signature&);
25 
26   ShaderLayout(const ShaderLayout&) = delete;
27   ShaderLayout& operator=(const ShaderLayout&) = delete;
28 
29   ShaderLayout(ShaderLayout&&) noexcept;
30   ShaderLayout& operator=(ShaderLayout&&) = delete;
31 
32   ~ShaderLayout();
33 
34  private:
35   VkDevice device_;
36   VkDescriptorSetLayout handle_;
37 
38  public:
handle()39   VkDescriptorSetLayout handle() const {
40     return handle_;
41   }
42 
43   // We need to define a custom swap function since this class
44   // does not allow for move assignment. The swap function will
45   // be used in the hash map.
46   friend void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept;
47 };
48 
49 struct ShaderInfo final {
50   struct {
51     const uint32_t* bin;
52     uint32_t size;
53   } src_code;
54 
55   std::string kernel_name{""};
56   ShaderLayout::Signature kernel_layout{};
57 
58   // Shader Metadata
59   utils::uvec3 out_tile_size{1u, 1u, 1u};
60 
61   std::vector<uint32_t> tile_size;
62   StorageType bias_storage_type{StorageType::UNKNOWN};
63   StorageType weight_storage_type{StorageType::UNKNOWN};
64 
65   explicit ShaderInfo();
66   explicit ShaderInfo(std::string, const char*);
67   explicit ShaderInfo(
68       std::string,
69       const uint32_t*,
70       const uint32_t,
71       std::vector<VkDescriptorType>);
72   explicit ShaderInfo(
73       std::string,
74       const uint32_t*,
75       const uint32_t,
76       std::vector<VkDescriptorType>,
77       const std::vector<uint32_t>& tile_size,
78       const StorageType bias_storage_type,
79       const StorageType weight_storage_type);
80 };
81 
82 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2);
83 
84 class ShaderModule final {
85  public:
86   explicit ShaderModule(VkDevice device, const ShaderInfo& source);
87 
88   ShaderModule(const ShaderModule&) = delete;
89   ShaderModule& operator=(const ShaderModule&) = delete;
90 
91   ShaderModule(ShaderModule&&) noexcept;
92   ShaderModule& operator=(ShaderModule&&) = delete;
93 
94   ~ShaderModule();
95 
96  private:
97   VkDevice device_;
98   VkShaderModule handle_;
99 
100  public:
handle()101   inline VkShaderModule handle() const {
102     return handle_;
103   }
104 
105   // We need to define a custom swap function since this class
106   // does not allow for move assignment. The swap function will
107   // be used in the hash map.
108   friend void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept;
109 };
110 
111 class ShaderLayoutCache final {
112  public:
113   explicit ShaderLayoutCache(VkDevice device);
114 
115   ShaderLayoutCache(const ShaderLayoutCache&) = delete;
116   ShaderLayoutCache& operator=(const ShaderLayoutCache&) = delete;
117 
118   ShaderLayoutCache(ShaderLayoutCache&&) noexcept;
119   ShaderLayoutCache& operator=(ShaderLayoutCache&&) = delete;
120 
121   ~ShaderLayoutCache();
122 
123   using Key = ShaderLayout::Signature;
124   using Value = ShaderLayout;
125 
126   struct Hasher {
operatorHasher127     inline size_t operator()(const ShaderLayout::Signature& signature) const {
128       size_t hashed = 0u;
129 
130       for (const VkDescriptorType type : signature) {
131         hashed =
132             utils::hash_combine(hashed, std::hash<VkDescriptorType>()(type));
133       }
134 
135       return hashed;
136     }
137   };
138 
139  private:
140   // Multiple threads could potentially be adding entries into the cache, so use
141   // a mutex to manage access
142   std::mutex cache_mutex_;
143 
144   VkDevice device_;
145   std::unordered_map<Key, Value, Hasher> cache_;
146 
147  public:
148   VkDescriptorSetLayout retrieve(const Key&);
149   void purge();
150 };
151 
152 class ShaderCache final {
153  public:
154   explicit ShaderCache(VkDevice device);
155 
156   ShaderCache(const ShaderCache&) = delete;
157   ShaderCache& operator=(const ShaderCache&) = delete;
158 
159   ShaderCache(ShaderCache&&) noexcept;
160   ShaderCache& operator=(ShaderCache&&) = delete;
161 
162   ~ShaderCache();
163 
164   using Key = ShaderInfo;
165   using Value = ShaderModule;
166 
167   struct Hasher {
operatorHasher168     inline size_t operator()(const ShaderInfo& source) const {
169       size_t seed = 0;
170       seed = utils::hash_combine(
171           seed, std::hash<const uint32_t*>()(source.src_code.bin));
172       seed = utils::hash_combine(
173           seed, std::hash<uint32_t>()(source.src_code.size));
174 
175       return seed;
176     }
177   };
178 
179  private:
180   // Multiple threads could potentially be adding entries into the cache, so use
181   // a mutex to manage access
182   std::mutex cache_mutex_;
183 
184   VkDevice device_;
185   std::unordered_map<Key, Value, Hasher> cache_;
186 
187  public:
188   VkShaderModule retrieve(const Key&);
189   void purge();
190 };
191 
192 } // namespace api
193 } // namespace vulkan
194 } // namespace native
195 } // namespace at
196 
197 inline bool operator==(
198     const VkDescriptorSetLayoutBinding& _1,
199     const VkDescriptorSetLayoutBinding& _2) {
200   return (
201       _1.binding == _2.binding && _1.descriptorType == _2.descriptorType &&
202       _1.descriptorCount == _2.descriptorCount &&
203       _1.stageFlags == _2.stageFlags &&
204       _1.pImmutableSamplers == _2.pImmutableSamplers);
205 }
206 
207 #endif /* USE_VULKAN_API */
208