1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/service/dispatcher_impl.h"
17
18 #include <algorithm>
19 #include <array>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #ifdef PLATFORM_GOOGLE
27 #include "file/logging/log_lines.h"
28 #endif
29 #include "grpcpp/create_channel.h"
30 #include "grpcpp/impl/codegen/server_context.h"
31 #include "grpcpp/security/credentials.h"
32 #include "absl/container/flat_hash_map.h"
33 #include "absl/container/flat_hash_set.h"
34 #include "absl/memory/memory.h"
35 #include "absl/time/time.h"
36 #include "tensorflow/core/data/dataset_utils.h"
37 #include "tensorflow/core/data/hash_utils.h"
38 #include "tensorflow/core/data/service/common.h"
39 #include "tensorflow/core/data/service/common.pb.h"
40 #include "tensorflow/core/data/service/credentials_factory.h"
41 #include "tensorflow/core/data/service/dataset_store.h"
42 #include "tensorflow/core/data/service/dispatcher.pb.h"
43 #include "tensorflow/core/data/service/dispatcher_state.h"
44 #include "tensorflow/core/data/service/export.pb.h"
45 #include "tensorflow/core/data/service/grpc_util.h"
46 #include "tensorflow/core/data/service/journal.h"
47 #include "tensorflow/core/data/service/validate_utils.h"
48 #include "tensorflow/core/data/service/worker.grpc.pb.h"
49 #include "tensorflow/core/data/standalone.h"
50 #include "tensorflow/core/framework/dataset.h"
51 #include "tensorflow/core/framework/graph.pb.h"
52 #include "tensorflow/core/framework/metrics.h"
53 #include "tensorflow/core/framework/node_def.pb.h"
54 #include "tensorflow/core/framework/tensor.h"
55 #include "tensorflow/core/platform/env.h"
56 #include "tensorflow/core/platform/errors.h"
57 #include "tensorflow/core/platform/mutex.h"
58 #include "tensorflow/core/platform/path.h"
59 #include "tensorflow/core/platform/protobuf.h"
60 #include "tensorflow/core/platform/random.h"
61 #include "tensorflow/core/platform/status.h"
62 #include "tensorflow/core/platform/statusor.h"
63 #include "tensorflow/core/platform/strcat.h"
64 #include "tensorflow/core/platform/thread_annotations.h"
65 #include "tensorflow/core/protobuf/data_service.pb.h"
66 #include "tensorflow/core/protobuf/service_config.pb.h"
67 #include "tensorflow/core/public/session_options.h"
68
69 namespace tensorflow {
70 namespace data {
71 namespace {
72
73 using ::tensorflow::protobuf::util::MessageDifferencer;
74
75 // The name of the journal directory inside the dispatcher's working directory.
76 // This name is load-bearing; do not change.
77 constexpr char kJournalDir[] = "tf_data_dispatcher_journal";
78 // The name of the datasets directory inside the dispatcher's working directory.
79 constexpr char kDatasetsDir[] = "datasets";
80 constexpr int64_t kDefaultIterationGcCheckIntervalMs =
81 10 * 60 * 1000; // 10 minutes.
82 constexpr int64_t kDefaultIterationGcTimeoutMs = 5 * 60 * 1000; // 5 minutes.
83 constexpr int64_t kDefaultClientTimeoutMs = 2 * 60 * 1000; // 2 minutes.
84
85 constexpr std::array<const char*, 8> kNodeNameSharingOps = {
86 "HashTable",
87 "HashTableV2",
88 "MutableHashTable",
89 "MutableHashTableV2",
90 "MutableDenseHashTable",
91 "MutableDenseHashTableV2",
92 "MutableHashTableOfTensors",
93 "MutableHashTableOfTensorsV2",
94 };
95
96 using DispatcherConfig = experimental::DispatcherConfig;
97 using Dataset = DispatcherState::Dataset;
98 using Worker = DispatcherState::Worker;
99 using Job = DispatcherState::Job;
100 using IterationKey = DispatcherState::IterationKey;
101 using Iteration = DispatcherState::Iteration;
102 using Task = DispatcherState::Task;
103
JournalDir(const std::string & work_dir)104 std::string JournalDir(const std::string& work_dir) {
105 return io::JoinPath(work_dir, kJournalDir);
106 }
107
DatasetsDir(const std::string & work_dir)108 std::string DatasetsDir(const std::string& work_dir) {
109 return io::JoinPath(work_dir, kDatasetsDir);
110 }
111
DatasetKey(const std::string & dataset_id,uint64 fingerprint)112 std::string DatasetKey(const std::string& dataset_id, uint64 fingerprint) {
113 return absl::StrCat("id_", dataset_id, "_fp_", fingerprint);
114 }
115
CreateWorkerStub(const std::string & address,const std::string & protocol,std::unique_ptr<WorkerService::Stub> & stub)116 Status CreateWorkerStub(const std::string& address, const std::string& protocol,
117 std::unique_ptr<WorkerService::Stub>& stub) {
118 ::grpc::ChannelArguments args;
119 args.SetMaxReceiveMessageSize(-1);
120 std::shared_ptr<::grpc::ChannelCredentials> credentials;
121 TF_RETURN_IF_ERROR(
122 CredentialsFactory::CreateClientCredentials(protocol, &credentials));
123 auto channel = ::grpc::CreateCustomChannel(address, credentials, args);
124 stub = WorkerService::NewStub(channel);
125 return OkStatus();
126 }
127
PrepareGraph(GraphDef * graph)128 void PrepareGraph(GraphDef* graph) {
129 for (NodeDef& node : *graph->mutable_node()) {
130 for (const auto& op : kNodeNameSharingOps) {
131 // Set `use_node_name_sharing` to `true` so that resources aren't deleted
132 // prematurely. Otherwise, resources may be deleted when their ops are
133 // deleted at the end of the GraphRunner::Run used by standalone::Dataset.
134 if (node.op() == op) {
135 (*node.mutable_attr())["use_node_name_sharing"].set_b(true);
136 }
137 if (!node.device().empty()) {
138 *node.mutable_device() = "";
139 }
140 }
141 }
142 StripDevicePlacement(graph->mutable_library());
143 }
144
// Returns a copy of `config` where any zero-valued GC/timeout field is
// replaced with its default.
DispatcherConfig ApplyConfigDefaults(const DispatcherConfig& config) {
  DispatcherConfig result(config);
  if (result.job_gc_check_interval_ms() == 0) {
    result.set_job_gc_check_interval_ms(kDefaultIterationGcCheckIntervalMs);
  }
  if (result.job_gc_timeout_ms() == 0) {
    result.set_job_gc_timeout_ms(kDefaultIterationGcTimeoutMs);
  }
  if (result.client_timeout_ms() == 0) {
    result.set_client_timeout_ms(kDefaultClientTimeoutMs);
  }
  return result;
}
158
// VLOGs `message` at `log_level`. On Google-internal builds, uses VLOG_LINES
// so long messages are split across multiple log entries instead of being
// truncated; elsewhere the message is emitted as a single VLOG line.
void VLogLines(const int log_level, const std::string& message) {
#if defined(PLATFORM_GOOGLE)
  VLOG_LINES(log_level, message);
#else
  VLOG(log_level) << message;
#endif
}
166 } // namespace
167
DataServiceDispatcherImpl(const DispatcherConfig & config)168 DataServiceDispatcherImpl::DataServiceDispatcherImpl(
169 const DispatcherConfig& config)
170 : config_(ApplyConfigDefaults(config)),
171 env_(Env::Default()),
172 state_(config_) {
173 if (config_.work_dir().empty()) {
174 dataset_store_ = std::make_unique<MemoryDatasetStore>();
175 } else {
176 dataset_store_ = std::make_unique<FileSystemDatasetStore>(
177 DatasetsDir(config_.work_dir()));
178 }
179 }
180
DataServiceDispatcherImpl::~DataServiceDispatcherImpl() {
  {
    // Set `cancelled_` and signal under the lock so the GC thread reliably
    // observes the cancellation when it wakes from its condition variable.
    mutex_lock l(mu_);
    cancelled_ = true;
    iteration_gc_thread_cv_.notify_all();
  }
  // Destroying the Thread joins it; must happen after releasing `mu_` so the
  // GC thread can acquire the lock and exit.
  iteration_gc_thread_.reset();
}
189
// Starts the dispatcher: launches the iteration GC thread, validates the
// work-dir/fault-tolerance configuration, and (in fault-tolerant mode)
// replays the journal to restore dispatcher state before accepting requests.
Status DataServiceDispatcherImpl::Start() {
  mutex_lock l(mu_);
  if (config_.job_gc_timeout_ms() >= 0) {
    iteration_gc_thread_ = absl::WrapUnique(env_->StartThread(
        {}, "iteration-gc-thread", [&] { IterationGcThread(); }));
  }
  if (config_.work_dir().empty()) {
    // Fault tolerance requires somewhere to persist the journal.
    if (config_.fault_tolerant_mode()) {
      return errors::InvalidArgument(
          "fault_tolerant_mode is True, but no work_dir is configured.");
    }
  } else {
    TF_RETURN_IF_ERROR(
        env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
  }
  if (!config_.fault_tolerant_mode()) {
    LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
                 "not be able to recover its state on restart.";
    started_ = true;
    return OkStatus();
  }
  journal_writer_ =
      std::make_unique<FileJournalWriter>(env_, JournalDir(config_.work_dir()));
  LOG(INFO) << "Attempting to restore dispatcher state from journal in "
            << JournalDir(config_.work_dir());
  Update update;
  bool end_of_journal = false;
  FileJournalReader reader(env_, JournalDir(config_.work_dir()));
  Status s = reader.Read(update, end_of_journal);
  if (errors::IsNotFound(s)) {
    // First start with this work_dir: nothing to replay.
    LOG(INFO) << "No journal found. Starting dispatcher from new state.";
  } else if (!s.ok()) {
    return s;
  } else {
    // Replay every journaled update without re-journaling it.
    while (!end_of_journal) {
      TF_RETURN_IF_ERROR(ApplyWithoutJournaling(update));
      TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
    }
  }
  // Rebuild in-memory split providers for dynamically-sharded iterations,
  // fast-forwarding each to its journaled position.
  for (const auto& iteration : state_.ListIterations()) {
    if (IsDynamicShard(iteration->job->processing_mode)) {
      TF_RETURN_IF_ERROR(RestoreSplitProviders(
          *iteration, split_providers_[iteration->iteration_id]));
    }
  }
  for (const auto& client_id : state_.ListActiveClientIds()) {
    // Conservatively pretend we just received a heartbeat from all clients, so
    // that we don't garbage collect iterations too early.
    latest_client_heartbeats_time_[client_id] =
        absl::FromUnixMicros(env_->NowMicros());
  }
  // Initialize the journal writer in `Start` so that we fail fast in case it
  // can't be initialized.
  TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized());
  started_ = true;
  return OkStatus();
}
247
NumActiveIterations()248 size_t DataServiceDispatcherImpl::NumActiveIterations() TF_LOCKS_EXCLUDED(mu_) {
249 mutex_lock l(mu_);
250 size_t count = 0;
251 for (const auto& iteration : state_.ListIterations()) {
252 if (!iteration->finished) {
253 count++;
254 }
255 }
256 return count;
257 }
258
RestoreSplitProviders(const Iteration & iteration,std::vector<std::unique_ptr<SplitProvider>> & restored)259 Status DataServiceDispatcherImpl::RestoreSplitProviders(
260 const Iteration& iteration,
261 std::vector<std::unique_ptr<SplitProvider>>& restored)
262 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
263 const std::vector<int64_t>& indices =
264 iteration.distributed_epoch_state.value().indices;
265 std::vector<std::unique_ptr<SplitProvider>> split_providers;
266 TF_RETURN_IF_ERROR(
267 MakeSplitProviders(iteration.job->dataset_id, split_providers));
268 for (int provider_index = 0; provider_index < indices.size();
269 ++provider_index) {
270 int index = indices[provider_index];
271 VLOG(1) << "Restoring split provider " << provider_index
272 << " for iteration " << iteration.iteration_id << " to index "
273 << index;
274 Tensor unused_tensor;
275 bool unused_end_of_splits;
276 for (int i = 0; i < index; ++i) {
277 TF_RETURN_IF_ERROR(split_providers[provider_index]->GetNext(
278 &unused_tensor, &unused_end_of_splits));
279 }
280 }
281 restored = std::move(split_providers);
282 return OkStatus();
283 }
284
FindTasksToDelete(const absl::flat_hash_set<int64_t> & current_tasks,const std::vector<std::shared_ptr<const Task>> assigned_tasks,WorkerHeartbeatResponse * response)285 Status DataServiceDispatcherImpl::FindTasksToDelete(
286 const absl::flat_hash_set<int64_t>& current_tasks,
287 const std::vector<std::shared_ptr<const Task>> assigned_tasks,
288 WorkerHeartbeatResponse* response) {
289 absl::flat_hash_set<int64_t> assigned_ids;
290 for (const auto& assigned : assigned_tasks) {
291 assigned_ids.insert(assigned->task_id);
292 }
293 for (int64_t current_task : current_tasks) {
294 if (!assigned_ids.contains(current_task)) {
295 response->add_tasks_to_delete(current_task);
296 }
297 }
298 return OkStatus();
299 }
300
// Finds tasks assigned to `worker_address` that the worker does not yet know
// about, populating `response->new_tasks`. Also recreates pending tasks for
// round-robin iterations that previously lost their task on this worker.
// `assigned_tasks` is refreshed in place to include any newly created tasks.
Status DataServiceDispatcherImpl::FindNewTasks(
    const std::string& worker_address,
    const absl::flat_hash_set<int64_t>& current_tasks,
    std::vector<std::shared_ptr<const Task>>& assigned_tasks,
    WorkerHeartbeatResponse* response) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  // Check for round-robin iterations that had tasks on the worker removed. Now
  // that the worker is back, we create a new pending task for the worker.
  absl::flat_hash_set<int64_t> assigned_iteration_ids;
  for (const auto& task : assigned_tasks) {
    assigned_iteration_ids.insert(task->iteration->iteration_id);
  }
  for (const auto& iteration : state_.ListIterations()) {
    if (!assigned_iteration_ids.contains(iteration->iteration_id) &&
        iteration->IsRoundRobin() && !iteration->finished) {
      VLOG(1) << "Creating pending task for reconnected worker "
              << worker_address;
      TF_RETURN_IF_ERROR(CreatePendingTask(iteration, worker_address));
    }
  }
  // Refresh assigned_tasks to include newly added pending tasks.
  TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, assigned_tasks));
  for (const auto& task : assigned_tasks) {
    // Skip tasks the worker already reports as running.
    if (current_tasks.contains(task->task_id)) {
      continue;
    }
    TaskDef* task_def = response->add_new_tasks();
    TF_RETURN_IF_ERROR(PopulateTaskDef(task, task_def));
  }
  return OkStatus();
}
331
// Handles a worker heartbeat: registers the worker on first contact, then
// reconciles the worker's reported tasks against the dispatcher's
// assignments, filling `response` with tasks to delete and new tasks to run.
Status DataServiceDispatcherImpl::WorkerHeartbeat(
    const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  VLOG(4) << "Received worker heartbeat request from worker "
          << request->worker_address();
  mutex_lock l(mu_);
  const std::string& worker_address = request->worker_address();
  // Assigned tasks from the perspective of the dispatcher.
  std::vector<std::shared_ptr<const Task>> assigned_tasks;
  Status s = state_.TasksForWorker(worker_address, assigned_tasks);
  if (!s.ok()) {
    if (!errors::IsNotFound(s)) {
      return s;
    }
    // NotFound means this worker is new: register it and create its tasks.
    VLOG(1) << "Registering new worker at address " << worker_address;
    TF_RETURN_IF_ERROR(state_.ValidateWorker(worker_address));
    Update update;
    update.mutable_register_worker()->set_worker_address(worker_address);
    update.mutable_register_worker()->set_transfer_address(
        request->transfer_address());
    *update.mutable_register_worker()->mutable_worker_tags() =
        request->worker_tags();
    update.mutable_register_worker()->set_worker_uid(request->worker_uid());
    TF_RETURN_IF_ERROR(Apply(update));
    TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
    TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, assigned_tasks));
  }
  // Tasks the worker itself reports as currently running.
  absl::flat_hash_set<int64_t> current_tasks;
  current_tasks.insert(request->current_tasks().cbegin(),
                       request->current_tasks().cend());
  TF_RETURN_IF_ERROR(
      FindTasksToDelete(current_tasks, assigned_tasks, response));
  TF_RETURN_IF_ERROR(
      FindNewTasks(worker_address, current_tasks, assigned_tasks, response));

  VLOG(4) << "Finished worker heartbeat for worker at address "
          << request->worker_address();
  return OkStatus();
}
371
WorkerUpdate(const WorkerUpdateRequest * request,WorkerUpdateResponse * response)372 Status DataServiceDispatcherImpl::WorkerUpdate(
373 const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
374 TF_RETURN_IF_ERROR(CheckStarted());
375 mutex_lock l(mu_);
376 for (auto& update : request->updates()) {
377 int64_t task_id = update.task_id();
378 std::shared_ptr<const Task> task;
379 TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
380 if (update.completed()) {
381 if (task->finished) {
382 VLOG(1) << "Received completion update for already-finished task "
383 << task->task_id << " on worker " << task->worker_address;
384 continue;
385 }
386 Update update;
387 update.mutable_finish_task()->set_task_id(task_id);
388 TF_RETURN_IF_ERROR(Apply(update));
389 VLOG(3) << "Task " << task_id << " from iteration "
390 << task->iteration->iteration_id << " completed";
391 }
392 }
393 return OkStatus();
394 }
395
GetDatasetDef(const GetDatasetDefRequest * request,GetDatasetDefResponse * response)396 Status DataServiceDispatcherImpl::GetDatasetDef(
397 const GetDatasetDefRequest* request, GetDatasetDefResponse* response) {
398 TF_RETURN_IF_ERROR(CheckStarted());
399 mutex_lock l(mu_);
400 std::shared_ptr<const Dataset> dataset;
401 TF_RETURN_IF_ERROR(state_.DatasetFromId(request->dataset_id(), dataset));
402 std::shared_ptr<const DatasetDef> dataset_def;
403 TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
404 *response->mutable_dataset_def() = *dataset_def;
405 return OkStatus();
406 }
407
// RPC handler: produces the next split for a distributed_epoch iteration from
// the requested split provider. Returns end_of_splits when the provider is
// exhausted (resetting it for the next repetition) or when the requested
// repetition is already behind the iteration's current repetition.
Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request,
                                           GetSplitResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  int64_t iteration_id = request->iteration_id();
  int64_t repetition = request->repetition();
  int64_t provider_index = request->split_provider_index();
  VLOG(3) << "Received GetSplit request for iteration " << iteration_id
          << ", repetition " << repetition << ", split provider index "
          << provider_index;
  std::shared_ptr<const Iteration> iteration;
  TF_RETURN_IF_ERROR(state_.IterationFromId(iteration_id, iteration));
  // Splits only exist for dynamically-sharded (distributed_epoch) iterations.
  if (!iteration->distributed_epoch_state.has_value()) {
    return errors::FailedPrecondition(
        "Cannot get split for iteration ", iteration_id,
        ", since it is not a distributed_epoch iteration.");
  }
  int64_t current_repetition =
      iteration->distributed_epoch_state.value().repetitions[provider_index];
  // A request for an older repetition means the caller is behind; tell it the
  // repetition is over rather than handing out splits from the current one.
  if (repetition < current_repetition) {
    response->set_end_of_splits(true);
    VLOG(3) << "Returning end_of_splits since current repetition "
            << current_repetition
            << " is greater than the requested repetition " << repetition;
    return OkStatus();
  }
  SplitProvider* split_provider =
      split_providers_[iteration_id][provider_index].get();
  DCHECK(split_provider != nullptr);
  Tensor split;
  bool end_of_splits = false;
  TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
  // Journal the split production so it can be replayed after a restart.
  TF_RETURN_IF_ERROR(RecordSplitProduced(iteration_id, repetition,
                                         request->split_provider_index(),
                                         end_of_splits));
  response->set_end_of_splits(end_of_splits);
  if (end_of_splits) {
    // Reset the split provider to prepare for the next iteration.
    TF_RETURN_IF_ERROR(split_providers_[iteration_id][provider_index]->Reset());
  } else {
    split.AsProtoTensorContent(response->mutable_split());
  }
  VLOG(3) << "Returning from GetSplit, end_of_splits=" << end_of_splits;
  return OkStatus();
}
453
MakeSplitProviders(const std::string & dataset_id,std::vector<std::unique_ptr<SplitProvider>> & split_providers)454 Status DataServiceDispatcherImpl::MakeSplitProviders(
455 const std::string& dataset_id,
456 std::vector<std::unique_ptr<SplitProvider>>& split_providers)
457 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
458 std::shared_ptr<const Dataset> dataset;
459 TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
460 std::shared_ptr<const DatasetDef> dataset_def;
461 TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
462 standalone::Dataset::Params params;
463 std::unique_ptr<standalone::Dataset> standalone_dataset;
464 TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
465 params, dataset_def->graph(), &standalone_dataset));
466 TF_RETURN_IF_ERROR(standalone_dataset->MakeSplitProviders(&split_providers));
467 return OkStatus();
468 }
469
// RPC handler: reports the data service version so clients/workers can check
// compatibility. `request` carries no fields that affect the response.
Status DataServiceDispatcherImpl::GetVersion(const GetVersionRequest* request,
                                             GetVersionResponse* response) {
  response->set_version(kDataServiceVersion);
  return OkStatus();
}
475
GetOrRegisterDataset(const GetOrRegisterDatasetRequest * request,GetOrRegisterDatasetResponse * response)476 Status DataServiceDispatcherImpl::GetOrRegisterDataset(
477 const GetOrRegisterDatasetRequest* request,
478 GetOrRegisterDatasetResponse* response) {
479 TF_RETURN_IF_ERROR(CheckStarted());
480 uint64 fingerprint;
481 DatasetDef dataset_def = request->dataset();
482 GraphDef* graph = dataset_def.mutable_graph();
483 PrepareGraph(graph);
484 TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint));
485 VLogLines(/*log_level=*/4,
486 absl::StrCat("Registering dataset graph: ", graph->DebugString()));
487
488 mutex_lock l(mu_);
489 TF_ASSIGN_OR_RETURN(std::optional<std::string> dataset_id,
490 FindDataset(*request, fingerprint));
491 if (dataset_id.has_value()) {
492 VLOG(3) << "RegisterDataset returns an existing dataset with ID = "
493 << *dataset_id << ", fingerprint = " << fingerprint << ".";
494 response->set_dataset_id(*dataset_id);
495 return Status::OK();
496 }
497
498 std::string new_dataset_id;
499 TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def,
500 request->metadata(), request->dataset_id(),
501 new_dataset_id));
502 response->set_dataset_id(new_dataset_id);
503 VLOG(3) << "Registered new dataset with id " << new_dataset_id;
504 return OkStatus();
505 }
506
// Looks for an existing dataset matching the request: by ID when the client
// supplied one, otherwise by graph fingerprint. Returns an empty optional if
// none exists; otherwise validates metadata compatibility (ID-based matches
// only) and returns the existing dataset's ID.
StatusOr<std::optional<std::string>> DataServiceDispatcherImpl::FindDataset(
    const GetOrRegisterDatasetRequest& request, uint64 fingerprint)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<const Dataset> existing_dataset;
  Status status;
  // TODO(b/236725000): Stop supporting fingerprint-based deduping. This becomes
  // unreliable due to nondeterminism in the dataset graphdef generation. The
  // users should provide a `dataset_id` to dedupe the dataset instead.
  if (request.dataset_id().empty()) {
    status = state_.DatasetFromFingerprint(fingerprint, existing_dataset);
  } else {
    status = state_.DatasetFromId(request.dataset_id(), existing_dataset);
  }

  // NotFound is expected: it simply means the dataset must be registered.
  if (errors::IsNotFound(status)) {
    return std::optional<std::string>();
  }
  TF_RETURN_IF_ERROR(status);
  if (!request.dataset_id().empty()) {
    // Reusing an explicit dataset_id requires the metadata to match.
    TF_RETURN_IF_ERROR(ValidateMatchingDataset(
        request.dataset_id(), request.metadata(), existing_dataset->metadata));
  }
  return std::optional<std::string>(existing_dataset->dataset_id);
}
531
RegisterDataset(uint64 fingerprint,const DatasetDef & dataset,const DataServiceMetadata & metadata,const std::string & requested_dataset_id,std::string & dataset_id)532 Status DataServiceDispatcherImpl::RegisterDataset(
533 uint64 fingerprint, const DatasetDef& dataset,
534 const DataServiceMetadata& metadata,
535 const std::string& requested_dataset_id, std::string& dataset_id)
536 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
537 dataset_id = requested_dataset_id;
538 if (dataset_id.empty()) {
539 dataset_id = state_.NextAvailableDatasetId();
540 }
541 Update update;
542 RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
543 register_dataset->set_dataset_id(dataset_id);
544 register_dataset->set_fingerprint(fingerprint);
545 *register_dataset->mutable_metadata() = metadata;
546 register_dataset->set_dedupe_by_dataset_id(!requested_dataset_id.empty());
547 TF_RETURN_IF_ERROR(
548 dataset_store_->Put(DatasetKey(dataset_id, fingerprint), dataset));
549 return Apply(update);
550 }
551
GetDataServiceMetadata(const GetDataServiceMetadataRequest * request,GetDataServiceMetadataResponse * response)552 Status DataServiceDispatcherImpl::GetDataServiceMetadata(
553 const GetDataServiceMetadataRequest* request,
554 GetDataServiceMetadataResponse* response) {
555 TF_RETURN_IF_ERROR(CheckStarted());
556 std::string dataset_id = request->dataset_id();
557 std::shared_ptr<const Dataset> dataset;
558
559 mutex_lock l(mu_);
560 TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
561 VLOG(3) << "Get the data service metadata for dataset id: " << dataset_id
562 << ".";
563 *response->mutable_metadata() = dataset->metadata;
564 return OkStatus();
565 }
566
// RPC handler: exposes the dispatcher's service-level configuration
// (currently just the deployment mode) to clients.
Status DataServiceDispatcherImpl::GetDataServiceConfig(
    const GetDataServiceConfigRequest* request,
    GetDataServiceConfigResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  response->mutable_config()->set_deployment_mode(config_.deployment_mode());
  return OkStatus();
}
574
GetOrCreateJob(const GetOrCreateJobRequest * request,GetOrCreateJobResponse * response)575 Status DataServiceDispatcherImpl::GetOrCreateJob(
576 const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
577 TF_RETURN_IF_ERROR(CheckStarted());
578 VLOG(3) << "GetOrCreateJob(" << request->DebugString() << ")";
579 std::shared_ptr<const Job> job;
580 {
581 mutex_lock l(mu_);
582 std::string job_name;
583 if (request->optional_job_name_case() == GetOrCreateJobRequest::kJobName) {
584 job_name = request->job_name();
585 } else {
586 job_name = absl::StrCat("anonymous_job_", state_.NextAvailableJobId(),
587 "_", random::New64());
588 }
589 Status s = state_.JobByName(job_name, job);
590 if (s.ok()) {
591 TF_RETURN_IF_ERROR(ValidateMatchingJob(job, *request));
592 } else if (errors::IsNotFound(s)) {
593 TF_RETURN_IF_ERROR(CreateJob(job_name, *request, job));
594 } else {
595 return s;
596 }
597 response->set_job_id(job->id);
598 }
599 VLOG(3) << "Received job id " << job->id << " for CreateJob("
600 << request->DebugString() << ")";
601 return Status::OK();
602 }
603
// RPC handler: returns the iteration for (job, repetition), creating it (and
// its tasks on all workers) if it doesn't exist or was garbage collected.
// Also acquires a new iteration-client ID to track this client's usage.
Status DataServiceDispatcherImpl::GetOrCreateIteration(
    const GetOrCreateIterationRequest* request,
    GetOrCreateIterationResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  VLOG(3) << "GetOrCreateIteration(" << request->DebugString() << ")";
  std::shared_ptr<const Iteration> iteration;
  std::vector<std::shared_ptr<const Task>> tasks;
  {
    mutex_lock l(mu_);
    std::shared_ptr<const Job> job;
    TF_RETURN_IF_ERROR(state_.JobFromId(request->job_id(), job));
    // Iterations are keyed by (job name, repetition number).
    IterationKey key(job->job_name, request->repetition());
    Status s = state_.IterationByKey(key, iteration);
    if (!s.ok() && !errors::IsNotFound(s)) {
      return s;
    }
    // Recreate the iteration if it never existed or was garbage collected.
    if (errors::IsNotFound(s) || iteration->garbage_collected) {
      TF_RETURN_IF_ERROR(CreateIteration(*request, iteration));
      TF_RETURN_IF_ERROR(CreateTasksForIteration(iteration, tasks));
    }
    int64_t iteration_client_id;
    TF_RETURN_IF_ERROR(
        AcquireIterationClientId(iteration, iteration_client_id));
    response->set_iteration_client_id(iteration_client_id);
  }
  // Assign tasks outside the lock, since it involves RPCs to workers.
  TF_RETURN_IF_ERROR(AssignTasks(tasks));
  VLOG(3) << "Created iteration " << iteration->iteration_id
          << " for CreateIteration(" << request->DebugString() << ")";
  return OkStatus();
}
634
// RPC handler: coordinates removal of a round-robin task. Each consumer votes
// via a shared TaskRemover; the task is only removed once all consumers agree
// on the same round. The removal vote happens outside the lock since it
// blocks waiting on other consumers.
Status DataServiceDispatcherImpl::MaybeRemoveTask(
    const MaybeRemoveTaskRequest* request, MaybeRemoveTaskResponse* response) {
  VLOG(1) << "Attempting to remove task. Request: " << request->DebugString();
  std::shared_ptr<TaskRemover> remover;
  std::shared_ptr<const Task> task;
  {
    mutex_lock l(mu_);
    Status s = state_.TaskFromId(request->task_id(), task);
    if (errors::IsNotFound(s)) {
      // Task is already removed.
      response->set_removed(true);
      return OkStatus();
    }
    TF_RETURN_IF_ERROR(s);
    // Share one TaskRemover per task across all consumers' requests.
    auto& remover_ref = remove_task_requests_[task->task_id];
    if (remover_ref == nullptr) {
      if (!task->iteration->IsRoundRobin()) {
        return errors::FailedPrecondition(
            "MaybeRemoveTask called on a non-round-robin task.");
      }
      remover_ref = std::make_shared<TaskRemover>(
          task->iteration->job->num_consumers.value());
    }
    remover = remover_ref;
  }
  // May block until all consumers have voted; must not hold `mu_` here.
  bool removed =
      remover->RequestRemoval(request->consumer_index(), request->round());
  response->set_removed(removed);
  if (!removed) {
    VLOG(1) << "Failed to remove task " << task->task_id;
    return OkStatus();
  }
  // Consensus reached: journal the removal (unless another request already
  // did so while we weren't holding the lock).
  mutex_lock l(mu_);
  if (!task->removed) {
    Update update;
    RemoveTaskUpdate* remove_task = update.mutable_remove_task();
    remove_task->set_task_id(request->task_id());
    TF_RETURN_IF_ERROR(Apply(update));
  }
  VLOG(1) << "Task " << task->task_id << " successfully removed";
  return OkStatus();
}
677
ReleaseIterationClient(const ReleaseIterationClientRequest * request,ReleaseIterationClientResponse * response)678 Status DataServiceDispatcherImpl::ReleaseIterationClient(
679 const ReleaseIterationClientRequest* request,
680 ReleaseIterationClientResponse* response) {
681 TF_RETURN_IF_ERROR(CheckStarted());
682 mutex_lock l(mu_);
683 int64_t iteration_client_id = request->iteration_client_id();
684 std::shared_ptr<const Iteration> iteration;
685 TF_RETURN_IF_ERROR(
686 state_.IterationForIterationClientId(iteration_client_id, iteration));
687 Update update;
688 ReleaseIterationClientUpdate* release_iteration_client =
689 update.mutable_release_iteration_client();
690 release_iteration_client->set_iteration_client_id(iteration_client_id);
691 release_iteration_client->set_time_micros(env_->NowMicros());
692 TF_RETURN_IF_ERROR(Apply(update));
693 return OkStatus();
694 }
695
696 // Validates that the job matches the requested processing mode.
// Validates that an existing job named in the request was created with the
// same parameters (processing mode, cross-trainer cache setting, target
// workers). Accumulates every mismatch into one InvalidArgument error so the
// caller sees all differences at once.
Status DataServiceDispatcherImpl::ValidateMatchingJob(
    std::shared_ptr<const Job> job, const GetOrCreateJobRequest& request)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::string diff;
  if (!MessageDifferencer::Equals(job->processing_mode,
                                  request.processing_mode_def())) {
    strings::StrAppend(&diff, "Existing processing mode: <",
                       job->processing_mode.ShortDebugString(), ">; got <",
                       request.processing_mode_def().ShortDebugString(), ">. ");
  }

  if (job->use_cross_trainer_cache != request.use_cross_trainer_cache()) {
    strings::StrAppend(
        &diff, "Existing cross-trainer cache: <",
        (job->use_cross_trainer_cache ? "enabled" : "disabled"), ">; got <",
        (request.use_cross_trainer_cache() ? "enabled" : "disabled"), ">. ");
  }

  if (job->target_workers != request.target_workers()) {
    strings::StrAppend(&diff, "Existing target workers: <",
                       TargetWorkersToString(job->target_workers), ">; got <",
                       TargetWorkersToString(request.target_workers()), ">. ");
  }

  // An empty diff means every parameter matched.
  if (!diff.empty()) {
    return errors::InvalidArgument(
        "Tried to create job with name ", job->job_name,
        ", but found an existing job with different parameters: ", diff);
  }
  return OkStatus();
}
728
CreateJob(const std::string & job_name,const GetOrCreateJobRequest & request,std::shared_ptr<const Job> & job)729 Status DataServiceDispatcherImpl::CreateJob(
730 const std::string& job_name, const GetOrCreateJobRequest& request,
731 std::shared_ptr<const Job>& job) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
732 TF_RETURN_IF_ERROR(ValidateProcessingMode(request.processing_mode_def()));
733 int64_t job_id = state_.NextAvailableJobId();
734 Update update;
735 CreateJobUpdate* create_job = update.mutable_create_job();
736 create_job->set_job_id(job_id);
737 create_job->set_job_name(job_name);
738 create_job->set_dataset_id(request.dataset_id());
739 *create_job->mutable_processing_mode_def() = request.processing_mode_def();
740 const bool is_coordinated_read = (request.optional_num_consumers_case() ==
741 GetOrCreateJobRequest::kNumConsumers);
742 if (is_coordinated_read) {
743 create_job->set_num_consumers(request.num_consumers());
744 }
745 create_job->set_target_workers(request.target_workers());
746 create_job->set_use_cross_trainer_cache(request.use_cross_trainer_cache());
747 TF_RETURN_IF_ERROR(Apply(update));
748 TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));
749 tensorflow::metrics::RecordTFDataServiceJobsCreated(
750 request.processing_mode_def(), is_coordinated_read);
751 return Status::OK();
752 }
753
// Creates a new iteration of the job identified by `request.job_id()`,
// assigning it the next available iteration id, and records it in `state_`.
// For dynamically sharded processing modes, also creates the iteration's
// split providers. On success, `iteration` points at the created iteration.
Status DataServiceDispatcherImpl::CreateIteration(
    const GetOrCreateIterationRequest& request,
    std::shared_ptr<const Iteration>& iteration)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64_t iteration_id = state_.NextAvailableIterationId();
  int64_t num_split_providers = 0;
  std::shared_ptr<const Job> job;
  TF_RETURN_IF_ERROR(state_.JobFromId(request.job_id(), job));
  if (IsDynamicShard(job->processing_mode)) {
    // NOTE(review): `split_providers_[iteration_id]` creates the map entry
    // before MakeSplitProviders runs, so a failure leaves an empty entry
    // behind — confirm this is acceptable.
    TF_RETURN_IF_ERROR(
        MakeSplitProviders(job->dataset_id, split_providers_[iteration_id]));
    num_split_providers = split_providers_[iteration_id].size();
  }
  Update update;
  CreateIterationUpdate* create_iteration = update.mutable_create_iteration();
  create_iteration->set_iteration_id(iteration_id);
  create_iteration->set_repetition(request.repetition());
  create_iteration->set_job_id(request.job_id());
  create_iteration->set_num_split_providers(num_split_providers);
  TF_RETURN_IF_ERROR(Apply(update));
  TF_RETURN_IF_ERROR(state_.IterationFromId(iteration_id, iteration));
  return OkStatus();
}
777
CreateTasksForWorker(const std::string & worker_address)778 Status DataServiceDispatcherImpl::CreateTasksForWorker(
779 const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
780 std::vector<std::shared_ptr<const Iteration>> iterations =
781 state_.ListIterations();
782 for (const auto& iteration : iterations) {
783 if (iteration->finished) {
784 continue;
785 }
786 if (iteration->job->num_consumers.has_value()) {
787 TF_RETURN_IF_ERROR(CreatePendingTask(iteration, worker_address));
788 continue;
789 }
790 std::shared_ptr<const Task> task;
791 TF_RETURN_IF_ERROR(CreateTask(iteration, worker_address, task));
792 }
793 return OkStatus();
794 }
795
AcquireIterationClientId(const std::shared_ptr<const Iteration> & iteration,int64_t & iteration_client_id)796 Status DataServiceDispatcherImpl::AcquireIterationClientId(
797 const std::shared_ptr<const Iteration>& iteration,
798 int64_t& iteration_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
799 iteration_client_id = state_.NextAvailableIterationClientId();
800 Update update;
801 AcquireIterationClientUpdate* acquire_iteration_client =
802 update.mutable_acquire_iteration_client();
803 acquire_iteration_client->set_iteration_client_id(iteration_client_id);
804 acquire_iteration_client->set_iteration_id(iteration->iteration_id);
805 TF_RETURN_IF_ERROR(Apply(update));
806 // Does not release clients before they start to read from the dataset.
807 latest_client_heartbeats_time_[iteration_client_id] = absl::InfiniteFuture();
808 return OkStatus();
809 }
810
CreateTasksForIteration(std::shared_ptr<const Iteration> iteration,std::vector<std::shared_ptr<const Task>> & tasks)811 Status DataServiceDispatcherImpl::CreateTasksForIteration(
812 std::shared_ptr<const Iteration> iteration,
813 std::vector<std::shared_ptr<const Task>>& tasks)
814 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
815 std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
816 tasks.clear();
817 tasks.reserve(workers.size());
818 for (const auto& worker : workers) {
819 std::shared_ptr<const Task> task;
820 TF_RETURN_IF_ERROR(CreateTask(iteration, worker->address, task));
821 tasks.push_back(task);
822 }
823 return OkStatus();
824 }
825
CreatePendingTask(std::shared_ptr<const Iteration> iteration,const std::string & worker_address)826 Status DataServiceDispatcherImpl::CreatePendingTask(
827 std::shared_ptr<const Iteration> iteration,
828 const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
829 int64_t task_id = state_.NextAvailableTaskId();
830 Update update;
831 CreatePendingTaskUpdate* create_task = update.mutable_create_pending_task();
832 create_task->set_task_id(task_id);
833 create_task->set_iteration_id(iteration->iteration_id);
834 create_task->set_worker_address(worker_address);
835 create_task->set_starting_round(round_robin_rounds_[iteration->iteration_id] +
836 1);
837 std::shared_ptr<const Worker> worker;
838 TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
839 create_task->set_transfer_address(worker->transfer_address);
840 *create_task->mutable_worker_tags() = {worker->tags.begin(),
841 worker->tags.end()};
842 create_task->set_worker_uid(worker->uid);
843 TF_RETURN_IF_ERROR(Apply(update));
844 return OkStatus();
845 }
846
CreateTask(std::shared_ptr<const Iteration> iteration,const std::string & worker_address,std::shared_ptr<const Task> & task)847 Status DataServiceDispatcherImpl::CreateTask(
848 std::shared_ptr<const Iteration> iteration,
849 const std::string& worker_address, std::shared_ptr<const Task>& task)
850 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
851 int64_t task_id = state_.NextAvailableTaskId();
852 Update update;
853 CreateTaskUpdate* create_task = update.mutable_create_task();
854 create_task->set_task_id(task_id);
855 create_task->set_iteration_id(iteration->iteration_id);
856 create_task->set_worker_address(worker_address);
857 std::shared_ptr<const Worker> worker;
858 TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
859 create_task->set_transfer_address(worker->transfer_address);
860 *create_task->mutable_worker_tags() = {worker->tags.begin(),
861 worker->tags.end()};
862 create_task->set_worker_uid(worker->uid);
863 TF_RETURN_IF_ERROR(Apply(update));
864 TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
865 return OkStatus();
866 }
867
AssignTasks(std::vector<std::shared_ptr<const Task>> tasks)868 Status DataServiceDispatcherImpl::AssignTasks(
869 std::vector<std::shared_ptr<const Task>> tasks) TF_LOCKS_EXCLUDED(mu_) {
870 for (const auto& task : tasks) {
871 TF_RETURN_IF_ERROR(AssignTask(task));
872 }
873 return OkStatus();
874 }
875
// Returns a gRPC stub for the worker at `worker_address`, creating and
// caching one if none exists yet. The returned pointer is owned by
// `worker_stubs_`; callers must not delete it.
Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
    const std::string& worker_address, WorkerService::Stub*& out_stub)
    TF_LOCKS_EXCLUDED(mu_) {
  {
    // Fast path: reuse an already-cached stub.
    mutex_lock l(mu_);
    auto it = worker_stubs_.find(worker_address);
    if (it != worker_stubs_.end()) {
      out_stub = it->second.get();
      return OkStatus();
    }
  }
  // Create the stub without holding `mu_`; creation does not need the lock.
  std::unique_ptr<WorkerService::Stub> stub;
  TF_RETURN_IF_ERROR(
      CreateWorkerStub(worker_address, config_.protocol(), stub));
  {
    mutex_lock l(mu_);
    // A concurrent call could have already created the stub.
    auto& worker = worker_stubs_[worker_address];
    if (worker == nullptr) {
      worker = std::move(stub);
    }
    out_stub = worker.get();
  }
  return OkStatus();
}
901
// Assigns `task` to its worker by sending a ProcessTask RPC. `mu_` is held
// only while building the TaskDef; the RPC itself runs without the lock.
// Unavailable/aborted/cancelled RPC errors are treated as a (presumed)
// worker preemption and swallowed; other RPC errors are returned.
Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
    TF_LOCKS_EXCLUDED(mu_) {
  VLOG(2) << "Started assigning task " << task->task_id << " to worker "
          << task->worker_address;
  grpc::ClientContext client_ctx;
  ProcessTaskRequest req;
  TaskDef* task_def = req.mutable_task();
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(PopulateTaskDef(task, task_def));
  }
  ProcessTaskResponse resp;
  WorkerService::Stub* stub;
  TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub));
  grpc::Status s = stub->ProcessTask(&client_ctx, req, &resp);
  if (!s.ok()) {
    if (s.error_code() == grpc::StatusCode::UNAVAILABLE ||
        s.error_code() == grpc::StatusCode::ABORTED ||
        s.error_code() == grpc::StatusCode::CANCELLED) {
      // Worker is presumably preempted. We will assign the task to the worker
      // when it reconnects.
      return OkStatus();
    }
    return grpc_util::WrapError(
        absl::StrCat("Failed to submit task to worker ", task->worker_address),
        s);
  }
  VLOG(2) << "Finished assigning task " << task->task_id << " to worker "
          << task->worker_address;
  return OkStatus();
}
933
// Handles a heartbeat from an iteration client: records the heartbeat time,
// advances round-robin bookkeeping for coordinated reads, tries to resolve
// the iteration's oldest pending task, and reports the iteration's current
// tasks back to the client.
Status DataServiceDispatcherImpl::ClientHeartbeat(
    const ClientHeartbeatRequest* request, ClientHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(4) << "Received heartbeat from client id "
          << request->iteration_client_id();
  latest_client_heartbeats_time_[request->iteration_client_id()] =
      absl::FromUnixMicros(env_->NowMicros());
  std::shared_ptr<const Iteration> iteration;
  Status s = state_.IterationForIterationClientId(
      request->iteration_client_id(), iteration);
  if (errors::IsNotFound(s) && !config_.fault_tolerant_mode()) {
    // Without fault tolerance, an unknown client id may simply mean the
    // dispatcher restarted and lost its state; surface a clearer error.
    return errors::NotFound(
        "Unknown iteration client id ", request->iteration_client_id(),
        ". The dispatcher is not configured to be fault tolerant, so this "
        "could be caused by a dispatcher restart.");
  }
  TF_RETURN_IF_ERROR(s);
  if (iteration->garbage_collected) {
    return errors::FailedPrecondition(
        "The requested iteration has been garbage collected due to inactivity. "
        "Consider configuring the dispatcher with a higher "
        "`iteration_gc_timeout_ms`.");
  }
  // Track the highest round this client has reported (coordinated reads).
  if (request->optional_current_round_case() ==
      ClientHeartbeatRequest::kCurrentRound) {
    round_robin_rounds_[request->iteration_client_id()] =
        std::max(round_robin_rounds_[request->iteration_client_id()],
                 request->current_round());
  }
  if (!iteration->pending_tasks.empty()) {
    // Try to resolve the oldest pending task: reject it if this client is
    // already past the target round, or accept it if the client is blocked
    // at or before the target round and hasn't accepted yet.
    const auto& task = iteration->pending_tasks.front();
    Update update;
    ClientHeartbeatUpdate* client_heartbeat = update.mutable_client_heartbeat();
    bool apply_update = false;
    client_heartbeat->set_iteration_client_id(request->iteration_client_id());
    std::optional<int64_t> blocked_round;
    if (request->optional_blocked_round_case() ==
        ClientHeartbeatRequest::kBlockedRound) {
      blocked_round = request->blocked_round();
    }
    VLOG(1) << "Handling pending task in iteration client heartbeat. "
               "iteration_client_id: "
            << request->iteration_client_id()
            << ". current_round: " << request->current_round()
            << ". blocked_round: " << blocked_round.value_or(-1)
            << ". target_round: " << task.target_round;
    if (request->current_round() >= task.target_round) {
      TaskRejected* rejected = client_heartbeat->mutable_task_rejected();
      // Exponentially try later and later rounds until consumers all agree.
      int64_t round_offset = 2;
      for (int i = 0; i < task.failures; ++i) {
        round_offset *= 2;
      }
      rejected->set_new_target_round(
          round_robin_rounds_[request->iteration_client_id()] + round_offset);
      apply_update = true;
    }
    if (blocked_round.has_value() &&
        blocked_round.value() <= task.target_round &&
        !task.ready_consumers.contains(request->iteration_client_id())) {
      client_heartbeat->set_task_accepted(true);
      apply_update = true;
    }
    if (apply_update) {
      TF_RETURN_IF_ERROR(Apply(update));
    }
  }
  // If a pending task remains, tell the client which round to block at.
  if (!iteration->pending_tasks.empty()) {
    response->set_block_round(iteration->pending_tasks.front().target_round);
  }

  // Report all of the iteration's current tasks.
  std::vector<std::shared_ptr<const Task>> tasks;
  TF_RETURN_IF_ERROR(state_.TasksForIteration(iteration->iteration_id, tasks));
  for (const auto& task : tasks) {
    TaskInfo* task_info = response->mutable_task_info()->Add();
    task_info->set_worker_address(task->worker_address);
    task_info->set_transfer_address(task->transfer_address);
    *task_info->mutable_worker_tags() = {task->worker_tags.begin(),
                                         task->worker_tags.end()};
    task_info->set_task_id(task->task_id);
    task_info->set_iteration_id(iteration->iteration_id);
    task_info->set_worker_uid(task->worker_uid);
    task_info->set_starting_round(task->starting_round);
  }
  response->set_iteration_finished(iteration->finished);
  response->set_deployment_mode(config_.deployment_mode());
  VLOG(4) << "Found " << response->task_info_size()
          << " tasks for iteration client id "
          << request->iteration_client_id();
  return OkStatus();
}
1026
GetWorkers(const GetWorkersRequest * request,GetWorkersResponse * response)1027 Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
1028 GetWorkersResponse* response) {
1029 TF_RETURN_IF_ERROR(CheckStarted());
1030 mutex_lock l(mu_);
1031 VLOG(3) << "Enter GetWorkers";
1032 std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
1033 for (const auto& worker : workers) {
1034 WorkerInfo* info = response->add_workers();
1035 info->set_address(worker->address);
1036 }
1037 VLOG(3) << "Returning list of " << response->workers_size()
1038 << " workers from GetWorkers";
1039 return OkStatus();
1040 }
1041
// Fills in `task_def` with everything a worker needs to run `task`:
// dataset/iteration/task ids, processing mode, sharding and coordination
// parameters, and either an inlined dataset definition or a path to one.
Status DataServiceDispatcherImpl::PopulateTaskDef(
    std::shared_ptr<const Task> task, TaskDef* task_def) const
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  task_def->set_dataset_id(task->iteration->job->dataset_id);
  task_def->set_iteration_id(task->iteration->iteration_id);
  task_def->set_worker_address(task->worker_address);
  task_def->set_task_id(task->task_id);
  *task_def->mutable_processing_mode_def() =
      task->iteration->job->processing_mode;
  if (IsStaticShard(task->iteration->job->processing_mode)) {
    // Static sharding: shard by this worker's index within the configured
    // worker address list.
    task_def->set_num_workers(config_.worker_addresses_size());
    TF_ASSIGN_OR_RETURN(int64_t worker_index,
                        state_.GetWorkerIndex(task->worker_address));
    task_def->set_worker_index(worker_index);
  }
  if (task->iteration->distributed_epoch_state.has_value()) {
    task_def->set_num_split_providers(
        task->iteration->distributed_epoch_state.value().indices.size());
  }
  if (task->iteration->job->num_consumers.has_value()) {
    task_def->set_num_consumers(task->iteration->job->num_consumers.value());
  }
  task_def->set_use_cross_trainer_cache(
      task->iteration->job->use_cross_trainer_cache);
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(
      state_.DatasetFromId(task->iteration->job->dataset_id, dataset));
  std::string dataset_key =
      DatasetKey(dataset->dataset_id, dataset->fingerprint);
  if (config_.work_dir().empty()) {
    // No shared work directory: ship the dataset graph inline in the RPC.
    std::shared_ptr<const DatasetDef> dataset_def;
    TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
    *task_def->mutable_dataset_def() = *dataset_def;
  } else {
    // Shared work directory: send only the path (presumably the worker reads
    // the dataset definition from there).
    std::string path =
        io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
    task_def->set_path(path);
  }
  return OkStatus();
}
1082
CheckStarted()1083 Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) {
1084 mutex_lock l(mu_);
1085 if (!started_) {
1086 return errors::Unavailable("Dispatcher has not started yet.");
1087 }
1088 return OkStatus();
1089 }
1090
RecordSplitProduced(int64_t iteration_id,int64_t repetition,int64_t split_provider_index,bool finished)1091 Status DataServiceDispatcherImpl::RecordSplitProduced(
1092 int64_t iteration_id, int64_t repetition, int64_t split_provider_index,
1093 bool finished) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1094 Update update;
1095 ProduceSplitUpdate* produce_split = update.mutable_produce_split();
1096 produce_split->set_iteration_id(iteration_id);
1097 produce_split->set_repetition(repetition);
1098 produce_split->set_split_provider_index(split_provider_index);
1099 produce_split->set_finished(finished);
1100 return Apply(update);
1101 }
1102
ApplyWithoutJournaling(const Update & update)1103 Status DataServiceDispatcherImpl::ApplyWithoutJournaling(const Update& update)
1104 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1105 return state_.Apply(update);
1106 }
1107
Apply(const Update & update)1108 Status DataServiceDispatcherImpl::Apply(const Update& update)
1109 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1110 if (journal_writer_.has_value()) {
1111 TF_RETURN_IF_ERROR(journal_writer_.value()->Write(update));
1112 }
1113 return state_.Apply(update);
1114 }
1115
// Background thread loop: periodically releases clients that stopped
// heartbeating and garbage-collects old iterations, until `cancelled_` is
// set. The first pass runs immediately (next_check_micros starts at 0).
void DataServiceDispatcherImpl::IterationGcThread() {
  int64_t next_check_micros = 0;
  while (true) {
    mutex_lock l(mu_);
    // Wait (interruptibly via the condition variable) until the next check
    // time or cancellation.
    while (!cancelled_ && env_->NowMicros() < next_check_micros) {
      int64_t remaining_micros = next_check_micros - env_->NowMicros();
      iteration_gc_thread_cv_.wait_for(
          l, std::chrono::microseconds(remaining_micros));
    }
    if (cancelled_) {
      return;
    }
    // Both maintenance steps are best-effort: failures are logged, and the
    // loop keeps going.
    {
      Status s = ReleaseMissingClients();
      if (!s.ok()) {
        LOG(WARNING) << "Error releasing missing clients: " << s;
      }
    }

    {
      Status s = GcOldIterations();
      if (!s.ok()) {
        LOG(WARNING) << "Error garbage collecting old iterations: " << s;
      }
    }
    next_check_micros =
        env_->NowMicros() + (config_.job_gc_check_interval_ms() * 1000);
  }
}
1145
ReleaseMissingClients()1146 Status DataServiceDispatcherImpl::ReleaseMissingClients()
1147 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1148 int64_t now = env_->NowMicros();
1149 for (const auto& client_id : state_.ListActiveClientIds()) {
1150 if (absl::FromUnixMicros(now) >
1151 latest_client_heartbeats_time_[client_id] +
1152 absl::Milliseconds(config_.client_timeout_ms())) {
1153 LOG(INFO) << "Releasing timed-out client with id " << client_id;
1154 Update update;
1155 ReleaseIterationClientUpdate* release_client =
1156 update.mutable_release_iteration_client();
1157 release_client->set_iteration_client_id(client_id);
1158 release_client->set_time_micros(now);
1159 TF_RETURN_IF_ERROR(Apply(update));
1160 }
1161 }
1162 return OkStatus();
1163 }
1164
// Garbage-collects iterations that have no clients and whose last client
// was released more than `job_gc_timeout_ms` ago. Finished iterations and
// iterations that never had a client released are skipped.
Status DataServiceDispatcherImpl::GcOldIterations()
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Iteration>> iterations =
      state_.ListIterations();
  int64_t now = env_->NowMicros();
  for (const auto& iteration : iterations) {
    if (iteration->finished || iteration->num_clients > 0 ||
        iteration->last_client_released_micros < 0 ||
        now < iteration->last_client_released_micros +
                  (config_.job_gc_timeout_ms() * 1000)) {
      continue;
    }
    Update update;
    update.mutable_garbage_collect_iteration()->set_iteration_id(
        iteration->iteration_id);
    // NOTE(review): this uses `state_.Apply` rather than the journaling
    // `Apply`, so the GC update is not written to the journal — confirm
    // whether garbage collection is meant to survive a dispatcher restart.
    TF_RETURN_IF_ERROR(state_.Apply(update));
    LOG(INFO) << "Garbage collected iteration " << iteration->DebugString();
  }
  return OkStatus();
}
1185
GetDatasetDef(const std::string & dataset_id,std::shared_ptr<const DatasetDef> & dataset_def)1186 Status DataServiceDispatcherImpl::GetDatasetDef(
1187 const std::string& dataset_id,
1188 std::shared_ptr<const DatasetDef>& dataset_def)
1189 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1190 std::shared_ptr<const Dataset> dataset;
1191 TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
1192 return GetDatasetDef(*dataset, dataset_def);
1193 }
1194
GetDatasetDef(const Dataset & dataset,std::shared_ptr<const DatasetDef> & dataset_def)1195 Status DataServiceDispatcherImpl::GetDatasetDef(
1196 const Dataset& dataset, std::shared_ptr<const DatasetDef>& dataset_def)
1197 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1198 std::string key = DatasetKey(dataset.dataset_id, dataset.fingerprint);
1199 return dataset_store_->Get(key, dataset_def);
1200 }
1201
// Exports a snapshot of the dispatcher's state (config, worker addresses,
// and per-iteration summaries). If the dispatcher has not started, only the
// config is populated.
DispatcherStateExport DataServiceDispatcherImpl::ExportState() const
    TF_LOCKS_EXCLUDED(mu_) {
  DispatcherStateExport dispatcher_state_export;
  *dispatcher_state_export.mutable_dispatcher_config() = config_;
  mutex_lock l(mu_);
  if (!started_) {
    return dispatcher_state_export;
  }

  std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
  for (const auto& worker : workers) {
    dispatcher_state_export.add_worker_addresses(worker->address);
  }

  std::vector<std::shared_ptr<const Iteration>> iterations =
      state_.ListIterations();
  for (const auto& iteration : iterations) {
    DispatcherStateExport::Iteration* iteration_export =
        dispatcher_state_export.add_iterations();
    iteration_export->set_dataset_id(iteration->job->dataset_id);
    iteration_export->set_iteration_id(iteration->iteration_id);
    iteration_export->mutable_iteration_key()->set_name(
        iteration->iteration_key.name);
    // NOTE(review): the exported field is named `iteration` but is populated
    // from `repetition` — confirm this mapping is intended.
    iteration_export->mutable_iteration_key()->set_iteration(
        iteration->iteration_key.repetition);
    *iteration_export->mutable_processing_mode() =
        iteration->job->processing_mode;
    if (iteration->job->num_consumers) {
      iteration_export->set_num_consumers(*iteration->job->num_consumers);
    }
    iteration_export->set_num_clients(iteration->num_clients);
    iteration_export->set_finished(iteration->finished);
    iteration_export->set_garbage_collected(iteration->garbage_collected);
  }
  return dispatcher_state_export;
}
1238
1239 } // namespace data
1240 } // namespace tensorflow
1241