/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"

#include <algorithm>
#include <deque>

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
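//
// For reference, this kernel typically backs the Python-level call
// `dataset = dataset.prefetch(buffer_size)` (with `tf.data.AUTOTUNE` enabling
// autotuning of the buffer size), which overlaps the production of upcoming
// elements with the consumption of the current one.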

/* static */ constexpr const char* const PrefetchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PrefetchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSize;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod;
/* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSizeMin;

namespace {

// Determines the fraction of slack time by which to delay prefetching of data.
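// `kSleepFactor` is used both to scale how long the prefetch thread sleeps
// before fetching the first element of each slack period and to decay the
// running slack estimate when a new measurement is folded in (see `Consume()`
// and `PrefetchThread()` below).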
constexpr double kSleepFactor = 0.2;
constexpr char kBuffer[] = "buffer";
constexpr char kStatus[] = "status";
constexpr char kSizeSuffix[] = ".size";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessageSuffix[] = ".error_message";

}  // namespace

class PrefetchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
          int64_t slack_period, bool legacy_autotune, int64_t buffer_size_min)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        buffer_size_(buffer_size),
        slack_period_(slack_period),
        legacy_autotune_(legacy_autotune),
        buffer_size_min_(buffer_size_min) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64_t CardinalityInternal() const override { return input_->Cardinality(); }

  int64_t CardinalityInternal(CardinalityOptions options) const override {
    return input_->Cardinality(options);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

  Status Get(OpKernelContext* ctx, int64 index,
             std::vector<Tensor>* out_tensors) const override {
    return input_->Get(ctx, index, out_tensors);
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* buffer_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
    AttrValue slack_period_attr;
    b->BuildAttrValue(slack_period_, &slack_period_attr);
    AttrValue legacy_autotune_attr;
    b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
    AttrValue buffer_size_min_attr;
    b->BuildAttrValue(buffer_size_min_, &buffer_size_min_attr);

    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, buffer_size},
                      {std::make_pair(kSlackPeriod, slack_period_attr),
                       std::make_pair(kLegacyAutotune, legacy_autotune_attr),
                       std::make_pair(kBufferSizeMin, buffer_size_min_attr)},
                      output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          cond_var_(std::make_shared<condition_variable>()),
          buffer_size_min_(params.dataset->buffer_size_min_),
          auto_tuner_(params.dataset->buffer_size_, buffer_size_min_),
          legacy_autotune_(params.dataset->legacy_autotune_),
          // If `legacy_autotune_` is enabled, initialize `buffer_size_` to 0 so
          // that the created node is not collected as a tunable node by the
          // autotuning optimization.
          buffer_size_(std::make_shared<model::SharedState>(
              legacy_autotune_ ? 0 : params.dataset->buffer_size_, mu_,
              cond_var_)) {
      slack_us_ = 0;
    }

    ~Iterator() override {
      CancelThreads();
      if (deregister_fn_) deregister_fn_();
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      interleave_depth_ = ctx->interleave_depth();

      if (buffer_size_->value == model::kAutotune) {
        buffer_size_->value = buffer_size_min_;
      }
      cancellation_manager_ = std::make_unique<CancellationManager>();
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
          ctx->cancellation_manager(), [this]() { CancelThreads(); },
          &deregister_fn_));
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      return dataset()->input_->MakeIterator(IteratorContext(params), this,
                                             prefix(), &input_impl_);
    }

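    // Returns the next element, either from the prefetch buffer (waiting for
    // the prefetch thread to produce one if the buffer is empty) or, when the
    // buffer limit is zero, synchronously from the input iterator.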
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const auto& stats_aggregator = ctx->stats_aggregator();
      {
        mutex_lock l(*mu_);
        TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
        // Wait until the next element in the buffer has been
        // produced, or we are shutting down.
        while (buffer_.empty() && !prefetch_thread_finished_ &&
               buffer_limit() != 0) {
          if (legacy_autotune_) {
            auto_tuner_.RecordEmpty();
            buffer_size_->value = auto_tuner_.buffer_limit();
          }
          RecordStop(ctx);
          cond_var_->wait(l);
          RecordStart(ctx);
        }

        if (!buffer_.empty()) {
          return Consume(ctx, out_tensors, end_of_sequence);
        }

        if (prefetch_thread_finished_) {
          *end_of_sequence = true;
          return OkStatus();
        }

        DCHECK_EQ(buffer_limit(), 0);
      }

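      // The buffer is disabled (its limit is zero), so bypass it and read the
      // next element synchronously from the input iterator.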
      mutex_lock input_l(input_mu_);
      {
        mutex_lock l(*mu_);
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::BufferSizeScalarName(dataset()->node_name()),
              static_cast<float>(buffer_.size()), num_elements());
          stats_aggregator->AddScalar(
              stats_utils::BufferCapacityScalarName(dataset()->node_name()),
              static_cast<float>(buffer_limit()), num_elements());
        }
        // Release mu_
      }
      return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
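    // Exposes this iterator to the autotuning model as an asynchronous node
    // with a 1:1 input/output ratio and a tunable `buffer_size` parameter.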
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncKnownRatioNode(
          std::move(args),
          /*ratio=*/1,
          {model::MakeParameter(kBufferSize, buffer_size_,
                                /*min=*/buffer_size_min_,
                                /*max=*/std::numeric_limits<int64_t>::max())});
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // Acquire both locks to ensure that the prefetch thread and
      // all GetNext threads are blocked.
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBufferSize, buffer_.size()));
      for (size_t i = 0; i < buffer_.size(); i++) {
        auto& buffer_element = buffer_[i];
        TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
        if (buffer_element.status.ok()) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              absl::StrCat(prefix(), "::", i),
              absl::StrCat(kBuffer, kSizeSuffix), buffer_element.value.size()));
          for (size_t j = 0; j < buffer_element.value.size(); j++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                absl::StrCat(prefix(), "::", i),
                absl::StrCat(kBuffer, "[", j, "]"), buffer_element.value[j]));
          }
        }
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      DCHECK(buffer_.empty());
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      size_t buffer_size;
      {
        int64_t temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBufferSize, &temp));
        buffer_size = static_cast<size_t>(temp);
      }
      for (size_t i = 0; i < buffer_size; i++) {
        buffer_.emplace_back();
        auto& buffer_element = buffer_.back();
        TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status));
        if (buffer_element.status.ok()) {
          size_t value_size;
          {
            int64_t temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, kSizeSuffix), &temp));
            value_size = static_cast<size_t>(temp);
          }
          buffer_element.value.reserve(value_size);
          for (size_t j = 0; j < value_size; j++) {
            buffer_element.value.emplace_back();
            TF_RETURN_IF_ERROR(
                reader->ReadTensor(ctx->flr(), absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, "[", j, "]"),
                                   &buffer_element.value.back()));
          }
        }
        RecordBufferEnqueue(ctx, buffer_element.value);
      }
      return OkStatus();
    }

    data::TraceMeMetadata GetTraceMeMetadata() const override {
      int64_t limit = -1, size = -1;
      data::TraceMeMetadata result;
      // NOTE: We only set the buffer limit and size if the lock can be
      // acquired right away, to avoid introducing tracing overhead.
      if (mu_->try_lock()) {
        limit = buffer_limit();
        size = buffer_.size();
        if (!buffer_.empty()) {
          std::vector<std::string> shapes;
          shapes.reserve(buffer_.front().value.size());
          for (const auto& component : buffer_.front().value) {
            shapes.push_back(component.shape().DebugString());
          }
          result.push_back(std::make_pair("next_element_shapes",
                                          absl::StrJoin(shapes, ",")));
        }
        mu_->unlock();
      }
      result.push_back(std::make_pair(
          "buffer_limit",
          limit == -1
              ? kTraceInfoUnavailable
              : strings::Printf("%lld", static_cast<long long>(limit))));
      result.push_back(std::make_pair(
          "buffer_size",
          size == -1 ? kTraceInfoUnavailable
                     : strings::Printf("%lld", static_cast<long long>(size))));
      result.push_back(std::make_pair(
          "autotune",
          dataset()->buffer_size_ == model::kAutotune ? "true" : "false"));
      result.push_back(std::make_pair(
          "autotune_mode", legacy_autotune_ ? "legacy" : "performance"));
      if (dataset()->slack_period_ > 0) {
        result.push_back(std::make_pair(
            "slack",
            strings::Printf("%lld", static_cast<long long>(slack_us_.load()))));
      }
      result.push_back(std::make_pair(
          "interleave_depth",
          strings::Printf("%lld", static_cast<long long>(interleave_depth_))));
      return result;
    }

   private:
    // A buffer element comprises a status and (if that status is
    // OK) a vector of tensors, representing an element of the input dataset.
    struct BufferElement {
      BufferElement() : uid(tensorflow::EnvTime::NowNanos()) {}

      // The producer sets `status` if getting the input element fails.
      Status status;
      // The buffered data element.
      std::vector<Tensor> value;
      // Time, in microseconds, at which the element was added to the buffer;
      // used to measure the slack between production and consumption.
      int64_t created_us;
      // Identifier used to pair the "PrefetchProduce" and "PrefetchConsume"
      // profiler trace events for this element.
      const uint64 uid;
    };

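    // Returns the current maximum number of elements to keep buffered: the
    // legacy autotuner's estimate when legacy autotuning is enabled, otherwise
    // the (possibly autotuned) value of the `buffer_size` parameter.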
    int64_t buffer_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (legacy_autotune_) {
        return auto_tuner_.buffer_limit();
      }
      return buffer_size_->value;
    }

    void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
      cancellation_manager_->StartCancel();
      mutex_lock l(*mu_);
      cancelled_ = true;
      cond_var_->notify_all();
    }

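    // Pops the next element off the buffer, forwarding its status and (on
    // success) its tensors to the caller, and wakes up any threads waiting for
    // buffer space.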
    Status Consume(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      const auto& stats_aggregator = ctx->stats_aggregator();
      if (stats_aggregator) {
        double buffer_limit_ = buffer_limit();
        stats_aggregator->AddToHistogram(
            stats_utils::BufferUtilizationHistogramName(dataset()->node_name()),
            {static_cast<float>(buffer_.size()) /
             static_cast<float>(buffer_limit_)},
            num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferSizeScalarName(dataset()->node_name()),
            static_cast<float>(buffer_.size()), num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferCapacityScalarName(dataset()->node_name()),
            static_cast<float>(buffer_limit_), num_elements());
      }
      // A new element is available. Forward the status from computing it, and
      // (if we successfully got an element) the output values.
      Status s = buffer_.front().status;
      if (s.ok()) {
        int64_t buffer_element_id = buffer_.front().uid;
        profiler::TraceMe traceme(
            [&] {
              return profiler::TraceMeEncode(
                  "PrefetchConsume", {{"element_id", buffer_element_id}});
            },
            profiler::kInfo);
        if (dataset()->slack_period_ > 0 &&
            (num_elements() + 1) % dataset()->slack_period_ == 0) {
          // TODO(rachelim): Consider doing something more sophisticated
          // to decide how long to sleep for; e.g. using a Kalman filter.
          int64_t slack_us = EnvTime::NowMicros() - buffer_.front().created_us;
          // Every slack_period_-th element, update the most recent slack time,
          // measured by the duration between when the element is prefetched
          // and when it is consumed. We add kSleepFactor * slack_us_ to the
          // measurement because we slept for that duration before prefetching
          // the element.
          slack_us_ = kSleepFactor * slack_us_ + slack_us;
          VLOG(2) << "Setting slack_us_: " << slack_us_;
        }
        *out_tensors = std::move(buffer_.front().value);
        RecordBufferDequeue(ctx, *out_tensors);
      } else {
        // If the status is not OK, we still record the dequeue event so that
        // each enqueue event is paired with a dequeue event, even in the
        // presence of errors.
        RecordBufferDequeue(ctx, buffer_.front().value);
      }
      if (legacy_autotune_) {
        auto_tuner_.RecordConsumption(buffer_.size());
        buffer_size_->value = auto_tuner_.buffer_limit();
      }
      buffer_.pop_front();
      *end_of_sequence = false;

      // Wake the prefetch thread, in case it has been waiting for space
      // in the buffer. Also wake up threads from other calls to GetNext.
      //
      // TODO(mrry): Consider using different condition variables for
      // GetNext and Prefetch.
      cond_var_->notify_all();
      return s;
    }

    Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!prefetch_thread_) {
        std::shared_ptr<IteratorContext> new_ctx =
            std::make_shared<IteratorContext>(*ctx);
        prefetch_thread_ = ctx->StartThread(
            "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
      }
      return OkStatus();
    }

    // Prefetches elements of the input, storing results in an internal buffer.
    //
    // It owns the iterator context passed to it.
    void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      // Keep track of where we are in an iteration "burst".
      int num_produced = 0;
      while (true) {
        // 1. Wait for a slot in the buffer.
        {
          mutex_lock l(*mu_);
          while (!cancelled_ && buffer_.size() >= buffer_limit()) {
            RecordStop(ctx.get());
            cond_var_->wait(l);
            RecordStart(ctx.get());
          }

          if (cancelled_) {
            prefetch_thread_finished_ = true;
            cond_var_->notify_all();
            return;
          }
        }

        if (dataset()->slack_period_ > 0 &&
            num_produced % dataset()->slack_period_ == 0) {
          // For the first element in the "burst", sleep for a bit if there is
          // slack.
          VLOG(2) << "Sleeping for: " << slack_us_ * kSleepFactor;
          ctx->env()->SleepForMicroseconds(slack_us_ * kSleepFactor);
        }

        // 2. Read the next element.
        // Acquire the input mutex since we will be reading an element from the
        // input iterator. Note that we do not want to release this mutex until
        // we have added the fetched element to `buffer_`; otherwise, there
        // would be local state that SaveInternal could miss.
        mutex_lock input_l(input_mu_);
        bool end_of_sequence;
        BufferElement buffer_element;
        {
          profiler::TraceMe traceme(
              [&] {
                return profiler::TraceMeEncode(
                    "PrefetchProduce", {{"element_id", buffer_element.uid}});
              },
              profiler::kInfo);
          buffer_element.status = input_impl_->GetNext(
              ctx.get(), &buffer_element.value, &end_of_sequence);
        }
        if (buffer_element.status.ok() && end_of_sequence) {
          mutex_lock l(*mu_);
          prefetch_thread_finished_ = true;
          cond_var_->notify_all();
          return;
        }

        // 3. Signal that the element has been produced.
        {
          mutex_lock l(*mu_);
          RecordBufferEnqueue(ctx.get(), buffer_element.value);
          buffer_element.created_us = EnvTime::NowMicros();
          buffer_.push_back(std::move(buffer_element));
          cond_var_->notify_all();
        }
        ++num_produced;
      }
    }

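    // Helpers for (de)serializing a buffer element's status in the iterator
    // checkpoint, keyed by the element's position in the buffer.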
    Status WriteStatus(IteratorStateWriter* writer, size_t index,
                       const Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(absl::StrCat(prefix(), "::", index), CodeKey(),
                              static_cast<int64_t>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(absl::StrCat(prefix(), "::", index),
                                ErrorMessageKey(), status.error_message()));
      }
      return OkStatus();
    }

    Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64_t code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                                            CodeKey(), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                               ErrorMessageKey(), &error_message));
        *status = Status(code, error_message);
      } else {
        *status = OkStatus();
      }
      return OkStatus();
    }

    string CodeKey() { return absl::StrCat(kStatus, kCodeSuffix); }

    string ErrorMessageKey() {
      return absl::StrCat(kStatus, kErrorMessageSuffix);
    }

    // This mutex is used to ensure exclusivity between multiple threads
    // reading/writing this iterator's local state.
    //
    // NOTE: We should never call GetNext on the input while holding this mutex.
    const std::shared_ptr<mutex> mu_;
    // This mutex is used to ensure exclusivity between multiple threads
    // accessing the input iterator. We keep this separate from `mu_` to allow
    // prefetching to run in parallel with GetNext calls.
    mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_);
    // Controls cancellation of `input_impl_`. Must be ordered before
    // `input_impl_` so that `input_impl_` is destroyed first.
    std::unique_ptr<CancellationManager> cancellation_manager_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(input_mu_);
    const std::shared_ptr<condition_variable> cond_var_;
    const int64_t buffer_size_min_;
    PrefetchAutotuner auto_tuner_ TF_GUARDED_BY(*mu_);
    std::deque<BufferElement> buffer_ TF_GUARDED_BY(*mu_);
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
    bool prefetch_thread_finished_ TF_GUARDED_BY(*mu_) = false;
    const bool legacy_autotune_;

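    // Running estimate, in microseconds, of the delay between when an element
    // is prefetched and when it is consumed; updated every `slack_period_`-th
    // consumed element and used to delay the prefetch thread when slack
    // injection is enabled.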
    std::atomic<int64_t> slack_us_;

    // If legacy_autotune_ is false, identifies the maximum size of the buffer.
    const std::shared_ptr<model::SharedState> buffer_size_;

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;

    // Records the number of ParallelInterleave operations in the path from the
    // root node to this node (not including this node) in the input pipeline
    // tree. We record the interleave depth so that it can be included in the
    // trace metadata.
    int64 interleave_depth_ = -1;
    std::unique_ptr<Thread> prefetch_thread_ TF_GUARDED_BY(*mu_);
  };

  const DatasetBase* const input_;
  const int64_t buffer_size_;

  // If non-zero, determines the period between injecting "slack" into the
  // execution.
  const int64_t slack_period_;

  // Determines whether legacy autotuning should be used.
  const bool legacy_autotune_ = true;

  // If autotune is enabled, determines the minimal value of the `buffer_size`
  // parameter.
  const int64_t buffer_size_min_ = 0;

  TraceMeMetadata traceme_metadata_;
};

PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  if (ctx->HasAttr(kSlackPeriod)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kSlackPeriod, &slack_period_));
  }
  if (ctx->HasAttr(kLegacyAutotune)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_));
  }
  if (ctx->HasAttr(kBufferSizeMin)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kBufferSizeMin, &buffer_size_min_));
  }
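  // The "autotune_buffer_optimization" experiment forces the model-based
  // (non-legacy) buffer autotuning and guarantees a buffer of at least one
  // element.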
  if (GetExperiments().contains("autotune_buffer_optimization")) {
    legacy_autotune_ = false;
    buffer_size_min_ = std::max(static_cast<int64_t>(1), buffer_size_min_);
  }
}

void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                    DatasetBase** output) {
  int64_t buffer_size = 0;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64_t>(ctx, kBufferSize, &buffer_size));
  OP_REQUIRES(ctx, buffer_size >= 0 || buffer_size == model::kAutotune,
              errors::InvalidArgument("buffer_size must be >= 0 or set "
                                      "buffer_size to be ",
                                      model::kAutotune, " for auto-tuning"));

  if (buffer_size == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output = new Dataset(ctx, input, buffer_size, slack_period_,
                        legacy_autotune_, buffer_size_min_);
}

namespace {
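// Register the op for CPU and GPU. The CPU kernel is registered at a higher
// priority, so it is preferred when the placer could choose either device;
// on GPU, the scalar `buffer_size` input and the dataset handles stay in host
// memory.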
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU).Priority(2),
                        PrefetchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
                            .Device(DEVICE_GPU)
                            .HostMemory("buffer_size")
                            .HostMemory("input_dataset")
                            .HostMemory("handle")
                            .Priority(1),
                        PrefetchDatasetOp);
}  // namespace

}  // namespace data
}  // namespace tensorflow