/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_

#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/time/time.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/dataset_store.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/dispatcher_state.h"
#include "tensorflow/core/data/service/export.pb.h"
#include "tensorflow/core/data/service/task_remover.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace data {

// A service which coordinates a pool of workers to serve dataset elements over
// RPC.
//
// Glossary:
// * Dataset: A definition of how to generate a potentially large collection of
//   elements.
// * Iteration: A coordinated phase of reading from the tf.data service. An
//   iteration produces some amount of data, and (potentially multiple)
//   consumers consume the data from the iteration until there is no data left.
//   Each iteration has a ProcessingModeDef which determines what data it
//   produces.
// * Task: An iteration is broken into multiple tasks, which each represent
//   iterating over all or part of the dataset. Workers process tasks.
// * Consumer: A process reading from the tf.data service.
//
// **Adding workers**
//
// tf.data service supports adding workers mid-iteration. When a new worker
// connects to the dispatcher, the dispatcher creates new tasks for the worker,
// one for each outstanding iteration. Consumers periodically heartbeat to
// the dispatcher to learn about new tasks.
//
// For non-round-robin reads, there is no coordination among consumers. Each
// consumer will start reading from the new task as soon as it learns about the
// task from its heartbeat. Round robin reads, on the other hand, require
// consumers to read from the same task at each step. This requires coordination
// to ensure that all consumers start reading from the new task in the same
// round.
//
// The protocol for adding round robin tasks works as follows:
//
// - The dispatcher keeps track of which round each round-robin iteration is
//   on. This information is reported by consumers in their heartbeats.
// - When a new worker joins and there is an outstanding round-robin iteration,
//   we create a new task for the iteration and assign it to the worker.
//   However, we don't yet report the task in consumer heartbeats.
//   We call the task a "pending task" and add it to its iteration's "pending
//   tasks" queue.
// - When we create a pending task, we choose a "target round" to try adding
//   the task to. The target round is chosen by adding a "target round delta" to
//   the latest reported round for the iteration (see the sketch after this
//   list).
// - When a consumer heartbeats for an iteration and there is a pending task for
//   that iteration, the dispatcher sends a heartbeat response telling the
//   consumer to block before reading from the target round.
// - When a consumer receives a heartbeat response telling it to block
//   (before reading) a round, the consumer tries to block the round. If the
//   consumer has already started the round, it is too late to block the
//   round.
// - When consumers heartbeat, they tell the dispatcher their current round and
//   whether they have blocked themselves from reading past a certain round. If
//   a consumer reports a current round exceeding the target round, the target
//   round has failed and needs to be increased. We choose a new target round by
//   doubling the previous target round delta. If the consumer reports that it
//   has blocked before the target round, we record that the consumer is ready
//   to add the new task. Once all consumers are ready to add the new task, we
//   remove the task from the pending tasks list and begin reporting the task to
//   consumers. We set the "starting_round" field of the task to indicate the
//   target round where all consumers should start reading from the task.
// - If a new worker joins while there are already pending tasks, a pending
//   task for the new worker is created and queued behind the existing tasks.
//   The new task won't be considered until all previous pending tasks have been
//   successfully added.
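//
// A minimal sketch of the target-round arithmetic above (the variable names
// are hypothetical; the real logic lives in the dispatcher implementation):
//
//   int64_t target_round = latest_reported_round + target_round_delta;
//   // On a consumer heartbeat reporting `current_round`:
//   if (current_round >= target_round) {
//     // Too late to block the target round; double the delta and retry.
//     target_round_delta *= 2;
//     target_round = current_round + target_round_delta;
//   }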
//
// An example of executing this protocol with two consumers could go as follows:
// 1. Consumers read up to round 50 and heartbeat that they are on round 50.
// 2. A new worker joins. Dispatcher chooses round 51 as the target round.
// 3. Consumer 1 heartbeats that its current round is 50. Dispatcher tells it to
//    block round 51.
// 4. Consumer 2 heartbeats that its current round is 51. Dispatcher realizes
//    that it is too late to block round 51 and chooses round 53 as the new
//    target round. Dispatcher tells consumer 2 to block round 53.
// 5. Consumer 1 heartbeats that its current round is 50 and that it has blocked
//    round 51. Dispatcher tells it to block round 53 instead. Dispatcher
//    records that consumer 1 is ready to add a task in round 53.
// 6. Consumer 2 heartbeats that its current round is 52 and it has blocked
//    round 53. Dispatcher realizes that all consumers are blocked on round 53
//    or earlier and promotes the task from pending to regular. Dispatcher sends
//    consumer 2 a task list containing the new task, and tells consumer 2 that
//    it no longer needs to block.
// 7. Consumer 1 heartbeats. Dispatcher sends consumer 1 the task list
//    containing the new task, and tells it that it no longer needs to block.
//
class DataServiceDispatcherImpl {
 public:
  explicit DataServiceDispatcherImpl(
      const experimental::DispatcherConfig& config);

  ~DataServiceDispatcherImpl();

  // Starts the dispatcher. If there is a journal, this will read from the
  // journal to restore the dispatcher's state.
  Status Start();
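
  // A minimal usage sketch (a hypothetical setup; the `port` value is
  // illustrative, not a default):
  //
  //   experimental::DispatcherConfig config;
  //   config.set_port(5000);
  //   DataServiceDispatcherImpl dispatcher(config);
  //   TF_RETURN_IF_ERROR(dispatcher.Start());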

  // Returns the number of active iterations.
  size_t NumActiveIterations() TF_LOCKS_EXCLUDED(mu_);

  // See dispatcher.proto for API documentation.

  /// Worker-facing API.
  Status WorkerHeartbeat(const WorkerHeartbeatRequest* request,
                         WorkerHeartbeatResponse* response);
  Status WorkerUpdate(const WorkerUpdateRequest* request,
                      WorkerUpdateResponse* response);
  Status GetDatasetDef(const GetDatasetDefRequest* request,
                       GetDatasetDefResponse* response);
  Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response);
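
  // A sketch of a worker heartbeat against the impl (the proto field names are
  // assumptions based on dispatcher.proto at this revision):
  //
  //   WorkerHeartbeatRequest req;
  //   req.set_worker_address(worker_address);
  //   WorkerHeartbeatResponse resp;
  //   TF_RETURN_IF_ERROR(dispatcher.WorkerHeartbeat(&req, &resp));
  //   // resp.new_tasks() lists tasks newly assigned to this worker.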

  /// Client-facing API.
  Status GetVersion(const GetVersionRequest* request,
                    GetVersionResponse* response);
  Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request,
                              GetOrRegisterDatasetResponse* response);
  Status GetDataServiceMetadata(const GetDataServiceMetadataRequest* request,
                                GetDataServiceMetadataResponse* response);
  Status GetDataServiceConfig(const GetDataServiceConfigRequest* request,
                              GetDataServiceConfigResponse* response);
  Status GetOrCreateJob(const GetOrCreateJobRequest* request,
                        GetOrCreateJobResponse* response);
  Status GetOrCreateIteration(const GetOrCreateIterationRequest* request,
                              GetOrCreateIterationResponse* response);
  Status ReleaseIterationClient(const ReleaseIterationClientRequest* request,
                                ReleaseIterationClientResponse* response);
  Status MaybeRemoveTask(const MaybeRemoveTaskRequest* request,
                         MaybeRemoveTaskResponse* response);
  Status ClientHeartbeat(const ClientHeartbeatRequest* request,
                         ClientHeartbeatResponse* response);
  Status GetWorkers(const GetWorkersRequest* request,
                    GetWorkersResponse* response);
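
  // A sketch of registering a dataset directly against the impl (the proto
  // field names are assumptions based on dispatcher.proto at this revision):
  //
  //   GetOrRegisterDatasetRequest req;
  //   *req.mutable_dataset() = dataset_def;  // a DatasetDef to register
  //   GetOrRegisterDatasetResponse resp;
  //   TF_RETURN_IF_ERROR(dispatcher.GetOrRegisterDataset(&req, &resp));
  //   const std::string dataset_id = resp.dataset_id();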

  // Exports the dispatcher state for debugging.
  DispatcherStateExport ExportState() const;

 private:
  // Restores split providers from the state in `iteration` and stores them in
  // `restored`.
  Status RestoreSplitProviders(
      const DispatcherState::Iteration& iteration,
      std::vector<std::unique_ptr<SplitProvider>>& restored)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Makes split providers for the specified `dataset_id`, and stores them in
  // `split_providers`.
  Status MakeSplitProviders(
      const std::string& dataset_id,
      std::vector<std::unique_ptr<SplitProvider>>& split_providers)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Registers a dataset with the given fingerprint, storing the new dataset's
  // id in `dataset_id`.
  Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
                         const DataServiceMetadata& metadata,
                         const std::string& requested_dataset_id,
                         std::string& dataset_id)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Finds the ID of the dataset with the requested dataset ID, or with the
  // matching fingerprint if no dataset with that ID exists. Returns an empty
  // optional if no such dataset exists.
  StatusOr<std::optional<std::string>> FindDataset(
      const GetOrRegisterDatasetRequest& request, uint64 fingerprint);
  // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
  // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is
  // stored in `out_stub`.
  Status GetOrCreateWorkerStub(const std::string& worker_address,
                               WorkerService::Stub*& out_stub)
      TF_LOCKS_EXCLUDED(mu_);
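  //
  // A sketch of the caching pattern (the `CreateStub` helper is hypothetical;
  // the real creation and locking logic live in the .cc file):
  //
  //   auto& stub = worker_stubs_[worker_address];  // guarded by mu_
  //   if (stub == nullptr) stub = CreateStub(worker_address);
  //   out_stub = stub.get();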
  // Creates a job and stores it in `job`.
  Status CreateJob(const std::string& job_name,
                   const GetOrCreateJobRequest& request,
                   std::shared_ptr<const DispatcherState::Job>& job)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates an iteration and stores it in `iteration`. This method updates the
  // dispatcher state with the new iteration, but does not assign tasks to
  // workers.
  Status CreateIteration(
      const GetOrCreateIterationRequest& request,
      std::shared_ptr<const DispatcherState::Iteration>& iteration)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates tasks for the specified worker, one task for every unfinished
  // iteration.
  Status CreateTasksForWorker(const std::string& worker_address);
  // Finds tasks that should be deleted from a worker, updating the heartbeat
  // response.
  Status FindTasksToDelete(
      const absl::flat_hash_set<int64_t>& current_tasks,
      const std::vector<std::shared_ptr<const DispatcherState::Task>>&
          assigned_tasks,
      WorkerHeartbeatResponse* response);
  // Finds new tasks that should be assigned to a worker and adds them to
  // the heartbeat response.
  Status FindNewTasks(
      const std::string& worker_address,
      const absl::flat_hash_set<int64_t>& current_tasks,
      std::vector<std::shared_ptr<const DispatcherState::Task>>& assigned_tasks,
      WorkerHeartbeatResponse* response);
  // Acquires an iteration client id to read from the given iteration and sets
  // `iteration_client_id`.
  Status AcquireIterationClientId(
      const std::shared_ptr<const DispatcherState::Iteration>& iteration,
      int64_t& iteration_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates one task for each worker, for the given iteration. The created
  // tasks are stored in `tasks`. This method only updates dispatcher metadata
  // with the new tasks, but doesn't assign the tasks to the workers.
  Status CreateTasksForIteration(
      std::shared_ptr<const DispatcherState::Iteration> iteration,
      std::vector<std::shared_ptr<const DispatcherState::Task>>& tasks)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Creates a new task for an iteration. The created task may be either pending
  // or active.
  Status CreateTask(std::shared_ptr<const DispatcherState::Iteration> iteration,
                    const std::string& worker_address,
                    std::shared_ptr<const DispatcherState::Task>& task)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates a pending task for a round robin iteration. All consumers need to
  // agree on which round to add the task in before the pending task can be
  // promoted to a regular task.
  Status CreatePendingTask(
      std::shared_ptr<const DispatcherState::Iteration> iteration,
      const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates a new active task for an iteration, storing the created task in
  // `task`.
  Status CreateActiveTask(
      std::shared_ptr<const DispatcherState::Iteration> iteration,
      const std::string& worker_address,
      std::shared_ptr<const DispatcherState::Task>& task);
  // Assigns the list of tasks to the workers indicated by their
  // `worker_address` fields.
  Status AssignTasks(
      std::vector<std::shared_ptr<const DispatcherState::Task>> tasks)
      TF_LOCKS_EXCLUDED(mu_);
  // Assigns a task to the worker indicated by its `worker_address` field.
  Status AssignTask(std::shared_ptr<const DispatcherState::Task> task)
      TF_LOCKS_EXCLUDED(mu_);
  // Validates that an existing job matches a given request.
  // Returns an error status describing any difference.
  Status ValidateMatchingJob(std::shared_ptr<const DispatcherState::Job> job,
                             const GetOrCreateJobRequest& request)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Fills out a TaskDef with information about a task.
  Status PopulateTaskDef(std::shared_ptr<const DispatcherState::Task> task,
                         TaskDef* task_def) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Checks that the dispatcher has started, returning UNAVAILABLE if it hasn't.
  Status CheckStarted() TF_LOCKS_EXCLUDED(mu_);
  // Records that a split was produced by a call to `GetSplit`.
  Status RecordSplitProduced(int64_t iteration_id, int64_t repetition,
                             int64_t split_provider_index, bool finished)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Applies a state update, updating both the journal and the in-memory state.
  Status Apply(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Applies a state update, but doesn't update the journal. Only meant to be
  // used when recovering state when the dispatcher starts.
  Status ApplyWithoutJournaling(const Update& update)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
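  //
  // A sketch of how a mutation typically flows through `Apply` (the
  // `garbage_collect_iteration` update is one example variant; the exact
  // fields are defined in the journal proto):
  //
  //   mutex_lock l(mu_);
  //   Update update;
  //   update.mutable_garbage_collect_iteration()->set_iteration_id(id);
  //   TF_RETURN_IF_ERROR(Apply(update));  // journals, then updates state_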
  // A thread which periodically checks for iterations to clean up.
  void IterationGcThread();
  // Releases iteration clients that haven't heartbeated recently.
  Status ReleaseMissingClients() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Scans for old iterations and marks them as finished.
  Status GcOldIterations() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a `DatasetDef` from `dataset_store_` for the given dataset id, and
  // stores it in `dataset_def`.
  Status GetDatasetDef(const std::string& dataset_id,
                       std::shared_ptr<const DatasetDef>& dataset_def)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a `DatasetDef` from `dataset_store_` for the given dataset, and
  // stores it in `dataset_def`.
  Status GetDatasetDef(const DispatcherState::Dataset& dataset,
                       std::shared_ptr<const DatasetDef>& dataset_def)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  const experimental::DispatcherConfig config_;
  Env* env_;

  mutable mutex mu_;
  bool started_ TF_GUARDED_BY(mu_) = false;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;

  // Cached worker stubs for communicating with workers.
  absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
      worker_stubs_ TF_GUARDED_BY(mu_);
  // Store of dataset definitions.
  std::unique_ptr<DatasetStore> dataset_store_ TF_GUARDED_BY(mu_);
  // Mapping from iteration id to the split providers for the iteration.
  absl::flat_hash_map<int64_t, std::vector<std::unique_ptr<SplitProvider>>>
      split_providers_ TF_GUARDED_BY(mu_);
  // Mapping from round robin iteration id to the round the iteration is
  // currently on. This is based on the data provided by client heartbeats, and
  // may be stale.
  absl::flat_hash_map<int64_t, int64_t> round_robin_rounds_ TF_GUARDED_BY(mu_);
  // Map from task id to a TaskRemover which determines when to remove the task.
  absl::flat_hash_map<int64_t, std::shared_ptr<TaskRemover>>
      remove_task_requests_ TF_GUARDED_BY(mu_);
  // Map from client id to the time of the client's last heartbeat.
  absl::flat_hash_map<int64_t, absl::Time> latest_client_heartbeats_time_
      TF_GUARDED_BY(mu_);

  std::optional<std::unique_ptr<JournalWriter>> journal_writer_
      TF_GUARDED_BY(mu_);
  DispatcherState state_ TF_GUARDED_BY(mu_);
  // Condition variable for waking up the iteration gc thread.
  condition_variable iteration_gc_thread_cv_;
  std::unique_ptr<Thread> iteration_gc_thread_;

  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_