1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/vulkan/runtime/vk_api/Shader.h>
10
11 #include <utility>
12
13 namespace vkcompute {
14 namespace vkapi {
15
16 //
17 // ShaderInfo
18 //
19
ShaderInfo()20 ShaderInfo::ShaderInfo()
21 : src_code{
22 nullptr,
23 0u,
24 } {}
25
ShaderInfo(std::string name,const uint32_t * const spirv_bin,const uint32_t size,std::vector<VkDescriptorType> layout,const utils::uvec3 tile_size)26 ShaderInfo::ShaderInfo(
27 std::string name,
28 const uint32_t* const spirv_bin,
29 const uint32_t size,
30 std::vector<VkDescriptorType> layout,
31 const utils::uvec3 tile_size)
32 : src_code{
33 spirv_bin,
34 size,
35 },
36 kernel_name{std::move(name)},
37 kernel_layout{std::move(layout)},
38 out_tile_size(tile_size) {
39 }
40
operator ==(const ShaderInfo & _1,const ShaderInfo & _2)41 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
42 return (
43 _1.src_code.bin == _2.src_code.bin &&
44 _1.src_code.size == _2.src_code.size);
45 }
46
47 //
48 // ShaderLayout
49 //
50
ShaderLayout(VkDevice device,const ShaderLayout::Signature & signature)51 ShaderLayout::ShaderLayout(
52 VkDevice device,
53 const ShaderLayout::Signature& signature)
54 : device_(device), handle_{VK_NULL_HANDLE} {
55 std::vector<VkDescriptorSetLayoutBinding> bindings;
56
57 uint32_t binding_num = 0u;
58 for (const VkDescriptorType type : signature) {
59 bindings.push_back({
60 binding_num++, // binding
61 type, // descriptorType
62 1u, // descriptorCount
63 VK_SHADER_STAGE_COMPUTE_BIT, // stageFlags
64 nullptr, // pImmutableSamplers
65 });
66 }
67
68 const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{
69 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType
70 nullptr, // pNext
71 0u, // flags
72 static_cast<uint32_t>(bindings.size()), // bindingCount
73 bindings.data(), // pBindings
74 };
75
76 VK_CHECK(vkCreateDescriptorSetLayout(
77 device_, &descriptor_set_layout_create_info, nullptr, &handle_));
78 }
79
ShaderLayout(ShaderLayout && other)80 ShaderLayout::ShaderLayout(ShaderLayout&& other) noexcept
81 : device_(other.device_), handle_(other.handle_) {
82 other.handle_ = VK_NULL_HANDLE;
83 }
84
~ShaderLayout()85 ShaderLayout::~ShaderLayout() {
86 if (handle_ == VK_NULL_HANDLE) {
87 return;
88 }
89 vkDestroyDescriptorSetLayout(device_, handle_, nullptr);
90 handle_ = VK_NULL_HANDLE;
91 }
92
swap(ShaderLayout & lhs,ShaderLayout & rhs)93 void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept {
94 VkDevice tmp_device = lhs.device_;
95 VkDescriptorSetLayout tmp_handle = lhs.handle_;
96
97 lhs.device_ = rhs.device_;
98 lhs.handle_ = rhs.handle_;
99
100 rhs.device_ = tmp_device;
101 rhs.handle_ = tmp_handle;
102 }
103
104 //
105 // ShaderModule
106 //
107
ShaderModule(VkDevice device,const ShaderInfo & source)108 ShaderModule::ShaderModule(VkDevice device, const ShaderInfo& source)
109 : device_(device), handle_{VK_NULL_HANDLE} {
110 const uint32_t* code = source.src_code.bin;
111 uint32_t size = source.src_code.size;
112
113 const VkShaderModuleCreateInfo shader_module_create_info{
114 VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType
115 nullptr, // pNext
116 0u, // flags
117 size, // codeSize
118 code, // pCode
119 };
120
121 VK_CHECK(vkCreateShaderModule(
122 device_, &shader_module_create_info, nullptr, &handle_));
123 }
124
ShaderModule(ShaderModule && other)125 ShaderModule::ShaderModule(ShaderModule&& other) noexcept
126 : device_(other.device_), handle_(other.handle_) {
127 other.handle_ = VK_NULL_HANDLE;
128 }
129
~ShaderModule()130 ShaderModule::~ShaderModule() {
131 if (handle_ == VK_NULL_HANDLE) {
132 return;
133 }
134 vkDestroyShaderModule(device_, handle_, nullptr);
135 handle_ = VK_NULL_HANDLE;
136 }
137
swap(ShaderModule & lhs,ShaderModule & rhs)138 void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept {
139 VkDevice tmp_device = lhs.device_;
140 VkShaderModule tmp_handle = lhs.handle_;
141
142 lhs.device_ = rhs.device_;
143 lhs.handle_ = rhs.handle_;
144
145 rhs.device_ = tmp_device;
146 rhs.handle_ = tmp_handle;
147 }
148
149 //
150 // ShaderLayoutCache
151 //
152
ShaderLayoutCache(VkDevice device)153 ShaderLayoutCache::ShaderLayoutCache(VkDevice device)
154 : cache_mutex_{}, device_(device), cache_{} {}
155
ShaderLayoutCache(ShaderLayoutCache && other)156 ShaderLayoutCache::ShaderLayoutCache(ShaderLayoutCache&& other) noexcept
157 : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
158 std::lock_guard<std::mutex> lock(other.cache_mutex_);
159 }
160
~ShaderLayoutCache()161 ShaderLayoutCache::~ShaderLayoutCache() {
162 purge();
163 }
164
retrieve(const ShaderLayoutCache::Key & key)165 VkDescriptorSetLayout ShaderLayoutCache::retrieve(
166 const ShaderLayoutCache::Key& key) {
167 std::lock_guard<std::mutex> lock(cache_mutex_);
168
169 auto it = cache_.find(key);
170 if (cache_.cend() == it) {
171 it = cache_.insert({key, ShaderLayoutCache::Value(device_, key)}).first;
172 }
173
174 return it->second.handle();
175 }
176
purge()177 void ShaderLayoutCache::purge() {
178 std::lock_guard<std::mutex> lock(cache_mutex_);
179 cache_.clear();
180 }
181
182 //
183 // ShaderCache
184 //
185
ShaderCache(VkDevice device)186 ShaderCache::ShaderCache(VkDevice device)
187 : cache_mutex_{}, device_(device), cache_{} {}
188
ShaderCache(ShaderCache && other)189 ShaderCache::ShaderCache(ShaderCache&& other) noexcept
190 : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
191 std::lock_guard<std::mutex> lock(other.cache_mutex_);
192 }
193
~ShaderCache()194 ShaderCache::~ShaderCache() {
195 purge();
196 }
197
retrieve(const ShaderCache::Key & key)198 VkShaderModule ShaderCache::retrieve(const ShaderCache::Key& key) {
199 std::lock_guard<std::mutex> lock(cache_mutex_);
200
201 auto it = cache_.find(key);
202 if (cache_.cend() == it) {
203 it = cache_.insert({key, ShaderCache::Value(device_, key)}).first;
204 }
205
206 return it->second.handle();
207 }
208
purge()209 void ShaderCache::purge() {
210 cache_.clear();
211 }
212
213 } // namespace vkapi
214 } // namespace vkcompute
215