1 #pragma once 2 3 #include <torch/csrc/distributed/c10d/Backend.hpp> 4 5 namespace c10d { 6 7 class FakeWork : public Work { 8 public: wait(std::chrono::milliseconds timeout)9 bool wait(std::chrono::milliseconds timeout) override { 10 return true; 11 } 12 getFuture()13 c10::intrusive_ptr<c10::ivalue::Future> getFuture() override { 14 auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get()); 15 fut->markCompleted(); 16 return fut; 17 } 18 }; 19 20 class FakeProcessGroup : public Backend { 21 public: FakeProcessGroup(int rank,int size)22 FakeProcessGroup(int rank, int size) : Backend(rank, size) {} 23 broadcast(std::vector<at::Tensor> &,const BroadcastOptions &=BroadcastOptions ())24 c10::intrusive_ptr<Work> broadcast( 25 std::vector<at::Tensor>& /* tensors */, 26 const BroadcastOptions& /* opts */ = BroadcastOptions()) override { 27 return c10::make_intrusive<FakeWork>(); 28 } 29 allreduce(std::vector<at::Tensor> &,const AllreduceOptions &=AllreduceOptions ())30 c10::intrusive_ptr<Work> allreduce( 31 std::vector<at::Tensor>& /* tensors */, 32 const AllreduceOptions& /* opts */ = AllreduceOptions()) override { 33 return c10::make_intrusive<FakeWork>(); 34 } 35 allreduce_sparse(std::vector<at::Tensor> &,const AllreduceOptions &=AllreduceOptions ())36 c10::intrusive_ptr<Work> allreduce_sparse( 37 std::vector<at::Tensor>& /* tensors */, 38 const AllreduceOptions& /* opts */ = AllreduceOptions()) override { 39 return c10::make_intrusive<FakeWork>(); 40 } 41 allreduce_coalesced(std::vector<at::Tensor> &,const AllreduceCoalescedOptions &=AllreduceCoalescedOptions ())42 c10::intrusive_ptr<Work> allreduce_coalesced( 43 std::vector<at::Tensor>& /* tensors */, 44 const AllreduceCoalescedOptions& /* opts */ = 45 AllreduceCoalescedOptions()) override { 46 return c10::make_intrusive<FakeWork>(); 47 } 48 reduce(std::vector<at::Tensor> &,const ReduceOptions &=ReduceOptions ())49 c10::intrusive_ptr<Work> reduce( 50 std::vector<at::Tensor>& /* tensors */, 51 const ReduceOptions& /* opts */ = ReduceOptions()) override { 52 return c10::make_intrusive<FakeWork>(); 53 } 54 55 // NOTE [allgather on FakeProcessGroup] 56 // Assume each rank have the same input tensor so we just copy to the results 57 // since it's not a real allgather, we simply make this copying logic to let 58 // some simple validation works (i.e. calling allgather to see if each rank 59 // have the same tensor or not). 60 // 61 // NOTE: in general it's not good form to try to make FakeProcessGroup work 62 // with real data, but the reasoning here is that we want FakeProcessGroup to 63 // work with DeviceMesh's init code that have the data validation, which 64 // makes it worth the tradeoff. allgather(std::vector<std::vector<at::Tensor>> & outputTensors,std::vector<at::Tensor> & inputTensors,const AllgatherOptions &=AllgatherOptions ())65 c10::intrusive_ptr<Work> allgather( 66 std::vector<std::vector<at::Tensor>>& outputTensors, 67 std::vector<at::Tensor>& inputTensors, 68 const AllgatherOptions& /* opts */ = AllgatherOptions()) override { 69 for (auto& tensor : outputTensors[0]) { 70 tensor.copy_(inputTensors[0]); 71 } 72 return c10::make_intrusive<FakeWork>(); 73 } 74 _allgather_base(at::Tensor & outputBuffer,at::Tensor & inputBuffer,const AllgatherOptions &=AllgatherOptions ())75 c10::intrusive_ptr<Work> _allgather_base( 76 at::Tensor& outputBuffer, 77 at::Tensor& inputBuffer, 78 const AllgatherOptions& /* opts */ = AllgatherOptions()) override { 79 auto chunks = outputBuffer.chunk(size_); 80 for (auto& tensor : chunks) { 81 tensor.copy_(inputBuffer); 82 } 83 return c10::make_intrusive<FakeWork>(); 84 } 85 allgather_coalesced(std::vector<std::vector<at::Tensor>> &,std::vector<at::Tensor> &,const AllgatherOptions &=AllgatherOptions ())86 c10::intrusive_ptr<Work> allgather_coalesced( 87 std::vector<std::vector<at::Tensor>>& /* outputTensorLists */, 88 std::vector<at::Tensor>& /* inputTensors */, 89 const AllgatherOptions& /* opts */ = AllgatherOptions()) override { 90 return c10::make_intrusive<FakeWork>(); 91 } 92 allgather_into_tensor_coalesced(std::vector<at::Tensor> & outputs,std::vector<at::Tensor> & inputs,const AllgatherOptions &=AllgatherOptions ())93 c10::intrusive_ptr<Work> allgather_into_tensor_coalesced( 94 std::vector<at::Tensor>& outputs, 95 std::vector<at::Tensor>& inputs, 96 const AllgatherOptions& /* opts */ = AllgatherOptions()) override { 97 for (size_t i = 0; i < outputs.size(); ++i) { 98 auto chunks = outputs[i].chunk(size_); 99 for (auto& chunk : chunks) { 100 chunk.copy_(inputs[i]); 101 } 102 } 103 return c10::make_intrusive<FakeWork>(); 104 } 105 gather(std::vector<std::vector<at::Tensor>> &,std::vector<at::Tensor> &,const GatherOptions &=GatherOptions ())106 c10::intrusive_ptr<Work> gather( 107 std::vector<std::vector<at::Tensor>>& /* outputTensors */, 108 std::vector<at::Tensor>& /* inputTensors */, 109 const GatherOptions& /* opts */ = GatherOptions()) override { 110 return c10::make_intrusive<FakeWork>(); 111 } 112 scatter(std::vector<at::Tensor> &,std::vector<std::vector<at::Tensor>> &,const ScatterOptions &=ScatterOptions ())113 c10::intrusive_ptr<Work> scatter( 114 std::vector<at::Tensor>& /* outputTensors */, 115 std::vector<std::vector<at::Tensor>>& /* inputTensors */, 116 const ScatterOptions& /* opts */ = ScatterOptions()) override { 117 return c10::make_intrusive<FakeWork>(); 118 } 119 reduce_scatter(std::vector<at::Tensor> &,std::vector<std::vector<at::Tensor>> &,const ReduceScatterOptions &=ReduceScatterOptions ())120 c10::intrusive_ptr<Work> reduce_scatter( 121 std::vector<at::Tensor>& /* outputTensors */, 122 std::vector<std::vector<at::Tensor>>& /* inputTensors */, 123 const ReduceScatterOptions& /* opts */ = 124 ReduceScatterOptions()) override { 125 return c10::make_intrusive<FakeWork>(); 126 } 127 _reduce_scatter_base(at::Tensor &,at::Tensor &,const ReduceScatterOptions &=ReduceScatterOptions ())128 c10::intrusive_ptr<Work> _reduce_scatter_base( 129 at::Tensor& /* outputBuffer */, 130 at::Tensor& /* inputBuffer */, 131 const ReduceScatterOptions& /* opts */ = 132 ReduceScatterOptions()) override { 133 return c10::make_intrusive<FakeWork>(); 134 } 135 reduce_scatter_tensor_coalesced(std::vector<at::Tensor> &,std::vector<at::Tensor> &,const ReduceScatterOptions &=ReduceScatterOptions ())136 c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced( 137 std::vector<at::Tensor>& /* outputs */, 138 std::vector<at::Tensor>& /* inputs */, 139 const ReduceScatterOptions& /* opts */ = 140 ReduceScatterOptions()) override { 141 return c10::make_intrusive<FakeWork>(); 142 } 143 alltoall_base(at::Tensor &,at::Tensor &,std::vector<int64_t> &,std::vector<int64_t> &,const AllToAllOptions &=AllToAllOptions ())144 c10::intrusive_ptr<Work> alltoall_base( 145 at::Tensor& /* outputBuffer */, 146 at::Tensor& /* inputBuffer */, 147 std::vector<int64_t>& /* outputSplitSizes */, 148 std::vector<int64_t>& /* inputSplitSizes */, 149 const AllToAllOptions& /* opts */ = AllToAllOptions()) override { 150 return c10::make_intrusive<FakeWork>(); 151 } 152 alltoall(std::vector<at::Tensor> &,std::vector<at::Tensor> &,const AllToAllOptions & opts=AllToAllOptions ())153 c10::intrusive_ptr<Work> alltoall( 154 std::vector<at::Tensor>& /* outputTensors */, 155 std::vector<at::Tensor>& /* inputTensors */, 156 const AllToAllOptions& opts = AllToAllOptions()) override { 157 return c10::make_intrusive<FakeWork>(); 158 } 159 send(std::vector<at::Tensor> &,int,int)160 c10::intrusive_ptr<Work> send( 161 std::vector<at::Tensor>& /* tensors */, 162 int /* dstRank */, 163 int /* tag */) override { 164 return c10::make_intrusive<FakeWork>(); 165 } 166 recv(std::vector<at::Tensor> &,int,int)167 c10::intrusive_ptr<Work> recv( 168 std::vector<at::Tensor>& /* tensors */, 169 int /* srcRank */, 170 int /* tag */) override { 171 return c10::make_intrusive<FakeWork>(); 172 } 173 recvAnysource(std::vector<at::Tensor> &,int)174 c10::intrusive_ptr<Work> recvAnysource( 175 std::vector<at::Tensor>& /* tensors */, 176 int /* tag */) override { 177 return c10::make_intrusive<FakeWork>(); 178 } 179 barrier(const BarrierOptions &=BarrierOptions ())180 c10::intrusive_ptr<Work> barrier( 181 const BarrierOptions& /* opts */ = BarrierOptions()) override { 182 return c10::make_intrusive<FakeWork>(); 183 } 184 }; 185 186 } // namespace c10d 187