1 #include <torch/csrc/distributed/c10d/python_comm_hook.h> 2 3 #include <ATen/core/functional.h> 4 #include <torch/csrc/distributed/c10d/reducer.hpp> 5 #include <torch/csrc/jit/python/pybind_utils.h> 6 #include <torch/csrc/utils/tensor_flatten.h> 7 8 namespace c10d { 9 ~PythonCommHook()10PythonCommHook::~PythonCommHook() { 11 py::gil_scoped_acquire ag; 12 state_.dec_ref(); 13 hook_.dec_ref(); 14 // Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor 15 // to decref on the PyObject again. 16 // See Note [Destructing py::object] in python_ivalue.h 17 state_.ptr() = nullptr; 18 hook_.ptr() = nullptr; 19 } 20 runHook(GradBucket & bucket)21c10::intrusive_ptr<c10::ivalue::Future> PythonCommHook::runHook( 22 GradBucket& bucket) { 23 py::gil_scoped_acquire acquire; 24 25 py::object py_fut = hook_(state_, bucket); 26 27 try { 28 return py_fut.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>()->fut; 29 } catch (const py::cast_error& e) { 30 auto type = py_fut.get_type(); 31 auto errMsg = c10::str( 32 e.what(), 33 ". DDP communication hook's callback must return a " 34 "torch.futures.Future object, but got ", 35 type.attr("__module__").cast<std::string>(), 36 ".", 37 type.attr("__qualname__").cast<std::string>()); 38 TORCH_CHECK(false, errMsg); 39 } 40 } 41 parseHookResult(const c10::IValue & result)42at::Tensor PythonCommHook::parseHookResult(const c10::IValue& result) { 43 TORCH_INTERNAL_ASSERT( 44 result.isPyObject(), "expected the hook result is a PyObject"); 45 46 py::gil_scoped_acquire ag; 47 py::object obj = torch::jit::toPyObject(result); 48 auto value = torch::jit::toIValue(obj, c10::TensorType::get()); 49 return value.toTensor(); 50 } 51 52 } // namespace c10d 53