xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/vk_api/Shader.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/vk_api/Shader.h>
10 
11 #include <utility>
12 
13 namespace vkcompute {
14 namespace vkapi {
15 
16 //
17 // ShaderInfo
18 //
19 
ShaderInfo()20 ShaderInfo::ShaderInfo()
21     : src_code{
22           nullptr,
23           0u,
24       } {}
25 
ShaderInfo(std::string name,const uint32_t * const spirv_bin,const uint32_t size,std::vector<VkDescriptorType> layout,const utils::uvec3 tile_size)26 ShaderInfo::ShaderInfo(
27     std::string name,
28     const uint32_t* const spirv_bin,
29     const uint32_t size,
30     std::vector<VkDescriptorType>  layout,
31     const utils::uvec3 tile_size)
32     : src_code{
33           spirv_bin,
34           size,
35       },
36       kernel_name{std::move(name)},
37       kernel_layout{std::move(layout)},
38       out_tile_size(tile_size) {
39 }
40 
operator ==(const ShaderInfo & _1,const ShaderInfo & _2)41 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
42   return (
43       _1.src_code.bin == _2.src_code.bin &&
44       _1.src_code.size == _2.src_code.size);
45 }
46 
47 //
48 // ShaderLayout
49 //
50 
ShaderLayout(VkDevice device,const ShaderLayout::Signature & signature)51 ShaderLayout::ShaderLayout(
52     VkDevice device,
53     const ShaderLayout::Signature& signature)
54     : device_(device), handle_{VK_NULL_HANDLE} {
55   std::vector<VkDescriptorSetLayoutBinding> bindings;
56 
57   uint32_t binding_num = 0u;
58   for (const VkDescriptorType type : signature) {
59     bindings.push_back({
60         binding_num++, // binding
61         type, // descriptorType
62         1u, // descriptorCount
63         VK_SHADER_STAGE_COMPUTE_BIT, // stageFlags
64         nullptr, // pImmutableSamplers
65     });
66   }
67 
68   const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{
69       VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType
70       nullptr, // pNext
71       0u, // flags
72       static_cast<uint32_t>(bindings.size()), // bindingCount
73       bindings.data(), // pBindings
74   };
75 
76   VK_CHECK(vkCreateDescriptorSetLayout(
77       device_, &descriptor_set_layout_create_info, nullptr, &handle_));
78 }
79 
ShaderLayout(ShaderLayout && other)80 ShaderLayout::ShaderLayout(ShaderLayout&& other) noexcept
81     : device_(other.device_), handle_(other.handle_) {
82   other.handle_ = VK_NULL_HANDLE;
83 }
84 
~ShaderLayout()85 ShaderLayout::~ShaderLayout() {
86   if (handle_ == VK_NULL_HANDLE) {
87     return;
88   }
89   vkDestroyDescriptorSetLayout(device_, handle_, nullptr);
90   handle_ = VK_NULL_HANDLE;
91 }
92 
swap(ShaderLayout & lhs,ShaderLayout & rhs)93 void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept {
94   VkDevice tmp_device = lhs.device_;
95   VkDescriptorSetLayout tmp_handle = lhs.handle_;
96 
97   lhs.device_ = rhs.device_;
98   lhs.handle_ = rhs.handle_;
99 
100   rhs.device_ = tmp_device;
101   rhs.handle_ = tmp_handle;
102 }
103 
104 //
105 // ShaderModule
106 //
107 
ShaderModule(VkDevice device,const ShaderInfo & source)108 ShaderModule::ShaderModule(VkDevice device, const ShaderInfo& source)
109     : device_(device), handle_{VK_NULL_HANDLE} {
110   const uint32_t* code = source.src_code.bin;
111   uint32_t size = source.src_code.size;
112 
113   const VkShaderModuleCreateInfo shader_module_create_info{
114       VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType
115       nullptr, // pNext
116       0u, // flags
117       size, // codeSize
118       code, // pCode
119   };
120 
121   VK_CHECK(vkCreateShaderModule(
122       device_, &shader_module_create_info, nullptr, &handle_));
123 }
124 
ShaderModule(ShaderModule && other)125 ShaderModule::ShaderModule(ShaderModule&& other) noexcept
126     : device_(other.device_), handle_(other.handle_) {
127   other.handle_ = VK_NULL_HANDLE;
128 }
129 
~ShaderModule()130 ShaderModule::~ShaderModule() {
131   if (handle_ == VK_NULL_HANDLE) {
132     return;
133   }
134   vkDestroyShaderModule(device_, handle_, nullptr);
135   handle_ = VK_NULL_HANDLE;
136 }
137 
swap(ShaderModule & lhs,ShaderModule & rhs)138 void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept {
139   VkDevice tmp_device = lhs.device_;
140   VkShaderModule tmp_handle = lhs.handle_;
141 
142   lhs.device_ = rhs.device_;
143   lhs.handle_ = rhs.handle_;
144 
145   rhs.device_ = tmp_device;
146   rhs.handle_ = tmp_handle;
147 }
148 
149 //
150 // ShaderLayoutCache
151 //
152 
ShaderLayoutCache(VkDevice device)153 ShaderLayoutCache::ShaderLayoutCache(VkDevice device)
154     : cache_mutex_{}, device_(device), cache_{} {}
155 
ShaderLayoutCache(ShaderLayoutCache && other)156 ShaderLayoutCache::ShaderLayoutCache(ShaderLayoutCache&& other) noexcept
157     : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
158   std::lock_guard<std::mutex> lock(other.cache_mutex_);
159 }
160 
~ShaderLayoutCache()161 ShaderLayoutCache::~ShaderLayoutCache() {
162   purge();
163 }
164 
retrieve(const ShaderLayoutCache::Key & key)165 VkDescriptorSetLayout ShaderLayoutCache::retrieve(
166     const ShaderLayoutCache::Key& key) {
167   std::lock_guard<std::mutex> lock(cache_mutex_);
168 
169   auto it = cache_.find(key);
170   if (cache_.cend() == it) {
171     it = cache_.insert({key, ShaderLayoutCache::Value(device_, key)}).first;
172   }
173 
174   return it->second.handle();
175 }
176 
purge()177 void ShaderLayoutCache::purge() {
178   std::lock_guard<std::mutex> lock(cache_mutex_);
179   cache_.clear();
180 }
181 
182 //
183 // ShaderCache
184 //
185 
ShaderCache(VkDevice device)186 ShaderCache::ShaderCache(VkDevice device)
187     : cache_mutex_{}, device_(device), cache_{} {}
188 
ShaderCache(ShaderCache && other)189 ShaderCache::ShaderCache(ShaderCache&& other) noexcept
190     : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
191   std::lock_guard<std::mutex> lock(other.cache_mutex_);
192 }
193 
~ShaderCache()194 ShaderCache::~ShaderCache() {
195   purge();
196 }
197 
retrieve(const ShaderCache::Key & key)198 VkShaderModule ShaderCache::retrieve(const ShaderCache::Key& key) {
199   std::lock_guard<std::mutex> lock(cache_mutex_);
200 
201   auto it = cache_.find(key);
202   if (cache_.cend() == it) {
203     it = cache_.insert({key, ShaderCache::Value(device_, key)}).first;
204   }
205 
206   return it->second.handle();
207 }
208 
purge()209 void ShaderCache::purge() {
210   cache_.clear();
211 }
212 
213 } // namespace vkapi
214 } // namespace vkcompute
215