1 #include <torch/data/samplers/stream.h> 2 #include <torch/serialize/archive.h> 3 #include <torch/types.h> 4 5 #include <c10/util/Exception.h> 6 7 #include <cstddef> 8 9 namespace torch { 10 namespace data { 11 namespace samplers { 12 BatchSize(size_t size)13BatchSize::BatchSize(size_t size) : size_(size) {} size() const14size_t BatchSize::size() const noexcept { 15 return size_; 16 } operator size_t() const17BatchSize::operator size_t() const noexcept { 18 return size_; 19 } 20 StreamSampler(size_t epoch_size)21StreamSampler::StreamSampler(size_t epoch_size) : epoch_size_(epoch_size) {} 22 reset(std::optional<size_t> new_size)23void StreamSampler::reset(std::optional<size_t> new_size) { 24 if (new_size.has_value()) { 25 epoch_size_ = *new_size; 26 } 27 examples_retrieved_so_far_ = 0; 28 } 29 next(size_t batch_size)30std::optional<BatchSize> StreamSampler::next(size_t batch_size) { 31 AT_ASSERT(examples_retrieved_so_far_ <= epoch_size_); 32 if (examples_retrieved_so_far_ == epoch_size_) { 33 return nullopt; 34 } 35 if (examples_retrieved_so_far_ + batch_size > epoch_size_) { 36 batch_size = epoch_size_ - examples_retrieved_so_far_; 37 } 38 examples_retrieved_so_far_ += batch_size; 39 return BatchSize(batch_size); 40 } 41 save(serialize::OutputArchive & archive) const42void StreamSampler::save(serialize::OutputArchive& archive) const { 43 archive.write( 44 "examples_retrieved_so_far", 45 torch::tensor( 46 static_cast<int64_t>(examples_retrieved_so_far_), torch::kInt64), 47 /*is_buffer=*/true); 48 } 49 load(serialize::InputArchive & archive)50void StreamSampler::load(serialize::InputArchive& archive) { 51 auto tensor = torch::empty(1, torch::kInt64); 52 archive.read( 53 "examples_retrieved_so_far", 54 tensor, 55 /*is_buffer=*/true); 56 examples_retrieved_so_far_ = tensor.item<int64_t>(); 57 } 58 59 } // namespace samplers 60 } // namespace data 61 } // namespace torch 62