1 #include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>
2
3 #include <fmt/format.h>
4 #include <mutex>
5 #include <shared_mutex>
6 #include <stdexcept>
7 #include <utility>
8
9 namespace c10d::control_plane {
10
11 namespace {
12
13 class HandlerRegistry {
14 public:
registerHandler(const std::string & name,HandlerFunc f)15 void registerHandler(const std::string& name, HandlerFunc f) {
16 std::unique_lock<std::shared_mutex> lock(handlersMutex_);
17
18 if (handlers_.find(name) != handlers_.end()) {
19 throw std::invalid_argument(
20 fmt::format("Handler {} already registered", name));
21 }
22
23 handlers_[name] = std::move(f);
24 }
25
getHandler(const std::string & name)26 HandlerFunc getHandler(const std::string& name) {
27 std::shared_lock<std::shared_mutex> lock(handlersMutex_);
28
29 auto it = handlers_.find(name);
30 if (it == handlers_.end()) {
31 throw std::invalid_argument(
32 fmt::format("Failed to find handler {}", name));
33 }
34 return handlers_[name];
35 }
36
getHandlerNames()37 std::vector<std::string> getHandlerNames() {
38 std::shared_lock<std::shared_mutex> lock(handlersMutex_);
39
40 std::vector<std::string> names;
41 names.reserve(handlers_.size());
42 for (const auto& [name, _] : handlers_) {
43 names.push_back(name);
44 }
45 return names;
46 }
47
48 private:
49 std::shared_mutex handlersMutex_{};
50 std::unordered_map<std::string, HandlerFunc> handlers_{};
51 };
52
getHandlerRegistry()53 HandlerRegistry& getHandlerRegistry() {
54 static HandlerRegistry registry;
55 return registry;
56 }
57
__anon798e2ef20202() 58 RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
59 res.setContent("pong", "text/plain");
60 res.setStatus(200);
61 }};
62
63 } // namespace
64
registerHandler(const std::string & name,HandlerFunc f)65 void registerHandler(const std::string& name, HandlerFunc f) {
66 return getHandlerRegistry().registerHandler(name, std::move(f));
67 }
68
getHandler(const std::string & name)69 HandlerFunc getHandler(const std::string& name) {
70 return getHandlerRegistry().getHandler(name);
71 }
72
getHandlerNames()73 std::vector<std::string> getHandlerNames() {
74 return getHandlerRegistry().getHandlerNames();
75 }
76
77 } // namespace c10d::control_plane
78