#pragma once

#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>

namespace torch {
namespace distributed {
namespace autograd {

// This method is used to attach the 'send' autograd function to the autograd
// graph when we use RPC. It creates a new 'send' autograd function and
// attaches the provided tensors as next_edges of the 'send' function. In
// addition, it registers the 'send' function in the provided autograd
// context. Finally, the RPC message is updated with the appropriate autograd
// information for the recipient.
TORCH_API void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors);

// This method is used to attach the 'recv' autograd function to the autograd
// graph when we use RPC. It creates a new 'recv' autograd function and
// attaches the provided tensors as inputs to the 'recv' function. It creates
// a new autograd context if needed and registers the 'recv' function with
// this context.
//
// Returns a pointer to the autograd context created.
TORCH_API ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const rpc::DeviceMap& deviceMap);

// This method is a wrapper utility used internally to attach autograd
// information to an outgoing RPC. If there is a valid autograd context and
// either the message's tensors require grad or forceGradRecording is true,
// the original RPC is wrapped in an RpcWithAutograd message; otherwise the
// original RPC message is returned unchanged.
//
// NB: forceGradRecording is useful when the request does not contain any
// tensors but the corresponding response does.
TORCH_API c10::intrusive_ptr<rpc::Message> getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    c10::intrusive_ptr<rpc::Message> wrappedRpcMsg,
    rpc::MessageType msgType,
    bool forceGradRecording = false,
    const rpc::DeviceMap& deviceMap = {});

// Sends the given message over the given agent, first attaching autograd
// information (via getMessageWithAutograd) and, unless disabled, profiling
// information.
TORCH_API c10::intrusive_ptr<c10::ivalue::Future> sendMessageWithAutograd(
    rpc::RpcAgent& agent,
    const rpc::WorkerInfo& dst,
    c10::intrusive_ptr<rpc::Message> wrappedRpcMsg,
    bool forceGradRecording = false,
    const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
    bool forceDisableProfiling = false);

} // namespace autograd
} // namespace distributed
} // namespace torch
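
// Illustrative sketch (not part of this header's API surface): a client-side
// call path might wrap an outgoing RPC with sendMessageWithAutograd, which
// consults the current distributed autograd context via getMessageWithAutograd
// and then hands the (possibly wrapped) message to the agent. Names such as
// `dstWorkerInfo` and `message` below are hypothetical placeholders:
//
//   auto agent = rpc::RpcAgent::getCurrentRpcAgent();
//   auto future = sendMessageWithAutograd(
//       *agent,
//       dstWorkerInfo, // rpc::WorkerInfo of the destination worker
//       std::move(message), // c10::intrusive_ptr<rpc::Message> to send
//       /*forceGradRecording=*/false);
//   // `future` completes with the response message (or an error).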
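
// Illustrative sketch (not part of this header's API surface): on the
// receiving side, a request handler that has unwrapped an RpcWithAutograd
// message could register the 'recv' side of the graph with addRecvRpcBackward.
// `rpcWithAutograd` below is a hypothetical RpcWithAutograd& holding the
// received payload:
//
//   ContextPtr ctx = addRecvRpcBackward(
//       rpcWithAutograd.autogradMetadata(),
//       rpcWithAutograd.tensors(),
//       rpcWithAutograd.fromWorkerId(),
//       rpcWithAutograd.deviceMap());
//   // `ctx` is the (possibly newly created) autograd context in which the
//   // 'recv' function was registered for the eventual backward pass.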