/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_

#include <memory>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"

namespace tensorflow {
namespace eager {

// A TensorFlow Eager Worker runs ops and supports worker-to-worker
// Tensor transfer.
//
// See eager_service.proto for more details about each method.
// This class can be wrapped by transport-specific classes (e.g. gRPC) that
// implement the RPC layer on top of it.
class EagerServiceImpl {
 public:
  explicit EagerServiceImpl(const WorkerEnv* env) : env_(env) {
    gc_thread_.reset(
        env_->env->StartThread({}, "EagerServiceContextGC", [this]() {
          while (true) {
            {
              mutex_lock l(gc_thread_shutdown_mu_);
              gc_thread_cv_.wait_for(l, std::chrono::seconds(1));

              if (shutting_down_) {
                return;
              }
            }
            {
              mutex_lock l(contexts_mu_);
              for (auto it = contexts_.begin(); it != contexts_.end();) {
                if (it->second->IsStale()) {
                  it->second->Unref();
                  it = contexts_.erase(it);
                } else {
                  it++;
                }
              }
            }
          }
        }));
  }
  virtual ~EagerServiceImpl() {
    {
      mutex_lock l(gc_thread_shutdown_mu_);
      shutting_down_ = true;
      gc_thread_cv_.notify_all();
    }
    gc_thread_.reset();

    mutex_lock l(contexts_mu_);
    for (auto& entry : contexts_) {
      entry.second->Unref();
    }
  }

  Status CreateContext(const CreateContextRequest* request,
                       CreateContextResponse* response);

  Status UpdateContext(const UpdateContextRequest* request,
                       UpdateContextResponse* response);

  // Create a ServerContext for master eager context.
  Status CreateMasterContext(const tensorflow::uint64 context_id,
                             EagerContext* context);

  static constexpr uint64 kInvalidStreamId = 0;

  // Used by both Enqueue and StreamingEnqueue RPCs.
  Status Enqueue(CallOptions* call_opts, const EnqueueRequest* request,
                 EnqueueResponse* response,
                 uint64 stream_id = kInvalidStreamId);
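  //
  // A hedged sketch of the two call patterns implied by the default argument
  // (names are illustrative only): a unary Enqueue call relies on the default
  // kInvalidStreamId, while a streaming caller passes its per-stream id.
  //
  //   Status s1 = Enqueue(call_opts, &request, &response);
  //   Status s2 = Enqueue(call_opts, &request, &response, stream_id);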

  Status WaitQueueDone(const WaitQueueDoneRequest* request,
                       WaitQueueDoneResponse* response);

  void RunComponentFunction(CallOptions* call_opts,
                            const RunComponentFunctionRequest* request,
                            RunComponentFunctionResponse* response,
                            StatusCallback done);

  Status KeepAlive(const KeepAliveRequest* request,
                   KeepAliveResponse* response);

  Status CloseContext(const CloseContextRequest* request,
                      CloseContextResponse* response);

 protected:
  // This is the server-side execution context. All state regarding execution
  // of a client's ops is held in this server-side context (all generated
  // tensors, and the EagerContext).
  class ServerContext : public core::RefCounted {
   public:
    // Create a ServerContext for local master.
    static ServerContext* CreateMasterContext(tensorflow::EagerContext* ctx,
                                              const WorkerEnv* env) {
      return new ServerContext(ctx, -1, env, /* is_master= */ true);
    }

    explicit ServerContext(tensorflow::EagerContext* ctx,
                           int64_t destroy_after_secs, const WorkerEnv* env,
                           const bool is_master = false)
        : ctx_(ctx), env_(env), is_master_(is_master) {
      ctx->Ref();
      destroy_after_micros_ =
          destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros;
      RecordAccess();
    }

    ~ServerContext() override {
      // TFE_Context is responsible for shutting down master eager context.
      if (!is_master_) {
        ctx_->WaitForAndCloseRemoteContexts();
      }
      // ctx_->RefCountIsOne() should be true here when is_master_ = false.
      // TODO(iga): Remove EagerContext refcounting.
      ctx_->Unref();
    }

    tensorflow::EagerContext* Context() const { return ctx_; }

    void RecordAccess() {
      mutex_lock l(last_accessed_mu_);
      last_accessed_micros_ = env_->env->NowMicros();
    }

    bool IsStale() {
      mutex_lock l(last_accessed_mu_);
      const int64_t time_passed =
          env_->env->NowMicros() - last_accessed_micros_;
      return (destroy_after_micros_ > 0 && time_passed > destroy_after_micros_);
    }

   private:
    // The context for this execution.
    tensorflow::EagerContext* ctx_;

    const WorkerEnv* const env_;  // Not owned.

    mutex last_accessed_mu_;
    int64_t last_accessed_micros_ TF_GUARDED_BY(last_accessed_mu_);
    int64_t destroy_after_micros_;

    const bool is_master_;
  };
  // The returned ServerContext will need to be Unrefed.
  tensorflow::Status GetServerContext(uint64, ServerContext**);
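  //
  // A minimal usage sketch, assuming a hypothetical `request` that carries a
  // context_id (illustrative only; not a prescribed calling pattern):
  //
  //   ServerContext* context = nullptr;
  //   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
  //   core::ScopedUnref context_unref(context);  // Drops the ref on exit.
  //   tensorflow::EagerContext* eager_context = context->Context();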

  class ClientTensorHandleDeleteNode : public EagerNode {
   public:
    ClientTensorHandleDeleteNode(
        ServerContext* context,
        std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete)
        : tensorflow::EagerNode(),
          context_(context),
          handle_to_delete_(std::move(handle_to_delete)) {
      context_->Ref();
    }

    ~ClientTensorHandleDeleteNode() override { context_->Unref(); }

    Status Run() override {
      VLOG(3) << "ServerContext: Deleting tensor handle "
              << handle_to_delete_->op_id << ":"
              << handle_to_delete_->output_num;
      return context_->Context()->RemoteMgr()->DeleteTensorHandle(
          *handle_to_delete_);
    }

    void Abort(Status status) override {}

    // Remote node deletions are best effort
    bool Fatal() const override { return false; }

    string DebugString() const override {
      string out = "[ClientTensorHandleDeleteNode]";
      strings::StrAppend(&out, " op_id: ", handle_to_delete_->op_id);
      strings::StrAppend(&out, ", output_num: ", handle_to_delete_->output_num);
      return out;
    }

   private:
    // Owns one reference.
    ServerContext* const context_;
    const std::unique_ptr<RemoteTensorHandleInternal> handle_to_delete_;
  };
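
  // A minimal sketch of how a ClientTensorHandleDeleteNode might be scheduled
  // (illustrative only; `executor` is a hypothetical EagerExecutor* and
  // `context`/`handle_to_delete` are assumed to exist already):
  //
  //   auto node = std::make_unique<ClientTensorHandleDeleteNode>(
  //       context, std::move(handle_to_delete));
  //   Status s = executor->AddOrExecute(std::move(node));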

 private:
  Status ExecuteOp(CallOptions* call_opts, const Operation& operation,
                   EagerContext* eager_context, EagerExecutor* eager_executor,
                   QueueResponse* queue_response);
  Status SendTensor(const SendTensorOp& send_tensor,
                    EagerContext* eager_context);
  Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle,
                          EagerContext* eager_context);
  Status RegisterFunction(const RegisterFunctionOp& register_function,
                          EagerContext* eager_context);
  Status CleanupFunction(const CleanupFunctionOp& cleanup_function);
  const WorkerEnv* const env_;  // Not owned.

  mutex contexts_mu_;
  std::unordered_map<uint64, ServerContext*> contexts_
      TF_GUARDED_BY(contexts_mu_);

  std::unique_ptr<Thread> gc_thread_;
  mutex gc_thread_shutdown_mu_;
  condition_variable gc_thread_cv_;
  bool shutting_down_ TF_GUARDED_BY(gc_thread_shutdown_mu_) = false;

  TF_DISALLOW_COPY_AND_ASSIGN(EagerServiceImpl);
};

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_