xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // CUDA userspace driver library wrapper functionality.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_GPU_DRIVER_H_
19 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_GPU_DRIVER_H_
20 
21 #include <stddef.h>
22 
23 #include "tensorflow/compiler/xla/stream_executor/device_options.h"
24 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_types.h"
25 #include "tensorflow/compiler/xla/stream_executor/lib/status.h"
26 #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
27 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
28 
29 namespace stream_executor {
30 namespace gpu {
31 
// Identifies the memory space where an allocation resides. See
// GpuDriver::GetPointerMemorySpace().
enum class MemorySpace { kHost, kDevice };

// Returns a short human-readable string, such as "host" or "device", for the
// provided memory space.
std::string MemorySpaceString(MemorySpace memory_space);

// Opaque wrapper for a platform context handle (CUcontext / hipCtx_t);
// defined in the platform-specific implementation.
class GpuContext;
40 
// GpuDriver contains wrappers for calls to the userspace library driver. It's
// useful to isolate these calls and put basic wrappers around them to separate
// userspace library driver behaviors from the rest of the program.
//
// At the moment it's simply used as a namespace.
//
// The calls log any specific errors internally and return whether the operation
// was successful to the caller.
//
// The order of parameters is generally kept symmetric with the underlying CUDA
// driver API.
//
// Links on functions are to specific documentation under
// http://docs.nvidia.com/cuda/cuda-driver-api/
//
// Thread safety: these functions should not be used from signal handlers.
class GpuDriver {
 public:
  // Wraps a call to cuInit with logging to help indicate what has gone wrong in
  // the case of failure. Safe to call multiple times; will be fast on all calls
  // after the first.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__INITIALIZE.html#group__CUDA__INITIALIZE_1g0a2f1517e1bd8502c7194c3a8c134bc3
  static port::Status Init();

  // Returns the device associated with the given context.
  // device is an outparam owned by the caller, must not be null.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g4e84b109eba36cdaaade167f34ae881e
  static port::StatusOr<GpuDeviceHandle> DeviceFromContext(GpuContext* context);

  // Creates a new CUDA stream associated with the given context via
  // cuStreamCreate.
  // stream is an outparam owned by the caller, must not be null.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1ga581f0c5833e21ded8b5a56594e243f4
  static bool CreateStream(GpuContext* context, GpuStreamHandle* stream,
                           int priority = 0);

  // Destroys a CUDA stream associated with the given context.
  // stream is owned by the caller, must not be null, and *stream is set to null
  // if the stream is successfully destroyed.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g244c8833de4596bcd31a06cdf21ee758
  static void DestroyStream(GpuContext* context, GpuStreamHandle* stream);

  // CUDA events can explicitly disable event TSC retrieval for some presumed
  // performance improvement if timing is unnecessary.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db
  enum class EventFlags { kDefault, kDisableTiming };

  // Creates a new event associated with the given context.
  // result is an outparam owned by the caller and must not be null.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db
  static port::Status InitEvent(GpuContext* context, GpuEventHandle* result,
                                EventFlags flags);

  // Destroys *event and turns it into a nullptr. event may not be null, but
  // *event may be, via cuEventDestroy.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g593ec73a8ec5a5fc031311d3e4dca1ef
  static port::Status DestroyEvent(GpuContext* context, GpuEventHandle* event);

  // Allocates a GPU memory space of size bytes associated with the given
  // context via cuMemAlloc.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb82d2a09844a58dd9e744dc31e8aa467
  static void* DeviceAllocate(GpuContext* context, uint64_t bytes);

  // Deallocates a GPU memory space of size bytes associated with the given
  // context via cuMemFree.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a
  static void DeviceDeallocate(GpuContext* context, void* location);

  // Allocates a unified memory space of size bytes associated with the given
  // context via cuMemAllocManaged.
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb347ded34dc326af404aa02af5388a32
  // (supported on CUDA only)
  static void* UnifiedMemoryAllocate(GpuContext* context, uint64_t bytes);

  // Deallocates a unified memory space of size bytes associated with the given
  // context via cuMemFree.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a
  // (supported on CUDA only)
  static void UnifiedMemoryDeallocate(GpuContext* context, void* location);

  // Allocates page-locked and CUDA-registered memory on the host via
  // cuMemAllocHost.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gdd8311286d2c2691605362c689bc64e0
  static void* HostAllocate(GpuContext* context, uint64_t bytes);

  // Deallocates a location created by HostAllocate, via cuMemFreeHost.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g62e0fdbe181dab6b1c90fa1a51c7b92c
  static void HostDeallocate(GpuContext* context, void* location);

  // Registers a memory region at location of size bytes via cuMemHostRegister.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gf0a9fe11544326dabd743b7aa6b54223
  static bool HostRegister(GpuContext* context, void* location, uint64_t bytes);

  // Unregisters a memory region that was previously registered at location via
  // cuMemHostUnregister.
  //
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g63f450c8125359be87b7623b1c0b2a14
  //
  // TODO(leary) verify an error will be returned if the location wasn't
  // previously registered.
  static bool HostUnregister(GpuContext* context, void* location);

  // Virtual memory support was added to CUDA in 10.2
#if CUDA_VERSION >= 10020

  // Reserves a range of virtual device memory addresses via
  // cuMemAddressReserve. bytes must be a multiple of the host page size.
  // Returns nullptr base address in VmemSpan if the reservation fails.
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1ge489256c107df2a07ddf96d80c86cd9b
  struct VmemSpan {
    GpuDevicePtr base;
    // Size in bytes.
    uint64_t size_bytes;
  };
  static port::StatusOr<VmemSpan> ReserveVirtualMemory(GpuContext* context,
                                                       uint64_t bytes);

  // Frees a range of virtual addresses that were previously reserved through
  // ReserveVirtualMemory via cuMemAddressFree.
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g6993ecea2ea03e1b802b8255edc2da5b
  static void FreeVirtualMemory(GpuContext* context, VmemSpan reservation);

  // Calculates the minimum alignment for memory allocations done through
  // cuMemCreate via cuMemGetAllocationGranularity.
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g30ee906c2cf66a0347b3dfec3d7eb31a
  static port::StatusOr<uint64_t> GetMinAllocationGranularity(
      GpuDeviceHandle device);

  // Allocates physical memory and returns a handle that can be mapped to
  // virtual addresses via cuMemCreate. bytes must be a multiple of the
  // granularity returned by GetMinAllocationGranularity.
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g899d69a862bba36449789c64b430dc7c
  struct GenericMemoryHandle {
    uint64_t handle;
    uint64_t bytes;
  };
  static port::StatusOr<GenericMemoryHandle> CreateMemoryHandle(
      GpuContext* context, uint64_t bytes);

  // Frees memory represented by the provided MemoryHandle via cuMemRelease.
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g3014f0759f43a8d82db951b8e4b91d68
  static void ReleaseMemoryHandle(GpuContext* context,
                                  GenericMemoryHandle handle);

  // Maps a memory allocation handle to a reserved virtual address range via
  // cuMemMap and sets the appropriate access settings via cuMemSetAccess.
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1gff1d395423af5c5c75375516959dae56
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g1b6b12b10e8324bf462ecab4e7ef30e1
  static port::Status MapMemory(
      GpuContext* context, GpuDevicePtr va, const GenericMemoryHandle& handle,
      const std::vector<GpuDeviceHandle>& device_handles);

  // Unmaps the backing memory from the given virtual address range. This range
  // must fully unmap a memory handle that was mapped using MapMemory; partial
  // unmapping is not supported.
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1gfb50aac00c848fd7087e858f59bf7e2a
  static void UnmapMemory(GpuContext* context, GpuDevicePtr va, uint64_t bytes);

#endif  // CUDA_VERSION >= 10020

  // Given a device ordinal, returns a device handle into the device outparam,
  // which must not be null.
  //
  // N.B. these device handles do not have a corresponding destroy function in
  // the CUDA driver API.
  static port::Status GetDevice(int device_ordinal, GpuDeviceHandle* device);

  // Given a device handle, returns the name reported by the driver for the
  // device.
  static port::Status GetDeviceName(GpuDeviceHandle device,
                                    std::string* device_name);

  // Given a device to create a context for, returns a context handle into the
  // context outparam, which must not be null.
  //
  // N.B. CUDA contexts are weird. They are implicitly associated with the
  // calling thread. Current documentation on contexts and their influence on
  // userspace processes is given here:
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g65dc0012348bc84810e2103a40d8e2cf
  static port::Status CreateContext(int device_ordinal, GpuDeviceHandle device,
                                    const DeviceOptions& device_options,
                                    GpuContext** context);

  // Destroys the provided context via cuCtxDestroy.
  // Don't do this while clients could still be using the context, per the docs
  // bad things will happen.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g27a365aebb0eb548166309f58a1e8b8e
  static void DestroyContext(GpuContext* context);

  // Returns the context handle (CUcontext for CUDA and hipCtx_t for ROCm) of a
  // GpuContext.
  static GpuContextHandle GetContextHandle(GpuContext* context);

  // Queries the runtime for the specified attribute of the specified function.
  // cuFuncGetAttribute (the underlying CUDA driver API routine) only operates
  // in terms of integer-sized values, so there's no potential for overrun (as
  // of CUDA 5.5).
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b
  static port::Status FuncGetAttribute(GpuFunctionAttribute attribute,
                                       GpuFunctionHandle function,
                                       int* attribute_value);

  // Sets the preferred cache configuration for the specified function.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g40f8c11e81def95dc0072a375f965681
  static port::Status FuncSetCacheConfig(GpuFunctionHandle function,
                                         GpuFuncCachePreference cache_config);

  // Gets the preferred shared memory bank configuration for the specified
  // CONTEXT (not function!), either default or four- or eight-byte bank size.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g17153a1b8b8c756f7ab8505686a4ad74
  static port::StatusOr<GpuSharedMemConfig> ContextGetSharedMemConfig(
      GpuContext* context);

  // Sets the preferred shared memory bank configuration for the specified
  // CONTEXT (not function!), either default or four- or eight-byte bank size.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g2574235fa643f8f251bf7bc28fac3692
  static port::Status ContextSetSharedMemConfig(
      GpuContext* context, GpuSharedMemConfig shared_mem_config);

  // Launches a CUDA kernel via cuLaunchKernel.
  // TODO(leary) describe the structure of kernel_params and extra in a readable
  // way.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
  static port::Status LaunchKernel(
      GpuContext* context, absl::string_view kernel_name,
      GpuFunctionHandle function, unsigned int grid_dim_x,
      unsigned int grid_dim_y, unsigned int grid_dim_z,
      unsigned int block_dim_x, unsigned int block_dim_y,
      unsigned int block_dim_z, unsigned int shared_mem_bytes,
      GpuStreamHandle stream, void** kernel_params, void** extra);

  // Loads ptx_contents with the CUDA driver's PTX JIT and stores the resulting
  // handle in "module". Any error logs that are produced are logged internally.
  // (supported on CUDA only)
  static port::Status LoadPtx(GpuContext* context, const char* ptx_contents,
                              GpuModuleHandle* module);

  // Loads cubin_bytes with the CUDA driver's blob loading interface and stores
  // the resulting handle in "module".
  // (supported on CUDA only)
  static port::Status LoadCubin(GpuContext* context, const char* cubin_bytes,
                                GpuModuleHandle* module);

  // Loads HSACO with the ROCM runtime and stores the resulting handle in
  // "module". Any error logs that are produced are logged internally.
  // (supported on ROCm only)
  static port::Status LoadHsaco(GpuContext* context, const char* hsaco_contents,
                                GpuModuleHandle* module);

  // Retrieves a named kernel from a loaded module, and places the resulting
  // handle into function (outparam) on success. Neither kernel_name nor
  // function may be null. No ownership is taken of kernel_name.
  static bool GetModuleFunction(GpuContext* context, GpuModuleHandle module,
                                const char* kernel_name,
                                GpuFunctionHandle* function);

  // Retrieves a named global/constant symbol from a loaded module, and returns
  // a device pointer and size of the symbol on success. symbol_name may not be
  // null. At least one of dptr or bytes should not be null. No ownership is
  // taken of symbol_name.
  static bool GetModuleSymbol(GpuContext* context, GpuModuleHandle module,
                              const char* symbol_name, GpuDevicePtr* dptr,
                              size_t* bytes);

  // Unloads module from the current context via cuModuleUnload.
  // TODO(leary) the documentation doesn't say what kind of disasters happen
  // if you try to unload a module while its GpuFunctionHandles are in use.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MODULE.html#group__CUDA__MODULE_1g8ea3d716524369de3763104ced4ea57b
  static void UnloadModule(GpuContext* context, GpuModuleHandle module);

  // Performs a synchronous memset of the device memory segment via cuMemsetD8.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6e582bf866e9e2fb014297bfaf354d7b
  static port::Status SynchronousMemsetUint8(GpuContext* context,
                                             GpuDevicePtr location, uint8 value,
                                             size_t size);

  // Performs a synchronous memset of the device memory segment via cuMemsetD32.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g983e8d8759acd1b64326317481fbf132
  static port::Status SynchronousMemsetUint32(GpuContext* context,
                                              GpuDevicePtr location,
                                              uint32 value,
                                              size_t uint32_count);

  // Performs an asynchronous memset of the device memory segment via
  // cuMemsetD8Async.
  // NOTE(review): the count parameter is named uint32_count but this is the
  // 8-bit variant — presumably it counts uint8 elements; the name looks
  // copy-pasted from the 32-bit variant. Confirm against the .cc
  // implementation before relying on units.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gaef08a7ccd61112f94e82f2b30d43627
  static port::Status AsynchronousMemsetUint8(GpuContext* context,
                                              GpuDevicePtr location,
                                              uint8 value, size_t uint32_count,
                                              GpuStreamHandle stream);

  // Performs an asynchronous memset of the device memory segment via
  // cuMemsetD32Async.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g58229da5d30f1c0cdf667b320ec2c0f5
  static port::Status AsynchronousMemsetUint32(GpuContext* context,
                                               GpuDevicePtr location,
                                               uint32 value,
                                               size_t uint32_count,
                                               GpuStreamHandle stream);

  // -- Synchronous memcopies.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g4d32266788c440b0220b1a9ba5795169

  static port::Status SynchronousMemcpyD2H(GpuContext* context, void* host_dst,
                                           GpuDevicePtr gpu_src, uint64_t size);
  static port::Status SynchronousMemcpyH2D(GpuContext* context,
                                           GpuDevicePtr gpu_dst,
                                           const void* host_src, uint64_t size);
  static port::Status SynchronousMemcpyD2D(GpuContext* context,
                                           GpuDevicePtr gpu_dst,
                                           GpuDevicePtr gpu_src, uint64_t size);

  // -- Asynchronous memcopies.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g56f30236c7c5247f8e061b59d3268362

  static bool AsynchronousMemcpyD2H(GpuContext* context, void* host_dst,
                                    GpuDevicePtr gpu_src, uint64_t size,
                                    GpuStreamHandle stream);
  static bool AsynchronousMemcpyH2D(GpuContext* context, GpuDevicePtr gpu_dst,
                                    const void* host_src, uint64_t size,
                                    GpuStreamHandle stream);
  static bool AsynchronousMemcpyD2D(GpuContext* context, GpuDevicePtr gpu_dst,
                                    GpuDevicePtr gpu_src, uint64_t size,
                                    GpuStreamHandle stream);

  // The CUDA stream callback type signature.
  // The data passed to AddStreamCallback is subsequently passed to this
  // callback when it fires.
  //
  // Some notable things:
  // * Callbacks must not make any CUDA API calls.
  // * Callbacks from independent streams execute in an undefined order and may
  //   be serialized.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g613d97a277d7640f4cb1c03bd51c2483
  typedef void (*StreamCallback)(GpuStreamHandle stream, GpuStatus status,
                                 void* data);

  // Enqueues a callback operation into stream.
  // See StreamCallback above and the NVIDIA documentation for additional
  // details.
  static bool AddStreamCallback(GpuContext* context, GpuStreamHandle stream,
                                StreamCallback callback, void* data);

  // Causes stream to wait for event to trigger before proceeding via
  // cuStreamWaitEvent.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#axzz334nAXAhM
  static bool WaitStreamOnEvent(GpuContext* context, GpuStreamHandle stream,
                                GpuEventHandle event);

  // Blocks the calling thread until the operations enqueued onto stream have
  // been completed, via cuStreamSynchronize.
  //
  // TODO(leary) if a pathological thread enqueues operations onto the stream
  // while another thread blocks like this, can you wind up waiting an unbounded
  // amount of time?
  //
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g15e49dd91ec15991eb7c0a741beb7dad
  static port::Status SynchronizeStream(GpuContext* context,
                                        GpuStreamHandle stream);

  // Blocks the calling thread until the operations associated with the context
  // have been completed, via cuCtxSynchronize.
  //
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g7a54725f28d34b8c6299f0c6ca579616
  static bool SynchronizeContext(GpuContext* context);

  // Returns true if all stream tasks have completed at time of the call. Note
  // the potential for races around this call (if another thread adds work to
  // the stream immediately after this returns).
  static bool IsStreamIdle(GpuContext* context, GpuStreamHandle stream);

  // Returns whether code in the from context can access memory in the to
  // context via cuDeviceCanAccessPeer.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g496bdaae1f632ebfb695b99d2c40f19e
  static bool CanEnablePeerAccess(GpuContext* from, GpuContext* to);

  // Returns whether the from device can access memory in the to
  // device via cuDeviceCanAccessPeer. Because of differences between ROCM and
  // CUDA, this API is not supported in ROCM builds and will result in a link
  // error if used.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g496bdaae1f632ebfb695b99d2c40f19e
  static bool CanEnablePeerAccess(GpuDeviceHandle from, GpuDeviceHandle to);

  // Enables peer access per CanEnablePeerAccess, via cuCtxEnablePeerAccess.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g0889ec6728e61c05ed359551d67b3f5a
  static port::Status EnablePeerAccess(GpuContext* from, GpuContext* to);

  // Returns the elapsed milliseconds between start and stop via
  // cuEventElapsedTime.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1gdfb1178807353bbcaa9e245da497cf97
  static bool GetEventElapsedTime(GpuContext* context,
                                  float* elapsed_milliseconds,
                                  GpuEventHandle start, GpuEventHandle stop);

  // Records that an event occurred when execution reaches the current point in
  // the stream via cuEventRecord.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g95424d3be52c4eb95d83861b70fb89d1
  static port::Status RecordEvent(GpuContext* context, GpuEventHandle event,
                                  GpuStreamHandle stream);

  // Polls (without blocking) to determine the status of an event - pending or
  // complete (or an error status).
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g6f0704d755066b0ee705749ae911deef
  static port::StatusOr<GpuStatus> QueryEvent(GpuContext* context,
                                              GpuEventHandle event);

  // -- Pointer-specific calls.

  // Returns the context in which pointer was allocated or registered.
  static port::StatusOr<GpuContext*> GetPointerContext(GpuDevicePtr pointer);

  // Returns the device associated with the context from GetPointerContext().
  static port::StatusOr<GpuDeviceHandle> GetPointerDevice(GpuDevicePtr pointer);

  // Returns the memory space addressed by pointer.
  static port::StatusOr<MemorySpace> GetPointerMemorySpace(
      GpuDevicePtr pointer);

  // Returns the base address and size of the device pointer dptr.
  static port::Status GetPointerAddressRange(GpuDevicePtr dptr,
                                             GpuDevicePtr* base, size_t* size);

  // -- Device-specific calls.

  // Returns the compute capability for the device; i.e (3, 5).
  // This is currently done via the deprecated device API.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html#group__CUDA__DEVICE__DEPRECATED_1ge2091bbac7e1fb18c2821612115607ea
  // (supported on CUDA only)
  static port::Status GetComputeCapability(int* cc_major, int* cc_minor,
                                           GpuDeviceHandle device);

  // Returns Gpu ISA version for the device; i.e 803, 900.
  // (supported on ROCm only)
  static port::Status GetGpuISAVersion(int* version, GpuDeviceHandle device);

  // Return the full GCN Architecture Name for the device
  // for eg: amdgcn-amd-amdhsa--gfx908:sramecc+:xnack-
  // (supported on ROCm only)
  static port::Status GetGpuGCNArchName(GpuDeviceHandle device,
                                        std::string* gcnArchName);

#if TENSORFLOW_USE_ROCM
  // tests the current device for MFMA insn support (ROCm only)
  static port::StatusOr<bool> GetMFMASupport();
#endif

  // Returns the number of multiprocessors on the device (note that the device
  // may be multi-GPU-per-board).
  static port::StatusOr<int> GetMultiprocessorCount(GpuDeviceHandle device);

  // Returns the limit on number of threads that can be resident in a single
  // multiprocessor.
  static port::StatusOr<int64_t> GetMaxThreadsPerMultiprocessor(
      GpuDeviceHandle device);

  // Returns the limit on number of threads which may be resident for a single
  // block (cooperative thread array).
  static port::StatusOr<int64_t> GetMaxThreadsPerBlock(GpuDeviceHandle device);

  // Returns the amount of shared memory available on a single GPU core (i.e.
  // SM on NVIDIA devices).
  static port::StatusOr<int64_t> GetMaxSharedMemoryPerCore(
      GpuDeviceHandle device);

  // Returns the amount of shared memory available for a single block
  // (cooperative thread array).
  static port::StatusOr<int64_t> GetMaxSharedMemoryPerBlock(
      GpuDeviceHandle device);

  // Returns the maximum supported number of registers per block.
  static port::StatusOr<int64_t> GetMaxRegistersPerBlock(
      GpuDeviceHandle device);

  // Returns the number of threads per warp.
  static port::StatusOr<int64_t> GetThreadsPerWarp(GpuDeviceHandle device);

  // Queries the grid limits for device with cuDeviceGetAttribute calls.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266
  static bool GetGridLimits(int* x, int* y, int* z, GpuDeviceHandle device);

  // Returns a grab-bag of device properties in a caller-owned device_properties
  // structure for device_ordinal via cuDeviceGetProperties.
  //
  // This call is deprecated in the NVIDIA driver API; its replacement is
  // GetDeviceAttribute
  //
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html#group__CUDA__DEVICE__DEPRECATED_1g65a5b4e25186bd257df80b98c98cffe6
  static bool GetDeviceProperties(GpuDeviceProperty* device_properties,
                                  int device_ordinal);

  // Gets a specific integer-valued property about the given device.
  //
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266
  static port::StatusOr<int> GetDeviceAttribute(GpuDeviceAttribute attribute,
                                                GpuDeviceHandle device);

  // Returns whether ECC is enabled for the given GpuDeviceHandle via
  // cuDeviceGetAttribute with CU_DEVICE_ATTRIBUTE_ECC_ENABLED.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266
  static bool IsEccEnabled(GpuDeviceHandle device, bool* result);

  // Returns the total amount of memory available for allocation by the CUDA
  // context, in bytes, via cuDeviceTotalMem.
  static bool GetDeviceTotalMemory(GpuDeviceHandle device, uint64_t* result);

  // Returns the free amount of memory and total amount of memory, as reported
  // by cuMemGetInfo.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g808f555540d0143a331cc42aa98835c0
  static bool GetDeviceMemoryInfo(GpuContext* context, int64_t* free,
                                  int64_t* total);

  // Returns a PCI bus id string for the device.
  // [domain]:[bus]:[device].[function]
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g85295e7d9745ab8f0aa80dd1e172acfc
  static std::string GetPCIBusID(GpuDeviceHandle device);

  // -- Context- and device-independent calls.

  // Returns the number of visible CUDA device via cuDeviceGetCount.
  // This should correspond to the set of device ordinals available.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g52b5ce05cb8c5fb6831b2c0ff2887c74
  static int GetDeviceCount();

  // Returns the driver version number via cuDriverGetVersion.
  // This is, surprisingly, NOT the actual driver version (e.g. 331.79) but,
  // instead, the CUDA toolkit release number that this driver is compatible
  // with; e.g. 6000 (for a CUDA 6.0 compatible driver) or 6050 (for a CUDA 6.5
  // compatible driver).
  //
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VERSION.html#group__CUDA__VERSION_1g8b7a10395392e049006e61bcdc8ebe71
  static bool GetDriverVersion(int* driver_version);

  // -- Other calls

  // Returns the maximum number of blocks (per multiprocessor) occupied by the
  // specified kernel/GpuFunctionHandle when launched with the specified
  // parameters.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__OCCUPANCY.html#group__CUDA__OCCUPANCY_1gcc6e1094d05cba2cee17fe33ddd04a98
  static port::StatusOr<int> GetMaxOccupiedBlocksPerCore(
      GpuContext* context, GpuFunctionHandle kernel, int threads_per_block,
      size_t dynamic_shared_memory_bytes);

  // Seam for injecting an error at CUDA initialization time for testing
  // purposes.
  static bool driver_inject_init_error_;
};
587 
// Ensures a context is activated within a scope (RAII-style activation; the
// constructor activates, the destructor runs the matching scope-exit check).
class ScopedActivateContext {
 public:
  // Activates the context via cuCtxSetCurrent, if it is not the currently
  // active context (a la cuCtxGetCurrent). Note the alternative push/pop
  // mechanism is said by NVIDIA to be relatively slow and deprecated.
  // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1gbe562ee6258b4fcc272ca6478ca2a2f7
  explicit ScopedActivateContext(GpuContext* context);

  // Checks that the context has remained activated for the duration of the
  // scope.
  ~ScopedActivateContext();

 private:
  // NOTE(review): presumably the context that was current before activation,
  // restored by the destructor when non-null — confirm in the .cc
  // implementation.
  GpuContext* to_restore_ = nullptr;
};
604 
605 }  // namespace gpu
606 }  // namespace stream_executor
607 
608 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_GPU_DRIVER_H_
609