1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/data/samplers/base.h> 5 6 #include <cstddef> 7 #include <vector> 8 9 namespace torch { 10 namespace serialize { 11 class OutputArchive; 12 class InputArchive; 13 } // namespace serialize 14 } // namespace torch 15 16 namespace torch { 17 namespace data { 18 namespace samplers { 19 20 /// A `Sampler` that selects a subset of indices to sample from and defines a 21 /// sampling behavior. In a distributed setting, this selects a subset of the 22 /// indices depending on the provided num_replicas and rank parameters. The 23 /// `Sampler` performs a rounding operation based on the `allow_duplicates` 24 /// parameter to decide the local sample count. 25 template <typename BatchRequest = std::vector<size_t>> 26 class DistributedSampler : public Sampler<BatchRequest> { 27 public: 28 DistributedSampler( 29 size_t size, 30 size_t num_replicas = 1, 31 size_t rank = 0, 32 bool allow_duplicates = true) size_(size)33 : size_(size), 34 num_replicas_(num_replicas), 35 rank_(rank), 36 epoch_(0), 37 allow_duplicates_(allow_duplicates) {} 38 39 /// Set the epoch for the current enumeration. This can be used to alter the 40 /// sample selection and shuffling behavior. set_epoch(size_t epoch)41 void set_epoch(size_t epoch) { 42 epoch_ = epoch; 43 } 44 epoch()45 size_t epoch() const { 46 return epoch_; 47 } 48 49 protected: local_sample_count()50 size_t local_sample_count() { 51 if (allow_duplicates_) { 52 return (size_ + num_replicas_ - 1) / num_replicas_; 53 } else { 54 return size_ / num_replicas_; 55 } 56 } 57 58 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 59 size_t size_; 60 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 61 size_t num_replicas_; 62 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 63 size_t rank_; 64 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 65 size_t epoch_; 66 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 67 bool allow_duplicates_; 68 }; 69 70 /// Select samples randomly. The sampling order is shuffled at each `reset()` 71 /// call. 72 class TORCH_API DistributedRandomSampler : public DistributedSampler<> { 73 public: 74 DistributedRandomSampler( 75 size_t size, 76 size_t num_replicas = 1, 77 size_t rank = 0, 78 bool allow_duplicates = true); 79 80 /// Resets the `DistributedRandomSampler` to a new set of indices. 81 void reset(std::optional<size_t> new_size = std::nullopt) override; 82 83 /// Returns the next batch of indices. 84 std::optional<std::vector<size_t>> next(size_t batch_size) override; 85 86 /// Serializes the `DistributedRandomSampler` to the `archive`. 87 void save(serialize::OutputArchive& archive) const override; 88 89 /// Deserializes the `DistributedRandomSampler` from the `archive`. 90 void load(serialize::InputArchive& archive) override; 91 92 /// Returns the current index of the `DistributedRandomSampler`. 93 size_t index() const noexcept; 94 95 private: 96 void populate_indices(); 97 98 size_t begin_index_; 99 size_t end_index_; 100 size_t sample_index_; 101 std::vector<size_t> all_indices_; 102 }; 103 104 /// Select samples sequentially. 105 class TORCH_API DistributedSequentialSampler : public DistributedSampler<> { 106 public: 107 DistributedSequentialSampler( 108 size_t size, 109 size_t num_replicas = 1, 110 size_t rank = 0, 111 bool allow_duplicates = true); 112 113 /// Resets the `DistributedSequentialSampler` to a new set of indices. 114 void reset(std::optional<size_t> new_size = std::nullopt) override; 115 116 /// Returns the next batch of indices. 117 std::optional<std::vector<size_t>> next(size_t batch_size) override; 118 119 /// Serializes the `DistributedSequentialSampler` to the `archive`. 120 void save(serialize::OutputArchive& archive) const override; 121 122 /// Deserializes the `DistributedSequentialSampler` from the `archive`. 123 void load(serialize::InputArchive& archive) override; 124 125 /// Returns the current index of the `DistributedSequentialSampler`. 126 size_t index() const noexcept; 127 128 private: 129 void populate_indices(); 130 131 size_t begin_index_; 132 size_t end_index_; 133 size_t sample_index_; 134 std::vector<size_t> all_indices_; 135 }; 136 137 } // namespace samplers 138 } // namespace data 139 } // namespace torch 140