xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/UCCUtils.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_C10D_UCC
4 
5 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
6 #include <torch/csrc/distributed/c10d/Store.hpp>
7 #include <ucc/api/ucc.h>
8 
9 namespace c10d {
10 
// Macro to generate the error message on a non-successful UCC return value.
//
// Assigns to `_err` a string of the form
//   [<file>:<line>] <log prefix><_error_msg>, error code <_result>: <status
//   string>, system error code <errno>
// `_result` is expanded twice (once as the numeric code, once as the argument
// of ucc_status_string). The expansion refers to a variable named `logger`,
// so one must be in scope wherever this macro is used.
#define TORCH_UCC_GET_ERROR_MSG(_err, _error_msg, _result) \
  do {                                                     \
    _err = c10::str(                                       \
        "[",                                               \
        __FILE__,                                          \
        ":",                                               \
        __LINE__,                                          \
        "] ",                                              \
        logger->getLogPrefix(),                            \
        _error_msg,                                        \
        ", error code ",                                   \
        _result,                                           \
        ": ",                                              \
        ucc_status_string(_result),                        \
        ", system error code ",                            \
        errno);                                            \
  } while (0)
29 
// Macro to throw on a non-successful UCC return value.
//
// Evaluates `_cmd` exactly once. If the status is anything other than UCC_OK,
// an error string is built with TORCH_UCC_GET_ERROR_MSG and raised through
// TORCH_CHECK(false, ...). Like TORCH_UCC_GET_ERROR_MSG, this needs a
// variable named `logger` in scope at the point of use.
#define TORCH_UCC_CHECK(_cmd, _error_msg)             \
  do {                                                \
    const ucc_status_t result = (_cmd);               \
    if (result == UCC_OK) {                           \
      break;                                          \
    }                                                 \
    std::string err;                                  \
    TORCH_UCC_GET_ERROR_MSG(err, _error_msg, result); \
    TORCH_CHECK(false, err);                          \
  } while (0)
40 
// Macro to throw on a non-successful UCC return value and free its request.
//
// Evaluates `_cmd` exactly once. On any status other than UCC_OK it builds the
// error string, then evaluates `_request` exactly once (after `_cmd`, so a
// request handle produced by `_cmd` is observed) and, if the handle is
// non-null, passes it to ucc_collective_finalize before raising through
// TORCH_CHECK(false, ...). The return value of ucc_collective_finalize is
// ignored: the original failure is the one that gets reported.
// Needs a variable named `logger` in scope (see TORCH_UCC_GET_ERROR_MSG).
#define TORCH_UCC_CHECK_REQUEST(_request, _cmd, _error_msg)   \
  do {                                                        \
    ucc_status_t result = (_cmd);                             \
    if (result != UCC_OK) {                                   \
      std::string err;                                        \
      TORCH_UCC_GET_ERROR_MSG(err, _error_msg, result);       \
      ucc_coll_req_h torch_ucc_failed_request_ = (_request);  \
      if (torch_ucc_failed_request_ != nullptr) {             \
        ucc_collective_finalize(torch_ucc_failed_request_);   \
      }                                                       \
      TORCH_CHECK(false, err);                                \
    }                                                         \
  } while (0)
54 
// Macros to print logs with unified format:
//   <logger->getLogPrefix(_phase)>[ERROR|INFO|DEBUG] <_msg>
// ERROR and INFO go to LOG(ERROR) / LOG(INFO); DEBUG goes to VLOG(1).
// A variable named `logger` must be in scope at the point of use.
// NOTE(review): each definition already ends with ';', so `MACRO(...);`
// expands to an extra empty statement. That makes these unsafe as the
// unbraced body of an `if` that has an `else`. Left unchanged so that existing
// call sites keep compiling.
#define TORCH_UCC_LOG_ERROR(_phase, _msg)      \
  LOG(ERROR) << logger->getLogPrefix(_phase)   \
             << "[ERROR] " << _msg;
#define TORCH_UCC_LOG_INFO(_phase, _msg)       \
  LOG(INFO) << logger->getLogPrefix(_phase)    \
            << "[INFO] " << _msg;
#define TORCH_UCC_LOG_DEBUG(_phase, _msg)      \
  VLOG(1) << logger->getLogPrefix(_phase)      \
          << "[DEBUG] " << _msg;
62 
// Phase tags used to label log output (see ucc_phase_map for the printable
// names). The numeric values are spelled out; they are the same values the
// implicit numbering after TORCH_UCC_UNKNOWN = -1 produces.
enum torch_ucc_phase_t {
  TORCH_UCC_UNKNOWN = -1,
  TORCH_UCC_INIT = 0,
  TORCH_UCC_HEALTH_CHECK = 1,
  TORCH_UCC_READY = 2,
  TORCH_UCC_COLL_POST = 3,
  TORCH_UCC_COLL_PROGRESS = 4,
  TORCH_UCC_FINALIZE = 5,
};
72 
73 const std::map<torch_ucc_phase_t, std::string> ucc_phase_map = {
74     {TORCH_UCC_UNKNOWN, "UNKNOWN"},
75     {TORCH_UCC_INIT, "INIT"},
76     {TORCH_UCC_HEALTH_CHECK, "HEALTH_CHECK"},
77     {TORCH_UCC_READY, "READY"},
78     {TORCH_UCC_COLL_POST, "COLL_POST"},
79     {TORCH_UCC_COLL_PROGRESS, "COLL_PROGRESS"},
80     {TORCH_UCC_FINALIZE, "FINALIZE"},
81 };
82 
83 class CommTraceLogger;
84 
85 class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder {
86  public:
87   ProcessGroupUCCLogger();
88   ProcessGroupUCCLogger(std::string log_prefix, torch_ucc_phase_t phase);
89 
90   std::string getLogPrefix(torch_ucc_phase_t phase = TORCH_UCC_UNKNOWN);
91   void setLogPrefix(std::string log_prefix);
setPhase(torch_ucc_phase_t phase)92   inline void setPhase(torch_ucc_phase_t phase) {
93     local_phase = phase;
94   }
95 
96   void initCommsTracer();
97   void flushComms(int rank, int world_size);
98   std::shared_ptr<CommTraceLogger> trace_generator = nullptr;
99 
100  protected:
101   std::string log_prefix;
102   torch_ucc_phase_t local_phase = TORCH_UCC_UNKNOWN;
103   bool initialized_CommTraceLogger = false;
104 };
105 
106 struct torch_ucc_oob_coll_info_t {
107   c10::intrusive_ptr<Store> store;
108   uint32_t comm_id;
109   int rank;
110   int size;
111   void* rbuf;
112   size_t msglen;
getKeyc10d::torch_ucc_oob_coll_info_t113   std::string getKey(std::string key) {
114     return std::to_string(comm_id) + key;
115   }
116 };
117 
118 class CommBase {
119  public:
CommBase(const c10::intrusive_ptr<ProcessGroupUCCLogger> & logger_)120   CommBase(const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_)
121       : logger(logger_) {}
122   virtual void progress() = 0;
123   virtual void free_request(ucc_coll_req_h request) = 0;
~CommBase()124   virtual ~CommBase() {}
125   c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
126 };
127 class CommUCC : public CommBase {
128  public:
129   ucc_lib_h lib{nullptr};
130   ucc_context_h context{nullptr};
131 
132  public:
133   void progress() override;
134   CommUCC(
135       std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
136       const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
137   void free_request(ucc_coll_req_h request) override;
138   ~CommUCC();
139 };
140 
141 ucc_status_t oob_allgather(
142     void* sbuf,
143     void* rbuf,
144     size_t msglen,
145     void* coll_info,
146     void** req);
147 
148 ucc_status_t oob_allgather_test(void* req);
149 
150 ucc_status_t oob_allgather_free(void* req);
151 
152 // trim: remove spaces before and after the string view
153 // implementation borrowed from https://stackoverflow.com/a/17976541
trim(c10::string_view s)154 inline c10::string_view trim(c10::string_view s) {
155   auto wsfront = std::find_if_not(
156       s.begin(), s.end(), [](int c) { return std::isspace(c); });
157   auto wsback = std::find_if_not(s.rbegin(), s.rend(), [](int c) {
158                   return std::isspace(c);
159                 }).base();
160   return (
161       wsback <= wsfront ? "" : s.substr(wsfront - s.begin(), wsback - wsfront));
162 }
163 
tolower(c10::string_view s)164 inline std::string tolower(c10::string_view s) {
165   std::string result;
166   result.reserve(s.size());
167   for (auto c : s) {
168     result.push_back(std::tolower(c));
169   }
170   return result;
171 }
172 
parse_list(std::string list)173 inline std::vector<std::string> parse_list(std::string list) {
174   std::vector<std::string> result;
175   list = tolower(trim(list));
176   while (!list.empty()) {
177     const auto end_pos = list.find_first_of(',');
178     const auto token = trim(list.substr(0, end_pos));
179     result.push_back(std::string(token));
180     list = (end_pos != c10::string_view::npos) ? list.substr(end_pos + 1) : "";
181   }
182   return result;
183 }
184 
185 } // namespace c10d
186 
187 #endif // USE_C10D_UCC
188