1 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
2 #include <torch/csrc/distributed/rpc/utils.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4 #include <vector>
5
6 namespace torch {
7 namespace distributed {
8 namespace autograd {
9
// Number of elements expected in the pickled profiling metadata tuple that
// fromMessage() decodes: the wrapped message type, the profiler config and
// the profiling id.
constexpr int kProfilingResponseElementExpectedSize = 3;
11
12 using rpc::RpcCommandBase;
13
14 // This constructor is called when creating the RpcWithProfilingReq on the
15 // client.
RpcWithProfilingReq(rpc::MessageType messageType,c10::intrusive_ptr<rpc::Message> wrappedMessage,torch::autograd::profiler::ProfilerConfig && profilerConfig,rpc::ProfilingId profilingKeyId)16 RpcWithProfilingReq::RpcWithProfilingReq(
17 rpc::MessageType messageType,
18 c10::intrusive_ptr<rpc::Message> wrappedMessage,
19 torch::autograd::profiler::ProfilerConfig&& profilerConfig,
20 rpc::ProfilingId profilingKeyId)
21 : messageType_(messageType),
22 wrappedMessage_(std::move(wrappedMessage)),
23 tensors_(wrappedMessage_->tensors()),
24 profilerConfig_(profilerConfig),
25 profilingKeyId_(profilingKeyId) {
26 TORCH_INTERNAL_ASSERT(
27 messageType_ == rpc::MessageType::RUN_WITH_PROFILING_REQ,
28 c10::str(
29 "Incorrect message type, expected message type ",
30 rpc::MessageType::RUN_WITH_PROFILING_REQ));
31 wrappedMessageType_ = wrappedMessage_->type();
32 }
33
34 // this constructor is only called in fromMessage() which is called in
35 // deserializeRequest(). It is called when reconstructing the
36 // RpcWithProfilingReq on the remote end.
RpcWithProfilingReq(rpc::MessageType messageType,std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,rpc::MessageType wrappedMessageType,std::vector<torch::Tensor> tensors,torch::autograd::profiler::ProfilerConfig && profilerConfig,rpc::ProfilingId profilingKeyId)37 RpcWithProfilingReq::RpcWithProfilingReq(
38 rpc::MessageType messageType,
39 std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
40 rpc::MessageType wrappedMessageType,
41 std::vector<torch::Tensor> tensors,
42 torch::autograd::profiler::ProfilerConfig&& profilerConfig,
43 rpc::ProfilingId profilingKeyId)
44 : messageType_(messageType),
45 wrappedRpc_(std::move(wrappedRpc)),
46 wrappedMessageType_(wrappedMessageType),
47 tensors_(std::move(tensors)),
48 profilerConfig_(profilerConfig),
49 profilingKeyId_(profilingKeyId) {
50 TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cant be null");
51 }
52
wrappedMessageType() const53 rpc::MessageType RpcWithProfilingReq::wrappedMessageType() const {
54 return wrappedMessageType_;
55 }
56
setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc)57 void RpcWithProfilingReq::setWrappedRpc(
58 std::unique_ptr<RpcCommandBase> wrappedRpc) {
59 wrappedRpc_ = std::move(wrappedRpc);
60 }
61
toMessageImpl()62 c10::intrusive_ptr<rpc::Message> RpcWithProfilingReq::toMessageImpl() && {
63 // save the original message ID and type before moving it.
64 auto wrappedMsgId = wrappedMessage_->id();
65 auto wrappedMsgType = wrappedMessage_->type();
66 // destructively move the wrappedMessage and get the payload. Now the payload
67 // of wrappedMessage won't be in a valid state.
68 auto wrappedPayload = std::move(*wrappedMessage_).movePayload();
69 // The wrapped payload should not be empty
70 TORCH_INTERNAL_ASSERT(
71 !wrappedPayload.empty(), "Wrapped payload should not be empty.");
72 // Create the ivalues to send over. We need to send the original message type
73 // and id, as well as some profiling metadata.
74 std::vector<at::IValue> ivalues{
75 wrappedMsgType, profilerConfig_.toIValue(), profilingKeyId_.toIValue()};
76 // Pickle it into a char payload to be sent over the wire.
77 std::vector<torch::Tensor> tensorTable;
78 std::vector<char> profilingPayload =
79 jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
80 // add the profiling payload to the wrapped payload
81 rpc::writeWrappedPayload(wrappedPayload, profilingPayload);
82 // Put the wrapped payload into a message to return.
83 auto returnMsg = c10::make_intrusive<rpc::Message>(
84 std::move(wrappedPayload),
85 std::move(tensors_),
86 messageType_,
87 wrappedMsgId);
88
89 return returnMsg;
90 }
91
wrappedRpc()92 RpcCommandBase& RpcWithProfilingReq::wrappedRpc() {
93 TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
94 return *wrappedRpc_;
95 }
96
97 torch::autograd::profiler::ProfilerConfig RpcWithProfilingReq::
getProfilingConfig() const98 getProfilingConfig() const {
99 return profilerConfig_;
100 }
101
getProfilingId() const102 const rpc::ProfilingId& RpcWithProfilingReq::getProfilingId() const {
103 return profilingKeyId_;
104 }
105
fromMessage(const rpc::Message & message)106 std::unique_ptr<RpcWithProfilingReq> RpcWithProfilingReq::fromMessage(
107 const rpc::Message& message) {
108 rpc::MessageType origMsgType = message.type();
109 std::vector<torch::Tensor> tensors = message.tensors();
110 int64_t msgId = message.id();
111 auto payload = message.payload();
112 auto tupleElements = rpc::readWrappedPayload(payload, message);
113 // Ensure that we have the expected number of elements
114 TORCH_INTERNAL_ASSERT(
115 tupleElements.size() == kProfilingResponseElementExpectedSize,
116 c10::str(
117 "Expected payload of size ",
118 kProfilingResponseElementExpectedSize,
119 " but got ",
120 tupleElements.size()));
121 rpc::MessageType wrappedMsgType =
122 static_cast<rpc::MessageType>(tupleElements[0].toInt());
123 // Create a config to be enabled on this node that is a replica of the
124 // state on the requesting node.
125 torch::autograd::profiler::ProfilerConfig cfg =
126 torch::autograd::profiler::ProfilerConfig::fromIValue(tupleElements[1]);
127
128 rpc::ProfilingId profilerId = rpc::ProfilingId::fromIValue(tupleElements[2]);
129
130 // Create new message type and build wrapped RPC
131 auto wrappedMessage = c10::make_intrusive<rpc::Message>(
132 std::move(payload), std::move(tensors), wrappedMsgType, msgId);
133 TORCH_INTERNAL_ASSERT(
134 wrappedMessage->isRequest(),
135 "Messages wrapped with profiling requests must be requests.");
136 std::unique_ptr<RpcCommandBase> wrappedRpc =
137 deserializeRequest(*wrappedMessage);
138
139 return std::make_unique<RpcWithProfilingReq>(
140 origMsgType,
141 std::move(wrappedRpc),
142 wrappedMsgType,
143 std::move(wrappedMessage->tensors()),
144 std::move(cfg),
145 profilerId);
146 }
147 } // namespace autograd
148 } // namespace distributed
149 } // namespace torch
150