# Owner(s): ["module: fft"]

import torch
import unittest
import math
from contextlib import contextmanager
from itertools import product
import itertools
import doctest
import inspect

from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
     make_tensor, skipIfTorchDynamo)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
from torch.testing._internal.common_methods_invocations import (
    spectral_funcs, SpectralFuncType)
from torch.testing._internal.common_cuda import SM53OrLater
from torch._prims_common import corresponding_complex_dtype

from typing import Optional, List
from packaging import version


if TEST_NUMPY:
    import numpy as np


if TEST_LIBROSA:
    import librosa

has_scipy_fft = False
try:
    import scipy.fft
    has_scipy_fft = True
except ModuleNotFoundError:
    pass

# Norm modes accepted by the NumPy/SciPy reference implementations.
# "forward" and "backward" need NumPy >= 1.20 (and SciPy >= 1.6 when scipy.fft
# is available); older versions only understand None and "ortho".
# ``np`` is only bound when NumPy is importable, so TEST_NUMPY is checked first;
# otherwise this expression would raise NameError at import time.
REFERENCE_NORM_MODES = (
    (None, "forward", "backward", "ortho")
    if TEST_NUMPY and version.parse(np.__version__) >= version.parse('1.20.0') and (
        not has_scipy_fft or version.parse(scipy.__version__) >= version.parse('1.6.0'))
    else (None, "ortho"))


def _complex_stft(x, *args, **kwargs):
    """Two-sided STFT of a complex signal, built from the STFTs of its parts.

    The STFT is linear, so ``stft(x) == stft(x.real) + 1j * stft(x.imag)``.
    """
    # Transform real and imaginary components separably
    stft_real = torch.stft(x.real, *args, **kwargs, return_complex=True, onesided=False)
    stft_imag = torch.stft(x.imag, *args, **kwargs, return_complex=True, onesided=False)
    return stft_real + 1j * stft_imag


def _hermitian_conj(x, dim):
    """Returns the hermitian conjugate along a single dimension

    H(x)[i] = conj(x[-i])

    Index 0 (and index n // 2 when the size n is even) maps onto itself; every
    other entry is swapped with its mirror image before conjugation.
    """
    n = x.size(dim)
    mid = (n - 1) // 2
    base = [slice(None)] * x.dim()

    def _at(index):
        # Full index tuple selecting ``index`` along ``dim`` and everything elsewhere
        full = list(base)
        full[dim] = index
        return tuple(full)

    # Start from a copy so the self-conjugate positions are already in place
    out = x.clone()

    # Use an explicit start (n - mid) instead of -mid: for n <= 2, mid == 0 and
    # slice(-0, None) would select the entire dimension instead of nothing.
    idx_pos = _at(slice(1, mid + 1))
    idx_neg = _at(slice(n - mid, None))
    out[idx_pos] = x[idx_neg].flip(dim)
    out[idx_neg] = x[idx_pos].flip(dim)
    return out.conj()


def _complex_istft(x, *args, **kwargs):
    """Inverse STFT of a two-sided spectrogram ``x`` of a complex signal.

    The frequency axis is dim -2; its length is taken as ``n_fft``.
    """
    # Decompose into Hermitian (FFT of real) and anti-Hermitian (FFT of imaginary)
    n_fft = x.size(-2)
    slc = (Ellipsis, slice(None, n_fft // 2 + 1), slice(None))

    hconj = _hermitian_conj(x, dim=-2)
    x_hermitian = (x + hconj) / 2
    x_antihermitian = (x - hconj) / 2
    istft_real = torch.istft(x_hermitian[slc], *args, **kwargs, onesided=True)
    istft_imag = torch.istft(-1j * x_antihermitian[slc], *args, **kwargs, onesided=True)
    return torch.complex(istft_real, istft_imag)


def _stft_reference(x, hop_length, window):
    r"""Reference stft implementation

    This doesn't implement all of torch.stft, only the STFT definition:

    .. math:: X(m, \omega) = \sum_n x[n]w[n - m] e^{-jn\omega}

    Samples beyond the end of ``x`` are treated as zero.
    """
    n_fft = window.numel()
    X = torch.empty((n_fft, (x.numel() - n_fft + hop_length) // hop_length),
                    device=x.device, dtype=torch.cdouble)
    for m in range(X.size(1)):
        start = m * hop_length
        if start + n_fft > x.numel():
            # Zero-pad a short final frame; torch.empty would leave the tail
            # filled with uninitialized memory.
            slc = torch.zeros(n_fft, device=x.device, dtype=x.dtype)
            tmp = x[start:]
            slc[:tmp.numel()] = tmp
        else:
            slc = x[start: start + n_fft]
        X[:, m] = torch.fft.fft(slc * window)
    return X


def skip_helper_for_fft(device, dtype):
    """Skip the running test for half/complex32 on CPU or on CUDA older than SM53."""
    device_type = torch.device(device).type
    if dtype not in (torch.half, torch.complex32):
        return

    if device_type == 'cpu':
        raise unittest.SkipTest("half and complex32 are not supported on CPU")
    if not SM53OrLater:
        raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>53")


# Tests of functions related to Fourier analysis in the torch.fft namespace
class TestFFT(TestCase):
    """Tests of functions related to Fourier analysis in the torch.fft namespace."""
exact_dtype = True 130 131 @onlyNativeDeviceTypes 132 @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD], 133 allowed_dtypes=(torch.float, torch.cfloat)) 134 def test_reference_1d(self, device, dtype, op): 135 if op.ref is None: 136 raise unittest.SkipTest("No reference implementation") 137 138 norm_modes = REFERENCE_NORM_MODES 139 test_args = [ 140 *product( 141 # input 142 (torch.randn(67, device=device, dtype=dtype), 143 torch.randn(80, device=device, dtype=dtype), 144 torch.randn(12, 14, device=device, dtype=dtype), 145 torch.randn(9, 6, 3, device=device, dtype=dtype)), 146 # n 147 (None, 50, 6), 148 # dim 149 (-1, 0), 150 # norm 151 norm_modes 152 ), 153 # Test transforming middle dimensions of multi-dim tensor 154 *product( 155 (torch.randn(4, 5, 6, 7, device=device, dtype=dtype),), 156 (None,), 157 (1, 2, -2,), 158 norm_modes 159 ) 160 ] 161 162 for iargs in test_args: 163 args = list(iargs) 164 input = args[0] 165 args = args[1:] 166 167 expected = op.ref(input.cpu().numpy(), *args) 168 exact_dtype = dtype in (torch.double, torch.complex128) 169 actual = op(input, *args) 170 self.assertEqual(actual, expected, exact_dtype=exact_dtype) 171 172 @skipCPUIfNoFFT 173 @onlyNativeDeviceTypes 174 @toleranceOverride({ 175 torch.half : tol(1e-2, 1e-2), 176 torch.chalf : tol(1e-2, 1e-2), 177 }) 178 @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128) 179 def test_fft_round_trip(self, device, dtype): 180 skip_helper_for_fft(device, dtype) 181 # Test that round trip through ifft(fft(x)) is the identity 182 if dtype not in (torch.half, torch.complex32): 183 test_args = list(product( 184 # input 185 (torch.randn(67, device=device, dtype=dtype), 186 torch.randn(80, device=device, dtype=dtype), 187 torch.randn(12, 14, device=device, dtype=dtype), 188 torch.randn(9, 6, 3, device=device, dtype=dtype)), 189 # dim 190 (-1, 0), 191 # norm 192 (None, "forward", "backward", "ortho") 193 )) 194 else: 195 # 
cuFFT supports powers of 2 for half and complex half precision 196 test_args = list(product( 197 # input 198 (torch.randn(64, device=device, dtype=dtype), 199 torch.randn(128, device=device, dtype=dtype), 200 torch.randn(4, 16, device=device, dtype=dtype), 201 torch.randn(8, 6, 2, device=device, dtype=dtype)), 202 # dim 203 (-1, 0), 204 # norm 205 (None, "forward", "backward", "ortho") 206 )) 207 208 fft_functions = [(torch.fft.fft, torch.fft.ifft)] 209 # Real-only functions 210 if not dtype.is_complex: 211 # NOTE: Using ihfft as "forward" transform to avoid needing to 212 # generate true half-complex input 213 fft_functions += [(torch.fft.rfft, torch.fft.irfft), 214 (torch.fft.ihfft, torch.fft.hfft)] 215 216 for forward, backward in fft_functions: 217 for x, dim, norm in test_args: 218 kwargs = { 219 'n': x.size(dim), 220 'dim': dim, 221 'norm': norm, 222 } 223 224 y = backward(forward(x, **kwargs), **kwargs) 225 if x.dtype is torch.half and y.dtype is torch.complex32: 226 # Since type promotion currently doesn't work with complex32 227 # manually promote `x` to complex32 228 x = x.to(torch.complex32) 229 # For real input, ifft(fft(x)) will convert to complex 230 self.assertEqual(x, y, exact_dtype=( 231 forward != torch.fft.fft or x.is_complex())) 232 233 # Note: NumPy will throw a ValueError for an empty input 234 @onlyNativeDeviceTypes 235 @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat)) 236 def test_empty_fft(self, device, dtype, op): 237 t = torch.empty(1, 0, device=device, dtype=dtype) 238 match = r"Invalid number of data points \([-\d]*\) specified" 239 240 with self.assertRaisesRegex(RuntimeError, match): 241 op(t) 242 243 @onlyNativeDeviceTypes 244 def test_empty_ifft(self, device): 245 t = torch.empty(2, 1, device=device, dtype=torch.complex64) 246 match = r"Invalid number of data points \([-\d]*\) specified" 247 248 for f in [torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn, 249 torch.fft.hfft, 
torch.fft.hfft2, torch.fft.hfftn]: 250 with self.assertRaisesRegex(RuntimeError, match): 251 f(t) 252 253 @onlyNativeDeviceTypes 254 def test_fft_invalid_dtypes(self, device): 255 t = torch.randn(64, device=device, dtype=torch.complex128) 256 257 with self.assertRaisesRegex(RuntimeError, "rfft expects a real input tensor"): 258 torch.fft.rfft(t) 259 260 with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input tensor"): 261 torch.fft.rfftn(t) 262 263 with self.assertRaisesRegex(RuntimeError, "ihfft expects a real input tensor"): 264 torch.fft.ihfft(t) 265 266 @skipCPUIfNoFFT 267 @onlyNativeDeviceTypes 268 @dtypes(torch.int8, torch.half, torch.float, torch.double, 269 torch.complex32, torch.complex64, torch.complex128) 270 def test_fft_type_promotion(self, device, dtype): 271 skip_helper_for_fft(device, dtype) 272 273 if dtype.is_complex or dtype.is_floating_point: 274 t = torch.randn(64, device=device, dtype=dtype) 275 else: 276 t = torch.randint(-2, 2, (64,), device=device, dtype=dtype) 277 278 PROMOTION_MAP = { 279 torch.int8: torch.complex64, 280 torch.half: torch.complex32, 281 torch.float: torch.complex64, 282 torch.double: torch.complex128, 283 torch.complex32: torch.complex32, 284 torch.complex64: torch.complex64, 285 torch.complex128: torch.complex128, 286 } 287 T = torch.fft.fft(t) 288 self.assertEqual(T.dtype, PROMOTION_MAP[dtype]) 289 290 PROMOTION_MAP_C2R = { 291 torch.int8: torch.float, 292 torch.half: torch.half, 293 torch.float: torch.float, 294 torch.double: torch.double, 295 torch.complex32: torch.half, 296 torch.complex64: torch.float, 297 torch.complex128: torch.double, 298 } 299 if dtype in (torch.half, torch.complex32): 300 # cuFFT supports powers of 2 for half and complex half precision 301 # NOTE: With hfft and default args where output_size n=2*(input_size - 1), 302 # we make sure that logical fft size is a power of two. 
303 x = torch.randn(65, device=device, dtype=dtype) 304 R = torch.fft.hfft(x) 305 else: 306 R = torch.fft.hfft(t) 307 self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype]) 308 309 if not dtype.is_complex: 310 PROMOTION_MAP_R2C = { 311 torch.int8: torch.complex64, 312 torch.half: torch.complex32, 313 torch.float: torch.complex64, 314 torch.double: torch.complex128, 315 } 316 C = torch.fft.rfft(t) 317 self.assertEqual(C.dtype, PROMOTION_MAP_R2C[dtype]) 318 319 @onlyNativeDeviceTypes 320 @ops(spectral_funcs, dtypes=OpDTypes.unsupported, 321 allowed_dtypes=[torch.half, torch.bfloat16]) 322 def test_fft_half_and_bfloat16_errors(self, device, dtype, op): 323 # TODO: Remove torch.half error when complex32 is fully implemented 324 sample = first_sample(self, op.sample_inputs(device, dtype)) 325 device_type = torch.device(device).type 326 default_msg = "Unsupported dtype" 327 if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM: 328 err_msg = default_msg 329 elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater: 330 err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53" 331 else: 332 err_msg = default_msg 333 with self.assertRaisesRegex(RuntimeError, err_msg): 334 op(sample.input, *sample.args, **sample.kwargs) 335 336 @onlyNativeDeviceTypes 337 @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf)) 338 def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op): 339 t = make_tensor(13, 13, device=device, dtype=dtype) 340 err_msg = "cuFFT only supports dimensions whose sizes are powers of two" 341 with self.assertRaisesRegex(RuntimeError, err_msg): 342 op(t) 343 344 if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD): 345 kwargs = {'s': (12, 12)} 346 else: 347 kwargs = {'n': 12} 348 349 with self.assertRaisesRegex(RuntimeError, err_msg): 350 op(t, **kwargs) 351 352 # nd-fft tests 353 @onlyNativeDeviceTypes 354 @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') 355 
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND], 356 allowed_dtypes=(torch.cfloat, torch.cdouble)) 357 def test_reference_nd(self, device, dtype, op): 358 if op.ref is None: 359 raise unittest.SkipTest("No reference implementation") 360 361 norm_modes = REFERENCE_NORM_MODES 362 363 # input_ndim, s, dim 364 transform_desc = [ 365 *product(range(2, 5), (None,), (None, (0,), (0, -1))), 366 *product(range(2, 5), (None, (4, 10)), (None,)), 367 (6, None, None), 368 (5, None, (1, 3, 4)), 369 (3, None, (1,)), 370 (1, None, (0,)), 371 (4, (10, 10), None), 372 (4, (10, 10), (0, 1)) 373 ] 374 375 for input_ndim, s, dim in transform_desc: 376 shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) 377 input = torch.randn(*shape, device=device, dtype=dtype) 378 379 for norm in norm_modes: 380 expected = op.ref(input.cpu().numpy(), s, dim, norm) 381 exact_dtype = dtype in (torch.double, torch.complex128) 382 actual = op(input, s, dim, norm) 383 self.assertEqual(actual, expected, exact_dtype=exact_dtype) 384 385 @skipCPUIfNoFFT 386 @onlyNativeDeviceTypes 387 @toleranceOverride({ 388 torch.half : tol(1e-2, 1e-2), 389 torch.chalf : tol(1e-2, 1e-2), 390 }) 391 @dtypes(torch.half, torch.float, torch.double, 392 torch.complex32, torch.complex64, torch.complex128) 393 def test_fftn_round_trip(self, device, dtype): 394 skip_helper_for_fft(device, dtype) 395 396 norm_modes = (None, "forward", "backward", "ortho") 397 398 # input_ndim, dim 399 transform_desc = [ 400 *product(range(2, 5), (None, (0,), (0, -1))), 401 (7, None), 402 (5, (1, 3, 4)), 403 (3, (1,)), 404 (1, 0), 405 ] 406 407 fft_functions = [(torch.fft.fftn, torch.fft.ifftn)] 408 409 # Real-only functions 410 if not dtype.is_complex: 411 # NOTE: Using ihfftn as "forward" transform to avoid needing to 412 # generate true half-complex input 413 fft_functions += [(torch.fft.rfftn, torch.fft.irfftn), 414 (torch.fft.ihfftn, torch.fft.hfftn)] 415 416 for input_ndim, dim in transform_desc: 417 
if dtype in (torch.half, torch.complex32): 418 # cuFFT supports powers of 2 for half and complex half precision 419 shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim) 420 else: 421 shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) 422 x = torch.randn(*shape, device=device, dtype=dtype) 423 424 for (forward, backward), norm in product(fft_functions, norm_modes): 425 if isinstance(dim, tuple): 426 s = [x.size(d) for d in dim] 427 else: 428 s = x.size() if dim is None else x.size(dim) 429 430 kwargs = {'s': s, 'dim': dim, 'norm': norm} 431 y = backward(forward(x, **kwargs), **kwargs) 432 # For real input, ifftn(fftn(x)) will convert to complex 433 if x.dtype is torch.half and y.dtype is torch.chalf: 434 # Since type promotion currently doesn't work with complex32 435 # manually promote `x` to complex32 436 self.assertEqual(x.to(torch.chalf), y) 437 else: 438 self.assertEqual(x, y, exact_dtype=( 439 forward != torch.fft.fftn or x.is_complex())) 440 441 @onlyNativeDeviceTypes 442 @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND], 443 allowed_dtypes=[torch.float, torch.cfloat]) 444 def test_fftn_invalid(self, device, dtype, op): 445 a = torch.rand(10, 10, 10, device=device, dtype=dtype) 446 # FIXME: https://github.com/pytorch/pytorch/issues/108205 447 errMsg = "dims must be unique" 448 with self.assertRaisesRegex(RuntimeError, errMsg): 449 op(a, dim=(0, 1, 0)) 450 451 with self.assertRaisesRegex(RuntimeError, errMsg): 452 op(a, dim=(2, -1)) 453 454 with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"): 455 op(a, s=(1,), dim=(0, 1)) 456 457 with self.assertRaisesRegex(IndexError, "Dimension out of range"): 458 op(a, dim=(3,)) 459 460 with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"): 461 op(a, s=(10, 10, 10, 10)) 462 463 @skipCPUIfNoFFT 464 @onlyNativeDeviceTypes 465 @dtypes(torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble) 466 def 
test_fftn_noop_transform(self, device, dtype): 467 skip_helper_for_fft(device, dtype) 468 RESULT_TYPE = { 469 torch.half: torch.chalf, 470 torch.float: torch.cfloat, 471 torch.double: torch.cdouble, 472 } 473 474 for op in [ 475 torch.fft.fftn, 476 torch.fft.ifftn, 477 torch.fft.fft2, 478 torch.fft.ifft2, 479 ]: 480 inp = make_tensor((10, 10), device=device, dtype=dtype) 481 out = torch.fft.fftn(inp, dim=[]) 482 483 expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype) 484 expect = inp.to(expect_dtype) 485 self.assertEqual(expect, out) 486 487 488 @skipCPUIfNoFFT 489 @onlyNativeDeviceTypes 490 @toleranceOverride({ 491 torch.half : tol(1e-2, 1e-2), 492 }) 493 @dtypes(torch.half, torch.float, torch.double) 494 def test_hfftn(self, device, dtype): 495 skip_helper_for_fft(device, dtype) 496 497 # input_ndim, dim 498 transform_desc = [ 499 *product(range(2, 5), (None, (0,), (0, -1))), 500 (6, None), 501 (5, (1, 3, 4)), 502 (3, (1,)), 503 (1, (0,)), 504 (4, (0, 1)) 505 ] 506 507 for input_ndim, dim in transform_desc: 508 actual_dims = list(range(input_ndim)) if dim is None else dim 509 if dtype is torch.half: 510 shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)) 511 else: 512 shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim)) 513 expect = torch.randn(*shape, device=device, dtype=dtype) 514 input = torch.fft.ifftn(expect, dim=dim, norm="ortho") 515 516 lastdim = actual_dims[-1] 517 lastdim_size = input.size(lastdim) // 2 + 1 518 idx = [slice(None)] * input_ndim 519 idx[lastdim] = slice(0, lastdim_size) 520 input = input[idx] 521 522 s = [shape[dim] for dim in actual_dims] 523 actual = torch.fft.hfftn(input, s=s, dim=dim, norm="ortho") 524 525 self.assertEqual(expect, actual) 526 527 @skipCPUIfNoFFT 528 @onlyNativeDeviceTypes 529 @toleranceOverride({ 530 torch.half : tol(1e-2, 1e-2), 531 }) 532 @dtypes(torch.half, torch.float, torch.double) 533 def test_ihfftn(self, device, dtype): 534 skip_helper_for_fft(device, dtype) 535 536 # 
input_ndim, dim 537 transform_desc = [ 538 *product(range(2, 5), (None, (0,), (0, -1))), 539 (6, None), 540 (5, (1, 3, 4)), 541 (3, (1,)), 542 (1, (0,)), 543 (4, (0, 1)) 544 ] 545 546 for input_ndim, dim in transform_desc: 547 if dtype is torch.half: 548 shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)) 549 else: 550 shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim)) 551 552 input = torch.randn(*shape, device=device, dtype=dtype) 553 expect = torch.fft.ifftn(input, dim=dim, norm="ortho") 554 555 # Slice off the half-symmetric component 556 lastdim = -1 if dim is None else dim[-1] 557 lastdim_size = expect.size(lastdim) // 2 + 1 558 idx = [slice(None)] * input_ndim 559 idx[lastdim] = slice(0, lastdim_size) 560 expect = expect[idx] 561 562 actual = torch.fft.ihfftn(input, dim=dim, norm="ortho") 563 self.assertEqual(expect, actual) 564 565 566 # 2d-fft tests 567 568 # NOTE: 2d transforms are only thin wrappers over n-dim transforms, 569 # so don't require exhaustive testing. 
570 571 572 @skipCPUIfNoFFT 573 @onlyNativeDeviceTypes 574 @dtypes(torch.double, torch.complex128) 575 def test_fft2_numpy(self, device, dtype): 576 norm_modes = REFERENCE_NORM_MODES 577 578 # input_ndim, s 579 transform_desc = [ 580 *product(range(2, 5), (None, (4, 10))), 581 ] 582 583 fft_functions = ['fft2', 'ifft2', 'irfft2', 'hfft2'] 584 if dtype.is_floating_point: 585 fft_functions += ['rfft2', 'ihfft2'] 586 587 for input_ndim, s in transform_desc: 588 shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) 589 input = torch.randn(*shape, device=device, dtype=dtype) 590 for fname, norm in product(fft_functions, norm_modes): 591 torch_fn = getattr(torch.fft, fname) 592 if "hfft" in fname: 593 if not has_scipy_fft: 594 continue # Requires scipy to compare against 595 numpy_fn = getattr(scipy.fft, fname) 596 else: 597 numpy_fn = getattr(np.fft, fname) 598 599 def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None): 600 return torch_fn(t, s, dim, norm) 601 602 torch_fns = (torch_fn, torch.jit.script(fn)) 603 604 # Once with dim defaulted 605 input_np = input.cpu().numpy() 606 expected = numpy_fn(input_np, s, norm=norm) 607 for fn in torch_fns: 608 actual = fn(input, s, norm=norm) 609 self.assertEqual(actual, expected) 610 611 # Once with explicit dims 612 dim = (1, 0) 613 expected = numpy_fn(input_np, s, dim, norm) 614 for fn in torch_fns: 615 actual = fn(input, s, dim, norm) 616 self.assertEqual(actual, expected) 617 618 @skipCPUIfNoFFT 619 @onlyNativeDeviceTypes 620 @dtypes(torch.float, torch.complex64) 621 def test_fft2_fftn_equivalence(self, device, dtype): 622 norm_modes = (None, "forward", "backward", "ortho") 623 624 # input_ndim, s, dim 625 transform_desc = [ 626 *product(range(2, 5), (None, (4, 10)), (None, (1, 0))), 627 (3, None, (0, 2)), 628 ] 629 630 fft_functions = ['fft', 'ifft', 'irfft', 'hfft'] 631 # Real-only functions 632 if dtype.is_floating_point: 633 fft_functions += ['rfft', 'ihfft'] 
634 635 for input_ndim, s, dim in transform_desc: 636 shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim) 637 x = torch.randn(*shape, device=device, dtype=dtype) 638 639 for func, norm in product(fft_functions, norm_modes): 640 f2d = getattr(torch.fft, func + '2') 641 fnd = getattr(torch.fft, func + 'n') 642 643 kwargs = {'s': s, 'norm': norm} 644 645 if dim is not None: 646 kwargs['dim'] = dim 647 expect = fnd(x, **kwargs) 648 else: 649 expect = fnd(x, dim=(-2, -1), **kwargs) 650 651 actual = f2d(x, **kwargs) 652 653 self.assertEqual(actual, expect) 654 655 @skipCPUIfNoFFT 656 @onlyNativeDeviceTypes 657 def test_fft2_invalid(self, device): 658 a = torch.rand(10, 10, 10, device=device) 659 fft_funcs = (torch.fft.fft2, torch.fft.ifft2, 660 torch.fft.rfft2, torch.fft.irfft2) 661 662 for func in fft_funcs: 663 with self.assertRaisesRegex(RuntimeError, "dims must be unique"): 664 func(a, dim=(0, 0)) 665 666 with self.assertRaisesRegex(RuntimeError, "dims must be unique"): 667 func(a, dim=(2, -1)) 668 669 with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"): 670 func(a, s=(1,)) 671 672 with self.assertRaisesRegex(IndexError, "Dimension out of range"): 673 func(a, dim=(2, 3)) 674 675 c = torch.complex(a, a) 676 with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"): 677 torch.fft.rfft2(c) 678 679 # Helper functions 680 681 @skipCPUIfNoFFT 682 @onlyNativeDeviceTypes 683 @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') 684 @dtypes(torch.float, torch.double) 685 def test_fftfreq_numpy(self, device, dtype): 686 test_args = [ 687 *product( 688 # n 689 range(1, 20), 690 # d 691 (None, 10.0), 692 ) 693 ] 694 695 functions = ['fftfreq', 'rfftfreq'] 696 697 for fname in functions: 698 torch_fn = getattr(torch.fft, fname) 699 numpy_fn = getattr(np.fft, fname) 700 701 for n, d in test_args: 702 args = (n,) if d is None else (n, d) 703 expected = numpy_fn(*args) 704 actual = torch_fn(*args, device=device, dtype=dtype) 
705 self.assertEqual(actual, expected, exact_dtype=False) 706 707 @skipCPUIfNoFFT 708 @onlyNativeDeviceTypes 709 @dtypes(torch.float, torch.double) 710 def test_fftfreq_out(self, device, dtype): 711 for func in (torch.fft.fftfreq, torch.fft.rfftfreq): 712 expect = func(n=100, d=.5, device=device, dtype=dtype) 713 actual = torch.empty((), device=device, dtype=dtype) 714 with self.assertWarnsRegex(UserWarning, "out tensor will be resized"): 715 func(n=100, d=.5, out=actual) 716 self.assertEqual(actual, expect) 717 718 719 @skipCPUIfNoFFT 720 @onlyNativeDeviceTypes 721 @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') 722 @dtypes(torch.float, torch.double, torch.complex64, torch.complex128) 723 def test_fftshift_numpy(self, device, dtype): 724 test_args = [ 725 # shape, dim 726 *product(((11,), (12,)), (None, 0, -1)), 727 *product(((4, 5), (6, 6)), (None, 0, (-1,))), 728 *product(((1, 1, 4, 6, 7, 2),), (None, (3, 4))), 729 ] 730 731 functions = ['fftshift', 'ifftshift'] 732 733 for shape, dim in test_args: 734 input = torch.rand(*shape, device=device, dtype=dtype) 735 input_np = input.cpu().numpy() 736 737 for fname in functions: 738 torch_fn = getattr(torch.fft, fname) 739 numpy_fn = getattr(np.fft, fname) 740 741 expected = numpy_fn(input_np, axes=dim) 742 actual = torch_fn(input, dim=dim) 743 self.assertEqual(actual, expected) 744 745 @skipCPUIfNoFFT 746 @onlyNativeDeviceTypes 747 @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') 748 @dtypes(torch.float, torch.double) 749 def test_fftshift_frequencies(self, device, dtype): 750 for n in range(10, 15): 751 sorted_fft_freqs = torch.arange(-(n // 2), n - (n // 2), 752 device=device, dtype=dtype) 753 x = torch.fft.fftfreq(n, d=1 / n, device=device, dtype=dtype) 754 755 # Test fftshift sorts the fftfreq output 756 shifted = torch.fft.fftshift(x) 757 self.assertEqual(shifted, shifted.sort().values) 758 self.assertEqual(sorted_fft_freqs, shifted) 759 760 # And ifftshift is the inverse 761 self.assertEqual(x, 
torch.fft.ifftshift(shifted)) 762 763 # Legacy fft tests 764 def _test_fft_ifft_rfft_irfft(self, device, dtype): 765 complex_dtype = corresponding_complex_dtype(dtype) 766 767 def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x): 768 x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device)) 769 dim = tuple(range(-signal_ndim, 0)) 770 for norm in ('ortho', None): 771 res = torch.fft.fftn(x, dim=dim, norm=norm) 772 rec = torch.fft.ifftn(res, dim=dim, norm=norm) 773 self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='fft and ifft') 774 res = torch.fft.ifftn(x, dim=dim, norm=norm) 775 rec = torch.fft.fftn(res, dim=dim, norm=norm) 776 self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='ifft and fft') 777 778 def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x): 779 x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device)) 780 signal_numel = 1 781 signal_sizes = x.size()[-signal_ndim:] 782 dim = tuple(range(-signal_ndim, 0)) 783 for norm in (None, 'ortho'): 784 res = torch.fft.rfftn(x, dim=dim, norm=norm) 785 rec = torch.fft.irfftn(res, s=signal_sizes, dim=dim, norm=norm) 786 self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft') 787 res = torch.fft.fftn(x, dim=dim, norm=norm) 788 rec = torch.fft.ifftn(res, dim=dim, norm=norm) 789 x_complex = torch.complex(x, torch.zeros_like(x)) 790 self.assertEqual(x_complex, rec, atol=1e-8, rtol=0, msg='fft and ifft (from real)') 791 792 # contiguous case 793 _test_real((100,), 1) 794 _test_real((10, 1, 10, 100), 1) 795 _test_real((100, 100), 2) 796 _test_real((2, 2, 5, 80, 60), 2) 797 _test_real((50, 40, 70), 3) 798 _test_real((30, 1, 50, 25, 20), 3) 799 800 _test_complex((100,), 1) 801 _test_complex((100, 100), 1) 802 _test_complex((100, 100), 2) 803 _test_complex((1, 20, 80, 60), 2) 804 _test_complex((50, 40, 70), 3) 805 _test_complex((6, 5, 50, 25, 20), 3) 806 807 # non-contiguous case 808 _test_real((165,), 1, lambda x: x.narrow(0, 25, 100)) # input is not aligned to complex type 809 
_test_real((100, 100, 3), 1, lambda x: x[:, :, 0]) 810 _test_real((100, 100), 2, lambda x: x.t()) 811 _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60]) 812 _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80]) 813 _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3)) 814 815 _test_complex((100,), 1, lambda x: x.expand(100, 100)) 816 _test_complex((20, 90, 110), 2, lambda x: x[:, 5:85].narrow(2, 5, 100)) 817 _test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:]) 818 _test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21]) 819 820 @skipCPUIfNoFFT 821 @onlyNativeDeviceTypes 822 @dtypes(torch.double) 823 def test_fft_ifft_rfft_irfft(self, device, dtype): 824 self._test_fft_ifft_rfft_irfft(device, dtype) 825 826 @deviceCountAtLeast(1) 827 @onlyCUDA 828 @dtypes(torch.double) 829 def test_cufft_plan_cache(self, devices, dtype): 830 @contextmanager 831 def plan_cache_max_size(device, n): 832 if device is None: 833 plan_cache = torch.backends.cuda.cufft_plan_cache 834 else: 835 plan_cache = torch.backends.cuda.cufft_plan_cache[device] 836 original = plan_cache.max_size 837 plan_cache.max_size = n 838 try: 839 yield 840 finally: 841 plan_cache.max_size = original 842 843 with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)): 844 self._test_fft_ifft_rfft_irfft(devices[0], dtype) 845 846 with plan_cache_max_size(devices[0], 0): 847 self._test_fft_ifft_rfft_irfft(devices[0], dtype) 848 849 torch.backends.cuda.cufft_plan_cache.clear() 850 851 # check that stll works after clearing cache 852 with plan_cache_max_size(devices[0], 10): 853 self._test_fft_ifft_rfft_irfft(devices[0], dtype) 854 855 with self.assertRaisesRegex(RuntimeError, r"must be non-negative"): 856 torch.backends.cuda.cufft_plan_cache.max_size = -1 857 858 with self.assertRaisesRegex(RuntimeError, r"read-only property"): 859 torch.backends.cuda.cufft_plan_cache.size = -1 860 
861 with self.assertRaisesRegex(RuntimeError, r"but got device with index"): 862 torch.backends.cuda.cufft_plan_cache[torch.cuda.device_count() + 10] 863 864 # Multigpu tests 865 if len(devices) > 1: 866 # Test that different GPU has different cache 867 x0 = torch.randn(2, 3, 3, device=devices[0]) 868 x1 = x0.to(devices[1]) 869 self.assertEqual(torch.fft.rfftn(x0, dim=(-2, -1)), torch.fft.rfftn(x1, dim=(-2, -1))) 870 # If a plan is used across different devices, the following line (or 871 # the assert above) would trigger illegal memory access. Other ways 872 # to trigger the error include 873 # (1) setting CUDA_LAUNCH_BLOCKING=1 (pytorch/pytorch#19224) and 874 # (2) printing a device 1 tensor. 875 x0.copy_(x1) 876 877 # Test that un-indexed `torch.backends.cuda.cufft_plan_cache` uses current device 878 with plan_cache_max_size(devices[0], 10): 879 with plan_cache_max_size(devices[1], 11): 880 self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) 881 self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11) 882 883 self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 884 with torch.cuda.device(devices[1]): 885 self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 886 with torch.cuda.device(devices[0]): 887 self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 888 889 self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) 890 with torch.cuda.device(devices[1]): 891 with plan_cache_max_size(None, 11): # default is cuda:1 892 self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10) 893 self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11) 894 895 self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 896 with torch.cuda.device(devices[0]): 897 self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 898 
self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 899 900 @onlyCUDA 901 @dtypes(torch.cfloat, torch.cdouble) 902 def test_cufft_context(self, device, dtype): 903 # Regression test for https://github.com/pytorch/pytorch/issues/109448 904 x = torch.randn(32, dtype=dtype, device=device, requires_grad=True) 905 dout = torch.zeros(32, dtype=dtype, device=device) 906 907 # compute iFFT(FFT(x)) 908 out = torch.fft.ifft(torch.fft.fft(x)) 909 out.backward(dout, retain_graph=True) 910 911 dx = torch.fft.fft(torch.fft.ifft(dout)) 912 913 self.assertTrue((x.grad - dx).abs().max() == 0) 914 self.assertFalse((x.grad - x).abs().max() == 0) 915 916 # passes on ROCm w/ python 2.7, fails w/ python 3.6 917 @skipIfTorchDynamo("cannot set WRITEABLE flag to True of this array") 918 @skipCPUIfNoFFT 919 @onlyNativeDeviceTypes 920 @dtypes(torch.double) 921 def test_stft(self, device, dtype): 922 if not TEST_LIBROSA: 923 raise unittest.SkipTest('librosa not found') 924 925 def librosa_stft(x, n_fft, hop_length, win_length, window, center): 926 if window is None: 927 window = np.ones(n_fft if win_length is None else win_length) 928 else: 929 window = window.cpu().numpy() 930 input_1d = x.dim() == 1 931 if input_1d: 932 x = x.view(1, -1) 933 934 # NOTE: librosa 0.9 changed default pad_mode to 'constant' (zero padding) 935 # however, we use the pre-0.9 default ('reflect') 936 pad_mode = 'reflect' 937 938 result = [] 939 for xi in x: 940 ri = librosa.stft(xi.cpu().numpy(), n_fft=n_fft, hop_length=hop_length, 941 win_length=win_length, window=window, center=center, 942 pad_mode=pad_mode) 943 result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1))) 944 result = torch.stack(result, 0) 945 if input_1d: 946 result = result[0] 947 return result 948 949 def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, 950 center=True, expected_error=None): 951 x = torch.randn(*sizes, dtype=dtype, device=device) 952 if win_sizes is not None: 953 
window = torch.randn(*win_sizes, dtype=dtype, device=device) 954 else: 955 window = None 956 if expected_error is None: 957 result = x.stft(n_fft, hop_length, win_length, window, 958 center=center, return_complex=False) 959 # NB: librosa defaults to np.complex64 output, no matter what 960 # the input dtype 961 ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center) 962 self.assertEqual(result, ref_result, atol=7e-6, rtol=0, msg='stft comparison against librosa', exact_dtype=False) 963 # With return_complex=True, the result is the same but viewed as complex instead of real 964 result_complex = x.stft(n_fft, hop_length, win_length, window, center=center, return_complex=True) 965 self.assertEqual(result_complex, torch.view_as_complex(result)) 966 else: 967 self.assertRaises(expected_error, 968 lambda: x.stft(n_fft, hop_length, win_length, window, center=center)) 969 970 for center in [True, False]: 971 _test((10,), 7, center=center) 972 _test((10, 4000), 1024, center=center) 973 974 _test((10,), 7, 2, center=center) 975 _test((10, 4000), 1024, 512, center=center) 976 977 _test((10,), 7, 2, win_sizes=(7,), center=center) 978 _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center) 979 980 # spectral oversample 981 _test((10,), 7, 2, win_length=5, center=center) 982 _test((10, 4000), 1024, 512, win_length=100, center=center) 983 984 _test((10, 4, 2), 1, 1, expected_error=RuntimeError) 985 _test((10,), 11, 1, center=False, expected_error=RuntimeError) 986 _test((10,), -1, 1, expected_error=RuntimeError) 987 _test((10,), 3, win_length=5, expected_error=RuntimeError) 988 _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError) 989 _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError) 990 991 @skipIfTorchDynamo("double") 992 @skipCPUIfNoFFT 993 @onlyNativeDeviceTypes 994 @dtypes(torch.double) 995 def test_istft_against_librosa(self, device, dtype): 996 if not TEST_LIBROSA: 997 raise unittest.SkipTest('librosa not found') 998 999 
def librosa_istft(x, n_fft, hop_length, win_length, window, length, center): 1000 if window is None: 1001 window = np.ones(n_fft if win_length is None else win_length) 1002 else: 1003 window = window.cpu().numpy() 1004 1005 return librosa.istft(x.cpu().numpy(), n_fft=n_fft, hop_length=hop_length, 1006 win_length=win_length, length=length, window=window, center=center) 1007 1008 def _test(size, n_fft, hop_length=None, win_length=None, win_sizes=None, 1009 length=None, center=True): 1010 x = torch.randn(size, dtype=dtype, device=device) 1011 if win_sizes is not None: 1012 window = torch.randn(*win_sizes, dtype=dtype, device=device) 1013 else: 1014 window = None 1015 1016 x_stft = x.stft(n_fft, hop_length, win_length, window, center=center, 1017 onesided=True, return_complex=True) 1018 1019 ref_result = librosa_istft(x_stft, n_fft, hop_length, win_length, 1020 window, length, center) 1021 result = x_stft.istft(n_fft, hop_length, win_length, window, 1022 length=length, center=center) 1023 self.assertEqual(result, ref_result) 1024 1025 for center in [True, False]: 1026 _test(10, 7, center=center) 1027 _test(4000, 1024, center=center) 1028 _test(4000, 1024, center=center, length=4000) 1029 1030 _test(10, 7, 2, center=center) 1031 _test(4000, 1024, 512, center=center) 1032 _test(4000, 1024, 512, center=center, length=4000) 1033 1034 _test(10, 7, 2, win_sizes=(7,), center=center) 1035 _test(4000, 1024, 512, win_sizes=(1024,), center=center) 1036 _test(4000, 1024, 512, win_sizes=(1024,), center=center, length=4000) 1037 1038 @onlyNativeDeviceTypes 1039 @skipCPUIfNoFFT 1040 @dtypes(torch.double, torch.cdouble) 1041 def test_complex_stft_roundtrip(self, device, dtype): 1042 test_args = list(product( 1043 # input 1044 (torch.randn(600, device=device, dtype=dtype), 1045 torch.randn(807, device=device, dtype=dtype), 1046 torch.randn(12, 60, device=device, dtype=dtype)), 1047 # n_fft 1048 (50, 27), 1049 # hop_length 1050 (None, 10), 1051 # center 1052 (True,), 1053 # pad_mode 
1054 ("constant", "reflect", "circular"), 1055 # normalized 1056 (True, False), 1057 # onesided 1058 (True, False) if not dtype.is_complex else (False,), 1059 )) 1060 1061 for args in test_args: 1062 x, n_fft, hop_length, center, pad_mode, normalized, onesided = args 1063 common_kwargs = { 1064 'n_fft': n_fft, 'hop_length': hop_length, 'center': center, 1065 'normalized': normalized, 'onesided': onesided, 1066 } 1067 1068 # Functional interface 1069 x_stft = torch.stft(x, pad_mode=pad_mode, return_complex=True, **common_kwargs) 1070 x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex, 1071 length=x.size(-1), **common_kwargs) 1072 self.assertEqual(x_roundtrip, x) 1073 1074 # Tensor method interface 1075 x_stft = x.stft(pad_mode=pad_mode, return_complex=True, **common_kwargs) 1076 x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex, 1077 length=x.size(-1), **common_kwargs) 1078 self.assertEqual(x_roundtrip, x) 1079 1080 @onlyNativeDeviceTypes 1081 @skipCPUIfNoFFT 1082 @dtypes(torch.double, torch.cdouble) 1083 def test_stft_roundtrip_complex_window(self, device, dtype): 1084 test_args = list(product( 1085 # input 1086 (torch.randn(600, device=device, dtype=dtype), 1087 torch.randn(807, device=device, dtype=dtype), 1088 torch.randn(12, 60, device=device, dtype=dtype)), 1089 # n_fft 1090 (50, 27), 1091 # hop_length 1092 (None, 10), 1093 # pad_mode 1094 ("constant", "reflect", "replicate", "circular"), 1095 # normalized 1096 (True, False), 1097 )) 1098 for args in test_args: 1099 x, n_fft, hop_length, pad_mode, normalized = args 1100 window = torch.rand(n_fft, device=device, dtype=torch.cdouble) 1101 x_stft = torch.stft( 1102 x, n_fft=n_fft, hop_length=hop_length, window=window, 1103 center=True, pad_mode=pad_mode, normalized=normalized) 1104 self.assertEqual(x_stft.dtype, torch.cdouble) 1105 self.assertEqual(x_stft.size(-2), n_fft) # Not onesided 1106 1107 x_roundtrip = torch.istft( 1108 x_stft, n_fft=n_fft, hop_length=hop_length, 
window=window, 1109 center=True, normalized=normalized, length=x.size(-1), 1110 return_complex=True) 1111 self.assertEqual(x_stft.dtype, torch.cdouble) 1112 1113 if not dtype.is_complex: 1114 self.assertEqual(x_roundtrip.imag, torch.zeros_like(x_roundtrip.imag), 1115 atol=1e-6, rtol=0) 1116 self.assertEqual(x_roundtrip.real, x) 1117 else: 1118 self.assertEqual(x_roundtrip, x) 1119 1120 1121 @skipCPUIfNoFFT 1122 @dtypes(torch.cdouble) 1123 def test_complex_stft_definition(self, device, dtype): 1124 test_args = list(product( 1125 # input 1126 (torch.randn(600, device=device, dtype=dtype), 1127 torch.randn(807, device=device, dtype=dtype)), 1128 # n_fft 1129 (50, 27), 1130 # hop_length 1131 (10, 15) 1132 )) 1133 1134 for args in test_args: 1135 window = torch.randn(args[1], device=device, dtype=dtype) 1136 expected = _stft_reference(args[0], args[2], window) 1137 actual = torch.stft(*args, window=window, center=False) 1138 self.assertEqual(actual, expected) 1139 1140 @onlyNativeDeviceTypes 1141 @skipCPUIfNoFFT 1142 @dtypes(torch.cdouble) 1143 def test_complex_stft_real_equiv(self, device, dtype): 1144 test_args = list(product( 1145 # input 1146 (torch.rand(600, device=device, dtype=dtype), 1147 torch.rand(807, device=device, dtype=dtype), 1148 torch.rand(14, 50, device=device, dtype=dtype), 1149 torch.rand(6, 51, device=device, dtype=dtype)), 1150 # n_fft 1151 (50, 27), 1152 # hop_length 1153 (None, 10), 1154 # win_length 1155 (None, 20), 1156 # center 1157 (False, True), 1158 # pad_mode 1159 ("constant", "reflect", "circular"), 1160 # normalized 1161 (True, False), 1162 )) 1163 1164 for args in test_args: 1165 x, n_fft, hop_length, win_length, center, pad_mode, normalized = args 1166 expected = _complex_stft(x, n_fft, hop_length=hop_length, 1167 win_length=win_length, pad_mode=pad_mode, 1168 center=center, normalized=normalized) 1169 actual = torch.stft(x, n_fft, hop_length=hop_length, 1170 win_length=win_length, pad_mode=pad_mode, 1171 center=center, 
normalized=normalized) 1172 self.assertEqual(expected, actual) 1173 1174 @skipCPUIfNoFFT 1175 @dtypes(torch.cdouble) 1176 def test_complex_istft_real_equiv(self, device, dtype): 1177 test_args = list(product( 1178 # input 1179 (torch.rand(40, 20, device=device, dtype=dtype), 1180 torch.rand(25, 1, device=device, dtype=dtype), 1181 torch.rand(4, 20, 10, device=device, dtype=dtype)), 1182 # hop_length 1183 (None, 10), 1184 # center 1185 (False, True), 1186 # normalized 1187 (True, False), 1188 )) 1189 1190 for args in test_args: 1191 x, hop_length, center, normalized = args 1192 n_fft = x.size(-2) 1193 expected = _complex_istft(x, n_fft, hop_length=hop_length, 1194 center=center, normalized=normalized) 1195 actual = torch.istft(x, n_fft, hop_length=hop_length, 1196 center=center, normalized=normalized, 1197 return_complex=True) 1198 self.assertEqual(expected, actual) 1199 1200 @skipCPUIfNoFFT 1201 def test_complex_stft_onesided(self, device): 1202 # stft of complex input cannot be onesided 1203 for x_dtype, window_dtype in product((torch.double, torch.cdouble), repeat=2): 1204 x = torch.rand(100, device=device, dtype=x_dtype) 1205 window = torch.rand(10, device=device, dtype=window_dtype) 1206 1207 if x_dtype.is_complex or window_dtype.is_complex: 1208 with self.assertRaisesRegex(RuntimeError, 'complex'): 1209 x.stft(10, window=window, pad_mode='constant', onesided=True) 1210 else: 1211 y = x.stft(10, window=window, pad_mode='constant', onesided=True, 1212 return_complex=True) 1213 self.assertEqual(y.dtype, torch.cdouble) 1214 self.assertEqual(y.size(), (6, 51)) 1215 1216 x = torch.rand(100, device=device, dtype=torch.cdouble) 1217 with self.assertRaisesRegex(RuntimeError, 'complex'): 1218 x.stft(10, pad_mode='constant', onesided=True) 1219 1220 # stft is currently warning that it requires return-complex while an upgrader is written 1221 @onlyNativeDeviceTypes 1222 @skipCPUIfNoFFT 1223 def test_stft_requires_complex(self, device): 1224 x = torch.rand(100) 1225 with 
self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'): 1226 y = x.stft(10, pad_mode='constant') 1227 1228 # stft and istft are currently warning if a window is not provided 1229 @onlyNativeDeviceTypes 1230 @skipCPUIfNoFFT 1231 def test_stft_requires_window(self, device): 1232 x = torch.rand(100) 1233 with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"): 1234 y = x.stft(10, pad_mode='constant', return_complex=True) 1235 1236 @onlyNativeDeviceTypes 1237 @skipCPUIfNoFFT 1238 def test_istft_requires_window(self, device): 1239 stft = torch.rand((51, 5), dtype=torch.cdouble) 1240 # 51 = 2 * n_fft + 1, 5 = number of frames 1241 with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"): 1242 x = torch.istft(stft, n_fft=100, length=100) 1243 1244 @skipCPUIfNoFFT 1245 def test_fft_input_modification(self, device): 1246 # FFT functions should not modify their input (gh-34551) 1247 1248 signal = torch.ones((2, 2, 2), device=device) 1249 signal_copy = signal.clone() 1250 spectrum = torch.fft.fftn(signal, dim=(-2, -1)) 1251 self.assertEqual(signal, signal_copy) 1252 1253 spectrum_copy = spectrum.clone() 1254 _ = torch.fft.ifftn(spectrum, dim=(-2, -1)) 1255 self.assertEqual(spectrum, spectrum_copy) 1256 1257 half_spectrum = torch.fft.rfftn(signal, dim=(-2, -1)) 1258 self.assertEqual(signal, signal_copy) 1259 1260 half_spectrum_copy = half_spectrum.clone() 1261 _ = torch.fft.irfftn(half_spectrum_copy, s=(2, 2), dim=(-2, -1)) 1262 self.assertEqual(half_spectrum, half_spectrum_copy) 1263 1264 @onlyNativeDeviceTypes 1265 @skipCPUIfNoFFT 1266 def test_fft_plan_repeatable(self, device): 1267 # Regression test for gh-58724 and gh-63152 1268 for n in [2048, 3199, 5999]: 1269 a = torch.randn(n, device=device, dtype=torch.complex64) 1270 res1 = torch.fft.fftn(a) 1271 res2 = torch.fft.fftn(a.clone()) 1272 self.assertEqual(res1, res2) 1273 1274 a = torch.randn(n, device=device, dtype=torch.float64) 1275 res1 = torch.fft.rfft(a) 
1276 res2 = torch.fft.rfft(a.clone()) 1277 self.assertEqual(res1, res2) 1278 1279 @onlyNativeDeviceTypes 1280 @skipCPUIfNoFFT 1281 @dtypes(torch.double) 1282 def test_istft_round_trip_simple_cases(self, device, dtype): 1283 """stft -> istft should recover the original signale""" 1284 def _test(input, n_fft, length): 1285 stft = torch.stft(input, n_fft=n_fft, return_complex=True) 1286 inverse = torch.istft(stft, n_fft=n_fft, length=length) 1287 self.assertEqual(input, inverse, exact_dtype=True) 1288 1289 _test(torch.ones(4, dtype=dtype, device=device), 4, 4) 1290 _test(torch.zeros(4, dtype=dtype, device=device), 4, 4) 1291 1292 @onlyNativeDeviceTypes 1293 @skipCPUIfNoFFT 1294 @dtypes(torch.double) 1295 def test_istft_round_trip_various_params(self, device, dtype): 1296 """stft -> istft should recover the original signale""" 1297 def _test_istft_is_inverse_of_stft(stft_kwargs): 1298 # generates a random sound signal for each tril and then does the stft/istft 1299 # operation to check whether we can reconstruct signal 1300 data_sizes = [(2, 20), (3, 15), (4, 10)] 1301 num_trials = 100 1302 istft_kwargs = stft_kwargs.copy() 1303 del istft_kwargs['pad_mode'] 1304 for sizes in data_sizes: 1305 for i in range(num_trials): 1306 original = torch.randn(*sizes, dtype=dtype, device=device) 1307 stft = torch.stft(original, return_complex=True, **stft_kwargs) 1308 inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) 1309 self.assertEqual( 1310 inversed, original, msg='istft comparison against original', 1311 atol=7e-6, rtol=0, exact_dtype=True) 1312 1313 patterns = [ 1314 # hann_window, centered, normalized, onesided 1315 { 1316 'n_fft': 12, 1317 'hop_length': 4, 1318 'win_length': 12, 1319 'window': torch.hann_window(12, dtype=dtype, device=device), 1320 'center': True, 1321 'pad_mode': 'reflect', 1322 'normalized': True, 1323 'onesided': True, 1324 }, 1325 # hann_window, centered, not normalized, not onesided 1326 { 1327 'n_fft': 12, 1328 'hop_length': 2, 1329 
'win_length': 8, 1330 'window': torch.hann_window(8, dtype=dtype, device=device), 1331 'center': True, 1332 'pad_mode': 'reflect', 1333 'normalized': False, 1334 'onesided': False, 1335 }, 1336 # hamming_window, centered, normalized, not onesided 1337 { 1338 'n_fft': 15, 1339 'hop_length': 3, 1340 'win_length': 11, 1341 'window': torch.hamming_window(11, dtype=dtype, device=device), 1342 'center': True, 1343 'pad_mode': 'constant', 1344 'normalized': True, 1345 'onesided': False, 1346 }, 1347 # hamming_window, centered, not normalized, onesided 1348 # window same size as n_fft 1349 { 1350 'n_fft': 5, 1351 'hop_length': 2, 1352 'win_length': 5, 1353 'window': torch.hamming_window(5, dtype=dtype, device=device), 1354 'center': True, 1355 'pad_mode': 'constant', 1356 'normalized': False, 1357 'onesided': True, 1358 }, 1359 ] 1360 for i, pattern in enumerate(patterns): 1361 _test_istft_is_inverse_of_stft(pattern) 1362 1363 @onlyNativeDeviceTypes 1364 @skipCPUIfNoFFT 1365 @dtypes(torch.double) 1366 def test_istft_round_trip_with_padding(self, device, dtype): 1367 """long hop_length or not centered may cause length mismatch in the inversed signal""" 1368 def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs): 1369 # generates a random sound signal for each tril and then does the stft/istft 1370 # operation to check whether we can reconstruct signal 1371 num_trials = 100 1372 sizes = stft_kwargs['size'] 1373 del stft_kwargs['size'] 1374 istft_kwargs = stft_kwargs.copy() 1375 del istft_kwargs['pad_mode'] 1376 for i in range(num_trials): 1377 original = torch.randn(*sizes, dtype=dtype, device=device) 1378 stft = torch.stft(original, return_complex=True, **stft_kwargs) 1379 with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."): 1380 inversed = torch.istft(stft, length=original.size(-1), **istft_kwargs) 1381 n_frames = stft.size(-1) 1382 if stft_kwargs["center"] is True: 1383 len_expected = stft_kwargs["n_fft"] // 2 + 
stft_kwargs["hop_length"] * (n_frames - 1) 1384 else: 1385 len_expected = stft_kwargs["n_fft"] + stft_kwargs["hop_length"] * (n_frames - 1) 1386 # trim the original for case when constructed signal is shorter than original 1387 padding = inversed[..., len_expected:] 1388 inversed = inversed[..., :len_expected] 1389 original = original[..., :len_expected] 1390 # test the padding points of the inversed signal are all zeros 1391 zeros = torch.zeros_like(padding, device=padding.device) 1392 self.assertEqual( 1393 padding, zeros, msg='istft padding values against zeros', 1394 atol=7e-6, rtol=0, exact_dtype=True) 1395 self.assertEqual( 1396 inversed, original, msg='istft comparison against original', 1397 atol=7e-6, rtol=0, exact_dtype=True) 1398 1399 patterns = [ 1400 # hamming_window, not centered, not normalized, not onesided 1401 # window same size as n_fft 1402 { 1403 'size': [2, 20], 1404 'n_fft': 3, 1405 'hop_length': 2, 1406 'win_length': 3, 1407 'window': torch.hamming_window(3, dtype=dtype, device=device), 1408 'center': False, 1409 'pad_mode': 'reflect', 1410 'normalized': False, 1411 'onesided': False, 1412 }, 1413 # hamming_window, centered, not normalized, onesided, long hop_length 1414 # window same size as n_fft 1415 { 1416 'size': [2, 500], 1417 'n_fft': 256, 1418 'hop_length': 254, 1419 'win_length': 256, 1420 'window': torch.hamming_window(256, dtype=dtype, device=device), 1421 'center': True, 1422 'pad_mode': 'constant', 1423 'normalized': False, 1424 'onesided': True, 1425 }, 1426 ] 1427 for i, pattern in enumerate(patterns): 1428 _test_istft_is_inverse_of_stft_with_padding(pattern) 1429 1430 @onlyNativeDeviceTypes 1431 def test_istft_throws(self, device): 1432 """istft should throw exception for invalid parameters""" 1433 stft = torch.zeros((3, 5, 2), device=device) 1434 # the window is size 1 but it hops 20 so there is a gap which throw an error 1435 self.assertRaises( 1436 RuntimeError, torch.istft, stft, n_fft=4, 1437 hop_length=20, win_length=1, 
window=torch.ones(1)) 1438 # A window of zeros does not meet NOLA 1439 invalid_window = torch.zeros(4, device=device) 1440 self.assertRaises( 1441 RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window) 1442 # Input cannot be empty 1443 self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2)), 2) 1444 self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2) 1445 1446 @skipIfTorchDynamo("Failed running call_function") 1447 @onlyNativeDeviceTypes 1448 @skipCPUIfNoFFT 1449 @dtypes(torch.double) 1450 def test_istft_of_sine(self, device, dtype): 1451 complex_dtype = corresponding_complex_dtype(dtype) 1452 1453 def _test(amplitude, L, n): 1454 # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L 1455 x = torch.arange(2 * L + 1, device=device, dtype=dtype) 1456 original = amplitude * torch.sin(2 * math.pi / L * x * n) 1457 # stft = torch.stft(original, L, hop_length=L, win_length=L, 1458 # window=torch.ones(L), center=False, normalized=False) 1459 stft = torch.zeros((L // 2 + 1, 2), device=device, dtype=complex_dtype) 1460 stft_largest_val = (amplitude * L) / 2.0 1461 if n < stft.size(0): 1462 stft[n].imag = torch.tensor(-stft_largest_val, dtype=dtype) 1463 1464 if 0 <= L - n < stft.size(0): 1465 # symmetric about L // 2 1466 stft[L - n].imag = torch.tensor(stft_largest_val, dtype=dtype) 1467 1468 inverse = torch.istft( 1469 stft, L, hop_length=L, win_length=L, 1470 window=torch.ones(L, device=device, dtype=dtype), center=False, normalized=False) 1471 # There is a larger error due to the scaling of amplitude 1472 original = original[..., :inverse.size(-1)] 1473 self.assertEqual(inverse, original, atol=1e-3, rtol=0) 1474 1475 _test(amplitude=123, L=5, n=1) 1476 _test(amplitude=150, L=5, n=2) 1477 _test(amplitude=111, L=5, n=3) 1478 _test(amplitude=160, L=7, n=4) 1479 _test(amplitude=145, L=8, n=5) 1480 _test(amplitude=80, L=9, n=6) 1481 _test(amplitude=99, L=10, n=7) 1482 1483 
@onlyNativeDeviceTypes 1484 @skipCPUIfNoFFT 1485 @dtypes(torch.double) 1486 def test_istft_linearity(self, device, dtype): 1487 num_trials = 100 1488 complex_dtype = corresponding_complex_dtype(dtype) 1489 1490 def _test(data_size, kwargs): 1491 for i in range(num_trials): 1492 tensor1 = torch.randn(data_size, device=device, dtype=complex_dtype) 1493 tensor2 = torch.randn(data_size, device=device, dtype=complex_dtype) 1494 a, b = torch.rand(2, dtype=dtype, device=device) 1495 # Also compare method vs. functional call signature 1496 istft1 = tensor1.istft(**kwargs) 1497 istft2 = tensor2.istft(**kwargs) 1498 istft = a * istft1 + b * istft2 1499 estimate = torch.istft(a * tensor1 + b * tensor2, **kwargs) 1500 self.assertEqual(istft, estimate, atol=1e-5, rtol=0) 1501 patterns = [ 1502 # hann_window, centered, normalized, onesided 1503 ( 1504 (2, 7, 7), 1505 { 1506 'n_fft': 12, 1507 'window': torch.hann_window(12, device=device, dtype=dtype), 1508 'center': True, 1509 'normalized': True, 1510 'onesided': True, 1511 }, 1512 ), 1513 # hann_window, centered, not normalized, not onesided 1514 ( 1515 (2, 12, 7), 1516 { 1517 'n_fft': 12, 1518 'window': torch.hann_window(12, device=device, dtype=dtype), 1519 'center': True, 1520 'normalized': False, 1521 'onesided': False, 1522 }, 1523 ), 1524 # hamming_window, centered, normalized, not onesided 1525 ( 1526 (2, 12, 7), 1527 { 1528 'n_fft': 12, 1529 'window': torch.hamming_window(12, device=device, dtype=dtype), 1530 'center': True, 1531 'normalized': True, 1532 'onesided': False, 1533 }, 1534 ), 1535 # hamming_window, not centered, not normalized, onesided 1536 ( 1537 (2, 7, 3), 1538 { 1539 'n_fft': 12, 1540 'window': torch.hamming_window(12, device=device, dtype=dtype), 1541 'center': False, 1542 'normalized': False, 1543 'onesided': True, 1544 }, 1545 ) 1546 ] 1547 for data_size, kwargs in patterns: 1548 _test(data_size, kwargs) 1549 1550 @onlyNativeDeviceTypes 1551 @skipCPUIfNoFFT 1552 def test_batch_istft(self, device): 
1553 original = torch.tensor([ 1554 [4., 4., 4., 4., 4.], 1555 [0., 0., 0., 0., 0.], 1556 [0., 0., 0., 0., 0.] 1557 ], device=device, dtype=torch.complex64) 1558 1559 single = original.repeat(1, 1, 1) 1560 multi = original.repeat(4, 1, 1) 1561 1562 i_original = torch.istft(original, n_fft=4, length=4) 1563 i_single = torch.istft(single, n_fft=4, length=4) 1564 i_multi = torch.istft(multi, n_fft=4, length=4) 1565 1566 self.assertEqual(i_original.repeat(1, 1), i_single, atol=1e-6, rtol=0, exact_dtype=True) 1567 self.assertEqual(i_original.repeat(4, 1), i_multi, atol=1e-6, rtol=0, exact_dtype=True) 1568 1569 @onlyCUDA 1570 @skipIf(not TEST_MKL, "Test requires MKL") 1571 def test_stft_window_device(self, device): 1572 # Test the (i)stft window must be on the same device as the input 1573 x = torch.randn(1000, dtype=torch.complex64) 1574 window = torch.randn(100, dtype=torch.complex64) 1575 1576 with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"): 1577 torch.stft(x, n_fft=100, window=window.to(device)) 1578 1579 with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"): 1580 torch.stft(x.to(device), n_fft=100, window=window) 1581 1582 X = torch.stft(x, n_fft=100, window=window) 1583 1584 with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"): 1585 torch.istft(X, n_fft=100, window=window.to(device)) 1586 1587 with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"): 1588 torch.istft(x.to(device), n_fft=100, window=window) 1589 1590 1591class FFTDocTestFinder: 1592 '''The default doctest finder doesn't like that function.__module__ doesn't 1593 match torch.fft. It assumes the functions are leaked imports. 
1594 ''' 1595 def __init__(self) -> None: 1596 self.parser = doctest.DocTestParser() 1597 1598 def find(self, obj, name=None, module=None, globs=None, extraglobs=None): 1599 doctests = [] 1600 1601 modname = name if name is not None else obj.__name__ 1602 globs = {} if globs is None else globs 1603 1604 for fname in obj.__all__: 1605 func = getattr(obj, fname) 1606 if inspect.isroutine(func): 1607 qualname = modname + '.' + fname 1608 docstring = inspect.getdoc(func) 1609 if docstring is None: 1610 continue 1611 1612 examples = self.parser.get_doctest( 1613 docstring, globs=globs, name=fname, filename=None, lineno=None) 1614 doctests.append(examples) 1615 1616 return doctests 1617 1618 1619class TestFFTDocExamples(TestCase): 1620 pass 1621 1622def generate_doc_test(doc_test): 1623 def test(self, device): 1624 self.assertEqual(device, 'cpu') 1625 runner = doctest.DocTestRunner() 1626 runner.run(doc_test) 1627 1628 if runner.failures != 0: 1629 runner.summarize() 1630 self.fail('Doctest failed') 1631 1632 setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoFFT(test)) 1633 1634for doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)): 1635 generate_doc_test(doc_test) 1636 1637 1638instantiate_device_type_tests(TestFFT, globals()) 1639instantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu') 1640 1641if __name__ == '__main__': 1642 run_tests() 1643