/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h"

#include <atomic>
#include <deque>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"

namespace tensorflow {
namespace data {
namespace experimental {

/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDatasetType;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kInputDataset;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOtherArguments;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBufferOutputElements;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kPrefetchInputElements;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kTarguments;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputShapes;

constexpr char kInputExhausted[] = "input_exhausted";
constexpr char kNextIndex[] = "next_index";
constexpr char kBlockCount[] = "block_count";
constexpr char kWorkersSize[] = "workers_size";
constexpr char kInterleaveSize[] = "interleave_size";
constexpr char kInterleaveIndices[] = "interleave_indices";
constexpr char kStagingSize[] = "staging_size";
constexpr char kStagingIndices[] = "staging_indices";
constexpr char kWorkerThreadsRunning[] = "worker_threads_running";
constexpr char kDataParallelInterleaveWorker[] =
    "data_parallel_interleave_worker";
constexpr char kWorker[] = "worker";
constexpr char kInputSize[] = "input_size";
constexpr char kInput[] = "input";
constexpr char kOutputsSize[] = "outputs_size";
constexpr char kOutputs[] = "outputs";
constexpr char kIsProducing[] = "is_producing";
constexpr char kWorkerThread[] = "worker_thread";
constexpr char kIteratorExhausted[] = "iterator_exhausted";
constexpr char kIteratorCreationStatus[] = "iterator_creation_status";
constexpr char kOutput[] = "output";
constexpr char kEndOfSequence[] = "end_of_sequence";
constexpr char kStatus[] = "status";
constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";

class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          std::unique_ptr<CapturedFunction> captured_func, int64_t cycle_length,
          int64_t block_length, DeterminismPolicy deterministic,
          int64_t buffer_output_elements, int64_t prefetch_input_elements,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        captured_func_(std::move(captured_func)),
        cycle_length_(cycle_length),
        block_length_(block_length),
        deterministic_(deterministic),
        buffer_output_elements_(buffer_output_elements),
        prefetch_input_elements_(prefetch_input_elements),
        output_types_(output_types),
        output_shapes_(output_shapes),
        traceme_metadata_(
            {{"block_length",
              strings::Printf("%lld", static_cast<long long>(block_length))},
             {"cycle_length",
              strings::Printf("%lld", static_cast<long long>(cycle_length))},
             {"deterministic",
              deterministic.IsDeterministic() || deterministic.IsDefault()
                  ? "true"
                  : "false"}}),
        op_version_(op_version) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    bool deterministic =
        deterministic_.IsDeterministic() || deterministic_.IsDefault();
    return std::make_unique<Iterator>(
        Iterator::Params{
            this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
        deterministic);
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    std::vector<std::pair<size_t, Node*>> inputs;
    std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
    int input_index = 0;

    Node* input_node;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    inputs.emplace_back(input_index++, input_node);

    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
    list_inputs.emplace_back(input_index++, other_arguments);

    Node* cycle_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
    inputs.emplace_back(input_index++, cycle_length_node);

    Node* block_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
    inputs.emplace_back(input_index++, block_length_node);

    if (op_version_ == 1) {
      Node* sloppy_node;
      TF_RETURN_IF_ERROR(
          b->AddScalar(deterministic_.IsNondeterministic(), &sloppy_node));
      inputs.emplace_back(input_index++, sloppy_node);
    }

    Node* buffer_output_elements_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
    inputs.emplace_back(input_index++, buffer_output_elements_node);

    Node* prefetch_input_elements_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
    inputs.emplace_back(input_index++, prefetch_input_elements_node);

    std::vector<std::pair<StringPiece, AttrValue>> attrs;

    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
    attrs.emplace_back(kFunc, f);

    if (op_version_ == 2) {
      AttrValue deterministic_attr;
      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
      attrs.emplace_back(kDeterministic, deterministic_attr);
    }

    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    attrs.emplace_back(kTarguments, other_arguments_types_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
    return OkStatus();
  }

 private:
  int64_t num_threads() const {
    return cycle_length_ + prefetch_input_elements_;
  }

  // Parallel interleave's implementation is designed around a few principles:
  //  1. Thread creation is relatively expensive. (Not reusing
  //     threads causes a number of indirect costs such as poorer tcmalloc
  //     performance due to thread-local caches, etc.) We allocate a fixed
  //     number of threads at the start and never change it. This is why we've
  //     fused functionality that is theoretically orthogonal (i.e.
  //     .prefetch()) into the implementation.
  //  2. Drop-in replacement for standard interleave. The goal is to
  //     auto-opt people into an optimized implementation without any work
  //     on the customer's part. We thus go through great pains to maintain
  //     identical iteration orders and full determinism (disabled only via a
  //     flag, etc.)
  //  3. Performance across a variety of environments and I/O envelopes.
  //
  // The actual implementation centers around a collection of worker threads
  // and their corresponding worker state (tracked in the `workers_` vector).
  // Worker threads repeatedly receive a vector of Tensors that are used as
  // input to the flat-map function (`captured_func_`). The output of this
  // function must be a dataset. The worker thread then repeatedly calls
  // `GetNext()`, maintaining a buffer of elements to minimize the likelihood
  // that a caller will block waiting for an element to be produced.
  //
  // Pointers to these worker states are kept in 2 disjoint data structures:
  //  1. `interleave_indices_` is a vector containing indices of WorkerStates
  //     in `workers_` that we are interleaving. Worker threads backing these
  //     WorkerStates should be regularly producing values.
  //  2. `staging_indices_` is a deque containing indices of WorkerStates in
  //     `workers_` that we will move to `interleave_indices_` when an
  //     iterator in `interleave_indices_` is exhausted.
  //
  // The client calls `GetNext[Internal]()` to retrieve an output element. The
  // internal implementation updates the state of `interleave_indices_` and
  // `staging_indices_` as output iterators (run by the worker threads) are
  // exhausted.
  //
  // `input_impl_` is the input iterator that generates arguments for the
  // flat-map function (`captured_func_`). It is set to an iterator at
  // Iterator construction, and is fixed until we consume all input elements.
  // Once it is exhausted, we reset the unique_ptr to eagerly deallocate
  // memory.
  //
  // A few invariants are maintained:
  //  1. No element in `interleave_indices_` should be -1 unless
  //     `staging_indices_` is empty and `input_impl_` is empty.
  //  2. Every `workers_` element is pointed to by at most one element of the
  //     union of `interleave_indices_` and `staging_indices_`.
  //  3. Unless `input_impl_` is empty, every `workers_` element must be
  //     pointed to by an element in `interleave_indices_` or
  //     `staging_indices_`.
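  //
  // Example (for illustration only): with cycle_length = 2, block_length = 1
  // and prefetch_input_elements = 1, num_threads() == 3 worker threads are
  // created. The first two input elements, A and B, occupy
  // `interleave_indices_` while a third element C is prefetched through
  // `staging_indices_`. In deterministic mode the output starts A1, B1, A2,
  // B2, ...; once A's iterator is exhausted, C is promoted from staging into
  // A's interleave slot and the next input element (if any) takes C's place
  // in staging.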
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params, bool deterministic)
        : DatasetIterator<Dataset>(params),
          deterministic_(deterministic),
          workers_(dataset()->num_threads()),
          worker_thread_states_(dataset()->num_threads()) {}

    ~Iterator() override {
      CancelThreads();
      if (deregister_fn_) deregister_fn_();
    }

    // TODO(jsimsa): Register cancellation callback once the implementation is
    // refactored not to hold mu_ while calling `GetNext` on the input.
    Status Initialize(IteratorContext* ctx) override {
      cancellation_manager_ = std::make_unique<CancellationManager>();
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

    // It is implemented so that it matches the deterministic interleave
    // unless getting the next element would block and we are allowed to be
    // nondeterministic.
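    // For example (illustration only): in nondeterministic ("sloppy") mode,
    // if the worker at the current cycle position has nothing buffered but a
    // later worker in the cycle does, that worker's element is returned
    // immediately instead of blocking on the current position.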
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
      while (!cancelled_) {
        // Wait for an item to become available, blocking if necessary. If we
        // are allowed to be nondeterministic, we can skip over input datasets
        // that do not have an item readily available.
        bool can_produce_elements = false;
        bool must_wait_for_input = true;
        for (int64_t i = 0; i < interleave_indices_.size(); ++i) {
          int64_t index = (next_index_ + i) % interleave_indices_.size();
          int64_t current_worker_index = interleave_indices_[index];
          if (current_worker_index < 0) {
            continue;  // Empty interleave elements.
          }
          WorkerState* current_worker = &workers_[current_worker_index];
          can_produce_elements |= current_worker->MayHaveElements();
          if (!current_worker->outputs.empty()) {
            // We have an element!
            next_index_ = index;
            const bool element_acquired_sloppily = !deterministic_ && i > 1;
            if (!element_acquired_sloppily) {
              // If the element was acquired in the regular (deterministic)
              // order, then advance the current block and cycle pointers to
              // the next element in the regular order.
              block_count_++;
              if (block_count_ == dataset()->block_length_) {
                next_index_ = (index + 1) % interleave_indices_.size();
                block_count_ = 0;
              }
            } else {
              block_count_ = 0;
            }
            *end_of_sequence = false;
            Status s = current_worker->outputs.front().status;
            profiler::TraceMe traceme([&] {
              return profiler::TraceMeEncode(
                  "ParallelInterleaveConsume",
                  {{"element_id", current_worker->outputs.front().id}});
            });
            current_worker->outputs.front().output.swap(*out_tensors);
            current_worker->outputs.pop_front();
            current_worker->cond_var.notify_one();
            return s;
          } else if (current_worker->is_producing && deterministic_) {
            // current_worker.outputs.empty(), and we must wait for this
            // iterator.
            if (next_index_ != index) {
              // We have advanced to a new iterator; reset block counts.
              next_index_ = index;
              block_count_ = 0;
            }
            break;
          } else if (!current_worker->is_producing) {
            // This iterator has reached end of input.
            interleave_indices_[index] = -1;
            if (input_impl_) {
              // Start prefetching a new iterator.
              std::vector<Tensor> args;
              bool end_of_input = false;
              Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
              if (end_of_input) {
                input_impl_.reset();
              } else {
                current_worker->SetInputs(s, std::move(args));
                staging_indices_.emplace_back(current_worker_index);
              }
            }

            if (!staging_indices_.empty()) {
              // Move a worker from `staging_indices_` to
              // `interleave_indices_`.
              interleave_indices_[index] = staging_indices_.front();
              staging_indices_.pop_front();
              {
                mutex_lock ckpt_l(ckpt_mu_);
                if (worker_thread_states_[interleave_indices_[index]]
                        .iterator != nullptr) {
                  // TODO(wilsin): Write a unit test where we iterate through
                  // a dataset, pause, and check the model proto autotune
                  // value.
                  EnableAutotune(
                      ctx, worker_thread_states_[interleave_indices_[index]]
                               .iterator.get());
                }
              }
              next_index_ = (index + 1) % interleave_indices_.size();
              block_count_ = 0;
              // Restart the inner [for] loop
              can_produce_elements = true;
              must_wait_for_input = false;
              break;
            }
          }
        }

        if (!can_produce_elements && !input_impl_) {
          // No potential for future values.
          *end_of_sequence = true;
          return OkStatus();
        }

        if (must_wait_for_input) {
          // Wait for elements to become available.
          RecordStop(ctx);
          if (deterministic_) {
            workers_[interleave_indices_[next_index_]].cond_var.wait(l);
          } else {
            any_element_available_cond_var_.wait(l);
          }
          RecordStart(ctx);
        }
      }
      return errors::Cancelled(
          "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncInterleaveManyNode(
          std::move(args), {model::MakeNonTunableParameter(
                                kCycleLength, dataset()->cycle_length_),
                            model::MakeNonTunableParameter(
                                kDeterministic, deterministic_ ? 1.0 : 0.0)});
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      // The order of locking is important here to avoid deadlock.
      mutex_lock l(mu_);
      mutex_lock ckpt_l(ckpt_mu_);
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      } else {
        TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInputExhausted, ""));
      }
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kNextIndex, next_index_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBlockCount, block_count_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kWorkersSize, workers_.size()));
      for (int i = 0; i < workers_.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
      }
      for (int i = 0; i < worker_thread_states_.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(ctx, writer, i));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInterleaveSize,
                                             interleave_indices_.size()));
      for (int i = 0; i < interleave_indices_.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            prefix(), strings::StrCat(kInterleaveIndices, "_", i),
            interleave_indices_[i]));
      }
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kStagingSize, staging_indices_.size()));
      for (int i = 0; i < staging_indices_.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            prefix(), strings::StrCat(kStagingIndices, "_", i),
            staging_indices_[i]));
      }
      if (!worker_threads_.empty()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(prefix(), kWorkerThreadsRunning, ""));
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      {
        // The order of locking is important here to avoid deadlock.
        mutex_lock l(mu_);
        mutex_lock ckpt_l(ckpt_mu_);
        if (!reader->Contains(prefix(), kInputExhausted)) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }
        int64_t temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kNextIndex, &temp));
        next_index_ = size_t(temp);
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBlockCount, &temp));
        block_count_ = size_t(temp);

        // Restore WorkerStates.
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kWorkersSize, &temp));
        if (temp != dataset()->num_threads()) {
          return errors::Internal("Expected ", dataset()->num_threads(),
                                  " worker states but found ", temp, ".");
        }
        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
          TF_RETURN_IF_ERROR(ReadWorkerStateLocked(ctx, reader, i));
        }
      }
      std::unique_ptr<thread::ThreadPool> threadpool = ctx->CreateThreadPool(
          "read_worker_thread_state", dataset()->num_threads());
      Status s = OkStatus();
      BlockingCounter counter(dataset()->num_threads());
      for (size_t i = 0; i < dataset()->num_threads(); ++i) {
        threadpool->Schedule([this, i, ctx, reader, &s, &counter] {
          WorkerThreadState state;
          Status result = ReadWorkerThreadStateLocked(ctx, reader, i, &state);
          mutex_lock l(mu_);
          mutex_lock ckpt_l(ckpt_mu_);
          if (!result.ok()) {
            s.Update(result);
            counter.DecrementCount();
            return;
          }
          worker_thread_states_[i] = std::move(state);
          counter.DecrementCount();
        });
      }
      counter.Wait();
      if (!s.ok()) {
        return s;
      }

      mutex_lock l(mu_);
      mutex_lock ckpt_l(ckpt_mu_);
      // Restore `interleave_indices_`.
      std::set<int64_t> all_indices;
      {
        int64_t interleave_size;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kInterleaveSize, &interleave_size));
        interleave_indices_.reserve(interleave_size);
        for (int64_t i = 0; i < interleave_size; ++i) {
          int64_t temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              prefix(), strings::StrCat(kInterleaveIndices, "_", i), &temp));
          if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
            return errors::Internal(
                "Duplicate entry for ", temp,
                " found when reading interleave and staging indices.");
          }
          if (temp >= 0) {
            all_indices.insert(temp);
          }
          interleave_indices_.emplace_back(temp);
        }
      }

      // Restore `staging_indices_`.
      {
        int64_t staging_size;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kStagingSize, &staging_size));
        for (int i = 0; i < staging_size; ++i) {
          int64_t temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              prefix(), strings::StrCat(kStagingIndices, "_", i), &temp));
          if (all_indices.find(temp) != all_indices.end()) {
            return errors::Internal(
                "Duplicate entry for ", temp,
                " found when reading interleave and staging indices.");
          }
          if (temp >= 0) {
            all_indices.insert(temp);
          }
          staging_indices_.emplace_back(temp);
        }
      }

      // Start worker threads.
      if (reader->Contains(prefix(), kWorkerThreadsRunning)) {
        worker_threads_.reserve(dataset()->num_threads());
        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
          worker_threads_.emplace_back(ctx->StartThread(
              strings::StrCat(kDataParallelInterleaveWorker, "_", i),
              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
        }
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    // OutputElem contains the information from a call to GetNext by an output
    // iterator.
    struct OutputElem {
      // The output iterator sets `status` if getting the output element
      // fails.
      Status status;
      // The buffered data element.
      std::vector<Tensor> output;
      int64_t id = -1;

      explicit OutputElem(const Status& s) : status(s) {}
      OutputElem(const Status& s, int64_t id) : status(s), id(id) {}
    };

    // Worker threads operate on their relevant WorkerState structs.
    //
    // WorkerState's fields are all protected by mu_.
    struct WorkerState {
      // The arguments to be used to construct an output iterator.
      std::vector<Tensor> input;
      // The buffered output elements.
      std::deque<OutputElem> outputs;
      // Set to true iff the worker thread expects to append more elements to
      // outputs. is_producing can be false despite !outputs.empty().
      // Concretely, all output elements will have been consumed only when:
      // is_producing == false && outputs.empty();
      bool is_producing = false;
      // Condition variable used to coordinate between threads. The worker
      // thread waits on this condition variable when it is either (1) waiting
      // for the main thread to add arguments to `input`, or (2) waiting for
      // the main thread to consume an element of `outputs`. The main thread
      // waits on cond_var if it is waiting for the worker thread to produce
      // an element into `outputs` (this implies deterministic == true).
      condition_variable cond_var;

      inline bool MayHaveElements() const {
        return is_producing || !outputs.empty();
      }

      // Sets inputs for a worker thread and notifies it to start processing.
      void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
        if (s.ok()) {
          DCHECK(!MayHaveElements())
              << "Tried to start inputs, despite already producing!";
          input = std::move(input_arguments);
          is_producing = true;
          cond_var.notify_one();
        } else {
          outputs.emplace_back(s);
        }
      }
    };

    // The internal state of a worker thread that is not already captured
    // in its `WorkerState`.
    //
    // This is needed only for checkpointing purposes. We keep this
    // separate from `WorkerState` and guard its fields using a separate
    // lock `ckpt_mu_` so as to not affect the performance of the main
    // pipeline.
    struct WorkerThreadState {
      // The output element that has been produced from the input iterator
      // and is waiting to be added to `WorkerState.outputs`.
      OutputElem output_elem;

      // Whether the input iterator returned an `end_of_sequence`.
      bool end_of_sequence = false;

      // Status returned from `MakeIteratorFromInputElement`.
      Status iterator_creation_status;

      // The arguments to be used to construct `iterator`.
      std::vector<Tensor> input;

      std::unique_ptr<IteratorBase> iterator;

      WorkerThreadState() : output_elem(OkStatus()) {}
    };

    void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
      cancellation_manager_->StartCancel();
      mutex_lock l(mu_);
      cancelled_ = true;
      for (auto& worker : workers_) {
        worker.cond_var.notify_all();
      }
    }

    Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (worker_threads_.empty() && input_impl_) {
        worker_threads_.reserve(dataset()->num_threads());
        for (int64_t i = 0; i < dataset()->num_threads(); ++i) {
          std::vector<Tensor> args;
          bool end_of_input = false;
          Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
          if (end_of_input) {
            input_impl_.reset();
            return OkStatus();
          }
          if (i < dataset()->cycle_length_) {
            interleave_indices_.push_back(i);
          } else {
            staging_indices_.push_back(i);
          }
          workers_[i].SetInputs(s, std::move(args));
          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
          worker_threads_.push_back(ctx->StartThread(
              strings::StrCat(kDataParallelInterleaveWorker, "_", i),
              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
        }
        DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
        DCHECK(staging_indices_.size() == dataset()->prefetch_input_elements_);
      }
      return OkStatus();
    }

    // Produces elements into the worker's output buffers.
    void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
                      const int64_t thread_index) {
      // Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
      //
      // 1. Any local state that may need to be checkpointed should be kept
      //    in `worker_thread_states_[thread_index]`.
      // 2. `WorkerThreadState` should contain state that is needed only for
      //    checkpointing, i.e., if we were to remove checkpointing support,
      //    we could keep that state as local variables in this thread.
      // 3. This thread should only read/write state at `thread_index`
      //    and should not access other thread states.
      // 4. When restoring from checkpoint, threads are started only after
      //    the restore is complete.
      // 5. Once restored from a checkpoint, the local state is edited only
      //    by this thread. 3 & 4 allow making assumptions like temporarily
      //    caching local state in this thread and using it outside a lock
      //    e.g. `make_new_iterator`.
      // 6. `ckpt_mu_` should be wisely used to create *consistent*
      //    checkpoint markers.

      // std::function arguments are copy-constructible, so we pass raw
      // pointers, and then immediately wrap them to ensure correct ownership.
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
        mutex_lock l(mu_);
        workers_[thread_index].cond_var.notify_all();
        RecordStop(ctx.get());
      });
      bool make_new_iterator;
      {
        tf_shared_lock l(ckpt_mu_);
        // Decide whether a new iterator should be built.
        // 1. If there is an existing iterator, we use it.
        // 2. If there was an error in iterator creation that could not be
        //    notified to the client we attempt to send that to the client
        //    first.
        make_new_iterator =
            worker_thread_states_[thread_index].iterator == nullptr &&
            worker_thread_states_[thread_index].iterator_creation_status.ok();
      }
      // Even though `make_new_iterator` has cached values from
      // `worker_thread_states_[thread_index]` which is guarded by ckpt_mu_,
      // it is safe to *read* `make_new_iterator` outside of a lock without
      // worrying about concurrent changes to values in
      // `worker_thread_states_[thread_index]`. See comment at the start of
      // this function for details.
      while (true) {
        // Whether creation of the iterator succeeded.
        Status iterator_creation_status;
        // 1. Build a new iterator or use the existing one.
        if (make_new_iterator) {
          // 1a. Get new input tensors or use the existing ones.
          bool read_new_input;
          {
            tf_shared_lock l(ckpt_mu_);
            // worker_thread_states_[thread_index].input will be non-empty
            // if checkpointing happened at CHECKPOINT_MARKER_A.
            read_new_input = worker_thread_states_[thread_index].input.empty();
          }

          if (read_new_input) {
            mutex_lock l(mu_);
            while (!cancelled_ && !workers_[thread_index].is_producing) {
              RecordStop(ctx.get());
              workers_[thread_index].cond_var.wait(l);
              RecordStart(ctx.get());
            }
            if (cancelled_) return;
            // Copy the input tensors so that we do not need to block on `mu_`
            // when building the iterator.
            // We keep a copy of the input tensors in
            // `WorkerThreadState.input` till the iterator is in use. This is
            // used in `RestoreInternal` to re-build the iterator.
            // TODO(b/78046638): Explore ways to avoid tracking the input
            // tensors.
            tf_shared_lock ckpt_l(ckpt_mu_);
            worker_thread_states_[thread_index].input.swap(
                workers_[thread_index].input);
            // CHECKPOINT_MARKER_A
            // We have the input tensors but have not built the iterator yet.
          }
          bool thread_in_staging = false;
          {
            mutex_lock l(mu_);
            thread_in_staging = absl::c_find(staging_indices_, thread_index) !=
                                staging_indices_.end();
          }
          // 1b. Run the user defined function to produce a new iterator.
          {
            tf_shared_lock l(ckpt_mu_);
            worker_thread_states_[thread_index].iterator_creation_status =
                MakeIteratorFromInputElement(
                    ctx.get(), this, worker_thread_states_[thread_index].input,
                    thread_index, *instantiated_captured_func_, prefix(),
                    &worker_thread_states_[thread_index].iterator,
                    model_node());
            iterator_creation_status =
                worker_thread_states_[thread_index].iterator_creation_status;
            if (!iterator_creation_status.ok()) {
              worker_thread_states_[thread_index].input.clear();
            } else if (thread_in_staging) {
              // TODO(wilsin): Write a unit test where we iterate through a
              // dataset, pause, and check the model proto autotune value.
              DisableAutotune(
                  ctx.get(),
                  worker_thread_states_[thread_index].iterator.get());
            }
            // CHECKPOINT_MARKER_B
            // Either an iterator has been successfully built and placed in
            // `worker_thread_states_[thread_index].iterator` or it failed and
            // a non-OK status has been put in
            // `worker_thread_states_[thread_index].iterator_creation_status`.
          }
        } else {
          tf_shared_lock l(ckpt_mu_);
          iterator_creation_status =
              worker_thread_states_[thread_index].iterator_creation_status;
          // Mark that we have used up the restored iterator.
          make_new_iterator = true;
        }
        // 2. Start producing elements or send error state to client if
        //    iterator creation failed.
        if (!iterator_creation_status.ok()) {
          mutex_lock l(mu_);
          // Wait for space in the prefetch queue.
          while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                    dataset()->buffer_output_elements_) {
            RecordStop(ctx.get());
            workers_[thread_index].cond_var.wait(l);
            RecordStart(ctx.get());
          }
          if (cancelled_) return;
          tf_shared_lock ckpt_l(ckpt_mu_);
          workers_[thread_index].outputs.emplace_back(iterator_creation_status);
          workers_[thread_index].is_producing = false;
          worker_thread_states_[thread_index].iterator_creation_status =
              OkStatus();
          // CHECKPOINT_MARKER_C
          // Non-OK iterator creation status has been notified to the
          // client.
          if (deterministic_) {
            workers_[thread_index].cond_var.notify_one();
          } else {
            any_element_available_cond_var_.notify_one();
          }
        } else {
          bool end_of_sequence = false;
          while (!end_of_sequence) {
            // 3.a Produce an element!
            {
              tf_shared_lock ckpt_l(ckpt_mu_);
              if (worker_thread_states_[thread_index].output_elem.status.ok() &&
                  worker_thread_states_[thread_index]
                      .output_elem.output.empty() &&
                  !worker_thread_states_[thread_index].end_of_sequence) {
                int64_t& id =
                    worker_thread_states_[thread_index].output_elem.id;
                profiler::TraceMe traceme(
                    [&] {
                      id = profiler::TraceMe::NewActivityId();
                      return profiler::TraceMeEncode(
                          "ParallelInterleaveProduce", {{"element_id", id}});
                    },
                    profiler::kInfo);
                worker_thread_states_[thread_index].output_elem.status =
                    worker_thread_states_[thread_index].iterator->GetNext(
                        ctx.get(),
                        &worker_thread_states_[thread_index].output_elem.output,
                        &worker_thread_states_[thread_index].end_of_sequence);
                end_of_sequence =
                    worker_thread_states_[thread_index].end_of_sequence;
              } else {
                end_of_sequence =
                    worker_thread_states_[thread_index].end_of_sequence;
              }
              // CHECKPOINT_MARKER_D
              // An element has been read or an error or end_of_sequence has
              // been received from the input iterator and is waiting to be
              // sent to client.
            }

            // 3.b Make it available to the client.
            {
              mutex_lock l(mu_);

              // Wait for space in the prefetch queue.
              while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                        dataset()->buffer_output_elements_) {
                RecordStop(ctx.get());
                workers_[thread_index].cond_var.wait(l);
                RecordStart(ctx.get());
              }
              if (cancelled_) return;

              tf_shared_lock ckpt_l(ckpt_mu_);
              workers_[thread_index].is_producing = !end_of_sequence;

              // Output the element.

              // Move the temporary state in WorkerThreadState to WorkerState
              // and mark it as used.
              if (end_of_sequence) {
                worker_thread_states_[thread_index].iterator.reset();
                worker_thread_states_[thread_index].input.clear();
                worker_thread_states_[thread_index].end_of_sequence = false;
              } else {
                workers_[thread_index].outputs.emplace_back(
                    worker_thread_states_[thread_index].output_elem.status,
                    worker_thread_states_[thread_index].output_elem.id);
                workers_[thread_index].outputs.back().output.swap(
                    worker_thread_states_[thread_index].output_elem.output);
              }
              worker_thread_states_[thread_index].output_elem.status =
                  OkStatus();
              if (deterministic_) {
                workers_[thread_index].cond_var.notify_one();
              } else {
                any_element_available_cond_var_.notify_one();
              }
              // CHECKPOINT_MARKER_E
              // Output element or iterator status has been sent to the
              // client.
            }
          }
        }
      }
    }

    Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      string iterator_name =
          strings::StrCat(prefix(), "::", kWorker, "_", index);
      TF_RETURN_IF_ERROR(writer->WriteScalar(iterator_name, kInputSize,
                                             workers_[index].input.size()));
      for (int i = 0; i < workers_[index].input.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteTensor(iterator_name,
                                               strings::StrCat(kInput, "_", i),
                                               workers_[index].input[i]));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(iterator_name, kOutputsSize,
                                             workers_[index].outputs.size()));
      for (int i = 0; i < workers_[index].outputs.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteOutputElemLocked(
            writer, workers_[index].outputs[i], iterator_name,
            strings::StrCat(kOutputs, "_", i)));
      }
      if (workers_[index].is_producing) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(iterator_name, kIsProducing, ""));
      }
      return OkStatus();
    }

    Status ReadWorkerStateLocked(IteratorContext* ctx,
                                 IteratorStateReader* reader, int index)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      string worker_prefix =
          strings::StrCat(prefix(), "::", kWorker, "_", index);
      // Restore inputs.
      int64_t input_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(worker_prefix, kInputSize, &input_size));
      workers_[index].input.reserve(input_size);
      for (int i = 0; i < input_size; ++i) {
        workers_[index].input.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(ctx->flr(), worker_prefix,
                                              strings::StrCat(kInput, "_", i),
                                              &workers_[index].input.back()));
      }
      int64_t outputs_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(worker_prefix, kOutputsSize, &outputs_size));
      for (int i = 0; i < outputs_size; ++i) {
        workers_[index].outputs.emplace_back(OkStatus());
        TF_RETURN_IF_ERROR(ReadOutputElemLocked(
            ctx, reader, &workers_[index].outputs.back(), worker_prefix,
            strings::StrCat(kOutputs, "_", i)));
      }
      if (reader->Contains(worker_prefix, kIsProducing)) {
        workers_[index].is_producing = true;
      } else {
        workers_[index].is_producing = false;
      }
      return OkStatus();
    }

    Status WriteWorkerThreadStateLocked(SerializationContext* ctx,
                                        IteratorStateWriter* writer, int index)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      string iterator_name =
          strings::StrCat(prefix(), "::", kWorkerThread, "_", index);
      if (worker_thread_states_[index].iterator != nullptr) {
        TF_RETURN_IF_ERROR(
            SaveInput(ctx, writer, worker_thread_states_[index].iterator));
      } else {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(iterator_name, kIteratorExhausted, ""));
      }
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(iterator_name, kInputSize,
                              worker_thread_states_[index].input.size()));
      for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
        TF_RETURN_IF_ERROR(
            writer->WriteTensor(iterator_name, strings::StrCat(kInput, "_", i),
                                worker_thread_states_[index].input[i]));
      }
      TF_RETURN_IF_ERROR(WriteStatusLocked(
          writer, iterator_name, kIteratorCreationStatus,
          worker_thread_states_[index].iterator_creation_status));
      TF_RETURN_IF_ERROR(WriteOutputElemLocked(
          writer, worker_thread_states_[index].output_elem, iterator_name,
          kOutput));
      if (worker_thread_states_[index].end_of_sequence) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(iterator_name, kEndOfSequence, ""));
      }
      return OkStatus();
    }

    Status ReadWorkerThreadStateLocked(IteratorContext* ctx,
                                       IteratorStateReader* reader, int index,
                                       WorkerThreadState* state) {
      string worker_prefix =
          strings::StrCat(prefix(), "::", kWorkerThread, "_", index);
      // Restore inputs.
      int64_t input_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(worker_prefix, kInputSize, &input_size));
      state->input.reserve(input_size);
      for (int i = 0; i < input_size; ++i) {
        state->input.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(ctx->flr(), worker_prefix,
                                              strings::StrCat(kInput, "_", i),
                                              &state->input.back()));
      }
      // Restore iterator
      if (reader->Contains(worker_prefix, kIteratorExhausted)) {
        state->iterator.reset();
      } else {
        std::unique_ptr<IteratorBase> iterator;
        // NOTE: We intentionally ignore resource modeling outside GetNext().
        TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
            ctx, this, state->input, index, *instantiated_captured_func_,
            prefix(), &iterator, /*node=*/nullptr));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
        state->iterator.swap(iterator);
      }
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, worker_prefix,
                                          kIteratorCreationStatus,
                                          &state->iterator_creation_status));
      TF_RETURN_IF_ERROR(ReadOutputElemLocked(ctx, reader, &state->output_elem,
                                              worker_prefix, kOutput));
      if (reader->Contains(worker_prefix, kEndOfSequence)) {
        state->end_of_sequence = true;
      } else {
        state->end_of_sequence = false;
      }
      return OkStatus();
    }

    Status WriteOutputElemLocked(IteratorStateWriter* writer,
                                 const OutputElem& output_elem,
                                 const string& iterator_name,
                                 const string& prefix)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      TF_RETURN_IF_ERROR(WriteStatusLocked(
          writer, iterator_name, strings::StrCat(prefix, "_", kStatus),
          output_elem.status));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          iterator_name, strings::StrCat(prefix, "_", kOutputSize),
          output_elem.output.size()));
      for (int i = 0; i < output_elem.output.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteTensor(
            iterator_name, strings::StrCat(prefix, "_", kOutput, "_", i),
            output_elem.output[i]));
      }
      return OkStatus();
    }

    Status ReadOutputElemLocked(IteratorContext* ctx,
                                IteratorStateReader* reader,
                                OutputElem* output_elem,
                                const string& iterator_name,
                                const string& prefix) {
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, iterator_name,
                                          strings::StrCat(prefix, "_", kStatus),
                                          &output_elem->status));
      int64_t output_size;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          iterator_name, strings::StrCat(prefix, "_", kOutputSize),
          &output_size));
      output_elem->output.reserve(output_size);
      for (int i = 0; i < output_size; ++i) {
        output_elem->output.emplace_back();
        TF_RETURN_IF_ERROR(
            reader->ReadTensor(ctx->flr(), iterator_name,
                               strings::StrCat(prefix, "_", kOutput, "_", i),
                               &output_elem->output.back()));
      }
      return OkStatus();
    }

    Status WriteStatusLocked(IteratorStateWriter* writer,
                             const string& iterator_name, const string& prefix,
                             const Status& status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          iterator_name, strings::StrCat(prefix, "_", kCode),
          static_cast<int64_t>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            iterator_name, strings::StrCat(prefix, "_", kMessage),
            status.error_message()));
      }
      return OkStatus();
    }

    Status ReadStatusLocked(IteratorStateReader* reader,
                            const string& iterator_name, const string& prefix,
                            Status* status) {
      int64_t code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          iterator_name, strings::StrCat(prefix, "_", kCode), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            iterator_name, strings::StrCat(prefix, "_", kMessage),
            &error_message));
        *status = Status(code, error_message);
      } else {
        *status = OkStatus();
      }
      return OkStatus();
    }

    // Mutex & condition variable to guard mutable iterator internals and
    // coordinate among worker threads and client thread[s].
    mutex mu_ TF_ACQUIRED_BEFORE(ckpt_mu_);
    // The main thread waits on this condition variable if running in
    // nondeterministic mode and no values are available.
    condition_variable any_element_available_cond_var_;
    // Whether outputs must be produced in deterministic order.
    const bool deterministic_;
    // Mutex used to wait for a consistent state while checkpointing.
    // Only Save and Restore require an exclusive lock on this mutex. In
    // other scenarios we just acquire a shared lock so the pipeline's
    // performance should not be affected in the absence of checkpointing.
    // A thread must not wait on any condition variable while holding
    // `ckpt_mu_` in either shared or exclusive modes.
    mutex ckpt_mu_;

    // Controls cancellation of `input_impl_`. Must be ordered before
    // `input_impl_` so that `input_impl_` is destroyed first.
    std::unique_ptr<CancellationManager> cancellation_manager_;

    // The iterator producing elements which are converted to datasets by
    // dataset()->captured_func_ and then interleaved together.
    // input_impl_ is reset when we have exhausted its input.
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

    // The WorkerState structs the worker threads operate on.
    // Each `workers_` element is referenced by at most one of
    // `interleave_indices_` and `staging_indices_`.
    std::vector<WorkerState> workers_ TF_GUARDED_BY(mu_);

    // Stores the temporary state of WorkerThreads which is not stored in
    // WorkerState. This is used for checkpointing purposes only.
    std::vector<WorkerThreadState> worker_thread_states_
        TF_GUARDED_BY(ckpt_mu_);

    // Indices in `workers_` of iterators to interleave.
    std::vector<int64_t> interleave_indices_ TF_GUARDED_BY(mu_);
    // Indices in `workers_` of prefetched iterators.
    std::deque<int64_t> staging_indices_ TF_GUARDED_BY(mu_);

    // The index into `interleave_indices_` for the next element to produce.
    size_t next_index_ TF_GUARDED_BY(mu_) = 0;
    // The number of items produced so far within the current block.
    size_t block_count_ TF_GUARDED_BY(mu_) = 0;
    // Flag to instruct the worker threads to exit.
    bool cancelled_ TF_GUARDED_BY(mu_) = false;
    // The worker threads. This must be last to ensure the
    // threads have exited before any other members are deallocated.
    // TODO(b/65178177): Avoid allocating additional threads.
    std::vector<std::unique_ptr<Thread>> worker_threads_ TF_GUARDED_BY(mu_);

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;
  };

  const DatasetBase* const input_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const int64_t cycle_length_;
  const int64_t block_length_;
  const DeterminismPolicy deterministic_;
  const int64_t buffer_output_elements_;
  const int64_t prefetch_input_elements_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
  const int op_version_;
};

ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
    OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->HasAttr(kDeterministic) ? 2 : 1) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
                                               &func_metadata_));
  if (op_version_ == 2) {
    std::string deterministic;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
    OP_REQUIRES_OK(
        ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
  }
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
                                              DatasetBase* input,
                                              DatasetBase** output) {
  int64_t cycle_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
  OP_REQUIRES(ctx, cycle_length > 0,
              errors::InvalidArgument("`cycle_length` must be > 0"));

  int64_t block_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
  OP_REQUIRES(ctx, block_length > 0,
              errors::InvalidArgument("`block_length` must be > 0"));

  if (op_version_ == 1) {
    bool sloppy = false;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy));
    if (sloppy) {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
    } else {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kDeterministic);
    }
  }

  int64_t buffer_output_elements = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
                                          &buffer_output_elements));
  OP_REQUIRES(ctx, buffer_output_elements > 0,
              errors::InvalidArgument("`buffer_output_elements` must be > 0"));

  int64_t prefetch_input_elements = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
                                          &prefetch_input_elements));
  OP_REQUIRES(
      ctx, prefetch_input_elements >= 0,
      errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));

  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));

  *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
                        block_length, deterministic_, buffer_output_elements,
                        prefetch_input_elements, output_types_, output_shapes_,
                        op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("LegacyParallelInterleaveDatasetV2").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);

REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("LegacyParallelInterleaveDatasetV2");

}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow
1300