# Owner(s): ["module: fft"]

import torch
import unittest
import math
from contextlib import contextmanager
from itertools import product
import itertools
import doctest
import inspect

from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
     make_tensor, skipIfTorchDynamo)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
from torch.testing._internal.common_methods_invocations import (
    spectral_funcs, SpectralFuncType)
from torch.testing._internal.common_cuda import SM53OrLater
from torch._prims_common import corresponding_complex_dtype

from typing import Optional, List
from packaging import version


if TEST_NUMPY:
    import numpy as np


if TEST_LIBROSA:
    import librosa

has_scipy_fft = False
try:
    import scipy.fft
    has_scipy_fft = True
except ModuleNotFoundError:
    pass

REFERENCE_NORM_MODES = (
    (None, "forward", "backward", "ortho")
    if version.parse(np.__version__) >= version.parse('1.20.0') and (
        not has_scipy_fft or version.parse(scipy.__version__) >= version.parse('1.6.0'))
    else (None, "ortho"))
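# The full set of norm modes is only compared against NumPy >= 1.20 and (when
# present) SciPy >= 1.6; older versions do not accept the "forward"/"backward"
# norm arguments, so the reference comparisons fall back to (None, "ortho").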


def _complex_stft(x, *args, **kwargs):
    # Transform real and imaginary components separately
    stft_real = torch.stft(x.real, *args, **kwargs, return_complex=True, onesided=False)
    stft_imag = torch.stft(x.imag, *args, **kwargs, return_complex=True, onesided=False)
    return stft_real + 1j * stft_imag


def _hermitian_conj(x, dim):
    """Returns the Hermitian conjugate along a single dimension

    H(x)[i] = conj(x[-i])
    """
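    # Build H(x)[i] = conj(x[-i]): keep the DC bin (and the Nyquist bin for even
    # lengths) in place, swap the flipped positive- and negative-frequency halves,
    # then conjugate the whole result.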
    out = torch.empty_like(x)
    mid = (x.size(dim) - 1) // 2
    idx = [slice(None)] * out.dim()
    idx_center = list(idx)
    idx_center[dim] = 0
    out[idx_center] = x[idx_center]

    idx_neg = list(idx)
    idx_neg[dim] = slice(-mid, None)
    idx_pos = idx
    idx_pos[dim] = slice(1, mid + 1)

    out[idx_pos] = x[idx_neg].flip(dim)
    out[idx_neg] = x[idx_pos].flip(dim)
    if (2 * mid + 1 < x.size(dim)):
        idx[dim] = mid + 1
        out[idx] = x[idx]
    return out.conj()


def _complex_istft(x, *args, **kwargs):
    # Decompose into Hermitian (FFT of real) and anti-Hermitian (FFT of imaginary)
    n_fft = x.size(-2)
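    # Keep only the n_fft // 2 + 1 non-negative frequency bins along dim=-2,
    # which is the layout the onesided istft calls below expect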
    slc = (Ellipsis, slice(None, n_fft // 2 + 1), slice(None))

    hconj = _hermitian_conj(x, dim=-2)
    x_hermitian = (x + hconj) / 2
    x_antihermitian = (x - hconj) / 2
    istft_real = torch.istft(x_hermitian[slc], *args, **kwargs, onesided=True)
    istft_imag = torch.istft(-1j * x_antihermitian[slc], *args, **kwargs, onesided=True)
    return torch.complex(istft_real, istft_imag)


def _stft_reference(x, hop_length, window):
    r"""Reference stft implementation

    This doesn't implement all of torch.stft, only the STFT definition:

    .. math:: X(m, \omega) = \sum_n x[n]w[n - m] e^{-jn\omega}

    """
    n_fft = window.numel()
    X = torch.empty((n_fft, (x.numel() - n_fft + hop_length) // hop_length),
                    device=x.device, dtype=torch.cdouble)
    for m in range(X.size(1)):
        start = m * hop_length
        if start + n_fft > x.numel():
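            # The last frame would run past the end of the signal: zero-pad it to n_fft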
            slc = torch.zeros(n_fft, device=x.device, dtype=x.dtype)
            tmp = x[start:]
            slc[:tmp.numel()] = tmp
        else:
            slc = x[start: start + n_fft]
        X[:, m] = torch.fft.fft(slc * window)
    return X


def skip_helper_for_fft(device, dtype):
    device_type = torch.device(device).type
    if dtype not in (torch.half, torch.complex32):
        return

    if device_type == 'cpu':
        raise unittest.SkipTest("half and complex32 are not supported on CPU")
    if not SM53OrLater:
        raise unittest.SkipTest("half and complex32 are only supported on CUDA devices with SM53 or later")


# Tests of functions related to Fourier analysis in the torch.fft namespace
class TestFFT(TestCase):
    exact_dtype = True

    @onlyNativeDeviceTypes
    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD],
         allowed_dtypes=(torch.float, torch.cfloat))
    def test_reference_1d(self, device, dtype, op):
        if op.ref is None:
            raise unittest.SkipTest("No reference implementation")

        norm_modes = REFERENCE_NORM_MODES
        test_args = [
            *product(
                # input
                (torch.randn(67, device=device, dtype=dtype),
                 torch.randn(80, device=device, dtype=dtype),
                 torch.randn(12, 14, device=device, dtype=dtype),
                 torch.randn(9, 6, 3, device=device, dtype=dtype)),
                # n
                (None, 50, 6),
                # dim
                (-1, 0),
                # norm
                norm_modes
            ),
            # Test transforming middle dimensions of multi-dim tensor
            *product(
                (torch.randn(4, 5, 6, 7, device=device, dtype=dtype),),
                (None,),
                (1, 2, -2,),
                norm_modes
            )
        ]

        for iargs in test_args:
            args = list(iargs)
            input = args[0]
            args = args[1:]

            expected = op.ref(input.cpu().numpy(), *args)
            exact_dtype = dtype in (torch.double, torch.complex128)
            actual = op(input, *args)
            self.assertEqual(actual, expected, exact_dtype=exact_dtype)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
        torch.chalf : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
    def test_fft_round_trip(self, device, dtype):
        skip_helper_for_fft(device, dtype)
        # Test that round trip through ifft(fft(x)) is the identity
        if dtype not in (torch.half, torch.complex32):
            test_args = list(product(
                # input
                (torch.randn(67, device=device, dtype=dtype),
                 torch.randn(80, device=device, dtype=dtype),
                 torch.randn(12, 14, device=device, dtype=dtype),
                 torch.randn(9, 6, 3, device=device, dtype=dtype)),
                # dim
                (-1, 0),
                # norm
                (None, "forward", "backward", "ortho")
            ))
        else:
            # cuFFT supports powers of 2 for half and complex half precision
            test_args = list(product(
                # input
                (torch.randn(64, device=device, dtype=dtype),
                 torch.randn(128, device=device, dtype=dtype),
                 torch.randn(4, 16, device=device, dtype=dtype),
                 torch.randn(8, 6, 2, device=device, dtype=dtype)),
                # dim
                (-1, 0),
                # norm
                (None, "forward", "backward", "ortho")
            ))

        fft_functions = [(torch.fft.fft, torch.fft.ifft)]
        # Real-only functions
        if not dtype.is_complex:
            # NOTE: Using ihfft as "forward" transform to avoid needing to
            # generate true half-complex input
            fft_functions += [(torch.fft.rfft, torch.fft.irfft),
                              (torch.fft.ihfft, torch.fft.hfft)]

        for forward, backward in fft_functions:
            for x, dim, norm in test_args:
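                # n is set to the transformed dim's original size so the inverse
                # restores the input length exactly (irfft/hfft would otherwise
                # assume an even-length output)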
                kwargs = {
                    'n': x.size(dim),
                    'dim': dim,
                    'norm': norm,
                }

                y = backward(forward(x, **kwargs), **kwargs)
                if x.dtype is torch.half and y.dtype is torch.complex32:
                    # Since type promotion currently doesn't work with complex32
                    # manually promote `x` to complex32
                    x = x.to(torch.complex32)
                # For real input, ifft(fft(x)) will convert to complex
                self.assertEqual(x, y, exact_dtype=(
                    forward != torch.fft.fft or x.is_complex()))

    # Note: NumPy will throw a ValueError for an empty input
    @onlyNativeDeviceTypes
    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
    def test_empty_fft(self, device, dtype, op):
        t = torch.empty(1, 0, device=device, dtype=dtype)
        match = r"Invalid number of data points \([-\d]*\) specified"

        with self.assertRaisesRegex(RuntimeError, match):
            op(t)

    @onlyNativeDeviceTypes
    def test_empty_ifft(self, device):
        t = torch.empty(2, 1, device=device, dtype=torch.complex64)
        match = r"Invalid number of data points \([-\d]*\) specified"

        for f in [torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn,
                  torch.fft.hfft, torch.fft.hfft2, torch.fft.hfftn]:
            with self.assertRaisesRegex(RuntimeError, match):
                f(t)

    @onlyNativeDeviceTypes
    def test_fft_invalid_dtypes(self, device):
        t = torch.randn(64, device=device, dtype=torch.complex128)

        with self.assertRaisesRegex(RuntimeError, "rfft expects a real input tensor"):
            torch.fft.rfft(t)

        with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input tensor"):
            torch.fft.rfftn(t)

        with self.assertRaisesRegex(RuntimeError, "ihfft expects a real input tensor"):
            torch.fft.ihfft(t)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.half, torch.float, torch.double,
            torch.complex32, torch.complex64, torch.complex128)
    def test_fft_type_promotion(self, device, dtype):
        skip_helper_for_fft(device, dtype)

        if dtype.is_complex or dtype.is_floating_point:
            t = torch.randn(64, device=device, dtype=dtype)
        else:
            t = torch.randint(-2, 2, (64,), device=device, dtype=dtype)

        PROMOTION_MAP = {
            torch.int8: torch.complex64,
            torch.half: torch.complex32,
            torch.float: torch.complex64,
            torch.double: torch.complex128,
            torch.complex32: torch.complex32,
            torch.complex64: torch.complex64,
            torch.complex128: torch.complex128,
        }
        T = torch.fft.fft(t)
        self.assertEqual(T.dtype, PROMOTION_MAP[dtype])

        PROMOTION_MAP_C2R = {
            torch.int8: torch.float,
            torch.half: torch.half,
            torch.float: torch.float,
            torch.double: torch.double,
            torch.complex32: torch.half,
            torch.complex64: torch.float,
            torch.complex128: torch.double,
        }
        if dtype in (torch.half, torch.complex32):
            # cuFFT supports powers of 2 for half and complex half precision
            # NOTE: With hfft and default args where output_size n=2*(input_size - 1),
            # we make sure that logical fft size is a power of two.
            x = torch.randn(65, device=device, dtype=dtype)
            R = torch.fft.hfft(x)
        else:
            R = torch.fft.hfft(t)
        self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype])

        if not dtype.is_complex:
            PROMOTION_MAP_R2C = {
                torch.int8: torch.complex64,
                torch.half: torch.complex32,
                torch.float: torch.complex64,
                torch.double: torch.complex128,
            }
            C = torch.fft.rfft(t)
            self.assertEqual(C.dtype, PROMOTION_MAP_R2C[dtype])

    @onlyNativeDeviceTypes
    @ops(spectral_funcs, dtypes=OpDTypes.unsupported,
         allowed_dtypes=[torch.half, torch.bfloat16])
    def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
        # TODO: Remove torch.half error when complex32 is fully implemented
        sample = first_sample(self, op.sample_inputs(device, dtype))
        device_type = torch.device(device).type
        default_msg = "Unsupported dtype"
        if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
            err_msg = default_msg
        elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
            err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
        else:
            err_msg = default_msg
        with self.assertRaisesRegex(RuntimeError, err_msg):
            op(sample.input, *sample.args, **sample.kwargs)

    @onlyNativeDeviceTypes
    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf))
    def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op):
        t = make_tensor(13, 13, device=device, dtype=dtype)
        err_msg = "cuFFT only supports dimensions whose sizes are powers of two"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            op(t)

        if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD):
            kwargs = {'s': (12, 12)}
        else:
            kwargs = {'n': 12}

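        # 12 is still not a power of two, so the explicit-size call should raise as well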
        with self.assertRaisesRegex(RuntimeError, err_msg):
            op(t, **kwargs)

    # nd-fft tests
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
         allowed_dtypes=(torch.cfloat, torch.cdouble))
    def test_reference_nd(self, device, dtype, op):
        if op.ref is None:
            raise unittest.SkipTest("No reference implementation")

        norm_modes = REFERENCE_NORM_MODES

        # input_ndim, s, dim
        transform_desc = [
            *product(range(2, 5), (None,), (None, (0,), (0, -1))),
            *product(range(2, 5), (None, (4, 10)), (None,)),
            (6, None, None),
            (5, None, (1, 3, 4)),
            (3, None, (1,)),
            (1, None, (0,)),
            (4, (10, 10), None),
            (4, (10, 10), (0, 1))
        ]

        for input_ndim, s, dim in transform_desc:
            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
            input = torch.randn(*shape, device=device, dtype=dtype)

            for norm in norm_modes:
                expected = op.ref(input.cpu().numpy(), s, dim, norm)
                exact_dtype = dtype in (torch.double, torch.complex128)
                actual = op(input, s, dim, norm)
                self.assertEqual(actual, expected, exact_dtype=exact_dtype)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
        torch.chalf : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double,
            torch.complex32, torch.complex64, torch.complex128)
    def test_fftn_round_trip(self, device, dtype):
        skip_helper_for_fft(device, dtype)

        norm_modes = (None, "forward", "backward", "ortho")

        # input_ndim, dim
        transform_desc = [
            *product(range(2, 5), (None, (0,), (0, -1))),
            (7, None),
            (5, (1, 3, 4)),
            (3, (1,)),
            (1, 0),
        ]

        fft_functions = [(torch.fft.fftn, torch.fft.ifftn)]

        # Real-only functions
        if not dtype.is_complex:
            # NOTE: Using ihfftn as "forward" transform to avoid needing to
            # generate true half-complex input
            fft_functions += [(torch.fft.rfftn, torch.fft.irfftn),
                              (torch.fft.ihfftn, torch.fft.hfftn)]

        for input_ndim, dim in transform_desc:
            if dtype in (torch.half, torch.complex32):
                # cuFFT supports powers of 2 for half and complex half precision
                shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)
            else:
                shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
            x = torch.randn(*shape, device=device, dtype=dtype)

            for (forward, backward), norm in product(fft_functions, norm_modes):
                if isinstance(dim, tuple):
                    s = [x.size(d) for d in dim]
                else:
                    s = x.size() if dim is None else x.size(dim)

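                # s records the original sizes so the inverse transform reproduces
                # x's shape exactly (irfftn/hfftn cannot infer odd lengths)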
                kwargs = {'s': s, 'dim': dim, 'norm': norm}
                y = backward(forward(x, **kwargs), **kwargs)
                # For real input, ifftn(fftn(x)) will convert to complex
                if x.dtype is torch.half and y.dtype is torch.chalf:
                    # Since type promotion currently doesn't work with complex32
                    # manually promote `x` to complex32
                    self.assertEqual(x.to(torch.chalf), y)
                else:
                    self.assertEqual(x, y, exact_dtype=(
                        forward != torch.fft.fftn or x.is_complex()))

    @onlyNativeDeviceTypes
    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
         allowed_dtypes=[torch.float, torch.cfloat])
    def test_fftn_invalid(self, device, dtype, op):
        a = torch.rand(10, 10, 10, device=device, dtype=dtype)
        # FIXME: https://github.com/pytorch/pytorch/issues/108205
        errMsg = "dims must be unique"
        with self.assertRaisesRegex(RuntimeError, errMsg):
            op(a, dim=(0, 1, 0))

        with self.assertRaisesRegex(RuntimeError, errMsg):
            op(a, dim=(2, -1))

        with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):
            op(a, s=(1,), dim=(0, 1))

        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            op(a, dim=(3,))

        with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"):
            op(a, s=(10, 10, 10, 10))

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_fftn_noop_transform(self, device, dtype):
        skip_helper_for_fft(device, dtype)
        RESULT_TYPE = {
            torch.half: torch.chalf,
            torch.float: torch.cfloat,
            torch.double: torch.cdouble,
        }

        for op in [
            torch.fft.fftn,
            torch.fft.ifftn,
            torch.fft.fft2,
            torch.fft.ifft2,
        ]:
            inp = make_tensor((10, 10), device=device, dtype=dtype)
            out = op(inp, dim=[])

            expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype)
            expect = inp.to(expect_dtype)
            self.assertEqual(expect, out)


    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double)
    def test_hfftn(self, device, dtype):
        skip_helper_for_fft(device, dtype)

        # input_ndim, dim
        transform_desc = [
            *product(range(2, 5), (None, (0,), (0, -1))),
            (6, None),
            (5, (1, 3, 4)),
            (3, (1,)),
            (1, (0,)),
            (4, (0, 1))
        ]

        for input_ndim, dim in transform_desc:
            actual_dims = list(range(input_ndim)) if dim is None else dim
            if dtype is torch.half:
                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
            else:
                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
            expect = torch.randn(*shape, device=device, dtype=dtype)
            input = torch.fft.ifftn(expect, dim=dim, norm="ortho")

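            # Keep only the onesided half of the last transformed dimension;
            # this is the half-spectrum input that hfftn expects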
            lastdim = actual_dims[-1]
            lastdim_size = input.size(lastdim) // 2 + 1
            idx = [slice(None)] * input_ndim
            idx[lastdim] = slice(0, lastdim_size)
            input = input[idx]

            s = [shape[dim] for dim in actual_dims]
            actual = torch.fft.hfftn(input, s=s, dim=dim, norm="ortho")

            self.assertEqual(expect, actual)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double)
    def test_ihfftn(self, device, dtype):
        skip_helper_for_fft(device, dtype)

        # input_ndim, dim
        transform_desc = [
            *product(range(2, 5), (None, (0,), (0, -1))),
            (6, None),
            (5, (1, 3, 4)),
            (3, (1,)),
            (1, (0,)),
            (4, (0, 1))
        ]

        for input_ndim, dim in transform_desc:
            if dtype is torch.half:
                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
            else:
                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))

            input = torch.randn(*shape, device=device, dtype=dtype)
            expect = torch.fft.ifftn(input, dim=dim, norm="ortho")

            # Slice off the half-symmetric component
            lastdim = -1 if dim is None else dim[-1]
            lastdim_size = expect.size(lastdim) // 2 + 1
            idx = [slice(None)] * input_ndim
            idx[lastdim] = slice(0, lastdim_size)
            expect = expect[idx]

            actual = torch.fft.ihfftn(input, dim=dim, norm="ortho")
            self.assertEqual(expect, actual)


    # 2d-fft tests

    # NOTE: 2d transforms are only thin wrappers over n-dim transforms,
    # so don't require exhaustive testing.


    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double, torch.complex128)
    def test_fft2_numpy(self, device, dtype):
        norm_modes = REFERENCE_NORM_MODES

        # input_ndim, s
        transform_desc = [
            *product(range(2, 5), (None, (4, 10))),
        ]

        fft_functions = ['fft2', 'ifft2', 'irfft2', 'hfft2']
        if dtype.is_floating_point:
            fft_functions += ['rfft2', 'ihfft2']

        for input_ndim, s in transform_desc:
            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
            input = torch.randn(*shape, device=device, dtype=dtype)
            for fname, norm in product(fft_functions, norm_modes):
                torch_fn = getattr(torch.fft, fname)
                if "hfft" in fname:
                    if not has_scipy_fft:
                        continue  # Requires scipy to compare against
                    numpy_fn = getattr(scipy.fft, fname)
                else:
                    numpy_fn = getattr(np.fft, fname)

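                # Also exercise the TorchScript path by scripting a small wrapper
                # around torch_fn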
                def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None):
                    return torch_fn(t, s, dim, norm)

                torch_fns = (torch_fn, torch.jit.script(fn))

                # Once with dim defaulted
                input_np = input.cpu().numpy()
                expected = numpy_fn(input_np, s, norm=norm)
                for fn in torch_fns:
                    actual = fn(input, s, norm=norm)
                    self.assertEqual(actual, expected)

                # Once with explicit dims
                dim = (1, 0)
                expected = numpy_fn(input_np, s, dim, norm)
                for fn in torch_fns:
                    actual = fn(input, s, dim, norm)
                    self.assertEqual(actual, expected)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.complex64)
    def test_fft2_fftn_equivalence(self, device, dtype):
        norm_modes = (None, "forward", "backward", "ortho")

        # input_ndim, s, dim
        transform_desc = [
            *product(range(2, 5), (None, (4, 10)), (None, (1, 0))),
            (3, None, (0, 2)),
        ]

        fft_functions = ['fft', 'ifft', 'irfft', 'hfft']
        # Real-only functions
        if dtype.is_floating_point:
            fft_functions += ['rfft', 'ihfft']

        for input_ndim, s, dim in transform_desc:
            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
            x = torch.randn(*shape, device=device, dtype=dtype)

            for func, norm in product(fft_functions, norm_modes):
                f2d = getattr(torch.fft, func + '2')
                fnd = getattr(torch.fft, func + 'n')

                kwargs = {'s': s, 'norm': norm}

                if dim is not None:
                    kwargs['dim'] = dim
                    expect = fnd(x, **kwargs)
                else:
                    expect = fnd(x, dim=(-2, -1), **kwargs)

                actual = f2d(x, **kwargs)

                self.assertEqual(actual, expect)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    def test_fft2_invalid(self, device):
        a = torch.rand(10, 10, 10, device=device)
        fft_funcs = (torch.fft.fft2, torch.fft.ifft2,
                     torch.fft.rfft2, torch.fft.irfft2)

        for func in fft_funcs:
            with self.assertRaisesRegex(RuntimeError, "dims must be unique"):
                func(a, dim=(0, 0))

            with self.assertRaisesRegex(RuntimeError, "dims must be unique"):
                func(a, dim=(2, -1))

            with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):
                func(a, s=(1,))

            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                func(a, dim=(2, 3))

        c = torch.complex(a, a)
        with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"):
            torch.fft.rfft2(c)

    # Helper functions

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double)
    def test_fftfreq_numpy(self, device, dtype):
        test_args = [
            *product(
                # n
                range(1, 20),
                # d
                (None, 10.0),
            )
        ]

        functions = ['fftfreq', 'rfftfreq']

        for fname in functions:
            torch_fn = getattr(torch.fft, fname)
            numpy_fn = getattr(np.fft, fname)

            for n, d in test_args:
                args = (n,) if d is None else (n, d)
                expected = numpy_fn(*args)
                actual = torch_fn(*args, device=device, dtype=dtype)
                self.assertEqual(actual, expected, exact_dtype=False)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_fftfreq_out(self, device, dtype):
        for func in (torch.fft.fftfreq, torch.fft.rfftfreq):
            expect = func(n=100, d=.5, device=device, dtype=dtype)
            actual = torch.empty((), device=device, dtype=dtype)
            with self.assertWarnsRegex(UserWarning, "out tensor will be resized"):
                func(n=100, d=.5, out=actual)
            self.assertEqual(actual, expect)


    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
    def test_fftshift_numpy(self, device, dtype):
        test_args = [
            # shape, dim
            *product(((11,), (12,)), (None, 0, -1)),
            *product(((4, 5), (6, 6)), (None, 0, (-1,))),
            *product(((1, 1, 4, 6, 7, 2),), (None, (3, 4))),
        ]

        functions = ['fftshift', 'ifftshift']

        for shape, dim in test_args:
            input = torch.rand(*shape, device=device, dtype=dtype)
            input_np = input.cpu().numpy()

            for fname in functions:
                torch_fn = getattr(torch.fft, fname)
                numpy_fn = getattr(np.fft, fname)

                expected = numpy_fn(input_np, axes=dim)
                actual = torch_fn(input, dim=dim)
                self.assertEqual(actual, expected)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double)
    def test_fftshift_frequencies(self, device, dtype):
        for n in range(10, 15):
            sorted_fft_freqs = torch.arange(-(n // 2), n - (n // 2),
                                            device=device, dtype=dtype)
            x = torch.fft.fftfreq(n, d=1 / n, device=device, dtype=dtype)

            # Test fftshift sorts the fftfreq output
            shifted = torch.fft.fftshift(x)
            self.assertEqual(shifted, shifted.sort().values)
            self.assertEqual(sorted_fft_freqs, shifted)

            # And ifftshift is the inverse
            self.assertEqual(x, torch.fft.ifftshift(shifted))

    # Legacy fft tests
    def _test_fft_ifft_rfft_irfft(self, device, dtype):
        complex_dtype = corresponding_complex_dtype(dtype)

        def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
            x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device))
            dim = tuple(range(-signal_ndim, 0))
            for norm in ('ortho', None):
                res = torch.fft.fftn(x, dim=dim, norm=norm)
                rec = torch.fft.ifftn(res, dim=dim, norm=norm)
                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='fft and ifft')
                res = torch.fft.ifftn(x, dim=dim, norm=norm)
                rec = torch.fft.fftn(res, dim=dim, norm=norm)
                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='ifft and fft')

        def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
            x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device))
            signal_numel = 1
            signal_sizes = x.size()[-signal_ndim:]
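            # Keep the original sizes so irfftn can restore odd-length dims,
            # which it cannot infer from the half-spectrum alone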
            dim = tuple(range(-signal_ndim, 0))
            for norm in (None, 'ortho'):
                res = torch.fft.rfftn(x, dim=dim, norm=norm)
                rec = torch.fft.irfftn(res, s=signal_sizes, dim=dim, norm=norm)
                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft')
                res = torch.fft.fftn(x, dim=dim, norm=norm)
                rec = torch.fft.ifftn(res, dim=dim, norm=norm)
                x_complex = torch.complex(x, torch.zeros_like(x))
                self.assertEqual(x_complex, rec, atol=1e-8, rtol=0, msg='fft and ifft (from real)')

        # contiguous case
        _test_real((100,), 1)
        _test_real((10, 1, 10, 100), 1)
        _test_real((100, 100), 2)
        _test_real((2, 2, 5, 80, 60), 2)
        _test_real((50, 40, 70), 3)
        _test_real((30, 1, 50, 25, 20), 3)

        _test_complex((100,), 1)
        _test_complex((100, 100), 1)
        _test_complex((100, 100), 2)
        _test_complex((1, 20, 80, 60), 2)
        _test_complex((50, 40, 70), 3)
        _test_complex((6, 5, 50, 25, 20), 3)

        # non-contiguous case
        _test_real((165,), 1, lambda x: x.narrow(0, 25, 100))  # input is not aligned to complex type
        _test_real((100, 100, 3), 1, lambda x: x[:, :, 0])
        _test_real((100, 100), 2, lambda x: x.t())
        _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60])
        _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80])
        _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3))

        _test_complex((100,), 1, lambda x: x.expand(100, 100))
        _test_complex((20, 90, 110), 2, lambda x: x[:, 5:85].narrow(2, 5, 100))
        _test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
        _test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21])

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double)
    def test_fft_ifft_rfft_irfft(self, device, dtype):
        self._test_fft_ifft_rfft_irfft(device, dtype)

    @deviceCountAtLeast(1)
    @onlyCUDA
    @dtypes(torch.double)
    def test_cufft_plan_cache(self, devices, dtype):
        @contextmanager
        def plan_cache_max_size(device, n):
            if device is None:
                plan_cache = torch.backends.cuda.cufft_plan_cache
            else:
                plan_cache = torch.backends.cuda.cufft_plan_cache[device]
            original = plan_cache.max_size
            plan_cache.max_size = n
            try:
                yield
            finally:
                plan_cache.max_size = original

        with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)):
            self._test_fft_ifft_rfft_irfft(devices[0], dtype)

        with plan_cache_max_size(devices[0], 0):
            self._test_fft_ifft_rfft_irfft(devices[0], dtype)

        torch.backends.cuda.cufft_plan_cache.clear()

        # check that it still works after clearing the cache
        with plan_cache_max_size(devices[0], 10):
            self._test_fft_ifft_rfft_irfft(devices[0], dtype)

        with self.assertRaisesRegex(RuntimeError, r"must be non-negative"):
            torch.backends.cuda.cufft_plan_cache.max_size = -1

        with self.assertRaisesRegex(RuntimeError, r"read-only property"):
            torch.backends.cuda.cufft_plan_cache.size = -1

        with self.assertRaisesRegex(RuntimeError, r"but got device with index"):
            torch.backends.cuda.cufft_plan_cache[torch.cuda.device_count() + 10]

        # Multigpu tests
        if len(devices) > 1:
            # Test that different GPU has different cache
            x0 = torch.randn(2, 3, 3, device=devices[0])
            x1 = x0.to(devices[1])
            self.assertEqual(torch.fft.rfftn(x0, dim=(-2, -1)), torch.fft.rfftn(x1, dim=(-2, -1)))
            # If a plan is used across different devices, the following line (or
            # the assert above) would trigger illegal memory access. Other ways
            # to trigger the error include
            #   (1) setting CUDA_LAUNCH_BLOCKING=1 (pytorch/pytorch#19224) and
            #   (2) printing a device 1 tensor.
            x0.copy_(x1)

            # Test that un-indexed `torch.backends.cuda.cufft_plan_cache` uses current device
            with plan_cache_max_size(devices[0], 10):
                with plan_cache_max_size(devices[1], 11):
                    self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
                    self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)

                    self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0
                    with torch.cuda.device(devices[1]):
                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1
                        with torch.cuda.device(devices[0]):
                            self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0

                self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
                with torch.cuda.device(devices[1]):
                    with plan_cache_max_size(None, 11):  # default is cuda:1
                        self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
                        self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)

                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1
                        with torch.cuda.device(devices[0]):
                            self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0
                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1

    @onlyCUDA
    @dtypes(torch.cfloat, torch.cdouble)
    def test_cufft_context(self, device, dtype):
        # Regression test for https://github.com/pytorch/pytorch/issues/109448
        x = torch.randn(32, dtype=dtype, device=device, requires_grad=True)
        dout = torch.zeros(32, dtype=dtype, device=device)

        # compute iFFT(FFT(x))
        out = torch.fft.ifft(torch.fft.fft(x))
        out.backward(dout, retain_graph=True)

        dx = torch.fft.fft(torch.fft.ifft(dout))

        self.assertTrue((x.grad - dx).abs().max() == 0)
        self.assertFalse((x.grad - x).abs().max() == 0)

    # passes on ROCm w/ python 2.7, fails w/ python 3.6
    @skipIfTorchDynamo("cannot set WRITEABLE flag to True of this array")
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double)
    def test_stft(self, device, dtype):
        if not TEST_LIBROSA:
            raise unittest.SkipTest('librosa not found')

        def librosa_stft(x, n_fft, hop_length, win_length, window, center):
            if window is None:
                window = np.ones(n_fft if win_length is None else win_length)
            else:
                window = window.cpu().numpy()
            input_1d = x.dim() == 1
            if input_1d:
                x = x.view(1, -1)

            # NOTE: librosa 0.9 changed default pad_mode to 'constant' (zero padding)
            # however, we use the pre-0.9 default ('reflect')
            pad_mode = 'reflect'

            result = []
            for xi in x:
                ri = librosa.stft(xi.cpu().numpy(), n_fft=n_fft, hop_length=hop_length,
                                  win_length=win_length, window=window, center=center,
                                  pad_mode=pad_mode)
                result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1)))
            result = torch.stack(result, 0)
            if input_1d:
                result = result[0]
            return result

        def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
                  center=True, expected_error=None):
            x = torch.randn(*sizes, dtype=dtype, device=device)
            if win_sizes is not None:
                window = torch.randn(*win_sizes, dtype=dtype, device=device)
            else:
                window = None
            if expected_error is None:
                result = x.stft(n_fft, hop_length, win_length, window,
                                center=center, return_complex=False)
                # NB: librosa defaults to np.complex64 output, no matter what
                # the input dtype
                ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
                self.assertEqual(result, ref_result, atol=7e-6, rtol=0, msg='stft comparison against librosa', exact_dtype=False)
                # With return_complex=True, the result is the same but viewed as complex instead of real
                result_complex = x.stft(n_fft, hop_length, win_length, window, center=center, return_complex=True)
                self.assertEqual(result_complex, torch.view_as_complex(result))
            else:
                self.assertRaises(expected_error,
                                  lambda: x.stft(n_fft, hop_length, win_length, window, center=center))

        for center in [True, False]:
            _test((10,), 7, center=center)
            _test((10, 4000), 1024, center=center)

            _test((10,), 7, 2, center=center)
            _test((10, 4000), 1024, 512, center=center)

            _test((10,), 7, 2, win_sizes=(7,), center=center)
            _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center)

            # spectral oversample
            _test((10,), 7, 2, win_length=5, center=center)
            _test((10, 4000), 1024, 512, win_length=100, center=center)

        _test((10, 4, 2), 1, 1, expected_error=RuntimeError)
        _test((10,), 11, 1, center=False, expected_error=RuntimeError)
        _test((10,), -1, 1, expected_error=RuntimeError)
        _test((10,), 3, win_length=5, expected_error=RuntimeError)
        _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError)
        _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)

    @skipIfTorchDynamo("double")
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double)
    def test_istft_against_librosa(self, device, dtype):
        if not TEST_LIBROSA:
            raise unittest.SkipTest('librosa not found')

        def librosa_istft(x, n_fft, hop_length, win_length, window, length, center):
            if window is None:
                window = np.ones(n_fft if win_length is None else win_length)
            else:
                window = window.cpu().numpy()

            return librosa.istft(x.cpu().numpy(), n_fft=n_fft, hop_length=hop_length,
                                 win_length=win_length, length=length, window=window, center=center)

        def _test(size, n_fft, hop_length=None, win_length=None, win_sizes=None,
                  length=None, center=True):
            x = torch.randn(size, dtype=dtype, device=device)
            if win_sizes is not None:
                window = torch.randn(*win_sizes, dtype=dtype, device=device)
            else:
                window = None

            x_stft = x.stft(n_fft, hop_length, win_length, window, center=center,
                            onesided=True, return_complex=True)

            ref_result = librosa_istft(x_stft, n_fft, hop_length, win_length,
                                       window, length, center)
            result = x_stft.istft(n_fft, hop_length, win_length, window,
                                  length=length, center=center)
            self.assertEqual(result, ref_result)

        for center in [True, False]:
            _test(10, 7, center=center)
            _test(4000, 1024, center=center)
            _test(4000, 1024, center=center, length=4000)

            _test(10, 7, 2, center=center)
            _test(4000, 1024, 512, center=center)
            _test(4000, 1024, 512, center=center, length=4000)

            _test(10, 7, 2, win_sizes=(7,), center=center)
            _test(4000, 1024, 512, win_sizes=(1024,), center=center)
            _test(4000, 1024, 512, win_sizes=(1024,), center=center, length=4000)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double, torch.cdouble)
    def test_complex_stft_roundtrip(self, device, dtype):
        test_args = list(product(
            # input
            (torch.randn(600, device=device, dtype=dtype),
             torch.randn(807, device=device, dtype=dtype),
             torch.randn(12, 60, device=device, dtype=dtype)),
            # n_fft
            (50, 27),
            # hop_length
            (None, 10),
            # center
            (True,),
            # pad_mode
            ("constant", "reflect", "circular"),
            # normalized
            (True, False),
            # onesided
            (True, False) if not dtype.is_complex else (False,),
        ))

        for args in test_args:
            x, n_fft, hop_length, center, pad_mode, normalized, onesided = args
            common_kwargs = {
                'n_fft': n_fft, 'hop_length': hop_length, 'center': center,
                'normalized': normalized, 'onesided': onesided,
            }

            # Functional interface
            x_stft = torch.stft(x, pad_mode=pad_mode, return_complex=True, **common_kwargs)
            x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex,
                                      length=x.size(-1), **common_kwargs)
            self.assertEqual(x_roundtrip, x)

            # Tensor method interface
            x_stft = x.stft(pad_mode=pad_mode, return_complex=True, **common_kwargs)
            x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex,
                                      length=x.size(-1), **common_kwargs)
            self.assertEqual(x_roundtrip, x)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double, torch.cdouble)
    def test_stft_roundtrip_complex_window(self, device, dtype):
        test_args = list(product(
            # input
            (torch.randn(600, device=device, dtype=dtype),
             torch.randn(807, device=device, dtype=dtype),
             torch.randn(12, 60, device=device, dtype=dtype)),
            # n_fft
            (50, 27),
            # hop_length
            (None, 10),
            # pad_mode
            ("constant", "reflect", "replicate", "circular"),
            # normalized
            (True, False),
        ))
        for args in test_args:
            x, n_fft, hop_length, pad_mode, normalized = args
            window = torch.rand(n_fft, device=device, dtype=torch.cdouble)
            x_stft = torch.stft(
                x, n_fft=n_fft, hop_length=hop_length, window=window,
                center=True, pad_mode=pad_mode, normalized=normalized)
            self.assertEqual(x_stft.dtype, torch.cdouble)
            self.assertEqual(x_stft.size(-2), n_fft)  # Not onesided

            x_roundtrip = torch.istft(
                x_stft, n_fft=n_fft, hop_length=hop_length, window=window,
                center=True, normalized=normalized, length=x.size(-1),
                return_complex=True)
            self.assertEqual(x_stft.dtype, torch.cdouble)

            if not dtype.is_complex:
                self.assertEqual(x_roundtrip.imag, torch.zeros_like(x_roundtrip.imag),
                                 atol=1e-6, rtol=0)
                self.assertEqual(x_roundtrip.real, x)
            else:
                self.assertEqual(x_roundtrip, x)


    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_stft_definition(self, device, dtype):
        test_args = list(product(
            # input
            (torch.randn(600, device=device, dtype=dtype),
             torch.randn(807, device=device, dtype=dtype)),
            # n_fft
            (50, 27),
            # hop_length
            (10, 15)
        ))

        for args in test_args:
            window = torch.randn(args[1], device=device, dtype=dtype)
            expected = _stft_reference(args[0], args[2], window)
            actual = torch.stft(*args, window=window, center=False)
            self.assertEqual(actual, expected)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_stft_real_equiv(self, device, dtype):
        test_args = list(product(
            # input
            (torch.rand(600, device=device, dtype=dtype),
             torch.rand(807, device=device, dtype=dtype),
             torch.rand(14, 50, device=device, dtype=dtype),
             torch.rand(6, 51, device=device, dtype=dtype)),
            # n_fft
            (50, 27),
            # hop_length
            (None, 10),
            # win_length
            (None, 20),
            # center
            (False, True),
            # pad_mode
            ("constant", "reflect", "circular"),
            # normalized
            (True, False),
        ))

        for args in test_args:
            x, n_fft, hop_length, win_length, center, pad_mode, normalized = args
            expected = _complex_stft(x, n_fft, hop_length=hop_length,
                                     win_length=win_length, pad_mode=pad_mode,
                                     center=center, normalized=normalized)
            actual = torch.stft(x, n_fft, hop_length=hop_length,
                                win_length=win_length, pad_mode=pad_mode,
                                center=center, normalized=normalized)
            self.assertEqual(expected, actual)

    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_istft_real_equiv(self, device, dtype):
        test_args = list(product(
            # input
            (torch.rand(40, 20, device=device, dtype=dtype),
             torch.rand(25, 1, device=device, dtype=dtype),
             torch.rand(4, 20, 10, device=device, dtype=dtype)),
            # hop_length
            (None, 10),
            # center
            (False, True),
            # normalized
            (True, False),
        ))

        for args in test_args:
            x, hop_length, center, normalized = args
            n_fft = x.size(-2)
            expected = _complex_istft(x, n_fft, hop_length=hop_length,
                                      center=center, normalized=normalized)
            actual = torch.istft(x, n_fft, hop_length=hop_length,
                                 center=center, normalized=normalized,
                                 return_complex=True)
            self.assertEqual(expected, actual)

    @skipCPUIfNoFFT
    def test_complex_stft_onesided(self, device):
        # stft of complex input cannot be onesided
        for x_dtype, window_dtype in product((torch.double, torch.cdouble), repeat=2):
            x = torch.rand(100, device=device, dtype=x_dtype)
            window = torch.rand(10, device=device, dtype=window_dtype)

            if x_dtype.is_complex or window_dtype.is_complex:
                with self.assertRaisesRegex(RuntimeError, 'complex'):
                    x.stft(10, window=window, pad_mode='constant', onesided=True)
            else:
                y = x.stft(10, window=window, pad_mode='constant', onesided=True,
                           return_complex=True)
                self.assertEqual(y.dtype, torch.cdouble)
                self.assertEqual(y.size(), (6, 51))

        x = torch.rand(100, device=device, dtype=torch.cdouble)
        with self.assertRaisesRegex(RuntimeError, 'complex'):
            x.stft(10, pad_mode='constant', onesided=True)

    # stft is currently warning that it requires return_complex while an upgrader is written
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_stft_requires_complex(self, device):
        x = torch.rand(100)
        with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
            y = x.stft(10, pad_mode='constant')

    # stft and istft are currently warning if a window is not provided
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_stft_requires_window(self, device):
        x = torch.rand(100)
        with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"):
            y = x.stft(10, pad_mode='constant', return_complex=True)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_istft_requires_window(self, device):
        stft = torch.rand((51, 5), dtype=torch.cdouble)
        # 51 = n_fft // 2 + 1 onesided frequency bins (n_fft=100), 5 = number of frames
        with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"):
            x = torch.istft(stft, n_fft=100, length=100)

    @skipCPUIfNoFFT
    def test_fft_input_modification(self, device):
        # FFT functions should not modify their input (gh-34551)

        signal = torch.ones((2, 2, 2), device=device)
        signal_copy = signal.clone()
        spectrum = torch.fft.fftn(signal, dim=(-2, -1))
        self.assertEqual(signal, signal_copy)

        spectrum_copy = spectrum.clone()
        _ = torch.fft.ifftn(spectrum, dim=(-2, -1))
        self.assertEqual(spectrum, spectrum_copy)

        half_spectrum = torch.fft.rfftn(signal, dim=(-2, -1))
        self.assertEqual(signal, signal_copy)

        half_spectrum_copy = half_spectrum.clone()
        _ = torch.fft.irfftn(half_spectrum_copy, s=(2, 2), dim=(-2, -1))
        self.assertEqual(half_spectrum, half_spectrum_copy)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_fft_plan_repeatable(self, device):
        # Regression test for gh-58724 and gh-63152
        for n in [2048, 3199, 5999]:
            a = torch.randn(n, device=device, dtype=torch.complex64)
            res1 = torch.fft.fftn(a)
            res2 = torch.fft.fftn(a.clone())
            self.assertEqual(res1, res2)

            a = torch.randn(n, device=device, dtype=torch.float64)
            res1 = torch.fft.rfft(a)
            res2 = torch.fft.rfft(a.clone())
            self.assertEqual(res1, res2)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_round_trip_simple_cases(self, device, dtype):
        """stft -> istft should recover the original signal"""
        def _test(input, n_fft, length):
            stft = torch.stft(input, n_fft=n_fft, return_complex=True)
            inverse = torch.istft(stft, n_fft=n_fft, length=length)
            self.assertEqual(input, inverse, exact_dtype=True)

        _test(torch.ones(4, dtype=dtype, device=device), 4, 4)
        _test(torch.zeros(4, dtype=dtype, device=device), 4, 4)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_round_trip_various_params(self, device, dtype):
        """stft -> istft should recover the original signal"""
        def _test_istft_is_inverse_of_stft(stft_kwargs):
            # generates a random sound signal for each trial and then does the stft/istft
            # operation to check whether we can reconstruct the signal
            data_sizes = [(2, 20), (3, 15), (4, 10)]
            num_trials = 100
            istft_kwargs = stft_kwargs.copy()
            del istft_kwargs['pad_mode']
            for sizes in data_sizes:
                for i in range(num_trials):
                    original = torch.randn(*sizes, dtype=dtype, device=device)
                    stft = torch.stft(original, return_complex=True, **stft_kwargs)
                    inversed = torch.istft(stft, length=original.size(1), **istft_kwargs)
                    self.assertEqual(
                        inversed, original, msg='istft comparison against original',
                        atol=7e-6, rtol=0, exact_dtype=True)

        patterns = [
            # hann_window, centered, normalized, onesided
            {
                'n_fft': 12,
                'hop_length': 4,
                'win_length': 12,
                'window': torch.hann_window(12, dtype=dtype, device=device),
                'center': True,
                'pad_mode': 'reflect',
                'normalized': True,
                'onesided': True,
            },
            # hann_window, centered, not normalized, not onesided
            {
                'n_fft': 12,
                'hop_length': 2,
                'win_length': 8,
                'window': torch.hann_window(8, dtype=dtype, device=device),
                'center': True,
                'pad_mode': 'reflect',
                'normalized': False,
                'onesided': False,
            },
            # hamming_window, centered, normalized, not onesided
            {
                'n_fft': 15,
                'hop_length': 3,
                'win_length': 11,
                'window': torch.hamming_window(11, dtype=dtype, device=device),
                'center': True,
                'pad_mode': 'constant',
                'normalized': True,
                'onesided': False,
            },
            # hamming_window, centered, not normalized, onesided
            # window same size as n_fft
            {
                'n_fft': 5,
                'hop_length': 2,
                'win_length': 5,
                'window': torch.hamming_window(5, dtype=dtype, device=device),
                'center': True,
                'pad_mode': 'constant',
                'normalized': False,
                'onesided': True,
            },
        ]
        for i, pattern in enumerate(patterns):
            _test_istft_is_inverse_of_stft(pattern)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_round_trip_with_padding(self, device, dtype):
1367        """A long hop_length or center=False may cause a length mismatch in the reconstructed signal"""
1368        def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs):
1369            # generate a random sound signal for each trial, then run stft/istft
1370            # and check that the original signal is reconstructed
1371            num_trials = 100
1372            sizes = stft_kwargs['size']
1373            del stft_kwargs['size']
1374            istft_kwargs = stft_kwargs.copy()
1375            del istft_kwargs['pad_mode']
1376            for i in range(num_trials):
1377                original = torch.randn(*sizes, dtype=dtype, device=device)
1378                stft = torch.stft(original, return_complex=True, **stft_kwargs)
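                # istft warns here because the requested length exceeds the number of
                # samples it can reconstruct from these frames; the tail is zero-padded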
1379                with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."):
1380                    inversed = torch.istft(stft, length=original.size(-1), **istft_kwargs)
1381                n_frames = stft.size(-1)
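                # Expected reconstructible length: every additional frame advances the
                # signal by hop_length; with center=True, istft also trims the padding
                # that the centered stft added, so fewer samples remain than in the
                # uncentered case.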
1382                if stft_kwargs["center"] is True:
1383                    len_expected = stft_kwargs["n_fft"] // 2 + stft_kwargs["hop_length"] * (n_frames - 1)
1384                else:
1385                    len_expected = stft_kwargs["n_fft"] + stft_kwargs["hop_length"] * (n_frames - 1)
1386                # trim both signals to the expected length for the case when the reconstructed signal is shorter than the original
1387                padding = inversed[..., len_expected:]
1388                inversed = inversed[..., :len_expected]
1389                original = original[..., :len_expected]
1390                # check that the padding points of the reconstructed signal are all zeros
1391                zeros = torch.zeros_like(padding)
1392                self.assertEqual(
1393                    padding, zeros, msg='istft padding values against zeros',
1394                    atol=7e-6, rtol=0, exact_dtype=True)
1395                self.assertEqual(
1396                    inversed, original, msg='istft comparison against original',
1397                    atol=7e-6, rtol=0, exact_dtype=True)
1398
1399        patterns = [
1400            # hamming_window, not centered, not normalized, not onesided
1401            # window same size as n_fft
1402            {
1403                'size': [2, 20],
1404                'n_fft': 3,
1405                'hop_length': 2,
1406                'win_length': 3,
1407                'window': torch.hamming_window(3, dtype=dtype, device=device),
1408                'center': False,
1409                'pad_mode': 'reflect',
1410                'normalized': False,
1411                'onesided': False,
1412            },
1413            # hamming_window, centered, not normalized, onesided, long hop_length
1414            # window same size as n_fft
1415            {
1416                'size': [2, 500],
1417                'n_fft': 256,
1418                'hop_length': 254,
1419                'win_length': 256,
1420                'window': torch.hamming_window(256, dtype=dtype, device=device),
1421                'center': True,
1422                'pad_mode': 'constant',
1423                'normalized': False,
1424                'onesided': True,
1425            },
1426        ]
1427        for i, pattern in enumerate(patterns):
1428            _test_istft_is_inverse_of_stft_with_padding(pattern)
1429
1430    @onlyNativeDeviceTypes
1431    def test_istft_throws(self, device):
1432        """istft should throw exception for invalid parameters"""
1433        stft = torch.zeros((3, 5, 2), device=device)
1434        # the window has size 1 but hops by 20, so there are gaps in coverage, which throws an error
1435        self.assertRaises(
1436            RuntimeError, torch.istft, stft, n_fft=4,
1437            hop_length=20, win_length=1, window=torch.ones(1))
1438        # A window of zeros does not satisfy the NOLA (nonzero overlap-add) constraint
1439        invalid_window = torch.zeros(4, device=device)
1440        self.assertRaises(
1441            RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window)
1442        # Input cannot be empty
1443        self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2)), 2)
1444        self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2)
1445
1446    @skipIfTorchDynamo("Failed running call_function")
1447    @onlyNativeDeviceTypes
1448    @skipCPUIfNoFFT
1449    @dtypes(torch.double)
1450    def test_istft_of_sine(self, device, dtype):
1451        complex_dtype = corresponding_complex_dtype(dtype)
1452
1453        def _test(amplitude, L, n):
1454            # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
1455            x = torch.arange(2 * L + 1, device=device, dtype=dtype)
1456            original = amplitude * torch.sin(2 * math.pi / L * x * n)
1457            # stft = torch.stft(original, L, hop_length=L, win_length=L,
1458            #                   window=torch.ones(L), center=False, normalized=False)
1459            stft = torch.zeros((L // 2 + 1, 2), device=device, dtype=complex_dtype)
1460            stft_largest_val = (amplitude * L) / 2.0
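            # The DFT of amplitude * sin(2*pi*n*x/L) over L samples is purely imaginary:
            # -j * amplitude * L / 2 at bin n and +j * amplitude * L / 2 at bin L - n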
1461            if n < stft.size(0):
1462                stft[n].imag = torch.tensor(-stft_largest_val, dtype=dtype)
1463
1464            if 0 <= L - n < stft.size(0):
1465                # symmetric about L // 2
1466                stft[L - n].imag = torch.tensor(stft_largest_val, dtype=dtype)
1467
1468            inverse = torch.istft(
1469                stft, L, hop_length=L, win_length=L,
1470                window=torch.ones(L, device=device, dtype=dtype), center=False, normalized=False)
1471            # A larger tolerance is needed because of the large amplitude scaling
1472            original = original[..., :inverse.size(-1)]
1473            self.assertEqual(inverse, original, atol=1e-3, rtol=0)
1474
1475        _test(amplitude=123, L=5, n=1)
1476        _test(amplitude=150, L=5, n=2)
1477        _test(amplitude=111, L=5, n=3)
1478        _test(amplitude=160, L=7, n=4)
1479        _test(amplitude=145, L=8, n=5)
1480        _test(amplitude=80, L=9, n=6)
1481        _test(amplitude=99, L=10, n=7)
1482
1483    @onlyNativeDeviceTypes
1484    @skipCPUIfNoFFT
1485    @dtypes(torch.double)
1486    def test_istft_linearity(self, device, dtype):
1487        num_trials = 100
1488        complex_dtype = corresponding_complex_dtype(dtype)
1489
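        # istft is linear: istft(a*X1 + b*X2) should equal a*istft(X1) + b*istft(X2)
        # for random spectra X1, X2 and random real coefficients a, b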
1490        def _test(data_size, kwargs):
1491            for i in range(num_trials):
1492                tensor1 = torch.randn(data_size, device=device, dtype=complex_dtype)
1493                tensor2 = torch.randn(data_size, device=device, dtype=complex_dtype)
1494                a, b = torch.rand(2, dtype=dtype, device=device)
1495                # Also compare method vs. functional call signature
1496                istft1 = tensor1.istft(**kwargs)
1497                istft2 = tensor2.istft(**kwargs)
1498                istft = a * istft1 + b * istft2
1499                estimate = torch.istft(a * tensor1 + b * tensor2, **kwargs)
1500                self.assertEqual(istft, estimate, atol=1e-5, rtol=0)
1501        patterns = [
1502            # hann_window, centered, normalized, onesided
1503            (
1504                (2, 7, 7),
1505                {
1506                    'n_fft': 12,
1507                    'window': torch.hann_window(12, device=device, dtype=dtype),
1508                    'center': True,
1509                    'normalized': True,
1510                    'onesided': True,
1511                },
1512            ),
1513            # hann_window, centered, not normalized, not onesided
1514            (
1515                (2, 12, 7),
1516                {
1517                    'n_fft': 12,
1518                    'window': torch.hann_window(12, device=device, dtype=dtype),
1519                    'center': True,
1520                    'normalized': False,
1521                    'onesided': False,
1522                },
1523            ),
1524            # hamming_window, centered, normalized, not onesided
1525            (
1526                (2, 12, 7),
1527                {
1528                    'n_fft': 12,
1529                    'window': torch.hamming_window(12, device=device, dtype=dtype),
1530                    'center': True,
1531                    'normalized': True,
1532                    'onesided': False,
1533                },
1534            ),
1535            # hamming_window, not centered, not normalized, onesided
1536            (
1537                (2, 7, 3),
1538                {
1539                    'n_fft': 12,
1540                    'window': torch.hamming_window(12, device=device, dtype=dtype),
1541                    'center': False,
1542                    'normalized': False,
1543                    'onesided': True,
1544                },
1545            )
1546        ]
1547        for data_size, kwargs in patterns:
1548            _test(data_size, kwargs)
1549
1550    @onlyNativeDeviceTypes
1551    @skipCPUIfNoFFT
1552    def test_batch_istft(self, device):
1553        original = torch.tensor([
1554            [4., 4., 4., 4., 4.],
1555            [0., 0., 0., 0., 0.],
1556            [0., 0., 0., 0., 0.]
1557        ], device=device, dtype=torch.complex64)
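        # Only the DC bin is non-zero, so each frame inverts to a constant signal;
        # batched and unbatched istft calls must produce matching results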
1558
1559        single = original.repeat(1, 1, 1)
1560        multi = original.repeat(4, 1, 1)
1561
1562        i_original = torch.istft(original, n_fft=4, length=4)
1563        i_single = torch.istft(single, n_fft=4, length=4)
1564        i_multi = torch.istft(multi, n_fft=4, length=4)
1565
1566        self.assertEqual(i_original.repeat(1, 1), i_single, atol=1e-6, rtol=0, exact_dtype=True)
1567        self.assertEqual(i_original.repeat(4, 1), i_multi, atol=1e-6, rtol=0, exact_dtype=True)
1568
1569    @onlyCUDA
1570    @skipIf(not TEST_MKL, "Test requires MKL")
1571    def test_stft_window_device(self, device):
1572        # Test that the (i)stft window must be on the same device as the input
1573        x = torch.randn(1000, dtype=torch.complex64)
1574        window = torch.randn(100, dtype=torch.complex64)
1575
1576        with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
1577            torch.stft(x, n_fft=100, window=window.to(device))
1578
1579        with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
1580            torch.stft(x.to(device), n_fft=100, window=window)
1581
1582        X = torch.stft(x, n_fft=100, window=window)
1583
1584        with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
1585            torch.istft(X, n_fft=100, window=window.to(device))
1586
1587        with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
1588            torch.istft(x.to(device), n_fft=100, window=window)
1589
1590
1591class FFTDocTestFinder:
1592    '''The default doctest finder rejects functions whose function.__module__ does
1593    not match torch.fft, assuming they are imports leaked into the namespace.
1594    '''
1595    def __init__(self) -> None:
1596        self.parser = doctest.DocTestParser()
1597
1598    def find(self, obj, name=None, module=None, globs=None, extraglobs=None):
1599        doctests = []
1600
1601        modname = name if name is not None else obj.__name__
1602        globs = {} if globs is None else globs
1603
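        # Walk the module's public API (obj.__all__) and collect the doctest
        # examples from each routine's docstring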
1604        for fname in obj.__all__:
1605            func = getattr(obj, fname)
1606            if inspect.isroutine(func):
1607                qualname = modname + '.' + fname
1608                docstring = inspect.getdoc(func)
1609                if docstring is None:
1610                    continue
1611
1612                examples = self.parser.get_doctest(
1613                    docstring, globs=globs, name=fname, filename=None, lineno=None)
1614                doctests.append(examples)
1615
1616        return doctests
1617
1618
1619class TestFFTDocExamples(TestCase):
1620    pass
1621
1622def generate_doc_test(doc_test):
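    # Wrap a single collected doctest in a test method; these examples are only
    # instantiated for CPU (see the only_for='cpu' instantiation below)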
1623    def test(self, device):
1624        self.assertEqual(device, 'cpu')
1625        runner = doctest.DocTestRunner()
1626        runner.run(doc_test)
1627
1628        if runner.failures != 0:
1629            runner.summarize()
1630            self.fail('Doctest failed')
1631
1632    setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoFFT(test))
1633
1634for doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)):
1635    generate_doc_test(doc_test)
1636
1637
1638instantiate_device_type_tests(TestFFT, globals())
1639instantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu')
1640
1641if __name__ == '__main__':
1642    run_tests()
1643