1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" 16 #include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/shape_inference.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/framework/tensor_shape.h" 22 #include "tensorflow/core/lib/core/errors.h" 23 #include "tensorflow/core/lib/core/status.h" 24 #include "tensorflow/core/platform/macros.h" 25 26 using tensorflow::errors::Internal; 27 using tensorflow::errors::InvalidArgument; 28 using tensorflow::shape_inference::DimensionHandle; 29 using tensorflow::shape_inference::InferenceContext; 30 using tensorflow::shape_inference::ShapeHandle; 31 32 namespace tensorflow { 33 REGISTER_OP("AudioMicrofrontend") 34 .Input("audio: int16") 35 .Output("filterbanks: out_type") 36 .Attr("sample_rate: int = 16000") 37 .Attr("window_size: int = 25") 38 .Attr("window_step: int = 10") 39 .Attr("num_channels: int = 32") 40 .Attr("upper_band_limit: float = 7500.0") 41 .Attr("lower_band_limit: float = 125.0") 42 .Attr("smoothing_bits: int = 10") 43 .Attr("even_smoothing: float = 0.025") 44 .Attr("odd_smoothing: float = 0.06") 45 .Attr("min_signal_remaining: float = 0.05") 46 .Attr("enable_pcan: bool = false") 47 .Attr("pcan_strength: float = 0.95") 48 .Attr("pcan_offset: float = 80.0") 49 .Attr("gain_bits: int = 21") 50 .Attr("enable_log: bool = true") 51 .Attr("scale_shift: int = 6") 52 .Attr("left_context: int = 0") 53 .Attr("right_context: int = 0") 54 .Attr("frame_stride: int = 1") 55 .Attr("zero_padding: bool = false") 56 .Attr("out_scale: int = 1") 57 .Attr("out_type: {uint16, float} = DT_UINT16") __anon8824a42c0102(InferenceContext* ctx) 58 .SetShapeFn([](InferenceContext* ctx) { 59 ShapeHandle input; 60 TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 1, &input)); 61 62 int sample_rate; 63 TF_RETURN_IF_ERROR(ctx->GetAttr("sample_rate", &sample_rate)); 64 int window_size; 65 TF_RETURN_IF_ERROR(ctx->GetAttr("window_size", &window_size)); 66 window_size *= sample_rate / 1000; 67 int window_step; 68 TF_RETURN_IF_ERROR(ctx->GetAttr("window_step", &window_step)); 69 window_step *= sample_rate / 1000; 70 71 int num_channels; 72 TF_RETURN_IF_ERROR(ctx->GetAttr("num_channels", &num_channels)); 73 int left_context; 74 TF_RETURN_IF_ERROR(ctx->GetAttr("left_context", &left_context)); 75 int right_context; 76 TF_RETURN_IF_ERROR(ctx->GetAttr("right_context", &right_context)); 77 int frame_stride; 78 TF_RETURN_IF_ERROR(ctx->GetAttr("frame_stride", &frame_stride)); 79 80 DimensionHandle num_frames = ctx->Dim(input, 0); 81 if (ctx->Value(num_frames) < window_size) { 82 num_frames = ctx->MakeDim(0); 83 } else { 84 TF_RETURN_IF_ERROR(ctx->Subtract(num_frames, window_size, &num_frames)); 85 TF_RETURN_IF_ERROR( 86 ctx->Divide(num_frames, window_step, false, &num_frames)); 87 TF_RETURN_IF_ERROR( 88 ctx->Divide(num_frames, frame_stride, false, &num_frames)); 89 TF_RETURN_IF_ERROR(ctx->Add(num_frames, 1, &num_frames)); 90 } 91 92 int stack_size = 1 + left_context + right_context; 93 DimensionHandle num_features = ctx->MakeDim(num_channels); 94 TF_RETURN_IF_ERROR( 95 ctx->Multiply(num_features, stack_size, &num_features)); 96 97 ShapeHandle output = ctx->MakeShape({num_frames, num_features}); 98 ctx->set_output(0, output); 99 return OkStatus(); 100 }) 101 .Doc(R"doc( 102 Audio Microfrontend Op. 103 104 This Op converts a sequence of audio data into one or more 105 feature vectors containing filterbanks of the input. The 106 conversion process uses a lightweight library to perform: 107 108 1. A slicing window function 109 2. Short-time FFTs 110 3. Filterbank calculations 111 4. Noise reduction 112 5. PCAN Auto Gain Control 113 6. Logarithmic scaling 114 115 Arguments 116 audio: 1D Tensor, int16 audio data in temporal ordering. 117 sample_rate: Integer, the sample rate of the audio in Hz. 118 window_size: Integer, length of desired time frames in ms. 119 window_step: Integer, length of step size for the next frame in ms. 120 num_channels: Integer, the number of filterbank channels to use. 121 upper_band_limit: Float, the highest frequency included in the filterbanks. 122 lower_band_limit: Float, the lowest frequency included in the filterbanks. 123 smoothing_bits: Int, scale up signal by 2^(smoothing_bits) before reduction. 124 even_smoothing: Float, smoothing coefficient for even-numbered channels. 125 odd_smoothing: Float, smoothing coefficient for odd-numbered channels. 126 min_signal_remaining: Float, fraction of signal to preserve in smoothing. 127 enable_pcan: Bool, enable PCAN auto gain control. 128 pcan_strength: Float, gain normalization exponent. 129 pcan_offset: Float, positive value added in the normalization denominator. 130 gain_bits: Int, number of fractional bits in the gain. 131 enable_log: Bool, enable logarithmic scaling of filterbanks. 132 scale_shift: Integer, scale filterbanks by 2^(scale_shift). 133 left_context: Integer, number of preceding frames to attach to each frame. 134 right_context: Integer, number of preceding frames to attach to each frame. 135 frame_stride: Integer, M frames to skip over, where output[n] = frame[n*M]. 136 zero_padding: Bool, if left/right context is out-of-bounds, attach frame of 137 zeroes. Otherwise, frame[0] or frame[size-1] will be copied. 138 out_scale: Integer, divide all filterbanks by this number. 139 out_type: DType, type of the output Tensor, defaults to UINT16. 140 141 Returns 142 filterbanks: 2D Tensor, each row is a time frame, each column is a channel. 143 )doc"); 144 145 template <typename T> 146 class AudioMicrofrontendOp : public OpKernel { 147 public: AudioMicrofrontendOp(OpKernelConstruction * ctx)148 explicit AudioMicrofrontendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 149 OP_REQUIRES_OK(ctx, ctx->GetAttr("sample_rate", &sample_rate_)); 150 151 int window_size; 152 OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size)); 153 config_.window.size_ms = window_size; 154 155 int window_step; 156 OP_REQUIRES_OK(ctx, ctx->GetAttr("window_step", &window_step)); 157 config_.window.step_size_ms = window_step; 158 159 OP_REQUIRES_OK( 160 ctx, ctx->GetAttr("num_channels", &config_.filterbank.num_channels)); 161 OP_REQUIRES_OK(ctx, ctx->GetAttr("upper_band_limit", 162 &config_.filterbank.upper_band_limit)); 163 OP_REQUIRES_OK(ctx, ctx->GetAttr("lower_band_limit", 164 &config_.filterbank.lower_band_limit)); 165 OP_REQUIRES_OK(ctx, ctx->GetAttr("smoothing_bits", 166 &config_.noise_reduction.smoothing_bits)); 167 OP_REQUIRES_OK(ctx, ctx->GetAttr("even_smoothing", 168 &config_.noise_reduction.even_smoothing)); 169 OP_REQUIRES_OK(ctx, ctx->GetAttr("odd_smoothing", 170 &config_.noise_reduction.odd_smoothing)); 171 OP_REQUIRES_OK(ctx, 172 ctx->GetAttr("min_signal_remaining", 173 &config_.noise_reduction.min_signal_remaining)); 174 175 bool enable_pcan; 176 OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_pcan", &enable_pcan)); 177 config_.pcan_gain_control.enable_pcan = enable_pcan; 178 179 OP_REQUIRES_OK(ctx, ctx->GetAttr("pcan_strength", 180 &config_.pcan_gain_control.strength)); 181 OP_REQUIRES_OK( 182 ctx, ctx->GetAttr("pcan_offset", &config_.pcan_gain_control.offset)); 183 OP_REQUIRES_OK( 184 ctx, ctx->GetAttr("gain_bits", &config_.pcan_gain_control.gain_bits)); 185 186 bool enable_log; 187 OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_log", &enable_log)); 188 config_.log_scale.enable_log = enable_log; 189 190 OP_REQUIRES_OK(ctx, 191 ctx->GetAttr("scale_shift", &config_.log_scale.scale_shift)); 192 193 OP_REQUIRES_OK(ctx, ctx->GetAttr("left_context", &left_context_)); 194 OP_REQUIRES_OK(ctx, ctx->GetAttr("right_context", &right_context_)); 195 OP_REQUIRES_OK(ctx, ctx->GetAttr("frame_stride", &frame_stride_)); 196 OP_REQUIRES_OK(ctx, ctx->GetAttr("zero_padding", &zero_padding_)); 197 OP_REQUIRES_OK(ctx, ctx->GetAttr("out_scale", &out_scale_)); 198 } 199 Compute(OpKernelContext * ctx)200 void Compute(OpKernelContext* ctx) override { 201 const Tensor* audio; 202 OP_REQUIRES_OK(ctx, ctx->input("audio", &audio)); 203 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(audio->shape()), 204 InvalidArgument("audio is not a vector")); 205 206 auto audio_data = 207 reinterpret_cast<const int16_t*>(audio->tensor_data().data()); 208 int audio_size = audio->NumElements(); 209 210 Tensor* filterbanks = nullptr; 211 int window_size = config_.window.size_ms * sample_rate_ / 1000; 212 int window_step = config_.window.step_size_ms * sample_rate_ / 1000; 213 int num_frames = 0; 214 int sampled_frames = 0; 215 if (audio_size >= window_size) { 216 num_frames = (audio_size - window_size) / window_step + 1; 217 sampled_frames = (num_frames - 1) / frame_stride_ + 1; 218 } 219 TensorShape filterbanks_shape{ 220 sampled_frames, 221 config_.filterbank.num_channels * (1 + left_context_ + right_context_)}; 222 OP_REQUIRES_OK(ctx, 223 ctx->allocate_output(0, filterbanks_shape, &filterbanks)); 224 auto filterbanks_flat = filterbanks->flat<T>(); 225 226 struct FrontendState state; 227 if (!TF_PREDICT_TRUE( 228 FrontendPopulateState(&config_, &state, sample_rate_))) { 229 ctx->CtxFailure(__FILE__, __LINE__, 230 Internal("failed to populate frontend state")); 231 FrontendFreeStateContents(&state); 232 return; 233 } 234 235 std::vector<std::vector<T>> frame_buffer(num_frames); 236 int frame_index = 0; 237 while (audio_size > 0) { 238 size_t num_samples_read; 239 struct FrontendOutput output = FrontendProcessSamples( 240 &state, audio_data, audio_size, &num_samples_read); 241 audio_data += num_samples_read; 242 audio_size -= num_samples_read; 243 244 if (output.values != nullptr) { 245 frame_buffer[frame_index].reserve(output.size); 246 int i; 247 for (i = 0; i < output.size; ++i) { 248 frame_buffer[frame_index].push_back(static_cast<T>(output.values[i]) / 249 out_scale_); 250 } 251 ++frame_index; 252 } 253 } 254 FrontendFreeStateContents(&state); 255 256 int index = 0; 257 std::vector<T> pad(config_.filterbank.num_channels, 0); 258 int anchor; 259 for (anchor = 0; anchor < frame_buffer.size(); anchor += frame_stride_) { 260 int frame; 261 for (frame = anchor - left_context_; frame <= anchor + right_context_; 262 ++frame) { 263 std::vector<T>* feature; 264 if (zero_padding_ && (frame < 0 || frame >= frame_buffer.size())) { 265 feature = &pad; 266 } else if (frame < 0) { 267 feature = &frame_buffer[0]; 268 } else if (frame >= frame_buffer.size()) { 269 feature = &frame_buffer[frame_buffer.size() - 1]; 270 } else { 271 feature = &frame_buffer[frame]; 272 } 273 for (auto f : *feature) { 274 filterbanks_flat(index++) = f; 275 } 276 } 277 } 278 } 279 280 protected: 281 int sample_rate_; 282 struct FrontendConfig config_; 283 int left_context_; 284 int right_context_; 285 int frame_stride_; 286 bool zero_padding_; 287 int out_scale_; 288 289 TF_DISALLOW_COPY_AND_ASSIGN(AudioMicrofrontendOp); 290 }; 291 292 REGISTER_KERNEL_BUILDER(Name("AudioMicrofrontend") 293 .Device(tensorflow::DEVICE_CPU) 294 .TypeConstraint<uint16>("out_type"), 295 AudioMicrofrontendOp<uint16>); 296 REGISTER_KERNEL_BUILDER(Name("AudioMicrofrontend") 297 .Device(tensorflow::DEVICE_CPU) 298 .TypeConstraint<float>("out_type"), 299 AudioMicrofrontendOp<float>); 300 } // namespace tensorflow 301