#pragma once

#include <c10/util/irange.h>
#include <torch/arg.h>
#include <torch/data/datasets/stateful.h>
#include <torch/data/samplers.h>
#include <torch/data/worker_exception.h>

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include <torch/serialize.h>

namespace torch {
namespace data {
namespace datasets {

/// Interface for chunk reader, which performs data chunking and reading of
/// entire chunks.
///
/// A chunk could be an entire file, such as an audio data file or an image,
/// or part of a file in the case of a large text file split based on seek
/// positions.
template <
    typename ExampleType_,
    typename ChunkType_ = std::vector<ExampleType_>>
class ChunkDataReader {
 public:
  virtual ~ChunkDataReader() = default;

  using ChunkType = ChunkType_;
  using ExampleType = ExampleType_;

  /// Reads an entire chunk.
  virtual ChunkType read_chunk(size_t chunk_index) = 0;

  /// Returns the number of chunks available in this reader.
  virtual size_t chunk_count() = 0;

  /// Clears any internal state associated with this reader.
  virtual void reset() = 0;
};
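// As an illustration, a minimal ChunkDataReader implementation might look
// like the sketch below. `InMemoryChunkReader` is a hypothetical name, not
// part of this header; it serves fixed-size chunks of integers from a single
// in-memory vector.
//
//   class InMemoryChunkReader : public ChunkDataReader<int> {
//    public:
//     explicit InMemoryChunkReader(std::vector<int> data, size_t chunk_size)
//         : data_(std::move(data)), chunk_size_(chunk_size) {}
//
//     ChunkType read_chunk(size_t chunk_index) override {
//       const size_t begin = chunk_index * chunk_size_;
//       const size_t end = std::min(begin + chunk_size_, data_.size());
//       return ChunkType(data_.begin() + begin, data_.begin() + end);
//     }
//
//     size_t chunk_count() override {
//       // Round up so a trailing partial chunk is still counted.
//       return (data_.size() + chunk_size_ - 1) / chunk_size_;
//     }
//
//     void reset() override {} // no internal state to clear
//
//    private:
//     std::vector<int> data_;
//     size_t chunk_size_;
//   };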

namespace detail {

/// BatchDataBuffer manages a queue of UnwrappedBatchData. After a new chunk
/// is loaded, BatchDataBuffer splits it into small batches and pushes them
/// into the queue. When get_batch is called from the data loader, it pops
/// cached batches and returns them. If the cache is empty, it either waits
/// for more chunks to be loaded or returns null if all chunks are loaded.
template <
    typename UnwrappedBatch,
    typename ExampleSampler = samplers::RandomSampler>
class BatchDataBuffer {
 public:
  using UnwrappedBatchType = UnwrappedBatch;
  using BatchType = torch::optional<UnwrappedBatchType>;
  using BatchRequestType = typename ExampleSampler::BatchRequestType;

  BatchDataBuffer(
      size_t batch_size,
      ExampleSampler& example_sampler,
      size_t queue_capacity)
      : batch_size_(batch_size),
        example_sampler_(example_sampler),
        queue_capacity_(queue_capacity) {}

  /// Returns batch data from the queue. Called from the ChunkDataset main
  /// thread.
  BatchType get_batch() {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_read_.wait(lock, [this] {
      // wait until there is available data in the queue or all chunks are
      // loaded (i.e. the dataset is exhausted for this epoch)
      return (
          this->total_example_count_in_queue_ >= batch_size_ || this->stop_);
    });
    if (batch_queue_.empty()) {
      AT_ASSERT(stop_);
      // All batches have been retrieved. Return an empty batch.
      return nullopt;
    }

    UnwrappedBatchData batch = std::move(batch_queue_.front());
    batch_queue_.pop();
    if (batch.exception) {
      throw WorkerException(batch.exception);
    }

    total_example_count_in_queue_ -= batch.batch_data.size();
    lock.unlock();
    cv_write_.notify_all();

    return batch.batch_data;
  }

  /// Pushes preloaded chunks to the batch queue. Called from the ChunkDataset
  /// worker threads.
  void add_chunk_data(UnwrappedBatchType data) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // stop loading if we have preloaded enough data.
      return this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_;
    });
    if (stop_) {
      // When stop_ is true, no further chunk loading is necessary. Return
      // without any further processing.
      return;
    }

    auto data_size = data.size();
    auto remaining_size = data_size;
    example_sampler_.reset(data_size);

    auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) {
      auto batch_example_indices = this->example_sampler_.next(example_count);
      AT_ASSERT(
          batch_example_indices &&
          batch_example_indices.value().size() == example_count);
      BatchRequestType& indices = batch_example_indices.value();
      for (size_t i : indices) {
        TORCH_CHECK(i < data_size, "Index out of range");
        batch.emplace_back(std::move(data[i]));
      }
      remaining_size -= example_count;
    };

    if (!batch_queue_.empty()) {
      // If the queue has existing data and the last batch doesn't have enough
      // examples to fill a batch_size batch, add more examples to that batch
      // first.
      auto& batch = batch_queue_.back();
      size_t current_count = batch.batch_data.size();
      if (current_count < batch_size_) {
        auto example_count =
            std::min(remaining_size, batch_size_ - current_count);
        fill_batch(example_count, batch.batch_data);
      }
    }

    // If we still have data remaining after filling the last pushed batch,
    // add it to the queue too.
    // NOLINTNEXTLINE(bugprone-infinite-loop)
    while (remaining_size > 0) {
      UnwrappedBatchType current_batch;

      // Allocate the batch memory ahead of time.
      current_batch.reserve(batch_size_);

      auto example_count = std::min(remaining_size, batch_size_);
      fill_batch(example_count, current_batch);
      batch_queue_.emplace(std::move(current_batch));
    }
    total_example_count_in_queue_ += data_size;
    lock.unlock();
    cv_read_.notify_all();
  }

  /// Pushes exceptions thrown during preloading into the batch queue. Called
  /// from the ChunkDataset worker threads.
  void add_chunk_data(std::exception_ptr e_ptr) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // stop loading if we have preloaded enough data.
      return (
          this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_);
    });
    if (stop_) {
      // When stop_ is true, the current thread needs to be torn down and the
      // batch buffer will be discarded, so there is no need to enqueue any
      // new exceptions.
      return;
    }

    batch_queue_.emplace(e_ptr);
    lock.unlock();
    cv_read_.notify_all();
  }

  void stop() {
    {
      // Hold the lock before changing stop_ to prevent a race condition that
      // can cause a deadlock. To be more specific, the condition variable
      // cv_write_ waits on the predicate stop_ in add_chunk_data(). The wait
      // happens in two steps: 1) while still holding the lock, check whether
      // the predicate is true; 2) if it is true, proceed; otherwise, release
      // the lock and wait until notified. Without holding the lock,
      // cv_write_'s notification can happen between steps 1) and 2). In that
      // case, since cv_write_ is not yet waiting, the notification is lost
      // and cv_write_ will sleep forever. By taking the lock before changing
      // the predicate stop_, updating and evaluating stop_ are guaranteed to
      // happen in a synchronized way.
      std::lock_guard<std::mutex> lock(queue_mutex_);
      stop_ = true;
    }

    // Notify all writers to wake them from the wait so they exit the current
    // method.
    cv_write_.notify_all();
    // Notify all readers too.
    cv_read_.notify_all();
  }
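
  // The locking discipline above is the standard guarded-predicate pattern
  // for condition variables. In minimal form (a sketch, independent of this
  // class):
  //
  //   std::mutex m;
  //   std::condition_variable cv;
  //   bool ready = false;
  //
  //   // waiting side (cf. add_chunk_data):
  //   std::unique_lock<std::mutex> lock(m);
  //   cv.wait(lock, [&] { return ready; });
  //
  //   // signaling side (cf. stop): update the predicate while holding the
  //   // lock, so the update cannot slip between the waiter's predicate
  //   // check and its sleep, then notify.
  //   {
  //     std::lock_guard<std::mutex> guard(m);
  //     ready = true;
  //   }
  //   cv.notify_all();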

  /// The batch size is needed to create batches from the chunk data. Similar
  /// to a regular dataloader, where batches are created with prefetches,
  /// BatchDataBuffer performs the batch creation using the provided batch
  /// size.
  size_t batch_size_ = 0;

  /// Total count of the examples stored in the queue.
  size_t total_example_count_in_queue_ = 0;

  /// Struct that contains a raw unwrapped batch unit. An unwrapped batch unit
  /// is the raw data without the 'optional' wrapper. It can be a collection
  /// of images, utterances, etc.
  struct UnwrappedBatchData {
    explicit UnwrappedBatchData(UnwrappedBatchType data)
        : batch_data(std::move(data)) {}

    // NOLINTNEXTLINE(modernize-pass-by-value)
    explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {}

    /// batch data to return
    UnwrappedBatchType batch_data;

    /// exception pointer which captures any abnormal exceptions thrown while
    /// creating the batch.
    std::exception_ptr exception;
  };

  /// Local cache to store example batches from the loaded chunks.
  std::queue<UnwrappedBatchData> batch_queue_;

  // Synchronizes batch_queue_ updates.
  std::mutex queue_mutex_;

  std::condition_variable cv_read_;
  std::condition_variable cv_write_;

  ExampleSampler& example_sampler_;

  // Configurable maximum number of elements the queue can hold at one time.
  size_t queue_capacity_;

  // When set to true, it wakes the writer threads from the wait so they exit
  // the current function call. This is needed when ChunkDataset::reset() is
  // called while the previous epoch is not yet exhausted. When ChunkDataset
  // is waiting for its preloaders to finish their previous work before
  // tearing down the threads, a preloader could still be waiting on the
  // condition variable, causing the program to hang. This boolean is used to
  // break that waiting condition.
  bool stop_ = false;
};
} // namespace detail

/// Options to configure a `ChunkDataset`.
struct ChunkDatasetOptions {
  ChunkDatasetOptions() = delete;
  ChunkDatasetOptions(
      size_t preloader_count,
      size_t batch_size,
      size_t cache_size = 2048,
      size_t cross_chunk_shuffle_count = 1)
      : preloader_count_(preloader_count),
        batch_size_(batch_size),
        cache_size_(cache_size),
        cross_chunk_shuffle_count_(cross_chunk_shuffle_count) {
    TORCH_CHECK(
        preloader_count_ > 0,
        "Preloader count is 0. At least one preloader needs to be specified.");
    TORCH_CHECK(
        batch_size_ > 0,
        "Batch size is 0. A positive batch size needs to be specified.");
    TORCH_CHECK(
        cache_size_ > 0,
        "Cache size is 0. A positive cache size needs to be specified.");
    TORCH_CHECK(
        cache_size_ >= batch_size_,
        "Cache size is less than batch size. Cache needs to be large enough to "
        "hold at least one batch.");
    TORCH_CHECK(
        cross_chunk_shuffle_count_ > 0,
        "cross_chunk_shuffle_count needs to be greater than 0.");
  }

  /// The number of worker threads to preload chunk data.
  TORCH_ARG(size_t, preloader_count);

  /// The size of each batch.
  TORCH_ARG(size_t, batch_size);

  /// The capacity of the queue for batch caching.
  TORCH_ARG(size_t, cache_size) = 2048;

  // The number of chunks to perform cross-chunk shuffling. Defaults to 1,
  // meaning no cross-chunk shuffling. When it is equal to n (n > 1), n random
  // chunks will be loaded at once and example shuffling will be performed
  // across all of those n chunks.
  // Note: Usually the default config (1 chunk shuffle + example shuffle) is
  // good enough to generate randomly distributed data. Use this parameter
  // only if you know cross-chunk shuffling is needed for your case. Also,
  // there is a performance penalty when this value is greater than 1, as an
  // extra merge between multiple chunks is needed before example sampling is
  // performed.
  TORCH_ARG(size_t, cross_chunk_shuffle_count) = 1;
};
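// For instance, the following sketch configures two preloader threads, a
// batch size of 128, the default cache size, and cross-chunk shuffling over
// four chunks (all values here are illustrative, not recommendations):
//
//   ChunkDatasetOptions options(
//       /*preloader_count=*/2,
//       /*batch_size=*/128,
//       /*cache_size=*/2048,
//       /*cross_chunk_shuffle_count=*/4);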

/// A stateful dataset that supports hierarchical sampling and prefetching of
/// entire chunks.
///
/// Unlike a regular dataset, a chunk dataset requires two samplers to operate
/// and keeps an internal state. `ChunkSampler` selects which chunk to load
/// next, while the `ExampleSampler` determines the order of the Examples that
/// are returned in each `get_batch` call. The hierarchical sampling approach
/// used here is inspired by this paper:
/// http://martin.zinkevich.org/publications/nips2010.pdf
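///
/// A sketch of typical usage, assuming `using namespace torch::data;` and a
/// hypothetical `MyChunkReader` implementing the ChunkDataReader interface
/// (option values below are illustrative):
///
///   using MyDataset = datasets::ChunkDataset<
///       MyChunkReader,
///       samplers::RandomSampler,
///       samplers::RandomSampler>;
///   auto dataset = datasets::make_shared_dataset<MyDataset>(
///       MyChunkReader(/*...*/),
///       samplers::RandomSampler(0),
///       samplers::RandomSampler(0),
///       datasets::ChunkDatasetOptions(
///           /*preloader_count=*/2, /*batch_size=*/128));
///   auto data_loader = make_data_loader(
///       dataset, DataLoaderOptions(/*batch_size=*/128));
///   for (auto& batch : *data_loader) {
///     // `batch` holds up to 128 examples drawn across shuffled chunks.
///   }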
template <
    typename ChunkReader,
    typename ChunkSampler = samplers::RandomSampler,
    typename ExampleSampler = samplers::RandomSampler>
class ChunkDataset final
    : public StatefulDataset<
          ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
          typename ChunkReader::BatchType,
          size_t> {
 public:
  using BatchType = torch::optional<typename ChunkReader::BatchType>;
  using UnwrappedBatchType = typename ChunkReader::BatchType;
  using BatchRequestType = size_t;
  using ChunkSamplerType = ChunkSampler;
  using ExampleSamplerType = ExampleSampler;

  ChunkDataset(
      ChunkReader chunk_reader,
      ChunkSampler chunk_sampler,
      ExampleSampler example_sampler,
      ChunkDatasetOptions options,
      std::function<void(UnwrappedBatchType&)> preprocessing_policy =
          std::function<void(UnwrappedBatchType&)>())
      : chunk_reader_(std::move(chunk_reader)),
        chunk_sampler_(std::move(chunk_sampler)),
        example_sampler_(std::move(example_sampler)),
        options_(std::move(options)),
        preprocessing_policy_(std::move(preprocessing_policy)),
        quit_worker_(false),
        running_preloaders_(0),
        load_checkpoint_(false) {}

  ~ChunkDataset() override {
    // Stop the batch buffer first.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    free_workers();
  }

  /// Default get_batch method of BatchDataset. This method returns Example
  /// batches created from the preloaded chunks. The implementation is
  /// dataset-agnostic and does not need overriding in different chunk
  /// datasets.
  BatchType get_batch(size_t batch_size) override {
    TORCH_CHECK(
        batch_buffer_ != nullptr,
        "Dataset needs to call reset() before calling get_batch().");

    TORCH_CHECK(
        batch_size == options_.batch_size(),
        "The requested batch size does not match with the initialized batch size.\n"
        " The requested batch size is ",
        batch_size,
        ", while the dataset is created with batch size equal to ",
        options_.batch_size());
    return batch_buffer_->get_batch();
  }

  /// Helper method around get_batch(size_t), since `batch_size` is not
  /// strictly necessary here.
  BatchType get_batch() {
    return get_batch(options_.batch_size());
  }

  /// Clears any internal state and starts the internal prefetching mechanism
  /// for the chunk dataset.
  void reset() override {
    // We need this to support partial data reads via the dataloader iterator.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    // Free workers from the previous reset, if there are any.
    free_workers();
    preload_threads_.clear();

    if (!load_checkpoint_) {
      chunk_reader_.reset();
      chunk_sampler_.reset(chunk_reader_.chunk_count());
    }
    // The loaded checkpoint (if any) has been consumed by this reset; clear
    // the flag so that subsequent resets start from a freshly reset sampler.
    load_checkpoint_ = false;

    // Throw out any existing cached batches in the buffer and re-create a new
    // chunk buffer.
    batch_buffer_ = std::make_unique<
        detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
        options_.batch_size(), example_sampler_, options_.cache_size());

    // Create new workers for this new epoch.
    quit_worker_ = false;

    AT_ASSERT(running_preloaders_ == 0);
    running_preloaders_ = options_.preloader_count();
    for (const auto i : c10::irange(options_.preloader_count())) {
      preload_threads_.emplace_back([this, i]() { this->preloader(i); });
    }
  }

  /// size() is not used for the chunk dataset.
  std::optional<size_t> size() const override {
    return std::nullopt;
  }

  // Provides a reference to the chunk sampler. Used mainly in distributed
  // data loading to set the epoch number for the sampler.
  ChunkSamplerType& chunk_sampler() {
    return chunk_sampler_;
  }

  void save(serialize::OutputArchive& archive) const override {
    std::lock_guard<std::mutex> lock(chunk_index_guard_);
    chunk_sampler_.save(archive);
  }

  void load(serialize::InputArchive& archive) override {
    std::lock_guard<std::mutex> lock(chunk_index_guard_);
    chunk_sampler_.load(archive);
    load_checkpoint_ = true;
  }
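
  // The save/load pair above makes it possible to checkpoint the chunk
  // sampler position mid-epoch. A minimal sketch (the file name is
  // illustrative):
  //
  //   // checkpoint
  //   torch::serialize::OutputArchive output;
  //   dataset.save(output);
  //   output.save_to("chunk_dataset.pt");
  //
  //   // restore; the next reset() resumes from the saved sampler state
  //   torch::serialize::InputArchive input;
  //   input.load_from("chunk_dataset.pt");
  //   dataset.load(input);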

 private:
  /// Runs on the worker threads to preload chunk data.
  void preloader(size_t id) {
    while (!quit_worker_.load()) {
      try {
        std::vector<size_t> chunk_idx;
        {
          std::lock_guard<std::mutex> lock(chunk_index_guard_);
          if (auto chunk_sampler_result = chunk_sampler_.next(
                  this->options_.cross_chunk_shuffle_count())) {
            chunk_idx = chunk_sampler_result.value();
          } else {
            break;
          }
        }
        UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_idx[0]);
        for (const auto i : c10::irange(1, chunk_idx.size())) {
          auto chunk_data = chunk_reader_.read_chunk(chunk_idx[i]);
          std::move(
              chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
        }
        if (preprocessing_policy_) {
          preprocessing_policy_(data);
        }
        if (!data.empty()) { // skip empty chunks.
          batch_buffer_->add_chunk_data(std::move(data));
        }
      } catch (...) {
        batch_buffer_->add_chunk_data(std::current_exception());
      }
    }
    AT_ASSERT(running_preloaders_.load() > 0);
    --running_preloaders_;
    if (running_preloaders_.load() == 0) {
      // All preloaders are completed, so we can notify the batch_buffer.
      batch_buffer_->stop();
    }
  }

  /// Blocks the current thread until the workers finish execution and exit.
  void free_workers() {
    if (!quit_worker_.load()) {
      quit_worker_ = true;
      for (auto& worker_thread : preload_threads_) {
        worker_thread.join();
      }
    }
  }

 private:
  // Templated class that defines what a chunk is and how to read chunk data.
  // When a chunk is returned by chunk_reader_, ChunkDataset splits it into
  // batches and caches them in batch_buffer_.
  ChunkReader chunk_reader_;

  // Chunk sampler to shuffle different chunks.
  ChunkSamplerType chunk_sampler_;

  // Example sampler to shuffle examples within a specific chunk.
  ExampleSamplerType example_sampler_;

  // Batch data buffer which holds chunk data from the preloading threads.
  std::shared_ptr<
      detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
      batch_buffer_;

  // worker thread pool
  std::vector<std::thread> preload_threads_;

  /// The options the Dataset was configured with.
  const ChunkDatasetOptions options_;

  // Function pointer wrapper to apply custom processing over chunk data. This
  // is considered an advanced parameter for developers who want to apply a
  // pre-process to the chunk data before sampling it into minibatches. Unlike
  // the collate function, this policy is applied at the chunk level rather
  // than the minibatch level. When a chunk of data is loaded (multiple chunks
  // if cross_chunk_shuffle_count_ is greater than 1), this policy is applied
  // to the full loaded data. It is useful if developers want to perform
  // pre-processing (like bucketing) on the chunk data before the example
  // sampler samples the data. By default it is an empty pointer and no action
  // will be taken.
  std::function<void(UnwrappedBatchType&)> preprocessing_policy_;

  // Indicates whether the worker threads can be torn down.
  std::atomic<bool> quit_worker_;

  // Keeps track of running preloaders to notify the batch buffer. A value of
  // 0 indicates that the chunk loading is completed.
  std::atomic<size_t> running_preloaders_;

  // Mutex to synchronize chunk sampler next() calls.
  mutable std::mutex chunk_index_guard_;

  // Boolean value to indicate whether we need to load the checkpoint for
  // chunk_sampler_.
  bool load_checkpoint_;
};
} // namespace datasets
} // namespace data
} // namespace torch