1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/stream_executor/multi_platform_manager.h"
17
18 #include <string>
19
20 #include "absl/base/thread_annotations.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/ascii.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/synchronization/mutex.h"
27 #include "tensorflow/compiler/xla/stream_executor/lib/error.h"
28 #include "tensorflow/compiler/xla/stream_executor/lib/initialize.h"
29 #include "tensorflow/core/platform/errors.h"
30
31 namespace stream_executor {
32 namespace {
33
// Registry of Platform objects that backs the static MultiPlatformManager
// functions defined later in this file. Platforms are indexed both by id and
// by lower-cased name; every public method acquires mu_ itself.
class MultiPlatformManagerImpl {
 public:
  // Takes ownership of `platform`, records it under its id and its lower-cased
  // name, and then notifies every registered listener. Returns an INTERNAL
  // error if a platform with the same (case-insensitive) name is already
  // registered. CHECK-fails if `platform` is null or its id is already present.
  port::Status RegisterPlatform(std::unique_ptr<Platform> platform)
      ABSL_LOCKS_EXCLUDED(mu_);

  // Same as PlatformWithName(target, /*initialize_platform=*/true).
  port::StatusOr<Platform*> PlatformWithName(absl::string_view target)
      ABSL_LOCKS_EXCLUDED(mu_);

  // Same as PlatformWithId(id, /*initialize_platform=*/true).
  port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id)
      ABSL_LOCKS_EXCLUDED(mu_);

  // Returns the platform registered under `target` (compared after ASCII
  // lower-casing), or the lookup error if there is none. When
  // `initialize_platform` is true and the platform is not initialized yet, it
  // is initialized with an empty option map first; an initialization error is
  // returned instead of the platform.
  port::StatusOr<Platform*> PlatformWithName(absl::string_view target,
                                             bool initialize_platform)
      ABSL_LOCKS_EXCLUDED(mu_);

  // Id-keyed counterpart of the two-argument PlatformWithName().
  port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id,
                                           bool initialize_platform)
      ABSL_LOCKS_EXCLUDED(mu_);

  // Looks up the platform by name and initializes it with `options`. Returns
  // FAILED_PRECONDITION if the platform is already initialized, the lookup
  // error if it is not registered, or the error reported by Initialize().
  port::StatusOr<Platform*> InitializePlatformWithName(
      absl::string_view target,
      const std::map<std::string, std::string>& options)
      ABSL_LOCKS_EXCLUDED(mu_);
  // Id-keyed counterpart of InitializePlatformWithName().
  port::StatusOr<Platform*> InitializePlatformWithId(
      const Platform::Id& id, const std::map<std::string, std::string>& options)
      ABSL_LOCKS_EXCLUDED(mu_);

  // Returns every registered platform for which `filter` returns true. When
  // `initialize_platform` is true, matching platforms that are not yet
  // initialized are initialized with an empty option map; the first
  // initialization error aborts the scan and is returned.
  port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
      const std::function<bool(const Platform*)>& filter,
      bool initialize_platform) ABSL_LOCKS_EXCLUDED(mu_);

  using Listener = MultiPlatformManager::Listener;
  // Appends `listener` to the set notified by RegisterPlatform(). CHECK-fails
  // if any platform has already been registered.
  port::Status RegisterListener(std::unique_ptr<Listener> listener)
      ABSL_LOCKS_EXCLUDED(mu_);

 private:
  // Looks up the platform object with the given name.  Assumes the Platforms
  // mutex is held. Returns NOT_FOUND if no platform has that name.
  port::StatusOr<Platform*> LookupByNameLocked(absl::string_view target)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Looks up the platform object with the given id.  Assumes the Platforms
  // mutex is held. Returns NOT_FOUND if no platform has that id.
  port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns the names of the initialized platforms satisfying the given filter.
  // By default, it will return all initialized platform names.
  std::vector<std::string> InitializedPlatformNamesWithFilter(
      const std::function<bool(const Platform*)>& filter = [](const Platform*) {
        return true;
      }) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Guards all of the members below.
  absl::Mutex mu_;
  // Listeners notified, in registration order, by RegisterPlatform().
  std::vector<std::unique_ptr<Listener>> listeners_ ABSL_GUARDED_BY(mu_);
  // Platform id -> platform. The pointers are the ones released from their
  // unique_ptr in RegisterPlatform(); nothing in this file deletes them.
  absl::flat_hash_map<Platform::Id, Platform*> id_map_ ABSL_GUARDED_BY(mu_);
  // ASCII-lower-cased platform name -> platform (same objects as id_map_).
  absl::flat_hash_map<std::string, Platform*> name_map_ ABSL_GUARDED_BY(mu_);
};
92
RegisterPlatform(std::unique_ptr<Platform> platform)93 port::Status MultiPlatformManagerImpl::RegisterPlatform(
94 std::unique_ptr<Platform> platform) {
95 CHECK(platform != nullptr);
96 std::string key = absl::AsciiStrToLower(platform->Name());
97 absl::MutexLock lock(&mu_);
98 if (name_map_.find(key) != name_map_.end()) {
99 return port::Status(port::error::INTERNAL,
100 "platform is already registered with name: \"" +
101 platform->Name() + "\"");
102 }
103 Platform* platform_ptr = platform.get();
104 CHECK(id_map_.emplace(platform->id(), platform_ptr).second);
105 // Release ownership/uniqueness to prevent destruction on program exit.
106 // This avoids Platforms "cleaning up" on program exit, because otherwise,
107 // there are _very_ tricky races between StreamExecutor and underlying
108 // platforms (CUDA, OpenCL) during exit. Since these are fixed-size and 1x per
109 // program, these are deemed acceptable.
110 name_map_[key] = platform.release();
111 for (const auto& listener : listeners_) {
112 listener->PlatformRegistered(platform_ptr);
113 }
114 return ::tensorflow::OkStatus();
115 }
116
// Convenience overload: delegates to the two-argument PlatformWithName() with
// initialize_platform=true, so the returned platform is initialized on demand.
port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
    absl::string_view target) {
  return PlatformWithName(target, /*initialize_platform=*/true);
}
121
// Convenience overload: delegates to the two-argument PlatformWithId() with
// initialize_platform=true, so the returned platform is initialized on demand.
port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
    const Platform::Id& id) {
  return PlatformWithId(id, /*initialize_platform=*/true);
}
126
PlatformWithName(absl::string_view target,bool initialize_platform)127 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
128 absl::string_view target, bool initialize_platform) {
129 absl::MutexLock lock(&mu_);
130
131 TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
132 if (initialize_platform && !platform->Initialized()) {
133 TF_RETURN_IF_ERROR(platform->Initialize({}));
134 }
135
136 return platform;
137 }
138
PlatformWithId(const Platform::Id & id,bool initialize_platform)139 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
140 const Platform::Id& id, bool initialize_platform) {
141 absl::MutexLock lock(&mu_);
142
143 TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
144 if (initialize_platform && !platform->Initialized()) {
145 TF_RETURN_IF_ERROR(platform->Initialize({}));
146 }
147
148 return platform;
149 }
150
InitializePlatformWithName(absl::string_view target,const std::map<std::string,std::string> & options)151 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
152 absl::string_view target,
153 const std::map<std::string, std::string>& options) {
154 absl::MutexLock lock(&mu_);
155
156 TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
157 if (platform->Initialized()) {
158 return port::Status(
159 port::error::FAILED_PRECONDITION,
160 absl::StrCat("platform \"", target, "\" is already initialized"));
161 }
162
163 TF_RETURN_IF_ERROR(platform->Initialize(options));
164
165 return platform;
166 }
167
InitializePlatformWithId(const Platform::Id & id,const std::map<std::string,std::string> & options)168 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId(
169 const Platform::Id& id, const std::map<std::string, std::string>& options) {
170 absl::MutexLock lock(&mu_);
171
172 TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
173 if (platform->Initialized()) {
174 return port::Status(
175 port::error::FAILED_PRECONDITION,
176 absl::StrFormat("platform with id %p is already initialized", id));
177 }
178
179 TF_RETURN_IF_ERROR(platform->Initialize(options));
180
181 return platform;
182 }
183
RegisterListener(std::unique_ptr<Listener> listener)184 port::Status MultiPlatformManagerImpl::RegisterListener(
185 std::unique_ptr<Listener> listener) {
186 absl::MutexLock lock(&mu_);
187 CHECK(id_map_.empty());
188 CHECK(name_map_.empty());
189 listeners_.push_back(std::move(listener));
190 return ::tensorflow::OkStatus();
191 }
192
193 port::StatusOr<std::vector<Platform*>>
PlatformsWithFilter(const std::function<bool (const Platform *)> & filter,bool initialize_platform)194 MultiPlatformManagerImpl::PlatformsWithFilter(
195 const std::function<bool(const Platform*)>& filter,
196 bool initialize_platform) {
197 absl::MutexLock lock(&mu_);
198 CHECK_EQ(id_map_.size(), name_map_.size());
199 std::vector<Platform*> platforms;
200 platforms.reserve(id_map_.size());
201 for (const auto& entry : id_map_) {
202 Platform* platform = entry.second;
203 if (filter(platform)) {
204 if (initialize_platform && !platform->Initialized()) {
205 TF_RETURN_IF_ERROR(platform->Initialize({}));
206 }
207 platforms.push_back(platform);
208 }
209 }
210 return platforms;
211 }
212
213 std::vector<std::string>
InitializedPlatformNamesWithFilter(const std::function<bool (const Platform *)> & filter)214 MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter(
215 const std::function<bool(const Platform*)>& filter) {
216 CHECK_EQ(id_map_.size(), name_map_.size());
217 std::vector<std::string> initialized_platforms_names;
218 initialized_platforms_names.reserve(id_map_.size());
219 for (const auto& entry : id_map_) {
220 Platform* platform = entry.second;
221 if (filter(platform)) {
222 if (platform->Initialized()) {
223 initialized_platforms_names.push_back(platform->Name());
224 }
225 }
226 }
227 return initialized_platforms_names;
228 }
229
LookupByNameLocked(absl::string_view target)230 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
231 absl::string_view target) {
232 auto it = name_map_.find(absl::AsciiStrToLower(target));
233 if (it == name_map_.end()) {
234 return port::Status(
235 port::error::NOT_FOUND,
236 absl::StrCat("Could not find registered platform with name: \"", target,
237 "\". Available platform names are: ",
238 absl::StrJoin(InitializedPlatformNamesWithFilter(), " ")));
239 }
240 return it->second;
241 }
242
LookupByIdLocked(const Platform::Id & id)243 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByIdLocked(
244 const Platform::Id& id) {
245 auto it = id_map_.find(id);
246 if (it == id_map_.end()) {
247 return port::Status(
248 port::error::NOT_FOUND,
249 absl::StrFormat("could not find registered platform with id: %p", id));
250 }
251 return it->second;
252 }
253
Impl()254 MultiPlatformManagerImpl& Impl() {
255 static MultiPlatformManagerImpl* impl = new MultiPlatformManagerImpl;
256 return *impl;
257 }
258
259 } // namespace
260
// Public entry point: hands `platform` over to the shared
// MultiPlatformManagerImpl returned by Impl() and returns its status.
/*static*/ port::Status MultiPlatformManager::RegisterPlatform(
    std::unique_ptr<Platform> platform) {
  return Impl().RegisterPlatform(std::move(platform));
}
265
// Public entry point: forwards to the one-argument
// MultiPlatformManagerImpl::PlatformWithName(), which requests initialization
// of the platform.
/*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
    absl::string_view target) {
  return Impl().PlatformWithName(target);
}
270
// Public entry point: forwards to the one-argument
// MultiPlatformManagerImpl::PlatformWithId(), which requests initialization of
// the platform.
/*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
    const Platform::Id& id) {
  return Impl().PlatformWithId(id);
}
275
// Public entry point: forwards `id` and `initialize_platform` unchanged to
// MultiPlatformManagerImpl::PlatformWithId().
/*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
    const Platform::Id& id, bool initialize_platform) {
  return Impl().PlatformWithId(id, initialize_platform);
}
280
// Public entry point: forwards `target` and `initialize_platform` unchanged to
// MultiPlatformManagerImpl::PlatformWithName().
/*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
    absl::string_view target, bool initialize_platform) {
  return Impl().PlatformWithName(target, initialize_platform);
}
285
// Public entry point: forwards to
// MultiPlatformManagerImpl::InitializePlatformWithName(), which fails with
// FAILED_PRECONDITION if the platform is already initialized.
/*static*/ port::StatusOr<Platform*>
MultiPlatformManager::InitializePlatformWithName(
    absl::string_view target,
    const std::map<std::string, std::string>& options) {
  return Impl().InitializePlatformWithName(target, options);
}
292
// Public entry point: forwards to
// MultiPlatformManagerImpl::InitializePlatformWithId(), which fails with
// FAILED_PRECONDITION if the platform is already initialized.
/*static*/ port::StatusOr<Platform*>
MultiPlatformManager::InitializePlatformWithId(
    const Platform::Id& id, const std::map<std::string, std::string>& options) {
  return Impl().InitializePlatformWithId(id, options);
}
298
// Public entry point: hands `listener` over to
// MultiPlatformManagerImpl::RegisterListener(), which CHECK-fails if any
// platform has already been registered.
/*static*/ port::Status MultiPlatformManager::RegisterListener(
    std::unique_ptr<Listener> listener) {
  return Impl().RegisterListener(std::move(listener));
}
303
// Convenience overload: same as the two-argument PlatformsWithFilter() with
// initialize_platform=true.
/*static*/ port::StatusOr<std::vector<Platform*>>
MultiPlatformManager::PlatformsWithFilter(
    const std::function<bool(const Platform*)>& filter) {
  return PlatformsWithFilter(filter, /*initialize_platform=*/true);
}
309
// Public entry point: forwards `filter` and `initialize_platform` unchanged to
// MultiPlatformManagerImpl::PlatformsWithFilter().
/*static*/ port::StatusOr<std::vector<Platform*>>
MultiPlatformManager::PlatformsWithFilter(
    const std::function<bool(const Platform*)>& filter,
    bool initialize_platform) {
  return Impl().PlatformsWithFilter(filter, initialize_platform);
}
316
317 } // namespace stream_executor
318
// Named module initializer with an intentionally empty body. It exists only so
// that other initializers can be ordered relative to it.
REGISTER_MODULE_INITIALIZER(
    multi_platform_manager,
    {
        // Nothing -- this is just a module initializer
        // definition to reference for sequencing
        // purposes from Platform subclasses that register
        // themselves with the MultiPlatformManager.
    });

// Second empty, named module initializer, used as the ordering anchor for
// listener registration.
REGISTER_MODULE_INITIALIZER(
    multi_platform_manager_listener,
    {
        // Nothing -- this is just a module initializer definition to reference
        // for sequencing registration of listeners with the
        // MultiPlatformManager.
    });

// Listener registration should happen before platform registration.
// (RegisterListener() in this file CHECK-fails once any platform has been
// registered. The exact ordering semantics of the macro below are defined
// where it is declared -- presumably the first initializer runs before the
// second; confirm against its definition.)
REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
                                     multi_platform_manager);
339