1 #pragma once 2 3 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> 4 #include <torch/csrc/jit/python/pybind_utils.h> 5 #include <torch/csrc/utils/pybind.h> 6 7 namespace c10d { 8 9 // PyProcessGroup is a pybind11 trampoline class to allow a Python 10 // class to inherit from torch.distributed.ProcessGroup 11 class PyProcessGroup : public ProcessGroup { 12 public: 13 // PyWork is a pybind11 trampoline class to allow a Python 14 // class to inherit from torch.distributed.Work 15 class TORCH_PYTHON_API PyWork : public Work { 16 public: 17 PyWork() = default; 18 wait(std::chrono::milliseconds timeout=kNoTimeout)19 bool wait(std::chrono::milliseconds timeout = kNoTimeout) override { 20 PYBIND11_OVERRIDE( 21 bool, /* Return type */ 22 Work, /* Parent class */ 23 wait, /* Name of function in C++ */ 24 timeout); 25 } 26 getFuture()27 c10::intrusive_ptr<c10::ivalue::Future> getFuture() override { 28 // We cannot use PYBIND11_OVERRIDE because: 29 // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and 30 // 2. The python name is get_future 31 pybind11::gil_scoped_acquire gil; 32 auto override = 33 pybind11::get_override(static_cast<const Work*>(this), "get_future"); 34 35 if (override) { 36 py::object o = override(); 37 auto futWrapper = 38 o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>(); 39 return futWrapper->fut; 40 } 41 42 return Work::getFuture(); 43 } 44 45 // Take a reference of the corresponding py::object. 46 // With functional collectives, ownership of work objects is generally 47 // transferred to C++. For pure C++ work objects, it is sufficient to 48 // transfer the ownership of work object. For user-defined work objects in 49 // Python, it is necessary to keep the corresponding py::object alive in 50 // addition to ensure that the user-defined methods can be executed. ref_py_object()51 void ref_py_object() { 52 py_obj_ = py::cast(this); 53 } 54 55 private: 56 py::object py_obj_; 57 }; 58 59 using ProcessGroup::ProcessGroup; 60 getBackendName() const61 const std::string getBackendName() const override { 62 PYBIND11_OVERRIDE_PURE( 63 std::string, /* Return type */ 64 ProcessGroup, /* Parent class */ 65 getBackendName, /* Name of function in C++ */ 66 ); 67 } 68 allgather(std::vector<std::vector<at::Tensor>> & outputTensors,std::vector<at::Tensor> & inputTensors,const AllgatherOptions & opts=AllgatherOptions ())69 c10::intrusive_ptr<Work> allgather( 70 std::vector<std::vector<at::Tensor>>& outputTensors, 71 std::vector<at::Tensor>& inputTensors, 72 const AllgatherOptions& opts = AllgatherOptions()) override { 73 PYBIND11_OVERRIDE( 74 c10::intrusive_ptr<Work>, /* Return type */ 75 ProcessGroup, /* Parent class */ 76 allgather, /* Name of function in C++ */ 77 outputTensors, 78 inputTensors, 79 opts); 80 } 81 allgather_into_tensor_coalesced(std::vector<at::Tensor> & outputTensors,std::vector<at::Tensor> & inputTensors,const AllgatherOptions & opts=AllgatherOptions ())82 c10::intrusive_ptr<Work> allgather_into_tensor_coalesced( 83 std::vector<at::Tensor>& outputTensors, 84 std::vector<at::Tensor>& inputTensors, 85 const AllgatherOptions& opts = AllgatherOptions()) override { 86 PYBIND11_OVERRIDE( 87 c10::intrusive_ptr<Work>, /* Return type */ 88 ProcessGroup, /* Parent class */ 89 allgather_into_tensor_coalesced, /* Name of function in C++ */ 90 outputTensors, 91 inputTensors, 92 opts); 93 } 94 allreduce(std::vector<at::Tensor> & tensors,const AllreduceOptions & opts=AllreduceOptions ())95 c10::intrusive_ptr<Work> allreduce( 96 std::vector<at::Tensor>& tensors, 97 const AllreduceOptions& opts = AllreduceOptions()) override { 98 PYBIND11_OVERRIDE( 99 c10::intrusive_ptr<Work>, /* Return type */ 100 ProcessGroup, /* Parent class */ 101 allreduce, /* Name of function in C++ */ 102 tensors, 103 opts); 104 } 105 allreduce_coalesced(std::vector<at::Tensor> & tensors,const AllreduceCoalescedOptions & opts=AllreduceCoalescedOptions ())106 c10::intrusive_ptr<Work> allreduce_coalesced( 107 std::vector<at::Tensor>& tensors, 108 const AllreduceCoalescedOptions& opts = 109 AllreduceCoalescedOptions()) override { 110 PYBIND11_OVERRIDE( 111 c10::intrusive_ptr<Work>, /* Return type */ 112 ProcessGroup, /* Parent class */ 113 allreduce_coalesced, /* Name of function in C++ */ 114 tensors, 115 opts); 116 } 117 alltoall_base(at::Tensor & outputBuffer,at::Tensor & inputBuffer,std::vector<int64_t> & outputSplitSizes,std::vector<int64_t> & inputSplitSizes,const AllToAllOptions & opts=AllToAllOptions ())118 c10::intrusive_ptr<Work> alltoall_base( 119 at::Tensor& outputBuffer, 120 at::Tensor& inputBuffer, 121 std::vector<int64_t>& outputSplitSizes, 122 std::vector<int64_t>& inputSplitSizes, 123 const AllToAllOptions& opts = AllToAllOptions()) override { 124 PYBIND11_OVERRIDE( 125 c10::intrusive_ptr<Work>, /* Return type */ 126 ProcessGroup, /* Parent class */ 127 alltoall_base, /* Name of function in C++ */ 128 outputBuffer, 129 inputBuffer, 130 outputSplitSizes, 131 inputSplitSizes, 132 opts); 133 } 134 barrier(const BarrierOptions & opts=BarrierOptions ())135 c10::intrusive_ptr<Work> barrier( 136 const BarrierOptions& opts = BarrierOptions()) override { 137 PYBIND11_OVERRIDE( 138 c10::intrusive_ptr<Work>, /* Return type */ 139 ProcessGroup, /* Parent class */ 140 barrier, /* Name of function in C++ */ 141 opts); 142 } 143 broadcast(std::vector<at::Tensor> & tensors,const BroadcastOptions & opts=BroadcastOptions ())144 c10::intrusive_ptr<Work> broadcast( 145 std::vector<at::Tensor>& tensors, 146 const BroadcastOptions& opts = BroadcastOptions()) override { 147 PYBIND11_OVERRIDE( 148 c10::intrusive_ptr<Work>, /* Return type */ 149 ProcessGroup, /* Parent class */ 150 broadcast, /* Name of function in C++ */ 151 tensors, 152 opts); 153 } 154 reduce_scatter(std::vector<at::Tensor> & outputTensors,std::vector<std::vector<at::Tensor>> & inputTensors,const ReduceScatterOptions & opts=ReduceScatterOptions ())155 c10::intrusive_ptr<Work> reduce_scatter( 156 std::vector<at::Tensor>& outputTensors, 157 std::vector<std::vector<at::Tensor>>& inputTensors, 158 const ReduceScatterOptions& opts = ReduceScatterOptions()) override { 159 PYBIND11_OVERRIDE( 160 c10::intrusive_ptr<Work>, /* Return type */ 161 ProcessGroup, /* Parent class */ 162 reduce_scatter, /* Name of function in C++ */ 163 outputTensors, 164 inputTensors, 165 opts); 166 } 167 reduce_scatter_tensor_coalesced(std::vector<at::Tensor> & outputTensors,std::vector<at::Tensor> & inputTensors,const ReduceScatterOptions & opts=ReduceScatterOptions ())168 c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced( 169 std::vector<at::Tensor>& outputTensors, 170 std::vector<at::Tensor>& inputTensors, 171 const ReduceScatterOptions& opts = ReduceScatterOptions()) override { 172 PYBIND11_OVERRIDE( 173 c10::intrusive_ptr<Work>, /* Return type */ 174 ProcessGroup, /* Parent class */ 175 reduce_scatter_tensor_coalesced, /* Name of function in C++ */ 176 outputTensors, 177 inputTensors, 178 opts); 179 } 180 send(std::vector<at::Tensor> & tensors,int dstRank,int tag)181 c10::intrusive_ptr<Work> send( 182 std::vector<at::Tensor>& tensors, 183 int dstRank, 184 int tag) override { 185 PYBIND11_OVERRIDE( 186 c10::intrusive_ptr<Work>, /* Return type */ 187 ProcessGroup, /* Parent class */ 188 send, /* Name of function in C++ */ 189 tensors, 190 dstRank, 191 tag); 192 } 193 recv(std::vector<at::Tensor> & tensors,int srcRank,int tag)194 c10::intrusive_ptr<Work> recv( 195 std::vector<at::Tensor>& tensors, 196 int srcRank, 197 int tag) override { 198 PYBIND11_OVERRIDE( 199 c10::intrusive_ptr<Work>, /* Return type */ 200 ProcessGroup, /* Parent class */ 201 recv, /* Name of function in C++ */ 202 tensors, 203 srcRank, 204 tag); 205 } 206 }; 207 208 class TORCH_PYTHON_API PythonOnCompletionHook { 209 public: 210 // Wraps a py::object hook and acquires Python GIL in dtor before 211 // destructing the hook object. PythonOnCompletionHook(py::object hook)212 PythonOnCompletionHook(py::object hook) : hook_(std::move(hook)) {} 213 ~PythonOnCompletionHook()214 ~PythonOnCompletionHook() { 215 py::gil_scoped_acquire ag; 216 hook_.dec_ref(); 217 // Explicitly set hook_ to nullptr to prevent py::object's dtor 218 // to decref on the PyObject again. 219 // See Note [Destructing py::object] in python_ivalue.h 220 hook_.ptr() = nullptr; 221 } 222 operator ()(const std::shared_ptr<WorkInfo> & workInfo) const223 void operator()(const std::shared_ptr<WorkInfo>& workInfo) const { 224 std::exception_ptr eptr; 225 { 226 py::gil_scoped_acquire acquire; 227 try { 228 hook_(workInfo); 229 } catch (py::error_already_set& e) { 230 // py::error_already_set requires GIL to destruct, take 231 // special care. 232 eptr = std::make_exception_ptr(std::runtime_error(e.what())); 233 e.restore(); 234 PyErr_Clear(); 235 } catch (std::exception& e) { 236 eptr = std::current_exception(); 237 } 238 } 239 // No more Python-related stuff at this point, i.e., this 240 // exception can be captured and handled by PG backend. 241 if (eptr) 242 std::rethrow_exception(eptr); 243 } 244 245 private: 246 py::object hook_; 247 }; 248 249 } // namespace c10d 250