1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/data/samplers/base.h> 5 #include <torch/types.h> 6 7 #include <cstddef> 8 #include <vector> 9 10 namespace torch::serialize { 11 class OutputArchive; 12 class InputArchive; 13 } // namespace torch::serialize 14 15 namespace torch::jit::mobile { 16 17 /// A lighter `Sampler` that returns indices randomly and cannot be 18 /// serialized. 19 class TORCH_API RandomSampler : public torch::data::samplers::Sampler<> { 20 public: 21 /// Constructs a `RandomSampler` with a size and dtype for the stored indices. 22 /// 23 /// The constructor will eagerly allocate all required indices, which is the 24 /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored 25 /// indices. You can change it to influence memory usage. 26 explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64); 27 28 ~RandomSampler() override; 29 30 /// Resets the `RandomSampler` to a new set of indices. 31 void reset(std::optional<size_t> new_size = std::nullopt) override; 32 33 /// Returns the next batch of indices. 34 std::optional<std::vector<size_t>> next(size_t batch_size) override; 35 36 /// Serializes the `RandomSampler` to the `archive`. 37 void save(serialize::OutputArchive& archive) const override; 38 39 /// Deserializes the `RandomSampler` from the `archive`. 40 void load(serialize::InputArchive& archive) override; 41 42 /// Returns the current index of the `RandomSampler`. 43 size_t index() const noexcept; 44 45 private: 46 at::Tensor indices_; 47 int64_t index_ = 0; 48 }; 49 50 } // namespace torch::jit::mobile 51