1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_ 17 #define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_ 18 19 #include <memory> 20 #include <optional> 21 #include <string> 22 #include <vector> 23 24 #include "absl/container/flat_hash_map.h" 25 #include "absl/container/flat_hash_set.h" 26 #include "absl/time/time.h" 27 #include "tensorflow/core/data/service/common.h" 28 #include "tensorflow/core/data/service/common.pb.h" 29 #include "tensorflow/core/data/service/dataset_store.h" 30 #include "tensorflow/core/data/service/dispatcher.pb.h" 31 #include "tensorflow/core/data/service/dispatcher_state.h" 32 #include "tensorflow/core/data/service/export.pb.h" 33 #include "tensorflow/core/data/service/task_remover.h" 34 #include "tensorflow/core/data/service/worker.grpc.pb.h" 35 #include "tensorflow/core/framework/dataset.h" 36 #include "tensorflow/core/platform/env.h" 37 #include "tensorflow/core/platform/macros.h" 38 #include "tensorflow/core/platform/mutex.h" 39 #include "tensorflow/core/platform/status.h" 40 #include "tensorflow/core/platform/statusor.h" 41 #include "tensorflow/core/platform/thread_annotations.h" 42 #include "tensorflow/core/protobuf/data_service.pb.h" 43 #include "tensorflow/core/protobuf/service_config.pb.h" 44 #include "tensorflow/core/public/session.h" 45 46 namespace tensorflow { 47 namespace data { 48 49 // A service which coordinates a pool of workers to serve dataset elements over 50 // RPC. 51 // 52 // Glossary: 53 // * Dataset: A definition of how to generate a potentially large collection of 54 // elements. 55 // * Iteration: A coordinated phase of reading from the tf.data service. An 56 // iteration produces some amount of data, and (potentially multiple) 57 // consumers consume the data from the iteration until there is no data left. 58 // Each iteration has a ProcessingModeDef which determines what data it 59 // produces. 60 // * Task: An iteration is broken into multiple tasks, which each represent 61 // iterating over all of or part of the dataset. Workers process tasks. 62 // * Consumer: A process reading from the tf.data service. 63 // 64 // **Adding workers** 65 // 66 // tf.data service supports adding workers mid-iteration. When a new worker 67 // connects to the dispatcher, the dispatcher creates a new task for the worker, 68 // one task for each outstanding iteration. Consumers periodically heartbeat to 69 // the dispatcher to learn about new tasks. 70 // 71 // For non-round-robin-reads, there is no coordination among consumers. Each 72 // consumer will start reading from the new task as soon as it learns about the 73 // task from its heartbeat. Round robin reads, on the other hand, require 74 // consumers to read from the same task at each step. This requires coordination 75 // to ensure that all consumers start reading from the new task in the same 76 // round. 77 // 78 // The protocol for adding round robin tasks works as follows: 79 // 80 // - The dispatcher keeps track of which round each round-robin iteration is on. 81 // This 82 // information is reported by consumers in their heartbeats. 83 // - When a new worker joins and there is an outstanding round-robin iteration, 84 // we create a new task for the iteration and assign it to the worker. 85 // However, we don't yet report the task in consumer heartbeats. 86 // We call the task a "pending task" and add it to its iteration's "pending 87 // tasks" queue. 88 // - When we create a pending task, we choose a "target round" to try adding 89 // the task to. The target round is chosen by adding a "target round delta" to 90 // the latest reported round for the iteration. 91 // - When a consumer heartbeats for an iteration and there is a pending task for 92 // that iteration, the dispatcher sends a heartbeat response telling the 93 // consumer to block before reading from the target round. 94 // - When a consumer receives a heartbeat response telling it to block 95 // (before reading) a round, the consumer try to block the round. If the 96 // consumer has already started the round, it will too late to block the 97 // round. 98 // - When consumers heartbeat, they tell the dispatcher their current round and 99 // whether they have blocked themselves from reading past a certain round. If 100 // a consumer reports a current round exceeding the target round, the target 101 // round has failed and needs to be increased. We choose a new target round by 102 // doubling the previous target round delta. If the consumer reports that it 103 // has blocked before the target round, we record that the consumer is ready 104 // to add the new task. Once all consumers are ready to add the new task, we 105 // remove the task from the pending tasks list and begin reporting the task to 106 // consumers. We set the "starting_round" field of the task to indicate the 107 // target round where all consumers should start reading from the task. 108 // - If a new worker joins while there are already pending tasks, a pending 109 // task for the new worker is created and queued behind the existing tasks. 110 // The new task won't be considered until all previous pending tasks have been 111 // successfully added. 112 // 113 // An example of executing this protocol with two consumers could go as follows: 114 // 1. Consumers read up to round 50 and heartbeat that they are on round 50. 115 // 2. A new worker joins. Dispatcher chooses round 51 as the target round. 116 // 3. Consumer 1 heartbeats that its current round is 50. Dispatcher tells it to 117 // block round 51. 118 // 4. Consumer 2 heartbeats that its current round is 51. Dispatcher realizes 119 // that it is too late to block round 51 and chooses round 53 as the new 120 // target round. Dispatcher tells consumer 2 to block round 53. 121 // 5. Consumer 1 heartbeats that its current round is 50 and that it has blocked 122 // round 51. Dispatcher tells it to block round 53 instead. Dispatcher 123 // records that consumer 1 is ready to add a task in round 53. 124 // 6. Consumer 2 heartbeats that its current round is 52 and it has blocked 125 // round 53. Dispatcher realizes that all consumers are blocked on round 53 126 // or earlier and promotes the task from pending to regular. Dispatcher sends 127 // consumer 2 a task list containing the new task, and tells consumer 2 that 128 // it no longer needs to block. 129 // 7. Consumer 1 heartbeats. Dispatcher sends consumer 1 the task list 130 // containing the new task, and tells it that it no longer needs to block. 131 // 132 class DataServiceDispatcherImpl { 133 public: 134 explicit DataServiceDispatcherImpl( 135 const experimental::DispatcherConfig& config); 136 137 ~DataServiceDispatcherImpl(); 138 139 // Starts the dispatcher. If there is a journal, this will read from the 140 // journal to restore the dispatcher's state. 141 Status Start(); 142 143 // Returns the number of active iterations. 144 size_t NumActiveIterations() TF_LOCKS_EXCLUDED(mu_); 145 146 // See dispatcher.proto for API documentation. 147 148 /// Worker-facing API. 149 Status WorkerHeartbeat(const WorkerHeartbeatRequest* request, 150 WorkerHeartbeatResponse* response); 151 Status WorkerUpdate(const WorkerUpdateRequest* request, 152 WorkerUpdateResponse* response); 153 Status GetDatasetDef(const GetDatasetDefRequest* request, 154 GetDatasetDefResponse* response); 155 Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response); 156 157 /// Client-facing API. 158 Status GetVersion(const GetVersionRequest* request, 159 GetVersionResponse* response); 160 Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request, 161 GetOrRegisterDatasetResponse* response); 162 Status GetDataServiceMetadata(const GetDataServiceMetadataRequest* request, 163 GetDataServiceMetadataResponse* response); 164 Status GetDataServiceConfig(const GetDataServiceConfigRequest* request, 165 GetDataServiceConfigResponse* response); 166 Status GetOrCreateJob(const GetOrCreateJobRequest* request, 167 GetOrCreateJobResponse* response); 168 Status GetOrCreateIteration(const GetOrCreateIterationRequest* request, 169 GetOrCreateIterationResponse* response); 170 Status ReleaseIterationClient(const ReleaseIterationClientRequest* request, 171 ReleaseIterationClientResponse* response); 172 Status MaybeRemoveTask(const MaybeRemoveTaskRequest* request, 173 MaybeRemoveTaskResponse* response); 174 Status ClientHeartbeat(const ClientHeartbeatRequest* request, 175 ClientHeartbeatResponse* response); 176 Status GetWorkers(const GetWorkersRequest* request, 177 GetWorkersResponse* response); 178 179 // Exports the dispatcher state for debugging. 180 DispatcherStateExport ExportState() const; 181 182 private: 183 // Restores split providers from the state in `iteration` and stores them in 184 // `restored`. 185 Status RestoreSplitProviders( 186 const DispatcherState::Iteration& iteration, 187 std::vector<std::unique_ptr<SplitProvider>>& restored) 188 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 189 // Makes split providers for the specified `dataset_id`, and stores them in 190 // `split_providers`. 191 Status MakeSplitProviders( 192 const std::string& dataset_id, 193 std::vector<std::unique_ptr<SplitProvider>>& split_providers) 194 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 195 // Registers a dataset with the given fingerprint, storing the new dataset's 196 // id in `dataset_id`. 197 Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset, 198 const DataServiceMetadata& metadata, 199 const std::string& requested_dataset_id, 200 std::string& dataset_id) 201 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 202 // Finds the dataset ID with the requested dataset ID, or with the matching 203 // fingerprint if the ID does not exist. Returns nullptr if no such dataset 204 // exists. 205 StatusOr<std::optional<std::string>> FindDataset( 206 const GetOrRegisterDatasetRequest& request, uint64 fingerprint); 207 // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a 208 // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is 209 // stored in `out_stub`. 210 Status GetOrCreateWorkerStub(const std::string& worker_address, 211 WorkerService::Stub*& out_stub) 212 TF_LOCKS_EXCLUDED(mu_); 213 // Creates a job and stores it in `job`. 214 Status CreateJob(const std::string& job_name, 215 const GetOrCreateJobRequest& request, 216 std::shared_ptr<const DispatcherState::Job>& job) 217 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 218 // Creates an iteration and stores it in `iteration`. This method updates the 219 // dispatcher state with the new iteration, but does not assign tasks to 220 // workers. 221 Status CreateIteration( 222 const GetOrCreateIterationRequest& request, 223 std::shared_ptr<const DispatcherState::Iteration>& iteration) 224 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 225 // Creates tasks for the specified worker, one task for every unfinished 226 // iteration. 227 Status CreateTasksForWorker(const std::string& worker_address); 228 // Finds tasks that should be deleted from a worker, updating the heartbeat 229 // response. 230 Status FindTasksToDelete( 231 const absl::flat_hash_set<int64_t>& current_tasks, 232 const std::vector<std::shared_ptr<const DispatcherState::Task>> 233 assigned_tasks, 234 WorkerHeartbeatResponse* response); 235 // Finds new tasks that should be assigned to a worker and adds them to 236 // the heartbeat response. 237 Status FindNewTasks( 238 const std::string& worker_address, 239 const absl::flat_hash_set<int64_t>& current_tasks, 240 std::vector<std::shared_ptr<const DispatcherState::Task>>& assigned_tasks, 241 WorkerHeartbeatResponse* response); 242 // Acquires an iteration client id to read from the given iteration and sets 243 // `iteration_client_id`. 244 Status AcquireIterationClientId( 245 const std::shared_ptr<const DispatcherState::Iteration>& iteration, 246 int64_t& iteration_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 247 // Creates one task for each worker, for the given iteration. The created 248 // tasks are stored in `tasks`. This method only updates dispatcher metadata 249 // with the new tasks, but doesn't assign the tasks to the workers. 250 Status CreateTasksForIteration( 251 std::shared_ptr<const DispatcherState::Iteration> iteration, 252 std::vector<std::shared_ptr<const DispatcherState::Task>>& tasks) 253 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 254 255 // Creates a new task for an iteration. The created task may be either pending 256 // or active. 257 Status CreateTask(std::shared_ptr<const DispatcherState::Iteration> iteration, 258 const std::string& worker_address, 259 std::shared_ptr<const DispatcherState::Task>& task) 260 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 261 // Creates a pending task for a round robin iteration. All consumers need to 262 // agree on which round to add the task in before the pending task can be 263 // promoted to a regular task. 264 Status CreatePendingTask( 265 std::shared_ptr<const DispatcherState::Iteration> iteration, 266 const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 267 // Creates a new active task for an iteration, storing the created task in 268 // `task`. 269 Status CreateActiveTask( 270 std::shared_ptr<const DispatcherState::Iteration> iteration, 271 const std::string& worker_address, 272 std::shared_ptr<const DispatcherState::Task>& task); 273 // Assigns the list of tasks to the workers indicated by their 274 // `worker_address` fields. 275 Status AssignTasks( 276 std::vector<std::shared_ptr<const DispatcherState::Task>> tasks) 277 TF_LOCKS_EXCLUDED(mu_); 278 // Assigns a task to the worker indicated by its `worker_address` field. 279 Status AssignTask(std::shared_ptr<const DispatcherState::Task> task) 280 TF_LOCKS_EXCLUDED(mu_); 281 // Validates that an existing job matches a given request. 282 // Returns an error status describing any difference. 283 Status ValidateMatchingJob(std::shared_ptr<const DispatcherState::Job> job, 284 const GetOrCreateJobRequest& request) 285 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 286 // Fills out a TaskDef with information about a task. 287 Status PopulateTaskDef(std::shared_ptr<const DispatcherState::Task> task, 288 TaskDef* task_def) const 289 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 290 // Checks that the dispatcher has started, returning UNAVAILABLE if it hasn't. 291 Status CheckStarted() TF_LOCKS_EXCLUDED(mu_); 292 // Records that a split was produced by a call to `GetSplit`. 293 Status RecordSplitProduced(int64_t iteration_id, int64_t repetition, 294 int64_t split_provider_index, bool finished) 295 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 296 // Applies a state update, updating both the journal and the in-memory state. 297 Status Apply(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 298 // Applies a state update, but doesn't update the journal. Only meant to be 299 // used when recovering state when the dispatcher starts. 300 Status ApplyWithoutJournaling(const Update& update) 301 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 302 // A thread which periodically checks for iterations to clean up. 303 void IterationGcThread(); 304 // Releases iteration clients that haven't heartbeated recently. 305 Status ReleaseMissingClients() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 306 // Scans for old iterations and marks them as finished. 307 Status GcOldIterations() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 308 // Gets a `DatasetDef` from `dataset_store_` for the given dataset id, and 309 // stores it in `dataset_def`. 310 Status GetDatasetDef(const std::string& dataset_id, 311 std::shared_ptr<const DatasetDef>& dataset_def) 312 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 313 // Gets a `DatasetDef` from `dataset_store_` for the given dataset, and 314 // stores it in `dataset_def`. 315 Status GetDatasetDef(const DispatcherState::Dataset& dataset, 316 std::shared_ptr<const DatasetDef>& dataset_def) 317 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 318 319 const experimental::DispatcherConfig config_; 320 Env* env_; 321 322 mutable mutex mu_; 323 bool started_ TF_GUARDED_BY(mu_) = false; 324 bool cancelled_ TF_GUARDED_BY(mu_) = false; 325 326 // Cached worker stubs for communicating with workers. 327 absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>> 328 worker_stubs_ TF_GUARDED_BY(mu_); 329 // Store of dataset definitions. 330 std::unique_ptr<DatasetStore> dataset_store_ TF_GUARDED_BY(mu_); 331 // Mapping from iteration id to the split providers for the iteration. 332 absl::flat_hash_map<int64_t, std::vector<std::unique_ptr<SplitProvider>>> 333 split_providers_ TF_GUARDED_BY(mu_); 334 // Mapping from round robin iteration id to the round the iteration is 335 // currently on. This is based on the data provided by client heartbeats, and 336 // may be stale. 337 absl::flat_hash_map<int64_t, int64_t> round_robin_rounds_ TF_GUARDED_BY(mu_); 338 // Map from task id to a TaskRemover which determines when to remove the task. 339 absl::flat_hash_map<int64_t, std::shared_ptr<TaskRemover>> 340 remove_task_requests_ TF_GUARDED_BY(mu_); 341 // Map from client id to the time of the client's last heartbeat. 342 absl::flat_hash_map<int64_t, absl::Time> latest_client_heartbeats_time_ 343 TF_GUARDED_BY(mu_); 344 345 std::optional<std::unique_ptr<JournalWriter>> journal_writer_ 346 TF_GUARDED_BY(mu_); 347 DispatcherState state_ TF_GUARDED_BY(mu_); 348 // Condition variable for waking up the iteration gc thread. 349 condition_variable iteration_gc_thread_cv_; 350 std::unique_ptr<Thread> iteration_gc_thread_; 351 352 TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl); 353 }; 354 355 } // namespace data 356 } // namespace tensorflow 357 358 #endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_ 359