1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/experimental/data_service_dataset_op.h"
16
17 #include <algorithm>
18 #include <functional>
19 #include <limits>
20 #include <memory>
21 #include <optional>
22 #include <queue>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/numbers.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/string_view.h"
34 #include "absl/strings/substitute.h"
35 #include "absl/time/time.h"
36 #include "tensorflow/core/data/captured_function.h"
37 #include "tensorflow/core/data/dataset.pb.h"
38 #include "tensorflow/core/data/dataset_utils.h"
39 #include "tensorflow/core/data/name_utils.h"
40 #include "tensorflow/core/data/serialization_utils.h"
41 #include "tensorflow/core/data/service/client/common.h"
42 #include "tensorflow/core/data/service/client/validate_utils.h"
43 #include "tensorflow/core/data/service/common.h"
44 #include "tensorflow/core/data/service/common.pb.h"
45 #include "tensorflow/core/data/service/dispatcher.pb.h"
46 #include "tensorflow/core/data/service/dispatcher_client.h"
47 #include "tensorflow/core/data/service/grpc_util.h"
48 #include "tensorflow/core/data/service/worker.pb.h"
49 #include "tensorflow/core/data/service/worker_client.h"
50 #include "tensorflow/core/data/service/worker_impl.h"
51 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
52 #include "tensorflow/core/framework/dataset.h"
53 #include "tensorflow/core/framework/model.h"
54 #include "tensorflow/core/framework/partial_tensor_shape.h"
55 #include "tensorflow/core/framework/tensor.h"
56 #include "tensorflow/core/framework/types.pb.h"
57 #include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
58 #include "tensorflow/core/lib/core/errors.h"
59 #include "tensorflow/core/lib/gtl/cleanup.h"
60 #include "tensorflow/core/platform/env_time.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/mutex.h"
63 #include "tensorflow/core/platform/refcount.h"
64 #include "tensorflow/core/platform/status.h"
65 #include "tensorflow/core/platform/statusor.h"
66 #include "tensorflow/core/platform/thread_annotations.h"
67 #include "tensorflow/core/platform/tstring.h"
68 #include "tensorflow/core/platform/types.h"
69 #include "tensorflow/core/profiler/lib/traceme.h"
70 #include "tensorflow/core/profiler/lib/traceme_encode.h"
71 #include "tensorflow/core/protobuf/data_service.pb.h"
72 #include "tensorflow/core/protobuf/error_codes.pb.h"
73
74 namespace tensorflow {
75 namespace data {
76
// Out-of-line definitions of the `static constexpr` string constants declared
// in `DataServiceDatasetOp`. Before C++17 made such members implicitly inline,
// a definition like these was required whenever a constant is odr-used (for
// example, bound to a reference); under C++17 they are redundant (deprecated)
// redeclarations and are harmless.
/* static */ constexpr const char* const DataServiceDatasetOp::kDatasetType;
/* static */ constexpr const char* const DataServiceDatasetOp::kDatasetId;
/* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode;
/* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
/* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
/* static */ constexpr const char* const
    DataServiceDatasetOp::kDataTransferProtocol;
/* static */ constexpr const char* const DataServiceDatasetOp::kJobName;
/* static */ constexpr const char* const DataServiceDatasetOp::kConsumerIndex;
/* static */ constexpr const char* const DataServiceDatasetOp::kNumConsumers;
/* static */ constexpr const char* const
    DataServiceDatasetOp::kMaxOutstandingRequests;
/* static */ constexpr const char* const
    DataServiceDatasetOp::kTaskRefreshIntervalHintMs;
/* static */ constexpr const char* const DataServiceDatasetOp::kTargetWorkers;
/* static */ constexpr const char* const
    DataServiceDatasetOp::kIterationCounter;
/* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes;
/* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes;
/* static */ constexpr const char* const DataServiceDatasetOp::kUncompress;
/* static */ constexpr const char* const DataServiceDatasetOp::kUncompressFn;
/* static */ constexpr const char* const
    DataServiceDatasetOp::kCrossTrainerCacheOptions;
100
101 namespace {
102 // Default interval between task list refreshes.
103 const int64_t kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
104
105 constexpr char kDataServiceDatasetV1[] = "DataServiceDataset";
106 constexpr char kDataServiceDatasetV2[] = "DataServiceDatasetV2";
107 constexpr char kDataServiceDatasetV3[] = "DataServiceDatasetV3";
108 constexpr char kDataServiceDatasetV4[] = "DataServiceDatasetV4";
109
110 constexpr const char kParallelEpochs[] = "parallel_epochs";
111 constexpr const char kDistributedEpoch[] = "distributed_epoch";
112
113 // Same timeout used by the RegisterDatasetOp.
114 constexpr absl::Duration kGetMetadataRetryTimeout = absl::Hours(1);
115
IsColocatedTask(const TaskInfo & task)116 bool IsColocatedTask(const TaskInfo& task) {
117 return absl::c_any_of(task.worker_tags(), [](absl::string_view worker_tag) {
118 return absl::AsciiStrToUpper(worker_tag) == kColocatedWorkerTag;
119 });
120 }
121
GetDataServiceMetadata(const std::string & dataset_id,const tstring & address,const tstring & protocol)122 StatusOr<DataServiceMetadata> GetDataServiceMetadata(
123 const std::string& dataset_id, const tstring& address,
124 const tstring& protocol) {
125 DataServiceDispatcherClient client(address, protocol);
126 DataServiceMetadata metadata;
127 absl::Time deadline =
128 absl::FromUnixMicros(EnvTime::NowMicros()) + kGetMetadataRetryTimeout;
129
130 Status status = grpc_util::Retry(
131 [&]() { return client.GetDataServiceMetadata(dataset_id, metadata); },
132 absl::Substitute("Get data service metadata for dataset $0, "
133 "with dispatcher at $1.",
134 dataset_id, std::string(address)),
135 absl::ToUnixMicros(deadline));
136 if (errors::IsNotFound(status)) {
137 return errors::NotFound(
138 "Dataset id ", dataset_id,
139 " not found. It must be registered with `register_dataset` before "
140 "calling `from_dataset_id`.");
141 }
142 TF_RETURN_IF_ERROR(status);
143 return metadata;
144 }
145
GetValidatedCompression(const std::string & dataset_id,const DataServiceMetadata & metadata)146 StatusOr<DataServiceMetadata::Compression> GetValidatedCompression(
147 const std::string& dataset_id, const DataServiceMetadata& metadata) {
148 if (metadata.compression() == DataServiceMetadata::COMPRESSION_UNSPECIFIED) {
149 return errors::Internal(absl::Substitute(
150 "Got invalid compression $0 for dataset $1. A proper compression "
151 "should be registered in `register_dataset`.",
152 DataServiceMetadata::Compression_Name(metadata.compression()),
153 dataset_id));
154 }
155 return metadata.compression();
156 }
157
GetDataServiceConfig(const tstring & address,const tstring & protocol)158 StatusOr<DataServiceConfig> GetDataServiceConfig(const tstring& address,
159 const tstring& protocol) {
160 DataServiceDispatcherClient client(address, protocol);
161 DataServiceConfig config;
162 absl::Time deadline =
163 absl::FromUnixMicros(EnvTime::NowMicros()) + kGetMetadataRetryTimeout;
164
165 TF_RETURN_IF_ERROR(grpc_util::Retry(
166 [&]() { return client.GetDataServiceConfig(config); },
167 absl::Substitute("Get data service config with dispatcher at $0.",
168 std::string(address)),
169 absl::ToUnixMicros(deadline)));
170 return config;
171 }
172 } // namespace
173
174 // Dataset for reading data from the tf.data service non-deterministically.
175 //
176 // This dataset interleaves dataset elements produced by multiple tf.data
177 // workers. We periodically query the dispatcher to determine which workers
178 // to read from (in case workers are added or removed).
179 class DataServiceDatasetOp::Dataset : public DatasetBase {
180 public:
// Stores the op's inputs and attributes.
//
// Coordinated reads are enabled exactly when `consumer_index` has a value.
// The dataset takes ownership of one reference on `iteration_counter`, which
// the destructor `Unref()`s. When `owns_resource` is true, the destructor
// also deletes the counter from `ctx->resource_manager()` using
// `iteration_counter_handle`.
//
// Members are initialized in their declaration order (declared outside this
// view), regardless of the order written below; keep this list in that order.
Dataset(OpKernelContext* ctx, int op_version, const std::string& dataset_id,
        const ProcessingModeDef& processing_mode, const std::string& address,
        const std::string& protocol,
        const std::string& data_transfer_protocol,
        const std::string& job_name, std::optional<int64_t> consumer_index,
        std::optional<int64_t> num_consumers,
        int64_t max_outstanding_requests, int64_t task_refresh_interval_ms,
        const TargetWorkers target_workers,
        const DataServiceMetadata& metadata,
        IterationCounter* iteration_counter, bool owns_resource,
        ResourceHandle iteration_counter_handle,
        std::unique_ptr<CapturedFunction> captured_uncompress_func,
        const std::optional<CrossTrainerCacheOptions>&
            cross_trainer_cache_options,
        const DataTypeVector& output_types,
        const std::vector<PartialTensorShape>& output_shapes)
    : DatasetBase(DatasetContext(ctx)),
      op_version_(op_version),
      dataset_id_(dataset_id),
      processing_mode_(processing_mode),
      address_(address),
      protocol_(protocol),
      data_transfer_protocol_(data_transfer_protocol),
      job_name_(job_name),
      // A consumer index is only supplied for coordinated reads.
      is_coordinated_read_(consumer_index.has_value()),
      consumer_index_(consumer_index),
      num_consumers_(num_consumers),
      max_outstanding_requests_(max_outstanding_requests),
      task_refresh_interval_ms_(task_refresh_interval_ms),
      target_workers_(target_workers),
      metadata_(metadata),
      iteration_counter_(iteration_counter),
      owns_resource_(owns_resource),
      iteration_counter_handle_(iteration_counter_handle),
      // Kept so the destructor can delete the counter resource.
      resource_mgr_(ctx->resource_manager()),
      captured_uncompress_func_(std::move(captured_uncompress_func)),
      cross_trainer_cache_options_(cross_trainer_cache_options),
      output_types_(output_types),
      output_shapes_(output_shapes) {}
220
// Releases this dataset's reference on the iteration counter. If the dataset
// owns the counter resource, it also removes the counter from the resource
// manager. A failed removal is logged and otherwise ignored.
~Dataset() override {
  iteration_counter_->Unref();
  if (!owns_resource_) {
    return;
  }
  const Status delete_status = resource_mgr_->Delete<IterationCounter>(
      iteration_counter_handle_.container(), iteration_counter_handle_.name());
  if (!delete_status.ok()) {
    LOG(WARNING) << "Failed to delete iteration counter resource: "
                 << delete_status;
  }
}
232
MakeIteratorInternal(const string & prefix) const233 std::unique_ptr<IteratorBase> MakeIteratorInternal(
234 const string& prefix) const override {
235 return std::make_unique<Iterator>(
236 Iterator::Params{this,
237 name_utils::IteratorPrefix(kDatasetType, prefix)},
238 DataServiceParams{dataset_id_, processing_mode_, address_, protocol_,
239 data_transfer_protocol_, job_name_,
240 /*repetition=*/iteration_counter_->GetAndIncrement(),
241 num_consumers_, consumer_index_, target_workers_,
242 metadata_, cross_trainer_cache_options_});
243 }
244
output_dtypes() const245 const DataTypeVector& output_dtypes() const override { return output_types_; }
246
output_shapes() const247 const std::vector<PartialTensorShape>& output_shapes() const override {
248 return output_shapes_;
249 }
250
DebugString() const251 string DebugString() const override {
252 return name_utils::DatasetDebugString(kDatasetType);
253 }
254
CardinalityInternal() const255 int64_t CardinalityInternal() const override {
256 if (is_coordinated_read_) {
257 // Coordinated reads require the dataset to be infinite.
258 return kInfiniteCardinality;
259 }
260
261 if (metadata_.cardinality() == 0) {
262 return 0;
263 }
264
265 if (metadata_.cardinality() == kInfiniteCardinality) {
266 // Sharding may cause an infinite dataset to be empty. For example, in
267 // `range(10).batch(10, drop_remainder=True).repeat()`, inserting `shard`
268 // before `batch` will cause the dataset to be empty.
269 // This case is rare, and there is significant performance hit for dynamic
270 // sharding if it reports unknown cardinality, so it is reasonable to
271 // report infinite cardinality. For DATA sharding, it is ok to report
272 // infinite cardinality since it inserts `shard` after `repeat`.
273 if (processing_mode_.sharding_policy() == ProcessingModeDef::OFF ||
274 processing_mode_.sharding_policy() == ProcessingModeDef::DYNAMIC ||
275 processing_mode_.sharding_policy() == ProcessingModeDef::DATA) {
276 return kInfiniteCardinality;
277 }
278 }
279 return kUnknownCardinality;
280 }
281
CheckExternalState() const282 Status CheckExternalState() const override {
283 return Status(
284 error::FAILED_PRECONDITION,
285 strings::StrCat(DebugString(), " does not yet support serialization."));
286 }
287
InputDatasets(std::vector<const DatasetBase * > * inputs) const288 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
289 inputs->clear();
290 return OkStatus();
291 }
292
293 protected:
// Serializes this dataset as a graph node via `b->AddDataset`.
//
// Inputs are appended in this fixed order:
//   1. dataset_id: a string scalar for op_version >= 4; otherwise parsed into
//      an int64 scalar (Internal error if it is not an integer).
//   2. processing_mode: the serialized ProcessingModeDef.
//   3. address, protocol, job_name.
//   4. consumer_index and num_consumers: op_version >= 2 only; -1 when unset.
//   5. max_outstanding_requests.
//   6. The iteration counter resource handle.
// NOTE(review): this order presumably has to match the op's input signature
// for `op_version_` — confirm against the op registration before reordering.
//
// Attributes: task_refresh_interval_hint_ms, data_transfer_protocol,
// target_workers, (op_version >= 3) uncompress=false and uncompress_fn, and
// cross_trainer_cache_options (serialized proto, empty string when unset).
Status AsGraphDefInternal(SerializationContext* ctx,
                          DatasetGraphDefBuilder* b,
                          Node** output) const override {
  // Inputs
  std::vector<Node*> inputs;

  if (op_version_ >= 4) {
    Node* dataset_id;
    TF_RETURN_IF_ERROR(b->AddScalar(dataset_id_, &dataset_id));
    inputs.push_back(dataset_id);
  } else {
    // Older op versions take the dataset id as an integer.
    int64_t dataset_id_int;
    if (!absl::SimpleAtoi(dataset_id_, &dataset_id_int)) {
      return errors::Internal("Failed to parse dataset ID: ", dataset_id_,
                              ". Expect integers.");
    }
    Node* dataset_id;
    TF_RETURN_IF_ERROR(b->AddScalar(dataset_id_int, &dataset_id));
    inputs.push_back(dataset_id);
  }

  Node* processing_mode;
  tstring processing_mode_str = processing_mode_.SerializeAsString();
  TF_RETURN_IF_ERROR(b->AddScalar(processing_mode_str, &processing_mode));
  inputs.push_back(processing_mode);

  Node* address;
  TF_RETURN_IF_ERROR(b->AddScalar(address_, &address));
  inputs.push_back(address);

  Node* protocol;
  TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol));
  inputs.push_back(protocol);

  Node* job_name;
  TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name));
  inputs.push_back(job_name);

  if (op_version_ >= 2) {
    // Unset optionals are encoded as -1.
    Node* consumer_index;
    TF_RETURN_IF_ERROR(
        b->AddScalar(consumer_index_.value_or(-1), &consumer_index));
    inputs.push_back(consumer_index);

    Node* num_consumers;
    TF_RETURN_IF_ERROR(
        b->AddScalar(num_consumers_.value_or(-1), &num_consumers));
    inputs.push_back(num_consumers);
  }

  Node* max_outstanding_requests;
  TF_RETURN_IF_ERROR(
      b->AddScalar(max_outstanding_requests_, &max_outstanding_requests));
  inputs.push_back(max_outstanding_requests);

  // The iteration counter is passed as a scalar DT_RESOURCE tensor.
  Node* iteration_counter_handle = nullptr;
  Tensor handle(DT_RESOURCE, TensorShape({}));
  handle.scalar<ResourceHandle>()() = iteration_counter_handle_;
  TF_RETURN_IF_ERROR(b->AddTensor(handle, &iteration_counter_handle));
  inputs.push_back(iteration_counter_handle);

  // Attributes
  std::vector<std::pair<StringPiece, AttrValue>> attrs;
  AttrValue task_refresh_interval_hint_ms;
  b->BuildAttrValue(task_refresh_interval_ms_,
                    &task_refresh_interval_hint_ms);
  attrs.push_back(
      {kTaskRefreshIntervalHintMs, task_refresh_interval_hint_ms});

  AttrValue data_transfer_protocol;
  b->BuildAttrValue(data_transfer_protocol_, &data_transfer_protocol);
  attrs.push_back({kDataTransferProtocol, data_transfer_protocol});

  AttrValue target_workers;
  b->BuildAttrValue(TargetWorkersToString(target_workers_), &target_workers);
  attrs.push_back({kTargetWorkers, target_workers});

  if (op_version_ >= 3) {
    // Attr: uncompress is true for the first time the graph is built, when a
    // ParallelMap dataset is inserted for uncompression. Subsequent
    // serialization always sets it to false to avoid inserting repeated map
    // datasets for uncompression.
    AttrValue uncompress_attr;
    b->BuildAttrValue(false, &uncompress_attr);
    attrs.push_back({kUncompress, uncompress_attr});

    // Attr: uncompress_fn
    AttrValue uncompress_fn_attr;
    b->BuildAttrValue(captured_uncompress_func_->func(), &uncompress_fn_attr);
    attrs.push_back({kUncompressFn, uncompress_fn_attr});

    // NOTE(review): the argument nodes/types produced here are not passed to
    // AddDataset below; presumably AddToGraph is called for its side effect
    // of adding the captured function to the graph — confirm.
    std::vector<Node*> uncompress_arguments;
    DataTypeVector uncompress_arguments_types;
    TF_RETURN_IF_ERROR(captured_uncompress_func_->AddToGraph(
        ctx, b, &uncompress_arguments, &uncompress_arguments_types));
  }

  // Attr: cross_trainer_cache_options
  AttrValue cross_trainer_cache_options_attr;
  std::string serialized_cross_trainer_cache_options;
  if (cross_trainer_cache_options_.has_value()) {
    serialized_cross_trainer_cache_options =
        cross_trainer_cache_options_->SerializeAsString();
  }
  b->BuildAttrValue(serialized_cross_trainer_cache_options,
                    &cross_trainer_cache_options_attr);
  attrs.push_back(
      {kCrossTrainerCacheOptions, cross_trainer_cache_options_attr});
  return b->AddDataset(this, inputs, attrs, output);
}
404
405 private:
406 class Iterator : public DatasetIterator<Dataset> {
407 public:
Iterator(const Params & params,const DataServiceParams & data_service_params)408 explicit Iterator(const Params& params,
409 const DataServiceParams& data_service_params)
410 : DatasetIterator<Dataset>(params),
411 data_service_params_(data_service_params),
412 max_outstanding_requests_(params.dataset->max_outstanding_requests_) {
413 }
414
~Iterator()415 ~Iterator() override {
416 VLOG(1) << "Destroying data service dataset iterator for iteration id "
417 << iteration_client_id_;
418 CancelThreads();
419 if (deregister_fn_) deregister_fn_();
420 task_thread_manager_.reset();
421 if (initialized_) {
422 Status s = dispatcher_->ReleaseIterationClient(iteration_client_id_);
423 if (!s.ok()) {
424 LOG(WARNING) << "Failed to release iteration client id: " << s;
425 }
426 }
427 for (auto& worker_thread : worker_threads_) {
428 worker_thread.reset();
429 }
430 DeleteLocalWorkerTasks();
431 VLOG(1) << "Destroyed data service dataset iterator for iteration id "
432 << iteration_client_id_;
433 }
434
Initialize(IteratorContext * ctx)435 Status Initialize(IteratorContext* ctx) override {
436 TF_RETURN_IF_ERROR(ValidateDataServiceParams(data_service_params_));
437 VLOG(3) << "Connecting to " << dataset()->address_
438 << " in data service dataset op";
439 TF_RETURN_IF_ERROR(RegisterCancellationCallback(
440 ctx->cancellation_manager(), [this]() { CancelThreads(); },
441 &deregister_fn_));
442 dispatcher_ = std::make_unique<DataServiceDispatcherClient>(
443 dataset()->address_, dataset()->protocol_);
444 int64_t deadline_micros = kint64max;
445 std::optional<std::string> job_name;
446 if (!dataset()->job_name_.empty()) {
447 job_name = dataset()->job_name_;
448 }
449 TF_RETURN_IF_ERROR(grpc_util::Retry(
450 [&]() {
451 return dispatcher_->GetOrCreateJob(
452 dataset()->dataset_id_, dataset()->processing_mode_, job_name,
453 dataset()->num_consumers_,
454 dataset()->cross_trainer_cache_options_.has_value(),
455 dataset()->target_workers_, job_id_);
456 },
457 /*description=*/
458 strings::StrCat("get or create job with dispatcher at ",
459 dataset()->address_),
460 deadline_micros));
461 TF_RETURN_IF_ERROR(grpc_util::Retry(
462 [&]() {
463 return dispatcher_->GetOrCreateIteration(
464 job_id_, data_service_params_.repetition, iteration_client_id_);
465 },
466 /*description=*/
467 strings::StrCat("get or create iteration with dispatcher at ",
468 dataset()->address_),
469 deadline_micros));
470 initialized_ = true;
471 return OkStatus();
472 }
473
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)474 Status GetNextInternal(IteratorContext* ctx,
475 std::vector<Tensor>* out_tensors,
476 bool* end_of_sequence) override {
477 VLOG(3) << "Calling GetNext in data service dataset's iterator.";
478 mutex_lock l(mu_);
479 EnsureThreadsStarted(ctx);
480 std::shared_ptr<Result> result;
481 do {
482 while (!ResultReady() && !Finished() && !cancelled_ && status_.ok()) {
483 VLOG(3) << "Blocking in GetNext: " << DebugString();
484 get_next_cv_.wait(l);
485 }
486 if (cancelled_) {
487 VLOG(3) << "Returning from GetNext due to cancellation";
488 return errors::Cancelled("Data service iterator was cancelled");
489 }
490 if (!status_.ok()) {
491 VLOG(3) << "Returning from GetNext with error " << status_;
492 return status_;
493 }
494 if (results_.empty()) {
495 *end_of_sequence = true;
496 VLOG(3) << "Returning from GetNext with end_of_sequence";
497 return OkStatus();
498 }
499 if (!ResultReady()) {
500 return errors::Internal(
501 "Expected a result to be ready, but none were.");
502 }
503 result = PopNextResult();
504 worker_thread_cv_.notify_one();
505 } while (result->skip);
506
507 *end_of_sequence = result->end_of_sequence;
508 if (!*end_of_sequence) {
509 VLOG(1) << "Returning the next element from data service dataset's "
510 << "Iterator: task " << result->task_id << ", element "
511 << result->element_index;
512 if (StrictRoundRobin()) {
513 VLOG(1) << "Consumer " << dataset()->consumer_index_.value()
514 << ": Result " << get_next_index_++;
515 }
516 out_tensors->swap(result->element);
517 }
518 return OkStatus();
519 }
520
521 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const522 std::shared_ptr<model::Node> CreateNode(
523 IteratorContext* ctx, model::Node::Args args) const override {
524 return model::MakeKnownRatioNode(std::move(args),
525 /*ratio=*/1);
526 }
527
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)528 Status SaveInternal(SerializationContext* ctx,
529 IteratorStateWriter* writer) override {
530 return errors::Unimplemented("SaveInternal is not yet supported");
531 }
532
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)533 Status RestoreInternal(IteratorContext* ctx,
534 IteratorStateReader* reader) override {
535 return errors::Unimplemented("RestoreInternal is not yet supported");
536 }
537
GetTraceMeMetadata() const538 data::TraceMeMetadata GetTraceMeMetadata() const override {
539 data::TraceMeMetadata result;
540 int64_t num_tasks = -1;
541 if (mu_.try_lock()) {
542 num_tasks = tasks_.size() - finished_tasks_;
543 mu_.unlock();
544 }
545 result.push_back(std::make_pair(
546 "num_tasks",
547 num_tasks == -1
548 ? kTraceInfoUnavailable
549 : strings::Printf("%lld", static_cast<long long>(num_tasks))));
550 result.push_back(std::make_pair("job_name", dataset()->job_name_));
551 result.push_back(std::make_pair(
552 "max_outstanding_requests",
553 strings::Printf("%lld", static_cast<long long>(
554 dataset()->max_outstanding_requests_))));
555 return result;
556 }
557
558 private:
// State kept by the iterator for one task it reads from.
struct Task {
  Task(const TaskInfo& info,
       std::unique_ptr<DataServiceWorkerClient> worker)
      : info(info), worker(std::move(worker)) {}

  // Task description (id, worker address, starting round, ...) from the
  // dispatcher heartbeat response that introduced the task.
  const TaskInfo info;
  // Client for fetching task elements from the tf.data service worker.
  const std::unique_ptr<DataServiceWorkerClient> worker;
  // The next round to read from the task.
  int64_t round = 0;
  // Whether the task has been removed. The task will eventually be
  // deleted from `tasks_` on the next dispatcher heartbeat.
  bool removed = false;
  // NOTE(review): not read or written in the visible code; presumably records
  // whether the task's previous round was skipped — confirm at its use sites.
  bool skipped_previous_round = false;
  // Indicates whether a worker thread is currently processing the task.
  bool in_use TF_GUARDED_BY(&Iterator::mu_) = false;
  // Indicates whether the worker has returned end_of_sequence for the task.
  // UpdateTasks decrements `finished_tasks_` when it removes such a task.
  bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
};
578
// One buffered result consumed by GetNextInternal. Move-only.
struct Result {
  Result() = default;
  Result(Result&&) = default;
  Result& operator=(Result&&) = default;
  Result(const Result&) = delete;
  Result& operator=(const Result&) = delete;

  // Whether the result has been computed yet. GetNext needs to block
  // until the next result is ready.
  bool ready TF_GUARDED_BY(&Iterator::mu_) = false;
  // The element's tensors; swapped into GetNext's output when returned.
  std::vector<Tensor> element TF_GUARDED_BY(&Iterator::mu_);
  // The element's index within the tf.data worker it came from. Used for
  // debugging.
  int64_t element_index TF_GUARDED_BY(&Iterator::mu_) = -1;
  // The id of the task that generated the result.
  int64_t task_id TF_GUARDED_BY(&Iterator::mu_) = -1;
  // Copied into GetNext's `*end_of_sequence` when this result is returned.
  bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
  // When true, GetNext discards this result and waits for the next one.
  bool skip TF_GUARDED_BY(&Iterator::mu_) = false;
};
598
599 // Returns whether the iterator has finished and should return.
Finished() const600 bool Finished() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
601 return num_running_worker_threads_ == 0 && !ShouldWaitForNext();
602 }
603
604 // Returns whether the iterator has more data.
ShouldWaitForNext() const605 bool ShouldWaitForNext() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
606 if (should_finish_iteration_) {
607 return !iteration_finished_;
608 }
609 return tasks_.empty() || finished_tasks_ < tasks_.size();
610 }
611
EnsureThreadsStarted(IteratorContext * ctx)612 void EnsureThreadsStarted(IteratorContext* ctx)
613 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
614 if (!task_thread_manager_ && !cancelled_) {
615 auto new_ctx = std::make_shared<IteratorContext>(*ctx);
616 task_thread_manager_ =
617 ctx->StartThread("task-thread-manager",
618 [this, new_ctx]() { TaskThreadManager(new_ctx); });
619 }
620 }
621
CancelThreads()622 void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
623 mutex_lock l(mu_);
624 for (const auto& task : tasks_) {
625 task->worker->TryCancel();
626 }
627 cancelled_ = true;
628 worker_thread_cv_.notify_all();
629 manager_thread_cv_.notify_all();
630 get_next_cv_.notify_all();
631 }
632
DeleteLocalWorkerTasks()633 void DeleteLocalWorkerTasks() {
634 std::vector<std::shared_ptr<Task>> tasks;
635 {
636 mutex_lock l(mu_);
637 tasks = tasks_;
638 }
639
640 for (const std::shared_ptr<Task>& task : tasks) {
641 std::shared_ptr<DataServiceWorkerImpl> worker =
642 LocalWorkers::Get(task->info.worker_address());
643 if (worker && ShouldDeleteLocalTask(task->info)) {
644 worker->DeleteLocalTask(task->info);
645 }
646 }
647 }
648
649 // Deletes the task if it is only read by the local client.
ShouldDeleteLocalTask(const TaskInfo & task) const650 bool ShouldDeleteLocalTask(const TaskInfo& task) const {
651 if (StrictRoundRobin()) {
652 return false;
653 }
654
655 if (dataset()->target_workers_ == TARGET_WORKERS_LOCAL) {
656 return true;
657 }
658
659 return dataset()->target_workers_ == TARGET_WORKERS_AUTO &&
660 IsColocatedTask(task);
661 }
662
663 // Periodically refresh the task list.
664 // Maintain one thread fetching elements for each task.
665 // TODO(aaudibert): Instead of polling, have dispatcher send updates when
666 // the list of tasks changes.
TaskThreadManager(std::shared_ptr<IteratorContext> ctx)667 void TaskThreadManager(std::shared_ptr<IteratorContext> ctx) {
668 auto cleanup =
669 gtl::MakeCleanup([] { VLOG(1) << "Task thread manager exiting"; });
670 VLOG(1) << "Starting task thread manager";
671 uint64 next_check = Env::Default()->NowMicros();
672 while (true) {
673 {
674 mutex_lock l(mu_);
675 // All units are microseconds.
676 while (!cancelled_ && Env::Default()->NowMicros() < next_check) {
677 int64_t remaining_time = next_check - Env::Default()->NowMicros();
678 VLOG(4) << "Task thread manager waiting for " << remaining_time
679 << "us";
680 manager_thread_cv_.wait_for(
681 l, std::chrono::microseconds(remaining_time));
682 }
683 if (cancelled_) {
684 VLOG(3) << "Task thread manager finished";
685 return;
686 }
687 }
688 Heartbeat();
689 UpdateBufferSize();
690 UpdateWorkerThreads(ctx.get());
691 next_check = Env::Default()->NowMicros() +
692 dataset()->task_refresh_interval_ms_ * 1000;
693 }
694 }
695
TryBlockRound(int64_t round)696 void TryBlockRound(int64_t round) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
697 if (round_robin_round_limit_.has_value() &&
698 round_robin_round_limit_.value() == round) {
699 return;
700 }
701 if (current_round_ >= round) {
702 // In the next heartbeat, notify the dispatcher that we failed to add
703 // the task.
704 VLOG(1) << "Rejecting request to block round " << round
705 << ", because processing has already begun for round "
706 << current_round_;
707 return;
708 }
709 VLOG(1) << "Accepting request to block round " << round;
710 round_robin_round_limit_ = round;
711 }
712
UpdateIterationFinished(bool iteration_finished)713 void UpdateIterationFinished(bool iteration_finished)
714 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
715 if (!iteration_finished) {
716 return;
717 }
718 iteration_finished_ = true;
719 get_next_cv_.notify_all();
720 worker_thread_cv_.notify_all();
721 }
722
AddTask(const TaskInfo & task_info)723 Status AddTask(const TaskInfo& task_info) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
724 TF_ASSIGN_OR_RETURN(
725 std::unique_ptr<DataServiceWorkerClient> worker,
726 CreateDataServiceWorkerClient(task_info.transfer_address(),
727 dataset()->protocol_,
728 dataset()->data_transfer_protocol_));
729 tasks_.push_back(std::make_shared<Task>(task_info, std::move(worker)));
730 worker_thread_cv_.notify_one();
731 if (StrictRoundRobin()) {
732 VLOG(1) << "Consumer " << dataset()->consumer_index_.value()
733 << " adding task " << task_info.task_id()
734 << " to read from worker " << task_info.worker_address()
735 << ". Task starting round: " << task_info.starting_round();
736 DCHECK_LE(current_round_, task_info.starting_round());
737 if (current_round_ == task_info.starting_round()) {
738 DCHECK_EQ(next_task_index_, 0);
739 }
740 }
741 return OkStatus();
742 }
743
Heartbeat()744 void Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
745 ClientHeartbeatRequest req;
746 req.set_iteration_client_id(iteration_client_id_);
747 if (StrictRoundRobin()) {
748 mutex_lock l(mu_);
749 req.set_current_round(current_round_);
750 if (round_robin_round_limit_.has_value()) {
751 req.set_blocked_round(round_robin_round_limit_.value());
752 }
753 }
754 ClientHeartbeatResponse resp;
755 Status s = dispatcher_->ClientHeartbeat(req, resp);
756 if (!s.ok()) {
757 if (IsPreemptedError(s)) {
758 LOG(WARNING)
759 << "Failed to heartbeat to dispatcher from iteration client id "
760 << iteration_client_id_
761 << ". Dispatcher address: " << dataset()->address_
762 << ". Error: " << s;
763 return;
764 }
765 mutex_lock l(mu_);
766 status_ = s;
767 get_next_cv_.notify_all();
768 }
769 mutex_lock l(mu_);
770 UpdateIterationFinished(resp.iteration_finished());
771 if (resp.optional_block_round_case() ==
772 ClientHeartbeatResponse::kBlockRound) {
773 TryBlockRound(resp.block_round());
774 } else {
775 round_robin_round_limit_ = std::nullopt;
776 worker_thread_cv_.notify_all();
777 }
778 UpdateTasks(resp);
779 RecordTFMetrics(resp);
780 }
781
UpdateTasks(const ClientHeartbeatResponse & resp)782 void UpdateTasks(const ClientHeartbeatResponse& resp)
783 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
784 absl::flat_hash_map<int64_t, TaskInfo> task_id_to_task;
785 for (auto& task : resp.task_info()) {
786 task_id_to_task[task.task_id()] = task;
787 }
788 if (iteration_finished_) {
789 return;
790 }
791
792 int index = 0;
793 while (index < tasks_.size()) {
794 std::shared_ptr<Task> task = tasks_[index];
795 if (task_id_to_task.contains(task->info.task_id())) {
796 // Remove already-known tasks from `task_id_to_task`, so that at the
797 // end of the loop, only new tasks remain.
798 task_id_to_task.erase(task->info.task_id());
799 ++index;
800 } else {
801 // Task has been removed.
802 if (task->end_of_sequence) {
803 finished_tasks_--;
804 }
805 tasks_.erase(tasks_.begin() + index);
806 if (index < next_task_index_) {
807 next_task_index_--;
808 }
809 if (!tasks_.empty() && next_task_index_ >= tasks_.size()) {
810 AdvanceTaskIndex();
811 }
812 }
813 }
814 for (auto& task : resp.task_info()) {
815 auto it = task_id_to_task.find(task.task_id());
816 if (it == task_id_to_task.end()) {
817 continue;
818 }
819 if (!ShouldReadFromTask(task)) {
820 VLOG(3) << "Skipping untargeted worker task " << task.task_id();
821 should_finish_iteration_ = false;
822 continue;
823 }
824 Status s = AddTask(it->second);
825 if (!s.ok()) {
826 status_ = s;
827 get_next_cv_.notify_all();
828 break;
829 }
830 }
831 }
832
ShouldReadFromTask(const TaskInfo & task) const833 bool ShouldReadFromTask(const TaskInfo& task) const
834 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
835 if (StrictRoundRobin()) {
836 return true;
837 }
838
839 const bool is_local_task =
840 (LocalWorkers::Get(task.worker_address()) != nullptr);
841 if (dataset()->target_workers_ == TARGET_WORKERS_LOCAL &&
842 !is_local_task) {
843 return false;
844 }
845
846 // Cross-TF/TPU host reads may cause resource contention on the TF/TPU
847 // hosts. tf.data service avoids reading from non-local TF-hosted workers.
848 const bool is_cross_tf_host_read =
849 !is_local_task && IsColocatedTask(task);
850 if (dataset()->target_workers_ == TARGET_WORKERS_AUTO &&
851 is_cross_tf_host_read) {
852 return false;
853 }
854 return true;
855 }
856
RecordTFMetrics(const ClientHeartbeatResponse & resp)857 void RecordTFMetrics(const ClientHeartbeatResponse& resp)
858 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
859 for (const auto& task : resp.task_info()) {
860 if (worker_uids_.contains(task.worker_uid())) {
861 continue;
862 }
863 metrics::RecordTFDataServiceClientIterators(
864 task.worker_uid(), resp.deployment_mode(),
865 dataset()->processing_mode_, dataset()->is_coordinated_read_);
866 worker_uids_.insert(task.worker_uid());
867 }
868 }
869
UpdateBufferSize()870 void UpdateBufferSize() TF_LOCKS_EXCLUDED(mu_) {
871 if (dataset()->max_outstanding_requests_ == model::kAutotune) {
872 // Adjust `max_outstanding_requests_` to account for newly added tasks.
873 // `tasks_` includes the local tasks, so we subtract one from the
874 // configured local task buffer size.
875 mutex_lock l(mu_);
876 int64_t max_outstanding_requests = tasks_.size();
877 if (max_outstanding_requests > max_outstanding_requests_) {
878 worker_thread_cv_.notify_all();
879 }
880 max_outstanding_requests_ = max_outstanding_requests;
881 }
882 }
883
UpdateWorkerThreads(IteratorContext * ctx)884 void UpdateWorkerThreads(IteratorContext* ctx) TF_LOCKS_EXCLUDED(mu_) {
885 mutex_lock l(mu_);
886 const int64_t max_num_threads =
887 std::min<int64_t>(tasks_.size(), max_outstanding_requests_);
888 while (num_running_worker_threads_ < max_num_threads && !cancelled_ &&
889 status_.ok()) {
890 num_running_worker_threads_++;
891 auto done = [this]() {
892 mutex_lock l(mu_);
893 num_running_worker_threads_--;
894 get_next_cv_.notify_all();
895 };
896 worker_threads_.push_back(ctx->StartThread(
897 "tf-data-service-task_thread", [this, done = std::move(done)]() {
898 RunWorkerThread(std::move(done));
899 }));
900 }
901 }
902
// Body of each worker thread started by UpdateWorkerThreads.
//
// Each iteration does two things:
//   1. Under `mu_`: release the task handled in the previous iteration (if
//      any), then wait on `worker_thread_cv_` until GetTaskToProcess() yields
//      a task. Mark that task in use and count it in `outstanding_requests_`.
//   2. Without `mu_`: fetch one element for the task via GetElementTraced().
//
// The thread returns when `cancelled_` is set, when ShouldWaitForNext()
// reports false, or after a fetch error (recorded in `status_`). `done` runs
// on every exit path via the cleanup object.
RunWorkerThread(std::function<void ()> done)903 void RunWorkerThread(std::function<void()> done) {
904 auto cleanup = gtl::MakeCleanup([done = std::move(done)]() {
905 done();
906 VLOG(1) << "Worker thread exiting";
907 });
908 VLOG(1) << "Starting worker thread";
909 std::shared_ptr<Task> task_to_process;
910 while (true) {
911 std::shared_ptr<Result> result;
912 {
913 mutex_lock l(mu_);
// Hand back the task processed in the previous iteration and wake one
// waiting worker thread, since a task and a request slot were freed.
914 if (task_to_process) {
915 task_to_process->in_use = false;
916 --outstanding_requests_;
917 task_to_process = nullptr;
918 worker_thread_cv_.notify_one();
919 }
920 while (true) {
921 if (cancelled_ || !ShouldWaitForNext()) {
922 return;
923 }
924 task_to_process = GetTaskToProcess();
925 if (task_to_process) {
926 VLOG(3) << "Selected a task to process: "
927 << task_to_process->info.ShortDebugString();
928 break;
929 }
930 worker_thread_cv_.wait(l);
931 }
932 DCHECK(task_to_process != nullptr);
933 task_to_process->in_use = true;
934 ++outstanding_requests_;
935 if (StrictRoundRobin()) {
936 // Reserve a spot in the results_ queue.
937 results_.push(std::make_shared<Result>());
938 result = results_.back();
939 } else {
940 // The result will be added to results_ when it's ready.
941 result = std::make_shared<Result>();
942 }
943 VLOG(3) << "Processing task " << task_to_process->info.task_id();
944 }
// With kint64max, GetElement's `now_micros > deadline_micros` check never
// fires, so retries are not cut short by a deadline.
945 int64_t deadline_micros = kint64max;
// In strict round-robin mode the result was already reserved in `results_`
// above, so GetElement must not enqueue it a second time.
946 Status s =
947 GetElementTraced(task_to_process.get(), deadline_micros,
948 /*enqueue_result=*/!StrictRoundRobin(), result);
// NOTE(review): this error path does not notify `worker_thread_cv_`. In strict
// round-robin mode it also leaves the placeholder reserved in `results_`
// not-ready. Consumers presumably rely on the non-OK `status_` — confirm.
949 if (!s.ok()) {
950 mutex_lock l(mu_);
951 VLOG(1) << "Failed to get element from worker "
952 << task_to_process->info.worker_address() << ": " << s;
953 task_to_process->in_use = false;
954 --outstanding_requests_;
955 status_ = errors::CreateWithUpdatedMessage(
956 s, absl::StrCat("Failed to get element from worker ",
957 task_to_process->info.worker_address(), ": ",
958 s.error_message()));
959 get_next_cv_.notify_all();
960 return;
961 }
962 }
963 }
964
965 // Reports whether we can request another element without violating
966 // `max_outstanding_requests_`.
ShouldProcessTask()967 bool ShouldProcessTask() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
968 // When doing round-robin reads, outstanding requests pre-allocate a
969 // result in `results_`, so we only need to check the size of `results_`.
970 if (StrictRoundRobin()) {
971 return results_.size() < max_outstanding_requests_;
972 }
973 // Otherwise, results aren't added to `results_` until the data has been
974 // successfully retrieved. We need to count requests already added to
975 // `results_` as well as in-progress requests.
976 return results_.size() + outstanding_requests_ <
977 max_outstanding_requests_;
978 }
979
980 // Searches for a task to process, visiting tasks in-order and giving every
981 // task a chance to proceed.
// Returns nullptr when ShouldProcessTask() says the request budget is used up,
// or when no task qualifies after visiting each of `tasks_` at most once.
// In strict round-robin mode the search stops at the first task that is still
// in use, or when `current_round_` has reached `round_robin_round_limit_`;
// it never skips ahead past such a task. In every mode a task is skipped if
// its starting round is still in the future, or if it is in use, at
// end_of_sequence, or removed. The chosen task is stamped with
// `current_round_`, and `next_task_index_` moves past it.
GetTaskToProcess()982 std::shared_ptr<Task> GetTaskToProcess() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
983 if (!ShouldProcessTask()) {
984 return nullptr;
985 }
986
987 for (int i = 0; i < tasks_.size(); ++i) {
988 std::shared_ptr<Task>& task = tasks_[next_task_index_];
// Strict round robin: wait for this exact task rather than skip it.
989 if (StrictRoundRobin() &&
990 (task->in_use ||
991 current_round_ >= round_robin_round_limit_.value_or(
992 std::numeric_limits<int64_t>::max()))) {
993 VLOG(4) << "No round robin task found. in_use: " << task->in_use
994 << ". current_round: " << current_round_
995 << ". round_robin_round_limit: "
996 << round_robin_round_limit_.value_or(-1);
997 return nullptr;
998 }
999 if (current_round_ < task->info.starting_round() || task->in_use ||
1000 task->end_of_sequence || task->removed) {
1001 VLOG(3) << "Skipping task " << next_task_index_
1002 << ". starting round: " << task->info.starting_round()
1003 << ". current round: " << current_round_
1004 << ". task->in_use: " << task->in_use
1005 << ". end_of_sequence: " << task->end_of_sequence
1006 << ". task->removed: " << task->removed;
1007 AdvanceTaskIndex();
1008 continue;
1009 }
// Record the round this request belongs to; TryGetElement sends it as the
// round index for strict round-robin reads.
1010 task->round = current_round_;
1011 AdvanceTaskIndex();
1012 return task;
1013 }
1014 return nullptr;
1015 }
1016
1017 // Increments the next task index, starting over if all tasks have been
1018 // processed.
AdvanceTaskIndex()1019 void AdvanceTaskIndex() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1020 next_task_index_++;
1021 if (next_task_index_ >= tasks_.size()) {
1022 current_round_++;
1023 next_task_index_ = 0;
1024 }
1025 }
1026
TryGetElement(const Task & task,GetElementResult & result)1027 Status TryGetElement(const Task& task, GetElementResult& result) {
1028 GetElementRequest req;
1029 req.set_task_id(task.info.task_id());
1030 req.set_skipped_previous_round(task.skipped_previous_round);
1031 if (StrictRoundRobin()) {
1032 req.set_consumer_index(dataset()->consumer_index_.value());
1033 req.set_round_index(task.round);
1034 req.set_allow_skip(true);
1035 }
1036 if (dataset()->cross_trainer_cache_options_) {
1037 req.set_trainer_id(
1038 dataset()->cross_trainer_cache_options_->trainer_id());
1039 }
1040 return task.worker->GetElement(req, result);
1041 }
1042
ProcessGetElementResponse(bool enqueue_result,GetElementResult & get_element_result,std::shared_ptr<Result> result,Task & task)1043 void ProcessGetElementResponse(bool enqueue_result,
1044 GetElementResult& get_element_result,
1045 std::shared_ptr<Result> result, Task& task) {
1046 mutex_lock l(mu_);
1047 result->ready = true;
1048 result->end_of_sequence = get_element_result.end_of_sequence;
1049 result->skip = get_element_result.skip;
1050 if (!get_element_result.end_of_sequence && !get_element_result.skip) {
1051 task.skipped_previous_round = false;
1052 result->element = std::move(get_element_result.components);
1053 result->element_index = get_element_result.element_index;
1054 result->task_id = task.info.task_id();
1055 } else if (get_element_result.skip) {
1056 task.skipped_previous_round = true;
1057 } else {
1058 task.end_of_sequence = true;
1059 finished_tasks_++;
1060 }
1061 if (enqueue_result && !result->end_of_sequence) {
1062 results_.push(std::move(result));
1063 }
1064 get_next_cv_.notify_all();
1065 }
1066
GetElementTraced(Task * task,int64_t deadline_micros,bool enqueue_result,std::shared_ptr<Result> result)1067 Status GetElementTraced(Task* task, int64_t deadline_micros,
1068 bool enqueue_result, std::shared_ptr<Result> result)
1069 TF_LOCKS_EXCLUDED(mu_) {
1070 VLOG(3) << "Getting an element for task id " << task->info.task_id();
1071 tensorflow::profiler::TraceMe activity(
1072 "GetDataServiceElement", tensorflow::profiler::TraceMeLevel::kInfo);
1073 activity.AppendMetadata([&]() {
1074 return profiler::TraceMeEncode(
1075 {{"address", task->info.worker_address()}});
1076 });
1077 if (StrictRoundRobin()) {
1078 VLOG(3) << "Requesting element from consumer index "
1079 << dataset()->consumer_index_.value() << ", round "
1080 << task->round;
1081 activity.AppendMetadata([&]() {
1082 return profiler::TraceMeEncode(
1083 {{"consumer_index", dataset()->consumer_index_.value()},
1084 {"round_index", task->round}});
1085 });
1086 }
1087 Status s = GetElement(task, deadline_micros, enqueue_result, result);
1088 mutex_lock l(mu_);
1089 VLOG(3) << "Got an element for task id " << task->info.task_id();
1090 return s;
1091 }
1092
MaybeRemoveTask(Task & task,int64_t deadline_micros,Result & result)1093 Status MaybeRemoveTask(Task& task, int64_t deadline_micros,
1094 Result& result) {
1095 bool removed;
1096 VLOG(1) << "Requesting task removal for worker "
1097 << task.info.worker_address() << " in round " << task.round;
1098 TF_RETURN_IF_ERROR(grpc_util::Retry(
1099 [&] {
1100 return dispatcher_->MaybeRemoveTask(
1101 task.info.task_id(), dataset()->consumer_index_.value(),
1102 task.round, removed);
1103 },
1104 /*should_retry=*/
1105 [&] {
1106 mutex_lock l(mu_);
1107 return !cancelled_;
1108 },
1109 /*description=*/"request task removal ", deadline_micros));
1110 if (removed) {
1111 mutex_lock l(mu_);
1112 task.removed = true;
1113 result.ready = true;
1114 result.skip = true;
1115 get_next_cv_.notify_all();
1116 return OkStatus();
1117 }
1118 VLOG(1) << "Failed to remove task for worker "
1119 << task.info.worker_address();
1120 return OkStatus();
1121 }
1122
// Fetches one element for `task`, retrying errors that IsPreemptedError()
// accepts. Any other error is returned immediately.
//
// Between attempts it returns Cancelled if the iterator was cancelled, and
// returns the last error once `deadline_micros` has passed. Otherwise it
// sleeps for ComputeBackoffMicroseconds(num_retries), capped at the deadline.
// For strict round-robin reads, from the second failed attempt onward it asks
// the dispatcher to remove the task (MaybeRemoveTask). If that turns `result`
// into a skip, it returns OK without processing a response. On a successful
// fetch the reply is published via ProcessGetElementResponse.
GetElement(Task * task,int64_t deadline_micros,bool enqueue_result,std::shared_ptr<Result> result)1123 Status GetElement(Task* task, int64_t deadline_micros, bool enqueue_result,
1124 std::shared_ptr<Result> result) TF_LOCKS_EXCLUDED(mu_) {
1125 GetElementResult get_element_result;
1126 for (int num_retries = 0;; ++num_retries) {
1127 Status s = TryGetElement(*task, get_element_result);
1128 if (s.ok()) break;
1129 // Retry all errors that could indicate preemption.
1130 if (!IsPreemptedError(s)) {
1131 return s;
1132 }
1133 {
1134 mutex_lock l(mu_);
1135 if (cancelled_) {
1136 return errors::Cancelled("DataServiceDataset iterator cancelled");
1137 }
1138 }
1139 int64_t now_micros = Env::Default()->NowMicros();
1140 if (now_micros > deadline_micros) {
1141 return s;
1142 }
// `num_retries > 0`: only after at least one earlier failed attempt.
1143 if (StrictRoundRobin() && num_retries > 0) {
1144 TF_RETURN_IF_ERROR(MaybeRemoveTask(*task, deadline_micros, *result));
1145 mutex_lock l(mu_);
1146 if (result->skip) {
1147 return OkStatus();
1148 }
1149 }
1150 int64_t backoff_until = std::min(
1151 deadline_micros,
1152 now_micros + ::tensorflow::ComputeBackoffMicroseconds(num_retries));
1153 VLOG(1) << "Failed to get an element from worker "
1154 << task->info.worker_address() << ": " << s
1155 << ". Will retry in " << (backoff_until - now_micros)
1156 << " microseconds";
1157 Env::Default()->SleepForMicroseconds(backoff_until - now_micros);
1158 }
1159 ProcessGetElementResponse(enqueue_result, get_element_result, result,
1160 *task);
1161 return OkStatus();
1162 }
1163
ResultReady() const1164 bool ResultReady() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1165 return !results_.empty() && results_.front()->ready;
1166 }
1167
PopNextResult()1168 std::shared_ptr<Result> PopNextResult() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1169 std::shared_ptr<Result> result = results_.front();
1170 results_.pop();
1171 return result;
1172 }
1173
StrictRoundRobin() const1174 bool StrictRoundRobin() const {
1175 return dataset()->num_consumers_.has_value();
1176 }
1177
DebugString() const1178 std::string DebugString() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1179 return absl::Substitute(
1180 "results_ { size: $0 front.ready: $1 } iteration_finished_: $2 "
1181 "tasks { size: $3 } finished_tasks_: $4 "
1182 "num_running_worker_threads_: $5",
1183 results_.size(), !results_.empty() && results_.front()->ready,
1184 iteration_finished_, tasks_.size(), finished_tasks_,
1185 num_running_worker_threads_);
1186 }
1187
1188 const DataServiceParams data_service_params_;
1189
// Guards every member annotated TF_GUARDED_BY(mu_).
1190 mutable mutex mu_;
// Notified when results, `status_`, or worker-thread state change.
1191 condition_variable get_next_cv_ TF_GUARDED_BY(mu_);
// Worker threads wait on this in RunWorkerThread until a task can be taken.
1192 condition_variable worker_thread_cv_ TF_GUARDED_BY(mu_);
1193 condition_variable manager_thread_cv_ TF_GUARDED_BY(mu_);
1194 bool cancelled_ TF_GUARDED_BY(mu_) = false;
1195 // Method for deregistering the cancellation callback.
1196 std::function<void()> deregister_fn_;
1197
1198 // Number of outstanding requests.
// Incremented when a worker thread takes a task in RunWorkerThread, and
// decremented when that thread releases the task.
1199 int64_t outstanding_requests_ TF_GUARDED_BY(mu_) = 0;
1200
1201 // max_outstanding_requests controls how many elements may be held in memory
1202 // at the same time. This count includes both in-progress requests for
1203 // elements as well as completed requests which haven't yet been produced.
1204 int64_t max_outstanding_requests_ TF_GUARDED_BY(mu_);
1205
1206 // The number of threads in `worker_threads_` which are still running.
1207 int64_t num_running_worker_threads_ TF_GUARDED_BY(mu_) = 0;
1208
1209 // The index of the next task in `tasks_` to read from.
1210 int64_t next_task_index_ TF_GUARDED_BY(mu_) = 0;
1211
1212 // The number of tasks in the `tasks_` list that have reached end_of_sequence.
1213 int64_t finished_tasks_ TF_GUARDED_BY(mu_) = 0;
1214
1215 // List of tasks to read from.
1216 std::vector<std::shared_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);
1217
1218 // The current round robin round we are engaged in. A round involves reading
1219 // from each task once.
1220 int64_t current_round_ TF_GUARDED_BY(mu_) = 0;
1221
1222 // Maximum round robin round to read up to before blocking, not inclusive.
1223 // INVARIANT: current_round_ <= round_robin_round_limit_.
1224 // If current_round_ == round_robin_round_limit_,
1225 // next_task_index_ must be 0.
1226 std::optional<int64_t> round_robin_round_limit_ TF_GUARDED_BY(mu_);
1227
1228 // A status to be returned from the next call to `GetNext`. This is set by
1229 // asynchronous threads when they encounter errors.
1230 Status status_ TF_GUARDED_BY(mu_) = OkStatus();
1231 // A queue of results for `GetElement` requests to read from. When doing
1232 // strict round robin reads, the queue will contain placeholder results with
1233 // their `Result::ready` field false until their data has been retrieved
1234 // from a worker. When not doing round-robin reads, results are only added
1235 // to the queue after they are ready, to avoid head-of-line blocking.
1236 std::queue<std::shared_ptr<Result>> results_ TF_GUARDED_BY(mu_);
1237
1238 bool initialized_ = false;
1239 // Set once in Initialize().
1240 int64_t job_id_;
1241 int64_t iteration_client_id_;
1242 std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
1243 int64_t get_next_index_ TF_GUARDED_BY(mu_) = 0;
1244
// NOTE(review): read while holding `mu_` (e.g. in UpdateTasks and
// DebugString) but not annotated TF_GUARDED_BY(mu_). Confirm whether the
// annotation was intentionally left off.
1245 bool iteration_finished_ = false;
// Cleared by UpdateTasks when a heartbeat lists a task this client chose not
// to read (see ShouldReadFromTask).
1246 bool should_finish_iteration_ TF_GUARDED_BY(mu_) = true;
1247
1248 // The set of worker UIDs that we have already recorded metrics for.
1249 absl::flat_hash_set<int64_t> worker_uids_ TF_GUARDED_BY(mu_);
1250
// Threads started by UpdateWorkerThreads; each runs RunWorkerThread.
1251 std::vector<std::unique_ptr<Thread>> worker_threads_ TF_GUARDED_BY(mu_);
1252 std::unique_ptr<Thread> task_thread_manager_ TF_GUARDED_BY(mu_);
1253 };
1254
// Dataset configuration, fixed at construction (see the arguments passed by
// DataServiceDatasetOp::MakeDataset).
1255 const int op_version_;
1256 const tstring dataset_id_;
1257 const ProcessingModeDef processing_mode_;
1258 const tstring address_;
1259 const tstring protocol_;
1260 const tstring data_transfer_protocol_;
1261 const tstring job_name_;
1262 const bool is_coordinated_read_;
// MakeDataset passes these as unset when the corresponding op input is
// negative. A set `num_consumers_` selects strict round-robin reads in the
// iterator (StrictRoundRobin()).
1263 const std::optional<int64_t> consumer_index_;
1264 const std::optional<int64_t> num_consumers_;
// MakeDataset only passes a positive value or model::kAutotune.
1265 const int64_t max_outstanding_requests_;
1266 const int64_t task_refresh_interval_ms_;
1267 const TargetWorkers target_workers_;
1268 const DataServiceMetadata metadata_;
1269 IterationCounter* const iteration_counter_; // Owned
// True when MakeDataset created the counter resource itself because the
// lookup by handle returned NotFound.
1270 const bool owns_resource_;
1271 const ResourceHandle iteration_counter_handle_;
1272 ResourceMgr* const resource_mgr_; // Not owned
1273 const std::unique_ptr<CapturedFunction> captured_uncompress_func_;
1274 const std::optional<CrossTrainerCacheOptions> cross_trainer_cache_options_;
1275 const DataTypeVector output_types_;
1276 const std::vector<PartialTensorShape> output_shapes_;
1277 };
1278
DataServiceDatasetOp(OpKernelConstruction * ctx)1279 DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx)
1280 : DatasetOpKernel(ctx) {
1281 const auto& op_name = ctx->def().op();
1282 if (op_name == kDataServiceDatasetV1) {
1283 op_version_ = 1;
1284 } else if (op_name == kDataServiceDatasetV2) {
1285 op_version_ = 2;
1286 } else if (op_name == kDataServiceDatasetV3) {
1287 op_version_ = 3;
1288 } else if (op_name == kDataServiceDatasetV4) {
1289 op_version_ = 4;
1290 } else {
1291 ctx->CtxFailure(errors::FailedPrecondition(
1292 "Unrecognized data service dataset op name: ", op_name));
1293 return;
1294 }
1295
1296 OP_REQUIRES_OK(ctx, ctx->GetAttr(kTaskRefreshIntervalHintMs,
1297 &task_refresh_interval_hint_ms_));
1298 if (task_refresh_interval_hint_ms_ == model::kAutotune) {
1299 task_refresh_interval_hint_ms_ = kDefaultTaskRefreshIntervalMs;
1300 }
1301 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
1302 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
1303 if (ctx->HasAttr(kDataTransferProtocol)) {
1304 OP_REQUIRES_OK(
1305 ctx, ctx->GetAttr(kDataTransferProtocol, &data_transfer_protocol_));
1306 }
1307 if (data_transfer_protocol_.empty()) {
1308 data_transfer_protocol_ = kGrpcTransferProtocol;
1309 }
1310
1311 std::string target_workers_str = "AUTO";
1312 if (ctx->HasAttr(kTargetWorkers)) {
1313 OP_REQUIRES_OK(ctx, ctx->GetAttr(kTargetWorkers, &target_workers_str));
1314 }
1315 StatusOr<TargetWorkers> status_or_target_workers =
1316 ParseTargetWorkers(target_workers_str);
1317 OP_REQUIRES_OK(ctx, status_or_target_workers.status());
1318 target_workers_ = *status_or_target_workers;
1319
1320 if (op_version_ >= 3) {
1321 OP_REQUIRES_OK(ctx, ctx->GetAttr(kUncompress, &uncompress_));
1322 FunctionMetadata::Params params;
1323 params.use_inter_op_parallelism = true;
1324 OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kUncompressFn, params,
1325 &uncompress_fn_));
1326 }
1327
1328 if (ctx->HasAttr(kCrossTrainerCacheOptions)) {
1329 OP_REQUIRES_OK(ctx, ctx->GetAttr(kCrossTrainerCacheOptions,
1330 &seriazlied_cross_trainer_cache_options_));
1331 }
1332 }
1333
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)1334 void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
1335 DatasetBase** output) {
1336 tstring dataset_id;
1337 if (op_version_ >= 4) {
1338 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id));
1339 } else {
1340 int64_t dataset_id_int = 0;
1341 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id_int));
1342 dataset_id = absl::StrCat(dataset_id_int);
1343 }
1344
1345 tstring processing_mode_str;
1346 OP_REQUIRES_OK(
1347 ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str));
1348 ProcessingModeDef processing_mode;
1349 if (processing_mode_str == kParallelEpochs) {
1350 processing_mode.set_sharding_policy(ProcessingModeDef::OFF);
1351 } else if (processing_mode_str == kDistributedEpoch) {
1352 processing_mode.set_sharding_policy(ProcessingModeDef::DYNAMIC);
1353 } else {
1354 OP_REQUIRES(ctx, processing_mode.ParseFromString(processing_mode_str),
1355 errors::InvalidArgument(absl::Substitute(
1356 "Failed to parse ProcessingModeDef from string: $0",
1357 std::string(processing_mode_str))));
1358 }
1359
1360 tstring address;
1361 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
1362 OP_REQUIRES(ctx, !address.empty(),
1363 errors::InvalidArgument(kAddress, " must be non-empty."));
1364
1365 tstring protocol;
1366 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &protocol));
1367 OP_REQUIRES(ctx, !protocol.empty(),
1368 errors::InvalidArgument(kProtocol, " must be non-empty."));
1369
1370 tstring job_name;
1371 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kJobName, &job_name));
1372
1373 StatusOr<DataServiceConfig> config = GetDataServiceConfig(address, protocol);
1374 OP_REQUIRES_OK(ctx, config.status());
1375
1376 if (IsStaticShard(processing_mode) &&
1377 config->deployment_mode() == DEPLOYMENT_MODE_COLOCATED &&
1378 target_workers_ == TARGET_WORKERS_AUTO) {
1379 VLOG(1) << "Using LOCAL target workers for static sharding mode: "
1380 << processing_mode.ShortDebugString();
1381 target_workers_ = TARGET_WORKERS_LOCAL;
1382 }
1383 if (target_workers_ == TARGET_WORKERS_LOCAL) {
1384 data_transfer_protocol_ = kLocalTransferProtocol;
1385 }
1386
1387 std::optional<int64_t> consumer_index;
1388 std::optional<int64_t> num_consumers;
1389 if (op_version_ >= 2) {
1390 int64_t consumer_index_int;
1391 OP_REQUIRES_OK(
1392 ctx, ParseScalarArgument(ctx, kConsumerIndex, &consumer_index_int));
1393 if (consumer_index_int >= 0) {
1394 consumer_index = consumer_index_int;
1395 }
1396
1397 int64_t num_consumers_int;
1398 OP_REQUIRES_OK(ctx,
1399 ParseScalarArgument(ctx, kNumConsumers, &num_consumers_int));
1400 if (num_consumers_int >= 0) {
1401 num_consumers = num_consumers_int;
1402 }
1403 }
1404
1405 int64_t max_outstanding_requests;
1406 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxOutstandingRequests,
1407 &max_outstanding_requests));
1408
1409 ResourceHandle iteration_counter_handle;
1410 OP_REQUIRES_OK(
1411 ctx, HandleFromInput(ctx, kIterationCounter, &iteration_counter_handle));
1412 IterationCounter* iteration_counter = nullptr;
1413 Status s = ctx->resource_manager()->Lookup<IterationCounter>(
1414 iteration_counter_handle.container(), iteration_counter_handle.name(),
1415 &iteration_counter);
1416 bool owns_resource = false;
1417 if (errors::IsNotFound(s)) {
1418 owns_resource = true;
1419 static std::atomic<int64_t> resource_id_counter(0);
1420 const std::string& container = ctx->resource_manager()->default_container();
1421 std::string name =
1422 strings::StrCat(ctx->op_kernel().name(), "/", kIterationCounter, "_",
1423 resource_id_counter.fetch_add(1));
1424 OP_REQUIRES_OK(ctx,
1425 ctx->resource_manager()->LookupOrCreate<IterationCounter>(
1426 container, name, &iteration_counter,
1427 [](IterationCounter** counter) {
1428 *counter = new IterationCounter();
1429 return OkStatus();
1430 }));
1431 iteration_counter_handle =
1432 MakeResourceHandle<IterationCounter>(ctx, container, name);
1433 } else {
1434 OP_REQUIRES_OK(ctx, s);
1435 }
1436
1437 OP_REQUIRES(
1438 ctx,
1439 max_outstanding_requests == model::kAutotune ||
1440 max_outstanding_requests > 0,
1441 errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ",
1442 model::kAutotune));
1443
1444 StatusOr<DataServiceMetadata> metadata =
1445 GetDataServiceMetadata(dataset_id, address, protocol);
1446 OP_REQUIRES_OK(ctx, metadata.status());
1447
1448 bool should_uncompress = op_version_ >= 3 && uncompress_;
1449 if (should_uncompress) {
1450 StatusOr<DataServiceMetadata::Compression> compression =
1451 GetValidatedCompression(dataset_id, *metadata);
1452 OP_REQUIRES_OK(ctx, compression.status());
1453 should_uncompress =
1454 should_uncompress &&
1455 (*compression == DataServiceMetadata::COMPRESSION_SNAPPY);
1456 }
1457 DataTypeVector data_service_output_types = output_types_;
1458 std::vector<PartialTensorShape> data_service_output_shapes = output_shapes_;
1459 if (should_uncompress) {
1460 data_service_output_types = {DT_VARIANT};
1461 data_service_output_shapes = {TensorShape({})};
1462 }
1463
1464 std::unique_ptr<CapturedFunction> captured_uncompress_func;
1465 if (op_version_ >= 3) {
1466 OP_REQUIRES_OK(
1467 ctx, CapturedFunction::Create(ctx, uncompress_fn_,
1468 /*captured_inputs=*/std::vector<Tensor>{},
1469 &captured_uncompress_func));
1470 }
1471
1472 std::optional<CrossTrainerCacheOptions> cross_trainer_cache_options;
1473 if (!seriazlied_cross_trainer_cache_options_.empty()) {
1474 cross_trainer_cache_options.emplace();
1475 cross_trainer_cache_options->ParseFromString(
1476 seriazlied_cross_trainer_cache_options_);
1477 }
1478 DatasetBase* dataset = new Dataset(
1479 ctx, op_version_, dataset_id, processing_mode, address, protocol,
1480 data_transfer_protocol_, job_name, consumer_index, num_consumers,
1481 max_outstanding_requests, task_refresh_interval_hint_ms_, target_workers_,
1482 *metadata, iteration_counter, owns_resource, iteration_counter_handle,
1483 std::move(captured_uncompress_func), cross_trainer_cache_options,
1484 data_service_output_types, data_service_output_shapes);
1485 if (should_uncompress) {
1486 VLOG(2) << "Inserting a ParallelMap dataset to uncompress tf.data service "
1487 << "dataset " << dataset_id << ".";
1488 dataset->Initialize(/*metadata=*/{});
1489 captured_uncompress_func.reset();
1490 OP_REQUIRES_OK(
1491 ctx, CapturedFunction::Create(ctx, uncompress_fn_,
1492 /*captured_inputs=*/std::vector<Tensor>{},
1493 &captured_uncompress_func));
1494
1495 // Release the ownership of `dataset` and transfer it to the ParallelMap
1496 // dataset for uncompression.
1497 core::ScopedUnref unref(dataset);
1498 dataset = MakeDataServiceUncompressDataset(
1499 /*input=*/dataset, std::move(captured_uncompress_func),
1500 output_types_, output_shapes_)
1501 .release();
1502 }
1503 *output = dataset;
1504 }
1505
// CPU kernels for every supported op version (V1-V4). All share
// DataServiceDatasetOp, which derives `op_version_` from the op name.
1506 REGISTER_KERNEL_BUILDER(Name(kDataServiceDatasetV1).Device(DEVICE_CPU),
1507 DataServiceDatasetOp);
1508 REGISTER_KERNEL_BUILDER(Name(kDataServiceDatasetV2).Device(DEVICE_CPU),
1509 DataServiceDatasetOp);
1510 REGISTER_KERNEL_BUILDER(Name(kDataServiceDatasetV3).Device(DEVICE_CPU),
1511 DataServiceDatasetOp);
1512 REGISTER_KERNEL_BUILDER(Name(kDataServiceDatasetV4).Device(DEVICE_CPU),
1513 DataServiceDatasetOp);
// CPU kernel for the `DummyIterationCounter` op, implemented by
// DummyResourceOp<IterationCounter>.
1514 REGISTER_KERNEL_BUILDER(Name("DummyIterationCounter").Device(DEVICE_CPU),
1515 DummyResourceOp<IterationCounter>);
1516
1517 } // namespace data
1518 } // namespace tensorflow
1519