1 #pragma once
2
3 #ifdef USE_C10D_UCC
4
5 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
6 #include <torch/csrc/distributed/c10d/Store.hpp>
7 #include <ucc/api/ucc.h>
8
9 namespace c10d {
10
// Macro to generate the error message on a non-successful UCC return value.
//
// Builds one string with c10::str() and assigns it to `_err`. The string
// contains, in order: "[<file>:<line>] " of the expansion site, the result of
// logger->getLogPrefix(), `_error_msg`, the numeric `_result`, the text from
// ucc_status_string(_result), and the current value of errno.
//
// Arguments:
//   _err       - assignable destination for the message (TORCH_UCC_CHECK
//                passes a std::string).
//   _error_msg - caller-supplied description of what failed.
//   _result    - the UCC status being reported. It is expanded twice, so pass
//                a plain value rather than an expression with side effects.
//
// NOTE: the macro refers to an identifier named `logger` that is not one of
// its parameters; something with that name supporting `->getLogPrefix()` must
// be visible wherever the macro is expanded.
#define TORCH_UCC_GET_ERROR_MSG(_err, _error_msg, _result) \
  do { \
    _err = c10::str( \
        "[", \
        std::string(__FILE__), \
        ":", \
        std::to_string(__LINE__), \
        "] ", \
        logger->getLogPrefix(), \
        _error_msg, \
        ", error code ", \
        _result, \
        ": ", \
        ucc_status_string(_result), \
        ", system error code ", \
        errno); \
  } while (0)
29
// Macro to throw on a non-successful UCC return value.
//
// Evaluates `_cmd` exactly once. When the status is not UCC_OK, an error
// message is built with TORCH_UCC_GET_ERROR_MSG (which needs `logger` in scope
// at the call site) and reported through TORCH_CHECK(false, ...).
//
// Arguments:
//   _cmd       - expression yielding a ucc_status_t.
//   _error_msg - description of the failed operation, included in the message.
//
// The macro-local variables carry a `torch_ucc_check_` prefix so that caller
// expressions mentioning common names such as `result` or `err` still refer to
// the caller's own variables instead of being captured by the expansion.
#define TORCH_UCC_CHECK(_cmd, _error_msg) \
  do { \
    const ucc_status_t torch_ucc_check_status_ = (_cmd); \
    if (torch_ucc_check_status_ != UCC_OK) { \
      std::string torch_ucc_check_err_; \
      TORCH_UCC_GET_ERROR_MSG( \
          torch_ucc_check_err_, _error_msg, torch_ucc_check_status_); \
      TORCH_CHECK(false, torch_ucc_check_err_); \
    } \
  } while (0)
40
// Macro to throw on a non-successful UCC return value and free its request.
//
// Evaluates `_cmd` exactly once. When the status is not UCC_OK, an error
// message is built with TORCH_UCC_GET_ERROR_MSG (which needs `logger` in scope
// at the call site), the request is released with ucc_collective_finalize()
// if it is non-null, and the error is reported through TORCH_CHECK(false, ...).
//
// Arguments:
//   _request   - request handle to finalize on failure; evaluated once, and
//                only on the failure path.
//   _cmd       - expression yielding a ucc_status_t.
//   _error_msg - description of the failed operation, included in the message.
//
// The macro-local variables carry a `torch_ucc_check_` prefix so that caller
// expressions mentioning common names such as `result` or `err` still refer to
// the caller's own variables instead of being captured by the expansion.
#define TORCH_UCC_CHECK_REQUEST(_request, _cmd, _error_msg) \
  do { \
    const ucc_status_t torch_ucc_check_status_ = (_cmd); \
    if (torch_ucc_check_status_ != UCC_OK) { \
      std::string torch_ucc_check_err_; \
      TORCH_UCC_GET_ERROR_MSG( \
          torch_ucc_check_err_, _error_msg, torch_ucc_check_status_); \
      auto torch_ucc_check_req_ = (_request); \
      if (torch_ucc_check_req_ != nullptr) { \
        /* Cleanup only: the original failure is what gets reported. */ \
        ucc_collective_finalize(torch_ucc_check_req_); \
      } \
      TORCH_CHECK(false, torch_ucc_check_err_); \
    } \
  } while (0)
54
// Macros to print logs with a unified format.
//
// Each one streams logger->getLogPrefix(_phase), a severity tag ("[ERROR] ",
// "[INFO] " or "[DEBUG] ") and then `_msg` into LOG(ERROR), LOG(INFO) or
// VLOG(1) respectively. Like the error macros in this header, they rely on an
// identifier named `logger` being visible at the expansion site.
//
// NOTE(review): every definition already ends with ';'. A call written with
// its own trailing ';' therefore expands to an additional empty statement,
// which breaks when the macro is the unbraced body of an `if` that has an
// `else`. Wrap such uses in braces.
#define TORCH_UCC_LOG_ERROR(_phase, _msg) \
  LOG(ERROR) << logger->getLogPrefix(_phase) << "[ERROR] " << _msg;
#define TORCH_UCC_LOG_INFO(_phase, _msg) \
  LOG(INFO) << logger->getLogPrefix(_phase) << "[INFO] " << _msg;
#define TORCH_UCC_LOG_DEBUG(_phase, _msg) \
  VLOG(1) << logger->getLogPrefix(_phase) << "[DEBUG] " << _msg;
62
// Phase tags used by the UCC logging helpers in this header: they key
// ucc_phase_map and are accepted by ProcessGroupUCCLogger::getLogPrefix() and
// setPhase(). Values are spelled out explicitly; TORCH_UCC_UNKNOWN is the only
// negative one and the remaining tags count up from zero.
enum torch_ucc_phase_t {
  TORCH_UCC_UNKNOWN = -1,
  TORCH_UCC_INIT = 0,
  TORCH_UCC_HEALTH_CHECK = 1,
  TORCH_UCC_READY = 2,
  TORCH_UCC_COLL_POST = 3,
  TORCH_UCC_COLL_PROGRESS = 4,
  TORCH_UCC_FINALIZE = 5,
};
72
73 const std::map<torch_ucc_phase_t, std::string> ucc_phase_map = {
74 {TORCH_UCC_UNKNOWN, "UNKNOWN"},
75 {TORCH_UCC_INIT, "INIT"},
76 {TORCH_UCC_HEALTH_CHECK, "HEALTH_CHECK"},
77 {TORCH_UCC_READY, "READY"},
78 {TORCH_UCC_COLL_POST, "COLL_POST"},
79 {TORCH_UCC_COLL_PROGRESS, "COLL_PROGRESS"},
80 {TORCH_UCC_FINALIZE, "FINALIZE"},
81 };
82
83 class CommTraceLogger;
84
85 class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder {
86 public:
87 ProcessGroupUCCLogger();
88 ProcessGroupUCCLogger(std::string log_prefix, torch_ucc_phase_t phase);
89
90 std::string getLogPrefix(torch_ucc_phase_t phase = TORCH_UCC_UNKNOWN);
91 void setLogPrefix(std::string log_prefix);
setPhase(torch_ucc_phase_t phase)92 inline void setPhase(torch_ucc_phase_t phase) {
93 local_phase = phase;
94 }
95
96 void initCommsTracer();
97 void flushComms(int rank, int world_size);
98 std::shared_ptr<CommTraceLogger> trace_generator = nullptr;
99
100 protected:
101 std::string log_prefix;
102 torch_ucc_phase_t local_phase = TORCH_UCC_UNKNOWN;
103 bool initialized_CommTraceLogger = false;
104 };
105
106 struct torch_ucc_oob_coll_info_t {
107 c10::intrusive_ptr<Store> store;
108 uint32_t comm_id;
109 int rank;
110 int size;
111 void* rbuf;
112 size_t msglen;
getKeyc10d::torch_ucc_oob_coll_info_t113 std::string getKey(std::string key) {
114 return std::to_string(comm_id) + key;
115 }
116 };
117
118 class CommBase {
119 public:
CommBase(const c10::intrusive_ptr<ProcessGroupUCCLogger> & logger_)120 CommBase(const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_)
121 : logger(logger_) {}
122 virtual void progress() = 0;
123 virtual void free_request(ucc_coll_req_h request) = 0;
~CommBase()124 virtual ~CommBase() {}
125 c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
126 };
127 class CommUCC : public CommBase {
128 public:
129 ucc_lib_h lib{nullptr};
130 ucc_context_h context{nullptr};
131
132 public:
133 void progress() override;
134 CommUCC(
135 std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
136 const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
137 void free_request(ucc_coll_req_h request) override;
138 ~CommUCC();
139 };
140
// Out-of-band allgather entry points; only declared here, defined elsewhere.
// NOTE(review): the untyped void* parameters suggest these are C-style
// callbacks handed to UCC, with `coll_info` presumably pointing at a
// torch_ucc_oob_coll_info_t — confirm against the definitions.

// Starts an allgather of `msglen` bytes from `sbuf` into `rbuf`; `req` is an
// out-parameter for a request handle. Returns a UCC status code.
ucc_status_t oob_allgather(
    void* sbuf,
    void* rbuf,
    size_t msglen,
    void* coll_info,
    void** req);

// Takes a request handle `req` and returns a UCC status code.
ucc_status_t oob_allgather_test(void* req);

// Takes a request handle `req` and returns a UCC status code.
ucc_status_t oob_allgather_free(void* req);
151
152 // trim: remove spaces before and after the string view
153 // implementation borrowed from https://stackoverflow.com/a/17976541
trim(c10::string_view s)154 inline c10::string_view trim(c10::string_view s) {
155 auto wsfront = std::find_if_not(
156 s.begin(), s.end(), [](int c) { return std::isspace(c); });
157 auto wsback = std::find_if_not(s.rbegin(), s.rend(), [](int c) {
158 return std::isspace(c);
159 }).base();
160 return (
161 wsback <= wsfront ? "" : s.substr(wsfront - s.begin(), wsback - wsfront));
162 }
163
tolower(c10::string_view s)164 inline std::string tolower(c10::string_view s) {
165 std::string result;
166 result.reserve(s.size());
167 for (auto c : s) {
168 result.push_back(std::tolower(c));
169 }
170 return result;
171 }
172
parse_list(std::string list)173 inline std::vector<std::string> parse_list(std::string list) {
174 std::vector<std::string> result;
175 list = tolower(trim(list));
176 while (!list.empty()) {
177 const auto end_pos = list.find_first_of(',');
178 const auto token = trim(list.substr(0, end_pos));
179 result.push_back(std::string(token));
180 list = (end_pos != c10::string_view::npos) ? list.substr(end_pos + 1) : "";
181 }
182 return result;
183 }
184
185 } // namespace c10d
186
187 #endif // USE_C10D_UCC
188