# mypy: allow-untyped-defs
"""This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention"""
import contextlib
from typing import List, Union
from warnings import warn

from torch._C import _SDPBackend as SDPBackend
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    cudnn_sdp_enabled,
    enable_cudnn_sdp,
    enable_flash_sdp,
    enable_math_sdp,
    enable_mem_efficient_sdp,
    flash_sdp_enabled,
    math_sdp_enabled,
    mem_efficient_sdp_enabled,
    SDPAParams,
)


__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]

# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses
# This only affects users of the bias subclasses.
# If this is set to True, we will warn the user if they are not using the fused kernels,
# and we will also raise warnings for all the reasons why the fused kernels can't be run.
# To set this to True, run
# torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True
WARN_FOR_UNFUSED_KERNELS = False
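
# A minimal usage sketch (illustrative only, not part of this module's API
# surface): flipping the flag at runtime makes _raise_kernel_warnings below
# explain why a fused kernel was rejected.
#
#     import torch.nn.attention
#     torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True
#     # Calls that route through the bias subclasses will now emit warnings
#     # with the reasons a fused kernel could not be used.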


# Hacks for Sphinx documentation:
# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class
SDPBackend = SDPBackend
r"""An enum-like class that contains the different backends for scaled dot product attention.
    This backend class is designed to be used with the sdpa_kernel context manager.

    The following Enums are available:
        - ERROR: An error occurred when trying to determine the backend.
        - MATH: The math backend for scaled dot product attention.
        - FLASH_ATTENTION: The flash attention backend for scaled dot product attention.
        - EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention.
        - CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention.

    See :func:`torch.nn.attention.sdpa_kernel` for more details.

    .. warning:: This class is in beta and subject to change.
"""
SDPBackend.__module__ = __name__
SDPBackend.__name__ = "SDPBackend"
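
# Illustrative sketch: SDPBackend members compare like regular enum values, so
# a preferred subset can be built and tested for membership before handing it
# to sdpa_kernel (the names below are only examples):
#
#     preferred = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
#     assert SDPBackend.MATH not in preferred
#     with sdpa_kernel(preferred):
#         ...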


def _raise_kernel_warnings(params: SDPAParams) -> None:
    """
    If WARN_FOR_UNFUSED_KERNELS is set to True, this will raise warnings
    for all the reasons why the fused kernels can't be run.
    """
    if WARN_FOR_UNFUSED_KERNELS:
        if not can_use_efficient_attention(params):
            warn("Efficient attention can't be used because:")
            # Re-run the check with debug=True so each failing precondition is
            # reported via warnings.
            can_use_efficient_attention(params, True)
        if not can_use_flash_attention(params):
            warn("Flash attention can't be used because:")
            can_use_flash_attention(params, True)
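
# Sketch of the expected call pattern (the exact SDPAParams constructor
# arguments are an assumption here; in this tree the caller lives in
# torch.nn.attention.bias and builds the params from the q/k/v tensors):
#
#     params = SDPAParams(query, key, value, attn_mask, dropout_p, is_causal)
#     _raise_kernel_warnings(params)  # no-op unless WARN_FOR_UNFUSED_KERNELS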


@contextlib.contextmanager
def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]):
    r"""
    Context manager to select which backend to use for scaled dot product attention.

    .. warning:: This function is in beta and subject to change.

    Args:
        backends (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention.

    Example:

    .. code-block:: python

        from torch.nn.functional import scaled_dot_product_attention
        from torch.nn.attention import SDPBackend, sdpa_kernel
        # Only enable flash attention backend
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            scaled_dot_product_attention(...)

        # Enable the Math or Efficient attention backends
        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
            scaled_dot_product_attention(...)

    Upon exiting the context manager, the previous state of the flags is restored.
    """
    assert isinstance(
        backends, (list, SDPBackend)
    ), "Backend must be an instance of SDPBackend or a list of SDPBackend instances"

    if isinstance(backends, SDPBackend):
        backends = [backends]

    backends = set(backends)
    # Save the current flag for each backend so it can be restored on exit.
    previous_cudnn: bool = cudnn_sdp_enabled()
    previous_flash: bool = flash_sdp_enabled()
    previous_mem_efficient: bool = mem_efficient_sdp_enabled()
    previous_math: bool = math_sdp_enabled()
    try:
        enable_cudnn = SDPBackend.CUDNN_ATTENTION in backends
        enable_flash = SDPBackend.FLASH_ATTENTION in backends
        enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends
        enable_math = SDPBackend.MATH in backends

        # Enable exactly the requested backends and disable all others.
        enable_cudnn_sdp(enable_cudnn)
        enable_flash_sdp(enable_flash)
        enable_mem_efficient_sdp(enable_mem_efficient)
        enable_math_sdp(enable_math)
        yield {}
    finally:
        # Restore the flags to their state before entering the context manager.
        enable_cudnn_sdp(previous_cudnn)
        enable_flash_sdp(previous_flash)
        enable_mem_efficient_sdp(previous_mem_efficient)
        enable_math_sdp(previous_math)
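
# Usage sketch: because the previous flag state is saved and restored above,
# sdpa_kernel contexts nest; exiting an inner block re-enables the outer
# block's selection rather than all backends.
#
#     with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
#         with sdpa_kernel(SDPBackend.MATH):
#             ...  # only the math backend is enabled here
#         ...  # math and efficient attention are enabled again here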


def _get_flash_version() -> str:
    """This returns the closest matching tag for the flash attention backend"""
    return "2.5.7"