xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/python_rpc_handler.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <memory>
#include <mutex>
#include <string>

#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/utils/pybind.h>
7 
8 namespace torch::distributed::rpc {
9 
10 // Singleton class provides interface to execute python UDF remote call
11 // and deserialize the returned results by running python function
12 // in internal_rpc_utilities.
13 // The singleton object is constructed at first when RPC agent is
14 // constructed, where the python function in
15 // torch/distributed/internal_rpc_utils.py are imported only once.
16 class PYBIND11_EXPORT PythonRpcHandler {
17  public:
18   struct RRefProxyFunctions {
19     py::object rrefProxyCtor_;
20     py::object rpcSync_;
21     py::object rpcAsync_;
22     py::object remote_;
23   };
24 
25   struct RRefTypeFunctions {
26     py::object onOwner_;
27     py::object onUser_;
28   };
29 
30   static PythonRpcHandler& getInstance();
31 
32   // Run a pickled Python UDF and return the result py::object
33   py::object runPythonUdf(const py::object& pythonUdf);
34 
35   // Serialized a py::object into a string
36   SerializedPyObj serialize(const py::object& obj);
37 
38   // Deserialize a string into a py::object
39   py::object deserialize(const SerializedPyObj& serializedObj);
40 
41   // Check if obj is RemoteException, then throw it
42   void handleException(const py::object& obj);
43   // Alternative if the caller is already holding the GIL.
44   void handleExceptionGILHeld(const py::object& obj);
45   // Check if obj is an RemoteException instance.
46   bool isRemoteException(const py::object& obj);
47 
48   // Explicitly clean up py::objects to avoid segment faults when
49   // py::objects with CPython are cleaned up later at program exit
50   // See similar issues reported https://github.com/pybind/pybind11/issues/1598
51   // and https://github.com/pybind/pybind11/issues/1493
52   // Our local tests also caught this segment faults if py::objects are cleaned
53   // up at program exit. The explanation is: CPython cleans up most critical
54   // utilities before cleaning up PythonRpcHandler singleton, so when
55   // PythonRpcHandler singleton cleans up py::objects and call dec_ref(), it
56   // will crash.
57   // The solution is to clean up py::objects earlier when Rpc agent join().
58   // Be note that py::objects can not be cleaned up when Rpc agent is destroyed
59   // as well, as Rpc agent is global variable and it will have same issue as
60   // PythonRpcHandler.
61   void cleanup();
62 
63   std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit();
64 
65   // Parse the string to recover the jit_type, this is used for RRef python
66   // pickling/unpickling type recovery. The type string inference rule is as
67   // follows:
68   // 1. first try to parse if this is primitive types.
69   //    i.e. TensorType, IntType, PyObjectType, etc.
70   // 2. if not primitive type, we query the python_cu to see if it is a
71   //    class type or interface type registered in python
72   // We use a ScriptTypeParser instance with custom PythonTypeResolver
73   // to resolve types according to the above rules.
74   TypePtr parseTypeFromStr(const std::string& typeStr);
75 
76   // Return a set of Python functions for RRef helpers.
77   const RRefProxyFunctions& getRRefProxyFunctions() const;
78 
79   // Return a set of Python functions to retrieve the type of the object
80   // referenced by a given RRef.
81   const RRefTypeFunctions& getRRefTypeFunctions() const;
82 
83   PythonRpcHandler(const PythonRpcHandler&) = delete;
84   PythonRpcHandler& operator=(const PythonRpcHandler&) = delete;
85   PythonRpcHandler(PythonRpcHandler&&) = delete;
86   PythonRpcHandler& operator=(PythonRpcHandler&&) = delete;
87 
88  private:
89   void init();
90   PythonRpcHandler();
91   ~PythonRpcHandler() = default;
92 
93   // Ref to `torch.distributed.rpc.internal._run_function`.
94   py::object pyRunFunction_;
95 
96   // Ref to `torch.distributed.rpc.internal.serialize`.
97   py::object pySerialize_;
98 
99   // Ref to `torch.distributed.rpc.internal.deserialize`.
100   py::object pyDeserialize_;
101 
102   // Ref to 'torch.distributed.rpc.internal._handle_exception'
103   py::object pyHandleException_;
104 
105   // Python functions for RRef proxy
106   RRefProxyFunctions rrefProxyFunctions_;
107 
108   // Ref to 'torch.distributed.rpc.api._rref_typeof_on_'
109   RRefTypeFunctions rrefTypeFunctions_;
110 
111   // Shared ptr to python compilation unit in jit, it is constructed in python
112   // side (see _python_cu = torch._C.CompilationUnit() in jit/__init__.py)
113   // and imported in C++ (see get_python_cu() in
114   // csrc/jit/python/pybind_utils.h). We import the compilation unit here only
115   // once for less cost and thread safety.
116   std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit_;
117 
118   // jit type parser to parse type_str back to TypePtr for RRef type
119   // recovery when pickling and unpickling RRef
120   std::shared_ptr<jit::ScriptTypeParser> typeParser_;
121 
122   // Indicates whether or not we have properly initialized the handler.
123   bool initialized_;
124 
125   // Lock to protect initialization.
126   std::mutex init_lock_;
127 };
128 
129 } // namespace torch::distributed::rpc
130