xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/Pipeline.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
4 
5 #ifdef USE_VULKAN_API
6 
7 #include <ATen/native/vulkan/api/vk_api.h>
8 
9 #include <ATen/native/vulkan/api/Resource.h>
10 #include <ATen/native/vulkan/api/Shader.h>
11 
12 #include <mutex>
13 #include <unordered_map>
14 
15 namespace at {
16 namespace native {
17 namespace vulkan {
18 namespace api {
19 
20 struct PipelineBarrier final {
21   struct Stages final {
22     VkPipelineStageFlags src;
23     VkPipelineStageFlags dst;
24   } stage;
25 
26   std::vector<BufferMemoryBarrier> buffers;
27   std::vector<ImageMemoryBarrier> images;
28   std::vector<VkBufferMemoryBarrier> buffer_barrier_handles;
29   std::vector<VkImageMemoryBarrier> image_barrier_handles;
30 
31   inline operator bool() const {
32     return (0u != stage.src) || (0u != stage.dst) || !buffers.empty() ||
33         !images.empty();
34   }
35 };
36 
// Compact bitmask describing which abstract pipeline stages a resource is
// used in. A uint8_t is sufficient since only three stages are tracked.
using PipelineStageFlags = uint8_t;

// Individual stage bits; combine with bitwise-or to build a
// PipelineStageFlags mask.
enum PipelineStage : PipelineStageFlags {
  NO_STAGE = 0u,
  COMPUTE = 1u << 0u,
  HOST = 1u << 1u,
  TRANSFER = 1u << 2u,
};
45 
// Translate the abstract (stage, access) pair into the Vulkan access mask
// used when recording memory barriers. (Implemented out of line.)
VkAccessFlags vk_access(const PipelineStageFlags, const MemoryAccessFlags);
// Translate the abstract stage bitmask into its VkPipelineStageFlags
// equivalent. (Implemented out of line.)
VkPipelineStageFlags vk_stage(const PipelineStageFlags);
// Select the VkImageLayout appropriate for the given (stage, access) pair.
// (Implemented out of line.)
VkImageLayout vk_layout(const PipelineStageFlags, const MemoryAccessFlags);
49 
// Owning wrapper around a VkPipelineLayout built from a single descriptor
// set layout. Copy and move-assignment are deleted; move construction is
// allowed so instances can be stored in containers. The destructor is
// declared here and defined out of line — presumably it destroys the
// underlying Vulkan handle via device_ (confirm in Pipeline.cpp).
class PipelineLayout final {
 public:
  explicit PipelineLayout(VkDevice, VkDescriptorSetLayout);

  PipelineLayout(const PipelineLayout&) = delete;
  PipelineLayout& operator=(const PipelineLayout&) = delete;

  PipelineLayout(PipelineLayout&&) noexcept;
  PipelineLayout& operator=(PipelineLayout&&) = delete;

  ~PipelineLayout();

 private:
  // Device that created handle_; retained so the handle can be destroyed.
  VkDevice device_;
  VkPipelineLayout handle_;

 public:
  // Non-owning accessor for the raw Vulkan handle; the wrapper retains
  // ownership.
  VkPipelineLayout handle() const {
    return handle_;
  }

  // We need to define a custom swap function since this class
  // does not allow for move assignment. The swap function will
  // be used in the hash map.
  friend void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept;
};
76 
// Owning wrapper around a VkPipeline for a compute shader. Copy and
// move-assignment are deleted; move construction is allowed. The destructor
// is defined out of line — presumably it destroys the pipeline via device_
// (confirm in Pipeline.cpp).
class ComputePipeline final {
 public:
  // Everything needed to create (and cache-key) a compute pipeline:
  // layout, shader module, and the local work group size baked in at
  // pipeline creation time.
  struct Descriptor final {
    VkPipelineLayout pipeline_layout;
    VkShaderModule shader_module;
    utils::uvec3 local_work_group;
  };

  explicit ComputePipeline(
      VkDevice device,
      const Descriptor& descriptor,
      VkPipelineCache pipeline_cache);

  ComputePipeline(const ComputePipeline&) = delete;
  ComputePipeline& operator=(const ComputePipeline&) = delete;

  ComputePipeline(ComputePipeline&&) noexcept;
  ComputePipeline& operator=(ComputePipeline&&) = delete;

  ~ComputePipeline();

 private:
  // Device that created handle_; retained so the handle can be destroyed.
  VkDevice device_;
  VkPipeline handle_;

 public:
  // Non-owning accessor for the raw Vulkan handle; the wrapper retains
  // ownership.
  inline VkPipeline handle() const {
    return handle_;
  }

  // We need to define a custom swap function since this class
  // does not allow for move assignment. The swap function will
  // be used in the hash map.
  friend void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept;
};
112 
// Thread-safe cache mapping a VkDescriptorSetLayout to the PipelineLayout
// created for it, so equivalent layouts are built only once per device.
class PipelineLayoutCache final {
 public:
  explicit PipelineLayoutCache(VkDevice device);

  PipelineLayoutCache(const PipelineLayoutCache&) = delete;
  PipelineLayoutCache& operator=(const PipelineLayoutCache&) = delete;

  PipelineLayoutCache(PipelineLayoutCache&&) noexcept;
  PipelineLayoutCache& operator=(PipelineLayoutCache&&) = delete;

  ~PipelineLayoutCache();

  using Key = VkDescriptorSetLayout;
  using Value = PipelineLayout;

  // Hashes the descriptor set layout handle directly; handles are unique
  // per live object, which makes them suitable cache keys.
  struct Hasher {
    inline size_t operator()(VkDescriptorSetLayout descriptor_layout) const {
      return std::hash<VkDescriptorSetLayout>()(descriptor_layout);
    }
  };

 private:
  // Multiple threads could potentially be adding entries into the cache, so use
  // a mutex to manage access
  std::mutex cache_mutex_;

  VkDevice device_;
  std::unordered_map<Key, Value, Hasher> cache_;

 public:
  // Returns the pipeline layout handle for the given descriptor set layout;
  // presumably creates and caches it on a miss (confirm in Pipeline.cpp).
  VkPipelineLayout retrieve(const Key&);
  // Clears all cached entries.
  void purge();
};
146 
// Thread-safe cache mapping a ComputePipeline::Descriptor to the
// ComputePipeline built from it, backed by a VkPipelineCache for
// driver-level reuse across pipeline creations.
class ComputePipelineCache final {
 public:
  explicit ComputePipelineCache(VkDevice device);

  ComputePipelineCache(const ComputePipelineCache&) = delete;
  ComputePipelineCache& operator=(const ComputePipelineCache&) = delete;

  ComputePipelineCache(ComputePipelineCache&&) noexcept;
  ComputePipelineCache& operator=(ComputePipelineCache&&) = delete;

  ~ComputePipelineCache();

  using Key = ComputePipeline::Descriptor;
  using Value = ComputePipeline;

  // Combines the hashes of every Descriptor field — layout handle, shader
  // module handle, and all three local work group extents — so descriptors
  // differing in any field land in different buckets.
  // NOTE(review): std::unordered_map also requires operator== for the
  // Descriptor key; it is not defined in this header, so it must live
  // elsewhere — verify.
  struct Hasher {
    inline size_t operator()(
        const ComputePipeline::Descriptor& descriptor) const {
      size_t seed = 0;
      seed = utils::hash_combine(
          seed, std::hash<VkPipelineLayout>()(descriptor.pipeline_layout));
      seed = utils::hash_combine(
          seed, std::hash<VkShaderModule>()(descriptor.shader_module));
      seed = utils::hash_combine(
          seed, std::hash<uint32_t>()(descriptor.local_work_group.data[0u]));
      seed = utils::hash_combine(
          seed, std::hash<uint32_t>()(descriptor.local_work_group.data[1u]));
      seed = utils::hash_combine(
          seed, std::hash<uint32_t>()(descriptor.local_work_group.data[2u]));

      return seed;
    }
  };

 private:
  // Multiple threads could potentially be adding entries into the cache, so use
  // a mutex to manage access
  std::mutex cache_mutex_;

  VkDevice device_;
  // Driver-side pipeline cache passed to pipeline creation.
  VkPipelineCache pipeline_cache_;
  std::unordered_map<Key, Value, Hasher> cache_;

 public:
  // Returns the pipeline handle for the given descriptor; presumably
  // creates and caches it on a miss (confirm in Pipeline.cpp).
  VkPipeline retrieve(const Key&);
  // Clears all cached entries.
  void purge();
};
194 
195 //
196 // Impl
197 //
198 
199 } // namespace api
200 } // namespace vulkan
201 } // namespace native
202 } // namespace at
203 
204 #endif /* USE_VULKAN_API */
205