xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h>
4 #include <torch/csrc/distributed/rpc/rpc_agent.h>
5 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
6 
7 namespace torch {
8 namespace distributed {
9 namespace autograd {
10 
11 // Represents an RPC that carries autograd information. This class wraps
12 // another `RpcCommandBase` object, which represents the actual RPC, and
13 // attaches the autograd information associated with that RPC to it.
14 class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
15  public:
16   // Used when sending an RPC; the wrapped RPC is passed as a built Message.
17   RpcWithAutograd(
18       rpc::worker_id_t fromWorkerId,
19       rpc::MessageType messageType,
20       const AutogradMetadata& autogradMetadata,
21       c10::intrusive_ptr<rpc::Message> wrappedMessage,
22       rpc::DeviceMap deviceMap = {});
23 
24   // Used when receiving an RPC; the wrapped RPC is passed as a command object.
25   RpcWithAutograd(
26       rpc::worker_id_t fromWorkerId,
27       rpc::MessageType messageType,
28       const AutogradMetadata& autogradMetadata,
29       std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
30       rpc::MessageType wrappedMessageType,
31       std::vector<torch::Tensor> tensors,
32       rpc::DeviceMap deviceMap = {});
33   // Converts this RPC into a Message; `&&`-qualified, so call on an rvalue.
34   c10::intrusive_ptr<rpc::Message> toMessageImpl() && override;
35   // Builds a receiving-side RpcWithAutograd from an incoming message.
36   static std::unique_ptr<RpcWithAutograd> fromMessage(
37       const rpc::Message& message);
38 
39   // Returns a mutable reference to the tensors carried by this RPC that need
40   // to be considered for autograd computations.
41   std::vector<torch::Tensor>& tensors();
42   // Returns the autograd metadata attached to this RPC.
43   const AutogradMetadata& autogradMetadata() const;
44   // Returns the wrapped RPC; note wrappedRpc_ is nullptr on the sending side.
45   RpcCommandBase& wrappedRpc();
46   // Replaces the wrapped RPC, taking ownership of `wrappedRpc`.
47   void setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc);
48   // Hands ownership of the wrapped RPC to the caller (call on an rvalue).
49   std::unique_ptr<RpcCommandBase> moveWrappedRpc() &&;
50 
51   // Returns the message type of the wrapped RPC.
52   rpc::MessageType wrappedMessageType() const;
53 
54   // Returns the id of the worker from which this RPC originated.
55   rpc::worker_id_t fromWorkerId() const;
56 
57   // Returns the device map. NOTE(review): could likely be a const member fn.
58   const rpc::DeviceMap& deviceMap();
59 
60  private:
61   // Id of the worker from which this RPC originated. Needed to know which
62   // worker we need to contact during the backward pass.
63   rpc::worker_id_t fromWorkerId_;
64 
65   // Message type of this RPC itself (the wrapper, not the wrapped RPC).
66   rpc::MessageType messageType_;
67   // Autograd metadata attached to this RPC (see autogradMetadata()).
68   AutogradMetadata autogradMetadata_;
69 
70   // Since wrappedMessage_ is destructively constructed from wrappedRpc_, at
71   // most one of the two is valid at a time, and they serve different roles:
72   // wrappedRpc_ is used when constructing a receiving-side RpcWithAutograd;
73   // wrappedMessage_ is used when constructing a sending-side one.
74 
75   // Valid for a receiving-side instance constructed via fromMessage();
76   // nullptr for a sending-side instance before toMessage() is called.
77   std::unique_ptr<RpcCommandBase> wrappedRpc_;
78 
79   // Serialized message representing wrappedRpc_. Used mostly as a cache to
80   // avoid serializing the request twice.
81   // nullptr for a receiving-side instance constructed via fromMessage();
82   // valid for a sending-side instance before toMessage() is called.
83   c10::intrusive_ptr<rpc::Message> wrappedMessage_;
84 
85   // Message type of the wrapped message. Stored separately because
86   // wrappedMessage_ is not always guaranteed to be populated.
87   rpc::MessageType wrappedMessageType_;
88 
89   // Tensors of the wrapped RPC that need to be considered for autograd.
90   std::vector<torch::Tensor> tensors_;
91 
92   // Device mapping for tensors that are sent across an RPC to another node.
93   rpc::DeviceMap deviceMap_;
94 };
95 
96 } // namespace autograd
97 } // namespace distributed
98 } // namespace torch
99