xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/samplers/random.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/data/samplers/base.h>
5 #include <torch/types.h>
6 
7 #include <cstddef>
8 #include <vector>
9 
// The sampler declarations in this header only take the archive types by
// reference, so forward declarations are enough; the full serialization
// headers are not needed here.
namespace torch::serialize {
class InputArchive;
class OutputArchive;
} // namespace torch::serialize
16 
17 namespace torch {
18 namespace data {
19 namespace samplers {
20 
21 /// A `Sampler` that returns random indices.
22 class TORCH_API RandomSampler : public Sampler<> {
23  public:
24   /// Constructs a `RandomSampler` with a size and dtype for the stored indices.
25   ///
26   /// The constructor will eagerly allocate all required indices, which is the
27   /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored
28   /// indices. You can change it to influence memory usage.
29   explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64);
30 
31   ~RandomSampler() override;
32 
33   /// Resets the `RandomSampler` to a new set of indices.
34   void reset(std::optional<size_t> new_size = std::nullopt) override;
35 
36   /// Returns the next batch of indices.
37   std::optional<std::vector<size_t>> next(size_t batch_size) override;
38 
39   /// Serializes the `RandomSampler` to the `archive`.
40   void save(serialize::OutputArchive& archive) const override;
41 
42   /// Deserializes the `RandomSampler` from the `archive`.
43   void load(serialize::InputArchive& archive) override;
44 
45   /// Returns the current index of the `RandomSampler`.
46   size_t index() const noexcept;
47 
48  private:
49   at::Tensor indices_;
50   int64_t index_ = 0;
51 };
52 } // namespace samplers
53 } // namespace data
54 } // namespace torch
55