1 #include <torch/csrc/jit/mobile/train/random.h>
2 #include <torch/types.h>
3
#include <algorithm>
#include <cstddef>
#include <optional>
#include <vector>
7
8 namespace torch::jit::mobile {
9
RandomSampler(int64_t size,Dtype index_dtype)10 RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
11 : indices_(torch::randperm(size, index_dtype)) {}
12
13 RandomSampler::~RandomSampler() = default;
14
reset(std::optional<size_t> new_size)15 void RandomSampler::reset(std::optional<size_t> new_size) {
16 // This allocates a new chunk of memory every time (just FYI). It should be
17 // amortized over the entire epoch hopefully.
18 const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
19 indices_ = torch::randperm(static_cast<int64_t>(size), indices_.options());
20 index_ = 0;
21 }
22
next(size_t batch_size)23 std::optional<std::vector<size_t>> RandomSampler::next(size_t batch_size) {
24 AT_ASSERT(index_ <= indices_.numel());
25 const size_t remaining_indices = indices_.numel() - index_;
26 if (remaining_indices == 0) {
27 return nullopt;
28 }
29 std::vector<size_t> index_batch(std::min(batch_size, remaining_indices));
30 auto slice = indices_.slice(/*dim=*/0, index_, index_ + index_batch.size());
31 // You may want to store your indices with 32-bit or less, but here we need
32 // to upcast to 64-bit. A batch itself won't hold too many indices, so that
33 // should be ok. Note that if this indeed results in a type promotion, there
34 // will be two allocations: one for the upcast slice, and one for the
35 // returned `index_batch` vector.
36 slice = slice.to(torch::kInt64);
37 const auto* data = slice.const_data_ptr<int64_t>();
38 std::copy(data, data + index_batch.size(), index_batch.begin());
39 index_ += static_cast<int64_t>(index_batch.size());
40 return index_batch;
41 }
42
save(serialize::OutputArchive & archive) const43 void RandomSampler::save(serialize::OutputArchive& archive) const {
44 TORCH_CHECK(false, "Serialization of RandomSampler not supported on mobile.");
45 }
46
load(serialize::InputArchive & archive)47 void RandomSampler::load(serialize::InputArchive& archive) {
48 TORCH_CHECK(false, "Serialization of RandomSampler not supported on mobile.");
49 }
50
index() const51 size_t RandomSampler::index() const noexcept {
52 return index_;
53 }
54
55 } // namespace torch::jit::mobile
56