1 #include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
2
3 #include <torch/csrc/distributed/c10d/RankLocal.hpp>
4
5 namespace {
6
7 // Each rank operates on a different `c10d::ProcessGroup` instance for the same
8 // logical process group. Use `RankLocal<GroupRegistry>::get()` to ensure each
9 // rank gets a unique registry.
10 class GroupRegistry {
11 public:
register_group(const std::string & group_name,c10::intrusive_ptr<c10d::ProcessGroup> group)12 void register_group(
13 const std::string& group_name,
14 c10::intrusive_ptr<c10d::ProcessGroup> group) {
15 std::unique_lock write_lock(lock_);
16 auto [_, inserted] = registry_.try_emplace(group_name, std::move(group));
17 TORCH_CHECK(
18 inserted,
19 "A process group is already registered under the name",
20 group_name);
21 }
22
resolve_group(const std::string & group_name)23 c10::intrusive_ptr<c10d::ProcessGroup> resolve_group(
24 const std::string& group_name) {
25 std::shared_lock read_lock(lock_);
26 auto it = registry_.find(group_name);
27 TORCH_CHECK(
28 it != registry_.end(),
29 "Could not resolve the process group registered under the name ",
30 group_name);
31
32 auto group = it->second.lock();
33 TORCH_CHECK(
34 group != nullptr,
35 "Process group registered under the name ",
36 group_name,
37 " has already been destroyed.");
38 return group;
39 }
40
unregister_group(const std::string & group_name)41 void unregister_group(const std::string& group_name) {
42 std::unique_lock write_lock(lock_);
43 registry_.erase(group_name);
44 }
45
unregister_all_groups()46 void unregister_all_groups() {
47 std::unique_lock write_lock(lock_);
48 registry_.clear();
49 }
50
51 private:
52 std::map<std::string, c10::weak_intrusive_ptr<c10d::ProcessGroup>> registry_;
53 std::shared_mutex lock_;
54 };
55
56 } // namespace
57
58 namespace c10d {
59
60 static bool thread_isolation_mode = false;
61 static GroupRegistry process_registry;
62
set_thread_isolation_mode(bool enable)63 void set_thread_isolation_mode(bool enable) {
64 thread_isolation_mode = enable;
65 }
66
get_thread_isolation_mode()67 bool get_thread_isolation_mode() {
68 return thread_isolation_mode;
69 }
70
register_process_group(const std::string & group_name,c10::intrusive_ptr<c10d::ProcessGroup> group)71 void register_process_group(
72 const std::string& group_name,
73 c10::intrusive_ptr<c10d::ProcessGroup> group) {
74 if (thread_isolation_mode) {
75 RankLocal<::GroupRegistry>::get().register_group(
76 group_name, std::move(group));
77 } else {
78 process_registry.register_group(group_name, std::move(group));
79 }
80 }
81
resolve_process_group(const std::string & group_name)82 c10::intrusive_ptr<c10d::ProcessGroup> resolve_process_group(
83 const std::string& group_name) {
84 if (thread_isolation_mode) {
85 return RankLocal<::GroupRegistry>::get().resolve_group(group_name);
86 } else {
87 return process_registry.resolve_group(group_name);
88 }
89 }
90
unregister_process_group(const std::string & group_name)91 void unregister_process_group(const std::string& group_name) {
92 if (thread_isolation_mode) {
93 RankLocal<::GroupRegistry>::get().unregister_group(group_name);
94 } else {
95 process_registry.unregister_group(group_name);
96 }
97 }
98
unregister_all_process_groups()99 void unregister_all_process_groups() {
100 if (thread_isolation_mode) {
101 RankLocal<::GroupRegistry>::get().unregister_all_groups();
102 } else {
103 process_registry.unregister_all_groups();
104 }
105 }
106
107 } // namespace c10d
108