#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/vk_api.h>

#include <ATen/native/vulkan/api/Resource.h>
#include <ATen/native/vulkan/api/Shader.h>

#include <mutex>
#include <unordered_map>
#include <vector>

namespace at {
namespace native {
namespace vulkan {
namespace api {

struct PipelineBarrier final {
  struct Stages final {
    VkPipelineStageFlags src;
    VkPipelineStageFlags dst;
  } stage;

  std::vector<BufferMemoryBarrier> buffers;
  std::vector<ImageMemoryBarrier> images;
  std::vector<VkBufferMemoryBarrier> buffer_barrier_handles;
  std::vector<VkImageMemoryBarrier> image_barrier_handles;

  inline operator bool() const {
    return (0u != stage.src) || (0u != stage.dst) || !buffers.empty() ||
        !images.empty();
  }
};

using PipelineStageFlags = uint8_t;

enum PipelineStage : PipelineStageFlags {
  NO_STAGE = 0u << 0u,
  COMPUTE = 1u << 0u,
  HOST = 1u << 1u,
  TRANSFER = 1u << 2u,
};

VkAccessFlags vk_access(const PipelineStageFlags, const MemoryAccessFlags);
VkPipelineStageFlags vk_stage(const PipelineStageFlags);
VkImageLayout vk_layout(const PipelineStageFlags, const MemoryAccessFlags);

class PipelineLayout final {
 public:
  explicit PipelineLayout(VkDevice, VkDescriptorSetLayout);

  PipelineLayout(const PipelineLayout&) = delete;
  PipelineLayout& operator=(const PipelineLayout&) = delete;

  PipelineLayout(PipelineLayout&&) noexcept;
  PipelineLayout& operator=(PipelineLayout&&) = delete;

  ~PipelineLayout();

 private:
  VkDevice device_;
  VkPipelineLayout handle_;

 public:
  VkPipelineLayout handle() const {
    return handle_;
  }

  // We need to define a custom swap function since this class
  // does not allow for move assignment. The swap function will
  // be used in the hash map.
  friend void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept;
};

class ComputePipeline final {
 public:
  struct Descriptor final {
    VkPipelineLayout pipeline_layout;
    VkShaderModule shader_module;
    utils::uvec3 local_work_group;
  };

  explicit ComputePipeline(
      VkDevice device,
      const Descriptor& descriptor,
      VkPipelineCache pipeline_cache);

  ComputePipeline(const ComputePipeline&) = delete;
  ComputePipeline& operator=(const ComputePipeline&) = delete;

  ComputePipeline(ComputePipeline&&) noexcept;
  ComputePipeline& operator=(ComputePipeline&&) = delete;

  ~ComputePipeline();

 private:
  VkDevice device_;
  VkPipeline handle_;

 public:
  inline VkPipeline handle() const {
    return handle_;
  }

  // We need to define a custom swap function since this class
  // does not allow for move assignment. The swap function will
  // be used in the hash map.
  friend void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept;
};

class PipelineLayoutCache final {
 public:
  explicit PipelineLayoutCache(VkDevice device);

  PipelineLayoutCache(const PipelineLayoutCache&) = delete;
  PipelineLayoutCache& operator=(const PipelineLayoutCache&) = delete;

  PipelineLayoutCache(PipelineLayoutCache&&) noexcept;
  PipelineLayoutCache& operator=(PipelineLayoutCache&&) = delete;

  ~PipelineLayoutCache();

  using Key = VkDescriptorSetLayout;
  using Value = PipelineLayout;

  struct Hasher {
    inline size_t operator()(VkDescriptorSetLayout descriptor_layout) const {
      return std::hash<VkDescriptorSetLayout>()(descriptor_layout);
    }
  };

 private:
  // Multiple threads could potentially be adding entries into the cache, so
  // use a mutex to manage access
  std::mutex cache_mutex_;

  VkDevice device_;
  std::unordered_map<Key, Value, Hasher> cache_;

 public:
  VkPipelineLayout retrieve(const Key&);
  void purge();
};

class ComputePipelineCache final {
 public:
  explicit ComputePipelineCache(VkDevice device);

  ComputePipelineCache(const ComputePipelineCache&) = delete;
  ComputePipelineCache& operator=(const ComputePipelineCache&) = delete;

  ComputePipelineCache(ComputePipelineCache&&) noexcept;
  ComputePipelineCache& operator=(ComputePipelineCache&&) = delete;

  ~ComputePipelineCache();

  using Key = ComputePipeline::Descriptor;
  using Value = ComputePipeline;

  struct Hasher {
    inline size_t operator()(
        const ComputePipeline::Descriptor& descriptor) const {
      size_t seed = 0;
      seed = utils::hash_combine(
          seed, std::hash<VkPipelineLayout>()(descriptor.pipeline_layout));
      seed = utils::hash_combine(
          seed, std::hash<VkShaderModule>()(descriptor.shader_module));
      seed = utils::hash_combine(
          seed, std::hash<uint32_t>()(descriptor.local_work_group.data[0u]));
      seed = utils::hash_combine(
          seed, std::hash<uint32_t>()(descriptor.local_work_group.data[1u]));
      seed = utils::hash_combine(
          seed, std::hash<uint32_t>()(descriptor.local_work_group.data[2u]));

      return seed;
    }
  };

 private:
  // Multiple threads could potentially be adding entries into the cache, so
  // use a mutex to manage access
  std::mutex cache_mutex_;

  VkDevice device_;
  VkPipelineCache pipeline_cache_;
  std::unordered_map<Key, Value, Hasher> cache_;

 public:
  VkPipeline retrieve(const Key&);
  void purge();
};
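// Usage sketch (illustrative only, not part of this header): the two caches
// are expected to be owned by a longer-lived context object and queried
// together when recording a compute dispatch. The handles below (device,
// set_layout, shader_module) are hypothetical and assumed to have been
// created elsewhere.
//
//   PipelineLayoutCache layout_cache(device);
//   ComputePipelineCache pipeline_cache(device);
//
//   const ComputePipeline::Descriptor desc{
//       layout_cache.retrieve(set_layout), // VkPipelineLayout
//       shader_module, // VkShaderModule
//       {{4u, 4u, 4u}}, // local work group size
//   };
//   VkPipeline pipeline = pipeline_cache.retrieve(desc);
//
// Repeated calls with an identical Descriptor hash to the same bucket (see
// Hasher above) and should return the cached VkPipeline rather than create a
// new one; purge() destroys all cached entries.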
//
// Impl
//

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */