# Owner(s): ["module: fft"]

import torch
import unittest
import math
from contextlib import contextmanager
from itertools import product
import itertools
import doctest
import inspect

from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
     make_tensor, skipIfTorchDynamo)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
from torch.testing._internal.common_methods_invocations import (
    spectral_funcs, SpectralFuncType)
from torch.testing._internal.common_cuda import SM53OrLater
from torch._prims_common import corresponding_complex_dtype

from typing import Optional, List
from packaging import version


if TEST_NUMPY:
    import numpy as np


if TEST_LIBROSA:
    import librosa

has_scipy_fft = False
try:
    import scipy.fft
    has_scipy_fft = True
except ModuleNotFoundError:
    pass

# Normalization modes the reference (NumPy / SciPy) implementations understand.
# The full set is only used when NumPy >= 1.20 and, if scipy.fft is importable,
# SciPy >= 1.6; otherwise only (None, "ortho") is used.
# NumPy is an optional dependency of this file, so `np` is only touched when
# TEST_NUMPY is set; without it we fall back to the conservative set instead
# of failing with a NameError at import time.
if TEST_NUMPY and version.parse(np.__version__) >= version.parse('1.20.0') and (
        not has_scipy_fft or version.parse(scipy.__version__) >= version.parse('1.6.0')):
    REFERENCE_NORM_MODES = (None, "forward", "backward", "ortho")
else:
    REFERENCE_NORM_MODES = (None, "ortho")


def _complex_stft(x, *args, **kwargs):
    """Two-sided stft of complex ``x`` computed from two real-input stfts.

    Extra positional and keyword arguments are forwarded to ``torch.stft``.
    """
    # Transform real and imaginary components separably
    stft_real = torch.stft(x.real, *args, **kwargs, return_complex=True, onesided=False)
    stft_imag = torch.stft(x.imag, *args, **kwargs, return_complex=True, onesided=False)
    return stft_real + 1j * stft_imag


def _hermitian_conj(x, dim):
    """Returns the hermitian conjugate along a single dimension

    H(x)[i] = conj(x[-i])

    Indices are taken modulo ``x.size(dim)``, so element 0 (and the middle
    element of an even-length dimension) maps onto itself.
    """
    # flip gives x[n - 1 - i]; rolling that by one position gives x[(n - i) % n],
    # which is exactly x[-i] with wrap-around. Unlike explicit slice
    # bookkeeping, this also handles lengths 1 and 2 along `dim`.
    reversed_x = x.flip(dim).roll(1, dim)
    return reversed_x.conj()


def _complex_istft(x, *args, **kwargs):
    """istft of a two-sided spectrogram ``x`` computed from two onesided istfts.

    The frequency axis is taken to be ``x.size(-2)``. Extra positional and
    keyword arguments are forwarded to ``torch.istft``.
    """
    # Decompose into Hermitian (FFT of real) and anti-Hermitian (FFT of imaginary)
    n_fft = x.size(-2)
    # Keep only the non-redundant half of the frequency axis
    slc = (Ellipsis, slice(None, n_fft // 2 + 1), slice(None))

    hconj = _hermitian_conj(x, dim=-2)
    x_hermitian = (x + hconj) / 2
    x_antihermitian = (x - hconj) / 2
    istft_real = torch.istft(x_hermitian[slc], *args, **kwargs, onesided=True)
    istft_imag = torch.istft(-1j * x_antihermitian[slc], *args, **kwargs, onesided=True)
    return torch.complex(istft_real, istft_imag)


def _stft_reference(x, hop_length, window):
    r"""Reference stft implementation

    This doesn't implement all of torch.stft, only the STFT definition:

    .. math:: X(m, \omega) = \sum_n x[n]w[n - m] e^{-jn\omega}

    The frame length is ``window.numel()`` and the result is a cdouble tensor
    with one column per frame.
    """
    n_fft = window.numel()
    X = torch.empty((n_fft, (x.numel() - n_fft + hop_length) // hop_length),
                    device=x.device, dtype=torch.cdouble)
    for m in range(X.size(1)):
        start = m * hop_length
        if start + n_fft > x.numel():
            # Zero-pad a short trailing frame rather than leaving the tail
            # of the buffer uninitialized.
            slc = torch.zeros(n_fft, device=x.device, dtype=x.dtype)
            tmp = x[start:]
            slc[:tmp.numel()] = tmp
        else:
            slc = x[start: start + n_fft]
        X[:, m] = torch.fft.fft(slc * window)

    return X


def skip_helper_for_fft(device, dtype):
    """Raise ``unittest.SkipTest`` for half / complex32 on CPU or when SM53OrLater is false.

    Returns without doing anything for every other dtype.
    """
    device_type = torch.device(device).type
    if dtype not in (torch.half, torch.complex32):
        return

    if device_type == 'cpu':
        raise unittest.SkipTest("half and complex32 are not supported on CPU")
    if not SM53OrLater:
        raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>53")


# Tests of functions related to Fourier analysis in the torch.fft namespace
class TestFFT(TestCase):
    exact_dtype = True

    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD],
         allowed_dtypes=(torch.float, torch.cfloat))
    def test_reference_1d(self, device, dtype, op):
        # Compare every 1-d spectral op against its reference (op.ref)
        if op.ref is None:
            raise unittest.SkipTest("No reference implementation")

        norm_modes = REFERENCE_NORM_MODES
        test_args = [
            *product(
                # input
                (torch.randn(67, device=device, dtype=dtype),
                 torch.randn(80, device=device, dtype=dtype),
                 torch.randn(12, 14, device=device, dtype=dtype),
                 torch.randn(9, 6, 3, device=device, dtype=dtype)),
                # n
                (None, 50, 6),
                # dim
                (-1, 0),
                # norm
                norm_modes
            ),
            # Test transforming middle dimensions of multi-dim tensor
            *product(
                (torch.randn(4, 5, 6, 7, device=device, dtype=dtype),),
                (None,),
                (1, 2, -2,),
                norm_modes
            )
        ]

        for iargs in test_args:
            args = list(iargs)
            input = args[0]
            args = args[1:]

            expected = op.ref(input.cpu().numpy(), *args)
            exact_dtype = dtype in (torch.double, torch.complex128)
            actual = op(input, *args)
            self.assertEqual(actual, expected, exact_dtype=exact_dtype)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
        torch.chalf : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
    def test_fft_round_trip(self, device, dtype):
        skip_helper_for_fft(device, dtype)
        # Test that round trip through ifft(fft(x)) is the identity
        if dtype not in (torch.half, torch.complex32):
            test_args = list(product(
                # input
                (torch.randn(67, device=device, dtype=dtype),
                 torch.randn(80, device=device, dtype=dtype),
                 torch.randn(12, 14, device=device, dtype=dtype),
                 torch.randn(9, 6, 3, device=device, dtype=dtype)),
                # dim
                (-1, 0),
                # norm
                (None, "forward", "backward", "ortho")
            ))
        else:
            # cuFFT supports powers of 2 for half and complex half precision
            test_args = list(product(
                # input
                (torch.randn(64, device=device, dtype=dtype),
                 torch.randn(128, device=device, dtype=dtype),
                 torch.randn(4, 16, device=device, dtype=dtype),
                 torch.randn(8, 6, 2, device=device, dtype=dtype)),
                # dim
                (-1, 0),
                # norm
                (None, "forward", "backward", "ortho")
            ))

        fft_functions = [(torch.fft.fft, torch.fft.ifft)]
        # Real-only functions
        if not dtype.is_complex:
            # NOTE: Using ihfft as "forward" transform to avoid needing to
            # generate true half-complex input
            fft_functions += [(torch.fft.rfft, torch.fft.irfft),
                              (torch.fft.ihfft, torch.fft.hfft)]

        for forward, backward in fft_functions:
            for x, dim, norm in test_args:
                kwargs = {
                    'n': x.size(dim),
                    'dim': dim,
                    'norm': norm,
                }

                y = backward(forward(x, **kwargs), **kwargs)
                if x.dtype is torch.half and y.dtype is torch.complex32:
                    # Since type promotion currently doesn't work with complex32
                    # manually promote `x` to complex32
                    x = x.to(torch.complex32)
                # For real input, ifft(fft(x)) will convert to complex
                self.assertEqual(x, y, exact_dtype=(
                    forward != torch.fft.fft or x.is_complex()))

    # Note: NumPy will throw a ValueError for an empty input
    @onlyNativeDeviceTypes
    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
    def test_empty_fft(self, device, dtype, op):
        # Every spectral op must reject an input with a zero-length signal dimension
        t = torch.empty(1, 0, device=device, dtype=dtype)
        match = r"Invalid number of data points \([-\d]*\) specified"

        with self.assertRaisesRegex(RuntimeError, match):
            op(t)

    @onlyNativeDeviceTypes
    def test_empty_ifft(self, device):
        # The c2r transforms must reject a (2, 1) complex input with default arguments
        t = torch.empty(2, 1, device=device, dtype=torch.complex64)
        match = r"Invalid number of data points \([-\d]*\) specified"

        for f in [torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn,
                  torch.fft.hfft, torch.fft.hfft2, torch.fft.hfftn]:
            with self.assertRaisesRegex(RuntimeError, match):
                f(t)

    @onlyNativeDeviceTypes
    def test_fft_invalid_dtypes(self, device):
        # Real-input transforms must reject complex tensors
        t = torch.randn(64, device=device, dtype=torch.complex128)

        with self.assertRaisesRegex(RuntimeError, "rfft expects a real input tensor"):
            torch.fft.rfft(t)

        with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input tensor"):
            torch.fft.rfftn(t)

        with self.assertRaisesRegex(RuntimeError, "ihfft expects a real input tensor"):
            torch.fft.ihfft(t)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.half, torch.float, torch.double,
            torch.complex32, torch.complex64, torch.complex128)
    def test_fft_type_promotion(self, device, dtype):
        # Check the result dtype of fft (c2c), hfft (c2r) and rfft (r2c)
        skip_helper_for_fft(device, dtype)

        if dtype.is_complex or dtype.is_floating_point:
            t = torch.randn(64, device=device, dtype=dtype)
        else:
            t = torch.randint(-2, 2, (64,), device=device, dtype=dtype)

        PROMOTION_MAP = {
            torch.int8: torch.complex64,
            torch.half: torch.complex32,
            torch.float: torch.complex64,
            torch.double: torch.complex128,
            torch.complex32: torch.complex32,
            torch.complex64: torch.complex64,
            torch.complex128: torch.complex128,
        }
        T = torch.fft.fft(t)
        self.assertEqual(T.dtype, PROMOTION_MAP[dtype])

        PROMOTION_MAP_C2R = {
            torch.int8: torch.float,
            torch.half: torch.half,
            torch.float: torch.float,
            torch.double: torch.double,
            torch.complex32: torch.half,
            torch.complex64: torch.float,
            torch.complex128: torch.double,
        }
        if dtype in (torch.half, torch.complex32):
            # cuFFT supports powers of 2 for half and complex half precision
            # NOTE: With hfft and default args where output_size n=2*(input_size - 1),
            # we make sure that logical fft size is a power of two.
            x = torch.randn(65, device=device, dtype=dtype)
            R = torch.fft.hfft(x)
        else:
            R = torch.fft.hfft(t)
        self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype])

        if not dtype.is_complex:
            PROMOTION_MAP_R2C = {
                torch.int8: torch.complex64,
                torch.half: torch.complex32,
                torch.float: torch.complex64,
                torch.double: torch.complex128,
            }
            C = torch.fft.rfft(t)
            self.assertEqual(C.dtype, PROMOTION_MAP_R2C[dtype])

    @onlyNativeDeviceTypes
    @ops(spectral_funcs, dtypes=OpDTypes.unsupported,
         allowed_dtypes=[torch.half, torch.bfloat16])
    def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
        # Ops that list half / bfloat16 as unsupported must raise a RuntimeError
        # TODO: Remove torch.half error when complex32 is fully implemented
        sample = first_sample(self, op.sample_inputs(device, dtype))
        device_type = torch.device(device).type
        default_msg = "Unsupported dtype"
        if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
            err_msg = default_msg
        elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
            err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
        else:
            err_msg = default_msg
        with self.assertRaisesRegex(RuntimeError, err_msg):
            op(sample.input, *sample.args, **sample.kwargs)

@onlyNativeDeviceTypes @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf)) def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op): t = make_tensor(13, 13, device=device, dtype=dtype) err_msg = "cuFFT only supports dimensions whose sizes are powers of two" with self.assertRaisesRegex(RuntimeError, err_msg): op(t) if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD): kwargs = {'s': (12, 12)} else: kwargs = {'n': 12} with self.assertRaisesRegex(RuntimeError, err_msg): op(t, **kwargs) # nd-fft tests @onlyNativeDeviceTypes @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND], allowed_dtypes=(torch.cfloat, torch.cdouble)) def test_reference_nd(self, device, dtype, op): if op.ref is None: raise unittest.SkipTest("No reference implementation") norm_modes = REFERENCE_NORM_MODES # input_ndim, s, dim transform_desc = [ *product(range(2, 5), (None,), (None, (0,), (0, -1))), *product(range(2, 5), (None, (4, 10)), (None,)), (6, None, None), (5, None, (1, 3, 4)), (3, None, (1,)), (1, None, (0,)), (4, (10, 10), None), (4, (10, 10), (0, 1)) ] for input_ndim, s, dim in transform_desc: shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) input = torch.randn(*shape, device=device, dtype=dtype) for norm in norm_modes: expected = op.ref(input.cpu().numpy(), s, dim, norm) exact_dtype = dtype in (torch.double, torch.complex128) actual = op(input, s, dim, norm) self.assertEqual(actual, expected, exact_dtype=exact_dtype) @skipCPUIfNoFFT @onlyNativeDeviceTypes @toleranceOverride({ torch.half : tol(1e-2, 1e-2), torch.chalf : tol(1e-2, 1e-2), }) @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128) def test_fftn_round_trip(self, device, dtype): skip_helper_for_fft(device, dtype) norm_modes = (None, "forward", "backward", "ortho") # input_ndim, dim transform_desc = [ *product(range(2, 5), (None, (0,), (0, -1))), (7, 
None), (5, (1, 3, 4)), (3, (1,)), (1, 0), ] fft_functions = [(torch.fft.fftn, torch.fft.ifftn)] # Real-only functions if not dtype.is_complex: # NOTE: Using ihfftn as "forward" transform to avoid needing to # generate true half-complex input fft_functions += [(torch.fft.rfftn, torch.fft.irfftn), (torch.fft.ihfftn, torch.fft.hfftn)] for input_ndim, dim in transform_desc: if dtype in (torch.half, torch.complex32): # cuFFT supports powers of 2 for half and complex half precision shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim) else: shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) x = torch.randn(*shape, device=device, dtype=dtype) for (forward, backward), norm in product(fft_functions, norm_modes): if isinstance(dim, tuple): s = [x.size(d) for d in dim] else: s = x.size() if dim is None else x.size(dim) kwargs = {'s': s, 'dim': dim, 'norm': norm} y = backward(forward(x, **kwargs), **kwargs) # For real input, ifftn(fftn(x)) will convert to complex if x.dtype is torch.half and y.dtype is torch.chalf: # Since type promotion currently doesn't work with complex32 # manually promote `x` to complex32 self.assertEqual(x.to(torch.chalf), y) else: self.assertEqual(x, y, exact_dtype=( forward != torch.fft.fftn or x.is_complex())) @onlyNativeDeviceTypes @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND], allowed_dtypes=[torch.float, torch.cfloat]) def test_fftn_invalid(self, device, dtype, op): a = torch.rand(10, 10, 10, device=device, dtype=dtype) # FIXME: https://github.com/pytorch/pytorch/issues/108205 errMsg = "dims must be unique" with self.assertRaisesRegex(RuntimeError, errMsg): op(a, dim=(0, 1, 0)) with self.assertRaisesRegex(RuntimeError, errMsg): op(a, dim=(2, -1)) with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"): op(a, s=(1,), dim=(0, 1)) with self.assertRaisesRegex(IndexError, "Dimension out of range"): op(a, dim=(3,)) with self.assertRaisesRegex(RuntimeError, "tensor only has 3 
dimensions"): op(a, s=(10, 10, 10, 10)) @skipCPUIfNoFFT @onlyNativeDeviceTypes @dtypes(torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble) def test_fftn_noop_transform(self, device, dtype): skip_helper_for_fft(device, dtype) RESULT_TYPE = { torch.half: torch.chalf, torch.float: torch.cfloat, torch.double: torch.cdouble, } for op in [ torch.fft.fftn, torch.fft.ifftn, torch.fft.fft2, torch.fft.ifft2, ]: inp = make_tensor((10, 10), device=device, dtype=dtype) out = torch.fft.fftn(inp, dim=[]) expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype) expect = inp.to(expect_dtype) self.assertEqual(expect, out) @skipCPUIfNoFFT @onlyNativeDeviceTypes @toleranceOverride({ torch.half : tol(1e-2, 1e-2), }) @dtypes(torch.half, torch.float, torch.double) def test_hfftn(self, device, dtype): skip_helper_for_fft(device, dtype) # input_ndim, dim transform_desc = [ *product(range(2, 5), (None, (0,), (0, -1))), (6, None), (5, (1, 3, 4)), (3, (1,)), (1, (0,)), (4, (0, 1)) ] for input_ndim, dim in transform_desc: actual_dims = list(range(input_ndim)) if dim is None else dim if dtype is torch.half: shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)) else: shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim)) expect = torch.randn(*shape, device=device, dtype=dtype) input = torch.fft.ifftn(expect, dim=dim, norm="ortho") lastdim = actual_dims[-1] lastdim_size = input.size(lastdim) // 2 + 1 idx = [slice(None)] * input_ndim idx[lastdim] = slice(0, lastdim_size) input = input[idx] s = [shape[dim] for dim in actual_dims] actual = torch.fft.hfftn(input, s=s, dim=dim, norm="ortho") self.assertEqual(expect, actual) @skipCPUIfNoFFT @onlyNativeDeviceTypes @toleranceOverride({ torch.half : tol(1e-2, 1e-2), }) @dtypes(torch.half, torch.float, torch.double) def test_ihfftn(self, device, dtype): skip_helper_for_fft(device, dtype) # input_ndim, dim transform_desc = [ *product(range(2, 5), (None, (0,), (0, -1))), (6, None), (5, (1, 3, 4)), (3, (1,)), 
(1, (0,)), (4, (0, 1)) ] for input_ndim, dim in transform_desc: if dtype is torch.half: shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)) else: shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim)) input = torch.randn(*shape, device=device, dtype=dtype) expect = torch.fft.ifftn(input, dim=dim, norm="ortho") # Slice off the half-symmetric component lastdim = -1 if dim is None else dim[-1] lastdim_size = expect.size(lastdim) // 2 + 1 idx = [slice(None)] * input_ndim idx[lastdim] = slice(0, lastdim_size) expect = expect[idx] actual = torch.fft.ihfftn(input, dim=dim, norm="ortho") self.assertEqual(expect, actual) # 2d-fft tests # NOTE: 2d transforms are only thin wrappers over n-dim transforms, # so don't require exhaustive testing. @skipCPUIfNoFFT @onlyNativeDeviceTypes @dtypes(torch.double, torch.complex128) def test_fft2_numpy(self, device, dtype): norm_modes = REFERENCE_NORM_MODES # input_ndim, s transform_desc = [ *product(range(2, 5), (None, (4, 10))), ] fft_functions = ['fft2', 'ifft2', 'irfft2', 'hfft2'] if dtype.is_floating_point: fft_functions += ['rfft2', 'ihfft2'] for input_ndim, s in transform_desc: shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) input = torch.randn(*shape, device=device, dtype=dtype) for fname, norm in product(fft_functions, norm_modes): torch_fn = getattr(torch.fft, fname) if "hfft" in fname: if not has_scipy_fft: continue # Requires scipy to compare against numpy_fn = getattr(scipy.fft, fname) else: numpy_fn = getattr(np.fft, fname) def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None): return torch_fn(t, s, dim, norm) torch_fns = (torch_fn, torch.jit.script(fn)) # Once with dim defaulted input_np = input.cpu().numpy() expected = numpy_fn(input_np, s, norm=norm) for fn in torch_fns: actual = fn(input, s, norm=norm) self.assertEqual(actual, expected) # Once with explicit dims dim = (1, 0) expected = numpy_fn(input_np, s, dim, 
norm) for fn in torch_fns: actual = fn(input, s, dim, norm) self.assertEqual(actual, expected) @skipCPUIfNoFFT @onlyNativeDeviceTypes @dtypes(torch.float, torch.complex64) def test_fft2_fftn_equivalence(self, device, dtype): norm_modes = (None, "forward", "backward", "ortho") # input_ndim, s, dim transform_desc = [ *product(range(2, 5), (None, (4, 10)), (None, (1, 0))), (3, None, (0, 2)), ] fft_functions = ['fft', 'ifft', 'irfft', 'hfft'] # Real-only functions if dtype.is_floating_point: fft_functions += ['rfft', 'ihfft'] for input_ndim, s, dim in transform_desc: shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) x = torch.randn(*shape, device=device, dtype=dtype) for func, norm in product(fft_functions, norm_modes): f2d = getattr(torch.fft, func + '2') fnd = getattr(torch.fft, func + 'n') kwargs = {'s': s, 'norm': norm} if dim is not None: kwargs['dim'] = dim expect = fnd(x, **kwargs) else: expect = fnd(x, dim=(-2, -1), **kwargs) actual = f2d(x, **kwargs) self.assertEqual(actual, expect) @skipCPUIfNoFFT @onlyNativeDeviceTypes def test_fft2_invalid(self, device): a = torch.rand(10, 10, 10, device=device) fft_funcs = (torch.fft.fft2, torch.fft.ifft2, torch.fft.rfft2, torch.fft.irfft2) for func in fft_funcs: with self.assertRaisesRegex(RuntimeError, "dims must be unique"): func(a, dim=(0, 0)) with self.assertRaisesRegex(RuntimeError, "dims must be unique"): func(a, dim=(2, -1)) with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"): func(a, s=(1,)) with self.assertRaisesRegex(IndexError, "Dimension out of range"): func(a, dim=(2, 3)) c = torch.complex(a, a) with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"): torch.fft.rfft2(c) # Helper functions @skipCPUIfNoFFT @onlyNativeDeviceTypes @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') @dtypes(torch.float, torch.double) def test_fftfreq_numpy(self, device, dtype): test_args = [ *product( # n range(1, 20), # d (None, 10.0), ) ] functions = ['fftfreq', 
'rfftfreq'] for fname in functions: torch_fn = getattr(torch.fft, fname) numpy_fn = getattr(np.fft, fname) for n, d in test_args: args = (n,) if d is None else (n, d) expected = numpy_fn(*args) actual = torch_fn(*args, device=device, dtype=dtype) self.assertEqual(actual, expected, exact_dtype=False) @skipCPUIfNoFFT @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) def test_fftfreq_out(self, device, dtype): for func in (torch.fft.fftfreq, torch.fft.rfftfreq): expect = func(n=100, d=.5, device=device, dtype=dtype) actual = torch.empty((), device=device, dtype=dtype) with self.assertWarnsRegex(UserWarning, "out tensor will be resized"): func(n=100, d=.5, out=actual) self.assertEqual(actual, expect) @skipCPUIfNoFFT @onlyNativeDeviceTypes @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) def test_fftshift_numpy(self, device, dtype): test_args = [ # shape, dim *product(((11,), (12,)), (None, 0, -1)), *product(((4, 5), (6, 6)), (None, 0, (-1,))), *product(((1, 1, 4, 6, 7, 2),), (None, (3, 4))), ] functions = ['fftshift', 'ifftshift'] for shape, dim in test_args: input = torch.rand(*shape, device=device, dtype=dtype) input_np = input.cpu().numpy() for fname in functions: torch_fn = getattr(torch.fft, fname) numpy_fn = getattr(np.fft, fname) expected = numpy_fn(input_np, axes=dim) actual = torch_fn(input, dim=dim) self.assertEqual(actual, expected) @skipCPUIfNoFFT @onlyNativeDeviceTypes @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') @dtypes(torch.float, torch.double) def test_fftshift_frequencies(self, device, dtype): for n in range(10, 15): sorted_fft_freqs = torch.arange(-(n // 2), n - (n // 2), device=device, dtype=dtype) x = torch.fft.fftfreq(n, d=1 / n, device=device, dtype=dtype) # Test fftshift sorts the fftfreq output shifted = torch.fft.fftshift(x) self.assertEqual(shifted, shifted.sort().values) self.assertEqual(sorted_fft_freqs, shifted) # And ifftshift is the inverse 
self.assertEqual(x, torch.fft.ifftshift(shifted)) # Legacy fft tests def _test_fft_ifft_rfft_irfft(self, device, dtype): complex_dtype = corresponding_complex_dtype(dtype) def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x): x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device)) dim = tuple(range(-signal_ndim, 0)) for norm in ('ortho', None): res = torch.fft.fftn(x, dim=dim, norm=norm) rec = torch.fft.ifftn(res, dim=dim, norm=norm) self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='fft and ifft') res = torch.fft.ifftn(x, dim=dim, norm=norm) rec = torch.fft.fftn(res, dim=dim, norm=norm) self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='ifft and fft') def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x): x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device)) signal_numel = 1 signal_sizes = x.size()[-signal_ndim:] dim = tuple(range(-signal_ndim, 0)) for norm in (None, 'ortho'): res = torch.fft.rfftn(x, dim=dim, norm=norm) rec = torch.fft.irfftn(res, s=signal_sizes, dim=dim, norm=norm) self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft') res = torch.fft.fftn(x, dim=dim, norm=norm) rec = torch.fft.ifftn(res, dim=dim, norm=norm) x_complex = torch.complex(x, torch.zeros_like(x)) self.assertEqual(x_complex, rec, atol=1e-8, rtol=0, msg='fft and ifft (from real)') # contiguous case _test_real((100,), 1) _test_real((10, 1, 10, 100), 1) _test_real((100, 100), 2) _test_real((2, 2, 5, 80, 60), 2) _test_real((50, 40, 70), 3) _test_real((30, 1, 50, 25, 20), 3) _test_complex((100,), 1) _test_complex((100, 100), 1) _test_complex((100, 100), 2) _test_complex((1, 20, 80, 60), 2) _test_complex((50, 40, 70), 3) _test_complex((6, 5, 50, 25, 20), 3) # non-contiguous case _test_real((165,), 1, lambda x: x.narrow(0, 25, 100)) # input is not aligned to complex type _test_real((100, 100, 3), 1, lambda x: x[:, :, 0]) _test_real((100, 100), 2, lambda x: x.t()) _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60]) _test_real((65, 
80, 115), 3, lambda x: x[10:60, 13:53, 10:80]) _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3)) _test_complex((100,), 1, lambda x: x.expand(100, 100)) _test_complex((20, 90, 110), 2, lambda x: x[:, 5:85].narrow(2, 5, 100)) _test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:]) _test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21]) @skipCPUIfNoFFT @onlyNativeDeviceTypes @dtypes(torch.double) def test_fft_ifft_rfft_irfft(self, device, dtype): self._test_fft_ifft_rfft_irfft(device, dtype) @deviceCountAtLeast(1) @onlyCUDA @dtypes(torch.double) def test_cufft_plan_cache(self, devices, dtype): @contextmanager def plan_cache_max_size(device, n): if device is None: plan_cache = torch.backends.cuda.cufft_plan_cache else: plan_cache = torch.backends.cuda.cufft_plan_cache[device] original = plan_cache.max_size plan_cache.max_size = n try: yield finally: plan_cache.max_size = original with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)): self._test_fft_ifft_rfft_irfft(devices[0], dtype) with plan_cache_max_size(devices[0], 0): self._test_fft_ifft_rfft_irfft(devices[0], dtype) torch.backends.cuda.cufft_plan_cache.clear() # check that stll works after clearing cache with plan_cache_max_size(devices[0], 10): self._test_fft_ifft_rfft_irfft(devices[0], dtype) with self.assertRaisesRegex(RuntimeError, r"must be non-negative"): torch.backends.cuda.cufft_plan_cache.max_size = -1 with self.assertRaisesRegex(RuntimeError, r"read-only property"): torch.backends.cuda.cufft_plan_cache.size = -1 with self.assertRaisesRegex(RuntimeError, r"but got device with index"): torch.backends.cuda.cufft_plan_cache[torch.cuda.device_count() + 10] # Multigpu tests if len(devices) > 1: # Test that different GPU has different cache x0 = torch.randn(2, 3, 3, device=devices[0]) x1 = x0.to(devices[1]) self.assertEqual(torch.fft.rfftn(x0, dim=(-2, -1)), torch.fft.rfftn(x1, dim=(-2, -1))) # If 
a plan is used across different devices, the following line (or # the assert above) would trigger illegal memory access. Other ways # to trigger the error include # (1) setting CUDA_LAUNCH_BLOCKING=1 (pytorch/pytorch#19224) and # (2) printing a device 1 tensor. x0.copy_(x1) # Test that un-indexed `torch.backends.cuda.cufft_plan_cache` uses current device with plan_cache_max_size(devices[0], 10): with plan_cache_max_size(devices[1], 11): self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11) self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 with torch.cuda.device(devices[1]): self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 with torch.cuda.device(devices[0]): self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) with torch.cuda.device(devices[1]): with plan_cache_max_size(None, 11): # default is cuda:1 self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11) self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 with torch.cuda.device(devices[0]): self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 @onlyCUDA @dtypes(torch.cfloat, torch.cdouble) def test_cufft_context(self, device, dtype): # Regression test for https://github.com/pytorch/pytorch/issues/109448 x = torch.randn(32, dtype=dtype, device=device, requires_grad=True) dout = torch.zeros(32, dtype=dtype, device=device) # compute iFFT(FFT(x)) out = torch.fft.ifft(torch.fft.fft(x)) out.backward(dout, retain_graph=True) dx = torch.fft.fft(torch.fft.ifft(dout)) self.assertTrue((x.grad - dx).abs().max() == 0) 
self.assertFalse((x.grad - x).abs().max() == 0) # passes on ROCm w/ python 2.7, fails w/ python 3.6 @skipIfTorchDynamo("cannot set WRITEABLE flag to True of this array") @skipCPUIfNoFFT @onlyNativeDeviceTypes @dtypes(torch.double) def test_stft(self, device, dtype): if not TEST_LIBROSA: raise unittest.SkipTest('librosa not found') def librosa_stft(x, n_fft, hop_length, win_length, window, center): if window is None: window = np.ones(n_fft if win_length is None else win_length) else: window = window.cpu().numpy() input_1d = x.dim() == 1 if input_1d: x = x.view(1, -1) # NOTE: librosa 0.9 changed default pad_mode to 'constant' (zero padding) # however, we use the pre-0.9 default ('reflect') pad_mode = 'reflect' result = [] for xi in x: ri = librosa.stft(xi.cpu().numpy(), n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, center=center, pad_mode=pad_mode) result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1))) result = torch.stack(result, 0) if input_1d: result = result[0] return result def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, center=True, expected_error=None): x = torch.randn(*sizes, dtype=dtype, device=device) if win_sizes is not None: window = torch.randn(*win_sizes, dtype=dtype, device=device) else: window = None if expected_error is None: result = x.stft(n_fft, hop_length, win_length, window, center=center, return_complex=False) # NB: librosa defaults to np.complex64 output, no matter what # the input dtype ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center) self.assertEqual(result, ref_result, atol=7e-6, rtol=0, msg='stft comparison against librosa', exact_dtype=False) # With return_complex=True, the result is the same but viewed as complex instead of real result_complex = x.stft(n_fft, hop_length, win_length, window, center=center, return_complex=True) self.assertEqual(result_complex, torch.view_as_complex(result)) else: self.assertRaises(expected_error, lambda: 
x.stft(n_fft, hop_length, win_length, window, center=center)) for center in [True, False]: _test((10,), 7, center=center) _test((10, 4000), 1024, center=center) _test((10,), 7, 2, center=center) _test((10, 4000), 1024, 512, center=center) _test((10,), 7, 2, win_sizes=(7,), center=center) _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center) # spectral oversample _test((10,), 7, 2, win_length=5, center=center) _test((10, 4000), 1024, 512, win_length=100, center=center) _test((10, 4, 2), 1, 1, expected_error=RuntimeError) _test((10,), 11, 1, center=False, expected_error=RuntimeError) _test((10,), -1, 1, expected_error=RuntimeError) _test((10,), 3, win_length=5, expected_error=RuntimeError) _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError) _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError) @skipIfTorchDynamo("double") @skipCPUIfNoFFT @onlyNativeDeviceTypes @dtypes(torch.double) def test_istft_against_librosa(self, device, dtype): if not TEST_LIBROSA: raise unittest.SkipTest('librosa not found') def librosa_istft(x, n_fft, hop_length, win_length, window, length, center): if window is None: window = np.ones(n_fft if win_length is None else win_length) else: window = window.cpu().numpy() return librosa.istft(x.cpu().numpy(), n_fft=n_fft, hop_length=hop_length, win_length=win_length, length=length, window=window, center=center) def _test(size, n_fft, hop_length=None, win_length=None, win_sizes=None, length=None, center=True): x = torch.randn(size, dtype=dtype, device=device) if win_sizes is not None: window = torch.randn(*win_sizes, dtype=dtype, device=device) else: window = None x_stft = x.stft(n_fft, hop_length, win_length, window, center=center, onesided=True, return_complex=True) ref_result = librosa_istft(x_stft, n_fft, hop_length, win_length, window, length, center) result = x_stft.istft(n_fft, hop_length, win_length, window, length=length, center=center) self.assertEqual(result, ref_result) for center in [True, False]: 
_test(10, 7, center=center) _test(4000, 1024, center=center) _test(4000, 1024, center=center, length=4000) _test(10, 7, 2, center=center) _test(4000, 1024, 512, center=center) _test(4000, 1024, 512, center=center, length=4000) _test(10, 7, 2, win_sizes=(7,), center=center) _test(4000, 1024, 512, win_sizes=(1024,), center=center) _test(4000, 1024, 512, win_sizes=(1024,), center=center, length=4000) @onlyNativeDeviceTypes @skipCPUIfNoFFT @dtypes(torch.double, torch.cdouble) def test_complex_stft_roundtrip(self, device, dtype): test_args = list(product( # input (torch.randn(600, device=device, dtype=dtype), torch.randn(807, device=device, dtype=dtype), torch.randn(12, 60, device=device, dtype=dtype)), # n_fft (50, 27), # hop_length (None, 10), # center (True,), # pad_mode ("constant", "reflect", "circular"), # normalized (True, False), # onesided (True, False) if not dtype.is_complex else (False,), )) for args in test_args: x, n_fft, hop_length, center, pad_mode, normalized, onesided = args common_kwargs = { 'n_fft': n_fft, 'hop_length': hop_length, 'center': center, 'normalized': normalized, 'onesided': onesided, } # Functional interface x_stft = torch.stft(x, pad_mode=pad_mode, return_complex=True, **common_kwargs) x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex, length=x.size(-1), **common_kwargs) self.assertEqual(x_roundtrip, x) # Tensor method interface x_stft = x.stft(pad_mode=pad_mode, return_complex=True, **common_kwargs) x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex, length=x.size(-1), **common_kwargs) self.assertEqual(x_roundtrip, x) @onlyNativeDeviceTypes @skipCPUIfNoFFT @dtypes(torch.double, torch.cdouble) def test_stft_roundtrip_complex_window(self, device, dtype): test_args = list(product( # input (torch.randn(600, device=device, dtype=dtype), torch.randn(807, device=device, dtype=dtype), torch.randn(12, 60, device=device, dtype=dtype)), # n_fft (50, 27), # hop_length (None, 10), # pad_mode ("constant", "reflect", 
"replicate", "circular"), # normalized (True, False), )) for args in test_args: x, n_fft, hop_length, pad_mode, normalized = args window = torch.rand(n_fft, device=device, dtype=torch.cdouble) x_stft = torch.stft( x, n_fft=n_fft, hop_length=hop_length, window=window, center=True, pad_mode=pad_mode, normalized=normalized) self.assertEqual(x_stft.dtype, torch.cdouble) self.assertEqual(x_stft.size(-2), n_fft) # Not onesided x_roundtrip = torch.istft( x_stft, n_fft=n_fft, hop_length=hop_length, window=window, center=True, normalized=normalized, length=x.size(-1), return_complex=True) self.assertEqual(x_stft.dtype, torch.cdouble) if not dtype.is_complex: self.assertEqual(x_roundtrip.imag, torch.zeros_like(x_roundtrip.imag), atol=1e-6, rtol=0) self.assertEqual(x_roundtrip.real, x) else: self.assertEqual(x_roundtrip, x) @skipCPUIfNoFFT @dtypes(torch.cdouble) def test_complex_stft_definition(self, device, dtype): test_args = list(product( # input (torch.randn(600, device=device, dtype=dtype), torch.randn(807, device=device, dtype=dtype)), # n_fft (50, 27), # hop_length (10, 15) )) for args in test_args: window = torch.randn(args[1], device=device, dtype=dtype) expected = _stft_reference(args[0], args[2], window) actual = torch.stft(*args, window=window, center=False) self.assertEqual(actual, expected) @onlyNativeDeviceTypes @skipCPUIfNoFFT @dtypes(torch.cdouble) def test_complex_stft_real_equiv(self, device, dtype): test_args = list(product( # input (torch.rand(600, device=device, dtype=dtype), torch.rand(807, device=device, dtype=dtype), torch.rand(14, 50, device=device, dtype=dtype), torch.rand(6, 51, device=device, dtype=dtype)), # n_fft (50, 27), # hop_length (None, 10), # win_length (None, 20), # center (False, True), # pad_mode ("constant", "reflect", "circular"), # normalized (True, False), )) for args in test_args: x, n_fft, hop_length, win_length, center, pad_mode, normalized = args expected = _complex_stft(x, n_fft, hop_length=hop_length, win_length=win_length, 
pad_mode=pad_mode, center=center, normalized=normalized) actual = torch.stft(x, n_fft, hop_length=hop_length, win_length=win_length, pad_mode=pad_mode, center=center, normalized=normalized) self.assertEqual(expected, actual) @skipCPUIfNoFFT @dtypes(torch.cdouble) def test_complex_istft_real_equiv(self, device, dtype): test_args = list(product( # input (torch.rand(40, 20, device=device, dtype=dtype), torch.rand(25, 1, device=device, dtype=dtype), torch.rand(4, 20, 10, device=device, dtype=dtype)), # hop_length (None, 10), # center (False, True), # normalized (True, False), )) for args in test_args: x, hop_length, center, normalized = args n_fft = x.size(-2) expected = _complex_istft(x, n_fft, hop_length=hop_length, center=center, normalized=normalized) actual = torch.istft(x, n_fft, hop_length=hop_length, center=center, normalized=normalized, return_complex=True) self.assertEqual(expected, actual) @skipCPUIfNoFFT def test_complex_stft_onesided(self, device): # stft of complex input cannot be onesided for x_dtype, window_dtype in product((torch.double, torch.cdouble), repeat=2): x = torch.rand(100, device=device, dtype=x_dtype) window = torch.rand(10, device=device, dtype=window_dtype) if x_dtype.is_complex or window_dtype.is_complex: with self.assertRaisesRegex(RuntimeError, 'complex'): x.stft(10, window=window, pad_mode='constant', onesided=True) else: y = x.stft(10, window=window, pad_mode='constant', onesided=True, return_complex=True) self.assertEqual(y.dtype, torch.cdouble) self.assertEqual(y.size(), (6, 51)) x = torch.rand(100, device=device, dtype=torch.cdouble) with self.assertRaisesRegex(RuntimeError, 'complex'): x.stft(10, pad_mode='constant', onesided=True) # stft is currently warning that it requires return-complex while an upgrader is written @onlyNativeDeviceTypes @skipCPUIfNoFFT def test_stft_requires_complex(self, device): x = torch.rand(100) with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'): y = x.stft(10, 
pad_mode='constant') # stft and istft are currently warning if a window is not provided @onlyNativeDeviceTypes @skipCPUIfNoFFT def test_stft_requires_window(self, device): x = torch.rand(100) with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"): y = x.stft(10, pad_mode='constant', return_complex=True) @onlyNativeDeviceTypes @skipCPUIfNoFFT def test_istft_requires_window(self, device): stft = torch.rand((51, 5), dtype=torch.cdouble) # 51 = 2 * n_fft + 1, 5 = number of frames with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"): x = torch.istft(stft, n_fft=100, length=100) @skipCPUIfNoFFT def test_fft_input_modification(self, device): # FFT functions should not modify their input (gh-34551) signal = torch.ones((2, 2, 2), device=device) signal_copy = signal.clone() spectrum = torch.fft.fftn(signal, dim=(-2, -1)) self.assertEqual(signal, signal_copy) spectrum_copy = spectrum.clone() _ = torch.fft.ifftn(spectrum, dim=(-2, -1)) self.assertEqual(spectrum, spectrum_copy) half_spectrum = torch.fft.rfftn(signal, dim=(-2, -1)) self.assertEqual(signal, signal_copy) half_spectrum_copy = half_spectrum.clone() _ = torch.fft.irfftn(half_spectrum_copy, s=(2, 2), dim=(-2, -1)) self.assertEqual(half_spectrum, half_spectrum_copy) @onlyNativeDeviceTypes @skipCPUIfNoFFT def test_fft_plan_repeatable(self, device): # Regression test for gh-58724 and gh-63152 for n in [2048, 3199, 5999]: a = torch.randn(n, device=device, dtype=torch.complex64) res1 = torch.fft.fftn(a) res2 = torch.fft.fftn(a.clone()) self.assertEqual(res1, res2) a = torch.randn(n, device=device, dtype=torch.float64) res1 = torch.fft.rfft(a) res2 = torch.fft.rfft(a.clone()) self.assertEqual(res1, res2) @onlyNativeDeviceTypes @skipCPUIfNoFFT @dtypes(torch.double) def test_istft_round_trip_simple_cases(self, device, dtype): """stft -> istft should recover the original signale""" def _test(input, n_fft, length): stft = torch.stft(input, n_fft=n_fft, return_complex=True) inverse = 
torch.istft(stft, n_fft=n_fft, length=length) self.assertEqual(input, inverse, exact_dtype=True) _test(torch.ones(4, dtype=dtype, device=device), 4, 4) _test(torch.zeros(4, dtype=dtype, device=device), 4, 4) @onlyNativeDeviceTypes @skipCPUIfNoFFT @dtypes(torch.double) def test_istft_round_trip_various_params(self, device, dtype): """stft -> istft should recover the original signale""" def _test_istft_is_inverse_of_stft(stft_kwargs): # generates a random sound signal for each tril and then does the stft/istft # operation to check whether we can reconstruct signal data_sizes = [(2, 20), (3, 15), (4, 10)] num_trials = 100 istft_kwargs = stft_kwargs.copy() del istft_kwargs['pad_mode'] for sizes in data_sizes: for i in range(num_trials): original = torch.randn(*sizes, dtype=dtype, device=device) stft = torch.stft(original, return_complex=True, **stft_kwargs) inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) self.assertEqual( inversed, original, msg='istft comparison against original', atol=7e-6, rtol=0, exact_dtype=True) patterns = [ # hann_window, centered, normalized, onesided { 'n_fft': 12, 'hop_length': 4, 'win_length': 12, 'window': torch.hann_window(12, dtype=dtype, device=device), 'center': True, 'pad_mode': 'reflect', 'normalized': True, 'onesided': True, }, # hann_window, centered, not normalized, not onesided { 'n_fft': 12, 'hop_length': 2, 'win_length': 8, 'window': torch.hann_window(8, dtype=dtype, device=device), 'center': True, 'pad_mode': 'reflect', 'normalized': False, 'onesided': False, }, # hamming_window, centered, normalized, not onesided { 'n_fft': 15, 'hop_length': 3, 'win_length': 11, 'window': torch.hamming_window(11, dtype=dtype, device=device), 'center': True, 'pad_mode': 'constant', 'normalized': True, 'onesided': False, }, # hamming_window, centered, not normalized, onesided # window same size as n_fft { 'n_fft': 5, 'hop_length': 2, 'win_length': 5, 'window': torch.hamming_window(5, dtype=dtype, device=device), 'center': 
True, 'pad_mode': 'constant', 'normalized': False, 'onesided': True, }, ] for i, pattern in enumerate(patterns): _test_istft_is_inverse_of_stft(pattern) @onlyNativeDeviceTypes @skipCPUIfNoFFT @dtypes(torch.double) def test_istft_round_trip_with_padding(self, device, dtype): """long hop_length or not centered may cause length mismatch in the inversed signal""" def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs): # generates a random sound signal for each tril and then does the stft/istft # operation to check whether we can reconstruct signal num_trials = 100 sizes = stft_kwargs['size'] del stft_kwargs['size'] istft_kwargs = stft_kwargs.copy() del istft_kwargs['pad_mode'] for i in range(num_trials): original = torch.randn(*sizes, dtype=dtype, device=device) stft = torch.stft(original, return_complex=True, **stft_kwargs) with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."): inversed = torch.istft(stft, length=original.size(-1), **istft_kwargs) n_frames = stft.size(-1) if stft_kwargs["center"] is True: len_expected = stft_kwargs["n_fft"] // 2 + stft_kwargs["hop_length"] * (n_frames - 1) else: len_expected = stft_kwargs["n_fft"] + stft_kwargs["hop_length"] * (n_frames - 1) # trim the original for case when constructed signal is shorter than original padding = inversed[..., len_expected:] inversed = inversed[..., :len_expected] original = original[..., :len_expected] # test the padding points of the inversed signal are all zeros zeros = torch.zeros_like(padding, device=padding.device) self.assertEqual( padding, zeros, msg='istft padding values against zeros', atol=7e-6, rtol=0, exact_dtype=True) self.assertEqual( inversed, original, msg='istft comparison against original', atol=7e-6, rtol=0, exact_dtype=True) patterns = [ # hamming_window, not centered, not normalized, not onesided # window same size as n_fft { 'size': [2, 20], 'n_fft': 3, 'hop_length': 2, 'win_length': 3, 'window': torch.hamming_window(3, 
dtype=dtype, device=device), 'center': False, 'pad_mode': 'reflect', 'normalized': False, 'onesided': False, }, # hamming_window, centered, not normalized, onesided, long hop_length # window same size as n_fft { 'size': [2, 500], 'n_fft': 256, 'hop_length': 254, 'win_length': 256, 'window': torch.hamming_window(256, dtype=dtype, device=device), 'center': True, 'pad_mode': 'constant', 'normalized': False, 'onesided': True, }, ] for i, pattern in enumerate(patterns): _test_istft_is_inverse_of_stft_with_padding(pattern) @onlyNativeDeviceTypes def test_istft_throws(self, device): """istft should throw exception for invalid parameters""" stft = torch.zeros((3, 5, 2), device=device) # the window is size 1 but it hops 20 so there is a gap which throw an error self.assertRaises( RuntimeError, torch.istft, stft, n_fft=4, hop_length=20, win_length=1, window=torch.ones(1)) # A window of zeros does not meet NOLA invalid_window = torch.zeros(4, device=device) self.assertRaises( RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window) # Input cannot be empty self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2)), 2) self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2) @skipIfTorchDynamo("Failed running call_function") @onlyNativeDeviceTypes @skipCPUIfNoFFT @dtypes(torch.double) def test_istft_of_sine(self, device, dtype): complex_dtype = corresponding_complex_dtype(dtype) def _test(amplitude, L, n): # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L x = torch.arange(2 * L + 1, device=device, dtype=dtype) original = amplitude * torch.sin(2 * math.pi / L * x * n) # stft = torch.stft(original, L, hop_length=L, win_length=L, # window=torch.ones(L), center=False, normalized=False) stft = torch.zeros((L // 2 + 1, 2), device=device, dtype=complex_dtype) stft_largest_val = (amplitude * L) / 2.0 if n < stft.size(0): stft[n].imag = torch.tensor(-stft_largest_val, dtype=dtype) if 0 <= L - n < 
stft.size(0): # symmetric about L // 2 stft[L - n].imag = torch.tensor(stft_largest_val, dtype=dtype) inverse = torch.istft( stft, L, hop_length=L, win_length=L, window=torch.ones(L, device=device, dtype=dtype), center=False, normalized=False) # There is a larger error due to the scaling of amplitude original = original[..., :inverse.size(-1)] self.assertEqual(inverse, original, atol=1e-3, rtol=0) _test(amplitude=123, L=5, n=1) _test(amplitude=150, L=5, n=2) _test(amplitude=111, L=5, n=3) _test(amplitude=160, L=7, n=4) _test(amplitude=145, L=8, n=5) _test(amplitude=80, L=9, n=6) _test(amplitude=99, L=10, n=7) @onlyNativeDeviceTypes @skipCPUIfNoFFT @dtypes(torch.double) def test_istft_linearity(self, device, dtype): num_trials = 100 complex_dtype = corresponding_complex_dtype(dtype) def _test(data_size, kwargs): for i in range(num_trials): tensor1 = torch.randn(data_size, device=device, dtype=complex_dtype) tensor2 = torch.randn(data_size, device=device, dtype=complex_dtype) a, b = torch.rand(2, dtype=dtype, device=device) # Also compare method vs. 
functional call signature istft1 = tensor1.istft(**kwargs) istft2 = tensor2.istft(**kwargs) istft = a * istft1 + b * istft2 estimate = torch.istft(a * tensor1 + b * tensor2, **kwargs) self.assertEqual(istft, estimate, atol=1e-5, rtol=0) patterns = [ # hann_window, centered, normalized, onesided ( (2, 7, 7), { 'n_fft': 12, 'window': torch.hann_window(12, device=device, dtype=dtype), 'center': True, 'normalized': True, 'onesided': True, }, ), # hann_window, centered, not normalized, not onesided ( (2, 12, 7), { 'n_fft': 12, 'window': torch.hann_window(12, device=device, dtype=dtype), 'center': True, 'normalized': False, 'onesided': False, }, ), # hamming_window, centered, normalized, not onesided ( (2, 12, 7), { 'n_fft': 12, 'window': torch.hamming_window(12, device=device, dtype=dtype), 'center': True, 'normalized': True, 'onesided': False, }, ), # hamming_window, not centered, not normalized, onesided ( (2, 7, 3), { 'n_fft': 12, 'window': torch.hamming_window(12, device=device, dtype=dtype), 'center': False, 'normalized': False, 'onesided': True, }, ) ] for data_size, kwargs in patterns: _test(data_size, kwargs) @onlyNativeDeviceTypes @skipCPUIfNoFFT def test_batch_istft(self, device): original = torch.tensor([ [4., 4., 4., 4., 4.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.] 
], device=device, dtype=torch.complex64) single = original.repeat(1, 1, 1) multi = original.repeat(4, 1, 1) i_original = torch.istft(original, n_fft=4, length=4) i_single = torch.istft(single, n_fft=4, length=4) i_multi = torch.istft(multi, n_fft=4, length=4) self.assertEqual(i_original.repeat(1, 1), i_single, atol=1e-6, rtol=0, exact_dtype=True) self.assertEqual(i_original.repeat(4, 1), i_multi, atol=1e-6, rtol=0, exact_dtype=True) @onlyCUDA @skipIf(not TEST_MKL, "Test requires MKL") def test_stft_window_device(self, device): # Test the (i)stft window must be on the same device as the input x = torch.randn(1000, dtype=torch.complex64) window = torch.randn(100, dtype=torch.complex64) with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"): torch.stft(x, n_fft=100, window=window.to(device)) with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"): torch.stft(x.to(device), n_fft=100, window=window) X = torch.stft(x, n_fft=100, window=window) with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"): torch.istft(X, n_fft=100, window=window.to(device)) with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"): torch.istft(x.to(device), n_fft=100, window=window) class FFTDocTestFinder: '''The default doctest finder doesn't like that function.__module__ doesn't match torch.fft. It assumes the functions are leaked imports. ''' def __init__(self) -> None: self.parser = doctest.DocTestParser() def find(self, obj, name=None, module=None, globs=None, extraglobs=None): doctests = [] modname = name if name is not None else obj.__name__ globs = {} if globs is None else globs for fname in obj.__all__: func = getattr(obj, fname) if inspect.isroutine(func): qualname = modname + '.' 
+ fname docstring = inspect.getdoc(func) if docstring is None: continue examples = self.parser.get_doctest( docstring, globs=globs, name=fname, filename=None, lineno=None) doctests.append(examples) return doctests class TestFFTDocExamples(TestCase): pass def generate_doc_test(doc_test): def test(self, device): self.assertEqual(device, 'cpu') runner = doctest.DocTestRunner() runner.run(doc_test) if runner.failures != 0: runner.summarize() self.fail('Doctest failed') setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoFFT(test)) for doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)): generate_doc_test(doc_test) instantiate_device_type_tests(TestFFT, globals()) instantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu') if __name__ == '__main__': run_tests()