// torch/csrc/distributed/autograd/utils.cpp
#include <ATen/ThreadLocalState.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/types.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::distributed::autograd::AutogradMetadata;
using torch::distributed::autograd::RpcWithAutograd;
using torch::distributed::rpc::JitFuture;
using torch::distributed::rpc::Message;
using torch::distributed::rpc::MessageType;
using torch::distributed::rpc::RpcAgent;
using torch::distributed::rpc::WorkerInfo;

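// Creates a SendRpcBackward autograd function whose next edges point at the
// subset of 'tensors' that require grad, and records it in 'autogradContext'
// under autogradMetadata.autogradMessageId. This runs on the sender side of
// an RPC so the backward pass can later be routed back through this node.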
void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges(
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}

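// Retrieves (or creates) the distributed autograd context identified by
// autogradMetadata.autogradContextId. If any received tensor requires grad,
// installs a RecvRpcBackward function as the grad_fn of those tensors so
// gradients can flow back to 'fromWorkerId', and records the function in the
// context. Returns the context.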
ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const rpc::DeviceMap& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  auto autogradContext =
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}

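// Wraps 'wrappedRpcMessage' in an RpcWithProfilingReq carrying the caller's
// profiler configuration and a globally unique profiling id, instructing the
// remote end to run the request with profiling enabled. The current TLS
// profiling key is saved under that id and then unset.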
static c10::intrusive_ptr<Message> getMessageWithProfiling(
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMessage,
    MessageType msgType,
    torch::autograd::profiler::ProfilerConfig&& profilerConfig) {
  auto& remoteProfilerManager =
      torch::distributed::rpc::RemoteProfilerManager::getInstance();

  auto key = remoteProfilerManager.getCurrentProfilingKey();
  // Generate a globally unique id.
  auto globallyUniqueProfilingId = remoteProfilerManager.getNextProfilerId();
  // Save a mapping of id -> RPC profiling key and unset the current TLS key.
  remoteProfilerManager.saveRPCKey(globallyUniqueProfilingId, key);
  remoteProfilerManager.unsetCurrentKey();
  auto wrappedProfilingMsg = RpcWithProfilingReq(
      msgType,
      std::move(wrappedRpcMessage),
      std::move(profilerConfig),
      globallyUniqueProfilingId);

  return std::move(wrappedProfilingMsg).toMessage();
}

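// Returns 'wrappedRpcMsg' unchanged when there is no valid distributed
// autograd context and grad recording is neither forced nor required;
// otherwise wraps it in an RpcWithAutograd message, recording a 'send'
// autograd function when some tensor requires grad.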
c10::intrusive_ptr<Message> getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const rpc::DeviceMap& deviceMap) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grad, send the
  // original rpc message. Otherwise, attach grad info and grad functions
  // and send an rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg->tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return wrappedRpcMsg;
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext();

  // Wrap the original rpc with autograd information.
  AutogradMetadata autogradMetadata(
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward(
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the workerID
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage();
}

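// Attaches autograd information to 'wrappedRpcMsg' via getMessageWithAutograd,
// optionally wraps the result with profiling metadata, and sends it through
// 'agent' to 'dst'. Returns a future that completes with the response.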
c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
    bool forceGradRecording,
    const float rpcTimeoutSeconds,
    bool forceDisableProfiling) {
  auto msg = getMessageWithAutograd(
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording,
      agent.getDeviceMap(dst));

  // If the profiler is enabled, wrap this message with profiling metadata that
  // tells the remote end to process this request with the profiler enabled.
  if (!forceDisableProfiling) {
    switch (torch::profiler::impl::profilerType()) {
      case torch::profiler::impl::ActiveProfilerType::LEGACY: {
        auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
        auto msgWithProfiling = getMessageWithProfiling(
            std::move(msg),
            rpc::MessageType::RUN_WITH_PROFILING_REQ,
            std::move(profilerConfig));
        return agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
      }
      case torch::profiler::impl::ActiveProfilerType::KINETO:
        TORCH_WARN_ONCE(
            "Profiling a distributed call with the Kineto profiler will profile "
            "the caller, but not the worker.");
        break;
      default:
        break;
    }
  }

  return agent.send(dst, std::move(msg), rpcTimeoutSeconds);
}
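
// Illustrative call site (a sketch only, not part of the original file; 'dst'
// and 'wrappedMsg' are hypothetical, and kUnsetRpcTimeout is assumed to be
// the default timeout constant from the RPC constants header):
//
//   auto& agent = *rpc::RpcAgent::getCurrentRpcAgent();
//   c10::intrusive_ptr<JitFuture> fut = sendMessageWithAutograd(
//       agent,
//       dst, // an rpc::WorkerInfo describing the callee
//       std::move(wrappedMsg), // c10::intrusive_ptr<rpc::Message>
//       /*forceGradRecording=*/false,
//       /*rpcTimeoutSeconds=*/rpc::kUnsetRpcTimeout,
//       /*forceDisableProfiling=*/false);
//   fut->waitAndThrow(); // propagate any remote error to the caller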

} // namespace autograd
} // namespace distributed
} // namespace torch