xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/python_comm_hook.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/python_comm_hook.h>
2 
3 #include <ATen/core/functional.h>
4 #include <torch/csrc/distributed/c10d/reducer.hpp>
5 #include <torch/csrc/jit/python/pybind_utils.h>
6 #include <torch/csrc/utils/tensor_flatten.h>
7 
8 namespace c10d {
9 
~PythonCommHook()10 PythonCommHook::~PythonCommHook() {
11   py::gil_scoped_acquire ag;
12   state_.dec_ref();
13   hook_.dec_ref();
14   // Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor
15   // to decref on the PyObject again.
16   // See Note [Destructing py::object] in python_ivalue.h
17   state_.ptr() = nullptr;
18   hook_.ptr() = nullptr;
19 }
20 
runHook(GradBucket & bucket)21 c10::intrusive_ptr<c10::ivalue::Future> PythonCommHook::runHook(
22     GradBucket& bucket) {
23   py::gil_scoped_acquire acquire;
24 
25   py::object py_fut = hook_(state_, bucket);
26 
27   try {
28     return py_fut.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>()->fut;
29   } catch (const py::cast_error& e) {
30     auto type = py_fut.get_type();
31     auto errMsg = c10::str(
32         e.what(),
33         ". DDP communication hook's callback must return a "
34         "torch.futures.Future object, but got ",
35         type.attr("__module__").cast<std::string>(),
36         ".",
37         type.attr("__qualname__").cast<std::string>());
38     TORCH_CHECK(false, errMsg);
39   }
40 }
41 
parseHookResult(const c10::IValue & result)42 at::Tensor PythonCommHook::parseHookResult(const c10::IValue& result) {
43   TORCH_INTERNAL_ASSERT(
44       result.isPyObject(), "expected the hook result is a PyObject");
45 
46   py::gil_scoped_acquire ag;
47   py::object obj = torch::jit::toPyObject(result);
48   auto value = torch::jit::toIValue(obj, c10::TensorType::get());
49   return value.toTensor();
50 }
51 
52 } // namespace c10d
53