/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/worker_impl.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "grpcpp/create_channel.h"
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/auto_shard_rewriter.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/dispatcher_client.h"
#include "tensorflow/core/data/service/export.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/service/utils.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace {

constexpr int64_t kRetryIntervalMicros = 5 * 1000 * 1000;        // 5 seconds.
constexpr int64_t kDefaultHeartBeatIntervalMs = 30 * 1000;       // 30 seconds.
constexpr int64_t kDefaultDispatcherTimeoutMs = 60 * 60 * 1000;  // 1 hour.

using WorkerConfig = experimental::WorkerConfig;

// Moves the element into the response. If the tensor contains a single
// CompressedElement variant, the move will be zero-copy. Otherwise, the tensor
// data will be serialized as TensorProtos.
Status MoveElementToResponse(std::vector<Tensor>&& element,
                             GetElementResponse& resp) {
  if (element.size() != 1 || element[0].dtype() != DT_VARIANT ||
      !TensorShapeUtils::IsScalar(element[0].shape())) {
    for (const auto& component : element) {
      UncompressedElement* uncompressed = resp.mutable_uncompressed();
      component.AsProtoTensorContent(uncompressed->add_components());
    }
    return OkStatus();
  }
  Variant& variant = element[0].scalar<Variant>()();
  CompressedElement* compressed = variant.get<CompressedElement>();
  if (compressed == nullptr) {
    return errors::FailedPrecondition(
        "Expected dataset to produce a CompressedElement variant tensor, but "
        "it produced ",
        variant.TypeName());
  }
  *resp.mutable_compressed() = *compressed;
  return OkStatus();
}
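
// Returns a copy of `config`, with any unset heartbeat interval or dispatcher
// timeout replaced by the defaults defined above.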
WorkerConfig ApplyWorkerDefaults(const WorkerConfig& config) {
  WorkerConfig new_config(config);
  if (new_config.heartbeat_interval_ms() == 0) {
    new_config.set_heartbeat_interval_ms(kDefaultHeartBeatIntervalMs);
  }
  if (new_config.dispatcher_timeout_ms() == 0) {
    new_config.set_dispatcher_timeout_ms(kDefaultDispatcherTimeoutMs);
  }
  return new_config;
}
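
// Builds the exported form of a `TaskDef`: in-memory dataset graphs are
// replaced with a placeholder message, while datasets stored on the
// dispatcher are exported by path.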
TaskDef Export(const TaskDef& task) {
  TaskDef result;
  switch (task.dataset_case()) {
    case TaskDef::kDatasetDef:
      result.set_path(
          "In-memory dataset graphs are omitted for brevity. To view datasets "
          "stored on the dispatcher, configure a `work_dir`.");
      break;
    case TaskDef::kPath:
      result.set_path(task.path());
      break;
    default:
      break;
  }
  result.set_dataset_id(task.dataset_id());
  result.set_task_id(task.task_id());
  result.set_iteration_id(task.iteration_id());
  result.set_num_split_providers(task.num_split_providers());
  result.set_worker_address(task.worker_address());
  *result.mutable_processing_mode_def() = task.processing_mode_def();
  switch (task.optional_num_consumers_case()) {
    case TaskDef::kNumConsumers:
      result.set_num_consumers(task.num_consumers());
      break;
    default:
      break;
  }
  result.set_num_workers(task.num_workers());
  result.set_worker_index(task.worker_index());
  return result;
}
}  // namespace

mutex LocalWorkers::mu_(LINKER_INITIALIZED);
LocalWorkers::AddressToWorkerMap* LocalWorkers::local_workers_ =
    new AddressToWorkerMap();

DataServiceWorkerImpl::DataServiceWorkerImpl(const WorkerConfig& config)
    : config_(ApplyWorkerDefaults(config)), worker_uid_(port::JobUid()) {
  metrics::RecordTFDataServiceWorkerCreated();
}

DataServiceWorkerImpl::~DataServiceWorkerImpl() {
  mutex_lock l(mu_);
  cancelled_ = true;
  task_completion_cv_.notify_one();
  heartbeat_cv_.notify_one();
}
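
// Validates the worker config, registers with the dispatcher (retrying while
// registration fails with a preemption-related error), and starts the
// background task-completion and heartbeat threads.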
Status DataServiceWorkerImpl::Start(const std::string& worker_address,
                                    const std::string& transfer_address) {
  VLOG(3) << "Starting tf.data service worker at address " << worker_address;
  TF_RETURN_IF_ERROR(ValidateWorkerConfig());
  worker_address_ = worker_address;
  transfer_address_ = transfer_address;

  dispatcher_ = std::make_unique<DataServiceDispatcherClient>(
      config_.dispatcher_address(), config_.protocol());
  TF_RETURN_IF_ERROR(dispatcher_->Initialize());

  Status s = Heartbeat();
  while (!s.ok()) {
    if (!IsPreemptedError(s)) {
      return s;
    }
    LOG(WARNING) << "Failed to register with dispatcher at "
                 << config_.dispatcher_address() << ": " << s;
    Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
    s = Heartbeat();
  }
  LOG(INFO) << "Worker registered with dispatcher running at "
            << config_.dispatcher_address();
  task_completion_thread_ = absl::WrapUnique(
      Env::Default()->StartThread({}, "data-service-worker-task-completion",
                                  [this]() { TaskCompletionThread(); }));
  heartbeat_thread_ = absl::WrapUnique(Env::Default()->StartThread(
      {}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); }));
  mutex_lock l(mu_);
  registered_ = true;
  return OkStatus();
}
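
// Marks the worker as cancelled, stops every task, and then waits out the
// configured shutdown quiet period before the gRPC server is torn down.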
void DataServiceWorkerImpl::Stop() {
  std::vector<std::shared_ptr<Task>> tasks;
  {
    mutex_lock l(mu_);
    cancelled_ = true;
    for (const auto& entry : tasks_) {
      tasks.push_back(entry.second);
    }
  }
  for (auto& task : tasks) {
    StopTask(*task);
  }
  // At this point there are no outstanding requests in this RPC handler.
  // However, requests successfully returned from this RPC handler may still be
  // in progress within the gRPC server. If we shut down the gRPC server
  // immediately, it could cause these requests to fail, e.g. with broken pipe.
  // To mitigate this, we sleep for some time to give the gRPC server time to
  // complete requests.
  Env::Default()->SleepForMicroseconds(config_.shutdown_quiet_period_ms() *
                                       1000);
}
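
// Returns an error if any of the configured worker tags is empty.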
Status DataServiceWorkerImpl::ValidateWorkerConfig() const {
  const bool any_tag_is_empty = absl::c_any_of(
      config_.worker_tags(),
      [](const std::string& worker_tag) { return worker_tag.empty(); });
  if (any_tag_is_empty) {
    return errors::FailedPrecondition(
        "Worker tags cannot be empty. Got tags {",
        absl::StrJoin(config_.worker_tags().begin(),
                      config_.worker_tags().end(), ", "),
        "}");
  }
  return OkStatus();
}
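
// Looks up the requested task and produces its next element. Returns
// Unavailable while the worker is unregistered or the task has not yet
// arrived from the dispatcher, so that clients know to retry.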
Status DataServiceWorkerImpl::GetElementResult(
    const GetElementRequest* request, struct GetElementResult* result) {
  Task* task = nullptr;
  {
    mutex_lock l(mu_);
    if (cancelled_) {
      return errors::Cancelled("Worker is shutting down");
    }
    if (!registered_) {
      // We need to reject requests until the worker has registered with the
      // dispatcher, so that we don't return NOT_FOUND for tasks that the
      // worker had before preemption.
      return errors::Unavailable(
          "Worker has not yet registered with dispatcher.");
    }
    auto it = tasks_.find(request->task_id());
    if (it == tasks_.end()) {
      if (deleted_tasks_.contains(request->task_id())) {
        return errors::FailedPrecondition(
            "Got request for local task ", request->task_id(), " of worker ",
            worker_address_, ", which has been deleted. You may be creating ",
            "a duplicate iteration which has already finished. To fix this, "
            "make sure to create your dataset only once, as opposed to "
            "re-creating it repeatedly inside a loop.");
      }
      if (finished_tasks_.contains(request->task_id())) {
        VLOG(3) << "Task is already finished";
        result->end_of_sequence = true;
        result->skip = false;
        return OkStatus();
      }
      // Perhaps the worker hasn't gotten the task from the dispatcher yet.
      // Return Unavailable so that the client knows to continue retrying.
      return errors::Unavailable("Task ", request->task_id(), " not found");
    }
    task = it->second.get();
    task->outstanding_requests++;
  }
  auto cleanup = gtl::MakeCleanup([&] {
    mutex_lock l(mu_);
    task->outstanding_requests--;
    cv_.notify_all();
  });
  TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
  TF_RETURN_IF_ERROR(task->task_runner->GetNext(*request, *result));

  if (result->end_of_sequence) {
    mutex_lock l(mu_);
    VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
    pending_completed_tasks_.insert(request->task_id());
    task_completion_cv_.notify_one();
  }
  return OkStatus();
}

Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
                                          ProcessTaskResponse* response) {
  mutex_lock l(mu_);
  const TaskDef& task = request->task();
  VLOG(3) << "Received request to process task " << task.task_id();
  return ProcessTaskInternal(task);
}
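
// Creates an entry in `tasks_` for `task_def` if one does not already exist.
// The task's iterator is created lazily, on the first GetElement request.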
Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<Task>& task = tasks_[task_def.task_id()];
  if (task) {
    VLOG(1) << "Received request to process already-processed task "
            << task->task_def.task_id();
    return OkStatus();
  }
  task = std::make_unique<Task>(task_def);
  VLOG(3) << "Began processing for task " << task_def.task_id()
          << " with processing mode "
          << task_def.processing_mode_def().DebugString();
  return OkStatus();
}
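
// Initializes `task` if it is not yet initialized: fetches the dataset
// definition, builds the dataset and its iterator, and wraps them in a
// TaskRunner.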
Status DataServiceWorkerImpl::EnsureTaskInitialized(
    DataServiceWorkerImpl::Task& task) {
  if (task.task_def.worker_address() != worker_address_) {
    return errors::Internal(absl::Substitute(
        "Dispatcher's worker address $0 does not match worker's address $1.",
        task.task_def.worker_address(), worker_address_));
  }

  mutex_lock l(task.mu);
  if (task.initialized) {
    return OkStatus();
  }
  TF_ASSIGN_OR_RETURN(DatasetDef dataset_def, GetDatasetDef(task.task_def));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<standalone::Dataset> dataset,
                      MakeDataset(dataset_def, task.task_def));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<standalone::Iterator> iterator,
                      MakeDatasetIterator(*dataset, task.task_def));
  auto task_iterator = std::make_unique<StandaloneTaskIterator>(
      std::move(dataset), std::move(iterator));
  TF_RETURN_IF_ERROR(TaskRunner::Create(
      config_, task.task_def, std::move(task_iterator), task.task_runner));

  task.initialized = true;
  VLOG(3) << "Created iterator for task " << task.task_def.task_id();
  return OkStatus();
}
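
// Resolves the dataset definition for `task_def`, either inline from the
// TaskDef or by reading it from disk, falling back to fetching it from the
// dispatcher if the read fails.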
StatusOr<DatasetDef> DataServiceWorkerImpl::GetDatasetDef(
    const TaskDef& task_def) const {
  switch (task_def.dataset_case()) {
    case TaskDef::kDatasetDef:
      return task_def.dataset_def();
    case TaskDef::kPath: {
      DatasetDef def;
      Status s = ReadDatasetDef(task_def.path(), def);
      if (!s.ok()) {
        LOG(INFO) << "Failed to read dataset from " << task_def.path() << ": "
                  << s << ". Falling back to reading from dispatcher.";
        TF_RETURN_IF_ERROR(
            dispatcher_->GetDatasetDef(task_def.dataset_id(), def));
      }
      return def;
    }
    case TaskDef::DATASET_NOT_SET:
      return errors::Internal("Unrecognized dataset case: ",
                              task_def.dataset_case());
  }
}
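
// Instantiates a standalone dataset from `dataset_def`, applying the
// auto-shard rewrite when it is enabled for this task.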
StatusOr<std::unique_ptr<standalone::Dataset>>
DataServiceWorkerImpl::MakeDataset(const DatasetDef& dataset_def,
                                   const TaskDef& task_def) const {
  TF_ASSIGN_OR_RETURN(AutoShardRewriter auto_shard_rewriter,
                      AutoShardRewriter::Create(task_def));
  // `ApplyAutoShardRewrite` does nothing if auto-sharding is disabled.
  TF_ASSIGN_OR_RETURN(
      GraphDef rewritten_graph,
      auto_shard_rewriter.ApplyAutoShardRewrite(dataset_def.graph()));
  std::unique_ptr<standalone::Dataset> dataset;
  TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
      standalone::Dataset::Params(), rewritten_graph, &dataset));
  return dataset;
}
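
// Creates the iterator for `dataset`. Dynamically sharded tasks draw their
// splits from dispatcher-backed split providers; other processing modes use
// a plain iterator.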
StatusOr<std::unique_ptr<standalone::Iterator>>
DataServiceWorkerImpl::MakeDatasetIterator(standalone::Dataset& dataset,
                                           const TaskDef& task_def) const {
  std::unique_ptr<standalone::Iterator> iterator;
  if (IsNoShard(task_def.processing_mode_def()) ||
      IsStaticShard(task_def.processing_mode_def())) {
    TF_RETURN_IF_ERROR(dataset.MakeIterator(&iterator));
    return iterator;
  }

  if (IsDynamicShard(task_def.processing_mode_def())) {
    std::vector<std::unique_ptr<SplitProvider>> split_providers;
    split_providers.reserve(task_def.num_split_providers());
    for (int i = 0; i < task_def.num_split_providers(); ++i) {
      split_providers.push_back(std::make_unique<DataServiceSplitProvider>(
          config_.dispatcher_address(), config_.protocol(),
          task_def.iteration_id(), i, config_.dispatcher_timeout_ms()));
    }
    TF_RETURN_IF_ERROR(
        dataset.MakeIterator(std::move(split_providers), &iterator));
    return iterator;
  }

  return errors::InvalidArgument("Unrecognized processing mode: ",
                                 task_def.processing_mode_def().DebugString());
}
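
// Cancels the task's runner and blocks until all outstanding requests for the
// task have drained.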
void DataServiceWorkerImpl::StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(task.mu);
    task.initialized = true;
  }
  if (task.task_runner) {
    task.task_runner->Cancel();
  }
  mutex_lock l(mu_);
  while (task.outstanding_requests > 0) {
    cv_.wait(l);
  }
}
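
// RPC handler for GetElement: fetches the next element for the task and moves
// it into the response.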
Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
                                         GetElementResponse* response) {
  VLOG(3) << "Received GetElement request for task " << request->task_id();
  struct GetElementResult result;
  TF_RETURN_IF_ERROR(GetElementResult(request, &result));
  response->set_end_of_sequence(result.end_of_sequence);
  response->set_skip_task(result.skip);
  if (!response->end_of_sequence() && !response->skip_task()) {
    TF_RETURN_IF_ERROR(
        MoveElementToResponse(std::move(result.components), *response));
    VLOG(3) << "Producing an element for task " << request->task_id();
  }
  return OkStatus();
}

Status DataServiceWorkerImpl::GetWorkerTasks(
    const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) {
  mutex_lock l(mu_);
  for (const auto& it : tasks_) {
    Task* task = it.second.get();
    TaskInfo* task_info = response->add_tasks();
    task_info->set_worker_address(worker_address_);
    task_info->set_task_id(task->task_def.task_id());
    task_info->set_iteration_id(task->task_def.iteration_id());
  }
  return OkStatus();
}
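
// Background thread that reports completed tasks to the dispatcher, retrying
// after an interval when an update fails.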
void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) {
  while (true) {
    {
      mutex_lock l(mu_);
      while (!cancelled_ && pending_completed_tasks_.empty()) {
        task_completion_cv_.wait(l);
      }
      if (cancelled_) {
        VLOG(3) << "Task completion thread shutting down";
        return;
      }
    }
    Status s = SendTaskUpdates();
    if (!s.ok()) {
      LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
      mutex_lock l(mu_);
      if (!cancelled_) {
        task_completion_cv_.wait_for(
            l, std::chrono::microseconds(kRetryIntervalMicros));
      }
    }
  }
}
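
// Sends one update per pending completed task to the dispatcher, and clears
// the tasks from the pending set once the update succeeds.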
Status DataServiceWorkerImpl::SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_) {
  std::vector<TaskProgress> task_progress;
  {
    mutex_lock l(mu_);
    VLOG(3) << "Sending " << pending_completed_tasks_.size()
            << " task updates to dispatcher";
    task_progress.reserve(pending_completed_tasks_.size());
    for (int task_id : pending_completed_tasks_) {
      task_progress.emplace_back();
      task_progress.back().set_task_id(task_id);
      task_progress.back().set_completed(true);
    }
  }

  TF_RETURN_IF_ERROR(dispatcher_->WorkerUpdate(worker_address_, task_progress));
  mutex_lock l(mu_);
  for (const auto& update : task_progress) {
    pending_completed_tasks_.erase(update.task_id());
  }
  VLOG(3) << "Sent " << task_progress.size() << " task updates ";
  return OkStatus();
}
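
// Background thread that heartbeats to the dispatcher on the configured
// interval, once the worker has registered.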
void DataServiceWorkerImpl::HeartbeatThread() TF_LOCKS_EXCLUDED(mu_) {
  while (true) {
    int64_t next_heartbeat_micros =
        Env::Default()->NowMicros() + (config_.heartbeat_interval_ms() * 1000);
    {
      mutex_lock l(mu_);
      while (!cancelled_ &&
             Env::Default()->NowMicros() < next_heartbeat_micros) {
        int64_t time_to_wait_micros =
            next_heartbeat_micros - Env::Default()->NowMicros();
        heartbeat_cv_.wait_for(l,
                               std::chrono::microseconds(time_to_wait_micros));
      }
      if (cancelled_) {
        VLOG(3) << "Heartbeat thread shutting down";
        return;
      }
      if (!registered_) {
        VLOG(1) << "Not performing heartbeat; worker is not yet registered";
        continue;
      }
    }
    Status s = Heartbeat();
    if (!s.ok()) {
      LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s;
    }
  }
}
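
// Sends a heartbeat carrying this worker's address, tags, and current task
// ids, then applies the dispatcher's response: starting any new tasks and
// stopping tasks the dispatcher asks to delete.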
Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
  std::vector<int64_t> current_tasks;
  {
    mutex_lock l(mu_);
    for (const auto& task : tasks_) {
      current_tasks.push_back(task.first);
    }
  }
  WorkerHeartbeatRequest request;
  request.set_worker_address(worker_address_);
  request.set_transfer_address(transfer_address_);
  *request.mutable_worker_tags() = config_.worker_tags();
  request.set_worker_uid(worker_uid_);
  *request.mutable_current_tasks() = {current_tasks.begin(),
                                      current_tasks.end()};
  TF_ASSIGN_OR_RETURN(WorkerHeartbeatResponse response,
                      dispatcher_->WorkerHeartbeat(request));

  std::vector<std::shared_ptr<Task>> tasks_to_delete;
  {
    mutex_lock l(mu_);
    for (const auto& task : response.new_tasks()) {
      VLOG(1) << "Received new task from dispatcher with id " << task.task_id();
      if (deleted_tasks_.contains(task.task_id())) {
        continue;
      }
      Status s = ProcessTaskInternal(task);
      if (!s.ok() && !errors::IsAlreadyExists(s)) {
        LOG(WARNING) << "Failed to start processing task " << task.task_id()
                     << ": " << s;
      }
    }
    tasks_to_delete.reserve(response.tasks_to_delete_size());
    for (int64_t task_id : response.tasks_to_delete()) {
      VLOG(3) << "Deleting task " << task_id
              << " at the request of the dispatcher";
      if (!tasks_.contains(task_id)) {
        continue;
      }
      tasks_to_delete.push_back(std::move(tasks_[task_id]));
      tasks_.erase(task_id);
      finished_tasks_.insert(task_id);
    }
  }
  for (const auto& task : tasks_to_delete) {
    StopTask(*task);
  }
  return OkStatus();
}
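
// Removes a task at a client's request, recording it as both completed and
// deleted so that later requests for it fail with a clear error.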
void DataServiceWorkerImpl::DeleteLocalTask(const TaskInfo& task_info)
    TF_LOCKS_EXCLUDED(mu_) {
  std::shared_ptr<Task> task;
  {
    mutex_lock l(mu_);
    auto it = tasks_.find(task_info.task_id());
    if (it == tasks_.end() || !it->second) {
      return;
    }
    task = std::move(it->second);
    tasks_.erase(task_info.task_id());
    pending_completed_tasks_.insert(task_info.task_id());
    deleted_tasks_.insert(task_info.task_id());
  }

  VLOG(2) << "Delete local task " << task_info.task_id() << " from worker "
          << worker_address_ << " at the request of the client.";
  StopTask(*task);
}
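
// Exports the worker's configuration and, once registered, its active,
// finished, and deleted tasks for debugging.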
WorkerStateExport DataServiceWorkerImpl::ExportState() const {
  WorkerStateExport worker_state_export;
  *worker_state_export.mutable_worker_config() = config_;
  mutex_lock l(mu_);
  if (!registered_) {
    return worker_state_export;
  }
  for (const auto& task : tasks_) {
    *worker_state_export.add_tasks() = Export(task.second->task_def);
  }
  for (int64_t finished_task : finished_tasks_) {
    worker_state_export.add_finished_task_ids(finished_task);
  }
  for (int64_t deleted_task : deleted_tasks_) {
    worker_state_export.add_deleted_task_ids(deleted_task);
  }
  return worker_state_export;
}
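
// Registry of workers running in the same process, keyed by worker address.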
void LocalWorkers::Add(absl::string_view worker_address,
                       std::shared_ptr<DataServiceWorkerImpl> worker) {
  DCHECK(worker != nullptr) << "Adding a nullptr local worker is disallowed.";
  VLOG(1) << "Register local worker at address " << worker_address;
  mutex_lock l(mu_);
  (*local_workers_)[worker_address] = worker;
}

std::shared_ptr<DataServiceWorkerImpl> LocalWorkers::Get(
    absl::string_view worker_address) {
  tf_shared_lock l(mu_);
  AddressToWorkerMap::const_iterator it = local_workers_->find(worker_address);
  if (it == local_workers_->end()) {
    return nullptr;
  }
  return it->second;
}

bool LocalWorkers::Empty() {
  tf_shared_lock l(mu_);
  return local_workers_->empty();
}

void LocalWorkers::Remove(absl::string_view worker_address) {
  VLOG(1) << "Remove local worker at address " << worker_address;
  mutex_lock l(mu_);
  local_workers_->erase(worker_address);
}

}  // namespace data
}  // namespace tensorflow