xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/PyProcessGroup.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

#include <stdexcept>
6 
7 namespace c10d {
8 
9 // PyProcessGroup is a pybind11 trampoline class to allow a Python
10 // class to inherit from torch.distributed.ProcessGroup
11 class PyProcessGroup : public ProcessGroup {
12  public:
13   // PyWork is a pybind11 trampoline class to allow a Python
14   // class to inherit from torch.distributed.Work
15   class TORCH_PYTHON_API PyWork : public Work {
16    public:
17     PyWork() = default;
18 
wait(std::chrono::milliseconds timeout=kNoTimeout)19     bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
20       PYBIND11_OVERRIDE(
21           bool, /* Return type */
22           Work, /* Parent class */
23           wait, /* Name of function in C++ */
24           timeout);
25     }
26 
getFuture()27     c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
28       // We cannot use PYBIND11_OVERRIDE because:
29       // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and
30       // 2. The python name is get_future
31       pybind11::gil_scoped_acquire gil;
32       auto override =
33           pybind11::get_override(static_cast<const Work*>(this), "get_future");
34 
35       if (override) {
36         py::object o = override();
37         auto futWrapper =
38             o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
39         return futWrapper->fut;
40       }
41 
42       return Work::getFuture();
43     }
44 
45     // Take a reference of the corresponding py::object.
46     // With functional collectives, ownership of work objects is generally
47     // transferred to C++. For pure C++ work objects, it is sufficient to
48     // transfer the ownership of work object. For user-defined work objects in
49     // Python, it is necessary to keep the corresponding py::object alive in
50     // addition to ensure that the user-defined methods can be executed.
ref_py_object()51     void ref_py_object() {
52       py_obj_ = py::cast(this);
53     }
54 
55    private:
56     py::object py_obj_;
57   };
58 
59   using ProcessGroup::ProcessGroup;
60 
getBackendName() const61   const std::string getBackendName() const override {
62     PYBIND11_OVERRIDE_PURE(
63         std::string, /* Return type */
64         ProcessGroup, /* Parent class */
65         getBackendName, /* Name of function in C++ */
66     );
67   }
68 
allgather(std::vector<std::vector<at::Tensor>> & outputTensors,std::vector<at::Tensor> & inputTensors,const AllgatherOptions & opts=AllgatherOptions ())69   c10::intrusive_ptr<Work> allgather(
70       std::vector<std::vector<at::Tensor>>& outputTensors,
71       std::vector<at::Tensor>& inputTensors,
72       const AllgatherOptions& opts = AllgatherOptions()) override {
73     PYBIND11_OVERRIDE(
74         c10::intrusive_ptr<Work>, /* Return type */
75         ProcessGroup, /* Parent class */
76         allgather, /* Name of function in C++ */
77         outputTensors,
78         inputTensors,
79         opts);
80   }
81 
allgather_into_tensor_coalesced(std::vector<at::Tensor> & outputTensors,std::vector<at::Tensor> & inputTensors,const AllgatherOptions & opts=AllgatherOptions ())82   c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
83       std::vector<at::Tensor>& outputTensors,
84       std::vector<at::Tensor>& inputTensors,
85       const AllgatherOptions& opts = AllgatherOptions()) override {
86     PYBIND11_OVERRIDE(
87         c10::intrusive_ptr<Work>, /* Return type */
88         ProcessGroup, /* Parent class */
89         allgather_into_tensor_coalesced, /* Name of function in C++ */
90         outputTensors,
91         inputTensors,
92         opts);
93   }
94 
allreduce(std::vector<at::Tensor> & tensors,const AllreduceOptions & opts=AllreduceOptions ())95   c10::intrusive_ptr<Work> allreduce(
96       std::vector<at::Tensor>& tensors,
97       const AllreduceOptions& opts = AllreduceOptions()) override {
98     PYBIND11_OVERRIDE(
99         c10::intrusive_ptr<Work>, /* Return type */
100         ProcessGroup, /* Parent class */
101         allreduce, /* Name of function in C++ */
102         tensors,
103         opts);
104   }
105 
allreduce_coalesced(std::vector<at::Tensor> & tensors,const AllreduceCoalescedOptions & opts=AllreduceCoalescedOptions ())106   c10::intrusive_ptr<Work> allreduce_coalesced(
107       std::vector<at::Tensor>& tensors,
108       const AllreduceCoalescedOptions& opts =
109           AllreduceCoalescedOptions()) override {
110     PYBIND11_OVERRIDE(
111         c10::intrusive_ptr<Work>, /* Return type */
112         ProcessGroup, /* Parent class */
113         allreduce_coalesced, /* Name of function in C++ */
114         tensors,
115         opts);
116   }
117 
alltoall_base(at::Tensor & outputBuffer,at::Tensor & inputBuffer,std::vector<int64_t> & outputSplitSizes,std::vector<int64_t> & inputSplitSizes,const AllToAllOptions & opts=AllToAllOptions ())118   c10::intrusive_ptr<Work> alltoall_base(
119       at::Tensor& outputBuffer,
120       at::Tensor& inputBuffer,
121       std::vector<int64_t>& outputSplitSizes,
122       std::vector<int64_t>& inputSplitSizes,
123       const AllToAllOptions& opts = AllToAllOptions()) override {
124     PYBIND11_OVERRIDE(
125         c10::intrusive_ptr<Work>, /* Return type */
126         ProcessGroup, /* Parent class */
127         alltoall_base, /* Name of function in C++ */
128         outputBuffer,
129         inputBuffer,
130         outputSplitSizes,
131         inputSplitSizes,
132         opts);
133   }
134 
barrier(const BarrierOptions & opts=BarrierOptions ())135   c10::intrusive_ptr<Work> barrier(
136       const BarrierOptions& opts = BarrierOptions()) override {
137     PYBIND11_OVERRIDE(
138         c10::intrusive_ptr<Work>, /* Return type */
139         ProcessGroup, /* Parent class */
140         barrier, /* Name of function in C++ */
141         opts);
142   }
143 
broadcast(std::vector<at::Tensor> & tensors,const BroadcastOptions & opts=BroadcastOptions ())144   c10::intrusive_ptr<Work> broadcast(
145       std::vector<at::Tensor>& tensors,
146       const BroadcastOptions& opts = BroadcastOptions()) override {
147     PYBIND11_OVERRIDE(
148         c10::intrusive_ptr<Work>, /* Return type */
149         ProcessGroup, /* Parent class */
150         broadcast, /* Name of function in C++ */
151         tensors,
152         opts);
153   }
154 
reduce_scatter(std::vector<at::Tensor> & outputTensors,std::vector<std::vector<at::Tensor>> & inputTensors,const ReduceScatterOptions & opts=ReduceScatterOptions ())155   c10::intrusive_ptr<Work> reduce_scatter(
156       std::vector<at::Tensor>& outputTensors,
157       std::vector<std::vector<at::Tensor>>& inputTensors,
158       const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
159     PYBIND11_OVERRIDE(
160         c10::intrusive_ptr<Work>, /* Return type */
161         ProcessGroup, /* Parent class */
162         reduce_scatter, /* Name of function in C++ */
163         outputTensors,
164         inputTensors,
165         opts);
166   }
167 
reduce_scatter_tensor_coalesced(std::vector<at::Tensor> & outputTensors,std::vector<at::Tensor> & inputTensors,const ReduceScatterOptions & opts=ReduceScatterOptions ())168   c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
169       std::vector<at::Tensor>& outputTensors,
170       std::vector<at::Tensor>& inputTensors,
171       const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
172     PYBIND11_OVERRIDE(
173         c10::intrusive_ptr<Work>, /* Return type */
174         ProcessGroup, /* Parent class */
175         reduce_scatter_tensor_coalesced, /* Name of function in C++ */
176         outputTensors,
177         inputTensors,
178         opts);
179   }
180 
send(std::vector<at::Tensor> & tensors,int dstRank,int tag)181   c10::intrusive_ptr<Work> send(
182       std::vector<at::Tensor>& tensors,
183       int dstRank,
184       int tag) override {
185     PYBIND11_OVERRIDE(
186         c10::intrusive_ptr<Work>, /* Return type */
187         ProcessGroup, /* Parent class */
188         send, /* Name of function in C++ */
189         tensors,
190         dstRank,
191         tag);
192   }
193 
recv(std::vector<at::Tensor> & tensors,int srcRank,int tag)194   c10::intrusive_ptr<Work> recv(
195       std::vector<at::Tensor>& tensors,
196       int srcRank,
197       int tag) override {
198     PYBIND11_OVERRIDE(
199         c10::intrusive_ptr<Work>, /* Return type */
200         ProcessGroup, /* Parent class */
201         recv, /* Name of function in C++ */
202         tensors,
203         srcRank,
204         tag);
205   }
206 };
207 
208 class TORCH_PYTHON_API PythonOnCompletionHook {
209  public:
210   // Wraps a py::object hook and acquires Python GIL in dtor before
211   // destructing the hook object.
PythonOnCompletionHook(py::object hook)212   PythonOnCompletionHook(py::object hook) : hook_(std::move(hook)) {}
213 
~PythonOnCompletionHook()214   ~PythonOnCompletionHook() {
215     py::gil_scoped_acquire ag;
216     hook_.dec_ref();
217     // Explicitly set hook_ to nullptr to prevent py::object's dtor
218     // to decref on the PyObject again.
219     // See Note [Destructing py::object] in python_ivalue.h
220     hook_.ptr() = nullptr;
221   }
222 
operator ()(const std::shared_ptr<WorkInfo> & workInfo) const223   void operator()(const std::shared_ptr<WorkInfo>& workInfo) const {
224     std::exception_ptr eptr;
225     {
226       py::gil_scoped_acquire acquire;
227       try {
228         hook_(workInfo);
229       } catch (py::error_already_set& e) {
230         // py::error_already_set requires GIL to destruct, take
231         // special care.
232         eptr = std::make_exception_ptr(std::runtime_error(e.what()));
233         e.restore();
234         PyErr_Clear();
235       } catch (std::exception& e) {
236         eptr = std::current_exception();
237       }
238     }
239     // No more Python-related stuff at this point, i.e., this
240     // exception can be captured and handled by PG backend.
241     if (eptr)
242       std::rethrow_exception(eptr);
243   }
244 
245  private:
246   py::object hook_;
247 };
248 
249 } // namespace c10d
250