1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ 17 18 #include <memory> 19 #include <string> 20 #include <utility> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/container/flat_hash_set.h" 24 #include "absl/strings/string_view.h" 25 #include "tensorflow/core/data/service/common.pb.h" 26 #include "tensorflow/core/data/service/data_transfer.h" 27 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h" 28 #include "tensorflow/core/data/service/dispatcher_client.h" 29 #include "tensorflow/core/data/service/export.pb.h" 30 #include "tensorflow/core/data/service/task_runner.h" 31 #include "tensorflow/core/data/service/worker.pb.h" 32 #include "tensorflow/core/data/standalone.h" 33 #include "tensorflow/core/framework/cancellation.h" 34 #include "tensorflow/core/lib/core/errors.h" 35 #include "tensorflow/core/platform/env.h" 36 #include "tensorflow/core/platform/mutex.h" 37 #include "tensorflow/core/platform/status.h" 38 #include "tensorflow/core/platform/statusor.h" 39 #include "tensorflow/core/platform/thread_annotations.h" 40 #include "tensorflow/core/protobuf/service_config.pb.h" 41 #include "tensorflow/core/public/session.h" 42 43 namespace tensorflow { 44 namespace data { 45 46 // A TensorFlow DataService serves dataset elements over RPC. 47 class DataServiceWorkerImpl { 48 public: 49 explicit DataServiceWorkerImpl(const experimental::WorkerConfig& config); 50 ~DataServiceWorkerImpl(); 51 52 // Starts the worker. The worker needs to know its own address so that it can 53 // register with the dispatcher. This is set in `Start` instead of in the 54 // constructor because the worker may be binding to port `0`, in which case 55 // the address isn't known until the worker has started and decided which port 56 // to bind to. 57 Status Start(const std::string& worker_address, 58 const std::string& transfer_address); 59 // Stops the worker, attempting a clean shutdown by rejecting new requests 60 // and waiting for outstanding requests to complete. 61 void Stop(); 62 63 // Serves a GetElement request, storing the result in `*result`. See 64 // worker.proto for GetElement API documentation. 65 Status GetElementResult(const GetElementRequest* request, 66 GetElementResult* result); 67 68 // Deletes the local task and iterator. Only called by local clients to delete 69 // unused task iterators assuming the task is not read by remote clients. This 70 // method is not visible to gRPC clients. 71 void DeleteLocalTask(const TaskInfo& task_info); 72 73 // See worker.proto for API documentation. 74 75 /// Dispatcher-facing API. 76 Status ProcessTask(const ProcessTaskRequest* request, 77 ProcessTaskResponse* response); 78 79 /// Client-facing API. 80 Status GetElement(const GetElementRequest* request, 81 GetElementResponse* response); 82 Status GetWorkerTasks(const GetWorkerTasksRequest* request, 83 GetWorkerTasksResponse* response); 84 85 // Exports the worker state for debugging. 86 WorkerStateExport ExportState() const; 87 88 private: 89 struct Task { TaskTask90 explicit Task(TaskDef task_def) : task_def(std::move(task_def)) {} 91 92 TaskDef task_def; 93 mutex mu; 94 bool initialized TF_GUARDED_BY(mu) = false; 95 int64_t outstanding_requests TF_GUARDED_BY(&DataServiceWorkerImpl::mu_) = 0; 96 std::unique_ptr<TaskRunner> task_runner; 97 }; 98 99 // Validates the worker config. 100 Status ValidateWorkerConfig() const; 101 // Sends task status to the dispatcher and checks for dispatcher commands. 102 Status SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_); 103 // Creates an iterator to process a task. 104 Status ProcessTaskInternal(const TaskDef& task) 105 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 106 Status EnsureTaskInitialized(Task& task); 107 // Stops a task, cancelling the task's outstanding requests and waiting for 108 // them to finish. 109 void StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_); 110 // A thread for notifying the dispatcher when tasks complete. 111 void TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_); 112 // A thread for doing periodic heartbeats to the dispatcher. 113 void HeartbeatThread() TF_LOCKS_EXCLUDED(mu_); 114 // Performs a heartbeat to the dispatcher. 115 Status Heartbeat() TF_LOCKS_EXCLUDED(mu_); 116 // Gets the DatasetDef for `task_def`. 117 StatusOr<DatasetDef> GetDatasetDef(const TaskDef& task_def) const; 118 // Creates a dataset from `dataset_def`. 119 StatusOr<std::unique_ptr<standalone::Dataset>> MakeDataset( 120 const DatasetDef& dataset_def, const TaskDef& task_def) const; 121 // Creates an iterator for `dataset`. 122 StatusOr<std::unique_ptr<standalone::Iterator>> MakeDatasetIterator( 123 standalone::Dataset& dataset, const TaskDef& task_def) const; 124 125 const experimental::WorkerConfig config_; 126 // Worker Borg job UID for telemetry. -1 if not supported. 127 const int64_t worker_uid_; 128 129 // The worker's own address. 130 std::string worker_address_; 131 std::string transfer_address_; 132 std::unique_ptr<DataServiceDispatcherClient> dispatcher_; 133 134 mutable mutex mu_; 135 condition_variable cv_; 136 // Information about tasks, keyed by task ids. The tasks are updated based on 137 // the heartbeat responses from the dispatcher. 138 absl::flat_hash_map<int64_t, std::shared_ptr<Task>> tasks_ TF_GUARDED_BY(mu_); 139 // Ids of tasks that have finished. 140 absl::flat_hash_set<int64_t> finished_tasks_ TF_GUARDED_BY(mu_); 141 // Completed tasks which haven't yet been communicated to the dispatcher. 142 absl::flat_hash_set<int64_t> pending_completed_tasks_ TF_GUARDED_BY(mu_); 143 // Tasks deleted by the local client. If the client tries to read from them 144 // again, the worker will return a non-retriable FailedPrecondition error. 145 absl::flat_hash_set<int64_t> deleted_tasks_ TF_GUARDED_BY(mu_); 146 bool cancelled_ TF_GUARDED_BY(mu_) = false; 147 // Whether the worker has registered with the dispatcher yet. 148 bool registered_ TF_GUARDED_BY(mu_) = false; 149 condition_variable task_completion_cv_ TF_GUARDED_BY(mu_); 150 condition_variable heartbeat_cv_ TF_GUARDED_BY(mu_); 151 CancellationManager cancellation_manager_; 152 153 // A thread for notifying the dispatcher when tasks complete. 154 std::unique_ptr<Thread> task_completion_thread_; 155 // A thread for performing regular heartbeats to the dispatcher. 156 std::unique_ptr<Thread> heartbeat_thread_; 157 158 TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl); 159 }; 160 161 // Local in-process workers shared among clients and servers. If clients and 162 // workers colocate in the same process, clients can read from local workers to 163 // reduce RPC calls and data copy. 164 class LocalWorkers { 165 public: 166 // Adds a `worker` at `worker_address`. If a worker already exists at the 167 // address, it will be updated to the new `worker`. 168 // REQUIRES: worker != nullptr. 169 static void Add(absl::string_view worker_address, 170 std::shared_ptr<DataServiceWorkerImpl> worker); 171 172 // Gets a local worker at `worker_address`. Returns nullptr if a worker is not 173 // found. 174 static std::shared_ptr<DataServiceWorkerImpl> Get( 175 absl::string_view worker_address); 176 177 // Returns if there are any local workers in the process. 178 static bool Empty(); 179 180 // Removes a worker at `worker_address`. It is no-op if a worker is not found 181 // at the address. 182 static void Remove(absl::string_view worker_address); 183 184 private: 185 using AddressToWorkerMap = 186 absl::flat_hash_map<std::string, std::shared_ptr<DataServiceWorkerImpl>>; 187 static mutex mu_; 188 static AddressToWorkerMap* local_workers_ TF_GUARDED_BY(mu_); 189 }; 190 191 } // namespace data 192 } // namespace tensorflow 193 194 #endif // TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ 195