xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/train/random.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/mobile/train/random.h>
2 #include <torch/types.h>
3 
4 #include <algorithm>
5 #include <cstddef>
6 #include <vector>
7 
8 namespace torch::jit::mobile {
9 
RandomSampler(int64_t size,Dtype index_dtype)10 RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
11     : indices_(torch::randperm(size, index_dtype)) {}
12 
13 RandomSampler::~RandomSampler() = default;
14 
reset(std::optional<size_t> new_size)15 void RandomSampler::reset(std::optional<size_t> new_size) {
16   // This allocates a new chunk of memory every time (just FYI). It should be
17   // amortized over the entire epoch hopefully.
18   const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
19   indices_ = torch::randperm(static_cast<int64_t>(size), indices_.options());
20   index_ = 0;
21 }
22 
next(size_t batch_size)23 std::optional<std::vector<size_t>> RandomSampler::next(size_t batch_size) {
24   AT_ASSERT(index_ <= indices_.numel());
25   const size_t remaining_indices = indices_.numel() - index_;
26   if (remaining_indices == 0) {
27     return nullopt;
28   }
29   std::vector<size_t> index_batch(std::min(batch_size, remaining_indices));
30   auto slice = indices_.slice(/*dim=*/0, index_, index_ + index_batch.size());
31   // You may want to store your indices with 32-bit or less, but here we need
32   // to upcast to 64-bit. A batch itself won't hold too many indices, so that
33   // should be ok. Note that if this indeed results in a type promotion, there
34   // will be two allocations: one for the upcast slice, and one for the
35   // returned `index_batch` vector.
36   slice = slice.to(torch::kInt64);
37   const auto* data = slice.const_data_ptr<int64_t>();
38   std::copy(data, data + index_batch.size(), index_batch.begin());
39   index_ += static_cast<int64_t>(index_batch.size());
40   return index_batch;
41 }
42 
save(serialize::OutputArchive & archive) const43 void RandomSampler::save(serialize::OutputArchive& archive) const {
44   TORCH_CHECK(false, "Serialization of RandomSampler not supported on mobile.");
45 }
46 
load(serialize::InputArchive & archive)47 void RandomSampler::load(serialize::InputArchive& archive) {
48   TORCH_CHECK(false, "Serialization of RandomSampler not supported on mobile.");
49 }
50 
index() const51 size_t RandomSampler::index() const noexcept {
52   return index_;
53 }
54 
55 } // namespace torch::jit::mobile
56