1 #include <torch/csrc/distributed/rpc/python_rpc_handler.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3 #include <torch/csrc/jit/python/pybind_utils.h>
4 #include <torch/csrc/utils/python_compat.h>
5
6 namespace torch::distributed::rpc {
7
8 namespace {
9
// Fully qualified name of the Python module that provides the RPC helper
// functions imported in init() and the RemoteException type matched in
// isRemoteException().
constexpr const char* kInternalModule = "torch.distributed.rpc.internal";
11
// Acquires the GIL for the rest of the enclosing scope. When the current
// RpcAgent reports that GIL profiling is enabled, the time spent waiting for
// the GIL is measured (in microseconds) and handed to
// RpcAgent::addGilWaitTime(); the average GIL acquisition time will be
// recorded in RpcAgent's getMetrics().
//
// NOTE: the macro expands to plain statements, so it declares
// `shouldProfileGIL`, `startTime`, `ag` and (in a nested scope) `dur` in the
// caller's scope. Use it at most once per scope and avoid those names.
#define PROFILE_GIL_SCOPED_ACQUIRE                                 \
  const auto shouldProfileGIL =                                    \
      RpcAgent::getCurrentRpcAgent()->isGILProfilingEnabled();     \
  std::chrono::high_resolution_clock::time_point startTime;        \
  if (shouldProfileGIL) {                                          \
    startTime = std::chrono::high_resolution_clock::now();         \
  }                                                                \
  pybind11::gil_scoped_acquire ag;                                 \
  if (shouldProfileGIL) {                                          \
    auto dur = std::chrono::duration_cast<std::chrono::microseconds>( \
        std::chrono::high_resolution_clock::now() - startTime);    \
    RpcAgent::getCurrentRpcAgent()->addGilWaitTime(dur);           \
  } // NOLINT
27
28 // PythonTypeResolver that inherits from Script::Resolver to
29 // support resolving types together with ScriptTypeParser.
30 struct PythonTypeResolver : public jit::Resolver {
resolveValuetorch::distributed::rpc::__anon7a1394420111::PythonTypeResolver31 std::shared_ptr<jit::SugaredValue> resolveValue(
32 const std::string& /* unused */,
33 torch::jit::GraphFunction& /* unused */,
34 const jit::SourceRange& /* unused */) override {
35 TORCH_INTERNAL_ASSERT(
36 false, "RPC Type resolver does not need to resolve value");
37 }
38
resolveTypetorch::distributed::rpc::__anon7a1394420111::PythonTypeResolver39 TypePtr resolveType(
40 const std::string& name,
41 const jit::SourceRange& /* unused */) override {
42 if (name == "PyObject") {
43 return PyObjectType::get();
44 }
45 return PythonRpcHandler::getInstance().jitCompilationUnit()->get_type(name);
46 }
47 };
48
// Looks up attribute `name` on the Python object `module` and returns it.
// Fails a TORCH_CHECK if the attribute is not a Python function.
py::object getFunction(const py::object& module, const char* name) {
  py::object candidate = module.attr(name);
  const bool isPyFunction = py::isinstance<py::function>(candidate);
  TORCH_CHECK(isPyFunction, "attribute ", name, " is not a function");
  return candidate;
}
58
cleanupPyObj(py::object & obj)59 void cleanupPyObj(py::object& obj) {
60 obj.dec_ref();
61 // explicitly setting PyObject* to nullptr to prevent py::object's dtor to
62 // decref on the PyObject again.
63 // See Note [Destructing py::object] in python_ivalue.h
64 obj.ptr() = nullptr;
65 }
66
67 } // namespace
68
init()69 void PythonRpcHandler::init() {
70 std::lock_guard<std::mutex> guard(init_lock_);
71 if (!initialized_) {
72 PROFILE_GIL_SCOPED_ACQUIRE;
73 py::object rpcInternal = py::module::import(kInternalModule);
74 py::object rpcApi = py::module::import("torch.distributed.rpc.api");
75 py::object rrefProxy =
76 py::module::import("torch.distributed.rpc.rref_proxy");
77
78 pyRunFunction_ = getFunction(rpcInternal, "_run_function");
79 pySerialize_ = getFunction(rpcInternal, "serialize");
80 pyDeserialize_ = getFunction(rpcInternal, "deserialize");
81 pyHandleException_ = getFunction(rpcInternal, "_handle_exception");
82
83 rrefTypeFunctions_.onOwner_ = getFunction(rpcApi, "_rref_typeof_on_owner");
84 rrefTypeFunctions_.onUser_ = getFunction(rpcApi, "_rref_typeof_on_user");
85
86 rrefProxyFunctions_.rpcSync_ = getFunction(rpcApi, "rpc_sync");
87 rrefProxyFunctions_.rpcAsync_ = getFunction(rpcApi, "rpc_async");
88 rrefProxyFunctions_.remote_ = getFunction(rpcApi, "remote");
89 rrefProxyFunctions_.rrefProxyCtor_ = getFunction(rrefProxy, "RRefProxy");
90
91 jitCompilationUnit_ = torch::jit::get_python_cu();
92 typeParser_ = std::make_shared<jit::ScriptTypeParser>(
93 std::make_shared<PythonTypeResolver>());
94 initialized_ = true;
95 }
96 }
97
PythonRpcHandler()98 PythonRpcHandler::PythonRpcHandler() : initialized_(false) {}
99
cleanup()100 void PythonRpcHandler::cleanup() {
101 std::lock_guard<std::mutex> guard(init_lock_);
102 PROFILE_GIL_SCOPED_ACQUIRE;
103 cleanupPyObj(pyRunFunction_);
104 cleanupPyObj(pySerialize_);
105 cleanupPyObj(pyDeserialize_);
106 cleanupPyObj(pyHandleException_);
107
108 cleanupPyObj(rrefProxyFunctions_.rpcSync_);
109 cleanupPyObj(rrefProxyFunctions_.rpcAsync_);
110 cleanupPyObj(rrefProxyFunctions_.remote_);
111 cleanupPyObj(rrefProxyFunctions_.rrefProxyCtor_);
112
113 jitCompilationUnit_ = nullptr;
114 typeParser_ = nullptr;
115 initialized_ = false;
116 }
117
getInstance()118 PythonRpcHandler& PythonRpcHandler::getInstance() {
119 // A thread could hold GIL when calling PythonRpcHandler::getInstance(),
120 // meantime another thread could have been doing static data
121 // initialization by calling `new PythonRpcHandler()`, inside of which GIL is
122 // also required. Static data initialization is thread-safe, so the thread
123 // holding the GIL will wait for the other thread to finish static data
124 // initializating before going forward. Because the initialization can't
125 // proceed without GIL, there is a deadlock. We ask the calling thread to
126 // release GIL to avoid this situation.
127 TORCH_INTERNAL_ASSERT(!PyGILState_Check());
128 // Leaky singleton to avoid module destructor race.
129 static PythonRpcHandler* handler = new PythonRpcHandler();
130 handler->init();
131 return *handler;
132 }
133
134 std::shared_ptr<torch::jit::CompilationUnit> PythonRpcHandler::
jitCompilationUnit()135 jitCompilationUnit() {
136 return jitCompilationUnit_;
137 }
138
// Runs a Python user-defined function by passing `pythonUdf` to the cached
// `_run_function` Python callable and returns its result. Acquires the GIL.
// Trips an internal assert if the cached callable is no longer available.
py::object PythonRpcHandler::runPythonUdf(const py::object& pythonUdf) {
  PROFILE_GIL_SCOPED_ACQUIRE;
  // Throw a descriptive error message if pyRunFunction_ is already cleaned up.
  // cleanup() resets the handle to a null pointer (not to Python's None), so
  // the null case has to be checked explicitly; is_none() alone would let a
  // cleaned-up handle through and the call below would dereference nullptr.
  TORCH_INTERNAL_ASSERT(
      pyRunFunction_.ptr() != nullptr && !pyRunFunction_.is_none(),
      "Cannot run python UDF since pyRunFunction_ is None. Check if python RPC "
      "handler is already cleaned up.");
  return pyRunFunction_(pythonUdf);
}
148
serialize(const py::object & obj)149 SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
150 PROFILE_GIL_SCOPED_ACQUIRE;
151 py::tuple t = pySerialize_(obj);
152 return SerializedPyObj(
153 t[0].cast<std::string>(), t[1].cast<std::vector<torch::Tensor>>());
154 }
155
deserialize(const SerializedPyObj & serializedObj)156 py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
157 PROFILE_GIL_SCOPED_ACQUIRE;
158 // NB: pyDeserialize_ can return an AttributeError if the deserialize() Python
159 // function fails. Functions consuming the result needs to handle such error
160 // properly.
161 return pyDeserialize_(
162 py::bytes(serializedObj.payload_), serializedObj.tensors_);
163 }
164
handleException(const py::object & obj)165 void PythonRpcHandler::handleException(const py::object& obj) {
166 PROFILE_GIL_SCOPED_ACQUIRE;
167 pyHandleException_(obj);
168 }
169
handleExceptionGILHeld(const py::object & obj)170 void PythonRpcHandler::handleExceptionGILHeld(const py::object& obj) {
171 TORCH_CHECK(PyGILState_Check(), "GIL should be held");
172 pyHandleException_(obj);
173 }
174
isRemoteException(const py::object & obj)175 bool PythonRpcHandler::isRemoteException(const py::object& obj) {
176 PROFILE_GIL_SCOPED_ACQUIRE;
177 auto type = obj.get_type();
178 auto moduleName = type.attr("__module__").cast<std::string>();
179 auto qualName = type.attr("__qualname__").cast<std::string>();
180 return moduleName == kInternalModule && qualName == "RemoteException";
181 }
182
parseTypeFromStr(const std::string & type_str)183 TypePtr PythonRpcHandler::parseTypeFromStr(const std::string& type_str) {
184 return typeParser_->parseType(type_str);
185 }
186
187 const PythonRpcHandler::RRefProxyFunctions& PythonRpcHandler::
getRRefProxyFunctions() const188 getRRefProxyFunctions() const {
189 return rrefProxyFunctions_;
190 }
191
192 const PythonRpcHandler::RRefTypeFunctions& PythonRpcHandler::
getRRefTypeFunctions() const193 getRRefTypeFunctions() const {
194 return rrefTypeFunctions_;
195 }
196
197 } // namespace torch::distributed::rpc
198