/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"

#include <deque>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const ParallelMapDatasetOp::kDatasetType;
/* static */ constexpr const char* const ParallelMapDatasetOp::kInputDataset;
/* static */ constexpr const char* const ParallelMapDatasetOp::kOtherArguments;
/* static */ constexpr const char* const
    ParallelMapDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const ParallelMapDatasetOp::kFunc;
/* static */ constexpr const char* const ParallelMapDatasetOp::kTarguments;
/* static */ constexpr const char* const ParallelMapDatasetOp::kOutputTypes;
/* static */ constexpr const char* const ParallelMapDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
    ParallelMapDatasetOp::kUseInterOpParallelism;
/* static */ constexpr const char* const ParallelMapDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelMapDatasetOp::kSloppy;
/* static */ constexpr const char* const
    ParallelMapDatasetOp::kPreserveCardinality;

namespace {

constexpr char kParallelMapDatasetV1[] = "ParallelMapDataset";
constexpr char kParallelMapDatasetV2[] = "ParallelMapDatasetV2";

constexpr char kComponent[] = "component";
constexpr char kInvocationResults[] = "invocation_results";
constexpr char kSize[] = "size";
constexpr char kEndOfInput[] = "end_of_input";
constexpr char kErrorCode[] = "code";
constexpr char kErrorMessage[] = "error_message";

// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;

}  // namespace

class ParallelMapDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          int64_t num_parallel_calls, const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes,
          DeterminismPolicy deterministic,
          std::unique_ptr<CapturedFunction> captured_func,
          bool preserve_cardinality, int op_version)
      : Dataset(DatasetContext(ctx), input, num_parallel_calls, output_types,
                output_shapes, deterministic, std::move(captured_func),
                preserve_cardinality, op_version) {}

  Dataset(DatasetContext dataset_context, const DatasetBase* input,
          int64_t num_parallel_calls, const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes,
          DeterminismPolicy deterministic,
          std::unique_ptr<CapturedFunction> captured_func,
          bool preserve_cardinality, int op_version)
      : DatasetBase(std::move(dataset_context)),
        input_(input),
        num_parallel_calls_(num_parallel_calls),
        output_types_(output_types),
        output_shapes_(output_shapes),
        deterministic_(deterministic),
        preserve_cardinality_(preserve_cardinality),
        captured_func_(std::move(captured_func)),
        op_version_(op_version) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return output_types_;
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    return name_utils::DatasetDebugString(ParallelMapDatasetOp::kDatasetType,
                                          params);
  }

  int64_t CardinalityInternal() const override {
    if (preserve_cardinality_) {
      return input_->Cardinality();
    } else {
      return kUnknownCardinality;
    }
  }

  int64_t CardinalityInternal(CardinalityOptions options) const override {
    if (preserve_cardinality_) {
      return input_->Cardinality(options);
    } else {
      return kUnknownCardinality;
    }
  }

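  // Supports random access: the `index`-th output element is produced by
  // fetching the `index`-th input element and running the map function on it
  // synchronously.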
  Status Get(OpKernelContext* ctx, int64 index,
             std::vector<Tensor>* out_tensors) const override {
    TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
    std::vector<Tensor> args;
    TF_RETURN_IF_ERROR(input_->Get(ctx, index, &args));
    if (!instantiated_captured_func_) {
      TF_RETURN_IF_ERROR(
          captured_func_->Instantiate(InstantiateCapturedFunctionParams(ctx),
                                      &instantiated_captured_func_));
    }
    return instantiated_captured_func_->RunInstantiated(args, out_tensors);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    // Input: input_dataset
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

    // Input: other_arguments
    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));

    // Input: num_parallel_calls
    Node* num_parallel_calls = nullptr;
    if (op_version_ == 1) {
      TF_RETURN_IF_ERROR(b->AddScalar(static_cast<int32>(num_parallel_calls_),
                                      &num_parallel_calls));
    } else {
      TF_RETURN_IF_ERROR(
          b->AddScalar(num_parallel_calls_, &num_parallel_calls));
    }
    std::vector<std::pair<StringPiece, AttrValue>> attrs;

    // Attr: f
    AttrValue f_attr;
    b->BuildAttrValue(captured_func_->func(), &f_attr);
    attrs.emplace_back(kFunc, f_attr);

    // Attr: Targuments
    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    attrs.emplace_back(kTarguments, other_arguments_types_attr);

    // Attr: use_inter_op_parallelism
    AttrValue use_inter_op_parallelism_attr;
    b->BuildAttrValue(captured_func_->use_inter_op_parallelism(),
                      &use_inter_op_parallelism_attr);
    attrs.emplace_back(kUseInterOpParallelism, use_inter_op_parallelism_attr);

    if (op_version_ == 1) {
      // Attr: sloppy
      AttrValue sloppy_attr;
      b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
      attrs.emplace_back(kSloppy, sloppy_attr);
    }
    if (op_version_ == 2) {
      AttrValue deterministic_attr;
      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
      attrs.emplace_back(kDeterministic, deterministic_attr);
    }

    // Attr: preserve_cardinality
    AttrValue preserve_cardinality_attr;
    b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);
    attrs.emplace_back(kPreserveCardinality, preserve_cardinality_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this,
        {std::make_pair(0, input_graph_node),
         std::make_pair(2, num_parallel_calls)},  // Single tensor inputs.
        {std::make_pair(1, other_arguments)},     // Tensor list inputs.
        attrs, output));
    return OkStatus();
  }

 private:
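  // Iterator that applies the map function to input elements using a pool of
  // in-flight invocations. A background runner thread schedules up to
  // `num_parallel_calls_` concurrent function calls and buffers their results
  // in `invocation_results_`; `GetNextInternal()` consumes the buffered
  // results, either in input order (deterministic) or as they complete.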
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          cond_var_(std::make_shared<condition_variable>()),
          num_parallel_calls_(std::make_shared<model::SharedState>(
              params.dataset->num_parallel_calls_, mu_, cond_var_)),
          deterministic_(params.dataset->deterministic_.IsDeterministic() ||
                         params.dataset->deterministic_.IsDefault()),
          preserve_cardinality_(params.dataset->preserve_cardinality_),
          autotune_(params.dataset->num_parallel_calls_ == model::kAutotune) {}

    ~Iterator() override {
      CancelThreads(/*wait=*/true);
      input_impl_.reset();
      if (deregister_fn_) deregister_fn_();
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      interleave_depth_ = ctx->interleave_depth();

      if (num_parallel_calls_->value == model::kAutotune) {
        num_parallel_calls_->value = GetAutotuneDefaultParallelism(ctx);
      }
      cancellation_manager_ = std::make_unique<CancellationManager>();
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
          ctx->cancellation_manager(),
          [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      std::shared_ptr<InvocationResult> result;
      {
        mutex_lock l(*mu_);
        EnsureThreadsStarted(ctx);
        while (ShouldWait(&result)) {
          RecordStop(ctx);
          cond_var_->wait(l);
          RecordStart(ctx);
        }
        if (cancelled_) {
          return errors::Cancelled("Iterator was cancelled");
        }
      }
      RecordStop(ctx);
      result->notification.WaitForNotification();
      RecordStart(ctx);
      profiler::TraceMe traceme([&] {
        return profiler::TraceMeEncode("ParallelMapConsume",
                                       {{"element_id", result->uid}});
      });
      return ProcessResult(ctx, result, out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncKnownRatioNode(
          std::move(args),
          /*ratio=*/1,
          {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
                                /*max=*/ctx->runner_threadpool_size())});
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      mutex_lock l(*mu_);
      // Wait for all in-flight calls to complete.
      while (num_calls_ > 0) {
        cond_var_->wait(l);
      }
      if (num_calls_ != 0) {
        return errors::FailedPrecondition(
            "Unexpected outstanding calls encountered.");
      }
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(absl::StrCat(prefix(), "::", kInvocationResults),
                              kSize, invocation_results_.size()));
      for (size_t i = 0; i < invocation_results_.size(); i++) {
        const auto& result = *(invocation_results_[i]);
        std::string element_prefix =
            absl::StrCat(prefix(), "::", kInvocationResults, "::", i);
        TF_RETURN_IF_ERROR(
            WriteStatusLocked(writer, element_prefix, result.status));
        TF_RETURN_IF_ERROR(writer->WriteScalar(element_prefix, kSize,
                                               result.return_values.size()));
        for (size_t j = 0; j < result.return_values.size(); j++) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              element_prefix, absl::StrCat(kComponent, "[", j, "]"),
              result.return_values[j]));
        }
        if (result.end_of_input) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(element_prefix, kEndOfInput, ""));
        }
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      int64_t invocation_results_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(absl::StrCat(prefix(), "::", kInvocationResults),
                             kSize, &invocation_results_size));
      DCHECK(invocation_results_.empty());
      for (size_t i = 0; i < invocation_results_size; i++) {
        invocation_results_.push_back(std::make_shared<InvocationResult>());
        auto& result = *invocation_results_.back();
        std::string element_prefix =
            absl::StrCat(prefix(), "::", kInvocationResults, "::", i);
        TF_RETURN_IF_ERROR(
            ReadStatusLocked(reader, element_prefix, &result.status));
        size_t num_return_values;
        {
          int64_t size;
          TF_RETURN_IF_ERROR(reader->ReadScalar(element_prefix, kSize, &size));
          num_return_values = static_cast<size_t>(size);
          if (num_return_values != size) {
            return errors::InvalidArgument(
                element_prefix, ",", kSize, ": ", size,
                " is not a valid value of type size_t.");
          }
        }
        result.return_values.reserve(num_return_values);
        for (size_t j = 0; j < num_return_values; j++) {
          result.return_values.emplace_back();
          TF_RETURN_IF_ERROR(reader->ReadTensor(
              ctx->flr(), element_prefix, absl::StrCat(kComponent, "[", j, "]"),
              &result.return_values.back()));
        }
        result.end_of_input = reader->Contains(element_prefix, kEndOfInput);
        RecordBufferEnqueue(ctx, result.return_values);
        result.notification.Notify();
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      int64_t parallelism = -1;
      // NOTE: We only set the parallelism value if the lock can be acquired
      // right away to avoid introducing tracing overhead.
      if (mu_->try_lock()) {
        parallelism = num_parallel_calls_->value;
        mu_->unlock();
      }
      data::TraceMeMetadata result;
      result.push_back(
          std::make_pair("autotune", autotune_ ? "true" : "false"));
      result.push_back(
          std::make_pair("deterministic", deterministic_ ? "true" : "false"));
      result.push_back(std::make_pair(
          "parallelism",
          parallelism == -1
              ? kTraceInfoUnavailable
              : strings::Printf("%lld", static_cast<long long>(parallelism))));
      result.push_back(std::make_pair(
          "interleave_depth",
          strings::Printf("%lld", static_cast<long long>(interleave_depth_))));
      return result;
    }

   private:
    struct InvocationResult {
      InvocationResult() : uid(tensorflow::EnvTime::NowNanos()) {}

      Notification notification;
      Status status;
      std::vector<Tensor> return_values;
      bool end_of_input = false;
      const int64_t uid;
    };

    void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
      cancellation_manager_->StartCancel();
      mutex_lock l(*mu_);
      cancelled_ = true;
      cond_var_->notify_all();
      // Wait for all in-flight calls to complete.
      while (wait && num_calls_ > 0) {
        cond_var_->wait(l);
      }
    }

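    // Lazily starts the runner thread (and, if a stats aggregator is
    // configured, the stats thread) the first time a result is requested.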
    void EnsureThreadsStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!runner_thread_) {
        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
        runner_thread_ = ctx->StartThread(
            "tf_data_parallel_map",
            std::bind(&Iterator::RunnerThread, this, ctx_copy));
        if (ctx->stats_aggregator()) {
          stats_thread_ = ctx->StartThread(
              "tf_data_parallel_map_stats",
              std::bind(&Iterator::StatsThread, this, ctx_copy));
        }
      }
    }

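    // Marks `result` as completed and wakes up both the runner thread and any
    // consumer blocked in `GetNextInternal`.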
    void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
                       const std::shared_ptr<InvocationResult>& result)
        TF_LOCKS_EXCLUDED(*mu_) {
      mutex_lock l(*mu_);
      num_calls_--;
      result->notification.Notify();
      cond_var_->notify_all();
    }

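    // Fetches the next input element and applies the map function to it
    // asynchronously, storing the outcome in `result`.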
    void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
                      const std::shared_ptr<InvocationResult>& result)
        TF_LOCKS_EXCLUDED(*mu_) {
      profiler::TraceMe traceme([&] {
        return profiler::TraceMeEncode("ParallelMapProduce",
                                       {{"element_id", result->uid}});
      });
      // Get the next input element.
      std::vector<Tensor> input_element;
      result->status = input_impl_->GetNext(ctx.get(), &input_element,
                                            &result->end_of_input);
      if (result->end_of_input || !result->status.ok()) {
        CallCompleted(ctx, result);
        return;
      }

      auto done = [this, ctx, result](Status status) {
        result->status.Update(status);
        RecordBufferEnqueue(ctx.get(), result->return_values);
        CallCompleted(ctx, result);
      };

      // Apply the map function on `input_element`, storing the result in
      // `result->return_values`, and invoking `done` when finished.
      if (dataset()->captured_func_->use_inter_op_parallelism()) {
        instantiated_captured_func_->RunAsync(
            ctx.get(), std::move(input_element), &result->return_values,
            std::move(done), model_node());
      } else {
        // In this case, the function will be executed using a single-threaded
        // executor. We schedule it using `ctx->runner()` to enable concurrent
        // application of the function over different input elements.
        auto fn = std::bind(
            [this, ctx, result](std::vector<Tensor> input_element) {
              return instantiated_captured_func_->Run(
                  ctx.get(), std::move(input_element), &result->return_values,
                  model_node());
            },
            std::move(input_element));
        (*ctx->runner())(
            [this, ctx, fn = std::move(fn), done = std::move(done)]() {
              Status s;
              // Check whether we are already recording to prevent invalid
              // nesting of `RecordStart` calls.
              if (IsRecording(ctx.get())) {
                s = fn();
              } else {
                RecordStart(ctx.get());
                s = fn();
                RecordStop(ctx.get());
              }
              done(s);
            });
      }
    }

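    // Converts a completed invocation into the iterator's output, translating
    // end-of-input and `OutOfRange` errors according to
    // `preserve_cardinality_`.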
    Status ProcessResult(IteratorContext* ctx,
                         const std::shared_ptr<InvocationResult>& result,
                         std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) {
      if (!result->end_of_input && result->status.ok()) {
        *out_tensors = std::move(result->return_values);
        RecordBufferDequeue(ctx, *out_tensors);
        *end_of_sequence = false;
        return OkStatus();
      }
      if (errors::IsOutOfRange(result->status)) {
        if (preserve_cardinality_) {
          // To guarantee that the transformation preserves the cardinality of
          // the dataset, we convert `OutOfRange` to `InvalidArgument` as the
          // former may be interpreted by a caller as the end of sequence.
          return errors::InvalidArgument(
              "Function invocation produced OutOfRangeError: ",
              result->status.error_message());
        } else {
          // `f` may deliberately raise `errors::OutOfRange` to indicate
          // that we should terminate the iteration early.
          *end_of_sequence = true;
          return OkStatus();
        }
      }
      *end_of_sequence = result->end_of_input;
      return result->status;
    }

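    // Background thread that keeps up to `num_parallel_calls_` map-function
    // invocations in flight. It blocks while the invocation buffer is full or
    // the parallelism limit has been reached, and exits when the iterator is
    // cancelled.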
    void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
        TF_LOCKS_EXCLUDED(*mu_) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      std::vector<std::shared_ptr<InvocationResult>> new_calls;
      {
        tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
        new_calls.reserve(num_parallel_calls_->value);
      }
      auto busy = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
        int64_t num_parallel_calls = num_parallel_calls_->value;
        return num_calls_ >= num_parallel_calls ||
               invocation_results_.size() >= num_parallel_calls;
      };
      while (true) {
        {
          mutex_lock l(*mu_);
          while (!cancelled_ && busy()) {
            RecordStop(ctx.get());
            cond_var_->wait(l);
            RecordStart(ctx.get());
          }
          if (cancelled_) {
            return;
          }
          while (!busy()) {
            invocation_results_.push_back(std::make_shared<InvocationResult>());
            new_calls.push_back(invocation_results_.back());
            num_calls_++;
          }
          cond_var_->notify_all();
        }
        for (const auto& call : new_calls) {
          CallFunction(ctx, call);
        }
        new_calls.clear();
      }
    }

    // Determines whether the caller needs to wait for a result. Upon returning
    // false, `result` will point to the result.
    bool ShouldWait(std::shared_ptr<InvocationResult>* result)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (cancelled_) {
        return false;
      }
      if (!deterministic_) {
        // Iterate through in-flight results and return the first one that is
        // found to be available and not end-of-input. If the first result (in
        // order) is end-of-input, we know that all earlier iterations have
        // already been completed, so it is safe to return that result for the
        // caller to process end of iteration.
        for (auto it = invocation_results_.begin();
             it != invocation_results_.end(); ++it) {
          if ((*it)->notification.HasBeenNotified() &&
              (it == invocation_results_.begin() || !(*it)->end_of_input)) {
            std::swap(*result, *it);
            invocation_results_.erase(it);
            cond_var_->notify_all();
            return false;
          }
        }
      } else if (!invocation_results_.empty()) {
        std::swap(*result, invocation_results_.front());
        invocation_results_.pop_front();
        cond_var_->notify_all();
        return false;
      }
      return true;
    }

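    // Background thread that periodically reports thread utilization (the
    // ratio of outstanding calls to the parallelism limit) to the stats
    // aggregator, once every `kStatsReportingPeriodMillis`.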
    void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
      for (int64_t step = 0;; ++step) {
        int num_calls;
        int num_parallel_calls;
        {
          mutex_lock l(*mu_);
          if (step != 0 && !cancelled_) {
            cond_var_->wait_for(
                l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
          }
          if (cancelled_) {
            return;
          }
          num_calls = num_calls_;
          num_parallel_calls = num_parallel_calls_->value;
        }
        if (num_parallel_calls == 0) {
          // Avoid division by zero.
          num_parallel_calls = 1;
        }
        ctx->stats_aggregator()->AddScalar(
            stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
            static_cast<float>(num_calls) /
                static_cast<float>(num_parallel_calls),
            step);
      }
    }

    Status WriteStatusLocked(IteratorStateWriter* writer,
                             const std::string& key, const Status& status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          key, kErrorCode, static_cast<int64_t>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(key, kErrorMessage, status.error_message()));
      }
      return OkStatus();
    }

    Status ReadStatusLocked(IteratorStateReader* reader, const std::string& key,
                            Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64_t code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(key, kErrorCode, &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(key, kErrorMessage, &error_message));
        *status = Status(code, error_message);
      } else {
        *status = OkStatus();
      }
      return OkStatus();
    }

    // Used for coordination between the main thread and the runner thread.
    const std::shared_ptr<mutex> mu_;
    // Used for coordination between the main thread and the runner thread. In
    // particular, the runner thread should only schedule new calls when the
    // number of in-flight calls is less than the user specified level of
    // parallelism and there are slots available in the `invocation_results_`
    // buffer.
    const std::shared_ptr<condition_variable> cond_var_;
    // Identifies the maximum number of parallel calls.
    const std::shared_ptr<model::SharedState> num_parallel_calls_;
    const bool deterministic_;
    const bool preserve_cardinality_;
    const bool autotune_;
    // Counts the number of outstanding calls.
    int64_t num_calls_ TF_GUARDED_BY(*mu_) = 0;
    // Controls cancellation of `input_impl_`. Must be ordered before
    // `input_impl_` so that `input_impl_` is destroyed first.
    std::unique_ptr<CancellationManager> cancellation_manager_;
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
    // Must be ordered after `cancellation_manager_` so that `input_impl_` is
    // destroyed first.
    std::unique_ptr<IteratorBase> input_impl_;
    // Buffer for storing the invocation results.
    std::deque<std::shared_ptr<InvocationResult>> invocation_results_
        TF_GUARDED_BY(*mu_);
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
    std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
    std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;

    // Records the number of ParallelInterleave operations in the path from the
    // root node to this node (not including this node) in the input pipeline
    // tree. We record the interleave depth so that it can be included in the
    // trace metadata.
    int64 interleave_depth_ = -1;
  };

  const DatasetBase* const input_;
  const int64_t num_parallel_calls_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const DeterminismPolicy deterministic_;
  const bool preserve_cardinality_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const int op_version_;
  // This is used for random access provided by Get().
  mutable std::unique_ptr<InstantiatedCapturedFunction>
      instantiated_captured_func_;
};

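// ParallelMapDataset (v1) exposes a `sloppy` attr, while ParallelMapDatasetV2
// uses a `deterministic` attr instead; the presence of `sloppy` is used to
// infer which version of the op is being constructed.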
ParallelMapDatasetOp::ParallelMapDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx), op_version_(ctx->HasAttr(kSloppy) ? 1 : 2) {
  FunctionMetadata::Params params;
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseInterOpParallelism,
                                   &params.use_inter_op_parallelism));
  OP_REQUIRES_OK(ctx,
                 FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  if (op_version_ == 1) {
    bool sloppy;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kSloppy, &sloppy));
    if (sloppy) {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
    } else {
      deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
    }
  }
  if (op_version_ == 2) {
    std::string deterministic;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
    OP_REQUIRES_OK(
        ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
  }
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
}

void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                        DatasetBase** output) {
  int64_t num_parallel_calls;
  if (op_version_ == 1) {
    int32_t parallel_calls;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument(ctx, kNumParallelCalls, &parallel_calls));
    num_parallel_calls = parallel_calls;
  }
  if (op_version_ == 2) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
  }
  OP_REQUIRES(
      ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune,
      errors::InvalidArgument("num_parallel_calls must be greater than zero."));

  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));

  if (num_parallel_calls == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output =
      new Dataset(ctx, input, num_parallel_calls, output_types_, output_shapes_,
                  deterministic_, std::move(captured_func),
                  preserve_cardinality_, op_version_);
}

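// Builds a ParallelMapDatasetV2 that applies `captured_function` to `input`
// with autotuned parallelism, default determinism, and preserved cardinality.
// As the name suggests, this is intended for the tf.data service path that
// uncompresses dataset elements.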
std::unique_ptr<DatasetBase> MakeDataServiceUncompressDataset(
    DatasetBase* input, std::unique_ptr<CapturedFunction> captured_function,
    const DataTypeVector& output_types,
    const std::vector<PartialTensorShape>& output_shapes) {
  DatasetContext::Params param;
  param.type_string = kParallelMapDatasetV2;
  param.node_name = kParallelMapDatasetV2;
  return std::make_unique<ParallelMapDatasetOp::Dataset>(
      DatasetContext(std::move(param)), input,
      /*num_parallel_calls=*/model::kAutotune, output_types, output_shapes,
      DeterminismPolicy(DeterminismPolicy::Type::kDefault),
      std::move(captured_function),
      /*preserve_cardinality=*/true, /*op_version=*/2);
}

namespace {
REGISTER_KERNEL_BUILDER(Name(kParallelMapDatasetV1).Device(DEVICE_CPU),
                        ParallelMapDatasetOp);
REGISTER_KERNEL_BUILDER(Name(kParallelMapDatasetV2).Device(DEVICE_CPU),
                        ParallelMapDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelMapDatasetV1);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelMapDatasetV2);
}  // namespace
}  // namespace data
}  // namespace tensorflow