1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h"
16 
17 #include <atomic>
18 #include <deque>
19 #include <functional>
20 #include <string>
21 #include <utility>
22 
23 #include "tensorflow/core/common_runtime/function.h"
24 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
25 #include "tensorflow/core/data/dataset_utils.h"
26 #include "tensorflow/core/data/name_utils.h"
27 #include "tensorflow/core/framework/dataset.h"
28 #include "tensorflow/core/framework/partial_tensor_shape.h"
29 #include "tensorflow/core/framework/stats_aggregator.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/lib/core/threadpool.h"
32 #include "tensorflow/core/lib/gtl/cleanup.h"
33 #include "tensorflow/core/lib/random/random.h"
34 #include "tensorflow/core/platform/blocking_counter.h"
35 #include "tensorflow/core/platform/stringprintf.h"
36 #include "tensorflow/core/profiler/lib/traceme.h"
37 #include "tensorflow/core/profiler/lib/traceme_encode.h"
38 
39 namespace tensorflow {
40 namespace data {
41 namespace experimental {
42 
43 /* static */ constexpr const char* const
44     ParallelInterleaveDatasetOp::kDatasetType;
45 /* static */ constexpr const char* const
46     ParallelInterleaveDatasetOp::kInputDataset;
47 /* static */ constexpr const char* const
48     ParallelInterleaveDatasetOp::kOtherArguments;
49 /* static */ constexpr const char* const
50     ParallelInterleaveDatasetOp::kCycleLength;
51 /* static */ constexpr const char* const
52     ParallelInterleaveDatasetOp::kBlockLength;
53 /* static */ constexpr const char* const
54     ParallelInterleaveDatasetOp::kDeterministic;
55 /* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
56 /* static */ constexpr const char* const
57     ParallelInterleaveDatasetOp::kBufferOutputElements;
58 /* static */ constexpr const char* const
59     ParallelInterleaveDatasetOp::kPrefetchInputElements;
60 /* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
61 /* static */ constexpr const char* const
62     ParallelInterleaveDatasetOp::kTarguments;
63 /* static */ constexpr const char* const
64     ParallelInterleaveDatasetOp::kOutputTypes;
65 /* static */ constexpr const char* const
66     ParallelInterleaveDatasetOp::kOutputShapes;
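
// Note: op_version 1 of this op encodes (non)determinism via the `sloppy`
// input, while op_version 2 uses the `deterministic` attr; see
// AsGraphDefInternal() below.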
67 
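// The string constants below are used as checkpoint keys and name prefixes by
// the checkpointing code (SaveInternal/RestoreInternal) and the worker
// threads. For example, worker i's buffered state is written under the
// "<prefix>::worker_<i>" name with keys such as "input_size", "input_<j>",
// "outputs_size", and "outputs_<k>_...", as done in WriteWorkerStateLocked().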
68 constexpr char kInputExhausted[] = "input_exhausted";
69 constexpr char kNextIndex[] = "next_index";
70 constexpr char kBlockCount[] = "block_count";
71 constexpr char kWorkersSize[] = "workers_size";
72 constexpr char kInterleaveSize[] = "interleave_size";
73 constexpr char kInterleaveIndices[] = "interleave_indices";
74 constexpr char kStagingSize[] = "staging_size";
75 constexpr char kStagingIndices[] = "staging_indices";
76 constexpr char kWorkerThreadsRunning[] = "worker_threads_running";
77 constexpr char kDataParallelInterleaveWorker[] =
78     "data_parallel_interleave_worker";
79 constexpr char kWorker[] = "worker";
80 constexpr char kInputSize[] = "input_size";
81 constexpr char kInput[] = "input";
82 constexpr char kOutputsSize[] = "outputs_size";
83 constexpr char kOutputs[] = "outputs";
84 constexpr char kIsProducing[] = "is_producing";
85 constexpr char kWorkerThread[] = "worker_thread";
86 constexpr char kIteratorExhausted[] = "iterator_exhausted";
87 constexpr char kIteratorCreationStatus[] = "iterator_creation_status";
88 constexpr char kOutput[] = "output";
89 constexpr char kEndOfSequence[] = "end_of_sequence";
90 constexpr char kStatus[] = "status";
91 constexpr char kOutputSize[] = "output_size";
92 constexpr char kCode[] = "code";
93 constexpr char KMessage[] = "msg";
94 
95 class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
96  public:
97   Dataset(OpKernelContext* ctx, const DatasetBase* input,
98           std::unique_ptr<CapturedFunction> captured_func, int64_t cycle_length,
99           int64_t block_length, DeterminismPolicy deterministic,
100           int64_t buffer_output_elements, int64_t prefetch_input_elements,
101           const DataTypeVector& output_types,
102           const std::vector<PartialTensorShape>& output_shapes, int op_version)
103       : DatasetBase(DatasetContext(ctx)),
104         input_(input),
105         captured_func_(std::move(captured_func)),
106         cycle_length_(cycle_length),
107         block_length_(block_length),
108         deterministic_(deterministic),
109         buffer_output_elements_(buffer_output_elements),
110         prefetch_input_elements_(prefetch_input_elements),
111         output_types_(output_types),
112         output_shapes_(output_shapes),
113         traceme_metadata_(
114             {{"block_length",
115               strings::Printf("%lld", static_cast<long long>(block_length))},
116              {"cycle_length",
117               strings::Printf("%lld", static_cast<long long>(cycle_length))},
118              {"deterministic",
119               deterministic.IsDeterministic() || deterministic.IsDefault()
120                   ? "true"
121                   : "false"}}),
122         op_version_(op_version) {
123     input_->Ref();
124   }
125 
126   ~Dataset() override { input_->Unref(); }
127 
128   std::unique_ptr<IteratorBase> MakeIteratorInternal(
129       const string& prefix) const override {
130     name_utils::IteratorPrefixParams params;
131     params.op_version = op_version_;
132     bool deterministic =
133         deterministic_.IsDeterministic() || deterministic_.IsDefault();
134     return std::make_unique<Iterator>(
135         Iterator::Params{
136             this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
137         deterministic);
138   }
139 
140   const DataTypeVector& output_dtypes() const override { return output_types_; }
141 
142   const std::vector<PartialTensorShape>& output_shapes() const override {
143     return output_shapes_;
144   }
145 
146   string DebugString() const override {
147     name_utils::DatasetDebugStringParams params;
148     params.op_version = op_version_;
149     return name_utils::DatasetDebugString(kDatasetType, params);
150   }
151 
152   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
153     inputs->push_back(input_);
154     return OkStatus();
155   }
156 
157   Status CheckExternalState() const override {
158     TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
159     return input_->CheckExternalState();
160   }
161 
162  protected:
163   Status AsGraphDefInternal(SerializationContext* ctx,
164                             DatasetGraphDefBuilder* b,
165                             Node** output) const override {
166     std::vector<std::pair<size_t, Node*>> inputs;
167     std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
168     int input_index = 0;
169 
170     Node* input_node;
171     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
172     inputs.emplace_back(input_index++, input_node);
173 
174     std::vector<Node*> other_arguments;
175     DataTypeVector other_arguments_types;
176     TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
177                                                   &other_arguments_types));
178     list_inputs.emplace_back(input_index++, other_arguments);
179 
180     Node* cycle_length_node;
181     TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
182     inputs.emplace_back(input_index++, cycle_length_node);
183 
184     Node* block_length_node;
185     TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
186     inputs.emplace_back(input_index++, block_length_node);
187 
188     if (op_version_ == 1) {
189       Node* sloppy_node;
190       TF_RETURN_IF_ERROR(
191           b->AddScalar(deterministic_.IsNondeterministic(), &sloppy_node));
192       inputs.emplace_back(input_index++, sloppy_node);
193     }
194 
195     Node* buffer_output_elements_node;
196     TF_RETURN_IF_ERROR(
197         b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
198     inputs.emplace_back(input_index++, buffer_output_elements_node);
199 
200     Node* prefetch_input_elements_node;
201     TF_RETURN_IF_ERROR(
202         b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
203     inputs.emplace_back(input_index++, prefetch_input_elements_node);
204 
205     std::vector<std::pair<StringPiece, AttrValue>> attrs;
206 
207     AttrValue f;
208     b->BuildAttrValue(captured_func_->func(), &f);
209     attrs.emplace_back(kFunc, f);
210 
211     if (op_version_ == 2) {
212       AttrValue deterministic_attr;
213       b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
214       attrs.emplace_back(kDeterministic, deterministic_attr);
215     }
216 
217     AttrValue other_arguments_types_attr;
218     b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
219     attrs.emplace_back(kTarguments, other_arguments_types_attr);
220 
221     TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
222     return OkStatus();
223   }
224 
225  private:
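  // One worker thread per interleaved iterator (cycle_length_) plus one per
  // prefetched/staged iterator (prefetch_input_elements_).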
226   int64_t num_threads() const {
227     return cycle_length_ + prefetch_input_elements_;
228   }
229 
230   // Parallel interleave's implementation is designed around a few principles:
231   //  1. Thread creation is relatively expensive. (Not reusing
232   //     threads causes a number of indirect costs such as poorer tcmalloc
233   //     performance due to thread-local caches, etc.) We allocate a fixed
234   //     number of threads at the start and never change that number. This is
235   //     why we've fused functionality that is theoretically orthogonal (i.e.
236   //     .prefetch()) into the implementation.
237   //  2. Drop-in replacement for standard interleave. The goal is to
238   //     auto-opt people into an optimized implementation without any work
239   //     on the customer's part. We thus go to great pains to maintain
240   //     identical iteration orders and full determinism (unless determinism
241   //     is explicitly disabled via a flag).
242   //  3. Performance across a variety of environments and I/O envelopes.
243   //
244   // The actual implementation centers around a collection of worker threads
245   // and their corresponding worker state (tracked in the `workers_` vector).
246   // Worker threads repeatedly receive a vector of Tensors that are used as
247   // input to the flat-map function (`captured_func_`). The output of this
248   // function must be a dataset. The worker thread then repeatedly calls
249   // `GetNext()`, maintaining a buffer of elements to minimize the likelihood
250   // that a caller will block waiting for an element to be produced.
251   //
252   // Pointers to these worker states are kept in 2 disjoint data structures:
253   //  1. `interleave_indices_` is a vector containing indices of WorkerStates
254   //     in `workers_` that we are interleaving. Worker threads backing these
255   //     WorkerStates should be regularly producing values.
256   //  2. `staging_indices_` is a deque containing indices of WorkerStates in
257   //     `workers_` that we will move to `interleave_indices_` when an
258   //     iterator in `interleave_indices_` is exhausted.
259   //
260   // The client calls `GetNext[Internal]()` to retrieve an output element. The
261   // internal implementation updates the state of `interleave_indices_` and
262   // `staging_indices_` as output iterators (run by the worker threads) are
263   // exhausted.
264   //
265   // `input_impl_` is the input iterator that generates arguments for the
266   // flat-map function (`captured_func_`). It is set to an iterator at
267   // Iterator construction, and is fixed until we consume all input elements.
268   // Once it is exhausted, we reset the unique_ptr to eagerly deallocate
269   // memory.
270   //
271   // A few invariants are maintained:
272   //  1. No element in interleave_indices_ should be a -1 unless
273   //     `staging_indices_` is empty and `input_impl_` is empty.
274   //  2. Every `worker_` element is pointed to by at most one element of the
275   //     union of `interleave_indices_` and `staging_indices_`.
276   //  3. Unless `input_impl_` is empty, every `worker_` must be pointed to by
277   //     an element in `interleave_indices_` or `staging_indices_`.
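  //
  // As a concrete illustration of the iteration order maintained above
  // (assuming both inputs have enough elements): with cycle_length_ = 2 and
  // block_length_ = 3, and the flat-map function producing datasets
  // A = {a0, a1, ...} and B = {b0, b1, ...}, the deterministic output order is
  //   a0, a1, a2, b0, b1, b2, a3, a4, a5, b3, ...
  // i.e. block_length_ elements are taken from each interleaved iterator
  // before the cycle advances.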
278   class Iterator : public DatasetIterator<Dataset> {
279    public:
280     explicit Iterator(const Params& params, bool deterministic)
281         : DatasetIterator<Dataset>(params),
282           deterministic_(deterministic),
283           workers_(dataset()->num_threads()),
284           worker_thread_states_(dataset()->num_threads()) {}
285 
286     ~Iterator() override {
287       CancelThreads();
288       if (deregister_fn_) deregister_fn_();
289     }
290 
291     // TODO(jsimsa): Register cancellation callback once the implementation is
292     // refactored not to hold mu_ while calling `GetNext` on the input.
293     Status Initialize(IteratorContext* ctx) override {
294       cancellation_manager_ = std::make_unique<CancellationManager>();
295       IteratorContext::Params params(ctx);
296       params.cancellation_manager = cancellation_manager_.get();
297       TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
298           IteratorContext(params), this, prefix(), &input_impl_));
299       return dataset()->captured_func_->Instantiate(
300           ctx, &instantiated_captured_func_);
301     }
302 
303     // GetNextInternal matches the deterministic interleave order unless
304     // getting the next element would block and we are allowed to be
305     // nondeterministic.
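    //
    // For example (illustrative only): when deterministic_ is true, the
    // caller blocks on the worker at interleave_indices_[next_index_] even if
    // other workers already have buffered outputs; when deterministic_ is
    // false, the first worker in cycle order with a non-empty buffer is
    // consumed (see `element_acquired_sloppily` below).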
306     Status GetNextInternal(IteratorContext* ctx,
307                            std::vector<Tensor>* out_tensors,
308                            bool* end_of_sequence) override {
309       mutex_lock l(mu_);
310       TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
311       while (!cancelled_) {
312         // Wait for an item to become available, blocking if necessary. If we
313         // are allowed to be nondeterministic, we can skip over input datasets
314         // that do not have an item readily available.
315         bool can_produce_elements = false;
316         bool must_wait_for_input = true;
317         for (int64_t i = 0; i < interleave_indices_.size(); ++i) {
318           int64_t index = (next_index_ + i) % interleave_indices_.size();
319           int64_t current_worker_index = interleave_indices_[index];
320           if (current_worker_index < 0) {
321             continue;  // Empty interleave elements.
322           }
323           WorkerState* current_worker = &workers_[current_worker_index];
324           can_produce_elements |= current_worker->MayHaveElements();
325           if (!current_worker->outputs.empty()) {
326             // We have an element!
327             next_index_ = index;
328             const bool element_acquired_sloppily = !deterministic_ && i > 1;
329             if (!element_acquired_sloppily) {
330               // If the element was acquired in the regular (deterministic)
331               // order, then advance the current block and cycle pointers to
332               // the next element in the regular order.
333               block_count_++;
334               if (block_count_ == dataset()->block_length_) {
335                 next_index_ = (index + 1) % interleave_indices_.size();
336                 block_count_ = 0;
337               }
338             } else {
339               block_count_ = 0;
340             }
341             *end_of_sequence = false;
342             Status s = current_worker->outputs.front().status;
343             profiler::TraceMe traceme([&] {
344               return profiler::TraceMeEncode(
345                   "ParallelInterleaveConsume",
346                   {{"element_id", current_worker->outputs.front().id}});
347             });
348             current_worker->outputs.front().output.swap(*out_tensors);
349             current_worker->outputs.pop_front();
350             current_worker->cond_var.notify_one();
351             return s;
352           } else if (current_worker->is_producing && deterministic_) {
353             // current_worker.outputs.empty(), and we must wait for this
354             // iterator.
355             if (next_index_ != index) {
356               // We have advanced to a new iterator; reset block counts.
357               next_index_ = index;
358               block_count_ = 0;
359             }
360             break;
361           } else if (!current_worker->is_producing) {
362             // This iterator has reached end of input.
363             interleave_indices_[index] = -1;
364             if (input_impl_) {
365               // Start prefetching a new iterator.
366               std::vector<Tensor> args;
367               bool end_of_input = false;
368               Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
369               if (end_of_input) {
370                 input_impl_.reset();
371               } else {
372                 current_worker->SetInputs(s, std::move(args));
373                 staging_indices_.emplace_back(current_worker_index);
374               }
375             }
376 
377             if (!staging_indices_.empty()) {
378               // Move a worker from `staging_indices_` to
379               // `interleave_indices_`.
380               interleave_indices_[index] = staging_indices_.front();
381               staging_indices_.pop_front();
382               {
383                 mutex_lock ckpt_l(ckpt_mu_);
384                 if (worker_thread_states_[interleave_indices_[index]]
385                         .iterator != nullptr) {
386                   // TODO(wilsin): Write a unit test where we iterate through a
387                   // dataset, pause, and check the model proto autotune value.
388                   EnableAutotune(
389                       ctx, worker_thread_states_[interleave_indices_[index]]
390                                .iterator.get());
391                 }
392               }
393               next_index_ = (index + 1) % interleave_indices_.size();
394               block_count_ = 0;
395               // Restart the inner [for] loop
396               can_produce_elements = true;
397               must_wait_for_input = false;
398               break;
399             }
400           }
401         }
402 
403         if (!can_produce_elements && !input_impl_) {
404           // No potential for future values.
405           *end_of_sequence = true;
406           return OkStatus();
407         }
408 
409         if (must_wait_for_input) {
410           // Wait for elements to become available.
411           RecordStop(ctx);
412           if (deterministic_) {
413             workers_[interleave_indices_[next_index_]].cond_var.wait(l);
414           } else {
415             any_element_available_cond_var_.wait(l);
416           }
417           RecordStart(ctx);
418         }
419       }
420       return errors::Cancelled(
421           "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
422     }
423 
424    protected:
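    // Reports this iterator to the tf.data autotuning model as an async
    // interleave-many node; cycle_length and deterministic are recorded as
    // non-tunable parameters.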
425     std::shared_ptr<model::Node> CreateNode(
426         IteratorContext* ctx, model::Node::Args args) const override {
427       return model::MakeAsyncInterleaveManyNode(
428           std::move(args), {model::MakeNonTunableParameter(
429                                 kCycleLength, dataset()->cycle_length_),
430                             model::MakeNonTunableParameter(
431                                 kDeterministic, deterministic_ ? 1.0 : 0.0)});
432     }
433 
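    // Serializes the iterator state: the input iterator (or an
    // "input_exhausted" marker), the cycle position, each worker's buffered
    // outputs, each worker thread's in-flight state, and the interleave and
    // staging index lists.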
434     Status SaveInternal(SerializationContext* ctx,
435                         IteratorStateWriter* writer) override {
436       TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
437           dataset()->captured_func_->CheckExternalState()));
438       // The order of locking is important here to avoid deadlock.
439       mutex_lock l(mu_);
440       mutex_lock ckpt_l(ckpt_mu_);
441       if (input_impl_) {
442         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
443       } else {
444         TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInputExhausted, ""));
445       }
446       TF_RETURN_IF_ERROR(
447           writer->WriteScalar(prefix(), kNextIndex, next_index_));
448       TF_RETURN_IF_ERROR(
449           writer->WriteScalar(prefix(), kBlockCount, block_count_));
450       TF_RETURN_IF_ERROR(
451           writer->WriteScalar(prefix(), kWorkersSize, workers_.size()));
452       for (int i = 0; i < workers_.size(); ++i) {
453         TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
454       }
455       for (int i = 0; i < worker_thread_states_.size(); ++i) {
456         TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(ctx, writer, i));
457       }
458       TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInterleaveSize,
459                                              interleave_indices_.size()));
460       for (int i = 0; i < interleave_indices_.size(); ++i) {
461         TF_RETURN_IF_ERROR(writer->WriteScalar(
462             prefix(), strings::StrCat(kInterleaveIndices, "_", i),
463             interleave_indices_[i]));
464       }
465       TF_RETURN_IF_ERROR(
466           writer->WriteScalar(prefix(), kStagingSize, staging_indices_.size()));
467       for (int i = 0; i < staging_indices_.size(); ++i) {
468         TF_RETURN_IF_ERROR(writer->WriteScalar(
469             prefix(), strings::StrCat(kStagingIndices, "_", i),
470             staging_indices_[i]));
471       }
472       if (!worker_threads_.empty()) {
473         TF_RETURN_IF_ERROR(
474             writer->WriteScalar(prefix(), kWorkerThreadsRunning, ""));
475       }
476       return OkStatus();
477     }
478 
479     Status RestoreInternal(IteratorContext* ctx,
480                            IteratorStateReader* reader) override {
481       {
482         // The order of locking is important here to avoid deadlock.
483         mutex_lock l(mu_);
484         mutex_lock ckpt_l(ckpt_mu_);
485         if (!reader->Contains(prefix(), kInputExhausted)) {
486           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
487         } else {
488           input_impl_.reset();
489         }
490         int64_t temp;
491         TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kNextIndex, &temp));
492         next_index_ = size_t(temp);
493         TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBlockCount, &temp));
494         block_count_ = size_t(temp);
495 
496         // Restore WorkerStates.
497         TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kWorkersSize, &temp));
498         if (temp != dataset()->num_threads()) {
499           return errors::Internal("Expected ", dataset()->num_threads(),
500                                   " worker states but found ", temp, ".");
501         }
502         for (size_t i = 0; i < dataset()->num_threads(); ++i) {
503           TF_RETURN_IF_ERROR(ReadWorkerStateLocked(ctx, reader, i));
504         }
505       }
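      // Worker-thread states are restored in parallel on a dedicated thread
      // pool; each restore may rebuild an iterator from its checkpointed
      // input element (see ReadWorkerThreadStateLocked).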
506       std::unique_ptr<thread::ThreadPool> threadpool = ctx->CreateThreadPool(
507           "read_worker_thread_state", dataset()->num_threads());
508       Status s = OkStatus();
509       BlockingCounter counter(dataset()->num_threads());
510       for (size_t i = 0; i < dataset()->num_threads(); ++i) {
511         threadpool->Schedule([this, i, ctx, reader, &s, &counter] {
512           WorkerThreadState state;
513           Status result = ReadWorkerThreadStateLocked(ctx, reader, i, &state);
514           mutex_lock l(mu_);
515           mutex_lock ckpt_l(ckpt_mu_);
516           if (!result.ok()) {
517             s.Update(result);
518             counter.DecrementCount();
519             return;
520           }
521           worker_thread_states_[i] = std::move(state);
522           counter.DecrementCount();
523         });
524       }
525       counter.Wait();
526       if (!s.ok()) {
527         return s;
528       }
529 
530       mutex_lock l(mu_);
531       mutex_lock ckpt_l(ckpt_mu_);
532       // Restore `interleave_indices_`.
533       std::set<int64_t> all_indices;
534       {
535         int64_t interleave_size;
536         TF_RETURN_IF_ERROR(
537             reader->ReadScalar(prefix(), kInterleaveSize, &interleave_size));
538         interleave_indices_.reserve(interleave_size);
539         for (int64_t i = 0; i < interleave_size; ++i) {
540           int64_t temp;
541           TF_RETURN_IF_ERROR(reader->ReadScalar(
542               prefix(), strings::StrCat(kInterleaveIndices, "_", i), &temp));
543           if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
544             return errors::Internal(
545                 "Duplicate entry for ", temp,
546                 " found when reading interleave and staging indices.");
547           }
548           if (temp >= 0) {
549             all_indices.insert(temp);
550           }
551           interleave_indices_.emplace_back(temp);
552         }
553       }
554 
555       // Restore `staging_indices_`.
556       {
557         int64_t staging_size;
558         TF_RETURN_IF_ERROR(
559             reader->ReadScalar(prefix(), kStagingSize, &staging_size));
560         for (int i = 0; i < staging_size; ++i) {
561           int64_t temp;
562           TF_RETURN_IF_ERROR(reader->ReadScalar(
563               prefix(), strings::StrCat(kStagingIndices, "_", i), &temp));
564           if (all_indices.find(temp) != all_indices.end()) {
565             return errors::Internal(
566                 "Duplicate entry for ", temp,
567                 " found when reading interleave and staging indices.");
568           }
569           if (temp >= 0) {
570             all_indices.insert(temp);
571           }
572           staging_indices_.emplace_back(temp);
573         }
574       }
575 
576       // Start Worker threads.
577       if (reader->Contains(prefix(), kWorkerThreadsRunning)) {
578         worker_threads_.reserve(dataset()->num_threads());
579         for (size_t i = 0; i < dataset()->num_threads(); ++i) {
580           std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
581           worker_threads_.emplace_back(ctx->StartThread(
582               strings::StrCat(kDataParallelInterleaveWorker, "_", i),
583               [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
584         }
585       }
586       return OkStatus();
587     }
588 
589     TraceMeMetadata GetTraceMeMetadata() const override {
590       return dataset()->traceme_metadata_;
591     }
592 
593    private:
594     // OutputElem contains the information from a call to GetNext by an output
595     // iterator.
596     struct OutputElem {
597       // The output iterator sets `status` if getting the output element
598       // fails.
599       Status status;
600       // The buffered data element.
601       std::vector<Tensor> output;
602       int64_t id = -1;
603 
604       explicit OutputElem(const Status& s) : status(s) {}
605       OutputElem(const Status& s, int64_t id) : status(s), id(id) {}
606     };
607 
608     // Worker threads operate on their relevant WorkerState structs.
609     //
610     // WorkerState's fields are all protected by mu_.
611     struct WorkerState {
612       // The arguments to be used to construct an output iterator.
613       std::vector<Tensor> input;
614       // The buffered output elements.
615       std::deque<OutputElem> outputs;
616       // Set to true iff the worker thread expects to append more elements to
617       // outputs. is_producing can be false despite !outputs.empty().
618       // Concretely, all output elements will have been consumed only when:
619       // is_producing == false && outputs.empty();
620       bool is_producing = false;
621       // Condition variable used to coordinate between threads. The worker
622       // thread waits on this condition variable when it is either (1) waiting
623       // for the main thread to add arguments to `input`, or (2) waiting for
624       // the main thread to consume an element of `outputs`. The main thread
625       // waits on cond_var if it is waiting for the worker thread to produce
626       // an element into `outputs` (this implies deterministic==true).
627       condition_variable cond_var;
628 
629       inline bool MayHaveElements() const {
630         return is_producing || !outputs.empty();
631       }
632 
633       // Sets inputs for a worker thread and notifies it to start processing.
634       void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
635         if (s.ok()) {
636           DCHECK(!MayHaveElements())
637               << "Tried to start inputs, despite already producing!";
638           input = std::move(input_arguments);
639           is_producing = true;
640           cond_var.notify_one();
641         } else {
642           outputs.emplace_back(s);
643         }
644       }
645     };
646 
647     // The internal state of a worker thread that is not already captured
648     // in its `WorkerState`.
649     //
650     // This is needed only for checkpointing purposes. We keep this
651     // separate from `WorkerState` and guard its fields using a separate
652     // lock `ckpt_mu_` so as to not affect the performance of main pipeline.
653     struct WorkerThreadState {
654       // The output element that has been produced from the input iterator
655       // and is waiting to be added to `WorkerState.outputs`.
656       OutputElem output_elem;
657 
658       // Whether the input iterator returned an `end_of_sequence`.
659       bool end_of_sequence = false;
660 
661       // Status returned from `MakeIteratorFromInputElement`.
662       Status iterator_creation_status;
663 
664       // The arguments to be used to construct `iterator`.
665       std::vector<Tensor> input;
666 
667       std::unique_ptr<IteratorBase> iterator;
668 
669       WorkerThreadState() : output_elem(OkStatus()) {}
670     };
671 
672     void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
673       cancellation_manager_->StartCancel();
674       mutex_lock l(mu_);
675       cancelled_ = true;
676       for (auto& worker : workers_) {
677         worker.cond_var.notify_all();
678       }
679     }
680 
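    // Starts the worker threads on first use: pulls one input element per
    // worker (stopping early if the input is exhausted), assigns the first
    // cycle_length_ workers to interleave_indices_ and the remainder to
    // staging_indices_, and launches one thread per worker.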
681     Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
682         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
683       if (worker_threads_.empty() && input_impl_) {
684         worker_threads_.reserve(dataset()->num_threads());
685         for (int64_t i = 0; i < dataset()->num_threads(); ++i) {
686           std::vector<Tensor> args;
687           bool end_of_input = false;
688           Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
689           if (end_of_input) {
690             input_impl_.reset();
691             return OkStatus();
692           }
693           if (i < dataset()->cycle_length_) {
694             interleave_indices_.push_back(i);
695           } else {
696             staging_indices_.push_back(i);
697           }
698           workers_[i].SetInputs(s, std::move(args));
699           std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
700           worker_threads_.push_back(ctx->StartThread(
701               strings::StrCat(kDataParallelInterleaveWorker, "_", i),
702               [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
703         }
704         DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
705         DCHECK(staging_indices_.size() == dataset()->prefetch_input_elements_);
706       }
707       return OkStatus();
708     }
709 
710     // Produces elements into the worker's output buffers.
711     void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
712                       const int64_t thread_index) {
713       // Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
714       //
715       // 1. Any local state that may need to be checkpointed should be kept
716       //    in `worker_thread_states_[thread_index]`.
717       // 2. `WorkerThreadState` should contain state that is needed only for
718       //    checkpointing, i.e., if we were to remove checkpointing support,
719       //    we could keep that state as local variables in this thread.
720       // 3. This thread should only read/write state at `thread_index`
721       //    and should not access other thread states.
722       // 4. When restoring from checkpoint, threads are started only after
723       //    the restore is complete.
724       // 5. Once restored from a checkpoint, the local state is edited only
725       //    by this thread. 3 & 4 allow making assumptions like temporarily
726       //    caching local state in this thread and using it outside a lock
727       //    e.g. `make_new_iterator`.
728       // 6. `ckpt_mu_` should be wisely used to create *consistent*
729       //    checkpoint markers.
730 
731       // std::function arguments are copy-constructible, so we pass raw
732       // pointers, and then immediately wrap them to ensure correct ownership.
733       RecordStart(ctx.get());
734       auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
735         mutex_lock l(mu_);
736         workers_[thread_index].cond_var.notify_all();
737         RecordStop(ctx.get());
738       });
739       bool make_new_iterator;
740       {
741         tf_shared_lock l(ckpt_mu_);
742         // Decide whether a new iterator should be built.
743         // 1. If there is an existing iterator, we use it.
744         // 2. If there was an error in iterator creation that could not be
745         //    notified to the client we attempt to send that to the client
746         //    first.
747         make_new_iterator =
748             worker_thread_states_[thread_index].iterator == nullptr &&
749             worker_thread_states_[thread_index].iterator_creation_status.ok();
750       }
751       // Even though `make_new_iterator` has cached values from
752       // `worker_thread_states_[thread_index]` which is guarded by ckpt_mu_,
753       // it is safe to *read* `make_new_iterator` outside of a lock without
754       // worrying about concurrent changes to values in
755       // `worker_thread_states_[thread_index]`. See comment at the start of
756       // this function for details.
757       while (true) {
758         // Whether creation of the iterator succeeded.
759         Status iterator_creation_status;
760         // 1. Build a new iterator or use the existing one.
761         if (make_new_iterator) {
762           // 1a. Get new input tensors or use the existing ones.
763           bool read_new_input;
764           {
765             tf_shared_lock l(ckpt_mu_);
766             // worker_thread_states_[thread_index].input will be non-empty
767             // if checkpointing happened at CHECKPOINT_MARKER_A.
768             read_new_input = worker_thread_states_[thread_index].input.empty();
769           }
770 
771           if (read_new_input) {
772             mutex_lock l(mu_);
773             while (!cancelled_ && !workers_[thread_index].is_producing) {
774               RecordStop(ctx.get());
775               workers_[thread_index].cond_var.wait(l);
776               RecordStart(ctx.get());
777             }
778             if (cancelled_) return;
779             // Copy the input tensors so that we do not need to block on `mu_`
780             // when building the iterator.
781             // We keep a copy of the input tensors in
782             // `WorkerThreadState.input` till the iterator is in use. This is
783             // used in `RestoreInternal` to re-build the iterator.
784             // TODO(b/78046638): Explore ways to avoid tracking the input
785             // tensors.
786             tf_shared_lock ckpt_l(ckpt_mu_);
787             worker_thread_states_[thread_index].input.swap(
788                 workers_[thread_index].input);
789             // CHECKPOINT_MARKER_A
790             // We have the input tensors but have not built the iterator yet.
791           }
792           bool thread_in_staging = false;
793           {
794             mutex_lock l(mu_);
795             thread_in_staging = absl::c_find(staging_indices_, thread_index) !=
796                                 staging_indices_.end();
797           }
798           // 1b. Run the user defined function to produce a new iterator.
799           {
800             tf_shared_lock l(ckpt_mu_);
801             worker_thread_states_[thread_index].iterator_creation_status =
802                 MakeIteratorFromInputElement(
803                     ctx.get(), this, worker_thread_states_[thread_index].input,
804                     thread_index, *instantiated_captured_func_, prefix(),
805                     &worker_thread_states_[thread_index].iterator,
806                     model_node());
807             iterator_creation_status =
808                 worker_thread_states_[thread_index].iterator_creation_status;
809             if (!iterator_creation_status.ok()) {
810               worker_thread_states_[thread_index].input.clear();
811             } else if (thread_in_staging) {
812               // TODO(wilsin): Write a unit test where we iterate through a
813               // dataset, pause, and check the model proto autotune value.
814               DisableAutotune(
815                   ctx.get(),
816                   worker_thread_states_[thread_index].iterator.get());
817             }
818             // CHECKPOINT_MARKER_B
819             // Either an iterator has been successfully built and placed in
820             // `worker_thread_states_[thread_index].iterator` or it failed and
821             // a non-OK status has been put in
822             // `worker_thread_states_[thread_index].iterator_creation_status`.
823           }
824         } else {
825           tf_shared_lock l(ckpt_mu_);
826           iterator_creation_status =
827               worker_thread_states_[thread_index].iterator_creation_status;
828           // Mark that we have used up the restored iterator.
829           make_new_iterator = true;
830         }
831         // 2. Start producing elements or send error state to client if
832         //    iterator creation failed.
833         if (!iterator_creation_status.ok()) {
834           mutex_lock l(mu_);
835           // Wait for space in the prefetch queue.
836           while (!cancelled_ && workers_[thread_index].outputs.size() ==
837                                     dataset()->buffer_output_elements_) {
838             RecordStop(ctx.get());
839             workers_[thread_index].cond_var.wait(l);
840             RecordStart(ctx.get());
841           }
842           if (cancelled_) return;
843           tf_shared_lock ckpt_l(ckpt_mu_);
844           workers_[thread_index].outputs.emplace_back(iterator_creation_status);
845           workers_[thread_index].is_producing = false;
846           worker_thread_states_[thread_index].iterator_creation_status =
847               OkStatus();
848           // CHECKPOINT_MARKER_C
849           // Non-OK iterator creation status has been notified to the
850           // client.
851           if (deterministic_) {
852             workers_[thread_index].cond_var.notify_one();
853           } else {
854             any_element_available_cond_var_.notify_one();
855           }
856         } else {
857           bool end_of_sequence = false;
858           while (!end_of_sequence) {
859             // 3.a Produce an element!
860             {
861               tf_shared_lock ckpt_l(ckpt_mu_);
862               if (worker_thread_states_[thread_index].output_elem.status.ok() &&
863                   worker_thread_states_[thread_index]
864                       .output_elem.output.empty() &&
865                   !worker_thread_states_[thread_index].end_of_sequence) {
866                 int64_t& id =
867                     worker_thread_states_[thread_index].output_elem.id;
868                 profiler::TraceMe traceme(
869                     [&] {
870                       id = profiler::TraceMe::NewActivityId();
871                       return profiler::TraceMeEncode(
872                           "ParallelInterleaveProduce", {{"element_id", id}});
873                     },
874                     profiler::kInfo);
875                 worker_thread_states_[thread_index].output_elem.status =
876                     worker_thread_states_[thread_index].iterator->GetNext(
877                         ctx.get(),
878                         &worker_thread_states_[thread_index].output_elem.output,
879                         &worker_thread_states_[thread_index].end_of_sequence);
880                 end_of_sequence =
881                     worker_thread_states_[thread_index].end_of_sequence;
882               } else {
883                 end_of_sequence =
884                     worker_thread_states_[thread_index].end_of_sequence;
885               }
886               // CHECKPOINT_MARKER_D
887               // An element has been read or an error or end_of_sequence has
888               // been received from the input iterator and is waiting to be
889               // sent to client.
890             }
891 
892             // 3.b Make it available to the client.
893             {
894               mutex_lock l(mu_);
895 
896               // Wait for space in the prefetch queue.
897               while (!cancelled_ && workers_[thread_index].outputs.size() ==
898                                         dataset()->buffer_output_elements_) {
899                 RecordStop(ctx.get());
900                 workers_[thread_index].cond_var.wait(l);
901                 RecordStart(ctx.get());
902               }
903               if (cancelled_) return;
904 
905               tf_shared_lock ckpt_l(ckpt_mu_);
906               workers_[thread_index].is_producing = !end_of_sequence;
907 
908               // Output the element.
909 
910               // Move the temporary state in WorkerThreadState to WorkerState
911               // and mark it as used.
912               if (end_of_sequence) {
913                 worker_thread_states_[thread_index].iterator.reset();
914                 worker_thread_states_[thread_index].input.clear();
915                 worker_thread_states_[thread_index].end_of_sequence = false;
916               } else {
917                 workers_[thread_index].outputs.emplace_back(
918                     worker_thread_states_[thread_index].output_elem.status,
919                     worker_thread_states_[thread_index].output_elem.id);
920                 workers_[thread_index].outputs.back().output.swap(
921                     worker_thread_states_[thread_index].output_elem.output);
922               }
923               worker_thread_states_[thread_index].output_elem.status =
924                   OkStatus();
925               if (deterministic_) {
926                 workers_[thread_index].cond_var.notify_one();
927               } else {
928                 any_element_available_cond_var_.notify_one();
929               }
930               // CHECKPOINT_MARKER_E
931               // Output element or iterator status has been sent to the
932               // client.
933             }
934           }
935         }
936       }
937     }
938 
939     Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
940         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
941       string iterator_name =
942           strings::StrCat(prefix(), "::", kWorker, "_", index);
943       TF_RETURN_IF_ERROR(writer->WriteScalar(iterator_name, kInputSize,
944                                              workers_[index].input.size()));
945       for (int i = 0; i < workers_[index].input.size(); ++i) {
946         TF_RETURN_IF_ERROR(writer->WriteTensor(iterator_name,
947                                                strings::StrCat(kInput, "_", i),
948                                                workers_[index].input[i]));
949       }
950       TF_RETURN_IF_ERROR(writer->WriteScalar(iterator_name, kOutputsSize,
951                                              workers_[index].outputs.size()));
952       for (int i = 0; i < workers_[index].outputs.size(); ++i) {
953         TF_RETURN_IF_ERROR(WriteOutputElemLocked(
954             writer, workers_[index].outputs[i], iterator_name,
955             strings::StrCat(kOutputs, "_", i)));
956       }
957       if (workers_[index].is_producing) {
958         TF_RETURN_IF_ERROR(
959             writer->WriteScalar(iterator_name, kIsProducing, ""));
960       }
961       return OkStatus();
962     }
963 
964     Status ReadWorkerStateLocked(IteratorContext* ctx,
965                                  IteratorStateReader* reader, int index)
966         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
967       string worker_prefix =
968           strings::StrCat(prefix(), "::", kWorker, "_", index);
969       // Restore inputs.
970       int64_t input_size;
971       TF_RETURN_IF_ERROR(
972           reader->ReadScalar(worker_prefix, kInputSize, &input_size));
973       workers_[index].input.reserve(input_size);
974       for (int i = 0; i < input_size; ++i) {
975         workers_[index].input.emplace_back();
976         TF_RETURN_IF_ERROR(reader->ReadTensor(ctx->flr(), worker_prefix,
977                                               strings::StrCat(kInput, "_", i),
978                                               &workers_[index].input.back()));
979       }
980       int64_t outputs_size;
981       TF_RETURN_IF_ERROR(
982           reader->ReadScalar(worker_prefix, kOutputsSize, &outputs_size));
983       for (int i = 0; i < outputs_size; ++i) {
984         workers_[index].outputs.emplace_back(OkStatus());
985         TF_RETURN_IF_ERROR(ReadOutputElemLocked(
986             ctx, reader, &workers_[index].outputs.back(), worker_prefix,
987             strings::StrCat(kOutputs, "_", i)));
988       }
989       if (reader->Contains(worker_prefix, kIsProducing)) {
990         workers_[index].is_producing = true;
991       } else {
992         workers_[index].is_producing = false;
993       }
994       return OkStatus();
995     }
996 
997     Status WriteWorkerThreadStateLocked(SerializationContext* ctx,
998                                         IteratorStateWriter* writer, int index)
999         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
1000       string iterator_name =
1001           strings::StrCat(prefix(), "::", kWorkerThread, "_", index);
1002       if (worker_thread_states_[index].iterator != nullptr) {
1003         TF_RETURN_IF_ERROR(
1004             SaveInput(ctx, writer, worker_thread_states_[index].iterator));
1005       } else {
1006         TF_RETURN_IF_ERROR(
1007             writer->WriteScalar(iterator_name, kIteratorExhausted, ""));
1008       }
1009       TF_RETURN_IF_ERROR(
1010           writer->WriteScalar(iterator_name, kInputSize,
1011                               worker_thread_states_[index].input.size()));
1012       for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
1013         TF_RETURN_IF_ERROR(
1014             writer->WriteTensor(iterator_name, strings::StrCat(kInput, "_", i),
1015                                 worker_thread_states_[index].input[i]));
1016       }
1017       TF_RETURN_IF_ERROR(WriteStatusLocked(
1018           writer, iterator_name, kIteratorCreationStatus,
1019           worker_thread_states_[index].iterator_creation_status));
1020       TF_RETURN_IF_ERROR(WriteOutputElemLocked(
1021           writer, worker_thread_states_[index].output_elem, iterator_name,
1022           kOutput));
1023       if (worker_thread_states_[index].end_of_sequence) {
1024         TF_RETURN_IF_ERROR(
1025             writer->WriteScalar(iterator_name, kEndOfSequence, ""));
1026       }
1027       return OkStatus();
1028     }
1029 
1030     Status ReadWorkerThreadStateLocked(IteratorContext* ctx,
1031                                        IteratorStateReader* reader, int index,
1032                                        WorkerThreadState* state) {
1033       string worker_prefix =
1034           strings::StrCat(prefix(), "::", kWorkerThread, "_", index);
1035       // Restore inputs.
1036       int64_t input_size;
1037       TF_RETURN_IF_ERROR(
1038           reader->ReadScalar(worker_prefix, kInputSize, &input_size));
1039       state->input.reserve(input_size);
1040       for (int i = 0; i < input_size; ++i) {
1041         state->input.emplace_back();
1042         TF_RETURN_IF_ERROR(reader->ReadTensor(ctx->flr(), worker_prefix,
1043                                               strings::StrCat(kInput, "_", i),
1044                                               &state->input.back()));
1045       }
1046       // Restore iterator
1047       if (reader->Contains(worker_prefix, kIteratorExhausted)) {
1048         state->iterator.reset();
1049       } else {
1050         std::unique_ptr<IteratorBase> iterator;
1051         // NOTE: We intentionally ignore resource modeling outside GetNext().
1052         TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
1053             ctx, this, state->input, index, *instantiated_captured_func_,
1054             prefix(), &iterator, /*node=*/nullptr));
1055         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
1056         state->iterator.swap(iterator);
1057       }
1058       TF_RETURN_IF_ERROR(ReadStatusLocked(reader, worker_prefix,
1059                                           kIteratorCreationStatus,
1060                                           &state->iterator_creation_status));
1061       TF_RETURN_IF_ERROR(ReadOutputElemLocked(ctx, reader, &state->output_elem,
1062                                               worker_prefix, kOutput));
1063       if (reader->Contains(worker_prefix, kEndOfSequence)) {
1064         state->end_of_sequence = true;
1065       } else {
1066         state->end_of_sequence = false;
1067       }
1068       return OkStatus();
1069     }
1070 
1071     Status WriteOutputElemLocked(IteratorStateWriter* writer,
1072                                  const OutputElem& output_elem,
1073                                  const string& iterator_name,
1074                                  const string& prefix)
1075         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
1076       TF_RETURN_IF_ERROR(WriteStatusLocked(
1077           writer, iterator_name, strings::StrCat(prefix, "_", kStatus),
1078           output_elem.status));
1079       TF_RETURN_IF_ERROR(writer->WriteScalar(
1080           iterator_name, strings::StrCat(prefix, "_", kOutputSize),
1081           output_elem.output.size()));
1082       for (int i = 0; i < output_elem.output.size(); ++i) {
1083         TF_RETURN_IF_ERROR(writer->WriteTensor(
1084             iterator_name, strings::StrCat(prefix, "_", kOutput, "_", i),
1085             output_elem.output[i]));
1086       }
1087       return OkStatus();
1088     }
1089 
1090     Status ReadOutputElemLocked(IteratorContext* ctx,
1091                                 IteratorStateReader* reader,
1092                                 OutputElem* output_elem,
1093                                 const string& iterator_name,
1094                                 const string& prefix) {
1095       TF_RETURN_IF_ERROR(ReadStatusLocked(reader, iterator_name,
1096                                           strings::StrCat(prefix, "_", kStatus),
1097                                           &output_elem->status));
1098       int64_t output_size;
1099       TF_RETURN_IF_ERROR(reader->ReadScalar(
1100           iterator_name, strings::StrCat(prefix, "_", kOutputSize),
1101           &output_size));
1102       output_elem->output.reserve(output_size);
1103       for (int i = 0; i < output_size; ++i) {
1104         output_elem->output.emplace_back();
1105         TF_RETURN_IF_ERROR(
1106             reader->ReadTensor(ctx->flr(), iterator_name,
1107                                strings::StrCat(prefix, "_", kOutput, "_", i),
1108                                &output_elem->output.back()));
1109       }
1110       return OkStatus();
1111     }
1112 
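    // A Status is checkpointed as its integer error code plus, for non-OK
    // statuses, the error message; ReadStatusLocked below reverses this
    // encoding.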
1113     Status WriteStatusLocked(IteratorStateWriter* writer,
1114                              const string& iterator_name, const string& prefix,
1115                              const Status& status)
1116         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
1117       TF_RETURN_IF_ERROR(writer->WriteScalar(
1118           iterator_name, strings::StrCat(prefix, "_", kCode),
1119           static_cast<int64_t>(status.code())));
1120       if (!status.ok()) {
1121         TF_RETURN_IF_ERROR(writer->WriteScalar(
1122             iterator_name, strings::StrCat(prefix, "_", KMessage),
1123             status.error_message()));
1124       }
1125       return OkStatus();
1126     }
1127 
1128     Status ReadStatusLocked(IteratorStateReader* reader,
1129                             const string& iterator_name, const string& prefix,
1130                             Status* status) {
1131       int64_t code_int;
1132       TF_RETURN_IF_ERROR(reader->ReadScalar(
1133           iterator_name, strings::StrCat(prefix, "_", kCode), &code_int));
1134       error::Code code = static_cast<error::Code>(code_int);
1135 
1136       if (code != error::Code::OK) {
1137         tstring error_message;
1138         TF_RETURN_IF_ERROR(reader->ReadScalar(
1139             iterator_name, strings::StrCat(prefix, "_", KMessage),
1140             &error_message));
1141         *status = Status(code, error_message);
1142       } else {
1143         *status = OkStatus();
1144       }
1145       return OkStatus();
1146     }
1147 
    // Mutex & condition variable to guard mutable iterator internals and
    // coordinate among worker threads and the client thread(s).
    mutex mu_ TF_ACQUIRED_BEFORE(ckpt_mu_);
    // The main thread waits on this condition variable if running in
    // nondeterministic mode and no values are available.
    condition_variable any_element_available_cond_var_;
    // Whether outputs must be produced in deterministic order.
    const bool deterministic_;
    // Mutex used to wait for a consistent state while checkpointing.
    // Only Save and Restore require an exclusive lock on this mutex. In
    // other scenarios we just acquire a shared lock, so the pipeline's
    // performance should not be affected in the absence of checkpointing.
    // A thread must not wait on any condition variable while holding
    // `ckpt_mu_` in either shared or exclusive mode.
    mutex ckpt_mu_;

    // Controls cancellation of `input_impl_`. Must be ordered before
    // `input_impl_` so that `input_impl_` is destroyed first.
    std::unique_ptr<CancellationManager> cancellation_manager_;

    // The iterator producing elements which are converted to datasets by
    // the dataset()->captured_func_ and then interleaved together.
    // input_impl_ is reset once its input has been exhausted.
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

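    // The instantiated form of the captured function, used to map each input
    // element to the dataset whose elements are interleaved into the output.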
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

    // The WorkerState structs the worker threads operate on.
    // Each element of workers_ is referenced by at most one of
    // interleave_indices_ and staging_indices_.
    std::vector<WorkerState> workers_ TF_GUARDED_BY(mu_);

    // Stores the temporary state of WorkerThreads which is not stored in
    // WorkerState. This is used for checkpointing purposes only.
    std::vector<WorkerThreadState> worker_thread_states_
        TF_GUARDED_BY(ckpt_mu_);

    // Indices in `workers_` of iterators to interleave.
    std::vector<int64_t> interleave_indices_ TF_GUARDED_BY(mu_);
    // Indices in `workers_` of prefetched iterators.
    std::deque<int64_t> staging_indices_ TF_GUARDED_BY(mu_);

    // The index into output_elements_ of the next element to produce.
    size_t next_index_ TF_GUARDED_BY(mu_) = 0;
    // The number of items produced so far within the current block.
    size_t block_count_ TF_GUARDED_BY(mu_) = 0;
    // Flag to instruct the worker threads to exit.
    bool cancelled_ TF_GUARDED_BY(mu_) = false;
    // The worker threads. This must be last to ensure that the
    // threads have exited before any other members are deallocated.
    // TODO(b/65178177): Avoid allocating additional threads.
    std::vector<std::unique_ptr<Thread>> worker_threads_ TF_GUARDED_BY(mu_);

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;
  };

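  // Parameters of the Dataset, fixed at construction time.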
  const DatasetBase* const input_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const int64_t cycle_length_;
  const int64_t block_length_;
  const DeterminismPolicy deterministic_;
  const int64_t buffer_output_elements_;
  const int64_t prefetch_input_elements_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
  const int op_version_;
};

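// V1 of the op expresses determinism through the boolean `sloppy` input,
// whereas V2 exposes a `deterministic` attr; the presence of that attr is
// used to detect the op version.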
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
    OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->HasAttr(kDeterministic) ? 2 : 1) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
                                               &func_metadata_));
  if (op_version_ == 2) {
    std::string deterministic;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
    OP_REQUIRES_OK(
        ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
  }
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

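// Validates the op's scalar inputs and builds the `Dataset`, forwarding the
// captured interleave function and the parsed parameters.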
void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
                                              DatasetBase* input,
                                              DatasetBase** output) {
  int64_t cycle_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
  OP_REQUIRES(ctx, cycle_length > 0,
              errors::InvalidArgument("`cycle_length` must be > 0"));

  int64_t block_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
  OP_REQUIRES(ctx, block_length > 0,
              errors::InvalidArgument("`block_length` must be > 0"));

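  // V1 kernels do not have the `deterministic` attr; they control determinism
  // through the boolean `sloppy` input instead.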
  if (op_version_ == 1) {
    bool sloppy = false;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy));
    if (sloppy) {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
    } else {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kDeterministic);
    }
  }

  int64_t buffer_output_elements = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
                                          &buffer_output_elements));
  OP_REQUIRES(ctx, buffer_output_elements > 0,
              errors::InvalidArgument("`buffer_output_elements` must be > 0"));

  int64_t prefetch_input_elements = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
                                          &prefetch_input_elements));
  OP_REQUIRES(
      ctx, prefetch_input_elements >= 0,
      errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));

  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));

  *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
                        block_length, deterministic_, buffer_output_elements,
                        prefetch_input_elements, output_types_, output_shapes_,
                        op_version_);
}

namespace {
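// All three op names below are backed by this same kernel implementation.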
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("LegacyParallelInterleaveDatasetV2").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);

REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("LegacyParallelInterleaveDatasetV2");

}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow