1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <ATen/core/ivalue.h> 5 #include <torch/csrc/Export.h> 6 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> 7 #include <utility> 8 9 namespace c10d { 10 11 // Broadcast many tensors to all processes in the process group. 12 TORCH_API void broadcast_coalesced( 13 const c10::intrusive_ptr<c10d::ProcessGroup>& process_group, 14 at::TensorList tensors, 15 size_t buffer_size, 16 int rank = 0); 17 18 // This class passes bucket contents tensor to DDP communication hook. 19 class TORCH_API GradBucket { 20 public: GradBucket(size_t index,size_t bucket_count,at::Tensor tensor,std::vector<size_t> offsets,std::vector<size_t> lengths,std::vector<c10::IntArrayRef> sizes_vec,std::vector<at::Tensor> parameters,std::optional<at::Tensor> sparse_grad_indices)21 explicit GradBucket( 22 size_t index, 23 size_t bucket_count, 24 at::Tensor tensor, 25 std::vector<size_t> offsets, 26 std::vector<size_t> lengths, 27 std::vector<c10::IntArrayRef> sizes_vec, 28 std::vector<at::Tensor> parameters, 29 std::optional<at::Tensor> sparse_grad_indices) 30 : index_(index), 31 bucket_count_(bucket_count), 32 buffer_(std::move(tensor)), 33 offsets_(std::move(offsets)), 34 lengths_(std::move(lengths)), 35 sizes_vec_(std::move(sizes_vec)), 36 parameters_(std::move(parameters)), 37 sparse_grad_indices_(std::move(sparse_grad_indices)) {} 38 39 // Returns the index of the bucket, which is unique across all the buckets. getIndex() const40 size_t getIndex() const { 41 return index_; 42 } 43 getBuffer() const44 const at::Tensor& getBuffer() const { 45 return buffer_; 46 } 47 48 // Returns a mutable buffer compared with the above method. getBufferRef()49 at::Tensor& getBufferRef() { 50 return buffer_; 51 } 52 53 // Overwrites the buffer at a specific index. setBuffer(at::Tensor & buffer)54 void setBuffer(at::Tensor& buffer) { 55 buffer_ = buffer; 56 } 57 58 // Each tensor in the list that getGradients corresponds to a 59 // parameter. 60 std::vector<at::Tensor> getGradients() const; 61 62 // Returns model parameters belonging to this bucket. They are returned in the 63 // same order as gradient tensors via getGradients(). For example, 64 // getParameters[i] will have its gradient stored in 65 // getGradients[i] getParameters() const66 const std::vector<at::Tensor> getParameters() const { 67 return parameters_; 68 } 69 70 // Returns whther this bucket is the last bucket to allreduce in an iteration. isLast() const71 bool isLast() const { 72 return index_ == bucket_count_ - 1; 73 } 74 getSparseGradIndices()75 std::optional<at::Tensor>& getSparseGradIndices() { 76 return sparse_grad_indices_; 77 } 78 79 private: 80 size_t index_; 81 size_t bucket_count_; 82 at::Tensor buffer_; 83 84 // Per-variable info in buffer_. 85 std::vector<size_t> offsets_; 86 std::vector<size_t> lengths_; 87 std::vector<c10::IntArrayRef> sizes_vec_; 88 89 // Model parameters for this bucket. 90 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) 91 const std::vector<at::Tensor> parameters_; 92 93 // Predefined sparse indices for this bucket (only used for sparse tensors). 94 // The gradients will be updated to have indices with these tensor values 95 std::optional<at::Tensor> sparse_grad_indices_; 96 }; 97 98 // Base class of both `PythonCommHook` and `CppCommHook`. 99 // Requires implementing 1) `runHook` method that communicates gradients 100 // asynchronously, and 2) `parseHookResult` method that converts the hook 101 // result into a tensor. 102 class TORCH_API CommHookInterface { 103 public: 104 virtual ~CommHookInterface() = default; 105 106 // Passes the input grad bucket to the registered communication hook. 107 // Once the tensor in the bucket are ready, kicks off the hook asynchronously 108 // and returns a future that holds the communication results. 109 virtual c10::intrusive_ptr<c10::ivalue::Future> runHook( 110 GradBucket& bucket) = 0; 111 112 // Returns the resulting tensor once the communication hook result is 113 // ready. The resulting tensor will then be copied to the grads of 114 // individual parameters. 115 virtual at::Tensor parseHookResult(const c10::IValue& result) = 0; 116 }; 117 118 namespace detail { 119 // This helper function is called both by CppCommHookInterface below and inside 120 // reducer. 121 TORCH_API at::Tensor parseCppCommHookResult(const c10::IValue& result); 122 } // namespace detail 123 124 // This CppCommHook interface only requires implementing runHook method that 125 // potentially uses a state. 126 template <typename T> 127 class CppCommHookInterface : public CommHookInterface { 128 public: CppCommHookInterface(T state)129 explicit CppCommHookInterface(T state) : state_(std::move(state)) {} 130 131 ~CppCommHookInterface() override = default; 132 parseHookResult(const c10::IValue & result)133 at::Tensor parseHookResult(const c10::IValue& result) override { 134 return detail::parseCppCommHookResult(result); 135 } 136 137 protected: 138 T state_; 139 }; 140 141 } // namespace c10d 142