#include <torch/csrc/distributed/rpc/utils.h>

#include <fmt/format.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/python_call.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_resp.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/unpickler.h>

#include <c10/util/irange.h>

using namespace torch::autograd::profiler;

namespace torch::distributed::rpc {
namespace {
void processRemoteProfiledEvents(
    autograd::RpcWithProfilingResp& rpcWithProfilingResp) {
  // Check if the profiler is enabled
  auto enabled = profilerEnabled();
  TORCH_CHECK(
      enabled,
      "Profiler was expected to be enabled. This can happen in callback "
      "continuations that run in different threads, and the TLS of the "
      "profiler was not propagated.");
  std::vector<LegacyEvent> events = rpcWithProfilingResp.getProfiledEvents();
  const auto& profilingId = rpcWithProfilingResp.getProfilingId();
  auto& remoteProfilerManager = RemoteProfilerManager::getInstance();
  auto key = remoteProfilerManager.retrieveRPCProfilingKey(profilingId);
  remoteProfilerManager.eraseKey(profilingId);
  auto keyPrefixStr = key + rpc::REMOTE_PROFILING_KEY_PREFIX;
  std::for_each(
      events.begin(), events.end(), [&keyPrefixStr](LegacyEvent& event) {
        std::string name = keyPrefixStr + std::string(event.name());
        event.setName(at::StringView(name));
      });
  // Add event list to the thread local profiler.
  addEventList(std::move(events));
}

} // namespace

const std::string kRPCErrorPrefix = std::string("RPCErr");
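
// Errors produced by makeRPCError() below have the form
// "<kRPCErrorPrefix>:<RPCErrorType as int>:<error message>";
// getRPCErrorType() parses the numeric field back out of that string.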
RPCErrorType getRPCErrorType(const JitFuture& jitFuture) {
  TORCH_INTERNAL_ASSERT(
      jitFuture.hasError(),
      "JitFuture of Message passed to getRPCErrorType does not have an error.");

  // Attempt to parse for error string given by makeRPCError, otherwise return
  // unknown error.
  // Note that this function expects errors formatted with makeRPCError().
  auto err = jitFuture.tryRetrieveErrorMessage();
  size_t pos = err.find(kRPCErrorPrefix);
  if (pos != std::string::npos) {
    // Parse the RPCErrorType.
    auto errStartIdx =
        pos + torch::distributed::rpc::kRPCErrorPrefix.size() + 1;
    auto errEndIdx = err.find(':', errStartIdx);
    if (errEndIdx == std::string::npos) {
      // Indicates error was not formatted correctly.
      return RPCErrorType::UNKNOWN_ERROR;
    }
    auto errStr = err.substr(errStartIdx, errEndIdx - errStartIdx);
    auto errType = static_cast<RPCErrorType>(std::stoi(errStr));
    return errType;
  } else {
    return RPCErrorType::UNKNOWN_ERROR;
  }
}

std::string makeRPCError(
    const std::string& rpcErrorStr,
    RPCErrorType errorType) {
  return fmt::format(
      "{}:{}:{}",
      torch::distributed::rpc::kRPCErrorPrefix,
      static_cast<int>(errorType),
      rpcErrorStr);
}

std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
  switch (request.type()) {
    case MessageType::SCRIPT_CALL: {
      return ScriptCall::fromMessage(request);
    }
    case MessageType::PYTHON_CALL: {
      return PythonCall::fromMessage(request);
    }
    case MessageType::SCRIPT_REMOTE_CALL: {
      return ScriptRemoteCall::fromMessage(request);
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      return PythonRemoteCall::fromMessage(request);
    }
    case MessageType::SCRIPT_RREF_FETCH_CALL: {
      return ScriptRRefFetchCall::fromMessage(request);
    }
    case MessageType::PYTHON_RREF_FETCH_CALL: {
      return PythonRRefFetchCall::fromMessage(request);
    }
    case MessageType::RREF_USER_DELETE: {
      return RRefUserDelete::fromMessage(request);
    }
    case MessageType::RREF_CHILD_ACCEPT: {
      return RRefChildAccept::fromMessage(request);
    }
    case MessageType::RREF_FORK_REQUEST: {
      return RRefForkRequest::fromMessage(request);
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      return autograd::RpcWithAutograd::fromMessage(request);
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      return autograd::PropagateGradientsReq::fromMessage(request);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
      return autograd::CleanupAutogradContextReq::fromMessage(request);
    }
    case MessageType::RUN_WITH_PROFILING_REQ: {
      return autograd::RpcWithProfilingReq::fromMessage(request);
    }
    case MessageType::RREF_BACKWARD_REQ: {
      return autograd::RRefBackwardReq::fromMessage(request);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", request.type(), " not supported.");
    }
  }
}

std::unique_ptr<RpcCommandBase> deserializeResponse(
    const Message& response,
    MessageType& wrappedMsgType) {
  switch (response.type()) {
    case MessageType::SCRIPT_RET: {
      return ScriptResp::fromMessage(response);
    }
    case MessageType::PYTHON_RET: {
      return PythonResp::fromMessage(response);
    }
    case MessageType::REMOTE_RET: {
      return RemoteRet::fromMessage(response);
    }
    case MessageType::SCRIPT_RREF_FETCH_RET: {
      return ScriptRRefFetchRet::fromMessage(response);
    }
    case MessageType::PYTHON_RREF_FETCH_RET: {
      return PythonRRefFetchRet::fromMessage(response);
    }
    case MessageType::RREF_ACK: {
      return RRefAck::fromMessage(response);
    }
    case MessageType::FORWARD_AUTOGRAD_RESP: {
      std::unique_ptr<RpcCommandBase> rpcPtr =
          autograd::RpcWithAutograd::fromMessage(response);
      RpcCommandBase& rpc = *rpcPtr;
      auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(rpc);

      // Need to reverse the device map for the backward pass of distributed
      // autograd.
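      // For example, a forward-pass entry {cuda:0 -> cuda:1} becomes
      // {cuda:1 -> cuda:0} in reverseDeviceMap below, which is then handed to
      // addRecvRpcBackward for use during the backward pass.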
      DeviceMap reverseDeviceMap;
      for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
        reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
      }

      // Attach 'recv' autograd function.
      addRecvRpcBackward(
          rpcWithAutograd.autogradMetadata(),
          rpcWithAutograd.tensors(),
          rpcWithAutograd.fromWorkerId(),
          reverseDeviceMap);

      wrappedMsgType = rpcWithAutograd.wrappedMessageType();

      return std::move(rpcWithAutograd).moveWrappedRpc();
    }
    case MessageType::BACKWARD_AUTOGRAD_RESP: {
      return autograd::PropagateGradientsResp::fromMessage(response);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
      return autograd::CleanupAutogradContextResp::fromMessage(response);
    }
    case MessageType::RUN_WITH_PROFILING_RESP: {
      std::unique_ptr<RpcCommandBase> rpcPtr =
          autograd::RpcWithProfilingResp::fromMessage(response);
      RpcCommandBase& rpc = *rpcPtr;
      auto& rpcWithProfilingResp =
          static_cast<autograd::RpcWithProfilingResp&>(rpc);
      // Process remotely profiled events.
      processRemoteProfiledEvents(rpcWithProfilingResp);

      wrappedMsgType = rpcWithProfilingResp.wrappedMessageType();
      auto wrappedRPC = std::move(rpcWithProfilingResp).moveWrappedRpc();
      return wrappedRPC;
    }
    case MessageType::RREF_BACKWARD_RESP: {
      return autograd::RRefBackwardResp::fromMessage(response);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Response type ", response.type(), " not supported.");
    }
  }
}

IValue deserializeResptoIValueInternal(
    RpcCommandBase& rpc,
    MessageType messageType) {
  switch (messageType) {
    case MessageType::SCRIPT_RET: {
      auto& ret = static_cast<ScriptResp&>(rpc);
      return ret.value();
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false,
          "Response type ",
          messageType,
          " is not supported to be deserialized to IValue.");
    }
  }
}

IValue deserializeRespToIValue(const Message& message) {
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  return deserializeResptoIValueInternal(*response, msgType);
}

namespace {

// Helper for wireDeserialize() below.
//
// The format we use below looks like:
//    section_name_1 size_1\n
//    section_name_2 size_2\n
//    ..
//    \n
//    [sections in order]
//
// Sections themselves include:
//    - "payload" - the payload bits
//    - "meta"    - metadata for the unpickler
//    - "0" ...   - tensor sections for the unpickler
//
// Note that per the header comments, the format is subject to change,
// and is best used for RPCs, rather than persistent disk storage.
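//
// For example (illustrative sizes only), a message with a 10-byte payload,
// a 20-byte pickle blob, and a single 64-byte tensor section would be
// serialized by wireSerialize() below as:
//    payload 10\n
//    meta 20\n
//    0 64\n
//    \n
//    [10 payload bytes][20 meta bytes][64 tensor bytes]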
std::unordered_map<std::string, std::pair<const char*, size_t>>
parseWireSections(const void* data, size_t data_size) {
  const char* ptr = static_cast<const char*>(data);
  const char* endp = ptr + data_size;

  std::vector<std::pair<std::string, size_t>> headerEnts;
  bool ok = false;
  while (ptr != endp) {
    if (*ptr == '\n') {
      ok = true; // The only "correct" exit point.
      ++ptr;
      break;
    }
    // Parse name
    const char* namePtr = ptr;
    while (ptr != endp && *ptr != ' ') {
      ptr++;
    }
    if (ptr == endp) {
      break;
    }
    std::string name(namePtr, ptr - namePtr);
    if (++ptr == endp) {
      break; // past the ' '
    }
    // Parse size
    const char* sizePtr = ptr;
    while (ptr != endp && *ptr != '\n') {
      ptr++;
    }
    if (ptr == endp) {
      break;
    }
    size_t sz = std::stoll(std::string(sizePtr, ptr - sizePtr));
    headerEnts.emplace_back(name, sz);
    ++ptr; // past the '\n'
  }
  if (!ok) {
    TORCH_CHECK(false, "failed parse");
  }

  std::unordered_map<std::string, std::pair<const char*, size_t>> out;
  for (const auto& headerEnt : headerEnts) {
    out[headerEnt.first] = {ptr, headerEnt.second};
    ptr += headerEnt.second;
  }
  if (ptr != endp) {
    TORCH_CHECK(false, "failed bounds");
  }
  return out;
}

static const char* kMeta = "meta";
static const char* kPayload = "payload";
} // namespace

c10::List<at::Tensor> cloneSparseTensors(
    const std::vector<at::Tensor>& tensors) {
  // Sanity-check: If the majority of bits don't need to go over the wire,
  // force a clone(). Some Tensors are effectively small views, only using
  // ~1% of the underlying Storage.
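  // For instance (illustrative numbers): a 16-element float slice of a tensor
  // whose storage holds 1M floats (~4 MB) passes both thresholds below
  // (storage >= 8 KB and storage >= 2x the useful bytes), so it is cloned and
  // only ~64 bytes are pickled instead of the whole storage.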
  auto worthRecopying = [](const at::Tensor& t) -> bool {
    if (!t.has_storage()) {
      return false; // avoid throwing below.
    }
    auto storageSize = t.storage().nbytes();
    auto usefulSize = t.element_size() * t.numel();
    constexpr size_t kMinMultiple = 2;
    constexpr size_t kMinRecopyBytes = 8 * 1024;
    return storageSize >= kMinRecopyBytes &&
        storageSize >= usefulSize * kMinMultiple;
  };
  c10::List<at::Tensor> pTensors;
  pTensors.reserve(tensors.size());
  for (const auto& t : tensors) {
    pTensors.push_back(worthRecopying(t) ? t.clone() : t);
  }
  return pTensors;
}

std::string wireSerialize(
    const std::vector<char>& payload,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    TORCH_CHECK(
        tensor.device().is_cpu(),
        "ProcessGroup RPC backend only supports",
        " CPU tensors, please move your tensors to CPU before sending ",
        "them over RPC. Found tensor on device: ",
        tensor.device());
  }

  struct Ent {
    std::string name;
    const char* data;
    size_t size;
  };
  std::vector<Ent> entries;
  std::string metaEntry;
  std::vector<at::Tensor> tensorData;

  if (!payload.empty()) {
    entries.push_back({kPayload, payload.data(), payload.size()});
  }

  if (!tensors.empty()) {
    torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t {
      metaEntry.append(static_cast<const char*>(buf), sz);
      return sz;
    });
    pickler.protocol();
    pickler.pushIValue(cloneSparseTensors(tensors));
    pickler.stop();
    tensorData = pickler.tensorData();
    entries.push_back({kMeta, metaEntry.data(), metaEntry.size()});
    for (const auto i : c10::irange(tensorData.size())) {
      // Construct a WritableTensorData for each tensor in the pickler's
      // tensorData. Since tensorData is in function scope and
      // getWriteableTensorData just records the tensors, the data() pointers
      // stay valid for CPU tensors. Note that RPC serde doesn't support CUDA
      // tensors yet; if we ever do, we need to be careful, since
      // getWriteableTensorData copies CUDA tensors to CPU and data() might be
      // destructed once we go out of scope of this loop.
      auto writeableTensorData = jit::getWriteableTensorData(tensorData[i]);
      entries.push_back(
          {std::to_string(i),
           writeableTensorData.data(),
           writeableTensorData.sizeInBytes()});
    }
  }

  std::string header;
  size_t tot = 0;
  for (const auto& e : entries) {
    tot += e.size;
    header.append(e.name)
        .append(" ")
        .append(std::to_string(e.size))
        .append("\n");
  }
  header.push_back('\n');

  std::string out;
  out.reserve(header.size() + tot);
  out.append(header);
  for (const auto& e : entries) {
    out.append(e.data, e.size);
  }
  return out;
}

std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
    const void* data,
    size_t data_size) {
  auto sections = parseWireSections(data, data_size);

  std::vector<char> payload;
  auto payloadIt = sections.find(kPayload);
  if (payloadIt != sections.end() && payloadIt->second.second != 0) {
    payload.assign(
        payloadIt->second.first,
        payloadIt->second.first + payloadIt->second.second);
  }

  std::vector<at::Tensor> tensors;
  auto metaIt = sections.find(kMeta);
  if (metaIt != sections.end()) {
    const auto& metaData = metaIt->second;
    size_t metaDataPos = 0;
    auto metaDataReadFunc = [&](char* buf, size_t n) -> size_t {
      if (metaDataPos >= metaData.second || n == 0) {
        return 0;
      }
      size_t toCopy = std::min(metaDataPos + n, metaData.second) - metaDataPos;
      memcpy(buf, metaData.first + metaDataPos, toCopy);
      metaDataPos += toCopy;
      return toCopy;
    };
    auto sectionReadFunc = [&](const std::string& ename) -> at::DataPtr {
      auto it = sections.find(ename);
      if (it == sections.end()) {
        TORCH_CHECK(false, "Couldn't find entity " + ename);
      }
      const auto& idat = it->second;
      auto dptr = at::getCPUAllocator()->allocate(idat.second);
      if (idat.second != 0) {
        memcpy(dptr.get(), idat.first, idat.second);
      }
      return dptr;
    };

    // No need to pass a typeResolver here, as the unpickler only processes
    // strings and tensors.
    torch::jit::Unpickler unpickler(
        metaDataReadFunc, nullptr, nullptr, sectionReadFunc, {});
    auto ival = unpickler.parse_ivalue();
    for (auto&& t : ival.toTensorList()) {
      tensors.emplace_back(std::move(t));
    }
  }
  return {std::move(payload), std::move(tensors)};
}

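// Appends additionalPayload to originalPayload and then appends the size of
// additionalPayload as an 8-byte big-endian integer, i.e. the result is laid
// out as:
//   [original payload][additional payload][int64 size of additional payload]
// readWrappedPayload() below unwinds this in reverse.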
void writeWrappedPayload(
    std::vector<char>& originalPayload,
    std::vector<char>& additionalPayload) {
  originalPayload.insert(
      originalPayload.end(),
      additionalPayload.begin(),
      additionalPayload.end());

  // Add size of the additional payload
  int64_t indexToWrite = originalPayload.size();
  originalPayload.resize(originalPayload.size() + sizeof(int64_t));
  const int64_t additionalPayloadSize = additionalPayload.size();
  torch::utils::THP_encodeInt64Buffer(
      reinterpret_cast<uint8_t*>(originalPayload.data()) + indexToWrite,
      &additionalPayloadSize,
      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
      1);
}

std::vector<at::IValue> readWrappedPayload(
    std::vector<char>& payload,
    const rpc::Message& message) {
  // Read the additional payload and remove it from the payload.
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t additionalPayloadSize;
  TORCH_INTERNAL_ASSERT(payload.size() >= sizeof(int64_t));
  size_t indexToRead = payload.size() - sizeof(int64_t);
  torch::utils::THP_decodeInt64Buffer(
      &additionalPayloadSize,
      reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
      1);
  payload.resize(indexToRead);

  TORCH_INTERNAL_ASSERT(
      additionalPayloadSize > 0 &&
          static_cast<int64_t>(payload.size()) > additionalPayloadSize,
      "Wrong payload sizes: payload.size() is ",
      payload.size(),
      " but additional payload size is ",
      additionalPayloadSize);
  auto wrappedPayloadBegin =
      static_cast<const char*>(message.payload().data()) + payload.size() -
      additionalPayloadSize;
  std::vector<torch::Tensor> tensorTable;
  IValue tuple = jit::unpickle(
      wrappedPayloadBegin,
      additionalPayloadSize,
      *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      tensorTable);
  std::vector<at::IValue> tupleElements = tuple.toTupleRef().elements().vec();
  payload.resize(payload.size() - additionalPayloadSize);
  return tupleElements;
}

void populateRemoteProfiledEvents(
    std::vector<LegacyEvent>& profiledEvents,
    const ProfilerConfig& profilingConfig,
    const std::vector<std::vector<LegacyEvent>>& eventLists) {
  // Gather all events into a vector
  for (auto& l : eventLists) {
    for (auto& e : l) {
      profiledEvents.push_back(e);
    }
  }
  // Find the __start_profile event.
  bool cudaProfilingEnabled = profilingConfig.state == ProfilerState::CUDA;
  const LegacyEvent* profilerStart = nullptr;

  for (auto& e : profiledEvents) {
    if (std::string(e.name()) == "__start_profile") {
      profilerStart = &e;
      break;
    }
  }
  // We should always find __start_profile.
  TORCH_CHECK(
      profilerStart != nullptr, "Expected to find __start_profile event.");

  if (cudaProfilingEnabled) {
    // Deserialized events don't have the corresponding CUDA events, making it
    // impossible to use cudaEventElapsedTime on the receiving end. To avoid
    // this, find all push/pop pairs of CUDA events and set the corresponding
    // CUDA time to zero for the push event and to the elapsed time for the
    // pop event, to be used later for the elapsed CUDA time computation.
    std::unordered_map<at::RecordFunctionHandle, const LegacyEvent*>
        startEvents;
    for (auto& e : profiledEvents) {
      if (e.hasCuda()) {
        if (e.kind() == EventKind::PushRange) {
          startEvents[e.handle()] = &e;
        }
      }
    }
    for (auto& e : profiledEvents) {
      if (e.hasCuda()) {
        if (e.kind() == EventKind::PopRange) {
          auto it = startEvents.find(e.handle());
          if (it != startEvents.end()) {
            e.setCudaUs(it->second->cudaElapsedUs(e));
          } else {
            TORCH_WARN("Found a pop event without a corresponding push event");
            e.setCudaUs(0);
          }
        } else {
          e.setCudaUs(0);
        }
      }
    }
  }
}

} // namespace torch::distributed::rpc