xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/script_remote_call.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/rpc_agent.h>
2 #include <torch/csrc/distributed/rpc/script_remote_call.h>
3 
4 #include <torch/csrc/jit/serialization/pickle.h>
5 
6 namespace torch::distributed::rpc {
7 
ScriptRemoteCall(std::shared_ptr<Operator> op,std::vector<at::IValue> && stack,const RRefId & retRRefId,const ForkId & retForkId)8 ScriptRemoteCall::ScriptRemoteCall(
9     std::shared_ptr<Operator> op,
10     std::vector<at::IValue>&& stack,
11     const RRefId& retRRefId,
12     const ForkId& retForkId)
13     : ScriptCall(std::move(op), std::move(stack)),
14       retRRefId_(retRRefId),
15       retForkId_(retForkId) {}
16 
ScriptRemoteCall(const c10::QualifiedName & qualifiedName,std::vector<at::IValue> && stack,const RRefId & retRRefId,const ForkId & retForkId,const bool isAsyncExecution)17 ScriptRemoteCall::ScriptRemoteCall(
18     const c10::QualifiedName& qualifiedName,
19     std::vector<at::IValue>&& stack,
20     const RRefId& retRRefId,
21     const ForkId& retForkId,
22     const bool isAsyncExecution)
23     : ScriptCall(qualifiedName, std::move(stack), isAsyncExecution),
24       retRRefId_(retRRefId),
25       retForkId_(retForkId) {}
26 
fromIValues(std::vector<at::IValue> & ivalues)27 std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromIValues(
28     std::vector<at::IValue>& ivalues) {
29   // remove the last element from values and convert it back to an RRef
30   auto retForkId = RRefId::fromIValue(ivalues.back());
31   ivalues.pop_back();
32   auto retRRefId = ForkId::fromIValue(ivalues.back());
33   ivalues.pop_back();
34 
35   auto scriptCallPtr = ScriptCall::fromIValues(ivalues);
36 
37   if (scriptCallPtr->hasOp()) {
38     return std::make_unique<ScriptRemoteCall>(
39         scriptCallPtr->op(), std::move(ivalues), retRRefId, retForkId);
40   } else {
41     return std::make_unique<ScriptRemoteCall>(
42         scriptCallPtr->qualifiedName(),
43         std::move(ivalues),
44         retRRefId,
45         retForkId,
46         scriptCallPtr->isAsyncExecution());
47   }
48 }
49 
toMessageImpl()50 c10::intrusive_ptr<Message> ScriptRemoteCall::toMessageImpl() && {
51   std::vector<IValue> ivalues;
52   ScriptCall::toIValues(ivalues);
53   ivalues.emplace_back(retRRefId_.toIValue());
54   ivalues.emplace_back(retForkId_.toIValue());
55 
56   std::vector<torch::Tensor> tensor_table;
57   auto payload = jit::pickle(
58       c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
59 
60   return c10::make_intrusive<Message>(
61       std::move(payload),
62       std::move(tensor_table),
63       MessageType::SCRIPT_REMOTE_CALL);
64 }
65 
fromMessage(const Message & message)66 std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
67     const Message& message) {
68   auto payload = static_cast<const char*>(message.payload().data());
69   auto payload_size = message.payload().size();
70 
71   auto value = jit::unpickle(
72       payload,
73       payload_size,
74       *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
75       message.tensors());
76   auto values = value.toTupleRef().elements().vec();
77   TORCH_CHECK(!values.empty(), "Malformed message: empty values unpickled");
78   return fromIValues(values);
79 }
80 
81 } // namespace torch::distributed::rpc
82