1 #pragma once 2 3 #include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h> 4 #include <torch/csrc/distributed/rpc/rpc_agent.h> 5 #include <torch/csrc/distributed/rpc/rpc_command_base.h> 6 7 namespace torch { 8 namespace distributed { 9 namespace autograd { 10 11 // Represents an RPC that includes autograd information. This class basically 12 // wraps another `RpcCommandBase` object which represents the actual RPC and has 13 // additional autograd information associated with that RPC. 14 class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { 15 public: 16 // Used when we are sending an RPC over the wire. 17 RpcWithAutograd( 18 rpc::worker_id_t fromWorkerId, 19 rpc::MessageType messageType, 20 const AutogradMetadata& autogradMetadata, 21 c10::intrusive_ptr<rpc::Message> wrappedMessage, 22 rpc::DeviceMap deviceMap = {}); 23 24 // Used when receiving an RPC over the wire. 25 RpcWithAutograd( 26 rpc::worker_id_t fromWorkerId, 27 rpc::MessageType messageType, 28 const AutogradMetadata& autogradMetadata, 29 std::unique_ptr<rpc::RpcCommandBase> wrappedRpc, 30 rpc::MessageType wrappedMessageType, 31 std::vector<torch::Tensor> tensors, 32 rpc::DeviceMap deviceMap = {}); 33 34 c10::intrusive_ptr<rpc::Message> toMessageImpl() && override; 35 36 static std::unique_ptr<RpcWithAutograd> fromMessage( 37 const rpc::Message& message); 38 39 // Retrieves tensors as part of this RPC, which need to be considered for 40 // autograd computations. 41 std::vector<torch::Tensor>& tensors(); 42 43 const AutogradMetadata& autogradMetadata() const; 44 45 RpcCommandBase& wrappedRpc(); 46 47 void setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc); 48 49 std::unique_ptr<RpcCommandBase> moveWrappedRpc() &&; 50 51 // Message type of the wrapped RPC. 52 rpc::MessageType wrappedMessageType() const; 53 54 // Retrieve the worker id from which the RPC originated. 55 rpc::worker_id_t fromWorkerId() const; 56 57 // Retrieve the device map. 58 const rpc::DeviceMap& deviceMap(); 59 60 private: 61 // WorkerId from which this RPC originated. This is necessary for knowing 62 // which worker we need to contact during the backward pass. 63 rpc::worker_id_t fromWorkerId_; 64 65 // Message type for this call. 66 rpc::MessageType messageType_; 67 68 AutogradMetadata autogradMetadata_; 69 70 // Since wrappedMessage_ is destructively constructed from wrappedRpc_, 71 // they are valid exclusively. They are used for different purpose. 72 // wrappedRpc_ is used while constructing receive rpcWithAutograd; 73 // wrappedMessage_ is used while constructing send rpcWithAutograd; 74 75 // When receive rpcWithAutograd is constructed fromMessage, it is valid; 76 // When send rpcWithAutograd is constructed before toMessage, it is nullptr; 77 std::unique_ptr<RpcCommandBase> wrappedRpc_; 78 79 // Serialized message representing wrappedRpc_. Used mostly as a cache to 80 // avoid serializing the request twice. 81 // When receive rpcWithAutograd is constructed fromMessage, it is nullptr; 82 // When send rpcWithAutograd is constructed before toMessage, it is valid; 83 c10::intrusive_ptr<rpc::Message> wrappedMessage_; 84 85 // message type of the wrappedMessage, this is stored separately since 86 // wrappedMessage_ is not always guaranteed to be populated. 87 rpc::MessageType wrappedMessageType_; 88 89 // Tensors part of the wrappedRpc that need to be considered for autograd. 90 std::vector<torch::Tensor> tensors_; 91 92 // Device mapping for tensors that are sent across an RPC to another node. 93 rpc::DeviceMap deviceMap_; 94 }; 95 96 } // namespace autograd 97 } // namespace distributed 98 } // namespace torch 99