1 #pragma once
2
#include <torch/types.h>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <optional>
#include <utility>
#include <vector>
8
9 namespace torch {
10 namespace data {
11 namespace detail {
12 namespace sequencers {
13 namespace detail {
/// Reports whether any slot of `buffer` currently holds a result.
///
/// Returns false both for an empty buffer and for a buffer whose slots are
/// all empty optionals.
template <typename Result>
bool buffer_contains_result(const std::vector<std::optional<Result>>& buffer) {
  const auto first_filled =
      std::find_if(buffer.cbegin(), buffer.cend(), [](const auto& slot) {
        return slot.has_value();
      });
  return first_filled != buffer.cend();
}
21 } // namespace detail
22
23 /// A `Sequencer` accepts a function that yields the next result of a
24 /// `DataLoader` and then has the opportunity to influence the order in which
25 /// these results are returned. The `NoSequencer` does not enforce any
26 /// sequencing and returns any result directly. The `OrderedSequencer` instead
27 /// buffers results internally to return them in order of their sequence number.
template <typename Result>
struct Sequencer {
  /// Callable that yields the next result of the `DataLoader`. An empty
  /// optional signals that no further results are available.
  using ResultProducer = std::function<auto() -> std::optional<Result>>;

  virtual ~Sequencer() = default;

  /// Obtains results from `next_result` and decides which one to hand back
  /// to the caller (and when). An empty optional means no more results.
  virtual std::optional<Result> next(ResultProducer next_result) = 0;
};
34
35 /// A `Sequencer` that does not enforce any ordering. It is effectively the
36 /// identity function.
37 template <typename Result>
38 struct NoSequencer final : public Sequencer<Result> {
39 using typename Sequencer<Result>::ResultProducer;
nextfinal40 std::optional<Result> next(ResultProducer next_result) override {
41 return next_result();
42 }
43 };
44
45 /// A `Sequencer` that buffers results and returns them in order of their
46 /// sequence number. The `OrderedSequencer` maintains an internal, monotonically
47 /// incrementing counter for the next sequence number it expects. If it receives
48 /// a result with a higher sequence number, it will buffer it for later (when
49 /// the sequence number reaches that of this result). Otherwise, if the sequence
50 /// numbers match, the result is returned.
51 ///
52 /// Implementation note: The `OrderedSequencer` is implemented with a fixed-size
53 /// buffer. Let `m` be the maximum number of jobs in the data loader's queue and
54 /// `s` be the current sequence number. Assume `m` jobs are scheduled in the
55 /// `DataLoader`. Any new result is stored at index `job.sqn mod m` in the
56 /// `OrderedSequencer`. Why are we sure sequence numbers of new jobs will not
57 /// collide with sequence numbers of buffered jobs? The `OrderedSequencer` will
58 /// not return from `next()` until it receives the result with sqn `s`. This
59 /// means no new jobs can be scheduled in the `DataLoader` in the meantime,
60 /// which enforces that as long as sqn `s` has not been received, `s + m` (which
61 /// would cause a collision in the fixed-size buffer) will not yet be scheduled.
62 template <typename Result>
63 struct OrderedSequencer : public Sequencer<Result> {
64 using typename Sequencer<Result>::ResultProducer;
65
66 /// Constructs the `OrderedSequencer` with the maximum number of results it
67 /// will ever hold at one point in time.
OrderedSequencerOrderedSequencer68 explicit OrderedSequencer(size_t max_jobs) : buffer_(max_jobs) {}
69
70 /// Buffers results until the next one in the expected order is received.
nextOrderedSequencer71 std::optional<Result> next(ResultProducer next_result) override {
72 // If we already have the result for the next sqn, return it.
73 if (auto& maybe_result = buffer(next_sequence_number_)) {
74 auto result = std::move(*maybe_result);
75 buffer(next_sequence_number_++).reset();
76 return result;
77 }
78 // Otherwise wait for the next result.
79 while (true) {
80 auto result = next_result();
81 if (!result) {
82 AT_ASSERT(!detail::buffer_contains_result(buffer_));
83 break;
84 }
85 // If it was not nullopt and the sequence numbers match, return it
86 // directly and bump the sequence number.
87 if (result->sequence_number == next_sequence_number_) {
88 ++next_sequence_number_;
89 return result;
90 }
91 // Stash the result for later.
92 AT_ASSERT(!buffer(result->sequence_number).has_value());
93 buffer(result->sequence_number) = std::move(result);
94 }
95 // The result was an empty optional, so we are done with this epoch.
96 return nullopt;
97 }
98
99 /// Accesses the buffer at the `index` modulo the buffer size.
bufferOrderedSequencer100 std::optional<Result>& buffer(size_t index) {
101 return buffer_.at(index % buffer_.size());
102 }
103
104 /// The monotonically increasing sequence number we expect.
105 size_t next_sequence_number_ = 0;
106
107 /// A fixed-size buffer (after construction).
108 std::vector<std::optional<Result>> buffer_;
109 };
110 } // namespace sequencers
111 } // namespace detail
112 } // namespace data
113 } // namespace torch
114