# mypy: allow-untyped-defs
import contextlib
from typing import Union
from typing_extensions import deprecated

import torch


__all__ = [
    "is_built",
    "cuFFTPlanCacheAttrContextProp",
    "cuFFTPlanCache",
    "cuFFTPlanCacheManager",
    "cuBLASModule",
    "preferred_linalg_library",
    "preferred_blas_library",
    "cufft_plan_cache",
    "matmul",
    "SDPAParams",
    "enable_cudnn_sdp",
    "cudnn_sdp_enabled",
    "enable_flash_sdp",
    "flash_sdp_enabled",
    "enable_mem_efficient_sdp",
    "mem_efficient_sdp_enabled",
    "math_sdp_enabled",
    "enable_math_sdp",
    "allow_fp16_bf16_reduction_math_sdp",
    "fp16_bf16_reduction_math_sdp_allowed",
    "is_flash_attention_available",
    "can_use_flash_attention",
    "can_use_efficient_attention",
    "sdp_kernel",
]


def is_built():
    r"""
    Return whether PyTorch is built with CUDA support.

    Note that this doesn't necessarily mean CUDA is available; just that if this PyTorch
    binary were run on a machine with working CUDA drivers and devices, we would be able to use it.
    """
    return torch._C._has_cuda

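# Usage sketch (illustrative comments, not executed at import): `is_built()` only
# reports compile-time CUDA support, so pair it with `torch.cuda.is_available()`
# to confirm a working driver and device at runtime.
#
#   if torch.backends.cuda.is_built() and torch.cuda.is_available():
#       device = torch.device("cuda")
#   else:
#       device = torch.device("cpu")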

class cuFFTPlanCacheAttrContextProp:
    # Like regular ContextProp, but uses the `.device_index` attribute from the
    # calling object as the first argument to the getter and setter.
    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype):
        return self.getter(obj.device_index)

    def __set__(self, obj, val):
        if isinstance(self.setter, str):
            raise RuntimeError(self.setter)
        self.setter(obj.device_index, val)


class cuFFTPlanCache:
    r"""
    Represent a specific plan cache for a specific `device_index`.

    The attributes `size` and `max_size`, and the method `clear`, can fetch and/or
    change properties of the C++ cuFFT plan cache.
    """

    def __init__(self, device_index):
        self.device_index = device_index

    size = cuFFTPlanCacheAttrContextProp(
        torch._cufft_get_plan_cache_size,
        ".size is a read-only property showing the number of plans currently in the "
        "cache. To change the cache capacity, set cufft_plan_cache.max_size.",
    )

    max_size = cuFFTPlanCacheAttrContextProp(
        torch._cufft_get_plan_cache_max_size, torch._cufft_set_plan_cache_max_size
    )

    def clear(self):
        return torch._cufft_clear_plan_cache(self.device_index)


class cuFFTPlanCacheManager:
    r"""
    Represent all cuFFT plan caches, returning the cuFFTPlanCache for a given device when indexed.

    When this object is used directly as a `cuFFTPlanCache` object (e.g., setting
    the `.max_size` attribute), the current device's cuFFT plan cache is used.
    """

    __initialized = False

    def __init__(self):
        self.caches = []
        self.__initialized = True

    def __getitem__(self, device):
        index = torch.cuda._utils._get_device_index(device)
        if index < 0 or index >= torch.cuda.device_count():
            raise RuntimeError(
                f"cufft_plan_cache: expected 0 <= device index < {torch.cuda.device_count()}, but got "
                f"device with index {index}"
            )
        if len(self.caches) == 0:
            self.caches.extend(
                cuFFTPlanCache(index) for index in range(torch.cuda.device_count())
            )
        return self.caches[index]

    def __getattr__(self, name):
        return getattr(self[torch.cuda.current_device()], name)

    def __setattr__(self, name, value):
        if self.__initialized:
            return setattr(self[torch.cuda.current_device()], name, value)
        else:
            return super().__setattr__(name, value)

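# Usage sketch (illustrative comments, not executed at import): the module-level
# `cufft_plan_cache` instance defined at the bottom of this file can be indexed
# by device, or used directly to target the current device. Assumes a CUDA build
# with at least one visible device.
#
#   torch.backends.cuda.cufft_plan_cache[0].max_size = 32  # capacity for device 0
#   print(torch.backends.cuda.cufft_plan_cache.size)       # plans cached on the current device
#   torch.backends.cuda.cufft_plan_cache[0].clear()        # drop device 0's cached plans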

class cuBLASModule:
    def __getattr__(self, name):
        if name == "allow_tf32":
            return torch._C._get_cublas_allow_tf32()
        elif name == "allow_fp16_reduced_precision_reduction":
            return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
        elif name == "allow_bf16_reduced_precision_reduction":
            return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
        raise AttributeError("Unknown attribute " + name)

    def __setattr__(self, name, value):
        if name == "allow_tf32":
            return torch._C._set_cublas_allow_tf32(value)
        elif name == "allow_fp16_reduced_precision_reduction":
            return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
        elif name == "allow_bf16_reduced_precision_reduction":
            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
        raise AttributeError("Unknown attribute " + name)

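# Usage sketch (illustrative comments, not executed at import): the module-level
# `matmul` instance defined at the bottom of this file exposes the cuBLAS flags
# as plain attributes.
#
#   torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 in float32 matmuls
#   torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
#   print(torch.backends.cuda.matmul.allow_tf32)  # -> True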

_LinalgBackends = {
    "default": torch._C._LinalgBackend.Default,
    "cusolver": torch._C._LinalgBackend.Cusolver,
    "magma": torch._C._LinalgBackend.Magma,
}
_LinalgBackends_str = ", ".join(_LinalgBackends.keys())


def preferred_linalg_library(
    backend: Union[None, str, torch._C._LinalgBackend] = None
) -> torch._C._LinalgBackend:
    r"""
    Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations.

    .. warning:: This flag is experimental and subject to change.

    When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries,
    and if both are available it decides which to use with a heuristic.
    This flag (a :class:`str`) allows overriding those heuristics.

    * If `"cusolver"` is set then cuSOLVER will be used wherever possible.
    * If `"magma"` is set then MAGMA will be used wherever possible.
    * If `"default"` (the default) is set then heuristics will be used to pick between
      cuSOLVER and MAGMA if both are available.
    * When no input is given, this function returns the currently preferred library.
    * Users may set the environment variable TORCH_LINALG_PREFER_CUSOLVER=1 to make cuSOLVER
      the preferred library globally.
      This flag only sets the initial value of the preferred library, which
      may still be overridden by this function call later in your script.

    Note: When a library is preferred, other libraries may still be used if the preferred library
    doesn't implement the operation(s) called.
    This flag may achieve better performance if PyTorch's heuristic library selection is incorrect
    for your application's inputs.

    Currently supported linalg operators:

    * :func:`torch.linalg.inv`
    * :func:`torch.linalg.inv_ex`
    * :func:`torch.linalg.cholesky`
    * :func:`torch.linalg.cholesky_ex`
    * :func:`torch.cholesky_solve`
    * :func:`torch.cholesky_inverse`
    * :func:`torch.linalg.lu_factor`
    * :func:`torch.linalg.lu`
    * :func:`torch.linalg.lu_solve`
    * :func:`torch.linalg.qr`
    * :func:`torch.linalg.eigh`
    * :func:`torch.linalg.eigvalsh`
    * :func:`torch.linalg.svd`
    * :func:`torch.linalg.svdvals`
    """
    if backend is None:
        pass
    elif isinstance(backend, str):
        if backend not in _LinalgBackends:
            raise RuntimeError(
                f"Unknown input value. Choose from: {_LinalgBackends_str}."
            )
        torch._C._set_linalg_preferred_backend(_LinalgBackends[backend])
    elif isinstance(backend, torch._C._LinalgBackend):
        torch._C._set_linalg_preferred_backend(backend)
    else:
        raise RuntimeError("Unknown input value type.")

    return torch._C._get_linalg_preferred_backend()

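# Usage sketch (illustrative comments, not executed at import):
#
#   torch.backends.cuda.preferred_linalg_library("cusolver")  # prefer cuSOLVER
#   backend = torch.backends.cuda.preferred_linalg_library()  # query only
#   # backend is now torch._C._LinalgBackend.Cusolver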

_BlasBackends = {
    "cublas": torch._C._BlasBackend.Cublas,
    "cublaslt": torch._C._BlasBackend.Cublaslt,
    "hipblaslt": torch._C._BlasBackend.Cublaslt,  # alias
}
_BlasBackends_str = ", ".join(_BlasBackends.keys())


def preferred_blas_library(
    backend: Union[None, str, torch._C._BlasBackend] = None
) -> torch._C._BlasBackend:
    r"""
    Override the library PyTorch uses for BLAS operations. Choose between cuBLAS and cuBLASLt.

    .. warning:: This flag is experimental and subject to change.

    When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available.
    For PyTorch built for ROCm, hipBLAS and hipBLASLt may offer different performance.
    This flag (a :class:`str`) allows overriding which BLAS library to use.

    * If `"cublas"` is set then cuBLAS will be used wherever possible.
    * If `"cublaslt"` is set then cuBLASLt will be used wherever possible.
    * When no input is given, this function returns the currently preferred library.
    * Users may set the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to make cuBLASLt
      the preferred library globally.
      This flag only sets the initial value of the preferred library, which
      may still be overridden by this function call later in your script.

    Note: When a library is preferred, other libraries may still be used if the preferred library
    doesn't implement the operation(s) called.
    This flag may achieve better performance if PyTorch's library selection is incorrect
    for your application's inputs.
    """
    if backend is None:
        pass
    elif isinstance(backend, str):
        if backend not in _BlasBackends:
            raise RuntimeError(
                f"Unknown input value. Choose from: {_BlasBackends_str}."
            )
        torch._C._set_blas_preferred_backend(_BlasBackends[backend])
    elif isinstance(backend, torch._C._BlasBackend):
        torch._C._set_blas_preferred_backend(backend)
    else:
        raise RuntimeError("Unknown input value type.")

    return torch._C._get_blas_preferred_backend()

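# Usage sketch (illustrative comments, not executed at import):
#
#   torch.backends.cuda.preferred_blas_library("cublaslt")  # prefer cuBLASLt
#   backend = torch.backends.cuda.preferred_blas_library()  # query only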

from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend


# Set the __module__ attribute
SDPAParams.__module__ = "torch.backends.cuda"
SDPAParams.__name__ = "SDPAParams"


def flash_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether flash scaled dot product attention is enabled or not.
    """
    return torch._C._get_flash_sdp_enabled()


def enable_flash_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables flash scaled dot product attention.
    """
    torch._C._set_sdp_use_flash(enabled)


def mem_efficient_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether memory efficient scaled dot product attention is enabled or not.
    """
    return torch._C._get_mem_efficient_sdp_enabled()


def enable_mem_efficient_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables memory efficient scaled dot product attention.
    """
    torch._C._set_sdp_use_mem_efficient(enabled)


def math_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether math scaled dot product attention is enabled or not.
    """
    return torch._C._get_math_sdp_enabled()


def enable_math_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables math scaled dot product attention.
    """
    torch._C._set_sdp_use_math(enabled)


def allow_fp16_bf16_reduction_math_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables fp16/bf16 reduction in math scaled dot product attention.
    """
    torch._C._set_math_sdp_allow_fp16_bf16_reduction(enabled)


def fp16_bf16_reduction_math_sdp_allowed():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether fp16/bf16 reduction in math scaled dot product attention is enabled or not.
    """
    return torch._C._get_math_sdp_allow_fp16_bf16_reduction()

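# Usage sketch (illustrative comments, not executed at import): allowing
# reduced-precision reductions trades some accuracy for speed in the math
# backend's fp16/bf16 attention computations.
#
#   torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
#   assert torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed()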

def is_flash_attention_available() -> bool:
    r"""Check if PyTorch was built with FlashAttention for scaled_dot_product_attention.

    Returns:
        True if FlashAttention is built and available; otherwise, False.

    Note:
        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
        in non-CUDA environments.
    """
    return torch._C._is_flash_attention_available()


def can_use_flash_attention(params: SDPAParams, debug: bool = False) -> bool:
    r"""Check if FlashAttention can be utilized in scaled_dot_product_attention.

    Args:
        params: An instance of SDPAParams containing the tensors for query,
                key, value, an optional attention mask, dropout rate, and
                a flag indicating if the attention is causal.
        debug: Whether to ``logging.warn`` with information as to why FlashAttention could not be run.
            Defaults to False.

    Returns:
        True if FlashAttention can be used with the given parameters; otherwise, False.

    Note:
        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
        in non-CUDA environments.
    """
    return torch._C._can_use_flash_attention(params, debug)


def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool:
    r"""Check if efficient_attention can be utilized in scaled_dot_product_attention.

    Args:
        params: An instance of SDPAParams containing the tensors for query,
                key, value, an optional attention mask, dropout rate, and
                a flag indicating if the attention is causal.
        debug: Whether to ``logging.warn`` with information as to why efficient_attention could not be run.
            Defaults to False.

    Returns:
        True if efficient_attention can be used with the given parameters; otherwise, False.

    Note:
        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
        in non-CUDA environments.
    """
    return torch._C._can_use_mem_efficient_attention(params, debug)

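# Usage sketch (illustrative comments, not executed at import): build an
# SDPAParams instance for a candidate attention call and ask each backend
# whether it can serve it. The positional constructor shown here
# (query, key, value, attn_mask, dropout, is_causal, ...) is an assumption;
# check `torch._C._SDPAParams` in your build for the exact signature.
#
#   q = k = v = torch.rand(8, 16, 128, 64, dtype=torch.float16, device="cuda")
#   params = torch.backends.cuda.SDPAParams(q, k, v, None, 0.0, True)
#   if torch.backends.cuda.can_use_flash_attention(params, debug=True):
#       ...  # FlashAttention is eligible for these inputs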

def cudnn_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether cuDNN scaled dot product attention is enabled or not.
    """
    return torch._C._get_cudnn_sdp_enabled()


def enable_cudnn_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables cuDNN scaled dot product attention.
    """
    torch._C._set_sdp_use_cudnn(enabled)

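# Usage sketch (illustrative comments, not executed at import): the four
# enable_*_sdp toggles above are process-global; scaled_dot_product_attention
# only dispatches to backends that remain enabled.
#
#   torch.backends.cuda.enable_flash_sdp(False)    # rule out FlashAttention
#   torch.backends.cuda.enable_cudnn_sdp(False)    # rule out cuDNN attention
#   assert torch.backends.cuda.math_sdp_enabled()  # math fallback still on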

@contextlib.contextmanager
@deprecated(
    (
        "`torch.backends.cuda.sdp_kernel()` is deprecated. "
        "In the future, this context manager will be removed. "
        "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, "
        "with updated signature."
    ),
    category=FutureWarning,
)
def sdp_kernel(
    enable_flash: bool = True,
    enable_math: bool = True,
    enable_mem_efficient: bool = True,
    enable_cudnn: bool = True,
):
    r"""
    .. warning:: This flag is beta and subject to change.

    This context manager can be used to temporarily enable or disable any of the four backends for scaled dot product attention.
    Upon exiting the context manager, the previous state of the flags will be restored.
    """
    from torch.nn.attention import sdpa_kernel

    backend_list = []
    if enable_flash:
        backend_list.append(SDPBackend.FLASH_ATTENTION)
    if enable_mem_efficient:
        backend_list.append(SDPBackend.EFFICIENT_ATTENTION)
    if enable_math:
        backend_list.append(SDPBackend.MATH)
    if enable_cudnn:
        backend_list.append(SDPBackend.CUDNN_ATTENTION)

    with sdpa_kernel(backend_list) as context:
        try:
            yield context
        finally:
            pass

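# Usage sketch (illustrative comments, not executed at import): prefer the
# replacement context manager, which takes an allow-list of backends instead of
# per-backend booleans (q, k, v here stand for the usual attention tensors).
#
#   # Deprecated form (emits a FutureWarning):
#   with torch.backends.cuda.sdp_kernel(enable_math=False):
#       out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
#
#   # Replacement:
#   from torch.nn.attention import SDPBackend, sdpa_kernel
#   with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
#       out = torch.nn.functional.scaled_dot_product_attention(q, k, v)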

cufft_plan_cache = cuFFTPlanCacheManager()
matmul = cuBLASModule()