xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/signal/dct_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Discrete Cosine Transform ops."""
16import math as _math
17
18from tensorflow.python.framework import dtypes as _dtypes
19from tensorflow.python.framework import ops as _ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import array_ops as _array_ops
22from tensorflow.python.ops import math_ops as _math_ops
23from tensorflow.python.ops.signal import fft_ops
24from tensorflow.python.util import dispatch
25from tensorflow.python.util.tf_export import tf_export
26
27
def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
  """Checks that DCT/IDCT arguments are compatible and well formed.

  Args:
    input_tensor: The signal to transform. Its `.shape` is only read for a
      Type-I transform when `n` is `None`.
    dct_type: The requested DCT type.
    n: The requested transform length, or `None` to use the input's length.
    axis: The requested transform axis.
    norm: The requested normalization.

  Raises:
    NotImplementedError: If `axis` is not `-1`.
    ValueError: If `n` is not `None` and less than 1, `dct_type` is not one of
      1, 2, 3 or 4, `norm` is not `None` or `'ortho'`, or, for Type I, `norm`
      is `'ortho'` or the transform length is known to be less than 2.
  """
  if axis != -1:
    raise NotImplementedError("axis must be -1. Got: %s" % axis)
  if n is not None and n < 1:
    raise ValueError("n should be a positive integer or None")
  if dct_type not in (1, 2, 3, 4):
    raise ValueError("Types I, II, III and IV (I)DCT are supported.")
  if dct_type == 1:
    if norm == "ortho":
      raise ValueError("Normalization is not supported for the Type-I DCT.")
    # When `n` is given, `dct` truncates or zero-pads the input to `n` samples
    # before transforming, so `n` (not the input's own length) must be at
    # least 2. Otherwise use the input's static length; if that is unknown the
    # comparison is skipped (or evaluates falsy) and no error is raised here.
    length = n if n is not None else input_tensor.shape[-1]
    if length is not None and length < 2:
      raise ValueError(
          "Type-I DCT requires the dimension to be greater than one.")

  if norm not in (None, "ortho"):
    raise ValueError(
        "Unknown normalization. Expected None or 'ortho', got: %s" % norm)
46
47
def _last_dim_length(tensor):
  """Returns the innermost dimension of `tensor`: static int if known, else a Tensor."""
  static_length = tensor_shape.dimension_value(tensor.shape[-1])
  if static_length is not None:
    return static_length
  return _array_ops.shape(tensor)[-1]


def _fit_to_length(tensor, n):
  """Truncates or zero-pads the innermost dimension of `tensor` to length `n`."""
  static_length = tensor_shape.dimension_value(tensor.shape[-1])
  if static_length is not None:
    if n <= static_length:
      return tensor[..., 0:n]
    # A known last dimension implies a known rank, so a static paddings
    # matrix that only pads the end of the last axis can be built.
    rank = len(tensor.shape)
    padding = [[0, 0] for _ in range(rank)]
    padding[rank - 1][1] = n - static_length
    padding = _ops.convert_to_tensor(padding, dtype=_dtypes.int32)
    return _array_ops.pad(tensor, paddings=padding)
  # The length is only known at run time, so it cannot be compared with `n`
  # in Python (that would use a symbolic Tensor as a bool in graph mode).
  # Truncate unconditionally (a no-op for shorter inputs), then append however
  # many zeros are still missing (possibly none).
  truncated = tensor[..., 0:n]
  truncated_shape = _array_ops.shape(truncated)
  missing = _array_ops.expand_dims(n - truncated_shape[-1], 0)
  zeros = _array_ops.zeros(
      _array_ops.concat([truncated_shape[:-1], missing], axis=0),
      dtype=tensor.dtype)
  return _array_ops.concat([truncated, zeros], axis=-1)


# TODO(rjryan): Implement `axis` parameter.
@tf_export("signal.dct", v1=["signal.dct", "spectral.dct"])
@dispatch.add_dispatch_support
def dct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.

  Types I, II, III and IV are supported.
  Type I is implemented as the real part of a length `2N - 2`
  `tf.signal.rfft` of the even extension `[x_0, ..., x_{N-1}, x_{N-2}, ..., x_1]`.
  Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
  described here: [Type 2 DCT using 2N FFT padded (Makhoul)]
  (https://dsp.stackexchange.com/a/10606).
  Type III is a fairly straightforward inverse of Type II
  (i.e. using a length `2N` padded `tf.signal.irfft`).
  Type IV is calculated through a length `2N` DCT-II of the zero-padded
  signal, keeping the odd-indexed outputs.

  @compatibility(scipy)
  Equivalent to [scipy.fftpack.dct]
   (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.dct.html)
   for Type-I, Type-II, Type-III and Type-IV DCT.
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
      signals to take the DCT of.
    type: The DCT type to perform. Must be 1, 2, 3 or 4.
    n: The length of the transform. If length is less than sequence length,
      only the first n elements of the sequence are considered for the DCT.
      If n is greater than the sequence length, zeros are padded and then
      the DCT is computed as usual.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32`/`float64` `Tensor` containing the DCT of
    `input`. When `n` is given, `samples` is `n`.

  Raises:
    ValueError: If `type` is not `1`, `2`, `3` or `4`, `n` is not `None` or
      greater than 0, or `norm` is not `None` or `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `'ortho'`, or the Type-I
      transform is found to have fewer than 2 samples.
    NotImplementedError: If `axis` is not `-1`.

  [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
  """
  with _ops.name_scope(name, "dct", [input]):
    input = _ops.convert_to_tensor(input)
    # Validate after conversion so that non-Tensor inputs (e.g. Python lists)
    # expose the `.shape` that the Type-I length check reads.
    _validate_dct_arguments(input, type, n, axis, norm)
    zero = _ops.convert_to_tensor(0.0, dtype=input.dtype)

    if n is not None:
      input = _fit_to_length(input, n)
      axis_dim = n
    else:
      axis_dim = _last_dim_length(input)
    axis_dim_float = _math_ops.cast(axis_dim, input.dtype)

    if type == 1:
      # Even extension of length 2N - 2: the input followed by its reversed
      # interior samples x_{N-2}, ..., x_1.
      dct1_input = _array_ops.concat([input, input[..., -2:0:-1]], axis=-1)
      return _math_ops.real(fft_ops.rfft(dct1_input))

    if type == 2:
      # Makhoul: take the first N bins of a 2N-point RFFT of the zero-padded
      # signal and rotate bin k by 2 * exp(-i * pi * k / (2N)).
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              zero, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))

      # TODO(rjryan): Benchmark performance and memory usage of the various
      # approaches to computing a DCT via the RFFT.
      dct2 = _math_ops.real(
          fft_ops.rfft(
              input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)

      if norm == "ortho":
        n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
        n2 = n1 * _math.sqrt(2.0)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        dct2 *= weights

      return dct2

    if type == 3:
      if norm == "ortho":
        n1 = _math_ops.sqrt(axis_dim_float)
        n2 = n1 * _math.sqrt(0.5)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        input *= weights
      else:
        input *= axis_dim_float
      # Undo the Type-II rotation (2 * exp(+i * pi * k / (2N))), inverse-RFFT
      # at length 2N and keep the first N samples.
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              zero,
              _math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))
      dct3 = _math_ops.real(
          fft_ops.irfft(
              scale * _math_ops.complex(input, zero),
              fft_length=[2 * axis_dim]))[..., :axis_dim]

      return dct3

    # Only type 4 remains after validation. Take an unnormalized DCT-II of the
    # signal zero-padded to 2N and keep its odd indices. Padding by
    # concatenation (instead of passing `n=2 * axis_dim`) also works when N
    # is only known at run time.
    padded = _array_ops.concat([input, _array_ops.zeros_like(input)], axis=-1)
    dct2 = dct(padded, type=2, axis=axis, norm=None)
    dct4 = dct2[..., 1::2]
    if norm == "ortho":
      dct4 *= _math.sqrt(0.5) * _math_ops.rsqrt(axis_dim_float)

    return dct4
178
179
# TODO(rjryan): Implement `axis` parameter.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
@dispatch.add_dispatch_support
def idct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.

  Currently Types I, II, III, IV are supported. Type III is the inverse of
  Type II, and vice versa. Types I and IV are their own inverses.

  Note that you must re-normalize to obtain an inverse if `norm` is not
  `'ortho'`. For Types II, III and IV the factor is 1/(2N):
  `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`.
  For Type I the factor is 1/(2(N-1)):
  `signal == idct(dct(signal, type=1), type=1) * 0.5 / (signal.shape[-1] - 1)`.
  When `norm='ortho'` (not available for Type I), we have:
  `signal == idct(dct(signal, norm='ortho'), norm='ortho')`.

  @compatibility(scipy)
  Equivalent to [scipy.fftpack.idct]
   (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.idct.html)
   for Type-I, Type-II, Type-III and Type-IV DCT.
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
      signals to take the DCT of.
    type: The IDCT type to perform. Must be 1, 2, 3 or 4.
    n: The length of the transform, forwarded to `tf.signal.dct`. The input is
      truncated or zero-padded along its last dimension to this length. Must
      be `None` or a positive integer.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32`/`float64` `Tensor` containing the IDCT of
    `input`.

  Raises:
    ValueError: If `type` is not `1`, `2`, `3` or `4`, `n` is not `None` or
      greater than 0, or `norm` is not `None` or `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `'ortho'`.
    NotImplementedError: If `axis` is not `-1`.

  [idct]:
  https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
  """
  # Types II and III invert each other; Types I and IV are self-inverse.
  # Argument validation is left to `dct`, which checks exactly the same
  # arguments (an invalid `type` is passed through unchanged and rejected
  # there with the usual ValueError).
  if type == 2:
    inverse_type = 3
  elif type == 3:
    inverse_type = 2
  else:
    inverse_type = type
  return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)
225