1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/data/samplers/base.h> 5 #include <torch/data/samplers/custom_batch_request.h> 6 #include <torch/types.h> 7 8 #include <cstddef> 9 10 namespace torch { 11 namespace serialize { 12 class InputArchive; 13 class OutputArchive; 14 } // namespace serialize 15 } // namespace torch 16 17 namespace torch { 18 namespace data { 19 namespace samplers { 20 21 /// A wrapper around a batch size value, which implements the 22 /// `CustomBatchRequest` interface. 23 struct TORCH_API BatchSize : public CustomBatchRequest { 24 explicit BatchSize(size_t size); 25 size_t size() const noexcept override; 26 operator size_t() const noexcept; 27 size_t size_; 28 }; 29 30 /// A sampler for (potentially infinite) streams of data. 31 /// 32 /// The major feature of the `StreamSampler` is that it does not return 33 /// particular indices, but instead only the number of elements to fetch from 34 /// the dataset. The dataset has to decide how to produce those elements. 35 class TORCH_API StreamSampler : public Sampler<BatchSize> { 36 public: 37 /// Constructs the `StreamSampler` with the number of individual examples that 38 /// should be fetched until the sampler is exhausted. 39 explicit StreamSampler(size_t epoch_size); 40 41 /// Resets the internal state of the sampler. 42 void reset(std::optional<size_t> new_size = std::nullopt) override; 43 44 /// Returns a `BatchSize` object with the number of elements to fetch in the 45 /// next batch. This number is the minimum of the supplied `batch_size` and 46 /// the difference between the `epoch_size` and the current index. If the 47 /// `epoch_size` has been reached, returns an empty optional. 48 std::optional<BatchSize> next(size_t batch_size) override; 49 50 /// Serializes the `StreamSampler` to the `archive`. 51 void save(serialize::OutputArchive& archive) const override; 52 53 /// Deserializes the `StreamSampler` from the `archive`. 54 void load(serialize::InputArchive& archive) override; 55 56 private: 57 size_t examples_retrieved_so_far_ = 0; 58 size_t epoch_size_; 59 }; 60 61 } // namespace samplers 62 } // namespace data 63 } // namespace torch 64