1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for AudioMicrofrontend."""
16
17import tensorflow as tf
18
19from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op
20from tensorflow.python.framework import ops
21
# Shared arguments passed to `audio_microfrontend` by every test in this file.
SAMPLE_RATE = 1000
# WINDOW_SIZE and WINDOW_STEP are also used to size the input signal in each
# test: the 4-sample pattern is repeated (WINDOW_SIZE + k * WINDOW_STEP) // 4
# times, with k = 4 or 7 depending on the test.
WINDOW_SIZE = 25
WINDOW_STEP = 10
# The expected outputs below contain two values per unstacked frame.
NUM_CHANNELS = 2
UPPER_BAND_LIMIT = 450.0
LOWER_BAND_LIMIT = 8.0
SMOOTHING_BITS = 10
29
30
class AudioFeatureGenerationTest(tf.test.TestCase):
  """Checks `audio_microfrontend` output against hard-coded expected values.

  Every test feeds the op the same repeating int16 pattern
  `[0, 32767, 0, -32768]` together with the module-level configuration
  constants and `enable_pcan=True`, then compares the evaluated result with
  literal expected values. Eager execution is disabled in `setUp`, so the op
  is built as a graph tensor and evaluated inside a cached session.
  """

  def setUp(self):
    super(AudioFeatureGenerationTest, self).setUp()
    # The tests build graph tensors and evaluate them through a session.
    ops.disable_eager_execution()

  def _make_audio(self, num_steps):
    """Builds the int16 test signal sized from the window constants.

    Args:
      num_steps: Multiplier applied to `WINDOW_STEP` when sizing the signal;
        the 4-sample pattern is repeated
        `(WINDOW_SIZE + num_steps * WINDOW_STEP) // 4` times.

    Returns:
      A 1-D `tf.int16` constant tensor.
    """
    repeats = (WINDOW_SIZE + num_steps * WINDOW_STEP) // 4
    return tf.constant([0, 32767, 0, -32768] * repeats, tf.int16)

  def _run_frontend(self, num_steps=4, **extra_args):
    """Runs `audio_microfrontend` on the test signal and returns its value.

    Args:
      num_steps: Forwarded to `_make_audio` to size the input signal.
      **extra_args: Additional keyword arguments passed to
        `audio_microfrontend` on top of the shared configuration (for
        example `left_context`, `frame_stride`, `out_type`).

    Returns:
      The evaluated op output, as returned by `self.evaluate`.
    """
    # `cached_session` replaces the deprecated `test_session`.
    with self.cached_session():
      filterbanks = frontend_op.audio_microfrontend(
          self._make_audio(num_steps),
          sample_rate=SAMPLE_RATE,
          window_size=WINDOW_SIZE,
          window_step=WINDOW_STEP,
          num_channels=NUM_CHANNELS,
          upper_band_limit=UPPER_BAND_LIMIT,
          lower_band_limit=LOWER_BAND_LIMIT,
          smoothing_bits=SMOOTHING_BITS,
          enable_pcan=True,
          **extra_args)
      return self.evaluate(filterbanks)

  def testSimple(self):
    """Default output: four frames with two channel values each."""
    self.assertAllEqual(
        self._run_frontend(),
        [[479, 425], [436, 378], [410, 350], [391, 325]])

  def testSimpleFloatScaled(self):
    """Float output with `out_scale=64`.

    The expected values are the integers from `testSimple` divided by 64;
    they are exactly representable in float32, so exact equality is used.
    """
    self.assertAllEqual(
        self._run_frontend(out_scale=64, out_type=tf.float32),
        [[7.484375, 6.640625], [6.8125, 5.90625],
         [6.40625, 5.46875], [6.109375, 5.078125]])

  def testStacking(self):
    """`right_context=1, frame_stride=2`.

    The expected rows are the `testSimple` frames concatenated in
    consecutive pairs: frames (0, 1), then frames (2, 3).
    """
    self.assertAllEqual(
        self._run_frontend(right_context=1, frame_stride=2),
        [[479, 425, 436, 378], [410, 350, 391, 325]])

  def testStackingWithOverlap(self):
    """`left_context=1, right_context=1` with the default stride.

    Each expected row is (previous, current, next) frame; at the two ends
    the expected values repeat the edge frame in place of the missing
    neighbour.
    """
    self.assertAllEqual(
        self._run_frontend(left_context=1, right_context=1),
        [[479, 425, 479, 425, 436, 378], [479, 425, 436, 378, 410, 350],
         [436, 378, 410, 350, 391, 325], [410, 350, 391, 325, 391, 325]])

  def testStackingDropFrame(self):
    """`left_context=1, frame_stride=2`.

    Only two rows are expected, and the last `testSimple` frame
    ([391, 325]) does not appear in the expected output.
    """
    self.assertAllEqual(
        self._run_frontend(left_context=1, frame_stride=2),
        [[479, 425, 479, 425], [436, 378, 410, 350]])

  def testZeroPadding(self):
    """`left_context=2, frame_stride=3, zero_padding=True` on a longer signal.

    The signal is sized with seven window steps instead of four. In the
    expected output the missing left context of the first row is filled
    with zeros rather than a repeated edge frame.
    """
    self.assertAllEqual(
        self._run_frontend(
            num_steps=7, left_context=2, frame_stride=3, zero_padding=True),
        [[0, 0, 0, 0, 479, 425], [436, 378, 410, 350, 391, 325],
         [374, 308, 362, 292, 352, 275]])
161
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()
164