/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
#define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_

#include <memory>
#include <vector>

#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/cross_trainer_cache.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/thread_safe_buffer.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/service_config.pb.h"

namespace tensorflow {
namespace data {

// Iterator over a task's elements.
class TaskIterator {
 public:
  virtual ~TaskIterator() = default;
  // If the iterator is not yet exhausted, `GetNext` stores the next element in
  // `element` and sets `end_of_sequence` to `false`. Otherwise, sets
  // `end_of_sequence` to `true`.
  virtual Status GetNext(std::vector<Tensor>& element,
                         bool& end_of_sequence) = 0;
  // Reports the cardinality of the dataset that created this iterator.
  virtual int64_t Cardinality() const = 0;
};
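
// Illustrative usage sketch (not part of the interface): draining a
// `TaskIterator` until it reports end of sequence. The concrete `iterator`
// is assumed to be supplied by the caller.
//
//   std::vector<Tensor> element;
//   bool end_of_sequence = false;
//   while (true) {
//     TF_RETURN_IF_ERROR(iterator->GetNext(element, end_of_sequence));
//     if (end_of_sequence) break;
//     // ... consume `element` ...
//   }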

// Implementation of TaskIterator wrapping a standalone iterator.
class StandaloneTaskIterator : public TaskIterator {
 public:
  // `dataset` should be the dataset that created `iterator`.
  // StandaloneTaskIterator takes ownership of the dataset to ensure it
  // lives as long as `iterator`.
  StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset,
                         std::unique_ptr<standalone::Iterator> iterator);
  Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override;
  int64_t Cardinality() const override;

 private:
  std::unique_ptr<standalone::Dataset> dataset_;
  std::unique_ptr<standalone::Iterator> iterator_;
};
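
// A minimal construction sketch, assuming the `standalone::Dataset` API
// declared in tensorflow/core/data/standalone.h (`FromGraph`, `MakeIterator`)
// and a `graph_def` provided by the caller:
//
//   std::unique_ptr<standalone::Dataset> dataset;
//   TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
//       standalone::Dataset::Params(), graph_def, &dataset));
//   std::unique_ptr<standalone::Iterator> iterator;
//   TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator));
//   auto task_iterator = std::make_unique<StandaloneTaskIterator>(
//       std::move(dataset), std::move(iterator));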

// Interface for providing elements to task consumers.
class TaskRunner {
 public:
  // Creates a `TaskRunner` and stores it in `out`.
  static Status Create(const experimental::WorkerConfig& worker_config,
                       const TaskDef& task_def,
                       std::unique_ptr<TaskIterator> iterator,
                       std::unique_ptr<TaskRunner>& out);
  virtual ~TaskRunner() = default;
  // Gets the next element for the given request.
  virtual Status GetNext(const GetElementRequest& req,
                         GetElementResult& result) = 0;
  // Cancels in-progress `GetNext` requests.
  virtual void Cancel() = 0;
};
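
// A minimal sketch of how a worker might construct and use a `TaskRunner`.
// The `worker_config`, `task_def`, `iterator`, and `req` values are assumed
// to be provided by the surrounding worker code.
//
//   std::unique_ptr<TaskRunner> task_runner;
//   TF_RETURN_IF_ERROR(TaskRunner::Create(worker_config, task_def,
//                                         std::move(iterator), task_runner));
//   GetElementResult result;
//   TF_RETURN_IF_ERROR(task_runner->GetNext(req, result));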

// A task runner which provides elements on a first-come first-served basis.
// It does not consider which consumer is making the request.
class FirstComeFirstServedTaskRunner : public TaskRunner {
 public:
  explicit FirstComeFirstServedTaskRunner(
      std::unique_ptr<TaskIterator> iterator);
  ~FirstComeFirstServedTaskRunner() override;

  // Gets the next element. It may block if the element is not ready yet.
  Status GetNext(const GetElementRequest& req,
                 GetElementResult& result) override;
  Status GetNext(GetElementResult& result);

  void Cancel() override;

 private:
  // Function to continually prefetch the next element. Returns an error if the
  // task has been cancelled.
  Status PrefetchFn();

  // Runs `PrefetchFn` on a dedicated thread.
  void RunPrefetchThread();

  // Gets the next element from the input iterator.
  StatusOr<GetElementResult> GetNextFromInputIterator() TF_LOCKS_EXCLUDED(mu_);

  mutex mu_;
  std::unique_ptr<TaskIterator> iterator_ TF_GUARDED_BY(mu_);
  int64_t element_index_ TF_GUARDED_BY(mu_) = 0;

  ThreadSafeBuffer<GetElementResult> buffer_;
  std::unique_ptr<Thread> prefetch_thread_;

  TF_DISALLOW_COPY_AND_ASSIGN(FirstComeFirstServedTaskRunner);
};
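
// Illustrative sketch: a single consumer draining a
// `FirstComeFirstServedTaskRunner` via the request-free `GetNext` overload.
// `GetElementResult::end_of_sequence` is assumed to mark exhaustion, as
// declared in data_transfer.h.
//
//   FirstComeFirstServedTaskRunner runner(std::move(iterator));
//   GetElementResult result;
//   do {
//     TF_RETURN_IF_ERROR(runner.GetNext(result));
//   } while (!result.end_of_sequence);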

// A task runner which prefetches elements on a first-come first-served basis
// and caches elements in a sliding-window `CrossTrainerCache`. The cache has a
// bounded size and progresses when a trainer has consumed all elements in the
// cache. Trainers read from a sliding window of the dataset and may not read
// the full dataset.
class CachingTaskRunner : public TaskRunner {
 public:
  explicit CachingTaskRunner(std::unique_ptr<TaskIterator> iterator,
                             size_t max_cache_size_bytes);
  ~CachingTaskRunner() override;

  // Gets the next element from the cross-trainer cache, blocking if the data
  // is not ready.
  // REQUIRES: !req.trainer_id().empty()
  Status GetNext(const GetElementRequest& req,
                 GetElementResult& result) override;

  // Cancels the task runner. After cancellation, all `GetNext` calls will
  // return a Cancelled status.
  void Cancel() override;

 private:
  // The `GetElementResultSequence` generates a sequence of elements from the
  // `FirstComeFirstServedTaskRunner`. It is used by the `CrossTrainerCache` to
  // generate cached elements.
  class GetElementResultSequence : public CachableSequence<GetElementResult> {
   public:
    explicit GetElementResultSequence(
        FirstComeFirstServedTaskRunner& fcfs_task_runner);
    StatusOr<GetElementResult> GetNext() override;
    size_t GetElementSizeBytes(const GetElementResult& element) const override;

   private:
    FirstComeFirstServedTaskRunner& fcfs_task_runner_;
  };

  FirstComeFirstServedTaskRunner fcfs_task_runner_;
  CrossTrainerCache<GetElementResult> cache_;

  TF_DISALLOW_COPY_AND_ASSIGN(CachingTaskRunner);
};
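
// Illustrative sketch: a trainer-scoped read from the cross-trainer cache. A
// non-empty trainer ID is required, as documented on `GetNext` above; the
// cache size and trainer ID below are examples only.
//
//   CachingTaskRunner runner(std::move(iterator),
//                            /*max_cache_size_bytes=*/1 << 30);
//   GetElementRequest req;
//   req.set_trainer_id("trainer_0");
//   GetElementResult result;
//   TF_RETURN_IF_ERROR(runner.GetNext(req, result));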

// An element produced by a task.
struct Element {
  explicit Element(std::vector<Tensor>&& components, int64_t index)
      : components(std::move(components)), index(index) {}
  // The components of the element.
  std::vector<Tensor> components;
  // The element's index within the task, e.g. 0 for the first element produced
  // by the task, 1 for the second element, etc.
  int64_t index;
};

// Thread for prefetching a round's worth of elements.
class PrefetchThread {
 public:
  explicit PrefetchThread(std::unique_ptr<TaskIterator> iterator,
                          int64_t round_size);
  ~PrefetchThread();
  // Runs the prefetch thread. It runs until an error is encountered or the
  // destructor is called.
  void Run();
  // Fills `out` with a round of data. Waits for up to `wait_us` microseconds
  // before giving up and returning with `out` empty. A negative `wait_us`
  // signals to wait indefinitely.
  Status FillBuffer(int64_t wait_us,
                    std::vector<std::unique_ptr<Element>>& out);
  // Returns the status for any failures encountered by the prefetch thread.
  Status GetStatus();

 private:
  const std::unique_ptr<TaskIterator> iterator_;
  const int64_t round_size_;
  mutex mu_;
  int64_t index_ TF_GUARDED_BY(mu_) = 0;
  // Buffered results for the next round.
  std::vector<std::unique_ptr<Element>> buffer_ TF_GUARDED_BY(mu_);
  // The status if the prefetch thread fails.
  Status status_ TF_GUARDED_BY(mu_) = OkStatus();
  // Condition variable notified when elements are added to or removed from
  // `buffer_`, or when `status_` is changed.
  condition_variable cv_;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
  // Thread which constantly tries to fill `buffer_` up with
  // `round_size_` elements.
  std::unique_ptr<Thread> thread_;
};
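
// Illustrative sketch: collecting one round of elements, waiting up to one
// second before giving up. The round size and timeout are examples only, and
// the prefetch thread is assumed to already be filling `buffer_`.
//
//   PrefetchThread prefetch(std::move(iterator),
//                           /*round_size=*/num_consumers);
//   std::vector<std::unique_ptr<Element>> round;
//   TF_RETURN_IF_ERROR(prefetch.FillBuffer(/*wait_us=*/1000000, round));
//   if (round.empty()) {
//     // The round was not ready within the timeout.
//   }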

// A task runner which enforces round-robin order for consuming a task's
// elements. `RoundRobinTaskRunner` provides elements in a series of "rounds".
// In each successive round, the runner waits to receive requests from all
// consumers; these requests block until every consumer's request for the
// round has arrived. Once all requests have arrived, the runner hands out
// elements to consumers in order of their consumer indices.
//
// Consumers are expected to successively request consecutive element indices,
// starting at 0. The same element can be requested multiple times by the same
// consumer, as long as the consumer hasn't yet requested the next element (at
// the start of each round we discard elements from the previous round).
//
// If the worker restarts mid-round, a situation arises where some consumers
// are requesting element index `n` while others are requesting element index
// `n + 1`. To remedy this, the first round after restart may be a partial
// round, where we only serve elements to consumers requesting data for element
// index `n`, blocking other consumers until the second round.
class RoundRobinTaskRunner : public TaskRunner {
 public:
  RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
                       int64_t num_consumers, string worker_address);

  Status GetNext(const GetElementRequest& req,
                 GetElementResult& result) override;
  void Cancel() override;

 private:
  // Prepares a full round of data. `wait_us` indicates how long to wait before
  // skipping if a full round of data is not yet ready.
  Status PrepareFullRound(int64_t wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Prepares a partial round to get consumers back in sync.
  Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status ValidateRequest(const GetElementRequest& req);
  // Prepares data for the next round, blocking until the round is ready to
  // start.
  Status PrepareRound(const GetElementRequest& req);
  const int64_t num_consumers_;
  const string worker_address_;
  mutex mu_;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
  // Condition variable notified whenever we start a new round of round-robin.
  condition_variable new_round_cv_;
  // Outstanding requests, indexed by round number and then consumer index.
  absl::flat_hash_map<int64_t,
                      absl::flat_hash_map<int64_t, const GetElementRequest*>>
      requests_ TF_GUARDED_BY(mu_);
  // Index of the first round we plan to serve. At startup, this is the minimum
  // of all requested element indices.
  int64_t first_round_ TF_GUARDED_BY(mu_) = kint64max;
  int64_t current_round_ TF_GUARDED_BY(mu_) = -1;
  bool round_skipped_ TF_GUARDED_BY(mu_) = false;
  // Buffered results for the current round.
  std::vector<std::unique_ptr<Element>> buffer_ TF_GUARDED_BY(mu_);
  // Thread which constantly tries to prepare `num_consumers` elements for the
  // next round.
  PrefetchThread prefetch_thread_;
};
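
// Illustrative sketch of a single consumer's request pattern. The
// `GetElementRequest` field names below (`consumer_index`, `round_index`)
// follow the round-robin protocol described above and are assumed from
// worker.proto.
//
//   GetElementRequest req;
//   req.set_consumer_index(my_consumer_index);  // Fixed for this consumer.
//   req.set_round_index(round);                 // 0, 1, 2, ... in order.
//   GetElementResult result;
//   TF_RETURN_IF_ERROR(runner.GetNext(req, result));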

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_