1import math 2from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union 3 4import torch 5import torch._prims as prims 6import torch._prims_common as utils 7from torch._decomp import register_decomposition 8from torch._prims_common import DimsType, ShapeType, TensorLikeType 9from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper 10 11 12__all__ = [ 13 # Transforms 14 "fft", 15 "fft2", 16 "fftn", 17 "hfft", 18 "hfft2", 19 "hfftn", 20 "rfft", 21 "rfft2", 22 "rfftn", 23 "ifft", 24 "ifft2", 25 "ifftn", 26 "ihfft", 27 "ihfft2", 28 "ihfftn", 29 "irfft", 30 "irfft2", 31 "irfftn", 32 # Helpers 33 "fftshift", 34 "ifftshift", 35] 36 37NormType = Union[None, Literal["forward", "backward", "ortho"]] 38_NORM_VALUES = {None, "forward", "backward", "ortho"} 39aten = torch._ops.ops.aten 40 41 42def _apply_norm( 43 x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool 44) -> TensorLikeType: 45 """Apply normalization to the un-normalized FFT result""" 46 torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}") 47 48 if norm == "ortho": 49 return x * (1 / math.sqrt(signal_numel)) 50 51 normalize = (not forward and (norm is None or norm == "backward")) or ( 52 forward and norm == "forward" 53 ) 54 return x * (1 / signal_numel) if normalize else x 55 56 57def _promote_type_fft( 58 dtype: torch.dtype, require_complex: bool, device: torch.device 59) -> torch.dtype: 60 """Helper to promote a dtype to one supported by the FFT primitives""" 61 if dtype.is_complex: 62 return dtype 63 64 # Promote integral to default float type 65 if not dtype.is_floating_point: 66 dtype = torch.get_default_dtype() 67 68 allowed_types = [torch.float32, torch.float64] 69 maybe_support_half = device.type in ["cuda", "meta"] 70 71 if maybe_support_half: 72 allowed_types.append(torch.float16) 73 torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}") 74 75 if require_complex: 76 dtype = utils.corresponding_complex_dtype(dtype) 77 78 return dtype 79 80 81def _maybe_promote_tensor_fft( 82 t: TensorLikeType, require_complex: bool = False 83) -> TensorLikeType: 84 """Helper to promote a tensor to a dtype supported by the FFT primitives""" 85 cur_type = t.dtype 86 new_type = _promote_type_fft(cur_type, require_complex, t.device) 87 return _maybe_convert_to_dtype(t, new_type) # type: ignore[return-value] 88 89 90def _resize_fft_input( 91 x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...] 92) -> TensorLikeType: 93 """ 94 Fixes the shape of x such that x.size(dims[i]) == sizes[i], 95 either by zero-padding, or by slicing x starting from 0. 96 """ 97 assert len(dims) == len(sizes) 98 must_copy = False 99 x_sizes = x.shape 100 pad_amount = [0] * len(x_sizes) * 2 101 for i in range(len(dims)): 102 if sizes[i] == -1: 103 continue 104 105 if x_sizes[dims[i]] < sizes[i]: 106 must_copy = True 107 pad_idx = len(pad_amount) - 2 * dims[i] - 1 108 pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]] 109 110 if x_sizes[dims[i]] > sizes[i]: 111 x = x.narrow(dims[i], 0, sizes[i]) 112 113 return torch.constant_pad_nd(x, pad_amount) if must_copy else x 114 115 116def _fft_c2r( 117 func_name: str, 118 input: TensorLikeType, 119 n: Optional[int], 120 dim: int, 121 norm: NormType, 122 forward: bool, 123) -> TensorLikeType: 124 """Common code for performing any complex to real FFT (irfft or hfft)""" 125 input = _maybe_promote_tensor_fft(input, require_complex=True) 126 dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) 127 last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1) 128 torch._check( 129 last_dim_size >= 1, 130 lambda: f"Invalid number of data points ({last_dim_size}) specified", 131 ) 132 133 if n is not None: 134 input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,)) 135 136 if forward: 137 input = torch.conj(input) 138 139 output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size) 140 return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward) 141 142 143def _fft_r2c( 144 func_name: str, 145 input: TensorLikeType, 146 n: Optional[int], 147 dim: int, 148 norm: NormType, 149 forward: bool, 150 onesided: bool, 151) -> TensorLikeType: 152 """Common code for performing any real to complex FFT (rfft or ihfft)""" 153 torch._check( 154 not input.dtype.is_complex, 155 lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}", 156 ) 157 input = _maybe_promote_tensor_fft(input) 158 dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) 159 dim_size = n if n is not None else input.shape[dim] 160 torch._check( 161 dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified" 162 ) 163 164 if n is not None: 165 input = _resize_fft_input(input, dims, (n,)) 166 167 ret = prims.fft_r2c(input, dim=dims, onesided=onesided) 168 ret = _apply_norm(ret, norm, dim_size, forward) 169 return ret if forward else torch.conj(ret) 170 171 172def _fft_c2c( 173 func_name: str, 174 input: TensorLikeType, 175 n: Optional[int], 176 dim: int, 177 norm: NormType, 178 forward: bool, 179) -> TensorLikeType: 180 """Common code for performing any complex to complex FFT (fft or ifft)""" 181 torch._check( 182 input.dtype.is_complex, 183 lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}", 184 ) 185 dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) 186 dim_size = n if n is not None else input.shape[dim] 187 torch._check( 188 dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified" 189 ) 190 191 if n is not None: 192 input = _resize_fft_input(input, dims, (n,)) 193 194 ret = prims.fft_c2c(input, dim=dims, forward=forward) 195 return _apply_norm(ret, norm, dim_size, forward) 196 197 198@register_decomposition(aten.fft_fft) 199@out_wrapper() 200def fft( 201 input: TensorLikeType, 202 n: Optional[int] = None, 203 dim: int = -1, 204 norm: NormType = None, 205) -> TensorLikeType: 206 if input.dtype.is_complex: 207 return _fft_c2c("fft", input, n, dim, norm, forward=True) 208 else: 209 return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False) 210 211 212@register_decomposition(aten.fft_ifft) 213@out_wrapper() 214def ifft( 215 input: TensorLikeType, 216 n: Optional[int] = None, 217 dim: int = -1, 218 norm: NormType = None, 219) -> TensorLikeType: 220 if input.dtype.is_complex: 221 return _fft_c2c("ifft", input, n, dim, norm, forward=False) 222 else: 223 return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False) 224 225 226@register_decomposition(aten.fft_rfft) 227@out_wrapper() 228def rfft( 229 input: TensorLikeType, 230 n: Optional[int] = None, 231 dim: int = -1, 232 norm: NormType = None, 233) -> TensorLikeType: 234 return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True) 235 236 237@register_decomposition(aten.fft_irfft) 238@out_wrapper() 239def irfft( 240 input: TensorLikeType, 241 n: Optional[int] = None, 242 dim: int = -1, 243 norm: NormType = None, 244) -> TensorLikeType: 245 return _fft_c2r("irfft", input, n, dim, norm, forward=False) 246 247 248@register_decomposition(aten.fft_hfft) 249@out_wrapper() 250def hfft( 251 input: TensorLikeType, 252 n: Optional[int] = None, 253 dim: int = -1, 254 norm: NormType = None, 255) -> TensorLikeType: 256 return _fft_c2r("hfft", input, n, dim, norm, forward=True) 257 258 259@register_decomposition(aten.fft_ihfft) 260@out_wrapper() 261def ihfft( 262 input: TensorLikeType, 263 n: Optional[int] = None, 264 dim: int = -1, 265 norm: NormType = None, 266) -> TensorLikeType: 267 return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True) 268 269 270class _ShapeAndDims(NamedTuple): 271 shape: Tuple[int, ...] 272 dims: Tuple[int, ...] 273 274 275def _canonicalize_fft_shape_and_dim_args( 276 input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType] 277) -> _ShapeAndDims: 278 """Convert the shape and dim arguments into a canonical form where neither are optional""" 279 input_dim = input.ndim 280 input_sizes = input.shape 281 282 if dim is not None: 283 if not isinstance(dim, Sequence): 284 dim = (dim,) 285 ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False) 286 287 # Check dims are unique 288 torch._check( 289 len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique" 290 ) 291 292 if shape is not None: 293 if not isinstance(shape, Sequence): 294 shape = (shape,) 295 296 # Has shape, might have dim 297 torch._check( 298 dim is None or len(dim) == len(shape), 299 lambda: "When given, dim and shape arguments must have the same length", 300 ) 301 transform_ndim = len(shape) 302 303 torch._check( 304 transform_ndim <= input_dim, 305 lambda: f"Got shape with {transform_ndim} values but input tensor " 306 f"only has {input_dim} dimensions.", 307 ) 308 309 # If shape is given, dims defaults to the last len(shape) dimensions 310 if dim is None: 311 ret_dims = tuple(range(input_dim - transform_ndim, input_dim)) 312 313 # Translate any -1 values in shape to the default length 314 ret_shape = tuple( 315 s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined] 316 ) 317 elif dim is None: 318 # No shape, no dim 319 ret_dims = tuple(range(input_dim)) 320 ret_shape = tuple(input_sizes) 321 else: 322 # No shape, has dim 323 ret_shape = tuple(input_sizes[d] for d in ret_dims) # type: ignore[possibly-undefined] 324 325 for n in ret_shape: 326 torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified") 327 328 return _ShapeAndDims(shape=ret_shape, dims=ret_dims) # type: ignore[possibly-undefined] 329 330 331def _prod(xs: Iterable[int]) -> int: 332 """Compute product of a list""" 333 prod = 1 334 for x in xs: 335 prod *= x 336 return prod 337 338 339def _fftn_c2c( 340 function_name: str, 341 input: TensorLikeType, 342 shape: Tuple[int, ...], 343 dim: Tuple[int, ...], 344 norm: NormType, 345 forward: bool, 346) -> TensorLikeType: 347 """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)""" 348 torch._check( 349 input.dtype.is_complex, 350 lambda: f"{function_name} expects a complex input tensor, " 351 f"but got {input.dtype}", 352 ) 353 x = _resize_fft_input(input, dim, shape) 354 output = prims.fft_c2c(x, dim=dim, forward=forward) 355 return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward) 356 357 358@register_decomposition(aten.fft_fftn) 359@out_wrapper() 360def fftn( 361 input: TensorLikeType, 362 s: Optional[ShapeType] = None, 363 dim: Optional[DimsType] = None, 364 norm: NormType = None, 365) -> TensorLikeType: 366 (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) 367 x = _maybe_promote_tensor_fft(input, require_complex=True) 368 return _fftn_c2c("fftn", x, shape, dim, norm, forward=True) 369 370 371@register_decomposition(aten.fft_ifftn) 372@out_wrapper() 373def ifftn( 374 input: TensorLikeType, 375 s: Optional[ShapeType] = None, 376 dim: Optional[DimsType] = None, 377 norm: NormType = None, 378) -> TensorLikeType: 379 (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) 380 x = _maybe_promote_tensor_fft(input, require_complex=True) 381 return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False) 382 383 384@register_decomposition(aten.fft_rfftn) 385@out_wrapper() 386def rfftn( 387 input: TensorLikeType, 388 s: Optional[ShapeType] = None, 389 dim: Optional[DimsType] = None, 390 norm: NormType = None, 391) -> TensorLikeType: 392 torch._check( 393 not input.dtype.is_complex, 394 lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}", 395 ) 396 shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) 397 input = _maybe_promote_tensor_fft(input, require_complex=False) 398 input = _resize_fft_input(input, dim, shape) 399 out = prims.fft_r2c(input, dim=dim, onesided=True) 400 return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True) 401 402 403@register_decomposition(aten.fft_ihfftn) 404@out_wrapper() 405def ihfftn( 406 input: TensorLikeType, 407 s: Optional[ShapeType] = None, 408 dim: Optional[DimsType] = None, 409 norm: NormType = None, 410) -> TensorLikeType: 411 torch._check( 412 not input.dtype.is_complex, 413 lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}", 414 ) 415 shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) 416 torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis") 417 input = _maybe_promote_tensor_fft(input, require_complex=False) 418 input = _resize_fft_input(input, dim, shape) 419 420 tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True) 421 422 if len(dim) == 1: 423 tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False) 424 return prims.conj(tmp) 425 426 tmp = prims.conj_physical(tmp) 427 tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False) 428 return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False) 429 430 431class _CanonicalizeC2rReturn(NamedTuple): 432 shape: Tuple[int, ...] 433 dim: Tuple[int, ...] 434 last_dim_size: int 435 436 437def _canonicalize_fft_c2r_shape_and_dim_args( 438 fname: str, 439 input: TensorLikeType, 440 s: Optional[ShapeType], 441 dim: Optional[DimsType], 442) -> _CanonicalizeC2rReturn: 443 """Canonicalize shape and dim arguments for n-dimensional c2r transforms, 444 as well as calculating the last_dim_size which is shape[dim[-1]] for the output""" 445 (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) 446 torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis") 447 448 if s is None or s[-1] == -1: 449 last_dim_size = 2 * (input.shape[dim[-1]] - 1) 450 else: 451 last_dim_size = shape[-1] 452 453 torch._check( 454 last_dim_size >= 1, 455 lambda: f"Invalid number of data points ({last_dim_size}) specified", 456 ) 457 458 shape_list = list(shape) 459 shape_list[-1] = last_dim_size // 2 + 1 460 return _CanonicalizeC2rReturn( 461 shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size 462 ) 463 464 465@register_decomposition(aten.fft_irfftn) 466@out_wrapper() 467def irfftn( 468 input: TensorLikeType, 469 s: Optional[ShapeType] = None, 470 dim: Optional[DimsType] = None, 471 norm: NormType = None, 472) -> TensorLikeType: 473 shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args( 474 "irfftn", input, s, dim 475 ) 476 input = _maybe_promote_tensor_fft(input, require_complex=True) 477 input = _resize_fft_input(input, dim, shape) 478 out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size) 479 return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False) 480 481 482@register_decomposition(aten.fft_hfftn) 483@out_wrapper() 484def hfftn( 485 input: TensorLikeType, 486 s: Optional[ShapeType] = None, 487 dim: Optional[DimsType] = None, 488 norm: NormType = None, 489) -> TensorLikeType: 490 shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args( 491 "hfftn", input, s, dim 492 ) 493 input = _maybe_promote_tensor_fft(input, require_complex=True) 494 input = _resize_fft_input(input, dim, shape) 495 496 tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else input 497 tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True) 498 tmp = prims.conj_physical(tmp) 499 out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size) 500 return _apply_norm(out, norm, last_dim_size, forward=True) 501 502 503@register_decomposition(aten.fft_fft2) 504@out_wrapper() 505def fft2( 506 input: TensorLikeType, 507 s: Optional[ShapeType] = None, 508 dim: Optional[DimsType] = (-2, -1), 509 norm: NormType = None, 510) -> TensorLikeType: 511 return torch.fft.fftn(input, s=s, dim=dim, norm=norm) 512 513 514@register_decomposition(aten.fft_ifft2) 515@out_wrapper() 516def ifft2( 517 input: TensorLikeType, 518 s: Optional[ShapeType] = None, 519 dim: Optional[DimsType] = (-2, -1), 520 norm: NormType = None, 521) -> TensorLikeType: 522 return torch.fft.ifftn(input, s=s, dim=dim, norm=norm) 523 524 525@register_decomposition(aten.fft_rfft2) 526@out_wrapper() 527def rfft2( 528 input: TensorLikeType, 529 s: Optional[ShapeType] = None, 530 dim: Optional[DimsType] = (-2, -1), 531 norm: NormType = None, 532) -> TensorLikeType: 533 return torch.fft.rfftn(input, s=s, dim=dim, norm=norm) 534 535 536@register_decomposition(aten.fft_irfft2) 537@out_wrapper() 538def irfft2( 539 input: TensorLikeType, 540 s: Optional[ShapeType] = None, 541 dim: Optional[DimsType] = (-2, -1), 542 norm: NormType = None, 543) -> TensorLikeType: 544 return torch.fft.irfftn(input, s=s, dim=dim, norm=norm) 545 546 547@register_decomposition(aten.fft_hfft2) 548@out_wrapper() 549def hfft2( 550 input: TensorLikeType, 551 s: Optional[ShapeType] = None, 552 dim: Optional[DimsType] = (-2, -1), 553 norm: NormType = None, 554) -> TensorLikeType: 555 return torch.fft.hfftn(input, s=s, dim=dim, norm=norm) 556 557 558@register_decomposition(aten.fft_ihfft2) 559@out_wrapper() 560def ihfft2( 561 input: TensorLikeType, 562 s: Optional[ShapeType] = None, 563 dim: Optional[DimsType] = (-2, -1), 564 norm: NormType = None, 565) -> TensorLikeType: 566 return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm) 567 568 569def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]: 570 """Convert Optional[DimsType] to a simple list, defaulting to all dimensions""" 571 if dim is None: 572 return list(range(x.ndim)) 573 elif not isinstance(dim, Sequence): 574 return [dim] 575 else: 576 return list(dim) 577 578 579@register_decomposition(aten.fft_fftshift) 580def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: 581 dims = _default_alldims(dim, input) 582 shift = [input.shape[d] // 2 for d in dims] 583 return torch.roll(input, shift, dims) 584 585 586@register_decomposition(aten.fft_ifftshift) 587def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: 588 dims = _default_alldims(dim, input) 589 shift = [(input.shape[d] + 1) // 2 for d in dims] 590 return torch.roll(input, shift, dims) 591