1 #pragma once 2 3 #include <torch/extension.h> 4 5 #include <deque> 6 #include <exception> 7 #include <memory> 8 #include <mutex> 9 #include <thread> 10 #include <vector> 11 #include <chrono> 12 13 #include <pybind11/chrono.h> 14 15 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> 16 #include <torch/csrc/distributed/c10d/Work.hpp> 17 #include <torch/csrc/distributed/c10d/Store.hpp> 18 #include <torch/csrc/distributed/c10d/Types.hpp> 19 #include <torch/csrc/distributed/c10d/Utils.hpp> 20 21 namespace c10d { 22 23 // 24 // ProcessGroupTest implements dummy bindings for c10d. 25 // 26 27 class ProcessGroupTest : public ProcessGroup { 28 public: 29 class WorkTest : public Work { 30 public: WorkTest()31 WorkTest() {} 32 33 virtual ~WorkTest(); 34 bool isCompleted() override; 35 bool isSuccess() const override; 36 bool wait(std::chrono::milliseconds timeout) override; 37 38 protected: 39 friend class ProcessGroupTest; 40 }; 41 42 explicit ProcessGroupTest(int rank = -1, int size = -1); 43 virtual ~ProcessGroupTest(); 44 45 c10::intrusive_ptr<Work> broadcast( 46 std::vector<at::Tensor>& data, 47 const BroadcastOptions& opts = BroadcastOptions()) override; 48 49 c10::intrusive_ptr<Work> allreduce( 50 std::vector<at::Tensor>& tensors, 51 const AllreduceOptions& opts = AllreduceOptions()) override; 52 53 c10::intrusive_ptr<Work> allreduce_coalesced( 54 std::vector<at::Tensor>& tensors, 55 const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; 56 57 c10::intrusive_ptr<Work> reduce( 58 std::vector<at::Tensor>& tensors, 59 const ReduceOptions& opts = ReduceOptions()) override; 60 61 c10::intrusive_ptr<Work> allgather( 62 std::vector<std::vector<at::Tensor>>& outputTensors, 63 std::vector<at::Tensor>& inputTensors, 64 const AllgatherOptions& opts = AllgatherOptions()) override; 65 66 c10::intrusive_ptr<Work> _allgather_base( 67 at::Tensor& outputBuffer, 68 at::Tensor& inputBuffer, 69 const AllgatherOptions& opts = AllgatherOptions()) override; 70 71 c10::intrusive_ptr<Work> barrier( 72 const BarrierOptions& opts = BarrierOptions()) override; 73 74 c10::intrusive_ptr<Work> gather( 75 std::vector<std::vector<at::Tensor>>& outputTensors, 76 std::vector<at::Tensor>& inputTensors, 77 const GatherOptions& opts = GatherOptions()) override; 78 79 c10::intrusive_ptr<Work> scatter( 80 std::vector<at::Tensor>& outputTensors, 81 std::vector<std::vector<at::Tensor>>& inputTensors, 82 const ScatterOptions& opts = ScatterOptions()) override; 83 84 c10::intrusive_ptr<Work> reduce_scatter( 85 std::vector<at::Tensor>& outputTensors, 86 std::vector<std::vector<at::Tensor>>& inputTensors, 87 const ReduceScatterOptions& opts = ReduceScatterOptions()) override; 88 89 c10::intrusive_ptr<Work> send( 90 std::vector<at::Tensor>& tensors, 91 int dstRank, 92 int tag) override; 93 94 c10::intrusive_ptr<Work> recv( 95 std::vector<at::Tensor>& tensors, 96 int srcRank, 97 int tag) override; 98 99 c10::intrusive_ptr<Work> recvAnysource( 100 std::vector<at::Tensor>& tensor, 101 int tag) override; 102 103 // Create a new ProcessGroupTest instance 104 static c10::intrusive_ptr<ProcessGroup> createProcessGroupTest( 105 const c10::intrusive_ptr<::c10d::Store>& store, 106 int rank, 107 int size, 108 const std::chrono::duration<float>& timeout); 109 ProcessGroupTestConstructor()110 static void ProcessGroupTestConstructor() __attribute__((constructor)) { 111 py::object module = py::module::import("torch.distributed"); 112 py::object register_backend = module.attr("Backend").attr("register_backend"); 113 // The first parameter is the backend name used by user in invoking 114 // torch.distributed.init_process_group(). 115 // Note it could be different with module name. For example, the module 116 // name is "torch_test" but the backend name is "test". 117 // The second parameter is the instantiation function. 118 register_backend("test", py::cpp_function(createProcessGroupTest)); 119 } 120 121 }; 122 123 } // namespace c10d 124