xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/dispatcher_client.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_
17 
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <vector>
22 
23 #include "tensorflow/core/data/service/common.h"
24 #include "tensorflow/core/data/service/common.pb.h"
25 #include "tensorflow/core/data/service/data_transfer.h"
26 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
27 #include "tensorflow/core/data/service/dispatcher.pb.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/status.h"
32 #include "tensorflow/core/platform/statusor.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/core/protobuf/data_service.pb.h"
35 #include "tensorflow/core/protobuf/service_config.pb.h"
36 
37 namespace tensorflow {
38 namespace data {
39 
40 // Client for communicating with the tf.data service dispatcher.
41 class DataServiceDispatcherClient : public DataServiceClientBase {
42  public:
DataServiceDispatcherClient(const std::string & address,const std::string & protocol)43   DataServiceDispatcherClient(const std::string& address,
44                               const std::string& protocol)
45       : DataServiceClientBase(address, protocol) {}
46 
47   // Sends a heartbeat to the dispatcher. If the worker wasn't already
48   // registered with the dispatcher, this will register the worker. The
49   // dispatcher will report which new tasks the worker should run, and which
50   // tasks it should delete.
51   StatusOr<WorkerHeartbeatResponse> WorkerHeartbeat(
52       const WorkerHeartbeatRequest& request);
53 
54   // Updates the dispatcher with information about the worker's state.
55   Status WorkerUpdate(const std::string& worker_address,
56                       std::vector<TaskProgress>& task_progress);
57 
58   // Gets a dataset definition for the given dataset id, and stores the
59   // definition in `dataset_def`.
60   Status GetDatasetDef(const std::string& dataset_id, DatasetDef& dataset_def);
61 
62   // Gets the next split for the specified iteration id, repetition, and split
63   // provider index.
64   Status GetSplit(int64_t iteration_id, int64_t repetition,
65                   int64_t split_provider_index, Tensor& split,
66                   bool& end_of_splits);
67 
68   // Registers a dataset with the tf.data service, and stores the generated
69   // dataset id in `dataset_id`.
70   Status RegisterDataset(const DatasetDef& dataset,
71                          const DataServiceMetadata& metadata,
72                          const std::optional<std::string>& requested_dataset_id,
73                          std::string& dataset_id);
74 
75   // If `job_name` is set, looks up a job matching `job_name`.
76   // If `job_name` is absent or no matching job is found, creates a
77   // new job. The resulting job id is stored in `job_id`.
78   Status GetOrCreateJob(const std::string& dataset_id,
79                         const ProcessingModeDef& processing_mode,
80                         const std::optional<std::string>& job_name,
81                         std::optional<int64_t> num_consumers,
82                         bool use_cross_trainer_cache,
83                         TargetWorkers target_workers, int64_t& job_id);
84 
85   // Looks up an iteration of a job, creating an iteration if one doesn't
86   // already exist. The returned `iteration_client_id` can be used to query
87   // information about the iteration. The client should call
88   // `ReleaseIterationClient` when finished with the iteration, so that
89   // resources can be reclaimed.
90   Status GetOrCreateIteration(int64_t job_id, int64_t repetition,
91                               int64_t& iteration_client_id);
92 
93   // Releases a iteration client id, indicating that the id will no longer be
94   // used to read from the iteration.
95   Status ReleaseIterationClient(int64_t iteration_client_id);
96 
97   // Attempts to remove a task. The task is removed if all consumers try to
98   // remove the task in the same round.
99   Status MaybeRemoveTask(int64_t task_id, int64_t consumer_index, int64_t round,
100                          bool& removed);
101 
102   // Heartbeats to the dispatcher, getting back the tasks that should be
103   // running, and whether the iteration is finished.
104   Status ClientHeartbeat(ClientHeartbeatRequest& req,
105                          ClientHeartbeatResponse& resp);
106 
107   // Queries the dispatcher for its registered workers. The worker info will be
108   // stored in `workers`.
109   Status GetWorkers(std::vector<WorkerInfo>& workers);
110 
111   // Returns data service metadata for the registered dataset.
112   Status GetDataServiceMetadata(const std::string& dataset_id,
113                                 DataServiceMetadata& metadata);
114 
115   // Returns data service config of the data service cluster.
116   Status GetDataServiceConfig(DataServiceConfig& config);
117 
118  protected:
119   Status EnsureInitialized() override;
120 
121  private:
122   mutex mu_;
123   // Initialization is guarded by `mu_`, but using the stub does not require
124   // holding `mu_`
125   std::unique_ptr<DispatcherService::Stub> stub_;
126 };
127 
128 }  // namespace data
129 }  // namespace tensorflow
130 
131 #endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_
132