1 #pragma once 2 3 #include <torch/data/detail/queue.h> 4 #include <torch/types.h> 5 6 #include <c10/util/Exception.h> 7 #include <optional> 8 9 #include <chrono> 10 #include <utility> 11 12 namespace torch { 13 namespace data { 14 namespace detail { 15 16 /// Encapsulates the full life cycle of DataLoader jobs. 17 /// 18 /// When a new job is enqueued to the `DataShuttle`, a counter for in-flight 19 /// jobs is bumped. This job is said to be "in-flight" until its result is 20 /// popped. Worker threads dequeue jobs as soon as they are available. When a 21 /// worker finishes a job, it enqueues the result. Only when the main thread 22 /// dequeues a result is the count of in-flight jobs decremented. When the main 23 /// thread attempts to dequeue a job but no jobs are in-flight, that means the 24 /// epoch is complete and `pop_result` returns an empty optional. 25 template <typename Job, typename Result> 26 class DataShuttle { 27 public: 28 /// Pushes a new job. Called by the main thread. push_job(Job job)29 void push_job(Job job) { 30 new_jobs_.push(std::move(job)); 31 ++in_flight_jobs_; 32 } 33 34 /// Pushes the result of a job. Called by worker threads. push_result(Result result)35 void push_result(Result result) { 36 results_.push(std::move(result)); 37 } 38 39 /// Returns the next job, blocking until there is one available. Called by 40 /// worker threads. pop_job()41 Job pop_job() { 42 return new_jobs_.pop(); 43 } 44 45 /// Returns the result of a job, or nullopt if all jobs were exhausted. Called 46 /// by the main thread. 47 std::optional<Result> pop_result( 48 std::optional<std::chrono::milliseconds> timeout = std::nullopt) { 49 if (in_flight_jobs_ > 0) { 50 auto result = results_.pop(timeout); 51 --in_flight_jobs_; 52 return result; 53 } 54 return nullopt; 55 } 56 57 /// Discards any jobs that are not yet in flight, and waits for all in-flight 58 /// jobs to finish, discarding their result. drain()59 void drain() { 60 // Clear all inputs so that no further jobs are scheduled. 61 auto number_cleared = new_jobs_.clear(); 62 in_flight_jobs_ -= number_cleared; 63 // Remove any outstanding results. 64 while (in_flight_jobs_ > 0) { 65 pop_result(); 66 } 67 } 68 69 /// Returns the number of jobs that are still in progress. 70 /// When this number is zero, an epoch is finished. in_flight_jobs()71 size_t in_flight_jobs() const noexcept { 72 return in_flight_jobs_; 73 } 74 75 private: 76 /// The queue for jobs that are not yet in flight. 77 Queue<Job> new_jobs_; 78 /// The number of in-flight jobs. 79 /// NOTE: Not atomic because only manipulated by the main thread. 80 size_t in_flight_jobs_ = 0; 81 /// The queue for results of finished jobs. 82 Queue<Result> results_; 83 }; 84 85 } // namespace detail 86 } // namespace data 87 } // namespace torch 88