/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/device_base.h"

#include <algorithm>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/notification.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

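// Deletes the Eigen::ThreadPoolDevice instances pre-allocated by
// set_eigen_cpu_device().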
DeviceBase::~DeviceBase() {
  for (auto& temp : eigen_cpu_devices_) {
    delete temp;
  }
  eigen_cpu_devices_.clear();
}

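// Blocking wrapper around the asynchronous CopyDeviceTensorToCPU(): waits on
// an absl::Notification until the copy's done-callback delivers its Status.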
Status DeviceContext::CopyDeviceTensorToCPUSync(const Tensor* device_tensor,
                                                StringPiece tensor_name,
                                                Device* device,
                                                Tensor* cpu_tensor) {
  absl::Notification n;
  Status status;
  CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
                        [&](const Status& s) {
                          status = s;
                          n.Notify();
                        });
  n.WaitForNotification();
  return status;
}

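// Blocking wrapper around the asynchronous CopyCPUTensorToDevice(), mirroring
// CopyDeviceTensorToCPUSync() above.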
Status DeviceContext::CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor,
                                                Device* device,
                                                Tensor* device_tensor) const {
  absl::Notification n;
  Status status;
  CopyCPUTensorToDevice(cpu_tensor, device, device_tensor,
                        [&](const Status& s) {
                          status = s;
                          n.Notify();
                        });
  n.WaitForNotification();
  return status;
}

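// The default implementations of attributes(), name(), and parsed_name()
// below crash deliberately; subclasses that represent real devices are
// expected to override them.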
const DeviceAttributes& DeviceBase::attributes() const {
  LOG(FATAL) << "DeviceBase does not implement attributes()";  // Crash OK
  std::abort();
}

const string& DeviceBase::name() const {
  LOG(FATAL) << "DeviceBase does not implement name()";  // Crash OK
  std::abort();
}

const DeviceNameUtils::ParsedName& DeviceBase::parsed_name() const {
  LOG(FATAL) << "DeviceBase does not implement parsed_name()";  // Crash OK
  std::abort();
}

void DeviceBase::set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
  // Eigen::ThreadPoolDevice is a very cheap struct (two pointers and an int),
  // so we can afford a pre-allocated array of Eigen::ThreadPoolDevice. Here we
  // ensure that the Eigen::ThreadPoolDevices in eigen_cpu_devices_ have
  // increasingly larger numThreads.
  for (int i = 1; i <= d->numThreads(); ++i) {
    eigen_cpu_devices_.push_back(new Eigen::ThreadPoolDevice(
        d->getPool(), i /* numThreads() */, d->allocator()));
  }
}

const Eigen::ThreadPoolDevice* DeviceBase::eigen_cpu_device() {
  // Based on GetPerThreadMaxParallelism(), we return a different
  // pre-allocated Eigen::ThreadPoolDevice. All of these ThreadPoolDevices
  // share the same underlying thread pool, but they report different nominal
  // numThreads(), so that the user of the returned Eigen::ThreadPoolDevice
  // does not aggressively occupy all the threads in the underlying pool.
  const int parallelism = std::max<int>(
      1,
      std::min<int>(GetPerThreadMaxParallelism(), eigen_cpu_devices_.size()));
  return eigen_cpu_devices_[parallelism - 1];
}

namespace {

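// Returns the process-wide set of symbolic device names. The set is lazily
// constructed and intentionally never freed.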
absl::flat_hash_set<std::string>* GetSymbolicDeviceList() {
  static absl::flat_hash_set<std::string>* symbolic_device_list =
      new absl::flat_hash_set<std::string>();
  return symbolic_device_list;
}

}  // namespace

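// Registers `device_name` as a symbolic execution device.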
void AddSymbolicExecutionDevice(const absl::string_view device_name) {
  GetSymbolicDeviceList()->insert(std::string(device_name));
}

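// Returns true if `device_name` was previously registered via
// AddSymbolicExecutionDevice().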
bool IsSymbolicExecutionDevice(const absl::string_view device_name) {
  return GetSymbolicDeviceList()->contains(device_name);
}

}  // namespace tensorflow