1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName 12 13 #include <executorch/backends/vulkan/runtime/vk_api/Shader.h> 14 15 #include <string> 16 #include <unordered_map> 17 18 #define VK_KERNEL(shader_name) \ 19 ::vkcompute::api::shader_registry().get_shader_info(#shader_name) 20 21 #define VK_KERNEL_FROM_STR(shader_name_str) \ 22 ::vkcompute::api::shader_registry().get_shader_info(shader_name_str) 23 24 namespace vkcompute { 25 namespace api { 26 27 enum class DispatchKey : int8_t { 28 CATCHALL, 29 ADRENO, 30 MALI, 31 OVERRIDE, 32 }; 33 34 class ShaderRegistry final { 35 using ShaderListing = std::unordered_map<std::string, vkapi::ShaderInfo>; 36 using Dispatcher = std::unordered_map<DispatchKey, std::string>; 37 using Registry = std::unordered_map<std::string, Dispatcher>; 38 39 ShaderListing listings_; 40 Dispatcher dispatcher_; 41 Registry registry_; 42 43 public: 44 /* 45 * Check if the registry has a shader registered under the given name 46 */ 47 bool has_shader(const std::string& shader_name); 48 49 /* 50 * Check if the registry has a dispatch registered under the given name 51 */ 52 bool has_dispatch(const std::string& op_name); 53 54 /* 55 * Register a ShaderInfo to a given shader name 56 */ 57 void register_shader(vkapi::ShaderInfo&& shader_info); 58 59 /* 60 * Register a dispatch entry to the given op name 61 */ 62 void register_op_dispatch( 63 const std::string& op_name, 64 const DispatchKey key, 65 const std::string& shader_name); 66 67 /* 68 * Given a shader name, return the ShaderInfo which contains the SPIRV binary 69 */ 70 const vkapi::ShaderInfo& get_shader_info(const std::string& shader_name); 71 }; 72 73 class ShaderRegisterInit final { 74 using InitFn = void(); 75 76 public: ShaderRegisterInit(InitFn * init_fn)77 ShaderRegisterInit(InitFn* init_fn) { 78 init_fn(); 79 }; 80 }; 81 82 // The global shader registry is retrieved using this function, where it is 83 // declared as a static local variable. 84 ShaderRegistry& shader_registry(); 85 86 } // namespace api 87 } // namespace vkcompute 88