xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/UCCTracing.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef USE_C10D_UCC
2 
3 #include <torch/csrc/distributed/c10d/UCCTracing.hpp>
4 #include <torch/csrc/distributed/c10d/UCCUtils.hpp>
5 
6 #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
7 
8 #include <sys/stat.h>
9 #include <cstdlib>
10 #include <ctime>
11 #include <fstream>
12 
13 namespace c10d {
14 
initCommsTracer()15 void ProcessGroupUCCLogger::initCommsTracer() {
16   trace_generator = std::make_shared<CommTraceLogger>();
17   initialized_CommTraceLogger = true;
18 }
19 
flushComms(int rank,int world_size)20 void ProcessGroupUCCLogger::flushComms(int rank, int world_size) {
21   if (!initialized_CommTraceLogger ||
22       trace_generator->getCommsTrace().empty()) {
23     return;
24   }
25 
26   std::string dirname = c10::str("ProcessGroupUCC_trace_np", world_size);
27   time_t now_ = time(0);
28   std::tm* ltm = localtime(&now_);
29   if (ltm) {
30     dirname += c10::str(
31         "_", (1 + ltm->tm_mon), "_", ltm->tm_mday, "_", (1900 + ltm->tm_year));
32   }
33 
34   std::string fullpath = "/tmp/" + dirname;
35   char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR");
36   if (user_path) {
37     fullpath = user_path;
38   }
39   std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json");
40   std::ofstream _outfile;
41   if (!_outfile.is_open()) {
42     if (!mkdir(fullpath.c_str(), 0777)) {
43       LOG(INFO) << getLogPrefix() << "[INFO] failed to mkdir " << fullpath;
44     } else if (errno != EEXIST) {
45       return;
46     }
47     _outfile.open(trace_filename, std::ofstream::out | std::ofstream::trunc);
48   }
49   // flush the traced comms
50   if (_outfile.is_open()) {
51     _outfile << "[" << c10::Join(",", trace_generator->getCommsTrace())
52              << "\n]";
53     _outfile.flush();
54     _outfile.close();
55   }
56 }
57 
58 /* unused */
setCurBlock(const std::string & name)59 void CommTraceLogger::setCurBlock(const std::string& name) {
60   curBlocks_.push_back(
61       c10::str("\"", name, "\"")); // add quote marks for JSON format
62 }
63 
64 /* unused */
popBlock()65 void CommTraceLogger::popBlock() {
66   // TODO: remove specific name
67   curBlocks_.pop_back();
68 }
69 
recordOptionalInfo(int root)70 void CommTraceLogger::recordOptionalInfo(int root) {
71   curRoot_ = root;
72 }
73 
recordOptionalInfo(const std::vector<int64_t> & outputSplitSizes,const std::vector<int64_t> & inputSplitSizes)74 void CommTraceLogger::recordOptionalInfo(
75     const std::vector<int64_t>& outputSplitSizes,
76     const std::vector<int64_t>& inputSplitSizes) {
77   curOutSplitSizes_ = outputSplitSizes;
78   curInSplitSizes_ = inputSplitSizes;
79 }
80 
recordComms(const std::string & commName,const uintptr_t workReq,const int rank,const int world_size,const std::vector<at::Tensor> & inputTensors,const std::vector<at::Tensor> & outputTensors)81 void CommTraceLogger::recordComms(
82     const std::string& commName,
83     const uintptr_t workReq,
84     const int rank,
85     const int world_size,
86     const std::vector<at::Tensor>& inputTensors,
87     const std::vector<at::Tensor>& outputTensors) {
88   auto inNelems = (!inputTensors.empty()) ? inputTensors[0].numel() : 0;
89   auto outNelems = (!outputTensors.empty()) ? outputTensors[0].numel() : 0;
90   auto dtype =
91       (!outputTensors.empty()) ? outputTensors[0].scalar_type() : at::kByte;
92   auto devType = (!outputTensors.empty()) ? outputTensors[0].device().type()
93                                           : c10::DeviceType::CPU;
94   auto now = std::chrono::system_clock::now();
95   static auto startTS = now;
96   int64_t time_since_begin =
97       std::chrono::duration_cast<std::chrono::nanoseconds>(now - startTS)
98           .count();
99 
100   // TODO: get markers from torch profiler if enabled
101 
102   // common fields for all operations
103   std::string cur_trace_ = c10::str(
104       "\n\t\t\"markers\": [",
105       curBlocks_,
106       "]",
107       ",\n\t\t\"startTime_ns\": ",
108       time_since_begin,
109       ",\n\t\t\"comms\": \"",
110       commName,
111       "\"",
112       ",\n\t\t\"req\": ",
113       workReq,
114       ",\n\t\t\"seqnum\": ",
115       seqnum,
116       ",\n\t\t\"world_size\": ",
117       world_size);
118 
119   if (inNelems > 0 || outNelems > 0) {
120     // for most collectives - append msg sizes, data type, device type
121     cur_trace_ = c10::str(
122         cur_trace_,
123         ",\n\t\t\"in_msg_size\": ",
124         inNelems,
125         ",\n\t\t\"out_msg_size\": ",
126         outNelems,
127         ",\n\t\t\"dtype\": \"",
128         at::toString(dtype),
129         "\",\n\t\t\"devType\": \"",
130         c10::DeviceTypeName(devType),
131         "\"");
132   }
133   if (curRoot_ != -1) {
134     // append root rank if applicable, e.g., broadcast, gather, scatter
135     cur_trace_ = c10::str(cur_trace_, ",\n\t\t\"root\": ", curRoot_);
136   }
137   if (!curInSplitSizes_.empty() || !curOutSplitSizes_.empty()) {
138     // append input and output splits if applicable, e.g., ALLTOALL_BASE
139     cur_trace_ = c10::str(
140         cur_trace_,
141         ",\n\t\t\"in_split\": [",
142         c10::Join(",", curInSplitSizes_),
143         "]"
144         ",\n\t\t\"out_split\": [",
145         c10::Join(",", curOutSplitSizes_),
146         "]");
147   }
148   comms_trace_.push_back(c10::str("\n\t{", cur_trace_, "\n\t}"));
149 
150   // record the trace to kineto trace if applicable
151   RECORD_PARAM_COMMS(
152       static_cast<int64_t>(seqnum), // seq
153       std::make_tuple("0", ""), // pg_name tuple
154       rank,
155       commName.c_str(),
156       inNelems,
157       outNelems,
158       dtype,
159       curInSplitSizes_,
160       curOutSplitSizes_,
161       -1,
162       -1,
163       world_size);
164 
165   ++seqnum;
166 
167   // reset optional field
168   curRoot_ = -1;
169   curInSplitSizes_ = {};
170   curOutSplitSizes_ = {};
171 }
172 
173 } // namespace c10d
174 
175 #endif // USE_C10D_UCC
176