#pragma once

#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>

namespace torch {
namespace distributed {
namespace autograd {

// This method is used to attach the 'send' autograd function to the autograd
// graph when we use RPC. This method creates a new 'send' autograd function
// and attaches the provided tensors as next_edges to the 'send' function. In
// addition, it registers the 'send' function in the provided autograd
// context. Finally, the RPC message is updated with the appropriate autograd
// information for the recipient.
TORCH_API void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors);
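
// A minimal usage sketch for addSendRpcBackward (illustrative only; the setup
// below reflects a typical internal caller, and 'tensorsToSend' is a
// hypothetical variable, not part of this header):
//
//   auto& container = DistAutogradContainer::getInstance();
//   ContextPtr ctx = container.currentContext();
//   AutogradMetadata meta(ctx->contextId(), container.newAutogradMessageId());
//   addSendRpcBackward(ctx, meta, tensorsToSend);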

// This method is used to attach the 'recv' autograd function to the autograd
// graph when we use RPC. This method creates a new 'recv' autograd function
// and attaches the provided tensors as inputs to the 'recv' function. It
// creates a new autograd context if needed and registers the 'recv' function
// with this context.
//
// Returns a pointer to the autograd context created.
TORCH_API ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const rpc::DeviceMap& deviceMap);
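
// A minimal sketch of how the receiving side might use addRecvRpcBackward
// (the 'rpcWithAutograd' and 'deviceMap' variables are assumptions about the
// caller, not part of this header):
//
//   ContextPtr ctx = addRecvRpcBackward(
//       rpcWithAutograd.autogradMetadata(),
//       rpcWithAutograd.tensors(),
//       rpcWithAutograd.fromWorkerId(),
//       deviceMap);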

// This method is a wrapper utility used internally to wrap autograd info and
// attach an autograd function for an outgoing RPC call. If there is a valid
// autograd context and either the tensors require grads or forceGradRecording
// is true, it returns an RpcWithAutograd message; otherwise it returns the
// original RPC message unchanged.
// NB: forceGradRecording is useful when the request does not contain any
// tensors but the corresponding response does.
TORCH_API c10::intrusive_ptr<rpc::Message> getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    c10::intrusive_ptr<rpc::Message> wrappedRpcMsg,
    rpc::MessageType msgType,
    bool forceGradRecording = false,
    const rpc::DeviceMap& deviceMap = {});
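
// A minimal sketch of wrapping an outgoing request with getMessageWithAutograd
// ('dstId' and 'wrappedMsg' are placeholders assumed to come from the caller):
//
//   c10::intrusive_ptr<rpc::Message> msg = getMessageWithAutograd(
//       dstId,
//       std::move(wrappedMsg),
//       rpc::MessageType::FORWARD_AUTOGRAD_REQ);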

// Send the given message via the provided RPC agent after checking for (and,
// if needed, attaching) autograd information.
TORCH_API c10::intrusive_ptr<c10::ivalue::Future> sendMessageWithAutograd(
    rpc::RpcAgent& agent,
    const rpc::WorkerInfo& dst,
    c10::intrusive_ptr<rpc::Message> wrappedRpcMsg,
    bool forceGradRecording = false,
    const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
    bool forceDisableProfiling = false);
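
// A minimal sketch of sending a wrapped request with sendMessageWithAutograd
// (the destination worker name 'dstName' and 'wrappedMsg' are placeholders):
//
//   auto agent = rpc::RpcAgent::getCurrentRpcAgent();
//   auto future = sendMessageWithAutograd(
//       *agent,
//       agent->getWorkerInfo(dstName),
//       std::move(wrappedMsg),
//       /*forceGradRecording=*/false);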

} // namespace autograd
} // namespace distributed
} // namespace torch