xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/script_resp.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/script_resp.h>
2 
3 #include <torch/csrc/distributed/rpc/rpc_agent.h>
4 #include <torch/csrc/jit/serialization/pickle.h>
5 #include <torch/csrc/jit/serialization/unpickler.h>
6 
7 namespace torch::distributed::rpc {
8 
ScriptResp(at::IValue && value)9 ScriptResp::ScriptResp(at::IValue&& value) : value_(value) {}
10 
value()11 const at::IValue& ScriptResp::value() {
12   return value_;
13 }
14 
toMessageImpl()15 c10::intrusive_ptr<Message> ScriptResp::toMessageImpl() && {
16   std::vector<torch::Tensor> tensor_table;
17   auto payload = jit::pickle(value_, &tensor_table);
18   return c10::make_intrusive<Message>(
19       std::move(payload), std::move(tensor_table), MessageType::SCRIPT_RET);
20 }
21 
fromMessage(const Message & message)22 std::unique_ptr<ScriptResp> ScriptResp::fromMessage(const Message& message) {
23   auto payload = static_cast<const char*>(message.payload().data());
24   auto payload_size = message.payload().size();
25   auto value = jit::unpickle(
26       payload,
27       payload_size,
28       *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
29       message.tensors());
30   return std::make_unique<ScriptResp>(std::move(value));
31 }
32 
33 } // namespace torch::distributed::rpc
34