1 #pragma once 2 3 #include <torch/data/dataloader_options.h> 4 #include <torch/data/detail/data_shuttle.h> 5 #include <torch/data/detail/sequencers.h> 6 #include <torch/data/iterator.h> 7 #include <torch/data/samplers/random.h> 8 #include <torch/data/worker_exception.h> 9 #include <torch/types.h> 10 11 #include <torch/csrc/utils/variadic.h> 12 13 #include <c10/util/Exception.h> 14 #include <c10/util/irange.h> 15 16 #include <cstddef> 17 #include <exception> 18 #include <memory> 19 #include <thread> 20 #include <type_traits> 21 #include <utility> 22 #include <vector> 23 24 namespace torch { 25 namespace data { 26 template <typename Dataset, typename Batch, typename BatchRequest> 27 class DataLoaderBase { 28 public: 29 using BatchType = Batch; 30 using BatchRequestType = BatchRequest; 31 32 /// Constructs a new DataLoader from a `dataset` to sample from, `options` 33 /// to configure the DataLoader with, and a `sampler` that specifies the 34 /// sampling strategy. 35 DataLoaderBase( 36 DataLoaderOptions options, 37 std::unique_ptr<Dataset> main_thread_dataset = nullptr) options_(std::move (options))38 : options_(std::move(options)), 39 main_thread_dataset_(std::move(main_thread_dataset)), 40 sequencer_(new_sequencer()) {} 41 42 // NOLINTNEXTLINE(bugprone-exception-escape) ~DataLoaderBase()43 virtual ~DataLoaderBase() { 44 join(); 45 } 46 47 /// Returns an iterator into the DataLoader. The lifetime of the iterator is 48 /// bound to the DataLoader. In C++ standards language, the category of the 49 /// iterator is `OutputIterator`. See 50 /// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this 51 /// means. In short: you may increment the iterator and dereference it, but 52 /// cannot go back, or step forward more than one position at a time. When the 53 /// DataLoader is exhausted, it will compare equal with the special 54 /// "sentinel" iterator returned by `DataLoader::end()`. Most of the time, you 55 /// should only use range-for loops to loop over the DataLoader, but 56 /// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(), 57 /// output_iterator)` are supported too. begin()58 Iterator<Batch> begin() { 59 TORCH_CHECK( 60 shuttle_.in_flight_jobs() == 0, 61 "Attempted to get a new DataLoader iterator " 62 "while another iterator is not yet exhausted"); 63 reset(); 64 return Iterator<Batch>(std::make_unique<detail::ValidIterator<Batch>>( 65 [this] { return this->next(); })); 66 } 67 68 /// Returns a special "sentinel" iterator that compares equal with a 69 /// non-sentinel iterator once the DataLoader is exhausted. end()70 Iterator<Batch> end() { 71 return Iterator<Batch>(std::make_unique<detail::SentinelIterator<Batch>>()); 72 } 73 74 /// Joins the DataLoader's worker threads and drains internal queues. 75 /// This function may only be invoked from the main thread (in which the 76 /// DataLoader lives). join()77 void join() { 78 if (joined_) { 79 return; 80 } 81 shuttle_.drain(); 82 // Send one 'quit' message per worker. Since a worker dies (exits its 83 // thread) after receiving this message, each `QuitWorker()` message will be 84 // read by exactly one worker. 85 for (const auto w : c10::irange(options_.workers)) { 86 (void)w; // Suppress unused variable warning 87 push_job(QuitWorker()); 88 } 89 for (auto& worker : workers_) { 90 worker.join(); 91 } 92 joined_ = true; 93 } 94 95 /// Returns the options with which the DataLoader was configured. options()96 const FullDataLoaderOptions& options() const noexcept { 97 return options_; 98 } 99 100 protected: 101 /// Simple mix-in to give something a sequence number. 102 struct Sequenced { 103 Sequenced() = default; SequencedSequenced104 Sequenced(size_t sqn) : sequence_number(sqn) {} 105 size_t sequence_number; 106 }; 107 108 struct QuitWorker {}; 109 110 /// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a 111 /// `QuitWorker` object, to indicate the worker should shut down. 112 struct Job : Sequenced { 113 Job() = default; JobJob114 Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {} JobJob115 Job(BatchRequest&& i, size_t sqn) 116 : Sequenced(sqn), batch_request(std::move(i)) {} 117 std::optional<QuitWorker> quit; 118 std::optional<BatchRequest> batch_request; 119 }; 120 121 /// The finished result of a job. 122 struct Result : Sequenced { 123 Result() = default; ResultResult124 Result(std::optional<Batch>&& b, size_t sqn) 125 : Sequenced(sqn), batch(std::move(b)) {} ResultResult126 Result(std::exception_ptr exception, size_t sqn) 127 : Sequenced(sqn), exception(std::move(exception)) {} 128 std::optional<Batch> batch; 129 std::exception_ptr exception; 130 }; 131 132 /// Subclass hook for getting the next batch request. The stateless case will 133 /// ask the sampler for a new batch request (e.g. a vector of indices), while 134 /// the stateful one will simply return the batch size. 135 virtual std::optional<BatchRequestType> get_batch_request() = 0; 136 137 /// Resets the internal state of the DataLoader, optionally pre-fetching 138 /// new jobs. reset()139 virtual void reset() { 140 shuttle_.drain(); 141 sequence_number_ = 0; 142 sequencer_ = new_sequencer(); 143 prefetch(); 144 } 145 146 /// Schedules `requested_jobs` many new batches to be fetched. The actual 147 /// number of jobs scheduled may be less if the DataLoader exhausts. prefetch(size_t requested_jobs)148 void prefetch(size_t requested_jobs) { 149 for (const auto r : c10::irange(requested_jobs)) { 150 (void)r; // Suppress unused variable 151 if (auto batch_request = get_batch_request()) { 152 this->push_job(std::move(*batch_request)); 153 } else { 154 break; 155 } 156 } 157 } 158 159 /// Schedules the maximum number of jobs (based on the `max_jobs` option). prefetch()160 void prefetch() { 161 prefetch(options_.max_jobs); 162 } 163 164 /// Returns the next batch of data, or an empty `optional` if the DataLoader 165 /// is exhausted. This operation will block until a batch is available if one 166 /// is still expected. next()167 std::optional<BatchType> next() { 168 if (options_.workers > 0) { 169 while (std::optional<Result> result = this->pop_result()) { 170 if (result->exception) { 171 throw WorkerException(result->exception); 172 } else if (result->batch) { 173 prefetch(1); 174 return std::move(result->batch); 175 } 176 } 177 } else if (auto batch_request = get_batch_request()) { 178 return this->main_thread_dataset_->get_batch(std::move(*batch_request)); 179 } 180 return nullopt; 181 } 182 183 /// The function that worker threads run. worker_thread(Dataset & dataset)184 void worker_thread(Dataset& dataset) { 185 while (true) { 186 auto job = shuttle_.pop_job(); 187 if (job.quit) { 188 break; 189 } 190 try { 191 auto batch = dataset.get_batch(std::move(*job.batch_request)); 192 shuttle_.push_result({std::move(batch), job.sequence_number}); 193 } catch (...) { 194 shuttle_.push_result({std::current_exception(), job.sequence_number}); 195 } 196 } 197 } 198 199 /// Convenience method that calls `shuttle_.push_job()` with the next sequence 200 /// number. 201 template <typename T> push_job(T value)202 void push_job(T value) { 203 shuttle_.push_job({std::move(value), sequence_number_++}); 204 } 205 206 /// Convenience method that gets the next result from the sequencer. pop_result()207 std::optional<Result> pop_result() { 208 return sequencer_->next( 209 [this] { return this->shuttle_.pop_result(this->options_.timeout); }); 210 } 211 212 /// Convenience method that creates a new sequencer based on the 213 /// `enforce_ordering` option. new_sequencer()214 std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() { 215 if (options_.enforce_ordering) { 216 return std::make_unique<detail::sequencers::OrderedSequencer<Result>>( 217 options_.max_jobs); 218 } 219 return std::make_unique<detail::sequencers::NoSequencer<Result>>(); 220 } 221 222 /// The options the DataLoader was configured with. 223 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 224 const FullDataLoaderOptions options_; 225 226 /// The dataset for the main thread, only has a value if the number of 227 /// worker threads was configured as zero, meaning the main thread has to do 228 /// all the work (synchronously). NOTE: Really want this to be on the heap 229 /// when empty, therefore `unique_ptr` and not `optional`. 230 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 231 std::unique_ptr<Dataset> main_thread_dataset_; 232 233 /// The sequence number for the *next* batch to be retrieved from the 234 /// dataset. 235 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 236 size_t sequence_number_ = 0; 237 238 /// The worker threads, running the `worker_thread()` method. 239 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 240 std::vector<std::thread> workers_; 241 242 /// The `DataShuttle` which takes care of the life cycle of a job. 243 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 244 detail::DataShuttle<Job, Result> shuttle_; 245 246 /// The `Sequencer`, which handles optional ordering of batches. 247 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 248 std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_; 249 250 /// True if the DataLoader has joined its worker threads. 251 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 252 bool joined_ = false; 253 }; 254 } // namespace data 255 } // namespace torch 256