// xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/api/ShaderRegistry.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #pragma once
10 
11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
12 
#include <executorch/backends/vulkan/runtime/vk_api/Shader.h>

#include <cstdint>
#include <string>
#include <unordered_map>
17 
// Look up a shader's ShaderInfo in the global registry, where the shader name
// is an expression evaluating to a string (e.g. a std::string variable).
#define VK_KERNEL_FROM_STR(shader_name_str) \
  ::vkcompute::api::shader_registry().get_shader_info(shader_name_str)

// Convenience form of VK_KERNEL_FROM_STR that takes the shader name as a bare
// token and stringizes it, so VK_KERNEL(my_shader) looks up "my_shader".
#define VK_KERNEL(shader_name) VK_KERNEL_FROM_STR(#shader_name)
23 
24 namespace vkcompute {
25 namespace api {
26 
/*
 * Key type of an op's dispatch table: ShaderRegistry's Dispatcher maps each
 * DispatchKey to a shader name, and entries are added per op through
 * ShaderRegistry::register_op_dispatch().
 *
 * NOTE(review): the enumerator names suggest GPU-vendor-specific variants
 * (Qualcomm Adreno, Arm Mali), a generic fallback (CATCHALL) and a forced
 * choice (OVERRIDE); the actual selection policy is not in this header, so
 * confirm it in ShaderRegistry.cpp before relying on that reading.
 *
 * The explicit std::int8_t underlying type keeps the enum one byte wide.
 * It is spelled with the std:: prefix because <cstdint> is what declares it.
 */
enum class DispatchKey : std::int8_t {
  CATCHALL,
  ADRENO,
  MALI,
  OVERRIDE,
};
33 
34 class ShaderRegistry final {
35   using ShaderListing = std::unordered_map<std::string, vkapi::ShaderInfo>;
36   using Dispatcher = std::unordered_map<DispatchKey, std::string>;
37   using Registry = std::unordered_map<std::string, Dispatcher>;
38 
39   ShaderListing listings_;
40   Dispatcher dispatcher_;
41   Registry registry_;
42 
43  public:
44   /*
45    * Check if the registry has a shader registered under the given name
46    */
47   bool has_shader(const std::string& shader_name);
48 
49   /*
50    * Check if the registry has a dispatch registered under the given name
51    */
52   bool has_dispatch(const std::string& op_name);
53 
54   /*
55    * Register a ShaderInfo to a given shader name
56    */
57   void register_shader(vkapi::ShaderInfo&& shader_info);
58 
59   /*
60    * Register a dispatch entry to the given op name
61    */
62   void register_op_dispatch(
63       const std::string& op_name,
64       const DispatchKey key,
65       const std::string& shader_name);
66 
67   /*
68    * Given a shader name, return the ShaderInfo which contains the SPIRV binary
69    */
70   const vkapi::ShaderInfo& get_shader_info(const std::string& shader_name);
71 };
72 
/*
 * Runs a registration callback as a side effect of construction.
 *
 * Constructing an instance invokes `init_fn` exactly once; the object holds
 * no state of its own. NOTE(review): instances are presumably defined as
 * namespace-scope statics so that shader registration happens during static
 * initialization — confirm against the generated registration code.
 *
 * A null `init_fn` is skipped rather than called, because calling through a
 * null function pointer is undefined behavior.
 */
class ShaderRegisterInit final {
  using InitFn = void();

 public:
  // Deliberately not `explicit`: the constructor stays implicit so any
  // existing copy-initialization call sites keep compiling.
  ShaderRegisterInit(InitFn* init_fn) {
    if (init_fn != nullptr) {
      init_fn();
    }
  }
};
81 
82 // The global shader registry is retrieved using this function, where it is
83 // declared as a static local variable.
84 ShaderRegistry& shader_registry();
85 
86 } // namespace api
87 } // namespace vkcompute
88