xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/dispatcher_impl.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/service/dispatcher_impl.h"
17 
18 #include <algorithm>
19 #include <array>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #ifdef PLATFORM_GOOGLE
27 #include "file/logging/log_lines.h"
28 #endif
29 #include "grpcpp/create_channel.h"
30 #include "grpcpp/impl/codegen/server_context.h"
31 #include "grpcpp/security/credentials.h"
32 #include "absl/container/flat_hash_map.h"
33 #include "absl/container/flat_hash_set.h"
34 #include "absl/memory/memory.h"
35 #include "absl/time/time.h"
36 #include "tensorflow/core/data/dataset_utils.h"
37 #include "tensorflow/core/data/hash_utils.h"
38 #include "tensorflow/core/data/service/common.h"
39 #include "tensorflow/core/data/service/common.pb.h"
40 #include "tensorflow/core/data/service/credentials_factory.h"
41 #include "tensorflow/core/data/service/dataset_store.h"
42 #include "tensorflow/core/data/service/dispatcher.pb.h"
43 #include "tensorflow/core/data/service/dispatcher_state.h"
44 #include "tensorflow/core/data/service/export.pb.h"
45 #include "tensorflow/core/data/service/grpc_util.h"
46 #include "tensorflow/core/data/service/journal.h"
47 #include "tensorflow/core/data/service/validate_utils.h"
48 #include "tensorflow/core/data/service/worker.grpc.pb.h"
49 #include "tensorflow/core/data/standalone.h"
50 #include "tensorflow/core/framework/dataset.h"
51 #include "tensorflow/core/framework/graph.pb.h"
52 #include "tensorflow/core/framework/metrics.h"
53 #include "tensorflow/core/framework/node_def.pb.h"
54 #include "tensorflow/core/framework/tensor.h"
55 #include "tensorflow/core/platform/env.h"
56 #include "tensorflow/core/platform/errors.h"
57 #include "tensorflow/core/platform/mutex.h"
58 #include "tensorflow/core/platform/path.h"
59 #include "tensorflow/core/platform/protobuf.h"
60 #include "tensorflow/core/platform/random.h"
61 #include "tensorflow/core/platform/status.h"
62 #include "tensorflow/core/platform/statusor.h"
63 #include "tensorflow/core/platform/strcat.h"
64 #include "tensorflow/core/platform/thread_annotations.h"
65 #include "tensorflow/core/protobuf/data_service.pb.h"
66 #include "tensorflow/core/protobuf/service_config.pb.h"
67 #include "tensorflow/core/public/session_options.h"
68 
69 namespace tensorflow {
70 namespace data {
71 namespace {
72 
73 using ::tensorflow::protobuf::util::MessageDifferencer;
74 
75 // The name of the journal directory inside the dispatcher's working directory.
76 // This name is load-bearing; do not change.
77 constexpr char kJournalDir[] = "tf_data_dispatcher_journal";
78 // The name of the datasets directory inside the dispatcher's working directory.
79 constexpr char kDatasetsDir[] = "datasets";
// Defaults substituted by ApplyConfigDefaults when the corresponding config
// field is unset (zero).
80 constexpr int64_t kDefaultIterationGcCheckIntervalMs =
81     10 * 60 * 1000;                                              // 10 minutes.
82 constexpr int64_t kDefaultIterationGcTimeoutMs = 5 * 60 * 1000;  // 5 minutes.
83 constexpr int64_t kDefaultClientTimeoutMs = 2 * 60 * 1000;       // 2 minutes.
84 
// Lookup-table ops for which PrepareGraph enables `use_node_name_sharing`,
// so their resources are shared by node name rather than deleted with the op.
85 constexpr std::array<const char*, 8> kNodeNameSharingOps = {
86     "HashTable",
87     "HashTableV2",
88     "MutableHashTable",
89     "MutableHashTableV2",
90     "MutableDenseHashTable",
91     "MutableDenseHashTableV2",
92     "MutableHashTableOfTensors",
93     "MutableHashTableOfTensorsV2",
94 };
95 
// Short aliases for the config proto and the dispatcher-state value types
// used throughout this file.
96 using DispatcherConfig = experimental::DispatcherConfig;
97 using Dataset = DispatcherState::Dataset;
98 using Worker = DispatcherState::Worker;
99 using Job = DispatcherState::Job;
100 using IterationKey = DispatcherState::IterationKey;
101 using Iteration = DispatcherState::Iteration;
102 using Task = DispatcherState::Task;
103 
JournalDir(const std::string & work_dir)104 std::string JournalDir(const std::string& work_dir) {
105   return io::JoinPath(work_dir, kJournalDir);
106 }
107 
DatasetsDir(const std::string & work_dir)108 std::string DatasetsDir(const std::string& work_dir) {
109   return io::JoinPath(work_dir, kDatasetsDir);
110 }
111 
DatasetKey(const std::string & dataset_id,uint64 fingerprint)112 std::string DatasetKey(const std::string& dataset_id, uint64 fingerprint) {
113   return absl::StrCat("id_", dataset_id, "_fp_", fingerprint);
114 }
115 
CreateWorkerStub(const std::string & address,const std::string & protocol,std::unique_ptr<WorkerService::Stub> & stub)116 Status CreateWorkerStub(const std::string& address, const std::string& protocol,
117                         std::unique_ptr<WorkerService::Stub>& stub) {
118   ::grpc::ChannelArguments args;
119   args.SetMaxReceiveMessageSize(-1);
120   std::shared_ptr<::grpc::ChannelCredentials> credentials;
121   TF_RETURN_IF_ERROR(
122       CredentialsFactory::CreateClientCredentials(protocol, &credentials));
123   auto channel = ::grpc::CreateCustomChannel(address, credentials, args);
124   stub = WorkerService::NewStub(channel);
125   return OkStatus();
126 }
127 
PrepareGraph(GraphDef * graph)128 void PrepareGraph(GraphDef* graph) {
129   for (NodeDef& node : *graph->mutable_node()) {
130     for (const auto& op : kNodeNameSharingOps) {
131       // Set `use_node_name_sharing` to `true` so that resources aren't deleted
132       // prematurely. Otherwise, resources may be deleted when their ops are
133       // deleted at the end of the GraphRunner::Run used by standalone::Dataset.
134       if (node.op() == op) {
135         (*node.mutable_attr())["use_node_name_sharing"].set_b(true);
136       }
137       if (!node.device().empty()) {
138         *node.mutable_device() = "";
139       }
140     }
141   }
142   StripDevicePlacement(graph->mutable_library());
143 }
144 
// Returns a copy of `config` in which unset (zero) GC and client-timeout
// fields are replaced by their defaults.
DispatcherConfig ApplyConfigDefaults(const DispatcherConfig& config) {
  DispatcherConfig defaulted(config);
  if (defaulted.job_gc_check_interval_ms() == 0) {
    defaulted.set_job_gc_check_interval_ms(kDefaultIterationGcCheckIntervalMs);
  }
  if (defaulted.job_gc_timeout_ms() == 0) {
    defaulted.set_job_gc_timeout_ms(kDefaultIterationGcTimeoutMs);
  }
  if (defaulted.client_timeout_ms() == 0) {
    defaulted.set_client_timeout_ms(kDefaultClientTimeoutMs);
  }
  return defaulted;
}
158 
VLogLines(const int log_level,const std::string & message)159 void VLogLines(const int log_level, const std::string& message) {
160 #if defined(PLATFORM_GOOGLE)
161   VLOG_LINES(log_level, message);
162 #else
163   VLOG(log_level) << message;
164 #endif
165 }
166 }  // namespace
167 
DataServiceDispatcherImpl(const DispatcherConfig & config)168 DataServiceDispatcherImpl::DataServiceDispatcherImpl(
169     const DispatcherConfig& config)
170     : config_(ApplyConfigDefaults(config)),
171       env_(Env::Default()),
172       state_(config_) {
173   if (config_.work_dir().empty()) {
174     dataset_store_ = std::make_unique<MemoryDatasetStore>();
175   } else {
176     dataset_store_ = std::make_unique<FileSystemDatasetStore>(
177         DatasetsDir(config_.work_dir()));
178   }
179 }
180 
~DataServiceDispatcherImpl()181 DataServiceDispatcherImpl::~DataServiceDispatcherImpl() {
182   {
183     mutex_lock l(mu_);
184     cancelled_ = true;
185     iteration_gc_thread_cv_.notify_all();
186   }
187   iteration_gc_thread_.reset();
188 }
189 
// Starts the dispatcher.
//
// Launches the iteration GC thread (unless GC is disabled), validates the
// work-dir configuration, and — in fault-tolerant mode — replays the journal
// to rebuild dispatcher state before marking the dispatcher as started.
Start()190 Status DataServiceDispatcherImpl::Start() {
191   mutex_lock l(mu_);
  // A negative job_gc_timeout_ms disables iteration garbage collection.
192   if (config_.job_gc_timeout_ms() >= 0) {
193     iteration_gc_thread_ = absl::WrapUnique(env_->StartThread(
194         {}, "iteration-gc-thread", [&] { IterationGcThread(); }));
195   }
196   if (config_.work_dir().empty()) {
  // Fault tolerance needs a work_dir to hold the journal.
197     if (config_.fault_tolerant_mode()) {
198       return errors::InvalidArgument(
199           "fault_tolerant_mode is True, but no work_dir is configured.");
200     }
201   } else {
202     TF_RETURN_IF_ERROR(
203         env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
204   }
  // Without fault tolerance there is no journal to restore from or write to.
205   if (!config_.fault_tolerant_mode()) {
206     LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
207                  "not be able to recover its state on restart.";
208     started_ = true;
209     return OkStatus();
210   }
211   journal_writer_ =
212       std::make_unique<FileJournalWriter>(env_, JournalDir(config_.work_dir()));
213   LOG(INFO) << "Attempting to restore dispatcher state from journal in "
214             << JournalDir(config_.work_dir());
215   Update update;
216   bool end_of_journal = false;
217   FileJournalReader reader(env_, JournalDir(config_.work_dir()));
218   Status s = reader.Read(update, end_of_journal);
  // NotFound means there is no journal yet: start from fresh state.
219   if (errors::IsNotFound(s)) {
220     LOG(INFO) << "No journal found. Starting dispatcher from new state.";
221   } else if (!s.ok()) {
222     return s;
223   } else {
  // Replay every journaled update without writing it back to the journal.
224     while (!end_of_journal) {
225       TF_RETURN_IF_ERROR(ApplyWithoutJournaling(update));
226       TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
227     }
228   }
  // Dynamically sharded iterations need their split providers fast-forwarded
  // to the positions recorded before the restart.
229   for (const auto& iteration : state_.ListIterations()) {
230     if (IsDynamicShard(iteration->job->processing_mode)) {
231       TF_RETURN_IF_ERROR(RestoreSplitProviders(
232           *iteration, split_providers_[iteration->iteration_id]));
233     }
234   }
235   for (const auto& client_id : state_.ListActiveClientIds()) {
236     // Conservatively pretend we just received a heartbeat from all clients, so
237     // that we don't garbage collect iterations too early.
238     latest_client_heartbeats_time_[client_id] =
239         absl::FromUnixMicros(env_->NowMicros());
240   }
241   // Initialize the journal writer in `Start` so that we fail fast in case it
242   // can't be initialized.
243   TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized());
244   started_ = true;
245   return OkStatus();
246 }
247 
NumActiveIterations()248 size_t DataServiceDispatcherImpl::NumActiveIterations() TF_LOCKS_EXCLUDED(mu_) {
249   mutex_lock l(mu_);
250   size_t count = 0;
251   for (const auto& iteration : state_.ListIterations()) {
252     if (!iteration->finished) {
253       count++;
254     }
255   }
256   return count;
257 }
258 
RestoreSplitProviders(const Iteration & iteration,std::vector<std::unique_ptr<SplitProvider>> & restored)259 Status DataServiceDispatcherImpl::RestoreSplitProviders(
260     const Iteration& iteration,
261     std::vector<std::unique_ptr<SplitProvider>>& restored)
262     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
263   const std::vector<int64_t>& indices =
264       iteration.distributed_epoch_state.value().indices;
265   std::vector<std::unique_ptr<SplitProvider>> split_providers;
266   TF_RETURN_IF_ERROR(
267       MakeSplitProviders(iteration.job->dataset_id, split_providers));
268   for (int provider_index = 0; provider_index < indices.size();
269        ++provider_index) {
270     int index = indices[provider_index];
271     VLOG(1) << "Restoring split provider " << provider_index
272             << " for iteration " << iteration.iteration_id << " to index "
273             << index;
274     Tensor unused_tensor;
275     bool unused_end_of_splits;
276     for (int i = 0; i < index; ++i) {
277       TF_RETURN_IF_ERROR(split_providers[provider_index]->GetNext(
278           &unused_tensor, &unused_end_of_splits));
279     }
280   }
281   restored = std::move(split_providers);
282   return OkStatus();
283 }
284 
FindTasksToDelete(const absl::flat_hash_set<int64_t> & current_tasks,const std::vector<std::shared_ptr<const Task>> assigned_tasks,WorkerHeartbeatResponse * response)285 Status DataServiceDispatcherImpl::FindTasksToDelete(
286     const absl::flat_hash_set<int64_t>& current_tasks,
287     const std::vector<std::shared_ptr<const Task>> assigned_tasks,
288     WorkerHeartbeatResponse* response) {
289   absl::flat_hash_set<int64_t> assigned_ids;
290   for (const auto& assigned : assigned_tasks) {
291     assigned_ids.insert(assigned->task_id);
292   }
293   for (int64_t current_task : current_tasks) {
294     if (!assigned_ids.contains(current_task)) {
295       response->add_tasks_to_delete(current_task);
296     }
297   }
298   return OkStatus();
299 }
300 
FindNewTasks(const std::string & worker_address,const absl::flat_hash_set<int64_t> & current_tasks,std::vector<std::shared_ptr<const Task>> & assigned_tasks,WorkerHeartbeatResponse * response)301 Status DataServiceDispatcherImpl::FindNewTasks(
302     const std::string& worker_address,
303     const absl::flat_hash_set<int64_t>& current_tasks,
304     std::vector<std::shared_ptr<const Task>>& assigned_tasks,
305     WorkerHeartbeatResponse* response) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
306   // Check for round-robin iterations that had tasks on the worker removed. Now
307   // that the worker is back, we create a new pending task for the worker.
308   absl::flat_hash_set<int64_t> assigned_iteration_ids;
309   for (const auto& task : assigned_tasks) {
310     assigned_iteration_ids.insert(task->iteration->iteration_id);
311   }
312   for (const auto& iteration : state_.ListIterations()) {
313     if (!assigned_iteration_ids.contains(iteration->iteration_id) &&
314         iteration->IsRoundRobin() && !iteration->finished) {
315       VLOG(1) << "Creating pending task for reconnected worker "
316               << worker_address;
317       TF_RETURN_IF_ERROR(CreatePendingTask(iteration, worker_address));
318     }
319   }
320   // Refresh assigned_tasks to include newly added pending tasks.
321   TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, assigned_tasks));
322   for (const auto& task : assigned_tasks) {
323     if (current_tasks.contains(task->task_id)) {
324       continue;
325     }
326     TaskDef* task_def = response->add_new_tasks();
327     TF_RETURN_IF_ERROR(PopulateTaskDef(task, task_def));
328   }
329   return OkStatus();
330 }
331 
// Handler for the WorkerHeartbeat RPC.
//
// Registers the worker on first contact (journaled), then reconciles the
// worker's reported tasks with the dispatcher's assignments: stale tasks are
// reported for deletion and newly assigned tasks are sent to the worker.
WorkerHeartbeat(const WorkerHeartbeatRequest * request,WorkerHeartbeatResponse * response)332 Status DataServiceDispatcherImpl::WorkerHeartbeat(
333     const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
334   TF_RETURN_IF_ERROR(CheckStarted());
335   VLOG(4) << "Received worker heartbeat request from worker "
336           << request->worker_address();
337   mutex_lock l(mu_);
338   const std::string& worker_address = request->worker_address();
339   // Assigned tasks from the perspective of the dispatcher.
340   std::vector<std::shared_ptr<const Task>> assigned_tasks;
341   Status s = state_.TasksForWorker(worker_address, assigned_tasks);
342   if (!s.ok()) {
343     if (!errors::IsNotFound(s)) {
344       return s;
345     }
  // NotFound means this worker is new: register it via a journaled update,
  // create its tasks, then re-fetch the assignments.
346     VLOG(1) << "Registering new worker at address " << worker_address;
347     TF_RETURN_IF_ERROR(state_.ValidateWorker(worker_address));
348     Update update;
349     update.mutable_register_worker()->set_worker_address(worker_address);
350     update.mutable_register_worker()->set_transfer_address(
351         request->transfer_address());
352     *update.mutable_register_worker()->mutable_worker_tags() =
353         request->worker_tags();
354     update.mutable_register_worker()->set_worker_uid(request->worker_uid());
355     TF_RETURN_IF_ERROR(Apply(update));
356     TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
357     TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, assigned_tasks));
358   }
  // The tasks the worker itself believes it is running.
359   absl::flat_hash_set<int64_t> current_tasks;
360   current_tasks.insert(request->current_tasks().cbegin(),
361                        request->current_tasks().cend());
362   TF_RETURN_IF_ERROR(
363       FindTasksToDelete(current_tasks, assigned_tasks, response));
364   TF_RETURN_IF_ERROR(
365       FindNewTasks(worker_address, current_tasks, assigned_tasks, response));
366 
367   VLOG(4) << "Finished worker heartbeat for worker at address "
368           << request->worker_address();
369   return OkStatus();
370 }
371 
WorkerUpdate(const WorkerUpdateRequest * request,WorkerUpdateResponse * response)372 Status DataServiceDispatcherImpl::WorkerUpdate(
373     const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
374   TF_RETURN_IF_ERROR(CheckStarted());
375   mutex_lock l(mu_);
376   for (auto& update : request->updates()) {
377     int64_t task_id = update.task_id();
378     std::shared_ptr<const Task> task;
379     TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
380     if (update.completed()) {
381       if (task->finished) {
382         VLOG(1) << "Received completion update for already-finished task "
383                 << task->task_id << " on worker " << task->worker_address;
384         continue;
385       }
386       Update update;
387       update.mutable_finish_task()->set_task_id(task_id);
388       TF_RETURN_IF_ERROR(Apply(update));
389       VLOG(3) << "Task " << task_id << " from iteration "
390               << task->iteration->iteration_id << " completed";
391     }
392   }
393   return OkStatus();
394 }
395 
GetDatasetDef(const GetDatasetDefRequest * request,GetDatasetDefResponse * response)396 Status DataServiceDispatcherImpl::GetDatasetDef(
397     const GetDatasetDefRequest* request, GetDatasetDefResponse* response) {
398   TF_RETURN_IF_ERROR(CheckStarted());
399   mutex_lock l(mu_);
400   std::shared_ptr<const Dataset> dataset;
401   TF_RETURN_IF_ERROR(state_.DatasetFromId(request->dataset_id(), dataset));
402   std::shared_ptr<const DatasetDef> dataset_def;
403   TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
404   *response->mutable_dataset_def() = *dataset_def;
405   return OkStatus();
406 }
407 
// Handler for the GetSplit RPC.
//
// Serves the next split from the requested split provider of a
// distributed_epoch iteration, journaling each produced split so split
// progress survives dispatcher restarts.
GetSplit(const GetSplitRequest * request,GetSplitResponse * response)408 Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request,
409                                            GetSplitResponse* response) {
410   TF_RETURN_IF_ERROR(CheckStarted());
411   mutex_lock l(mu_);
412   int64_t iteration_id = request->iteration_id();
413   int64_t repetition = request->repetition();
414   int64_t provider_index = request->split_provider_index();
415   VLOG(3) << "Received GetSplit request for iteration " << iteration_id
416           << ", repetition " << repetition << ", split provider index "
417           << provider_index;
418   std::shared_ptr<const Iteration> iteration;
419   TF_RETURN_IF_ERROR(state_.IterationFromId(iteration_id, iteration));
  // Only distributed_epoch iterations carry per-provider repetition state.
420   if (!iteration->distributed_epoch_state.has_value()) {
421     return errors::FailedPrecondition(
422         "Cannot get split for iteration ", iteration_id,
423         ", since it is not a distributed_epoch iteration.");
424   }
425   int64_t current_repetition =
426       iteration->distributed_epoch_state.value().repetitions[provider_index];
  // A request for an older repetition than the provider's current one is
  // answered with end_of_splits: that repetition has already been exhausted.
427   if (repetition < current_repetition) {
428     response->set_end_of_splits(true);
429     VLOG(3) << "Returning end_of_splits since current repetition "
430             << current_repetition
431             << " is greater than the requested repetition " << repetition;
432     return OkStatus();
433   }
434   SplitProvider* split_provider =
435       split_providers_[iteration_id][provider_index].get();
436   DCHECK(split_provider != nullptr);
437   Tensor split;
438   bool end_of_splits = false;
439   TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
  // Journal the produced split (or the end-of-splits event) before replying.
440   TF_RETURN_IF_ERROR(RecordSplitProduced(iteration_id, repetition,
441                                          request->split_provider_index(),
442                                          end_of_splits));
443   response->set_end_of_splits(end_of_splits);
444   if (end_of_splits) {
445     // Reset the split provider to prepare for the next iteration.
446     TF_RETURN_IF_ERROR(split_providers_[iteration_id][provider_index]->Reset());
447   } else {
448     split.AsProtoTensorContent(response->mutable_split());
449   }
450   VLOG(3) << "Returning from GetSplit, end_of_splits=" << end_of_splits;
451   return OkStatus();
452 }
453 
MakeSplitProviders(const std::string & dataset_id,std::vector<std::unique_ptr<SplitProvider>> & split_providers)454 Status DataServiceDispatcherImpl::MakeSplitProviders(
455     const std::string& dataset_id,
456     std::vector<std::unique_ptr<SplitProvider>>& split_providers)
457     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
458   std::shared_ptr<const Dataset> dataset;
459   TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
460   std::shared_ptr<const DatasetDef> dataset_def;
461   TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
462   standalone::Dataset::Params params;
463   std::unique_ptr<standalone::Dataset> standalone_dataset;
464   TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
465       params, dataset_def->graph(), &standalone_dataset));
466   TF_RETURN_IF_ERROR(standalone_dataset->MakeSplitProviders(&split_providers));
467   return OkStatus();
468 }
469 
GetVersion(const GetVersionRequest * request,GetVersionResponse * response)470 Status DataServiceDispatcherImpl::GetVersion(const GetVersionRequest* request,
471                                              GetVersionResponse* response) {
472   response->set_version(kDataServiceVersion);
473   return OkStatus();
474 }
475 
GetOrRegisterDataset(const GetOrRegisterDatasetRequest * request,GetOrRegisterDatasetResponse * response)476 Status DataServiceDispatcherImpl::GetOrRegisterDataset(
477     const GetOrRegisterDatasetRequest* request,
478     GetOrRegisterDatasetResponse* response) {
479   TF_RETURN_IF_ERROR(CheckStarted());
480   uint64 fingerprint;
481   DatasetDef dataset_def = request->dataset();
482   GraphDef* graph = dataset_def.mutable_graph();
483   PrepareGraph(graph);
484   TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint));
485   VLogLines(/*log_level=*/4,
486             absl::StrCat("Registering dataset graph: ", graph->DebugString()));
487 
488   mutex_lock l(mu_);
489   TF_ASSIGN_OR_RETURN(std::optional<std::string> dataset_id,
490                       FindDataset(*request, fingerprint));
491   if (dataset_id.has_value()) {
492     VLOG(3) << "RegisterDataset returns an existing dataset with ID = "
493             << *dataset_id << ", fingerprint = " << fingerprint << ".";
494     response->set_dataset_id(*dataset_id);
495     return Status::OK();
496   }
497 
498   std::string new_dataset_id;
499   TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def,
500                                      request->metadata(), request->dataset_id(),
501                                      new_dataset_id));
502   response->set_dataset_id(new_dataset_id);
503   VLOG(3) << "Registered new dataset with id " << new_dataset_id;
504   return OkStatus();
505 }
506 
FindDataset(const GetOrRegisterDatasetRequest & request,uint64 fingerprint)507 StatusOr<std::optional<std::string>> DataServiceDispatcherImpl::FindDataset(
508     const GetOrRegisterDatasetRequest& request, uint64 fingerprint)
509     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
510   std::shared_ptr<const Dataset> existing_dataset;
511   Status status;
512   // TODO(b/236725000): Stop supporting fingerprint-based deduping. This becomes
513   // unreliable due to nondeterminism in the dataset graphdef generation. The
514   // users should provide a `dataset_id` to dedupe the dataset instead.
515   if (request.dataset_id().empty()) {
516     status = state_.DatasetFromFingerprint(fingerprint, existing_dataset);
517   } else {
518     status = state_.DatasetFromId(request.dataset_id(), existing_dataset);
519   }
520 
521   if (errors::IsNotFound(status)) {
522     return std::optional<std::string>();
523   }
524   TF_RETURN_IF_ERROR(status);
525   if (!request.dataset_id().empty()) {
526     TF_RETURN_IF_ERROR(ValidateMatchingDataset(
527         request.dataset_id(), request.metadata(), existing_dataset->metadata));
528   }
529   return std::optional<std::string>(existing_dataset->dataset_id);
530 }
531 
RegisterDataset(uint64 fingerprint,const DatasetDef & dataset,const DataServiceMetadata & metadata,const std::string & requested_dataset_id,std::string & dataset_id)532 Status DataServiceDispatcherImpl::RegisterDataset(
533     uint64 fingerprint, const DatasetDef& dataset,
534     const DataServiceMetadata& metadata,
535     const std::string& requested_dataset_id, std::string& dataset_id)
536     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
537   dataset_id = requested_dataset_id;
538   if (dataset_id.empty()) {
539     dataset_id = state_.NextAvailableDatasetId();
540   }
541   Update update;
542   RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
543   register_dataset->set_dataset_id(dataset_id);
544   register_dataset->set_fingerprint(fingerprint);
545   *register_dataset->mutable_metadata() = metadata;
546   register_dataset->set_dedupe_by_dataset_id(!requested_dataset_id.empty());
547   TF_RETURN_IF_ERROR(
548       dataset_store_->Put(DatasetKey(dataset_id, fingerprint), dataset));
549   return Apply(update);
550 }
551 
GetDataServiceMetadata(const GetDataServiceMetadataRequest * request,GetDataServiceMetadataResponse * response)552 Status DataServiceDispatcherImpl::GetDataServiceMetadata(
553     const GetDataServiceMetadataRequest* request,
554     GetDataServiceMetadataResponse* response) {
555   TF_RETURN_IF_ERROR(CheckStarted());
556   std::string dataset_id = request->dataset_id();
557   std::shared_ptr<const Dataset> dataset;
558 
559   mutex_lock l(mu_);
560   TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
561   VLOG(3) << "Get the data service metadata for dataset id: " << dataset_id
562           << ".";
563   *response->mutable_metadata() = dataset->metadata;
564   return OkStatus();
565 }
566 
GetDataServiceConfig(const GetDataServiceConfigRequest * request,GetDataServiceConfigResponse * response)567 Status DataServiceDispatcherImpl::GetDataServiceConfig(
568     const GetDataServiceConfigRequest* request,
569     GetDataServiceConfigResponse* response) {
570   TF_RETURN_IF_ERROR(CheckStarted());
571   response->mutable_config()->set_deployment_mode(config_.deployment_mode());
572   return OkStatus();
573 }
574 
GetOrCreateJob(const GetOrCreateJobRequest * request,GetOrCreateJobResponse * response)575 Status DataServiceDispatcherImpl::GetOrCreateJob(
576     const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
577   TF_RETURN_IF_ERROR(CheckStarted());
578   VLOG(3) << "GetOrCreateJob(" << request->DebugString() << ")";
579   std::shared_ptr<const Job> job;
580   {
581     mutex_lock l(mu_);
582     std::string job_name;
583     if (request->optional_job_name_case() == GetOrCreateJobRequest::kJobName) {
584       job_name = request->job_name();
585     } else {
586       job_name = absl::StrCat("anonymous_job_", state_.NextAvailableJobId(),
587                               "_", random::New64());
588     }
589     Status s = state_.JobByName(job_name, job);
590     if (s.ok()) {
591       TF_RETURN_IF_ERROR(ValidateMatchingJob(job, *request));
592     } else if (errors::IsNotFound(s)) {
593       TF_RETURN_IF_ERROR(CreateJob(job_name, *request, job));
594     } else {
595       return s;
596     }
597     response->set_job_id(job->id);
598   }
599   VLOG(3) << "Received job id " << job->id << " for CreateJob("
600           << request->DebugString() << ")";
601   return Status::OK();
602 }
603 
// Handler for the GetOrCreateIteration RPC.
//
// Looks up the iteration keyed by (job name, repetition); creates the
// iteration and its tasks when it is missing or was garbage collected.
// Always acquires a fresh iteration client id for the caller. Task
// assignment to workers is done after releasing the dispatcher lock.
GetOrCreateIteration(const GetOrCreateIterationRequest * request,GetOrCreateIterationResponse * response)604 Status DataServiceDispatcherImpl::GetOrCreateIteration(
605     const GetOrCreateIterationRequest* request,
606     GetOrCreateIterationResponse* response) {
607   TF_RETURN_IF_ERROR(CheckStarted());
608   VLOG(3) << "GetOrCreateIteration(" << request->DebugString() << ")";
609   std::shared_ptr<const Iteration> iteration;
610   std::vector<std::shared_ptr<const Task>> tasks;
611   {
612     mutex_lock l(mu_);
613     std::shared_ptr<const Job> job;
614     TF_RETURN_IF_ERROR(state_.JobFromId(request->job_id(), job));
615     IterationKey key(job->job_name, request->repetition());
616     Status s = state_.IterationByKey(key, iteration);
617     if (!s.ok() && !errors::IsNotFound(s)) {
618       return s;
619     }
  // NOTE: short-circuit order matters here — `iteration` is only
  // dereferenced when the lookup succeeded (s was not NotFound).
620     if (errors::IsNotFound(s) || iteration->garbage_collected) {
621       TF_RETURN_IF_ERROR(CreateIteration(*request, iteration));
622       TF_RETURN_IF_ERROR(CreateTasksForIteration(iteration, tasks));
623     }
624     int64_t iteration_client_id;
625     TF_RETURN_IF_ERROR(
626         AcquireIterationClientId(iteration, iteration_client_id));
627     response->set_iteration_client_id(iteration_client_id);
628   }
  // Assign any newly created tasks to workers outside the lock; `tasks` is
  // empty when an existing iteration was reused.
629   TF_RETURN_IF_ERROR(AssignTasks(tasks));
630   VLOG(3) << "Created iteration " << iteration->iteration_id
631           << " for CreateIteration(" << request->DebugString() << ")";
632   return OkStatus();
633 }
634 
// Handler for the MaybeRemoveTask RPC (round-robin iterations only).
//
// Task removal is coordinated through a per-task TaskRemover shared by all
// consumers; the dispatcher lock is released before calling RequestRemoval.
// NOTE(review): RequestRemoval appears to coordinate across consumers and may
// block until they agree — confirm against the TaskRemover implementation.
MaybeRemoveTask(const MaybeRemoveTaskRequest * request,MaybeRemoveTaskResponse * response)635 Status DataServiceDispatcherImpl::MaybeRemoveTask(
636     const MaybeRemoveTaskRequest* request, MaybeRemoveTaskResponse* response) {
637   VLOG(1) << "Attempting to remove task. Request: " << request->DebugString();
638   std::shared_ptr<TaskRemover> remover;
639   std::shared_ptr<const Task> task;
640   {
641     mutex_lock l(mu_);
642     Status s = state_.TaskFromId(request->task_id(), task);
643     if (errors::IsNotFound(s)) {
644       // Task is already removed.
645       response->set_removed(true);
646       return OkStatus();
647     }
648     TF_RETURN_IF_ERROR(s);
  // Create the shared remover lazily on the first removal request; the
  // remover tracks consent from all of the iteration's consumers.
649     auto& remover_ref = remove_task_requests_[task->task_id];
650     if (remover_ref == nullptr) {
651       if (!task->iteration->IsRoundRobin()) {
652         return errors::FailedPrecondition(
653             "MaybeRemoveTask called on a non-round-robin task.");
654       }
655       remover_ref = std::make_shared<TaskRemover>(
656           task->iteration->job->num_consumers.value());
657     }
658     remover = remover_ref;
659   }
  // Called without holding mu_, so other dispatcher RPCs are not blocked.
660   bool removed =
661       remover->RequestRemoval(request->consumer_index(), request->round());
662   response->set_removed(removed);
663   if (!removed) {
664     VLOG(1) << "Failed to remove task " << task->task_id;
665     return OkStatus();
666   }
  // Re-acquire the lock to journal the removal; skip the update if another
  // request already marked the task removed.
667   mutex_lock l(mu_);
668   if (!task->removed) {
669     Update update;
670     RemoveTaskUpdate* remove_task = update.mutable_remove_task();
671     remove_task->set_task_id(request->task_id());
672     TF_RETURN_IF_ERROR(Apply(update));
673   }
674   VLOG(1) << "Task " << task->task_id << " successfully removed";
675   return OkStatus();
676 }
677 
ReleaseIterationClient(const ReleaseIterationClientRequest * request,ReleaseIterationClientResponse * response)678 Status DataServiceDispatcherImpl::ReleaseIterationClient(
679     const ReleaseIterationClientRequest* request,
680     ReleaseIterationClientResponse* response) {
681   TF_RETURN_IF_ERROR(CheckStarted());
682   mutex_lock l(mu_);
683   int64_t iteration_client_id = request->iteration_client_id();
684   std::shared_ptr<const Iteration> iteration;
685   TF_RETURN_IF_ERROR(
686       state_.IterationForIterationClientId(iteration_client_id, iteration));
687   Update update;
688   ReleaseIterationClientUpdate* release_iteration_client =
689       update.mutable_release_iteration_client();
690   release_iteration_client->set_iteration_client_id(iteration_client_id);
691   release_iteration_client->set_time_micros(env_->NowMicros());
692   TF_RETURN_IF_ERROR(Apply(update));
693   return OkStatus();
694 }
695 
696 // Validates that the job matches the requested processing mode.
ValidateMatchingJob(std::shared_ptr<const Job> job,const GetOrCreateJobRequest & request)697 Status DataServiceDispatcherImpl::ValidateMatchingJob(
698     std::shared_ptr<const Job> job, const GetOrCreateJobRequest& request)
699     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
700   std::string diff;
701   if (!MessageDifferencer::Equals(job->processing_mode,
702                                   request.processing_mode_def())) {
703     strings::StrAppend(&diff, "Existing processing mode: <",
704                        job->processing_mode.ShortDebugString(), ">; got <",
705                        request.processing_mode_def().ShortDebugString(), ">. ");
706   }
707 
708   if (job->use_cross_trainer_cache != request.use_cross_trainer_cache()) {
709     strings::StrAppend(
710         &diff, "Existing cross-trainer cache: <",
711         (job->use_cross_trainer_cache ? "enabled" : "disabled"), ">; got <",
712         (request.use_cross_trainer_cache() ? "enabled" : "disabled"), ">. ");
713   }
714 
715   if (job->target_workers != request.target_workers()) {
716     strings::StrAppend(&diff, "Existing target workers: <",
717                        TargetWorkersToString(job->target_workers), ">; got <",
718                        TargetWorkersToString(request.target_workers()), ">. ");
719   }
720 
721   if (!diff.empty()) {
722     return errors::InvalidArgument(
723         "Tried to create job with name ", job->job_name,
724         ", but found an existing job with different parameters: ", diff);
725   }
726   return OkStatus();
727 }
728 
CreateJob(const std::string & job_name,const GetOrCreateJobRequest & request,std::shared_ptr<const Job> & job)729 Status DataServiceDispatcherImpl::CreateJob(
730     const std::string& job_name, const GetOrCreateJobRequest& request,
731     std::shared_ptr<const Job>& job) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
732   TF_RETURN_IF_ERROR(ValidateProcessingMode(request.processing_mode_def()));
733   int64_t job_id = state_.NextAvailableJobId();
734   Update update;
735   CreateJobUpdate* create_job = update.mutable_create_job();
736   create_job->set_job_id(job_id);
737   create_job->set_job_name(job_name);
738   create_job->set_dataset_id(request.dataset_id());
739   *create_job->mutable_processing_mode_def() = request.processing_mode_def();
740   const bool is_coordinated_read = (request.optional_num_consumers_case() ==
741                                     GetOrCreateJobRequest::kNumConsumers);
742   if (is_coordinated_read) {
743     create_job->set_num_consumers(request.num_consumers());
744   }
745   create_job->set_target_workers(request.target_workers());
746   create_job->set_use_cross_trainer_cache(request.use_cross_trainer_cache());
747   TF_RETURN_IF_ERROR(Apply(update));
748   TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));
749   tensorflow::metrics::RecordTFDataServiceJobsCreated(
750       request.processing_mode_def(), is_coordinated_read);
751   return Status::OK();
752 }
753 
CreateIteration(const GetOrCreateIterationRequest & request,std::shared_ptr<const Iteration> & iteration)754 Status DataServiceDispatcherImpl::CreateIteration(
755     const GetOrCreateIterationRequest& request,
756     std::shared_ptr<const Iteration>& iteration)
757     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
758   int64_t iteration_id = state_.NextAvailableIterationId();
759   int64_t num_split_providers = 0;
760   std::shared_ptr<const Job> job;
761   TF_RETURN_IF_ERROR(state_.JobFromId(request.job_id(), job));
762   if (IsDynamicShard(job->processing_mode)) {
763     TF_RETURN_IF_ERROR(
764         MakeSplitProviders(job->dataset_id, split_providers_[iteration_id]));
765     num_split_providers = split_providers_[iteration_id].size();
766   }
767   Update update;
768   CreateIterationUpdate* create_iteration = update.mutable_create_iteration();
769   create_iteration->set_iteration_id(iteration_id);
770   create_iteration->set_repetition(request.repetition());
771   create_iteration->set_job_id(request.job_id());
772   create_iteration->set_num_split_providers(num_split_providers);
773   TF_RETURN_IF_ERROR(Apply(update));
774   TF_RETURN_IF_ERROR(state_.IterationFromId(iteration_id, iteration));
775   return OkStatus();
776 }
777 
CreateTasksForWorker(const std::string & worker_address)778 Status DataServiceDispatcherImpl::CreateTasksForWorker(
779     const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
780   std::vector<std::shared_ptr<const Iteration>> iterations =
781       state_.ListIterations();
782   for (const auto& iteration : iterations) {
783     if (iteration->finished) {
784       continue;
785     }
786     if (iteration->job->num_consumers.has_value()) {
787       TF_RETURN_IF_ERROR(CreatePendingTask(iteration, worker_address));
788       continue;
789     }
790     std::shared_ptr<const Task> task;
791     TF_RETURN_IF_ERROR(CreateTask(iteration, worker_address, task));
792   }
793   return OkStatus();
794 }
795 
AcquireIterationClientId(const std::shared_ptr<const Iteration> & iteration,int64_t & iteration_client_id)796 Status DataServiceDispatcherImpl::AcquireIterationClientId(
797     const std::shared_ptr<const Iteration>& iteration,
798     int64_t& iteration_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
799   iteration_client_id = state_.NextAvailableIterationClientId();
800   Update update;
801   AcquireIterationClientUpdate* acquire_iteration_client =
802       update.mutable_acquire_iteration_client();
803   acquire_iteration_client->set_iteration_client_id(iteration_client_id);
804   acquire_iteration_client->set_iteration_id(iteration->iteration_id);
805   TF_RETURN_IF_ERROR(Apply(update));
806   // Does not release clients before they start to read from the dataset.
807   latest_client_heartbeats_time_[iteration_client_id] = absl::InfiniteFuture();
808   return OkStatus();
809 }
810 
CreateTasksForIteration(std::shared_ptr<const Iteration> iteration,std::vector<std::shared_ptr<const Task>> & tasks)811 Status DataServiceDispatcherImpl::CreateTasksForIteration(
812     std::shared_ptr<const Iteration> iteration,
813     std::vector<std::shared_ptr<const Task>>& tasks)
814     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
815   std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
816   tasks.clear();
817   tasks.reserve(workers.size());
818   for (const auto& worker : workers) {
819     std::shared_ptr<const Task> task;
820     TF_RETURN_IF_ERROR(CreateTask(iteration, worker->address, task));
821     tasks.push_back(task);
822   }
823   return OkStatus();
824 }
825 
CreatePendingTask(std::shared_ptr<const Iteration> iteration,const std::string & worker_address)826 Status DataServiceDispatcherImpl::CreatePendingTask(
827     std::shared_ptr<const Iteration> iteration,
828     const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
829   int64_t task_id = state_.NextAvailableTaskId();
830   Update update;
831   CreatePendingTaskUpdate* create_task = update.mutable_create_pending_task();
832   create_task->set_task_id(task_id);
833   create_task->set_iteration_id(iteration->iteration_id);
834   create_task->set_worker_address(worker_address);
835   create_task->set_starting_round(round_robin_rounds_[iteration->iteration_id] +
836                                   1);
837   std::shared_ptr<const Worker> worker;
838   TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
839   create_task->set_transfer_address(worker->transfer_address);
840   *create_task->mutable_worker_tags() = {worker->tags.begin(),
841                                          worker->tags.end()};
842   create_task->set_worker_uid(worker->uid);
843   TF_RETURN_IF_ERROR(Apply(update));
844   return OkStatus();
845 }
846 
CreateTask(std::shared_ptr<const Iteration> iteration,const std::string & worker_address,std::shared_ptr<const Task> & task)847 Status DataServiceDispatcherImpl::CreateTask(
848     std::shared_ptr<const Iteration> iteration,
849     const std::string& worker_address, std::shared_ptr<const Task>& task)
850     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
851   int64_t task_id = state_.NextAvailableTaskId();
852   Update update;
853   CreateTaskUpdate* create_task = update.mutable_create_task();
854   create_task->set_task_id(task_id);
855   create_task->set_iteration_id(iteration->iteration_id);
856   create_task->set_worker_address(worker_address);
857   std::shared_ptr<const Worker> worker;
858   TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
859   create_task->set_transfer_address(worker->transfer_address);
860   *create_task->mutable_worker_tags() = {worker->tags.begin(),
861                                          worker->tags.end()};
862   create_task->set_worker_uid(worker->uid);
863   TF_RETURN_IF_ERROR(Apply(update));
864   TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
865   return OkStatus();
866 }
867 
AssignTasks(std::vector<std::shared_ptr<const Task>> tasks)868 Status DataServiceDispatcherImpl::AssignTasks(
869     std::vector<std::shared_ptr<const Task>> tasks) TF_LOCKS_EXCLUDED(mu_) {
870   for (const auto& task : tasks) {
871     TF_RETURN_IF_ERROR(AssignTask(task));
872   }
873   return OkStatus();
874 }
875 
Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
    const std::string& worker_address, WorkerService::Stub*& out_stub)
    TF_LOCKS_EXCLUDED(mu_) {
  // Returns a cached stub for `worker_address`, creating one on first use.
  // The stub is built outside the lock so a slow channel setup does not block
  // other dispatcher operations; a second locked section resolves races with
  // concurrent creators (the first writer wins, later stubs are discarded).
  {
    mutex_lock l(mu_);
    auto it = worker_stubs_.find(worker_address);
    if (it != worker_stubs_.end()) {
      out_stub = it->second.get();
      return OkStatus();
    }
  }
  std::unique_ptr<WorkerService::Stub> stub;
  TF_RETURN_IF_ERROR(
      CreateWorkerStub(worker_address, config_.protocol(), stub));
  {
    mutex_lock l(mu_);
    // A concurrent call could have already created the stub.
    auto& worker = worker_stubs_[worker_address];
    if (worker == nullptr) {
      worker = std::move(stub);
    }
    out_stub = worker.get();
  }
  return OkStatus();
}
901 
Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
    TF_LOCKS_EXCLUDED(mu_) {
  // Sends `task` to its worker via a ProcessTask RPC. Connection-level
  // failures are tolerated (see below); other RPC errors are returned.
  VLOG(2) << "Started assigning task " << task->task_id << " to worker "
          << task->worker_address;
  grpc::ClientContext client_ctx;
  ProcessTaskRequest req;
  TaskDef* task_def = req.mutable_task();
  {
    // PopulateTaskDef reads dispatcher state, so it requires the lock; the
    // RPC below is made outside the lock.
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(PopulateTaskDef(task, task_def));
  }
  ProcessTaskResponse resp;
  WorkerService::Stub* stub;
  TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub));
  grpc::Status s = stub->ProcessTask(&client_ctx, req, &resp);
  if (!s.ok()) {
    if (s.error_code() == grpc::StatusCode::UNAVAILABLE ||
        s.error_code() == grpc::StatusCode::ABORTED ||
        s.error_code() == grpc::StatusCode::CANCELLED) {
      // Worker is presumably preempted. We will assign the task to the worker
      // when it reconnects.
      return OkStatus();
    }
    return grpc_util::WrapError(
        absl::StrCat("Failed to submit task to worker ", task->worker_address),
        s);
  }
  VLOG(2) << "Finished assigning task " << task->task_id << " to worker "
          << task->worker_address;
  return OkStatus();
}
933 
Status DataServiceDispatcherImpl::ClientHeartbeat(
    const ClientHeartbeatRequest* request, ClientHeartbeatResponse* response) {
  // Handles a periodic heartbeat from an iteration client: refreshes the
  // client's liveness timestamp, negotiates pending (coordinated-read) tasks,
  // and returns the client's current task list.
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(4) << "Received heartbeat from client id "
          << request->iteration_client_id();
  // Record the heartbeat time; ReleaseMissingClients uses this to time out
  // silent clients.
  latest_client_heartbeats_time_[request->iteration_client_id()] =
      absl::FromUnixMicros(env_->NowMicros());
  std::shared_ptr<const Iteration> iteration;
  Status s = state_.IterationForIterationClientId(
      request->iteration_client_id(), iteration);
  if (errors::IsNotFound(s) && !config_.fault_tolerant_mode()) {
    return errors::NotFound(
        "Unknown iteration client id ", request->iteration_client_id(),
        ". The dispatcher is not configured to be fault tolerant, so this "
        "could be caused by a dispatcher restart.");
  }
  TF_RETURN_IF_ERROR(s);
  if (iteration->garbage_collected) {
    return errors::FailedPrecondition(
        "The requested iteration has been garbage collected due to inactivity. "
        "Consider configuring the dispatcher with a higher "
        "`iteration_gc_timeout_ms`.");
  }
  // Track the highest round this client has reported reaching.
  if (request->optional_current_round_case() ==
      ClientHeartbeatRequest::kCurrentRound) {
    round_robin_rounds_[request->iteration_client_id()] =
        std::max(round_robin_rounds_[request->iteration_client_id()],
                 request->current_round());
  }
  // For coordinated reads, negotiate the front pending task: either reject it
  // (the client is already past the target round) or accept it (the client is
  // blocked at or before the target round).
  if (!iteration->pending_tasks.empty()) {
    const auto& task = iteration->pending_tasks.front();
    Update update;
    ClientHeartbeatUpdate* client_heartbeat = update.mutable_client_heartbeat();
    bool apply_update = false;
    client_heartbeat->set_iteration_client_id(request->iteration_client_id());
    std::optional<int64_t> blocked_round;
    if (request->optional_blocked_round_case() ==
        ClientHeartbeatRequest::kBlockedRound) {
      blocked_round = request->blocked_round();
    }
    VLOG(1) << "Handling pending task in iteration client heartbeat. "
               "iteration_client_id: "
            << request->iteration_client_id()
            << ". current_round: " << request->current_round()
            << ". blocked_round: " << blocked_round.value_or(-1)
            << ". target_round: " << task.target_round;
    if (request->current_round() >= task.target_round) {
      TaskRejected* rejected = client_heartbeat->mutable_task_rejected();
      // Exponentially try later and later rounds until consumers all agree.
      int64_t round_offset = 2;
      for (int i = 0; i < task.failures; ++i) {
        round_offset *= 2;
      }
      rejected->set_new_target_round(
          round_robin_rounds_[request->iteration_client_id()] + round_offset);
      apply_update = true;
    }
    if (blocked_round.has_value() &&
        blocked_round.value() <= task.target_round &&
        !task.ready_consumers.contains(request->iteration_client_id())) {
      client_heartbeat->set_task_accepted(true);
      apply_update = true;
    }
    if (apply_update) {
      TF_RETURN_IF_ERROR(Apply(update));
    }
  }
  // If a pending task remains (the Apply above may have resolved it), tell
  // the client which round to block at.
  if (!iteration->pending_tasks.empty()) {
    response->set_block_round(iteration->pending_tasks.front().target_round);
  }

  std::vector<std::shared_ptr<const Task>> tasks;
  TF_RETURN_IF_ERROR(state_.TasksForIteration(iteration->iteration_id, tasks));
  for (const auto& task : tasks) {
    TaskInfo* task_info = response->mutable_task_info()->Add();
    task_info->set_worker_address(task->worker_address);
    task_info->set_transfer_address(task->transfer_address);
    *task_info->mutable_worker_tags() = {task->worker_tags.begin(),
                                         task->worker_tags.end()};
    task_info->set_task_id(task->task_id);
    task_info->set_iteration_id(iteration->iteration_id);
    task_info->set_worker_uid(task->worker_uid);
    task_info->set_starting_round(task->starting_round);
  }
  response->set_iteration_finished(iteration->finished);
  response->set_deployment_mode(config_.deployment_mode());
  VLOG(4) << "Found " << response->task_info_size()
          << " tasks for iteration client id "
          << request->iteration_client_id();
  return OkStatus();
}
1026 
GetWorkers(const GetWorkersRequest * request,GetWorkersResponse * response)1027 Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
1028                                              GetWorkersResponse* response) {
1029   TF_RETURN_IF_ERROR(CheckStarted());
1030   mutex_lock l(mu_);
1031   VLOG(3) << "Enter GetWorkers";
1032   std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
1033   for (const auto& worker : workers) {
1034     WorkerInfo* info = response->add_workers();
1035     info->set_address(worker->address);
1036   }
1037   VLOG(3) << "Returning list of " << response->workers_size()
1038           << " workers from GetWorkers";
1039   return OkStatus();
1040 }
1041 
Status DataServiceDispatcherImpl::PopulateTaskDef(
    std::shared_ptr<const Task> task, TaskDef* task_def) const
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  // Fills in `task_def` with everything a worker needs to run `task`:
  // identifiers, processing mode, sharding/consumer parameters, and the
  // dataset definition (inline or by path).
  task_def->set_dataset_id(task->iteration->job->dataset_id);
  task_def->set_iteration_id(task->iteration->iteration_id);
  task_def->set_worker_address(task->worker_address);
  task_def->set_task_id(task->task_id);
  *task_def->mutable_processing_mode_def() =
      task->iteration->job->processing_mode;
  // Static sharding requires the worker to know its position in the fixed
  // worker list.
  if (IsStaticShard(task->iteration->job->processing_mode)) {
    task_def->set_num_workers(config_.worker_addresses_size());
    TF_ASSIGN_OR_RETURN(int64_t worker_index,
                        state_.GetWorkerIndex(task->worker_address));
    task_def->set_worker_index(worker_index);
  }
  if (task->iteration->distributed_epoch_state.has_value()) {
    task_def->set_num_split_providers(
        task->iteration->distributed_epoch_state.value().indices.size());
  }
  if (task->iteration->job->num_consumers.has_value()) {
    task_def->set_num_consumers(task->iteration->job->num_consumers.value());
  }
  task_def->set_use_cross_trainer_cache(
      task->iteration->job->use_cross_trainer_cache);
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(
      state_.DatasetFromId(task->iteration->job->dataset_id, dataset));
  std::string dataset_key =
      DatasetKey(dataset->dataset_id, dataset->fingerprint);
  // Without a work dir the dataset graph is sent inline; with one, only the
  // path is sent and the worker reads the dataset from shared storage.
  if (config_.work_dir().empty()) {
    std::shared_ptr<const DatasetDef> dataset_def;
    TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
    *task_def->mutable_dataset_def() = *dataset_def;
  } else {
    std::string path =
        io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
    task_def->set_path(path);
  }
  return OkStatus();
}
1082 
CheckStarted()1083 Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) {
1084   mutex_lock l(mu_);
1085   if (!started_) {
1086     return errors::Unavailable("Dispatcher has not started yet.");
1087   }
1088   return OkStatus();
1089 }
1090 
RecordSplitProduced(int64_t iteration_id,int64_t repetition,int64_t split_provider_index,bool finished)1091 Status DataServiceDispatcherImpl::RecordSplitProduced(
1092     int64_t iteration_id, int64_t repetition, int64_t split_provider_index,
1093     bool finished) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1094   Update update;
1095   ProduceSplitUpdate* produce_split = update.mutable_produce_split();
1096   produce_split->set_iteration_id(iteration_id);
1097   produce_split->set_repetition(repetition);
1098   produce_split->set_split_provider_index(split_provider_index);
1099   produce_split->set_finished(finished);
1100   return Apply(update);
1101 }
1102 
// Applies `update` to the in-memory state without writing it to the journal,
// in contrast to Apply() below.
Status DataServiceDispatcherImpl::ApplyWithoutJournaling(const Update& update)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  return state_.Apply(update);
}
1107 
// Writes `update` to the journal (when a journal writer is configured), then
// applies it to the in-memory state. The journal write happens first so that
// any applied update can be recovered on restart.
Status DataServiceDispatcherImpl::Apply(const Update& update)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (journal_writer_.has_value()) {
    TF_RETURN_IF_ERROR(journal_writer_.value()->Write(update));
  }
  return state_.Apply(update);
}
1115 
// Background loop that periodically releases timed-out clients and
// garbage-collects old iterations, until the dispatcher is cancelled.
void DataServiceDispatcherImpl::IterationGcThread() {
  int64_t next_check_micros = 0;
  while (true) {
    mutex_lock l(mu_);
    // Wait (interruptibly, so cancellation is prompt) until the next
    // scheduled check time.
    while (!cancelled_ && env_->NowMicros() < next_check_micros) {
      int64_t remaining_micros = next_check_micros - env_->NowMicros();
      iteration_gc_thread_cv_.wait_for(
          l, std::chrono::microseconds(remaining_micros));
    }
    if (cancelled_) {
      return;
    }
    // Both maintenance steps are best-effort: failures are logged, not fatal.
    {
      Status s = ReleaseMissingClients();
      if (!s.ok()) {
        LOG(WARNING) << "Error releasing missing clients: " << s;
      }
    }

    {
      Status s = GcOldIterations();
      if (!s.ok()) {
        LOG(WARNING) << "Error garbage collecting old iterations: " << s;
      }
    }
    next_check_micros =
        env_->NowMicros() + (config_.job_gc_check_interval_ms() * 1000);
  }
}
1145 
ReleaseMissingClients()1146 Status DataServiceDispatcherImpl::ReleaseMissingClients()
1147     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1148   int64_t now = env_->NowMicros();
1149   for (const auto& client_id : state_.ListActiveClientIds()) {
1150     if (absl::FromUnixMicros(now) >
1151         latest_client_heartbeats_time_[client_id] +
1152             absl::Milliseconds(config_.client_timeout_ms())) {
1153       LOG(INFO) << "Releasing timed-out client with id " << client_id;
1154       Update update;
1155       ReleaseIterationClientUpdate* release_client =
1156           update.mutable_release_iteration_client();
1157       release_client->set_iteration_client_id(client_id);
1158       release_client->set_time_micros(now);
1159       TF_RETURN_IF_ERROR(Apply(update));
1160     }
1161   }
1162   return OkStatus();
1163 }
1164 
// Garbage-collects iterations that are unfinished, have no active clients,
// and whose last client was released longer ago than `job_gc_timeout_ms`.
Status DataServiceDispatcherImpl::GcOldIterations()
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Iteration>> iterations =
      state_.ListIterations();
  int64_t now = env_->NowMicros();
  for (const auto& iteration : iterations) {
    // Skip iterations that are already finished, still have clients, never
    // had a client released (last_client_released_micros < 0), or whose GC
    // timeout has not yet elapsed.
    if (iteration->finished || iteration->num_clients > 0 ||
        iteration->last_client_released_micros < 0 ||
        now < iteration->last_client_released_micros +
                  (config_.job_gc_timeout_ms() * 1000)) {
      continue;
    }
    Update update;
    update.mutable_garbage_collect_iteration()->set_iteration_id(
        iteration->iteration_id);
    // NOTE(review): this applies directly to `state_` and therefore bypasses
    // the journal, unlike Apply(). Confirm GC updates are intentionally not
    // journaled.
    TF_RETURN_IF_ERROR(state_.Apply(update));
    LOG(INFO) << "Garbage collected iteration " << iteration->DebugString();
  }
  return OkStatus();
}
1185 
GetDatasetDef(const std::string & dataset_id,std::shared_ptr<const DatasetDef> & dataset_def)1186 Status DataServiceDispatcherImpl::GetDatasetDef(
1187     const std::string& dataset_id,
1188     std::shared_ptr<const DatasetDef>& dataset_def)
1189     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1190   std::shared_ptr<const Dataset> dataset;
1191   TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
1192   return GetDatasetDef(*dataset, dataset_def);
1193 }
1194 
GetDatasetDef(const Dataset & dataset,std::shared_ptr<const DatasetDef> & dataset_def)1195 Status DataServiceDispatcherImpl::GetDatasetDef(
1196     const Dataset& dataset, std::shared_ptr<const DatasetDef>& dataset_def)
1197     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1198   std::string key = DatasetKey(dataset.dataset_id, dataset.fingerprint);
1199   return dataset_store_->Get(key, dataset_def);
1200 }
1201 
// Exports a snapshot of the dispatcher's configuration, workers, and
// iterations for debugging/introspection.
DispatcherStateExport DataServiceDispatcherImpl::ExportState() const
    TF_LOCKS_EXCLUDED(mu_) {
  DispatcherStateExport dispatcher_state_export;
  *dispatcher_state_export.mutable_dispatcher_config() = config_;
  mutex_lock l(mu_);
  // Before Start() completes, only the config is exported.
  if (!started_) {
    return dispatcher_state_export;
  }

  std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
  for (const auto& worker : workers) {
    dispatcher_state_export.add_worker_addresses(worker->address);
  }

  std::vector<std::shared_ptr<const Iteration>> iterations =
      state_.ListIterations();
  for (const auto& iteration : iterations) {
    DispatcherStateExport::Iteration* iteration_export =
        dispatcher_state_export.add_iterations();
    iteration_export->set_dataset_id(iteration->job->dataset_id);
    iteration_export->set_iteration_id(iteration->iteration_id);
    iteration_export->mutable_iteration_key()->set_name(
        iteration->iteration_key.name);
    // Note: the export field is named `iteration` but carries the key's
    // repetition count.
    iteration_export->mutable_iteration_key()->set_iteration(
        iteration->iteration_key.repetition);
    *iteration_export->mutable_processing_mode() =
        iteration->job->processing_mode;
    if (iteration->job->num_consumers) {
      iteration_export->set_num_consumers(*iteration->job->num_consumers);
    }
    iteration_export->set_num_clients(iteration->num_clients);
    iteration_export->set_finished(iteration->finished);
    iteration_export->set_garbage_collected(iteration->garbage_collected);
  }
  return dispatcher_state_export;
}
1238 
1239 }  // namespace data
1240 }  // namespace tensorflow
1241