xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/device_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/device_util.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/compiler/xla/status_macros.h"
21 
22 namespace tensorflow {
23 namespace jit {
24 
Insert(DeviceId device_id)25 void DeviceSet::Insert(DeviceId device_id) {
26   int word_index = device_id.id() / kWordSize;
27   int bit_index = device_id.id() % kWordSize;
28   const int storage_size = storage_.size();
29   if (word_index >= storage_size) {
30     storage_.resize(word_index + 1, 0);
31   }
32 
33   storage_[word_index] |= (1ull << bit_index);
34 }
35 
UnionWith(const DeviceSet & other)36 void DeviceSet::UnionWith(const DeviceSet& other) {
37   if (other.storage_.size() > storage_.size()) {
38     storage_.resize(other.storage_.size(), 0);
39   }
40 
41   for (int i = 0, end = other.storage_.size(); i < end; i++) {
42     storage_[i] |= other.storage_[i];
43   }
44 }
45 
IsEmpty() const46 bool DeviceSet::IsEmpty() const {
47   return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
48 }
49 
GetIdFor(absl::string_view name)50 StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
51   TF_RET_CHECK(!name.empty());
52 
53   auto it = name_to_id_.find(name);
54   if (it != name_to_id_.end()) {
55     return it->second;
56   }
57 
58   int new_id = names_.size();
59   names_.push_back(string(name));
60   id_to_device_type_.push_back(std::make_unique<DeviceType>(""));
61   DeviceType* device_type = id_to_device_type_.back().get();
62   TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
63 
64   is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
65   is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
66 
67   name_to_id_.emplace(string(name), DeviceId(new_id));
68 
69   const XlaOpRegistry::DeviceRegistration* compilation_device;
70   if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
71                                            &compilation_device)) {
72     compilation_device = nullptr;
73   }
74   id_to_compilation_device_.push_back(compilation_device);
75 
76   return DeviceId(new_id);
77 }
78 
DebugString(const DeviceSet & device_set) const79 string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
80   std::vector<string> names;
81   device_set.ForEach([&](DeviceId device_id) {
82     names.push_back(string(GetNameFor(device_id)));
83     return true;
84   });
85 
86   return absl::StrCat("[", absl::StrJoin(names, ","), "]");
87 }
88 }  // namespace jit
89 
DeviceNameToDeviceType(const string & device,DeviceType * device_type)90 Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
91   DeviceNameUtils::ParsedName parsed;
92   if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
93     return errors::Internal("Malformed assigned device '", device, "'");
94   }
95   *device_type = DeviceType(parsed.type);
96   return OkStatus();
97 }
98 
PickDeviceForXlaImpl(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu,bool failure_to_pick_is_error)99 StatusOr<std::optional<jit::DeviceId>> PickDeviceForXlaImpl(
100     const jit::DeviceInfoCache& device_info_cache,
101     const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
102     bool failure_to_pick_is_error) {
103 #define FAILED_TO_PICK_DEVICE(failing_status) \
104   do {                                        \
105     if (failure_to_pick_is_error) {           \
106       return failing_status;                  \
107     } else {                                  \
108       return {std::nullopt};                  \
109     }                                         \
110   } while (false)
111 
112   std::optional<jit::DeviceId> maybe_gpu_device;
113   std::optional<jit::DeviceId> maybe_cpu_device;
114   std::optional<jit::DeviceId> maybe_unknown_device;
115 
116   bool multiple_cpu_devices = false;
117   bool multiple_gpu_devices = false;
118   bool multiple_unknown_devices = false;
119 
120   // Returns 'true' if d0 and d1 are conflicting devices. If they are
121   // compatible, update d1 with a more specific one.
122   // TODO(sanjoy): Cache DeviceNameUtils::ParsedName inside device_info_cache.
123   const auto is_multiple_devices =
124       [&](const jit::DeviceId& d0, std::optional<jit::DeviceId>* d1) -> bool {
125     const absl::string_view name0 = device_info_cache.GetNameFor(d0);
126     const absl::string_view name1 = device_info_cache.GetNameFor(d1->value());
127 
128     DeviceNameUtils::ParsedName parsed0, parsed1;
129     if (!DeviceNameUtils::ParseFullName(name0, &parsed0) ||
130         !DeviceNameUtils::ParseFullName(name1, &parsed1) ||
131         !DeviceNameUtils::AreCompatibleDevNames(parsed0, parsed1)) {
132       return true;
133     }
134 
135     if (DeviceNameUtils::IsSpecification(parsed0, parsed1)) {
136       return false;
137     }
138 
139     if (DeviceNameUtils::IsSpecification(parsed1, parsed0)) {
140       *d1 = d0;
141       return false;
142     }
143 
144     return true;
145   };
146 
147   devices.ForEach([&](jit::DeviceId device) {
148     if (device_info_cache.IsGpu(device)) {
149       if (maybe_gpu_device) {
150         multiple_gpu_devices = is_multiple_devices(device, &maybe_gpu_device);
151         if (multiple_gpu_devices) return false;
152       } else {
153         maybe_gpu_device = device;
154       }
155     } else if (device_info_cache.IsCpu(device)) {
156       if (maybe_cpu_device) {
157         multiple_cpu_devices = is_multiple_devices(device, &maybe_cpu_device);
158         if (multiple_cpu_devices) return false;
159       } else {
160         maybe_cpu_device = device;
161       }
162     } else {
163       if (maybe_unknown_device) {
164         multiple_unknown_devices = true;
165         return false;
166       }
167       maybe_unknown_device = device;
168     }
169 
170     return true;
171   });
172 
173   if (multiple_cpu_devices) {
174     FAILED_TO_PICK_DEVICE(errors::Internal(
175         "Multiple CPU devices ", device_info_cache.DebugString(devices)));
176   }
177 
178   if (multiple_gpu_devices) {
179     FAILED_TO_PICK_DEVICE(errors::Internal(
180         "Multiple GPU devices ", device_info_cache.DebugString(devices)));
181   }
182 
183   if (multiple_unknown_devices) {
184     FAILED_TO_PICK_DEVICE(errors::Internal(
185         "Multiple unknown devices ", device_info_cache.DebugString(devices)));
186   }
187 
188   if (maybe_unknown_device && maybe_gpu_device) {
189     FAILED_TO_PICK_DEVICE(errors::Internal(
190         "Found both unknown and GPU devices: ",
191         device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
192         device_info_cache.GetNameFor(*maybe_gpu_device)));
193   }
194 
195   if (!allow_mixing_unknown_and_cpu) {
196     if (maybe_unknown_device && maybe_cpu_device) {
197       FAILED_TO_PICK_DEVICE(errors::Internal(
198           "Found both unknown and CPU devices: ",
199           device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
200           device_info_cache.GetNameFor(*maybe_cpu_device)));
201     }
202   }
203 
204   if (maybe_gpu_device) {
205     return {*maybe_gpu_device};
206   } else if (maybe_unknown_device) {
207     return {*maybe_unknown_device};
208   } else if (maybe_cpu_device) {
209     return {*maybe_cpu_device};
210   }
211 
212   FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));
213 
214 #undef FAILED_TO_PICK_DEVICE
215 }
216 
PickDeviceForXla(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu)217 StatusOr<jit::DeviceId> PickDeviceForXla(
218     const jit::DeviceInfoCache& device_info_cache,
219     const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
220   TF_ASSIGN_OR_RETURN(std::optional<jit::DeviceId> device_id,
221                       PickDeviceForXlaImpl(device_info_cache, devices,
222                                            allow_mixing_unknown_and_cpu,
223                                            /*failure_to_pick_is_error=*/true));
224   return *device_id;
225 }
226 
MaybePickDeviceForXla(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu)227 StatusOr<std::optional<jit::DeviceId>> MaybePickDeviceForXla(
228     const jit::DeviceInfoCache& device_info_cache,
229     const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
230   return PickDeviceForXlaImpl(device_info_cache, devices,
231                               allow_mixing_unknown_and_cpu,
232                               /*failure_to_pick_is_error=*/false);
233 }
234 }  // namespace tensorflow
235