1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/message.h> 4 #include <torch/csrc/distributed/rpc/rpc_command_base.h> 5 #include <torch/csrc/distributed/rpc/types.h> 6 7 namespace torch { 8 namespace distributed { 9 namespace autograd { 10 11 // Internal system RPC to invoke distributed backward pass on remote nodes when 12 // 'rref.backward()' is invoked. 13 class TORCH_API RRefBackwardReq : public rpc::RpcCommandBase { 14 public: 15 RRefBackwardReq( 16 const rpc::RRefId& rrefId, 17 int64_t autogradContextId, 18 bool retainGraph = false); 19 20 const rpc::RRefId& getRRefId() const; 21 22 int64_t getAutogradContextId() const; 23 24 bool retainGraph() const; 25 26 // Serialization and deserialization methods. 27 c10::intrusive_ptr<rpc::Message> toMessageImpl() && override; 28 static std::unique_ptr<RRefBackwardReq> fromMessage( 29 const rpc::Message& message); 30 31 private: 32 const rpc::RRefId rrefId_; 33 const int64_t autogradContextId_; 34 const bool retainGraph_; 35 }; 36 37 } // namespace autograd 38 } // namespace distributed 39 } // namespace torch 40