xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/UCCUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef USE_C10D_UCC
2 
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <cctype>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
10 
11 namespace c10d {
12 
namespace {
// Store key fragments used by the out-of-band (OOB) allgather callbacks below.

// Prefix of the per-rank key under which each rank publishes its payload;
// the rank number is appended to it.
constexpr char kTeamRank[] = "teamr";
// Counter key incremented by every rank when it releases an allgather request.
constexpr char kAllGatherDone[] = "ag_done";
// Prefix of the per-rank key used to tell a rank that the allgather keys have
// been removed; the rank number is appended to it.
constexpr char kAllGatherFree[] = "ag_free";
} // namespace
19 
oob_allgather(void * sbuf,void * rbuf,size_t msglen,void * coll_info,void ** req)20 ucc_status_t oob_allgather(
21     void* sbuf,
22     void* rbuf,
23     size_t msglen,
24     void* coll_info,
25     void** req) {
26   auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(coll_info);
27   TORCH_CHECK(info != nullptr);
28   std::vector<uint8_t> val = std::vector<uint8_t>(
29       reinterpret_cast<uint8_t*>(sbuf),
30       reinterpret_cast<uint8_t*>(sbuf) + msglen);
31   try {
32     info->store->set(info->getKey(kTeamRank + std::to_string(info->rank)), val);
33     info->rbuf = rbuf;
34     info->msglen = msglen;
35     *req = coll_info;
36   } catch (std::exception& ex) {
37     LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
38                << "[" << ex.what() << "]";
39     return UCC_ERR_NO_MESSAGE;
40   }
41   return UCC_OK;
42 }
43 
oob_allgather_test(void * req)44 ucc_status_t oob_allgather_test(void* req) {
45   auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req);
46   TORCH_CHECK(info != nullptr);
47 
48   try {
49     for (int r = 0; r < info->size; r++) {
50       if (!info->store->check({info->getKey(kTeamRank + std::to_string(r))})) {
51         return UCC_INPROGRESS;
52       }
53     }
54     for (int r = 0; r < info->size; r++) {
55       std::vector<uint8_t> data =
56           info->store->get(info->getKey(kTeamRank + std::to_string(r)));
57       memcpy(
58           (void*)((ptrdiff_t)info->rbuf + info->msglen * r),
59           data.data(),
60           info->msglen);
61     }
62   } catch (std::exception& ex) {
63     LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
64                << "[" << ex.what() << "]";
65     return UCC_ERR_NO_MESSAGE;
66   }
67   return UCC_OK;
68 }
69 
oob_allgather_free(void * req)70 ucc_status_t oob_allgather_free(void* req) {
71   auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req);
72   TORCH_CHECK(info != nullptr);
73   try {
74     int num_done = info->store->add({info->getKey(kAllGatherDone)}, 1);
75     if (num_done == info->size) {
76       info->store->deleteKey(info->getKey(kAllGatherDone));
77       // Note: to avoid race condition, it's important to remove all keys in
78       // oob_allgather_free first and only after that signal completion to
79       // other ranks
80       for (const auto r : c10::irange(info->size)) {
81         info->store->deleteKey(info->getKey(kTeamRank + std::to_string(r)));
82       }
83       for (const auto r : c10::irange(info->size)) {
84         info->store->add({info->getKey(kAllGatherFree + std::to_string(r))}, 1);
85       }
86     } else {
87       info->store->wait(
88           {info->getKey(kAllGatherFree + std::to_string(info->rank))});
89     }
90     info->store->deleteKey(
91         info->getKey(kAllGatherFree + std::to_string(info->rank)));
92   } catch (std::exception& ex) {
93     LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
94                << "[" << ex.what() << "]";
95     return UCC_ERR_NO_MESSAGE;
96   }
97   return UCC_OK;
98 }
99 
CommUCC(std::shared_ptr<torch_ucc_oob_coll_info_t> oob,const c10::intrusive_ptr<ProcessGroupUCCLogger> & logger)100 CommUCC::CommUCC(
101     std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
102     const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
103     : CommBase(logger) {
104   ucc_lib_config_h lib_config;
105   ucc_context_config_h context_config;
106   ucc_lib_params_t lib_params;
107   ucc_context_params_t context_params;
108   ucc_status_t st;
109 
110   TORCH_UCC_CHECK(
111       ucc_lib_config_read("TORCH", nullptr, &lib_config),
112       "failed to read UCC lib config");
113   memset(&lib_params, 0, sizeof(ucc_lib_params_t));
114   lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE;
115   lib_params.thread_mode = UCC_THREAD_MULTIPLE;
116   TORCH_UCC_CHECK(
117       ucc_init(&lib_params, lib_config, &lib), "failed to init UCC lib");
118   ucc_lib_config_release(lib_config);
119   ucc_lib_attr_t lib_attr;
120   lib_attr.mask = UCC_LIB_ATTR_FIELD_THREAD_MODE;
121   TORCH_UCC_CHECK(
122       ucc_lib_get_attr(lib, &lib_attr), "failed to query for lib attr");
123   TORCH_CHECK(
124       lib_attr.thread_mode == UCC_THREAD_MULTIPLE,
125       "ucc library wasn't initialized with multithreading support, "
126       "please check ucc build options");
127   st = ucc_context_config_read(lib, NULL, &context_config);
128   if (st != UCC_OK) {
129     // FIXME: would this cause deadlock if only one rank fails?
130     TORCH_UCC_CHECK(
131         ucc_finalize(lib),
132         "failed to finalize UCC library when failing to read UCC context config");
133     TORCH_UCC_LOG_ERROR(
134         TORCH_UCC_INIT,
135         c10::str("failed to read UCC context config: ", ucc_status_string(st)));
136     throw std::runtime_error(ucc_status_string(st));
137   }
138   st = ucc_context_config_modify(
139       context_config,
140       NULL,
141       "ESTIMATED_NUM_EPS",
142       std::to_string(oob->size).c_str());
143   if (st != UCC_OK) {
144     ucc_context_config_release(context_config);
145     ucc_finalize(lib);
146     TORCH_UCC_LOG_ERROR(
147         TORCH_UCC_INIT,
148         c10::str(
149             "UCC failed to modify UCC context config: ",
150             ucc_status_string(st)));
151     throw std::runtime_error(ucc_status_string(st));
152   }
153   memset(&context_params, 0, sizeof(ucc_context_params_t));
154   context_params.mask =
155       UCC_CONTEXT_PARAM_FIELD_TYPE | UCC_CONTEXT_PARAM_FIELD_OOB;
156   context_params.type = UCC_CONTEXT_SHARED;
157   context_params.oob.n_oob_eps = oob->size;
158   context_params.oob.oob_ep = oob->rank;
159   context_params.oob.allgather = oob_allgather;
160   context_params.oob.req_test = oob_allgather_test;
161   context_params.oob.req_free = oob_allgather_free;
162   context_params.oob.coll_info = oob.get();
163   st = ucc_context_create(lib, &context_params, context_config, &context);
164   ucc_context_config_release(context_config);
165   if (st != UCC_OK) {
166     TORCH_UCC_CHECK(
167         ucc_finalize(lib),
168         "failed to finalize UCC library when failing to creat UCC context");
169     TORCH_UCC_LOG_ERROR(
170         TORCH_UCC_INIT,
171         c10::str("UCC failed to create UCC context: ", ucc_status_string(st)));
172     throw std::runtime_error(ucc_status_string(st));
173   }
174 }
175 
progress()176 void CommUCC::progress() {
177   TORCH_UCC_CHECK(
178       ucc_context_progress(context), "failed to progress UCC collective");
179 }
180 
free_request(ucc_coll_req_h request)181 void CommUCC::free_request(ucc_coll_req_h request) {
182   TORCH_UCC_CHECK(
183       ucc_collective_finalize(request), "failed to release UCC request");
184 }
185 
~CommUCC()186 CommUCC::~CommUCC() {
187   if (context != nullptr) {
188     TORCH_UCC_CHECK(
189         ucc_context_destroy(context), "failed to destroy UCC context");
190   }
191   if (lib != nullptr) {
192     TORCH_UCC_CHECK(ucc_finalize(lib), "failed to finalize UCC library");
193   }
194   context = nullptr;
195   lib = nullptr;
196 }
197 
getLogPrefix(torch_ucc_phase_t phase)198 std::string ProcessGroupUCCLogger::getLogPrefix(torch_ucc_phase_t phase) {
199   // caller can override the phase stored locally
200   torch_ucc_phase_t phase_ =
201       (local_phase != phase && phase != TORCH_UCC_UNKNOWN) ? phase
202                                                            : local_phase;
203   return c10::str(log_prefix, "[", ucc_phase_map.at(phase_), "]");
204 }
setLogPrefix(std::string log_prefix_)205 void ProcessGroupUCCLogger::setLogPrefix(std::string log_prefix_) {
206   log_prefix = log_prefix_;
207 }
208 
ProcessGroupUCCLogger()209 ProcessGroupUCCLogger::ProcessGroupUCCLogger() {
210   setLogPrefix("[ProcessGroupUCC]");
211 }
ProcessGroupUCCLogger(std::string log_prefix,torch_ucc_phase_t phase)212 ProcessGroupUCCLogger::ProcessGroupUCCLogger(
213     std::string log_prefix,
214     torch_ucc_phase_t phase)
215     : local_phase(phase) {
216   setLogPrefix(log_prefix);
217 }
218 
219 } // namespace c10d
220 
221 #endif // USE_C10D_UCC
222