1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/device_util.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/compiler/xla/status_macros.h"
21
22 namespace tensorflow {
23 namespace jit {
24
Insert(DeviceId device_id)25 void DeviceSet::Insert(DeviceId device_id) {
26 int word_index = device_id.id() / kWordSize;
27 int bit_index = device_id.id() % kWordSize;
28 const int storage_size = storage_.size();
29 if (word_index >= storage_size) {
30 storage_.resize(word_index + 1, 0);
31 }
32
33 storage_[word_index] |= (1ull << bit_index);
34 }
35
UnionWith(const DeviceSet & other)36 void DeviceSet::UnionWith(const DeviceSet& other) {
37 if (other.storage_.size() > storage_.size()) {
38 storage_.resize(other.storage_.size(), 0);
39 }
40
41 for (int i = 0, end = other.storage_.size(); i < end; i++) {
42 storage_[i] |= other.storage_[i];
43 }
44 }
45
IsEmpty() const46 bool DeviceSet::IsEmpty() const {
47 return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
48 }
49
GetIdFor(absl::string_view name)50 StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
51 TF_RET_CHECK(!name.empty());
52
53 auto it = name_to_id_.find(name);
54 if (it != name_to_id_.end()) {
55 return it->second;
56 }
57
58 int new_id = names_.size();
59 names_.push_back(string(name));
60 id_to_device_type_.push_back(std::make_unique<DeviceType>(""));
61 DeviceType* device_type = id_to_device_type_.back().get();
62 TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
63
64 is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
65 is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
66
67 name_to_id_.emplace(string(name), DeviceId(new_id));
68
69 const XlaOpRegistry::DeviceRegistration* compilation_device;
70 if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
71 &compilation_device)) {
72 compilation_device = nullptr;
73 }
74 id_to_compilation_device_.push_back(compilation_device);
75
76 return DeviceId(new_id);
77 }
78
DebugString(const DeviceSet & device_set) const79 string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
80 std::vector<string> names;
81 device_set.ForEach([&](DeviceId device_id) {
82 names.push_back(string(GetNameFor(device_id)));
83 return true;
84 });
85
86 return absl::StrCat("[", absl::StrJoin(names, ","), "]");
87 }
88 } // namespace jit
89
DeviceNameToDeviceType(const string & device,DeviceType * device_type)90 Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
91 DeviceNameUtils::ParsedName parsed;
92 if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
93 return errors::Internal("Malformed assigned device '", device, "'");
94 }
95 *device_type = DeviceType(parsed.type);
96 return OkStatus();
97 }
98
PickDeviceForXlaImpl(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu,bool failure_to_pick_is_error)99 StatusOr<std::optional<jit::DeviceId>> PickDeviceForXlaImpl(
100 const jit::DeviceInfoCache& device_info_cache,
101 const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
102 bool failure_to_pick_is_error) {
103 #define FAILED_TO_PICK_DEVICE(failing_status) \
104 do { \
105 if (failure_to_pick_is_error) { \
106 return failing_status; \
107 } else { \
108 return {std::nullopt}; \
109 } \
110 } while (false)
111
112 std::optional<jit::DeviceId> maybe_gpu_device;
113 std::optional<jit::DeviceId> maybe_cpu_device;
114 std::optional<jit::DeviceId> maybe_unknown_device;
115
116 bool multiple_cpu_devices = false;
117 bool multiple_gpu_devices = false;
118 bool multiple_unknown_devices = false;
119
120 // Returns 'true' if d0 and d1 are conflicting devices. If they are
121 // compatible, update d1 with a more specific one.
122 // TODO(sanjoy): Cache DeviceNameUtils::ParsedName inside device_info_cache.
123 const auto is_multiple_devices =
124 [&](const jit::DeviceId& d0, std::optional<jit::DeviceId>* d1) -> bool {
125 const absl::string_view name0 = device_info_cache.GetNameFor(d0);
126 const absl::string_view name1 = device_info_cache.GetNameFor(d1->value());
127
128 DeviceNameUtils::ParsedName parsed0, parsed1;
129 if (!DeviceNameUtils::ParseFullName(name0, &parsed0) ||
130 !DeviceNameUtils::ParseFullName(name1, &parsed1) ||
131 !DeviceNameUtils::AreCompatibleDevNames(parsed0, parsed1)) {
132 return true;
133 }
134
135 if (DeviceNameUtils::IsSpecification(parsed0, parsed1)) {
136 return false;
137 }
138
139 if (DeviceNameUtils::IsSpecification(parsed1, parsed0)) {
140 *d1 = d0;
141 return false;
142 }
143
144 return true;
145 };
146
147 devices.ForEach([&](jit::DeviceId device) {
148 if (device_info_cache.IsGpu(device)) {
149 if (maybe_gpu_device) {
150 multiple_gpu_devices = is_multiple_devices(device, &maybe_gpu_device);
151 if (multiple_gpu_devices) return false;
152 } else {
153 maybe_gpu_device = device;
154 }
155 } else if (device_info_cache.IsCpu(device)) {
156 if (maybe_cpu_device) {
157 multiple_cpu_devices = is_multiple_devices(device, &maybe_cpu_device);
158 if (multiple_cpu_devices) return false;
159 } else {
160 maybe_cpu_device = device;
161 }
162 } else {
163 if (maybe_unknown_device) {
164 multiple_unknown_devices = true;
165 return false;
166 }
167 maybe_unknown_device = device;
168 }
169
170 return true;
171 });
172
173 if (multiple_cpu_devices) {
174 FAILED_TO_PICK_DEVICE(errors::Internal(
175 "Multiple CPU devices ", device_info_cache.DebugString(devices)));
176 }
177
178 if (multiple_gpu_devices) {
179 FAILED_TO_PICK_DEVICE(errors::Internal(
180 "Multiple GPU devices ", device_info_cache.DebugString(devices)));
181 }
182
183 if (multiple_unknown_devices) {
184 FAILED_TO_PICK_DEVICE(errors::Internal(
185 "Multiple unknown devices ", device_info_cache.DebugString(devices)));
186 }
187
188 if (maybe_unknown_device && maybe_gpu_device) {
189 FAILED_TO_PICK_DEVICE(errors::Internal(
190 "Found both unknown and GPU devices: ",
191 device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
192 device_info_cache.GetNameFor(*maybe_gpu_device)));
193 }
194
195 if (!allow_mixing_unknown_and_cpu) {
196 if (maybe_unknown_device && maybe_cpu_device) {
197 FAILED_TO_PICK_DEVICE(errors::Internal(
198 "Found both unknown and CPU devices: ",
199 device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
200 device_info_cache.GetNameFor(*maybe_cpu_device)));
201 }
202 }
203
204 if (maybe_gpu_device) {
205 return {*maybe_gpu_device};
206 } else if (maybe_unknown_device) {
207 return {*maybe_unknown_device};
208 } else if (maybe_cpu_device) {
209 return {*maybe_cpu_device};
210 }
211
212 FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));
213
214 #undef FAILED_TO_PICK_DEVICE
215 }
216
PickDeviceForXla(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu)217 StatusOr<jit::DeviceId> PickDeviceForXla(
218 const jit::DeviceInfoCache& device_info_cache,
219 const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
220 TF_ASSIGN_OR_RETURN(std::optional<jit::DeviceId> device_id,
221 PickDeviceForXlaImpl(device_info_cache, devices,
222 allow_mixing_unknown_and_cpu,
223 /*failure_to_pick_is_error=*/true));
224 return *device_id;
225 }
226
MaybePickDeviceForXla(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu)227 StatusOr<std::optional<jit::DeviceId>> MaybePickDeviceForXla(
228 const jit::DeviceInfoCache& device_info_cache,
229 const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
230 return PickDeviceForXlaImpl(device_info_cache, devices,
231 allow_mixing_unknown_and_cpu,
232 /*failure_to_pick_is_error=*/false);
233 }
234 } // namespace tensorflow
235