xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/device_factory.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/framework/device_factory.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/env_var.h"
31 
32 namespace tensorflow {
33 
34 namespace {
35 
get_device_factory_lock()36 static mutex* get_device_factory_lock() {
37   static mutex device_factory_lock(LINKER_INITIALIZED);
38   return &device_factory_lock;
39 }
40 
41 struct FactoryItem {
42   std::unique_ptr<DeviceFactory> factory;
43   int priority;
44   bool is_pluggable_device;
45 };
46 
device_factories()47 std::unordered_map<string, FactoryItem>& device_factories() {
48   static std::unordered_map<string, FactoryItem>* factories =
49       new std::unordered_map<string, FactoryItem>;
50   return *factories;
51 }
52 
IsDeviceFactoryEnabled(const string & device_type)53 bool IsDeviceFactoryEnabled(const string& device_type) {
54   std::vector<string> enabled_devices;
55   TF_CHECK_OK(tensorflow::ReadStringsFromEnvVar(
56       /*env_var_name=*/"TF_ENABLED_DEVICE_TYPES", /*default_val=*/"",
57       &enabled_devices));
58   if (enabled_devices.empty()) {
59     return true;
60   }
61   return std::find(enabled_devices.begin(), enabled_devices.end(),
62                    device_type) != enabled_devices.end();
63 }
64 }  // namespace
65 
66 // static
DevicePriority(const string & device_type)67 int32 DeviceFactory::DevicePriority(const string& device_type) {
68   tf_shared_lock l(*get_device_factory_lock());
69   std::unordered_map<string, FactoryItem>& factories = device_factories();
70   auto iter = factories.find(device_type);
71   if (iter != factories.end()) {
72     return iter->second.priority;
73   }
74 
75   return -1;
76 }
77 
IsPluggableDevice(const string & device_type)78 bool DeviceFactory::IsPluggableDevice(const string& device_type) {
79   tf_shared_lock l(*get_device_factory_lock());
80   std::unordered_map<string, FactoryItem>& factories = device_factories();
81   auto iter = factories.find(device_type);
82   if (iter != factories.end()) {
83     return iter->second.is_pluggable_device;
84   }
85   return false;
86 }
87 
88 // static
Register(const string & device_type,std::unique_ptr<DeviceFactory> factory,int priority,bool is_pluggable_device)89 void DeviceFactory::Register(const string& device_type,
90                              std::unique_ptr<DeviceFactory> factory,
91                              int priority, bool is_pluggable_device) {
92   if (!IsDeviceFactoryEnabled(device_type)) {
93     LOG(INFO) << "Device factory '" << device_type << "' disabled by "
94               << "TF_ENABLED_DEVICE_TYPES environment variable.";
95     return;
96   }
97   mutex_lock l(*get_device_factory_lock());
98   std::unordered_map<string, FactoryItem>& factories = device_factories();
99   auto iter = factories.find(device_type);
100   if (iter == factories.end()) {
101     factories[device_type] = {std::move(factory), priority,
102                               is_pluggable_device};
103   } else {
104     if (iter->second.priority < priority) {
105       iter->second = {std::move(factory), priority, is_pluggable_device};
106     } else if (iter->second.priority == priority) {
107       LOG(FATAL) << "Duplicate registration of device factory for type "
108                  << device_type << " with the same priority " << priority;
109     }
110   }
111 }
112 
GetFactory(const string & device_type)113 DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
114   tf_shared_lock l(*get_device_factory_lock());
115   auto it = device_factories().find(device_type);
116   if (it == device_factories().end()) {
117     return nullptr;
118   } else if (!IsDeviceFactoryEnabled(device_type)) {
119     LOG(FATAL) << "Device type " << device_type  // Crash OK
120                << " had factory registered but was explicitly disabled by "
121                << "`TF_ENABLED_DEVICE_TYPES`. This environment variable needs "
122                << "to be set at program startup.";
123   }
124   return it->second.factory.get();
125 }
126 
ListAllPhysicalDevices(std::vector<string> * devices)127 Status DeviceFactory::ListAllPhysicalDevices(std::vector<string>* devices) {
128   // CPU first. A CPU device is required.
129   // TODO(b/183974121): Consider merge the logic into the loop below.
130   auto cpu_factory = GetFactory("CPU");
131   if (!cpu_factory) {
132     return errors::NotFound(
133         "CPU Factory not registered. Did you link in threadpool_device?");
134   }
135 
136   size_t init_size = devices->size();
137   TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(devices));
138   if (devices->size() == init_size) {
139     return errors::NotFound("No CPU devices are available in this process");
140   }
141 
142   // Then the rest (including GPU).
143   tf_shared_lock l(*get_device_factory_lock());
144   for (auto& p : device_factories()) {
145     auto factory = p.second.factory.get();
146     if (factory != cpu_factory) {
147       TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices));
148     }
149   }
150 
151   return OkStatus();
152 }
153 
ListPluggablePhysicalDevices(std::vector<string> * devices)154 Status DeviceFactory::ListPluggablePhysicalDevices(
155     std::vector<string>* devices) {
156   tf_shared_lock l(*get_device_factory_lock());
157   for (auto& p : device_factories()) {
158     if (p.second.is_pluggable_device) {
159       auto factory = p.second.factory.get();
160       TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices));
161     }
162   }
163   return OkStatus();
164 }
165 
GetAnyDeviceDetails(int device_index,std::unordered_map<string,string> * details)166 Status DeviceFactory::GetAnyDeviceDetails(
167     int device_index, std::unordered_map<string, string>* details) {
168   if (device_index < 0) {
169     return errors::InvalidArgument("Device index out of bounds: ",
170                                    device_index);
171   }
172   const int orig_device_index = device_index;
173 
174   // Iterate over devices in the same way as in ListAllPhysicalDevices.
175   auto cpu_factory = GetFactory("CPU");
176   if (!cpu_factory) {
177     return errors::NotFound(
178         "CPU Factory not registered. Did you link in threadpool_device?");
179   }
180 
181   // TODO(b/183974121): Consider merge the logic into the loop below.
182   std::vector<string> devices;
183   TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(&devices));
184   if (device_index < devices.size()) {
185     return cpu_factory->GetDeviceDetails(device_index, details);
186   }
187   device_index -= devices.size();
188 
189   // Then the rest (including GPU).
190   tf_shared_lock l(*get_device_factory_lock());
191   for (auto& p : device_factories()) {
192     auto factory = p.second.factory.get();
193     if (factory != cpu_factory) {
194       devices.clear();
195       // TODO(b/146009447): Find the factory size without having to allocate a
196       // vector with all the physical devices.
197       TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(&devices));
198       if (device_index < devices.size()) {
199         return factory->GetDeviceDetails(device_index, details);
200       }
201       device_index -= devices.size();
202     }
203   }
204 
205   return errors::InvalidArgument("Device index out of bounds: ",
206                                  orig_device_index);
207 }
208 
AddCpuDevices(const SessionOptions & options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)209 Status DeviceFactory::AddCpuDevices(
210     const SessionOptions& options, const string& name_prefix,
211     std::vector<std::unique_ptr<Device>>* devices) {
212   auto cpu_factory = GetFactory("CPU");
213   if (!cpu_factory) {
214     return errors::NotFound(
215         "CPU Factory not registered. Did you link in threadpool_device?");
216   }
217   size_t init_size = devices->size();
218   TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(options, name_prefix, devices));
219   if (devices->size() == init_size) {
220     return errors::NotFound("No CPU devices are available in this process");
221   }
222 
223   return OkStatus();
224 }
225 
AddDevices(const SessionOptions & options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)226 Status DeviceFactory::AddDevices(
227     const SessionOptions& options, const string& name_prefix,
228     std::vector<std::unique_ptr<Device>>* devices) {
229   // CPU first. A CPU device is required.
230   // TODO(b/183974121): Consider merge the logic into the loop below.
231   TF_RETURN_IF_ERROR(AddCpuDevices(options, name_prefix, devices));
232 
233   auto cpu_factory = GetFactory("CPU");
234   // Then the rest (including GPU).
235   mutex_lock l(*get_device_factory_lock());
236   for (auto& p : device_factories()) {
237     auto factory = p.second.factory.get();
238     if (factory != cpu_factory) {
239       TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices));
240     }
241   }
242 
243   return OkStatus();
244 }
245 
NewDevice(const string & type,const SessionOptions & options,const string & name_prefix)246 std::unique_ptr<Device> DeviceFactory::NewDevice(const string& type,
247                                                  const SessionOptions& options,
248                                                  const string& name_prefix) {
249   auto device_factory = GetFactory(type);
250   if (!device_factory) {
251     return nullptr;
252   }
253   SessionOptions opt = options;
254   (*opt.config.mutable_device_count())[type] = 1;
255   std::vector<std::unique_ptr<Device>> devices;
256   TF_CHECK_OK(device_factory->CreateDevices(opt, name_prefix, &devices));
257   int expected_num_devices = 1;
258   auto iter = options.config.device_count().find(type);
259   if (iter != options.config.device_count().end()) {
260     expected_num_devices = iter->second;
261   }
262   DCHECK_EQ(devices.size(), static_cast<size_t>(expected_num_devices));
263   return std::move(devices[0]);
264 }
265 
266 }  // namespace tensorflow
267