/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/kernel_utils.h"

#include <algorithm>

#include "tensorflow/lite/kernels/internal/tensor_utils.h"

namespace tflite {
namespace kernel_utils {

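// Computes one time step of a basic (fully-connected) RNN cell in float.
// This overload is a convenience wrapper that forwards to the full float
// implementation below with no auxiliary input.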
void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
                  const float* recurrent_weights_ptr, const float* bias_ptr,
                  int input_size, int num_units, int batch_size,
                  int output_batch_leading_dim,
                  TfLiteFusedActivation activation,
                  float* hidden_state_ptr_batch, float* output_ptr_batch) {
  RnnBatchStep(input_ptr_batch, input_weights_ptr,
               /*aux_input_ptr_batch=*/nullptr,
               /*aux_input_weights_ptr=*/nullptr, recurrent_weights_ptr,
               bias_ptr, input_size, /*aux_input_size=*/0, num_units,
               batch_size, output_batch_leading_dim, activation,
               hidden_state_ptr_batch, output_ptr_batch);
}

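// Full float implementation of one RNN time step:
//   output = activation(input * input_weights + aux_input * aux_input_weights
//                       + hidden_state * recurrent_weights + bias)
// The new output is also copied back into hidden_state. When
// output_batch_leading_dim != num_units, the per-batch output rows are not
// contiguous, so the batched operations are unrolled one batch at a time.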
void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
                  const float* aux_input_ptr_batch,
                  const float* aux_input_weights_ptr,
                  const float* recurrent_weights_ptr, const float* bias_ptr,
                  int input_size, int aux_input_size, int num_units,
                  int batch_size, int output_batch_leading_dim,
                  TfLiteFusedActivation activation,
                  float* hidden_state_ptr_batch, float* output_ptr_batch) {
  // Since the output batch rows may not be contiguous
  // (output_batch_leading_dim != num_units), we unroll the batched operations
  // where this is the case.
  if (output_batch_leading_dim == num_units) {
    // Output = bias
    tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
                                          output_ptr_batch);

    // Output += input * input_weights
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size,
        output_ptr_batch);

    // Output += aux_input * aux_input_weights (if they are not empty).
    if (aux_input_size > 0) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          aux_input_weights_ptr, num_units, aux_input_size, aux_input_ptr_batch,
          batch_size, output_ptr_batch);
    }

    // Output += recurrent_weights * hidden_state
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch,
        batch_size, output_ptr_batch);

    // Output = activation(Output) and update hidden_state
    tensor_utils::ApplyActivationToVector(
        output_ptr_batch, num_units * batch_size, activation, output_ptr_batch);
    std::copy_n(output_ptr_batch, num_units * batch_size,
                hidden_state_ptr_batch);
  } else {
    // Output = bias
    for (int k = 0; k < batch_size; k++) {
      std::copy_n(bias_ptr, num_units,
                  output_ptr_batch + k * output_batch_leading_dim);
    }

    // Output += input * input_weights
    for (int k = 0; k < batch_size; k++) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_weights_ptr, num_units, input_size,
          input_ptr_batch + k * input_size, /*n_batch=*/1,
          output_ptr_batch + k * output_batch_leading_dim);
    }

    // Output += aux_input * aux_input_weights (if they are not empty).
    if (aux_input_size > 0) {
      for (int k = 0; k < batch_size; k++) {
        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            aux_input_weights_ptr, num_units, aux_input_size,
            aux_input_ptr_batch + k * aux_input_size,
            /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
      }
    }

    // Output += recurrent_weights * hidden_state
    for (int k = 0; k < batch_size; k++) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_weights_ptr, num_units, num_units,
          hidden_state_ptr_batch + k * num_units,
          /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
    }

    // Output = activation(Output) and update hidden_state
    for (int k = 0; k < batch_size; k++) {
      tensor_utils::ApplyActivationToVector(
          output_ptr_batch + k * output_batch_leading_dim, num_units,
          activation, output_ptr_batch + k * output_batch_leading_dim);
      std::copy_n(output_ptr_batch + k * output_batch_leading_dim, num_units,
                  hidden_state_ptr_batch + k * num_units);
    }
  }
}

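// Hybrid version of RnnBatchStep: the weights are int8 while the activations
// stay in float. This overload is a convenience wrapper that forwards to the
// full hybrid implementation below with no auxiliary input.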
void RnnBatchStep(
    const float* input_ptr_batch, const int8_t* input_weights_ptr,
    float input_weights_scale, const int8_t* recurrent_weights_ptr,
    float recurrent_weights_scale, const float* bias_ptr, int input_size,
    int num_units, int batch_size, int output_batch_leading_dim,
    TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
    int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
    float* hidden_state_ptr_batch, float* output_ptr_batch,
    bool asymmetric_quantize_inputs, int32_t* zero_points,
    int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
  RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
               /*aux_input_ptr_batch=*/nullptr,
               /*aux_input_weights_ptr=*/nullptr,
               /*aux_input_weights_scale=*/0.0f, recurrent_weights_ptr,
               recurrent_weights_scale, bias_ptr, input_size,
               /*aux_input_size=*/0, num_units, batch_size,
               output_batch_leading_dim, activation, quantized_input_ptr_batch,
               /*aux_quantized_input_ptr_batch=*/nullptr,
               quantized_hidden_state_ptr_batch, scaling_factors,
               hidden_state_ptr_batch, output_ptr_batch,
               asymmetric_quantize_inputs, zero_points, accum_scratch, row_sums,
               compute_row_sums);
}

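// Full hybrid implementation of one RNN time step. The float input, auxiliary
// input, and hidden state are quantized to int8 on the fly (per batch, either
// symmetrically or asymmetrically depending on asymmetric_quantize_inputs),
// multiplied against the int8 weights, scaled back to float, and then the bias
// and activation are applied as in the float path. All-zero inputs skip both
// the quantization and the matmul.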
void RnnBatchStep(
    const float* input_ptr_batch, const int8_t* input_weights_ptr,
    float input_weights_scale, const float* aux_input_ptr_batch,
    const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
    const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
    const float* bias_ptr, int input_size, int aux_input_size, int num_units,
    int batch_size, int output_batch_leading_dim,
    TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
    int8_t* aux_quantized_input_ptr_batch,
    int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
    float* hidden_state_ptr_batch, float* output_ptr_batch,
    bool asymmetric_quantize_inputs, int32_t* zero_points,
    int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
  // Since the output batch rows may not be contiguous
  // (output_batch_leading_dim != num_units), we unroll the batched operations
  // where this is the case.

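  // With asymmetric quantization, precomputed row sums of the weight matrices
  // are needed to correct for the input zero points. The row_sums scratch is
  // laid out as [input_row_sums, aux_input_row_sums (only if an aux input is
  // given), recurrent_row_sums], each num_units long, and is filled only once;
  // *compute_row_sums is then cleared.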
  int32_t* input_row_sums = nullptr;
  int32_t* aux_input_row_sums = nullptr;
  int32_t* recurrent_row_sums = nullptr;
  if (asymmetric_quantize_inputs) {
    input_row_sums = row_sums;
    aux_input_row_sums = row_sums;
    if (aux_input_ptr_batch) {
      aux_input_row_sums += num_units;
    }
    recurrent_row_sums = aux_input_row_sums + num_units;
    if (*compute_row_sums) {
      tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums,
                                       num_units, input_size);
      if (aux_input_ptr_batch) {
        tensor_utils::ReductionSumVector(aux_input_weights_ptr,
                                         aux_input_row_sums, num_units,
                                         aux_input_size);
      }
      tensor_utils::ReductionSumVector(
          recurrent_weights_ptr, recurrent_row_sums, num_units, num_units);
      *compute_row_sums = false;
    }
  }

  if (output_batch_leading_dim == num_units) {
    // Output = bias
    tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
                                          output_ptr_batch);

    // Save quantization and matmul computation for all zero input.
    if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) {
      // Quantize input from float to int8 + quantization params (scaling
      // factor).
      tensor_utils::BatchQuantizeFloats(
          input_ptr_batch, batch_size, input_size, quantized_input_ptr_batch,
          scaling_factors, zero_points, asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= input_weights_scale;
      }
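      // The weight scale has been folded into each batch's scaling factor, so
      // the matmul below can dequantize its int8 accumulation back to float
      // with a single per-batch multiply.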
      // Output += input * input_weights
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_weights_ptr, num_units, input_size, quantized_input_ptr_batch,
          scaling_factors, batch_size, output_ptr_batch,
          /*per_channel_scale=*/nullptr, zero_points, accum_scratch,
          input_row_sums, compute_row_sums, /*context=*/nullptr);
    }

    if (aux_input_ptr_batch &&
        !tensor_utils::IsZeroVector(aux_input_ptr_batch,
                                    batch_size * aux_input_size)) {
      tensor_utils::BatchQuantizeFloats(
          aux_input_ptr_batch, batch_size, aux_input_size,
          aux_quantized_input_ptr_batch, scaling_factors, zero_points,
          asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= aux_input_weights_scale;
      }

      // Output += aux_input * aux_input_weights
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          aux_input_weights_ptr, num_units, aux_input_size,
          aux_quantized_input_ptr_batch, scaling_factors, batch_size,
          output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points,
          accum_scratch, aux_input_row_sums, compute_row_sums,
          /*context=*/nullptr);
    }

    // Save quantization and matmul computation for all zero input.
    if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch,
                                    batch_size * num_units)) {
      // Quantize hidden_state
      tensor_utils::BatchQuantizeFloats(
          hidden_state_ptr_batch, batch_size, num_units,
          quantized_hidden_state_ptr_batch, scaling_factors, zero_points,
          asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= recurrent_weights_scale;
      }

      // Output += recurrent_weights * hidden_state
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_weights_ptr, num_units, num_units,
          quantized_hidden_state_ptr_batch, scaling_factors, batch_size,
          output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points,
          accum_scratch, recurrent_row_sums, compute_row_sums,
          /*context=*/nullptr);
    }

    // Output = activation(Output) and update hidden_state
    tensor_utils::ApplyActivationToVector(
        output_ptr_batch, num_units * batch_size, activation, output_ptr_batch);
    std::copy_n(output_ptr_batch, num_units * batch_size,
                hidden_state_ptr_batch);
  } else {
    // Output = bias
    for (int k = 0; k < batch_size; k++) {
      std::copy_n(bias_ptr, num_units,
                  output_ptr_batch + k * output_batch_leading_dim);
    }

    // Save quantization and matmul computation for all zero input.
    if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) {
      // Quantize input from float to int8 + quantization params (scaling
      // factor).
      tensor_utils::BatchQuantizeFloats(
          input_ptr_batch, batch_size, input_size, quantized_input_ptr_batch,
          scaling_factors, zero_points, asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= input_weights_scale;
      }

      // Output += input * input_weights
      for (int k = 0; k < batch_size; k++) {
        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_weights_ptr, num_units, input_size,
            quantized_input_ptr_batch + k * input_size, &scaling_factors[k],
            /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
            /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
            input_row_sums, compute_row_sums, /*context=*/nullptr);
      }
    }

    if (aux_input_ptr_batch &&
        !tensor_utils::IsZeroVector(aux_input_ptr_batch,
                                    batch_size * aux_input_size)) {
      tensor_utils::BatchQuantizeFloats(
          aux_input_ptr_batch, batch_size, aux_input_size,
          aux_quantized_input_ptr_batch, scaling_factors, zero_points,
          asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= aux_input_weights_scale;
      }

      // Output += aux_input * aux_input_weights
      for (int k = 0; k < batch_size; k++) {
        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            aux_input_weights_ptr, num_units, aux_input_size,
            aux_quantized_input_ptr_batch + k * aux_input_size,
            &scaling_factors[k],
            /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
            /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
            aux_input_row_sums, compute_row_sums, /*context=*/nullptr);
      }
    }

    // Save quantization and matmul computation for all zero input.
    if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch,
                                    batch_size * num_units)) {
      // Quantize hidden_state
      tensor_utils::BatchQuantizeFloats(
          hidden_state_ptr_batch, batch_size, num_units,
          quantized_hidden_state_ptr_batch, scaling_factors, zero_points,
          asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= recurrent_weights_scale;
      }

      // Output += recurrent_weights * hidden_state
      for (int k = 0; k < batch_size; k++) {
        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_weights_ptr, num_units, num_units,
            quantized_hidden_state_ptr_batch + k * num_units,
            &scaling_factors[k], /*n_batch=*/1,
            output_ptr_batch + k * output_batch_leading_dim,
            /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
            recurrent_row_sums, compute_row_sums, /*context=*/nullptr);
      }
    }

    // Output = activation(Output) and update hidden_state
    for (int k = 0; k < batch_size; k++) {
      tensor_utils::ApplyActivationToVector(
          output_ptr_batch + k * output_batch_leading_dim, num_units,
          activation, output_ptr_batch + k * output_batch_leading_dim);
      std::copy_n(output_ptr_batch + k * output_batch_leading_dim, num_units,
                  hidden_state_ptr_batch + k * num_units);
    }
  }
}

}  // namespace kernel_utils
}  // namespace tflite