1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for fft operations.""" 16 17import itertools 18import unittest 19 20from absl.testing import parameterized 21import numpy as np 22 23from tensorflow.core.protobuf import config_pb2 24from tensorflow.python.eager import context 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import errors 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import gen_spectral_ops 30from tensorflow.python.ops import gradient_checker_v2 31from tensorflow.python.ops import math_ops 32from tensorflow.python.ops.signal import fft_ops 33from tensorflow.python.platform import test 34 35VALID_FFT_RANKS = (1, 2, 3) 36 37 38# TODO(rjryan): Investigate precision issues. We should be able to achieve 39# better tolerances, at least for the complex128 tests. 
class BaseFFTOpsTest(test.TestCase):
  """Shared helpers for comparing TensorFlow FFTs against NumPy references.

  Subclasses must implement:
    * `_tf_fft(x, rank, fft_length=None, feed_dict=None)` and
      `_tf_ifft(x, rank, fft_length=None, feed_dict=None)`: evaluate the
      TensorFlow forward / inverse transform and return the result.
    * `_np_fft(x, rank, fft_length=None)` and
      `_np_ifft(x, rank, fft_length=None)`: the NumPy reference results.
  """

  def _np_axes(self, rank):
    """Returns the innermost `rank` axes, e.g. `(-2, -1)` for `rank == 2`.

    Args:
      rank: number of trailing dimensions the transform acts on.

    Returns:
      A tuple of negative axis indices for the `axes` argument of the
      `numpy.fft` n-dimensional functions.

    Raises:
      ValueError: if `rank` is not one of `VALID_FFT_RANKS`.
    """
    if rank not in VALID_FFT_RANKS:
      raise ValueError("invalid rank")
    return tuple(range(-rank, 0))

  @staticmethod
  def _random_real(shape):
    """Returns float64 samples drawn uniformly from [0, 1) with `shape`."""
    return np.random.uniform(size=np.prod(shape)).reshape(shape)

  @staticmethod
  def _random_complex(shape):
    """Returns complex128 samples whose real/imag parts are uniform in [0, 1)."""
    n = np.prod(shape)
    re = np.random.uniform(size=n)
    im = np.random.uniform(size=n)
    return (re + im * 1j).reshape(shape)

  def _compare(self, x, rank, fft_length=None, use_placeholder=False,
               rtol=1e-4, atol=1e-4):
    """Checks both the forward and the inverse transform of `x` against NumPy."""
    self._compare_forward(x, rank, fft_length, use_placeholder, rtol, atol)
    self._compare_backward(x, rank, fft_length, use_placeholder, rtol, atol)

  def _compare_forward(self, x, rank, fft_length=None, use_placeholder=False,
                       rtol=1e-4, atol=1e-4):
    """Asserts that `_tf_fft(x)` matches `_np_fft(x)` within `rtol`/`atol`.

    When `use_placeholder` is True, `x` is fed through a placeholder created
    without a static shape, so the op sees an unknown input shape.
    """
    expected = self._np_fft(x, rank, fft_length)
    if use_placeholder:
      x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
      actual = self._tf_fft(x_ph, rank, fft_length, feed_dict={x_ph: x})
    else:
      actual = self._tf_fft(x, rank, fft_length)

    self.assertAllClose(expected, actual, rtol=rtol, atol=atol)

  def _compare_backward(self, x, rank, fft_length=None, use_placeholder=False,
                        rtol=1e-4, atol=1e-4):
    """Asserts that `_tf_ifft(x)` matches `_np_ifft(x)` within `rtol`/`atol`.

    `use_placeholder` has the same meaning as in `_compare_forward`.
    """
    expected = self._np_ifft(x, rank, fft_length)
    if use_placeholder:
      x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
      actual = self._tf_ifft(x_ph, rank, fft_length, feed_dict={x_ph: x})
    else:
      actual = self._tf_ifft(x, rank, fft_length)

    self.assertAllClose(expected, actual, rtol=rtol, atol=atol)

  def _check_memory_fail(self, x, rank):
    """Runs the forward transform on GPU with a tiny GPU memory budget.

    NOTE(review): only referenced by the disabled
    `test_large_batch_memory_fail`; presumably meant to trigger an
    allocation failure for large inputs.
    """
    config = config_pb2.ConfigProto()
    # Restrict this process to 1% of GPU memory.
    config.gpu_options.per_process_gpu_memory_fraction = 1e-2
    with self.cached_session(config=config, force_gpu=True):
      self._tf_fft(x, rank, fft_length=None)

  def _check_grad_complex(self, func, x, y, result_is_complex=True,
                          rtol=1e-2, atol=1e-2):
    """Compares analytic and numeric gradients of `sum(|func(x + i*y)|^2)`.

    Args:
      func: a forward or inverse, real or complex, batched or unbatched FFT
        function that takes a complex input tensor.
      x: real part of the input (NumPy array).
      y: imaginary part of the input (NumPy array with the shape of `x`).
      result_is_complex: unused. Kept so existing callers keep working. The
        loss below is real whether `func` returns a real or complex tensor.
      rtol: relative tolerance for comparing the Jacobians.
      atol: absolute tolerance for comparing the Jacobians.
    """
    del result_is_complex  # Unused, see docstring.
    with self.cached_session():

      def f(inx, iny):
        inx.set_shape(x.shape)
        iny.set_shape(y.shape)
        z = func(math_ops.complex(inx, iny))
        # loss = sum(|z|^2)
        loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
        return loss

      ((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
          gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))

    self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
    self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol)

  def _check_grad_real(self, func, x, rtol=1e-2, atol=1e-2):
    """Compares analytic and numeric gradients of `sum(|func(x)|^2)`.

    Args:
      func: a forward RFFT function (batched or unbatched) taking a real input.
      x: real input (NumPy array).
      rtol: relative tolerance for comparing the Jacobians.
      atol: absolute tolerance for comparing the Jacobians.
    """
    def f(inx):
      inx.set_shape(x.shape)
      z = func(inx)
      # loss = sum(|z|^2)
      loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
      return loss

    (x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
        f, [x], delta=1e-2)
    self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)


@test_util.run_all_in_graph_and_eager_modes
class FFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
  """Tests complex-to-complex transforms (fft/fft2d/fft3d and inverses)."""

  def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
    # fft_length unused for complex FFTs.
    with self.cached_session() as sess:
      return sess.run(self._tf_fft_for_rank(rank)(x), feed_dict=feed_dict)

  def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
    # fft_length unused for complex FFTs.
    with self.cached_session() as sess:
      return sess.run(self._tf_ifft_for_rank(rank)(x), feed_dict=feed_dict)

  def _np_fft(self, x, rank, fft_length=None):
    """NumPy reference: n-D FFT over the innermost `rank` axes."""
    return np.fft.fftn(x, s=fft_length, axes=self._np_axes(rank))

  def _np_ifft(self, x, rank, fft_length=None):
    """NumPy reference: n-D inverse FFT over the innermost `rank` axes."""
    return np.fft.ifftn(x, s=fft_length, axes=self._np_axes(rank))

  def _tf_fft_for_rank(self, rank):
    if rank == 1:
      return fft_ops.fft
    elif rank == 2:
      return fft_ops.fft2d
    elif rank == 3:
      return fft_ops.fft3d
    else:
      raise ValueError("invalid rank")

  def _tf_ifft_for_rank(self, rank):
    if rank == 1:
      return fft_ops.ifft
    elif rank == 2:
      return fft_ops.ifft2d
    elif rank == 3:
      return fft_ops.ifft3d
    else:
      raise ValueError("invalid rank")

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (np.complex64, np.complex128)))
  def test_empty(self, rank, extra_dims, np_type):
    dims = rank + extra_dims
    x = np.zeros((0,) * dims).astype(np_type)
    self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
    self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)

  @parameterized.parameters(
      itertools.product(VALID_FFT_RANKS, range(3),
                        (np.complex64, np.complex128)))
  def test_basic(self, rank, extra_dims, np_type):
    dims = rank + extra_dims
    tol = 1e-4 if np_type == np.complex64 else 1e-8
    self._compare(
        np.mod(np.arange(np.power(4, dims)), 10).reshape(
            (4,) * dims).astype(np_type), rank, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      (1,), range(3), (np.complex64, np.complex128)))
  def test_large_batch(self, rank, extra_dims, np_type):
    dims = rank + extra_dims
    tol = 1e-4 if np_type == np.complex64 else 5e-5
    self._compare(
        np.mod(np.arange(np.power(128, dims)), 10).reshape(
            (128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)

  # TODO(yangzihao): Disable before we can figure out a way to
  # properly test memory fail for large batch fft.
  # def test_large_batch_memory_fail(self):
  #   if test.is_gpu_available(cuda_only=True):
  #     rank = 1
  #     for dims in range(rank, rank + 3):
  #       self._check_memory_fail(
  #           np.mod(np.arange(np.power(128, dims)), 64).reshape(
  #               (128,) * dims).astype(np.complex64), rank)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (np.complex64, np.complex128)))
  def test_placeholder(self, rank, extra_dims, np_type):
    # Placeholders only exist in graph mode.
    if context.executing_eagerly():
      return
    tol = 1e-4 if np_type == np.complex64 else 1e-8
    dims = rank + extra_dims
    self._compare(
        np.mod(np.arange(np.power(4, dims)), 10).reshape(
            (4,) * dims).astype(np_type),
        rank, use_placeholder=True, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (np.complex64, np.complex128)))
  def test_random(self, rank, extra_dims, np_type):
    tol = 1e-4 if np_type == np.complex64 else 5e-6
    dims = rank + extra_dims
    self._compare(self._random_complex((4,) * dims).astype(np_type), rank,
                  rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      # The input below is a single 1-D signal, so only rank 1 applies.
      (1,),
      # Check a variety of sizes (power-of-2, odd, etc.)
      [128, 256, 512, 1024, 127, 255, 511, 1023],
      (np.complex64, np.complex128)))
  def test_random_1d(self, rank, dim, np_type):
    has_gpu = test.is_gpu_available(cuda_only=True)
    tol = {(np.complex64, True): 1e-4,
           (np.complex64, False): 1e-2,
           (np.complex128, True): 1e-4,
           (np.complex128, False): 1e-2}[(np_type, has_gpu)]
    self._compare(self._random_complex((dim,)).astype(np_type), rank,
                  rtol=tol, atol=tol)

  def test_error(self):
    # TODO(rjryan): Fix this test under Eager.
    if context.executing_eagerly():
      return
    for rank in VALID_FFT_RANKS:
      # Inputs with fewer dimensions than the transform rank are rejected.
      for dims in range(0, rank):
        x = np.zeros((1,) * dims).astype(np.complex64)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape must be .*rank {}.*".format(rank)):
          self._tf_fft(x, rank)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape must be .*rank {}.*".format(rank)):
          self._tf_ifft(x, rank)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(2), (np.float32, np.float64)))
  def test_grad_simple(self, rank, extra_dims, np_type):
    tol = 1e-4 if np_type == np.float32 else 1e-10
    dims = rank + extra_dims
    re = np.ones(shape=(4,) * dims, dtype=np_type) / 10.0
    im = np.zeros(shape=(4,) * dims, dtype=np_type)
    self._check_grad_complex(self._tf_fft_for_rank(rank), re, im,
                             rtol=tol, atol=tol)
    self._check_grad_complex(self._tf_ifft_for_rank(rank), re, im,
                             rtol=tol, atol=tol)

  @unittest.skip("16.86% flaky")
  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(2), (np.float32, np.float64)))
  def test_grad_random(self, rank, extra_dims, np_type):
    dims = rank + extra_dims
    tol = 1e-2 if np_type == np.float32 else 1e-10
    re = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1
    im = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1
    self._check_grad_complex(self._tf_fft_for_rank(rank), re, im,
                             rtol=tol, atol=tol)
    self._check_grad_complex(self._tf_ifft_for_rank(rank), re, im,
                             rtol=tol, atol=tol)


@test_util.run_all_in_graph_and_eager_modes
@test_util.disable_xla("b/155276727")
class RFFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
  """Tests real transforms (rfft/rfft2d/rfft3d and irfft/irfft2d/irfft3d)."""

  def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
    with self.cached_session() as sess:
      return sess.run(
          self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)

  def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
    with self.cached_session() as sess:
      return sess.run(
          self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)

  def _np_fft(self, x, rank, fft_length=None):
    """NumPy reference: n-D real FFT over the innermost `rank` axes."""
    return np.fft.rfftn(x, s=fft_length, axes=self._np_axes(rank))

  def _np_ifft(self, x, rank, fft_length=None):
    """NumPy reference: n-D inverse real FFT over the innermost `rank` axes."""
    return np.fft.irfftn(x, s=fft_length, axes=self._np_axes(rank))

  def _tf_fft_for_rank(self, rank):
    if rank == 1:
      return fft_ops.rfft
    elif rank == 2:
      return fft_ops.rfft2d
    elif rank == 3:
      return fft_ops.rfft3d
    else:
      raise ValueError("invalid rank")

  def _tf_ifft_for_rank(self, rank):
    if rank == 1:
      return fft_ops.irfft
    elif rank == 2:
      return fft_ops.irfft2d
    elif rank == 3:
      return fft_ops.irfft3d
    else:
      raise ValueError("invalid rank")

  def _generate_valid_irfft_input(self, c2r, np_ctype, r2c, np_rtype, rank,
                                  fft_length):
    """Returns the complex input for an irfft test, cast to `np_ctype`.

    rocFFT requires/assumes that the input to the irfft transform is of the
    form that is a valid output from the rfft transform (i.e. it cannot be a
    set of random numbers). So on ROCm, the NumPy rfft of `r2c` is used as the
    irfft input. Everywhere else `c2r` is used as-is.

    Args:
      c2r: candidate complex input (used when not built with ROCm).
      np_ctype: complex dtype of the returned array.
      r2c: real signal whose rfft is used on ROCm.
      np_rtype: real dtype `r2c` is cast to before the rfft.
      rank: transform rank.
      fft_length: FFT length passed to the reference rfft.
    """
    if test.is_built_with_rocm():
      # np.fft.rfftn may return complex128 even for float32 input, so cast to
      # the requested complex dtype like the non-ROCm branch does.
      return self._np_fft(r2c.astype(np_rtype), rank, fft_length).astype(
          np_ctype)
    return c2r.astype(np_ctype)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (np.float32, np.float64)))
  def test_empty(self, rank, extra_dims, np_rtype):
    np_ctype = np.complex64 if np_rtype == np.float32 else np.complex128
    dims = rank + extra_dims
    x = np.zeros((0,) * dims).astype(np_rtype)
    self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
    x = np.zeros((0,) * dims).astype(np_ctype)
    self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (5, 6), (np.float32, np.float64)))
  def test_basic(self, rank, extra_dims, size, np_rtype):
    np_ctype = np.complex64 if np_rtype == np.float32 else np.complex128
    tol = 1e-4 if np_rtype == np.float32 else 5e-5
    dims = rank + extra_dims
    inner_dim = size // 2 + 1
    r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
        (size,) * dims)
    fft_length = (size,) * rank
    self._compare_forward(
        r2c.astype(np_rtype), rank, fft_length, rtol=tol, atol=tol)
    c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                 10).reshape((size,) * (dims - 1) + (inner_dim,))
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(c2r, rank, fft_length, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      (1,), range(3), (64, 128), (np.float32, np.float64)))
  def test_large_batch(self, rank, extra_dims, size, np_rtype):
    np_ctype = np.complex64 if np_rtype == np.float32 else np.complex128
    tol = 1e-4 if np_rtype == np.float32 else 1e-5
    dims = rank + extra_dims
    inner_dim = size // 2 + 1
    r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
        (size,) * dims)
    fft_length = (size,) * rank
    self._compare_forward(
        r2c.astype(np_rtype), rank, fft_length, rtol=tol, atol=tol)
    c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                 10).reshape((size,) * (dims - 1) + (inner_dim,))
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(c2r, rank, fft_length, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (5, 6), (np.float32, np.float64)))
  def test_placeholder(self, rank, extra_dims, size, np_rtype):
    # Placeholders only exist in graph mode.
    if context.executing_eagerly():
      return
    np_ctype = np.complex64 if np_rtype == np.float32 else np.complex128
    tol = 1e-4 if np_rtype == np.float32 else 1e-8
    dims = rank + extra_dims
    inner_dim = size // 2 + 1
    r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
        (size,) * dims)
    fft_length = (size,) * rank
    self._compare_forward(
        r2c.astype(np_rtype),
        rank,
        fft_length,
        use_placeholder=True,
        rtol=tol,
        atol=tol)
    c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                 10).reshape((size,) * (dims - 1) + (inner_dim,))
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(
        c2r, rank, fft_length, use_placeholder=True, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (5, 6), (np.float32, np.float64)))
  def test_fft_lenth_truncate(self, rank, extra_dims, size, np_rtype):
    """Test truncation (FFT size < dimensions)."""
    if test.is_built_with_rocm() and (rank == 3):
      # TODO(rocm): fix me
      # rfft fails for rank == 3 on ROCm
      self.skipTest("Test fails on ROCm...fix me")
    np_ctype = np.complex64 if np_rtype == np.float32 else np.complex128
    tol = 1e-4 if np_rtype == np.float32 else 8e-5
    dims = rank + extra_dims
    inner_dim = size // 2 + 1
    r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
        (size,) * dims)
    c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                 10).reshape((size,) * (dims - 1) + (inner_dim,))
    fft_length = (size - 2,) * rank
    self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
                          rtol=tol, atol=tol)
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(c2r, rank, fft_length, rtol=tol, atol=tol)
    # Confirm it works with unknown shapes as well.
    if not context.executing_eagerly():
      self._compare_forward(
          r2c.astype(np_rtype),
          rank,
          fft_length,
          use_placeholder=True,
          rtol=tol, atol=tol)
      self._compare_backward(
          c2r, rank, fft_length, use_placeholder=True, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (5, 6), (np.float32, np.float64)))
  def test_fft_lenth_pad(self, rank, extra_dims, size, np_rtype):
    """Test padding (FFT size > dimensions)."""
    np_ctype = np.complex64 if np_rtype == np.float32 else np.complex128
    tol = 1e-4 if np_rtype == np.float32 else 8e-5
    dims = rank + extra_dims
    inner_dim = size // 2 + 1
    r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
        (size,) * dims)
    c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                 10).reshape((size,) * (dims - 1) + (inner_dim,))
    fft_length = (size + 2,) * rank
    self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
                          rtol=tol, atol=tol)
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(c2r.astype(np_ctype), rank, fft_length,
                           rtol=tol, atol=tol)
    # Confirm it works with unknown shapes as well.
    if not context.executing_eagerly():
      self._compare_forward(
          r2c.astype(np_rtype),
          rank,
          fft_length,
          use_placeholder=True,
          rtol=tol, atol=tol)
      self._compare_backward(
          c2r.astype(np_ctype),
          rank,
          fft_length,
          use_placeholder=True,
          rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (5, 6), (np.float32, np.float64)))
  def test_random(self, rank, extra_dims, size, np_rtype):
    np_ctype = np.complex64 if np_rtype == np.float32 else np.complex128
    tol = 1e-4 if np_rtype == np.float32 else 1e-5
    dims = rank + extra_dims
    r2c = self._random_real((size,) * dims)
    inner_dim = size // 2 + 1
    fft_length = (size,) * rank
    self._compare_forward(
        r2c.astype(np_rtype), rank, fft_length, rtol=tol, atol=tol)
    complex_dims = (size,) * (dims - 1) + (inner_dim,)
    c2r = self._random_complex(complex_dims)
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(c2r, rank, fft_length, rtol=tol, atol=tol)

  def test_error(self):
    # TODO(rjryan): Fix this test under Eager.
    if context.executing_eagerly():
      return
    for rank in VALID_FFT_RANKS:
      # Inputs with fewer dimensions than the transform rank are rejected.
      for dims in range(0, rank):
        x = np.zeros((1,) * dims).astype(np.complex64)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape .* must have rank at least {}".format(rank)):
          self._tf_fft(x, rank)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape .* must have rank at least {}".format(rank)):
          self._tf_ifft(x, rank)
      for dims in range(rank, rank + 2):
        # Valid-rank input with `dims - rank` leading batch dimensions.
        x = np.zeros((1,) * dims)

        # Test non-rank-1 fft_length produces an error.
        fft_length = np.zeros((1, 1)).astype(np.int32)
        with self.assertRaisesWithPredicateMatch(ValueError,
                                                 "Shape .* must have rank 1"):
          self._tf_fft(x, rank, fft_length)
        with self.assertRaisesWithPredicateMatch(ValueError,
                                                 "Shape .* must have rank 1"):
          self._tf_ifft(x, rank, fft_length)

        # Test wrong fft_length length.
        fft_length = np.zeros((rank + 1,)).astype(np.int32)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
          self._tf_fft(x, rank, fft_length)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
          self._tf_ifft(x, rank, fft_length)

      # Test that calling the kernel directly without padding to fft_length
      # produces an error.
      rffts_for_rank = {
          1: [gen_spectral_ops.rfft, gen_spectral_ops.irfft],
          2: [gen_spectral_ops.rfft2d, gen_spectral_ops.irfft2d],
          3: [gen_spectral_ops.rfft3d, gen_spectral_ops.irfft3d]
      }
      rfft_fn, irfft_fn = rffts_for_rank[rank]
      with self.assertRaisesWithPredicateMatch(
          errors.InvalidArgumentError,
          "Input dimension .* must have length of at least 6 but got: 5"):
        x = np.zeros((5,) * rank).astype(np.float32)
        fft_length = [6] * rank
        with self.cached_session():
          self.evaluate(rfft_fn(x, fft_length))

      with self.assertRaisesWithPredicateMatch(
          errors.InvalidArgumentError,
          "Input dimension .* must have length of at least .* but got: 3"):
        x = np.zeros((3,) * rank).astype(np.complex64)
        fft_length = [6] * rank
        with self.cached_session():
          self.evaluate(irfft_fn(x, fft_length))

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(2), (5, 6), (np.float32, np.float64)))
  def test_grad_simple(self, rank, extra_dims, size, np_rtype):
    if rank == 3:
      self.skipTest("rfft3d/irfft3d do not have gradients yet.")
    dims = rank + extra_dims
    tol = 1e-3 if np_rtype == np.float32 else 1e-10
    re = np.ones(shape=(size,) * dims, dtype=np_rtype)
    im = -np.ones(shape=(size,) * dims, dtype=np_rtype)
    self._check_grad_real(self._tf_fft_for_rank(rank), re,
                          rtol=tol, atol=tol)
    if test.is_built_with_rocm():
      # Fails on ROCm because of irfft peculiarity
      return
    self._check_grad_complex(
        self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
        rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(2), (5, 6), (np.float32, np.float64)))
  def test_grad_random(self, rank, extra_dims, size, np_rtype):
    if rank == 3:
      self.skipTest("rfft3d/irfft3d do not have gradients yet.")
    dims = rank + extra_dims
    tol = 1e-2 if np_rtype == np.float32 else 1e-10
    re = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
    im = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
    self._check_grad_real(self._tf_fft_for_rank(rank), re,
                          rtol=tol, atol=tol)
    if test.is_built_with_rocm():
      # Fails on ROCm because of irfft peculiarity
      return
    self._check_grad_complex(
        self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
        rtol=tol, atol=tol)

  def test_invalid_args(self):
    # Test case for GitHub issue 55263: a negative fft_length must be rejected.
    a = np.empty([6, 0])
    b = np.array([1, -1])
    with self.assertRaisesRegex(errors.InvalidArgumentError, "must >= 0"):
      with self.session():
        v = fft_ops.rfft2d(input_tensor=a, fft_length=b)
        self.evaluate(v)


@test_util.run_all_in_graph_and_eager_modes
class FFTShiftTest(test.TestCase, parameterized.TestCase):
  """Tests fftshift/ifftshift, including agreement with numpy.fft."""

  def test_definition(self):
    with self.session():
      x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
      y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
      self.assertAllEqual(fft_ops.fftshift(x), y)
      self.assertAllEqual(fft_ops.ifftshift(y), x)
      x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
      y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
      self.assertAllEqual(fft_ops.fftshift(x), y)
      self.assertAllEqual(fft_ops.ifftshift(y), x)

  def test_axes_keyword(self):
    with self.session():
      freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
      shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
      self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, 1)), shifted)
      self.assertAllEqual(
          fft_ops.fftshift(freqs, axes=0),
          fft_ops.fftshift(freqs, axes=(0,)))
      self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, 1)), freqs)
      self.assertAllEqual(
          fft_ops.ifftshift(shifted, axes=0),
          fft_ops.ifftshift(shifted, axes=(0,)))
      self.assertAllEqual(fft_ops.fftshift(freqs), shifted)
      self.assertAllEqual(fft_ops.ifftshift(shifted), freqs)

  def test_numpy_compatibility(self):
    with self.session():
      x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
      y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
      self.assertAllEqual(fft_ops.fftshift(x), np.fft.fftshift(x))
      self.assertAllEqual(fft_ops.ifftshift(y), np.fft.ifftshift(y))
      x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
      y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
      self.assertAllEqual(fft_ops.fftshift(x), np.fft.fftshift(x))
      self.assertAllEqual(fft_ops.ifftshift(y), np.fft.ifftshift(y))
      freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
      shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
      self.assertAllEqual(
          fft_ops.fftshift(freqs, axes=(0, 1)),
          np.fft.fftshift(freqs, axes=(0, 1)))
      self.assertAllEqual(
          fft_ops.ifftshift(shifted, axes=(0, 1)),
          np.fft.ifftshift(shifted, axes=(0, 1)))

  @parameterized.parameters(None, 1, ([1, 2],))
  def test_placeholder(self, axes):
    # Placeholders only exist in graph mode.
    if context.executing_eagerly():
      return
    x = array_ops.placeholder(shape=[None, None, None], dtype="float32")
    y_fftshift = fft_ops.fftshift(x, axes=axes)
    y_ifftshift = fft_ops.ifftshift(x, axes=axes)
    x_np = np.random.rand(16, 256, 256)
    with self.session() as sess:
      y_fftshift_res, y_ifftshift_res = sess.run(
          [y_fftshift, y_ifftshift],
          feed_dict={x: x_np})
      self.assertAllClose(y_fftshift_res, np.fft.fftshift(x_np, axes=axes))
      self.assertAllClose(y_ifftshift_res, np.fft.ifftshift(x_np, axes=axes))

  def test_negative_axes(self):
    with self.session():
      freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
      shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
      self.assertAllEqual(fft_ops.fftshift(freqs, axes=(0, -1)), shifted)
      self.assertAllEqual(fft_ops.ifftshift(shifted, axes=(0, -1)), freqs)
      self.assertAllEqual(
          fft_ops.fftshift(freqs, axes=-1), fft_ops.fftshift(freqs, axes=(1,)))
      self.assertAllEqual(
          fft_ops.ifftshift(shifted, axes=-1),
          fft_ops.ifftshift(shifted, axes=(1,)))
fft_ops.ifftshift(shifted, axes=(1,))) 698 699 700if __name__ == "__main__": 701 test.main() 702