1 #ifdef USE_C10D_UCC
2
3 #include <torch/csrc/distributed/c10d/UCCTracing.hpp>
4 #include <torch/csrc/distributed/c10d/UCCUtils.hpp>
5
6 #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
7
8 #include <sys/stat.h>
9 #include <cstdlib>
10 #include <ctime>
11 #include <fstream>
12
13 namespace c10d {
14
initCommsTracer()15 void ProcessGroupUCCLogger::initCommsTracer() {
16 trace_generator = std::make_shared<CommTraceLogger>();
17 initialized_CommTraceLogger = true;
18 }
19
flushComms(int rank,int world_size)20 void ProcessGroupUCCLogger::flushComms(int rank, int world_size) {
21 if (!initialized_CommTraceLogger ||
22 trace_generator->getCommsTrace().empty()) {
23 return;
24 }
25
26 std::string dirname = c10::str("ProcessGroupUCC_trace_np", world_size);
27 time_t now_ = time(0);
28 std::tm* ltm = localtime(&now_);
29 if (ltm) {
30 dirname += c10::str(
31 "_", (1 + ltm->tm_mon), "_", ltm->tm_mday, "_", (1900 + ltm->tm_year));
32 }
33
34 std::string fullpath = "/tmp/" + dirname;
35 char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR");
36 if (user_path) {
37 fullpath = user_path;
38 }
39 std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json");
40 std::ofstream _outfile;
41 if (!_outfile.is_open()) {
42 if (!mkdir(fullpath.c_str(), 0777)) {
43 LOG(INFO) << getLogPrefix() << "[INFO] failed to mkdir " << fullpath;
44 } else if (errno != EEXIST) {
45 return;
46 }
47 _outfile.open(trace_filename, std::ofstream::out | std::ofstream::trunc);
48 }
49 // flush the traced comms
50 if (_outfile.is_open()) {
51 _outfile << "[" << c10::Join(",", trace_generator->getCommsTrace())
52 << "\n]";
53 _outfile.flush();
54 _outfile.close();
55 }
56 }
57
58 /* unused */
setCurBlock(const std::string & name)59 void CommTraceLogger::setCurBlock(const std::string& name) {
60 curBlocks_.push_back(
61 c10::str("\"", name, "\"")); // add quote marks for JSON format
62 }
63
64 /* unused */
popBlock()65 void CommTraceLogger::popBlock() {
66 // TODO: remove specific name
67 curBlocks_.pop_back();
68 }
69
recordOptionalInfo(int root)70 void CommTraceLogger::recordOptionalInfo(int root) {
71 curRoot_ = root;
72 }
73
recordOptionalInfo(const std::vector<int64_t> & outputSplitSizes,const std::vector<int64_t> & inputSplitSizes)74 void CommTraceLogger::recordOptionalInfo(
75 const std::vector<int64_t>& outputSplitSizes,
76 const std::vector<int64_t>& inputSplitSizes) {
77 curOutSplitSizes_ = outputSplitSizes;
78 curInSplitSizes_ = inputSplitSizes;
79 }
80
recordComms(const std::string & commName,const uintptr_t workReq,const int rank,const int world_size,const std::vector<at::Tensor> & inputTensors,const std::vector<at::Tensor> & outputTensors)81 void CommTraceLogger::recordComms(
82 const std::string& commName,
83 const uintptr_t workReq,
84 const int rank,
85 const int world_size,
86 const std::vector<at::Tensor>& inputTensors,
87 const std::vector<at::Tensor>& outputTensors) {
88 auto inNelems = (!inputTensors.empty()) ? inputTensors[0].numel() : 0;
89 auto outNelems = (!outputTensors.empty()) ? outputTensors[0].numel() : 0;
90 auto dtype =
91 (!outputTensors.empty()) ? outputTensors[0].scalar_type() : at::kByte;
92 auto devType = (!outputTensors.empty()) ? outputTensors[0].device().type()
93 : c10::DeviceType::CPU;
94 auto now = std::chrono::system_clock::now();
95 static auto startTS = now;
96 int64_t time_since_begin =
97 std::chrono::duration_cast<std::chrono::nanoseconds>(now - startTS)
98 .count();
99
100 // TODO: get markers from torch profiler if enabled
101
102 // common fields for all operations
103 std::string cur_trace_ = c10::str(
104 "\n\t\t\"markers\": [",
105 curBlocks_,
106 "]",
107 ",\n\t\t\"startTime_ns\": ",
108 time_since_begin,
109 ",\n\t\t\"comms\": \"",
110 commName,
111 "\"",
112 ",\n\t\t\"req\": ",
113 workReq,
114 ",\n\t\t\"seqnum\": ",
115 seqnum,
116 ",\n\t\t\"world_size\": ",
117 world_size);
118
119 if (inNelems > 0 || outNelems > 0) {
120 // for most collectives - append msg sizes, data type, device type
121 cur_trace_ = c10::str(
122 cur_trace_,
123 ",\n\t\t\"in_msg_size\": ",
124 inNelems,
125 ",\n\t\t\"out_msg_size\": ",
126 outNelems,
127 ",\n\t\t\"dtype\": \"",
128 at::toString(dtype),
129 "\",\n\t\t\"devType\": \"",
130 c10::DeviceTypeName(devType),
131 "\"");
132 }
133 if (curRoot_ != -1) {
134 // append root rank if applicable, e.g., broadcast, gather, scatter
135 cur_trace_ = c10::str(cur_trace_, ",\n\t\t\"root\": ", curRoot_);
136 }
137 if (!curInSplitSizes_.empty() || !curOutSplitSizes_.empty()) {
138 // append input and output splits if applicable, e.g., ALLTOALL_BASE
139 cur_trace_ = c10::str(
140 cur_trace_,
141 ",\n\t\t\"in_split\": [",
142 c10::Join(",", curInSplitSizes_),
143 "]"
144 ",\n\t\t\"out_split\": [",
145 c10::Join(",", curOutSplitSizes_),
146 "]");
147 }
148 comms_trace_.push_back(c10::str("\n\t{", cur_trace_, "\n\t}"));
149
150 // record the trace to kineto trace if applicable
151 RECORD_PARAM_COMMS(
152 static_cast<int64_t>(seqnum), // seq
153 std::make_tuple("0", ""), // pg_name tuple
154 rank,
155 commName.c_str(),
156 inNelems,
157 outNelems,
158 dtype,
159 curInSplitSizes_,
160 curOutSplitSizes_,
161 -1,
162 -1,
163 world_size);
164
165 ++seqnum;
166
167 // reset optional field
168 curRoot_ = -1;
169 curInSplitSizes_ = {};
170 curOutSplitSizes_ = {};
171 }
172
173 } // namespace c10d
174
175 #endif // USE_C10D_UCC
176