xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/samplers/distributed.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/data/samplers/base.h>
5 
6 #include <cstddef>
7 #include <vector>
8 
9 namespace torch {
10 namespace serialize {
11 class OutputArchive;
12 class InputArchive;
13 } // namespace serialize
14 } // namespace torch
15 
16 namespace torch {
17 namespace data {
18 namespace samplers {
19 
20 /// A `Sampler` that selects a subset of indices to sample from and defines a
21 /// sampling behavior. In a distributed setting, this selects a subset of the
22 /// indices depending on the provided num_replicas and rank parameters. The
23 /// `Sampler` performs a rounding operation based on the `allow_duplicates`
24 /// parameter to decide the local sample count.
25 template <typename BatchRequest = std::vector<size_t>>
26 class DistributedSampler : public Sampler<BatchRequest> {
27  public:
28   DistributedSampler(
29       size_t size,
30       size_t num_replicas = 1,
31       size_t rank = 0,
32       bool allow_duplicates = true)
size_(size)33       : size_(size),
34         num_replicas_(num_replicas),
35         rank_(rank),
36         epoch_(0),
37         allow_duplicates_(allow_duplicates) {}
38 
39   /// Set the epoch for the current enumeration. This can be used to alter the
40   /// sample selection and shuffling behavior.
set_epoch(size_t epoch)41   void set_epoch(size_t epoch) {
42     epoch_ = epoch;
43   }
44 
epoch()45   size_t epoch() const {
46     return epoch_;
47   }
48 
49  protected:
local_sample_count()50   size_t local_sample_count() {
51     if (allow_duplicates_) {
52       return (size_ + num_replicas_ - 1) / num_replicas_;
53     } else {
54       return size_ / num_replicas_;
55     }
56   }
57 
58   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
59   size_t size_;
60   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
61   size_t num_replicas_;
62   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
63   size_t rank_;
64   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
65   size_t epoch_;
66   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
67   bool allow_duplicates_;
68 };
69 
70 /// Select samples randomly. The sampling order is shuffled at each `reset()`
71 /// call.
72 class TORCH_API DistributedRandomSampler : public DistributedSampler<> {
73  public:
74   DistributedRandomSampler(
75       size_t size,
76       size_t num_replicas = 1,
77       size_t rank = 0,
78       bool allow_duplicates = true);
79 
80   /// Resets the `DistributedRandomSampler` to a new set of indices.
81   void reset(std::optional<size_t> new_size = std::nullopt) override;
82 
83   /// Returns the next batch of indices.
84   std::optional<std::vector<size_t>> next(size_t batch_size) override;
85 
86   /// Serializes the `DistributedRandomSampler` to the `archive`.
87   void save(serialize::OutputArchive& archive) const override;
88 
89   /// Deserializes the `DistributedRandomSampler` from the `archive`.
90   void load(serialize::InputArchive& archive) override;
91 
92   /// Returns the current index of the `DistributedRandomSampler`.
93   size_t index() const noexcept;
94 
95  private:
96   void populate_indices();
97 
98   size_t begin_index_;
99   size_t end_index_;
100   size_t sample_index_;
101   std::vector<size_t> all_indices_;
102 };
103 
104 /// Select samples sequentially.
105 class TORCH_API DistributedSequentialSampler : public DistributedSampler<> {
106  public:
107   DistributedSequentialSampler(
108       size_t size,
109       size_t num_replicas = 1,
110       size_t rank = 0,
111       bool allow_duplicates = true);
112 
113   /// Resets the `DistributedSequentialSampler` to a new set of indices.
114   void reset(std::optional<size_t> new_size = std::nullopt) override;
115 
116   /// Returns the next batch of indices.
117   std::optional<std::vector<size_t>> next(size_t batch_size) override;
118 
119   /// Serializes the `DistributedSequentialSampler` to the `archive`.
120   void save(serialize::OutputArchive& archive) const override;
121 
122   /// Deserializes the `DistributedSequentialSampler` from the `archive`.
123   void load(serialize::InputArchive& archive) override;
124 
125   /// Returns the current index of the `DistributedSequentialSampler`.
126   size_t index() const noexcept;
127 
128  private:
129   void populate_indices();
130 
131   size_t begin_index_;
132   size_t end_index_;
133   size_t sample_index_;
134   std::vector<size_t> all_indices_;
135 };
136 
137 } // namespace samplers
138 } // namespace data
139 } // namespace torch
140