1 #ifndef C10_UTIL_REGISTRY_H_
2 #define C10_UTIL_REGISTRY_H_
3
4 /**
5 * Simple registry implementation that uses static variables to
6 * register object creators during program initialization time.
7 */
8
9 // NB: This Registry works poorly when you have other namespaces.
10 // Make all macro invocations from inside the at namespace.
11
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Type.h>
25
26 namespace c10 {
27
/// Renders a registry key as text for diagnostic messages.
///
/// Key types without a dedicated specialization are not printed; a fixed
/// placeholder string is returned instead.
template <class KeyType>
inline std::string KeyStrRepr(const KeyType& /*unused*/) {
  const std::string placeholder = "[key type printing not supported]";
  return placeholder;
}

/// std::string keys are rendered verbatim.
template <>
inline std::string KeyStrRepr<std::string>(const std::string& key) {
  return std::string(key);
}
37
/// Priority attached to a registration. When the same key is registered more
/// than once, a registration with a strictly higher priority replaces the
/// existing one (see Registry::Register).
enum RegistryPriority {
  /// Lowest priority.
  REGISTRY_FALLBACK = 1,
  /// Priority used by Registry::Register when none is passed explicitly.
  REGISTRY_DEFAULT = 2,
  /// Highest priority.
  REGISTRY_PREFERRED = 3
};
43
44 /**
45 * @brief A template class that allows one to register classes by keys.
46 *
47 * The keys are usually a std::string specifying the name, but can be anything
48 * that can be used in a std::map.
49 *
50 * You should most likely not use the Registry class explicitly, but use the
51 * helper macros below to declare specific registries as well as registering
52 * objects.
53 */
54 template <class SrcType, class ObjectPtrType, class... Args>
55 class Registry {
56 public:
57 typedef std::function<ObjectPtrType(Args...)> Creator;
58
registry_()59 Registry(bool warning = true) : registry_(), priority_(), warning_(warning) {}
60
61 void Register(
62 const SrcType& key,
63 Creator creator,
64 const RegistryPriority priority = REGISTRY_DEFAULT) {
65 std::lock_guard<std::mutex> lock(register_mutex_);
66 // The if statement below is essentially the same as the following line:
67 // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key
68 // << " registered twice.";
69 // However, TORCH_CHECK_EQ depends on google logging, and since registration
70 // is carried out at static initialization time, we do not want to have an
71 // explicit dependency on glog's initialization function.
72 if (registry_.count(key) != 0) {
73 auto cur_priority = priority_[key];
74 if (priority > cur_priority) {
75 #ifdef DEBUG
76 std::string warn_msg =
77 "Overwriting already registered item for key " + KeyStrRepr(key);
78 fprintf(stderr, "%s\n", warn_msg.c_str());
79 #endif
80 registry_[key] = creator;
81 priority_[key] = priority;
82 } else if (priority == cur_priority) {
83 std::string err_msg =
84 "Key already registered with the same priority: " + KeyStrRepr(key);
85 fprintf(stderr, "%s\n", err_msg.c_str());
86 if (terminate_) {
87 std::exit(1);
88 } else {
89 throw std::runtime_error(err_msg);
90 }
91 } else if (warning_) {
92 std::string warn_msg =
93 "Higher priority item already registered, skipping registration of " +
94 KeyStrRepr(key);
95 fprintf(stderr, "%s\n", warn_msg.c_str());
96 }
97 } else {
98 registry_[key] = creator;
99 priority_[key] = priority;
100 }
101 }
102
103 void Register(
104 const SrcType& key,
105 Creator creator,
106 const std::string& help_msg,
107 const RegistryPriority priority = REGISTRY_DEFAULT) {
108 Register(key, creator, priority);
109 help_message_[key] = help_msg;
110 }
111
Has(const SrcType & key)112 inline bool Has(const SrcType& key) {
113 return (registry_.count(key) != 0);
114 }
115
Create(const SrcType & key,Args...args)116 ObjectPtrType Create(const SrcType& key, Args... args) {
117 auto it = registry_.find(key);
118 if (it == registry_.end()) {
119 // Returns nullptr if the key is not registered.
120 return nullptr;
121 }
122 return it->second(args...);
123 }
124
125 /**
126 * Returns the keys currently registered as a std::vector.
127 */
Keys()128 std::vector<SrcType> Keys() const {
129 std::vector<SrcType> keys;
130 keys.reserve(registry_.size());
131 for (const auto& it : registry_) {
132 keys.push_back(it.first);
133 }
134 return keys;
135 }
136
HelpMessage()137 inline const std::unordered_map<SrcType, std::string>& HelpMessage() const {
138 return help_message_;
139 }
140
HelpMessage(const SrcType & key)141 const char* HelpMessage(const SrcType& key) const {
142 auto it = help_message_.find(key);
143 if (it == help_message_.end()) {
144 return nullptr;
145 }
146 return it->second.c_str();
147 }
148
149 // Used for testing, if terminate is unset, Registry throws instead of
150 // calling std::exit
SetTerminate(bool terminate)151 void SetTerminate(bool terminate) {
152 terminate_ = terminate;
153 }
154
155 private:
156 std::unordered_map<SrcType, Creator> registry_;
157 std::unordered_map<SrcType, RegistryPriority> priority_;
158 bool terminate_{true};
159 const bool warning_;
160 std::unordered_map<SrcType, std::string> help_message_;
161 std::mutex register_mutex_;
162
163 C10_DISABLE_COPY_AND_ASSIGN(Registry);
164 };
165
166 template <class SrcType, class ObjectPtrType, class... Args>
167 class Registerer {
168 public:
169 explicit Registerer(
170 const SrcType& key,
171 Registry<SrcType, ObjectPtrType, Args...>* registry,
172 typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
173 const std::string& help_msg = "") {
174 registry->Register(key, creator, help_msg);
175 }
176
177 explicit Registerer(
178 const SrcType& key,
179 const RegistryPriority priority,
180 Registry<SrcType, ObjectPtrType, Args...>* registry,
181 typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
182 const std::string& help_msg = "") {
183 registry->Register(key, creator, help_msg, priority);
184 }
185
186 template <class DerivedType>
DefaultCreator(Args...args)187 static ObjectPtrType DefaultCreator(Args... args) {
188 return ObjectPtrType(new DerivedType(args...));
189 }
190 };
191
/**
 * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function
 * declaration, as well as creating a convenient typename for its corresponding
 * registerer.
 */
// Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE
// as import and DEFINE as export, because these registry macros will be used
// in downstream shared libraries as well, and one cannot use *_API - the API
// macro will be defined on a per-shared-library basis. Semantically, when one
// declares a typed registry it is always going to be IMPORT, and when one
// defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE),
// the instantiation unit is always going to be exported.
//
// The only unique condition is when in the same file one does DECLARE and
// DEFINE - in Windows compilers, this generates a warning that dllimport and
// dllexport are mixed, but the warning is fine and linker will be properly
// exporting the symbol. Same thing happens in the gflags flag declaration and
// definition cases.
//
// NOTE(review): the DECLARE macros below actually annotate the accessor with
// C10_API / TORCH_API rather than C10_IMPORT; only the DEFINE macros use
// C10_EXPORT. Confirm the note above is still accurate.
//
// C10_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, PtrType,
// CreatorArgs...) declares the accessor function
//   ::c10::Registry<SrcType, PtrType<ObjectType>, CreatorArgs...>*
//   RegistryName();
// and the typedef Registerer##RegistryName for the matching ::c10::Registerer.
// CreatorArgs may be empty: `, ##__VA_ARGS__` (a GNU preprocessor extension)
// drops the preceding comma in that case. The expansion does not end with a
// semicolon; the invocation supplies it.
#define C10_DECLARE_TYPED_REGISTRY(                                        \
    RegistryName, SrcType, ObjectType, PtrType, ...)                       \
  C10_API ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*    \
  RegistryName();                                                          \
  typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>   \
      Registerer##RegistryName

// Same as C10_DECLARE_TYPED_REGISTRY, but the accessor is annotated with
// TORCH_API instead of C10_API.
#define TORCH_DECLARE_TYPED_REGISTRY(                                      \
    RegistryName, SrcType, ObjectType, PtrType, ...)                       \
  TORCH_API ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*  \
  RegistryName();                                                          \
  typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>   \
      Registerer##RegistryName
223
// C10_DEFINE_TYPED_REGISTRY defines the accessor function declared by
// C10_DECLARE_TYPED_REGISTRY. The registry is created with `new` when the
// function-local static is first initialized, and every call returns that
// same pointer. It is never deleted.
// NOTE(review): the intentional leak presumably avoids static-destruction-order
// problems for uses during shutdown; confirm.
// Invoke this ONLY ONCE per registry and ONLY IN A SOURCE FILE.
#define C10_DEFINE_TYPED_REGISTRY(                                         \
    RegistryName, SrcType, ObjectType, PtrType, ...)                        \
  C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*  \
  RegistryName() {                                                          \
    static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*    \
        registry = new ::c10::                                              \
            Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>();        \
    return registry;                                                        \
  }

// Same as C10_DEFINE_TYPED_REGISTRY, but constructs the Registry with
// warning = false. Registrations skipped because the key already holds a
// higher-priority creator are therefore not reported on stderr.
#define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                         \
    RegistryName, SrcType, ObjectType, PtrType, ...)                        \
  C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*  \
  RegistryName() {                                                          \
    static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*    \
        registry =                                                          \
            new ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>( \
                false);                                                     \
    return registry;                                                        \
  }
244
// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
// creator with comma in its templated arguments.
//
// Each macro below defines a static Registerer##RegistryName object, named
// through C10_ANONYMOUS_VARIABLE. Its constructor registers into the registry
// returned by RegistryName() when that static is initialized. The macros are
// meant to be invoked at namespace scope, so this happens during static
// initialization.
//
// C10_REGISTER_TYPED_CREATOR(RegistryName, key, creator[, help_msg]):
//   registers `creator` under `key` with the default priority.
#define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...)                  \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key, RegistryName(), ##__VA_ARGS__);

// C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(RegistryName, key, priority,
// creator[, help_msg]): same as above with an explicit RegistryPriority.
#define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(                           \
    RegistryName, key, priority, ...)                                       \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key, priority, RegistryName(), ##__VA_ARGS__);

// C10_REGISTER_TYPED_CLASS(RegistryName, key, ClassType): registers
// Registerer##RegistryName::DefaultCreator<ClassType> under `key` with the
// default priority. That creator constructs `new ClassType(args...)` and wraps
// it in the registry's pointer type. The help message is the string returned by
// ::c10::demangle_type<ClassType>().
#define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...)                    \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key,                                                                  \
      RegistryName(),                                                       \
      Registerer##RegistryName::DefaultCreator<__VA_ARGS__>,                \
      ::c10::demangle_type<__VA_ARGS__>());

// C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(RegistryName, key, priority,
// ClassType): same as C10_REGISTER_TYPED_CLASS with an explicit priority.
#define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(                             \
    RegistryName, key, priority, ...)                                       \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key,                                                                  \
      priority,                                                             \
      RegistryName(),                                                       \
      Registerer##RegistryName::DefaultCreator<__VA_ARGS__>,                \
      ::c10::demangle_type<__VA_ARGS__>());
271
// C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use
// std::string as the key type, because that is the most common case.
// The *_REGISTRY variants hold objects through std::unique_ptr<ObjectType>,
// and the *_SHARED_REGISTRY variants through std::shared_ptr<ObjectType>.
// Trailing arguments are the creator's argument types, as in the typed
// macros. The TORCH_* variants differ only in using TORCH_API for the
// declaration.
#define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DECLARE_TYPED_REGISTRY(                               \
      RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

#define TORCH_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
  TORCH_DECLARE_TYPED_REGISTRY(                               \
      RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

#define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DEFINE_TYPED_REGISTRY(                               \
      RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

#define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \
  C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                               \
      RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

#define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DECLARE_TYPED_REGISTRY(                                      \
      RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

#define TORCH_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  TORCH_DECLARE_TYPED_REGISTRY(                                      \
      RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

#define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DEFINE_TYPED_REGISTRY(                                      \
      RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

#define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \
    RegistryName, ObjectType, ...)                  \
  C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(        \
      RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
306
// C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use
// std::string as the key type, because that is the most common case.
// `key` is written as a bare token and stringized with `#key`, exactly as
// spelled in the invocation (it is not macro-expanded first). For example,
// C10_REGISTER_CLASS(MyRegistry, Foo, FooImpl) registers FooImpl under the key
// "Foo".
#define C10_REGISTER_CREATOR(RegistryName, key, ...) \
  C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)

#define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \
  C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(                                  \
      RegistryName, #key, priority, __VA_ARGS__)

#define C10_REGISTER_CLASS(RegistryName, key, ...) \
  C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)

#define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \
  C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(                                  \
      RegistryName, #key, priority, __VA_ARGS__)
323
324 } // namespace c10
325
326 #endif // C10_UTIL_REGISTRY_H_
327