xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/train/random.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/data/samplers/base.h>
5 #include <torch/types.h>
6 
7 #include <cstddef>
8 #include <vector>
9 
10 namespace torch::serialize {
11 class OutputArchive; // forward declaration; in this header it appears only as a reference parameter of `RandomSampler::save`
12 class InputArchive; // forward declaration; in this header it appears only as a reference parameter of `RandomSampler::load`
13 } // namespace torch::serialize
14 
15 namespace torch::jit::mobile {
16 
17 /// A lighter `Sampler` (declared in `torch::jit::mobile`) that returns indices randomly and cannot be
18 /// serialized, even though `save` and `load` are still declared below as overrides of the base interface.
19 class TORCH_API RandomSampler : public torch::data::samplers::Sampler<> {
20  public:
21   /// Constructs a `RandomSampler` with a size and dtype for the stored indices.
22   ///
23   /// The constructor will eagerly allocate all required indices, which is the
24   /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored
25   /// indices and defaults to `torch::kInt64`. You can change it to influence memory usage.
26   explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64);
27 
28   ~RandomSampler() override;
29 
30   /// Resets the `RandomSampler` to a new set of indices. `new_size` defaults to `std::nullopt` (presumably "keep the current size" -- confirm in the definition). NOTE(review): `std::optional` is used here but `<optional>` is not included directly by this header; presumably it arrives transitively -- confirm.
31   void reset(std::optional<size_t> new_size = std::nullopt) override;
32 
33   /// Returns the next batch of indices for the given `batch_size`, wrapped in `std::optional` (presumably empty once no indices remain -- confirm in the definition).
34   std::optional<std::vector<size_t>> next(size_t batch_size) override;
35 
36   /// Override of the base-class hook that takes an output `archive`. NOTE(review): the class comment says this sampler cannot be serialized, so this presumably fails instead of writing state -- confirm in the definition.
37   void save(serialize::OutputArchive& archive) const override;
38 
39   /// Override of the base-class hook that takes an input `archive`. NOTE(review): the class comment says this sampler cannot be serialized, so this presumably fails instead of restoring state -- confirm in the definition.
40   void load(serialize::InputArchive& archive) override;
41 
42   /// Returns the current index of the `RandomSampler` as `size_t`; declared `noexcept`. NOTE(review): the `index_` member is `int64_t`, so the definition presumably converts signed to unsigned -- confirm.
43   size_t index() const noexcept;
44 
45  private:
46   at::Tensor indices_; // the stored indices described in the constructor documentation
47   int64_t index_ = 0; // initialized to 0; presumably the position reported by `index()` -- confirm in the definition
48 };
49 
50 } // namespace torch::jit::mobile
51