1 #include <torch/csrc/distributed/rpc/rpc_agent.h>
2 #include <torch/csrc/distributed/rpc/script_remote_call.h>
3
4 #include <torch/csrc/jit/serialization/pickle.h>
5
6 namespace torch::distributed::rpc {
7
ScriptRemoteCall(std::shared_ptr<Operator> op,std::vector<at::IValue> && stack,const RRefId & retRRefId,const ForkId & retForkId)8 ScriptRemoteCall::ScriptRemoteCall(
9 std::shared_ptr<Operator> op,
10 std::vector<at::IValue>&& stack,
11 const RRefId& retRRefId,
12 const ForkId& retForkId)
13 : ScriptCall(std::move(op), std::move(stack)),
14 retRRefId_(retRRefId),
15 retForkId_(retForkId) {}
16
ScriptRemoteCall(const c10::QualifiedName & qualifiedName,std::vector<at::IValue> && stack,const RRefId & retRRefId,const ForkId & retForkId,const bool isAsyncExecution)17 ScriptRemoteCall::ScriptRemoteCall(
18 const c10::QualifiedName& qualifiedName,
19 std::vector<at::IValue>&& stack,
20 const RRefId& retRRefId,
21 const ForkId& retForkId,
22 const bool isAsyncExecution)
23 : ScriptCall(qualifiedName, std::move(stack), isAsyncExecution),
24 retRRefId_(retRRefId),
25 retForkId_(retForkId) {}
26
fromIValues(std::vector<at::IValue> & ivalues)27 std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromIValues(
28 std::vector<at::IValue>& ivalues) {
29 // remove the last element from values and convert it back to an RRef
30 auto retForkId = RRefId::fromIValue(ivalues.back());
31 ivalues.pop_back();
32 auto retRRefId = ForkId::fromIValue(ivalues.back());
33 ivalues.pop_back();
34
35 auto scriptCallPtr = ScriptCall::fromIValues(ivalues);
36
37 if (scriptCallPtr->hasOp()) {
38 return std::make_unique<ScriptRemoteCall>(
39 scriptCallPtr->op(), std::move(ivalues), retRRefId, retForkId);
40 } else {
41 return std::make_unique<ScriptRemoteCall>(
42 scriptCallPtr->qualifiedName(),
43 std::move(ivalues),
44 retRRefId,
45 retForkId,
46 scriptCallPtr->isAsyncExecution());
47 }
48 }
49
toMessageImpl()50 c10::intrusive_ptr<Message> ScriptRemoteCall::toMessageImpl() && {
51 std::vector<IValue> ivalues;
52 ScriptCall::toIValues(ivalues);
53 ivalues.emplace_back(retRRefId_.toIValue());
54 ivalues.emplace_back(retForkId_.toIValue());
55
56 std::vector<torch::Tensor> tensor_table;
57 auto payload = jit::pickle(
58 c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
59
60 return c10::make_intrusive<Message>(
61 std::move(payload),
62 std::move(tensor_table),
63 MessageType::SCRIPT_REMOTE_CALL);
64 }
65
fromMessage(const Message & message)66 std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
67 const Message& message) {
68 auto payload = static_cast<const char*>(message.payload().data());
69 auto payload_size = message.payload().size();
70
71 auto value = jit::unpickle(
72 payload,
73 payload_size,
74 *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
75 message.tensors());
76 auto values = value.toTupleRef().elements().vec();
77 TORCH_CHECK(!values.empty(), "Malformed message: empty values unpickled");
78 return fromIValues(values);
79 }
80
81 } // namespace torch::distributed::rpc
82