1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> 12 #include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h> 13 14 #include <functional> 15 #include <unordered_map> 16 17 #define VK_HAS_OP(name) ::vkcompute::operator_registry().has_op(name) 18 19 #define VK_GET_OP_FN(name) ::vkcompute::operator_registry().get_op_fn(name) 20 21 #define VK_REGISTER_OP(name, function) \ 22 ::vkcompute::operator_registry().register_op( \ 23 #name, \ 24 std::bind(&function, std::placeholders::_1, std::placeholders::_2)) 25 26 #define REGISTER_OPERATORS \ 27 static void register_ops(); \ 28 static const OperatorRegisterInit reg(®ister_ops); \ 29 static void register_ops() 30 31 namespace vkcompute { 32 33 /* 34 * The Vulkan operator registry maps ATen operator names 35 * to their Vulkan delegate function implementation. It is 36 * a simplified version of 37 * executorch/runtime/kernel/operator_registry.h that uses 38 * the C++ Standard Library. 39 */ 40 class OperatorRegistry final { 41 using OpFunction = 42 const std::function<void(ComputeGraph&, const std::vector<ValueRef>&)>; 43 using OpTable = std::unordered_map<std::string, OpFunction>; 44 45 OpTable table_; 46 47 public: 48 /* 49 * Check if the registry has an operator registered under the given name 50 */ 51 bool has_op(const std::string& name); 52 53 /* 54 * Given an operator name, return the Vulkan delegate function 55 */ 56 OpFunction& get_op_fn(const std::string& name); 57 58 /* 59 * Register a function to a given operator name 60 */ 61 void register_op(const std::string& name, OpFunction& fn); 62 }; 63 64 class OperatorRegisterInit final { 65 using InitFn = void(); 66 67 public: OperatorRegisterInit(InitFn * init_fn)68 explicit OperatorRegisterInit(InitFn* init_fn) { 69 init_fn(); 70 } 71 }; 72 73 // The Vulkan operator registry is global. It is retrieved using this function, 74 // where it is declared as a static local variable. 75 OperatorRegistry& operator_registry(); 76 77 } // namespace vkcompute 78