xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/signal/mel_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""mel conversion ops."""
16
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import tensor_util
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import math_ops
22from tensorflow.python.ops.signal import shape_ops
23from tensorflow.python.util import dispatch
24from tensorflow.python.util.tf_export import tf_export
25
26
# mel spectrum constants, shared by `_hertz_to_mel` and `_mel_to_hertz`:
#
#   mel(f) = _MEL_HIGH_FREQUENCY_Q * ln(1 + f / _MEL_BREAK_FREQUENCY_HERTZ)
#
# 1127 is approximately 2595 / ln(10), i.e. this is the natural-log form of
# mel(f) = 2595 * log10(1 + f / 700).
_MEL_BREAK_FREQUENCY_HERTZ = 700.0  # Divisor applied to the frequency in Hertz.
_MEL_HIGH_FREQUENCY_Q = 1127.0  # Scale factor applied to the natural log.
30
31
def _mel_to_hertz(mel_values, name=None):
  """Maps mel-scale frequencies back onto the linear (Hertz) scale.

  Evaluates `_MEL_BREAK_FREQUENCY_HERTZ * (exp(mel / _MEL_HIGH_FREQUENCY_Q) - 1)`
  elementwise, which is the inverse of the mapping used by `_hertz_to_mel`.

  Args:
    mel_values: A `Tensor` (or value convertible to one) holding frequencies
      expressed in mels.
    name: Optional name for the enclosing name scope.

  Returns:
    A `Tensor` with the shape and type of `mel_values`, holding the matching
    frequencies in Hertz.
  """
  with ops.name_scope(name, 'mel_to_hertz', [mel_values]):
    mels = ops.convert_to_tensor(mel_values)
    # Undo the logarithm first, then undo the `1 + f / break` offset/scale.
    scaled_mels = mels / _MEL_HIGH_FREQUENCY_Q
    return _MEL_BREAK_FREQUENCY_HERTZ * (math_ops.exp(scaled_mels) - 1.0)
48
49
def _hertz_to_mel(frequencies_hertz, name=None):
  """Maps linear-scale frequencies in Hertz onto the mel scale.

  Evaluates `_MEL_HIGH_FREQUENCY_Q * log(1 + f / _MEL_BREAK_FREQUENCY_HERTZ)`
  elementwise, using the natural logarithm.

  Args:
    frequencies_hertz: A `Tensor` (or value convertible to one) holding
      frequencies in Hertz.
    name: Optional name for the enclosing name scope.

  Returns:
    A `Tensor` with the shape and type of `frequencies_hertz`, holding the
    matching frequencies in mels.
  """
  with ops.name_scope(name, 'hertz_to_mel', [frequencies_hertz]):
    hertz = ops.convert_to_tensor(frequencies_hertz)
    relative_to_break = hertz / _MEL_BREAK_FREQUENCY_HERTZ
    return _MEL_HIGH_FREQUENCY_Q * math_ops.log(1.0 + relative_to_break)
65
66
67def _validate_arguments(num_mel_bins, sample_rate,
68                        lower_edge_hertz, upper_edge_hertz, dtype):
69  """Checks the inputs to linear_to_mel_weight_matrix."""
70  if num_mel_bins <= 0:
71    raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
72  if lower_edge_hertz < 0.0:
73    raise ValueError('lower_edge_hertz must be non-negative. Got: %s' %
74                     lower_edge_hertz)
75  if lower_edge_hertz >= upper_edge_hertz:
76    raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' %
77                     (lower_edge_hertz, upper_edge_hertz))
78  if not isinstance(sample_rate, ops.Tensor):
79    if sample_rate <= 0.0:
80      raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
81    if upper_edge_hertz > sample_rate / 2:
82      raise ValueError('upper_edge_hertz must not be larger than the Nyquist '
83                       'frequency (sample_rate / 2). Got %s for sample_rate: %s'
84                       % (upper_edge_hertz, sample_rate))
85  if not dtype.is_floating:
86    raise ValueError('dtype must be a floating point type. Got: %s' % dtype)
87
88
@tf_export('signal.linear_to_mel_weight_matrix')
@dispatch.add_dispatch_support
def linear_to_mel_weight_matrix(num_mel_bins=20,
                                num_spectrogram_bins=129,
                                sample_rate=8000,
                                lower_edge_hertz=125.0,
                                upper_edge_hertz=3800.0,
                                dtype=dtypes.float32,
                                name=None):
  r"""Builds a matrix mapping linear-frequency bins onto the [mel scale][mel].

  The result re-weights a `Tensor` whose last dimension holds
  `num_spectrogram_bins` linearly spaced frequency bins covering
  `[0, sample_rate / 2]` into `num_mel_bins` bands covering
  `[lower_edge_hertz, upper_edge_hertz]` on the [mel scale][mel].

  The mel scale is defined following the [Hidden Markov Model Toolkit
  (HTK)](http://htk.eng.cam.ac.uk/) convention, which relates mels to a
  frequency in hertz by:

      $$\textrm{mel}(f) = 2595 * \textrm{log}_{10}(1 + \frac{f}{700})$$

  Every triangular filter (filterbank) in the returned matrix peaks at 1.0.

  Given a spectrogram `S` of shape `[frames, num_spectrogram_bins]` holding
  linear scale spectrum values (e.g. STFT magnitudes), right-multiplying by the
  returned matrix `A` yields a "mel spectrogram" `M` of shape
  `[frames, num_mel_bins]`:

      # `S` has shape [frames, num_spectrogram_bins]
      # `M` has shape [frames, num_mel_bins]
      M = tf.matmul(S, A)

  For a `Tensor` of arbitrary rank whose last dimension holds the
  linear-scale spectral bins, use `tf.tensordot` instead:

      # S has shape [..., num_spectrogram_bins].
      # M has shape [..., num_mel_bins].
      M = tf.tensordot(S, A, 1)

  Args:
    num_mel_bins: Python int. Number of bands in the resulting mel spectrum.
    num_spectrogram_bins: An integer `Tensor`. Number of bins in the source
      spectrogram data, understood to be `fft_size // 2 + 1`, i.e. only the
      nonredundant FFT bins.
    sample_rate: An integer or float `Tensor`. Samples per second of the signal
      the spectrogram was computed from. It determines the frequency of each
      spectrogram bin and therefore how the bins map onto the mel scale.
    lower_edge_hertz: Python float. Lowest frequency included in the mel
      spectrum; the lower edge of the lowest triangular band.
    upper_edge_hertz: Python float. Desired top edge of the highest frequency
      band.
    dtype: The `DType` of the result matrix. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[num_spectrogram_bins, num_mel_bins]`.

  Raises:
    ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not
      positive, `lower_edge_hertz` is negative, the frequency edges are
      incorrectly ordered, `upper_edge_hertz` is larger than the Nyquist
      frequency, or `dtype` is not a floating point type.

  [mel]: https://en.wikipedia.org/wiki/Mel_scale
  """
  with ops.name_scope(name, 'linear_to_mel_weight_matrix') as scope:
    # When `sample_rate` is a `Tensor` with a statically known value, use that
    # value instead so the range checks below can still be applied to it.
    if isinstance(sample_rate, ops.Tensor):
      static_sample_rate = tensor_util.constant_value(sample_rate)
      if static_sample_rate is not None:
        sample_rate = static_sample_rate

    # `num_spectrogram_bins` is intentionally not checked here: it is handed
    # to `math_ops.linspace`, which already validates it (both in its shape
    # function and in its kernel).
    _validate_arguments(num_mel_bins, sample_rate,
                        lower_edge_hertz, upper_edge_hertz, dtype)

    # When none of the arguments is a Tensor, everything below has constant
    # inputs and can be constant folded by graph optimization.
    sample_rate_in_dtype = math_ops.cast(
        sample_rate, dtype, name='sample_rate')
    lower_hertz = ops.convert_to_tensor(
        lower_edge_hertz, dtype, name='lower_edge_hertz')
    upper_hertz = ops.convert_to_tensor(
        upper_edge_hertz, dtype, name='upper_edge_hertz')
    zero = ops.convert_to_tensor(0.0, dtype)

    # HTK leaves out the spectrogram DC bin, so drop the first bin here and
    # restore it as an all-zero row just before returning.
    bands_to_zero = 1
    nyquist_hertz = sample_rate_in_dtype / 2.0
    all_bin_frequencies = math_ops.linspace(
        zero, nyquist_hertz, num_spectrogram_bins)
    kept_bin_frequencies = all_bin_frequencies[bands_to_zero:]
    # One row per kept spectrogram bin, to broadcast against the
    # [1, num_mel_bins] band edges below.
    spectrogram_bins_mel = array_ops.expand_dims(
        _hertz_to_mel(kept_bin_frequencies), 1)

    # Each band is a (lower_edge, center, upper_edge) triple, and a band's
    # center doubles as the upper/lower edge of its neighbours. Hence
    # num_mel_bins + 2 points, evenly spaced in mel between the two edges,
    # framed into overlapping windows of three.
    lowest_mel = _hertz_to_mel(lower_hertz)
    highest_mel = _hertz_to_mel(upper_hertz)
    mel_edge_points = math_ops.linspace(
        lowest_mel, highest_mel, num_mel_bins + 2)
    band_edges_mel = shape_ops.frame(
        mel_edge_points, frame_length=3, frame_step=1)

    # Separate the three columns of the triples and lay each one out as a
    # [1, num_mel_bins] tensor.
    edge_columns = array_ops.split(band_edges_mel, 3, axis=1)
    lower_edge_mel, center_mel, upper_edge_mel = [
        array_ops.reshape(column, [1, num_mel_bins])
        for column in edge_columns]

    # Rising and falling sides of every triangle, evaluated at every
    # spectrogram bin. The sides are straight lines in mel, not in Hertz.
    rising_distance = spectrogram_bins_mel - lower_edge_mel
    rising_width = center_mel - lower_edge_mel
    lower_slopes = rising_distance / rising_width
    falling_distance = upper_edge_mel - spectrogram_bins_mel
    falling_width = upper_edge_mel - center_mel
    upper_slopes = falling_distance / falling_width

    # The triangle is the lower of the two sides, clipped at zero.
    triangle_sides = math_ops.minimum(lower_slopes, upper_slopes)
    mel_weights_matrix = math_ops.maximum(zero, triangle_sides)

    # Put back the leading bin(s) sliced off above, as zero rows.
    row_and_column_padding = [[bands_to_zero, 0], [0, 0]]
    return array_ops.pad(
        mel_weights_matrix, row_and_column_padding, name=scope)
216