# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for computing common window functions."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _check_params(window_length, dtype):
  """Check window_length and dtype params.

  Args:
    window_length: A scalar value or `Tensor`.
    dtype: The data type to produce. Must be a floating point type.

  Returns:
    window_length converted to a tensor of type int32.

  Raises:
    ValueError: If `dtype` is not a floating point type or window_length is not
      a scalar.
  """
  if not dtype.is_floating:
    raise ValueError('dtype must be a floating point type. Found %s' % dtype)
  window_length = ops.convert_to_tensor(window_length, dtype=dtypes.int32)
  window_length.shape.assert_has_rank(0)
  return window_length


@tf_export('signal.kaiser_window')
@dispatch.add_dispatch_support
def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None):
  """Generate a [Kaiser window][kaiser].

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    beta: Beta parameter for Kaiser window, see reference below.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not a scalar.

  [kaiser]:
    https://docs.scipy.org/doc/numpy/reference/generated/numpy.kaiser.html
  """
  with ops.name_scope(name, 'kaiser_window'):
    window_length = _check_params(window_length, dtype)
    window_length_const = tensor_util.constant_value(window_length)
    if window_length_const == 1:
      return array_ops.ones([1], dtype=dtype)
    # tf.range does not support float16 so we work with float32 initially.
    halflen_float = (
        math_ops.cast(window_length, dtype=dtypes.float32) - 1.0) / 2.0
    # Sample positions -halflen, -halflen + 1, ..., +halflen. Building them from
    # an integer range yields exactly `window_length` samples, including an
    # empty window for window_length == 0. A float range would need a fudged
    # limit and rejects start > limit.
    arg = math_ops.cast(
        math_ops.range(window_length), dtype=dtypes.float32) - halflen_float
    # Convert everything into given dtype which can be float16.
    arg = math_ops.cast(arg, dtype=dtype)
    beta = math_ops.cast(beta, dtype=dtype)
    one = math_ops.cast(1.0, dtype=dtype)
    two = math_ops.cast(2.0, dtype=dtype)
    halflen_float = math_ops.cast(halflen_float, dtype=dtype)
    # When window_length is only known at run time and equals 1, halflen is 0
    # and arg is [0]. Plain division would give 0/0 = NaN. div_no_nan yields 0
    # instead, so the window evaluates to exactly [1.], matching the static fast
    # path above.
    num = beta * math_ops.sqrt(
        one - math_ops.div_no_nan(math_ops.pow(arg, two),
                                  math_ops.pow(halflen_float, two)))
    # i0(num) / i0(beta) written via the exponentially scaled i0e to avoid
    # overflow for large beta. The exp(num - beta) factor undoes the scaling
    # (exact when beta >= 0, so that num >= 0).
    window = math_ops.exp(num - beta) * (
        special_math_ops.bessel_i0e(num) / special_math_ops.bessel_i0e(beta))
    return window


@tf_export('signal.kaiser_bessel_derived_window')
@dispatch.add_dispatch_support
def kaiser_bessel_derived_window(window_length, beta=12.,
                                 dtype=dtypes.float32, name=None):
  """Generate a [Kaiser Bessel derived window][kbd].

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    beta: Beta parameter for Kaiser window.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`. The output always
    has `2 * (window_length // 2)` samples, so an odd `window_length` yields
    one sample fewer than requested.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not a scalar.

  [kbd]:
    https://en.wikipedia.org/wiki/Kaiser_window#Kaiser%E2%80%93Bessel-derived_(KBD)_window
  """
  with ops.name_scope(name, 'kaiser_bessel_derived_window'):
    window_length = _check_params(window_length, dtype)
    halflen = window_length // 2
    # First half is sqrt of the normalized running sum of a Kaiser window with
    # halflen + 1 points. The second half mirrors it.
    kaiserw = kaiser_window(halflen + 1, beta, dtype=dtype)
    kaiserw_csum = math_ops.cumsum(kaiserw)
    halfw = math_ops.sqrt(kaiserw_csum[:-1] / kaiserw_csum[-1])
    window = array_ops.concat((halfw, halfw[::-1]), axis=0)
    return window


@tf_export('signal.vorbis_window')
@dispatch.add_dispatch_support
def vorbis_window(window_length, dtype=dtypes.float32, name=None):
  """Generate a [Vorbis power complementary window][vorbis].

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not a scalar.

  [vorbis]:
    https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform#Window_functions
  """
  with ops.name_scope(name, 'vorbis_window'):
    window_length = _check_params(window_length, dtype)
    arg = math_ops.cast(math_ops.range(window_length), dtype=dtype)
    # w[n] = sin(pi/2 * sin^2(pi / N * (n + 0.5))) for n in [0, N).
    window = math_ops.sin(np.pi / 2.0 * math_ops.pow(math_ops.sin(
        np.pi / math_ops.cast(window_length, dtype=dtype) *
        (arg + 0.5)), 2.0))
    return window


@tf_export('signal.hann_window')
@dispatch.add_dispatch_support
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
  """Generate a [Hann window][hann].

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type.

  [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  return _raised_cosine_window(name, 'hann_window', window_length, periodic,
                               dtype, 0.5, 0.5)


@tf_export('signal.hamming_window')
@dispatch.add_dispatch_support
def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
                   name=None):
  """Generate a [Hamming][hamming] window.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type.

  [hamming]:
    https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  return _raised_cosine_window(name, 'hamming_window', window_length, periodic,
                               dtype, 0.54, 0.46)


def _raised_cosine_window(name, default_name, window_length, periodic,
                          dtype, a, b):
  """Helper function for computing a raised cosine window.

  Computes `a - b * cos(2 * pi * k / n)` for `k` in `[0, window_length)`, where
  `n = window_length` for an even-length periodic window and
  `n = window_length - 1` otherwise.

  Args:
    name: Name to use for the scope.
    default_name: Default name to use for the scope.
    window_length: A scalar `Tensor` or integer indicating the window length.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window.
    dtype: A floating point `DType`.
    a: The alpha parameter to the raised cosine window.
    b: The beta parameter to the raised cosine window.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not scalar or `periodic` is not scalar.
  """
  if not dtype.is_floating:
    raise ValueError('dtype must be a floating point type. Found %s' % dtype)

  with ops.name_scope(name, default_name, [window_length, periodic]):
    window_length = ops.convert_to_tensor(window_length, dtype=dtypes.int32,
                                          name='window_length')
    window_length.shape.assert_has_rank(0)
    window_length_const = tensor_util.constant_value(window_length)
    if window_length_const == 1:
      return array_ops.ones([1], dtype=dtype)
    periodic = math_ops.cast(
        ops.convert_to_tensor(periodic, dtype=dtypes.bool, name='periodic'),
        dtypes.int32)
    periodic.shape.assert_has_rank(0)
    even = 1 - math_ops.mod(window_length, 2)

    # Denominator of the cosine argument. Only an even-length periodic window
    # uses the full length; every other case uses window_length - 1.
    n = math_ops.cast(window_length + periodic * even - 1, dtype=dtype)
    count = math_ops.cast(math_ops.range(window_length), dtype)
    cos_arg = constant_op.constant(2 * np.pi, dtype=dtype) * count / n

    if window_length_const is not None:
      return math_ops.cast(a - b * math_ops.cos(cos_arg), dtype=dtype)
    # Length unknown until run time: a length-1 window has n == 0, so cos_arg
    # would be 0/0. Select an all-ones window in that case instead.
    return control_flow_ops.cond(
        math_ops.equal(window_length, 1),
        lambda: array_ops.ones([window_length], dtype=dtype),
        lambda: math_ops.cast(a - b * math_ops.cos(cos_arg), dtype=dtype))