xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/detail/sequencers.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/types.h>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <optional>
#include <utility>
#include <vector>
8 
9 namespace torch {
10 namespace data {
11 namespace detail {
12 namespace sequencers {
13 namespace detail {
/// Reports whether any slot in `buffer` currently holds a value.
/// Returns `false` for an empty buffer or when every slot is disengaged.
template <typename Result>
bool buffer_contains_result(const std::vector<std::optional<Result>>& buffer) {
  for (const auto& slot : buffer) {
    if (slot.has_value()) {
      return true;
    }
  }
  return false;
}
21 } // namespace detail
22 
23 /// A `Sequencer` accepts a function that yields the next result of a
24 /// `DataLoader` and then has the opportunity to influence the order in which
25 /// these results are returned. The `NoSequencer` does not enforce any
26 /// sequencing and returns any result directly. The `OrderedSequencer` instead
27 /// buffers results internally to return them in order of their sequence number.
template <typename Result>
struct Sequencer {
  /// Callable that yields the next available result, or an empty optional
  /// when there is none.
  using ResultProducer = std::function<std::optional<Result>()>;

  /// Virtual so that concrete sequencers can be destroyed through a pointer
  /// or reference to this interface.
  virtual ~Sequencer() = default;

  /// Returns the next result to hand out, using `next_result` to obtain
  /// results as needed. Implementations decide the order in which the
  /// produced results are returned.
  virtual std::optional<Result> next(ResultProducer next_result) = 0;
};
34 
35 /// A `Sequencer` that does not enforce any ordering. It is effectively the
36 /// identity function.
37 template <typename Result>
38 struct NoSequencer final : public Sequencer<Result> {
39   using typename Sequencer<Result>::ResultProducer;
nextfinal40   std::optional<Result> next(ResultProducer next_result) override {
41     return next_result();
42   }
43 };
44 
45 /// A `Sequencer` that buffers results and returns them in order of their
46 /// sequence number. The `OrderedSequencer` maintains an internal, monotonically
47 /// incrementing counter for the next sequence number it expects. If it receives
48 /// a result with a higher sequence number, it will buffer it for later (when
49 /// the sequence number reaches that of this result). Otherwise, if the sequence
50 /// numbers match, the result is returned.
51 ///
52 /// Implementation note: The `OrderedSequencer` is implemented with a fixed-size
53 /// buffer. Let `m` be the maximum number of jobs in the data loader's queue and
54 /// `s` be the current sequence number. Assume `m` jobs are scheduled in the
55 /// `DataLoader`. Any new result is stored at index `job.sqn mod m` in the
56 /// `OrderedSequencer`. Why are we sure sequence numbers of new jobs will not
57 /// collide with sequence numbers of buffered jobs? The `OrderedSequencer` will
58 /// not return from `next()` until it receives the result with sqn `s`. This
59 /// means no new jobs can be scheduled in the `DataLoader` in the meantime,
60 /// which enforces that as long as sqn `s` has not been received, `s + m` (which
61 /// would cause a collision in the fixed-size buffer) will not yet be scheduled.
62 template <typename Result>
63 struct OrderedSequencer : public Sequencer<Result> {
64   using typename Sequencer<Result>::ResultProducer;
65 
66   /// Constructs the `OrderedSequencer` with the maximum number of results it
67   /// will ever hold at one point in time.
OrderedSequencerOrderedSequencer68   explicit OrderedSequencer(size_t max_jobs) : buffer_(max_jobs) {}
69 
70   /// Buffers results until the next one in the expected order is received.
nextOrderedSequencer71   std::optional<Result> next(ResultProducer next_result) override {
72     // If we already have the result for the next sqn, return it.
73     if (auto& maybe_result = buffer(next_sequence_number_)) {
74       auto result = std::move(*maybe_result);
75       buffer(next_sequence_number_++).reset();
76       return result;
77     }
78     // Otherwise wait for the next result.
79     while (true) {
80       auto result = next_result();
81       if (!result) {
82         AT_ASSERT(!detail::buffer_contains_result(buffer_));
83         break;
84       }
85       // If it was not nullopt and the sequence numbers match, return it
86       // directly and bump the sequence number.
87       if (result->sequence_number == next_sequence_number_) {
88         ++next_sequence_number_;
89         return result;
90       }
91       // Stash the result for later.
92       AT_ASSERT(!buffer(result->sequence_number).has_value());
93       buffer(result->sequence_number) = std::move(result);
94     }
95     // The result was an empty optional, so we are done with this epoch.
96     return nullopt;
97   }
98 
99   /// Accesses the buffer at the `index` modulo the buffer size.
bufferOrderedSequencer100   std::optional<Result>& buffer(size_t index) {
101     return buffer_.at(index % buffer_.size());
102   }
103 
104   /// The monotonically increasing sequence number we expect.
105   size_t next_sequence_number_ = 0;
106 
107   /// A fixed-size buffer (after construction).
108   std::vector<std::optional<Result>> buffer_;
109 };
110 } // namespace sequencers
111 } // namespace detail
112 } // namespace data
113 } // namespace torch
114