/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/clusters/single_machine.h"

#include <atomic>
#include <memory>

#include "tensorflow/cc/training/queue_runner.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace grappler {

static std::atomic<bool> already_provisioned(false);

SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus)
    : Cluster(timeout_s), expected_init_time_s_(0), closing_(false) {
  VLOG(1) << "Number of CPU cores: " << num_cpu_cores
          << " Number of GPUs: " << num_gpus;
  thread_pool_.reset(new thread::ThreadPool(
      Env::Default(), SanitizeThreadSuffix("single_machine"), 2));

  (*options_.config.mutable_device_count())["CPU"] = 1;
  if (num_gpus > 0) {
    (*options_.config.mutable_device_count())["GPU"] = num_gpus;
  }
  CHECK_GE(num_cpu_cores, 1);
  options_.config.set_intra_op_parallelism_threads(num_cpu_cores);
  // Create a session-specific thread pool to ensure the threads are reset when
  // the session is reset.
  options_.config.add_session_inter_op_thread_pool()->set_num_threads(
      num_cpu_cores);
  if (timeout_s > 0) {
    options_.config.set_operation_timeout_in_ms(timeout_s * 1000);
  }
}
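
// A minimal usage sketch (illustration only, not part of this file's build);
// it assumes a GrapplerItem `item` built elsewhere, e.g. with
// GrapplerItemFromMetaGraphDef():
//
//   SingleMachine cluster(/*timeout_s=*/60, /*num_cpu_cores=*/4, /*num_gpus=*/0);
//   TF_RETURN_IF_ERROR(cluster.Provision());
//   TF_RETURN_IF_ERROR(cluster.Initialize(item));
//   RunMetadata metadata;
//   TF_RETURN_IF_ERROR(cluster.Run(item.graph, item.feed, item.fetch, &metadata));
//   TF_RETURN_IF_ERROR(cluster.Shutdown());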

SingleMachine::~SingleMachine() {
  CloseSession(false /*use_timeout*/).IgnoreError();

  // Reset the thread-pool so that there are no outstanding Session::Run(...)s
  // when we delete the session.
  thread_pool_.reset();
}

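// Provisions the cluster: starts a local session, lists the devices it
// exposes, and records their properties (using each device's memory limit as
// its memory size). Only one SingleMachine may be provisioned at a time.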
Status SingleMachine::Provision() {
  // This is really ugly: to avoid leaking variables, we need to reset the tf
  // session every time we're done processing a grappler item. However,
  // variables are global, and therefore we can't have more than 1 session alive
  // at a time. This check detects when more than one cluster is provisioned.
  if (already_provisioned) {
    return errors::Unavailable(
        "Can't provision more than one single cluster at a time");
  }

  TF_RETURN_IF_ERROR(ResetSession());

  std::vector<DeviceAttributes> devices;
  TF_RETURN_IF_ERROR(session_->ListDevices(&devices));
  for (const auto& dev : devices) {
    DeviceProperties attr;
    if (dev.device_type() == "CPU") {
      attr = GetLocalCPUInfo();
    } else if (dev.device_type() == "GPU") {
      DeviceNameUtils::ParsedName parsed;
      if (!DeviceNameUtils::ParseFullName(dev.name(), &parsed)) {
        return errors::InvalidArgument(
            strings::StrCat("Not able to parse GPU device name: ", dev.name()));
      }
      TfDeviceId tf_device_id(parsed.id);
      PlatformDeviceId platform_device_id;
      Status s =
          GpuIdManager::TfToPlatformDeviceId(tf_device_id, &platform_device_id);
      if (!s.ok()) {
        return errors::Unavailable("Unknown TF GPU device with id ",
                                   tf_device_id.value(), ": ",
                                   s.error_message());
      }
      attr = GetLocalGPUInfo(platform_device_id);
    } else if (dev.device_type().find("XLA") == string::npos) {
      // Filter out the fake XLA devices to avoid double counting the actual
      // hardware resources that are available.
      attr.set_type(dev.device_type());
    }
    // Overwrite the memory size since users might have requested to use only a
    // fraction of the available device memory.
    attr.set_memory_size(dev.memory_limit());
    devices_[dev.name()] = attr;
  }
  already_provisioned = true;

  // Clear highmark stats of all local allocators.
  if (cpu_allocator_stats_enabled_) {
    TF_RETURN_IF_ERROR(ClearAllocatorStats());
  }
  return OkStatus();
}

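// Caches the item's init ops, queue runner definitions and expected init time
// so that the next Run() on a new graph can (re)initialize the session.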
Status SingleMachine::Initialize(const GrapplerItem& item) {
  mutex_lock l(this->last_graph_mu_);
  if (last_graph_ != &item.graph || last_graph_id_ != item.id) {
    init_ops_ = item.init_ops;
    expected_init_time_s_ = item.expected_init_time;
    last_graph_ = nullptr;
    queue_runner_defs_ = item.queue_runners;
    last_graph_id_ = item.id;
  }
  return OkStatus();
}

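// Shuts the session down and releases the provisioning lock so another
// SingleMachine instance can be provisioned afterwards.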
Status SingleMachine::Shutdown() {
  TF_RETURN_IF_ERROR(ShutdownSession());

  mutex_lock l(this->last_graph_mu_);
  last_graph_ = nullptr;
  already_provisioned = false;

  return OkStatus();
}

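// Runs the given graph. The first time a graph is seen, the session is
// recreated, the init ops and queue runners are started, and NumWarmupSteps()
// warmup runs are executed; subsequent calls with the same graph only run the
// requested fetches, optionally collecting cost and run metadata.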
Status SingleMachine::Run(const GraphDef& graph_def,
                          const std::vector<std::pair<string, Tensor>>& feed,
                          const std::vector<string>& fetch,
                          RunMetadata* metadata) {
  mutex_lock l(this->last_graph_mu_);
  if (last_graph_ != &graph_def) {
    TF_RETURN_IF_ERROR(ResetSession());
    TF_RETURN_IF_ERROR(session_->Create(graph_def));
    if (!init_ops_.empty()) {
      init_metadata_ = RunMetadata();
      int64_t timeout_s = timeout_s_ + expected_init_time_s_;
      TF_RETURN_IF_ERROR(
          RunWithTimeout({}, init_ops_, &init_metadata_, timeout_s));
      // The compute cost for init ops is likely to be pessimistic since init
      // ops are run only once before warmup. Therefore we only keep their
      // memory costs.
      for (auto& node : *init_metadata_.mutable_cost_graph()->mutable_node()) {
        node.clear_compute_cost();
      }
      // Also clear the timeline to save memory.
      init_metadata_.clear_step_stats();
    }
    // We can have at most one hardware trace. Use it for the main graph, and
    // downgrade tracing of the queue runners to a software trace.
    RunOptions queue_options = run_options_;
    if (queue_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
      queue_options.set_trace_level(RunOptions::SOFTWARE_TRACE);
    }
    for (size_t i = 0; i < queue_runner_defs_.size(); ++i) {
      std::unique_ptr<QueueRunner> queue_runner;
      TF_RETURN_IF_ERROR(QueueRunner::New(queue_runner_defs_[i],
                                          coordinator_.get(), &queue_runner));

      TF_RETURN_IF_ERROR(queue_runner->StartAndCollectCostGraph(session_.get(),
                                                                queue_options));
      TF_RETURN_IF_ERROR(coordinator_->RegisterRunner(std::move(queue_runner)));
      TF_RETURN_IF_ERROR(coordinator_->GetStatus());
    }

    // Warm up TensorFlow if needed.
    for (int i = 0; i < NumWarmupSteps(); ++i) {
      TF_RETURN_IF_ERROR(RunWithTimeout(feed, fetch, nullptr));
    }
  }

  if (metadata) {
    TF_RETURN_IF_ERROR(RunWithTimeout(feed, fetch, metadata));
    // Merge the costs of the initialization and the queue runners.
    CostGraphDef queue_costs;
    TF_RETURN_IF_ERROR(coordinator_->ExportCostGraph(&queue_costs));
    MergeCosts(metadata->mutable_cost_graph(), init_metadata_.cost_graph(),
               queue_costs);
  } else {
    TF_RETURN_IF_ERROR(RunWithTimeout(feed, fetch, nullptr));
  }

  last_graph_ = &graph_def;

  return OkStatus();
}

Status SingleMachine::EnablePeakMemoryStats() {
  EnableCPUAllocatorStats();
  cpu_allocator_stats_enabled_ = true;
  // No need to enable GPU allocator stats since its stats are always collected.
  return OkStatus();
}

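// Reports, for every local device, the peak number of bytes in use as seen by
// that device's allocator. Requires allocation tracking to be enabled.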
Status SingleMachine::GetPeakMemoryUsage(
    std::unordered_map<string, uint64>* device_peak_memory) const {
  // Even when cpu_allocator->TracksAllocationSizes() returns true, the
  // AllocatorStats are not necessarily collected.
  if (!cpu_allocator_stats_enabled_) {
    return Status(error::INVALID_ARGUMENT,
                  "Tracking allocation for CPU is not enabled.");
  }

  const DeviceMgr* device_mgr;
  TF_RETURN_IF_ERROR(session_->LocalDeviceManager(&device_mgr));
  std::vector<Device*> devices = device_mgr->ListDevices();

  device_peak_memory->clear();
  for (Device* device : devices) {
    auto* allocator = device->GetAllocator(AllocatorAttributes());
    if (!allocator->TracksAllocationSizes()) {
      return Status(error::INVALID_ARGUMENT,
                    "Tracking allocation is not enabled.");
    }
    absl::optional<AllocatorStats> stats = allocator->GetStats();
    (*device_peak_memory)[device->name()] =
        (stats ? stats->peak_bytes_in_use : 0);
  }

  return OkStatus();
}

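// Runs the requested feeds/fetches on the dedicated thread pool, returning
// DEADLINE_EXCEEDED if the run does not finish within the timeout.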
Status SingleMachine::RunWithTimeout(
    const std::vector<std::pair<string, Tensor>>& feed,
    const std::vector<string>& fetch, RunMetadata* run_metadata) {
  return RunWithTimeout(feed, fetch, run_metadata, timeout_s_);
}

Status SingleMachine::RunWithTimeout(
    const std::vector<std::pair<string, Tensor>>& feed,
    const std::vector<string>& fetch, RunMetadata* run_metadata,
    int64_t timeout_s) {
  // We shouldn't be running or closing the session at this point.
  {
    mutex_lock l(close_mu_);
    CHECK(!closing_);
  }

  auto status = std::make_shared<Status>();
  auto local_metadata = std::make_shared<RunMetadata>();
  const bool executed_in_time = ExecuteWithTimeout(
      [this, status, local_metadata, feed, fetch]() {
        *status = session_->Run(run_options_, feed, {}, fetch, nullptr,
                                local_metadata.get());
      },
      timeout_s * 1000, thread_pool_.get());
  if (!executed_in_time) {
    return errors::DeadlineExceeded("Failed to run the graph after ", timeout_s,
                                    " seconds, aborting");
  } else if (run_metadata && status->ok()) {
    *run_metadata = *local_metadata;
  }
  return *status;
}

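// Stops the queue runners (if any) and closes the session, optionally bounding
// the whole operation by the cluster timeout.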
Status SingleMachine::CloseSession(bool use_timeout) {
  if (!session_ || !thread_pool_) {
    return OkStatus();
  }

  {
    mutex_lock l(close_mu_);

    if (!closing_) {
      closing_ = true;
    }
  }

  const bool executed_in_time = ExecuteWithTimeout(
      [&]() {
        if (this->coordinator_) {
          this->coordinator_->RequestStop().IgnoreError();
          // Wait for all the runners to have closed their queues.
          while (!this->coordinator_->AllRunnersStopped()) {
            Env::Default()->SleepForMicroseconds(1000000);
          }
          // Now we can close the session. This should cancel any pending I/O
          // operation.
          this->session_->Close().IgnoreError();
          // Last but not least, we can delete the coordinator.
          this->coordinator_.reset();
        } else {
          this->session_->Close().IgnoreError();
        }

        mutex_lock l2(close_mu_);
        closing_ = false;
      },
      use_timeout ? timeout_s_ * 1000 : -1, thread_pool_.get());

  if (!executed_in_time) {
    // Let the caller know that we can't shut down the session, and therefore
    // can't process anything further.
    return errors::Unavailable("Failed to close the previous session after ",
                               timeout_s_, " seconds, aborting");
  }

  return OkStatus();
}

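// Closes the session and then destroys the thread pool so that no
// Session::Run() call is still in flight when the session goes away.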
Status SingleMachine::ShutdownSession() {
  TF_RETURN_IF_ERROR(CloseSession(true /*use_timeout*/));

  // Delete the threadpool: this ensures that all the pending closures complete
  // before we return. Note that if TF deadlocked on us, the closures will
  // never complete, and the call to thread_pool_.reset() will never return:
  // therefore we need to delete the threadpool from a background thread.
  // That thread itself will also never complete, so the user should
  // abort the process to avoid leaking too many resources.
  auto n = std::make_shared<Notification>();
  Env::Default()->SchedClosure([this, n]() {
    thread_pool_.reset();
    n->Notify();
  });
  int64_t timeout_us = 1000000ll * timeout_s_;
  const bool notified = WaitForNotificationWithTimeout(n.get(), timeout_us);
  if (!notified) {
    // Let the caller know that we can't shut down the session properly since
    // there are calls to Session::Run() still running.
    return errors::Unavailable("The session is still running graphs after ",
                               timeout_s_, " seconds");
  }

  return OkStatus();
}

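// Tears down any existing session and creates a fresh session, thread pool,
// coordinator and device set.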
Status SingleMachine::ResetSession() {
  if (session_) {
    LOG(INFO) << "Cleaning up previous session";

    // Make sure the session is properly closed
    TF_RETURN_IF_ERROR(ShutdownSession());

    // Destroying the object deletes all its variables as well. This is only
    // true for DirectSession.
    session_.reset();
  }

  LOG(INFO) << "Starting new session";

  // Create a new threadpool
  thread_pool_.reset(new thread::ThreadPool(
      Env::Default(), SanitizeThreadSuffix("single_machine"), 2));

  session_.reset(NewSession(options_));
  if (!session_) {
    return errors::Unknown("Failed to create session");
  }
  coordinator_.reset(new Coordinator());

  // Build the DeviceSet.
  device_set_.reset(new DeviceSet);
  const DeviceMgr* device_mgr;
  TF_RETURN_IF_ERROR(session_->LocalDeviceManager(&device_mgr));
  for (auto d : device_mgr->ListDevices()) {
    device_set_->AddDevice(d);
    // We currently don't care about the client device.
  }

  return OkStatus();
}

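// Merges the cost graphs collected from the queue runners and from the init
// ops into the main graph's cost graph, remapping node ids (and the input and
// control-input references that use them) so they do not collide. Nodes
// already present in the main cost graph are kept as-is.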
void SingleMachine::MergeCosts(CostGraphDef* graph_costs,
                               const CostGraphDef& init_costs,
                               const CostGraphDef& queue_costs) {
  graph_costs->mutable_node()->Reserve(graph_costs->node_size() +
                                       init_costs.node_size() +
                                       queue_costs.node_size());
  std::unordered_set<string> nodes_seen;
  int queue_costs_id_offset = graph_costs->node_size();
  for (const auto& node : graph_costs->node()) {
    nodes_seen.insert(node.name());
    if (node.id() >= queue_costs_id_offset) {
      queue_costs_id_offset = node.id() + 1;
    }
  }

  int init_costs_id_offset = queue_costs_id_offset + queue_costs.node_size();
  // The costs obtained by running the main graph could be more stable than
  // the ones we get from the queue runners since the queue runners run
  // asynchronously.
  for (const auto& node : queue_costs.node()) {
    if (nodes_seen.find(node.name()) != nodes_seen.end()) {
      continue;
    }

    auto* new_node = graph_costs->add_node();
    new_node->MergeFrom(node);

    new_node->set_id(node.id() + queue_costs_id_offset);
    if (new_node->id() >= init_costs_id_offset) {
      init_costs_id_offset = new_node->id() + 1;
    }

    for (auto& input_info : *new_node->mutable_input_info()) {
      input_info.set_preceding_node(input_info.preceding_node() +
                                    queue_costs_id_offset);
    }
    for (auto& control_input : *new_node->mutable_control_input()) {
      control_input += queue_costs_id_offset;
    }
  }

  // Don't overwrite the costs with those generated during initialization since
  // these are possibly outdated.
  for (const auto& node : init_costs.node()) {
    if (nodes_seen.find(node.name()) != nodes_seen.end()) {
      continue;
    }

    auto* new_node = graph_costs->add_node();
    new_node->MergeFrom(node);

    new_node->set_id(node.id() + init_costs_id_offset);
    for (auto& input_info : *new_node->mutable_input_info()) {
      input_info.set_preceding_node(input_info.preceding_node() +
                                    init_costs_id_offset);
    }
    for (auto& control_input : *new_node->mutable_control_input()) {
      control_input += init_costs_id_offset;
    }
  }
}

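// Resets the allocator statistics (including the peak-memory highmarks) of
// every local device. Requires allocation tracking to be enabled.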
Status SingleMachine::ClearAllocatorStats() const {
  // Even when cpu_allocator->TracksAllocationSizes() returns true, the
  // AllocatorStats are not necessarily collected.
  if (!cpu_allocator_stats_enabled_) {
    return Status(error::INVALID_ARGUMENT,
                  "Tracking allocation for CPU is not enabled.");
  }

  const DeviceMgr* device_mgr;
  TF_RETURN_IF_ERROR(session_->LocalDeviceManager(&device_mgr));
  std::vector<Device*> devices = device_mgr->ListDevices();

  for (Device* device : devices) {
    auto* allocator = device->GetAllocator(AllocatorAttributes());
    if (!allocator->TracksAllocationSizes()) {
      return Status(error::INVALID_ARGUMENT,
                    "Tracking allocation is not enabled.");
    }
    if (!allocator->ClearStats()) {
      return Status(
          error::INVALID_ARGUMENT,
          absl::StrCat("Clearing allocation stats is not supported for ",
                       device->name()));
    }
  }
  return OkStatus();
}

}  // namespace grappler
}  // namespace tensorflow