1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/experimental/microfrontend/lib/frontend.h"
16 #include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/platform/macros.h"
25 
26 using tensorflow::errors::Internal;
27 using tensorflow::errors::InvalidArgument;
28 using tensorflow::shape_inference::DimensionHandle;
29 using tensorflow::shape_inference::InferenceContext;
30 using tensorflow::shape_inference::ShapeHandle;
31 
32 namespace tensorflow {
33 REGISTER_OP("AudioMicrofrontend")
34     .Input("audio: int16")
35     .Output("filterbanks: out_type")
36     .Attr("sample_rate: int = 16000")
37     .Attr("window_size: int = 25")
38     .Attr("window_step: int = 10")
39     .Attr("num_channels: int = 32")
40     .Attr("upper_band_limit: float = 7500.0")
41     .Attr("lower_band_limit: float = 125.0")
42     .Attr("smoothing_bits: int = 10")
43     .Attr("even_smoothing: float = 0.025")
44     .Attr("odd_smoothing: float = 0.06")
45     .Attr("min_signal_remaining: float = 0.05")
46     .Attr("enable_pcan: bool = false")
47     .Attr("pcan_strength: float = 0.95")
48     .Attr("pcan_offset: float = 80.0")
49     .Attr("gain_bits: int = 21")
50     .Attr("enable_log: bool = true")
51     .Attr("scale_shift: int = 6")
52     .Attr("left_context: int = 0")
53     .Attr("right_context: int = 0")
54     .Attr("frame_stride: int = 1")
55     .Attr("zero_padding: bool = false")
56     .Attr("out_scale: int = 1")
57     .Attr("out_type: {uint16, float} = DT_UINT16")
__anon8824a42c0102(InferenceContext* ctx) 58     .SetShapeFn([](InferenceContext* ctx) {
59       ShapeHandle input;
60       TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 1, &input));
61 
62       int sample_rate;
63       TF_RETURN_IF_ERROR(ctx->GetAttr("sample_rate", &sample_rate));
64       int window_size;
65       TF_RETURN_IF_ERROR(ctx->GetAttr("window_size", &window_size));
66       window_size *= sample_rate / 1000;
67       int window_step;
68       TF_RETURN_IF_ERROR(ctx->GetAttr("window_step", &window_step));
69       window_step *= sample_rate / 1000;
70 
71       int num_channels;
72       TF_RETURN_IF_ERROR(ctx->GetAttr("num_channels", &num_channels));
73       int left_context;
74       TF_RETURN_IF_ERROR(ctx->GetAttr("left_context", &left_context));
75       int right_context;
76       TF_RETURN_IF_ERROR(ctx->GetAttr("right_context", &right_context));
77       int frame_stride;
78       TF_RETURN_IF_ERROR(ctx->GetAttr("frame_stride", &frame_stride));
79 
80       DimensionHandle num_frames = ctx->Dim(input, 0);
81       if (ctx->Value(num_frames) < window_size) {
82         num_frames = ctx->MakeDim(0);
83       } else {
84         TF_RETURN_IF_ERROR(ctx->Subtract(num_frames, window_size, &num_frames));
85         TF_RETURN_IF_ERROR(
86             ctx->Divide(num_frames, window_step, false, &num_frames));
87         TF_RETURN_IF_ERROR(
88             ctx->Divide(num_frames, frame_stride, false, &num_frames));
89         TF_RETURN_IF_ERROR(ctx->Add(num_frames, 1, &num_frames));
90       }
91 
92       int stack_size = 1 + left_context + right_context;
93       DimensionHandle num_features = ctx->MakeDim(num_channels);
94       TF_RETURN_IF_ERROR(
95           ctx->Multiply(num_features, stack_size, &num_features));
96 
97       ShapeHandle output = ctx->MakeShape({num_frames, num_features});
98       ctx->set_output(0, output);
99       return OkStatus();
100     })
101     .Doc(R"doc(
102 Audio Microfrontend Op.
103 
104 This Op converts a sequence of audio data into one or more
105 feature vectors containing filterbanks of the input. The
106 conversion process uses a lightweight library to perform:
107 
108 1. A slicing window function
109 2. Short-time FFTs
110 3. Filterbank calculations
111 4. Noise reduction
112 5. PCAN Auto Gain Control
113 6. Logarithmic scaling
114 
115 Arguments
116   audio: 1D Tensor, int16 audio data in temporal ordering.
117   sample_rate: Integer, the sample rate of the audio in Hz.
118   window_size: Integer, length of desired time frames in ms.
119   window_step: Integer, length of step size for the next frame in ms.
120   num_channels: Integer, the number of filterbank channels to use.
121   upper_band_limit: Float, the highest frequency included in the filterbanks.
122   lower_band_limit: Float, the lowest frequency included in the filterbanks.
123   smoothing_bits: Int, scale up signal by 2^(smoothing_bits) before reduction.
124   even_smoothing: Float, smoothing coefficient for even-numbered channels.
125   odd_smoothing: Float, smoothing coefficient for odd-numbered channels.
126   min_signal_remaining: Float, fraction of signal to preserve in smoothing.
127   enable_pcan: Bool, enable PCAN auto gain control.
128   pcan_strength: Float, gain normalization exponent.
129   pcan_offset: Float, positive value added in the normalization denominator.
130   gain_bits: Int, number of fractional bits in the gain.
131   enable_log: Bool, enable logarithmic scaling of filterbanks.
132   scale_shift: Integer, scale filterbanks by 2^(scale_shift).
133   left_context: Integer, number of preceding frames to attach to each frame.
134   right_context: Integer, number of preceding frames to attach to each frame.
135   frame_stride: Integer, M frames to skip over, where output[n] = frame[n*M].
136   zero_padding: Bool, if left/right context is out-of-bounds, attach frame of
137                 zeroes. Otherwise, frame[0] or frame[size-1] will be copied.
138   out_scale: Integer, divide all filterbanks by this number.
139   out_type: DType, type of the output Tensor, defaults to UINT16.
140 
141 Returns
142   filterbanks: 2D Tensor, each row is a time frame, each column is a channel.
143 )doc");
144 
145 template <typename T>
146 class AudioMicrofrontendOp : public OpKernel {
147  public:
AudioMicrofrontendOp(OpKernelConstruction * ctx)148   explicit AudioMicrofrontendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
149     OP_REQUIRES_OK(ctx, ctx->GetAttr("sample_rate", &sample_rate_));
150 
151     int window_size;
152     OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size));
153     config_.window.size_ms = window_size;
154 
155     int window_step;
156     OP_REQUIRES_OK(ctx, ctx->GetAttr("window_step", &window_step));
157     config_.window.step_size_ms = window_step;
158 
159     OP_REQUIRES_OK(
160         ctx, ctx->GetAttr("num_channels", &config_.filterbank.num_channels));
161     OP_REQUIRES_OK(ctx, ctx->GetAttr("upper_band_limit",
162                                      &config_.filterbank.upper_band_limit));
163     OP_REQUIRES_OK(ctx, ctx->GetAttr("lower_band_limit",
164                                      &config_.filterbank.lower_band_limit));
165     OP_REQUIRES_OK(ctx, ctx->GetAttr("smoothing_bits",
166                                      &config_.noise_reduction.smoothing_bits));
167     OP_REQUIRES_OK(ctx, ctx->GetAttr("even_smoothing",
168                                      &config_.noise_reduction.even_smoothing));
169     OP_REQUIRES_OK(ctx, ctx->GetAttr("odd_smoothing",
170                                      &config_.noise_reduction.odd_smoothing));
171     OP_REQUIRES_OK(ctx,
172                    ctx->GetAttr("min_signal_remaining",
173                                 &config_.noise_reduction.min_signal_remaining));
174 
175     bool enable_pcan;
176     OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_pcan", &enable_pcan));
177     config_.pcan_gain_control.enable_pcan = enable_pcan;
178 
179     OP_REQUIRES_OK(ctx, ctx->GetAttr("pcan_strength",
180                                      &config_.pcan_gain_control.strength));
181     OP_REQUIRES_OK(
182         ctx, ctx->GetAttr("pcan_offset", &config_.pcan_gain_control.offset));
183     OP_REQUIRES_OK(
184         ctx, ctx->GetAttr("gain_bits", &config_.pcan_gain_control.gain_bits));
185 
186     bool enable_log;
187     OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_log", &enable_log));
188     config_.log_scale.enable_log = enable_log;
189 
190     OP_REQUIRES_OK(ctx,
191                    ctx->GetAttr("scale_shift", &config_.log_scale.scale_shift));
192 
193     OP_REQUIRES_OK(ctx, ctx->GetAttr("left_context", &left_context_));
194     OP_REQUIRES_OK(ctx, ctx->GetAttr("right_context", &right_context_));
195     OP_REQUIRES_OK(ctx, ctx->GetAttr("frame_stride", &frame_stride_));
196     OP_REQUIRES_OK(ctx, ctx->GetAttr("zero_padding", &zero_padding_));
197     OP_REQUIRES_OK(ctx, ctx->GetAttr("out_scale", &out_scale_));
198   }
199 
Compute(OpKernelContext * ctx)200   void Compute(OpKernelContext* ctx) override {
201     const Tensor* audio;
202     OP_REQUIRES_OK(ctx, ctx->input("audio", &audio));
203     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(audio->shape()),
204                 InvalidArgument("audio is not a vector"));
205 
206     auto audio_data =
207         reinterpret_cast<const int16_t*>(audio->tensor_data().data());
208     int audio_size = audio->NumElements();
209 
210     Tensor* filterbanks = nullptr;
211     int window_size = config_.window.size_ms * sample_rate_ / 1000;
212     int window_step = config_.window.step_size_ms * sample_rate_ / 1000;
213     int num_frames = 0;
214     int sampled_frames = 0;
215     if (audio_size >= window_size) {
216       num_frames = (audio_size - window_size) / window_step + 1;
217       sampled_frames = (num_frames - 1) / frame_stride_ + 1;
218     }
219     TensorShape filterbanks_shape{
220         sampled_frames,
221         config_.filterbank.num_channels * (1 + left_context_ + right_context_)};
222     OP_REQUIRES_OK(ctx,
223                    ctx->allocate_output(0, filterbanks_shape, &filterbanks));
224     auto filterbanks_flat = filterbanks->flat<T>();
225 
226     struct FrontendState state;
227     if (!TF_PREDICT_TRUE(
228             FrontendPopulateState(&config_, &state, sample_rate_))) {
229       ctx->CtxFailure(__FILE__, __LINE__,
230                       Internal("failed to populate frontend state"));
231       FrontendFreeStateContents(&state);
232       return;
233     }
234 
235     std::vector<std::vector<T>> frame_buffer(num_frames);
236     int frame_index = 0;
237     while (audio_size > 0) {
238       size_t num_samples_read;
239       struct FrontendOutput output = FrontendProcessSamples(
240           &state, audio_data, audio_size, &num_samples_read);
241       audio_data += num_samples_read;
242       audio_size -= num_samples_read;
243 
244       if (output.values != nullptr) {
245         frame_buffer[frame_index].reserve(output.size);
246         int i;
247         for (i = 0; i < output.size; ++i) {
248           frame_buffer[frame_index].push_back(static_cast<T>(output.values[i]) /
249                                               out_scale_);
250         }
251         ++frame_index;
252       }
253     }
254     FrontendFreeStateContents(&state);
255 
256     int index = 0;
257     std::vector<T> pad(config_.filterbank.num_channels, 0);
258     int anchor;
259     for (anchor = 0; anchor < frame_buffer.size(); anchor += frame_stride_) {
260       int frame;
261       for (frame = anchor - left_context_; frame <= anchor + right_context_;
262            ++frame) {
263         std::vector<T>* feature;
264         if (zero_padding_ && (frame < 0 || frame >= frame_buffer.size())) {
265           feature = &pad;
266         } else if (frame < 0) {
267           feature = &frame_buffer[0];
268         } else if (frame >= frame_buffer.size()) {
269           feature = &frame_buffer[frame_buffer.size() - 1];
270         } else {
271           feature = &frame_buffer[frame];
272         }
273         for (auto f : *feature) {
274           filterbanks_flat(index++) = f;
275         }
276       }
277     }
278   }
279 
280  protected:
281   int sample_rate_;
282   struct FrontendConfig config_;
283   int left_context_;
284   int right_context_;
285   int frame_stride_;
286   bool zero_padding_;
287   int out_scale_;
288 
289   TF_DISALLOW_COPY_AND_ASSIGN(AudioMicrofrontendOp);
290 };
291 
// CPU kernel registrations: one template instantiation per output dtype
// allowed by the op's "out_type: {uint16, float}" attr constraint.
REGISTER_KERNEL_BUILDER(Name("AudioMicrofrontend")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<uint16>("out_type"),
                        AudioMicrofrontendOp<uint16>);
REGISTER_KERNEL_BUILDER(Name("AudioMicrofrontend")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<float>("out_type"),
                        AudioMicrofrontendOp<float>);
300 }  // namespace tensorflow
301