1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/lib/core/bits.h"
20
21 namespace tensorflow {
22
23 namespace {
24
25 using shape_inference::DimensionHandle;
26 using shape_inference::InferenceContext;
27 using shape_inference::ShapeHandle;
28
DecodeWavShapeFn(InferenceContext * c)29 Status DecodeWavShapeFn(InferenceContext* c) {
30 ShapeHandle unused;
31 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
32
33 DimensionHandle channels_dim;
34 int32_t desired_channels;
35 TF_RETURN_IF_ERROR(c->GetAttr("desired_channels", &desired_channels));
36 if (desired_channels == -1) {
37 channels_dim = c->UnknownDim();
38 } else {
39 if (desired_channels < 0) {
40 return errors::InvalidArgument("channels must be non-negative, got ",
41 desired_channels);
42 }
43 channels_dim = c->MakeDim(desired_channels);
44 }
45 DimensionHandle samples_dim;
46 int32_t desired_samples;
47 TF_RETURN_IF_ERROR(c->GetAttr("desired_samples", &desired_samples));
48 if (desired_samples == -1) {
49 samples_dim = c->UnknownDim();
50 } else {
51 if (desired_samples < 0) {
52 return errors::InvalidArgument("samples must be non-negative, got ",
53 desired_samples);
54 }
55 samples_dim = c->MakeDim(desired_samples);
56 }
57 c->set_output(0, c->MakeShape({samples_dim, channels_dim}));
58 c->set_output(1, c->Scalar());
59 return OkStatus();
60 }
61
EncodeWavShapeFn(InferenceContext * c)62 Status EncodeWavShapeFn(InferenceContext* c) {
63 ShapeHandle unused;
64 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
65 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
66 c->set_output(0, c->Scalar());
67 return OkStatus();
68 }
69
SpectrogramShapeFn(InferenceContext * c)70 Status SpectrogramShapeFn(InferenceContext* c) {
71 ShapeHandle input;
72 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
73 int32_t window_size;
74 TF_RETURN_IF_ERROR(c->GetAttr("window_size", &window_size));
75 int32_t stride;
76 TF_RETURN_IF_ERROR(c->GetAttr("stride", &stride));
77
78 DimensionHandle input_length = c->Dim(input, 0);
79 DimensionHandle input_channels = c->Dim(input, 1);
80
81 DimensionHandle output_length;
82 if (!c->ValueKnown(input_length)) {
83 output_length = c->UnknownDim();
84 } else {
85 const int64_t input_length_value = c->Value(input_length);
86 const int64_t length_minus_window = (input_length_value - window_size);
87 int64_t output_length_value;
88 if (length_minus_window < 0) {
89 output_length_value = 0;
90 } else {
91 output_length_value = 1 + (length_minus_window / stride);
92 }
93 output_length = c->MakeDim(output_length_value);
94 }
95
96 DimensionHandle output_channels =
97 c->MakeDim(1 + NextPowerOfTwo(window_size) / 2);
98 c->set_output(0,
99 c->MakeShape({input_channels, output_length, output_channels}));
100 return OkStatus();
101 }
102
MfccShapeFn(InferenceContext * c)103 Status MfccShapeFn(InferenceContext* c) {
104 ShapeHandle spectrogram;
105 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram));
106 ShapeHandle unused;
107 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
108
109 int32_t dct_coefficient_count;
110 TF_RETURN_IF_ERROR(
111 c->GetAttr("dct_coefficient_count", &dct_coefficient_count));
112
113 DimensionHandle spectrogram_channels = c->Dim(spectrogram, 0);
114 DimensionHandle spectrogram_length = c->Dim(spectrogram, 1);
115
116 DimensionHandle output_channels = c->MakeDim(dct_coefficient_count);
117
118 c->set_output(0, c->MakeShape({spectrogram_channels, spectrogram_length,
119 output_channels}));
120 return OkStatus();
121 }
122
123 } // namespace
124
// Registers the DecodeWav op: one string input (`contents`) and two outputs,
// `audio` (float) and `sample_rate` (int32). Both `desired_*` attributes
// default to -1, which DecodeWavShapeFn maps to an unknown dimension.
REGISTER_OP("DecodeWav")
    .Input("contents: string")
    .Attr("desired_channels: int = -1")
    .Attr("desired_samples: int = -1")
    .Output("audio: float")
    .Output("sample_rate: int32")
    .SetShapeFn(DecodeWavShapeFn);
132
// Registers the EncodeWav op: inputs `audio` (float) and `sample_rate`
// (int32), one string output (`contents`). Shapes are checked by
// EncodeWavShapeFn.
REGISTER_OP("EncodeWav")
    .Input("audio: float")
    .Input("sample_rate: int32")
    .Output("contents: string")
    .SetShapeFn(EncodeWavShapeFn);
138
// Registers the AudioSpectrogram op: one float input and one float output
// (`spectrogram`). `window_size` and `stride` are declared without defaults;
// `magnitude_squared` defaults to false. The output shape is computed by
// SpectrogramShapeFn.
REGISTER_OP("AudioSpectrogram")
    .Input("input: float")
    .Attr("window_size: int")
    .Attr("stride: int")
    .Attr("magnitude_squared: bool = false")
    .Output("spectrogram: float")
    .SetShapeFn(SpectrogramShapeFn);
146
// Registers the Mfcc op: inputs `spectrogram` (float) and `sample_rate`
// (int32), one float output. Attribute defaults: upper_frequency_limit=4000,
// lower_frequency_limit=20, filterbank_channel_count=40,
// dct_coefficient_count=13. The output shape is computed by MfccShapeFn,
// which reads only `dct_coefficient_count`.
REGISTER_OP("Mfcc")
    .Input("spectrogram: float")
    .Input("sample_rate: int32")
    .Attr("upper_frequency_limit: float = 4000")
    .Attr("lower_frequency_limit: float = 20")
    .Attr("filterbank_channel_count: int = 40")
    .Attr("dct_coefficient_count: int = 13")
    .Output("output: float")
    .SetShapeFn(MfccShapeFn);
156
157 } // namespace tensorflow
158