/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"

#include <deque>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const ParallelMapDatasetOp::kDatasetType;
/* static */ constexpr const char* const ParallelMapDatasetOp::kInputDataset;
/* static */ constexpr const char* const ParallelMapDatasetOp::kOtherArguments;
/* static */ constexpr const char* const
    ParallelMapDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const ParallelMapDatasetOp::kFunc;
/* static */ constexpr const char* const ParallelMapDatasetOp::kTarguments;
/* static */ constexpr const char* const ParallelMapDatasetOp::kOutputTypes;
/* static */ constexpr const char* const ParallelMapDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
    ParallelMapDatasetOp::kUseInterOpParallelism;
/* static */ constexpr const char* const ParallelMapDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelMapDatasetOp::kSloppy;
/* static */ constexpr const char* const
    ParallelMapDatasetOp::kPreserveCardinality;

namespace {

constexpr char kParallelMapDatasetV1[] = "ParallelMapDataset";
constexpr char kParallelMapDatasetV2[] = "ParallelMapDatasetV2";

constexpr char kComponent[] = "component";
constexpr char kInvocationResults[] = "invocation_results";
constexpr char kSize[] = "size";
constexpr char kEndOfInput[] = "end_of_input";
constexpr char kErrorCode[] = "code";
constexpr char kErrorMessage[] = "error_message";

// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;

}  // namespace

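// A dataset that applies a (possibly captured) function to the elements of
// its input dataset in parallel, buffering up to `num_parallel_calls_`
// invocation results at a time.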
class ParallelMapDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          int64_t num_parallel_calls, const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes,
          DeterminismPolicy deterministic,
          std::unique_ptr<CapturedFunction> captured_func,
          bool preserve_cardinality, int op_version)
      : Dataset(DatasetContext(ctx), input, num_parallel_calls, output_types,
                output_shapes, deterministic, std::move(captured_func),
                preserve_cardinality, op_version) {}

  Dataset(DatasetContext dataset_context, const DatasetBase* input,
          int64_t num_parallel_calls, const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes,
          DeterminismPolicy deterministic,
          std::unique_ptr<CapturedFunction> captured_func,
          bool preserve_cardinality, int op_version)
      : DatasetBase(std::move(dataset_context)),
        input_(input),
        num_parallel_calls_(num_parallel_calls),
        output_types_(output_types),
        output_shapes_(output_shapes),
        deterministic_(deterministic),
        preserve_cardinality_(preserve_cardinality),
        captured_func_(std::move(captured_func)),
        op_version_(op_version) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    return name_utils::DatasetDebugString(ParallelMapDatasetOp::kDatasetType,
                                          params);
  }

  int64_t CardinalityInternal() const override {
    if (preserve_cardinality_) {
      return input_->Cardinality();
    } else {
      return kUnknownCardinality;
    }
  }

  int64_t CardinalityInternal(CardinalityOptions options) const override {
    if (preserve_cardinality_) {
      return input_->Cardinality(options);
    } else {
      return kUnknownCardinality;
    }
  }

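  // Random access: fetches the `index`-th element of the input dataset and
  // applies the map function to it, instantiating the captured function on
  // first use.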
  Status Get(OpKernelContext* ctx, int64 index,
             std::vector<Tensor>* out_tensors) const override {
    TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
    std::vector<Tensor> args;
    TF_RETURN_IF_ERROR(input_->Get(ctx, index, &args));
    if (!instantiated_captured_func_) {
      TF_RETURN_IF_ERROR(
          captured_func_->Instantiate(InstantiateCapturedFunctionParams(ctx),
                                      &instantiated_captured_func_));
    }
    return instantiated_captured_func_->RunInstantiated(args, out_tensors);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    // Input: input_dataset
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

    // Input: other_arguments
    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));

    // Input: num_parallel_calls
    Node* num_parallel_calls = nullptr;
    if (op_version_ == 1) {
      TF_RETURN_IF_ERROR(b->AddScalar(static_cast<int32>(num_parallel_calls_),
                                      &num_parallel_calls));
    } else {
      TF_RETURN_IF_ERROR(
          b->AddScalar(num_parallel_calls_, &num_parallel_calls));
    }
    std::vector<std::pair<StringPiece, AttrValue>> attrs;

    // Attr: f
    AttrValue f_attr;
    b->BuildAttrValue(captured_func_->func(), &f_attr);
    attrs.emplace_back(kFunc, f_attr);

    // Attr: Targuments
    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    attrs.emplace_back(kTarguments, other_arguments_types_attr);

    // Attr: use_inter_op_parallelism
    AttrValue use_inter_op_parallelism_attr;
    b->BuildAttrValue(captured_func_->use_inter_op_parallelism(),
                      &use_inter_op_parallelism_attr);
    attrs.emplace_back(kUseInterOpParallelism, use_inter_op_parallelism_attr);

    if (op_version_ == 1) {
      // Attr: sloppy
      AttrValue sloppy_attr;
      b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
      attrs.emplace_back(kSloppy, sloppy_attr);
    }
    if (op_version_ == 2) {
      AttrValue deterministic_attr;
      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
      attrs.emplace_back(kDeterministic, deterministic_attr);
    }

    // Attr: preserve_cardinality
    AttrValue preserve_cardinality_attr;
    b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);
    attrs.emplace_back(kPreserveCardinality, preserve_cardinality_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this,
        {std::make_pair(0, input_graph_node),
         std::make_pair(2, num_parallel_calls)},  // Single tensor inputs.
        {std::make_pair(1, other_arguments)},     // Tensor list inputs.
        attrs, output));
    return OkStatus();
  }

 private:
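  // The iterator owns a background runner thread that invokes the map
  // function on up to `num_parallel_calls_` input elements at a time, storing
  // the results in `invocation_results_` for the consumer.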
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          cond_var_(std::make_shared<condition_variable>()),
          num_parallel_calls_(std::make_shared<model::SharedState>(
              params.dataset->num_parallel_calls_, mu_, cond_var_)),
          deterministic_(params.dataset->deterministic_.IsDeterministic() ||
                         params.dataset->deterministic_.IsDefault()),
          preserve_cardinality_(params.dataset->preserve_cardinality_),
          autotune_(params.dataset->num_parallel_calls_ == model::kAutotune) {}

    ~Iterator() override {
      CancelThreads(/*wait=*/true);
      input_impl_.reset();
      if (deregister_fn_) deregister_fn_();
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      interleave_depth_ = ctx->interleave_depth();

      if (num_parallel_calls_->value == model::kAutotune) {
        num_parallel_calls_->value = GetAutotuneDefaultParallelism(ctx);
      }
      cancellation_manager_ = std::make_unique<CancellationManager>();
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
          ctx->cancellation_manager(),
          [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

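    // Blocks until an invocation result is ready (respecting the determinism
    // setting), then forwards its tensors, error status, or end-of-sequence
    // signal to the caller.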
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      std::shared_ptr<InvocationResult> result;
      {
        mutex_lock l(*mu_);
        EnsureThreadsStarted(ctx);
        while (ShouldWait(&result)) {
          RecordStop(ctx);
          cond_var_->wait(l);
          RecordStart(ctx);
        }
        if (cancelled_) {
          return errors::Cancelled("Iterator was cancelled");
        }
      }
      RecordStop(ctx);
      result->notification.WaitForNotification();
      RecordStart(ctx);
      profiler::TraceMe traceme([&] {
        return profiler::TraceMeEncode("ParallelMapConsume",
                                       {{"element_id", result->uid}});
      });
      return ProcessResult(ctx, result, out_tensors, end_of_sequence);
    }

   protected:
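    // Models this iterator as an async known-ratio node whose `parallelism`
    // parameter can be tuned by the autotuner, bounded by the runner
    // threadpool size.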
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncKnownRatioNode(
          std::move(args),
          /*ratio=*/1,
          {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
                                /*max=*/ctx->runner_threadpool_size())});
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      mutex_lock l(*mu_);
      // Wait for all in-flight calls to complete.
      while (num_calls_ > 0) {
        cond_var_->wait(l);
      }
      if (num_calls_ != 0) {
        return errors::FailedPrecondition(
            "Unexpected outstanding calls encountered.");
      }
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(absl::StrCat(prefix(), "::", kInvocationResults),
                              kSize, invocation_results_.size()));
      for (size_t i = 0; i < invocation_results_.size(); i++) {
        const auto& result = *(invocation_results_[i]);
        std::string element_prefix =
            absl::StrCat(prefix(), "::", kInvocationResults, "::", i);
        TF_RETURN_IF_ERROR(
            WriteStatusLocked(writer, element_prefix, result.status));
        TF_RETURN_IF_ERROR(writer->WriteScalar(element_prefix, kSize,
                                               result.return_values.size()));
        for (size_t j = 0; j < result.return_values.size(); j++) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              element_prefix, absl::StrCat(kComponent, "[", j, "]"),
              result.return_values[j]));
        }
        if (result.end_of_input) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(element_prefix, kEndOfInput, ""));
        }
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      int64_t invocation_results_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(absl::StrCat(prefix(), "::", kInvocationResults),
                             kSize, &invocation_results_size));
      DCHECK(invocation_results_.empty());
      for (size_t i = 0; i < invocation_results_size; i++) {
        invocation_results_.push_back(std::make_shared<InvocationResult>());
        auto& result = *invocation_results_.back();
        std::string element_prefix =
            absl::StrCat(prefix(), "::", kInvocationResults, "::", i);
        TF_RETURN_IF_ERROR(
            ReadStatusLocked(reader, element_prefix, &result.status));
        size_t num_return_values;
        {
          int64_t size;
          TF_RETURN_IF_ERROR(reader->ReadScalar(element_prefix, kSize, &size));
          num_return_values = static_cast<size_t>(size);
          if (num_return_values != size) {
            return errors::InvalidArgument(
                element_prefix, ",", kSize, ": ", size,
                " is not a valid value of type size_t.");
          }
        }
        result.return_values.reserve(num_return_values);
        for (size_t j = 0; j < num_return_values; j++) {
          result.return_values.emplace_back();
          TF_RETURN_IF_ERROR(reader->ReadTensor(
              ctx->flr(), element_prefix, absl::StrCat(kComponent, "[", j, "]"),
              &result.return_values.back()));
        }
        result.end_of_input = reader->Contains(element_prefix, kEndOfInput);
        RecordBufferEnqueue(ctx, result.return_values);
        result.notification.Notify();
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      int64_t parallelism = -1;
      // NOTE: We only set the parallelism value if the lock can be acquired
      // right away to avoid introducing tracing overhead.
      if (mu_->try_lock()) {
        parallelism = num_parallel_calls_->value;
        mu_->unlock();
      }
      data::TraceMeMetadata result;
      result.push_back(
          std::make_pair("autotune", autotune_ ? "true" : "false"));
      result.push_back(
          std::make_pair("deterministic", deterministic_ ? "true" : "false"));
      result.push_back(std::make_pair(
          "parallelism",
          parallelism == -1
              ? kTraceInfoUnavailable
              : strings::Printf("%lld", static_cast<long long>(parallelism))));
      result.push_back(std::make_pair(
          "interleave_depth",
          strings::Printf("%lld", static_cast<long long>(interleave_depth_))));
      return result;
    }

   private:
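    // Holds the state of a single in-flight invocation of the map function.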
    struct InvocationResult {
      InvocationResult() : uid(tensorflow::EnvTime::NowNanos()) {}

      Notification notification;
      Status status;
      std::vector<Tensor> return_values;
      bool end_of_input = false;
      const int64_t uid;
    };

    void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
      cancellation_manager_->StartCancel();
      mutex_lock l(*mu_);
      cancelled_ = true;
      cond_var_->notify_all();
      // Wait for all in-flight calls to complete.
      while (wait && num_calls_ > 0) {
        cond_var_->wait(l);
      }
    }

    void EnsureThreadsStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!runner_thread_) {
        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
        runner_thread_ = ctx->StartThread(
            "tf_data_parallel_map",
            std::bind(&Iterator::RunnerThread, this, ctx_copy));
        if (ctx->stats_aggregator()) {
          stats_thread_ = ctx->StartThread(
              "tf_data_parallel_map_stats",
              std::bind(&Iterator::StatsThread, this, ctx_copy));
        }
      }
    }

    void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
                       const std::shared_ptr<InvocationResult>& result)
        TF_LOCKS_EXCLUDED(*mu_) {
      mutex_lock l(*mu_);
      num_calls_--;
      result->notification.Notify();
      cond_var_->notify_all();
    }

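    // Pulls the next input element and applies the map function to it
    // asynchronously, notifying `result` when the call completes.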
    void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
                      const std::shared_ptr<InvocationResult>& result)
        TF_LOCKS_EXCLUDED(*mu_) {
      profiler::TraceMe traceme([&] {
        return profiler::TraceMeEncode("ParallelMapProduce",
                                       {{"element_id", result->uid}});
      });
      // Get the next input element.
      std::vector<Tensor> input_element;
      result->status = input_impl_->GetNext(ctx.get(), &input_element,
                                            &result->end_of_input);
      if (result->end_of_input || !result->status.ok()) {
        CallCompleted(ctx, result);
        return;
      }

      auto done = [this, ctx, result](Status status) {
        result->status.Update(status);
        RecordBufferEnqueue(ctx.get(), result->return_values);
        CallCompleted(ctx, result);
      };

      // Apply the map function on `input_element`, storing the result in
      // `result->return_values`, and invoking `done` when finished.
      if (dataset()->captured_func_->use_inter_op_parallelism()) {
        instantiated_captured_func_->RunAsync(
            ctx.get(), std::move(input_element), &result->return_values,
            std::move(done), model_node());
      } else {
        // In this case, the function will be executed using a single-threaded
        // executor. We schedule it using `ctx->runner()` to enable concurrent
        // application of the function over different input elements.
        auto fn = std::bind(
            [this, ctx, result](std::vector<Tensor> input_element) {
              return instantiated_captured_func_->Run(
                  ctx.get(), std::move(input_element), &result->return_values,
                  model_node());
            },
            std::move(input_element));
        (*ctx->runner())(
            [this, ctx, fn = std::move(fn), done = std::move(done)]() {
              Status s;
              // Check whether we are already recording to prevent invalid
              // nesting of `RecordStart` calls.
              if (IsRecording(ctx.get())) {
                s = fn();
              } else {
                RecordStart(ctx.get());
                s = fn();
                RecordStop(ctx.get());
              }
              done(s);
            });
      }
    }

    Status ProcessResult(IteratorContext* ctx,
                         const std::shared_ptr<InvocationResult>& result,
                         std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) {
      if (!result->end_of_input && result->status.ok()) {
        *out_tensors = std::move(result->return_values);
        RecordBufferDequeue(ctx, *out_tensors);
        *end_of_sequence = false;
        return OkStatus();
      }
      if (errors::IsOutOfRange(result->status)) {
        if (preserve_cardinality_) {
          // To guarantee that the transformation preserves the cardinality of
          // the dataset, we convert `OutOfRange` to `InvalidArgument` as the
          // former may be interpreted by a caller as the end of sequence.
          return errors::InvalidArgument(
              "Function invocation produced OutOfRangeError: ",
              result->status.error_message());
        } else {
          // `f` may deliberately raise `errors::OutOfRange` to indicate
          // that we should terminate the iteration early.
          *end_of_sequence = true;
          return OkStatus();
        }
      }
      *end_of_sequence = result->end_of_input;
      return result->status;
    }

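    // Producer loop: keeps scheduling new map-function calls for as long as
    // the number of in-flight calls and buffered results stays below
    // `num_parallel_calls_`.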
    void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
        TF_LOCKS_EXCLUDED(*mu_) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      std::vector<std::shared_ptr<InvocationResult>> new_calls;
      {
        tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
        new_calls.reserve(num_parallel_calls_->value);
      }
      auto busy = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
        int64_t num_parallel_calls = num_parallel_calls_->value;
        return num_calls_ >= num_parallel_calls ||
               invocation_results_.size() >= num_parallel_calls;
      };
      while (true) {
        {
          mutex_lock l(*mu_);
          while (!cancelled_ && busy()) {
            RecordStop(ctx.get());
            cond_var_->wait(l);
            RecordStart(ctx.get());
          }
          if (cancelled_) {
            return;
          }
          while (!busy()) {
            invocation_results_.push_back(std::make_shared<InvocationResult>());
            new_calls.push_back(invocation_results_.back());
            num_calls_++;
          }
          cond_var_->notify_all();
        }
        for (const auto& call : new_calls) {
          CallFunction(ctx, call);
        }
        new_calls.clear();
      }
    }

    // Determines whether the caller needs to wait for a result. Upon returning
    // false, `result` points to the result to process, unless the iterator has
    // been cancelled, in which case `result` is left unset.
    bool ShouldWait(std::shared_ptr<InvocationResult>* result)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (cancelled_) {
        return false;
      }
      if (!deterministic_) {
        // Iterate through in-flight results and return the first one that is
        // found to be available and not end-of-input. If the first result (in
        // order) is end-of-input, we know that all earlier iterations have
        // already been completed, so it is safe to return that result for the
        // caller to process end of iteration.
        for (auto it = invocation_results_.begin();
             it != invocation_results_.end(); ++it) {
          if ((*it)->notification.HasBeenNotified() &&
              (it == invocation_results_.begin() || !(*it)->end_of_input)) {
            std::swap(*result, *it);
            invocation_results_.erase(it);
            cond_var_->notify_all();
            return false;
          }
        }
      } else if (!invocation_results_.empty()) {
        std::swap(*result, invocation_results_.front());
        invocation_results_.pop_front();
        cond_var_->notify_all();
        return false;
      }
      return true;
    }

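    // Periodically reports the ratio of in-flight calls to
    // `num_parallel_calls_` to the stats aggregator as a thread-utilization
    // scalar.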
    void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
      for (int64_t step = 0;; ++step) {
        int num_calls;
        int num_parallel_calls;
        {
          mutex_lock l(*mu_);
          if (step != 0 && !cancelled_) {
            cond_var_->wait_for(
                l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
          }
          if (cancelled_) {
            return;
          }
          num_calls = num_calls_;
          num_parallel_calls = num_parallel_calls_->value;
        }
        if (num_parallel_calls == 0) {
          // Avoid division by zero.
          num_parallel_calls = 1;
        }
        ctx->stats_aggregator()->AddScalar(
            stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
            static_cast<float>(num_calls) /
                static_cast<float>(num_parallel_calls),
            step);
      }
    }

    Status WriteStatusLocked(IteratorStateWriter* writer,
                             const std::string& key, const Status& status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          key, kErrorCode, static_cast<int64_t>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(key, kErrorMessage, status.error_message()));
      }
      return OkStatus();
    }

    Status ReadStatusLocked(IteratorStateReader* reader, const std::string& key,
                            Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64_t code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(key, kErrorCode, &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(key, kErrorMessage, &error_message));
        *status = Status(code, error_message);
      } else {
        *status = OkStatus();
      }
      return OkStatus();
    }

    // Used for coordination between the main thread and the runner thread.
    const std::shared_ptr<mutex> mu_;
    // Used for coordination between the main thread and the runner thread. In
    // particular, the runner thread should only schedule new calls when the
    // number of in-flight calls is less than the user specified level of
    // parallelism and there are slots available in the `invocation_results_`
    // buffer.
    const std::shared_ptr<condition_variable> cond_var_;
    // Identifies the maximum number of parallel calls.
    const std::shared_ptr<model::SharedState> num_parallel_calls_;
    const bool deterministic_;
    const bool preserve_cardinality_;
    const bool autotune_;
    // Counts the number of outstanding calls.
    int64_t num_calls_ TF_GUARDED_BY(*mu_) = 0;
    // Controls cancellation of `input_impl_`. Must be ordered before
    // `input_impl_` so that `input_impl_` is destroyed first.
    std::unique_ptr<CancellationManager> cancellation_manager_;
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
    // Must be ordered after `cancellation_manager_` so that `input_impl_` is
    // destroyed first.
    std::unique_ptr<IteratorBase> input_impl_;
    // Buffer for storing the invocation results.
    std::deque<std::shared_ptr<InvocationResult>> invocation_results_
        TF_GUARDED_BY(*mu_);
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
    std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
    std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;

    // Records the number of ParallelInterleave operations in the path from the
    // root node to this node (not including this node) in the input pipeline
    // tree. We record the interleave depth so that it can be included in the
    // trace metadata.
    int64 interleave_depth_ = -1;
  };

  const DatasetBase* const input_;
  const int64_t num_parallel_calls_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const DeterminismPolicy deterministic_;
  const bool preserve_cardinality_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const int op_version_;
  // This is used for random access provided by Get().
  mutable std::unique_ptr<InstantiatedCapturedFunction>
      instantiated_captured_func_;
};

ParallelMapDatasetOp::ParallelMapDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx), op_version_(ctx->HasAttr(kSloppy) ? 1 : 2) {
  FunctionMetadata::Params params;
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseInterOpParallelism,
                                   &params.use_inter_op_parallelism));
  OP_REQUIRES_OK(ctx,
                 FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  if (op_version_ == 1) {
    bool sloppy;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kSloppy, &sloppy));
    if (sloppy) {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
    } else {
      deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
    }
  }
  if (op_version_ == 2) {
    std::string deterministic;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
    OP_REQUIRES_OK(
        ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
  }
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
}

void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase** output) {
  int64_t num_parallel_calls;
  if (op_version_ == 1) {
    int32_t parallel_calls;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument(ctx, kNumParallelCalls, &parallel_calls));
    num_parallel_calls = parallel_calls;
  }
  if (op_version_ == 2) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
  }
  OP_REQUIRES(
      ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune,
      errors::InvalidArgument("num_parallel_calls must be greater than zero."));

  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));

  if (num_parallel_calls == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output =
      new Dataset(ctx, input, num_parallel_calls, output_types_, output_shapes_,
                  deterministic_, std::move(captured_func),
                  preserve_cardinality_, op_version_);
}

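// Wraps `input` in an anonymous ParallelMapDatasetV2 that applies
// `captured_function` with autotuned parallelism, default determinism, and
// preserved cardinality. (Presumably used by the tf.data service to apply an
// uncompression function, as the name suggests.)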
std::unique_ptr<DatasetBase> MakeDataServiceUncompressDataset(
    DatasetBase* input, std::unique_ptr<CapturedFunction> captured_function,
    const DataTypeVector& output_types,
    const std::vector<PartialTensorShape>& output_shapes) {
  DatasetContext::Params param;
  param.type_string = kParallelMapDatasetV2;
  param.node_name = kParallelMapDatasetV2;
  return std::make_unique<ParallelMapDatasetOp::Dataset>(
      DatasetContext(std::move(param)), input,
      /*num_parallel_calls=*/model::kAutotune, output_types, output_shapes,
      DeterminismPolicy(DeterminismPolicy::Type::kDefault),
      std::move(captured_function),
      /*preserve_cardinality=*/true, /*op_version=*/2);
}

namespace {
REGISTER_KERNEL_BUILDER(Name(kParallelMapDatasetV1).Device(DEVICE_CPU),
                        ParallelMapDatasetOp);
REGISTER_KERNEL_BUILDER(Name(kParallelMapDatasetV2).Device(DEVICE_CPU),
                        ParallelMapDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelMapDatasetV1);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelMapDatasetV2);
}  // namespace
}  // namespace data
}  // namespace tensorflow