1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/message.h> 4 #include <torch/csrc/distributed/rpc/types.h> 5 #include <torch/csrc/jit/frontend/script_type_parser.h> 6 #include <torch/csrc/utils/pybind.h> 7 8 namespace torch::distributed::rpc { 9 10 // Singleton class provides interface to execute python UDF remote call 11 // and deserialize the returned results by running python function 12 // in internal_rpc_utilities. 13 // The singleton object is constructed at first when RPC agent is 14 // constructed, where the python function in 15 // torch/distributed/internal_rpc_utils.py are imported only once. 16 class PYBIND11_EXPORT PythonRpcHandler { 17 public: 18 struct RRefProxyFunctions { 19 py::object rrefProxyCtor_; 20 py::object rpcSync_; 21 py::object rpcAsync_; 22 py::object remote_; 23 }; 24 25 struct RRefTypeFunctions { 26 py::object onOwner_; 27 py::object onUser_; 28 }; 29 30 static PythonRpcHandler& getInstance(); 31 32 // Run a pickled Python UDF and return the result py::object 33 py::object runPythonUdf(const py::object& pythonUdf); 34 35 // Serialized a py::object into a string 36 SerializedPyObj serialize(const py::object& obj); 37 38 // Deserialize a string into a py::object 39 py::object deserialize(const SerializedPyObj& serializedObj); 40 41 // Check if obj is RemoteException, then throw it 42 void handleException(const py::object& obj); 43 // Alternative if the caller is already holding the GIL. 44 void handleExceptionGILHeld(const py::object& obj); 45 // Check if obj is an RemoteException instance. 46 bool isRemoteException(const py::object& obj); 47 48 // Explicitly clean up py::objects to avoid segment faults when 49 // py::objects with CPython are cleaned up later at program exit 50 // See similar issues reported https://github.com/pybind/pybind11/issues/1598 51 // and https://github.com/pybind/pybind11/issues/1493 52 // Our local tests also caught this segment faults if py::objects are cleaned 53 // up at program exit. The explanation is: CPython cleans up most critical 54 // utilities before cleaning up PythonRpcHandler singleton, so when 55 // PythonRpcHandler singleton cleans up py::objects and call dec_ref(), it 56 // will crash. 57 // The solution is to clean up py::objects earlier when Rpc agent join(). 58 // Be note that py::objects can not be cleaned up when Rpc agent is destroyed 59 // as well, as Rpc agent is global variable and it will have same issue as 60 // PythonRpcHandler. 61 void cleanup(); 62 63 std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit(); 64 65 // Parse the string to recover the jit_type, this is used for RRef python 66 // pickling/unpickling type recovery. The type string inference rule is as 67 // follows: 68 // 1. first try to parse if this is primitive types. 69 // i.e. TensorType, IntType, PyObjectType, etc. 70 // 2. if not primitive type, we query the python_cu to see if it is a 71 // class type or interface type registered in python 72 // We use a ScriptTypeParser instance with custom PythonTypeResolver 73 // to resolve types according to the above rules. 74 TypePtr parseTypeFromStr(const std::string& typeStr); 75 76 // Return a set of Python functions for RRef helpers. 77 const RRefProxyFunctions& getRRefProxyFunctions() const; 78 79 // Return a set of Python functions to retrieve the type of the object 80 // referenced by a given RRef. 81 const RRefTypeFunctions& getRRefTypeFunctions() const; 82 83 PythonRpcHandler(const PythonRpcHandler&) = delete; 84 PythonRpcHandler& operator=(const PythonRpcHandler&) = delete; 85 PythonRpcHandler(PythonRpcHandler&&) = delete; 86 PythonRpcHandler& operator=(PythonRpcHandler&&) = delete; 87 88 private: 89 void init(); 90 PythonRpcHandler(); 91 ~PythonRpcHandler() = default; 92 93 // Ref to `torch.distributed.rpc.internal._run_function`. 94 py::object pyRunFunction_; 95 96 // Ref to `torch.distributed.rpc.internal.serialize`. 97 py::object pySerialize_; 98 99 // Ref to `torch.distributed.rpc.internal.deserialize`. 100 py::object pyDeserialize_; 101 102 // Ref to 'torch.distributed.rpc.internal._handle_exception' 103 py::object pyHandleException_; 104 105 // Python functions for RRef proxy 106 RRefProxyFunctions rrefProxyFunctions_; 107 108 // Ref to 'torch.distributed.rpc.api._rref_typeof_on_' 109 RRefTypeFunctions rrefTypeFunctions_; 110 111 // Shared ptr to python compilation unit in jit, it is constructed in python 112 // side (see _python_cu = torch._C.CompilationUnit() in jit/__init__.py) 113 // and imported in C++ (see get_python_cu() in 114 // csrc/jit/python/pybind_utils.h). We import the compilation unit here only 115 // once for less cost and thread safety. 116 std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit_; 117 118 // jit type parser to parse type_str back to TypePtr for RRef type 119 // recovery when pickling and unpickling RRef 120 std::shared_ptr<jit::ScriptTypeParser> typeParser_; 121 122 // Indicates whether or not we have properly initialized the handler. 123 bool initialized_; 124 125 // Lock to protect initialization. 126 std::mutex init_lock_; 127 }; 128 129 } // namespace torch::distributed::rpc 130