xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/signal/fft_ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for fft operations."""
16
17import itertools
18import unittest
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.core.protobuf import config_pb2
24from tensorflow.python.eager import context
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import errors
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gen_spectral_ops
30from tensorflow.python.ops import gradient_checker_v2
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops.signal import fft_ops
33from tensorflow.python.platform import test
34
# FFT ranks exercised by the parameterized tests in this file; the test cases
# map rank 1, 2 and 3 to the 1D, 2D and 3D variants of each op respectively.
VALID_FFT_RANKS = (1, 2, 3)
36
37
38# TODO(rjryan): Investigate precision issues. We should be able to achieve
39# better tolerances, at least for the complex128 tests.
class BaseFFTOpsTest(test.TestCase):
  """Comparison and gradient-check helpers shared by the FFT op test cases.

  The helpers below call `self._np_fft`, `self._np_ifft`, `self._tf_fft` and
  `self._tf_ifft`. None of them is defined on this class, so they have to be
  provided by the concrete test case that inherits from it.
  """

  def _assert_tf_matches_np(self, np_transform, tf_transform, x, rank,
                            fft_length, use_placeholder, rtol, atol):
    """Asserts that `tf_transform` and `np_transform` agree on `x`.

    Args:
      np_transform: Callable `(x, rank, fft_length)` giving the reference value.
      tf_transform: Callable `(x, rank, fft_length, feed_dict=...)` giving the
        value under test.
      x: NumPy input array.
      rank: Passed through unchanged to both callables.
      fft_length: Passed through unchanged to both callables.
      use_placeholder: If true, `x` is fed through a placeholder that is
        created with only a dtype (no shape argument).
      rtol: Relative tolerance handed to `assertAllClose`.
      atol: Absolute tolerance handed to `assertAllClose`.
    """
    expected = np_transform(x, rank, fft_length)
    if not use_placeholder:
      actual = tf_transform(x, rank, fft_length)
    else:
      x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
      actual = tf_transform(x_ph, rank, fft_length, feed_dict={x_ph: x})
    self.assertAllClose(expected, actual, rtol=rtol, atol=atol)

  def _compare(self, x, rank, fft_length=None, use_placeholder=False,
               rtol=1e-4, atol=1e-4):
    """Runs the forward check and then the backward check on `x`."""
    for check in (self._compare_forward, self._compare_backward):
      check(x, rank, fft_length, use_placeholder, rtol, atol)

  def _compare_forward(self, x, rank, fft_length=None, use_placeholder=False,
                       rtol=1e-4, atol=1e-4):
    """Asserts `self._tf_fft` matches `self._np_fft` on `x` within tolerance."""
    self._assert_tf_matches_np(self._np_fft, self._tf_fft, x, rank,
                               fft_length, use_placeholder, rtol, atol)

  def _compare_backward(self, x, rank, fft_length=None, use_placeholder=False,
                        rtol=1e-4, atol=1e-4):
    """Asserts `self._tf_ifft` matches `self._np_ifft` on `x` within tolerance."""
    self._assert_tf_matches_np(self._np_ifft, self._tf_ifft, x, rank,
                               fft_length, use_placeholder, rtol, atol)

  def _check_memory_fail(self, x, rank):
    """Runs `self._tf_fft` on `x` inside a memory-restricted GPU session.

    The session is requested with `force_gpu=True` and with
    `per_process_gpu_memory_fraction` set to 1e-2.
    """
    session_config = config_pb2.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 1e-2
    with self.cached_session(config=session_config, force_gpu=True):
      self._tf_fft(x, rank, fft_length=None)

  def _check_grad_complex(self, func, x, y, result_is_complex=True,
                          rtol=1e-2, atol=1e-2):
    """Checks the gradient of sum(|func(complex(x, y))|^2) w.r.t. `x` and `y`.

    Args:
      func: Transform applied to the complex tensor built from `x` and `y`.
      x: NumPy array used as the real part.
      y: NumPy array used as the imaginary part.
      result_is_complex: Accepted for the callers' benefit; not read here.
      rtol: Relative tolerance when comparing the two Jacobians.
      atol: Absolute tolerance when comparing the two Jacobians.
    """

    def loss_fn(re_part, im_part):
      re_part.set_shape(x.shape)
      im_part.set_shape(y.shape)
      # func is a forward or inverse, real or complex, batched or unbatched
      # FFT function with a complex input.
      out = func(math_ops.complex(re_part, im_part))
      # |z|^2 == real(z * conj(z)); the sum turns it into a scalar loss.
      squared_magnitude = math_ops.real(out * math_ops.conj(out))
      return math_ops.reduce_sum(squared_magnitude)

    with self.cached_session():
      first_jacobians, second_jacobians = (
          gradient_checker_v2.compute_gradient(loss_fn, [x, y], delta=1e-2))

    x_jacob_t, y_jacob_t = first_jacobians
    x_jacob_n, y_jacob_n = second_jacobians
    self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
    self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol)

  def _check_grad_real(self, func, x, rtol=1e-2, atol=1e-2):
    """Checks the gradient of sum(|func(x)|^2) with respect to the real `x`."""

    def loss_fn(real_in):
      real_in.set_shape(x.shape)
      # func is a forward RFFT function (batched or unbatched).
      out = func(real_in)
      # |z|^2 == real(z * conj(z)); the sum turns it into a scalar loss.
      squared_magnitude = math_ops.real(out * math_ops.conj(out))
      return math_ops.reduce_sum(squared_magnitude)

    first_jacobians, second_jacobians = gradient_checker_v2.compute_gradient(
        loss_fn, [x], delta=1e-2)
    (x_jacob_t,) = first_jacobians
    (x_jacob_n,) = second_jacobians
    self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
107
108
@test_util.run_all_in_graph_and_eager_modes
class FFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
  """Compares the complex FFT ops and their inverses against NumPy."""

  def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
    """Evaluates the forward TF FFT for `rank` on `x` in a cached session."""
    # fft_length unused for complex FFTs.
    with self.cached_session() as sess:
      return sess.run(self._tf_fft_for_rank(rank)(x), feed_dict=feed_dict)

  def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
    """Evaluates the inverse TF FFT for `rank` on `x` in a cached session."""
    # fft_length unused for complex FFTs.
    with self.cached_session() as sess:
      return sess.run(self._tf_ifft_for_rank(rank)(x), feed_dict=feed_dict)

  def _np_fft(self, x, rank, fft_length=None):
    """Returns NumPy's forward FFT of `x` over its innermost `rank` axes.

    Raises:
      ValueError: If `rank` is not one of `VALID_FFT_RANKS`.
    """
    if rank not in VALID_FFT_RANKS:
      raise ValueError("invalid rank")
    # rank 1 -> (-1,), rank 2 -> (-2, -1), rank 3 -> (-3, -2, -1).
    return np.fft.fftn(x, s=fft_length, axes=tuple(range(-rank, 0)))

  def _np_ifft(self, x, rank, fft_length=None):
    """Returns NumPy's inverse FFT of `x` over its innermost `rank` axes.

    Raises:
      ValueError: If `rank` is not one of `VALID_FFT_RANKS`.
    """
    if rank not in VALID_FFT_RANKS:
      raise ValueError("invalid rank")
    # rank 1 -> (-1,), rank 2 -> (-2, -1), rank 3 -> (-3, -2, -1).
    return np.fft.ifftn(x, s=fft_length, axes=tuple(range(-rank, 0)))

  def _tf_fft_for_rank(self, rank):
    """Returns the forward complex FFT op for `rank` (fft, fft2d or fft3d)."""
    if rank == 1:
      return fft_ops.fft
    elif rank == 2:
      return fft_ops.fft2d
    elif rank == 3:
      return fft_ops.fft3d
    else:
      raise ValueError("invalid rank")

  def _tf_ifft_for_rank(self, rank):
    """Returns the inverse complex FFT op for `rank` (ifft, ifft2d or ifft3d)."""
    if rank == 1:
      return fft_ops.ifft
    elif rank == 2:
      return fft_ops.ifft2d
    elif rank == 3:
      return fft_ops.ifft3d
    else:
      raise ValueError("invalid rank")

  def _random_complex(self, shape):
    """Returns a complex array of `shape` with uniform [0, 1) real/imag parts."""
    n = np.prod(shape)
    # Draw the real parts first, then the imaginary parts.
    re = np.random.uniform(size=n)
    im = np.random.uniform(size=n)
    return (re + im * 1j).reshape(shape)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (np.complex64, np.complex128)))
  def test_empty(self, rank, extra_dims, np_type):
    """Zero-sized inputs keep their shape through the FFT and its inverse."""
    dims = rank + extra_dims
    x = np.zeros((0,) * dims).astype(np_type)
    self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
    self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)

  @parameterized.parameters(
      itertools.product(VALID_FFT_RANKS, range(3),
                        (np.complex64, np.complex128)))
  def test_basic(self, rank, extra_dims, np_type):
    """Compares against NumPy on a deterministic input of side length 4."""
    dims = rank + extra_dims
    tol = 1e-4 if np_type == np.complex64 else 1e-8
    self._compare(
        np.mod(np.arange(np.power(4, dims)), 10).reshape(
            (4,) * dims).astype(np_type), rank, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      (1,), range(3), (np.complex64, np.complex128)))
  def test_large_batch(self, rank, extra_dims, np_type):
    """Compares against NumPy on a deterministic input of side length 128."""
    dims = rank + extra_dims
    tol = 1e-4 if np_type == np.complex64 else 5e-5
    self._compare(
        np.mod(np.arange(np.power(128, dims)), 10).reshape(
            (128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)

  # TODO(yangzihao): Disable before we can figure out a way to
  # properly test memory fail for large batch fft.
  # def test_large_batch_memory_fail(self):
  #   if test.is_gpu_available(cuda_only=True):
  #     rank = 1
  #     for dims in range(rank, rank + 3):
  #       self._check_memory_fail(
  #           np.mod(np.arange(np.power(128, dims)), 64).reshape(
  #               (128,) * dims).astype(np.complex64), rank)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (np.complex64, np.complex128)))
  def test_placeholder(self, rank, extra_dims, np_type):
    """Same as `test_basic`, but feeds the input through a placeholder."""
    # Returns without checking anything when executing eagerly.
    if context.executing_eagerly():
      return
    tol = 1e-4 if np_type == np.complex64 else 1e-8
    dims = rank + extra_dims
    self._compare(
        np.mod(np.arange(np.power(4, dims)), 10).reshape(
            (4,) * dims).astype(np_type),
        rank, use_placeholder=True, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (np.complex64, np.complex128)))
  def test_random(self, rank, extra_dims, np_type):
    """Compares against NumPy on random complex input of side length 4."""
    tol = 1e-4 if np_type == np.complex64 else 5e-6
    dims = rank + extra_dims
    self._compare(self._random_complex((4,) * dims).astype(np_type), rank,
                  rtol=tol, atol=tol)

  # Only rank 1 is meaningful here: the input is a single 1-D vector.
  # Parameterizing over every valid rank would just repeat each case.
  @parameterized.parameters(itertools.product(
      (1,),
      # Check a variety of sizes (power-of-2, odd, etc.)
      [128, 256, 512, 1024, 127, 255, 511, 1023],
      (np.complex64, np.complex128)))
  def test_random_1d(self, rank, dim, np_type):
    """Compares the 1-D transform against NumPy on random vectors."""
    has_gpu = test.is_gpu_available(cuda_only=True)
    tol = {(np.complex64, True): 1e-4,
           (np.complex64, False): 1e-2,
           (np.complex128, True): 1e-4,
           (np.complex128, False): 1e-2}[(np_type, has_gpu)]
    self._compare(self._random_complex((dim,)).astype(np_type), rank,
                  rtol=tol, atol=tol)

  def test_error(self):
    """Inputs with fewer than `rank` dimensions are rejected."""
    # TODO(rjryan): Fix this test under Eager.
    if context.executing_eagerly():
      return
    for rank in VALID_FFT_RANKS:
      for dims in range(rank):
        x = np.zeros((1,) * dims).astype(np.complex64)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape must be .*rank {}.*".format(rank)):
          self._tf_fft(x, rank)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape must be .*rank {}.*".format(rank)):
          self._tf_ifft(x, rank)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(2), (np.float32, np.float64)))
  def test_grad_simple(self, rank, extra_dims, np_type):
    """Gradient check on a constant input (real part 0.1, imaginary part 0)."""
    tol = 1e-4 if np_type == np.float32 else 1e-10
    dims = rank + extra_dims
    re = np.ones(shape=(4,) * dims, dtype=np_type) / 10.0
    im = np.zeros(shape=(4,) * dims, dtype=np_type)
    self._check_grad_complex(self._tf_fft_for_rank(rank), re, im,
                             rtol=tol, atol=tol)
    self._check_grad_complex(self._tf_ifft_for_rank(rank), re, im,
                             rtol=tol, atol=tol)

  @unittest.skip("16.86% flaky")
  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(2), (np.float32, np.float64)))
  def test_grad_random(self, rank, extra_dims, np_type):
    """Gradient check on random input in [-1, 1); currently skipped."""
    dims = rank + extra_dims
    tol = 1e-2 if np_type == np.float32 else 1e-10
    re = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1
    im = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1
    self._check_grad_complex(self._tf_fft_for_rank(rank), re, im,
                             rtol=tol, atol=tol)
    self._check_grad_complex(self._tf_ifft_for_rank(rank), re, im,
                             rtol=tol, atol=tol)
282
283
@test_util.run_all_in_graph_and_eager_modes
@test_util.disable_xla("b/155276727")
class RFFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
  """Compares the real-input FFT ops and their inverses against NumPy."""

  def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
    """Evaluates the forward TF RFFT for `rank` on `x` in a cached session."""
    with self.cached_session() as sess:
      return sess.run(
          self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)

  def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
    """Evaluates the inverse TF RFFT for `rank` on `x` in a cached session."""
    with self.cached_session() as sess:
      return sess.run(
          self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)

  def _np_fft(self, x, rank, fft_length=None):
    """Returns NumPy's real FFT of `x` over its innermost `rank` axes.

    Raises:
      ValueError: If `rank` is not one of `VALID_FFT_RANKS`.
    """
    if rank not in VALID_FFT_RANKS:
      raise ValueError("invalid rank")
    # rank 1 -> (-1,), rank 2 -> (-2, -1), rank 3 -> (-3, -2, -1).
    return np.fft.rfftn(x, s=fft_length, axes=tuple(range(-rank, 0)))

  def _np_ifft(self, x, rank, fft_length=None):
    """Returns NumPy's inverse real FFT of `x` over its innermost `rank` axes.

    Raises:
      ValueError: If `rank` is not one of `VALID_FFT_RANKS`.
    """
    if rank not in VALID_FFT_RANKS:
      raise ValueError("invalid rank")
    # rank 1 -> (-1,), rank 2 -> (-2, -1), rank 3 -> (-3, -2, -1).
    return np.fft.irfftn(x, s=fft_length, axes=tuple(range(-rank, 0)))

  def _tf_fft_for_rank(self, rank):
    """Returns the forward real FFT op for `rank` (rfft, rfft2d or rfft3d)."""
    if rank == 1:
      return fft_ops.rfft
    elif rank == 2:
      return fft_ops.rfft2d
    elif rank == 3:
      return fft_ops.rfft3d
    else:
      raise ValueError("invalid rank")

  def _tf_ifft_for_rank(self, rank):
    """Returns the inverse real FFT op for `rank` (irfft, irfft2d or irfft3d)."""
    if rank == 1:
      return fft_ops.irfft
    elif rank == 2:
      return fft_ops.irfft2d
    elif rank == 3:
      return fft_ops.irfft3d
    else:
      raise ValueError("invalid rank")

  def _complex_dtype(self, np_rtype):
    """Returns complex64 for float32 and complex128 for any other `np_rtype`."""
    return np.complex64 if np_rtype == np.float32 else np.complex128

  def _make_mod10_inputs(self, size, dims):
    """Builds the deterministic (r2c, c2r) input pair used by several tests.

    Both arrays hold `arange(...) mod 10`. `r2c` has shape `(size,) * dims`;
    `c2r` has the same leading dimensions with an innermost dimension of
    `size // 2 + 1`. Neither array is cast; callers pick the dtype.
    """
    inner_dim = size // 2 + 1
    r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
        (size,) * dims)
    c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                 10).reshape((size,) * (dims - 1) + (inner_dim,))
    return r2c, c2r

  # rocFFT requires/assumes that the input to the irfft transform
  # is of the form that is a valid output from the rfft transform
  # (i.e. it cannot be a set of random numbers)
  # So for ROCm, call rfft and use its output as the input for testing irfft
  def _generate_valid_irfft_input(self, c2r, np_ctype, r2c, np_rtype, rank,
                                  fft_length):
    """Returns the irfft test input: `c2r` as-is, or rfft(`r2c`) on ROCm."""
    if test.is_built_with_rocm():
      return self._np_fft(r2c.astype(np_rtype), rank, fft_length)
    else:
      return c2r.astype(np_ctype)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (np.float32, np.float64)))
  def test_empty(self, rank, extra_dims, np_rtype):
    """Zero-sized inputs keep their shape through rfft and irfft."""
    np_ctype = self._complex_dtype(np_rtype)
    dims = rank + extra_dims
    x = np.zeros((0,) * dims).astype(np_rtype)
    self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
    x = np.zeros((0,) * dims).astype(np_ctype)
    self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (5, 6), (np.float32, np.float64)))
  def test_basic(self, rank, extra_dims, size, np_rtype):
    """Compares against NumPy with `fft_length` equal to the input size."""
    np_ctype = self._complex_dtype(np_rtype)
    tol = 1e-4 if np_rtype == np.float32 else 5e-5
    dims = rank + extra_dims
    r2c, c2r = self._make_mod10_inputs(size, dims)
    fft_length = (size,) * rank
    self._compare_forward(
        r2c.astype(np_rtype), rank, fft_length, rtol=tol, atol=tol)
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(c2r, rank, fft_length, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      (1,), range(3), (64, 128), (np.float32, np.float64)))
  def test_large_batch(self, rank, extra_dims, size, np_rtype):
    """Same as `test_basic` for rank 1 with side lengths 64 and 128."""
    np_ctype = self._complex_dtype(np_rtype)
    tol = 1e-4 if np_rtype == np.float32 else 1e-5
    dims = rank + extra_dims
    r2c, c2r = self._make_mod10_inputs(size, dims)
    fft_length = (size,) * rank
    self._compare_forward(
        r2c.astype(np_rtype), rank, fft_length, rtol=tol, atol=tol)
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(c2r, rank, fft_length, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (5, 6), (np.float32, np.float64)))
  def test_placeholder(self, rank, extra_dims, size, np_rtype):
    """Same as `test_basic`, but feeds the inputs through placeholders."""
    # Returns without checking anything when executing eagerly.
    if context.executing_eagerly():
      return
    np_ctype = self._complex_dtype(np_rtype)
    tol = 1e-4 if np_rtype == np.float32 else 1e-8
    dims = rank + extra_dims
    r2c, c2r = self._make_mod10_inputs(size, dims)
    fft_length = (size,) * rank
    self._compare_forward(
        r2c.astype(np_rtype),
        rank,
        fft_length,
        use_placeholder=True,
        rtol=tol,
        atol=tol)
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(
        c2r, rank, fft_length, use_placeholder=True, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (5, 6), (np.float32, np.float64)))
  def test_fft_lenth_truncate(self, rank, extra_dims, size, np_rtype):
    """Test truncation (FFT size < dimensions)."""
    if test.is_built_with_rocm() and (rank == 3):
      # TODO(rocm): fix me
      # rfft fails for rank == 3 on ROCm
      self.skipTest("Test fails on ROCm...fix me")
    np_ctype = self._complex_dtype(np_rtype)
    tol = 1e-4 if np_rtype == np.float32 else 8e-5
    dims = rank + extra_dims
    r2c, c2r = self._make_mod10_inputs(size, dims)
    fft_length = (size - 2,) * rank
    self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
                          rtol=tol, atol=tol)
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(c2r, rank, fft_length, rtol=tol, atol=tol)
    # Confirm it works with unknown shapes as well.
    if not context.executing_eagerly():
      self._compare_forward(
          r2c.astype(np_rtype),
          rank,
          fft_length,
          use_placeholder=True,
          rtol=tol, atol=tol)
      self._compare_backward(
          c2r, rank, fft_length, use_placeholder=True, rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (5, 6), (np.float32, np.float64)))
  def test_fft_lenth_pad(self, rank, extra_dims, size, np_rtype):
    """Test padding (FFT size > dimensions)."""
    np_ctype = self._complex_dtype(np_rtype)
    tol = 1e-4 if np_rtype == np.float32 else 8e-5
    dims = rank + extra_dims
    r2c, c2r = self._make_mod10_inputs(size, dims)
    fft_length = (size + 2,) * rank
    self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
                          rtol=tol, atol=tol)
    # Cast once up front; both backward comparisons below use the same array.
    c2r = self._generate_valid_irfft_input(
        c2r, np_ctype, r2c, np_rtype, rank, fft_length).astype(np_ctype)
    self._compare_backward(c2r, rank, fft_length, rtol=tol, atol=tol)
    # Confirm it works with unknown shapes as well.
    if not context.executing_eagerly():
      self._compare_forward(
          r2c.astype(np_rtype),
          rank,
          fft_length,
          use_placeholder=True,
          rtol=tol, atol=tol)
      self._compare_backward(
          c2r,
          rank,
          fft_length,
          use_placeholder=True,
          rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(3), (5, 6), (np.float32, np.float64)))
  def test_random(self, rank, extra_dims, size, np_rtype):
    """Compares against NumPy on uniformly random real and complex inputs."""

    def gen_real(shape):
      return np.random.uniform(size=np.prod(shape)).reshape(shape)

    def gen_complex(shape):
      n = np.prod(shape)
      # Draw the real parts first, then the imaginary parts.
      re = np.random.uniform(size=n)
      im = np.random.uniform(size=n)
      return (re + im * 1j).reshape(shape)

    np_ctype = self._complex_dtype(np_rtype)
    tol = 1e-4 if np_rtype == np.float32 else 1e-5
    dims = rank + extra_dims
    r2c = gen_real((size,) * dims)
    inner_dim = size // 2 + 1
    fft_length = (size,) * rank
    self._compare_forward(
        r2c.astype(np_rtype), rank, fft_length, rtol=tol, atol=tol)
    complex_dims = (size,) * (dims - 1) + (inner_dim,)
    c2r = gen_complex(complex_dims)
    c2r = self._generate_valid_irfft_input(c2r, np_ctype, r2c, np_rtype, rank,
                                           fft_length)
    self._compare_backward(c2r, rank, fft_length, rtol=tol, atol=tol)

  def test_error(self):
    """Malformed inputs and `fft_length` values raise the expected errors."""
    # TODO(rjryan): Fix this test under Eager.
    if context.executing_eagerly():
      return

    # Raw kernels per rank, used to bypass the padding done by the wrappers.
    rffts_for_rank = {
        1: [gen_spectral_ops.rfft, gen_spectral_ops.irfft],
        2: [gen_spectral_ops.rfft2d, gen_spectral_ops.irfft2d],
        3: [gen_spectral_ops.rfft3d, gen_spectral_ops.irfft3d]
    }

    for rank in VALID_FFT_RANKS:
      for dims in range(rank):
        x = np.zeros((1,) * dims).astype(np.complex64)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape .* must have rank at least {}".format(rank)):
          self._tf_fft(x, rank)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape .* must have rank at least {}".format(rank)):
          self._tf_ifft(x, rank)

      # NOTE(review): `dims` is not used in the loop body below -- `x` is
      # always built with exactly `rank` dimensions, so both iterations run
      # identical checks. Presumably `(1,) * dims` was intended; confirm the
      # expected error messages still hold before changing it.
      for dims in range(rank, rank + 2):
        x = np.zeros((1,) * rank)

        # Test non-rank-1 fft_length produces an error.
        fft_length = np.zeros((1, 1)).astype(np.int32)
        with self.assertRaisesWithPredicateMatch(ValueError,
                                                 "Shape .* must have rank 1"):
          self._tf_fft(x, rank, fft_length)
        with self.assertRaisesWithPredicateMatch(ValueError,
                                                 "Shape .* must have rank 1"):
          self._tf_ifft(x, rank, fft_length)

        # Test wrong fft_length length.
        fft_length = np.zeros((rank + 1,)).astype(np.int32)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
          self._tf_fft(x, rank, fft_length)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
          self._tf_ifft(x, rank, fft_length)

      # Test that calling the kernel directly without padding to fft_length
      # produces an error.
      rfft_fn, irfft_fn = rffts_for_rank[rank]
      fft_length = [6] * rank

      # Inputs are built outside the assertRaises bodies so that only the op
      # evaluation can satisfy the expectation.
      x = np.zeros((5,) * rank).astype(np.float32)
      with self.assertRaisesWithPredicateMatch(
          errors.InvalidArgumentError,
          "Input dimension .* must have length of at least 6 but got: 5"):
        with self.cached_session():
          self.evaluate(rfft_fn(x, fft_length))

      x = np.zeros((3,) * rank).astype(np.complex64)
      with self.assertRaisesWithPredicateMatch(
          errors.InvalidArgumentError,
          "Input dimension .* must have length of at least .* but got: 3"):
        with self.cached_session():
          self.evaluate(irfft_fn(x, fft_length))

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(2), (5, 6), (np.float32, np.float64)))
  def test_grad_simple(self, rank, extra_dims, size, np_rtype):
    """Gradient check on constant input (real part 1, imaginary part -1)."""
    # rfft3d/irfft3d do not have gradients yet.
    if rank == 3:
      return
    dims = rank + extra_dims
    tol = 1e-3 if np_rtype == np.float32 else 1e-10
    re = np.ones(shape=(size,) * dims, dtype=np_rtype)
    im = -np.ones(shape=(size,) * dims, dtype=np_rtype)
    self._check_grad_real(self._tf_fft_for_rank(rank), re,
                          rtol=tol, atol=tol)
    if test.is_built_with_rocm():
      # Fails on ROCm because of irfft peculiarity
      return
    self._check_grad_complex(
        self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
        rtol=tol, atol=tol)

  @parameterized.parameters(itertools.product(
      VALID_FFT_RANKS, range(2), (5, 6), (np.float32, np.float64)))
  def test_grad_random(self, rank, extra_dims, size, np_rtype):
    """Gradient check on random input drawn from [-1, 1)."""
    # rfft3d/irfft3d do not have gradients yet.
    if rank == 3:
      return
    dims = rank + extra_dims
    tol = 1e-2 if np_rtype == np.float32 else 1e-10
    re = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
    im = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
    self._check_grad_real(self._tf_fft_for_rank(rank), re,
                          rtol=tol, atol=tol)
    if test.is_built_with_rocm():
      # Fails on ROCm because of irfft peculiarity
      return
    self._check_grad_complex(
        self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
        rtol=tol, atol=tol)

  def test_invalid_args(self):
    """A negative `fft_length` entry is rejected with InvalidArgumentError."""
    # Test case for GitHub issue 55263
    a = np.empty([6, 0])
    b = np.array([1, -1])
    with self.assertRaisesRegex(errors.InvalidArgumentError, "must >= 0"):
      with self.session():
        v = fft_ops.rfft2d(input_tensor=a, fft_length=b)
        self.evaluate(v)
622
623
@test_util.run_all_in_graph_and_eager_modes
class FFTShiftTest(test.TestCase, parameterized.TestCase):
  """Checks `fftshift`/`ifftshift` against fixed values and against NumPy."""

  def test_definition(self):
    """Odd- and even-length vectors shift to the expected ordering and back."""
    cases = (
        # (unshifted, shifted) for an odd length ...
        ([0, 1, 2, 3, 4, -4, -3, -2, -1], [-4, -3, -2, -1, 0, 1, 2, 3, 4]),
        # ... and for an even length.
        ([0, 1, 2, 3, 4, -5, -4, -3, -2, -1],
         [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]),
    )
    with self.session():
      for unshifted, centered in cases:
        self.assertAllEqual(fft_ops.fftshift(unshifted), centered)
        self.assertAllEqual(fft_ops.ifftshift(centered), unshifted)

  def test_axes_keyword(self):
    """An int axis equals its 1-tuple form; no `axes` equals all axes here."""
    grid = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
    grid_shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
    with self.session():
      self.assertAllEqual(fft_ops.fftshift(grid, axes=(0, 1)), grid_shifted)
      self.assertAllEqual(
          fft_ops.fftshift(grid, axes=0),
          fft_ops.fftshift(grid, axes=(0,)))
      self.assertAllEqual(fft_ops.ifftshift(grid_shifted, axes=(0, 1)), grid)
      self.assertAllEqual(
          fft_ops.ifftshift(grid_shifted, axes=0),
          fft_ops.ifftshift(grid_shifted, axes=(0,)))
      self.assertAllEqual(fft_ops.fftshift(grid), grid_shifted)
      self.assertAllEqual(fft_ops.ifftshift(grid_shifted), grid)

  def test_numpy_compatibility(self):
    """Results equal those of `np.fft.fftshift` / `np.fft.ifftshift`."""
    vector_cases = (
        ([0, 1, 2, 3, 4, -4, -3, -2, -1], [-4, -3, -2, -1, 0, 1, 2, 3, 4]),
        ([0, 1, 2, 3, 4, -5, -4, -3, -2, -1],
         [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]),
    )
    grid = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
    grid_shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
    both_axes = (0, 1)
    with self.session():
      for unshifted, centered in vector_cases:
        self.assertAllEqual(
            fft_ops.fftshift(unshifted), np.fft.fftshift(unshifted))
        self.assertAllEqual(
            fft_ops.ifftshift(centered), np.fft.ifftshift(centered))
      self.assertAllEqual(
          fft_ops.fftshift(grid, axes=both_axes),
          np.fft.fftshift(grid, axes=both_axes))
      self.assertAllEqual(
          fft_ops.ifftshift(grid_shifted, axes=both_axes),
          np.fft.ifftshift(grid_shifted, axes=both_axes))

  @parameterized.parameters(None, 1, ([1, 2],))
  def test_placeholder(self, axes):
    """Shifts a rank-3 placeholder of unknown dimensions and matches NumPy."""
    # Returns without checking anything when executing eagerly.
    if context.executing_eagerly():
      return
    x = array_ops.placeholder(shape=[None, None, None], dtype="float32")
    shifted_op = fft_ops.fftshift(x, axes=axes)
    unshifted_op = fft_ops.ifftshift(x, axes=axes)
    data = np.random.rand(16, 256, 256)
    with self.session() as sess:
      shifted_val, unshifted_val = sess.run(
          [shifted_op, unshifted_op], feed_dict={x: data})
    self.assertAllClose(shifted_val, np.fft.fftshift(data, axes=axes))
    self.assertAllClose(unshifted_val, np.fft.ifftshift(data, axes=axes))

  def test_negative_axes(self):
    """Axis -1 is treated as the last axis (axis 1 of a 2-D input)."""
    grid = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
    grid_shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
    with self.session():
      self.assertAllEqual(fft_ops.fftshift(grid, axes=(0, -1)), grid_shifted)
      self.assertAllEqual(fft_ops.ifftshift(grid_shifted, axes=(0, -1)), grid)
      self.assertAllEqual(
          fft_ops.fftshift(grid, axes=-1), fft_ops.fftshift(grid, axes=(1,)))
      self.assertAllEqual(
          fft_ops.ifftshift(grid_shifted, axes=-1),
          fft_ops.ifftshift(grid_shifted, axes=(1,)))
698
699
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
702