xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/detail/data_shuttle.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/data/detail/queue.h>
4 #include <torch/types.h>
5 
6 #include <c10/util/Exception.h>
7 #include <optional>
8 
9 #include <chrono>
10 #include <utility>
11 
12 namespace torch {
13 namespace data {
14 namespace detail {
15 
16 /// Encapsulates the full life cycle of DataLoader jobs.
17 ///
18 /// When a new job is enqueued to the `DataShuttle`, a counter for in-flight
19 /// jobs is bumped. This job is said to be "in-flight" until its result is
20 /// popped. Worker threads dequeue jobs as soon as they are available. When a
21 /// worker finishes a job, it enqueues the result. Only when the main thread
22 /// dequeues a result is the count of in-flight jobs decremented. When the main
23 /// thread attempts to dequeue a job but no jobs are in-flight, that means the
24 /// epoch is complete and `pop_result` returns an empty optional.
25 template <typename Job, typename Result>
26 class DataShuttle {
27  public:
28   /// Pushes a new job. Called by the main thread.
push_job(Job job)29   void push_job(Job job) {
30     new_jobs_.push(std::move(job));
31     ++in_flight_jobs_;
32   }
33 
34   /// Pushes the result of a job. Called by worker threads.
push_result(Result result)35   void push_result(Result result) {
36     results_.push(std::move(result));
37   }
38 
39   /// Returns the next job, blocking until there is one available. Called by
40   /// worker threads.
pop_job()41   Job pop_job() {
42     return new_jobs_.pop();
43   }
44 
45   /// Returns the result of a job, or nullopt if all jobs were exhausted. Called
46   /// by the main thread.
47   std::optional<Result> pop_result(
48       std::optional<std::chrono::milliseconds> timeout = std::nullopt) {
49     if (in_flight_jobs_ > 0) {
50       auto result = results_.pop(timeout);
51       --in_flight_jobs_;
52       return result;
53     }
54     return nullopt;
55   }
56 
57   /// Discards any jobs that are not yet in flight, and waits for all in-flight
58   /// jobs to finish, discarding their result.
drain()59   void drain() {
60     // Clear all inputs so that no further jobs are scheduled.
61     auto number_cleared = new_jobs_.clear();
62     in_flight_jobs_ -= number_cleared;
63     // Remove any outstanding results.
64     while (in_flight_jobs_ > 0) {
65       pop_result();
66     }
67   }
68 
69   /// Returns the number of jobs that are still in progress.
70   /// When this number is zero, an epoch is finished.
in_flight_jobs()71   size_t in_flight_jobs() const noexcept {
72     return in_flight_jobs_;
73   }
74 
75  private:
76   /// The queue for jobs that are not yet in flight.
77   Queue<Job> new_jobs_;
78   /// The number of in-flight jobs.
79   /// NOTE: Not atomic because only manipulated by the main thread.
80   size_t in_flight_jobs_ = 0;
81   /// The queue for results of finished jobs.
82   Queue<Result> results_;
83 };
84 
85 } // namespace detail
86 } // namespace data
87 } // namespace torch
88