xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/dataloader/stateless.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/data/dataloader/base.h>
4 #include <torch/data/worker_exception.h>
5 
6 #include <c10/util/Exception.h>
7 #include <c10/util/irange.h>
8 
9 #include <cstddef>
10 #include <thread>
11 #include <utility>
12 
13 namespace torch {
14 namespace data {
15 
16 /// A dataloader for stateless datasets.
17 ///
18 /// This dataloader follows the traditional PyTorch dataloader design, whereby a
19 /// (posssibly) stateful sampler produces *batch requests* for a stateless
20 /// dataset, which acts as a simple batch request to batch mapping. The batch
21 /// request will often be an array of indices, and if the dataset is a simple
22 /// image dataset, the dataset would produce the images at those indices.
23 template <typename Dataset, typename Sampler>
24 class StatelessDataLoader : public DataLoaderBase<
25                                 Dataset,
26                                 typename Dataset::BatchType,
27                                 typename Sampler::BatchRequestType> {
28  public:
29   using super = DataLoaderBase<
30       Dataset,
31       typename Dataset::BatchType,
32       typename Sampler::BatchRequestType>;
33   using typename super::BatchRequestType;
34 
35   /// Constructs the `StatelessDataLoader` from a `dataset`, a `sampler` and
36   /// some `options`.
StatelessDataLoader(Dataset dataset,Sampler sampler,DataLoaderOptions options)37   StatelessDataLoader(
38       Dataset dataset,
39       Sampler sampler,
40       DataLoaderOptions options)
41       : super(std::move(options)), sampler_(std::move(sampler)) {
42     for (const auto w : c10::irange(this->options_.workers)) {
43       // Here we copy the dataset into the worker thread closure. Each worker
44       // has its own copy of the dataset. This means the dataset must be
45       // trivially copiable, or else we don't expect more than one worker to
46       // be in use.
47       (void)w; // Suppress unused variable warning
48       this->workers_.emplace_back(
49           [this, dataset]() mutable { this->worker_thread(dataset); });
50     }
51     if (this->options_.workers == 0) {
52       this->main_thread_dataset_ =
53           std::make_unique<Dataset>(std::move(dataset));
54     }
55   }
56 
57  private:
58   /// Resets the internal state of the dataloader and the sampler.
reset()59   void reset() override {
60     sampler_.reset();
61     // Call the base class method last because it calls `prefetch()`
62     super::reset();
63   }
64 
65   /// Queries the sampler for the next batch request (possibly progressing its
66   /// internal state).
get_batch_request()67   std::optional<BatchRequestType> get_batch_request() override {
68     auto indices = sampler_.next(this->options_.batch_size);
69     if (!indices ||
70         (indices->size() < this->options_.batch_size &&
71          this->options_.drop_last)) {
72       return nullopt;
73     }
74     AT_ASSERT(indices->size() > 0);
75     return indices;
76   }
77 
78   /// The `Sampler` used to produce batch requests.
79   Sampler sampler_;
80 };
81 } // namespace data
82 } // namespace torch
83