#pragma once

#include <torch/data/dataloader/base.h>
#include <torch/data/worker_exception.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <thread>
#include <utility>

namespace torch {
namespace data {

/// A dataloader for stateless datasets.
///
/// This dataloader follows the traditional PyTorch dataloader design, whereby
/// a (possibly) stateful sampler produces *batch requests* for a stateless
/// dataset, which acts as a simple batch-request-to-batch mapping. The batch
/// request will often be an array of indices, and if the dataset is a simple
/// image dataset, the dataset would produce the images at those indices.
template <typename Dataset, typename Sampler>
class StatelessDataLoader : public DataLoaderBase<
                                Dataset,
                                typename Dataset::BatchType,
                                typename Sampler::BatchRequestType> {
 public:
  using super = DataLoaderBase<
      Dataset,
      typename Dataset::BatchType,
      typename Sampler::BatchRequestType>;
  using typename super::BatchRequestType;

  /// Constructs the `StatelessDataLoader` from a `dataset`, a `sampler` and
  /// some `options`.
  StatelessDataLoader(
      Dataset dataset,
      Sampler sampler,
      DataLoaderOptions options)
      : super(std::move(options)), sampler_(std::move(sampler)) {
    for (const auto w : c10::irange(this->options_.workers)) {
      // Here we copy the dataset into the worker thread closure. Each worker
      // has its own copy of the dataset. This means the dataset must be
      // trivially copyable, or else we don't expect more than one worker to
      // be in use.
      (void)w; // Suppress unused variable warning
      this->workers_.emplace_back(
          [this, dataset]() mutable { this->worker_thread(dataset); });
    }
    if (this->options_.workers == 0) {
      this->main_thread_dataset_ =
          std::make_unique<Dataset>(std::move(dataset));
    }
  }

 private:
  /// Resets the internal state of the dataloader and the sampler.
  void reset() override {
    sampler_.reset();
    // Call the base class method last because it calls `prefetch()`.
    super::reset();
  }

  /// Queries the sampler for the next batch request (possibly progressing its
  /// internal state).
  std::optional<BatchRequestType> get_batch_request() override {
    auto indices = sampler_.next(this->options_.batch_size);
    if (!indices ||
        (indices->size() < this->options_.batch_size &&
         this->options_.drop_last)) {
      return std::nullopt;
    }
    AT_ASSERT(indices->size() > 0);
    return indices;
  }

  /// The `Sampler` used to produce batch requests.
  Sampler sampler_;
};
} // namespace data
} // namespace torch
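
// A minimal usage sketch (illustration only, not part of this header): a
// `StatelessDataLoader` is usually obtained through
// `torch::data::make_data_loader`, which pairs a stateless dataset with a
// sampler and returns a loader of this type. The dataset path "./mnist", the
// choice of `SequentialSampler`, and the batch size / worker counts below are
// assumptions made for the example.
//
//   #include <torch/torch.h>
//
//   // Stack<> collates individual examples into one batched Example.
//   auto dataset = torch::data::datasets::MNIST("./mnist")
//                      .map(torch::data::transforms::Stack<>());
//
//   // The sampler produces index-based batch requests; the dataset maps
//   // each request to a batch.
//   auto loader =
//       torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
//           std::move(dataset),
//           torch::data::DataLoaderOptions().batch_size(64).workers(2));
//
//   for (auto& batch : *loader) {
//     // batch.data and batch.target hold the stacked tensors for one batch.
//   }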