1 #pragma once 2 3 #include <c10/util/irange.h> 4 #include <torch/data/dataloader/base.h> 5 6 #include <cstddef> 7 #include <thread> 8 #include <utility> 9 10 namespace torch { 11 namespace data { 12 13 /// A dataloader for stateful datasets. 14 /// 15 /// A dataloader for stateful datatasets differs from one for stateless 16 /// datasets one in that the dataset is shared among worker threads, and that 17 /// this dataset is itself responsible for producing batches rather than 18 /// depending on a sampler. The statefulness here actually refers to the 19 /// dataset. The StatefulDataLoader simply alters the data loading algorithm to 20 /// accommodate the stateful, shared nature of the dataset. Note that the 21 /// dataset must be thread safe if more than one worker thread is used. 22 /// 23 /// A stateful dataloader is created by calling `make_data_loader` with a 24 /// stateful dataset. 25 template <typename Dataset> 26 class StatefulDataLoader : public DataLoaderBase< 27 Dataset, 28 typename Dataset::BatchType::value_type, 29 typename Dataset::BatchRequestType> { 30 public: 31 using super = DataLoaderBase< 32 Dataset, 33 typename Dataset::BatchType::value_type, 34 typename Dataset::BatchRequestType>; 35 using typename super::BatchRequestType; 36 37 /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`. StatefulDataLoader(Dataset dataset,DataLoaderOptions options)38 StatefulDataLoader(Dataset dataset, DataLoaderOptions options) 39 : super(options, std::make_unique<Dataset>(std::move(dataset))) { 40 for ([[maybe_unused]] const auto _ : c10::irange(this->options_.workers)) { 41 // As opposed to the stateless case, here all worker threads access the 42 // same underlying dataset. 43 this->workers_.emplace_back( 44 [this] { this->worker_thread(*this->main_thread_dataset_); }); 45 } 46 } 47 48 private: 49 /// Resets the internal state of the dataloader and the dataset. reset()50 void reset() override { 51 this->main_thread_dataset_->reset(); 52 // Call the base class method last because it calls `prefetch()` 53 super::reset(); 54 } 55 56 /// For stateful datasets, the batch request is always the batch size. The 57 /// dataset is responsible for determining what goes into the batch next. get_batch_request()58 std::optional<BatchRequestType> get_batch_request() override { 59 return this->options_.batch_size; 60 } 61 }; 62 } // namespace data 63 } // namespace torch 64