#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/vk_api.h>

#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/api/Utils.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

namespace at {
namespace native {
namespace vulkan {
namespace api {

//
// ShaderLayout
//
// Non-copyable wrapper around a VkDescriptorSetLayout handle created for
// `device_` from a Signature (an ordered list of descriptor types). Copying
// and move-assignment are deleted; the constructor, move constructor and
// destructor are defined out of line.
// NOTE(review): presumably the destructor destroys `handle_` and the move
// constructor leaves the source empty — confirm in the .cpp.
//

class ShaderLayout final {
 public:
  // Ordered list of descriptor types describing the layout.
  // NOTE(review): presumably one entry per binding, in binding order —
  // confirm against the constructor implementation.
  using Signature = std::vector<VkDescriptorType>;

  explicit ShaderLayout(VkDevice, const Signature&);

  ShaderLayout(const ShaderLayout&) = delete;
  ShaderLayout& operator=(const ShaderLayout&) = delete;

  ShaderLayout(ShaderLayout&&) noexcept;
  ShaderLayout& operator=(ShaderLayout&&) = delete;

  ~ShaderLayout();

 private:
  VkDevice device_;
  VkDescriptorSetLayout handle_;

 public:
  // Returns the raw descriptor set layout handle held by this object.
  VkDescriptorSetLayout handle() const {
    return handle_;
  }

  // Move assignment is deleted, so a custom swap is provided instead; it is
  // intended for use by the hash map that caches layouts.
  friend void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept;
};

//
// ShaderInfo
//
// Describes a shader: a non-owning pointer to its binary, the kernel name,
// the descriptor layout signature, and tiling / storage-type metadata.
//

struct ShaderInfo final {
  struct {
    // Pointer to the shader binary. ShaderInfo declares no destructor, so it
    // never frees this buffer; the caller must keep it alive.
    // Default-initialized so a ShaderInfo never carries an indeterminate
    // pointer, even if a constructor does not set it.
    const uint32_t* bin{nullptr};
    // Length of `bin`.
    // NOTE(review): units (bytes vs. 32-bit words) are decided by the
    // out-of-line constructors — confirm there.
    uint32_t size{0u};
  } src_code;

  std::string kernel_name{""};
  ShaderLayout::Signature kernel_layout{};

  // Shader Metadata
  utils::uvec3 out_tile_size{1u, 1u, 1u};

  std::vector<uint32_t> tile_size;
  StorageType bias_storage_type{StorageType::UNKNOWN};
  StorageType weight_storage_type{StorageType::UNKNOWN};

  explicit ShaderInfo();
  explicit ShaderInfo(std::string, const char*);
  explicit ShaderInfo(
      std::string,
      const uint32_t*,
      const uint32_t,
      std::vector<VkDescriptorType>);
  explicit ShaderInfo(
      std::string,
      const uint32_t*,
      const uint32_t,
      std::vector<VkDescriptorType>,
      const std::vector<uint32_t>& tile_size,
      const StorageType bias_storage_type,
      const StorageType weight_storage_type);
};

// Equality for ShaderInfo (defined out of line). ShaderCache's unordered_map
// uses it as its key equality (through std::equal_to), so it must be
// consistent with ShaderCache::Hasher: values comparing equal must hash
// equally.
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2);

//
// ShaderModule
//
// Non-copyable wrapper around a VkShaderModule handle created for `device_`
// from a ShaderInfo. Copying and move-assignment are deleted; the
// constructor, move constructor and destructor are defined out of line.
// NOTE(review): presumably the destructor destroys `handle_` — confirm in
// the .cpp.
//

class ShaderModule final {
 public:
  explicit ShaderModule(VkDevice device, const ShaderInfo& source);

  ShaderModule(const ShaderModule&) = delete;
  ShaderModule& operator=(const ShaderModule&) = delete;

  ShaderModule(ShaderModule&&) noexcept;
  ShaderModule& operator=(ShaderModule&&) = delete;

  ~ShaderModule();

 private:
  VkDevice device_;
  VkShaderModule handle_;

 public:
  // Returns the raw shader module handle held by this object.
  VkShaderModule handle() const {
    return handle_;
  }

  // Move assignment is deleted, so a custom swap is provided instead; it is
  // intended for use by the hash map that caches shader modules.
  friend void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept;
};

//
// ShaderLayoutCache
//
// Maps a layout Signature to the ShaderLayout built for it on `device_`.
// retrieve() returns the VkDescriptorSetLayout for a signature; purge()
// clears the cache. Both are defined out of line.
// NOTE(review): presumably retrieve() creates and inserts the layout on a
// cache miss — confirm in the .cpp.
//

class ShaderLayoutCache final {
 public:
  explicit ShaderLayoutCache(VkDevice device);

  ShaderLayoutCache(const ShaderLayoutCache&) = delete;
  ShaderLayoutCache& operator=(const ShaderLayoutCache&) = delete;

  ShaderLayoutCache(ShaderLayoutCache&&) noexcept;
  ShaderLayoutCache& operator=(ShaderLayoutCache&&) = delete;

  ~ShaderLayoutCache();

  using Key = ShaderLayout::Signature;
  using Value = ShaderLayout;

  // Folds the hash of every descriptor type, in signature order, into a
  // single value with utils::hash_combine. An empty signature hashes to 0.
  struct Hasher {
    size_t operator()(const ShaderLayout::Signature& signature) const {
      size_t hashed = 0u;

      for (const VkDescriptorType type : signature) {
        hashed =
            utils::hash_combine(hashed, std::hash<VkDescriptorType>()(type));
      }

      return hashed;
    }
  };

 private:
  // Multiple threads could potentially add entries to the cache, so a mutex
  // guards access to it.
  std::mutex cache_mutex_;

  VkDevice device_;
  std::unordered_map<Key, Value, Hasher> cache_;

 public:
  VkDescriptorSetLayout retrieve(const Key&);
  void purge();
};

//
// ShaderCache
//
// Maps a ShaderInfo to the ShaderModule built for it on `device_`.
// retrieve() returns the VkShaderModule for a ShaderInfo; purge() clears the
// cache. Both are defined out of line.
//

class ShaderCache final {
 public:
  explicit ShaderCache(VkDevice device);

  ShaderCache(const ShaderCache&) = delete;
  ShaderCache& operator=(const ShaderCache&) = delete;

  ShaderCache(ShaderCache&&) noexcept;
  ShaderCache& operator=(ShaderCache&&) = delete;

  ~ShaderCache();

  using Key = ShaderInfo;
  using Value = ShaderModule;

  // The hash depends only on the *address* of the shader binary and its
  // size. Other fields (kernel name, layout, metadata) do not contribute.
  // Key equality is decided by operator==(ShaderInfo, ShaderInfo).
  // NOTE(review): identical binaries stored at different addresses hash
  // differently — confirm operator== also compares by pointer, otherwise
  // equal keys could land in different buckets.
  struct Hasher {
    size_t operator()(const ShaderInfo& source) const {
      size_t seed = 0;
      seed = utils::hash_combine(
          seed, std::hash<const uint32_t*>()(source.src_code.bin));
      seed = utils::hash_combine(
          seed, std::hash<uint32_t>()(source.src_code.size));

      return seed;
    }
  };

 private:
  // Multiple threads could potentially add entries to the cache, so a mutex
  // guards access to it.
  std::mutex cache_mutex_;

  VkDevice device_;
  std::unordered_map<Key, Value, Hasher> cache_;

 public:
  VkShaderModule retrieve(const Key&);
  void purge();
};

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

// Field-wise equality for VkDescriptorSetLayoutBinding. pImmutableSamplers is
// compared by pointer value (a shallow comparison), not by the samplers it
// points to.
inline bool operator==(
    const VkDescriptorSetLayoutBinding& _1,
    const VkDescriptorSetLayoutBinding& _2) {
  return (
      _1.binding == _2.binding && _1.descriptorType == _2.descriptorType &&
      _1.descriptorCount == _2.descriptorCount &&
      _1.stageFlags == _2.stageFlags &&
      _1.pImmutableSamplers == _2.pImmutableSamplers);
}

#endif /* USE_VULKAN_API */