#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/serialization/pickle.h>

#include <c10/util/irange.h>

namespace torch {
namespace distributed {
namespace autograd {

using rpc::Message;
using rpc::MessageType;
using torch::autograd::Variable;

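// PropagateGradientsReq is the BACKWARD_AUTOGRAD_REQ RPC message used by
// distributed autograd to propagate gradients to a remote worker during the
// backward pass. It carries the gradient tensors, the autograd metadata
// (context id and message id) and the retain_graph flag.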
PropagateGradientsReq::PropagateGradientsReq(
    const AutogradMetadata& autogradMetadata,
    std::vector<Variable> grads,
    bool retainGraph)
    : autogradMetadata_(autogradMetadata),
      grads_(std::move(grads)),
      retainGraph_(retainGraph) {}

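// Wire format: the payload is a single pickled tuple laid out as
// [grad_0, ..., grad_{n-1}, autogradContextId, autogradMessageId, retainGraph].
// fromMessage() below reads the fields back in the same order.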
c10::intrusive_ptr<Message> PropagateGradientsReq::toMessageImpl() && {
  std::vector<at::IValue> ivalues;
  // Add all the grad tensors.
  ivalues.reserve(grads_.size() + 3);
  for (const auto& grad : grads_) {
    ivalues.emplace_back(grad);
  }

  // Now add autograd metadata.
  ivalues.emplace_back(autogradMetadata_.autogradContextId);
  ivalues.emplace_back(autogradMetadata_.autogradMessageId);

  // Add retain graph.
  ivalues.emplace_back(retainGraph_);

  // Now pickle using JIT pickler.
  std::vector<torch::Tensor> tensorTable;
  std::vector<char> payload =
      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);

  return c10::make_intrusive<Message>(
      std::move(payload),
      std::move(tensorTable),
      MessageType::BACKWARD_AUTOGRAD_REQ);
}

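// Reconstructs a PropagateGradientsReq from an incoming BACKWARD_AUTOGRAD_REQ
// message: unpickle the payload, read retainGraph, autogradMessageId and
// autogradContextId from the tail of the tuple, and treat the remaining
// leading elements as the gradient tensors.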
std::unique_ptr<PropagateGradientsReq> PropagateGradientsReq::fromMessage(
    const Message& message) {
  // Unpickle the message and retrieve tupleElements.
  auto payload = static_cast<const char*>(message.payload().data());
  auto payload_size = message.payload().size();
  IValue tuple = jit::unpickle(
      payload,
      payload_size,
      *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      message.tensors());
  const auto& tupleElements = tuple.toTupleRef().elements();

  // Build PropagateGradientsReq.
  TORCH_INTERNAL_ASSERT(tupleElements.size() >= 3);

  // Retrieve retainGraph.
  bool retainGraph = tupleElements.back().toBool();

  // Build AutogradMetadata.
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t autogradContextId, autogradMessageId;
  autogradMessageId = tupleElements[tupleElements.size() - 2].toInt();
  autogradContextId = tupleElements[tupleElements.size() - 3].toInt();

  AutogradMetadata autogradMetadata(autogradContextId, autogradMessageId);

  // Retrieve the gradient tensors.
  std::vector<Variable> grads(tupleElements.size() - 3);
  for (const auto i : c10::irange(tupleElements.size() - 3)) {
    grads[i] = tupleElements[i].toTensor();
  }

  return std::make_unique<PropagateGradientsReq>(
      autogradMetadata, grads, retainGraph);
}

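// Accessors for the fields carried by this request.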
const AutogradMetadata& PropagateGradientsReq::getAutogradMetadata() {
  return autogradMetadata_;
}

const std::vector<torch::autograd::Variable>& PropagateGradientsReq::
    getGrads() {
  return grads_;
}

bool PropagateGradientsReq::retainGraph() {
  return retainGraph_;
}

} // namespace autograd
} // namespace distributed
} // namespace torch