// torch/csrc/api/src/data/samplers/random.cpp
#include <torch/data/samplers/random.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <algorithm>
#include <cstddef>
#include <vector>

namespace torch {
namespace data {
namespace samplers {
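// Eagerly builds a random permutation of the integers [0, size) as the
// initial index order, stored as a tensor of type `index_dtype`.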
RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
    : indices_(torch::randperm(size, index_dtype)) {}

RandomSampler::~RandomSampler() = default;

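// Regenerates the random permutation, optionally with a new size, and rewinds
// the cursor to the start of the new epoch.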
void RandomSampler::reset(std::optional<size_t> new_size) {
  // This allocates a new chunk of memory every time (just FYI). It should be
  // amortized over the entire epoch hopefully.
  const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
  indices_ = torch::randperm(size, indices_.options());
  index_ = 0;
}

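// Hands out the next batch of up to `batch_size` shuffled indices, or
// `nullopt` once the permutation is exhausted.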
optional<std::vector<size_t>> RandomSampler::next(size_t batch_size) {
  AT_ASSERT(index_ <= indices_.numel());
  const size_t remaining_indices = indices_.numel() - index_;
  if (remaining_indices == 0) {
    return nullopt;
  }
  std::vector<size_t> index_batch(std::min(batch_size, remaining_indices));
  auto slice = indices_.slice(/*dim=*/0, index_, index_ + index_batch.size());
  // You may want to store your indices with 32-bit or less, but here we need
  // to upcast to 64-bit. A batch itself won't hold too many indices, so that
  // should be ok. Note that if this indeed results in a type promotion, there
  // will be two allocations: one for the upcast slice, and one for the
  // returned `index_batch` vector.
  slice = slice.to(torch::kInt64);
  const auto* data = slice.const_data_ptr<int64_t>();
  std::copy(data, data + index_batch.size(), index_batch.begin());
  index_ += index_batch.size();
  return index_batch;
}

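// Serializes the sampler state (cursor position and shuffled indices) so that
// iteration can resume after checkpointing.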
void RandomSampler::save(serialize::OutputArchive& archive) const {
  archive.write(
      "index",
      torch::tensor(static_cast<int64_t>(index_), torch::kInt64),
      /*is_buffer=*/true);
  archive.write(
      "indices",
      indices_,
      /*is_buffer=*/true);
}

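// Restores the state written by `save()`: first the cursor position, then the
// shuffled indices.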
void RandomSampler::load(serialize::InputArchive& archive) {
  auto tensor = torch::empty(1, torch::kInt64);
  archive.read(
      "index",
      tensor,
      /*is_buffer=*/true);
  index_ = tensor.item<int64_t>();
  archive.read(
      "indices",
      indices_,
      /*is_buffer=*/true);
}

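// Returns how many indices have been handed out so far in the current epoch.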
size_t RandomSampler::index() const noexcept {
  return index_;
}

} // namespace samplers
} // namespace data
} // namespace torch
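
// A minimal usage sketch (illustrative only, not part of the original file):
// draw shuffled index batches for a hypothetical dataset of 100 examples,
// assuming the default index dtype declared for `RandomSampler` in random.h.
//
//   torch::data::samplers::RandomSampler sampler(/*size=*/100);
//   while (auto batch = sampler.next(/*batch_size=*/16)) {
//     // `*batch` holds up to 16 shuffled indices as a std::vector<size_t>.
//   }
//   sampler.reset(torch::nullopt);  // reshuffle before the next epoch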