// xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
2 #include <torch/csrc/distributed/rpc/utils.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4 #include <vector>
5 
6 namespace torch {
7 namespace distributed {
8 namespace autograd {
9 
// Number of elements in the pickled profiling metadata tuple that is appended
// to the wrapped payload: (wrapped message type, profiler config, profiling
// id). Despite "Response" in its name, this file uses it for the request.
constexpr int kProfilingResponseElementExpectedSize = 3;
11 
12 using rpc::RpcCommandBase;
13 
14 // This constructor is called when creating the RpcWithProfilingReq on the
15 // client.
RpcWithProfilingReq(rpc::MessageType messageType,c10::intrusive_ptr<rpc::Message> wrappedMessage,torch::autograd::profiler::ProfilerConfig && profilerConfig,rpc::ProfilingId profilingKeyId)16 RpcWithProfilingReq::RpcWithProfilingReq(
17     rpc::MessageType messageType,
18     c10::intrusive_ptr<rpc::Message> wrappedMessage,
19     torch::autograd::profiler::ProfilerConfig&& profilerConfig,
20     rpc::ProfilingId profilingKeyId)
21     : messageType_(messageType),
22       wrappedMessage_(std::move(wrappedMessage)),
23       tensors_(wrappedMessage_->tensors()),
24       profilerConfig_(profilerConfig),
25       profilingKeyId_(profilingKeyId) {
26   TORCH_INTERNAL_ASSERT(
27       messageType_ == rpc::MessageType::RUN_WITH_PROFILING_REQ,
28       c10::str(
29           "Incorrect message type, expected message type ",
30           rpc::MessageType::RUN_WITH_PROFILING_REQ));
31   wrappedMessageType_ = wrappedMessage_->type();
32 }
33 
34 // this constructor is only called in fromMessage() which is called in
35 // deserializeRequest(). It is called when reconstructing the
36 // RpcWithProfilingReq on the remote end.
RpcWithProfilingReq(rpc::MessageType messageType,std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,rpc::MessageType wrappedMessageType,std::vector<torch::Tensor> tensors,torch::autograd::profiler::ProfilerConfig && profilerConfig,rpc::ProfilingId profilingKeyId)37 RpcWithProfilingReq::RpcWithProfilingReq(
38     rpc::MessageType messageType,
39     std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
40     rpc::MessageType wrappedMessageType,
41     std::vector<torch::Tensor> tensors,
42     torch::autograd::profiler::ProfilerConfig&& profilerConfig,
43     rpc::ProfilingId profilingKeyId)
44     : messageType_(messageType),
45       wrappedRpc_(std::move(wrappedRpc)),
46       wrappedMessageType_(wrappedMessageType),
47       tensors_(std::move(tensors)),
48       profilerConfig_(profilerConfig),
49       profilingKeyId_(profilingKeyId) {
50   TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cant be null");
51 }
52 
wrappedMessageType() const53 rpc::MessageType RpcWithProfilingReq::wrappedMessageType() const {
54   return wrappedMessageType_;
55 }
56 
setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc)57 void RpcWithProfilingReq::setWrappedRpc(
58     std::unique_ptr<RpcCommandBase> wrappedRpc) {
59   wrappedRpc_ = std::move(wrappedRpc);
60 }
61 
toMessageImpl()62 c10::intrusive_ptr<rpc::Message> RpcWithProfilingReq::toMessageImpl() && {
63   // save the original message ID and type before moving it.
64   auto wrappedMsgId = wrappedMessage_->id();
65   auto wrappedMsgType = wrappedMessage_->type();
66   // destructively move the wrappedMessage and get the payload. Now the payload
67   // of wrappedMessage won't be in a valid state.
68   auto wrappedPayload = std::move(*wrappedMessage_).movePayload();
69   // The wrapped payload should not be empty
70   TORCH_INTERNAL_ASSERT(
71       !wrappedPayload.empty(), "Wrapped payload should not be empty.");
72   // Create the ivalues to send over. We need to send the original message type
73   // and id, as well as some profiling metadata.
74   std::vector<at::IValue> ivalues{
75       wrappedMsgType, profilerConfig_.toIValue(), profilingKeyId_.toIValue()};
76   // Pickle it into a char payload to be sent over the wire.
77   std::vector<torch::Tensor> tensorTable;
78   std::vector<char> profilingPayload =
79       jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
80   // add the profiling payload to the wrapped payload
81   rpc::writeWrappedPayload(wrappedPayload, profilingPayload);
82   // Put the wrapped payload into a message to return.
83   auto returnMsg = c10::make_intrusive<rpc::Message>(
84       std::move(wrappedPayload),
85       std::move(tensors_),
86       messageType_,
87       wrappedMsgId);
88 
89   return returnMsg;
90 }
91 
wrappedRpc()92 RpcCommandBase& RpcWithProfilingReq::wrappedRpc() {
93   TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
94   return *wrappedRpc_;
95 }
96 
97 torch::autograd::profiler::ProfilerConfig RpcWithProfilingReq::
getProfilingConfig() const98     getProfilingConfig() const {
99   return profilerConfig_;
100 }
101 
getProfilingId() const102 const rpc::ProfilingId& RpcWithProfilingReq::getProfilingId() const {
103   return profilingKeyId_;
104 }
105 
fromMessage(const rpc::Message & message)106 std::unique_ptr<RpcWithProfilingReq> RpcWithProfilingReq::fromMessage(
107     const rpc::Message& message) {
108   rpc::MessageType origMsgType = message.type();
109   std::vector<torch::Tensor> tensors = message.tensors();
110   int64_t msgId = message.id();
111   auto payload = message.payload();
112   auto tupleElements = rpc::readWrappedPayload(payload, message);
113   // Ensure that we have the expected number of elements
114   TORCH_INTERNAL_ASSERT(
115       tupleElements.size() == kProfilingResponseElementExpectedSize,
116       c10::str(
117           "Expected payload of size ",
118           kProfilingResponseElementExpectedSize,
119           " but got ",
120           tupleElements.size()));
121   rpc::MessageType wrappedMsgType =
122       static_cast<rpc::MessageType>(tupleElements[0].toInt());
123   // Create a config to be enabled on this node that is a replica of the
124   // state on the requesting node.
125   torch::autograd::profiler::ProfilerConfig cfg =
126       torch::autograd::profiler::ProfilerConfig::fromIValue(tupleElements[1]);
127 
128   rpc::ProfilingId profilerId = rpc::ProfilingId::fromIValue(tupleElements[2]);
129 
130   // Create new message type and build wrapped RPC
131   auto wrappedMessage = c10::make_intrusive<rpc::Message>(
132       std::move(payload), std::move(tensors), wrappedMsgType, msgId);
133   TORCH_INTERNAL_ASSERT(
134       wrappedMessage->isRequest(),
135       "Messages wrapped with profiling requests must be requests.");
136   std::unique_ptr<RpcCommandBase> wrappedRpc =
137       deserializeRequest(*wrappedMessage);
138 
139   return std::make_unique<RpcWithProfilingReq>(
140       origMsgType,
141       std::move(wrappedRpc),
142       wrappedMsgType,
143       std::move(wrappedMessage->tensors()),
144       std::move(cfg),
145       profilerId);
146 }
147 } // namespace autograd
148 } // namespace distributed
149 } // namespace torch
150