1 #include <torch/csrc/distributed/rpc/request_callback_impl.h>
2
3 #include <torch/csrc/autograd/profiler.h>
4 #include <torch/csrc/distributed/autograd/context/container.h>
5 #include <torch/csrc/distributed/autograd/context/context.h>
6 #include <torch/csrc/distributed/autograd/engine/dist_engine.h>
7 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
8 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
9 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
10 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
11 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
12 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
13 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
14 #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
15 #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h>
16 #include <torch/csrc/distributed/autograd/utils.h>
17 #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
18 #include <torch/csrc/distributed/rpc/py_rref.h>
19 #include <torch/csrc/distributed/rpc/python_call.h>
20 #include <torch/csrc/distributed/rpc/python_remote_call.h>
21 #include <torch/csrc/distributed/rpc/python_resp.h>
22 #include <torch/csrc/distributed/rpc/python_rpc_handler.h>
23 #include <torch/csrc/distributed/rpc/rref_context.h>
24 #include <torch/csrc/distributed/rpc/rref_impl.h>
25 #include <torch/csrc/distributed/rpc/rref_proto.h>
26 #include <torch/csrc/distributed/rpc/script_call.h>
27 #include <torch/csrc/distributed/rpc/script_remote_call.h>
28 #include <torch/csrc/distributed/rpc/script_resp.h>
29 #include <torch/csrc/distributed/rpc/unpickled_python_call.h>
30 #include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>
31 #include <torch/csrc/distributed/rpc/utils.h>
32 #include <torch/csrc/jit/python/python_ivalue.h>
33
34 #include <utility>
35
36 namespace torch::distributed::rpc {
37
38 using namespace torch::distributed::autograd;
39
40 namespace {
41
deserializePythonRpcCommandReference(RpcCommandBase & rpc,const MessageType & messageType)42 std::unique_ptr<RpcCommandBase> deserializePythonRpcCommandReference(
43 RpcCommandBase& rpc,
44 const MessageType& messageType) {
45 switch (messageType) {
46 case MessageType::PYTHON_CALL: {
47 auto& pc = static_cast<PythonCall&>(rpc);
48 return std::make_unique<UnpickledPythonCall>(
49 pc.serializedPyObj(), pc.isAsyncExecution());
50 }
51 case MessageType::PYTHON_REMOTE_CALL: {
52 auto& prc = static_cast<PythonRemoteCall&>(rpc);
53 return std::make_unique<UnpickledPythonRemoteCall>(
54 prc.serializedPyObj(),
55 prc.retRRefId(),
56 prc.retForkId(),
57 prc.isAsyncExecution());
58 }
59 case MessageType::FORWARD_AUTOGRAD_REQ: {
60 // Deserialize the wrapped RPC if it contains Python UDF
61 auto& rwa = static_cast<RpcWithAutograd&>(rpc);
62 auto& wrappedRpc = rwa.wrappedRpc();
63 auto pythonRpc = deserializePythonRpcCommandReference(
64 wrappedRpc, rwa.wrappedMessageType());
65 if (pythonRpc) {
66 rwa.setWrappedRpc(std::move(pythonRpc));
67 }
68 return nullptr;
69 }
70 case MessageType::RUN_WITH_PROFILING_REQ: {
71 // Deserialize wrapped RPC if it contains python call
72 auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc);
73 auto& wrappedRpc = rpcWithProfilingReq.wrappedRpc();
74 auto pythonRpc = deserializePythonRpcCommandReference(
75 wrappedRpc, rpcWithProfilingReq.wrappedMessageType());
76 if (pythonRpc) {
77 rpcWithProfilingReq.setWrappedRpc(std::move(pythonRpc));
78 }
79 return nullptr;
80 }
81 default: {
82 return nullptr;
83 }
84 }
85 }
86
serializePyObject(IValue value)87 SerializedPyObj serializePyObject(IValue value) {
88 auto& pythonRpcHandler = PythonRpcHandler::getInstance();
89 // Need this GIL to guard jit::toPyObj and destruct its returned
90 // py::object
91 py::gil_scoped_acquire acquire;
92 try {
93 return pythonRpcHandler.serialize(jit::toPyObject(std::move(value)));
94 } catch (py::error_already_set& e) {
95 // py::error_already_set requires GIL to destruct, take special care.
96 std::string err_msg = e.what();
97 e.restore();
98 PyErr_Clear();
99 throw std::runtime_error(err_msg);
100 }
101 }
102
103 } // anonymous namespace
104
runPythonFunction(const py::object & function,const std::vector<c10::Stream> & streams,bool isAsyncExecution) const105 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runPythonFunction(
106 const py::object& function,
107 const std::vector<c10::Stream>& streams,
108 bool isAsyncExecution) const {
109 c10::MultiStreamGuard guard(streams);
110 auto& pythonRpcHandler = PythonRpcHandler::getInstance();
111 py::gil_scoped_acquire acquire;
112
113 py::object result;
114 try {
115 result = pythonRpcHandler.runPythonUdf(function);
116 } catch (py::error_already_set& e) {
117 // py::error_already_set requires GIL to destruct, take special care.
118 auto future =
119 asFuture(std::make_exception_ptr(std::runtime_error(e.what())));
120 e.restore();
121 PyErr_Clear();
122 return future;
123 } catch (std::exception& e) {
124 return asFuture(std::current_exception());
125 }
126
127 // After sync execution or failed async execution return the value as-is.
128 if (pythonRpcHandler.isRemoteException(result) || !isAsyncExecution) {
129 return asFuture(
130 c10::ivalue::ConcretePyObjectHolder::create(result),
131 at::PyObjectType::get());
132 }
133
134 try {
135 return result.cast<jit::PythonFutureWrapper&>().fut;
136 } catch (const py::cast_error& e) {
137 auto type = result.get_type();
138 auto errMsg = c10::str(
139 e.what(),
140 ". Functions decorated with @rpc.async_function must return a "
141 "torch.futures.Future object, but got ",
142 type.attr("__module__").cast<std::string>(),
143 ".",
144 type.attr("__qualname__").cast<std::string>());
145 return asFuture(std::make_exception_ptr(std::runtime_error(errMsg)));
146 }
147 }
148
149 std::unique_ptr<RpcCommandBase> RequestCallbackImpl::
deserializePythonRpcCommand(std::unique_ptr<RpcCommandBase> rpc,const MessageType & messageType) const150 deserializePythonRpcCommand(
151 std::unique_ptr<RpcCommandBase> rpc,
152 const MessageType& messageType) const {
153 auto pythonRpc = deserializePythonRpcCommandReference(*rpc, messageType);
154 return pythonRpc ? std::move(pythonRpc) : std::move(rpc);
155 }
156
processScriptCall(RpcCommandBase & rpc,const std::vector<c10::Stream> & streams) const157 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptCall(
158 RpcCommandBase& rpc,
159 const std::vector<c10::Stream>& streams) const {
160 auto& scriptCall = static_cast<ScriptCall&>(rpc);
161
162 c10::intrusive_ptr<JitFuture> future;
163 if (scriptCall.hasOp()) {
164 future = runJitOperator(*scriptCall.op(), scriptCall.stackRef(), streams);
165 } else {
166 future = runJitFunction(
167 scriptCall.qualifiedName(),
168 scriptCall.stackRef(),
169 streams,
170 scriptCall.isAsyncExecution());
171 }
172
173 return future->then(
174 [](JitFuture& jitFuture) {
175 return withStorages(ScriptResp(jitFuture.value()).toMessage());
176 },
177 c10::getCustomClassType<c10::intrusive_ptr<Message>>());
178 }
179
processPythonCall(RpcCommandBase & rpc,const std::vector<c10::Stream> & streams) const180 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonCall(
181 RpcCommandBase& rpc,
182 const std::vector<c10::Stream>& streams) const {
183 auto& upc = static_cast<UnpickledPythonCall&>(rpc);
184 auto future =
185 runPythonFunction(upc.pythonUdf(), streams, upc.isAsyncExecution());
186
187 return future->then(
188 [](JitFuture& future) {
189 return withStorages(
190 PythonResp(serializePyObject(future.value())).toMessage());
191 },
192 c10::getCustomClassType<c10::intrusive_ptr<Message>>());
193 }
194
processScriptRemoteCall(RpcCommandBase & rpc,const std::vector<c10::Stream> & streams) const195 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptRemoteCall(
196 RpcCommandBase& rpc,
197 const std::vector<c10::Stream>& streams) const {
198 auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
199
200 c10::intrusive_ptr<JitFuture> future;
201 if (scriptRemoteCall.hasOp()) {
202 future = runJitOperator(
203 *scriptRemoteCall.op(), scriptRemoteCall.stackRef(), streams);
204 } else {
205 future = runJitFunction(
206 scriptRemoteCall.qualifiedName(),
207 scriptRemoteCall.stackRef(),
208 streams,
209 scriptRemoteCall.isAsyncExecution());
210 }
211
212 return assignOwnerRRef(
213 scriptRemoteCall.retRRefId(), scriptRemoteCall.retForkId(), future);
214 }
215
processPythonRemoteCall(RpcCommandBase & rpc,const std::vector<c10::Stream> & streams) const216 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRemoteCall(
217 RpcCommandBase& rpc,
218 const std::vector<c10::Stream>& streams) const {
219 auto& uprc = static_cast<UnpickledPythonRemoteCall&>(rpc);
220 auto future =
221 runPythonFunction(uprc.pythonUdf(), streams, uprc.isAsyncExecution());
222
223 return assignOwnerRRef(uprc.rrefId(), uprc.forkId(), future);
224 }
225
processPythonRRefFetchCall(RpcCommandBase & rpc) const226 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRRefFetchCall(
227 RpcCommandBase& rpc) const {
228 auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
229
230 auto future = retrieveOwnerRRef(prf.rrefId());
231
232 return future->then(
233 [](JitFuture& future) {
234 SerializedPyObj result = serializePyObject(future.value());
235 return withStorages(
236 PythonRRefFetchRet(std::move(result).toIValues()).toMessage());
237 },
238 c10::getCustomClassType<c10::intrusive_ptr<Message>>());
239 }
240
handleRRefDelete(c10::intrusive_ptr<RRef> & rref) const241 void RequestCallbackImpl::handleRRefDelete(
242 c10::intrusive_ptr<RRef>& rref) const {
243 if (rref && rref->isPyObj()) {
244 py::gil_scoped_acquire acquire;
245 rref.reset();
246 }
247 }
248
processRpcWithErrors(RpcCommandBase & rpc,const MessageType & messageType,const std::vector<c10::Stream> & streams) const249 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processRpcWithErrors(
250 RpcCommandBase& rpc,
251 const MessageType& messageType,
252 const std::vector<c10::Stream>& streams) const {
253 try {
254 return processRpc(rpc, messageType, streams);
255 } catch (py::error_already_set& e) {
256 // Pass a dummy message ID since it will be overwritten anyways.
257 auto future = asFuture(handleError(e, messageType, -1));
258 // There are request callback impls in Python, where Python
259 // exceptions could be thrown. For releasing Python exception
260 // py::objects, GIL must be held.
261 py::gil_scoped_acquire acquire;
262 e.restore(); // Release ownership on py::objects and also restore
263 // Python Error Indicator.
264 PyErr_Clear(); // Clear the Python Error Indicator as we has
265 // recorded the exception in the response message.
266 return future;
267 } catch (std::exception& e) {
268 // Pass a dummy message ID since it will be overwritten anyways.
269 return asFuture(handleError(e, messageType, -1));
270 }
271 }
272
cudaAvailable() const273 bool RequestCallbackImpl::cudaAvailable() const {
274 #ifdef USE_CUDA
275 return true;
276 #else
277 return false;
278 #endif
279 }
280
processRRefBackward(RpcCommandBase & rpc) const281 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processRRefBackward(
282 RpcCommandBase& rpc) const {
283 auto& rrefBackwardReq = static_cast<RRefBackwardReq&>(rpc);
284
285 auto future = retrieveOwnerRRef(rrefBackwardReq.getRRefId());
286
287 return future->then(
288 [autogradContextId = rrefBackwardReq.getAutogradContextId(),
289 retainGraph = rrefBackwardReq.retainGraph()](JitFuture& future) {
290 // Run backward (TODO: make this async?).
291 PyRRef::backwardOwnerRRef(
292 autogradContextId, retainGraph, future.value());
293
294 return withStorages(RRefBackwardResp().toMessage());
295 },
296 c10::getCustomClassType<c10::intrusive_ptr<Message>>());
297 }
298
// Looks up the JIT function `name` in the compilation unit held by
// PythonRpcHandler and starts it with runAsync(stack) while a
// MultiStreamGuard over `streams` is active.
//
// A std::exception thrown while looking up or starting the function is
// returned as an errored future. When `isAsyncExecution` is set, the
// function's result must itself be of Future type: that inner future is
// unwrapped via thenAsync, and an errored future is returned if the element
// type is not a Future.
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runJitFunction(
    const c10::QualifiedName& name,
    std::vector<at::IValue>& stack,
    const std::vector<c10::Stream>& streams,
    bool isAsyncExecution) const {
  c10::MultiStreamGuard streamGuard(streams);
  c10::intrusive_ptr<JitFuture> jitResult;
  try {
    // runAsync() begins on the calling thread but may hand back a future that
    // is not yet completed (for non-async code it will typically already be
    // completed). In the async case our callback will typically be invoked by
    // the continuation on an at::launch() thread.
    const auto compilationUnit =
        PythonRpcHandler::getInstance().jitCompilationUnit();
    jitResult = compilationUnit->get_function(name).runAsync(stack);
  } catch (const std::exception&) {
    return asFuture(std::current_exception());
  }

  if (!isAsyncExecution) {
    return jitResult;
  }

  const at::TypePtr resultType = jitResult->elementType();
  if (resultType->kind() != at::FutureType::Kind) {
    const auto message = c10::str(
        "Async functions must return an IValue of Future type, but got ",
        resultType->str());
    return asFuture(std::make_exception_ptr(std::runtime_error(message)));
  }

  // Flatten Future<Future<T>> into Future<T>.
  return jitResult->thenAsync(
      [](JitFuture& outer) { return outer.value().toFuture(); },
      resultType->cast<at::FutureType>()->getElementType());
}
333
334 } // namespace torch::distributed::rpc
335