// xref: /aosp_15_r20/external/pytorch/c10/test/util/registry_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <iostream>
3 #include <memory>
4 
5 #include <c10/util/Registry.h>
6 
// Note: we use a namespace other than c10 to test that the macros defined in
// Registry.h actually work when used from outside the c10 namespace.
9 namespace c10_test {
10 
// Minimal polymorphic base class; it is the object type managed by
// FooRegistry in this test file.
class Foo {
 public:
  // The int argument is accepted but never used, so the parameter is left
  // unnamed to avoid an unused-parameter warning. `explicit` prevents an
  // implicit int -> Foo conversion.
  explicit Foo(int /*x*/) {
    // LOG(INFO) << "Foo " << x;
  }
  // Virtual so that derived objects are destroyed correctly through a Foo
  // pointer.
  virtual ~Foo() = default;
};
18 
19 C10_DECLARE_REGISTRY(FooRegistry, Foo, int);
20 C10_DEFINE_REGISTRY(FooRegistry, Foo, int);
21 #define REGISTER_FOO(clsname) C10_REGISTER_CLASS(FooRegistry, clsname, clsname)
22 
23 class Bar : public Foo {
24  public:
Bar(int x)25   explicit Bar(int x) : Foo(x) {
26     // LOG(INFO) << "Bar " << x;
27   }
28 };
29 REGISTER_FOO(Bar);
30 
31 class AnotherBar : public Foo {
32  public:
AnotherBar(int x)33   explicit AnotherBar(int x) : Foo(x) {
34     // LOG(INFO) << "AnotherBar " << x;
35   }
36 };
37 REGISTER_FOO(AnotherBar);
38 
// Each class registered through REGISTER_FOO must be creatable by its key.
TEST(RegistryTest, CanRunCreator) {
  std::unique_ptr<Foo> bar(FooRegistry()->Create("Bar", 1));
  // EXPECT_NE on the raw pointer matches the other pointer checks in this
  // file and reports both operands on failure.
  EXPECT_NE(bar.get(), nullptr) << "Cannot create bar.";
  std::unique_ptr<Foo> another_bar(FooRegistry()->Create("AnotherBar", 1));
  EXPECT_NE(another_bar.get(), nullptr);
}
45 
// Asking the registry for a key that was never registered must yield a null
// result.
TEST(RegistryTest, ReturnNullOnNonExistingCreator) {
  EXPECT_EQ(FooRegistry()->Create("Non-existing bar", 1), nullptr);
}
49 
50 // C10_REGISTER_CLASS_WITH_PRIORITY defines static variable
RegisterFooDefault()51 void RegisterFooDefault() {
52   C10_REGISTER_CLASS_WITH_PRIORITY(
53       FooRegistry, FooWithPriority, c10::REGISTRY_DEFAULT, Foo);
54 }
55 
RegisterFooDefaultAgain()56 void RegisterFooDefaultAgain() {
57   C10_REGISTER_CLASS_WITH_PRIORITY(
58       FooRegistry, FooWithPriority, c10::REGISTRY_DEFAULT, Foo);
59 }
60 
RegisterFooBarFallback()61 void RegisterFooBarFallback() {
62   C10_REGISTER_CLASS_WITH_PRIORITY(
63       FooRegistry, FooWithPriority, c10::REGISTRY_FALLBACK, Bar);
64 }
65 
RegisterFooBarPreferred()66 void RegisterFooBarPreferred() {
67   C10_REGISTER_CLASS_WITH_PRIORITY(
68       FooRegistry, FooWithPriority, c10::REGISTRY_PREFERRED, Bar);
69 }
70 
// Exercises priority handling when several classes are registered under the
// same key ("FooWithPriority"): a duplicate default-priority registration
// throws, a fallback registration is ignored, and a preferred registration
// replaces the default one.
TEST(RegistryTest, RegistryPriorities) {
  // NOTE(review): presumably switches the registry from terminating the
  // process to throwing on a registration conflict, which the EXPECT_THROW
  // below depends on - confirm against Registry.h.
  FooRegistry()->SetTerminate(false);
  RegisterFooDefault();

  // Throws because Foo is already registered with default priority.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(RegisterFooDefaultAgain(), std::runtime_error);

  // The remaining checks use dynamic_cast to identify the created class, so
  // they are compiled only when __GXX_RTTI is defined.
#ifdef __GXX_RTTI
  // Not going to register Bar because Foo is registered with default
  // priority, so the created object must not be a Bar.
  RegisterFooBarFallback();
  std::unique_ptr<Foo> bar1(FooRegistry()->Create("FooWithPriority", 1));
  EXPECT_EQ(dynamic_cast<Bar*>(bar1.get()), nullptr);

  // Will register Bar because of its higher priority, so the created object
  // must now be a Bar.
  RegisterFooBarPreferred();
  std::unique_ptr<Foo> bar2(FooRegistry()->Create("FooWithPriority", 1));
  EXPECT_NE(dynamic_cast<Bar*>(bar2.get()), nullptr);
#endif
}
91 
92 } // namespace c10_test
93