// xref: /aosp_15_r20/external/pytorch/test/mobile/nnc/test_registry.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <torch/csrc/jit/mobile/nnc/registry.h>
3 
4 namespace torch {
5 namespace jit {
6 namespace mobile {
7 namespace nnc {
8 
9 extern "C" {
// Fake kernel entry point used by the registry tests. The argument array is
// ignored; the constant 1 is the value FindAndRun expects back from the
// kernel registered under "foo:v1:VERTOKEN".
int generated_asm_kernel_foo(void** /*args*/) {
  return 1;
}
13 
// Fake kernel entry point used by the registry tests. The argument array is
// ignored; the constant 2 is the value FindAndRun expects back from the
// kernel registered under "bar:v1:VERTOKEN".
int generated_asm_kernel_bar(void** /*args*/) {
  return 2;
}
17 } // extern "C"
18 
19 REGISTER_NNC_KERNEL("foo:v1:VERTOKEN", generated_asm_kernel_foo)
20 REGISTER_NNC_KERNEL("bar:v1:VERTOKEN", generated_asm_kernel_bar)
21 
// Looks up each registered kernel by id and checks that executing it returns
// the constant produced by the corresponding fake entry point.
TEST(MobileNNCRegistryTest, FindAndRun) {
  auto foo_kernel = registry::get_nnc_kernel("foo:v1:VERTOKEN");
  // Abort this test on a failed lookup instead of dereferencing the handle.
  // NOTE(review): assumes get_nnc_kernel yields a null handle for an
  // unregistered id - confirm against registry.h.
  ASSERT_TRUE(foo_kernel != nullptr);
  EXPECT_EQ(foo_kernel->execute(nullptr), 1);

  auto bar_kernel = registry::get_nnc_kernel("bar:v1:VERTOKEN");
  ASSERT_TRUE(bar_kernel != nullptr);
  EXPECT_EQ(bar_kernel->execute(nullptr), 2);
}
29 
// An id that was never passed to REGISTER_NNC_KERNEL in this file must be
// reported as absent by the registry.
TEST(MobileNNCRegistryTest, NoKernel) {
  EXPECT_FALSE(registry::has_nnc_kernel("missing"));
}
33 
34 } // namespace nnc
35 } // namespace mobile
36 } // namespace jit
37 } // namespace torch
38