xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/request_callback_impl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/request_callback_impl.h>
2 
3 #include <torch/csrc/autograd/profiler.h>
4 #include <torch/csrc/distributed/autograd/context/container.h>
5 #include <torch/csrc/distributed/autograd/context/context.h>
6 #include <torch/csrc/distributed/autograd/engine/dist_engine.h>
7 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
8 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
9 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
10 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
11 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
12 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
13 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
14 #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
15 #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h>
16 #include <torch/csrc/distributed/autograd/utils.h>
17 #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
18 #include <torch/csrc/distributed/rpc/py_rref.h>
19 #include <torch/csrc/distributed/rpc/python_call.h>
20 #include <torch/csrc/distributed/rpc/python_remote_call.h>
21 #include <torch/csrc/distributed/rpc/python_resp.h>
22 #include <torch/csrc/distributed/rpc/python_rpc_handler.h>
23 #include <torch/csrc/distributed/rpc/rref_context.h>
24 #include <torch/csrc/distributed/rpc/rref_impl.h>
25 #include <torch/csrc/distributed/rpc/rref_proto.h>
26 #include <torch/csrc/distributed/rpc/script_call.h>
27 #include <torch/csrc/distributed/rpc/script_remote_call.h>
28 #include <torch/csrc/distributed/rpc/script_resp.h>
29 #include <torch/csrc/distributed/rpc/unpickled_python_call.h>
30 #include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>
31 #include <torch/csrc/distributed/rpc/utils.h>
32 #include <torch/csrc/jit/python/python_ivalue.h>
33 
34 #include <utility>
35 
36 namespace torch::distributed::rpc {
37 
38 using namespace torch::distributed::autograd;
39 
40 namespace {
41 
deserializePythonRpcCommandReference(RpcCommandBase & rpc,const MessageType & messageType)42 std::unique_ptr<RpcCommandBase> deserializePythonRpcCommandReference(
43     RpcCommandBase& rpc,
44     const MessageType& messageType) {
45   switch (messageType) {
46     case MessageType::PYTHON_CALL: {
47       auto& pc = static_cast<PythonCall&>(rpc);
48       return std::make_unique<UnpickledPythonCall>(
49           pc.serializedPyObj(), pc.isAsyncExecution());
50     }
51     case MessageType::PYTHON_REMOTE_CALL: {
52       auto& prc = static_cast<PythonRemoteCall&>(rpc);
53       return std::make_unique<UnpickledPythonRemoteCall>(
54           prc.serializedPyObj(),
55           prc.retRRefId(),
56           prc.retForkId(),
57           prc.isAsyncExecution());
58     }
59     case MessageType::FORWARD_AUTOGRAD_REQ: {
60       // Deserialize the wrapped RPC if it contains Python UDF
61       auto& rwa = static_cast<RpcWithAutograd&>(rpc);
62       auto& wrappedRpc = rwa.wrappedRpc();
63       auto pythonRpc = deserializePythonRpcCommandReference(
64           wrappedRpc, rwa.wrappedMessageType());
65       if (pythonRpc) {
66         rwa.setWrappedRpc(std::move(pythonRpc));
67       }
68       return nullptr;
69     }
70     case MessageType::RUN_WITH_PROFILING_REQ: {
71       // Deserialize wrapped RPC if it contains python call
72       auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc);
73       auto& wrappedRpc = rpcWithProfilingReq.wrappedRpc();
74       auto pythonRpc = deserializePythonRpcCommandReference(
75           wrappedRpc, rpcWithProfilingReq.wrappedMessageType());
76       if (pythonRpc) {
77         rpcWithProfilingReq.setWrappedRpc(std::move(pythonRpc));
78       }
79       return nullptr;
80     }
81     default: {
82       return nullptr;
83     }
84   }
85 }
86 
serializePyObject(IValue value)87 SerializedPyObj serializePyObject(IValue value) {
88   auto& pythonRpcHandler = PythonRpcHandler::getInstance();
89   // Need this GIL to guard jit::toPyObj and destruct its returned
90   // py::object
91   py::gil_scoped_acquire acquire;
92   try {
93     return pythonRpcHandler.serialize(jit::toPyObject(std::move(value)));
94   } catch (py::error_already_set& e) {
95     // py::error_already_set requires GIL to destruct, take special care.
96     std::string err_msg = e.what();
97     e.restore();
98     PyErr_Clear();
99     throw std::runtime_error(err_msg);
100   }
101 }
102 
103 } // anonymous namespace
104 
runPythonFunction(const py::object & function,const std::vector<c10::Stream> & streams,bool isAsyncExecution) const105 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runPythonFunction(
106     const py::object& function,
107     const std::vector<c10::Stream>& streams,
108     bool isAsyncExecution) const {
109   c10::MultiStreamGuard guard(streams);
110   auto& pythonRpcHandler = PythonRpcHandler::getInstance();
111   py::gil_scoped_acquire acquire;
112 
113   py::object result;
114   try {
115     result = pythonRpcHandler.runPythonUdf(function);
116   } catch (py::error_already_set& e) {
117     // py::error_already_set requires GIL to destruct, take special care.
118     auto future =
119         asFuture(std::make_exception_ptr(std::runtime_error(e.what())));
120     e.restore();
121     PyErr_Clear();
122     return future;
123   } catch (std::exception& e) {
124     return asFuture(std::current_exception());
125   }
126 
127   // After sync execution or failed async execution return the value as-is.
128   if (pythonRpcHandler.isRemoteException(result) || !isAsyncExecution) {
129     return asFuture(
130         c10::ivalue::ConcretePyObjectHolder::create(result),
131         at::PyObjectType::get());
132   }
133 
134   try {
135     return result.cast<jit::PythonFutureWrapper&>().fut;
136   } catch (const py::cast_error& e) {
137     auto type = result.get_type();
138     auto errMsg = c10::str(
139         e.what(),
140         ". Functions decorated with @rpc.async_function must return a "
141         "torch.futures.Future object, but got ",
142         type.attr("__module__").cast<std::string>(),
143         ".",
144         type.attr("__qualname__").cast<std::string>());
145     return asFuture(std::make_exception_ptr(std::runtime_error(errMsg)));
146   }
147 }
148 
149 std::unique_ptr<RpcCommandBase> RequestCallbackImpl::
deserializePythonRpcCommand(std::unique_ptr<RpcCommandBase> rpc,const MessageType & messageType) const150     deserializePythonRpcCommand(
151         std::unique_ptr<RpcCommandBase> rpc,
152         const MessageType& messageType) const {
153   auto pythonRpc = deserializePythonRpcCommandReference(*rpc, messageType);
154   return pythonRpc ? std::move(pythonRpc) : std::move(rpc);
155 }
156 
processScriptCall(RpcCommandBase & rpc,const std::vector<c10::Stream> & streams) const157 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptCall(
158     RpcCommandBase& rpc,
159     const std::vector<c10::Stream>& streams) const {
160   auto& scriptCall = static_cast<ScriptCall&>(rpc);
161 
162   c10::intrusive_ptr<JitFuture> future;
163   if (scriptCall.hasOp()) {
164     future = runJitOperator(*scriptCall.op(), scriptCall.stackRef(), streams);
165   } else {
166     future = runJitFunction(
167         scriptCall.qualifiedName(),
168         scriptCall.stackRef(),
169         streams,
170         scriptCall.isAsyncExecution());
171   }
172 
173   return future->then(
174       [](JitFuture& jitFuture) {
175         return withStorages(ScriptResp(jitFuture.value()).toMessage());
176       },
177       c10::getCustomClassType<c10::intrusive_ptr<Message>>());
178 }
179 
processPythonCall(RpcCommandBase & rpc,const std::vector<c10::Stream> & streams) const180 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonCall(
181     RpcCommandBase& rpc,
182     const std::vector<c10::Stream>& streams) const {
183   auto& upc = static_cast<UnpickledPythonCall&>(rpc);
184   auto future =
185       runPythonFunction(upc.pythonUdf(), streams, upc.isAsyncExecution());
186 
187   return future->then(
188       [](JitFuture& future) {
189         return withStorages(
190             PythonResp(serializePyObject(future.value())).toMessage());
191       },
192       c10::getCustomClassType<c10::intrusive_ptr<Message>>());
193 }
194 
processScriptRemoteCall(RpcCommandBase & rpc,const std::vector<c10::Stream> & streams) const195 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptRemoteCall(
196     RpcCommandBase& rpc,
197     const std::vector<c10::Stream>& streams) const {
198   auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
199 
200   c10::intrusive_ptr<JitFuture> future;
201   if (scriptRemoteCall.hasOp()) {
202     future = runJitOperator(
203         *scriptRemoteCall.op(), scriptRemoteCall.stackRef(), streams);
204   } else {
205     future = runJitFunction(
206         scriptRemoteCall.qualifiedName(),
207         scriptRemoteCall.stackRef(),
208         streams,
209         scriptRemoteCall.isAsyncExecution());
210   }
211 
212   return assignOwnerRRef(
213       scriptRemoteCall.retRRefId(), scriptRemoteCall.retForkId(), future);
214 }
215 
processPythonRemoteCall(RpcCommandBase & rpc,const std::vector<c10::Stream> & streams) const216 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRemoteCall(
217     RpcCommandBase& rpc,
218     const std::vector<c10::Stream>& streams) const {
219   auto& uprc = static_cast<UnpickledPythonRemoteCall&>(rpc);
220   auto future =
221       runPythonFunction(uprc.pythonUdf(), streams, uprc.isAsyncExecution());
222 
223   return assignOwnerRRef(uprc.rrefId(), uprc.forkId(), future);
224 }
225 
processPythonRRefFetchCall(RpcCommandBase & rpc) const226 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRRefFetchCall(
227     RpcCommandBase& rpc) const {
228   auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
229 
230   auto future = retrieveOwnerRRef(prf.rrefId());
231 
232   return future->then(
233       [](JitFuture& future) {
234         SerializedPyObj result = serializePyObject(future.value());
235         return withStorages(
236             PythonRRefFetchRet(std::move(result).toIValues()).toMessage());
237       },
238       c10::getCustomClassType<c10::intrusive_ptr<Message>>());
239 }
240 
handleRRefDelete(c10::intrusive_ptr<RRef> & rref) const241 void RequestCallbackImpl::handleRRefDelete(
242     c10::intrusive_ptr<RRef>& rref) const {
243   if (rref && rref->isPyObj()) {
244     py::gil_scoped_acquire acquire;
245     rref.reset();
246   }
247 }
248 
processRpcWithErrors(RpcCommandBase & rpc,const MessageType & messageType,const std::vector<c10::Stream> & streams) const249 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processRpcWithErrors(
250     RpcCommandBase& rpc,
251     const MessageType& messageType,
252     const std::vector<c10::Stream>& streams) const {
253   try {
254     return processRpc(rpc, messageType, streams);
255   } catch (py::error_already_set& e) {
256     // Pass a dummy message ID since it will be overwritten anyways.
257     auto future = asFuture(handleError(e, messageType, -1));
258     // There are request callback impls in Python, where Python
259     // exceptions could be thrown. For releasing Python exception
260     // py::objects, GIL must be held.
261     py::gil_scoped_acquire acquire;
262     e.restore(); // Release ownership on py::objects and also restore
263                  // Python Error Indicator.
264     PyErr_Clear(); // Clear the Python Error Indicator as we has
265                    // recorded the exception in the response message.
266     return future;
267   } catch (std::exception& e) {
268     // Pass a dummy message ID since it will be overwritten anyways.
269     return asFuture(handleError(e, messageType, -1));
270   }
271 }
272 
cudaAvailable() const273 bool RequestCallbackImpl::cudaAvailable() const {
274 #ifdef USE_CUDA
275   return true;
276 #else
277   return false;
278 #endif
279 }
280 
processRRefBackward(RpcCommandBase & rpc) const281 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processRRefBackward(
282     RpcCommandBase& rpc) const {
283   auto& rrefBackwardReq = static_cast<RRefBackwardReq&>(rpc);
284 
285   auto future = retrieveOwnerRRef(rrefBackwardReq.getRRefId());
286 
287   return future->then(
288       [autogradContextId = rrefBackwardReq.getAutogradContextId(),
289        retainGraph = rrefBackwardReq.retainGraph()](JitFuture& future) {
290         // Run backward (TODO: make this async?).
291         PyRRef::backwardOwnerRRef(
292             autogradContextId, retainGraph, future.value());
293 
294         return withStorages(RRefBackwardResp().toMessage());
295       },
296       c10::getCustomClassType<c10::intrusive_ptr<Message>>());
297 }
298 
// Invokes the TorchScript function `name` from the RPC handler's compilation
// unit on `stack`, with `streams` made current for the rest of this call.
//
// Any std::exception thrown while starting the call becomes an errored future.
// For async execution the function must return a Future; the outer future is
// then chained (thenAsync) onto that inner Future, so callers observe the
// inner value. A non-Future return type yields an errored future.
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runJitFunction(
    const c10::QualifiedName& name,
    std::vector<at::IValue>& stack,
    const std::vector<c10::Stream>& streams,
    bool isAsyncExecution) const {
  c10::MultiStreamGuard streamGuard(streams);

  c10::intrusive_ptr<JitFuture> launched;
  try {
    // runAsync() begins on the calling thread but may hand back a future that
    // has not completed yet (non-async code usually completes immediately).
    // For async code, continuations typically run on an at::launch() thread.
    launched = PythonRpcHandler::getInstance()
                   .jitCompilationUnit()
                   ->get_function(name)
                   .runAsync(stack);
  } catch (const std::exception&) {
    return asFuture(std::current_exception());
  }

  if (!isAsyncExecution) {
    return launched;
  }

  at::TypePtr returnType = launched->elementType();
  if (returnType->kind() != at::FutureType::Kind) {
    return asFuture(std::make_exception_ptr(std::runtime_error(c10::str(
        "Async functions must return an IValue of Future type, but got ",
        returnType->str()))));
  }
  return launched->thenAsync(
      [](JitFuture& outer) { return outer.value().toFuture(); },
      returnType->cast<at::FutureType>()->getElementType());
}
333 
334 } // namespace torch::distributed::rpc
335