xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/data/samplers/stream.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/data/samplers/stream.h>
2 #include <torch/serialize/archive.h>
3 #include <torch/types.h>
4 
5 #include <c10/util/Exception.h>
6 
7 #include <cstddef>
8 
9 namespace torch {
10 namespace data {
11 namespace samplers {
12 
BatchSize(size_t size)13 BatchSize::BatchSize(size_t size) : size_(size) {}
size() const14 size_t BatchSize::size() const noexcept {
15   return size_;
16 }
operator size_t() const17 BatchSize::operator size_t() const noexcept {
18   return size_;
19 }
20 
StreamSampler(size_t epoch_size)21 StreamSampler::StreamSampler(size_t epoch_size) : epoch_size_(epoch_size) {}
22 
reset(std::optional<size_t> new_size)23 void StreamSampler::reset(std::optional<size_t> new_size) {
24   if (new_size.has_value()) {
25     epoch_size_ = *new_size;
26   }
27   examples_retrieved_so_far_ = 0;
28 }
29 
next(size_t batch_size)30 std::optional<BatchSize> StreamSampler::next(size_t batch_size) {
31   AT_ASSERT(examples_retrieved_so_far_ <= epoch_size_);
32   if (examples_retrieved_so_far_ == epoch_size_) {
33     return nullopt;
34   }
35   if (examples_retrieved_so_far_ + batch_size > epoch_size_) {
36     batch_size = epoch_size_ - examples_retrieved_so_far_;
37   }
38   examples_retrieved_so_far_ += batch_size;
39   return BatchSize(batch_size);
40 }
41 
save(serialize::OutputArchive & archive) const42 void StreamSampler::save(serialize::OutputArchive& archive) const {
43   archive.write(
44       "examples_retrieved_so_far",
45       torch::tensor(
46           static_cast<int64_t>(examples_retrieved_so_far_), torch::kInt64),
47       /*is_buffer=*/true);
48 }
49 
load(serialize::InputArchive & archive)50 void StreamSampler::load(serialize::InputArchive& archive) {
51   auto tensor = torch::empty(1, torch::kInt64);
52   archive.read(
53       "examples_retrieved_so_far",
54       tensor,
55       /*is_buffer=*/true);
56   examples_retrieved_so_far_ = tensor.item<int64_t>();
57 }
58 
59 } // namespace samplers
60 } // namespace data
61 } // namespace torch
62