xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/GroupRegistry.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
2 
3 #include <torch/csrc/distributed/c10d/RankLocal.hpp>
4 
5 namespace {
6 
7 // Each rank operates on a different `c10d::ProcessGroup` instance for the same
8 // logical process group. Use `RankLocal<GroupRegistry>::get()` to ensure each
9 // rank gets a unique registry.
10 class GroupRegistry {
11  public:
register_group(const std::string & group_name,c10::intrusive_ptr<c10d::ProcessGroup> group)12   void register_group(
13       const std::string& group_name,
14       c10::intrusive_ptr<c10d::ProcessGroup> group) {
15     std::unique_lock write_lock(lock_);
16     auto [_, inserted] = registry_.try_emplace(group_name, std::move(group));
17     TORCH_CHECK(
18         inserted,
19         "A process group is already registered under the name",
20         group_name);
21   }
22 
resolve_group(const std::string & group_name)23   c10::intrusive_ptr<c10d::ProcessGroup> resolve_group(
24       const std::string& group_name) {
25     std::shared_lock read_lock(lock_);
26     auto it = registry_.find(group_name);
27     TORCH_CHECK(
28         it != registry_.end(),
29         "Could not resolve the process group registered under the name ",
30         group_name);
31 
32     auto group = it->second.lock();
33     TORCH_CHECK(
34         group != nullptr,
35         "Process group registered under the name ",
36         group_name,
37         " has already been destroyed.");
38     return group;
39   }
40 
unregister_group(const std::string & group_name)41   void unregister_group(const std::string& group_name) {
42     std::unique_lock write_lock(lock_);
43     registry_.erase(group_name);
44   }
45 
unregister_all_groups()46   void unregister_all_groups() {
47     std::unique_lock write_lock(lock_);
48     registry_.clear();
49   }
50 
51  private:
52   std::map<std::string, c10::weak_intrusive_ptr<c10d::ProcessGroup>> registry_;
53   std::shared_mutex lock_;
54 };
55 
56 } // namespace
57 
58 namespace c10d {
59 
60 static bool thread_isolation_mode = false;
61 static GroupRegistry process_registry;
62 
set_thread_isolation_mode(bool enable)63 void set_thread_isolation_mode(bool enable) {
64   thread_isolation_mode = enable;
65 }
66 
get_thread_isolation_mode()67 bool get_thread_isolation_mode() {
68   return thread_isolation_mode;
69 }
70 
register_process_group(const std::string & group_name,c10::intrusive_ptr<c10d::ProcessGroup> group)71 void register_process_group(
72     const std::string& group_name,
73     c10::intrusive_ptr<c10d::ProcessGroup> group) {
74   if (thread_isolation_mode) {
75     RankLocal<::GroupRegistry>::get().register_group(
76         group_name, std::move(group));
77   } else {
78     process_registry.register_group(group_name, std::move(group));
79   }
80 }
81 
resolve_process_group(const std::string & group_name)82 c10::intrusive_ptr<c10d::ProcessGroup> resolve_process_group(
83     const std::string& group_name) {
84   if (thread_isolation_mode) {
85     return RankLocal<::GroupRegistry>::get().resolve_group(group_name);
86   } else {
87     return process_registry.resolve_group(group_name);
88   }
89 }
90 
unregister_process_group(const std::string & group_name)91 void unregister_process_group(const std::string& group_name) {
92   if (thread_isolation_mode) {
93     RankLocal<::GroupRegistry>::get().unregister_group(group_name);
94   } else {
95     process_registry.unregister_group(group_name);
96   }
97 }
98 
unregister_all_process_groups()99 void unregister_all_process_groups() {
100   if (thread_isolation_mode) {
101     RankLocal<::GroupRegistry>::get().unregister_all_groups();
102   } else {
103     process_registry.unregister_all_groups();
104   }
105 }
106 
107 } // namespace c10d
108