xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/utils.h>
2 
3 #include <fmt/format.h>
4 #include <torch/csrc/autograd/profiler.h>
5 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
6 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
7 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
8 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
9 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
10 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
11 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
12 #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
13 #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h>
14 #include <torch/csrc/distributed/autograd/utils.h>
15 #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
16 #include <torch/csrc/distributed/rpc/python_call.h>
17 #include <torch/csrc/distributed/rpc/python_remote_call.h>
18 #include <torch/csrc/distributed/rpc/python_resp.h>
19 #include <torch/csrc/distributed/rpc/rref_proto.h>
20 #include <torch/csrc/distributed/rpc/script_call.h>
21 #include <torch/csrc/distributed/rpc/script_remote_call.h>
22 #include <torch/csrc/distributed/rpc/script_resp.h>
23 #include <torch/csrc/jit/serialization/pickler.h>
24 #include <torch/csrc/jit/serialization/unpickler.h>
25 
26 #include <c10/util/irange.h>
27 
28 using namespace torch::autograd::profiler;
29 
30 namespace torch::distributed::rpc {
31 namespace {
32 void processRemoteProfiledEvents(
33     autograd::RpcWithProfilingResp& rpcWithProfilingResp) {
34   // Check if the profiler is enabled
35   auto enabled = profilerEnabled();
36   TORCH_CHECK(
37       enabled,
38       "Profiler was expected to be enabled. This can happen in callback "
39       "continuations that run in different threads, and the TLS of the "
40       "profiler was not propagated.");
41   std::vector<LegacyEvent> events = rpcWithProfilingResp.getProfiledEvents();
42   const auto& profilingId = rpcWithProfilingResp.getProfilingId();
43   auto& remoteProfilerManager = RemoteProfilerManager::getInstance();
44   auto key = remoteProfilerManager.retrieveRPCProfilingKey(profilingId);
45   remoteProfilerManager.eraseKey(profilingId);
46   auto keyPrefixStr = key + rpc::REMOTE_PROFILING_KEY_PREFIX;
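  // Illustrative example (the key and prefix values here are assumptions, not
  // taken from this file): if the retrieved profiling key is
  // "rpc_async#worker1" and the prefix constant expands to "#remote_op: ",
  // an event originally named "aten::add" is renamed below to
  // "rpc_async#worker1#remote_op: aten::add" before being handed to the
  // local profiler.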
47   std::for_each(
48       events.begin(), events.end(), [&keyPrefixStr](LegacyEvent& event) {
49         std::string name = keyPrefixStr + std::string(event.name());
50         event.setName(at::StringView(name));
51       });
52   // Add event list to the thread local profiler.
53   addEventList(std::move(events));
54 }
55 
56 } // namespace
57 
58 const std::string kRPCErrorPrefix = std::string("RPCErr");
59 
60 RPCErrorType getRPCErrorType(const JitFuture& jitFuture) {
61   TORCH_INTERNAL_ASSERT(
62       jitFuture.hasError(),
63       "JitFuture of Message passed to getRPCErrorType does not have an error.");
64 
65   // Attempt to parse the error string produced by makeRPCError(); if that
66   // fails, return UNKNOWN_ERROR.
67   // Note that this function only recognizes errors formatted with makeRPCError().
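  // For reference, a makeRPCError()-formatted message looks like
  // "RPCErr:<int(errorType)>:<error string>", e.g. "RPCErr:1:worker timed out"
  // (the numeric value and text in this example are illustrative only).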
68   auto err = jitFuture.tryRetrieveErrorMessage();
69   size_t pos = err.find(kRPCErrorPrefix);
70   if (pos != std::string::npos) {
71     // Parse the RPCErrorType.
72     auto errStartIdx =
73         pos + torch::distributed::rpc::kRPCErrorPrefix.size() + 1;
74     auto errEndIdx = err.find(':', errStartIdx);
75     if (errEndIdx == std::string::npos) {
76       // Indicates error was not formatted correctly.
77       return RPCErrorType::UNKNOWN_ERROR;
78     }
79     auto errStr = err.substr(errStartIdx, errEndIdx - errStartIdx);
80     auto errType = static_cast<RPCErrorType>(std::stoi(errStr));
81     return errType;
82   } else {
83     return RPCErrorType::UNKNOWN_ERROR;
84   }
85 }
86 
87 std::string makeRPCError(
88     const std::string& rpcErrorStr,
89     RPCErrorType errorType) {
90   return fmt::format(
91       "{}:{}:{}",
92       torch::distributed::rpc::kRPCErrorPrefix,
93       static_cast<int>(errorType),
94       rpcErrorStr);
95 }
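// Illustrative round trip (a sketch, not code from this file): an agent that
// hits a timeout could produce
//   auto err = makeRPCError("RPC ran past its deadline", RPCErrorType::TIMEOUT);
// and a callback that later observes the errored JitFuture can recover the
// category with getRPCErrorType(). RPCErrorType::TIMEOUT is assumed to be a
// member of the RPCErrorType enum used here.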
96 
97 std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
98   switch (request.type()) {
99     case MessageType::SCRIPT_CALL: {
100       return ScriptCall::fromMessage(request);
101     }
102     case MessageType::PYTHON_CALL: {
103       return PythonCall::fromMessage(request);
104     }
105     case MessageType::SCRIPT_REMOTE_CALL: {
106       return ScriptRemoteCall::fromMessage(request);
107     }
108     case MessageType::PYTHON_REMOTE_CALL: {
109       return PythonRemoteCall::fromMessage(request);
110     }
111     case MessageType::SCRIPT_RREF_FETCH_CALL: {
112       return ScriptRRefFetchCall::fromMessage(request);
113     }
114     case MessageType::PYTHON_RREF_FETCH_CALL: {
115       return PythonRRefFetchCall::fromMessage(request);
116     }
117     case MessageType::RREF_USER_DELETE: {
118       return RRefUserDelete::fromMessage(request);
119     }
120     case MessageType::RREF_CHILD_ACCEPT: {
121       return RRefChildAccept::fromMessage(request);
122     }
123     case MessageType::RREF_FORK_REQUEST: {
124       return RRefForkRequest::fromMessage(request);
125     }
126     case MessageType::FORWARD_AUTOGRAD_REQ: {
127       return autograd::RpcWithAutograd::fromMessage(request);
128     }
129     case MessageType::BACKWARD_AUTOGRAD_REQ: {
130       return autograd::PropagateGradientsReq::fromMessage(request);
131     }
132     case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
133       return autograd::CleanupAutogradContextReq::fromMessage(request);
134     }
135     case MessageType::RUN_WITH_PROFILING_REQ: {
136       return autograd::RpcWithProfilingReq::fromMessage(request);
137     }
138     case MessageType::RREF_BACKWARD_REQ: {
139       return autograd::RRefBackwardReq::fromMessage(request);
140     }
141     default: {
142       TORCH_INTERNAL_ASSERT(
143           false, "Request type ", request.type(), " not supported.");
144     }
145   }
146 }
147 
148 std::unique_ptr<RpcCommandBase> deserializeResponse(
149     const Message& response,
150     MessageType& wrappedMsgType) {
151   switch (response.type()) {
152     case MessageType::SCRIPT_RET: {
153       return ScriptResp::fromMessage(response);
154     }
155     case MessageType::PYTHON_RET: {
156       return PythonResp::fromMessage(response);
157     }
158     case MessageType::REMOTE_RET: {
159       return RemoteRet::fromMessage(response);
160     }
161     case MessageType::SCRIPT_RREF_FETCH_RET: {
162       return ScriptRRefFetchRet::fromMessage(response);
163     }
164     case MessageType::PYTHON_RREF_FETCH_RET: {
165       return PythonRRefFetchRet::fromMessage(response);
166     }
167     case MessageType::RREF_ACK: {
168       return RRefAck::fromMessage(response);
169     }
170     case MessageType::FORWARD_AUTOGRAD_RESP: {
171       std::unique_ptr<RpcCommandBase> rpcPtr =
172           autograd::RpcWithAutograd::fromMessage(response);
173       RpcCommandBase& rpc = *rpcPtr;
174       auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(rpc);
175 
176       // Need to reverse the device map for the backward pass of distributed
177       // autograd.
178       DeviceMap reverseDeviceMap;
179       for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
180         reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
181       }
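      // For example (devices illustrative): a forward device map of
      // {cuda:0 -> cuda:1} is reversed to {cuda:1 -> cuda:0} so that, during
      // the backward pass, gradients can be mapped back onto the source
      // devices.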
182 
183       // Attach 'recv' autograd function.
184       addRecvRpcBackward(
185           rpcWithAutograd.autogradMetadata(),
186           rpcWithAutograd.tensors(),
187           rpcWithAutograd.fromWorkerId(),
188           reverseDeviceMap);
189 
190       wrappedMsgType = rpcWithAutograd.wrappedMessageType();
191 
192       return std::move(rpcWithAutograd).moveWrappedRpc();
193     }
194     case MessageType::BACKWARD_AUTOGRAD_RESP: {
195       return autograd::PropagateGradientsResp::fromMessage(response);
196     }
197     case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
198       return autograd::CleanupAutogradContextResp::fromMessage(response);
199     }
200     case MessageType::RUN_WITH_PROFILING_RESP: {
201       std::unique_ptr<RpcCommandBase> rpcPtr =
202           autograd::RpcWithProfilingResp::fromMessage(response);
203       RpcCommandBase& rpc = *rpcPtr;
204       auto& rpcWithProfilingResp =
205           static_cast<autograd::RpcWithProfilingResp&>(rpc);
206       // Process remotely profiled events.
207       processRemoteProfiledEvents(rpcWithProfilingResp);
208 
209       wrappedMsgType = rpcWithProfilingResp.wrappedMessageType();
210       auto wrappedRPC = std::move(rpcWithProfilingResp).moveWrappedRpc();
211       return wrappedRPC;
212     }
213     case MessageType::RREF_BACKWARD_RESP: {
214       return autograd::RRefBackwardResp::fromMessage(response);
215     }
216     default: {
217       TORCH_INTERNAL_ASSERT(
218           false, "Response type ", response.type(), " not supported.");
219     }
220   }
221 }
222 
223 IValue deserializeResptoIValueInternal(
224     RpcCommandBase& rpc,
225     MessageType messageType) {
226   switch (messageType) {
227     case MessageType::SCRIPT_RET: {
228       auto& ret = static_cast<ScriptResp&>(rpc);
229       return ret.value();
230     }
231     default: {
232       TORCH_INTERNAL_ASSERT(
233           false,
234           "Response type ",
235           messageType,
236           " is not supported to be deserialized to IValue.");
237     }
238   }
239 }
240 
241 IValue deserializeRespToIValue(const Message& message) {
242   MessageType msgType = message.type();
243   auto response = deserializeResponse(message, msgType);
244   return deserializeResptoIValueInternal(*response, msgType);
245 }
246 
247 namespace {
248 
249 // Helper for wireDeserialize() below.
250 //
251 // The format we use below looks like:
252 //    section_name_1 size_1\n
253 //    section_name_2 size_2\n
254 //    ..
255 //    \n
256 //    [sections in order]
257 //
258 // Sections themselves include:
259 //    - "payload" - the payload bits
260 //    - "meta"    - metadata for the unpickler
261 //    - "0" ...   - tensor sections for the unpickler
262 //
263 // Note that per the header comments, the format is subject to change,
264 // and is best used for RPCs rather than persistent disk storage.
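//
// A concrete, purely illustrative instance: a 10-byte payload plus one pickled
// tensor whose pickle metadata is 58 bytes and whose storage is 16 bytes would
// be laid out as
//    payload 10\n
//    meta 58\n
//    0 16\n
//    \n
//    <10 payload bytes><58 meta bytes><16 tensor bytes>
// The sizes above are made up; only the structure matters.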
265 std::unordered_map<std::string, std::pair<const char*, size_t>>
266 parseWireSections(const void* data, size_t data_size) {
267   const char* ptr = static_cast<const char*>(data);
268   const char* endp = ptr + data_size;
269 
270   std::vector<std::pair<std::string, size_t>> headerEnts;
271   bool ok = false;
272   while (ptr != endp) {
273     if (*ptr == '\n') {
274       ok = true; // The only "correct" exit point.
275       ++ptr;
276       break;
277     }
278     // Parse name
279     const char* namePtr = ptr;
280     while (ptr != endp && *ptr != ' ') {
281       ptr++;
282     }
283     if (ptr == endp) {
284       break;
285     }
286     std::string name(namePtr, ptr - namePtr);
287     if (++ptr == endp) {
288       break; // past the ' '
289     }
290     // Parse size
291     const char* sizePtr = ptr;
292     while (ptr != endp && *ptr != '\n') {
293       ptr++;
294     }
295     if (ptr == endp) {
296       break;
297     }
298     size_t sz = std::stoll(std::string(sizePtr, ptr - sizePtr));
299     headerEnts.emplace_back(name, sz);
300     ++ptr; // past the '\n'
301   }
302   if (!ok) {
303     TORCH_CHECK(false, "failed parse");
304   }
305 
306   std::unordered_map<std::string, std::pair<const char*, size_t>> out;
307   for (const auto& headerEnt : headerEnts) {
308     out[headerEnt.first] = {ptr, headerEnt.second};
309     ptr += headerEnt.second;
310   }
311   if (ptr != endp) {
312     TORCH_CHECK(false, "failed bounds");
313   }
314   return out;
315 }
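// Note: the map returned by parseWireSections() points into the caller-provided
// buffer; no section bytes are copied here, so the buffer must outlive any use
// of the parsed sections (wireDeserialize() below copies what it keeps).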
316 
317 static const char* kMeta = "meta";
318 static const char* kPayload = "payload";
319 } // namespace
320 
321 c10::List<at::Tensor> cloneSparseTensors(
322     const std::vector<at::Tensor>& tensors) {
323   // If the majority of the underlying bits don't need to go over the wire,
324   // force a clone() so only the useful bytes are sent. Some Tensors are
325   // effectively small views, using only ~1% of the underlying Storage.
326   auto worthRecopying = [](const at::Tensor& t) -> bool {
327     if (!t.has_storage()) {
328       return false; // avoid throwing below.
329     }
330     auto storageSize = t.storage().nbytes();
331     auto usefulSize = t.element_size() * t.numel();
332     constexpr size_t kMinMultiple = 2;
333     constexpr size_t kMinRecopyBytes = 8 * 1024;
334     return storageSize >= kMinRecopyBytes &&
335         storageSize >= usefulSize * kMinMultiple;
336   };
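  // For instance (sizes illustrative): a 1 KB slice viewing a 1 MB storage is
  // cloned so that only ~1 KB crosses the wire, while a contiguous 1 MB tensor,
  // or any tensor whose storage is under 8 KB, is sent as-is.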
337   c10::List<at::Tensor> pTensors;
338   pTensors.reserve(tensors.size());
339   for (const auto& t : tensors) {
340     pTensors.push_back(worthRecopying(t) ? t.clone() : t);
341   }
342   return pTensors;
343 }
344 
345 std::string wireSerialize(
346     const std::vector<char>& payload,
347     const std::vector<at::Tensor>& tensors) {
348   for (const auto& tensor : tensors) {
349     TORCH_CHECK(
350         tensor.device().is_cpu(),
351         "ProcessGroup RPC backend only supports",
352         " CPU tensors; please move your tensors to CPU before sending ",
353         "them over RPC. Found tensor on device: ",
354         tensor.device());
355   }
356 
357   struct Ent {
358     std::string name;
359     const char* data;
360     size_t size;
361   };
362   std::vector<Ent> entries;
363   std::string metaEntry;
364   std::vector<at::Tensor> tensorData;
365 
366   if (!payload.empty()) {
367     entries.push_back({kPayload, payload.data(), payload.size()});
368   }
369 
370   if (!tensors.empty()) {
371     torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t {
372       metaEntry.append(static_cast<const char*>(buf), sz);
373       return sz;
374     });
375     pickler.protocol();
376     pickler.pushIValue(cloneSparseTensors(tensors));
377     pickler.stop();
378     tensorData = pickler.tensorData();
379     entries.push_back({kMeta, metaEntry.data(), metaEntry.size()});
380     for (const auto i : c10::irange(tensorData.size())) {
381       // Construct WritableTensorData for each tensor in the pickler's tensorData.
382       // Since tensorData lives in function scope and getWriteableTensorData only
383       // records the tensors, the data() pointers stay valid for CPU tensors.
384       // Note that RPC serde doesn't support CUDA tensors yet; if we ever add
385       // CUDA support, we need to be careful: getWriteableTensorData copies a
386       // CUDA tensor to CPU, and that copy's data() would be destroyed once we
387       // go out of the scope of this loop.
388       auto writeableTensorData = jit::getWriteableTensorData(tensorData[i]);
389       entries.push_back(
390           {std::to_string(i),
391            writeableTensorData.data(),
392            writeableTensorData.sizeInBytes()});
393     }
394   }
395 
396   std::string header;
397   size_t tot = 0;
398   for (const auto& e : entries) {
399     tot += e.size;
400     header.append(e.name)
401         .append(" ")
402         .append(std::to_string(e.size))
403         .append("\n");
404   }
405   header.push_back('\n');
406 
407   std::string out;
408   out.reserve(header.size() + tot);
409   out.append(header);
410   for (const auto& e : entries) {
411     out.append(e.data, e.size);
412   }
413   return out;
414 }
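// Minimal usage sketch (both functions are defined in this file):
//   std::string wire = wireSerialize(payload, tensors);
//   auto [payloadOut, tensorsOut] = wireDeserialize(wire.data(), wire.size());
// payloadOut and tensorsOut should match the inputs, modulo the clone() that
// cloneSparseTensors() may apply above.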
415 
416 std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
417     const void* data,
418     size_t data_size) {
419   auto sections = parseWireSections(data, data_size);
420 
421   std::vector<char> payload;
422   auto payloadIt = sections.find(kPayload);
423   if (payloadIt != sections.end() && payloadIt->second.second != 0) {
424     payload.assign(
425         payloadIt->second.first,
426         payloadIt->second.first + payloadIt->second.second);
427   }
428 
429   std::vector<at::Tensor> tensors;
430   auto metaIt = sections.find(kMeta);
431   if (metaIt != sections.end()) {
432     const auto& metaData = metaIt->second;
433     size_t metaDataPos = 0;
434     auto metaDataReadFunc = [&](char* buf, size_t n) -> size_t {
435       if (metaDataPos >= metaData.second || n == 0) {
436         return 0;
437       }
438       size_t toCopy = std::min(metaDataPos + n, metaData.second) - metaDataPos;
439       memcpy(buf, metaData.first + metaDataPos, toCopy);
440       metaDataPos += toCopy;
441       return toCopy;
442     };
443     auto sectionReadFunc = [&](const std::string& ename) -> at::DataPtr {
444       auto it = sections.find(ename);
445       if (it == sections.end()) {
446         TORCH_CHECK(false, "Couldn't find entity " + ename);
447       }
448       const auto& idat = it->second;
449       auto dptr = at::getCPUAllocator()->allocate(idat.second);
450       if (idat.second != 0) {
451         memcpy(dptr.get(), idat.first, idat.second);
452       }
453       return dptr;
454     };
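    // metaDataReadFunc streams the pickled metadata (the "meta" section) to the
    // Unpickler, while sectionReadFunc resolves the numeric tensor sections
    // ("0", "1", ...) written by wireSerialize(), copying each one into a
    // freshly allocated CPU buffer.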
455 
456     // No need to pass a typeResolver here, as the unpickler only processes
457     // strings and tensors.
458     torch::jit::Unpickler unpickler(
459         metaDataReadFunc, nullptr, nullptr, sectionReadFunc, {});
460     auto ival = unpickler.parse_ivalue();
461     for (auto&& t : ival.toTensorList()) {
462       tensors.emplace_back(std::move(t));
463     }
464   }
465   return {std::move(payload), std::move(tensors)};
466 }
467 
468 void writeWrappedPayload(
469     std::vector<char>& originalPayload,
470     std::vector<char>& additionalPayload) {
471   originalPayload.insert(
472       originalPayload.end(),
473       additionalPayload.begin(),
474       additionalPayload.end());
475 
476   // Add size of the additional payload
477   int64_t indexToWrite = originalPayload.size();
478   originalPayload.resize(originalPayload.size() + sizeof(int64_t));
479   const int64_t additionalPayloadSize = additionalPayload.size();
480   torch::utils::THP_encodeInt64Buffer(
481       reinterpret_cast<uint8_t*>(originalPayload.data()) + indexToWrite,
482       &additionalPayloadSize,
483       torch::utils::THPByteOrder::THP_BIG_ENDIAN,
484       1);
485 }
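// After writeWrappedPayload(), the buffer layout is:
//   [original payload][additional payload][8-byte big-endian int64 holding the
//   additional payload's size]
// readWrappedPayload() below consumes this layout in reverse: it reads the
// trailing size, unpickles the additional payload, and shrinks the vector back
// down to the original payload.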
486 
487 std::vector<at::IValue> readWrappedPayload(
488     std::vector<char>& payload,
489     const rpc::Message& message) {
490   // Read the additional payload and remove it from the payload.
491   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
492   int64_t additionalPayloadSize;
493   TORCH_INTERNAL_ASSERT(payload.size() >= sizeof(int64_t));
494   size_t indexToRead = payload.size() - sizeof(int64_t);
495   torch::utils::THP_decodeInt64Buffer(
496       &additionalPayloadSize,
497       reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
498       torch::utils::THPByteOrder::THP_BIG_ENDIAN,
499       1);
500   payload.resize(indexToRead);
501 
502   TORCH_INTERNAL_ASSERT(
503       additionalPayloadSize > 0 &&
504           static_cast<int64_t>(payload.size()) > additionalPayloadSize,
505       "Wrong payload sizes: payload.size() is ",
506       payload.size(),
507       " but additional payload size is ",
508       additionalPayloadSize);
509   auto wrappedPayloadBegin =
510       static_cast<const char*>(message.payload().data()) + payload.size() -
511       additionalPayloadSize;
512   std::vector<torch::Tensor> tensorTable;
513   IValue tuple = jit::unpickle(
514       wrappedPayloadBegin,
515       additionalPayloadSize,
516       *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
517       tensorTable);
518   std::vector<at::IValue> tupleElements = tuple.toTupleRef().elements().vec();
519   payload.resize(payload.size() - additionalPayloadSize);
520   return tupleElements;
521 }
522 
523 void populateRemoteProfiledEvents(
524     std::vector<LegacyEvent>& profiledEvents,
525     const ProfilerConfig& profilingConfig,
526     const std::vector<std::vector<LegacyEvent>>& eventLists) {
527   // Gather all events into a vector
528   for (auto& l : eventLists) {
529     for (auto& e : l) {
530       profiledEvents.push_back(e);
531     }
532   }
533   // find __start_profile event
534   bool cudaProfilingEnabled = profilingConfig.state == ProfilerState::CUDA;
535   const LegacyEvent* profilerStart = nullptr;
536 
537   for (auto& e : profiledEvents) {
538     if (std::string(e.name()) == "__start_profile") {
539       profilerStart = &e;
540       break;
541     }
542   }
543   // We should always find __start_profile.
544   TORCH_CHECK(
545       profilerStart != nullptr, "Expected to find __start_profile event.");
546 
547   if (cudaProfilingEnabled) {
548     // Deserialized events don't have the corresponding CUDA events, making it
549     // impossible to use cudaEventElapsedTime on the receiving end. To avoid this,
550     // find all push/pop pairs of CUDA events and set the corresponding CUDA
551     // time to zero for the push event and to the elapsed time for the pop
552     // event, to be used later for the elapsed CUDA time computation.
553     std::unordered_map<at::RecordFunctionHandle, const LegacyEvent*>
554         startEvents;
555     for (auto& e : profiledEvents) {
556       if (e.hasCuda()) {
557         if (e.kind() == EventKind::PushRange) {
558           startEvents[e.handle()] = &e;
559         }
560       }
561     }
562     for (auto& e : profiledEvents) {
563       if (e.hasCuda()) {
564         if (e.kind() == EventKind::PopRange) {
565           auto it = startEvents.find(e.handle());
566           if (it != startEvents.end()) {
567             e.setCudaUs(it->second->cudaElapsedUs(e));
568           } else {
569             TORCH_WARN("Found a pop event without a corresponding push event");
570             e.setCudaUs(0);
571           }
572         } else {
573           e.setCudaUs(0);
574         }
575       }
576     }
577   }
578 }
579 
580 } // namespace torch::distributed::rpc
581