/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Interfaces for platform-dependent implementations to satisfy. These are
// delegated to from the StreamExecutor in pointer-to-implementation style;
// i.e. the StreamExecutor is just a husk that delegates calls to the
// platform-specific objects which implement the interfaces defined here.

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/stream_executor/allocator_stats.h"
#include "tensorflow/compiler/xla/stream_executor/device_description.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/device_options.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.h"
#include "tensorflow/compiler/xla/stream_executor/event.h"
#include "tensorflow/compiler/xla/stream_executor/kernel.h"
#include "tensorflow/compiler/xla/stream_executor/kernel_cache_config.h"
#include "tensorflow/compiler/xla/stream_executor/kernel_spec.h"
#include "tensorflow/compiler/xla/stream_executor/launch_dim.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/module_spec.h"
#include "tensorflow/compiler/xla/stream_executor/platform.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
#include "tensorflow/compiler/xla/stream_executor/trace_listener.h"

namespace stream_executor {

class Stream;
class Timer;

// An opaque handle to a loaded module.
//
// An instance of this is returned from StreamExecutor::GetModule.
class ModuleHandle {
 public:
  /*implicit*/ ModuleHandle(void* id = nullptr) : id_(id) {}

  // A ModuleHandle with id() == nullptr is an invalid module handle, akin to
  // a null pointer.
  void* id() const { return id_; }

  explicit operator bool() const { return id() != nullptr; }

 private:
  void* id_;
};
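
// Example (illustrative sketch, not part of this header's API surface): the
// explicit bool conversion lets validity checks read like pointer checks.
// "executor" is hypothetical and the GetModule arguments are elided:
//
//   ModuleHandle handle = executor->GetModule(...);
//   if (handle) {
//     // handle.id() is a valid, platform-defined opaque identifier.
//   }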

namespace internal {

// Platform-dependent interface class for the generic Events interface, in
// the PIMPL style.
class EventInterface {
 public:
  EventInterface() {}
  virtual ~EventInterface() {}

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
};

// Pointer-to-implementation object type (i.e. the KernelBase class delegates
// to this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any kernel data/resource info/functionality
// off of.
class KernelInterface {
 public:
  // Default constructor for the abstract interface.
  KernelInterface() {}

  // Default destructor for the abstract interface.
  virtual ~KernelInterface() {}

  // Returns the number of formal parameters that this kernel accepts.
  virtual unsigned Arity() const = 0;

  // Sets the preferred cache configuration.
  virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;

  // Gets the preferred cache configuration.
  virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
};

// Pointer-to-implementation object type (i.e. the Stream class delegates to
// this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any stream data/resource info/functionality
// off of.
class StreamInterface {
 public:
  // Default constructor for the abstract interface.
  StreamInterface() {}

  // Default destructor for the abstract interface.
  virtual ~StreamInterface() {}

  // Returns the GPU stream associated with this platform's stream
  // implementation, or nullptr otherwise.
  virtual void* GpuStreamHack() { return nullptr; }

  // Returns a pointer to a GPU stream associated with this platform's stream,
  // or a nullptr.
  virtual void** GpuStreamMemberHack() { return nullptr; }

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
};

// Pointer-to-implementation object type (i.e. the Timer class delegates to
// this interface) with virtual destruction. This class exists for the
// platform-dependent code to hang any timer data/resource info/functionality
// off of.
class TimerInterface {
 public:
  // Default constructor for the abstract interface.
  TimerInterface() {}

  // Default destructor for the abstract interface.
  virtual ~TimerInterface() {}

  // Returns the number of microseconds elapsed in a completed timer.
  virtual uint64_t Microseconds() const = 0;

  // Returns the number of nanoseconds elapsed in a completed timer.
  virtual uint64_t Nanoseconds() const = 0;

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
};
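
// Example (illustrative sketch): a platform plugs into the PIMPL interfaces
// above by subclassing them. "FakeTimer" is hypothetical, not part of any
// real platform:
//
//   class FakeTimer : public TimerInterface {
//    public:
//     uint64_t Microseconds() const override { return Nanoseconds() / 1000; }
//     uint64_t Nanoseconds() const override { return elapsed_ns_; }
//
//    private:
//     uint64_t elapsed_ns_ = 0;
//   };
//
// StreamExecutorInterface::GetTimerImplementation() (declared below) would
// hand such an object back wrapped in a std::unique_ptr.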

// Interface for the different StreamExecutor platforms (e.g. CUDA, OpenCL).
//
// Various platforms will provide an implementation that satisfies this
// interface.
class StreamExecutorInterface {
 public:
  // Default constructor for the abstract interface.
  StreamExecutorInterface() {}

  // Default destructor for the abstract interface.
  virtual ~StreamExecutorInterface() {}

  // Returns the (transitively) wrapped executor if this executor is
  // wrapping another executor; otherwise, returns this.
  virtual StreamExecutorInterface* GetUnderlyingExecutor() { return this; }

  // See the StreamExecutor interface for comments on the same-named methods.
  virtual port::Status Init(int device_ordinal,
                            DeviceOptions device_options) = 0;

  virtual port::Status GetKernel(const MultiKernelLoaderSpec& spec,
                                 KernelBase* kernel) {
    return port::UnimplementedError("Not Implemented");
  }
  virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
  virtual port::Status LoadModule(const MultiModuleLoaderSpec& spec,
                                  ModuleHandle* module_handle) {
    return port::UnimplementedError("Not Implemented");
  }
  virtual port::StatusOr<std::shared_ptr<DeviceMemoryBase>>
  CreateOrShareConstant(Stream* stream, const std::vector<uint8_t>& content) {
    return port::UnimplementedError("Not Implemented");
  }
  virtual port::Status Launch(Stream* stream, const ThreadDim& thread_dims,
                              const BlockDim& block_dims, const KernelBase& k,
                              const KernelArgsArrayBase& args) {
    return port::UnimplementedError("Not Implemented");
  }

  // Releases any state associated with the kernel.
  virtual void UnloadKernel(const KernelBase* kernel) {}
  virtual DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) = 0;
  DeviceMemoryBase Allocate(uint64_t size) {
    return Allocate(size, /*memory_space=*/0);
  }
  virtual void* GetSubBuffer(DeviceMemoryBase* parent, uint64_t offset,
                             uint64_t size) = 0;
  virtual void Deallocate(DeviceMemoryBase* mem) = 0;
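
  // Example (illustrative): Allocate returns an opaque DeviceMemoryBase that
  // must later be released via Deallocate; the single-argument overload
  // defaults to memory_space 0. "exec" is hypothetical:
  //
  //   DeviceMemoryBase buffer = exec->Allocate(/*size=*/1024);
  //   // ... use buffer ...
  //   exec->Deallocate(&buffer);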
  // Allocates unified memory space of the given size, if supported.
  // See
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
  // for more details on unified memory.
  virtual void* UnifiedMemoryAllocate(uint64_t size) { return nullptr; }

  // Deallocates unified memory space previously allocated with
  // UnifiedMemoryAllocate.
  virtual void UnifiedMemoryDeallocate(void* mem) {}
  virtual void* HostMemoryAllocate(uint64_t size) = 0;
  virtual void HostMemoryDeallocate(void* mem) = 0;
  virtual bool HostMemoryRegister(void* mem, uint64_t size) = 0;
  virtual bool HostMemoryUnregister(void* mem) = 0;
  virtual bool SynchronizeAllActivity() = 0;
  virtual port::Status SynchronousMemZero(DeviceMemoryBase* location,
                                          uint64_t size) = 0;
  virtual port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
                                         uint64_t size) = 0;
  virtual port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
                                         const void* host_src,
                                         uint64_t size) = 0;
  virtual port::Status SynchronousMemcpy(void* host_dst,
                                         const DeviceMemoryBase& gpu_src,
                                         uint64_t size) = 0;
  virtual port::Status SynchronousMemcpyDeviceToDevice(
      DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src,
      uint64_t size) = 0;
  virtual port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
                               uint64_t size) = 0;
  virtual port::Status Memset(Stream* stream, DeviceMemoryBase* location,
                              uint8 pattern, uint64_t size) {
    return port::InternalError("Not implemented");
  }
  virtual port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
                                uint32 pattern, uint64_t size) = 0;
  virtual bool Memcpy(Stream* stream, void* host_dst,
                      const DeviceMemoryBase& gpu_src, uint64_t size) = 0;
  virtual bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst,
                      const void* host_src, uint64_t size) = 0;
  virtual bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst,
                                    const DeviceMemoryBase& gpu_src,
                                    uint64_t size) = 0;
  virtual bool HostCallback(Stream* stream, std::function<void()> callback);
  virtual bool HostCallback(Stream* stream,
                            std::function<port::Status()> callback) = 0;
  virtual port::Status AllocateEvent(Event* event) = 0;
  virtual port::Status DeallocateEvent(Event* event) = 0;
  virtual port::Status RecordEvent(Stream* stream, Event* event) = 0;
  virtual port::Status WaitForEvent(Stream* stream, Event* event) = 0;
  virtual Event::Status PollForEventStatus(Event* event) = 0;
  virtual bool AllocateStream(Stream* stream) = 0;
  virtual void DeallocateStream(Stream* stream) = 0;
  virtual bool CreateStreamDependency(Stream* dependent, Stream* other) = 0;
  virtual bool AllocateTimer(Timer* timer) = 0;
  virtual void DeallocateTimer(Timer* timer) = 0;
  virtual bool StartTimer(Stream* stream, Timer* timer) = 0;
  virtual bool StopTimer(Stream* stream, Timer* timer) = 0;
  virtual port::Status BlockHostUntilDone(Stream* stream) = 0;
  virtual port::Status GetStatus(Stream* stream) {
    return port::Status(port::error::UNIMPLEMENTED,
                        "GetStatus is not supported on this executor.");
  }
  virtual int PlatformDeviceCount() = 0;
  virtual port::Status EnablePeerAccessTo(StreamExecutorInterface* other) = 0;
  virtual bool CanEnablePeerAccessTo(StreamExecutorInterface* other) = 0;

  virtual int64_t GetDeviceLoad() { return -1; }

  virtual bool DeviceMemoryUsage(int64_t* free, int64_t* total) const {
    return false;
  }
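
  // Example (illustrative): DeviceMemoryUsage is an optional query, so the
  // outputs are valid only when it returns true. "exec" is hypothetical:
  //
  //   int64_t free_bytes = 0, total_bytes = 0;
  //   if (exec->DeviceMemoryUsage(&free_bytes, &total_bytes)) {
  //     // free_bytes and total_bytes now describe the device's memory.
  //   }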

  // Retrieves the device pointer and size for the named symbol. The device
  // pointer is stored at `mem` and the size is stored at `bytes`. Either
  // `mem` or `bytes` can be null, but not both. To use constant memory in
  // CUDA, GetSymbol has to be used. Returns true if the symbol is found.
  //
  // If `module_handle` is set then we search for `symbol_name` only within
  // the module corresponding to `module_handle`. Otherwise all loaded modules
  // are searched.
  virtual bool GetSymbol(const std::string& symbol_name,
                         ModuleHandle module_handle, void** mem,
                         size_t* bytes) {
    return false;
  }

  // Creates a new DeviceDescription object. Ownership is transferred to the
  // caller.
  virtual port::StatusOr<std::unique_ptr<DeviceDescription>>
  CreateDeviceDescription() const = 0;

  // Attempts to register the provided TraceListener with the device-specific
  // Executor implementation. When this is called, the PIMPL interface has
  // already taken ownership of the object and is managing the generic tracing
  // events. The device-specific implementation must determine if the passed
  // listener is of a type appropriate for it to trace during registration
  // (and before dispatching events to it).
  // Returns true if the listener was successfully registered, false
  // otherwise. Does not take ownership of listener.
  virtual bool RegisterTraceListener(TraceListener* listener) { return false; }

  // Unregisters the specified listener from the device-specific Executor.
  // Returns true if the listener was successfully unregistered, false
  // otherwise.
  virtual bool UnregisterTraceListener(TraceListener* listener) {
    return false;
  }

  // Returns whether this StreamExecutor has BLAS support for its underlying
  // platform.
  virtual bool SupportsBlas() const { return false; }

  // Creates a new BlasSupport object, ownership is transferred to the caller.
  // If SupportsBlas() is false, this will always return null.
  //
  // If SupportsBlas() is true, this may return null, for example, if the BLAS
  // initialization fails.
  virtual blas::BlasSupport* CreateBlas() { return nullptr; }

  // Returns whether this StreamExecutor has FFT support for its underlying
  // platform.
  virtual bool SupportsFft() const { return false; }

  // Creates a new fft::FftSupport object, ownership is transferred to the
  // caller.
  // If SupportsFft() is false, this will always return null.
  //
  // If SupportsFft() is true, this may return null, for example, if the FFT
  // initialization fails.
  virtual fft::FftSupport* CreateFft() { return nullptr; }

  // Returns whether this StreamExecutor has random number generation support
  // for its underlying platform.
  virtual bool SupportsRng() const { return false; }

  // Returns whether this StreamExecutor has neural net support for its
  // underlying platform.
  virtual bool SupportsDnn() const { return false; }
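
  // Example (illustrative): plugin support is a two-step check. Supports*()
  // reports whether the platform can provide the plugin at all, and, per the
  // comments above, Create*() may still return null even when Supports*() is
  // true, so callers must null-check. "executor" is hypothetical:
  //
  //   if (executor->SupportsBlas()) {
  //     std::unique_ptr<blas::BlasSupport> blas(executor->CreateBlas());
  //     if (blas != nullptr) {
  //       // BLAS initialized successfully; safe to use.
  //     }
  //   }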

  // Creates a new RngSupport object, ownership is transferred to the caller.
  // If SupportsRng() is false, this will always return null.
  //
  // If SupportsRng() is true, this may return null, for example, if the RNG
  // initialization fails.
  virtual rng::RngSupport* CreateRng() { return nullptr; }

  // Creates a new DnnSupport object, ownership is transferred to the caller.
  // If SupportsDnn() is false, this will always return null.
  //
  // If SupportsDnn() is true, this may return null, for example, if the DNN
  // initialization fails.
  virtual dnn::DnnSupport* CreateDnn() { return nullptr; }

  // Each call creates a new instance of the platform-specific implementation
  // of the corresponding interface type.
  virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0;
  virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0;
  virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
  virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;

  // Returns the CUDA or ROCm context associated with this StreamExecutor
  // platform implementation.
  //
  // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
  // causing a fatal error if it is not. This hack is made available solely
  // for use from distbelief code, which temporarily has strong ties to CUDA
  // or ROCm as a platform.
  virtual void* GpuContextHack() { return nullptr; }

  // Returns allocator statistics.
  virtual std::optional<AllocatorStats> GetAllocatorStats() {
    return std::nullopt;
  }

  // If implemented, clears the internal stats except for the `in_use` fields
  // and sets the `peak_bytes_in_use` to be equal to the `bytes_in_use`.
  // Returns true if implemented.
  //
  // REQUIRES: GetAllocatorStats is overridden.
  virtual bool ClearAllocatorStats() { return false; }

  // Clears the compilation cache from volatile memory. Returns OK if no
  // compilation cache exists or if clearing the compilation cache is
  // unsupported. Caches in non-volatile storage are unaffected.
  virtual port::Status FlushCompilationCache() {
    return ::tensorflow::OkStatus();
  }

  // Returns a stream allocated by this executor, or nullptr if not found.
  // Performs linear search over alive GPU streams.
  virtual Stream* FindAllocatedStream(void* /*gpu_stream*/) { return nullptr; }

 private:
  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
};

}  // namespace internal
}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_