/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include namespace executorch { namespace runtime { namespace { // Maximum number of operators and their associated kernels that can be // registered. #ifdef MAX_KERNEL_NUM constexpr uint32_t kMaxRegisteredKernels = MAX_KERNEL_NUM; #else constexpr uint32_t kMaxOperators = 250; constexpr uint32_t kMaxKernelsPerOp = 8; constexpr uint32_t kMaxRegisteredKernels = kMaxOperators * kMaxKernelsPerOp; #endif // Data that backs the kernel table. Since Kernel has a custom default // constructor (implicitly, because it contains KernelKey, which has a custom // ctor), some toolchains don't like having a global array of them: it would // require constructing them at init time. Since we don't care about the values // until we add each entry to the table, allocate static zeroed memory instead // and point the table at it. // @lint-ignore CLANGTIDY facebook-hte-CArray alignas(sizeof(Kernel)) uint8_t registered_kernels_data[kMaxRegisteredKernels * sizeof(Kernel)]; /// Global table of registered kernels. Kernel* registered_kernels = reinterpret_cast(registered_kernels_data); /// The number of kernels registered in the table. size_t num_registered_kernels = 0; // Registers the kernels, but may return an error. Error register_kernels_internal(const Span kernels) { // Operator registration happens in static initialization time before or after // PAL init, so call it here. It is safe to call multiple times. ::et_pal_init(); if (kernels.size() + num_registered_kernels > kMaxRegisteredKernels) { ET_LOG( Error, "The total number of kernels to be registered is larger than the limit " "%" PRIu32 ". %" PRIu32 " kernels are already registered and we're trying to register another " "%" PRIu32 " kernels.", kMaxRegisteredKernels, (uint32_t)num_registered_kernels, (uint32_t)kernels.size()); ET_LOG(Error, "======== Kernels already in the registry: ========"); for (size_t i = 0; i < num_registered_kernels; i++) { ET_LOG(Error, "%s", registered_kernels[i].name_); ET_LOG_KERNEL_KEY(registered_kernels[i].kernel_key_); } ET_LOG(Error, "======== Kernels being registered: ========"); for (size_t i = 0; i < kernels.size(); i++) { ET_LOG(Error, "%s", kernels[i].name_); ET_LOG_KERNEL_KEY(kernels[i].kernel_key_); } return Error::Internal; } // for debugging purpose const char* lib_name = et_pal_get_shared_library_name(kernels.data()); for (const auto& kernel : kernels) { // Linear search. This is fine if the number of kernels is small. for (int32_t i = 0; i < num_registered_kernels; i++) { Kernel k = registered_kernels[i]; if (strcmp(kernel.name_, k.name_) == 0 && kernel.kernel_key_ == k.kernel_key_) { ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name); ET_LOG_KERNEL_KEY(k.kernel_key_); return Error::InvalidArgument; } } registered_kernels[num_registered_kernels++] = kernel; } ET_LOG( Debug, "Successfully registered all kernels from shared library: %s", lib_name); return Error::Ok; } } // namespace // Registers the kernels, but panics if an error occurs. Always returns Ok. Error register_kernels(const Span kernels) { Error success = register_kernels_internal(kernels); if (success == Error::InvalidArgument || success == Error::Internal) { ET_CHECK_MSG( false, "Kernel registration failed with error %" PRIu32 ", see error log for details.", static_cast(success)); } return success; } namespace { int copy_char_as_number_to_buf(char num, char* buf) { if ((char)num < 10) { *buf = '0' + (char)num; buf += 1; return 1; } else { *buf = '0' + ((char)num) / 10; buf += 1; *buf = '0' + ((char)num) % 10; buf += 1; return 2; } } } // namespace namespace internal { void make_kernel_key_string(Span key, char* buf) { if (key.empty()) { // If no tensor is present in an op, kernel key does not apply return; } strncpy(buf, "v1/", 3); buf += 3; for (size_t i = 0; i < key.size(); i++) { auto& meta = key[i]; buf += copy_char_as_number_to_buf((char)meta.dtype_, buf); *buf = ';'; buf += 1; for (int j = 0; j < meta.dim_order_.size(); j++) { buf += copy_char_as_number_to_buf((char)meta.dim_order_[j], buf); if (j != meta.dim_order_.size() - 1) { *buf = ','; buf += 1; } } *buf = (i < (key.size() - 1)) ? '|' : 0x00; buf += 1; } } } // namespace internal bool registry_has_op_function( const char* name, Span meta_list) { return get_op_function_from_registry(name, meta_list).ok(); } Result get_op_function_from_registry( const char* name, Span meta_list) { // @lint-ignore CLANGTIDY facebook-hte-CArray char buf[KernelKey::MAX_SIZE] = {0}; internal::make_kernel_key_string(meta_list, buf); KernelKey kernel_key = KernelKey(buf); int32_t fallback_idx = -1; for (size_t idx = 0; idx < num_registered_kernels; idx++) { if (strcmp(registered_kernels[idx].name_, name) == 0) { if (registered_kernels[idx].kernel_key_ == kernel_key) { return registered_kernels[idx].op_; } if (registered_kernels[idx].kernel_key_.is_fallback()) { fallback_idx = idx; } } } if (fallback_idx != -1) { return registered_kernels[fallback_idx].op_; } ET_LOG(Error, "kernel '%s' not found.", name); ET_LOG_TENSOR_META(meta_list); return Error::OperatorMissing; } Span get_registered_kernels() { return {registered_kernels, num_registered_kernels}; } } // namespace runtime } // namespace executorch