xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/experimental/data_service_dataset_op.h"
16 
17 #include <algorithm>
18 #include <functional>
19 #include <limits>
20 #include <memory>
21 #include <optional>
22 #include <queue>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/numbers.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/string_view.h"
34 #include "absl/strings/substitute.h"
35 #include "absl/time/time.h"
36 #include "tensorflow/core/data/captured_function.h"
37 #include "tensorflow/core/data/dataset.pb.h"
38 #include "tensorflow/core/data/dataset_utils.h"
39 #include "tensorflow/core/data/name_utils.h"
40 #include "tensorflow/core/data/serialization_utils.h"
41 #include "tensorflow/core/data/service/client/common.h"
42 #include "tensorflow/core/data/service/client/validate_utils.h"
43 #include "tensorflow/core/data/service/common.h"
44 #include "tensorflow/core/data/service/common.pb.h"
45 #include "tensorflow/core/data/service/dispatcher.pb.h"
46 #include "tensorflow/core/data/service/dispatcher_client.h"
47 #include "tensorflow/core/data/service/grpc_util.h"
48 #include "tensorflow/core/data/service/worker.pb.h"
49 #include "tensorflow/core/data/service/worker_client.h"
50 #include "tensorflow/core/data/service/worker_impl.h"
51 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
52 #include "tensorflow/core/framework/dataset.h"
53 #include "tensorflow/core/framework/model.h"
54 #include "tensorflow/core/framework/partial_tensor_shape.h"
55 #include "tensorflow/core/framework/tensor.h"
56 #include "tensorflow/core/framework/types.pb.h"
57 #include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
58 #include "tensorflow/core/lib/core/errors.h"
59 #include "tensorflow/core/lib/gtl/cleanup.h"
60 #include "tensorflow/core/platform/env_time.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/mutex.h"
63 #include "tensorflow/core/platform/refcount.h"
64 #include "tensorflow/core/platform/status.h"
65 #include "tensorflow/core/platform/statusor.h"
66 #include "tensorflow/core/platform/thread_annotations.h"
67 #include "tensorflow/core/platform/tstring.h"
68 #include "tensorflow/core/platform/types.h"
69 #include "tensorflow/core/profiler/lib/traceme.h"
70 #include "tensorflow/core/profiler/lib/traceme_encode.h"
71 #include "tensorflow/core/protobuf/data_service.pb.h"
72 #include "tensorflow/core/protobuf/error_codes.pb.h"
73 
74 namespace tensorflow {
75 namespace data {
76 
77 /* static */ constexpr const char* const DataServiceDatasetOp::kDatasetType;
78 /* static */ constexpr const char* const DataServiceDatasetOp::kDatasetId;
79 /* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode;
80 /* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
81 /* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
82 /* static */ constexpr const char* const
83     DataServiceDatasetOp::kDataTransferProtocol;
84 /* static */ constexpr const char* const DataServiceDatasetOp::kJobName;
85 /* static */ constexpr const char* const DataServiceDatasetOp::kConsumerIndex;
86 /* static */ constexpr const char* const DataServiceDatasetOp::kNumConsumers;
87 /* static */ constexpr const char* const
88     DataServiceDatasetOp::kMaxOutstandingRequests;
89 /* static */ constexpr const char* const
90     DataServiceDatasetOp::kTaskRefreshIntervalHintMs;
91 /* static */ constexpr const char* const DataServiceDatasetOp::kTargetWorkers;
92 /* static */ constexpr const char* const
93     DataServiceDatasetOp::kIterationCounter;
94 /* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes;
95 /* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes;
96 /* static */ constexpr const char* const DataServiceDatasetOp::kUncompress;
97 /* static */ constexpr const char* const DataServiceDatasetOp::kUncompressFn;
98 /* static */ constexpr const char* const
99     DataServiceDatasetOp::kCrossTrainerCacheOptions;
100 
101 namespace {
102 // Default interval between task list refreshes.
103 const int64_t kDefaultTaskRefreshIntervalMs = 1000;  // 1 second.
104 
105 constexpr char kDataServiceDatasetV1[] = "DataServiceDataset";
106 constexpr char kDataServiceDatasetV2[] = "DataServiceDatasetV2";
107 constexpr char kDataServiceDatasetV3[] = "DataServiceDatasetV3";
108 constexpr char kDataServiceDatasetV4[] = "DataServiceDatasetV4";
109 
110 constexpr const char kParallelEpochs[] = "parallel_epochs";
111 constexpr const char kDistributedEpoch[] = "distributed_epoch";
112 
113 // Same timeout used by the RegisterDatasetOp.
114 constexpr absl::Duration kGetMetadataRetryTimeout = absl::Hours(1);
115 
IsColocatedTask(const TaskInfo & task)116 bool IsColocatedTask(const TaskInfo& task) {
117   return absl::c_any_of(task.worker_tags(), [](absl::string_view worker_tag) {
118     return absl::AsciiStrToUpper(worker_tag) == kColocatedWorkerTag;
119   });
120 }
121 
GetDataServiceMetadata(const std::string & dataset_id,const tstring & address,const tstring & protocol)122 StatusOr<DataServiceMetadata> GetDataServiceMetadata(
123     const std::string& dataset_id, const tstring& address,
124     const tstring& protocol) {
125   DataServiceDispatcherClient client(address, protocol);
126   DataServiceMetadata metadata;
127   absl::Time deadline =
128       absl::FromUnixMicros(EnvTime::NowMicros()) + kGetMetadataRetryTimeout;
129 
130   Status status = grpc_util::Retry(
131       [&]() { return client.GetDataServiceMetadata(dataset_id, metadata); },
132       absl::Substitute("Get data service metadata for dataset $0, "
133                        "with dispatcher at $1.",
134                        dataset_id, std::string(address)),
135       absl::ToUnixMicros(deadline));
136   if (errors::IsNotFound(status)) {
137     return errors::NotFound(
138         "Dataset id ", dataset_id,
139         " not found. It must be registered with `register_dataset` before "
140         "calling `from_dataset_id`.");
141   }
142   TF_RETURN_IF_ERROR(status);
143   return metadata;
144 }
145 
GetValidatedCompression(const std::string & dataset_id,const DataServiceMetadata & metadata)146 StatusOr<DataServiceMetadata::Compression> GetValidatedCompression(
147     const std::string& dataset_id, const DataServiceMetadata& metadata) {
148   if (metadata.compression() == DataServiceMetadata::COMPRESSION_UNSPECIFIED) {
149     return errors::Internal(absl::Substitute(
150         "Got invalid compression $0 for dataset $1. A proper compression "
151         "should be registered in `register_dataset`.",
152         DataServiceMetadata::Compression_Name(metadata.compression()),
153         dataset_id));
154   }
155   return metadata.compression();
156 }
157 
GetDataServiceConfig(const tstring & address,const tstring & protocol)158 StatusOr<DataServiceConfig> GetDataServiceConfig(const tstring& address,
159                                                  const tstring& protocol) {
160   DataServiceDispatcherClient client(address, protocol);
161   DataServiceConfig config;
162   absl::Time deadline =
163       absl::FromUnixMicros(EnvTime::NowMicros()) + kGetMetadataRetryTimeout;
164 
165   TF_RETURN_IF_ERROR(grpc_util::Retry(
166       [&]() { return client.GetDataServiceConfig(config); },
167       absl::Substitute("Get data service config with dispatcher at $0.",
168                        std::string(address)),
169       absl::ToUnixMicros(deadline)));
170   return config;
171 }
172 }  // namespace
173 
174 // Dataset for reading data from the tf.data service non-deterministically.
175 //
176 // This dataset interleaves dataset elements produced by multiple tf.data
177 // workers. We periodically query the dispatcher to determine which workers
178 // to read from (in case workers are added or removed).
179 class DataServiceDatasetOp::Dataset : public DatasetBase {
180  public:
// Stores the op's configuration. A read is "coordinated" exactly when
// `consumer_index` is set. The resource manager of `ctx` is remembered so the
// destructor can delete the iteration counter resource when `owns_resource`
// is true; the destructor also releases one reference to `iteration_counter`.
Dataset(OpKernelContext* ctx,
        int op_version,
        const std::string& dataset_id,
        const ProcessingModeDef& processing_mode,
        const std::string& address,
        const std::string& protocol,
        const std::string& data_transfer_protocol,
        const std::string& job_name,
        std::optional<int64_t> consumer_index,
        std::optional<int64_t> num_consumers,
        int64_t max_outstanding_requests,
        int64_t task_refresh_interval_ms,
        const TargetWorkers target_workers,
        const DataServiceMetadata& metadata,
        IterationCounter* iteration_counter,
        bool owns_resource,
        ResourceHandle iteration_counter_handle,
        std::unique_ptr<CapturedFunction> captured_uncompress_func,
        const std::optional<CrossTrainerCacheOptions>&
            cross_trainer_cache_options,
        const DataTypeVector& output_types,
        const std::vector<PartialTensorShape>& output_shapes)
    : DatasetBase(DatasetContext(ctx)),
      op_version_(op_version),
      dataset_id_(dataset_id),
      processing_mode_(processing_mode),
      address_(address),
      protocol_(protocol),
      data_transfer_protocol_(data_transfer_protocol),
      job_name_(job_name),
      // Coordinated reads are identified by the presence of a consumer index.
      is_coordinated_read_(consumer_index.has_value()),
      consumer_index_(consumer_index),
      num_consumers_(num_consumers),
      max_outstanding_requests_(max_outstanding_requests),
      task_refresh_interval_ms_(task_refresh_interval_ms),
      target_workers_(target_workers),
      metadata_(metadata),
      iteration_counter_(iteration_counter),
      owns_resource_(owns_resource),
      iteration_counter_handle_(iteration_counter_handle),
      resource_mgr_(ctx->resource_manager()),
      captured_uncompress_func_(std::move(captured_uncompress_func)),
      cross_trainer_cache_options_(cross_trainer_cache_options),
      output_types_(output_types),
      output_shapes_(output_shapes) {}
220 
// Drops this dataset's reference to the iteration counter. If the dataset
// owns the counter resource, also deletes it from the resource manager;
// failure to delete is logged, not propagated.
~Dataset() override {
  iteration_counter_->Unref();
  if (!owns_resource_) {
    return;
  }
  const Status delete_status = resource_mgr_->Delete<IterationCounter>(
      iteration_counter_handle_.container(), iteration_counter_handle_.name());
  if (!delete_status.ok()) {
    LOG(WARNING) << "Failed to delete iteration counter resource: "
                 << delete_status;
  }
}
232 
MakeIteratorInternal(const string & prefix) const233   std::unique_ptr<IteratorBase> MakeIteratorInternal(
234       const string& prefix) const override {
235     return std::make_unique<Iterator>(
236         Iterator::Params{this,
237                          name_utils::IteratorPrefix(kDatasetType, prefix)},
238         DataServiceParams{dataset_id_, processing_mode_, address_, protocol_,
239                           data_transfer_protocol_, job_name_,
240                           /*repetition=*/iteration_counter_->GetAndIncrement(),
241                           num_consumers_, consumer_index_, target_workers_,
242                           metadata_, cross_trainer_cache_options_});
243   }
244 
output_dtypes() const245   const DataTypeVector& output_dtypes() const override { return output_types_; }
246 
output_shapes() const247   const std::vector<PartialTensorShape>& output_shapes() const override {
248     return output_shapes_;
249   }
250 
DebugString() const251   string DebugString() const override {
252     return name_utils::DatasetDebugString(kDatasetType);
253   }
254 
CardinalityInternal() const255   int64_t CardinalityInternal() const override {
256     if (is_coordinated_read_) {
257       // Coordinated reads require the dataset to be infinite.
258       return kInfiniteCardinality;
259     }
260 
261     if (metadata_.cardinality() == 0) {
262       return 0;
263     }
264 
265     if (metadata_.cardinality() == kInfiniteCardinality) {
266       // Sharding may cause an infinite dataset to be empty. For example, in
267       // `range(10).batch(10, drop_remainder=True).repeat()`, inserting `shard`
268       // before `batch` will cause the dataset to be empty.
269       // This case is rare, and there is significant performance hit for dynamic
270       // sharding if it reports unknown cardinality, so it is reasonable to
271       // report infinite cardinality. For DATA sharding, it is ok to report
272       // infinite cardinality since it inserts `shard` after `repeat`.
273       if (processing_mode_.sharding_policy() == ProcessingModeDef::OFF ||
274           processing_mode_.sharding_policy() == ProcessingModeDef::DYNAMIC ||
275           processing_mode_.sharding_policy() == ProcessingModeDef::DATA) {
276         return kInfiniteCardinality;
277       }
278     }
279     return kUnknownCardinality;
280   }
281 
CheckExternalState() const282   Status CheckExternalState() const override {
283     return Status(
284         error::FAILED_PRECONDITION,
285         strings::StrCat(DebugString(), " does not yet support serialization."));
286   }
287 
InputDatasets(std::vector<const DatasetBase * > * inputs) const288   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
289     inputs->clear();
290     return OkStatus();
291   }
292 
293  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const294   Status AsGraphDefInternal(SerializationContext* ctx,
295                             DatasetGraphDefBuilder* b,
296                             Node** output) const override {
297     // Inputs
298     std::vector<Node*> inputs;
299 
300     if (op_version_ >= 4) {
301       Node* dataset_id;
302       TF_RETURN_IF_ERROR(b->AddScalar(dataset_id_, &dataset_id));
303       inputs.push_back(dataset_id);
304     } else {
305       int64_t dataset_id_int;
306       if (!absl::SimpleAtoi(dataset_id_, &dataset_id_int)) {
307         return errors::Internal("Failed to parse dataset ID: ", dataset_id_,
308                                 ". Expect integers.");
309       }
310       Node* dataset_id;
311       TF_RETURN_IF_ERROR(b->AddScalar(dataset_id_int, &dataset_id));
312       inputs.push_back(dataset_id);
313     }
314 
315     Node* processing_mode;
316     tstring processing_mode_str = processing_mode_.SerializeAsString();
317     TF_RETURN_IF_ERROR(b->AddScalar(processing_mode_str, &processing_mode));
318     inputs.push_back(processing_mode);
319 
320     Node* address;
321     TF_RETURN_IF_ERROR(b->AddScalar(address_, &address));
322     inputs.push_back(address);
323 
324     Node* protocol;
325     TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol));
326     inputs.push_back(protocol);
327 
328     Node* job_name;
329     TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name));
330     inputs.push_back(job_name);
331 
332     if (op_version_ >= 2) {
333       Node* consumer_index;
334       TF_RETURN_IF_ERROR(
335           b->AddScalar(consumer_index_.value_or(-1), &consumer_index));
336       inputs.push_back(consumer_index);
337 
338       Node* num_consumers;
339       TF_RETURN_IF_ERROR(
340           b->AddScalar(num_consumers_.value_or(-1), &num_consumers));
341       inputs.push_back(num_consumers);
342     }
343 
344     Node* max_outstanding_requests;
345     TF_RETURN_IF_ERROR(
346         b->AddScalar(max_outstanding_requests_, &max_outstanding_requests));
347     inputs.push_back(max_outstanding_requests);
348 
349     Node* iteration_counter_handle = nullptr;
350     Tensor handle(DT_RESOURCE, TensorShape({}));
351     handle.scalar<ResourceHandle>()() = iteration_counter_handle_;
352     TF_RETURN_IF_ERROR(b->AddTensor(handle, &iteration_counter_handle));
353     inputs.push_back(iteration_counter_handle);
354 
355     // Attributes
356     std::vector<std::pair<StringPiece, AttrValue>> attrs;
357     AttrValue task_refresh_interval_hint_ms;
358     b->BuildAttrValue(task_refresh_interval_ms_,
359                       &task_refresh_interval_hint_ms);
360     attrs.push_back(
361         {kTaskRefreshIntervalHintMs, task_refresh_interval_hint_ms});
362 
363     AttrValue data_transfer_protocol;
364     b->BuildAttrValue(data_transfer_protocol_, &data_transfer_protocol);
365     attrs.push_back({kDataTransferProtocol, data_transfer_protocol});
366 
367     AttrValue target_workers;
368     b->BuildAttrValue(TargetWorkersToString(target_workers_), &target_workers);
369     attrs.push_back({kTargetWorkers, target_workers});
370 
371     if (op_version_ >= 3) {
372       // Attr: uncompress is true for the first time the graph is built, when a
373       // ParallelMap dataset is inserted for uncompression. Subsequent
374       // serialization always sets it to false to avoid inserting repeated map
375       // datasets for uncompression.
376       AttrValue uncompress_attr;
377       b->BuildAttrValue(false, &uncompress_attr);
378       attrs.push_back({kUncompress, uncompress_attr});
379 
380       // Attr: uncompress_fn
381       AttrValue uncompress_fn_attr;
382       b->BuildAttrValue(captured_uncompress_func_->func(), &uncompress_fn_attr);
383       attrs.push_back({kUncompressFn, uncompress_fn_attr});
384 
385       std::vector<Node*> uncompress_arguments;
386       DataTypeVector uncompress_arguments_types;
387       TF_RETURN_IF_ERROR(captured_uncompress_func_->AddToGraph(
388           ctx, b, &uncompress_arguments, &uncompress_arguments_types));
389     }
390 
391     // Attr: cross_trainer_cache_options
392     AttrValue cross_trainer_cache_options_attr;
393     std::string serialized_cross_trainer_cache_options;
394     if (cross_trainer_cache_options_.has_value()) {
395       serialized_cross_trainer_cache_options =
396           cross_trainer_cache_options_->SerializeAsString();
397     }
398     b->BuildAttrValue(serialized_cross_trainer_cache_options,
399                       &cross_trainer_cache_options_attr);
400     attrs.push_back(
401         {kCrossTrainerCacheOptions, cross_trainer_cache_options_attr});
402     return b->AddDataset(this, inputs, attrs, output);
403   }
404 
405  private:
406   class Iterator : public DatasetIterator<Dataset> {
407    public:
Iterator(const Params & params,const DataServiceParams & data_service_params)408     explicit Iterator(const Params& params,
409                       const DataServiceParams& data_service_params)
410         : DatasetIterator<Dataset>(params),
411           data_service_params_(data_service_params),
412           max_outstanding_requests_(params.dataset->max_outstanding_requests_) {
413     }
414 
// Tears the iterator down in a fixed order:
//  1. cancel background work;
//  2. deregister the cancellation callback;
//  3. drop the task thread manager;
//  4. release the iteration client, if Initialize completed;
//  5. drop the worker threads;
//  6. delete tasks owned only by this client from local workers.
// NOTE(review): resetting the thread handles presumably joins the threads —
// confirm against the Thread implementation.
~Iterator() override {
  VLOG(1) << "Destroying data service dataset iterator for iteration id "
          << iteration_client_id_;
  CancelThreads();
  if (deregister_fn_) {
    deregister_fn_();
  }
  task_thread_manager_.reset();

  if (initialized_) {
    const Status release_status =
        dispatcher_->ReleaseIterationClient(iteration_client_id_);
    if (!release_status.ok()) {
      LOG(WARNING) << "Failed to release iteration client id: "
                   << release_status;
    }
  }

  for (auto& thread : worker_threads_) {
    thread.reset();
  }
  DeleteLocalWorkerTasks();
  VLOG(1) << "Destroyed data service dataset iterator for iteration id "
          << iteration_client_id_;
}
434 
Initialize(IteratorContext * ctx)435     Status Initialize(IteratorContext* ctx) override {
436       TF_RETURN_IF_ERROR(ValidateDataServiceParams(data_service_params_));
437       VLOG(3) << "Connecting to " << dataset()->address_
438               << " in data service dataset op";
439       TF_RETURN_IF_ERROR(RegisterCancellationCallback(
440           ctx->cancellation_manager(), [this]() { CancelThreads(); },
441           &deregister_fn_));
442       dispatcher_ = std::make_unique<DataServiceDispatcherClient>(
443           dataset()->address_, dataset()->protocol_);
444       int64_t deadline_micros = kint64max;
445       std::optional<std::string> job_name;
446       if (!dataset()->job_name_.empty()) {
447         job_name = dataset()->job_name_;
448       }
449       TF_RETURN_IF_ERROR(grpc_util::Retry(
450           [&]() {
451             return dispatcher_->GetOrCreateJob(
452                 dataset()->dataset_id_, dataset()->processing_mode_, job_name,
453                 dataset()->num_consumers_,
454                 dataset()->cross_trainer_cache_options_.has_value(),
455                 dataset()->target_workers_, job_id_);
456           },
457           /*description=*/
458           strings::StrCat("get or create job with dispatcher at ",
459                           dataset()->address_),
460           deadline_micros));
461       TF_RETURN_IF_ERROR(grpc_util::Retry(
462           [&]() {
463             return dispatcher_->GetOrCreateIteration(
464                 job_id_, data_service_params_.repetition, iteration_client_id_);
465           },
466           /*description=*/
467           strings::StrCat("get or create iteration with dispatcher at ",
468                           dataset()->address_),
469           deadline_micros));
470       initialized_ = true;
471       return OkStatus();
472     }
473 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)474     Status GetNextInternal(IteratorContext* ctx,
475                            std::vector<Tensor>* out_tensors,
476                            bool* end_of_sequence) override {
477       VLOG(3) << "Calling GetNext in data service dataset's iterator.";
478       mutex_lock l(mu_);
479       EnsureThreadsStarted(ctx);
480       std::shared_ptr<Result> result;
481       do {
482         while (!ResultReady() && !Finished() && !cancelled_ && status_.ok()) {
483           VLOG(3) << "Blocking in GetNext: " << DebugString();
484           get_next_cv_.wait(l);
485         }
486         if (cancelled_) {
487           VLOG(3) << "Returning from GetNext due to cancellation";
488           return errors::Cancelled("Data service iterator was cancelled");
489         }
490         if (!status_.ok()) {
491           VLOG(3) << "Returning from GetNext with error " << status_;
492           return status_;
493         }
494         if (results_.empty()) {
495           *end_of_sequence = true;
496           VLOG(3) << "Returning from GetNext with end_of_sequence";
497           return OkStatus();
498         }
499         if (!ResultReady()) {
500           return errors::Internal(
501               "Expected a result to be ready, but none were.");
502         }
503         result = PopNextResult();
504         worker_thread_cv_.notify_one();
505       } while (result->skip);
506 
507       *end_of_sequence = result->end_of_sequence;
508       if (!*end_of_sequence) {
509         VLOG(1) << "Returning the next element from data service dataset's "
510                 << "Iterator: task " << result->task_id << ", element "
511                 << result->element_index;
512         if (StrictRoundRobin()) {
513           VLOG(1) << "Consumer " << dataset()->consumer_index_.value()
514                   << ": Result " << get_next_index_++;
515         }
516         out_tensors->swap(result->element);
517       }
518       return OkStatus();
519     }
520 
521    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const522     std::shared_ptr<model::Node> CreateNode(
523         IteratorContext* ctx, model::Node::Args args) const override {
524       return model::MakeKnownRatioNode(std::move(args),
525                                        /*ratio=*/1);
526     }
527 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)528     Status SaveInternal(SerializationContext* ctx,
529                         IteratorStateWriter* writer) override {
530       return errors::Unimplemented("SaveInternal is not yet supported");
531     }
532 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)533     Status RestoreInternal(IteratorContext* ctx,
534                            IteratorStateReader* reader) override {
535       return errors::Unimplemented("RestoreInternal is not yet supported");
536     }
537 
GetTraceMeMetadata() const538     data::TraceMeMetadata GetTraceMeMetadata() const override {
539       data::TraceMeMetadata result;
540       int64_t num_tasks = -1;
541       if (mu_.try_lock()) {
542         num_tasks = tasks_.size() - finished_tasks_;
543         mu_.unlock();
544       }
545       result.push_back(std::make_pair(
546           "num_tasks",
547           num_tasks == -1
548               ? kTraceInfoUnavailable
549               : strings::Printf("%lld", static_cast<long long>(num_tasks))));
550       result.push_back(std::make_pair("job_name", dataset()->job_name_));
551       result.push_back(std::make_pair(
552           "max_outstanding_requests",
553           strings::Printf("%lld", static_cast<long long>(
554                                       dataset()->max_outstanding_requests_))));
555       return result;
556     }
557 
558    private:
559     struct Task {
Tasktensorflow::data::DataServiceDatasetOp::Dataset::Iterator::Task560       Task(const TaskInfo& info,
561            std::unique_ptr<DataServiceWorkerClient> worker)
562           : info(info), worker(std::move(worker)) {}
563 
564       const TaskInfo info;
565       // Client for fetching task elements from the tf.data service worker.
566       const std::unique_ptr<DataServiceWorkerClient> worker;
567       // The next round to read from the task.
568       int64_t round = 0;
569       // Whether the task has been removed. The task will eventually be
570       // deleted from `tasks_` on the next dispatcher heartbeat.
571       bool removed = false;
572       bool skipped_previous_round = false;
573       // Indicates whether a worker thread is currently processing the task.
574       bool in_use TF_GUARDED_BY(&Iterator::mu_) = false;
575       // Indicates whether the worker has returned end_of_sequence for the task.
576       bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
577     };
578 
579     struct Result {
580       Result() = default;
581       Result(Result&&) = default;
582       Result& operator=(Result&&) = default;
583       Result(const Result&) = delete;
584       Result& operator=(const Result&) = delete;
585 
586       // Whether the result has been computed yet. GetNext needs to block
587       // until the next result is ready.
588       bool ready TF_GUARDED_BY(&Iterator::mu_) = false;
589       std::vector<Tensor> element TF_GUARDED_BY(&Iterator::mu_);
590       // The element's index within the tf.data worker it came from. Used for
591       // debugging.
592       int64_t element_index TF_GUARDED_BY(&Iterator::mu_) = -1;
593       // The id of the task that generated the result.
594       int64_t task_id TF_GUARDED_BY(&Iterator::mu_) = -1;
595       bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
596       bool skip TF_GUARDED_BY(&Iterator::mu_) = false;
597     };
598 
599     // Returns whether the iterator has finished and should return.
Finished() const600     bool Finished() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
601       return num_running_worker_threads_ == 0 && !ShouldWaitForNext();
602     }
603 
604     // Returns whether the iterator has more data.
ShouldWaitForNext() const605     bool ShouldWaitForNext() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
606       if (should_finish_iteration_) {
607         return !iteration_finished_;
608       }
609       return tasks_.empty() || finished_tasks_ < tasks_.size();
610     }
611 
EnsureThreadsStarted(IteratorContext * ctx)612     void EnsureThreadsStarted(IteratorContext* ctx)
613         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
614       if (!task_thread_manager_ && !cancelled_) {
615         auto new_ctx = std::make_shared<IteratorContext>(*ctx);
616         task_thread_manager_ =
617             ctx->StartThread("task-thread-manager",
618                              [this, new_ctx]() { TaskThreadManager(new_ctx); });
619       }
620     }
621 
CancelThreads()622     void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
623       mutex_lock l(mu_);
624       for (const auto& task : tasks_) {
625         task->worker->TryCancel();
626       }
627       cancelled_ = true;
628       worker_thread_cv_.notify_all();
629       manager_thread_cv_.notify_all();
630       get_next_cv_.notify_all();
631     }
632 
DeleteLocalWorkerTasks()633     void DeleteLocalWorkerTasks() {
634       std::vector<std::shared_ptr<Task>> tasks;
635       {
636         mutex_lock l(mu_);
637         tasks = tasks_;
638       }
639 
640       for (const std::shared_ptr<Task>& task : tasks) {
641         std::shared_ptr<DataServiceWorkerImpl> worker =
642             LocalWorkers::Get(task->info.worker_address());
643         if (worker && ShouldDeleteLocalTask(task->info)) {
644           worker->DeleteLocalTask(task->info);
645         }
646       }
647     }
648 
649     // Deletes the task if it is only read by the local client.
ShouldDeleteLocalTask(const TaskInfo & task) const650     bool ShouldDeleteLocalTask(const TaskInfo& task) const {
651       if (StrictRoundRobin()) {
652         return false;
653       }
654 
655       if (dataset()->target_workers_ == TARGET_WORKERS_LOCAL) {
656         return true;
657       }
658 
659       return dataset()->target_workers_ == TARGET_WORKERS_AUTO &&
660              IsColocatedTask(task);
661     }
662 
663     // Periodically refresh the task list.
664     // Maintain one thread fetching elements for each task.
665     // TODO(aaudibert): Instead of polling, have dispatcher send updates when
666     // the list of tasks changes.
TaskThreadManager(std::shared_ptr<IteratorContext> ctx)667     void TaskThreadManager(std::shared_ptr<IteratorContext> ctx) {
668       auto cleanup =
669           gtl::MakeCleanup([] { VLOG(1) << "Task thread manager exiting"; });
670       VLOG(1) << "Starting task thread manager";
671       uint64 next_check = Env::Default()->NowMicros();
672       while (true) {
673         {
674           mutex_lock l(mu_);
675           // All units are microseconds.
676           while (!cancelled_ && Env::Default()->NowMicros() < next_check) {
677             int64_t remaining_time = next_check - Env::Default()->NowMicros();
678             VLOG(4) << "Task thread manager waiting for " << remaining_time
679                     << "us";
680             manager_thread_cv_.wait_for(
681                 l, std::chrono::microseconds(remaining_time));
682           }
683           if (cancelled_) {
684             VLOG(3) << "Task thread manager finished";
685             return;
686           }
687         }
688         Heartbeat();
689         UpdateBufferSize();
690         UpdateWorkerThreads(ctx.get());
691         next_check = Env::Default()->NowMicros() +
692                      dataset()->task_refresh_interval_ms_ * 1000;
693       }
694     }
695 
TryBlockRound(int64_t round)696     void TryBlockRound(int64_t round) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
697       if (round_robin_round_limit_.has_value() &&
698           round_robin_round_limit_.value() == round) {
699         return;
700       }
701       if (current_round_ >= round) {
702         // In the next heartbeat, notify the dispatcher that we failed to add
703         // the task.
704         VLOG(1) << "Rejecting request to block round " << round
705                 << ", because processing has already begun for round "
706                 << current_round_;
707         return;
708       }
709       VLOG(1) << "Accepting request to block round " << round;
710       round_robin_round_limit_ = round;
711     }
712 
UpdateIterationFinished(bool iteration_finished)713     void UpdateIterationFinished(bool iteration_finished)
714         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
715       if (!iteration_finished) {
716         return;
717       }
718       iteration_finished_ = true;
719       get_next_cv_.notify_all();
720       worker_thread_cv_.notify_all();
721     }
722 
AddTask(const TaskInfo & task_info)723     Status AddTask(const TaskInfo& task_info) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
724       TF_ASSIGN_OR_RETURN(
725           std::unique_ptr<DataServiceWorkerClient> worker,
726           CreateDataServiceWorkerClient(task_info.transfer_address(),
727                                         dataset()->protocol_,
728                                         dataset()->data_transfer_protocol_));
729       tasks_.push_back(std::make_shared<Task>(task_info, std::move(worker)));
730       worker_thread_cv_.notify_one();
731       if (StrictRoundRobin()) {
732         VLOG(1) << "Consumer " << dataset()->consumer_index_.value()
733                 << " adding task " << task_info.task_id()
734                 << " to read from worker " << task_info.worker_address()
735                 << ". Task starting round: " << task_info.starting_round();
736         DCHECK_LE(current_round_, task_info.starting_round());
737         if (current_round_ == task_info.starting_round()) {
738           DCHECK_EQ(next_task_index_, 0);
739         }
740       }
741       return OkStatus();
742     }
743 
Heartbeat()744     void Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
745       ClientHeartbeatRequest req;
746       req.set_iteration_client_id(iteration_client_id_);
747       if (StrictRoundRobin()) {
748         mutex_lock l(mu_);
749         req.set_current_round(current_round_);
750         if (round_robin_round_limit_.has_value()) {
751           req.set_blocked_round(round_robin_round_limit_.value());
752         }
753       }
754       ClientHeartbeatResponse resp;
755       Status s = dispatcher_->ClientHeartbeat(req, resp);
756       if (!s.ok()) {
757         if (IsPreemptedError(s)) {
758           LOG(WARNING)
759               << "Failed to heartbeat to dispatcher from iteration client id "
760               << iteration_client_id_
761               << ". Dispatcher address: " << dataset()->address_
762               << ". Error: " << s;
763           return;
764         }
765         mutex_lock l(mu_);
766         status_ = s;
767         get_next_cv_.notify_all();
768       }
769       mutex_lock l(mu_);
770       UpdateIterationFinished(resp.iteration_finished());
771       if (resp.optional_block_round_case() ==
772           ClientHeartbeatResponse::kBlockRound) {
773         TryBlockRound(resp.block_round());
774       } else {
775         round_robin_round_limit_ = std::nullopt;
776         worker_thread_cv_.notify_all();
777       }
778       UpdateTasks(resp);
779       RecordTFMetrics(resp);
780     }
781 
UpdateTasks(const ClientHeartbeatResponse & resp)782     void UpdateTasks(const ClientHeartbeatResponse& resp)
783         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
784       absl::flat_hash_map<int64_t, TaskInfo> task_id_to_task;
785       for (auto& task : resp.task_info()) {
786         task_id_to_task[task.task_id()] = task;
787       }
788       if (iteration_finished_) {
789         return;
790       }
791 
792       int index = 0;
793       while (index < tasks_.size()) {
794         std::shared_ptr<Task> task = tasks_[index];
795         if (task_id_to_task.contains(task->info.task_id())) {
796           // Remove already-known tasks from `task_id_to_task`, so that at the
797           // end of the loop, only new tasks remain.
798           task_id_to_task.erase(task->info.task_id());
799           ++index;
800         } else {
801           // Task has been removed.
802           if (task->end_of_sequence) {
803             finished_tasks_--;
804           }
805           tasks_.erase(tasks_.begin() + index);
806           if (index < next_task_index_) {
807             next_task_index_--;
808           }
809           if (!tasks_.empty() && next_task_index_ >= tasks_.size()) {
810             AdvanceTaskIndex();
811           }
812         }
813       }
814       for (auto& task : resp.task_info()) {
815         auto it = task_id_to_task.find(task.task_id());
816         if (it == task_id_to_task.end()) {
817           continue;
818         }
819         if (!ShouldReadFromTask(task)) {
820           VLOG(3) << "Skipping untargeted worker task " << task.task_id();
821           should_finish_iteration_ = false;
822           continue;
823         }
824         Status s = AddTask(it->second);
825         if (!s.ok()) {
826           status_ = s;
827           get_next_cv_.notify_all();
828           break;
829         }
830       }
831     }
832 
ShouldReadFromTask(const TaskInfo & task) const833     bool ShouldReadFromTask(const TaskInfo& task) const
834         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
835       if (StrictRoundRobin()) {
836         return true;
837       }
838 
839       const bool is_local_task =
840           (LocalWorkers::Get(task.worker_address()) != nullptr);
841       if (dataset()->target_workers_ == TARGET_WORKERS_LOCAL &&
842           !is_local_task) {
843         return false;
844       }
845 
846       // Cross-TF/TPU host reads may cause resource contention on the TF/TPU
847       // hosts. tf.data service avoids reading from non-local TF-hosted workers.
848       const bool is_cross_tf_host_read =
849           !is_local_task && IsColocatedTask(task);
850       if (dataset()->target_workers_ == TARGET_WORKERS_AUTO &&
851           is_cross_tf_host_read) {
852         return false;
853       }
854       return true;
855     }
856 
RecordTFMetrics(const ClientHeartbeatResponse & resp)857     void RecordTFMetrics(const ClientHeartbeatResponse& resp)
858         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
859       for (const auto& task : resp.task_info()) {
860         if (worker_uids_.contains(task.worker_uid())) {
861           continue;
862         }
863         metrics::RecordTFDataServiceClientIterators(
864             task.worker_uid(), resp.deployment_mode(),
865             dataset()->processing_mode_, dataset()->is_coordinated_read_);
866         worker_uids_.insert(task.worker_uid());
867       }
868     }
869 
UpdateBufferSize()870     void UpdateBufferSize() TF_LOCKS_EXCLUDED(mu_) {
871       if (dataset()->max_outstanding_requests_ == model::kAutotune) {
872         // Adjust `max_outstanding_requests_` to account for newly added tasks.
873         // `tasks_` includes the local tasks, so we subtract one from the
874         // configured local task buffer size.
875         mutex_lock l(mu_);
876         int64_t max_outstanding_requests = tasks_.size();
877         if (max_outstanding_requests > max_outstanding_requests_) {
878           worker_thread_cv_.notify_all();
879         }
880         max_outstanding_requests_ = max_outstanding_requests;
881       }
882     }
883 
UpdateWorkerThreads(IteratorContext * ctx)884     void UpdateWorkerThreads(IteratorContext* ctx) TF_LOCKS_EXCLUDED(mu_) {
885       mutex_lock l(mu_);
886       const int64_t max_num_threads =
887           std::min<int64_t>(tasks_.size(), max_outstanding_requests_);
888       while (num_running_worker_threads_ < max_num_threads && !cancelled_ &&
889              status_.ok()) {
890         num_running_worker_threads_++;
891         auto done = [this]() {
892           mutex_lock l(mu_);
893           num_running_worker_threads_--;
894           get_next_cv_.notify_all();
895         };
896         worker_threads_.push_back(ctx->StartThread(
897             "tf-data-service-task_thread", [this, done = std::move(done)]() {
898               RunWorkerThread(std::move(done));
899             }));
900       }
901     }
902 
RunWorkerThread(std::function<void ()> done)903     void RunWorkerThread(std::function<void()> done) {
904       auto cleanup = gtl::MakeCleanup([done = std::move(done)]() {
905         done();
906         VLOG(1) << "Worker thread exiting";
907       });
908       VLOG(1) << "Starting worker thread";
909       std::shared_ptr<Task> task_to_process;
910       while (true) {
911         std::shared_ptr<Result> result;
912         {
913           mutex_lock l(mu_);
914           if (task_to_process) {
915             task_to_process->in_use = false;
916             --outstanding_requests_;
917             task_to_process = nullptr;
918             worker_thread_cv_.notify_one();
919           }
920           while (true) {
921             if (cancelled_ || !ShouldWaitForNext()) {
922               return;
923             }
924             task_to_process = GetTaskToProcess();
925             if (task_to_process) {
926               VLOG(3) << "Selected a task to process: "
927                       << task_to_process->info.ShortDebugString();
928               break;
929             }
930             worker_thread_cv_.wait(l);
931           }
932           DCHECK(task_to_process != nullptr);
933           task_to_process->in_use = true;
934           ++outstanding_requests_;
935           if (StrictRoundRobin()) {
936             // Reserve a spot in the results_ queue.
937             results_.push(std::make_shared<Result>());
938             result = results_.back();
939           } else {
940             // The result will be added to results_ when it's ready.
941             result = std::make_shared<Result>();
942           }
943           VLOG(3) << "Processing task " << task_to_process->info.task_id();
944         }
945         int64_t deadline_micros = kint64max;
946         Status s =
947             GetElementTraced(task_to_process.get(), deadline_micros,
948                              /*enqueue_result=*/!StrictRoundRobin(), result);
949         if (!s.ok()) {
950           mutex_lock l(mu_);
951           VLOG(1) << "Failed to get element from worker "
952                   << task_to_process->info.worker_address() << ": " << s;
953           task_to_process->in_use = false;
954           --outstanding_requests_;
955           status_ = errors::CreateWithUpdatedMessage(
956               s, absl::StrCat("Failed to get element from worker ",
957                               task_to_process->info.worker_address(), ": ",
958                               s.error_message()));
959           get_next_cv_.notify_all();
960           return;
961         }
962       }
963     }
964 
965     // Reports whether we can request another element without violating
966     // `max_outstanding_requests_`.
ShouldProcessTask()967     bool ShouldProcessTask() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
968       // When doing round-robin reads, outstanding requests pre-allocate a
969       // result in `results_`, so we only need to check the size of `results_`.
970       if (StrictRoundRobin()) {
971         return results_.size() < max_outstanding_requests_;
972       }
973       // Otherwise, results aren't added to `results_` until the data has been
974       // successfully retrieved. We need to count requests already added to
975       // `results_` as well as in-progress requests.
976       return results_.size() + outstanding_requests_ <
977              max_outstanding_requests_;
978     }
979 
980     // Searches for a task to process, visiting tasks in-order and giving every
981     // task a chance to proceed.
GetTaskToProcess()982     std::shared_ptr<Task> GetTaskToProcess() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
983       if (!ShouldProcessTask()) {
984         return nullptr;
985       }
986 
987       for (int i = 0; i < tasks_.size(); ++i) {
988         std::shared_ptr<Task>& task = tasks_[next_task_index_];
989         if (StrictRoundRobin() &&
990             (task->in_use ||
991              current_round_ >= round_robin_round_limit_.value_or(
992                                    std::numeric_limits<int64_t>::max()))) {
993           VLOG(4) << "No round robin task found. in_use: " << task->in_use
994                   << ". current_round: " << current_round_
995                   << ". round_robin_round_limit: "
996                   << round_robin_round_limit_.value_or(-1);
997           return nullptr;
998         }
999         if (current_round_ < task->info.starting_round() || task->in_use ||
1000             task->end_of_sequence || task->removed) {
1001           VLOG(3) << "Skipping task " << next_task_index_
1002                   << ". starting round: " << task->info.starting_round()
1003                   << ". current round: " << current_round_
1004                   << ". task->in_use: " << task->in_use
1005                   << ". end_of_sequence: " << task->end_of_sequence
1006                   << ". task->removed: " << task->removed;
1007           AdvanceTaskIndex();
1008           continue;
1009         }
1010         task->round = current_round_;
1011         AdvanceTaskIndex();
1012         return task;
1013       }
1014       return nullptr;
1015     }
1016 
1017     // Increments the next task index, starting over if all tasks have been
1018     // processed.
AdvanceTaskIndex()1019     void AdvanceTaskIndex() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1020       next_task_index_++;
1021       if (next_task_index_ >= tasks_.size()) {
1022         current_round_++;
1023         next_task_index_ = 0;
1024       }
1025     }
1026 
TryGetElement(const Task & task,GetElementResult & result)1027     Status TryGetElement(const Task& task, GetElementResult& result) {
1028       GetElementRequest req;
1029       req.set_task_id(task.info.task_id());
1030       req.set_skipped_previous_round(task.skipped_previous_round);
1031       if (StrictRoundRobin()) {
1032         req.set_consumer_index(dataset()->consumer_index_.value());
1033         req.set_round_index(task.round);
1034         req.set_allow_skip(true);
1035       }
1036       if (dataset()->cross_trainer_cache_options_) {
1037         req.set_trainer_id(
1038             dataset()->cross_trainer_cache_options_->trainer_id());
1039       }
1040       return task.worker->GetElement(req, result);
1041     }
1042 
ProcessGetElementResponse(bool enqueue_result,GetElementResult & get_element_result,std::shared_ptr<Result> result,Task & task)1043     void ProcessGetElementResponse(bool enqueue_result,
1044                                    GetElementResult& get_element_result,
1045                                    std::shared_ptr<Result> result, Task& task) {
1046       mutex_lock l(mu_);
1047       result->ready = true;
1048       result->end_of_sequence = get_element_result.end_of_sequence;
1049       result->skip = get_element_result.skip;
1050       if (!get_element_result.end_of_sequence && !get_element_result.skip) {
1051         task.skipped_previous_round = false;
1052         result->element = std::move(get_element_result.components);
1053         result->element_index = get_element_result.element_index;
1054         result->task_id = task.info.task_id();
1055       } else if (get_element_result.skip) {
1056         task.skipped_previous_round = true;
1057       } else {
1058         task.end_of_sequence = true;
1059         finished_tasks_++;
1060       }
1061       if (enqueue_result && !result->end_of_sequence) {
1062         results_.push(std::move(result));
1063       }
1064       get_next_cv_.notify_all();
1065     }
1066 
GetElementTraced(Task * task,int64_t deadline_micros,bool enqueue_result,std::shared_ptr<Result> result)1067     Status GetElementTraced(Task* task, int64_t deadline_micros,
1068                             bool enqueue_result, std::shared_ptr<Result> result)
1069         TF_LOCKS_EXCLUDED(mu_) {
1070       VLOG(3) << "Getting an element for task id " << task->info.task_id();
1071       tensorflow::profiler::TraceMe activity(
1072           "GetDataServiceElement", tensorflow::profiler::TraceMeLevel::kInfo);
1073       activity.AppendMetadata([&]() {
1074         return profiler::TraceMeEncode(
1075             {{"address", task->info.worker_address()}});
1076       });
1077       if (StrictRoundRobin()) {
1078         VLOG(3) << "Requesting element from consumer index "
1079                 << dataset()->consumer_index_.value() << ", round "
1080                 << task->round;
1081         activity.AppendMetadata([&]() {
1082           return profiler::TraceMeEncode(
1083               {{"consumer_index", dataset()->consumer_index_.value()},
1084                {"round_index", task->round}});
1085         });
1086       }
1087       Status s = GetElement(task, deadline_micros, enqueue_result, result);
1088       mutex_lock l(mu_);
1089       VLOG(3) << "Got an element for task id " << task->info.task_id();
1090       return s;
1091     }
1092 
MaybeRemoveTask(Task & task,int64_t deadline_micros,Result & result)1093     Status MaybeRemoveTask(Task& task, int64_t deadline_micros,
1094                            Result& result) {
1095       bool removed;
1096       VLOG(1) << "Requesting task removal for worker "
1097               << task.info.worker_address() << " in round " << task.round;
1098       TF_RETURN_IF_ERROR(grpc_util::Retry(
1099           [&] {
1100             return dispatcher_->MaybeRemoveTask(
1101                 task.info.task_id(), dataset()->consumer_index_.value(),
1102                 task.round, removed);
1103           },
1104           /*should_retry=*/
1105           [&] {
1106             mutex_lock l(mu_);
1107             return !cancelled_;
1108           },
1109           /*description=*/"request task removal ", deadline_micros));
1110       if (removed) {
1111         mutex_lock l(mu_);
1112         task.removed = true;
1113         result.ready = true;
1114         result.skip = true;
1115         get_next_cv_.notify_all();
1116         return OkStatus();
1117       }
1118       VLOG(1) << "Failed to remove task for worker "
1119               << task.info.worker_address();
1120       return OkStatus();
1121     }
1122 
GetElement(Task * task,int64_t deadline_micros,bool enqueue_result,std::shared_ptr<Result> result)1123     Status GetElement(Task* task, int64_t deadline_micros, bool enqueue_result,
1124                       std::shared_ptr<Result> result) TF_LOCKS_EXCLUDED(mu_) {
1125       GetElementResult get_element_result;
1126       for (int num_retries = 0;; ++num_retries) {
1127         Status s = TryGetElement(*task, get_element_result);
1128         if (s.ok()) break;
1129         // Retry all errors that could indicate preemption.
1130         if (!IsPreemptedError(s)) {
1131           return s;
1132         }
1133         {
1134           mutex_lock l(mu_);
1135           if (cancelled_) {
1136             return errors::Cancelled("DataServiceDataset iterator cancelled");
1137           }
1138         }
1139         int64_t now_micros = Env::Default()->NowMicros();
1140         if (now_micros > deadline_micros) {
1141           return s;
1142         }
1143         if (StrictRoundRobin() && num_retries > 0) {
1144           TF_RETURN_IF_ERROR(MaybeRemoveTask(*task, deadline_micros, *result));
1145           mutex_lock l(mu_);
1146           if (result->skip) {
1147             return OkStatus();
1148           }
1149         }
1150         int64_t backoff_until = std::min(
1151             deadline_micros,
1152             now_micros + ::tensorflow::ComputeBackoffMicroseconds(num_retries));
1153         VLOG(1) << "Failed to get an element from worker "
1154                 << task->info.worker_address() << ": " << s
1155                 << ". Will retry in " << (backoff_until - now_micros)
1156                 << " microseconds";
1157         Env::Default()->SleepForMicroseconds(backoff_until - now_micros);
1158       }
1159       ProcessGetElementResponse(enqueue_result, get_element_result, result,
1160                                 *task);
1161       return OkStatus();
1162     }
1163 
ResultReady() const1164     bool ResultReady() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1165       return !results_.empty() && results_.front()->ready;
1166     }
1167 
PopNextResult()1168     std::shared_ptr<Result> PopNextResult() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1169       std::shared_ptr<Result> result = results_.front();
1170       results_.pop();
1171       return result;
1172     }
1173 
StrictRoundRobin() const1174     bool StrictRoundRobin() const {
1175       return dataset()->num_consumers_.has_value();
1176     }
1177 
DebugString() const1178     std::string DebugString() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1179       return absl::Substitute(
1180           "results_ { size: $0 front.ready: $1 } iteration_finished_: $2 "
1181           "tasks { size: $3 } finished_tasks_: $4 "
1182           "num_running_worker_threads_: $5",
1183           results_.size(), !results_.empty() && results_.front()->ready,
1184           iteration_finished_, tasks_.size(), finished_tasks_,
1185           num_running_worker_threads_);
1186     }
1187 
// Parameters this iterator was created with. NOTE(review): populated outside
// this view; confirm where it is read.
const DataServiceParams data_service_params_;

// Guards the mutable iterator state below.
mutable mutex mu_;
// Notified whenever state a GetNext caller may be waiting on changes: a result
// becomes ready, `status_` is set, the iteration finishes, a worker thread
// exits, a task is removed, or the iterator is cancelled. Presumably waited on
// by GetNext (not visible here).
condition_variable get_next_cv_ TF_GUARDED_BY(mu_);
// Waited on by worker threads (RunWorkerThread) when no task can be claimed.
condition_variable worker_thread_cv_ TF_GUARDED_BY(mu_);
// Waited on by TaskThreadManager between task refreshes; notified on cancel.
condition_variable manager_thread_cv_ TF_GUARDED_BY(mu_);
// Set by CancelThreads(); background threads return once they observe it.
bool cancelled_ TF_GUARDED_BY(mu_) = false;
// Method for deregistering the cancellation callback.
std::function<void()> deregister_fn_;

// Number of GetElement requests currently in flight. Incremented when a worker
// thread claims a task and decremented when it releases the task.
int64_t outstanding_requests_ TF_GUARDED_BY(mu_) = 0;

// max_outstanding_requests controls how many elements may be held in memory
// at the same time. This count includes both in-progress requests for
// elements as well as completed requests which haven't yet been produced.
// When autotuned, UpdateBufferSize() sets it to the number of tasks.
int64_t max_outstanding_requests_ TF_GUARDED_BY(mu_);

// The number of threads in `worker_threads_` which are still running.
int64_t num_running_worker_threads_ TF_GUARDED_BY(mu_) = 0;

// The index of the next task in `tasks_` to read from.
int64_t next_task_index_ TF_GUARDED_BY(mu_) = 0;

// The number of tasks in the `tasks_` list that have reached end_of_sequence.
int64_t finished_tasks_ TF_GUARDED_BY(mu_) = 0;

// List of tasks to read from.
std::vector<std::shared_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);

// The current round robin round we are engaged in. A round involves reading
// from each task once.
int64_t current_round_ TF_GUARDED_BY(mu_) = 0;

// Maximum round robin round to read up to before blocking, not inclusive.
// INVARIANT: current_round_ <= round_robin_round_limit_.
//            If current_round_ == round_robin_round_limit_,
//            next_task_index_ must be 0.
std::optional<int64_t> round_robin_round_limit_ TF_GUARDED_BY(mu_);

// A status to be returned from the next call to `GetNext`. This is set by
// asynchronous threads when they encounter errors.
Status status_ TF_GUARDED_BY(mu_) = OkStatus();
// A queue of results for `GetElement` requests to read from. When doing
// strict round robin reads, the queue will contain placeholder results with
// their `Result::ready` field false until their data has been retrieved
// from a worker. When not doing round-robin reads, results are only added
// to the queue after they are ready, to avoid head-of-line blocking.
std::queue<std::shared_ptr<Result>> results_ TF_GUARDED_BY(mu_);

// NOTE(review): presumably set once Initialize() has completed; confirm.
bool initialized_ = false;
// Set once in Initialize().
int64_t job_id_;
int64_t iteration_client_id_;
std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
// NOTE(review): not used in this section; presumably counts GetNext calls.
int64_t get_next_index_ TF_GUARDED_BY(mu_) = 0;

// Latched to true once a heartbeat reports the iteration as finished.
// NOTE(review): written under `mu_` but lacks a TF_GUARDED_BY annotation.
bool iteration_finished_ = false;
// Cleared when a heartbeat reports a task this client chooses not to read
// (see UpdateTasks); its consumer is outside this view.
bool should_finish_iteration_ TF_GUARDED_BY(mu_) = true;

// The set of worker UIDs that we have already recorded metrics for.
absl::flat_hash_set<int64_t> worker_uids_ TF_GUARDED_BY(mu_);

// Threads started by UpdateWorkerThreads(), each running RunWorkerThread().
std::vector<std::unique_ptr<Thread>> worker_threads_ TF_GUARDED_BY(mu_);
// NOTE(review): presumably the thread running TaskThreadManager(); confirm.
std::unique_ptr<Thread> task_thread_manager_ TF_GUARDED_BY(mu_);
1253   };
1254 
// Version of the op that created this dataset (1-4).
const int op_version_;
const tstring dataset_id_;
const ProcessingModeDef processing_mode_;
// Dispatcher address and RPC protocol.
const tstring address_;
const tstring protocol_;
// Protocol used to create worker clients (see AddTask).
const tstring data_transfer_protocol_;
const tstring job_name_;
const bool is_coordinated_read_;
// Present only when the op input was non-negative. When `num_consumers_` is
// set, the iterator performs strict round-robin reads (see StrictRoundRobin).
const std::optional<int64_t> consumer_index_;
const std::optional<int64_t> num_consumers_;
// May be model::kAutotune, in which case the iterator sizes its request
// budget from the number of tasks.
const int64_t max_outstanding_requests_;
// Period of the task-list refresh loop in TaskThreadManager.
const int64_t task_refresh_interval_ms_;
// Which workers to read from; consulted by ShouldReadFromTask and
// ShouldDeleteLocalTask.
const TargetWorkers target_workers_;
const DataServiceMetadata metadata_;
// NOTE(review): ownership appears to depend on `owns_resource_`; confirm.
IterationCounter* const iteration_counter_;  // Owned
const bool owns_resource_;
const ResourceHandle iteration_counter_handle_;
ResourceMgr* const resource_mgr_;  // Not owned
const std::unique_ptr<CapturedFunction> captured_uncompress_func_;
// When set, its trainer id is sent with every GetElement request.
const std::optional<CrossTrainerCacheOptions> cross_trainer_cache_options_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
1277 };
1278 
DataServiceDatasetOp(OpKernelConstruction * ctx)1279 DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx)
1280     : DatasetOpKernel(ctx) {
1281   const auto& op_name = ctx->def().op();
1282   if (op_name == kDataServiceDatasetV1) {
1283     op_version_ = 1;
1284   } else if (op_name == kDataServiceDatasetV2) {
1285     op_version_ = 2;
1286   } else if (op_name == kDataServiceDatasetV3) {
1287     op_version_ = 3;
1288   } else if (op_name == kDataServiceDatasetV4) {
1289     op_version_ = 4;
1290   } else {
1291     ctx->CtxFailure(errors::FailedPrecondition(
1292         "Unrecognized data service dataset op name: ", op_name));
1293     return;
1294   }
1295 
1296   OP_REQUIRES_OK(ctx, ctx->GetAttr(kTaskRefreshIntervalHintMs,
1297                                    &task_refresh_interval_hint_ms_));
1298   if (task_refresh_interval_hint_ms_ == model::kAutotune) {
1299     task_refresh_interval_hint_ms_ = kDefaultTaskRefreshIntervalMs;
1300   }
1301   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
1302   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
1303   if (ctx->HasAttr(kDataTransferProtocol)) {
1304     OP_REQUIRES_OK(
1305         ctx, ctx->GetAttr(kDataTransferProtocol, &data_transfer_protocol_));
1306   }
1307   if (data_transfer_protocol_.empty()) {
1308     data_transfer_protocol_ = kGrpcTransferProtocol;
1309   }
1310 
1311   std::string target_workers_str = "AUTO";
1312   if (ctx->HasAttr(kTargetWorkers)) {
1313     OP_REQUIRES_OK(ctx, ctx->GetAttr(kTargetWorkers, &target_workers_str));
1314   }
1315   StatusOr<TargetWorkers> status_or_target_workers =
1316       ParseTargetWorkers(target_workers_str);
1317   OP_REQUIRES_OK(ctx, status_or_target_workers.status());
1318   target_workers_ = *status_or_target_workers;
1319 
1320   if (op_version_ >= 3) {
1321     OP_REQUIRES_OK(ctx, ctx->GetAttr(kUncompress, &uncompress_));
1322     FunctionMetadata::Params params;
1323     params.use_inter_op_parallelism = true;
1324     OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kUncompressFn, params,
1325                                                  &uncompress_fn_));
1326   }
1327 
1328   if (ctx->HasAttr(kCrossTrainerCacheOptions)) {
1329     OP_REQUIRES_OK(ctx, ctx->GetAttr(kCrossTrainerCacheOptions,
1330                                      &seriazlied_cross_trainer_cache_options_));
1331   }
1332 }
1333 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)1334 void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
1335                                        DatasetBase** output) {
1336   tstring dataset_id;
1337   if (op_version_ >= 4) {
1338     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id));
1339   } else {
1340     int64_t dataset_id_int = 0;
1341     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id_int));
1342     dataset_id = absl::StrCat(dataset_id_int);
1343   }
1344 
1345   tstring processing_mode_str;
1346   OP_REQUIRES_OK(
1347       ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str));
1348   ProcessingModeDef processing_mode;
1349   if (processing_mode_str == kParallelEpochs) {
1350     processing_mode.set_sharding_policy(ProcessingModeDef::OFF);
1351   } else if (processing_mode_str == kDistributedEpoch) {
1352     processing_mode.set_sharding_policy(ProcessingModeDef::DYNAMIC);
1353   } else {
1354     OP_REQUIRES(ctx, processing_mode.ParseFromString(processing_mode_str),
1355                 errors::InvalidArgument(absl::Substitute(
1356                     "Failed to parse ProcessingModeDef from string: $0",
1357                     std::string(processing_mode_str))));
1358   }
1359 
1360   tstring address;
1361   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
1362   OP_REQUIRES(ctx, !address.empty(),
1363               errors::InvalidArgument(kAddress, " must be non-empty."));
1364 
1365   tstring protocol;
1366   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &protocol));
1367   OP_REQUIRES(ctx, !protocol.empty(),
1368               errors::InvalidArgument(kProtocol, " must be non-empty."));
1369 
1370   tstring job_name;
1371   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kJobName, &job_name));
1372 
1373   StatusOr<DataServiceConfig> config = GetDataServiceConfig(address, protocol);
1374   OP_REQUIRES_OK(ctx, config.status());
1375 
1376   if (IsStaticShard(processing_mode) &&
1377       config->deployment_mode() == DEPLOYMENT_MODE_COLOCATED &&
1378       target_workers_ == TARGET_WORKERS_AUTO) {
1379     VLOG(1) << "Using LOCAL target workers for static sharding mode: "
1380             << processing_mode.ShortDebugString();
1381     target_workers_ = TARGET_WORKERS_LOCAL;
1382   }
1383   if (target_workers_ == TARGET_WORKERS_LOCAL) {
1384     data_transfer_protocol_ = kLocalTransferProtocol;
1385   }
1386 
1387   std::optional<int64_t> consumer_index;
1388   std::optional<int64_t> num_consumers;
1389   if (op_version_ >= 2) {
1390     int64_t consumer_index_int;
1391     OP_REQUIRES_OK(
1392         ctx, ParseScalarArgument(ctx, kConsumerIndex, &consumer_index_int));
1393     if (consumer_index_int >= 0) {
1394       consumer_index = consumer_index_int;
1395     }
1396 
1397     int64_t num_consumers_int;
1398     OP_REQUIRES_OK(ctx,
1399                    ParseScalarArgument(ctx, kNumConsumers, &num_consumers_int));
1400     if (num_consumers_int >= 0) {
1401       num_consumers = num_consumers_int;
1402     }
1403   }
1404 
1405   int64_t max_outstanding_requests;
1406   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxOutstandingRequests,
1407                                           &max_outstanding_requests));
1408 
1409   ResourceHandle iteration_counter_handle;
1410   OP_REQUIRES_OK(
1411       ctx, HandleFromInput(ctx, kIterationCounter, &iteration_counter_handle));
1412   IterationCounter* iteration_counter = nullptr;
1413   Status s = ctx->resource_manager()->Lookup<IterationCounter>(
1414       iteration_counter_handle.container(), iteration_counter_handle.name(),
1415       &iteration_counter);
1416   bool owns_resource = false;
1417   if (errors::IsNotFound(s)) {
1418     owns_resource = true;
1419     static std::atomic<int64_t> resource_id_counter(0);
1420     const std::string& container = ctx->resource_manager()->default_container();
1421     std::string name =
1422         strings::StrCat(ctx->op_kernel().name(), "/", kIterationCounter, "_",
1423                         resource_id_counter.fetch_add(1));
1424     OP_REQUIRES_OK(ctx,
1425                    ctx->resource_manager()->LookupOrCreate<IterationCounter>(
1426                        container, name, &iteration_counter,
1427                        [](IterationCounter** counter) {
1428                          *counter = new IterationCounter();
1429                          return OkStatus();
1430                        }));
1431     iteration_counter_handle =
1432         MakeResourceHandle<IterationCounter>(ctx, container, name);
1433   } else {
1434     OP_REQUIRES_OK(ctx, s);
1435   }
1436 
1437   OP_REQUIRES(
1438       ctx,
1439       max_outstanding_requests == model::kAutotune ||
1440           max_outstanding_requests > 0,
1441       errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ",
1442                               model::kAutotune));
1443 
1444   StatusOr<DataServiceMetadata> metadata =
1445       GetDataServiceMetadata(dataset_id, address, protocol);
1446   OP_REQUIRES_OK(ctx, metadata.status());
1447 
1448   bool should_uncompress = op_version_ >= 3 && uncompress_;
1449   if (should_uncompress) {
1450     StatusOr<DataServiceMetadata::Compression> compression =
1451         GetValidatedCompression(dataset_id, *metadata);
1452     OP_REQUIRES_OK(ctx, compression.status());
1453     should_uncompress =
1454         should_uncompress &&
1455         (*compression == DataServiceMetadata::COMPRESSION_SNAPPY);
1456   }
1457   DataTypeVector data_service_output_types = output_types_;
1458   std::vector<PartialTensorShape> data_service_output_shapes = output_shapes_;
1459   if (should_uncompress) {
1460     data_service_output_types = {DT_VARIANT};
1461     data_service_output_shapes = {TensorShape({})};
1462   }
1463 
1464   std::unique_ptr<CapturedFunction> captured_uncompress_func;
1465   if (op_version_ >= 3) {
1466     OP_REQUIRES_OK(
1467         ctx, CapturedFunction::Create(ctx, uncompress_fn_,
1468                                       /*captured_inputs=*/std::vector<Tensor>{},
1469                                       &captured_uncompress_func));
1470   }
1471 
1472   std::optional<CrossTrainerCacheOptions> cross_trainer_cache_options;
1473   if (!seriazlied_cross_trainer_cache_options_.empty()) {
1474     cross_trainer_cache_options.emplace();
1475     cross_trainer_cache_options->ParseFromString(
1476         seriazlied_cross_trainer_cache_options_);
1477   }
1478   DatasetBase* dataset = new Dataset(
1479       ctx, op_version_, dataset_id, processing_mode, address, protocol,
1480       data_transfer_protocol_, job_name, consumer_index, num_consumers,
1481       max_outstanding_requests, task_refresh_interval_hint_ms_, target_workers_,
1482       *metadata, iteration_counter, owns_resource, iteration_counter_handle,
1483       std::move(captured_uncompress_func), cross_trainer_cache_options,
1484       data_service_output_types, data_service_output_shapes);
1485   if (should_uncompress) {
1486     VLOG(2) << "Inserting a ParallelMap dataset to uncompress tf.data service "
1487             << "dataset " << dataset_id << ".";
1488     dataset->Initialize(/*metadata=*/{});
1489     captured_uncompress_func.reset();
1490     OP_REQUIRES_OK(
1491         ctx, CapturedFunction::Create(ctx, uncompress_fn_,
1492                                       /*captured_inputs=*/std::vector<Tensor>{},
1493                                       &captured_uncompress_func));
1494 
1495     // Release the ownership of `dataset` and transfer it to the ParallelMap
1496     // dataset for uncompression.
1497     core::ScopedUnref unref(dataset);
1498     dataset = MakeDataServiceUncompressDataset(
1499                   /*input=*/dataset, std::move(captured_uncompress_func),
1500                   output_types_, output_shapes_)
1501                   .release();
1502   }
1503   *output = dataset;
1504 }
1505 
// CPU kernel registrations for the tf.data service dataset ops.
//
// The same kernel class, DataServiceDatasetOp, is registered for every version
// of the op (V1 through V4). Version-specific behavior is chosen at runtime in
// MakeDataset by comparing `op_version_` against thresholds:
//   - consumer index and number of consumers are parsed only when
//     op_version_ >= 2;
//   - the uncompress function is captured, and uncompression may be applied,
//     only when op_version_ >= 3.
// NOTE(review): how each registered op name maps to `op_version_` is decided
// outside this section (presumably in the kernel constructor) — confirm there
// before adding a new version here.
REGISTER_KERNEL_BUILDER(Name(kDataServiceDatasetV1).Device(DEVICE_CPU),
                        DataServiceDatasetOp);
REGISTER_KERNEL_BUILDER(Name(kDataServiceDatasetV2).Device(DEVICE_CPU),
                        DataServiceDatasetOp);
REGISTER_KERNEL_BUILDER(Name(kDataServiceDatasetV3).Device(DEVICE_CPU),
                        DataServiceDatasetOp);
REGISTER_KERNEL_BUILDER(Name(kDataServiceDatasetV4).Device(DEVICE_CPU),
                        DataServiceDatasetOp);
// Registers DummyResourceOp<IterationCounter> as the CPU kernel for the
// "DummyIterationCounter" op.
// When the iteration-counter handle passed to MakeDataset cannot be found in
// the resource manager (Lookup returns NotFound), MakeDataset does two things:
//   - creates its own IterationCounter in the default container;
//   - marks the dataset as owning that resource.
// NOTE(review): presumably this dummy op is what supplies such not-found
// handles — verify against DummyResourceOp's definition.
REGISTER_KERNEL_BUILDER(Name("DummyIterationCounter").Device(DEVICE_CPU),
                        DummyResourceOp<IterationCounter>);
1516 
1517 }  // namespace data
1518 }  // namespace tensorflow
1519