xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/test/gpu_info_lib.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library for getting system information during TensorFlow tests."""
16
17import ctypes as ct
18import platform
19
20
21from tensorflow.core.util import test_log_pb2
22from tensorflow.python.framework import errors
23from tensorflow.python.platform import gfile
24
25
def _gather_gpu_devices_proc():
  """Try to gather NVidia GPU device information via /proc/driver.

  Reads every ``/proc/driver/nvidia/gpus/<bus id>/information`` file. Each
  line is split at its first ``:`` into a key and a value (tabs removed,
  surrounding spaces stripped, key lower-cased). Lines without a ``:`` --
  e.g. blank lines -- are skipped rather than aborting the whole parse.

  Returns:
    A list of test_log_pb2.GPUInfo messages, one per information file found
    (empty if nothing matches). ``model`` and ``uuid`` default to "Unknown"
    when the corresponding key is absent.
  """
  dev_info = []
  for f in gfile.Glob("/proc/driver/nvidia/gpus/*/information"):
    # "/proc/driver/nvidia/gpus/<bus id>/information".split("/") puts the
    # bus id at index 5 (index 0 is the empty string before the leading "/").
    bus_id = f.split("/")[5]
    key_values = {}
    # Use a context manager so the file handle is always released.
    with gfile.GFile(f, "r") as fh:
      for line in fh:
        key, sep, value = line.rstrip().replace("\t", "").partition(":")
        if not sep:
          # Malformed/blank line: ignore it instead of failing the parse.
          continue
        key_values[key.lower()] = value.strip(" ")
    info = test_log_pb2.GPUInfo()
    info.model = key_values.get("model", "Unknown")
    info.uuid = key_values.get("gpu uuid", "Unknown")
    info.bus_id = bus_id
    dev_info.append(info)
  return dev_info
41
42
class CUDADeviceProperties(ct.Structure):
  """ctypes mirror of the CUDA runtime's ``cudaDeviceProp`` struct.

  Instances are passed by reference to ``cudaGetDeviceProperties``, which
  fills them in. Field order and types define the C memory layout, so
  entries must not be reordered or retyped. Within this module only the
  ``name`` field (the first field, at offset 0) is read.
  """
  # See $CUDA_HOME/include/cuda_runtime_api.h for the definition of
  # the cudaDeviceProp struct.
  # NOTE(review): in recent CUDA toolkits cudaDeviceProp appears to be
  # declared in driver_types.h, and newer versions seem to insert fields
  # (e.g. uuid/luid) directly after `name`. If so, every field after `name`
  # is misaligned relative to the runtime's layout and only `name` can be
  # trusted -- confirm against the installed headers before reading others.
  _fields_ = [
      ("name", ct.c_char * 256),
      ("totalGlobalMem", ct.c_size_t),
      ("sharedMemPerBlock", ct.c_size_t),
      ("regsPerBlock", ct.c_int),
      ("warpSize", ct.c_int),
      ("memPitch", ct.c_size_t),
      ("maxThreadsPerBlock", ct.c_int),
      ("maxThreadsDim", ct.c_int * 3),
      ("maxGridSize", ct.c_int * 3),
      ("clockRate", ct.c_int),
      ("totalConstMem", ct.c_size_t),
      ("major", ct.c_int),
      ("minor", ct.c_int),
      ("textureAlignment", ct.c_size_t),
      ("texturePitchAlignment", ct.c_size_t),
      ("deviceOverlap", ct.c_int),
      ("multiProcessorCount", ct.c_int),
      ("kernelExecTimeoutEnabled", ct.c_int),
      ("integrated", ct.c_int),
      ("canMapHostMemory", ct.c_int),
      ("computeMode", ct.c_int),
      ("maxTexture1D", ct.c_int),
      ("maxTexture1DMipmap", ct.c_int),
      ("maxTexture1DLinear", ct.c_int),
      ("maxTexture2D", ct.c_int * 2),
      ("maxTexture2DMipmap", ct.c_int * 2),
      ("maxTexture2DLinear", ct.c_int * 3),
      ("maxTexture2DGather", ct.c_int * 2),
      ("maxTexture3D", ct.c_int * 3),
      ("maxTexture3DAlt", ct.c_int * 3),
      ("maxTextureCubemap", ct.c_int),
      ("maxTexture1DLayered", ct.c_int * 2),
      ("maxTexture2DLayered", ct.c_int * 3),
      ("maxTextureCubemapLayered", ct.c_int * 2),
      ("maxSurface1D", ct.c_int),
      ("maxSurface2D", ct.c_int * 2),
      ("maxSurface3D", ct.c_int * 3),
      ("maxSurface1DLayered", ct.c_int * 2),
      ("maxSurface2DLayered", ct.c_int * 3),
      ("maxSurfaceCubemap", ct.c_int),
      ("maxSurfaceCubemapLayered", ct.c_int * 2),
      ("surfaceAlignment", ct.c_size_t),
      ("concurrentKernels", ct.c_int),
      ("ECCEnabled", ct.c_int),
      ("pciBusID", ct.c_int),
      ("pciDeviceID", ct.c_int),
      ("pciDomainID", ct.c_int),
      ("tccDriver", ct.c_int),
      ("asyncEngineCount", ct.c_int),
      ("unifiedAddressing", ct.c_int),
      ("memoryClockRate", ct.c_int),
      ("memoryBusWidth", ct.c_int),
      ("l2CacheSize", ct.c_int),
      ("maxThreadsPerMultiProcessor", ct.c_int),
      ("streamPrioritiesSupported", ct.c_int),
      ("globalL1CacheSupported", ct.c_int),
      ("localL1CacheSupported", ct.c_int),
      ("sharedMemPerMultiprocessor", ct.c_size_t),
      ("regsPerMultiprocessor", ct.c_int),
      ("managedMemSupported", ct.c_int),
      ("isMultiGpuBoard", ct.c_int),
      ("multiGpuBoardGroupID", ct.c_int),
      # Pad with extra space to avoid dereference crashes if future
      # versions of CUDA extend the size of this struct.
      # (The runtime writes sizeof(cudaDeviceProp) bytes into this object;
      # the 4096-byte tail keeps a larger struct from overrunning it.)
      ("__future_buffer", ct.c_char * 4096)
  ]
113
114
def _load_libcudart():
  """Load the CUDA runtime library for the current platform.

  Returns:
    A ctypes handle to libcudart.

  Raises:
    NotImplementedError: if the platform is not Linux, Darwin or Windows.
    OSError: if the library cannot be loaded.
  """
  system = platform.system()
  if system == "Linux":
    return ct.cdll.LoadLibrary("libcudart.so")
  if system == "Darwin":
    return ct.cdll.LoadLibrary("libcudart.dylib")
  if system == "Windows":
    # NOTE(review): Windows CUDA toolkits usually ship the runtime as a
    # versioned cudart64_*.dll, so this name may not resolve -- confirm.
    return ct.windll.LoadLibrary("libcudart.dll")
  raise NotImplementedError("Cannot identify system.")


def _gather_gpu_devices_cudart():
  """Try to gather NVidia GPU device information via libcudart.

  Returns:
    A list of test_log_pb2.GPUInfo messages, one per CUDA device, with
    ``model`` and ``bus_id`` set. The runtime calls used here do not report
    a UUID, so ``uuid`` is left unset.

  Raises:
    OSError: if libcudart cannot be loaded.
    NotImplementedError: on an unrecognized platform or CUDA < 6.5.
    ValueError: if a CUDA runtime call returns a non-zero error code.
  """
  libcudart = _load_libcudart()

  version = ct.c_int()
  rc = libcudart.cudaRuntimeGetVersion(ct.byref(version))
  if rc != 0:
    raise ValueError("Could not get version")
  # The runtime encodes versions as 1000 * major + 10 * minor (6.5 -> 6050).
  if version.value < 6050:
    raise NotImplementedError("CUDA version must be >= 6.5")

  device_count = ct.c_int()
  rc = libcudart.cudaGetDeviceCount(ct.byref(device_count))
  if rc != 0:
    raise ValueError("Could not get device count")

  dev_info = []
  for i in range(device_count.value):
    properties = CUDADeviceProperties()
    rc = libcudart.cudaGetDeviceProperties(ct.byref(properties), i)
    if rc != 0:
      raise ValueError("Could not get device properties")

    # cudaDeviceGetPCIBusId writes a NUL-terminated string into the caller's
    # buffer, so it needs genuinely mutable memory. Wrapping a Python str in
    # c_char_p raises TypeError on Python 3, and wrapping bytes would let C
    # scribble over an immutable object.
    pci_bus_id = ct.create_string_buffer(13)
    rc = libcudart.cudaDeviceGetPCIBusId(pci_bus_id, len(pci_bus_id), i)
    if rc != 0:
      raise ValueError("Could not get device PCI bus id")

    info = test_log_pb2.GPUInfo()  # No UUID available
    # c_char arrays/buffers yield bytes; decode so these fields hold str
    # values like the ones the /proc code path stores.
    info.model = properties.name.decode("utf-8", "replace")
    info.bus_id = pci_bus_id.value.decode("utf-8", "replace")
    dev_info.append(info)

  return dev_info
157
158
def gather_gpu_devices():
  """Collect information about the NVIDIA GPUs visible to this process.

  /proc/driver/nvidia is consulted first because it also reports each GPU's
  UUID. If that source fails or finds no devices, libcudart is queried
  instead. Any failure of the fallback yields an empty list.

  Returns:
    A list of test_log_pb2.GPUInfo messages.
  """
  # Preferred source: /proc, which includes the UUID.
  try:
    proc_devices = _gather_gpu_devices_proc()
  except (IOError, ValueError, errors.OpError):
    proc_devices = []
  if proc_devices:
    return proc_devices

  # Fallback source: the CUDA runtime library.
  try:
    return _gather_gpu_devices_cudart()
  except (OSError, ValueError, NotImplementedError, errors.OpError):
    return []
179