xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/python_rpc_handler.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/python_rpc_handler.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3 #include <torch/csrc/jit/python/pybind_utils.h>
4 #include <torch/csrc/utils/python_compat.h>
5 
6 namespace torch::distributed::rpc {
7 
8 namespace {
9 
// Dotted name of the Python module that provides the RPC helper functions
// imported in init(); also compared against in isRemoteException().
constexpr const char* kInternalModule = "torch.distributed.rpc.internal";
11 
// Acquires the GIL for the rest of the enclosing scope and, when the current
// RpcAgent has GIL profiling enabled, measures how long the acquisition took
// and reports it through RpcAgent::addGilWaitTime(). The average GIL
// acquisition time will be recorded in RpcAgent's getMetrics().
//
// The expansion declares `startTime`, `shouldProfileGIL` and `ag` in the
// enclosing scope, so it can appear at most once per scope.
#define PROFILE_GIL_SCOPED_ACQUIRE                                        \
  std::chrono::high_resolution_clock::time_point startTime;               \
  auto shouldProfileGIL =                                                 \
      RpcAgent::getCurrentRpcAgent()->isGILProfilingEnabled();            \
  if (shouldProfileGIL) {                                                 \
    startTime = std::chrono::high_resolution_clock::now();                \
  }                                                                       \
  pybind11::gil_scoped_acquire ag;                                        \
  if (shouldProfileGIL) {                                                 \
    const auto gilWait =                                                  \
        std::chrono::duration_cast<std::chrono::microseconds>(            \
            std::chrono::high_resolution_clock::now() - startTime);       \
    RpcAgent::getCurrentRpcAgent()->addGilWaitTime(gilWait);              \
  } // NOLINT
27 
28 // PythonTypeResolver that inherits from Script::Resolver to
29 // support resolving types together with ScriptTypeParser.
30 struct PythonTypeResolver : public jit::Resolver {
resolveValuetorch::distributed::rpc::__anon7a1394420111::PythonTypeResolver31   std::shared_ptr<jit::SugaredValue> resolveValue(
32       const std::string& /* unused */,
33       torch::jit::GraphFunction& /* unused */,
34       const jit::SourceRange& /* unused */) override {
35     TORCH_INTERNAL_ASSERT(
36         false, "RPC Type resolver does not need to resolve value");
37   }
38 
resolveTypetorch::distributed::rpc::__anon7a1394420111::PythonTypeResolver39   TypePtr resolveType(
40       const std::string& name,
41       const jit::SourceRange& /* unused */) override {
42     if (name == "PyObject") {
43       return PyObjectType::get();
44     }
45     return PythonRpcHandler::getInstance().jitCompilationUnit()->get_type(name);
46   }
47 };
48 
// Fetches attribute `name` from `module` and returns it. Fails a TORCH_CHECK
// if the attribute is not a py::function.
py::object getFunction(const py::object& module, const char* name) {
  py::object fn = py::getattr(module, name);
  TORCH_CHECK(
      py::isinstance<py::function>(fn), "attribute ", name, " is not a function");
  return fn;
}
58 
cleanupPyObj(py::object & obj)59 void cleanupPyObj(py::object& obj) {
60   obj.dec_ref();
61   // explicitly setting PyObject* to nullptr to prevent py::object's dtor to
62   // decref on the PyObject again.
63   // See Note [Destructing py::object] in python_ivalue.h
64   obj.ptr() = nullptr;
65 }
66 
67 } // namespace
68 
init()69 void PythonRpcHandler::init() {
70   std::lock_guard<std::mutex> guard(init_lock_);
71   if (!initialized_) {
72     PROFILE_GIL_SCOPED_ACQUIRE;
73     py::object rpcInternal = py::module::import(kInternalModule);
74     py::object rpcApi = py::module::import("torch.distributed.rpc.api");
75     py::object rrefProxy =
76         py::module::import("torch.distributed.rpc.rref_proxy");
77 
78     pyRunFunction_ = getFunction(rpcInternal, "_run_function");
79     pySerialize_ = getFunction(rpcInternal, "serialize");
80     pyDeserialize_ = getFunction(rpcInternal, "deserialize");
81     pyHandleException_ = getFunction(rpcInternal, "_handle_exception");
82 
83     rrefTypeFunctions_.onOwner_ = getFunction(rpcApi, "_rref_typeof_on_owner");
84     rrefTypeFunctions_.onUser_ = getFunction(rpcApi, "_rref_typeof_on_user");
85 
86     rrefProxyFunctions_.rpcSync_ = getFunction(rpcApi, "rpc_sync");
87     rrefProxyFunctions_.rpcAsync_ = getFunction(rpcApi, "rpc_async");
88     rrefProxyFunctions_.remote_ = getFunction(rpcApi, "remote");
89     rrefProxyFunctions_.rrefProxyCtor_ = getFunction(rrefProxy, "RRefProxy");
90 
91     jitCompilationUnit_ = torch::jit::get_python_cu();
92     typeParser_ = std::make_shared<jit::ScriptTypeParser>(
93         std::make_shared<PythonTypeResolver>());
94     initialized_ = true;
95   }
96 }
97 
PythonRpcHandler()98 PythonRpcHandler::PythonRpcHandler() : initialized_(false) {}
99 
cleanup()100 void PythonRpcHandler::cleanup() {
101   std::lock_guard<std::mutex> guard(init_lock_);
102   PROFILE_GIL_SCOPED_ACQUIRE;
103   cleanupPyObj(pyRunFunction_);
104   cleanupPyObj(pySerialize_);
105   cleanupPyObj(pyDeserialize_);
106   cleanupPyObj(pyHandleException_);
107 
108   cleanupPyObj(rrefProxyFunctions_.rpcSync_);
109   cleanupPyObj(rrefProxyFunctions_.rpcAsync_);
110   cleanupPyObj(rrefProxyFunctions_.remote_);
111   cleanupPyObj(rrefProxyFunctions_.rrefProxyCtor_);
112 
113   jitCompilationUnit_ = nullptr;
114   typeParser_ = nullptr;
115   initialized_ = false;
116 }
117 
getInstance()118 PythonRpcHandler& PythonRpcHandler::getInstance() {
119   // A thread could hold GIL when calling PythonRpcHandler::getInstance(),
120   // meantime another thread could have been doing static data
121   // initialization by calling `new PythonRpcHandler()`, inside of which GIL is
122   // also required. Static data initialization is thread-safe, so the thread
123   // holding the GIL will wait for the other thread to finish static data
124   // initializating before going forward. Because the initialization can't
125   // proceed without GIL, there is a deadlock. We ask the calling thread to
126   // release GIL to avoid this situation.
127   TORCH_INTERNAL_ASSERT(!PyGILState_Check());
128   // Leaky singleton to avoid module destructor race.
129   static PythonRpcHandler* handler = new PythonRpcHandler();
130   handler->init();
131   return *handler;
132 }
133 
134 std::shared_ptr<torch::jit::CompilationUnit> PythonRpcHandler::
jitCompilationUnit()135     jitCompilationUnit() {
136   return jitCompilationUnit_;
137 }
138 
// Acquires the GIL and invokes the cached Python `_run_function` helper on
// `pythonUdf`, returning whatever that helper returns.
py::object PythonRpcHandler::runPythonUdf(const py::object& pythonUdf) {
  PROFILE_GIL_SCOPED_ACQUIRE;
  // Throw a descriptive error message if pyRunFunction_ is already cleaned up.
  // cleanup() leaves the object holding a null pointer (not Python None), so
  // the null case must be checked explicitly; otherwise the call below would
  // be made on a null callable.
  TORCH_INTERNAL_ASSERT(
      pyRunFunction_.ptr() != nullptr && !pyRunFunction_.is_none(),
      "Cannot run python UDF since pyRunFunction_ is None. Check if python RPC "
      "handler is already cleaned up.");
  return pyRunFunction_(pythonUdf);
}
148 
serialize(const py::object & obj)149 SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
150   PROFILE_GIL_SCOPED_ACQUIRE;
151   py::tuple t = pySerialize_(obj);
152   return SerializedPyObj(
153       t[0].cast<std::string>(), t[1].cast<std::vector<torch::Tensor>>());
154 }
155 
deserialize(const SerializedPyObj & serializedObj)156 py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
157   PROFILE_GIL_SCOPED_ACQUIRE;
158   // NB: pyDeserialize_ can return an AttributeError if the deserialize() Python
159   // function fails. Functions consuming the result needs to handle such error
160   // properly.
161   return pyDeserialize_(
162       py::bytes(serializedObj.payload_), serializedObj.tensors_);
163 }
164 
handleException(const py::object & obj)165 void PythonRpcHandler::handleException(const py::object& obj) {
166   PROFILE_GIL_SCOPED_ACQUIRE;
167   pyHandleException_(obj);
168 }
169 
handleExceptionGILHeld(const py::object & obj)170 void PythonRpcHandler::handleExceptionGILHeld(const py::object& obj) {
171   TORCH_CHECK(PyGILState_Check(), "GIL should be held");
172   pyHandleException_(obj);
173 }
174 
isRemoteException(const py::object & obj)175 bool PythonRpcHandler::isRemoteException(const py::object& obj) {
176   PROFILE_GIL_SCOPED_ACQUIRE;
177   auto type = obj.get_type();
178   auto moduleName = type.attr("__module__").cast<std::string>();
179   auto qualName = type.attr("__qualname__").cast<std::string>();
180   return moduleName == kInternalModule && qualName == "RemoteException";
181 }
182 
parseTypeFromStr(const std::string & type_str)183 TypePtr PythonRpcHandler::parseTypeFromStr(const std::string& type_str) {
184   return typeParser_->parseType(type_str);
185 }
186 
187 const PythonRpcHandler::RRefProxyFunctions& PythonRpcHandler::
getRRefProxyFunctions() const188     getRRefProxyFunctions() const {
189   return rrefProxyFunctions_;
190 }
191 
192 const PythonRpcHandler::RRefTypeFunctions& PythonRpcHandler::
getRRefTypeFunctions() const193     getRRefTypeFunctions() const {
194   return rrefTypeFunctions_;
195 }
196 
197 } // namespace torch::distributed::rpc
198