xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12 #include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
13 
14 #include <functional>
15 #include <unordered_map>
16 
// Evaluates to whether an operator is registered under `name` in the global
// Vulkan operator registry.
#define VK_HAS_OP(name) (::vkcompute::operator_registry().has_op(name))

// Evaluates to the delegate function registered under `name` in the global
// Vulkan operator registry.
#define VK_GET_OP_FN(name) (::vkcompute::operator_registry().get_op_fn(name))
20 
// Registers `function` in the global Vulkan operator registry under the
// stringified token sequence `name`; e.g. VK_REGISTER_OP(aten.add.Tensor, fn)
// registers `fn` under the key "aten.add.Tensor".
//
// The function pointer converts implicitly to the registry's (const)
// std::function parameter type. No std::bind with pass-through placeholders
// is needed, because that would only add a redundant wrapper.
#define VK_REGISTER_OP(name, function) \
  ::vkcompute::operator_registry().register_op(#name, &function)
25 
// Declares a file-local function `register_ops` whose body is the braced block
// written after the macro, plus a static OperatorRegisterInit object `reg`
// whose constructor calls that function. When expanded at namespace scope, the
// body therefore runs once, during dynamic initialization of this translation
// unit's static objects. Typical use:
//
//   REGISTER_OPERATORS {
//     VK_REGISTER_OP(aten.add.Tensor, add);
//   }
//
// OperatorRegisterInit is fully qualified, matching the other macros in this
// header, so the macro also compiles when expanded outside namespace vkcompute.
#define REGISTER_OPERATORS                                           \
  static void register_ops();                                        \
  static const ::vkcompute::OperatorRegisterInit reg(&register_ops); \
  static void register_ops()
30 
31 namespace vkcompute {
32 
33 /*
34  * The Vulkan operator registry maps ATen operator names
35  * to their Vulkan delegate function implementation. It is
36  * a simplified version of
37  * executorch/runtime/kernel/operator_registry.h that uses
38  * the C++ Standard Library.
39  */
40 class OperatorRegistry final {
41   using OpFunction =
42       const std::function<void(ComputeGraph&, const std::vector<ValueRef>&)>;
43   using OpTable = std::unordered_map<std::string, OpFunction>;
44 
45   OpTable table_;
46 
47  public:
48   /*
49    * Check if the registry has an operator registered under the given name
50    */
51   bool has_op(const std::string& name);
52 
53   /*
54    * Given an operator name, return the Vulkan delegate function
55    */
56   OpFunction& get_op_fn(const std::string& name);
57 
58   /*
59    * Register a function to a given operator name
60    */
61   void register_op(const std::string& name, OpFunction& fn);
62 };
63 
64 class OperatorRegisterInit final {
65   using InitFn = void();
66 
67  public:
OperatorRegisterInit(InitFn * init_fn)68   explicit OperatorRegisterInit(InitFn* init_fn) {
69     init_fn();
70   }
71 };
72 
73 // The Vulkan operator registry is global. It is retrieved using this function,
74 // where it is declared as a static local variable.
75 OperatorRegistry& operator_registry();
76 
77 } // namespace vkcompute
78