1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/message.h> 4 #include <torch/csrc/distributed/rpc/rpc_command_base.h> 5 #include <torch/csrc/distributed/rpc/types.h> 6 #include <torch/csrc/jit/runtime/operator.h> 7 #include <torch/csrc/jit/serialization/pickler.h> 8 #include <vector> 9 10 namespace torch::distributed::rpc { 11 12 // Temporary solution of RRef operations. 13 // TODO: Remove all these messages and use rpc + registered functions instead. 14 class TORCH_API RRefMessageBase : public RpcCommandBase { 15 public: RRefMessageBase(const RRefId & rrefId,MessageType type)16 RRefMessageBase(const RRefId& rrefId, MessageType type) 17 : rrefId_(rrefId), type_(type) {} 18 19 const RRefId& rrefId(); 20 21 protected: 22 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 23 const RRefId rrefId_; 24 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 25 const MessageType type_; 26 }; 27 28 class TORCH_API ForkMessageBase : public RRefMessageBase { 29 public: ForkMessageBase(const RRefId & rrefId,const ForkId & forkId,MessageType type)30 ForkMessageBase(const RRefId& rrefId, const ForkId& forkId, MessageType type) 31 : RRefMessageBase(rrefId, type), forkId_(forkId) {} 32 33 const ForkId& forkId(); 34 35 c10::intrusive_ptr<Message> toMessageImpl() && override; 36 static std::pair<RRefId, ForkId> fromMessage( 37 const Message& message, 38 MessageType type); 39 40 protected: 41 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 42 const ForkId forkId_; 43 }; 44 45 // UserRRef uses this message to fetch the remote RRef value from the owner. 46 class TORCH_API ScriptRRefFetchCall final : public RRefMessageBase { 47 public: ScriptRRefFetchCall(worker_id_t fromWorkerId,const RRefId & rrefId)48 ScriptRRefFetchCall(worker_id_t fromWorkerId, const RRefId& rrefId) 49 : RRefMessageBase(rrefId, MessageType::SCRIPT_RREF_FETCH_CALL), 50 fromWorkerId_(fromWorkerId) {} 51 fromWorkerId()52 inline worker_id_t fromWorkerId() const { 53 return fromWorkerId_; 54 } 55 56 c10::intrusive_ptr<Message> toMessageImpl() && override; 57 static std::unique_ptr<ScriptRRefFetchCall> fromMessage( 58 const Message& message); 59 60 private: 61 const worker_id_t fromWorkerId_; 62 }; 63 64 class TORCH_API PythonRRefFetchCall final : public RRefMessageBase { 65 public: PythonRRefFetchCall(worker_id_t fromWorkerId,const RRefId & rrefId)66 PythonRRefFetchCall(worker_id_t fromWorkerId, const RRefId& rrefId) 67 : RRefMessageBase(rrefId, MessageType::PYTHON_RREF_FETCH_CALL), 68 fromWorkerId_(fromWorkerId) {} 69 70 c10::intrusive_ptr<Message> toMessageImpl() && override; 71 static std::unique_ptr<PythonRRefFetchCall> fromMessage( 72 const Message& message); 73 74 private: 75 const worker_id_t fromWorkerId_; 76 }; 77 78 // OwnerRRef uses this message to send the RRef value to a remote UserRRef 79 class TORCH_API RRefFetchRet : public RpcCommandBase { 80 public: RRefFetchRet(std::vector<at::IValue> values,MessageType type)81 RRefFetchRet(std::vector<at::IValue> values, MessageType type) 82 : values_(std::move(values)), type_(type) {} 83 84 const std::vector<at::IValue>& values(); 85 c10::intrusive_ptr<Message> toMessageImpl() && override; 86 87 private: 88 std::vector<at::IValue> values_; 89 const MessageType type_; 90 }; 91 92 class TORCH_API ScriptRRefFetchRet final : public RRefFetchRet { 93 public: ScriptRRefFetchRet(std::vector<at::IValue> values)94 explicit ScriptRRefFetchRet(std::vector<at::IValue> values) 95 : RRefFetchRet(std::move(values), MessageType::SCRIPT_RREF_FETCH_RET) {} 96 97 static std::unique_ptr<ScriptRRefFetchRet> fromMessage( 98 const Message& message); 99 }; 100 101 class TORCH_API PythonRRefFetchRet final : public RRefFetchRet { 102 public: PythonRRefFetchRet(std::vector<at::IValue> values)103 explicit PythonRRefFetchRet(std::vector<at::IValue> values) 104 : RRefFetchRet(std::move(values), MessageType::PYTHON_RREF_FETCH_RET) {} 105 106 static std::unique_ptr<PythonRRefFetchRet> fromMessage( 107 const Message& message); 108 }; 109 110 // UserRRef (regardless it's the creator or not) uses this message to notify 111 // OwnerRRef on delete. 112 class TORCH_API RRefUserDelete final : public ForkMessageBase { 113 public: RRefUserDelete(const RRefId & rrefId,const ForkId & forkId)114 RRefUserDelete(const RRefId& rrefId, const ForkId& forkId) 115 : ForkMessageBase(rrefId, forkId, MessageType::RREF_USER_DELETE) {} 116 117 static std::unique_ptr<RRefUserDelete> fromMessage(const Message& message); 118 }; 119 120 class TORCH_API RemoteRet final : public ForkMessageBase { 121 public: RemoteRet(const RRefId & rrefId,const ForkId & forkId)122 RemoteRet(const RRefId& rrefId, const ForkId& forkId) 123 : ForkMessageBase(rrefId, forkId, MessageType::REMOTE_RET) {} 124 125 static std::unique_ptr<RemoteRet> fromMessage(const Message& message); 126 }; 127 128 // A child RRef uses this message to notify its parent that the child has been 129 // confirmed by the owner. 130 class TORCH_API RRefChildAccept final : public RpcCommandBase { 131 public: RRefChildAccept(const ForkId & forkId)132 explicit RRefChildAccept(const ForkId& forkId) : forkId_(forkId) {} 133 134 const ForkId& forkId() const; 135 136 c10::intrusive_ptr<Message> toMessageImpl() && override; 137 static std::unique_ptr<RRefChildAccept> fromMessage(const Message& message); 138 139 private: 140 const ForkId forkId_; 141 }; 142 143 // A child RRef uses this message to send a fork request to the owner. 144 class TORCH_API RRefForkRequest final : public ForkMessageBase { 145 public: RRefForkRequest(const RRefId & rrefId,const ForkId & forkId)146 RRefForkRequest(const RRefId& rrefId, const ForkId& forkId) 147 : ForkMessageBase(rrefId, forkId, MessageType::RREF_FORK_REQUEST) {} 148 149 static std::unique_ptr<RRefForkRequest> fromMessage(const Message& message); 150 }; 151 152 class TORCH_API RRefAck final : public RpcCommandBase { 153 public: 154 RRefAck() = default; 155 156 c10::intrusive_ptr<Message> toMessageImpl() && override; 157 static std::unique_ptr<RRefAck> fromMessage(const Message& message); 158 }; 159 160 } // namespace torch::distributed::rpc 161