1 #pragma once 2 3 #include <torch/types.h> 4 5 #include <c10/util/Exception.h> 6 7 #include <chrono> 8 #include <condition_variable> 9 #include <cstddef> 10 #include <mutex> 11 #include <queue> 12 13 namespace torch { 14 namespace data { 15 namespace detail { 16 17 /// A basic locked, blocking MPMC queue. 18 /// 19 /// Every `push` and `pop` is guarded by a mutex. A condition variable is used 20 /// to communicate insertion of new elements, such that waiting threads will be 21 /// woken up if they are currently waiting inside a call to `pop()`. 22 /// 23 /// Note that this data structure is written specifically for use with the 24 /// `DataLoader`. Its behavior is tailored to this use case and may not be 25 /// applicable to more general uses. 26 template <typename T> 27 class Queue { 28 public: 29 /// Pushes a new value to the back of the `Queue` and notifies one thread on 30 /// the waiting side about this event. push(T value)31 void push(T value) { 32 { 33 std::lock_guard<std::mutex> lock(mutex_); 34 queue_.push(std::move(value)); 35 } 36 cv_.notify_one(); 37 } 38 39 /// Blocks until at least one element is ready to be popped from the front of 40 /// the queue. An optional `timeout` in seconds can be used to limit the time 41 /// spent waiting for an element. If the wait times out, an exception is 42 /// raised. 43 T pop(std::optional<std::chrono::milliseconds> timeout = std::nullopt) { 44 std::unique_lock<std::mutex> lock(mutex_); 45 if (timeout) { 46 if (!cv_.wait_for( 47 lock, *timeout, [this] { return !this->queue_.empty(); })) { 48 // clang-format off 49 AT_ERROR( 50 "Timeout in DataLoader queue while waiting for next batch" 51 " (timeout was ", timeout->count(), " ms)"); 52 // clang-format on 53 } 54 } else { 55 cv_.wait(lock, [this] { return !this->queue_.empty(); }); 56 } 57 AT_ASSERT(!queue_.empty()); 58 T value = queue_.front(); 59 queue_.pop(); 60 lock.unlock(); 61 return value; 62 } 63 64 /// Empties the queue and returns the number of elements that were present at 65 /// the start of the function. No threads are notified about this event as it 66 /// is assumed to be used to drain the queue during shutdown of a 67 /// `DataLoader`. clear()68 size_t clear() { 69 std::lock_guard<std::mutex> lock(this->mutex_); 70 const auto size = queue_.size(); 71 while (!queue_.empty()) { 72 queue_.pop(); 73 } 74 return size; 75 } 76 77 private: 78 std::queue<T> queue_; 79 std::mutex mutex_; 80 std::condition_variable cv_; 81 }; 82 } // namespace detail 83 } // namespace data 84 } // namespace torch 85