/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Interfaces for platform-dependent implementations to satisfy. These are
// delegated to from the StreamExecutor in pointer-to-implementation style;
// i.e. the StreamExecutor is just a husk that delegates calls to the
// platform-specific objects which implement the interfaces defined here.

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

#include <functional>
#include <map>
#include <memory>
#include <optional>  // for std::optional / std::nullopt used below
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/stream_executor/allocator_stats.h"
#include "tensorflow/compiler/xla/stream_executor/device_description.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/device_options.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.h"
#include "tensorflow/compiler/xla/stream_executor/event.h"
#include "tensorflow/compiler/xla/stream_executor/kernel.h"
#include "tensorflow/compiler/xla/stream_executor/kernel_cache_config.h"
#include "tensorflow/compiler/xla/stream_executor/kernel_spec.h"
#include "tensorflow/compiler/xla/stream_executor/launch_dim.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/module_spec.h"
#include "tensorflow/compiler/xla/stream_executor/platform.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
#include "tensorflow/compiler/xla/stream_executor/trace_listener.h"

namespace stream_executor {

class Stream;
class Timer;

// An opaque handle to a loaded module.
//
// An instance of this is returned from StreamExecutor::GetModule.
class ModuleHandle {
 public:
  /*implicit*/ ModuleHandle(void* id = nullptr) : id_(id) {}

  // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a
  // null pointer.
  void* id() const { return id_; }

  explicit operator bool() const { return id() != nullptr; }

 private:
  void* id_;
};

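// Example: a minimal usage sketch (hypothetical; assumes `executor` is a
// StreamExecutor and `spec` is a MultiModuleLoaderSpec supplied by the
// caller). A default-constructed handle is falsy until a module is loaded
// into it:
//
//   ModuleHandle module_handle;
//   TF_RETURN_IF_ERROR(executor->LoadModule(spec, &module_handle));
//   if (module_handle) {
//     // Valid handle: may be passed to GetSymbol or UnloadModule.
//   }
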
namespace internal {

// Platform-dependent interface class for the generic Events interface, in
// the PIMPL style.
class EventInterface {
 public:
  EventInterface() {}
  virtual ~EventInterface() {}

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
};

// Pointer-to-implementation object type (i.e. the KernelBase class delegates
// to this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any kernel data/resource info/functionality
// off of.
class KernelInterface {
 public:
  // Default constructor for the abstract interface.
  KernelInterface() {}

  // Default destructor for the abstract interface.
  virtual ~KernelInterface() {}

  // Returns the number of formal parameters that this kernel accepts.
  virtual unsigned Arity() const = 0;

  // Sets the preferred cache configuration.
  virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;

  // Gets the preferred cache configuration.
  virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
};

// Pointer-to-implementation object type (i.e. the Stream class delegates to
// this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any stream data/resource info/functionality
// off of.
class StreamInterface {
 public:
  // Default constructor for the abstract interface.
  StreamInterface() {}

  // Default destructor for the abstract interface.
  virtual ~StreamInterface() {}

  // Returns the GPU stream associated with this platform's stream
  // implementation, or nullptr otherwise.
  virtual void* GpuStreamHack() { return nullptr; }

  // Returns a pointer to a GPU stream associated with this platform's stream,
  // or a nullptr.
  virtual void** GpuStreamMemberHack() { return nullptr; }

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
};

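// Example: a minimal sketch of a platform-specific stream implementation
// (hypothetical; `MyPlatformStream` and `stream_handle_` are illustrative
// names, not part of any real backend):
//
//   class MyPlatformStream : public StreamInterface {
//    public:
//     void* GpuStreamHack() override { return stream_handle_; }
//
//    private:
//     void* stream_handle_ = nullptr;  // e.g. a CUstream or hipStream_t
//   };
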
// Pointer-to-implementation object type (i.e. the Timer class delegates to
// this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any timer data/resource info/functionality
// off of.
class TimerInterface {
 public:
  // Default constructor for the abstract interface.
  TimerInterface() {}

  // Default destructor for the abstract interface.
  virtual ~TimerInterface() {}

  // Returns the number of microseconds elapsed in a completed timer.
  virtual uint64_t Microseconds() const = 0;

  // Returns the number of nanoseconds elapsed in a completed timer.
  virtual uint64_t Nanoseconds() const = 0;

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
};

// Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL).
//
// Various platforms will provide an implementation that satisfies this
// interface.
class StreamExecutorInterface {
 public:
  // Default constructor for the abstract interface.
  StreamExecutorInterface() {}

  // Default destructor for the abstract interface.
  virtual ~StreamExecutorInterface() {}

  // Returns the (transitively) wrapped executor if this executor is
  // wrapping another executor; otherwise, returns this.
  virtual StreamExecutorInterface* GetUnderlyingExecutor() { return this; }

  // See the StreamExecutor interface for comments on the same-named methods.
  virtual port::Status Init(int device_ordinal,
                            DeviceOptions device_options) = 0;

  virtual port::Status GetKernel(const MultiKernelLoaderSpec& spec,
                                 KernelBase* kernel) {
    return port::UnimplementedError("Not Implemented");
  }
  virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
  virtual port::Status LoadModule(const MultiModuleLoaderSpec& spec,
                                  ModuleHandle* module_handle) {
    return port::UnimplementedError("Not Implemented");
  }
  virtual port::StatusOr<std::shared_ptr<DeviceMemoryBase>>
  CreateOrShareConstant(Stream* stream, const std::vector<uint8_t>& content) {
    return port::UnimplementedError("Not Implemented");
  }
  virtual port::Status Launch(Stream* stream, const ThreadDim& thread_dims,
                              const BlockDim& block_dims, const KernelBase& k,
                              const KernelArgsArrayBase& args) {
    return port::UnimplementedError("Not Implemented");
  }

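  // Example: a sketch of the typical load-then-launch sequence the
  // StreamExecutor husk drives through this interface (hypothetical; `spec`,
  // `stream`, `args`, and the launch dimensions are assumed to come from the
  // caller):
  //
  //   KernelBase kernel(executor);
  //   TF_RETURN_IF_ERROR(impl->GetKernel(spec, &kernel));
  //   TF_RETURN_IF_ERROR(
  //       impl->Launch(stream, ThreadDim(256), BlockDim(128), kernel, args));
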
  // Releases any state associated with the kernel.
  virtual void UnloadKernel(const KernelBase* kernel) {}
  virtual DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) = 0;
  DeviceMemoryBase Allocate(uint64_t size) {
    return Allocate(size, /*memory_space=*/0);
  }
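
  // Example: the convenience overload above defaults to memory space 0, i.e.
  // ordinary device memory. A sketch of allocate/deallocate pairing
  // (hypothetical caller code):
  //
  //   DeviceMemoryBase mem = impl->Allocate(/*size=*/1024);
  //   if (!mem.is_null()) {
  //     // ... use mem ...
  //     impl->Deallocate(&mem);
  //   }
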
  virtual void* GetSubBuffer(DeviceMemoryBase* parent, uint64_t offset,
                             uint64_t size) = 0;
  virtual void Deallocate(DeviceMemoryBase* mem) = 0;
  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  virtual void* UnifiedMemoryAllocate(uint64_t size) { return nullptr; }

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  virtual void UnifiedMemoryDeallocate(void* mem) {}
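
  // Example: unified memory is accessible from both host and device; the
  // default implementation above returns nullptr, which callers must treat
  // as "unsupported" (hypothetical caller code):
  //
  //   void* ptr = impl->UnifiedMemoryAllocate(bytes);
  //   if (ptr != nullptr) {
  //     // ... read/write from host or device ...
  //     impl->UnifiedMemoryDeallocate(ptr);
  //   }
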
  virtual void* HostMemoryAllocate(uint64_t size) = 0;
  virtual void HostMemoryDeallocate(void* mem) = 0;
  virtual bool HostMemoryRegister(void* mem, uint64_t size) = 0;
  virtual bool HostMemoryUnregister(void* mem) = 0;
  virtual bool SynchronizeAllActivity() = 0;
  virtual port::Status SynchronousMemZero(DeviceMemoryBase* location,
                                          uint64_t size) = 0;
  virtual port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
                                         uint64_t size) = 0;
  virtual port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
                                         const void* host_src,
                                         uint64_t size) = 0;
  virtual port::Status SynchronousMemcpy(void* host_dst,
                                         const DeviceMemoryBase& gpu_src,
                                         uint64_t size) = 0;
  virtual port::Status SynchronousMemcpyDeviceToDevice(
      DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src,
      uint64_t size) = 0;
  virtual port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
                               uint64_t size) = 0;
  virtual port::Status Memset(Stream* stream, DeviceMemoryBase* location,
                              uint8 pattern, uint64_t size) {
    return port::InternalError("Not implemented");
  }
  virtual port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
                                uint32 pattern, uint64_t size) = 0;
  virtual bool Memcpy(Stream* stream, void* host_dst,
                      const DeviceMemoryBase& gpu_src, uint64_t size) = 0;
  virtual bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst,
                      const void* host_src, uint64_t size) = 0;
  virtual bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst,
                                    const DeviceMemoryBase& gpu_src,
                                    uint64_t size) = 0;
  virtual bool HostCallback(Stream* stream, std::function<void()> callback);
  virtual bool HostCallback(Stream* stream,
                            std::function<port::Status()> callback) = 0;
  virtual port::Status AllocateEvent(Event* event) = 0;
  virtual port::Status DeallocateEvent(Event* event) = 0;
  virtual port::Status RecordEvent(Stream* stream, Event* event) = 0;
  virtual port::Status WaitForEvent(Stream* stream, Event* event) = 0;
  virtual Event::Status PollForEventStatus(Event* event) = 0;
  virtual bool AllocateStream(Stream* stream) = 0;
  virtual void DeallocateStream(Stream* stream) = 0;
  virtual bool CreateStreamDependency(Stream* dependent, Stream* other) = 0;
  virtual bool AllocateTimer(Timer* timer) = 0;
  virtual void DeallocateTimer(Timer* timer) = 0;
  virtual bool StartTimer(Stream* stream, Timer* timer) = 0;
  virtual bool StopTimer(Stream* stream, Timer* timer) = 0;
  virtual port::Status BlockHostUntilDone(Stream* stream) = 0;
  virtual port::Status GetStatus(Stream* stream) {
    return port::Status(port::error::UNIMPLEMENTED,
                        "GetStatus is not supported on this executor.");
  }
  virtual int PlatformDeviceCount() = 0;
  virtual port::Status EnablePeerAccessTo(StreamExecutorInterface* other) = 0;
  virtual bool CanEnablePeerAccessTo(StreamExecutorInterface* other) = 0;

  virtual int64_t GetDeviceLoad() { return -1; }

  virtual bool DeviceMemoryUsage(int64_t* free, int64_t* total) const {
    return false;
  }

  // Retrieves device pointer and size for a symbol. The device pointer is
  // stored at mem, and the size is stored at bytes. Either mem or bytes can be
  // null, however, both of them cannot be null at the same time. To use
  // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol
  // is found.
  //
  // If ModuleHandle is set then we search for `symbol_name` only within the
  // module corresponding to `module_handle`.  Otherwise all loaded modules are
  // searched.
  virtual bool GetSymbol(const std::string& symbol_name,
                         ModuleHandle module_handle, void** mem,
                         size_t* bytes) {
    return false;
  }

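  // Example: looking up a symbol in a previously loaded module (hypothetical;
  // `module_handle` comes from a prior LoadModule call and "my_constant" is
  // an illustrative symbol name):
  //
  //   void* ptr = nullptr;
  //   size_t bytes = 0;
  //   if (impl->GetSymbol("my_constant", module_handle, &ptr, &bytes)) {
  //     DeviceMemoryBase symbol(ptr, bytes);
  //     // ... the symbol's device memory can now be read or written ...
  //   }
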
  // Creates a new DeviceDescription object. Ownership is transferred to the
  // caller.
  virtual port::StatusOr<std::unique_ptr<DeviceDescription>>
  CreateDeviceDescription() const = 0;

  // Attempts to register the provided TraceListener with the device-specific
  // Executor implementation. When this is called, the PIMPL interface has
  // already taken ownership of the object and is managing the generic tracing
  // events. The device-specific implementation must determine if the passed
  // listener is of a type appropriate for it to trace during registration (and
  // before dispatching events to it).
  // Returns true if the listener was successfully registered, false otherwise.
  // Does not take ownership of listener.
  virtual bool RegisterTraceListener(TraceListener* listener) { return false; }

  // Unregisters the specified listener from the device-specific Executor.
  // Returns true if the listener was successfully unregistered, false
  // otherwise.
  virtual bool UnregisterTraceListener(TraceListener* listener) {
    return false;
  }

  // Returns whether this StreamExecutor has BLAS support for its underlying
  // platform.
  virtual bool SupportsBlas() const { return false; }

  // Creates a new BlasSupport object, ownership is transferred to the caller.
  // If SupportsBlas() is false, this will always return null.
  //
  // If SupportsBlas() is true, this may return null, for example, if the BLAS
  // initialization fails.
  virtual blas::BlasSupport* CreateBlas() { return nullptr; }

  // Returns whether this StreamExecutor has FFT support for its underlying
  // platform.
  virtual bool SupportsFft() const { return false; }

  // Creates a new fft::FftSupport object, ownership is transferred to the
  // caller.
  // If SupportsFft() is false, this will always return null.
  //
  // If SupportsFft() is true, this may return null, for example, if the FFT
  // initialization fails.
  virtual fft::FftSupport* CreateFft() { return nullptr; }

  // Returns whether this StreamExecutor has Random Number Generation support
  // for its underlying platform.
  virtual bool SupportsRng() const { return false; }

  // Returns whether this StreamExecutor has neural net support for its
  // underlying platform.
  virtual bool SupportsDnn() const { return false; }

  // Creates a new RngSupport object, ownership is transferred to the caller.
  // If SupportsRng() is false, this will always return null.
  //
  // If SupportsRng() is true, this may return null, for example, if the RNG
  // initialization fails.
  virtual rng::RngSupport* CreateRng() { return nullptr; }

  // Creates a new DnnSupport object, ownership is transferred to the caller.
  // If SupportsDnn() is false, this will always return null.
  //
  // If SupportsDnn() is true, this may return null, for example, if the DNN
  // initialization fails.
  virtual dnn::DnnSupport* CreateDnn() { return nullptr; }

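  // Example: the Supports*/Create* pairs form a capability check followed by
  // a nullable factory; callers must handle both gates (hypothetical caller
  // code):
  //
  //   if (impl->SupportsBlas()) {
  //     std::unique_ptr<blas::BlasSupport> blas(impl->CreateBlas());
  //     if (blas != nullptr) {
  //       // ... use the BLAS plugin ...
  //     }
  //   }
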
  // Each call creates a new instance of the platform-specific implementation
  // of the corresponding interface type.
  virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0;
  virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0;
  virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
  virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;

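  // Example: a sketch of the PIMPL wiring these factories enable; the generic
  // wrapper objects fetch their platform-specific guts from here (illustrative
  // of the pattern, not the exact StreamExecutor code):
  //
  //   std::unique_ptr<EventInterface> event_impl =
  //       executor->implementation()->CreateEventImplementation();
  //   // The generic Event object then delegates all calls to `event_impl`.
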
  // Returns the CUDA or ROCm context associated with this StreamExecutor
  // platform implementation.
  //
  // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
  // causing a fatal error if it is not. This hack is made available solely for
  // use from distbelief code, which temporarily has strong ties to CUDA or
  // ROCm as a platform.
  virtual void* GpuContextHack() { return nullptr; }

  // Return allocator statistics.
  virtual std::optional<AllocatorStats> GetAllocatorStats() {
    return std::nullopt;
  }

  // If implemented, clears the internal stats except for the `in_use` fields
  // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`. Returns
  // true if implemented.
  //
  // REQUIRES: GetAllocatorStats is overridden.
  virtual bool ClearAllocatorStats() { return false; }

  // Clears the compilation cache from volatile memory. Returns OK if no
  // compilation cache exists or if clearing the compilation cache is
  // unsupported. Caches in non-volatile storage are unaffected.
  virtual port::Status FlushCompilationCache() {
    return ::tensorflow::OkStatus();
  }

  // Returns a stream allocated by this executor, or nullptr if not found.
  // Performs linear search over alive GPU streams.
  virtual Stream* FindAllocatedStream(void* /*gpu_stream*/) { return nullptr; }

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
};

}  // namespace internal
}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_