// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/dataloader/base.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/data/dataloader_options.h>
#include <torch/data/detail/data_shuttle.h>
#include <torch/data/detail/sequencers.h>
#include <torch/data/iterator.h>
#include <torch/data/samplers/random.h>
#include <torch/data/worker_exception.h>
#include <torch/types.h>

#include <torch/csrc/utils/variadic.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <exception>
#include <memory>
#include <optional>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

24 namespace torch {
25 namespace data {
26 template <typename Dataset, typename Batch, typename BatchRequest>
27 class DataLoaderBase {
28  public:
29   using BatchType = Batch;
30   using BatchRequestType = BatchRequest;
31 
32   /// Constructs a new DataLoader from a `dataset` to sample from, `options`
33   /// to configure the DataLoader with, and a `sampler` that specifies the
34   /// sampling strategy.
35   DataLoaderBase(
36       DataLoaderOptions options,
37       std::unique_ptr<Dataset> main_thread_dataset = nullptr)
options_(std::move (options))38       : options_(std::move(options)),
39         main_thread_dataset_(std::move(main_thread_dataset)),
40         sequencer_(new_sequencer()) {}
41 
42   // NOLINTNEXTLINE(bugprone-exception-escape)
~DataLoaderBase()43   virtual ~DataLoaderBase() {
44     join();
45   }
46 
47   /// Returns an iterator into the DataLoader. The lifetime of the iterator is
48   /// bound to the DataLoader. In C++ standards language, the category of the
49   /// iterator is `OutputIterator`. See
50   /// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this
51   /// means. In short: you may increment the iterator and dereference it, but
52   /// cannot go back, or step forward more than one position at a time. When the
53   /// DataLoader is exhausted, it will compare equal with the special
54   /// "sentinel" iterator returned by `DataLoader::end()`. Most of the time, you
55   /// should only use range-for loops to loop over the DataLoader, but
56   /// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(),
57   /// output_iterator)`  are supported too.
begin()58   Iterator<Batch> begin() {
59     TORCH_CHECK(
60         shuttle_.in_flight_jobs() == 0,
61         "Attempted to get a new DataLoader iterator "
62         "while another iterator is not yet exhausted");
63     reset();
64     return Iterator<Batch>(std::make_unique<detail::ValidIterator<Batch>>(
65         [this] { return this->next(); }));
66   }
67 
68   /// Returns a special "sentinel" iterator that compares equal with a
69   /// non-sentinel iterator once the DataLoader is exhausted.
end()70   Iterator<Batch> end() {
71     return Iterator<Batch>(std::make_unique<detail::SentinelIterator<Batch>>());
72   }
73 
74   /// Joins the DataLoader's worker threads and drains internal queues.
75   /// This function may only be invoked from the main thread (in which the
76   /// DataLoader lives).
join()77   void join() {
78     if (joined_) {
79       return;
80     }
81     shuttle_.drain();
82     // Send one 'quit' message per worker. Since a worker dies (exits its
83     // thread) after receiving this message, each `QuitWorker()` message will be
84     // read by exactly one worker.
85     for (const auto w : c10::irange(options_.workers)) {
86       (void)w; // Suppress unused variable warning
87       push_job(QuitWorker());
88     }
89     for (auto& worker : workers_) {
90       worker.join();
91     }
92     joined_ = true;
93   }
94 
95   /// Returns the options with which the DataLoader was configured.
options()96   const FullDataLoaderOptions& options() const noexcept {
97     return options_;
98   }
99 
100  protected:
101   /// Simple mix-in to give something a sequence number.
102   struct Sequenced {
103     Sequenced() = default;
SequencedSequenced104     Sequenced(size_t sqn) : sequence_number(sqn) {}
105     size_t sequence_number;
106   };
107 
108   struct QuitWorker {};
109 
110   /// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a
111   /// `QuitWorker` object, to indicate the worker should shut down.
112   struct Job : Sequenced {
113     Job() = default;
JobJob114     Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
JobJob115     Job(BatchRequest&& i, size_t sqn)
116         : Sequenced(sqn), batch_request(std::move(i)) {}
117     std::optional<QuitWorker> quit;
118     std::optional<BatchRequest> batch_request;
119   };
120 
121   /// The finished result of a job.
122   struct Result : Sequenced {
123     Result() = default;
ResultResult124     Result(std::optional<Batch>&& b, size_t sqn)
125         : Sequenced(sqn), batch(std::move(b)) {}
ResultResult126     Result(std::exception_ptr exception, size_t sqn)
127         : Sequenced(sqn), exception(std::move(exception)) {}
128     std::optional<Batch> batch;
129     std::exception_ptr exception;
130   };
131 
132   /// Subclass hook for getting the next batch request. The stateless case will
133   /// ask the sampler for a new batch request (e.g. a vector of indices), while
134   /// the stateful one will simply return the batch size.
135   virtual std::optional<BatchRequestType> get_batch_request() = 0;
136 
137   /// Resets the internal state of the DataLoader, optionally pre-fetching
138   /// new jobs.
reset()139   virtual void reset() {
140     shuttle_.drain();
141     sequence_number_ = 0;
142     sequencer_ = new_sequencer();
143     prefetch();
144   }
145 
146   /// Schedules `requested_jobs` many new batches to be fetched. The actual
147   /// number of jobs scheduled may be less if the DataLoader exhausts.
prefetch(size_t requested_jobs)148   void prefetch(size_t requested_jobs) {
149     for (const auto r : c10::irange(requested_jobs)) {
150       (void)r; // Suppress unused variable
151       if (auto batch_request = get_batch_request()) {
152         this->push_job(std::move(*batch_request));
153       } else {
154         break;
155       }
156     }
157   }
158 
159   /// Schedules the maximum number of jobs (based on the `max_jobs` option).
prefetch()160   void prefetch() {
161     prefetch(options_.max_jobs);
162   }
163 
164   /// Returns the next batch of data, or an empty `optional` if the DataLoader
165   /// is exhausted. This operation will block until a batch is available if one
166   /// is still expected.
next()167   std::optional<BatchType> next() {
168     if (options_.workers > 0) {
169       while (std::optional<Result> result = this->pop_result()) {
170         if (result->exception) {
171           throw WorkerException(result->exception);
172         } else if (result->batch) {
173           prefetch(1);
174           return std::move(result->batch);
175         }
176       }
177     } else if (auto batch_request = get_batch_request()) {
178       return this->main_thread_dataset_->get_batch(std::move(*batch_request));
179     }
180     return nullopt;
181   }
182 
183   /// The function that worker threads run.
worker_thread(Dataset & dataset)184   void worker_thread(Dataset& dataset) {
185     while (true) {
186       auto job = shuttle_.pop_job();
187       if (job.quit) {
188         break;
189       }
190       try {
191         auto batch = dataset.get_batch(std::move(*job.batch_request));
192         shuttle_.push_result({std::move(batch), job.sequence_number});
193       } catch (...) {
194         shuttle_.push_result({std::current_exception(), job.sequence_number});
195       }
196     }
197   }
198 
199   /// Convenience method that calls `shuttle_.push_job()` with the next sequence
200   /// number.
201   template <typename T>
push_job(T value)202   void push_job(T value) {
203     shuttle_.push_job({std::move(value), sequence_number_++});
204   }
205 
206   /// Convenience method that gets the next result from the sequencer.
pop_result()207   std::optional<Result> pop_result() {
208     return sequencer_->next(
209         [this] { return this->shuttle_.pop_result(this->options_.timeout); });
210   }
211 
212   /// Convenience method that creates a new sequencer based on the
213   /// `enforce_ordering` option.
new_sequencer()214   std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
215     if (options_.enforce_ordering) {
216       return std::make_unique<detail::sequencers::OrderedSequencer<Result>>(
217           options_.max_jobs);
218     }
219     return std::make_unique<detail::sequencers::NoSequencer<Result>>();
220   }
221 
222   /// The options the DataLoader was configured with.
223   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
224   const FullDataLoaderOptions options_;
225 
226   /// The dataset for the main thread, only has a value if the number of
227   /// worker threads was configured as zero, meaning the main thread has to do
228   /// all the work (synchronously). NOTE: Really want this to be on the heap
229   /// when empty, therefore `unique_ptr` and not `optional`.
230   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
231   std::unique_ptr<Dataset> main_thread_dataset_;
232 
233   /// The sequence number for the *next* batch to be retrieved from the
234   /// dataset.
235   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
236   size_t sequence_number_ = 0;
237 
238   /// The worker threads, running the `worker_thread()` method.
239   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
240   std::vector<std::thread> workers_;
241 
242   /// The `DataShuttle` which takes care of the life cycle of a job.
243   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
244   detail::DataShuttle<Job, Result> shuttle_;
245 
246   /// The `Sequencer`, which handles optional ordering of batches.
247   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
248   std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;
249 
250   /// True if the DataLoader has joined its worker threads.
251   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
252   bool joined_ = false;
253 };
254 } // namespace data
255 } // namespace torch
256