xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/worker_impl.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/service/worker_impl.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "grpcpp/create_channel.h"
24 #include "absl/algorithm/container.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/string_view.h"
28 #include "absl/strings/substitute.h"
29 #include "tensorflow/c/c_api_internal.h"
30 #include "tensorflow/c/tf_status_helper.h"
31 #include "tensorflow/core/data/dataset.pb.h"
32 #include "tensorflow/core/data/service/auto_shard_rewriter.h"
33 #include "tensorflow/core/data/service/common.h"
34 #include "tensorflow/core/data/service/common.pb.h"
35 #include "tensorflow/core/data/service/credentials_factory.h"
36 #include "tensorflow/core/data/service/data_transfer.h"
37 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
38 #include "tensorflow/core/data/service/dispatcher.pb.h"
39 #include "tensorflow/core/data/service/dispatcher_client.h"
40 #include "tensorflow/core/data/service/export.pb.h"
41 #include "tensorflow/core/data/service/grpc_util.h"
42 #include "tensorflow/core/data/service/split_provider.h"
43 #include "tensorflow/core/data/service/task_runner.h"
44 #include "tensorflow/core/data/service/utils.h"
45 #include "tensorflow/core/data/service/worker.pb.h"
46 #include "tensorflow/core/data/standalone.h"
47 #include "tensorflow/core/framework/dataset_options.pb.h"
48 #include "tensorflow/core/framework/metrics.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/tensor.pb.h"
51 #include "tensorflow/core/lib/core/errors.h"
52 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
53 #include "tensorflow/core/lib/monitoring/gauge.h"
54 #include "tensorflow/core/platform/env.h"
55 #include "tensorflow/core/platform/errors.h"
56 #include "tensorflow/core/platform/host_info.h"
57 #include "tensorflow/core/platform/refcount.h"
58 #include "tensorflow/core/platform/snappy.h"
59 #include "tensorflow/core/platform/status.h"
60 #include "tensorflow/core/platform/statusor.h"
61 #include "tensorflow/core/platform/thread_annotations.h"
62 #include "tensorflow/core/protobuf/service_config.pb.h"
63 #include "tensorflow/core/public/session_options.h"
64 
65 namespace tensorflow {
66 namespace data {
67 namespace {
68 
// Pause between retries when the dispatcher cannot be reached: 5 seconds.
constexpr int64_t kRetryIntervalMicros = int64_t{5} * 1000 * 1000;
// Heartbeat period applied when WorkerConfig leaves it at 0: 30 seconds.
constexpr int64_t kDefaultHeartBeatIntervalMs = int64_t{30} * 1000;
// Dispatcher timeout applied when WorkerConfig leaves it at 0: 1 hour.
constexpr int64_t kDefaultDispatcherTimeoutMs = int64_t{60} * 60 * 1000;
72 
73 using WorkerConfig = experimental::WorkerConfig;
74 
75 // Moves the element into the response. If the tensor contains a single
76 // CompressedElement variant, the move will be zero-copy. Otherwise, the tensor
77 // data will be serialized as TensorProtos.
MoveElementToResponse(std::vector<Tensor> && element,GetElementResponse & resp)78 Status MoveElementToResponse(std::vector<Tensor>&& element,
79                              GetElementResponse& resp) {
80   if (element.size() != 1 || element[0].dtype() != DT_VARIANT ||
81       !TensorShapeUtils::IsScalar(element[0].shape())) {
82     for (const auto& component : element) {
83       UncompressedElement* uncompressed = resp.mutable_uncompressed();
84       component.AsProtoTensorContent(uncompressed->add_components());
85     }
86     return OkStatus();
87   }
88   Variant& variant = element[0].scalar<Variant>()();
89   CompressedElement* compressed = variant.get<CompressedElement>();
90   if (compressed == nullptr) {
91     return errors::FailedPrecondition(
92         "Expected dataset to produce a CompressedElement variant tensor, but "
93         "it produced ",
94         variant.TypeName());
95   }
96   *resp.mutable_compressed() = *compressed;
97   return OkStatus();
98 }
99 
// Returns a copy of `config` where zero-valued timing fields are replaced by
// this file's defaults: dispatcher_timeout_ms -> kDefaultDispatcherTimeoutMs
// and heartbeat_interval_ms -> kDefaultHeartBeatIntervalMs. Every other field
// is copied unchanged.
WorkerConfig ApplyWorkerDefaults(const WorkerConfig& config) {
  WorkerConfig with_defaults = config;
  if (with_defaults.dispatcher_timeout_ms() == 0) {
    with_defaults.set_dispatcher_timeout_ms(kDefaultDispatcherTimeoutMs);
  }
  if (with_defaults.heartbeat_interval_ms() == 0) {
    with_defaults.set_heartbeat_interval_ms(kDefaultHeartBeatIntervalMs);
  }
  return with_defaults;
}
110 
// Builds the TaskDef reported by ExportState().
//
// Inline dataset graphs are replaced by a placeholder message stored in
// `path`. Path-based datasets keep their path. The ids, split-provider count,
// worker address/index, processing mode, worker count and (when set)
// num_consumers are copied over.
TaskDef Export(const TaskDef& task) {
  TaskDef exported;
  if (task.dataset_case() == TaskDef::kDatasetDef) {
    exported.set_path(
        "In-memory dataset graphs are omitted for brevity. To view datasets "
        "stored on the dispatcher, configure a `work_dir`.");
  } else if (task.dataset_case() == TaskDef::kPath) {
    exported.set_path(task.path());
  }

  exported.set_dataset_id(task.dataset_id());
  exported.set_task_id(task.task_id());
  exported.set_iteration_id(task.iteration_id());
  exported.set_num_split_providers(task.num_split_providers());
  exported.set_worker_address(task.worker_address());
  *exported.mutable_processing_mode_def() = task.processing_mode_def();
  if (task.optional_num_consumers_case() == TaskDef::kNumConsumers) {
    exported.set_num_consumers(task.num_consumers());
  }
  exported.set_num_workers(task.num_workers());
  exported.set_worker_index(task.worker_index());
  return exported;
}
142 }  // namespace
143 
144 mutex LocalWorkers::mu_(LINKER_INITIALIZED);
145 LocalWorkers::AddressToWorkerMap* LocalWorkers::local_workers_ =
146     new AddressToWorkerMap();
147 
DataServiceWorkerImpl(const WorkerConfig & config)148 DataServiceWorkerImpl::DataServiceWorkerImpl(const WorkerConfig& config)
149     : config_(ApplyWorkerDefaults(config)), worker_uid_(port::JobUid()) {
150   metrics::RecordTFDataServiceWorkerCreated();
151 }
152 
~DataServiceWorkerImpl()153 DataServiceWorkerImpl::~DataServiceWorkerImpl() {
154   mutex_lock l(mu_);
155   cancelled_ = true;
156   task_completion_cv_.notify_one();
157   heartbeat_cv_.notify_one();
158 }
159 
Start(const std::string & worker_address,const std::string & transfer_address)160 Status DataServiceWorkerImpl::Start(const std::string& worker_address,
161                                     const std::string& transfer_address) {
162   VLOG(3) << "Starting tf.data service worker at address " << worker_address;
163   TF_RETURN_IF_ERROR(ValidateWorkerConfig());
164   worker_address_ = worker_address;
165   transfer_address_ = transfer_address;
166 
167   dispatcher_ = std::make_unique<DataServiceDispatcherClient>(
168       config_.dispatcher_address(), config_.protocol());
169   TF_RETURN_IF_ERROR(dispatcher_->Initialize());
170 
171   Status s = Heartbeat();
172   while (!s.ok()) {
173     if (!IsPreemptedError(s)) {
174       return s;
175     }
176     LOG(WARNING) << "Failed to register with dispatcher at "
177                  << config_.dispatcher_address() << ": " << s;
178     Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
179     s = Heartbeat();
180   }
181   LOG(INFO) << "Worker registered with dispatcher running at "
182             << config_.dispatcher_address();
183   task_completion_thread_ = absl::WrapUnique(
184       Env::Default()->StartThread({}, "data-service-worker-task-completion",
185                                   [this]() { TaskCompletionThread(); }));
186   heartbeat_thread_ = absl::WrapUnique(Env::Default()->StartThread(
187       {}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); }));
188   mutex_lock l(mu_);
189   registered_ = true;
190   return OkStatus();
191 }
192 
Stop()193 void DataServiceWorkerImpl::Stop() {
194   std::vector<std::shared_ptr<Task>> tasks;
195   {
196     mutex_lock l(mu_);
197     cancelled_ = true;
198     for (const auto& entry : tasks_) {
199       tasks.push_back(entry.second);
200     }
201   }
202   for (auto& task : tasks) {
203     StopTask(*task);
204   }
205   // At this point there are no outstanding requests in this RPC handler.
206   // However, requests successfully returned from this RPC handler may still be
207   // in progress within the gRPC server. If we shut down the gRPC server
208   // immediately, it could cause these requests to fail, e.g. with broken pipe.
209   // To mitigate this, we sleep for some time to give the gRPC server time to
210   // complete requests.
211   Env::Default()->SleepForMicroseconds(config_.shutdown_quiet_period_ms() *
212                                        1000);
213 }
214 
ValidateWorkerConfig() const215 Status DataServiceWorkerImpl::ValidateWorkerConfig() const {
216   const bool any_tag_is_empty = absl::c_any_of(
217       config_.worker_tags(),
218       [](const std::string& worker_tag) { return worker_tag.empty(); });
219   if (any_tag_is_empty) {
220     return errors::FailedPrecondition(
221         "Worker tags cannot be empty. Got tags {",
222         absl::StrJoin(config_.worker_tags().begin(),
223                       config_.worker_tags().end(), ", "),
224         "}");
225   }
226   return OkStatus();
227 }
228 
GetElementResult(const GetElementRequest * request,struct GetElementResult * result)229 Status DataServiceWorkerImpl::GetElementResult(
230     const GetElementRequest* request, struct GetElementResult* result) {
231   Task* task = nullptr;
232   {
233     mutex_lock l(mu_);
234     if (cancelled_) {
235       return errors::Cancelled("Worker is shutting down");
236     }
237     if (!registered_) {
238       // We need to reject requests until the worker has registered with the
239       // dispatcher, so that we don't return NOT_FOUND for tasks that the worker
240       // had before preemption.
241       return errors::Unavailable(
242           "Worker has not yet registered with dispatcher.");
243     }
244     auto it = tasks_.find(request->task_id());
245     if (it == tasks_.end()) {
246       if (deleted_tasks_.contains(request->task_id())) {
247         return errors::FailedPrecondition(
248             "Got request for local task ", request->task_id(), " of worker ",
249             worker_address_, ", which has been deleted. You may be creating ",
250             "a duplicate iteration which has already finished. To fix this, "
251             "make sure to create your dataset only once, as opposed to "
252             "re-creating it repeatedly inside a loop.");
253       }
254       if (finished_tasks_.contains(request->task_id())) {
255         VLOG(3) << "Task is already finished";
256         result->end_of_sequence = true;
257         result->skip = false;
258         return OkStatus();
259       }
260       // Perhaps the worker hasn't gotten the task from the dispatcher yet.
261       // Return Unavailable so that the client knows to continue retrying.
262       return errors::Unavailable("Task ", request->task_id(), " not found");
263     }
264     task = it->second.get();
265     task->outstanding_requests++;
266   }
267   auto cleanup = gtl::MakeCleanup([&] {
268     mutex_lock l(mu_);
269     task->outstanding_requests--;
270     cv_.notify_all();
271   });
272   TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
273   TF_RETURN_IF_ERROR(task->task_runner->GetNext(*request, *result));
274 
275   if (result->end_of_sequence) {
276     mutex_lock l(mu_);
277     VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
278     pending_completed_tasks_.insert(request->task_id());
279     task_completion_cv_.notify_one();
280   }
281   return OkStatus();
282 }
283 
ProcessTask(const ProcessTaskRequest * request,ProcessTaskResponse * response)284 Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
285                                           ProcessTaskResponse* response) {
286   mutex_lock l(mu_);
287   const TaskDef& task = request->task();
288   VLOG(3) << "Received request to process task " << task.task_id();
289   return ProcessTaskInternal(task);
290 }
291 
ProcessTaskInternal(const TaskDef & task_def)292 Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
293     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
294   std::shared_ptr<Task>& task = tasks_[task_def.task_id()];
295   if (task) {
296     VLOG(1) << "Received request to process already-processed task "
297             << task->task_def.task_id();
298     return OkStatus();
299   }
300   task = std::make_unique<Task>(task_def);
301   VLOG(3) << "Began processing for task " << task_def.task_id()
302           << " with processing mode "
303           << task_def.processing_mode_def().DebugString();
304   return OkStatus();
305 }
306 
EnsureTaskInitialized(DataServiceWorkerImpl::Task & task)307 Status DataServiceWorkerImpl::EnsureTaskInitialized(
308     DataServiceWorkerImpl::Task& task) {
309   if (task.task_def.worker_address() != worker_address_) {
310     return errors::Internal(absl::Substitute(
311         "Dispatcher's worker address $0 does not match worker's address $1.",
312         task.task_def.worker_address(), worker_address_));
313   }
314 
315   mutex_lock l(task.mu);
316   if (task.initialized) {
317     return OkStatus();
318   }
319   TF_ASSIGN_OR_RETURN(DatasetDef dataset_def, GetDatasetDef(task.task_def));
320   TF_ASSIGN_OR_RETURN(std::unique_ptr<standalone::Dataset> dataset,
321                       MakeDataset(dataset_def, task.task_def));
322   TF_ASSIGN_OR_RETURN(std::unique_ptr<standalone::Iterator> iterator,
323                       MakeDatasetIterator(*dataset, task.task_def));
324   auto task_iterator = std::make_unique<StandaloneTaskIterator>(
325       std::move(dataset), std::move(iterator));
326   TF_RETURN_IF_ERROR(TaskRunner::Create(
327       config_, task.task_def, std::move(task_iterator), task.task_runner));
328 
329   task.initialized = true;
330   VLOG(3) << "Created iterator for task " << task.task_def.task_id();
331   return OkStatus();
332 }
333 
GetDatasetDef(const TaskDef & task_def) const334 StatusOr<DatasetDef> DataServiceWorkerImpl::GetDatasetDef(
335     const TaskDef& task_def) const {
336   switch (task_def.dataset_case()) {
337     case TaskDef::kDatasetDef:
338       return task_def.dataset_def();
339     case TaskDef::kPath: {
340       DatasetDef def;
341       Status s = ReadDatasetDef(task_def.path(), def);
342       if (!s.ok()) {
343         LOG(INFO) << "Failed to read dataset from " << task_def.path() << ": "
344                   << s << ". Falling back to reading from dispatcher.";
345         TF_RETURN_IF_ERROR(
346             dispatcher_->GetDatasetDef(task_def.dataset_id(), def));
347       }
348       return def;
349     }
350     case TaskDef::DATASET_NOT_SET:
351       return errors::Internal("Unrecognized dataset case: ",
352                               task_def.dataset_case());
353   }
354 }
355 
356 StatusOr<std::unique_ptr<standalone::Dataset>>
MakeDataset(const DatasetDef & dataset_def,const TaskDef & task_def) const357 DataServiceWorkerImpl::MakeDataset(const DatasetDef& dataset_def,
358                                    const TaskDef& task_def) const {
359   TF_ASSIGN_OR_RETURN(AutoShardRewriter auto_shard_rewriter,
360                       AutoShardRewriter::Create(task_def));
361   // `ApplyAutoShardRewrite` does nothing if auto-sharding is disabled.
362   TF_ASSIGN_OR_RETURN(
363       GraphDef rewritten_graph,
364       auto_shard_rewriter.ApplyAutoShardRewrite(dataset_def.graph()));
365   std::unique_ptr<standalone::Dataset> dataset;
366   TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
367       standalone::Dataset::Params(), rewritten_graph, &dataset));
368   return dataset;
369 }
370 
371 StatusOr<std::unique_ptr<standalone::Iterator>>
MakeDatasetIterator(standalone::Dataset & dataset,const TaskDef & task_def) const372 DataServiceWorkerImpl::MakeDatasetIterator(standalone::Dataset& dataset,
373                                            const TaskDef& task_def) const {
374   std::unique_ptr<standalone::Iterator> iterator;
375   if (IsNoShard(task_def.processing_mode_def()) ||
376       IsStaticShard(task_def.processing_mode_def())) {
377     TF_RETURN_IF_ERROR(dataset.MakeIterator(&iterator));
378     return iterator;
379   }
380 
381   if (IsDynamicShard(task_def.processing_mode_def())) {
382     std::vector<std::unique_ptr<SplitProvider>> split_providers;
383     split_providers.reserve(task_def.num_split_providers());
384     for (int i = 0; i < task_def.num_split_providers(); ++i) {
385       split_providers.push_back(std::make_unique<DataServiceSplitProvider>(
386           config_.dispatcher_address(), config_.protocol(),
387           task_def.iteration_id(), i, config_.dispatcher_timeout_ms()));
388     }
389     TF_RETURN_IF_ERROR(
390         dataset.MakeIterator(std::move(split_providers), &iterator));
391     return iterator;
392   }
393 
394   return errors::InvalidArgument("Unrecognized processing mode: ",
395                                  task_def.processing_mode_def().DebugString());
396 }
397 
StopTask(Task & task)398 void DataServiceWorkerImpl::StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_) {
399   {
400     mutex_lock l(task.mu);
401     task.initialized = true;
402   }
403   if (task.task_runner) {
404     task.task_runner->Cancel();
405   }
406   mutex_lock l(mu_);
407   while (task.outstanding_requests > 0) {
408     cv_.wait(l);
409   }
410 }
411 
GetElement(const GetElementRequest * request,GetElementResponse * response)412 Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
413                                          GetElementResponse* response) {
414   VLOG(3) << "Received GetElement request for task " << request->task_id();
415   struct GetElementResult result;
416   TF_RETURN_IF_ERROR(GetElementResult(request, &result));
417   response->set_end_of_sequence(result.end_of_sequence);
418   response->set_skip_task(result.skip);
419   if (!response->end_of_sequence() && !response->skip_task()) {
420     TF_RETURN_IF_ERROR(
421         MoveElementToResponse(std::move(result.components), *response));
422     VLOG(3) << "Producing an element for task " << request->task_id();
423   }
424   return OkStatus();
425 }
426 
GetWorkerTasks(const GetWorkerTasksRequest * request,GetWorkerTasksResponse * response)427 Status DataServiceWorkerImpl::GetWorkerTasks(
428     const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) {
429   mutex_lock l(mu_);
430   for (const auto& it : tasks_) {
431     Task* task = it.second.get();
432     TaskInfo* task_info = response->add_tasks();
433     task_info->set_worker_address(worker_address_);
434     task_info->set_task_id(task->task_def.task_id());
435     task_info->set_iteration_id(task->task_def.iteration_id());
436   }
437   return OkStatus();
438 }
439 
TaskCompletionThread()440 void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) {
441   while (true) {
442     {
443       mutex_lock l(mu_);
444       while (!cancelled_ && pending_completed_tasks_.empty()) {
445         task_completion_cv_.wait(l);
446       }
447       if (cancelled_) {
448         VLOG(3) << "Task completion thread shutting down";
449         return;
450       }
451     }
452     Status s = SendTaskUpdates();
453     if (!s.ok()) {
454       LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
455       mutex_lock l(mu_);
456       if (!cancelled_) {
457         task_completion_cv_.wait_for(
458             l, std::chrono::microseconds(kRetryIntervalMicros));
459       }
460     }
461   }
462 }
463 
SendTaskUpdates()464 Status DataServiceWorkerImpl::SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_) {
465   std::vector<TaskProgress> task_progress;
466   {
467     mutex_lock l(mu_);
468     VLOG(3) << "Sending " << pending_completed_tasks_.size()
469             << " task updates to dispatcher";
470     task_progress.reserve(pending_completed_tasks_.size());
471     for (int task_id : pending_completed_tasks_) {
472       task_progress.emplace_back();
473       task_progress.back().set_task_id(task_id);
474       task_progress.back().set_completed(true);
475     }
476   }
477 
478   TF_RETURN_IF_ERROR(dispatcher_->WorkerUpdate(worker_address_, task_progress));
479   mutex_lock l(mu_);
480   for (const auto& update : task_progress) {
481     pending_completed_tasks_.erase(update.task_id());
482   }
483   VLOG(3) << "Sent " << task_progress.size() << " task updates ";
484   return OkStatus();
485 }
486 
HeartbeatThread()487 void DataServiceWorkerImpl::HeartbeatThread() TF_LOCKS_EXCLUDED(mu_) {
488   while (true) {
489     int64_t next_heartbeat_micros =
490         Env::Default()->NowMicros() + (config_.heartbeat_interval_ms() * 1000);
491     {
492       mutex_lock l(mu_);
493       while (!cancelled_ &&
494              Env::Default()->NowMicros() < next_heartbeat_micros) {
495         int64_t time_to_wait_micros =
496             next_heartbeat_micros - Env::Default()->NowMicros();
497         heartbeat_cv_.wait_for(l,
498                                std::chrono::microseconds(time_to_wait_micros));
499       }
500       if (cancelled_) {
501         VLOG(3) << "Heartbeat thread shutting down";
502         return;
503       }
504       if (!registered_) {
505         VLOG(1) << "Not performing heartbeat; worker is not yet registered";
506         continue;
507       }
508     }
509     Status s = Heartbeat();
510     if (!s.ok()) {
511       LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s;
512     }
513   }
514 }
515 
Heartbeat()516 Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
517   std::vector<int64_t> current_tasks;
518   {
519     mutex_lock l(mu_);
520     for (const auto& task : tasks_) {
521       current_tasks.push_back(task.first);
522     }
523   }
524   WorkerHeartbeatRequest request;
525   request.set_worker_address(worker_address_);
526   request.set_transfer_address(transfer_address_);
527   *request.mutable_worker_tags() = config_.worker_tags();
528   request.set_worker_uid(worker_uid_);
529   *request.mutable_current_tasks() = {current_tasks.begin(),
530                                       current_tasks.end()};
531   TF_ASSIGN_OR_RETURN(WorkerHeartbeatResponse response,
532                       dispatcher_->WorkerHeartbeat(request));
533 
534   std::vector<std::shared_ptr<Task>> tasks_to_delete;
535   {
536     mutex_lock l(mu_);
537     for (const auto& task : response.new_tasks()) {
538       VLOG(1) << "Received new task from dispatcher with id " << task.task_id();
539       if (deleted_tasks_.contains(task.task_id())) {
540         continue;
541       }
542       Status s = ProcessTaskInternal(task);
543       if (!s.ok() && !errors::IsAlreadyExists(s)) {
544         LOG(WARNING) << "Failed to start processing task " << task.task_id()
545                      << ": " << s;
546       }
547     }
548     tasks_to_delete.reserve(response.tasks_to_delete_size());
549     for (int64_t task_id : response.tasks_to_delete()) {
550       VLOG(3) << "Deleting task " << task_id
551               << " at the request of the dispatcher";
552       if (!tasks_.contains(task_id)) {
553         continue;
554       }
555       tasks_to_delete.push_back(std::move(tasks_[task_id]));
556       tasks_.erase(task_id);
557       finished_tasks_.insert(task_id);
558     }
559   }
560   for (const auto& task : tasks_to_delete) {
561     StopTask(*task);
562   }
563   return OkStatus();
564 }
565 
DeleteLocalTask(const TaskInfo & task_info)566 void DataServiceWorkerImpl::DeleteLocalTask(const TaskInfo& task_info)
567     TF_LOCKS_EXCLUDED(mu_) {
568   std::shared_ptr<Task> task;
569   {
570     mutex_lock l(mu_);
571     auto it = tasks_.find(task_info.task_id());
572     if (it == tasks_.end() || !it->second) {
573       return;
574     }
575     task = std::move(it->second);
576     tasks_.erase(task_info.task_id());
577     pending_completed_tasks_.insert(task_info.task_id());
578     deleted_tasks_.insert(task_info.task_id());
579   }
580 
581   VLOG(2) << "Delete local task " << task_info.task_id() << " from worker "
582           << worker_address_ << " at the request of the client.";
583   StopTask(*task);
584 }
585 
ExportState() const586 WorkerStateExport DataServiceWorkerImpl::ExportState() const {
587   WorkerStateExport worker_state_export;
588   *worker_state_export.mutable_worker_config() = config_;
589   mutex_lock l(mu_);
590   if (!registered_) {
591     return worker_state_export;
592   }
593   for (const auto& task : tasks_) {
594     *worker_state_export.add_tasks() = Export(task.second->task_def);
595   }
596   for (int64_t finished_task : finished_tasks_) {
597     worker_state_export.add_finished_task_ids(finished_task);
598   }
599   for (int64_t deleted_task : deleted_tasks_) {
600     worker_state_export.add_deleted_task_ids(deleted_task);
601   }
602   return worker_state_export;
603 }
604 
Add(absl::string_view worker_address,std::shared_ptr<DataServiceWorkerImpl> worker)605 void LocalWorkers::Add(absl::string_view worker_address,
606                        std::shared_ptr<DataServiceWorkerImpl> worker) {
607   DCHECK(worker != nullptr) << "Adding a nullptr local worker is disallowed.";
608   VLOG(1) << "Register local worker at address " << worker_address;
609   mutex_lock l(mu_);
610   (*local_workers_)[worker_address] = worker;
611 }
612 
Get(absl::string_view worker_address)613 std::shared_ptr<DataServiceWorkerImpl> LocalWorkers::Get(
614     absl::string_view worker_address) {
615   tf_shared_lock l(mu_);
616   AddressToWorkerMap::const_iterator it = local_workers_->find(worker_address);
617   if (it == local_workers_->end()) {
618     return nullptr;
619   }
620   return it->second;
621 }
622 
Empty()623 bool LocalWorkers::Empty() {
624   tf_shared_lock l(mu_);
625   return local_workers_->empty();
626 }
627 
Remove(absl::string_view worker_address)628 void LocalWorkers::Remove(absl::string_view worker_address) {
629   VLOG(1) << "Remove local worker at address " << worker_address;
630   mutex_lock l(mu_);
631   local_workers_->erase(worker_address);
632 }
633 
634 }  // namespace data
635 }  // namespace tensorflow
636