1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_RPC_HANDLER_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_RPC_HANDLER_H_ 18 19 #include "tensorflow/core/platform/mutex.h" 20 #include "tensorflow/core/platform/status.h" 21 #include "tensorflow/core/platform/thread_annotations.h" 22 #include "tensorflow/core/protobuf/coordination_service.pb.h" 23 24 namespace tensorflow { 25 class CoordinationServiceAgent; 26 27 class CoordinationServiceRpcHandler { 28 public: CoordinationServiceRpcHandler()29 explicit CoordinationServiceRpcHandler() {} 30 31 void SetAgentInstance(CoordinationServiceAgent* agent); 32 33 void RegisterTaskAsync(const RegisterTaskRequest* request, 34 RegisterTaskResponse* response, StatusCallback done); 35 36 void HeartbeatAsync(const HeartbeatRequest* request, 37 HeartbeatResponse* response, StatusCallback done); 38 39 void WaitForAllTasksAsync(const WaitForAllTasksRequest* request, 40 WaitForAllTasksResponse* response, 41 StatusCallback done); 42 43 void ShutdownTaskAsync(const ShutdownTaskRequest* request, 44 ShutdownTaskResponse* response, StatusCallback done); 45 46 void ResetTaskAsync(const ResetTaskRequest* request, 47 ResetTaskResponse* response, StatusCallback done); 48 49 void ReportErrorToTaskAsync(const ReportErrorToTaskRequest* request, 50 ReportErrorToTaskResponse* response, 51 StatusCallback done); 52 53 void ReportErrorToServiceAsync(const ReportErrorToServiceRequest* request, 54 ReportErrorToServiceResponse* response, 55 StatusCallback done); 56 57 void InsertKeyValueAsync(const InsertKeyValueRequest* request, 58 InsertKeyValueResponse* response, 59 StatusCallback done); 60 61 void GetKeyValueAsync(const GetKeyValueRequest* request, 62 GetKeyValueResponse* response, StatusCallback done); 63 64 void TryGetKeyValueAsync(const TryGetKeyValueRequest* request, 65 TryGetKeyValueResponse* response, 66 StatusCallback done); 67 68 void GetKeyValueDirAsync(const GetKeyValueDirRequest* request, 69 GetKeyValueDirResponse* response, 70 StatusCallback done); 71 72 void DeleteKeyValueAsync(const DeleteKeyValueRequest* request, 73 DeleteKeyValueResponse* response, 74 StatusCallback done); 75 76 void BarrierAsync(const BarrierRequest* request, BarrierResponse* response, 77 StatusCallback done); 78 79 void CancelBarrierAsync(const CancelBarrierRequest* request, 80 CancelBarrierResponse* response, StatusCallback done); 81 82 private: 83 mutex agent_mu_; 84 CoordinationServiceAgent* agent_ TF_GUARDED_BY(agent_mu_); 85 }; 86 87 } // namespace tensorflow 88 89 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_RPC_HANDLER_H_ 90