import math
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
from torch._decomp import register_decomposition
from torch._prims_common import DimsType, ShapeType, TensorLikeType
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper


__all__ = [
    # Transforms
    "fft",
    "fft2",
    "fftn",
    "hfft",
    "hfft2",
    "hfftn",
    "rfft",
    "rfft2",
    "rfftn",
    "ifft",
    "ifft2",
    "ifftn",
    "ihfft",
    "ihfft2",
    "ihfftn",
    "irfft",
    "irfft2",
    "irfftn",
    # Helpers
    "fftshift",
    "ifftshift",
]

NormType = Union[None, Literal["forward", "backward", "ortho"]]
_NORM_VALUES = {None, "forward", "backward", "ortho"}
aten = torch._ops.ops.aten


def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Apply normalization to the un-normalized FFT result"""
    torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

    if norm == "ortho":
        return x * (1 / math.sqrt(signal_numel))

    normalize = (not forward and (norm is None or norm == "backward")) or (
        forward and norm == "forward"
    )
    return x * (1 / signal_numel) if normalize else x

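# Quick reference for the scaling applied by _apply_norm above (with
# signal_numel == n):
#   norm=None / "backward": forward transform unscaled, inverse scaled by 1/n
#   norm="forward":         forward transform scaled by 1/n, inverse unscaled
#   norm="ortho":           both directions scaled by 1/sqrt(n)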

def _promote_type_fft(
    dtype: torch.dtype, require_complex: bool, device: torch.device
) -> torch.dtype:
    """Helper to promote a dtype to one supported by the FFT primitives"""
    if dtype.is_complex:
        return dtype

    # Promote integral to default float type
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()

    allowed_types = [torch.float32, torch.float64]
    maybe_support_half = device.type in ["cuda", "meta"]

    if maybe_support_half:
        allowed_types.append(torch.float16)
    torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}")

    if require_complex:
        dtype = utils.corresponding_complex_dtype(dtype)

    return dtype

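# Illustrative promotions from _promote_type_fft above (assuming the usual
# float32 default dtype):
#   torch.int64   -> torch.float32, or torch.complex64 with require_complex
#   torch.float64 -> torch.float64, or torch.complex128 with require_complex
# float16 is only accepted on "cuda" and "meta" devices; bfloat16 is rejected.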

def _maybe_promote_tensor_fft(
    t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
    """Helper to promote a tensor to a dtype supported by the FFT primitives"""
    cur_type = t.dtype
    new_type = _promote_type_fft(cur_type, require_complex, t.device)
    return _maybe_convert_to_dtype(t, new_type)  # type: ignore[return-value]


def _resize_fft_input(
    x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
) -> TensorLikeType:
    """
    Fixes the shape of x such that x.size(dims[i]) == sizes[i],
    either by zero-padding, or by slicing x starting from 0.
    """
    assert len(dims) == len(sizes)
    must_copy = False
    x_sizes = x.shape
    pad_amount = [0] * len(x_sizes) * 2
    for i in range(len(dims)):
        if sizes[i] == -1:
            continue

        if x_sizes[dims[i]] < sizes[i]:
            must_copy = True
            # constant_pad_nd takes (before, after) pairs, last dimension first
            pad_idx = len(pad_amount) - 2 * dims[i] - 1
            pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]

        if x_sizes[dims[i]] > sizes[i]:
            x = x.narrow(dims[i], 0, sizes[i])

    return torch.constant_pad_nd(x, pad_amount) if must_copy else x

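# Example behavior of _resize_fft_input above, with x of shape (2, 3):
#   _resize_fft_input(x, dims=(1,), sizes=(5,)) zero-pads dim 1 out to length 5
#   _resize_fft_input(x, dims=(1,), sizes=(2,)) keeps only x[:, :2]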

def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to real FFT (irfft or hfft)"""
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    if n is not None:
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

    if forward:
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)

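# Conjugating the input lets one unnormalized c2r kernel serve both transforms
# above: under the default norm this computes hfft(x, n) as n * irfft(conj(x), n),
# while irfft itself (forward=False) gets its 1/n factor from _apply_norm.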

def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Common code for performing any real to complex FFT (rfft or ihfft)"""
    torch._check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    input = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = n if n is not None else input.shape[dim]
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_r2c(input, dim=dims, onesided=onesided)
    ret = _apply_norm(ret, norm, dim_size, forward)
    return ret if forward else torch.conj(ret)

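# The inverse case above (ihfft) reuses the forward r2c kernel: with the
# default norm, ihfft(x) == conj(rfft(x)) / n for real input x.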

def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to complex FFT (fft or ifft)"""
    torch._check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = n if n is not None else input.shape[dim]
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_c2c(input, dim=dims, forward=forward)
    return _apply_norm(ret, norm, dim_size, forward)


@register_decomposition(aten.fft_fft)
@out_wrapper()
def fft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    if input.dtype.is_complex:
        return _fft_c2c("fft", input, n, dim, norm, forward=True)
    else:
        return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)


@register_decomposition(aten.fft_ifft)
@out_wrapper()
def ifft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    if input.dtype.is_complex:
        return _fft_c2c("ifft", input, n, dim, norm, forward=False)
    else:
        return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)

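# Round-trip sketch for fft/ifft above (default "backward" norm), e.g.:
#   x = torch.randn(8, dtype=torch.complex64)
#   torch.testing.assert_close(torch.fft.ifft(torch.fft.fft(x)), x)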

@register_decomposition(aten.fft_rfft)
@out_wrapper()
def rfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True)


@register_decomposition(aten.fft_irfft)
@out_wrapper()
def irfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_c2r("irfft", input, n, dim, norm, forward=False)


@register_decomposition(aten.fft_hfft)
@out_wrapper()
def hfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_c2r("hfft", input, n, dim, norm, forward=True)


@register_decomposition(aten.fft_ihfft)
@out_wrapper()
def ihfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True)


class _ShapeAndDims(NamedTuple):
    shape: Tuple[int, ...]
    dims: Tuple[int, ...]


def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Convert the shape and dim arguments into a canonical form where neither is optional"""
    input_dim = input.ndim
    input_sizes = input.shape

    if dim is not None:
        if not isinstance(dim, Sequence):
            dim = (dim,)
        ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)

        # Check dims are unique
        torch._check(
            len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique"
        )

    if shape is not None:
        if not isinstance(shape, Sequence):
            shape = (shape,)

        # Has shape, might have dim
        torch._check(
            dim is None or len(dim) == len(shape),
            lambda: "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)

        torch._check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )

        # If shape is given, dims defaults to the last len(shape) dimensions
        if dim is None:
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))

        # Translate any -1 values in shape to the default length
        ret_shape = tuple(
            s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)  # type: ignore[possibly-undefined]
        )
    elif dim is None:
        # No shape, no dim
        ret_dims = tuple(range(input_dim))
        ret_shape = tuple(input_sizes)
    else:
        # No shape, has dim
        ret_shape = tuple(input_sizes[d] for d in ret_dims)  # type: ignore[possibly-undefined]

    for n in ret_shape:
        torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)  # type: ignore[possibly-undefined]

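# For example, _canonicalize_fft_shape_and_dim_args with input of shape (4, 5, 6):
#   shape=None,    dim=None   -> shape=(4, 5, 6), dims=(0, 1, 2)
#   shape=(8, -1), dim=None   -> shape=(8, 6),    dims=(1, 2)
#   shape=None,    dim=(0, 2) -> shape=(4, 6),    dims=(0, 2)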

def _prod(xs: Iterable[int]) -> int:
    """Compute the product of an iterable of ints"""
    prod = 1
    for x in xs:
        prod *= x
    return prod


def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
    torch._check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    x = _resize_fft_input(input, dim, shape)
    output = prims.fft_c2c(x, dim=dim, forward=forward)
    return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward)


@register_decomposition(aten.fft_fftn)
@out_wrapper()
def fftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("fftn", x, shape, dim, norm, forward=True)


@register_decomposition(aten.fft_ifftn)
@out_wrapper()
def ifftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False)


@register_decomposition(aten.fft_rfftn)
@out_wrapper()
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    torch._check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_r2c(input, dim=dim, onesided=True)
    return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True)

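# Note the onesided output of rfftn above: out.shape[dim[-1]] == shape[-1] // 2 + 1,
# while normalization still uses the full logical signal size _prod(shape).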

@register_decomposition(aten.fft_ihfftn)
@out_wrapper()
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    torch._check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)

    tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True)

    if len(dim) == 1:
        tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False)
        return prims.conj(tmp)

    tmp = prims.conj_physical(tmp)
    tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False)
    return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False)

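# ihfftn above factors into a onesided r2c over the last dim plus an inverse
# c2c over the rest; since conj(fft(y)) equals the unnormalized ifft of conj(y),
# this matches conj(rfftn(x)) / _prod(shape) under the default norm.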

class _CanonicalizeC2rReturn(NamedTuple):
    shape: Tuple[int, ...]
    dim: Tuple[int, ...]
    last_dim_size: int


def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize shape and dim arguments for n-dimensional c2r transforms,
    and calculate last_dim_size, the output signal size along dim[-1]"""
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    if s is None or s[-1] == -1:
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    shape_list = list(shape)
    shape_list[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
    )

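# For example, a complex input of shape (4, 9) with s=None, dim=None yields
# last_dim_size == 2 * (9 - 1) == 16; the returned shape (4, 9) gives the
# expected complex input sizes, its last entry being 16 // 2 + 1 == 9.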

@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "irfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size)
    return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False)


@register_decomposition(aten.fft_hfftn)
@out_wrapper()
def hfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "hfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)

    tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else input
    tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True)
    tmp = prims.conj_physical(tmp)
    out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size)
    return _apply_norm(out, norm, last_dim_size, forward=True)

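# hfftn above is ihfftn in reverse: a forward c2c over all but the last dim,
# then conjugation and a c2r over the last dim. The norm is applied in two
# stages, _prod(shape[:-1]) and last_dim_size, whose factors multiply to the
# full signal numel.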

@register_decomposition(aten.fft_fft2)
@out_wrapper()
def fft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.fftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_ifft2)
@out_wrapper()
def ifft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.ifftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_rfft2)
@out_wrapper()
def rfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.rfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_irfft2)
@out_wrapper()
def irfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.irfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_hfft2)
@out_wrapper()
def hfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.hfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_ihfft2)
@out_wrapper()
def ihfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)


def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
    """Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
    if dim is None:
        return list(range(x.ndim))
    elif not isinstance(dim, Sequence):
        return [dim]
    else:
        return list(dim)


@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    dims = _default_alldims(dim, input)
    shift = [input.shape[d] // 2 for d in dims]
    return torch.roll(input, shift, dims)


@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    dims = _default_alldims(dim, input)
    shift = [(input.shape[d] + 1) // 2 for d in dims]
    return torch.roll(input, shift, dims)
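
# The two shifts differ only for odd lengths (n // 2 vs. (n + 1) // 2), which
# is what makes ifftshift the exact inverse of fftshift. For n == 5:
#   fftshift([0, 1, 2, 3, 4]) -> [3, 4, 0, 1, 2]  (roll by 2)
#   ifftshift of that result  -> [0, 1, 2, 3, 4]  (roll by 3)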
591