1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/device_factory.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/env_var.h"
31
32 namespace tensorflow {
33
34 namespace {
35
get_device_factory_lock()36 static mutex* get_device_factory_lock() {
37 static mutex device_factory_lock(LINKER_INITIALIZED);
38 return &device_factory_lock;
39 }
40
// One entry of the device factory registry (see device_factories()).
// Instances are brace-initialized in field order by DeviceFactory::Register,
// so the order of the members below matters.
struct FactoryItem {
  // The registered factory; owned by this entry.
  std::unique_ptr<DeviceFactory> factory;
  // Priority supplied at registration. Register() replaces an existing entry
  // only when the new registration has a strictly greater priority.
  int priority;
  // Flag supplied at registration; reported by
  // DeviceFactory::IsPluggableDevice and used as the filter in
  // DeviceFactory::ListPluggablePhysicalDevices.
  bool is_pluggable_device;
};
46
device_factories()47 std::unordered_map<string, FactoryItem>& device_factories() {
48 static std::unordered_map<string, FactoryItem>* factories =
49 new std::unordered_map<string, FactoryItem>;
50 return *factories;
51 }
52
IsDeviceFactoryEnabled(const string & device_type)53 bool IsDeviceFactoryEnabled(const string& device_type) {
54 std::vector<string> enabled_devices;
55 TF_CHECK_OK(tensorflow::ReadStringsFromEnvVar(
56 /*env_var_name=*/"TF_ENABLED_DEVICE_TYPES", /*default_val=*/"",
57 &enabled_devices));
58 if (enabled_devices.empty()) {
59 return true;
60 }
61 return std::find(enabled_devices.begin(), enabled_devices.end(),
62 device_type) != enabled_devices.end();
63 }
64 } // namespace
65
66 // static
DevicePriority(const string & device_type)67 int32 DeviceFactory::DevicePriority(const string& device_type) {
68 tf_shared_lock l(*get_device_factory_lock());
69 std::unordered_map<string, FactoryItem>& factories = device_factories();
70 auto iter = factories.find(device_type);
71 if (iter != factories.end()) {
72 return iter->second.priority;
73 }
74
75 return -1;
76 }
77
IsPluggableDevice(const string & device_type)78 bool DeviceFactory::IsPluggableDevice(const string& device_type) {
79 tf_shared_lock l(*get_device_factory_lock());
80 std::unordered_map<string, FactoryItem>& factories = device_factories();
81 auto iter = factories.find(device_type);
82 if (iter != factories.end()) {
83 return iter->second.is_pluggable_device;
84 }
85 return false;
86 }
87
88 // static
Register(const string & device_type,std::unique_ptr<DeviceFactory> factory,int priority,bool is_pluggable_device)89 void DeviceFactory::Register(const string& device_type,
90 std::unique_ptr<DeviceFactory> factory,
91 int priority, bool is_pluggable_device) {
92 if (!IsDeviceFactoryEnabled(device_type)) {
93 LOG(INFO) << "Device factory '" << device_type << "' disabled by "
94 << "TF_ENABLED_DEVICE_TYPES environment variable.";
95 return;
96 }
97 mutex_lock l(*get_device_factory_lock());
98 std::unordered_map<string, FactoryItem>& factories = device_factories();
99 auto iter = factories.find(device_type);
100 if (iter == factories.end()) {
101 factories[device_type] = {std::move(factory), priority,
102 is_pluggable_device};
103 } else {
104 if (iter->second.priority < priority) {
105 iter->second = {std::move(factory), priority, is_pluggable_device};
106 } else if (iter->second.priority == priority) {
107 LOG(FATAL) << "Duplicate registration of device factory for type "
108 << device_type << " with the same priority " << priority;
109 }
110 }
111 }
112
GetFactory(const string & device_type)113 DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
114 tf_shared_lock l(*get_device_factory_lock());
115 auto it = device_factories().find(device_type);
116 if (it == device_factories().end()) {
117 return nullptr;
118 } else if (!IsDeviceFactoryEnabled(device_type)) {
119 LOG(FATAL) << "Device type " << device_type // Crash OK
120 << " had factory registered but was explicitly disabled by "
121 << "`TF_ENABLED_DEVICE_TYPES`. This environment variable needs "
122 << "to be set at program startup.";
123 }
124 return it->second.factory.get();
125 }
126
ListAllPhysicalDevices(std::vector<string> * devices)127 Status DeviceFactory::ListAllPhysicalDevices(std::vector<string>* devices) {
128 // CPU first. A CPU device is required.
129 // TODO(b/183974121): Consider merge the logic into the loop below.
130 auto cpu_factory = GetFactory("CPU");
131 if (!cpu_factory) {
132 return errors::NotFound(
133 "CPU Factory not registered. Did you link in threadpool_device?");
134 }
135
136 size_t init_size = devices->size();
137 TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(devices));
138 if (devices->size() == init_size) {
139 return errors::NotFound("No CPU devices are available in this process");
140 }
141
142 // Then the rest (including GPU).
143 tf_shared_lock l(*get_device_factory_lock());
144 for (auto& p : device_factories()) {
145 auto factory = p.second.factory.get();
146 if (factory != cpu_factory) {
147 TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices));
148 }
149 }
150
151 return OkStatus();
152 }
153
ListPluggablePhysicalDevices(std::vector<string> * devices)154 Status DeviceFactory::ListPluggablePhysicalDevices(
155 std::vector<string>* devices) {
156 tf_shared_lock l(*get_device_factory_lock());
157 for (auto& p : device_factories()) {
158 if (p.second.is_pluggable_device) {
159 auto factory = p.second.factory.get();
160 TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices));
161 }
162 }
163 return OkStatus();
164 }
165
GetAnyDeviceDetails(int device_index,std::unordered_map<string,string> * details)166 Status DeviceFactory::GetAnyDeviceDetails(
167 int device_index, std::unordered_map<string, string>* details) {
168 if (device_index < 0) {
169 return errors::InvalidArgument("Device index out of bounds: ",
170 device_index);
171 }
172 const int orig_device_index = device_index;
173
174 // Iterate over devices in the same way as in ListAllPhysicalDevices.
175 auto cpu_factory = GetFactory("CPU");
176 if (!cpu_factory) {
177 return errors::NotFound(
178 "CPU Factory not registered. Did you link in threadpool_device?");
179 }
180
181 // TODO(b/183974121): Consider merge the logic into the loop below.
182 std::vector<string> devices;
183 TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(&devices));
184 if (device_index < devices.size()) {
185 return cpu_factory->GetDeviceDetails(device_index, details);
186 }
187 device_index -= devices.size();
188
189 // Then the rest (including GPU).
190 tf_shared_lock l(*get_device_factory_lock());
191 for (auto& p : device_factories()) {
192 auto factory = p.second.factory.get();
193 if (factory != cpu_factory) {
194 devices.clear();
195 // TODO(b/146009447): Find the factory size without having to allocate a
196 // vector with all the physical devices.
197 TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(&devices));
198 if (device_index < devices.size()) {
199 return factory->GetDeviceDetails(device_index, details);
200 }
201 device_index -= devices.size();
202 }
203 }
204
205 return errors::InvalidArgument("Device index out of bounds: ",
206 orig_device_index);
207 }
208
AddCpuDevices(const SessionOptions & options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)209 Status DeviceFactory::AddCpuDevices(
210 const SessionOptions& options, const string& name_prefix,
211 std::vector<std::unique_ptr<Device>>* devices) {
212 auto cpu_factory = GetFactory("CPU");
213 if (!cpu_factory) {
214 return errors::NotFound(
215 "CPU Factory not registered. Did you link in threadpool_device?");
216 }
217 size_t init_size = devices->size();
218 TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(options, name_prefix, devices));
219 if (devices->size() == init_size) {
220 return errors::NotFound("No CPU devices are available in this process");
221 }
222
223 return OkStatus();
224 }
225
AddDevices(const SessionOptions & options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)226 Status DeviceFactory::AddDevices(
227 const SessionOptions& options, const string& name_prefix,
228 std::vector<std::unique_ptr<Device>>* devices) {
229 // CPU first. A CPU device is required.
230 // TODO(b/183974121): Consider merge the logic into the loop below.
231 TF_RETURN_IF_ERROR(AddCpuDevices(options, name_prefix, devices));
232
233 auto cpu_factory = GetFactory("CPU");
234 // Then the rest (including GPU).
235 mutex_lock l(*get_device_factory_lock());
236 for (auto& p : device_factories()) {
237 auto factory = p.second.factory.get();
238 if (factory != cpu_factory) {
239 TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices));
240 }
241 }
242
243 return OkStatus();
244 }
245
NewDevice(const string & type,const SessionOptions & options,const string & name_prefix)246 std::unique_ptr<Device> DeviceFactory::NewDevice(const string& type,
247 const SessionOptions& options,
248 const string& name_prefix) {
249 auto device_factory = GetFactory(type);
250 if (!device_factory) {
251 return nullptr;
252 }
253 SessionOptions opt = options;
254 (*opt.config.mutable_device_count())[type] = 1;
255 std::vector<std::unique_ptr<Device>> devices;
256 TF_CHECK_OK(device_factory->CreateDevices(opt, name_prefix, &devices));
257 int expected_num_devices = 1;
258 auto iter = options.config.device_count().find(type);
259 if (iter != options.config.device_count().end()) {
260 expected_num_devices = iter->second;
261 }
262 DCHECK_EQ(devices.size(), static_cast<size_t>(expected_num_devices));
263 return std::move(devices[0]);
264 }
265
266 } // namespace tensorflow
267