xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/multi_platform_manager.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/stream_executor/multi_platform_manager.h"
17 
18 #include <string>
19 
20 #include "absl/base/thread_annotations.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/ascii.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/synchronization/mutex.h"
27 #include "tensorflow/compiler/xla/stream_executor/lib/error.h"
28 #include "tensorflow/compiler/xla/stream_executor/lib/initialize.h"
29 #include "tensorflow/core/platform/errors.h"
30 
31 namespace stream_executor {
32 namespace {
33 
34 class MultiPlatformManagerImpl {
35  public:
36   port::Status RegisterPlatform(std::unique_ptr<Platform> platform)
37       ABSL_LOCKS_EXCLUDED(mu_);
38 
39   port::StatusOr<Platform*> PlatformWithName(absl::string_view target)
40       ABSL_LOCKS_EXCLUDED(mu_);
41 
42   port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id)
43       ABSL_LOCKS_EXCLUDED(mu_);
44 
45   port::StatusOr<Platform*> PlatformWithName(absl::string_view target,
46                                              bool initialize_platform)
47       ABSL_LOCKS_EXCLUDED(mu_);
48 
49   port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id,
50                                            bool initialize_platform)
51       ABSL_LOCKS_EXCLUDED(mu_);
52 
53   port::StatusOr<Platform*> InitializePlatformWithName(
54       absl::string_view target,
55       const std::map<std::string, std::string>& options)
56       ABSL_LOCKS_EXCLUDED(mu_);
57   port::StatusOr<Platform*> InitializePlatformWithId(
58       const Platform::Id& id, const std::map<std::string, std::string>& options)
59       ABSL_LOCKS_EXCLUDED(mu_);
60 
61   port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
62       const std::function<bool(const Platform*)>& filter,
63       bool initialize_platform) ABSL_LOCKS_EXCLUDED(mu_);
64 
65   using Listener = MultiPlatformManager::Listener;
66   port::Status RegisterListener(std::unique_ptr<Listener> listener)
67       ABSL_LOCKS_EXCLUDED(mu_);
68 
69  private:
70   // Looks up the platform object with the given name.  Assumes the Platforms
71   // mutex is held.
72   port::StatusOr<Platform*> LookupByNameLocked(absl::string_view target)
73       ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
74 
75   // Looks up the platform object with the given id.  Assumes the Platforms
76   // mutex is held.
77   port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
78       ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
79 
80   // Returns the names of the initialied platforms satisfying the given filter.
81   // By default, it will return all initialized platform names.
82   std::vector<std::string> InitializedPlatformNamesWithFilter(
__anon587030ef0202(const Platform*) 83       const std::function<bool(const Platform*)>& filter = [](const Platform*) {
84         return true;
85       }) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
86 
87   absl::Mutex mu_;
88   std::vector<std::unique_ptr<Listener>> listeners_ ABSL_GUARDED_BY(mu_);
89   absl::flat_hash_map<Platform::Id, Platform*> id_map_ ABSL_GUARDED_BY(mu_);
90   absl::flat_hash_map<std::string, Platform*> name_map_ ABSL_GUARDED_BY(mu_);
91 };
92 
RegisterPlatform(std::unique_ptr<Platform> platform)93 port::Status MultiPlatformManagerImpl::RegisterPlatform(
94     std::unique_ptr<Platform> platform) {
95   CHECK(platform != nullptr);
96   std::string key = absl::AsciiStrToLower(platform->Name());
97   absl::MutexLock lock(&mu_);
98   if (name_map_.find(key) != name_map_.end()) {
99     return port::Status(port::error::INTERNAL,
100                         "platform is already registered with name: \"" +
101                             platform->Name() + "\"");
102   }
103   Platform* platform_ptr = platform.get();
104   CHECK(id_map_.emplace(platform->id(), platform_ptr).second);
105   // Release ownership/uniqueness to prevent destruction on program exit.
106   // This avoids Platforms "cleaning up" on program exit, because otherwise,
107   // there are _very_ tricky races between StreamExecutor and underlying
108   // platforms (CUDA, OpenCL) during exit. Since these are fixed-size and 1x per
109   // program, these are deemed acceptable.
110   name_map_[key] = platform.release();
111   for (const auto& listener : listeners_) {
112     listener->PlatformRegistered(platform_ptr);
113   }
114   return ::tensorflow::OkStatus();
115 }
116 
PlatformWithName(absl::string_view target)117 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
118     absl::string_view target) {
119   return PlatformWithName(target, /*initialize_platform=*/true);
120 }
121 
PlatformWithId(const Platform::Id & id)122 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
123     const Platform::Id& id) {
124   return PlatformWithId(id, /*initialize_platform=*/true);
125 }
126 
PlatformWithName(absl::string_view target,bool initialize_platform)127 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithName(
128     absl::string_view target, bool initialize_platform) {
129   absl::MutexLock lock(&mu_);
130 
131   TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
132   if (initialize_platform && !platform->Initialized()) {
133     TF_RETURN_IF_ERROR(platform->Initialize({}));
134   }
135 
136   return platform;
137 }
138 
PlatformWithId(const Platform::Id & id,bool initialize_platform)139 port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
140     const Platform::Id& id, bool initialize_platform) {
141   absl::MutexLock lock(&mu_);
142 
143   TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
144   if (initialize_platform && !platform->Initialized()) {
145     TF_RETURN_IF_ERROR(platform->Initialize({}));
146   }
147 
148   return platform;
149 }
150 
InitializePlatformWithName(absl::string_view target,const std::map<std::string,std::string> & options)151 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
152     absl::string_view target,
153     const std::map<std::string, std::string>& options) {
154   absl::MutexLock lock(&mu_);
155 
156   TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
157   if (platform->Initialized()) {
158     return port::Status(
159         port::error::FAILED_PRECONDITION,
160         absl::StrCat("platform \"", target, "\" is already initialized"));
161   }
162 
163   TF_RETURN_IF_ERROR(platform->Initialize(options));
164 
165   return platform;
166 }
167 
InitializePlatformWithId(const Platform::Id & id,const std::map<std::string,std::string> & options)168 port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId(
169     const Platform::Id& id, const std::map<std::string, std::string>& options) {
170   absl::MutexLock lock(&mu_);
171 
172   TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
173   if (platform->Initialized()) {
174     return port::Status(
175         port::error::FAILED_PRECONDITION,
176         absl::StrFormat("platform with id %p is already initialized", id));
177   }
178 
179   TF_RETURN_IF_ERROR(platform->Initialize(options));
180 
181   return platform;
182 }
183 
RegisterListener(std::unique_ptr<Listener> listener)184 port::Status MultiPlatformManagerImpl::RegisterListener(
185     std::unique_ptr<Listener> listener) {
186   absl::MutexLock lock(&mu_);
187   CHECK(id_map_.empty());
188   CHECK(name_map_.empty());
189   listeners_.push_back(std::move(listener));
190   return ::tensorflow::OkStatus();
191 }
192 
193 port::StatusOr<std::vector<Platform*>>
PlatformsWithFilter(const std::function<bool (const Platform *)> & filter,bool initialize_platform)194 MultiPlatformManagerImpl::PlatformsWithFilter(
195     const std::function<bool(const Platform*)>& filter,
196     bool initialize_platform) {
197   absl::MutexLock lock(&mu_);
198   CHECK_EQ(id_map_.size(), name_map_.size());
199   std::vector<Platform*> platforms;
200   platforms.reserve(id_map_.size());
201   for (const auto& entry : id_map_) {
202     Platform* platform = entry.second;
203     if (filter(platform)) {
204       if (initialize_platform && !platform->Initialized()) {
205         TF_RETURN_IF_ERROR(platform->Initialize({}));
206       }
207       platforms.push_back(platform);
208     }
209   }
210   return platforms;
211 }
212 
213 std::vector<std::string>
InitializedPlatformNamesWithFilter(const std::function<bool (const Platform *)> & filter)214 MultiPlatformManagerImpl::InitializedPlatformNamesWithFilter(
215     const std::function<bool(const Platform*)>& filter) {
216   CHECK_EQ(id_map_.size(), name_map_.size());
217   std::vector<std::string> initialized_platforms_names;
218   initialized_platforms_names.reserve(id_map_.size());
219   for (const auto& entry : id_map_) {
220     Platform* platform = entry.second;
221     if (filter(platform)) {
222       if (platform->Initialized()) {
223         initialized_platforms_names.push_back(platform->Name());
224       }
225     }
226   }
227   return initialized_platforms_names;
228 }
229 
LookupByNameLocked(absl::string_view target)230 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByNameLocked(
231     absl::string_view target) {
232   auto it = name_map_.find(absl::AsciiStrToLower(target));
233   if (it == name_map_.end()) {
234     return port::Status(
235         port::error::NOT_FOUND,
236         absl::StrCat("Could not find registered platform with name: \"", target,
237                      "\". Available platform names are: ",
238                      absl::StrJoin(InitializedPlatformNamesWithFilter(), " ")));
239   }
240   return it->second;
241 }
242 
LookupByIdLocked(const Platform::Id & id)243 port::StatusOr<Platform*> MultiPlatformManagerImpl::LookupByIdLocked(
244     const Platform::Id& id) {
245   auto it = id_map_.find(id);
246   if (it == id_map_.end()) {
247     return port::Status(
248         port::error::NOT_FOUND,
249         absl::StrFormat("could not find registered platform with id: %p", id));
250   }
251   return it->second;
252 }
253 
Impl()254 MultiPlatformManagerImpl& Impl() {
255   static MultiPlatformManagerImpl* impl = new MultiPlatformManagerImpl;
256   return *impl;
257 }
258 
259 }  // namespace
260 
RegisterPlatform(std::unique_ptr<Platform> platform)261 /*static*/ port::Status MultiPlatformManager::RegisterPlatform(
262     std::unique_ptr<Platform> platform) {
263   return Impl().RegisterPlatform(std::move(platform));
264 }
265 
PlatformWithName(absl::string_view target)266 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
267     absl::string_view target) {
268   return Impl().PlatformWithName(target);
269 }
270 
PlatformWithId(const Platform::Id & id)271 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
272     const Platform::Id& id) {
273   return Impl().PlatformWithId(id);
274 }
275 
PlatformWithId(const Platform::Id & id,bool initialize_platform)276 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
277     const Platform::Id& id, bool initialize_platform) {
278   return Impl().PlatformWithId(id, initialize_platform);
279 }
280 
PlatformWithName(absl::string_view target,bool initialize_platform)281 /*static*/ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
282     absl::string_view target, bool initialize_platform) {
283   return Impl().PlatformWithName(target, initialize_platform);
284 }
285 
286 /*static*/ port::StatusOr<Platform*>
InitializePlatformWithName(absl::string_view target,const std::map<std::string,std::string> & options)287 MultiPlatformManager::InitializePlatformWithName(
288     absl::string_view target,
289     const std::map<std::string, std::string>& options) {
290   return Impl().InitializePlatformWithName(target, options);
291 }
292 
293 /*static*/ port::StatusOr<Platform*>
InitializePlatformWithId(const Platform::Id & id,const std::map<std::string,std::string> & options)294 MultiPlatformManager::InitializePlatformWithId(
295     const Platform::Id& id, const std::map<std::string, std::string>& options) {
296   return Impl().InitializePlatformWithId(id, options);
297 }
298 
RegisterListener(std::unique_ptr<Listener> listener)299 /*static*/ port::Status MultiPlatformManager::RegisterListener(
300     std::unique_ptr<Listener> listener) {
301   return Impl().RegisterListener(std::move(listener));
302 }
303 
304 /*static*/ port::StatusOr<std::vector<Platform*>>
PlatformsWithFilter(const std::function<bool (const Platform *)> & filter)305 MultiPlatformManager::PlatformsWithFilter(
306     const std::function<bool(const Platform*)>& filter) {
307   return PlatformsWithFilter(filter, /*initialize_platform=*/true);
308 }
309 
310 /*static*/ port::StatusOr<std::vector<Platform*>>
PlatformsWithFilter(const std::function<bool (const Platform *)> & filter,bool initialize_platform)311 MultiPlatformManager::PlatformsWithFilter(
312     const std::function<bool(const Platform*)>& filter,
313     bool initialize_platform) {
314   return Impl().PlatformsWithFilter(filter, initialize_platform);
315 }
316 
317 }  // namespace stream_executor
318 
319 REGISTER_MODULE_INITIALIZER(
320     multi_platform_manager,
321     {
322         // Nothing -- this is just a module initializer
323         // definition to reference for sequencing
324         // purposes from Platform subclasses that register
325         // themselves with the MultiPlatformManager.
326     });
327 
328 REGISTER_MODULE_INITIALIZER(
329     multi_platform_manager_listener,
330     {
331         // Nothing -- this is just a module initializer definition to reference
332         // for sequencing registration of listeners with the
333         // MultiPlatformManager.
334     });
335 
336 // Listener registration should happen before platform registration.
337 REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
338                                      multi_platform_manager);
339