1 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4
5 #include <c10/util/irange.h>
6
7 namespace torch {
8 namespace distributed {
9 namespace autograd {
10
11 using rpc::Message;
12 using rpc::MessageType;
13 using torch::autograd::Variable;
14
PropagateGradientsReq(const AutogradMetadata & autogradMetadata,std::vector<Variable> grads,bool retainGraph)15 PropagateGradientsReq::PropagateGradientsReq(
16 const AutogradMetadata& autogradMetadata,
17 std::vector<Variable> grads,
18 bool retainGraph)
19 : autogradMetadata_(autogradMetadata),
20 grads_(std::move(grads)),
21 retainGraph_(retainGraph) {}
22
toMessageImpl()23 c10::intrusive_ptr<Message> PropagateGradientsReq::toMessageImpl() && {
24 std::vector<at::IValue> ivalues;
25 // Add all the grad tensors.
26 ivalues.reserve(grads_.size() + 3);
27 for (const auto& grad : grads_) {
28 ivalues.emplace_back(grad);
29 }
30
31 // Now add autograd metadata.
32 ivalues.emplace_back(autogradMetadata_.autogradContextId);
33 ivalues.emplace_back(autogradMetadata_.autogradMessageId);
34
35 // Add retain graph.
36 ivalues.emplace_back(retainGraph_);
37
38 // Now pickle using JIT pickler.
39 std::vector<torch::Tensor> tensorTable;
40 std::vector<char> payload =
41 jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
42
43 return c10::make_intrusive<Message>(
44 std::move(payload),
45 std::move(tensorTable),
46 MessageType::BACKWARD_AUTOGRAD_REQ);
47 }
48
fromMessage(const Message & message)49 std::unique_ptr<PropagateGradientsReq> PropagateGradientsReq::fromMessage(
50 const Message& message) {
51 // Unpickle the message and retrieve tupleElements.
52 auto payload = static_cast<const char*>(message.payload().data());
53 auto payload_size = message.payload().size();
54 IValue tuple = jit::unpickle(
55 payload,
56 payload_size,
57 *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
58 message.tensors());
59 const auto& tupleElements = tuple.toTupleRef().elements();
60
61 // Build PropagateGradientsReq.
62 TORCH_INTERNAL_ASSERT(tupleElements.size() >= 3);
63
64 // Retrieve retainGraph.
65 bool retainGraph = tupleElements.back().toBool();
66
67 // Build AutogradMetadata.
68 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
69 int64_t autogradContextId, autogradMessageId;
70 autogradMessageId = tupleElements[tupleElements.size() - 2].toInt();
71 autogradContextId = tupleElements[tupleElements.size() - 3].toInt();
72
73 AutogradMetadata autogradMetadata(autogradContextId, autogradMessageId);
74
75 // Retrieve the gradient tensors.
76 std::vector<Variable> grads(tupleElements.size() - 3);
77 for (const auto i : c10::irange(tupleElements.size() - 3)) {
78 grads[i] = tupleElements[i].toTensor();
79 }
80
81 return std::make_unique<PropagateGradientsReq>(
82 autogradMetadata, grads, retainGraph);
83 }
84
getAutogradMetadata()85 const AutogradMetadata& PropagateGradientsReq::getAutogradMetadata() {
86 return autogradMetadata_;
87 }
88
89 const std::vector<torch::autograd::Variable>& PropagateGradientsReq::
getGrads()90 getGrads() {
91 return grads_;
92 }
93
retainGraph()94 bool PropagateGradientsReq::retainGraph() {
95 return retainGraph_;
96 }
97
98 } // namespace autograd
99 } // namespace distributed
100 } // namespace torch
101