1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/runtime/kernel/operator_registry.h>

#include <cinttypes>
#include <cstring>

#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/platform.h>
#include <executorch/runtime/platform/system.h>
16
17 namespace executorch {
18 namespace runtime {
19
20 namespace {
21
22 // Maximum number of operators and their associated kernels that can be
23 // registered.
24 #ifdef MAX_KERNEL_NUM
25 constexpr uint32_t kMaxRegisteredKernels = MAX_KERNEL_NUM;
26 #else
27 constexpr uint32_t kMaxOperators = 250;
28 constexpr uint32_t kMaxKernelsPerOp = 8;
29 constexpr uint32_t kMaxRegisteredKernels = kMaxOperators * kMaxKernelsPerOp;
30 #endif
31
32 // Data that backs the kernel table. Since Kernel has a custom default
33 // constructor (implicitly, because it contains KernelKey, which has a custom
34 // ctor), some toolchains don't like having a global array of them: it would
35 // require constructing them at init time. Since we don't care about the values
36 // until we add each entry to the table, allocate static zeroed memory instead
37 // and point the table at it.
38 // @lint-ignore CLANGTIDY facebook-hte-CArray
39 alignas(sizeof(Kernel)) uint8_t
40 registered_kernels_data[kMaxRegisteredKernels * sizeof(Kernel)];
41
42 /// Global table of registered kernels.
43 Kernel* registered_kernels = reinterpret_cast<Kernel*>(registered_kernels_data);
44
45 /// The number of kernels registered in the table.
46 size_t num_registered_kernels = 0;
47
48 // Registers the kernels, but may return an error.
register_kernels_internal(const Span<const Kernel> kernels)49 Error register_kernels_internal(const Span<const Kernel> kernels) {
50 // Operator registration happens in static initialization time before or after
51 // PAL init, so call it here. It is safe to call multiple times.
52 ::et_pal_init();
53
54 if (kernels.size() + num_registered_kernels > kMaxRegisteredKernels) {
55 ET_LOG(
56 Error,
57 "The total number of kernels to be registered is larger than the limit "
58 "%" PRIu32 ". %" PRIu32
59 " kernels are already registered and we're trying to register another "
60 "%" PRIu32 " kernels.",
61 kMaxRegisteredKernels,
62 (uint32_t)num_registered_kernels,
63 (uint32_t)kernels.size());
64 ET_LOG(Error, "======== Kernels already in the registry: ========");
65 for (size_t i = 0; i < num_registered_kernels; i++) {
66 ET_LOG(Error, "%s", registered_kernels[i].name_);
67 ET_LOG_KERNEL_KEY(registered_kernels[i].kernel_key_);
68 }
69 ET_LOG(Error, "======== Kernels being registered: ========");
70 for (size_t i = 0; i < kernels.size(); i++) {
71 ET_LOG(Error, "%s", kernels[i].name_);
72 ET_LOG_KERNEL_KEY(kernels[i].kernel_key_);
73 }
74 return Error::Internal;
75 }
76 // for debugging purpose
77 const char* lib_name = et_pal_get_shared_library_name(kernels.data());
78
79 for (const auto& kernel : kernels) {
80 // Linear search. This is fine if the number of kernels is small.
81 for (int32_t i = 0; i < num_registered_kernels; i++) {
82 Kernel k = registered_kernels[i];
83 if (strcmp(kernel.name_, k.name_) == 0 &&
84 kernel.kernel_key_ == k.kernel_key_) {
85 ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name);
86 ET_LOG_KERNEL_KEY(k.kernel_key_);
87 return Error::InvalidArgument;
88 }
89 }
90 registered_kernels[num_registered_kernels++] = kernel;
91 }
92 ET_LOG(
93 Debug,
94 "Successfully registered all kernels from shared library: %s",
95 lib_name);
96
97 return Error::Ok;
98 }
99
100 } // namespace
101
102 // Registers the kernels, but panics if an error occurs. Always returns Ok.
register_kernels(const Span<const Kernel> kernels)103 Error register_kernels(const Span<const Kernel> kernels) {
104 Error success = register_kernels_internal(kernels);
105 if (success == Error::InvalidArgument || success == Error::Internal) {
106 ET_CHECK_MSG(
107 false,
108 "Kernel registration failed with error %" PRIu32
109 ", see error log for details.",
110 static_cast<uint32_t>(success));
111 }
112 return success;
113 }
114
115 namespace {
// Writes `num` as decimal ASCII digits at `buf` (no terminator) and returns
// the number of characters written: 1 for values below 10, otherwise 2.
//
// NOTE(review): only values in [0, 99] yield valid digits. Larger or negative
// values produce non-digit characters. The callers in this file pass dtype and
// dim-order values, which are presumably small; confirm before relying on it.
int copy_char_as_number_to_buf(char num, char* buf) {
  if (num >= 10) {
    // Two digits: tens, then units.
    buf[0] = static_cast<char>('0' + num / 10);
    buf[1] = static_cast<char>('0' + num % 10);
    return 2;
  }
  buf[0] = static_cast<char>('0' + num);
  return 1;
}
129 } // namespace
130
131 namespace internal {
make_kernel_key_string(Span<const TensorMeta> key,char * buf)132 void make_kernel_key_string(Span<const TensorMeta> key, char* buf) {
133 if (key.empty()) {
134 // If no tensor is present in an op, kernel key does not apply
135 return;
136 }
137 strncpy(buf, "v1/", 3);
138 buf += 3;
139 for (size_t i = 0; i < key.size(); i++) {
140 auto& meta = key[i];
141 buf += copy_char_as_number_to_buf((char)meta.dtype_, buf);
142 *buf = ';';
143 buf += 1;
144 for (int j = 0; j < meta.dim_order_.size(); j++) {
145 buf += copy_char_as_number_to_buf((char)meta.dim_order_[j], buf);
146 if (j != meta.dim_order_.size() - 1) {
147 *buf = ',';
148 buf += 1;
149 }
150 }
151 *buf = (i < (key.size() - 1)) ? '|' : 0x00;
152 buf += 1;
153 }
154 }
155 } // namespace internal
156
registry_has_op_function(const char * name,Span<const TensorMeta> meta_list)157 bool registry_has_op_function(
158 const char* name,
159 Span<const TensorMeta> meta_list) {
160 return get_op_function_from_registry(name, meta_list).ok();
161 }
162
get_op_function_from_registry(const char * name,Span<const TensorMeta> meta_list)163 Result<OpFunction> get_op_function_from_registry(
164 const char* name,
165 Span<const TensorMeta> meta_list) {
166 // @lint-ignore CLANGTIDY facebook-hte-CArray
167 char buf[KernelKey::MAX_SIZE] = {0};
168 internal::make_kernel_key_string(meta_list, buf);
169 KernelKey kernel_key = KernelKey(buf);
170
171 int32_t fallback_idx = -1;
172 for (size_t idx = 0; idx < num_registered_kernels; idx++) {
173 if (strcmp(registered_kernels[idx].name_, name) == 0) {
174 if (registered_kernels[idx].kernel_key_ == kernel_key) {
175 return registered_kernels[idx].op_;
176 }
177 if (registered_kernels[idx].kernel_key_.is_fallback()) {
178 fallback_idx = idx;
179 }
180 }
181 }
182 if (fallback_idx != -1) {
183 return registered_kernels[fallback_idx].op_;
184 }
185 ET_LOG(Error, "kernel '%s' not found.", name);
186 ET_LOG_TENSOR_META(meta_list);
187 return Error::OperatorMissing;
188 }
189
get_registered_kernels()190 Span<const Kernel> get_registered_kernels() {
191 return {registered_kernels, num_registered_kernels};
192 }
193
194 } // namespace runtime
195 } // namespace executorch
196