#include #include namespace torch { namespace executor { KernelRegistry& getKernelRegistry() { static KernelRegistry kernel_registry; return kernel_registry; } bool register_kernels(const ArrayRef& kernels) { return getKernelRegistry().register_kernels(kernels); } bool KernelRegistry::register_kernels( const ArrayRef& kernels) { for (const auto& kernel : kernels) { this->kernels_map_[kernel.name_] = kernel.kernel_; } return true; } bool hasKernelFn(const char* name) { return getKernelRegistry().hasKernelFn(name); } bool KernelRegistry::hasKernelFn(const char* name) { auto kernel = this->kernels_map_.find(name); return kernel != this->kernels_map_.end(); } KernelFunction& getKernelFn(const char* name) { return getKernelRegistry().getKernelFn(name); } KernelFunction& KernelRegistry::getKernelFn(const char* name) { auto kernel = this->kernels_map_.find(name); TORCH_CHECK_MSG(kernel != this->kernels_map_.end(), "Kernel not found!"); return kernel->second; } } // namespace executor } // namespace torch