1 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h> 2 #include <torch/csrc/distributed/rpc/rpc_agent.h> 3 #include <torch/csrc/jit/serialization/pickle.h> 4 5 namespace torch { 6 namespace distributed { 7 namespace autograd { 8 CleanupAutogradContextReq(int64_t context_id)9CleanupAutogradContextReq::CleanupAutogradContextReq(int64_t context_id) 10 : context_id_(context_id){}; 11 getContextId()12int64_t CleanupAutogradContextReq::getContextId() { 13 return context_id_; 14 } 15 toMessageImpl()16c10::intrusive_ptr<rpc::Message> CleanupAutogradContextReq::toMessageImpl() && { 17 // pickle context_id using JIT pickler. 18 std::vector<torch::Tensor> tensorTable; 19 std::vector<char> payload = 20 jit::pickle(at::IValue(context_id_), &tensorTable); 21 return c10::make_intrusive<rpc::Message>( 22 std::move(payload), 23 std::move(tensorTable), 24 rpc::MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ); 25 } 26 27 std::unique_ptr<CleanupAutogradContextReq> CleanupAutogradContextReq:: fromMessage(const rpc::Message & message)28 fromMessage(const rpc::Message& message) { 29 // unpickle and get the context_id we need to clean up 30 auto payload = static_cast<const char*>(message.payload().data()); 31 auto payload_size = message.payload().size(); 32 IValue ivalue_context_id = jit::unpickle( 33 payload, 34 payload_size, 35 *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(), 36 message.tensors()); 37 38 // convert ivalue to int and construct request 39 int64_t context_id = ivalue_context_id.toInt(); 40 return std::make_unique<CleanupAutogradContextReq>(context_id); 41 } 42 43 } // namespace autograd 44 } // namespace distributed 45 } // namespace torch 46