1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/task_runner.h"
16
17 #include <algorithm>
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "tensorflow/core/data/service/common.h"
23 #include "tensorflow/core/data/service/cross_trainer_cache.h"
24 #include "tensorflow/core/data/service/data_transfer.h"
25 #include "tensorflow/core/data/service/logging_utils.h"
26 #include "tensorflow/core/data/service/thread_safe_buffer.h"
27 #include "tensorflow/core/data/service/worker.pb.h"
28 #include "tensorflow/core/data/standalone.h"
29 #include "tensorflow/core/framework/cancellation.h"
30 #include "tensorflow/core/framework/dataset.h"
31 #include "tensorflow/core/framework/tensor_util.h"
32 #include "tensorflow/core/lib/gtl/cleanup.h"
33 #include "tensorflow/core/platform/env.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/mutex.h"
36 #include "tensorflow/core/platform/status.h"
37 #include "tensorflow/core/platform/statusor.h"
38 #include "tensorflow/core/platform/thread_annotations.h"
39 #include "tensorflow/core/protobuf/service_config.pb.h"
40
41 namespace tensorflow {
42 namespace data {
namespace {

// Upper bound, in microseconds, that a round-robin round waits for a full set
// of elements before the round is skipped (only for requests that allow
// skipping). 100'000us == 100ms.
constexpr int64_t kWaitBeforeSkipUs = 100000;

// Cross-trainer cache budget used when the worker config does not specify a
// positive size: 10 GiB (10 * 2^30 bytes).
constexpr size_t kDefaultCrossTrainerCacheSizeBytes = size_t{10} << 30;

}  // namespace
50
StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset,std::unique_ptr<standalone::Iterator> iterator)51 StandaloneTaskIterator::StandaloneTaskIterator(
52 std::unique_ptr<standalone::Dataset> dataset,
53 std::unique_ptr<standalone::Iterator> iterator)
54 : dataset_(std::move(dataset)), iterator_(std::move(iterator)) {}
55
GetNext(std::vector<Tensor> & element,bool & end_of_sequence)56 Status StandaloneTaskIterator::GetNext(std::vector<Tensor>& element,
57 bool& end_of_sequence) {
58 return iterator_->GetNext(&element, &end_of_sequence);
59 }
60
Cardinality() const61 int64_t StandaloneTaskIterator::Cardinality() const {
62 return dataset_->Get()->Cardinality();
63 }
64
Create(const experimental::WorkerConfig & worker_config,const TaskDef & task_def,std::unique_ptr<TaskIterator> iterator,std::unique_ptr<TaskRunner> & out)65 Status TaskRunner::Create(const experimental::WorkerConfig& worker_config,
66 const TaskDef& task_def,
67 std::unique_ptr<TaskIterator> iterator,
68 std::unique_ptr<TaskRunner>& out) {
69 if (task_def.optional_num_consumers_case() == TaskDef::kNumConsumers) {
70 int64_t cardinality = iterator->Cardinality();
71 if (cardinality != kInfiniteCardinality &&
72 cardinality != kUnknownCardinality) {
73 return errors::FailedPrecondition(
74 "Round robin reads require that the input dataset has infinite "
75 "cardinality, but the dataset has cardinality ",
76 cardinality,
77 ". Consider adding a `.repeat()` transformation to the dataset.");
78 }
79 out = std::make_unique<RoundRobinTaskRunner>(std::move(iterator),
80 task_def.num_consumers(),
81 task_def.worker_address());
82 } else if (task_def.use_cross_trainer_cache()) {
83 const size_t max_cache_size_bytes =
84 worker_config.cross_trainer_cache_size_bytes() > 0
85 ? worker_config.cross_trainer_cache_size_bytes()
86 : kDefaultCrossTrainerCacheSizeBytes;
87 out = std::make_unique<CachingTaskRunner>(std::move(iterator),
88 max_cache_size_bytes);
89 } else {
90 out = std::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
91 }
92 return OkStatus();
93 }
94
FirstComeFirstServedTaskRunner(std::unique_ptr<TaskIterator> iterator)95 FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
96 std::unique_ptr<TaskIterator> iterator)
97 : iterator_(std::move(iterator)), buffer_(/*buffer_size=*/1) {
98 RunPrefetchThread();
99 }
100
~FirstComeFirstServedTaskRunner()101 FirstComeFirstServedTaskRunner::~FirstComeFirstServedTaskRunner() { Cancel(); }
102
GetNext(const GetElementRequest & req,GetElementResult & result)103 Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req,
104 GetElementResult& result) {
105 return GetNext(result);
106 }
107
GetNext(GetElementResult & result)108 Status FirstComeFirstServedTaskRunner::GetNext(GetElementResult& result) {
109 TF_ASSIGN_OR_RETURN(result, buffer_.Pop());
110 return OkStatus();
111 }
112
PrefetchFn()113 Status FirstComeFirstServedTaskRunner::PrefetchFn() {
114 while (true) {
115 TF_RETURN_IF_ERROR(buffer_.Push(GetNextFromInputIterator()));
116 }
117 return OkStatus();
118 }
119
RunPrefetchThread()120 void FirstComeFirstServedTaskRunner::RunPrefetchThread() {
121 auto prefetch_fn = [this] {
122 Status status = PrefetchFn();
123 if (!status.ok()) {
124 buffer_.Cancel(status);
125 }
126 };
127 prefetch_thread_ = absl::WrapUnique(Env::Default()->StartThread(
128 /*thread_options=*/{}, /*name=*/"tf_data_service_fcfs_prefetch_thread",
129 prefetch_fn));
130 }
131
132 StatusOr<GetElementResult>
GetNextFromInputIterator()133 FirstComeFirstServedTaskRunner::GetNextFromInputIterator()
134 TF_LOCKS_EXCLUDED(mu_) {
135 GetElementResult result;
136 std::vector<Tensor> element;
137 bool end_of_task = false;
138 result.skip = false;
139 {
140 mutex_lock l(mu_);
141 TF_RETURN_IF_ERROR(iterator_->GetNext(element, end_of_task));
142 result.end_of_sequence = end_of_task;
143 result.element_index = element_index_++;
144 }
145 if (!end_of_task) {
146 result.components = std::move(element);
147 }
148 return result;
149 }
150
Cancel()151 void FirstComeFirstServedTaskRunner::Cancel() {
152 VLOG(2) << "Cancelling tf.data service FCFS task.";
153 buffer_.Cancel(errors::Cancelled("tf.data service FCFS task is cancelled."));
154 }
155
CachingTaskRunner(std::unique_ptr<TaskIterator> iterator,size_t max_cache_size_bytes)156 CachingTaskRunner::CachingTaskRunner(std::unique_ptr<TaskIterator> iterator,
157 size_t max_cache_size_bytes)
158 : fcfs_task_runner_(std::move(iterator)),
159 cache_(max_cache_size_bytes,
160 std::make_unique<GetElementResultSequence>(fcfs_task_runner_)) {
161 LOG(INFO) << "Initialized tf.data service cross-trainer cache with "
162 << FormatBytes(max_cache_size_bytes) << " of memory.";
163 }
164
~CachingTaskRunner()165 CachingTaskRunner::~CachingTaskRunner() { Cancel(); }
166
GetNext(const GetElementRequest & req,GetElementResult & result)167 Status CachingTaskRunner::GetNext(const GetElementRequest& req,
168 GetElementResult& result) {
169 TF_ASSIGN_OR_RETURN(std::shared_ptr<const GetElementResult> element,
170 cache_.Get(req.trainer_id()));
171 result = element->Copy();
172 return OkStatus();
173 }
174
GetElementResultSequence(FirstComeFirstServedTaskRunner & fcfs_task_runner)175 CachingTaskRunner::GetElementResultSequence::GetElementResultSequence(
176 FirstComeFirstServedTaskRunner& fcfs_task_runner)
177 : fcfs_task_runner_(fcfs_task_runner) {}
178
179 StatusOr<GetElementResult>
GetNext()180 CachingTaskRunner::GetElementResultSequence::GetNext() {
181 GetElementResult result;
182 TF_RETURN_IF_ERROR(fcfs_task_runner_.GetNext(result));
183 if (result.end_of_sequence) {
184 return errors::InvalidArgument(
185 "Cross-trainer caching requires the input dataset to be infinite. "
186 "However, it reached the end of sequence.");
187 }
188 return result;
189 }
190
GetElementSizeBytes(const GetElementResult & element) const191 size_t CachingTaskRunner::GetElementResultSequence::GetElementSizeBytes(
192 const GetElementResult& element) const {
193 return element.EstimatedMemoryUsageBytes();
194 }
195
Cancel()196 void CachingTaskRunner::Cancel() {
197 VLOG(2) << "Cancelling tf.data service cross-trainer cache task.";
198 if (!cache_.IsCancelled()) {
199 cache_.Cancel(errors::Cancelled(
200 "tf.data service cross-trainer cache task is cancelled."));
201 }
202 fcfs_task_runner_.Cancel();
203 }
204
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,int64_t num_consumers,string worker_address)205 RoundRobinTaskRunner::RoundRobinTaskRunner(
206 std::unique_ptr<TaskIterator> iterator, int64_t num_consumers,
207 string worker_address)
208 : num_consumers_(num_consumers),
209 worker_address_(worker_address),
210 buffer_(num_consumers_),
211 prefetch_thread_(std::move(iterator), num_consumers_) {
212 VLOG(1) << "Creating task runner for distributing data round-robin to "
213 << num_consumers << " consumers";
214 }
215
ValidateRequest(const GetElementRequest & req)216 Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) {
217 if (req.consumer_index() < 0 || req.round_index() < 0) {
218 return errors::FailedPrecondition(
219 "RoundRobinTaskRunner needs to know the consumer index and element "
220 "index of each request.");
221 }
222 if (req.consumer_index() >= num_consumers_) {
223 return errors::FailedPrecondition(
224 "Requesting data for consumer index ", req.consumer_index(),
225 ", but the task is configured for only ", num_consumers_, " consumers");
226 }
227 return OkStatus();
228 }
229
PrepareFullRound(int64_t wait_us)230 Status RoundRobinTaskRunner::PrepareFullRound(int64_t wait_us)
231 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
232 VLOG(1) << worker_address_ << ": Preparing full round for round "
233 << current_round_;
234 // This was the last request to arrive, time to start a new round.
235 TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_));
236 round_skipped_ = buffer_.empty();
237 new_round_cv_.notify_all();
238 return OkStatus();
239 }
240
PreparePartialRound()241 Status RoundRobinTaskRunner::PreparePartialRound()
242 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
243 VLOG(1) << worker_address_ << ": Starting partial round " << first_round_
244 << " for " << requests_[first_round_].size() << " consumers";
245 current_round_ = first_round_;
246 new_round_cv_.notify_all();
247 // Indicates that we need a partial round to get consumers back in sync.
248 auto next_round_request = *(requests_[first_round_ + 1].begin()->second);
249 if (next_round_request.skipped_previous_round()) {
250 VLOG(1) << "Skipping partial round";
251 round_skipped_ = true;
252 return OkStatus();
253 }
254 TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_));
255 round_skipped_ = false;
256 return OkStatus();
257 }
258
PrepareRound(const GetElementRequest & req)259 Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
260 mutex_lock l(mu_);
261 first_round_ = std::min(first_round_, req.round_index());
262 absl::flat_hash_map<int64_t, const GetElementRequest*>& round =
263 requests_[req.round_index()];
264 round[req.consumer_index()] = &req;
265 auto cleanup = gtl::MakeCleanup([&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
266 requests_[req.round_index()].erase(req.consumer_index());
267 });
268 if (current_round_ < req.round_index() && round.size() == num_consumers_) {
269 current_round_ = req.round_index();
270 int64_t wait_us = kWaitBeforeSkipUs;
271 if (!req.allow_skip()) {
272 wait_us = -1;
273 }
274 TF_RETURN_IF_ERROR(PrepareFullRound(wait_us));
275 }
276 if (current_round_ < 0 &&
277 requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
278 num_consumers_) {
279 TF_RETURN_IF_ERROR(PreparePartialRound());
280 }
281 while (!cancelled_ && current_round_ < req.round_index()) {
282 TF_RETURN_IF_ERROR(prefetch_thread_.GetStatus());
283 new_round_cv_.wait(l);
284 }
285 if (current_round_ < req.round_index() && cancelled_) {
286 return errors::Cancelled("Worker is shutting down.");
287 }
288 if (current_round_ != req.round_index()) {
289 return errors::FailedPrecondition(
290 "Consumer ", req.consumer_index(), " requested data for round ",
291 req.round_index(), ", but the current round has already reached ",
292 current_round_,
293 ". This may indicate that the consumer was restarted with the same "
294 "iteration "
295 "name.`");
296 }
297 return prefetch_thread_.GetStatus();
298 }
299
GetNext(const GetElementRequest & req,GetElementResult & result)300 Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req,
301 GetElementResult& result) {
302 TF_RETURN_IF_ERROR(ValidateRequest(req));
303 result.end_of_sequence = false;
304 VLOG(2) << worker_address_ << ": Received request from consumer index "
305 << req.consumer_index() << " for round " << req.round_index();
306 TF_RETURN_IF_ERROR(PrepareRound(req));
307 tf_shared_lock l(mu_);
308 result.skip = round_skipped_;
309 if (round_skipped_) {
310 VLOG(1) << worker_address_ << ": Buffer not ready, skipping round "
311 << current_round_ << " for consumer " << req.consumer_index();
312 return OkStatus();
313 }
314 auto& buffer_result = buffer_[req.consumer_index()];
315 result.element_index = buffer_result->index;
316 std::vector<Tensor> element;
317 for (auto& component : buffer_result->components) {
318 element.push_back(tensor::DeepCopy(component));
319 }
320 if (VLOG_IS_ON(2)) {
321 int64_t size = 0;
322 for (auto& component : element) {
323 size += component.TotalBytes();
324 }
325 VLOG(2) << worker_address_ << ": Returning element " << result.element_index
326 << " to consumer " << req.consumer_index() << " for round "
327 << req.round_index() << ". element size " << size;
328 }
329 result.components = std::move(element);
330 return OkStatus();
331 }
332
Cancel()333 void RoundRobinTaskRunner::Cancel() {
334 mutex_lock l(mu_);
335 cancelled_ = true;
336 new_round_cv_.notify_all();
337 }
338
PrefetchThread(std::unique_ptr<TaskIterator> iterator,int64_t round_size)339 PrefetchThread::PrefetchThread(std::unique_ptr<TaskIterator> iterator,
340 int64_t round_size)
341 : iterator_(std::move(iterator)), round_size_(round_size) {
342 thread_ = absl::WrapUnique(
343 Env::Default()->StartThread({}, "round-robin-prefetch", [&] { Run(); }));
344 }
345
~PrefetchThread()346 PrefetchThread::~PrefetchThread() {
347 mutex_lock l(mu_);
348 cancelled_ = true;
349 cv_.notify_all();
350 }
351
Run()352 void PrefetchThread::Run() {
353 while (true) {
354 {
355 mutex_lock l(mu_);
356 while (!cancelled_ && buffer_.size() >= round_size_) {
357 cv_.wait(l);
358 }
359 if (cancelled_) {
360 return;
361 }
362 }
363 std::vector<Tensor> element;
364 bool end_of_sequence;
365 Status s = iterator_->GetNext(element, end_of_sequence);
366 if (!s.ok()) {
367 mutex_lock l(mu_);
368 status_ = s;
369 cv_.notify_all();
370 return;
371 }
372 if (end_of_sequence) {
373 mutex_lock l(mu_);
374 status_ = errors::FailedPrecondition(
375 "Encountered end of sequence on a round-robin read iterator. "
376 "Please ensure that the dataset used for round-robin reading has "
377 "infinite cardinality, e.g. by adding a .repeat() transformation "
378 "at the end.");
379 cv_.notify_all();
380 return;
381 }
382 mutex_lock l(mu_);
383 buffer_.push_back(std::make_unique<Element>(std::move(element), index_++));
384 cv_.notify_all();
385 }
386 }
387
FillBuffer(int64_t wait_us,std::vector<std::unique_ptr<Element>> & out)388 Status PrefetchThread::FillBuffer(int64_t wait_us,
389 std::vector<std::unique_ptr<Element>>& out) {
390 int64_t start_us = Env::Default()->NowMicros();
391 out.clear();
392 mutex_lock l(mu_);
393 while (buffer_.size() < round_size_ && !cancelled_ && status_.ok()) {
394 int64_t remaining_us = start_us + wait_us - Env::Default()->NowMicros();
395 if (wait_us >= 0 && remaining_us <= 0) {
396 break;
397 }
398 cv_.wait_for(l, std::chrono::microseconds(remaining_us));
399 }
400 TF_RETURN_IF_ERROR(status_);
401 if (cancelled_) {
402 return errors::Cancelled("Prefetch thread cancelled");
403 }
404 if (buffer_.size() < round_size_) {
405 DCHECK_GE(wait_us, 0);
406 return OkStatus();
407 }
408 for (auto& elem : buffer_) {
409 out.push_back(std::move(elem));
410 }
411 buffer_.clear();
412 cv_.notify_all();
413 return OkStatus();
414 }
415
GetStatus()416 Status PrefetchThread::GetStatus() {
417 mutex_lock l(mu_);
418 return status_;
419 }
420 } // namespace data
421 } // namespace tensorflow
422