1# Owner(s): ["module: dynamo"] 2 3import functools 4import queue 5import threading 6from unittest import skipIf as skipif, SkipTest 7 8import pytest 9from pytest import raises as assert_raises 10 11from torch.testing._internal.common_utils import ( 12 instantiate_parametrized_tests, 13 parametrize, 14 run_tests, 15 TEST_WITH_TORCHDYNAMO, 16 TestCase, 17) 18 19 20if TEST_WITH_TORCHDYNAMO: 21 import numpy as np 22 from numpy.random import random 23 from numpy.testing import assert_allclose # , IS_WASM 24else: 25 import torch._numpy as np 26 from torch._numpy.random import random 27 from torch._numpy.testing import assert_allclose # , IS_WASM 28 29 30skip = functools.partial(skipif, True) 31 32 33IS_WASM = False 34 35 36def fft1(x): 37 L = len(x) 38 phase = -2j * np.pi * (np.arange(L) / L) 39 phase = np.arange(L).reshape(-1, 1) * phase 40 return np.sum(x * np.exp(phase), axis=1) 41 42 43class TestFFTShift(TestCase): 44 def test_fft_n(self): 45 assert_raises((ValueError, RuntimeError), np.fft.fft, [1, 2, 3], 0) 46 47 48@instantiate_parametrized_tests 49class TestFFT1D(TestCase): 50 def setUp(self): 51 super().setUp() 52 np.random.seed(123456) 53 54 def test_identity(self): 55 maxlen = 512 56 x = random(maxlen) + 1j * random(maxlen) 57 xr = random(maxlen) 58 for i in range(1, maxlen): 59 assert_allclose(np.fft.ifft(np.fft.fft(x[0:i])), x[0:i], atol=1e-12) 60 assert_allclose(np.fft.irfft(np.fft.rfft(xr[0:i]), i), xr[0:i], atol=1e-12) 61 62 def test_fft(self): 63 np.random.seed(1234) 64 x = random(30) + 1j * random(30) 65 assert_allclose(fft1(x), np.fft.fft(x), atol=3e-5) 66 assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=3e-5) 67 assert_allclose(fft1(x) / np.sqrt(30), np.fft.fft(x, norm="ortho"), atol=5e-6) 68 assert_allclose(fft1(x) / 30.0, np.fft.fft(x, norm="forward"), atol=5e-6) 69 70 @parametrize("norm", (None, "backward", "ortho", "forward")) 71 def test_ifft(self, norm): 72 x = random(30) + 1j * random(30) 73 assert_allclose(x, np.fft.ifft(np.fft.fft(x, norm=norm), norm=norm), atol=1e-6) 74 75 # Ensure we get the correct error message 76 # NB: Exact wording differs slightly under Dynamo and in eager. 77 with pytest.raises((ValueError, RuntimeError), match="Invalid number of"): 78 np.fft.ifft([], norm=norm) 79 80 def test_fft2(self): 81 x = random((30, 20)) + 1j * random((30, 20)) 82 assert_allclose( 83 np.fft.fft(np.fft.fft(x, axis=1), axis=0), np.fft.fft2(x), atol=1e-6 84 ) 85 assert_allclose(np.fft.fft2(x), np.fft.fft2(x, norm="backward"), atol=1e-6) 86 assert_allclose( 87 np.fft.fft2(x) / np.sqrt(30 * 20), np.fft.fft2(x, norm="ortho"), atol=1e-6 88 ) 89 assert_allclose( 90 np.fft.fft2(x) / (30.0 * 20.0), np.fft.fft2(x, norm="forward"), atol=1e-6 91 ) 92 93 def test_ifft2(self): 94 x = random((30, 20)) + 1j * random((30, 20)) 95 assert_allclose( 96 np.fft.ifft(np.fft.ifft(x, axis=1), axis=0), np.fft.ifft2(x), atol=1e-6 97 ) 98 assert_allclose(np.fft.ifft2(x), np.fft.ifft2(x, norm="backward"), atol=1e-6) 99 assert_allclose( 100 np.fft.ifft2(x) * np.sqrt(30 * 20), np.fft.ifft2(x, norm="ortho"), atol=1e-6 101 ) 102 assert_allclose( 103 np.fft.ifft2(x) * (30.0 * 20.0), np.fft.ifft2(x, norm="forward"), atol=1e-6 104 ) 105 106 def test_fftn(self): 107 x = random((30, 20, 10)) + 1j * random((30, 20, 10)) 108 assert_allclose( 109 np.fft.fft(np.fft.fft(np.fft.fft(x, axis=2), axis=1), axis=0), 110 np.fft.fftn(x), 111 atol=1e-6, 112 ) 113 assert_allclose(np.fft.fftn(x), np.fft.fftn(x, norm="backward"), atol=1e-6) 114 assert_allclose( 115 np.fft.fftn(x) / np.sqrt(30 * 20 * 10), 116 np.fft.fftn(x, norm="ortho"), 117 atol=1e-6, 118 ) 119 assert_allclose( 120 np.fft.fftn(x) / (30.0 * 20.0 * 10.0), 121 np.fft.fftn(x, norm="forward"), 122 atol=1e-6, 123 ) 124 125 def test_ifftn(self): 126 x = random((30, 20, 10)) + 1j * random((30, 20, 10)) 127 assert_allclose( 128 np.fft.ifft(np.fft.ifft(np.fft.ifft(x, axis=2), axis=1), axis=0), 129 np.fft.ifftn(x), 130 atol=1e-6, 131 ) 132 assert_allclose(np.fft.ifftn(x), np.fft.ifftn(x, norm="backward"), atol=1e-6) 133 assert_allclose( 134 np.fft.ifftn(x) * np.sqrt(30 * 20 * 10), 135 np.fft.ifftn(x, norm="ortho"), 136 atol=1e-6, 137 ) 138 assert_allclose( 139 np.fft.ifftn(x) * (30.0 * 20.0 * 10.0), 140 np.fft.ifftn(x, norm="forward"), 141 atol=1e-6, 142 ) 143 144 def test_rfft(self): 145 x = random(30) 146 for n in [x.size, 2 * x.size]: 147 for norm in [None, "backward", "ortho", "forward"]: 148 assert_allclose( 149 np.fft.fft(x, n=n, norm=norm)[: (n // 2 + 1)], 150 np.fft.rfft(x, n=n, norm=norm), 151 atol=1e-6, 152 ) 153 assert_allclose( 154 np.fft.rfft(x, n=n), np.fft.rfft(x, n=n, norm="backward"), atol=1e-6 155 ) 156 assert_allclose( 157 np.fft.rfft(x, n=n) / np.sqrt(n), 158 np.fft.rfft(x, n=n, norm="ortho"), 159 atol=1e-6, 160 ) 161 assert_allclose( 162 np.fft.rfft(x, n=n) / n, np.fft.rfft(x, n=n, norm="forward"), atol=1e-6 163 ) 164 165 def test_irfft(self): 166 x = random(30) 167 assert_allclose(x, np.fft.irfft(np.fft.rfft(x)), atol=1e-6) 168 assert_allclose( 169 x, np.fft.irfft(np.fft.rfft(x, norm="backward"), norm="backward"), atol=1e-6 170 ) 171 assert_allclose( 172 x, np.fft.irfft(np.fft.rfft(x, norm="ortho"), norm="ortho"), atol=1e-6 173 ) 174 assert_allclose( 175 x, np.fft.irfft(np.fft.rfft(x, norm="forward"), norm="forward"), atol=1e-6 176 ) 177 178 def test_rfft2(self): 179 x = random((30, 20)) 180 assert_allclose(np.fft.fft2(x)[:, :11], np.fft.rfft2(x), atol=1e-6) 181 assert_allclose(np.fft.rfft2(x), np.fft.rfft2(x, norm="backward"), atol=1e-6) 182 assert_allclose( 183 np.fft.rfft2(x) / np.sqrt(30 * 20), np.fft.rfft2(x, norm="ortho"), atol=1e-6 184 ) 185 assert_allclose( 186 np.fft.rfft2(x) / (30.0 * 20.0), np.fft.rfft2(x, norm="forward"), atol=1e-6 187 ) 188 189 def test_irfft2(self): 190 x = random((30, 20)) 191 assert_allclose(x, np.fft.irfft2(np.fft.rfft2(x)), atol=1e-6) 192 assert_allclose( 193 x, 194 np.fft.irfft2(np.fft.rfft2(x, norm="backward"), norm="backward"), 195 atol=1e-6, 196 ) 197 assert_allclose( 198 x, np.fft.irfft2(np.fft.rfft2(x, norm="ortho"), norm="ortho"), atol=1e-6 199 ) 200 assert_allclose( 201 x, np.fft.irfft2(np.fft.rfft2(x, norm="forward"), norm="forward"), atol=1e-6 202 ) 203 204 def test_rfftn(self): 205 x = random((30, 20, 10)) 206 assert_allclose(np.fft.fftn(x)[:, :, :6], np.fft.rfftn(x), atol=1e-6) 207 assert_allclose(np.fft.rfftn(x), np.fft.rfftn(x, norm="backward"), atol=1e-6) 208 assert_allclose( 209 np.fft.rfftn(x) / np.sqrt(30 * 20 * 10), 210 np.fft.rfftn(x, norm="ortho"), 211 atol=1e-6, 212 ) 213 assert_allclose( 214 np.fft.rfftn(x) / (30.0 * 20.0 * 10.0), 215 np.fft.rfftn(x, norm="forward"), 216 atol=1e-6, 217 ) 218 219 def test_irfftn(self): 220 x = random((30, 20, 10)) 221 assert_allclose(x, np.fft.irfftn(np.fft.rfftn(x)), atol=1e-6) 222 assert_allclose( 223 x, 224 np.fft.irfftn(np.fft.rfftn(x, norm="backward"), norm="backward"), 225 atol=1e-6, 226 ) 227 assert_allclose( 228 x, np.fft.irfftn(np.fft.rfftn(x, norm="ortho"), norm="ortho"), atol=1e-6 229 ) 230 assert_allclose( 231 x, np.fft.irfftn(np.fft.rfftn(x, norm="forward"), norm="forward"), atol=1e-6 232 ) 233 234 def test_hfft(self): 235 x = random(14) + 1j * random(14) 236 x_herm = np.concatenate((random(1), x, random(1))) 237 x = np.concatenate((x_herm, np.flip(x).conj())) 238 assert_allclose(np.fft.fft(x), np.fft.hfft(x_herm), atol=1e-6) 239 assert_allclose( 240 np.fft.hfft(x_herm), np.fft.hfft(x_herm, norm="backward"), atol=1e-6 241 ) 242 assert_allclose( 243 np.fft.hfft(x_herm) / np.sqrt(30), 244 np.fft.hfft(x_herm, norm="ortho"), 245 atol=1e-6, 246 ) 247 assert_allclose( 248 np.fft.hfft(x_herm) / 30.0, np.fft.hfft(x_herm, norm="forward"), atol=1e-6 249 ) 250 251 def test_ihfft(self): 252 x = random(14) + 1j * random(14) 253 x_herm = np.concatenate((random(1), x, random(1))) 254 x = np.concatenate((x_herm, np.flip(x).conj())) 255 assert_allclose(x_herm, np.fft.ihfft(np.fft.hfft(x_herm)), atol=1e-6) 256 assert_allclose( 257 x_herm, 258 np.fft.ihfft(np.fft.hfft(x_herm, norm="backward"), norm="backward"), 259 atol=1e-6, 260 ) 261 assert_allclose( 262 x_herm, 263 np.fft.ihfft(np.fft.hfft(x_herm, norm="ortho"), norm="ortho"), 264 atol=1e-6, 265 ) 266 assert_allclose( 267 x_herm, 268 np.fft.ihfft(np.fft.hfft(x_herm, norm="forward"), norm="forward"), 269 atol=1e-6, 270 ) 271 272 @parametrize("op", [np.fft.fftn, np.fft.ifftn, np.fft.rfftn, np.fft.irfftn]) 273 def test_axes(self, op): 274 x = random((30, 20, 10)) 275 axes = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)] 276 for a in axes: 277 op_tr = op(np.transpose(x, a)) 278 tr_op = np.transpose(op(x, axes=a), a) 279 assert_allclose(op_tr, tr_op, atol=1e-6) 280 281 def test_all_1d_norm_preserving(self): 282 # verify that round-trip transforms are norm-preserving 283 x = random(30) 284 x_norm = np.linalg.norm(x) 285 n = x.size * 2 286 func_pairs = [ 287 (np.fft.fft, np.fft.ifft), 288 (np.fft.rfft, np.fft.irfft), 289 # hfft: order so the first function takes x.size samples 290 # (necessary for comparison to x_norm above) 291 (np.fft.ihfft, np.fft.hfft), 292 ] 293 for forw, back in func_pairs: 294 for n in [x.size, 2 * x.size]: 295 for norm in [None, "backward", "ortho", "forward"]: 296 tmp = forw(x, n=n, norm=norm) 297 tmp = back(tmp, n=n, norm=norm) 298 assert_allclose(x_norm, np.linalg.norm(tmp), atol=1e-6) 299 300 @parametrize("dtype", [np.half, np.single, np.double]) 301 def test_dtypes(self, dtype): 302 # make sure that all input precisions are accepted and internally 303 # converted to 64bit 304 x = random(30).astype(dtype) 305 assert_allclose(np.fft.ifft(np.fft.fft(x)), x, atol=1e-6) 306 assert_allclose(np.fft.irfft(np.fft.rfft(x)), x, atol=1e-6) 307 308 @parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128]) 309 @parametrize("order", ["F", "non-contiguous"]) 310 @parametrize( 311 "fft", 312 [np.fft.fft, np.fft.fft2, np.fft.fftn, np.fft.ifft, np.fft.ifft2, np.fft.ifftn], 313 ) 314 def test_fft_with_order(self, dtype, order, fft): 315 # Check that FFT/IFFT produces identical results for C, Fortran and 316 # non contiguous arrays 317 # rng = np.random.RandomState(42) 318 rng = np.random 319 X = rng.rand(8, 7, 13).astype(dtype) # , copy=False) 320 # See discussion in pull/14178 321 _tol = float(8.0 * np.sqrt(np.log2(X.size)) * np.finfo(X.dtype).eps) 322 if order == "F": 323 raise SkipTest("Fortran order arrays") 324 Y = np.asfortranarray(X) 325 else: 326 # Make a non contiguous array 327 Z = np.empty((16, 7, 13), dtype=X.dtype) 328 Z[::2] = X 329 Y = Z[::2] 330 X = Y.copy() 331 332 if fft.__name__.endswith("fft"): 333 for axis in range(3): 334 X_res = fft(X, axis=axis) 335 Y_res = fft(Y, axis=axis) 336 assert_allclose(X_res, Y_res, atol=_tol, rtol=_tol) 337 elif fft.__name__.endswith(("fft2", "fftn")): 338 axes = [(0, 1), (1, 2), (0, 2)] 339 if fft.__name__.endswith("fftn"): 340 axes.extend([(0,), (1,), (2,), None]) 341 for ax in axes: 342 X_res = fft(X, axes=ax) 343 Y_res = fft(Y, axes=ax) 344 assert_allclose(X_res, Y_res, atol=_tol, rtol=_tol) 345 else: 346 raise ValueError 347 348 349@skipif(IS_WASM, reason="Cannot start thread") 350class TestFFTThreadSafe(TestCase): 351 threads = 16 352 input_shape = (800, 200) 353 354 def _test_mtsame(self, func, *args): 355 def worker(args, q): 356 q.put(func(*args)) 357 358 q = queue.Queue() 359 expected = func(*args) 360 361 # Spin off a bunch of threads to call the same function simultaneously 362 t = [ 363 threading.Thread(target=worker, args=(args, q)) for i in range(self.threads) 364 ] 365 [x.start() for x in t] 366 367 [x.join() for x in t] 368 # Make sure all threads returned the correct value 369 for i in range(self.threads): 370 # under torch.dynamo `assert_array_equal` fails with relative errors of 371 # about 1.5e-14. Hence replace it with `assert_allclose(..., rtol=2e-14)` 372 assert_allclose( 373 q.get(timeout=5), 374 expected, 375 atol=2e-14 376 # msg="Function returned wrong value in multithreaded context", 377 ) 378 379 def test_fft(self): 380 a = np.ones(self.input_shape) * 1 + 0j 381 self._test_mtsame(np.fft.fft, a) 382 383 def test_ifft(self): 384 a = np.ones(self.input_shape) * 1 + 0j 385 self._test_mtsame(np.fft.ifft, a) 386 387 def test_rfft(self): 388 a = np.ones(self.input_shape) 389 self._test_mtsame(np.fft.rfft, a) 390 391 def test_irfft(self): 392 a = np.ones(self.input_shape) * 1 + 0j 393 self._test_mtsame(np.fft.irfft, a) 394 395 396if __name__ == "__main__": 397 run_tests() 398