# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Discrete Cosine Transform ops."""
import math as _math

from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
  """Checks that DCT/IDCT arguments are compatible and well formed.

  Args:
    input_tensor: The signal to transform. Only its `shape[-1]` is inspected,
      and only for the Type-I transform when `n` is `None`.
    dct_type: The requested transform type. Must be 1, 2, 3 or 4.
    n: The requested transform length, or `None` to use the input length.
    axis: The axis to transform along. Only `-1` is implemented.
    norm: The normalization mode, `None` or `'ortho'`.

  Raises:
    NotImplementedError: If `axis` is not `-1`.
    ValueError: If `n` is not `None` and is less than 1, if `dct_type` is not
      one of 1, 2, 3, 4, if `norm` is neither `None` nor `'ortho'`, or if
      `dct_type` is 1 and either `norm` is `'ortho'` or the statically known
      transform length is less than 2.
  """
  if axis != -1:
    raise NotImplementedError("axis must be -1. Got: %s" % axis)
  if n is not None and n < 1:
    raise ValueError("n should be a positive integer or None")
  if dct_type not in (1, 2, 3, 4):
    raise ValueError("Types I, II, III and IV (I)DCT are supported.")
  if dct_type == 1:
    if norm == "ortho":
      raise ValueError("Normalization is not supported for the Type-I DCT.")
    # `dct` truncates or zero-pads the signal to `n` samples before
    # transforming, so the length that matters for the Type-I constraint is
    # `n` when it is given and the static input length otherwise.
    if n is not None:
      transform_length = n
    else:
      transform_length = tensor_shape.dimension_value(input_tensor.shape[-1])
    # A `None` length means the dimension is not statically known; it cannot
    # be validated here.
    if transform_length is not None and transform_length < 2:
      raise ValueError(
          "Type-I DCT requires the dimension to be greater than one.")

  if norm not in (None, "ortho"):
    raise ValueError(
        "Unknown normalization. Expected None or 'ortho', got: %s" % norm)


# TODO(rjryan): Implement `axis` parameter.
@tf_export("signal.dct", v1=["signal.dct", "spectral.dct"])
@dispatch.add_dispatch_support
def dct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.

  Types I, II, III and IV are supported.
  Type I is implemented using a length `2N` padded `tf.signal.rfft`.
  Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
   described here: [Type 2 DCT using 2N FFT padded (Makhoul)]
   (https://dsp.stackexchange.com/a/10606).
  Type III is a fairly straightforward inverse of Type II
   (i.e. using a length `2N` padded `tf.signal.irfft`).
   Type IV is calculated through 2N length DCT2 of padded signal and
  picking the odd indices.

  @compatibility(scipy)
  Equivalent to [scipy.fftpack.dct]
   (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.dct.html)
   for Type-I, Type-II, Type-III and Type-IV DCT.
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
      signals to take the DCT of.
    type: The DCT type to perform. Must be 1, 2, 3 or 4.
    n: The length of the transform. If length is less than sequence length,
      only the first n elements of the sequence are considered for the DCT.
      If n is greater than the sequence length, zeros are padded and then
      the DCT is computed as usual.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32`/`float64` `Tensor` containing the DCT of
    `input`.

  Raises:
    NotImplementedError: If `axis` is not `-1`.
    ValueError: If `type` is not `1`, `2`, `3` or `4`, `n` is not `None` or
      greater than 0, or `norm` is not `None` or `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `ortho`, or `type` is `1` and
      the statically known transform length is less than 2.

  [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
  """
  _validate_dct_arguments(input, type, n, axis, norm)
  with _ops.name_scope(name, "dct", [input]):
    input = _ops.convert_to_tensor(input)
    zero = _ops.convert_to_tensor(0.0, dtype=input.dtype)

    # Prefer the static length; fall back to the dynamic shape otherwise.
    seq_len = (
        tensor_shape.dimension_value(input.shape[-1]) or
        _array_ops.shape(input)[-1])
    if n is not None:
      if n <= seq_len:
        # Keep only the first `n` samples.
        input = input[..., 0:n]
      else:
        # Zero-pad the end of the last axis up to `n` samples.
        rank = len(input.shape)
        padding = [[0, 0] for _ in range(rank)]
        padding[rank - 1][1] = n - seq_len
        padding = _ops.convert_to_tensor(padding, dtype=_dtypes.int32)
        input = _array_ops.pad(input, paddings=padding)

    # Length of the (possibly truncated or padded) signal being transformed.
    axis_dim = (tensor_shape.dimension_value(input.shape[-1])
                or _array_ops.shape(input)[-1])
    axis_dim_float = _math_ops.cast(axis_dim, input.dtype)

    if type == 1:
      # Append the reversed interior samples (both endpoints excluded) and
      # take the real part of the RFFT of the extended signal.
      dct1_input = _array_ops.concat([input, input[..., -2:0:-1]], axis=-1)
      dct1 = _math_ops.real(fft_ops.rfft(dct1_input))
      return dct1

    if type == 2:
      # Per-bin complex factor 2 * exp(-j * pi * k / (2 * N)).
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              zero, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))

      # TODO(rjryan): Benchmark performance and memory usage of the various
      # approaches to computing a DCT via the RFFT.
      dct2 = _math_ops.real(
          fft_ops.rfft(
              input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)

      if norm == "ortho":
        n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
        n2 = n1 * _math.sqrt(2.0)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        dct2 *= weights

      return dct2

    elif type == 3:
      if norm == "ortho":
        n1 = _math_ops.sqrt(axis_dim_float)
        n2 = n1 * _math.sqrt(0.5)
        # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
        weights = _array_ops.pad(
            _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
            constant_values=n2)
        input *= weights
      else:
        input *= axis_dim_float
      # Per-bin complex factor 2 * exp(+j * pi * k / (2 * N)), the conjugate
      # of the factor used by the Type-II branch.
      scale = 2.0 * _math_ops.exp(
          _math_ops.complex(
              zero,
              _math_ops.range(axis_dim_float) * _math.pi * 0.5 /
              axis_dim_float))
      dct3 = _math_ops.real(
          fft_ops.irfft(
              scale * _math_ops.complex(input, zero),
              fft_length=[2 * axis_dim]))[..., :axis_dim]

      return dct3

    elif type == 4:
      # DCT-2 of 2N length zero-padded signal, unnormalized.
      dct2 = dct(input, type=2, n=2*axis_dim, axis=axis, norm=None)
      # Get odd indices of DCT-2 of zero padded 2N signal to obtain
      # DCT-4 of the original N length signal.
      dct4 = dct2[..., 1::2]
      if norm == "ortho":
        dct4 *= _math.sqrt(0.5) * _math_ops.rsqrt(axis_dim_float)

      return dct4


# TODO(rjryan): Implement `axis` parameter.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
@dispatch.add_dispatch_support
def idct(input, type=2, n=None, axis=-1, norm=None, name=None):  # pylint: disable=redefined-builtin
  """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.

  Currently Types I, II, III, IV are supported. Type III is the inverse of
  Type II, and vice versa.

  Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is
  not `'ortho'`. That is:
  `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`.
  When `norm='ortho'`, we have:
  `signal == idct(dct(signal, norm='ortho'), norm='ortho')`.

  @compatibility(scipy)
  Equivalent to [scipy.fftpack.idct]
   (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.idct.html)
   for Type-I, Type-II, Type-III and Type-IV DCT.
  @end_compatibility

  Args:
    input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
      signals to take the DCT of.
    type: The IDCT type to perform. Must be 1, 2, 3 or 4.
    n: The length of the transform, or `None` to use the length of `input`.
      Forwarded unchanged to `dct`, which truncates or zero-pads `input` to
      `n` samples before transforming.
    axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
    norm: The normalization to apply. `None` for no normalization or `'ortho'`
      for orthonormal normalization.
    name: An optional name for the operation.

  Returns:
    A `[..., samples]` `float32`/`float64` `Tensor` containing the IDCT of
    `input`.

  Raises:
    NotImplementedError: If `axis` is not `-1`.
    ValueError: If `type` is not `1`, `2`, `3` or `4`, `n` is not `None` or
      greater than 0, or `norm` is not `None` or `'ortho'`.
    ValueError: If `type` is `1` and `norm` is `ortho`, or `type` is `1` and
      the statically known transform length is less than 2.

  [idct]:
  https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
  """
  _validate_dct_arguments(input, type, n, axis, norm)
  # The inverse of each type is computed as a forward DCT of the paired type:
  # I <-> I, II <-> III, IV <-> IV.
  inverse_type = {1: 1, 2: 3, 3: 2, 4: 4}[type]
  return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)