xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/rref_proto.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/rpc/message.h>
4 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
5 #include <torch/csrc/distributed/rpc/types.h>
6 #include <torch/csrc/jit/runtime/operator.h>
7 #include <torch/csrc/jit/serialization/pickler.h>
8 #include <vector>
9 
10 namespace torch::distributed::rpc {
11 
12 // Temporary solution of RRef operations.
13 // TODO: Remove all these messages and use rpc + registered functions instead.
14 class TORCH_API RRefMessageBase : public RpcCommandBase {
15  public:
RRefMessageBase(const RRefId & rrefId,MessageType type)16   RRefMessageBase(const RRefId& rrefId, MessageType type)
17       : rrefId_(rrefId), type_(type) {}
18 
19   const RRefId& rrefId();
20 
21  protected:
22   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
23   const RRefId rrefId_;
24   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
25   const MessageType type_;
26 };
27 
28 class TORCH_API ForkMessageBase : public RRefMessageBase {
29  public:
ForkMessageBase(const RRefId & rrefId,const ForkId & forkId,MessageType type)30   ForkMessageBase(const RRefId& rrefId, const ForkId& forkId, MessageType type)
31       : RRefMessageBase(rrefId, type), forkId_(forkId) {}
32 
33   const ForkId& forkId();
34 
35   c10::intrusive_ptr<Message> toMessageImpl() && override;
36   static std::pair<RRefId, ForkId> fromMessage(
37       const Message& message,
38       MessageType type);
39 
40  protected:
41   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
42   const ForkId forkId_;
43 };
44 
45 // UserRRef uses this message to fetch the remote RRef value from the owner.
46 class TORCH_API ScriptRRefFetchCall final : public RRefMessageBase {
47  public:
ScriptRRefFetchCall(worker_id_t fromWorkerId,const RRefId & rrefId)48   ScriptRRefFetchCall(worker_id_t fromWorkerId, const RRefId& rrefId)
49       : RRefMessageBase(rrefId, MessageType::SCRIPT_RREF_FETCH_CALL),
50         fromWorkerId_(fromWorkerId) {}
51 
fromWorkerId()52   inline worker_id_t fromWorkerId() const {
53     return fromWorkerId_;
54   }
55 
56   c10::intrusive_ptr<Message> toMessageImpl() && override;
57   static std::unique_ptr<ScriptRRefFetchCall> fromMessage(
58       const Message& message);
59 
60  private:
61   const worker_id_t fromWorkerId_;
62 };
63 
64 class TORCH_API PythonRRefFetchCall final : public RRefMessageBase {
65  public:
PythonRRefFetchCall(worker_id_t fromWorkerId,const RRefId & rrefId)66   PythonRRefFetchCall(worker_id_t fromWorkerId, const RRefId& rrefId)
67       : RRefMessageBase(rrefId, MessageType::PYTHON_RREF_FETCH_CALL),
68         fromWorkerId_(fromWorkerId) {}
69 
70   c10::intrusive_ptr<Message> toMessageImpl() && override;
71   static std::unique_ptr<PythonRRefFetchCall> fromMessage(
72       const Message& message);
73 
74  private:
75   const worker_id_t fromWorkerId_;
76 };
77 
78 // OwnerRRef uses this message to send the RRef value to a remote UserRRef
79 class TORCH_API RRefFetchRet : public RpcCommandBase {
80  public:
RRefFetchRet(std::vector<at::IValue> values,MessageType type)81   RRefFetchRet(std::vector<at::IValue> values, MessageType type)
82       : values_(std::move(values)), type_(type) {}
83 
84   const std::vector<at::IValue>& values();
85   c10::intrusive_ptr<Message> toMessageImpl() && override;
86 
87  private:
88   std::vector<at::IValue> values_;
89   const MessageType type_;
90 };
91 
92 class TORCH_API ScriptRRefFetchRet final : public RRefFetchRet {
93  public:
ScriptRRefFetchRet(std::vector<at::IValue> values)94   explicit ScriptRRefFetchRet(std::vector<at::IValue> values)
95       : RRefFetchRet(std::move(values), MessageType::SCRIPT_RREF_FETCH_RET) {}
96 
97   static std::unique_ptr<ScriptRRefFetchRet> fromMessage(
98       const Message& message);
99 };
100 
101 class TORCH_API PythonRRefFetchRet final : public RRefFetchRet {
102  public:
PythonRRefFetchRet(std::vector<at::IValue> values)103   explicit PythonRRefFetchRet(std::vector<at::IValue> values)
104       : RRefFetchRet(std::move(values), MessageType::PYTHON_RREF_FETCH_RET) {}
105 
106   static std::unique_ptr<PythonRRefFetchRet> fromMessage(
107       const Message& message);
108 };
109 
110 // UserRRef (regardless it's the creator or not) uses this message to notify
111 // OwnerRRef on delete.
112 class TORCH_API RRefUserDelete final : public ForkMessageBase {
113  public:
RRefUserDelete(const RRefId & rrefId,const ForkId & forkId)114   RRefUserDelete(const RRefId& rrefId, const ForkId& forkId)
115       : ForkMessageBase(rrefId, forkId, MessageType::RREF_USER_DELETE) {}
116 
117   static std::unique_ptr<RRefUserDelete> fromMessage(const Message& message);
118 };
119 
120 class TORCH_API RemoteRet final : public ForkMessageBase {
121  public:
RemoteRet(const RRefId & rrefId,const ForkId & forkId)122   RemoteRet(const RRefId& rrefId, const ForkId& forkId)
123       : ForkMessageBase(rrefId, forkId, MessageType::REMOTE_RET) {}
124 
125   static std::unique_ptr<RemoteRet> fromMessage(const Message& message);
126 };
127 
128 // A child RRef uses this message to notify its parent that the child has been
129 // confirmed by the owner.
130 class TORCH_API RRefChildAccept final : public RpcCommandBase {
131  public:
RRefChildAccept(const ForkId & forkId)132   explicit RRefChildAccept(const ForkId& forkId) : forkId_(forkId) {}
133 
134   const ForkId& forkId() const;
135 
136   c10::intrusive_ptr<Message> toMessageImpl() && override;
137   static std::unique_ptr<RRefChildAccept> fromMessage(const Message& message);
138 
139  private:
140   const ForkId forkId_;
141 };
142 
143 // A child RRef uses this message to send a fork request to the owner.
144 class TORCH_API RRefForkRequest final : public ForkMessageBase {
145  public:
RRefForkRequest(const RRefId & rrefId,const ForkId & forkId)146   RRefForkRequest(const RRefId& rrefId, const ForkId& forkId)
147       : ForkMessageBase(rrefId, forkId, MessageType::RREF_FORK_REQUEST) {}
148 
149   static std::unique_ptr<RRefForkRequest> fromMessage(const Message& message);
150 };
151 
152 class TORCH_API RRefAck final : public RpcCommandBase {
153  public:
154   RRefAck() = default;
155 
156   c10::intrusive_ptr<Message> toMessageImpl() && override;
157   static std::unique_ptr<RRefAck> fromMessage(const Message& message);
158 };
159 
160 } // namespace torch::distributed::rpc
161