xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/py_rref.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/py_rref.h>
2 
3 #include <torch/csrc/autograd/autograd.h>
4 #include <torch/csrc/distributed/autograd/autograd.h>
5 #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
6 #include <torch/csrc/distributed/rpc/python_functions.h>
7 #include <torch/csrc/distributed/rpc/python_rpc_handler.h>
8 #include <torch/csrc/distributed/rpc/rref_context.h>
9 #include <torch/csrc/jit/python/module_python.h>
10 #include <torch/csrc/jit/python/pybind_utils.h>
11 
12 namespace torch::distributed::rpc {
13 
14 /////////////////////  Pickle/Unpickle Helpers ////////////////////////////
15 
16 namespace {
17 
toPyTuple(const RRefForkData & rrefForkData)18 py::tuple toPyTuple(const RRefForkData& rrefForkData) {
19   // add GIL as it is contructing a py::object
20   pybind11::gil_scoped_acquire ag;
21   return py::make_tuple(
22       rrefForkData.ownerId_,
23       rrefForkData.rrefId_.createdOn_,
24       rrefForkData.rrefId_.localId_,
25       rrefForkData.forkId_.createdOn_,
26       rrefForkData.forkId_.localId_,
27       rrefForkData.parent_,
28       rrefForkData.typeStr_);
29 }
30 
fromPyTuple(const py::tuple & pyTuple)31 RRefForkData fromPyTuple(const py::tuple& pyTuple) {
32   // add GIL as it is accessing a py::object
33   pybind11::gil_scoped_acquire ag;
34   TORCH_INTERNAL_ASSERT(
35       pyTuple.size() == RFD_TUPLE_SIZE,
36       "Pickled RRefForkData must contain ",
37       RFD_TUPLE_SIZE,
38       " numbers.");
39   worker_id_t ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
40   // const reference will extend the lifetime of the temporary variable
41   const RRefId& rrefId = RRefId(
42       pyTuple[RREFID_ON_IDX].cast<worker_id_t>(),
43       pyTuple[RREFID_ID_IDX].cast<local_id_t>());
44   const RRefId& forkId = RRefId(
45       pyTuple[FORKID_ON_IDX].cast<worker_id_t>(),
46       pyTuple[FORKID_ID_IDX].cast<local_id_t>());
47 
48   worker_id_t parent = pyTuple[PARENT_IDX].cast<worker_id_t>();
49   const std::string& typeStr = pyTuple[TYPE_IDX].cast<std::string>();
50 
51   return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
52 }
53 
tryInferTypeWithTypeHint(const py::object & value,const py::object & type_hint)54 TypePtr tryInferTypeWithTypeHint(
55     const py::object& value,
56     const py::object& type_hint) {
57   // If the py::object to be contained by the RRef is a ScriptModule, we enforce
58   // users to specify its ModuleInterface type.
59   if (auto module = jit::as_module(value)) {
60     TORCH_CHECK(
61         !type_hint.is_none(),
62         "The RRef being created contains a ScriptModule, "
63         "must provide its ModuleInterface type hint. ");
64     c10::QualifiedName type_qualified_name = c10::QualifiedName(
65         py::cast<std::string>(py::module::import("torch._jit_internal")
66                                   .attr("_qualified_name")(type_hint)));
67     TypePtr type_hint_ptr =
68         jit::get_python_cu()->get_interface(type_qualified_name);
69     std::ostringstream subtype_check_msg;
70     TORCH_CHECK(
71         type_hint_ptr != nullptr &&
72             module.value().type()->isSubtypeOfExt(
73                 *type_hint_ptr, &subtype_check_msg),
74         module.value().type()->repr_str(),
75         " is not a subtype of the type hint: ",
76         type_qualified_name.qualifiedName(),
77         ", did you pass a valid interface type?\n",
78         subtype_check_msg.str());
79     return type_hint_ptr;
80   } else {
81     TORCH_CHECK(
82         type_hint.is_none(),
83         "type_hint should only be specified when the RRef being created contains a ScriptModule.");
84   }
85 
86   // Check if value is an instance of a ScriptClass. If not, skip type inference
87   // because it will try to script the class that value is in instance of, and
88   // this should be avoided.
89   py::bool_ can_compile = py::module::import("torch._jit_internal")
90                               .attr("can_compile_class")(value.get_type());
91 
92   if (py::cast<bool>(can_compile)) {
93     py::object existing_ty = py::module::import("torch.jit._state")
94                                  .attr("_get_script_class")(value.get_type());
95 
96     if (existing_ty.is_none()) {
97       return PyObjectType::get();
98     }
99   }
100 
101   // NB: `jit::tryToInferType(..)` infers types including ScriptClass, but
102   // excluding ScriptModule.
103   jit::InferredType type_inferred = jit::tryToInferType(value);
104   if (type_inferred.success()) {
105     // If we could infer the type from the pyobject, we create
106     // the RRef with the IValue of that type.
107     return type_inferred.type();
108   }
109 
110   // Otherwise it's a pure pyobject, create the RRef
111   // that holds an IValue of an pyobject.
112   return PyObjectType::get();
113 }
114 
115 } // namespace
116 
117 ///////////////////////////  PyRRef  //////////////////////////////////
118 
PyRRef(c10::intrusive_ptr<RRef> rref)119 PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref)
120     : rref_(std::move(rref)), profilingFuture_(std::nullopt) {
121   TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
122   C10_LOG_API_USAGE_ONCE("torch.distributed.rref");
123 }
124 
PyRRef(const py::object & value,const py::object & type_hint)125 PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
126     : PyRRef([&value, &type_hint]() mutable {
127         TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint);
128         auto rref = RRefContext::getInstance().createOwnerRRef(elem_type);
129         // jit::toIValue takes a py::handle as the first argument, and it calls
130         // py::handle.cast<py::object>() to incref of provided value. The
131         // returned ivalue will keep the reference alive.
132         // NB: the first argument const py::object& value must be kept alive
133         // until the following jit::toIValue returns (i.e., incref done). That's
134         // why this ctor can only be called while holding GIL.
135         IValue ivalue = jit::toIValue(value, elem_type);
136         rref->setValue(std::move(ivalue));
137         return rref;
138       }()) {}
139 
~PyRRef()140 PyRRef::~PyRRef() {
141   if (type_.has_value()) {
142     pybind11::gil_scoped_acquire ag;
143     (*type_).dec_ref();
144     // explicitly setting PyObject* to nullptr to prevent py::object's dtor to
145     // decref on the PyObject again.
146     // See Note [Destructing py::object] in python_ivalue.h
147     (*type_).ptr() = nullptr;
148   }
149 }
150 
getFuture() const151 c10::intrusive_ptr<JitFuture> PyRRef::getFuture() const {
152   // Marking hasValue to false, as this Future is only used for signaling
153   // profiler to update profiling result and the profiler does not retrieve
154   // any value from it.
155   return toPyJitFuture(rref_->getOwnerCreationFuture(), false /* hasValue */);
156 }
157 
getProfilingFuture() const158 c10::intrusive_ptr<JitFuture> PyRRef::getProfilingFuture() const {
159   TORCH_INTERNAL_ASSERT(profilingFuture_, "Profiling future has not been set!");
160   return *profilingFuture_;
161 }
162 
setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture)163 void PyRRef::setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture) {
164   profilingFuture_ = std::move(profilingFuture);
165 }
166 
isOwner() const167 bool PyRRef::isOwner() const {
168   return rref_->isOwner();
169 }
170 
confirmedByOwner() const171 bool PyRRef::confirmedByOwner() const {
172   return rref_->confirmedByOwner();
173 }
174 
owner() const175 WorkerInfo PyRRef::owner() const {
176   return RRefContext::getInstance().agent()->getWorkerInfo(rref_->owner());
177 }
178 
ownerName() const179 std::string PyRRef::ownerName() const {
180   return rref_->ownerName();
181 }
182 
toHere(const float timeoutSeconds) const183 py::object PyRRef::toHere(const float timeoutSeconds) const {
184   C10_LOG_API_USAGE_ONCE("torch.distributed.rref.to_here");
185   if (rref_->isOwner()) {
186     return localValue();
187   } else {
188     // toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
189     // a python object
190     IValue value = c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere(
191         timeoutSeconds);
192 
193     if (rref_->isPyObj()) {
194       // python_rpc_handler deserialization will acquires GIL.
195       auto rfr_values = value.toTupleRef().elements().vec();
196       auto& pythonRpcHandler = PythonRpcHandler::getInstance();
197       auto ret = pythonRpcHandler.deserialize(
198           SerializedPyObj::fromIValues(std::move(rfr_values)));
199       pythonRpcHandler.handleException(ret);
200       return ret;
201     } else {
202       // acquiring GIL as torch::jit::toPyObject creates new py::object
203       // without grabbing the GIL.
204       pybind11::gil_scoped_acquire ag;
205       return torch::jit::toPyObject(std::move(value));
206     }
207   }
208 }
209 
localValue() const210 py::object PyRRef::localValue() const {
211   TORCH_CHECK(
212       rref_->isOwner(),
213       "For ",
214       *rref_,
215       ", can't call localValue() on user ",
216       RRefContext::getInstance().agent()->getWorkerInfo(),
217       ". Call it on owner ",
218       owner());
219 
220   py::object res;
221   auto value =
222       c10::static_intrusive_pointer_cast<const OwnerRRef>(rref_)->getValue();
223   auto& rpcHandler = PythonRpcHandler::getInstance();
224   {
225     // acquiring GIL as torch::jit::toPyObject creates new py::object without
226     // grabbing the GIL.
227     pybind11::gil_scoped_acquire ag;
228     res = torch::jit::toPyObject(std::move(value));
229     rpcHandler.handleExceptionGILHeld(res);
230   }
231   return res;
232 }
233 
str() const234 std::string PyRRef::str() const {
235   if (rref_->isOwner()) {
236     return c10::str("OwnerRRef(", rref_->rrefId(), ")");
237   } else {
238     return c10::str(
239         "UserRRef(RRefId = ",
240         rref_->rrefId(),
241         ", ForkId = ",
242         c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId(),
243         ")");
244   }
245 }
246 
createRRefProxy(const RRefProxyType & type,float timeoutSeconds) const247 py::object PyRRef::createRRefProxy(
248     const RRefProxyType& type,
249     float timeoutSeconds) const {
250   auto& pythonRpcHandler = PythonRpcHandler::getInstance();
251   pybind11::gil_scoped_acquire ag;
252   auto& functions = pythonRpcHandler.getRRefProxyFunctions();
253   auto& ctor = functions.rrefProxyCtor_;
254   switch (type) {
255     case RRefProxyType::RPC_SYNC: {
256       return ctor(*this, functions.rpcSync_, timeoutSeconds);
257     }
258     case RRefProxyType::RPC_ASYNC: {
259       return ctor(*this, functions.rpcAsync_, timeoutSeconds);
260     }
261     case RRefProxyType::REMOTE: {
262       return ctor(*this, functions.remote_, timeoutSeconds);
263     }
264     default: {
265       TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type ", type);
266     }
267   }
268 }
269 
getRRefType(float timeout,bool blocking)270 py::object PyRRef::getRRefType(float timeout, bool blocking) {
271   // GIL is not released when calling this function.
272   if (!type_.has_value()) {
273     pybind11::gil_scoped_release release;
274     auto& pythonRpcHandler = PythonRpcHandler::getInstance();
275     auto& typeFuncs = pythonRpcHandler.getRRefTypeFunctions();
276     pybind11::gil_scoped_acquire acquire;
277     type_ = isOwner() ? typeFuncs.onOwner_(*this, blocking)
278                       : typeFuncs.onUser_(*this, timeout, blocking);
279   }
280   // Returns py::object that can be Python type or future.
281   return *type_;
282 }
283 
pickle() const284 py::tuple PyRRef::pickle() const {
285   auto& ctx = RRefContext::getInstance();
286   auto rrefForkData = ctx.prepareChildFork(rref_);
287   return toPyTuple(rrefForkData);
288 }
289 
unpickle(const py::tuple & pyTuple)290 PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
291   auto& ctx = RRefContext::getInstance();
292   auto rrefForkData = fromPyTuple(pyTuple);
293   TypePtr rrefType =
294       PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_);
295   c10::intrusive_ptr<RRef> rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
296   ctx.notifyOwnerAndParentOfFork(
297       rrefForkData.forkId_, rrefForkData.parent_, rref);
298   return PyRRef(std::move(rref));
299 }
300 
toIValue() const301 c10::IValue PyRRef::toIValue() const {
302   // cast to RRefInterface to hold it into IValue
303   auto rrefPtr = c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref_);
304   return IValue(rrefPtr);
305 }
306 
backward(int64_t autogradContextId,bool retainGraph)307 void PyRRef::backward(int64_t autogradContextId, bool retainGraph) {
308   backward(autogradContextId, retainGraph, rref_);
309 }
310 
backwardOwnerRRef(int64_t autogradContextId,bool retainGraph,IValue value)311 void PyRRef::backwardOwnerRRef(
312     int64_t autogradContextId,
313     bool retainGraph,
314     IValue value) {
315   // If we have a PyObj, retrieve the underlying tensor.
316   if (value.isPyObject()) {
317     py::gil_scoped_acquire gil;
318     py::object obj = torch::jit::toPyObject(value);
319     try {
320       value = torch::jit::toIValue(obj, c10::TensorType::get());
321     } catch (py::cast_error& e) {
322       TORCH_CHECK(false, "RRef should contain a tensor for .backward()");
323     }
324   }
325 
326   TORCH_CHECK(value.isTensor(), "RRef should contain a tensor for .backward()");
327   auto root = value.toTensor();
328 
329   if (autogradContextId == -1) {
330     torch::autograd::backward({root});
331   } else {
332     torch::distributed::autograd::backward(
333         autogradContextId, {root}, retainGraph);
334   }
335 }
336 
backward(int64_t autogradContextId,bool retainGraph,const c10::intrusive_ptr<RRef> & rref)337 void PyRRef::backward(
338     int64_t autogradContextId,
339     bool retainGraph,
340     const c10::intrusive_ptr<RRef>& rref) {
341   if (rref->isOwner()) {
342     backwardOwnerRRef(
343         autogradContextId,
344         retainGraph,
345         c10::static_intrusive_pointer_cast<const OwnerRRef>(rref)->getValue());
346   } else {
347     TORCH_CHECK(
348         autogradContextId != -1,
349         "User RRefs require 'dist_autograd_ctx_id' to be specified");
350 
351     autograd::RRefBackwardReq rrefBackwardReq(
352         rref->rrefId(), autogradContextId, retainGraph);
353 
354     // Invoke distributed backward remotely.
355     auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
356     rpcAgent
357         ->send(
358             rpcAgent->getWorkerInfo(rref->owner()),
359             std::move(rrefBackwardReq).toMessage())
360         ->waitAndThrow();
361   }
362 }
363 
364 } // namespace torch::distributed::rpc
365