#pragma once

#include <chrono>
#include <memory>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>

#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/debug.h>

constexpr auto kBackendDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

class TORCH_API Backend : public torch::CustomClassHolder {
 public:
  // Backend Options is a base struct that defines the basic options
  // when constructing a Backend. Each Backend subclass should
  // extend this struct and define its options if it wants to provide more
  // config options (beyond basic ones defined here) to end users.
  struct TORCH_API Options : torch::CustomClassHolder {
    explicit Options(
        std::string backend,
        std::chrono::milliseconds timeout = kBackendDefaultTimeout)
        : timeout(timeout), backend(std::move(backend)) {}
    ~Options() override = default;

    std::chrono::milliseconds timeout;

    // backend name
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
    const std::string backend;
  };
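  // Illustration only (not part of the upstream API): a backend that needs
  // extra knobs would typically extend Options along these lines, where
  // `MyBackendOptions` and `numWorkerThreads` are hypothetical names:
  //
  //   struct MyBackendOptions : Backend::Options {
  //     explicit MyBackendOptions(
  //         std::chrono::milliseconds timeout = kBackendDefaultTimeout)
  //         : Backend::Options("my_backend", timeout) {}
  //     int numWorkerThreads = 1;
  //   };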
  explicit Backend(int rank, int size);
  ~Backend() override = 0;

  int getRank() const {
    return rank_;
  }

  int getSize() const {
    return size_;
  }

  // Returns a unique opaque ID of this backend that can be used to correlate
  // with its collectives.
  int64_t getID() const {
    return reinterpret_cast<std::intptr_t>(this);
  }

  virtual bool supportsSplitting() const {
    return false;
  }

  virtual void startCoalescing() {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not implement startCoalescing"));
  }

  virtual c10::intrusive_ptr<Work> endCoalescing() {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not implement endCoalescing"));
  }
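  // Illustration only: backends that support coalescing are expected to be
  // driven roughly as follows (sketch, assuming a concrete `backend` object
  // and pre-populated tensor vectors):
  //
  //   backend->startCoalescing();
  //   backend->allreduce(tensorsA);
  //   backend->allreduce(tensorsB);
  //   c10::intrusive_ptr<Work> work = backend->endCoalescing();
  //   work->wait(); // waits on every collective issued in the window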
  // Subclasses must override this method to return the backend name
  virtual const std::string getBackendName() const {
    TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
  };

  virtual c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& /* tensors */,
      const BroadcastOptions& /* opts */ = BroadcastOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support broadcast"));
  }

  virtual c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allreduce"));
  }

  virtual c10::intrusive_ptr<Work> allreduce_sparse(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allreduce sparse"));
  }

  virtual c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceCoalescedOptions& /* opts */ =
          AllreduceCoalescedOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allreduce_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& /* tensors */,
      const ReduceOptions& /* opts */ = ReduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support reduce"));
  }

  virtual c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allgather"));
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
  // For implementers of ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
  virtual c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support _allgather_base"));
  }
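  // Illustration only: with a world size of 4 and an inputBuffer of shape
  // {8, 16}, the outputBuffer for `_allgather_base` is expected to have shape
  // {32, 16}, with rank i's contribution landing in rows [8 * i, 8 * (i + 1)).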
  // This function is deprecated and will be moved out of Backend to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your Backend, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allgather_coalesced"));
  }

  // This function is a coalesced version of `allgather_into_tensor` (currently
  // still named as `_allgather_base`). Each tensor in the vector corresponds
  // to an input/output of one `allgather_into_tensor` operation.
  virtual c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& /* outputs */,
      std::vector<at::Tensor>& /* inputs */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allgather_into_tensor_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const GatherOptions& /* opts */ = GatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support gather"));
  }

  virtual c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ScatterOptions& /* opts */ = ScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support scatter"));
  }

  virtual c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support reduce_scatter"));
  }

  virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support _reduce_scatter_base"));
  }
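  // Illustration only: `_reduce_scatter_base` uses the inverse layout of
  // `_allgather_base`; with a world size of 4 and an inputBuffer of shape
  // {32, 16}, each rank's outputBuffer has shape {8, 16} and receives the
  // reduction of rows [8 * rank, 8 * (rank + 1)) across all ranks.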
  // This function is a coalesced version of `reduce_scatter_tensor` (currently
  // still named as `_reduce_scatter_base`). Each tensor in the vector
  // corresponds to an input/output of one `reduce_scatter_tensor` operation.
  virtual c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& /* outputs */,
      std::vector<at::Tensor>& /* inputs */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support reduce_scatter_tensor_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      std::vector<int64_t>& /* outputSplitSizes */,
      std::vector<int64_t>& /* inputSplitSizes */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support alltoall_base"));
  }

  virtual c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllToAllOptions& opts = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support alltoall"));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& /* unused */,
      bool /* unused */ = false) {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not support monitoredBarrier, only GLOO supports monitored barrier."));
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only
  // implemented for GLOO and NCCL backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  virtual c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& /* tensors */,
      int /* dstRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support send"));
  }

  virtual c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& /* tensors */,
      int /* srcRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support recv"));
  }

  virtual c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& /* tensors */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support recvAnysource"));
  }

  virtual c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& /* opts */ = BarrierOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support barrier"));
  }

  virtual void registerOnCompletionHook(
      std::function<void(std::shared_ptr<WorkInfo>)>&& hook) {
    TORCH_CHECK(
        false,
        "Only ProcessGroupNCCL supports onCompletion hook, but got ",
        getBackendName(),
        " backend.");
  }

  virtual void waitForPendingWorks() {
    TORCH_CHECK(
        false,
        "Only ProcessGroupNCCL supports waitForPendingWorks, but got ",
        getBackendName(),
        " backend.");
  }

  virtual void enableCollectivesTiming() {
    TORCH_CHECK(
        false,
        "Backend ",
        getBackendName(),
        " is missing implementation of enableCollectivesTiming.");
  }

  bool hasHooks() const {
    return onCompletionHook_ != nullptr;
  }

  // Do not call this directly, use ProcessGroup::setGroupName instead.
  void setGroupUid(const std::string& pg_uid) {
    pg_uid_ = pg_uid;
  }

  const std::string& getGroupUid() const {
    return pg_uid_;
  }

  void setGroupDesc(const std::string& desc) {
    pg_desc_ = desc;
  }

  const std::string& getGroupDesc() const {
    return pg_desc_;
  }

  // See similar functions in ProcessGroup.hpp for context.
  std::optional<at::Device> getBoundDeviceId() const {
    return bound_device_id_;
  }

  // Perform an eager connect to the specified device if the backend supports
  // it.
  virtual void eagerConnectSingleDevice(at::Device device) {
    // no-op in the default case; this is an optimization some
    // backends may perform
  }

  void setBoundDeviceId(std::optional<at::Device> device) {
    if (device) {
      TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
    }
    bound_device_id_ = device;
  }
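  // Illustration only (hypothetical subclass, not part of this header): a
  // minimal concrete backend overrides getBackendName() plus whichever
  // collectives it supports, and inherits the erroring defaults for the rest:
  //
  //   class MyBackend : public Backend {
  //    public:
  //     MyBackend(int rank, int size) : Backend(rank, size) {}
  //     const std::string getBackendName() const override {
  //       return "my_backend";
  //     }
  //     c10::intrusive_ptr<Work> allreduce(
  //         std::vector<at::Tensor>& tensors,
  //         const AllreduceOptions& opts = AllreduceOptions()) override {
  //       // launch the reduction and return a Work handle
  //     }
  //   };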
 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
  void init();

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int rank_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int size_;
  // Debug level setting. It is parsed once when ProcessGroup is constructed
  // and remains the same across use of this process group.
  DebugLevel dist_debug_level_;
  std::string pg_uid_;
  std::string pg_desc_;

  std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_;

  std::optional<at::Device> bound_device_id_;
};

} // namespace c10d