/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/service_config.pb.h"

namespace tensorflow {
namespace data {

// Client for communicating with the tf.data service dispatcher.
41 class DataServiceDispatcherClient : public DataServiceClientBase { 42 public: DataServiceDispatcherClient(const std::string & address,const std::string & protocol)43 DataServiceDispatcherClient(const std::string& address, 44 const std::string& protocol) 45 : DataServiceClientBase(address, protocol) {} 46 47 // Sends a heartbeat to the dispatcher. If the worker wasn't already 48 // registered with the dispatcher, this will register the worker. The 49 // dispatcher will report which new tasks the worker should run, and which 50 // tasks it should delete. 51 StatusOr<WorkerHeartbeatResponse> WorkerHeartbeat( 52 const WorkerHeartbeatRequest& request); 53 54 // Updates the dispatcher with information about the worker's state. 55 Status WorkerUpdate(const std::string& worker_address, 56 std::vector<TaskProgress>& task_progress); 57 58 // Gets a dataset definition for the given dataset id, and stores the 59 // definition in `dataset_def`. 60 Status GetDatasetDef(const std::string& dataset_id, DatasetDef& dataset_def); 61 62 // Gets the next split for the specified iteration id, repetition, and split 63 // provider index. 64 Status GetSplit(int64_t iteration_id, int64_t repetition, 65 int64_t split_provider_index, Tensor& split, 66 bool& end_of_splits); 67 68 // Registers a dataset with the tf.data service, and stores the generated 69 // dataset id in `dataset_id`. 70 Status RegisterDataset(const DatasetDef& dataset, 71 const DataServiceMetadata& metadata, 72 const std::optional<std::string>& requested_dataset_id, 73 std::string& dataset_id); 74 75 // If `job_name` is set, looks up a job matching `job_name`. 76 // If `job_name` is absent or no matching job is found, creates a 77 // new job. The resulting job id is stored in `job_id`. 
78 Status GetOrCreateJob(const std::string& dataset_id, 79 const ProcessingModeDef& processing_mode, 80 const std::optional<std::string>& job_name, 81 std::optional<int64_t> num_consumers, 82 bool use_cross_trainer_cache, 83 TargetWorkers target_workers, int64_t& job_id); 84 85 // Looks up an iteration of a job, creating an iteration if one doesn't 86 // already exist. The returned `iteration_client_id` can be used to query 87 // information about the iteration. The client should call 88 // `ReleaseIterationClient` when finished with the iteration, so that 89 // resources can be reclaimed. 90 Status GetOrCreateIteration(int64_t job_id, int64_t repetition, 91 int64_t& iteration_client_id); 92 93 // Releases a iteration client id, indicating that the id will no longer be 94 // used to read from the iteration. 95 Status ReleaseIterationClient(int64_t iteration_client_id); 96 97 // Attempts to remove a task. The task is removed if all consumers try to 98 // remove the task in the same round. 99 Status MaybeRemoveTask(int64_t task_id, int64_t consumer_index, int64_t round, 100 bool& removed); 101 102 // Heartbeats to the dispatcher, getting back the tasks that should be 103 // running, and whether the iteration is finished. 104 Status ClientHeartbeat(ClientHeartbeatRequest& req, 105 ClientHeartbeatResponse& resp); 106 107 // Queries the dispatcher for its registered workers. The worker info will be 108 // stored in `workers`. 109 Status GetWorkers(std::vector<WorkerInfo>& workers); 110 111 // Returns data service metadata for the registered dataset. 112 Status GetDataServiceMetadata(const std::string& dataset_id, 113 DataServiceMetadata& metadata); 114 115 // Returns data service config of the data service cluster. 
116 Status GetDataServiceConfig(DataServiceConfig& config); 117 118 protected: 119 Status EnsureInitialized() override; 120 121 private: 122 mutex mu_; 123 // Initialization is guarded by `mu_`, but using the stub does not require 124 // holding `mu_` 125 std::unique_ptr<DispatcherService::Stub> stub_; 126 }; 127 128 } // namespace data 129 } // namespace tensorflow 130 131 #endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_ 132