// xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/rpc/message.h>
4 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
5 #include <torch/csrc/distributed/rpc/types.h>
6 
7 namespace torch {
8 namespace distributed {
9 namespace autograd {
10 
11 // Internal system RPC to invoke distributed backward pass on remote nodes when
12 // 'rref.backward()' is invoked.
13 class TORCH_API RRefBackwardReq : public rpc::RpcCommandBase {
14  public:
15   RRefBackwardReq(
16       const rpc::RRefId& rrefId,
17       int64_t autogradContextId,
18       bool retainGraph = false);
19 
20   const rpc::RRefId& getRRefId() const;
21 
22   int64_t getAutogradContextId() const;
23 
24   bool retainGraph() const;
25 
26   // Serialization and deserialization methods.
27   c10::intrusive_ptr<rpc::Message> toMessageImpl() && override;
28   static std::unique_ptr<RRefBackwardReq> fromMessage(
29       const rpc::Message& message);
30 
31  private:
32   const rpc::RRefId rrefId_;
33   const int64_t autogradContextId_;
34   const bool retainGraph_;
35 };
36 
37 } // namespace autograd
38 } // namespace distributed
39 } // namespace torch
40