xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/worker_session.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/worker_session.h"
16 
17 #include "tensorflow/core/lib/monitoring/collection_registry.h"
18 #include "tensorflow/core/lib/monitoring/gauge.h"
19 
20 namespace tensorflow {
21 
22 namespace {
23 
24 auto* worker_session_created =
25     monitoring::Gauge<bool, 0>::New("/tensorflow/core/worker_session_created",
26                                     "True if a worker session was created.");
27 
28 // A private cache that wraps worker_cache and allows reuse of
29 // WorkerInterface objects.
30 class WorkerFreeListCache : public WorkerCacheInterface {
31  public:
WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)32   explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
33       : wrapped_(std::move(w)) {}
34 
~WorkerFreeListCache()35   ~WorkerFreeListCache() final {
36     for (auto& p : workers_) {
37       wrapped_->ReleaseWorker(p.first, p.second.worker);
38     }
39   }
40 
ListWorkers(std::vector<string> * workers) const41   void ListWorkers(std::vector<string>* workers) const override {
42     wrapped_->ListWorkers(workers);
43   }
44 
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const45   void ListWorkersInJob(const string& job_name,
46                         std::vector<string>* workers) const override {
47     wrapped_->ListWorkersInJob(job_name, workers);
48   }
49 
GetOrCreateWorker(const string & target)50   WorkerInterface* GetOrCreateWorker(const string& target) override {
51     {
52       // Fast path if worker has been created.
53       tf_shared_lock l(mu_);
54       auto p = workers_.find(target);
55       if (p != workers_.end()) {
56         return p->second.worker;
57       }
58     }
59     {
60       // Slow path if worker hasn't been created.
61       mutex_lock l(mu_);
62       auto p = workers_.find(target);
63       if (p != workers_.end()) {
64         return p->second.worker;
65       }
66       WorkerState state;
67       state.worker = wrapped_->GetOrCreateWorker(target);
68       if (state.worker != nullptr) {
69         workers_.insert(std::make_pair(target, state));
70       }
71       return state.worker;
72     }
73   }
74 
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)75   Status GetEagerClientCache(
76       std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
77     return wrapped_->GetEagerClientCache(eager_client_cache);
78   }
79 
GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache> * coordination_client_cache)80   Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>*
81                                         coordination_client_cache) override {
82     return wrapped_->GetCoordinationClientCache(coordination_client_cache);
83   }
84 
ReleaseWorker(const string & target,WorkerInterface * worker)85   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
86     // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
87   }
88 
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)89   bool GetDeviceLocalityNonBlocking(const string& device,
90                                     DeviceLocality* locality) override {
91     return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
92   }
93 
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)94   void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
95                               StatusCallback done) override {
96     wrapped_->GetDeviceLocalityAsync(device, locality, done);
97   }
98 
SetLogging(bool active)99   void SetLogging(bool active) override { wrapped_->SetLogging(active); }
100 
ClearLogs()101   void ClearLogs() override { wrapped_->ClearLogs(); }
102 
RetrieveLogs(int64_t step_id,StepStats * ss)103   bool RetrieveLogs(int64_t step_id, StepStats* ss) override {
104     return wrapped_->RetrieveLogs(step_id, ss);
105   }
106 
107  private:
108   std::unique_ptr<WorkerCacheInterface> wrapped_;
109 
110   // Information kept per created WorkerInterface.
111   struct WorkerState {
112     WorkerInterface* worker;
113     // TODO(jeff,sanjay): Add reference count if we support eviction.
114   };
115 
116   // TODO(jeff,sanjay): Eviction when the map becomes too big.
117   mutex mu_;
118   std::unordered_map<string, WorkerState> workers_ TF_GUARDED_BY(mu_);
119 };
120 
121 }  // namespace
122 
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,std::unique_ptr<DeviceMgr> device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)123 WorkerSession::WorkerSession(
124     const string& session_name, const string& worker_name,
125     std::unique_ptr<WorkerCacheInterface> worker_cache,
126     std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
127     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
128     : session_name_(session_name),
129       worker_name_(worker_name),
130       worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
131       graph_mgr_(std::move(graph_mgr)),
132       cluster_flr_(new ClusterFunctionLibraryRuntime(
133           this, !session_name.empty(),
134           remote_device_mgr ? remote_device_mgr.get() : nullptr)),
135       device_mgr_(std::move(device_mgr)),
136       borrowed_device_mgr_(nullptr),
137       remote_device_mgr_(std::move(remote_device_mgr)) {
138   // Starts exporting metrics through a platform-specific monitoring API (if
139   // provided). For builds using "tensorflow/tsl/platform/default", this is
140   // currently a no-op.
141   worker_session_created->GetCell()->Set(true);
142 }
143 
UpdateWorkerCacheAndDevices(std::unique_ptr<WorkerCacheInterface> new_worker_cache,std::vector<std::unique_ptr<Device>> added_remote_devices,const std::vector<Device * > & removed_remote_devices)144 Status WorkerSession::UpdateWorkerCacheAndDevices(
145     std::unique_ptr<WorkerCacheInterface> new_worker_cache,
146     std::vector<std::unique_ptr<Device>> added_remote_devices,
147     const std::vector<Device*>& removed_remote_devices) {
148   {
149     mutex_lock l(worker_session_state_mu_);
150     worker_cache_ = std::shared_ptr<WorkerCacheInterface>(
151         new WorkerFreeListCache(std::move(new_worker_cache)));
152   }
153   TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices));
154   TF_RETURN_IF_ERROR(
155       remote_device_mgr_->AddDevices(std::move(added_remote_devices)));
156   return OkStatus();
157 }
158 
159 /* static */
CreateWithBorrowedDeviceMgr(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)160 std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
161     const string& session_name, const string& worker_name,
162     std::unique_ptr<WorkerCacheInterface> worker_cache,
163     DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
164     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) {
165   return std::shared_ptr<WorkerSession>(new WorkerSession(
166       session_name, worker_name, std::move(worker_cache), borrowed_device_mgr,
167       std::move(graph_mgr), std::move(remote_device_mgr)));
168 }
169 
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)170 WorkerSession::WorkerSession(
171     const string& session_name, const string& worker_name,
172     std::unique_ptr<WorkerCacheInterface> worker_cache,
173     DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
174     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
175     : session_name_(session_name),
176       worker_name_(worker_name),
177       worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
178       graph_mgr_(std::move(graph_mgr)),
179       cluster_flr_(new ClusterFunctionLibraryRuntime(
180           this, !session_name.empty(), remote_device_mgr.get())),
181       device_mgr_(nullptr),
182       borrowed_device_mgr_(borrowed_device_mgr),
183       remote_device_mgr_(std::move(remote_device_mgr)) {
184   // Starts exporting metrics through a platform-specific monitoring API (if
185   // provided). For builds using "tensorflow/tsl/platform/default", this is
186   // currently a no-op.
187   worker_session_created->GetCell()->Set(true);
188 }
189 
~WorkerSession()190 WorkerSession::~WorkerSession() {
191   if (graph_mgr_) {
192     Status s = graph_mgr_->DeregisterAll();
193     if (!s.ok()) {
194       LOG(WARNING) << "Error during worker session deletion: " << s;
195     }
196   }
197 }
198 
199 }  // namespace tensorflow
200