xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/multi_platform_manager.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This is a registration-oriented interface for multiple platforms. It will
17 // replace the MachineManager singleton interface, as MachineManager does not
18 // currently support simultaneous use of multiple platforms.
19 //
20 // Usage:
21 //
22 // In your BUILD rule, add a dependency on a platform plugin that you'd like
23 // to use, such as:
24 //
25 //   //third_party/tensorflow/compiler/xla/stream_executor/cuda:cuda_platform
26 //   //third_party/tensorflow/compiler/xla/stream_executor/opencl:opencl_platform
27 //
28 // This will register platform plugins that can be discovered via this
29 // interface. Sample API usage:
30 //
31 //   port::StatusOr<Platform*> platform_status =
32 //      se::MultiPlatformManager::PlatformWithName("OpenCL");
33 //   if (!platform_status.ok()) { ... }
34 //   Platform* platform = platform_status.ValueOrDie();
35 //   LOG(INFO) << platform->VisibleDeviceCount() << " devices visible";
36 //   if (platform->VisibleDeviceCount() <= 0) { return; }
37 //
38 //   for (int i = 0; i < platform->VisibleDeviceCount(); ++i) {
39 //     port::StatusOr<StreamExecutor*> executor_status =
40 //        platform->ExecutorForDevice(i);
41 //     if (!executor_status.ok()) {
42 //       LOG(INFO) << "could not retrieve executor for device ordinal " << i
43 //                 << ": " << executor_status.status();
44 //       continue;
45 //     }
46 //     LOG(INFO) << "found usable executor: " << executor_status.ValueOrDie();
47 //   }
48 //
49 // A few things to note:
50 //  - There is no standard formatting/practice for identifying the name of a
51 //    platform. Ideally, a platform will list its registered name in its header
52 //    or in other associated documentation.
53 //  - Platform name lookup is case-insensitive. "OpenCL" or "opencl" (or even
54 //    "OpEnCl") would work correctly in the above example.
55 //
56 // And similarly, for standard interfaces (BLAS, RNG, etc.) you can add
57 // dependencies on support libraries, e.g.:
58 //
59 //    //third_party/tensorflow/compiler/xla/stream_executor/cuda:pluton_blas_plugin
60 //    //third_party/tensorflow/compiler/xla/stream_executor/cuda:cudnn_plugin
61 //    //third_party/tensorflow/compiler/xla/stream_executor/cuda:cublas_plugin
62 //    //third_party/tensorflow/compiler/xla/stream_executor/cuda:curand_plugin
63 
64 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_
65 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_
66 
67 #include <functional>
68 #include <map>
69 #include <memory>
70 #include <vector>
71 
72 #include "absl/strings/string_view.h"
73 #include "tensorflow/compiler/xla/stream_executor/lib/initialize.h"
74 #include "tensorflow/compiler/xla/stream_executor/lib/status.h"
75 #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
76 #include "tensorflow/compiler/xla/stream_executor/platform.h"
77 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
78 
79 namespace stream_executor {
80 
81 // Manages multiple platforms that may be present on the current machine.
// Every member function declared in this class is static, so callers use
// MultiPlatformManager::Foo(...) directly rather than going through an instance.
82 class MultiPlatformManager {
83  public:
84   // Registers a platform object, returns an error status if the platform is
85   // already registered. The associated listener, if not null, will be used to
86   // trace events for ALL executors for that platform.
87   // Takes ownership of platform.
  // NOTE(review): this signature takes no listener argument, so the
  // "associated listener" sentence above looks stale -- confirm against the
  // implementation before relying on it.
88   static port::Status RegisterPlatform(std::unique_ptr<Platform> platform);
89 
90   // Retrieves the platform registered with the given platform name (e.g.
91   // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the
92   // Platform's Id() method).
93   //
94   // If the platform has not already been initialized, it will be initialized
95   // with a default set of parameters.
96   //
97   // If the requested platform is not registered, an error status is returned.
98   // Ownership of the platform is NOT transferred to the caller --
99   // the MultiPlatformManager owns the platforms in a singleton-like fashion.
100   static port::StatusOr<Platform*> PlatformWithName(absl::string_view target);
101   static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id);
102 
103   // Same lookups as above, but the platform is returned without being
104   // initialized when initialize_platform == false.
105   static port::StatusOr<Platform*> PlatformWithName(absl::string_view target,
106                                                     bool initialize_platform);
107   static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id,
108                                                   bool initialize_platform);
109 
110   // Retrieves the platform registered with the given platform name (e.g.
111   // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the
112   // Platform's Id() method).
113   //
114   // The platform will be initialized with the given options. If the platform
115   // was already initialized, an error will be returned.
116   //
117   // If the requested platform is not registered, an error status is returned.
118   // Ownership of the platform is NOT transferred to the caller --
119   // the MultiPlatformManager owns the platforms in a singleton-like fashion.
120   static port::StatusOr<Platform*> InitializePlatformWithName(
121       absl::string_view target,
122       const std::map<std::string, std::string>& options);
123 
124   static port::StatusOr<Platform*> InitializePlatformWithId(
125       const Platform::Id& id,
126       const std::map<std::string, std::string>& options);
127 
128   // Retrieves the platforms for which the given filter returns true.
129   // Returned Platforms are always initialized.
130   static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
131       const std::function<bool(const Platform*)>& filter);
132 
  // Overload taking an explicit initialize_platform flag.
  // NOTE(review): presumably mirrors the PlatformWithName/PlatformWithId
  // overloads above (no initialization when the flag is false) -- confirm
  // against the implementation.
133   static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
134       const std::function<bool(const Platform*)>& filter,
135       bool initialize_platform);
136 
137   // Although the MultiPlatformManager "owns" its platforms, it holds them as
138   // undecorated pointers to prevent races during program exit (between this
139   // object's data and the underlying platforms, e.g., CUDA, OpenCL).
140   // Because certain platforms have unpredictable deinitialization
141   // times/sequences, it is not possible to structure a safe deinitialization
142   // sequence. Thus, we intentionally "leak" allocated platforms to defer
143   // cleanup to the OS. This should be acceptable, as these are one-time
144   // allocations per program invocation.
145   // The MultiPlatformManager should be considered the owner
146   // of any platforms registered with it, and leak checking should be disabled
147   // during allocation of such Platforms, to avoid spurious reporting at program
148   // exit.
149 
150   // Interface for a listener that gets notified at certain events.
151   class Listener {
152    public:
153     virtual ~Listener() = default;
154     // Callback that is invoked when a Platform is registered.
155     virtual void PlatformRegistered(Platform* platform) = 0;
156   };
157   // Registers a listener to receive notifications about certain events.
158   // Precondition: No Platform has been registered yet.
  // The listener is passed as a std::unique_ptr, i.e. ownership is handed over
  // to this call.
159   static port::Status RegisterListener(std::unique_ptr<Listener> listener);
160 };
161 
162 }  // namespace stream_executor
163 
164 // multi_platform_manager.cc will define these instances.
165 //
166 // Registering a platform:
167 //   REGISTER_MODULE_INITIALIZER_SEQUENCE(my_platform, multi_platform_manager);
168 //   REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
169 //                                        my_platform);
170 //
171 // Registering a listener:
172 //   REGISTER_MODULE_INITIALIZER_SEQUENCE(my_listener,
173 //                                        multi_platform_manager_listener);
// NOTE(review): DECLARE_MODULE_INITIALIZER is presumably provided by
// lib/initialize.h, which this header includes -- confirm.
174 DECLARE_MODULE_INITIALIZER(multi_platform_manager);
175 DECLARE_MODULE_INITIALIZER(multi_platform_manager_listener);
176 
177 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_
178