#pragma once

#include <c10/util/irange.h>
#include <torch/arg.h>
#include <torch/data/datasets/stateful.h>
#include <torch/data/samplers.h>

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include <torch/serialize.h>

namespace torch {
namespace data {
namespace datasets {

/// Interface for chunk reader, which performs data chunking and reading of
/// entire chunks.
///
/// A chunk could be an entire file, such as an audio data file or an image,
/// or part of a file in the case of a large text file split based on seek
/// positions.
template <
    typename ExampleType_,
    typename ChunkType_ = std::vector<ExampleType_>>
class ChunkDataReader {
 public:
  virtual ~ChunkDataReader() = default;

  using ChunkType = ChunkType_;
  using ExampleType = ExampleType_;

  /// Reads an entire chunk.
  virtual ChunkType read_chunk(size_t chunk_index) = 0;

  /// Returns the number of chunks available in this reader.
  virtual size_t chunk_count() = 0;

  /// Clears any internal state associated with this reader.
  virtual void reset() = 0;
};
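
/// A minimal sketch of a reader implementing this interface (the
/// `IntChunkReader` name and its synthetic in-memory chunks are illustrative,
/// not part of this header; `std::iota` requires <numeric>):
///
///   class IntChunkReader : public ChunkDataReader<int> {
///    public:
///     // ChunkDataset additionally expects the reader to expose a BatchType.
///     using BatchType = ChunkType;
///
///     // Each chunk holds 100 consecutive integers.
///     ChunkType read_chunk(size_t chunk_index) override {
///       ChunkType chunk(100);
///       std::iota(
///           chunk.begin(), chunk.end(), static_cast<int>(chunk_index * 100));
///       return chunk;
///     }
///
///     size_t chunk_count() override {
///       return 10;
///     }
///
///     void reset() override {}
///   };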

namespace detail {
/// BatchDataBuffer manages a queue of UnwrappedBatchData. After a new chunk
/// is loaded, BatchDataBuffer splits it into small batches and pushes them
/// into the queue. When get_batch is called from the data loader, it pops
/// cached batches and returns them. If the cache is empty, it either waits
/// for more chunks to be loaded or returns null if all chunks are loaded.
template <
    typename UnwrappedBatch,
    typename ExampleSampler = samplers::RandomSampler>
class BatchDataBuffer {
 public:
  using UnwrappedBatchType = UnwrappedBatch;
  using BatchType = torch::optional<UnwrappedBatchType>;
  using BatchRequestType = typename ExampleSampler::BatchRequestType;

  BatchDataBuffer(
      size_t batch_size,
      ExampleSampler& example_sampler,
      size_t queue_capacity)
      : batch_size_(batch_size),
        example_sampler_(example_sampler),
        queue_capacity_(queue_capacity) {}

  /// Return batch data from the queue. Called from the ChunkDataset main
  /// thread.
  BatchType get_batch() {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_read_.wait(lock, [this] {
      // Wait until there is data available in the queue or all chunks are
      // loaded (i.e. the dataset is exhausted for this epoch).
      return (
          this->total_example_count_in_queue_ >= batch_size_ || this->stop_);
    });
    if (batch_queue_.empty()) {
      AT_ASSERT(stop_);
      // All batches have been retrieved. Return an empty batch.
      return nullopt;
    }

    UnwrappedBatchData batch = std::move(batch_queue_.front());
    batch_queue_.pop();
    if (batch.exception) {
      throw WorkerException(batch.exception);
    }

    total_example_count_in_queue_ -= batch.batch_data.size();
    lock.unlock();
    cv_write_.notify_all();

    return batch.batch_data;
  }

  /// Push preloaded chunks to the batch queue. Called from the ChunkDataset
  /// worker threads.
  void add_chunk_data(UnwrappedBatchType data) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // Stop loading if we have preloaded enough data.
      return this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_;
    });
    if (stop_) {
      // When stop_ is true, no further chunk loading is necessary. Return
      // without any further processing.
      return;
    }

    auto data_size = data.size();
    auto remaining_size = data_size;
    example_sampler_.reset(data_size);

    auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) {
      auto batch_example_indices = this->example_sampler_.next(example_count);
      AT_ASSERT(
          batch_example_indices &&
          batch_example_indices.value().size() == example_count);
      BatchRequestType& indices = batch_example_indices.value();
      for (size_t i : indices) {
        TORCH_CHECK(i < data_size, "Index out of range");
        batch.emplace_back(std::move(data[i]));
      }
      remaining_size -= example_count;
    };

    if (!batch_queue_.empty()) {
      // If the queue has existing data and the last batch doesn't have enough
      // examples to fill a batch_size batch, add more examples to this batch
      // first.
      auto& batch = batch_queue_.back();
      size_t current_count = batch.batch_data.size();
      if (current_count < batch_size_) {
        auto example_count =
            std::min(remaining_size, batch_size_ - current_count);
        fill_batch(example_count, batch.batch_data);
      }
    }

    // If we still have data remaining after filling the last pushed batch,
    // add it to the queue too.
    // NOLINTNEXTLINE(bugprone-infinite-loop)
    while (remaining_size > 0) {
      UnwrappedBatchType current_batch;

      // Allocate the batch memory ahead of time.
      current_batch.reserve(batch_size_);

      auto example_count = std::min(remaining_size, batch_size_);
      fill_batch(example_count, current_batch);
      batch_queue_.emplace(std::move(current_batch));
    }
    total_example_count_in_queue_ += data_size;
    lock.unlock();
    cv_read_.notify_all();
  }

  /// Push exceptions thrown during preloading into the batch queue. Called
  /// from the ChunkDataset worker threads.
  void add_chunk_data(std::exception_ptr e_ptr) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // Stop loading if we have preloaded enough data.
      return (
          this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_);
    });
    if (stop_) {
      // When stop_ is true, the current thread is being torn down and the
      // batch buffer will be discarded, so there is no need to enqueue any
      // new exceptions.
      return;
    }

    batch_queue_.emplace(e_ptr);
    lock.unlock();
    cv_read_.notify_all();
  }

  void stop() {
    {
      // Hold the lock before changing stop_ to prevent a race condition which
      // can cause a deadlock. To be more specific, the condition variable
      // cv_write_ waits on the predicate stop_ in add_chunk_data(). The wait
      // happens in two steps: 1) while still holding the lock, check if the
      // predicate is true; 2) if it is true, proceed; otherwise, release the
      // lock and wait until notified. Without holding the lock, cv_write_'s
      // notification can happen in between steps 1) and 2). In that case,
      // since cv_write_ is not yet waiting, the notification is lost and
      // cv_write_ will sleep forever. By taking the lock before changing the
      // predicate stop_, we ensure that updating and evaluating stop_ always
      // happen in a synchronized way.
      std::lock_guard<std::mutex> lock(queue_mutex_);
      stop_ = true;
    }

    // Notify all writers, waking them from the wait so they exit the current
    // method.
    cv_write_.notify_all();
    // Notify all readers too.
    cv_read_.notify_all();
  }
  /// The batch size is needed to create batches from the chunk data. Similar
  /// to a regular dataloader, where batches are created with prefetches,
  /// BatchDataBuffer performs the batch creation using the provided batch
  /// size.
  size_t batch_size_ = 0;

  /// Count of total examples stored in the queue.
  size_t total_example_count_in_queue_ = 0;

  /// Struct that contains a raw unwrapped batch unit. An unwrapped batch unit
  /// is the raw data without the 'optional' wrapper. It can be a collection
  /// of images, utterances, etc.
  struct UnwrappedBatchData {
    explicit UnwrappedBatchData(UnwrappedBatchType data)
        : batch_data(std::move(data)) {}

    // NOLINTNEXTLINE(modernize-pass-by-value)
    explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {}

    /// Batch data to return.
    UnwrappedBatchType batch_data;

    /// Exception pointer which captures any abnormal exceptions thrown while
    /// creating the batch.
    std::exception_ptr exception;
  };

  /// Local cache to store example batches from the loaded chunks.
  std::queue<UnwrappedBatchData> batch_queue_;

  // Syncs batch_queue_ updates.
  std::mutex queue_mutex_;

  std::condition_variable cv_read_;
  std::condition_variable cv_write_;

  ExampleSampler& example_sampler_;

  // Configurable maximum number of elements the queue can hold at one time.
  size_t queue_capacity_;

  // When set to true, it wakes the writer threads from their wait so they
  // exit the current function call. This is needed when ChunkDataset::reset()
  // is called while the previous epoch is not yet exhausted. When
  // ChunkDataset is waiting for its preloaders to finish previous work before
  // tearing down the threads, a preloader could still be waiting on the
  // condition variable, causing the program to hang. This boolean is used to
  // break this waiting condition.
  bool stop_ = false;
};
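
/// The intended interplay, as a minimal sketch (illustrative only; in this
/// file the buffer is created and driven by `ChunkDataset`, not by user
/// code):
///
///   samplers::RandomSampler sampler(/*size=*/0);
///   BatchDataBuffer<std::vector<int>> buffer(
///       /*batch_size=*/10, sampler, /*queue_capacity=*/100);
///   // Worker threads: buffer.add_chunk_data(std::move(chunk));
///   // Main thread:    auto batch = buffer.get_batch(); // optional batch
///   // Teardown:       buffer.stop();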
} // namespace detail

/// Options to configure a `ChunkDataset`.
struct ChunkDatasetOptions {
  ChunkDatasetOptions() = delete;
  ChunkDatasetOptions(
      size_t preloader_count,
      size_t batch_size,
      size_t cache_size = 2048,
      size_t cross_chunk_shuffle_count = 1)
      : preloader_count_(preloader_count),
        batch_size_(batch_size),
        cache_size_(cache_size),
        cross_chunk_shuffle_count_(cross_chunk_shuffle_count) {
    TORCH_CHECK(
        preloader_count_ > 0,
        "Preloader count is 0. At least one preloader needs to be specified.");
    TORCH_CHECK(
        batch_size_ > 0,
        "Batch size is 0. A positive batch size needs to be specified.");
    TORCH_CHECK(
        cache_size_ > 0,
        "Cache size is 0. A positive cache size needs to be specified.");
    TORCH_CHECK(
        cache_size_ >= batch_size_,
        "Cache size is less than batch size. Cache needs to be large enough to "
        "hold at least one batch.");
    TORCH_CHECK(
        cross_chunk_shuffle_count_ > 0,
        "cross_chunk_shuffle_count needs to be greater than 0.");
  }

  /// The number of worker threads to preload chunk data.
  TORCH_ARG(size_t, preloader_count);

  /// The size of each batch.
  TORCH_ARG(size_t, batch_size);

  /// The capacity of the queue for batch caching.
  TORCH_ARG(size_t, cache_size) = 2048;

  // The number of chunks to perform cross-chunk shuffling. Defaults to 1,
  // meaning no cross-chunk shuffling. When it is equal to n (n > 1), n random
  // chunks will be loaded at once and example shuffling will be performed
  // across all of those n chunks.
  // Note: Usually the default config (1 chunk shuffle + example shuffle) is
  // good enough to generate randomly distributed data. Use this parameter
  // only if you know cross-chunk shuffling is needed in your case. Also,
  // there is a performance penalty when this value is greater than 1, as we
  // need to do an extra merge between multiple chunks before performing
  // example sampling.
  TORCH_ARG(size_t, cross_chunk_shuffle_count) = 1;
};
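
/// For example, the following configures four preloader threads, batches of
/// 128 examples, and a cache of 1024 examples (values are illustrative):
///
///   ChunkDatasetOptions options(
///       /*preloader_count=*/4,
///       /*batch_size=*/128,
///       /*cache_size=*/1024);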

/// A stateful dataset that supports hierarchical sampling and prefetching of
/// entire chunks.
///
/// Unlike regular datasets, a chunk dataset requires two samplers to operate
/// and keeps internal state. `ChunkSampler` selects which chunk to load next,
/// while the `ExampleSampler` determines the order of Examples that are
/// returned in each `get_batch` call. The hierarchical sampling approach used
/// here is inspired by this paper:
/// http://martin.zinkevich.org/publications/nips2010.pdf
template <
    typename ChunkReader,
    typename ChunkSampler = samplers::RandomSampler,
    typename ExampleSampler = samplers::RandomSampler>
class ChunkDataset final
    : public StatefulDataset<
          ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
          typename ChunkReader::BatchType,
          size_t> {
 public:
  using BatchType = torch::optional<typename ChunkReader::BatchType>;
  using UnwrappedBatchType = typename ChunkReader::BatchType;
  using BatchRequestType = size_t;
  using ChunkSamplerType = ChunkSampler;
  using ExampleSamplerType = ExampleSampler;

  ChunkDataset(
      ChunkReader chunk_reader,
      ChunkSampler chunk_sampler,
      ExampleSampler example_sampler,
      ChunkDatasetOptions options,
      std::function<void(UnwrappedBatchType&)> preprocessing_policy =
          std::function<void(UnwrappedBatchType&)>())
      : chunk_reader_(std::move(chunk_reader)),
        chunk_sampler_(std::move(chunk_sampler)),
        example_sampler_(std::move(example_sampler)),
        options_(std::move(options)),
        preprocessing_policy_(std::move(preprocessing_policy)),
        quit_worker_(false),
        running_preloaders_(0),
        load_checkpoint_(false) {}

  ~ChunkDataset() override {
    // Stop the batch buffer first.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    free_workers();
  }

  /// Default get_batch method of BatchDataset. This method returns
  /// Example batches created from the preloaded chunks. The implementation
  /// is dataset agnostic and does not need overriding in different chunk
  /// datasets.
  BatchType get_batch(size_t batch_size) override {
    TORCH_CHECK(
        batch_buffer_ != nullptr,
        "Dataset needs to call reset() before calling get_batch().");

    TORCH_CHECK(
        batch_size == options_.batch_size(),
        "The requested batch size does not match with the initialized batch size.\n"
        " The requested batch size is ",
        batch_size,
        ", while the dataset is created with batch size equal to ",
        options_.batch_size());
    return batch_buffer_->get_batch();
  }

  /// Helper method around get_batch(size_t), as `batch_size` is not strictly
  /// necessary.
  BatchType get_batch() {
    return get_batch(options_.batch_size());
  }

  /// This will clear any internal state and start the internal prefetching
  /// mechanism for the chunk dataset.
  void reset() override {
    // We need this to support partial data reads via the dataloader iterator.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    // Free workers from the previous reset, if any.
    free_workers();
    preload_threads_.clear();

    if (!load_checkpoint_) {
      chunk_reader_.reset();
      chunk_sampler_.reset(chunk_reader_.chunk_count());
      load_checkpoint_ = false;
    }

    // Throw out any existing cached batches in the buffer and re-create a new
    // chunk buffer.
    batch_buffer_ = std::make_unique<
        detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
        options_.batch_size(), example_sampler_, options_.cache_size());

    // Create new workers for this new epoch.
    quit_worker_ = false;

    AT_ASSERT(running_preloaders_ == 0);
    running_preloaders_ = options_.preloader_count();
    for (const auto i : c10::irange(options_.preloader_count())) {
      preload_threads_.emplace_back([this, i]() { this->preloader(i); });
    }
  }

  /// size() is not used for the chunk dataset.
  std::optional<size_t> size() const override {
    return torch::nullopt;
  }

  // Provides a reference to the chunk sampler. Used mainly in distributed
  // data loading to set the epoch number for the sampler.
  ChunkSamplerType& chunk_sampler() {
    return chunk_sampler_;
  }
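
  // For instance, with a `samplers::DistributedRandomSampler` as the
  // ChunkSampler, the epoch can be set before each epoch's reset() (a sketch;
  // `dataset` and `epoch` are assumed to exist at the call site):
  //
  //   dataset->chunk_sampler().set_epoch(epoch);
  //   dataset->reset();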

  void save(serialize::OutputArchive& archive) const override {
    std::lock_guard<std::mutex> lock(chunk_index_guard_);
    chunk_sampler_.save(archive);
  }

  void load(serialize::InputArchive& archive) override {
    std::lock_guard<std::mutex> lock(chunk_index_guard_);
    chunk_sampler_.load(archive);
    load_checkpoint_ = true;
  }
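
  // A checkpointing sketch using the serialize archives directly (the file
  // name is illustrative):
  //
  //   torch::serialize::OutputArchive output;
  //   dataset->save(output);
  //   output.save_to("chunk_dataset.pt");
  //
  //   torch::serialize::InputArchive input;
  //   input.load_from("chunk_dataset.pt");
  //   dataset->load(input); // the next reset() resumes from this checkpoint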

 private:
  /// Runs on a worker thread to preload chunk data.
  void preloader(size_t id) {
    while (!quit_worker_.load()) {
      try {
        std::vector<size_t> chunk_idx;
        {
          std::lock_guard<std::mutex> lock(chunk_index_guard_);
          if (auto chunk_sampler_result = chunk_sampler_.next(
                  this->options_.cross_chunk_shuffle_count())) {
            chunk_idx = chunk_sampler_result.value();
          } else {
            break;
          }
        }
        UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_idx[0]);
        for (const auto i : c10::irange(1, chunk_idx.size())) {
          auto chunk_data = chunk_reader_.read_chunk(chunk_idx[i]);
          std::move(
              chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
        }
        if (preprocessing_policy_) {
          preprocessing_policy_(data);
        }
        if (!data.empty()) { // skip empty chunks.
          batch_buffer_->add_chunk_data(std::move(data));
        }
      } catch (...) {
        batch_buffer_->add_chunk_data(std::current_exception());
      }
    }
    AT_ASSERT(running_preloaders_.load() > 0);
    --running_preloaders_;
    if (running_preloaders_.load() == 0) {
      // All preloaders have completed, so we can notify the batch_buffer.
      batch_buffer_->stop();
    }
  }

  /// Blocks the current thread until the workers finish execution and exit.
  void free_workers() {
    if (!quit_worker_.load()) {
      quit_worker_ = true;
      for (auto& worker_thread : preload_threads_) {
        worker_thread.join();
      }
    }
  }

 private:
  // Templated class that defines what a chunk is and how to read chunk data.
  // When a chunk is returned by chunk_reader_, ChunkDataset splits it into
  // batches and caches them in batch_buffer_.
  ChunkReader chunk_reader_;

  // Chunk sampler to shuffle different chunks.
  ChunkSamplerType chunk_sampler_;

  // Example sampler to shuffle examples within a specific chunk.
  ExampleSamplerType example_sampler_;

  // Batch data buffer which holds chunk data from the preloading threads.
  std::shared_ptr<
      detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
      batch_buffer_;

  // Worker thread pool.
  std::vector<std::thread> preload_threads_;

  /// The options the Dataset was configured with.
  const ChunkDatasetOptions options_;

  // Function wrapper to apply custom processing over chunk data. This is
  // considered an advanced parameter for developers who want to apply
  // pre-processing to the chunk data before sampling it into minibatches.
  // Unlike the collate function, this policy is applied at the chunk level
  // instead of the minibatch level. When a chunk of data is loaded (multiple
  // chunks if cross_chunk_shuffle_count_ is greater than 1), this policy is
  // applied to the full loaded data. It is useful if developers want to
  // perform pre-processing (like bucketing) on the chunk data before the
  // example sampler samples the data. By default it is an empty function, and
  // no action will be taken.
  std::function<void(UnwrappedBatchType&)> preprocessing_policy_;

  // Indicates whether the worker threads can be torn down.
  std::atomic<bool> quit_worker_;

  // Keeps track of running preloaders to notify the batch buffer. A value of
  // 0 indicates that the chunk loading is completed.
  std::atomic<size_t> running_preloaders_;

  // Mutex to synchronize chunk sampler next() calls.
  mutable std::mutex chunk_index_guard_;

  // Boolean value to indicate whether we need to load the checkpoint for
  // chunk_sampler_.
  bool load_checkpoint_;
};
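
/// Putting the pieces together, a typical wiring looks like the following
/// sketch (it assumes the hypothetical `IntChunkReader` from above; the
/// sampler sizes are placeholders, since reset() re-seeds the samplers with
/// the actual chunk and example counts):
///
///   auto dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
///       IntChunkReader,
///       samplers::SequentialSampler,
///       samplers::SequentialSampler>>(
///       IntChunkReader{},
///       samplers::SequentialSampler(0),
///       samplers::SequentialSampler(0),
///       datasets::ChunkDatasetOptions(
///           /*preloader_count=*/2, /*batch_size=*/10));
///
///   auto data_loader = torch::data::make_data_loader(
///       dataset, DataLoaderOptions(/*batch_size=*/10).workers(0));
///
///   for (auto& batch : *data_loader) {
///     // `batch` is a std::vector<int> with up to 10 examples.
///   }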
} // namespace datasets
} // namespace data
} // namespace torch