xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/detail/queue.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/types.h>

#include <c10/util/Exception.h>

#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <optional>
#include <queue>
#include <utility>
12 
13 namespace torch {
14 namespace data {
15 namespace detail {
16 
17 /// A basic locked, blocking MPMC queue.
18 ///
19 /// Every `push` and `pop` is guarded by a mutex. A condition variable is used
20 /// to communicate insertion of new elements, such that waiting threads will be
21 /// woken up if they are currently waiting inside a call to `pop()`.
22 ///
23 /// Note that this data structure is written specifically for use with the
24 /// `DataLoader`. Its behavior is tailored to this use case and may not be
25 /// applicable to more general uses.
26 template <typename T>
27 class Queue {
28  public:
29   /// Pushes a new value to the back of the `Queue` and notifies one thread on
30   /// the waiting side about this event.
push(T value)31   void push(T value) {
32     {
33       std::lock_guard<std::mutex> lock(mutex_);
34       queue_.push(std::move(value));
35     }
36     cv_.notify_one();
37   }
38 
39   /// Blocks until at least one element is ready to be popped from the front of
40   /// the queue. An optional `timeout` in seconds can be used to limit the time
41   /// spent waiting for an element. If the wait times out, an exception is
42   /// raised.
43   T pop(std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
44     std::unique_lock<std::mutex> lock(mutex_);
45     if (timeout) {
46       if (!cv_.wait_for(
47               lock, *timeout, [this] { return !this->queue_.empty(); })) {
48         // clang-format off
49         AT_ERROR(
50             "Timeout in DataLoader queue while waiting for next batch"
51             " (timeout was ", timeout->count(), " ms)");
52         // clang-format on
53       }
54     } else {
55       cv_.wait(lock, [this] { return !this->queue_.empty(); });
56     }
57     AT_ASSERT(!queue_.empty());
58     T value = queue_.front();
59     queue_.pop();
60     lock.unlock();
61     return value;
62   }
63 
64   /// Empties the queue and returns the number of elements that were present at
65   /// the start of the function. No threads are notified about this event as it
66   /// is assumed to be used to drain the queue during shutdown of a
67   /// `DataLoader`.
clear()68   size_t clear() {
69     std::lock_guard<std::mutex> lock(this->mutex_);
70     const auto size = queue_.size();
71     while (!queue_.empty()) {
72       queue_.pop();
73     }
74     return size;
75   }
76 
77  private:
78   std::queue<T> queue_;
79   std::mutex mutex_;
80   std::condition_variable cv_;
81 };
82 } // namespace detail
83 } // namespace data
84 } // namespace torch
85