// xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/cpp_c10d_extension.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/extension.h>
4 
5 #include <deque>
6 #include <exception>
7 #include <memory>
8 #include <mutex>
9 #include <thread>
10 #include <vector>
11 #include <chrono>
12 
13 #include <pybind11/chrono.h>
14 
15 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
16 #include <torch/csrc/distributed/c10d/Work.hpp>
17 #include <torch/csrc/distributed/c10d/Store.hpp>
18 #include <torch/csrc/distributed/c10d/Types.hpp>
19 #include <torch/csrc/distributed/c10d/Utils.hpp>
20 
21 namespace c10d {
22 
23 //
24 // ProcessGroupTest implements dummy bindings for c10d.
25 //
26 
27 class ProcessGroupTest : public ProcessGroup {
28  public:
29   class WorkTest : public Work {
30    public:
WorkTest()31     WorkTest() {}
32 
33     virtual ~WorkTest();
34     bool isCompleted() override;
35     bool isSuccess() const override;
36     bool wait(std::chrono::milliseconds timeout) override;
37 
38    protected:
39     friend class ProcessGroupTest;
40   };
41 
42   explicit ProcessGroupTest(int rank = -1, int size = -1);
43   virtual ~ProcessGroupTest();
44 
45   c10::intrusive_ptr<Work> broadcast(
46       std::vector<at::Tensor>& data,
47       const BroadcastOptions& opts = BroadcastOptions()) override;
48 
49   c10::intrusive_ptr<Work> allreduce(
50       std::vector<at::Tensor>& tensors,
51       const AllreduceOptions& opts = AllreduceOptions()) override;
52 
53   c10::intrusive_ptr<Work> allreduce_coalesced(
54       std::vector<at::Tensor>& tensors,
55       const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override;
56 
57   c10::intrusive_ptr<Work> reduce(
58       std::vector<at::Tensor>& tensors,
59       const ReduceOptions& opts = ReduceOptions()) override;
60 
61   c10::intrusive_ptr<Work> allgather(
62       std::vector<std::vector<at::Tensor>>& outputTensors,
63       std::vector<at::Tensor>& inputTensors,
64       const AllgatherOptions& opts = AllgatherOptions()) override;
65 
66   c10::intrusive_ptr<Work> _allgather_base(
67       at::Tensor& outputBuffer,
68       at::Tensor& inputBuffer,
69       const AllgatherOptions& opts = AllgatherOptions()) override;
70 
71   c10::intrusive_ptr<Work> barrier(
72       const BarrierOptions& opts = BarrierOptions()) override;
73 
74   c10::intrusive_ptr<Work> gather(
75       std::vector<std::vector<at::Tensor>>& outputTensors,
76       std::vector<at::Tensor>& inputTensors,
77       const GatherOptions& opts = GatherOptions()) override;
78 
79   c10::intrusive_ptr<Work> scatter(
80       std::vector<at::Tensor>& outputTensors,
81       std::vector<std::vector<at::Tensor>>& inputTensors,
82       const ScatterOptions& opts = ScatterOptions()) override;
83 
84   c10::intrusive_ptr<Work> reduce_scatter(
85       std::vector<at::Tensor>& outputTensors,
86       std::vector<std::vector<at::Tensor>>& inputTensors,
87       const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
88 
89   c10::intrusive_ptr<Work> send(
90       std::vector<at::Tensor>& tensors,
91       int dstRank,
92       int tag) override;
93 
94   c10::intrusive_ptr<Work> recv(
95       std::vector<at::Tensor>& tensors,
96       int srcRank,
97       int tag) override;
98 
99   c10::intrusive_ptr<Work> recvAnysource(
100       std::vector<at::Tensor>& tensor,
101       int tag) override;
102 
103   // Create a new ProcessGroupTest instance
104   static c10::intrusive_ptr<ProcessGroup> createProcessGroupTest(
105       const c10::intrusive_ptr<::c10d::Store>& store,
106       int rank,
107       int size,
108       const std::chrono::duration<float>& timeout);
109 
ProcessGroupTestConstructor()110   static void ProcessGroupTestConstructor() __attribute__((constructor)) {
111       py::object module = py::module::import("torch.distributed");
112       py::object register_backend = module.attr("Backend").attr("register_backend");
113       // The first parameter is the backend name used by user in invoking
114       // torch.distributed.init_process_group().
115       // Note it could be different with module name. For example, the module
116       // name is "torch_test" but the backend name is "test".
117       // The second parameter is the instantiation function.
118       register_backend("test", py::cpp_function(createProcessGroupTest));
119   }
120 
121 };
122 
123 } // namespace c10d
124