xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/dataloader/stateful.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/irange.h>
4 #include <torch/data/dataloader/base.h>
5 
6 #include <cstddef>
7 #include <thread>
8 #include <utility>
9 
10 namespace torch {
11 namespace data {
12 
13 /// A dataloader for stateful datasets.
14 ///
15 /// A dataloader for stateful datatasets differs from one for stateless
16 /// datasets one in that the dataset is shared among worker threads, and that
17 /// this dataset is itself responsible for producing batches rather than
18 /// depending on a sampler. The statefulness here actually refers to the
19 /// dataset. The StatefulDataLoader simply alters the data loading algorithm to
20 /// accommodate the stateful, shared nature of the dataset. Note that the
21 /// dataset must be thread safe if more than one worker thread is used.
22 ///
23 /// A stateful dataloader is created by calling `make_data_loader` with a
24 /// stateful dataset.
25 template <typename Dataset>
26 class StatefulDataLoader : public DataLoaderBase<
27                                Dataset,
28                                typename Dataset::BatchType::value_type,
29                                typename Dataset::BatchRequestType> {
30  public:
31   using super = DataLoaderBase<
32       Dataset,
33       typename Dataset::BatchType::value_type,
34       typename Dataset::BatchRequestType>;
35   using typename super::BatchRequestType;
36 
37   /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`.
StatefulDataLoader(Dataset dataset,DataLoaderOptions options)38   StatefulDataLoader(Dataset dataset, DataLoaderOptions options)
39       : super(options, std::make_unique<Dataset>(std::move(dataset))) {
40     for ([[maybe_unused]] const auto _ : c10::irange(this->options_.workers)) {
41       // As opposed to the stateless case, here all worker threads access the
42       // same underlying dataset.
43       this->workers_.emplace_back(
44           [this] { this->worker_thread(*this->main_thread_dataset_); });
45     }
46   }
47 
48  private:
49   /// Resets the internal state of the dataloader and the dataset.
reset()50   void reset() override {
51     this->main_thread_dataset_->reset();
52     // Call the base class method last because it calls `prefetch()`
53     super::reset();
54   }
55 
56   /// For stateful datasets, the batch request is always the batch size. The
57   /// dataset is responsible for determining what goes into the batch next.
get_batch_request()58   std::optional<BatchRequestType> get_batch_request() override {
59     return this->options_.batch_size;
60   }
61 };
62 } // namespace data
63 } // namespace torch
64