# mypy: allow-untyped-defs
import contextlib
from typing import Union
from typing_extensions import deprecated

import torch


__all__ = [
    "is_built",
    "cuFFTPlanCacheAttrContextProp",
    "cuFFTPlanCache",
    "cuFFTPlanCacheManager",
    "cuBLASModule",
    "preferred_linalg_library",
    "preferred_blas_library",
    "cufft_plan_cache",
    "matmul",
    "SDPAParams",
    "enable_cudnn_sdp",
    "cudnn_sdp_enabled",
    "enable_flash_sdp",
    "flash_sdp_enabled",
    "enable_mem_efficient_sdp",
    "mem_efficient_sdp_enabled",
    "math_sdp_enabled",
    "enable_math_sdp",
    "allow_fp16_bf16_reduction_math_sdp",
    "fp16_bf16_reduction_math_sdp_allowed",
    "is_flash_attention_available",
    "can_use_flash_attention",
    "can_use_efficient_attention",
    "sdp_kernel",
]


def is_built():
    r"""
    Return whether PyTorch is built with CUDA support.

    Note that this doesn't necessarily mean CUDA is available; just that if this PyTorch
    binary were run on a machine with working CUDA drivers and devices, we would be able to use it.
    """
    return torch._C._has_cuda


class cuFFTPlanCacheAttrContextProp:
    # Like regular ContextProp, but uses the `.device_index` attribute from the
    # calling object as the first argument to the getter and setter.
    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype):
        return self.getter(obj.device_index)

    def __set__(self, obj, val):
        if isinstance(self.setter, str):
            raise RuntimeError(self.setter)
        self.setter(obj.device_index, val)


class cuFFTPlanCache:
    r"""
    Represent a specific plan cache for a specific `device_index`.

    The attributes `size` and `max_size`, and the method `clear`, can fetch and/or
    change properties of the C++ cuFFT plan cache.
    """

    def __init__(self, device_index):
        self.device_index = device_index

    size = cuFFTPlanCacheAttrContextProp(
        torch._cufft_get_plan_cache_size,
        ".size is a read-only property showing the number of plans currently in the "
        "cache. To change the cache capacity, set cufft_plan_cache.max_size.",
    )

    max_size = cuFFTPlanCacheAttrContextProp(
        torch._cufft_get_plan_cache_max_size, torch._cufft_set_plan_cache_max_size
    )

    def clear(self):
        return torch._cufft_clear_plan_cache(self.device_index)


class cuFFTPlanCacheManager:
    r"""
    Represent all cuFFT plan caches; return the cuFFTPlanCache for a given device when indexed.

    When this object is used directly as a `cuFFTPlanCache` object (e.g., by setting
    the `.max_size` attribute), the current device's cuFFT plan cache is used.
    """

    __initialized = False

    def __init__(self):
        self.caches = []
        self.__initialized = True

    def __getitem__(self, device):
        index = torch.cuda._utils._get_device_index(device)
        if index < 0 or index >= torch.cuda.device_count():
            raise RuntimeError(
                f"cufft_plan_cache: expected 0 <= device index < {torch.cuda.device_count()}, but got "
                f"device with index {index}"
            )
        if len(self.caches) == 0:
            self.caches.extend(
                cuFFTPlanCache(index) for index in range(torch.cuda.device_count())
            )
        return self.caches[index]

    def __getattr__(self, name):
        return getattr(self[torch.cuda.current_device()], name)

    def __setattr__(self, name, value):
        if self.__initialized:
            return setattr(self[torch.cuda.current_device()], name, value)
        else:
            return super().__setattr__(name, value)
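

# Usage sketch for the cuFFT plan cache (illustrative only; not executed here).
# Assumes a CUDA-enabled build with at least one visible device; the index 0 is
# a placeholder for any valid device index:
#
#     cache = torch.backends.cuda.cufft_plan_cache[0]  # per-device plan cache
#     cache.max_size = 32   # cap the number of cached cuFFT plans on device 0
#     print(cache.size)     # number of plans currently cached
#     cache.clear()         # drop all cached plans for device 0
#
# Indexing into `cufft_plan_cache` is optional: attribute access on the manager
# itself (e.g. `torch.backends.cuda.cufft_plan_cache.max_size = 32`) applies to
# the current device's cache.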


class cuBLASModule:
    def __getattr__(self, name):
        if name == "allow_tf32":
            return torch._C._get_cublas_allow_tf32()
        elif name == "allow_fp16_reduced_precision_reduction":
            return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
        elif name == "allow_bf16_reduced_precision_reduction":
            return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
        raise AttributeError("Unknown attribute " + name)

    def __setattr__(self, name, value):
        if name == "allow_tf32":
            return torch._C._set_cublas_allow_tf32(value)
        elif name == "allow_fp16_reduced_precision_reduction":
            return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
        elif name == "allow_bf16_reduced_precision_reduction":
            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
        raise AttributeError("Unknown attribute " + name)
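

# Usage sketch for the module-level `matmul` instance defined at the bottom of
# this file (illustrative only; not executed here). These flags trade precision
# for speed on Ampere-or-newer GPUs:
#
#     torch.backends.cuda.matmul.allow_tf32 = True   # allow TF32 in matmuls
#     torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
#     print(torch.backends.cuda.matmul.allow_tf32)   # -> True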


_LinalgBackends = {
    "default": torch._C._LinalgBackend.Default,
    "cusolver": torch._C._LinalgBackend.Cusolver,
    "magma": torch._C._LinalgBackend.Magma,
}
_LinalgBackends_str = ", ".join(_LinalgBackends.keys())


def preferred_linalg_library(
    backend: Union[None, str, torch._C._LinalgBackend] = None
) -> torch._C._LinalgBackend:
    r"""
    Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations.

    .. warning:: This flag is experimental and subject to change.

    When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries,
    and if both are available it decides which to use with a heuristic.
    This flag (a :class:`str`) allows overriding those heuristics.

    * If `"cusolver"` is set then cuSOLVER will be used wherever possible.
    * If `"magma"` is set then MAGMA will be used wherever possible.
    * If `"default"` (the default) is set then heuristics will be used to pick between
      cuSOLVER and MAGMA if both are available.
    * When no input is given, this function returns the currently preferred library.
    * Users may set the environment variable TORCH_LINALG_PREFER_CUSOLVER=1 to make cuSOLVER
      the preferred library globally.
      This flag only sets the initial value of the preferred library, and the preferred library
      may still be overridden by this function call later in your script.

    Note: When a library is preferred, other libraries may still be used if the preferred library
    doesn't implement the operation(s) called.
    This flag may achieve better performance if PyTorch's heuristic library selection is incorrect
    for your application's inputs.

    Currently supported linalg operators:

    * :func:`torch.linalg.inv`
    * :func:`torch.linalg.inv_ex`
    * :func:`torch.linalg.cholesky`
    * :func:`torch.linalg.cholesky_ex`
    * :func:`torch.cholesky_solve`
    * :func:`torch.cholesky_inverse`
    * :func:`torch.linalg.lu_factor`
    * :func:`torch.linalg.lu`
    * :func:`torch.linalg.lu_solve`
    * :func:`torch.linalg.qr`
    * :func:`torch.linalg.eigh`
    * :func:`torch.linalg.eigvalsh`
    * :func:`torch.linalg.svd`
    * :func:`torch.linalg.svdvals`
    """
    if backend is None:
        pass
    elif isinstance(backend, str):
        if backend not in _LinalgBackends:
            raise RuntimeError(
                f"Unknown input value. Choose from: {_LinalgBackends_str}."
            )
        torch._C._set_linalg_preferred_backend(_LinalgBackends[backend])
    elif isinstance(backend, torch._C._LinalgBackend):
        torch._C._set_linalg_preferred_backend(backend)
    else:
        raise RuntimeError("Unknown input value type.")

    return torch._C._get_linalg_preferred_backend()


_BlasBackends = {
    "cublas": torch._C._BlasBackend.Cublas,
    "cublaslt": torch._C._BlasBackend.Cublaslt,
    "hipblaslt": torch._C._BlasBackend.Cublaslt,  # alias
}
_BlasBackends_str = ", ".join(_BlasBackends.keys())


def preferred_blas_library(
    backend: Union[None, str, torch._C._BlasBackend] = None
) -> torch._C._BlasBackend:
    r"""
    Override the library PyTorch uses for BLAS operations. Choose between cuBLAS and cuBLASLt.

    .. warning:: This flag is experimental and subject to change.

    When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available.
    For PyTorch built for ROCm, hipBLAS and hipBLASLt may offer different performance.
    This flag (a :class:`str`) allows overriding which BLAS library to use.

    * If `"cublas"` is set then cuBLAS will be used wherever possible.
    * If `"cublaslt"` is set then cuBLASLt will be used wherever possible.
    * When no input is given, this function returns the currently preferred library.
    * Users may set the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to make cuBLASLt
      the preferred library globally.
      This flag only sets the initial value of the preferred library, and the preferred library
      may still be overridden by this function call later in your script.

    Note: When a library is preferred, other libraries may still be used if the preferred library
    doesn't implement the operation(s) called.
    This flag may achieve better performance if PyTorch's library selection is incorrect
    for your application's inputs.

    """
    if backend is None:
        pass
    elif isinstance(backend, str):
        if backend not in _BlasBackends:
            raise RuntimeError(
                f"Unknown input value. Choose from: {_BlasBackends_str}."
            )
        torch._C._set_blas_preferred_backend(_BlasBackends[backend])
    elif isinstance(backend, torch._C._BlasBackend):
        torch._C._set_blas_preferred_backend(backend)
    else:
        raise RuntimeError("Unknown input value type.")

    return torch._C._get_blas_preferred_backend()
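

# Usage sketch for the backend-preference toggles above (illustrative only;
# not executed here). Both functions return the currently preferred backend,
# so calling them with no argument is a pure query:
#
#     torch.backends.cuda.preferred_linalg_library("cusolver")  # prefer cuSOLVER
#     torch.backends.cuda.preferred_blas_library("cublaslt")    # prefer cuBLASLt
#     print(torch.backends.cuda.preferred_linalg_library())     # query current choice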


from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend


# Set the __module__ attribute
SDPAParams.__module__ = "torch.backends.cuda"
SDPAParams.__name__ = "SDPAParams"


def flash_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether flash scaled dot product attention is enabled or not.
    """
    return torch._C._get_flash_sdp_enabled()


def enable_flash_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables flash scaled dot product attention.
    """
    torch._C._set_sdp_use_flash(enabled)


def mem_efficient_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether memory efficient scaled dot product attention is enabled or not.
    """
    return torch._C._get_mem_efficient_sdp_enabled()


def enable_mem_efficient_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables memory efficient scaled dot product attention.
    """
    torch._C._set_sdp_use_mem_efficient(enabled)


def math_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether math scaled dot product attention is enabled or not.
    """
    return torch._C._get_math_sdp_enabled()


def enable_math_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables math scaled dot product attention.
    """
    torch._C._set_sdp_use_math(enabled)


def allow_fp16_bf16_reduction_math_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables fp16/bf16 reduction in math scaled dot product attention.
    """
    torch._C._set_math_sdp_allow_fp16_bf16_reduction(enabled)


def fp16_bf16_reduction_math_sdp_allowed():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether fp16/bf16 reduction in math scaled dot product attention is enabled or not.
    """
    return torch._C._get_math_sdp_allow_fp16_bf16_reduction()


def is_flash_attention_available() -> bool:
    r"""Check if PyTorch was built with FlashAttention for scaled_dot_product_attention.

    Returns:
        True if FlashAttention is built and available; otherwise, False.

    Note:
        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
        in non-CUDA environments.
    """
    return torch._C._is_flash_attention_available()


def can_use_flash_attention(params: SDPAParams, debug: bool = False) -> bool:
    r"""Check if FlashAttention can be utilized in scaled_dot_product_attention.

    Args:
        params: An instance of SDPAParams containing the tensors for query,
            key, value, an optional attention mask, dropout rate, and
            a flag indicating if the attention is causal.
        debug: Whether to log warning information explaining why FlashAttention could not be run.
            Defaults to False.

    Returns:
        True if FlashAttention can be used with the given parameters; otherwise, False.

    Note:
        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
        in non-CUDA environments.
    """
    return torch._C._can_use_flash_attention(params, debug)


def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool:
    r"""Check if efficient_attention can be utilized in scaled_dot_product_attention.

    Args:
        params: An instance of SDPAParams containing the tensors for query,
            key, value, an optional attention mask, dropout rate, and
            a flag indicating if the attention is causal.
        debug: Whether to log warning information explaining why efficient_attention could not be run.
            Defaults to False.

    Returns:
        True if efficient_attention can be used with the given parameters; otherwise, False.

    Note:
        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
        in non-CUDA environments.
    """
    return torch._C._can_use_mem_efficient_attention(params, debug)


def cudnn_sdp_enabled():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether cuDNN scaled dot product attention is enabled or not.
    """
    return torch._C._get_cudnn_sdp_enabled()


def enable_cudnn_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables cuDNN scaled dot product attention.
    """
    torch._C._set_sdp_use_cudnn(enabled)
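

# Usage sketch for the scaled dot product attention (SDPA) toggles above
# (illustrative only; not executed here). Globally disabling a backend steers
# torch.nn.functional.scaled_dot_product_attention away from it:
#
#     if torch.backends.cuda.is_flash_attention_available():
#         torch.backends.cuda.enable_flash_sdp(True)
#     torch.backends.cuda.enable_math_sdp(False)      # disallow the math fallback
#     print(torch.backends.cuda.flash_sdp_enabled())  # -> True
#
# For scoped (rather than global) control, prefer the torch.nn.attention.sdpa_kernel
# context manager referenced at the end of this file.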


@contextlib.contextmanager
@deprecated(
    (
        "`torch.backends.cuda.sdp_kernel()` is deprecated. "
        "In the future, this context manager will be removed. "
        "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, "
        "with updated signature."
    ),
    category=FutureWarning,
)
def sdp_kernel(
    enable_flash: bool = True,
    enable_math: bool = True,
    enable_mem_efficient: bool = True,
    enable_cudnn: bool = True,
):
    r"""
    .. warning:: This flag is beta and subject to change.

    This context manager can be used to temporarily enable or disable any of the four backends for scaled dot product attention.
    Upon exiting the context manager, the previous state of the flags will be restored.
    """
    from torch.nn.attention import sdpa_kernel

    backend_list = []
    if enable_flash:
        backend_list.append(SDPBackend.FLASH_ATTENTION)
    if enable_mem_efficient:
        backend_list.append(SDPBackend.EFFICIENT_ATTENTION)
    if enable_math:
        backend_list.append(SDPBackend.MATH)
    if enable_cudnn:
        backend_list.append(SDPBackend.CUDNN_ATTENTION)

    with sdpa_kernel(backend_list) as context:
        try:
            yield context
        finally:
            pass


cufft_plan_cache = cuFFTPlanCacheManager()
matmul = cuBLASModule()
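
# Migration sketch from the deprecated `sdp_kernel` context manager to
# `torch.nn.attention.sdpa_kernel` (illustrative only; not executed here).
# The tensor shapes and dtype are placeholders:
#
#     from torch.nn.attention import SDPBackend, sdpa_kernel
#
#     q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
#     with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
#         out = torch.nn.functional.scaled_dot_product_attention(q, k, v)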