/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/task_runner.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/cross_trainer_cache.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/logging_utils.h"
#include "tensorflow/core/data/service/thread_safe_buffer.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/service_config.pb.h"

namespace tensorflow {
namespace data {
namespace {
// Time to wait before skipping a round if data still isn't available.
constexpr int64_t kWaitBeforeSkipUs = 100 * 1000;  // 100ms.
constexpr size_t kDefaultCrossTrainerCacheSizeBytes =
    10 * (size_t{1} << 30);  // 10GB

}  // namespace

StandaloneTaskIterator::StandaloneTaskIterator(
    std::unique_ptr<standalone::Dataset> dataset,
    std::unique_ptr<standalone::Iterator> iterator)
    : dataset_(std::move(dataset)), iterator_(std::move(iterator)) {}

Status StandaloneTaskIterator::GetNext(std::vector<Tensor>& element,
                                       bool& end_of_sequence) {
  return iterator_->GetNext(&element, &end_of_sequence);
}

int64_t StandaloneTaskIterator::Cardinality() const {
  return dataset_->Get()->Cardinality();
}

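// Creates a TaskRunner for the given task. The concrete runner depends on the
// task configuration: round-robin for tasks with a fixed number of consumers,
// a caching runner when the cross-trainer cache is enabled, and
// first-come-first-served otherwise.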
Status TaskRunner::Create(const experimental::WorkerConfig& worker_config,
                          const TaskDef& task_def,
                          std::unique_ptr<TaskIterator> iterator,
                          std::unique_ptr<TaskRunner>& out) {
  if (task_def.optional_num_consumers_case() == TaskDef::kNumConsumers) {
    int64_t cardinality = iterator->Cardinality();
    if (cardinality != kInfiniteCardinality &&
        cardinality != kUnknownCardinality) {
      return errors::FailedPrecondition(
          "Round robin reads require that the input dataset has infinite "
          "cardinality, but the dataset has cardinality ",
          cardinality,
          ". Consider adding a `.repeat()` transformation to the dataset.");
    }
    out = std::make_unique<RoundRobinTaskRunner>(std::move(iterator),
                                                 task_def.num_consumers(),
                                                 task_def.worker_address());
  } else if (task_def.use_cross_trainer_cache()) {
    const size_t max_cache_size_bytes =
        worker_config.cross_trainer_cache_size_bytes() > 0
            ? worker_config.cross_trainer_cache_size_bytes()
            : kDefaultCrossTrainerCacheSizeBytes;
    out = std::make_unique<CachingTaskRunner>(std::move(iterator),
                                              max_cache_size_bytes);
  } else {
    out = std::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
  }
  return OkStatus();
}

FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
    std::unique_ptr<TaskIterator> iterator)
    : iterator_(std::move(iterator)), buffer_(/*buffer_size=*/1) {
  RunPrefetchThread();
}

FirstComeFirstServedTaskRunner::~FirstComeFirstServedTaskRunner() { Cancel(); }

Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req,
                                               GetElementResult& result) {
  return GetNext(result);
}

Status FirstComeFirstServedTaskRunner::GetNext(GetElementResult& result) {
  TF_ASSIGN_OR_RETURN(result, buffer_.Pop());
  return OkStatus();
}

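// Repeatedly pulls elements from the input iterator and pushes them into the
// bounded buffer. Returns only when reading from the iterator fails or the
// push fails (e.g. because the buffer has been cancelled).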
Status FirstComeFirstServedTaskRunner::PrefetchFn() {
  while (true) {
    TF_RETURN_IF_ERROR(buffer_.Push(GetNextFromInputIterator()));
  }
  return OkStatus();
}

void FirstComeFirstServedTaskRunner::RunPrefetchThread() {
  auto prefetch_fn = [this] {
    Status status = PrefetchFn();
    if (!status.ok()) {
      buffer_.Cancel(status);
    }
  };
  prefetch_thread_ = absl::WrapUnique(Env::Default()->StartThread(
      /*thread_options=*/{}, /*name=*/"tf_data_service_fcfs_prefetch_thread",
      prefetch_fn));
}

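// Reads the next element from the input iterator under `mu_` and assigns it a
// monotonically increasing element index. Components are only attached to the
// result if the iterator has not reached end of sequence.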
StatusOr<GetElementResult>
FirstComeFirstServedTaskRunner::GetNextFromInputIterator()
    TF_LOCKS_EXCLUDED(mu_) {
  GetElementResult result;
  std::vector<Tensor> element;
  bool end_of_task = false;
  result.skip = false;
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(iterator_->GetNext(element, end_of_task));
    result.end_of_sequence = end_of_task;
    result.element_index = element_index_++;
  }
  if (!end_of_task) {
    result.components = std::move(element);
  }
  return result;
}

void FirstComeFirstServedTaskRunner::Cancel() {
  VLOG(2) << "Cancelling tf.data service FCFS task.";
  buffer_.Cancel(errors::Cancelled("tf.data service FCFS task is cancelled."));
}

CachingTaskRunner::CachingTaskRunner(std::unique_ptr<TaskIterator> iterator,
                                     size_t max_cache_size_bytes)
    : fcfs_task_runner_(std::move(iterator)),
      cache_(max_cache_size_bytes,
             std::make_unique<GetElementResultSequence>(fcfs_task_runner_)) {
  LOG(INFO) << "Initialized tf.data service cross-trainer cache with "
            << FormatBytes(max_cache_size_bytes) << " of memory.";
}

CachingTaskRunner::~CachingTaskRunner() { Cancel(); }

Status CachingTaskRunner::GetNext(const GetElementRequest& req,
                                  GetElementResult& result) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<const GetElementResult> element,
                      cache_.Get(req.trainer_id()));
  result = element->Copy();
  return OkStatus();
}

CachingTaskRunner::GetElementResultSequence::GetElementResultSequence(
    FirstComeFirstServedTaskRunner& fcfs_task_runner)
    : fcfs_task_runner_(fcfs_task_runner) {}

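// Produces the next element for the cross-trainer cache by delegating to the
// underlying FCFS task runner. Reaching end of sequence is reported as an
// error, since the cache requires an infinite input dataset.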
StatusOr<GetElementResult>
CachingTaskRunner::GetElementResultSequence::GetNext() {
  GetElementResult result;
  TF_RETURN_IF_ERROR(fcfs_task_runner_.GetNext(result));
  if (result.end_of_sequence) {
    return errors::InvalidArgument(
        "Cross-trainer caching requires the input dataset to be infinite. "
        "However, it reached the end of sequence.");
  }
  return result;
}

size_t CachingTaskRunner::GetElementResultSequence::GetElementSizeBytes(
    const GetElementResult& element) const {
  return element.EstimatedMemoryUsageBytes();
}

void CachingTaskRunner::Cancel() {
  VLOG(2) << "Cancelling tf.data service cross-trainer cache task.";
  if (!cache_.IsCancelled()) {
    cache_.Cancel(errors::Cancelled(
        "tf.data service cross-trainer cache task is cancelled."));
  }
  fcfs_task_runner_.Cancel();
}

RoundRobinTaskRunner::RoundRobinTaskRunner(
    std::unique_ptr<TaskIterator> iterator, int64_t num_consumers,
    string worker_address)
    : num_consumers_(num_consumers),
      worker_address_(worker_address),
      buffer_(num_consumers_),
      prefetch_thread_(std::move(iterator), num_consumers_) {
  VLOG(1) << "Creating task runner for distributing data round-robin to "
          << num_consumers << " consumers";
}

Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) {
  if (req.consumer_index() < 0 || req.round_index() < 0) {
    return errors::FailedPrecondition(
        "RoundRobinTaskRunner needs to know the consumer index and element "
        "index of each request.");
  }
  if (req.consumer_index() >= num_consumers_) {
    return errors::FailedPrecondition(
        "Requesting data for consumer index ", req.consumer_index(),
        ", but the task is configured for only ", num_consumers_, " consumers");
  }
  return OkStatus();
}

Status RoundRobinTaskRunner::PrepareFullRound(int64_t wait_us)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  VLOG(1) << worker_address_ << ": Preparing full round for round "
          << current_round_;
  // This was the last request to arrive, time to start a new round.
  TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_));
  round_skipped_ = buffer_.empty();
  new_round_cv_.notify_all();
  return OkStatus();
}

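// Prepares a partial round, which is used to bring consumers that are out of
// sync (e.g. after a restart mid-round) back onto the same round before
// resuming full rounds.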
Status RoundRobinTaskRunner::PreparePartialRound()
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  VLOG(1) << worker_address_ << ": Starting partial round " << first_round_
          << " for " << requests_[first_round_].size() << " consumers";
  current_round_ = first_round_;
  new_round_cv_.notify_all();
  // Indicates that we need a partial round to get consumers back in sync.
  auto next_round_request = *(requests_[first_round_ + 1].begin()->second);
  if (next_round_request.skipped_previous_round()) {
    VLOG(1) << "Skipping partial round";
    round_skipped_ = true;
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_));
  round_skipped_ = false;
  return OkStatus();
}

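// Records the request and, once all consumers for a round have checked in,
// prepares that round. Otherwise blocks until the requested round becomes the
// current round or the runner is cancelled.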
Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
  mutex_lock l(mu_);
  first_round_ = std::min(first_round_, req.round_index());
  absl::flat_hash_map<int64_t, const GetElementRequest*>& round =
      requests_[req.round_index()];
  round[req.consumer_index()] = &req;
  auto cleanup = gtl::MakeCleanup([&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    requests_[req.round_index()].erase(req.consumer_index());
  });
  if (current_round_ < req.round_index() && round.size() == num_consumers_) {
    current_round_ = req.round_index();
    int64_t wait_us = kWaitBeforeSkipUs;
    if (!req.allow_skip()) {
      wait_us = -1;
    }
    TF_RETURN_IF_ERROR(PrepareFullRound(wait_us));
  }
  if (current_round_ < 0 &&
      requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
          num_consumers_) {
    TF_RETURN_IF_ERROR(PreparePartialRound());
  }
  while (!cancelled_ && current_round_ < req.round_index()) {
    TF_RETURN_IF_ERROR(prefetch_thread_.GetStatus());
    new_round_cv_.wait(l);
  }
  if (current_round_ < req.round_index() && cancelled_) {
    return errors::Cancelled("Worker is shutting down.");
  }
  if (current_round_ != req.round_index()) {
    return errors::FailedPrecondition(
        "Consumer ", req.consumer_index(), " requested data for round ",
        req.round_index(), ", but the current round has already reached ",
        current_round_,
        ". This may indicate that the consumer was restarted with the same "
        "iteration name.");
  }
  return prefetch_thread_.GetStatus();
}

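// Serves one element of the prepared round to the requesting consumer. If the
// round was skipped because data wasn't ready in time, only the skip flag is
// set; otherwise the consumer's slot of the round buffer is deep-copied into
// the result.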
Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req,
                                     GetElementResult& result) {
  TF_RETURN_IF_ERROR(ValidateRequest(req));
  result.end_of_sequence = false;
  VLOG(2) << worker_address_ << ": Received request from consumer index "
          << req.consumer_index() << " for round " << req.round_index();
  TF_RETURN_IF_ERROR(PrepareRound(req));
  tf_shared_lock l(mu_);
  result.skip = round_skipped_;
  if (round_skipped_) {
    VLOG(1) << worker_address_ << ": Buffer not ready, skipping round "
            << current_round_ << " for consumer " << req.consumer_index();
    return OkStatus();
  }
  auto& buffer_result = buffer_[req.consumer_index()];
  result.element_index = buffer_result->index;
  std::vector<Tensor> element;
  for (auto& component : buffer_result->components) {
    element.push_back(tensor::DeepCopy(component));
  }
  if (VLOG_IS_ON(2)) {
    int64_t size = 0;
    for (auto& component : element) {
      size += component.TotalBytes();
    }
    VLOG(2) << worker_address_ << ": Returning element " << result.element_index
            << " to consumer " << req.consumer_index() << " for round "
            << req.round_index() << ". element size " << size;
  }
  result.components = std::move(element);
  return OkStatus();
}

void RoundRobinTaskRunner::Cancel() {
  mutex_lock l(mu_);
  cancelled_ = true;
  new_round_cv_.notify_all();
}

PrefetchThread::PrefetchThread(std::unique_ptr<TaskIterator> iterator,
                               int64_t round_size)
    : iterator_(std::move(iterator)), round_size_(round_size) {
  thread_ = absl::WrapUnique(
      Env::Default()->StartThread({}, "round-robin-prefetch", [&] { Run(); }));
}

PrefetchThread::~PrefetchThread() {
  mutex_lock l(mu_);
  cancelled_ = true;
  cv_.notify_all();
}

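// Body of the prefetch thread: fills the buffer with up to one round's worth
// of elements, then waits for the buffer to be drained. Exits on cancellation,
// iterator error, or end of sequence (which is an error for round-robin
// reads).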
void PrefetchThread::Run() {
  while (true) {
    {
      mutex_lock l(mu_);
      while (!cancelled_ && buffer_.size() >= round_size_) {
        cv_.wait(l);
      }
      if (cancelled_) {
        return;
      }
    }
    std::vector<Tensor> element;
    bool end_of_sequence;
    Status s = iterator_->GetNext(element, end_of_sequence);
    if (!s.ok()) {
      mutex_lock l(mu_);
      status_ = s;
      cv_.notify_all();
      return;
    }
    if (end_of_sequence) {
      mutex_lock l(mu_);
      status_ = errors::FailedPrecondition(
          "Encountered end of sequence on a round-robin read iterator. "
          "Please ensure that the dataset used for round-robin reading has "
          "infinite cardinality, e.g. by adding a .repeat() transformation "
          "at the end.");
      cv_.notify_all();
      return;
    }
    mutex_lock l(mu_);
    buffer_.push_back(std::make_unique<Element>(std::move(element), index_++));
    cv_.notify_all();
  }
}

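// Moves a full round of prefetched elements into `out`. If `wait_us` is
// non-negative and the buffer isn't full within that many microseconds, `out`
// is left empty so the round can be skipped; a negative `wait_us` means wait
// without a deadline.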
Status PrefetchThread::FillBuffer(int64_t wait_us,
                                  std::vector<std::unique_ptr<Element>>& out) {
  int64_t start_us = Env::Default()->NowMicros();
  out.clear();
  mutex_lock l(mu_);
  while (buffer_.size() < round_size_ && !cancelled_ && status_.ok()) {
    int64_t remaining_us = start_us + wait_us - Env::Default()->NowMicros();
    if (wait_us >= 0 && remaining_us <= 0) {
      break;
    }
    cv_.wait_for(l, std::chrono::microseconds(remaining_us));
  }
  TF_RETURN_IF_ERROR(status_);
  if (cancelled_) {
    return errors::Cancelled("Prefetch thread cancelled");
  }
  if (buffer_.size() < round_size_) {
    DCHECK_GE(wait_us, 0);
    return OkStatus();
  }
  for (auto& elem : buffer_) {
    out.push_back(std::move(elem));
  }
  buffer_.clear();
  cv_.notify_all();
  return OkStatus();
}

Status PrefetchThread::GetStatus() {
  mutex_lock l(mu_);
  return status_;
}
}  // namespace data
}  // namespace tensorflow