xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/worker_impl.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/core/data/service/common.pb.h"
26 #include "tensorflow/core/data/service/data_transfer.h"
27 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
28 #include "tensorflow/core/data/service/dispatcher_client.h"
29 #include "tensorflow/core/data/service/export.pb.h"
30 #include "tensorflow/core/data/service/task_runner.h"
31 #include "tensorflow/core/data/service/worker.pb.h"
32 #include "tensorflow/core/data/standalone.h"
33 #include "tensorflow/core/framework/cancellation.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/mutex.h"
37 #include "tensorflow/core/platform/status.h"
38 #include "tensorflow/core/platform/statusor.h"
39 #include "tensorflow/core/platform/thread_annotations.h"
40 #include "tensorflow/core/protobuf/service_config.pb.h"
41 #include "tensorflow/core/public/session.h"
42 
43 namespace tensorflow {
44 namespace data {
45 
46 // A TensorFlow DataService serves dataset elements over RPC.
47 class DataServiceWorkerImpl {
48  public:
49   explicit DataServiceWorkerImpl(const experimental::WorkerConfig& config);
50   ~DataServiceWorkerImpl();
51 
52   // Starts the worker. The worker needs to know its own address so that it can
53   // register with the dispatcher. This is set in `Start` instead of in the
54   // constructor because the worker may be binding to port `0`, in which case
55   // the address isn't known until the worker has started and decided which port
56   // to bind to.
57   Status Start(const std::string& worker_address,
58                const std::string& transfer_address);
59   // Stops the worker, attempting a clean shutdown by rejecting new requests
60   // and waiting for outstanding requests to complete.
61   void Stop();
62 
63   // Serves a GetElement request, storing the result in `*result`. See
64   // worker.proto for GetElement API documentation.
65   Status GetElementResult(const GetElementRequest* request,
66                           GetElementResult* result);
67 
68   // Deletes the local task and iterator. Only called by local clients to delete
69   // unused task iterators assuming the task is not read by remote clients. This
70   // method is not visible to gRPC clients.
71   void DeleteLocalTask(const TaskInfo& task_info);
72 
73   // See worker.proto for API documentation.
74 
75   /// Dispatcher-facing API.
76   Status ProcessTask(const ProcessTaskRequest* request,
77                      ProcessTaskResponse* response);
78 
79   /// Client-facing API.
80   Status GetElement(const GetElementRequest* request,
81                     GetElementResponse* response);
82   Status GetWorkerTasks(const GetWorkerTasksRequest* request,
83                         GetWorkerTasksResponse* response);
84 
85   // Exports the worker state for debugging.
86   WorkerStateExport ExportState() const;
87 
88  private:
89   struct Task {
TaskTask90     explicit Task(TaskDef task_def) : task_def(std::move(task_def)) {}
91 
92     TaskDef task_def;
93     mutex mu;
94     bool initialized TF_GUARDED_BY(mu) = false;
95     int64_t outstanding_requests TF_GUARDED_BY(&DataServiceWorkerImpl::mu_) = 0;
96     std::unique_ptr<TaskRunner> task_runner;
97   };
98 
99   // Validates the worker config.
100   Status ValidateWorkerConfig() const;
101   // Sends task status to the dispatcher and checks for dispatcher commands.
102   Status SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_);
103   // Creates an iterator to process a task.
104   Status ProcessTaskInternal(const TaskDef& task)
105       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
106   Status EnsureTaskInitialized(Task& task);
107   // Stops a task, cancelling the task's outstanding requests and waiting for
108   // them to finish.
109   void StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_);
110   // A thread for notifying the dispatcher when tasks complete.
111   void TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_);
112   // A thread for doing periodic heartbeats to the dispatcher.
113   void HeartbeatThread() TF_LOCKS_EXCLUDED(mu_);
114   // Performs a heartbeat to the dispatcher.
115   Status Heartbeat() TF_LOCKS_EXCLUDED(mu_);
116   // Gets the DatasetDef for `task_def`.
117   StatusOr<DatasetDef> GetDatasetDef(const TaskDef& task_def) const;
118   // Creates a dataset from `dataset_def`.
119   StatusOr<std::unique_ptr<standalone::Dataset>> MakeDataset(
120       const DatasetDef& dataset_def, const TaskDef& task_def) const;
121   // Creates an iterator for `dataset`.
122   StatusOr<std::unique_ptr<standalone::Iterator>> MakeDatasetIterator(
123       standalone::Dataset& dataset, const TaskDef& task_def) const;
124 
125   const experimental::WorkerConfig config_;
126   // Worker Borg job UID for telemetry. -1 if not supported.
127   const int64_t worker_uid_;
128 
129   // The worker's own address.
130   std::string worker_address_;
131   std::string transfer_address_;
132   std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
133 
134   mutable mutex mu_;
135   condition_variable cv_;
136   // Information about tasks, keyed by task ids. The tasks are updated based on
137   // the heartbeat responses from the dispatcher.
138   absl::flat_hash_map<int64_t, std::shared_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);
139   // Ids of tasks that have finished.
140   absl::flat_hash_set<int64_t> finished_tasks_ TF_GUARDED_BY(mu_);
141   // Completed tasks which haven't yet been communicated to the dispatcher.
142   absl::flat_hash_set<int64_t> pending_completed_tasks_ TF_GUARDED_BY(mu_);
143   // Tasks deleted by the local client. If the client tries to read from them
144   // again, the worker will return a non-retriable FailedPrecondition error.
145   absl::flat_hash_set<int64_t> deleted_tasks_ TF_GUARDED_BY(mu_);
146   bool cancelled_ TF_GUARDED_BY(mu_) = false;
147   // Whether the worker has registered with the dispatcher yet.
148   bool registered_ TF_GUARDED_BY(mu_) = false;
149   condition_variable task_completion_cv_ TF_GUARDED_BY(mu_);
150   condition_variable heartbeat_cv_ TF_GUARDED_BY(mu_);
151   CancellationManager cancellation_manager_;
152 
153   // A thread for notifying the dispatcher when tasks complete.
154   std::unique_ptr<Thread> task_completion_thread_;
155   // A thread for performing regular heartbeats to the dispatcher.
156   std::unique_ptr<Thread> heartbeat_thread_;
157 
158   TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl);
159 };
160 
161 // Local in-process workers shared among clients and servers. If clients and
162 // workers colocate in the same process, clients can read from local workers to
163 // reduce RPC calls and data copy.
164 class LocalWorkers {
165  public:
166   // Adds a `worker` at `worker_address`. If a worker already exists at the
167   // address, it will be updated to the new `worker`.
168   // REQUIRES: worker != nullptr.
169   static void Add(absl::string_view worker_address,
170                   std::shared_ptr<DataServiceWorkerImpl> worker);
171 
172   // Gets a local worker at `worker_address`. Returns nullptr if a worker is not
173   // found.
174   static std::shared_ptr<DataServiceWorkerImpl> Get(
175       absl::string_view worker_address);
176 
177   // Returns if there are any local workers in the process.
178   static bool Empty();
179 
180   // Removes a worker at `worker_address`. It is no-op if a worker is not found
181   // at the address.
182   static void Remove(absl::string_view worker_address);
183 
184  private:
185   using AddressToWorkerMap =
186       absl::flat_hash_map<std::string, std::shared_ptr<DataServiceWorkerImpl>>;
187   static mutex mu_;
188   static AddressToWorkerMap* local_workers_ TF_GUARDED_BY(mu_);
189 };
190 
191 }  // namespace data
192 }  // namespace tensorflow
193 
194 #endif  // TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
195