xref: /aosp_15_r20/external/pytorch/c10/util/Registry.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifndef C10_UTIL_REGISTRY_H_
2 #define C10_UTIL_REGISTRY_H_
3 
4 /**
5  * Simple registry implementation that uses static variables to
6  * register object creators during program initialization time.
7  */
8 
// NB: This Registry works poorly when other namespaces are involved.
// Make all macro invocations from inside the `at` namespace.
11 
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Type.h>
25 
26 namespace c10 {
27 
/// Renders a registry key for diagnostic messages. Arbitrary key types have no
/// generic textual form, so a fixed placeholder is returned for them.
template <typename KeyType>
inline std::string KeyStrRepr(const KeyType& /*unused*/) {
  constexpr const char* kPlaceholder = "[key type printing not supported]";
  return std::string(kPlaceholder);
}

/// std::string keys are already printable and are returned verbatim.
template <>
inline std::string KeyStrRepr<std::string>(const std::string& key) {
  return std::string(key);
}
37 
38 enum RegistryPriority {
39   REGISTRY_FALLBACK = 1,
40   REGISTRY_DEFAULT = 2,
41   REGISTRY_PREFERRED = 3,
42 };
43 
44 /**
45  * @brief A template class that allows one to register classes by keys.
46  *
47  * The keys are usually a std::string specifying the name, but can be anything
48  * that can be used in a std::map.
49  *
50  * You should most likely not use the Registry class explicitly, but use the
51  * helper macros below to declare specific registries as well as registering
52  * objects.
53  */
54 template <class SrcType, class ObjectPtrType, class... Args>
55 class Registry {
56  public:
57   typedef std::function<ObjectPtrType(Args...)> Creator;
58 
registry_()59   Registry(bool warning = true) : registry_(), priority_(), warning_(warning) {}
60 
61   void Register(
62       const SrcType& key,
63       Creator creator,
64       const RegistryPriority priority = REGISTRY_DEFAULT) {
65     std::lock_guard<std::mutex> lock(register_mutex_);
66     // The if statement below is essentially the same as the following line:
67     // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key
68     //                                   << " registered twice.";
69     // However, TORCH_CHECK_EQ depends on google logging, and since registration
70     // is carried out at static initialization time, we do not want to have an
71     // explicit dependency on glog's initialization function.
72     if (registry_.count(key) != 0) {
73       auto cur_priority = priority_[key];
74       if (priority > cur_priority) {
75 #ifdef DEBUG
76         std::string warn_msg =
77             "Overwriting already registered item for key " + KeyStrRepr(key);
78         fprintf(stderr, "%s\n", warn_msg.c_str());
79 #endif
80         registry_[key] = creator;
81         priority_[key] = priority;
82       } else if (priority == cur_priority) {
83         std::string err_msg =
84             "Key already registered with the same priority: " + KeyStrRepr(key);
85         fprintf(stderr, "%s\n", err_msg.c_str());
86         if (terminate_) {
87           std::exit(1);
88         } else {
89           throw std::runtime_error(err_msg);
90         }
91       } else if (warning_) {
92         std::string warn_msg =
93             "Higher priority item already registered, skipping registration of " +
94             KeyStrRepr(key);
95         fprintf(stderr, "%s\n", warn_msg.c_str());
96       }
97     } else {
98       registry_[key] = creator;
99       priority_[key] = priority;
100     }
101   }
102 
103   void Register(
104       const SrcType& key,
105       Creator creator,
106       const std::string& help_msg,
107       const RegistryPriority priority = REGISTRY_DEFAULT) {
108     Register(key, creator, priority);
109     help_message_[key] = help_msg;
110   }
111 
Has(const SrcType & key)112   inline bool Has(const SrcType& key) {
113     return (registry_.count(key) != 0);
114   }
115 
Create(const SrcType & key,Args...args)116   ObjectPtrType Create(const SrcType& key, Args... args) {
117     auto it = registry_.find(key);
118     if (it == registry_.end()) {
119       // Returns nullptr if the key is not registered.
120       return nullptr;
121     }
122     return it->second(args...);
123   }
124 
125   /**
126    * Returns the keys currently registered as a std::vector.
127    */
Keys()128   std::vector<SrcType> Keys() const {
129     std::vector<SrcType> keys;
130     keys.reserve(registry_.size());
131     for (const auto& it : registry_) {
132       keys.push_back(it.first);
133     }
134     return keys;
135   }
136 
HelpMessage()137   inline const std::unordered_map<SrcType, std::string>& HelpMessage() const {
138     return help_message_;
139   }
140 
HelpMessage(const SrcType & key)141   const char* HelpMessage(const SrcType& key) const {
142     auto it = help_message_.find(key);
143     if (it == help_message_.end()) {
144       return nullptr;
145     }
146     return it->second.c_str();
147   }
148 
149   // Used for testing, if terminate is unset, Registry throws instead of
150   // calling std::exit
SetTerminate(bool terminate)151   void SetTerminate(bool terminate) {
152     terminate_ = terminate;
153   }
154 
155  private:
156   std::unordered_map<SrcType, Creator> registry_;
157   std::unordered_map<SrcType, RegistryPriority> priority_;
158   bool terminate_{true};
159   const bool warning_;
160   std::unordered_map<SrcType, std::string> help_message_;
161   std::mutex register_mutex_;
162 
163   C10_DISABLE_COPY_AND_ASSIGN(Registry);
164 };
165 
166 template <class SrcType, class ObjectPtrType, class... Args>
167 class Registerer {
168  public:
169   explicit Registerer(
170       const SrcType& key,
171       Registry<SrcType, ObjectPtrType, Args...>* registry,
172       typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
173       const std::string& help_msg = "") {
174     registry->Register(key, creator, help_msg);
175   }
176 
177   explicit Registerer(
178       const SrcType& key,
179       const RegistryPriority priority,
180       Registry<SrcType, ObjectPtrType, Args...>* registry,
181       typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
182       const std::string& help_msg = "") {
183     registry->Register(key, creator, help_msg, priority);
184   }
185 
186   template <class DerivedType>
DefaultCreator(Args...args)187   static ObjectPtrType DefaultCreator(Args... args) {
188     return ObjectPtrType(new DerivedType(args...));
189   }
190 };
191 
192 /**
193  * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function
194  * declaration, as well as creating a convenient typename for its corresponding
195  * registerer.
196  */
197 // Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE
198 // as import and DEFINE as export, because these registry macros will be used
199 // in downstream shared libraries as well, and one cannot use *_API - the API
200 // macro will be defined on a per-shared-library basis. Semantically, when one
201 // declares a typed registry it is always going to be IMPORT, and when one
202 // defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE),
203 // the instantiation unit is always going to be exported.
204 //
205 // The only unique condition is when in the same file one does DECLARE and
206 // DEFINE - in Windows compilers, this generates a warning that dllimport and
207 // dllexport are mixed, but the warning is fine and linker will be properly
208 // exporting the symbol. Same thing happens in the gflags flag declaration and
// definition cases.
// C10_DECLARE_TYPED_REGISTRY(Name, KeyT, ObjT, PtrT, CtorArgs...) declares:
//   - `Name()`, a C10_API function returning a pointer to the ::c10::Registry
//     whose creators turn a KeyT key plus CtorArgs... into a PtrT<ObjT>;
//   - `Registerer##Name`, an alias for the matching ::c10::Registerer, which
//     the C10_REGISTER_* macros instantiate.
// TORCH_DECLARE_TYPED_REGISTRY is identical except that Name() is marked
// TORCH_API instead of C10_API.
#define C10_DECLARE_TYPED_REGISTRY(Name, KeyT, ObjT, PtrT, ...)   \
  C10_API ::c10::Registry<KeyT, PtrT<ObjT>, ##__VA_ARGS__>* Name(); \
  using Registerer##Name =                                          \
      ::c10::Registerer<KeyT, PtrT<ObjT>, ##__VA_ARGS__>

#define TORCH_DECLARE_TYPED_REGISTRY(Name, KeyT, ObjT, PtrT, ...)   \
  TORCH_API ::c10::Registry<KeyT, PtrT<ObjT>, ##__VA_ARGS__>* Name(); \
  using Registerer##Name =                                            \
      ::c10::Registerer<KeyT, PtrT<ObjT>, ##__VA_ARGS__>
223 
// C10_DEFINE_TYPED_REGISTRY(Name, KeyT, ObjT, PtrT, CtorArgs...) defines the
// exported accessor declared by C10_DECLARE_TYPED_REGISTRY. The registry is
// allocated with `new` on the first call and never deleted (presumably so it
// stays valid regardless of static initialization/destruction order).
// Use it exactly once, in a source file.
// The _WITHOUT_WARNING variant constructs the registry with `false`, so
// registrations skipped in favor of a higher-priority creator are not
// reported on stderr.
#define C10_DEFINE_TYPED_REGISTRY(Name, KeyT, ObjT, PtrT, ...)        \
  C10_EXPORT ::c10::Registry<KeyT, PtrT<ObjT>, ##__VA_ARGS__>* Name() { \
    static auto* const registry =                                       \
        new ::c10::Registry<KeyT, PtrT<ObjT>, ##__VA_ARGS__>();         \
    return registry;                                                    \
  }

#define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                      \
    Name, KeyT, ObjT, PtrT, ...)                                        \
  C10_EXPORT ::c10::Registry<KeyT, PtrT<ObjT>, ##__VA_ARGS__>* Name() { \
    static auto* const registry =                                       \
        new ::c10::Registry<KeyT, PtrT<ObjT>, ##__VA_ARGS__>(false);    \
    return registry;                                                    \
  }
244 
// Registration macros. Each expands to a static Registerer##Name object
// (named via C10_ANONYMOUS_VARIABLE) whose constructor adds an entry to the
// registry returned by Name(). The expansion already ends in ';'.
//
//   C10_REGISTER_TYPED_CREATOR(Name, key, creator[, help_msg])
//   C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(Name, key, priority,
//                                            creator[, help_msg])
//   C10_REGISTER_TYPED_CLASS(Name, key, DerivedType)
//   C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(Name, key, priority, DerivedType)
//
// The *_CLASS forms register Registerer##Name::DefaultCreator<DerivedType>
// and use the demangled name of DerivedType as the help message.
//
// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
// creator with comma in its templated arguments.
#define C10_REGISTER_TYPED_CREATOR(Name, Key, ...)          \
  static Registerer##Name C10_ANONYMOUS_VARIABLE(g_##Name)( \
      Key, Name(), ##__VA_ARGS__);

#define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(Name, Key, Priority, ...) \
  static Registerer##Name C10_ANONYMOUS_VARIABLE(g_##Name)(                \
      Key, Priority, Name(), ##__VA_ARGS__);

#define C10_REGISTER_TYPED_CLASS(Name, Key, ...)            \
  static Registerer##Name C10_ANONYMOUS_VARIABLE(g_##Name)( \
      Key,                                                  \
      Name(),                                               \
      Registerer##Name::DefaultCreator<__VA_ARGS__>,        \
      ::c10::demangle_type<__VA_ARGS__>());

#define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(Name, Key, Priority, ...) \
  static Registerer##Name C10_ANONYMOUS_VARIABLE(g_##Name)(              \
      Key,                                                               \
      Priority,                                                          \
      Name(),                                                            \
      Registerer##Name::DefaultCreator<__VA_ARGS__>,                     \
      ::c10::demangle_type<__VA_ARGS__>());
271 
// C10_DECLARE_REGISTRY / C10_DEFINE_REGISTRY (plus their TORCH_ and
// _WITHOUT_WARNING variants) are hard-wired to std::string keys and
// std::unique_ptr results, because that is the most common case. The
// *_SHARED_REGISTRY variants are the same but produce std::shared_ptr.
#define C10_DECLARE_REGISTRY(Name, ObjT, ...) \
  C10_DECLARE_TYPED_REGISTRY(                 \
      Name, std::string, ObjT, std::unique_ptr, ##__VA_ARGS__)

#define TORCH_DECLARE_REGISTRY(Name, ObjT, ...) \
  TORCH_DECLARE_TYPED_REGISTRY(                 \
      Name, std::string, ObjT, std::unique_ptr, ##__VA_ARGS__)

#define C10_DEFINE_REGISTRY(Name, ObjT, ...) \
  C10_DEFINE_TYPED_REGISTRY(                 \
      Name, std::string, ObjT, std::unique_ptr, ##__VA_ARGS__)

#define C10_DEFINE_REGISTRY_WITHOUT_WARNING(Name, ObjT, ...) \
  C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                 \
      Name, std::string, ObjT, std::unique_ptr, ##__VA_ARGS__)

#define C10_DECLARE_SHARED_REGISTRY(Name, ObjT, ...) \
  C10_DECLARE_TYPED_REGISTRY(                        \
      Name, std::string, ObjT, std::shared_ptr, ##__VA_ARGS__)

#define TORCH_DECLARE_SHARED_REGISTRY(Name, ObjT, ...) \
  TORCH_DECLARE_TYPED_REGISTRY(                        \
      Name, std::string, ObjT, std::shared_ptr, ##__VA_ARGS__)

#define C10_DEFINE_SHARED_REGISTRY(Name, ObjT, ...) \
  C10_DEFINE_TYPED_REGISTRY(                        \
      Name, std::string, ObjT, std::shared_ptr, ##__VA_ARGS__)

#define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING(Name, ObjT, ...) \
  C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                        \
      Name, std::string, ObjT, std::shared_ptr, ##__VA_ARGS__)
306 
// C10_REGISTER_CREATOR / C10_REGISTER_CLASS (and their _WITH_PRIORITY forms)
// are hard-wired to std::string keys, because that is the most common case.
// The key argument is stringized with #, e.g.
// C10_REGISTER_CLASS(MyRegistry, Foo, FooImpl) registers FooImpl as "Foo".
#define C10_REGISTER_CREATOR(Name, Key, ...) \
  C10_REGISTER_TYPED_CREATOR(Name, #Key, __VA_ARGS__)

#define C10_REGISTER_CREATOR_WITH_PRIORITY(Name, Key, Priority, ...) \
  C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(Name, #Key, Priority, __VA_ARGS__)

#define C10_REGISTER_CLASS(Name, Key, ...) \
  C10_REGISTER_TYPED_CLASS(Name, #Key, __VA_ARGS__)

#define C10_REGISTER_CLASS_WITH_PRIORITY(Name, Key, Priority, ...) \
  C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(Name, #Key, Priority, __VA_ARGS__)
323 
324 } // namespace c10
325 
326 #endif // C10_UTIL_REGISTRY_H_
327