/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"

#include <algorithm>
#include <deque>

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const PrefetchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PrefetchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSize;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod;
/* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSizeMin;

namespace {

// Determines the fraction of slack time by which to delay prefetching of data.
constexpr double kSleepFactor = 0.2;
constexpr char kBuffer[] = "buffer";
constexpr char kStatus[] = "status";
constexpr char kSizeSuffix[] = ".size";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessageSuffix[] = ".error_message";

}  // namespace

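// A `Dataset` that wraps `input` and, at iteration time, prefetches up to
// `buffer_size` elements of it into an in-memory buffer ahead of the consumer.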
class PrefetchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
          int64_t slack_period, bool legacy_autotune, int64_t buffer_size_min)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        buffer_size_(buffer_size),
        slack_period_(slack_period),
        legacy_autotune_(legacy_autotune),
        buffer_size_min_(buffer_size_min) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64_t CardinalityInternal() const override { return input_->Cardinality(); }

  int64_t CardinalityInternal(CardinalityOptions options) const override {
    return input_->Cardinality(options);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

  Status Get(OpKernelContext* ctx, int64 index,
             std::vector<Tensor>* out_tensors) const override {
    return input_->Get(ctx, index, out_tensors);
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* buffer_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
    AttrValue slack_period_attr;
    b->BuildAttrValue(slack_period_, &slack_period_attr);
    AttrValue legacy_autotune_attr;
    b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
    AttrValue buffer_size_min_attr;
    b->BuildAttrValue(buffer_size_min_, &buffer_size_min_attr);

    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, buffer_size},
                      {std::make_pair(kSlackPeriod, slack_period_attr),
                       std::make_pair(kLegacyAutotune, legacy_autotune_attr),
                       std::make_pair(kBufferSizeMin, buffer_size_min_attr)},
                      output));
    return OkStatus();
  }

 private:
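  // The iterator implements the prefetching behavior: it starts a background
  // thread that pulls elements from the input iterator into `buffer_`, while
  // `GetNextInternal` hands buffered elements to the consumer.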
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          cond_var_(std::make_shared<condition_variable>()),
          buffer_size_min_(params.dataset->buffer_size_min_),
          auto_tuner_(params.dataset->buffer_size_, buffer_size_min_),
          legacy_autotune_(params.dataset->legacy_autotune_),
          // If `legacy_autotune_`, initialize `buffer_size_` to 0 so that the
          // created node is not collected as a tunable node by the autotuning
          // optimization.
          buffer_size_(std::make_shared<model::SharedState>(
              legacy_autotune_ ? 0 : params.dataset->buffer_size_, mu_,
              cond_var_)) {
      slack_us_ = 0;
    }

    ~Iterator() override {
      CancelThreads();
      if (deregister_fn_) deregister_fn_();
    }

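    // Creates the input iterator under a child cancellation manager and
    // registers a callback so that external cancellation also stops the
    // prefetch thread.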
    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      interleave_depth_ = ctx->interleave_depth();

      if (buffer_size_->value == model::kAutotune) {
        buffer_size_->value = buffer_size_min_;
      }
      cancellation_manager_ = std::make_unique<CancellationManager>();
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
          ctx->cancellation_manager(), [this]() { CancelThreads(); },
          &deregister_fn_));
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      return dataset()->input_->MakeIterator(IteratorContext(params), this,
                                             prefix(), &input_impl_);
    }

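    // Returns the next element. If the buffer limit is non-zero, blocks until
    // the prefetch thread has produced an element (or finished); if the buffer
    // limit is zero, falls back to reading from the input directly.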
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const auto& stats_aggregator = ctx->stats_aggregator();
      {
        mutex_lock l(*mu_);
        TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
        // Wait until the next element in the buffer has been
        // produced, or we are shutting down.
        while (buffer_.empty() && !prefetch_thread_finished_ &&
               buffer_limit() != 0) {
          if (legacy_autotune_) {
            auto_tuner_.RecordEmpty();
            buffer_size_->value = auto_tuner_.buffer_limit();
          }
          RecordStop(ctx);
          cond_var_->wait(l);
          RecordStart(ctx);
        }

        if (!buffer_.empty()) {
          return Consume(ctx, out_tensors, end_of_sequence);
        }

        if (prefetch_thread_finished_) {
          *end_of_sequence = true;
          return OkStatus();
        }

        DCHECK_EQ(buffer_limit(), 0);
      }

      mutex_lock input_l(input_mu_);
      {
        mutex_lock l(*mu_);
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::BufferSizeScalarName(dataset()->node_name()),
              static_cast<float>(buffer_.size()), num_elements());
          stats_aggregator->AddScalar(
              stats_utils::BufferCapacityScalarName(dataset()->node_name()),
              static_cast<float>(buffer_limit()), num_elements());
        }
        // Release mu_
      }
      return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
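    // Reports this iterator to the autotuning model as an asynchronous
    // known-ratio node whose tunable parameter is the buffer size.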
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncKnownRatioNode(
          std::move(args),
          /*ratio=*/1,
          {model::MakeParameter(kBufferSize, buffer_size_,
                                /*min=*/buffer_size_min_,
                                /*max=*/std::numeric_limits<int64_t>::max())});
    }

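    // Writes the input iterator state and every buffered element (status,
    // number of tensors, and the tensors themselves) to the checkpoint.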
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // Acquire both locks to ensure that the prefetch thread and
      // all GetNext threads are blocked.
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBufferSize, buffer_.size()));
      for (size_t i = 0; i < buffer_.size(); i++) {
        auto& buffer_element = buffer_[i];
        TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
        if (buffer_element.status.ok()) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              absl::StrCat(prefix(), "::", i),
              absl::StrCat(kBuffer, kSizeSuffix), buffer_element.value.size()));
          for (size_t j = 0; j < buffer_element.value.size(); j++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                absl::StrCat(prefix(), "::", i),
                absl::StrCat(kBuffer, "[", j, "]"), buffer_element.value[j]));
          }
        }
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      DCHECK(buffer_.empty());
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      size_t buffer_size;
      {
        int64_t temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBufferSize, &temp));
        buffer_size = static_cast<size_t>(temp);
      }
      for (size_t i = 0; i < buffer_size; i++) {
        buffer_.emplace_back();
        auto& buffer_element = buffer_.back();
        TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status));
        if (buffer_element.status.ok()) {
          size_t value_size;
          {
            int64_t temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, kSizeSuffix), &temp));
            value_size = static_cast<size_t>(temp);
          }
          buffer_element.value.reserve(value_size);
          for (size_t j = 0; j < value_size; j++) {
            buffer_element.value.emplace_back();
            TF_RETURN_IF_ERROR(
                reader->ReadTensor(ctx->flr(), absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, "[", j, "]"),
                                   &buffer_element.value.back()));
          }
        }
        RecordBufferEnqueue(ctx, buffer_element.value);
      }
      return OkStatus();
    }

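    // Collects metadata (buffer size/limit, autotune mode, slack, interleave
    // depth) to attach to profiler trace events for this iterator.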
    data::TraceMeMetadata GetTraceMeMetadata() const override {
      int64_t limit = -1, size = -1;
      data::TraceMeMetadata result;
      // NOTE: We only set the buffer values if the lock can be acquired
      // right away to avoid introducing tracing overhead.
      if (mu_->try_lock()) {
        limit = buffer_limit();
        size = buffer_.size();
        if (!buffer_.empty()) {
          std::vector<std::string> shapes;
          shapes.reserve(buffer_.front().value.size());
          for (const auto& component : buffer_.front().value) {
            shapes.push_back(component.shape().DebugString());
          }
          result.push_back(std::make_pair("next_element_shapes",
                                          absl::StrJoin(shapes, ",")));
        }
        mu_->unlock();
      }
      result.push_back(std::make_pair(
          "buffer_limit",
          limit == -1
              ? kTraceInfoUnavailable
              : strings::Printf("%lld", static_cast<long long>(limit))));
      result.push_back(std::make_pair(
          "buffer_size",
          size == -1 ? kTraceInfoUnavailable
                     : strings::Printf("%lld", static_cast<long long>(size))));
      result.push_back(std::make_pair(
          "autotune",
          dataset()->buffer_size_ == model::kAutotune ? "true" : "false"));
      result.push_back(std::make_pair(
          "autotune_mode", legacy_autotune_ ? "legacy" : "performance"));
      if (dataset()->slack_period_ > 0) {
        result.push_back(std::make_pair(
            "slack",
            strings::Printf("%lld", static_cast<long long>(slack_us_.load()))));
      }
      result.push_back(std::make_pair(
          "interleave_depth",
          strings::Printf("%lld", static_cast<long long>(interleave_depth_))));
      return result;
    }

   private:
    // A buffer element comprises a status and (if that status is
    // OK) a vector of tensors, representing an element of the input dataset.
    struct BufferElement {
      BufferElement() : uid(tensorflow::EnvTime::NowNanos()) {}

      // The producer sets `status` if getting the input element fails.
      Status status;
      // The buffered data element.
      std::vector<Tensor> value;
      // Time, in microseconds, at which the element was added to the buffer.
      int64_t created_us;
      // Unique identifier, used to pair the produce and consume trace events.
      const uint64 uid;
    };

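    // Returns the maximum number of elements to keep buffered: the legacy
    // autotuner's limit, or the (possibly autotuned) `buffer_size_` parameter.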
    int64_t buffer_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (legacy_autotune_) {
        return auto_tuner_.buffer_limit();
      }
      return buffer_size_->value;
    }

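    // Requests cancellation of the input pipeline and wakes up any waiting
    // threads so that they can observe `cancelled_`.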
    void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
      cancellation_manager_->StartCancel();
      mutex_lock l(*mu_);
      cancelled_ = true;
      cond_var_->notify_all();
    }

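    // Pops the front element of `buffer_`, forwarding its status and (on
    // success) moving its tensors into `out_tensors`, then notifies the
    // prefetch thread that a buffer slot has been freed.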
    Status Consume(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      const auto& stats_aggregator = ctx->stats_aggregator();
      if (stats_aggregator) {
        double buffer_limit_ = buffer_limit();
        stats_aggregator->AddToHistogram(
            stats_utils::BufferUtilizationHistogramName(dataset()->node_name()),
            {static_cast<float>(buffer_.size()) /
             static_cast<float>(buffer_limit_)},
            num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferSizeScalarName(dataset()->node_name()),
            static_cast<float>(buffer_.size()), num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferCapacityScalarName(dataset()->node_name()),
            static_cast<float>(buffer_limit_), num_elements());
      }
      // A new element is available. Forward the status from computing it, and
      // (if we successfully got an element) the output values.
      Status s = buffer_.front().status;
      if (s.ok()) {
        int64_t buffer_element_id = buffer_.front().uid;
        profiler::TraceMe traceme(
            [&] {
              return profiler::TraceMeEncode(
                  "PrefetchConsume", {{"element_id", buffer_element_id}});
            },
            profiler::kInfo);
        if (dataset()->slack_period_ > 0 &&
            (num_elements() + 1) % dataset()->slack_period_ == 0) {
          // TODO(rachelim): Consider doing something more sophisticated
          // to decide how long to sleep for; e.g. using a Kalman filter.
          int64_t slack_us = EnvTime::NowMicros() - buffer_.front().created_us;
          // Every slack_period_-th element, update the most recent slack time,
          // measured by the duration between when the element is prefetched
          // and when it is consumed. We add kSleepFactor * slack_us_ to the
          // measurement because we slept for that duration before prefetching
          // the element.
          slack_us_ = kSleepFactor * slack_us_ + slack_us;
          VLOG(2) << "Setting slack_us_: " << slack_us_;
        }
        *out_tensors = std::move(buffer_.front().value);
        RecordBufferDequeue(ctx, *out_tensors);
      } else {
        // If the status is not OK, we still record the dequeue event so that
        // each enqueue event is paired with a dequeue event, even in the
        // presence of errors.
        RecordBufferDequeue(ctx, buffer_.front().value);
      }
      if (legacy_autotune_) {
        auto_tuner_.RecordConsumption(buffer_.size());
        buffer_size_->value = auto_tuner_.buffer_limit();
      }
      buffer_.pop_front();
      *end_of_sequence = false;

      // Wake the prefetch thread, in case it has been waiting for space
      // in the buffer. Also wake up threads from other calls to GetNext.
      //
      // TODO(mrry): Consider using different condition variables for
      // GetNext and Prefetch.
      cond_var_->notify_all();
      return s;
    }

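    // Lazily starts the background prefetch thread the first time an element
    // is requested.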
    Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!prefetch_thread_) {
        std::shared_ptr<IteratorContext> new_ctx =
            std::make_shared<IteratorContext>(*ctx);
        prefetch_thread_ = ctx->StartThread(
            "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
      }
      return OkStatus();
    }

    // Prefetches elements of the input, storing results in an internal buffer.
    //
    // It owns the iterator context passed to it.
    void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      // Keep track of where we are in an iteration "burst".
      int num_produced = 0;
      while (true) {
        // 1. Wait for a slot in the buffer.
        {
          mutex_lock l(*mu_);
          while (!cancelled_ && buffer_.size() >= buffer_limit()) {
            RecordStop(ctx.get());
            cond_var_->wait(l);
            RecordStart(ctx.get());
          }

          if (cancelled_) {
            prefetch_thread_finished_ = true;
            cond_var_->notify_all();
            return;
          }
        }

        if (dataset()->slack_period_ > 0 &&
            num_produced % dataset()->slack_period_ == 0) {
          // For the first element in the "burst", sleep for a bit if there is
          // slack.
          VLOG(2) << "Sleeping for: " << slack_us_ * kSleepFactor;
          ctx->env()->SleepForMicroseconds(slack_us_ * kSleepFactor);
        }

        // 2. Read the next element.
        // Acquire the input mutex since we will be reading an element from the
        // input iterator. Note that we do not wish to release this mutex until
        // we have added the fetched element to `buffer_`; otherwise there
        // would be local state that SaveInternal could miss.
        mutex_lock input_l(input_mu_);
        bool end_of_sequence;
        BufferElement buffer_element;
        {
          profiler::TraceMe traceme(
              [&] {
                return profiler::TraceMeEncode(
                    "PrefetchProduce", {{"element_id", buffer_element.uid}});
              },
              profiler::kInfo);
          buffer_element.status = input_impl_->GetNext(
              ctx.get(), &buffer_element.value, &end_of_sequence);
        }
        if (buffer_element.status.ok() && end_of_sequence) {
          mutex_lock l(*mu_);
          prefetch_thread_finished_ = true;
          cond_var_->notify_all();
          return;
        }

        // 3. Signal that the element has been produced.
        {
          mutex_lock l(*mu_);
          RecordBufferEnqueue(ctx.get(), buffer_element.value);
          buffer_element.created_us = EnvTime::NowMicros();
          buffer_.push_back(std::move(buffer_element));
          cond_var_->notify_all();
        }
        ++num_produced;
      }
    }

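    // Writes the status of the `index`-th buffered element to the checkpoint
    // (the error code, plus the error message for non-OK statuses).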
    Status WriteStatus(IteratorStateWriter* writer, size_t index,
                       const Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(absl::StrCat(prefix(), "::", index), CodeKey(),
                              static_cast<int64_t>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(absl::StrCat(prefix(), "::", index),
                                ErrorMessageKey(), status.error_message()));
      }
      return OkStatus();
    }

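    // Reads the status of the `index`-th buffered element back from the
    // checkpoint, reconstructing the error message for non-OK statuses.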
    Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64_t code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                                            CodeKey(), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                               ErrorMessageKey(), &error_message));
        *status = Status(code, error_message);
      } else {
        *status = OkStatus();
      }
      return OkStatus();
    }

    string CodeKey() { return absl::StrCat(kStatus, kCodeSuffix); }

    string ErrorMessageKey() {
      return absl::StrCat(kStatus, kErrorMessageSuffix);
    }

    // This mutex is used to ensure exclusivity between multiple threads
    // reading/writing this iterator's local state.
    //
    // NOTE: We should never call GetNext on the input while holding this mutex.
    const std::shared_ptr<mutex> mu_;
    // This mutex is used to ensure exclusivity between multiple threads
    // accessing the input iterator. We keep this separate from `mu_` to allow
    // prefetching to run in parallel with GetNext calls.
    mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_);
    // Controls cancellation of `input_impl_`. Must be ordered before
    // `input_impl_` so that `input_impl_` is destroyed first.
    std::unique_ptr<CancellationManager> cancellation_manager_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(input_mu_);
    const std::shared_ptr<condition_variable> cond_var_;
    const int64_t buffer_size_min_;
    PrefetchAutotuner auto_tuner_ TF_GUARDED_BY(*mu_);
    std::deque<BufferElement> buffer_ TF_GUARDED_BY(*mu_);
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
    bool prefetch_thread_finished_ TF_GUARDED_BY(*mu_) = false;
    const bool legacy_autotune_;

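    // Estimate of the slack time in microseconds (the gap between when an
    // element is prefetched and when it is consumed); the prefetch thread
    // sleeps for a fraction of this to delay prefetching.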
    std::atomic<int64_t> slack_us_;

    // If legacy_autotune_ is false, identifies the maximum size of the buffer.
    const std::shared_ptr<model::SharedState> buffer_size_;

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;

    // Records the number of ParallelInterleave operations in the path from the
    // root node to this node (not including this node) in the input pipeline
    // tree. We record the interleave depth so that it can be included in the
    // trace metadata.
    int64 interleave_depth_ = -1;
    std::unique_ptr<Thread> prefetch_thread_ TF_GUARDED_BY(*mu_);
  };

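  // The input dataset (this dataset holds a reference to it) and the requested
  // buffer size, which may be `model::kAutotune`.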
  const DatasetBase* const input_;
  const int64_t buffer_size_;

  // If non-zero, determines the period between injecting "slack" into the
  // execution.
  const int64_t slack_period_;

  // Determines whether legacy autotuning should be used.
  const bool legacy_autotune_ = true;

  // If autotune is enabled, determines the minimum value of the `buffer_size`
  // parameter.
  const int64_t buffer_size_min_ = 0;

  TraceMeMetadata traceme_metadata_;
};

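// Reads the optional `slack_period`, `legacy_autotune`, and `buffer_size_min`
// attributes, and applies the "autotune_buffer_optimization" experiment if it
// is enabled.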
PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  if (ctx->HasAttr(kSlackPeriod)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kSlackPeriod, &slack_period_));
  }
  if (ctx->HasAttr(kLegacyAutotune)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_));
  }
  if (ctx->HasAttr(kBufferSizeMin)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kBufferSizeMin, &buffer_size_min_));
  }
  if (GetExperiments().contains("autotune_buffer_optimization")) {
    legacy_autotune_ = false;
    buffer_size_min_ = std::max(static_cast<int64_t>(1), buffer_size_min_);
  }
}

void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                    DatasetBase** output) {
  int64_t buffer_size = 0;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64_t>(ctx, kBufferSize, &buffer_size));
  OP_REQUIRES(ctx, buffer_size >= 0 || buffer_size == model::kAutotune,
              errors::InvalidArgument("buffer_size must be >= 0 or set "
                                      "buffer_size to be ",
                                      model::kAutotune, " for auto-tuning"));

  if (buffer_size == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output = new Dataset(ctx, input, buffer_size, slack_period_,
                        legacy_autotune_, buffer_size_min_);
}

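// For context: this kernel backs the `tf.data.Dataset.prefetch` transformation
// in Python; passing `tf.data.AUTOTUNE` as the buffer size maps to
// `model::kAutotune` here and enables autotuning of the buffer size.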
namespace {
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU).Priority(2),
                        PrefetchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
                            .Device(DEVICE_GPU)
                            .HostMemory("buffer_size")
                            .HostMemory("input_dataset")
                            .HostMemory("handle")
                            .Priority(1),
                        PrefetchDatasetOp);
}  // namespace

}  // namespace data
}  // namespace tensorflow