1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/data/samplers/base.h> 5 #include <torch/types.h> 6 7 #include <cstddef> 8 #include <vector> 9 10 namespace torch { 11 namespace serialize { 12 class OutputArchive; 13 class InputArchive; 14 } // namespace serialize 15 } // namespace torch 16 17 namespace torch { 18 namespace data { 19 namespace samplers { 20 21 /// A `Sampler` that returns random indices. 22 class TORCH_API RandomSampler : public Sampler<> { 23 public: 24 /// Constructs a `RandomSampler` with a size and dtype for the stored indices. 25 /// 26 /// The constructor will eagerly allocate all required indices, which is the 27 /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored 28 /// indices. You can change it to influence memory usage. 29 explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64); 30 31 ~RandomSampler() override; 32 33 /// Resets the `RandomSampler` to a new set of indices. 34 void reset(std::optional<size_t> new_size = std::nullopt) override; 35 36 /// Returns the next batch of indices. 37 std::optional<std::vector<size_t>> next(size_t batch_size) override; 38 39 /// Serializes the `RandomSampler` to the `archive`. 40 void save(serialize::OutputArchive& archive) const override; 41 42 /// Deserializes the `RandomSampler` from the `archive`. 43 void load(serialize::InputArchive& archive) override; 44 45 /// Returns the current index of the `RandomSampler`. 46 size_t index() const noexcept; 47 48 private: 49 at::Tensor indices_; 50 int64_t index_ = 0; 51 }; 52 } // namespace samplers 53 } // namespace data 54 } // namespace torch 55