xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/audio_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/lib/core/bits.h"
20 
21 namespace tensorflow {
22 
23 namespace {
24 
25 using shape_inference::DimensionHandle;
26 using shape_inference::InferenceContext;
27 using shape_inference::ShapeHandle;
28 
DecodeWavShapeFn(InferenceContext * c)29 Status DecodeWavShapeFn(InferenceContext* c) {
30   ShapeHandle unused;
31   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
32 
33   DimensionHandle channels_dim;
34   int32_t desired_channels;
35   TF_RETURN_IF_ERROR(c->GetAttr("desired_channels", &desired_channels));
36   if (desired_channels == -1) {
37     channels_dim = c->UnknownDim();
38   } else {
39     if (desired_channels < 0) {
40       return errors::InvalidArgument("channels must be non-negative, got ",
41                                      desired_channels);
42     }
43     channels_dim = c->MakeDim(desired_channels);
44   }
45   DimensionHandle samples_dim;
46   int32_t desired_samples;
47   TF_RETURN_IF_ERROR(c->GetAttr("desired_samples", &desired_samples));
48   if (desired_samples == -1) {
49     samples_dim = c->UnknownDim();
50   } else {
51     if (desired_samples < 0) {
52       return errors::InvalidArgument("samples must be non-negative, got ",
53                                      desired_samples);
54     }
55     samples_dim = c->MakeDim(desired_samples);
56   }
57   c->set_output(0, c->MakeShape({samples_dim, channels_dim}));
58   c->set_output(1, c->Scalar());
59   return OkStatus();
60 }
61 
EncodeWavShapeFn(InferenceContext * c)62 Status EncodeWavShapeFn(InferenceContext* c) {
63   ShapeHandle unused;
64   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
65   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
66   c->set_output(0, c->Scalar());
67   return OkStatus();
68 }
69 
SpectrogramShapeFn(InferenceContext * c)70 Status SpectrogramShapeFn(InferenceContext* c) {
71   ShapeHandle input;
72   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
73   int32_t window_size;
74   TF_RETURN_IF_ERROR(c->GetAttr("window_size", &window_size));
75   int32_t stride;
76   TF_RETURN_IF_ERROR(c->GetAttr("stride", &stride));
77 
78   DimensionHandle input_length = c->Dim(input, 0);
79   DimensionHandle input_channels = c->Dim(input, 1);
80 
81   DimensionHandle output_length;
82   if (!c->ValueKnown(input_length)) {
83     output_length = c->UnknownDim();
84   } else {
85     const int64_t input_length_value = c->Value(input_length);
86     const int64_t length_minus_window = (input_length_value - window_size);
87     int64_t output_length_value;
88     if (length_minus_window < 0) {
89       output_length_value = 0;
90     } else {
91       output_length_value = 1 + (length_minus_window / stride);
92     }
93     output_length = c->MakeDim(output_length_value);
94   }
95 
96   DimensionHandle output_channels =
97       c->MakeDim(1 + NextPowerOfTwo(window_size) / 2);
98   c->set_output(0,
99                 c->MakeShape({input_channels, output_length, output_channels}));
100   return OkStatus();
101 }
102 
MfccShapeFn(InferenceContext * c)103 Status MfccShapeFn(InferenceContext* c) {
104   ShapeHandle spectrogram;
105   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram));
106   ShapeHandle unused;
107   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
108 
109   int32_t dct_coefficient_count;
110   TF_RETURN_IF_ERROR(
111       c->GetAttr("dct_coefficient_count", &dct_coefficient_count));
112 
113   DimensionHandle spectrogram_channels = c->Dim(spectrogram, 0);
114   DimensionHandle spectrogram_length = c->Dim(spectrogram, 1);
115 
116   DimensionHandle output_channels = c->MakeDim(dct_coefficient_count);
117 
118   c->set_output(0, c->MakeShape({spectrogram_channels, spectrogram_length,
119                                  output_channels}));
120   return OkStatus();
121 }
122 
123 }  // namespace
124 
125 REGISTER_OP("DecodeWav")
126     .Input("contents: string")
127     .Attr("desired_channels: int = -1")
128     .Attr("desired_samples: int = -1")
129     .Output("audio: float")
130     .Output("sample_rate: int32")
131     .SetShapeFn(DecodeWavShapeFn);
132 
133 REGISTER_OP("EncodeWav")
134     .Input("audio: float")
135     .Input("sample_rate: int32")
136     .Output("contents: string")
137     .SetShapeFn(EncodeWavShapeFn);
138 
139 REGISTER_OP("AudioSpectrogram")
140     .Input("input: float")
141     .Attr("window_size: int")
142     .Attr("stride: int")
143     .Attr("magnitude_squared: bool = false")
144     .Output("spectrogram: float")
145     .SetShapeFn(SpectrogramShapeFn);
146 
147 REGISTER_OP("Mfcc")
148     .Input("spectrogram: float")
149     .Input("sample_rate: int32")
150     .Attr("upper_frequency_limit: float = 4000")
151     .Attr("lower_frequency_limit: float = 20")
152     .Attr("filterbank_channel_count: int = 40")
153     .Attr("dct_coefficient_count: int = 13")
154     .Output("output: float")
155     .SetShapeFn(MfccShapeFn);
156 
157 }  // namespace tensorflow
158