xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/signal/window_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops for computing common window functions."""
16
17import numpy as np
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import special_math_ops
27from tensorflow.python.util import dispatch
28from tensorflow.python.util.tf_export import tf_export
29
30
def _check_params(window_length, dtype):
  """Validate the common `window_length` / `dtype` arguments of window ops.

  Args:
    window_length: A scalar value or `Tensor` giving the window length.
    dtype: The requested output `DType`; it must be a floating point type.

  Returns:
    `window_length` as a rank-0 `int32` `Tensor`.

  Raises:
    ValueError: If `dtype` is not floating point, or if `window_length` does
      not have rank 0.
  """
  # Reject non-float output types before building any tensors.
  if not dtype.is_floating:
    raise ValueError('dtype must be a floating point type. Found %s' % dtype)
  length_tensor = ops.convert_to_tensor(window_length, dtype=dtypes.int32)
  # Window ops only accept a single scalar length.
  length_tensor.shape.assert_has_rank(0)
  return length_tensor
50
51
@tf_export('signal.kaiser_window')
@dispatch.add_dispatch_support
def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None):
  """Generate a [Kaiser window][kaiser].

  Sample `k` of the window is `I0(beta * sqrt(1 - (x_k / h)**2)) / I0(beta)`,
  where `h = (window_length - 1) / 2` and `x_k` runs from `-h` to `h` in unit
  steps.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    beta: Beta parameter for Kaiser window, see reference below.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not a scalar.

  [kaiser]:
    https://docs.scipy.org/doc/numpy/reference/generated/numpy.kaiser.html
  """
  with ops.name_scope(name, 'kaiser_window'):
    window_length = _check_params(window_length, dtype)
    window_length_const = tensor_util.constant_value(window_length)
    # Fast path when the length is statically known to be 1.
    if window_length_const == 1:
      return array_ops.ones([1], dtype=dtype)
    # tf.range does not support float16 so we work with float32 initially.
    halflen_float = (
        math_ops.cast(window_length, dtype=dtypes.float32) - 1.0) / 2.0
    # The +0.1 on the limit makes the endpoint `halflen_float` included.
    arg = math_ops.range(-halflen_float, halflen_float + 0.1,
                         dtype=dtypes.float32)
    # Convert everything into given dtype which can be float16.
    arg = math_ops.cast(arg, dtype=dtype)
    beta = math_ops.cast(beta, dtype=dtype)
    one = math_ops.cast(1.0, dtype=dtype)
    two = math_ops.cast(2.0, dtype=dtype)
    halflen_float = math_ops.cast(halflen_float, dtype=dtype)
    # If window_length is only known at run time and equals 1, halflen is 0
    # and arg is [0]; a plain division would produce 0 / 0 = NaN. div_no_nan
    # returns 0 there, so the single sample becomes I0(beta) / I0(beta) = 1,
    # matching the static fast path above. For window_length >= 2 halflen is
    # non-zero and div_no_nan is an ordinary division.
    num = beta * math_ops.sqrt(
        one - math_ops.div_no_nan(math_ops.pow(arg, two),
                                  math_ops.pow(halflen_float, two)))
    # I0(num) / I0(beta) written with the exponentially scaled Bessel function
    # i0e(x) = exp(-|x|) * I0(x) to avoid overflow for large beta.
    window = math_ops.exp(num - beta) * (
        special_math_ops.bessel_i0e(num) / special_math_ops.bessel_i0e(beta))
  return window
90
91
@tf_export('signal.kaiser_bessel_derived_window')
@dispatch.add_dispatch_support
def kaiser_bessel_derived_window(window_length, beta=12.,
                                 dtype=dtypes.float32, name=None):
  """Generate a [Kaiser Bessel derived window][kbd].

  The rising half is the square root of the normalized cumulative sum of a
  Kaiser window with `window_length // 2 + 1` samples. The falling half is
  its mirror image.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
      The KBD window is defined for even lengths.
    beta: Beta parameter for Kaiser window.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A 1-D `Tensor` of type `dtype` with `2 * (window_length // 2)` elements.
    This is `[window_length]` for even lengths; an odd `window_length`
    produces one sample fewer.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not a scalar.

  [kbd]:
    https://en.wikipedia.org/wiki/Kaiser_window#Kaiser%E2%80%93Bessel-derived_(KBD)_window
  """
  with ops.name_scope(name, 'kaiser_bessel_derived_window'):
    window_length = _check_params(window_length, dtype)
    half_length = window_length // 2
    # Running energy of a Kaiser window one sample longer than the half.
    kaiser_cumsum = math_ops.cumsum(
        kaiser_window(half_length + 1, beta, dtype=dtype))
    # Drop the last partial sum and normalize by the total energy.
    rising = math_ops.sqrt(kaiser_cumsum[:-1] / kaiser_cumsum[-1])
    falling = rising[::-1]
    window = array_ops.concat((rising, falling), axis=0)
  return window
118
119
@tf_export('signal.vorbis_window')
@dispatch.add_dispatch_support
def vorbis_window(window_length, dtype=dtypes.float32, name=None):
  """Generate a [Vorbis power complementary window][vorbis].

  Sample `k` is `sin(pi / 2 * sin(pi / N * (k + 0.5)) ** 2)` with
  `N = window_length`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not a scalar.

  [vorbis]:
    https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform#Window_functions
  """
  with ops.name_scope(name, 'vorbis_window'):
    window_length = _check_params(window_length, dtype)
    sample_index = math_ops.cast(math_ops.range(window_length), dtype=dtype)
    length_float = math_ops.cast(window_length, dtype=dtype)
    # Sample centers are offset by half a sample.
    inner_sine = math_ops.sin(np.pi / length_float * (sample_index + 0.5))
    window = math_ops.sin(np.pi / 2.0 * math_ops.pow(inner_sine, 2.0))
  return window
143
144
@tf_export('signal.hann_window')
@dispatch.add_dispatch_support
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
  """Generate a [Hann window][hann].

  This is the raised cosine window `0.5 - 0.5 * cos(2 * pi * k / n)`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type, or `window_length`
      or `periodic` is not a scalar.

  [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  return _raised_cosine_window(
      name=name,
      default_name='hann_window',
      window_length=window_length,
      periodic=periodic,
      dtype=dtype,
      a=0.5,
      b=0.5)
169
170
@tf_export('signal.hamming_window')
@dispatch.add_dispatch_support
def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
                   name=None):
  """Generate a [Hamming][hamming] window.

  This is the raised cosine window `0.54 - 0.46 * cos(2 * pi * k / n)`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type, or `window_length`
      or `periodic` is not a scalar.

  [hamming]:
    https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  return _raised_cosine_window(
      name=name,
      default_name='hamming_window',
      window_length=window_length,
      periodic=periodic,
      dtype=dtype,
      a=0.54,
      b=0.46)
197
198
def _raised_cosine_window(name, default_name, window_length, periodic,
                          dtype, a, b):
  """Build the raised cosine window `a - b * cos(2 * pi * k / n)`.

  Here `k = 0 .. window_length - 1`. The denominator `n` is
  `window_length - 1`, except for periodic windows of even length, where it
  is `window_length`.

  Args:
    name: Name to use for the scope.
    default_name: Default name to use for the scope.
    window_length: A scalar `Tensor` or integer indicating the window length.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window.
    dtype: A floating point `DType`.
    a: The alpha parameter to the raised cosine window.
    b: The beta parameter to the raised cosine window.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`. A window of
    length 1 is always `[1]`.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not scalar or `periodic` is not scalar.
  """
  if not dtype.is_floating:
    raise ValueError('dtype must be a floating point type. Found %s' % dtype)

  with ops.name_scope(name, default_name, [window_length, periodic]):
    length = ops.convert_to_tensor(window_length, dtype=dtypes.int32,
                                   name='window_length')
    length.shape.assert_has_rank(0)
    static_length = tensor_util.constant_value(length)
    # A statically known length of 1 short-circuits before `periodic` is
    # even converted.
    if static_length == 1:
      return array_ops.ones([1], dtype=dtype)

    periodic_int = math_ops.cast(
        ops.convert_to_tensor(periodic, dtype=dtypes.bool, name='periodic'),
        dtypes.int32)
    periodic_int.shape.assert_has_rank(0)
    # 1 for even lengths, 0 for odd lengths.
    is_even = 1 - math_ops.mod(length, 2)

    # `periodic` only lengthens the cosine period for even lengths; odd
    # lengths always use `length - 1`, like the symmetric window.
    denom = math_ops.cast(length + periodic_int * is_even - 1, dtype=dtype)
    k = math_ops.cast(math_ops.range(length), dtype)
    cos_arg = constant_op.constant(2 * np.pi, dtype=dtype) * k / denom

    def _window():
      return math_ops.cast(a - b * math_ops.cos(cos_arg), dtype=dtype)

    if static_length is not None:
      return _window()
    # Length only known at run time: for length 1 the denominator above is 0
    # (0 / 0 = NaN), so select the all-ones window in that case.
    return control_flow_ops.cond(
        math_ops.equal(length, 1),
        lambda: array_ops.ones([length], dtype=dtype),
        _window)
246