1 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4 
5 namespace torch {
6 namespace distributed {
7 namespace autograd {
8 
CleanupAutogradContextReq(int64_t context_id)9 CleanupAutogradContextReq::CleanupAutogradContextReq(int64_t context_id)
10     : context_id_(context_id){};
11 
getContextId()12 int64_t CleanupAutogradContextReq::getContextId() {
13   return context_id_;
14 }
15 
toMessageImpl()16 c10::intrusive_ptr<rpc::Message> CleanupAutogradContextReq::toMessageImpl() && {
17   // pickle context_id using JIT pickler.
18   std::vector<torch::Tensor> tensorTable;
19   std::vector<char> payload =
20       jit::pickle(at::IValue(context_id_), &tensorTable);
21   return c10::make_intrusive<rpc::Message>(
22       std::move(payload),
23       std::move(tensorTable),
24       rpc::MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ);
25 }
26 
27 std::unique_ptr<CleanupAutogradContextReq> CleanupAutogradContextReq::
fromMessage(const rpc::Message & message)28     fromMessage(const rpc::Message& message) {
29   // unpickle and get the context_id we need to clean up
30   auto payload = static_cast<const char*>(message.payload().data());
31   auto payload_size = message.payload().size();
32   IValue ivalue_context_id = jit::unpickle(
33       payload,
34       payload_size,
35       *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
36       message.tensors());
37 
38   // convert ivalue to int and construct request
39   int64_t context_id = ivalue_context_id.toInt();
40   return std::make_unique<CleanupAutogradContextReq>(context_id);
41 }
42 
43 } // namespace autograd
44 } // namespace distributed
45 } // namespace torch
46