xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/FakeProcessGroup.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/c10d/Backend.hpp>
4 
5 namespace c10d {
6 
7 class FakeWork : public Work {
8  public:
wait(std::chrono::milliseconds timeout)9   bool wait(std::chrono::milliseconds timeout) override {
10     return true;
11   }
12 
getFuture()13   c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
14     auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
15     fut->markCompleted();
16     return fut;
17   }
18 };
19 
20 class FakeProcessGroup : public Backend {
21  public:
FakeProcessGroup(int rank,int size)22   FakeProcessGroup(int rank, int size) : Backend(rank, size) {}
23 
broadcast(std::vector<at::Tensor> &,const BroadcastOptions &=BroadcastOptions ())24   c10::intrusive_ptr<Work> broadcast(
25       std::vector<at::Tensor>& /* tensors */,
26       const BroadcastOptions& /* opts */ = BroadcastOptions()) override {
27     return c10::make_intrusive<FakeWork>();
28   }
29 
allreduce(std::vector<at::Tensor> &,const AllreduceOptions &=AllreduceOptions ())30   c10::intrusive_ptr<Work> allreduce(
31       std::vector<at::Tensor>& /* tensors */,
32       const AllreduceOptions& /* opts */ = AllreduceOptions()) override {
33     return c10::make_intrusive<FakeWork>();
34   }
35 
allreduce_sparse(std::vector<at::Tensor> &,const AllreduceOptions &=AllreduceOptions ())36   c10::intrusive_ptr<Work> allreduce_sparse(
37       std::vector<at::Tensor>& /* tensors */,
38       const AllreduceOptions& /* opts */ = AllreduceOptions()) override {
39     return c10::make_intrusive<FakeWork>();
40   }
41 
allreduce_coalesced(std::vector<at::Tensor> &,const AllreduceCoalescedOptions &=AllreduceCoalescedOptions ())42   c10::intrusive_ptr<Work> allreduce_coalesced(
43       std::vector<at::Tensor>& /* tensors */,
44       const AllreduceCoalescedOptions& /* opts */ =
45           AllreduceCoalescedOptions()) override {
46     return c10::make_intrusive<FakeWork>();
47   }
48 
reduce(std::vector<at::Tensor> &,const ReduceOptions &=ReduceOptions ())49   c10::intrusive_ptr<Work> reduce(
50       std::vector<at::Tensor>& /* tensors */,
51       const ReduceOptions& /* opts */ = ReduceOptions()) override {
52     return c10::make_intrusive<FakeWork>();
53   }
54 
55   // NOTE [allgather on FakeProcessGroup]
56   // Assume each rank have the same input tensor so we just copy to the results
57   // since it's not a real allgather, we simply make this copying logic to let
58   // some simple validation works (i.e. calling allgather to see if each rank
59   // have the same tensor or not).
60   //
61   // NOTE: in general it's not good form to try to make FakeProcessGroup work
62   // with real data, but the reasoning here is that we want FakeProcessGroup to
63   // work with DeviceMesh's init code that have the data validation, which
64   // makes it worth the tradeoff.
allgather(std::vector<std::vector<at::Tensor>> & outputTensors,std::vector<at::Tensor> & inputTensors,const AllgatherOptions &=AllgatherOptions ())65   c10::intrusive_ptr<Work> allgather(
66       std::vector<std::vector<at::Tensor>>& outputTensors,
67       std::vector<at::Tensor>& inputTensors,
68       const AllgatherOptions& /* opts */ = AllgatherOptions()) override {
69     for (auto& tensor : outputTensors[0]) {
70       tensor.copy_(inputTensors[0]);
71     }
72     return c10::make_intrusive<FakeWork>();
73   }
74 
_allgather_base(at::Tensor & outputBuffer,at::Tensor & inputBuffer,const AllgatherOptions &=AllgatherOptions ())75   c10::intrusive_ptr<Work> _allgather_base(
76       at::Tensor& outputBuffer,
77       at::Tensor& inputBuffer,
78       const AllgatherOptions& /* opts */ = AllgatherOptions()) override {
79     auto chunks = outputBuffer.chunk(size_);
80     for (auto& tensor : chunks) {
81       tensor.copy_(inputBuffer);
82     }
83     return c10::make_intrusive<FakeWork>();
84   }
85 
allgather_coalesced(std::vector<std::vector<at::Tensor>> &,std::vector<at::Tensor> &,const AllgatherOptions &=AllgatherOptions ())86   c10::intrusive_ptr<Work> allgather_coalesced(
87       std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
88       std::vector<at::Tensor>& /* inputTensors */,
89       const AllgatherOptions& /* opts */ = AllgatherOptions()) override {
90     return c10::make_intrusive<FakeWork>();
91   }
92 
allgather_into_tensor_coalesced(std::vector<at::Tensor> & outputs,std::vector<at::Tensor> & inputs,const AllgatherOptions &=AllgatherOptions ())93   c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
94       std::vector<at::Tensor>& outputs,
95       std::vector<at::Tensor>& inputs,
96       const AllgatherOptions& /* opts */ = AllgatherOptions()) override {
97     for (size_t i = 0; i < outputs.size(); ++i) {
98       auto chunks = outputs[i].chunk(size_);
99       for (auto& chunk : chunks) {
100         chunk.copy_(inputs[i]);
101       }
102     }
103     return c10::make_intrusive<FakeWork>();
104   }
105 
gather(std::vector<std::vector<at::Tensor>> &,std::vector<at::Tensor> &,const GatherOptions &=GatherOptions ())106   c10::intrusive_ptr<Work> gather(
107       std::vector<std::vector<at::Tensor>>& /* outputTensors */,
108       std::vector<at::Tensor>& /* inputTensors */,
109       const GatherOptions& /* opts */ = GatherOptions()) override {
110     return c10::make_intrusive<FakeWork>();
111   }
112 
scatter(std::vector<at::Tensor> &,std::vector<std::vector<at::Tensor>> &,const ScatterOptions &=ScatterOptions ())113   c10::intrusive_ptr<Work> scatter(
114       std::vector<at::Tensor>& /* outputTensors */,
115       std::vector<std::vector<at::Tensor>>& /* inputTensors */,
116       const ScatterOptions& /* opts */ = ScatterOptions()) override {
117     return c10::make_intrusive<FakeWork>();
118   }
119 
reduce_scatter(std::vector<at::Tensor> &,std::vector<std::vector<at::Tensor>> &,const ReduceScatterOptions &=ReduceScatterOptions ())120   c10::intrusive_ptr<Work> reduce_scatter(
121       std::vector<at::Tensor>& /* outputTensors */,
122       std::vector<std::vector<at::Tensor>>& /* inputTensors */,
123       const ReduceScatterOptions& /* opts */ =
124           ReduceScatterOptions()) override {
125     return c10::make_intrusive<FakeWork>();
126   }
127 
_reduce_scatter_base(at::Tensor &,at::Tensor &,const ReduceScatterOptions &=ReduceScatterOptions ())128   c10::intrusive_ptr<Work> _reduce_scatter_base(
129       at::Tensor& /* outputBuffer */,
130       at::Tensor& /* inputBuffer */,
131       const ReduceScatterOptions& /* opts */ =
132           ReduceScatterOptions()) override {
133     return c10::make_intrusive<FakeWork>();
134   }
135 
reduce_scatter_tensor_coalesced(std::vector<at::Tensor> &,std::vector<at::Tensor> &,const ReduceScatterOptions &=ReduceScatterOptions ())136   c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
137       std::vector<at::Tensor>& /* outputs */,
138       std::vector<at::Tensor>& /* inputs */,
139       const ReduceScatterOptions& /* opts */ =
140           ReduceScatterOptions()) override {
141     return c10::make_intrusive<FakeWork>();
142   }
143 
alltoall_base(at::Tensor &,at::Tensor &,std::vector<int64_t> &,std::vector<int64_t> &,const AllToAllOptions &=AllToAllOptions ())144   c10::intrusive_ptr<Work> alltoall_base(
145       at::Tensor& /* outputBuffer */,
146       at::Tensor& /* inputBuffer */,
147       std::vector<int64_t>& /* outputSplitSizes */,
148       std::vector<int64_t>& /* inputSplitSizes */,
149       const AllToAllOptions& /* opts */ = AllToAllOptions()) override {
150     return c10::make_intrusive<FakeWork>();
151   }
152 
alltoall(std::vector<at::Tensor> &,std::vector<at::Tensor> &,const AllToAllOptions & opts=AllToAllOptions ())153   c10::intrusive_ptr<Work> alltoall(
154       std::vector<at::Tensor>& /* outputTensors */,
155       std::vector<at::Tensor>& /* inputTensors */,
156       const AllToAllOptions& opts = AllToAllOptions()) override {
157     return c10::make_intrusive<FakeWork>();
158   }
159 
send(std::vector<at::Tensor> &,int,int)160   c10::intrusive_ptr<Work> send(
161       std::vector<at::Tensor>& /* tensors */,
162       int /* dstRank */,
163       int /* tag */) override {
164     return c10::make_intrusive<FakeWork>();
165   }
166 
recv(std::vector<at::Tensor> &,int,int)167   c10::intrusive_ptr<Work> recv(
168       std::vector<at::Tensor>& /* tensors */,
169       int /* srcRank */,
170       int /* tag */) override {
171     return c10::make_intrusive<FakeWork>();
172   }
173 
recvAnysource(std::vector<at::Tensor> &,int)174   c10::intrusive_ptr<Work> recvAnysource(
175       std::vector<at::Tensor>& /* tensors */,
176       int /* tag */) override {
177     return c10::make_intrusive<FakeWork>();
178   }
179 
barrier(const BarrierOptions &=BarrierOptions ())180   c10::intrusive_ptr<Work> barrier(
181       const BarrierOptions& /* opts */ = BarrierOptions()) override {
182     return c10::make_intrusive<FakeWork>();
183   }
184 };
185 
186 } // namespace c10d
187