/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_

#include <chrono>  // std::chrono::seconds is used by the GC thread below.
#include <memory>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"

namespace tensorflow {
namespace eager {

// A TensorFlow Eager Worker runs ops and supports worker to worker
// Tensor transfer.
//
// See eager_service.proto for more details about each method.
// This class can be wrapped by specific classes that implement rpc transports
// over this (e.g. gRPC).
37 class EagerServiceImpl { 38 public: EagerServiceImpl(const WorkerEnv * env)39 explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) { 40 gc_thread_.reset( 41 env_->env->StartThread({}, "EagerServiceContextGC", [this]() { 42 while (true) { 43 { 44 mutex_lock l(gc_thread_shutdown_mu_); 45 gc_thread_cv_.wait_for(l, std::chrono::seconds(1)); 46 47 if (shutting_down_) { 48 return; 49 } 50 } 51 { 52 mutex_lock l(contexts_mu_); 53 for (auto it = contexts_.begin(); it != contexts_.end();) { 54 if (it->second->IsStale()) { 55 it->second->Unref(); 56 it = contexts_.erase(it); 57 } else { 58 it++; 59 } 60 } 61 } 62 } 63 })); 64 } ~EagerServiceImpl()65 virtual ~EagerServiceImpl() { 66 { 67 mutex_lock l(gc_thread_shutdown_mu_); 68 shutting_down_ = true; 69 gc_thread_cv_.notify_all(); 70 } 71 gc_thread_.reset(); 72 73 mutex_lock l(contexts_mu_); 74 for (auto& entry : contexts_) { 75 entry.second->Unref(); 76 } 77 } 78 79 Status CreateContext(const CreateContextRequest* request, 80 CreateContextResponse* response); 81 82 Status UpdateContext(const UpdateContextRequest* request, 83 UpdateContextResponse* response); 84 85 // Create a ServerContext for master eager context. 86 Status CreateMasterContext(const tensorflow::uint64 context_id, 87 EagerContext* context); 88 89 static constexpr uint64 kInvalidStreamId = 0; 90 91 // Used by both Enqueue and StreamingEnqueue RPCs. 
92 Status Enqueue(CallOptions* call_opts, const EnqueueRequest* request, 93 EnqueueResponse* response, 94 uint64 stream_id = kInvalidStreamId); 95 96 Status WaitQueueDone(const WaitQueueDoneRequest* request, 97 WaitQueueDoneResponse* response); 98 99 void RunComponentFunction(CallOptions* call_opts, 100 const RunComponentFunctionRequest* request, 101 RunComponentFunctionResponse* response, 102 StatusCallback done); 103 104 Status KeepAlive(const KeepAliveRequest* request, 105 KeepAliveResponse* response); 106 107 Status CloseContext(const CloseContextRequest* request, 108 CloseContextResponse* response); 109 110 protected: 111 // This is the server-side execution context. All state regarding execution of 112 // a client's ops is held in this server-side context (all generated tensors, 113 // and the EagerContext). 114 class ServerContext : public core::RefCounted { 115 public: 116 // Create a ServerContext for local master. CreateMasterContext(tensorflow::EagerContext * ctx,const WorkerEnv * env)117 static ServerContext* CreateMasterContext(tensorflow::EagerContext* ctx, 118 const WorkerEnv* env) { 119 return new ServerContext(ctx, -1, env, /* is_master= */ true); 120 } 121 122 explicit ServerContext(tensorflow::EagerContext* ctx, 123 int64_t destroy_after_secs, const WorkerEnv* env, 124 const bool is_master = false) ctx_(ctx)125 : ctx_(ctx), env_(env), is_master_(is_master) { 126 ctx->Ref(); 127 destroy_after_micros_ = 128 destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros; 129 RecordAccess(); 130 } 131 ~ServerContext()132 ~ServerContext() override { 133 // TFE_Context is responsible for shutting down master eager context. 134 if (!is_master_) { 135 ctx_->WaitForAndCloseRemoteContexts(); 136 } 137 // ctx_->RefCountIsOne() should be true here when is_master_ = false. 138 // TODO(iga): Remove EagerContext refcounting. 
139 ctx_->Unref(); 140 } 141 Context()142 tensorflow::EagerContext* Context() const { return ctx_; } 143 RecordAccess()144 void RecordAccess() { 145 mutex_lock l(last_accessed_mu_); 146 last_accessed_micros_ = env_->env->NowMicros(); 147 } 148 IsStale()149 bool IsStale() { 150 mutex_lock l(last_accessed_mu_); 151 const int64_t time_passed = 152 env_->env->NowMicros() - last_accessed_micros_; 153 return (destroy_after_micros_ > 0 && time_passed > destroy_after_micros_); 154 } 155 156 private: 157 // The context for this execution. 158 tensorflow::EagerContext* ctx_; 159 160 const WorkerEnv* const env_; // Not owned. 161 162 mutex last_accessed_mu_; 163 int64_t last_accessed_micros_ TF_GUARDED_BY(last_accessed_mu_); 164 int64_t destroy_after_micros_; 165 166 const bool is_master_; 167 }; 168 // The returned ServerContext will need to be Unrefed. 169 tensorflow::Status GetServerContext(uint64, ServerContext**); 170 171 class ClientTensorHandleDeleteNode : public EagerNode { 172 public: ClientTensorHandleDeleteNode(ServerContext * context,std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete)173 ClientTensorHandleDeleteNode( 174 ServerContext* context, 175 std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete) 176 : tensorflow::EagerNode(), 177 context_(context), 178 handle_to_delete_(std::move(handle_to_delete)) { 179 context_->Ref(); 180 } 181 ~ClientTensorHandleDeleteNode()182 ~ClientTensorHandleDeleteNode() override { context_->Unref(); } 183 Run()184 Status Run() override { 185 VLOG(3) << "ServerContext: Deleting tensor handle " 186 << handle_to_delete_->op_id << ":" 187 << handle_to_delete_->output_num; 188 return context_->Context()->RemoteMgr()->DeleteTensorHandle( 189 *handle_to_delete_); 190 } 191 Abort(Status status)192 void Abort(Status status) override {} 193 194 // Remote node deletions are best effort Fatal()195 bool Fatal() const override { return false; } 196 DebugString()197 string DebugString() const override { 198 string out = 
"[ClientTensorHandleDeleteNode]"; 199 strings::StrAppend(&out, " op_id: ", handle_to_delete_->op_id); 200 strings::StrAppend(&out, ", output_num: ", handle_to_delete_->output_num); 201 return out; 202 } 203 204 private: 205 // Owns one reference. 206 ServerContext* const context_; 207 const std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete_; 208 }; 209 210 private: 211 Status ExecuteOp(CallOptions* call_opts, const Operation& operation, 212 EagerContext* eager_context, EagerExecutor* eager_executor, 213 QueueResponse* queue_response); 214 Status SendTensor(const SendTensorOp& send_tensor, 215 EagerContext* eager_context); 216 Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle, 217 EagerContext* eager_context); 218 Status RegisterFunction(const RegisterFunctionOp& register_function, 219 EagerContext* eager_context); 220 Status CleanupFunction(const CleanupFunctionOp& cleanup_function); 221 const WorkerEnv* const env_; // Not owned. 222 223 mutex contexts_mu_; 224 std::unordered_map<uint64, ServerContext*> contexts_ 225 TF_GUARDED_BY(contexts_mu_); 226 227 std::unique_ptr<Thread> gc_thread_; 228 mutex gc_thread_shutdown_mu_; 229 condition_variable gc_thread_cv_; 230 bool shutting_down_ TF_GUARDED_BY(gc_thread_shutdown_mu_) = false; 231 232 TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl); 233 }; 234 235 } // namespace eager 236 } // namespace tensorflow 237 238 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ 239