xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/samplers/stream.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/data/samplers/base.h>
5 #include <torch/data/samplers/custom_batch_request.h>
6 #include <torch/types.h>
7 
8 #include <cstddef>
9 
// Forward declarations of the archive types referenced by the `save`/`load`
// signatures below; only pointers/references are needed here, so the full
// serialization headers do not have to be included by this header.
namespace torch::serialize {
class InputArchive;
class OutputArchive;
} // namespace torch::serialize
16 
17 namespace torch {
18 namespace data {
19 namespace samplers {
20 
21 /// A wrapper around a batch size value, which implements the
22 /// `CustomBatchRequest` interface.
23 struct TORCH_API BatchSize : public CustomBatchRequest {
24   explicit BatchSize(size_t size);
25   size_t size() const noexcept override;
26   operator size_t() const noexcept;
27   size_t size_;
28 };
29 
30 /// A sampler for (potentially infinite) streams of data.
31 ///
32 /// The major feature of the `StreamSampler` is that it does not return
33 /// particular indices, but instead only the number of elements to fetch from
34 /// the dataset. The dataset has to decide how to produce those elements.
35 class TORCH_API StreamSampler : public Sampler<BatchSize> {
36  public:
37   /// Constructs the `StreamSampler` with the number of individual examples that
38   /// should be fetched until the sampler is exhausted.
39   explicit StreamSampler(size_t epoch_size);
40 
41   /// Resets the internal state of the sampler.
42   void reset(std::optional<size_t> new_size = std::nullopt) override;
43 
44   /// Returns a `BatchSize` object with the number of elements to fetch in the
45   /// next batch. This number is the minimum of the supplied `batch_size` and
46   /// the difference between the `epoch_size` and the current index. If the
47   /// `epoch_size` has been reached, returns an empty optional.
48   std::optional<BatchSize> next(size_t batch_size) override;
49 
50   /// Serializes the `StreamSampler` to the `archive`.
51   void save(serialize::OutputArchive& archive) const override;
52 
53   /// Deserializes the `StreamSampler` from the `archive`.
54   void load(serialize::InputArchive& archive) override;
55 
56  private:
57   size_t examples_retrieved_so_far_ = 0;
58   size_t epoch_size_;
59 };
60 
61 } // namespace samplers
62 } // namespace data
63 } // namespace torch
64