/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/kernel_utils.h"

#include <algorithm>

#include "tensorflow/lite/kernels/internal/tensor_utils.h"

namespace tflite {
namespace kernel_utils {

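// RnnBatchStep computes one step of a fully connected (fused) RNN cell over a
// batch of inputs:
//
//   output = activation(input * input_weights + aux_input * aux_input_weights
//                       + hidden_state * recurrent_weights + bias)
//
// with the aux_input term only present in the overloads that take one. The
// activated output is also copied back into hidden_state. The float overloads
// operate directly on float weights; the int8_t overloads implement the hybrid
// path, quantizing the inputs and hidden state per batch row before the matrix
// multiplications.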
void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
                  const float* recurrent_weights_ptr, const float* bias_ptr,
                  int input_size, int num_units, int batch_size,
                  int output_batch_leading_dim,
                  TfLiteFusedActivation activation,
                  float* hidden_state_ptr_batch, float* output_ptr_batch) {
  RnnBatchStep(input_ptr_batch, input_weights_ptr,
               /*aux_input_ptr_batch=*/nullptr,
               /*aux_input_weights_ptr=*/nullptr, recurrent_weights_ptr,
               bias_ptr, input_size, /*aux_input_size=*/0, num_units,
               batch_size, output_batch_leading_dim, activation,
               hidden_state_ptr_batch, output_ptr_batch);
}
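
// Illustrative call of the overload above (a sketch, not part of this
// translation unit; the buffer names and sizes are placeholders chosen to
// match the layout the function expects):
//
//   const int batch_size = 2, input_size = 3, num_units = 4;
//   std::vector<float> input(batch_size * input_size);           // [batch, input]
//   std::vector<float> input_weights(num_units * input_size);    // [unit, input]
//   std::vector<float> recurrent_weights(num_units * num_units); // [unit, unit]
//   std::vector<float> bias(num_units);
//   std::vector<float> hidden_state(batch_size * num_units, 0.0f);
//   std::vector<float> output(batch_size * num_units);
//   kernel_utils::RnnBatchStep(input.data(), input_weights.data(),
//                              recurrent_weights.data(), bias.data(),
//                              input_size, num_units, batch_size,
//                              /*output_batch_leading_dim=*/num_units,
//                              kTfLiteActTanh, hidden_state.data(),
//                              output.data());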

void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
                  const float* aux_input_ptr_batch,
                  const float* aux_input_weights_ptr,
                  const float* recurrent_weights_ptr, const float* bias_ptr,
                  int input_size, int aux_input_size, int num_units,
                  int batch_size, int output_batch_leading_dim,
                  TfLiteFusedActivation activation,
                  float* hidden_state_ptr_batch, float* output_ptr_batch) {
  // The output batch rows may not be contiguous (output_batch_leading_dim !=
  // num_units), so the batched operations are unrolled in that case.
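  // For example, with num_units = 4 and output_batch_leading_dim = 8, batch
  // row k writes output_ptr_batch[k * 8] .. output_ptr_batch[k * 8 + 3] and
  // leaves the rest of that row untouched.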
  if (output_batch_leading_dim == num_units) {
    // Output = bias
    tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
                                          output_ptr_batch);

    // Output += input * input_weights
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size,
        output_ptr_batch);

    // Output += aux_input * aux_input_weights (if they are not empty).
    if (aux_input_size > 0) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          aux_input_weights_ptr, num_units, aux_input_size, aux_input_ptr_batch,
          batch_size, output_ptr_batch);
    }

    // Output += recurrent_weights * hidden_state
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch,
        batch_size, output_ptr_batch);

    // Output = activation(Output) and update hidden_state
    tensor_utils::ApplyActivationToVector(
        output_ptr_batch, num_units * batch_size, activation, output_ptr_batch);
    std::copy_n(output_ptr_batch, num_units * batch_size,
                hidden_state_ptr_batch);
  } else {
    // Output = bias
    for (int k = 0; k < batch_size; k++) {
      std::copy_n(bias_ptr, num_units,
                  output_ptr_batch + k * output_batch_leading_dim);
    }

    // Output += input * input_weights
    for (int k = 0; k < batch_size; k++) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_weights_ptr, num_units, input_size,
          input_ptr_batch + k * input_size, /*n_batch=*/1,
          output_ptr_batch + k * output_batch_leading_dim);
    }

    // Output += aux_input * aux_input_weights (if they are not empty).
    if (aux_input_size > 0) {
      for (int k = 0; k < batch_size; k++) {
        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            aux_input_weights_ptr, num_units, aux_input_size,
            aux_input_ptr_batch + k * aux_input_size,
            /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
      }
    }

    // Output += recurrent_weights * hidden_state
    for (int k = 0; k < batch_size; k++) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_weights_ptr, num_units, num_units,
          hidden_state_ptr_batch + k * num_units,
          /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim);
    }

    // Output = activation(Output) and update hidden_state
    for (int k = 0; k < batch_size; k++) {
      tensor_utils::ApplyActivationToVector(
          output_ptr_batch + k * output_batch_leading_dim, num_units,
          activation, output_ptr_batch + k * output_batch_leading_dim);
      std::copy_n(output_ptr_batch + k * output_batch_leading_dim, num_units,
                  hidden_state_ptr_batch + k * num_units);
    }
  }
}

void RnnBatchStep(
    const float* input_ptr_batch, const int8_t* input_weights_ptr,
    float input_weights_scale, const int8_t* recurrent_weights_ptr,
    float recurrent_weights_scale, const float* bias_ptr, int input_size,
    int num_units, int batch_size, int output_batch_leading_dim,
    TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
    int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
    float* hidden_state_ptr_batch, float* output_ptr_batch,
    bool asymmetric_quantize_inputs, int32_t* zero_points,
    int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
  RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
               /*aux_input_ptr_batch=*/nullptr,
               /*aux_input_weights_ptr=*/nullptr,
               /*aux_input_weights_scale=*/0.0f, recurrent_weights_ptr,
               recurrent_weights_scale, bias_ptr, input_size,
               /*aux_input_size=*/0, num_units, batch_size,
               output_batch_leading_dim, activation, quantized_input_ptr_batch,
               /*aux_quantized_input_ptr_batch=*/nullptr,
               quantized_hidden_state_ptr_batch, scaling_factors,
               hidden_state_ptr_batch, output_ptr_batch,
               asymmetric_quantize_inputs, zero_points, accum_scratch, row_sums,
               compute_row_sums);
}
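
// The hybrid overloads take the same logical inputs as the float versions,
// plus per-tensor weight scales and caller-provided scratch buffers. Sizes
// inferred from the calls below:
//   quantized_input_ptr_batch        - batch_size * input_size int8 values
//   quantized_hidden_state_ptr_batch - batch_size * num_units int8 values
//   scaling_factors, zero_points     - one entry per batch row
//   row_sums                         - num_units entries per participating
//                                      weight matrix (asymmetric case only)
//   accum_scratch                    - integer accumulator scratch forwarded
//                                      to MatrixBatchVectorMultiplyAccumulate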

void RnnBatchStep(
    const float* input_ptr_batch, const int8_t* input_weights_ptr,
    float input_weights_scale, const float* aux_input_ptr_batch,
    const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
    const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
    const float* bias_ptr, int input_size, int aux_input_size, int num_units,
    int batch_size, int output_batch_leading_dim,
    TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch,
    int8_t* aux_quantized_input_ptr_batch,
    int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
    float* hidden_state_ptr_batch, float* output_ptr_batch,
    bool asymmetric_quantize_inputs, int32_t* zero_points,
    int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) {
  // The output batch rows may not be contiguous (output_batch_leading_dim !=
  // num_units), so the batched operations are unrolled in that case.

  int32_t* input_row_sums = nullptr;
  int32_t* aux_input_row_sums = nullptr;
  int32_t* recurrent_row_sums = nullptr;
  if (asymmetric_quantize_inputs) {
    input_row_sums = row_sums;
    aux_input_row_sums = row_sums;
    if (aux_input_ptr_batch) {
      aux_input_row_sums += num_units;
    }
    recurrent_row_sums = aux_input_row_sums + num_units;
    if (*compute_row_sums) {
      tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums,
                                       num_units, input_size);
      if (aux_input_ptr_batch) {
        tensor_utils::ReductionSumVector(aux_input_weights_ptr,
                                         aux_input_row_sums, num_units,
                                         aux_input_size);
      }
      tensor_utils::ReductionSumVector(
          recurrent_weights_ptr, recurrent_row_sums, num_units, num_units);
      *compute_row_sums = false;
    }
  }
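  // When asymmetric_quantize_inputs is set, row_sums is laid out as
  // [input | aux_input | recurrent] blocks of num_units entries each if aux
  // input is present, or [input | recurrent] if it is not; otherwise the
  // three pointers above stay null.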

  if (output_batch_leading_dim == num_units) {
    // Output = bias
    tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
                                          output_ptr_batch);

    // Skip quantization and the matmul when the input is all zeros.
    if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) {
      // Quantize input from float to int8 and compute quantization params
      // (scaling factors).
      tensor_utils::BatchQuantizeFloats(
          input_ptr_batch, batch_size, input_size, quantized_input_ptr_batch,
          scaling_factors, zero_points, asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= input_weights_scale;
      }
      // Output += input * input_weights
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_weights_ptr, num_units, input_size, quantized_input_ptr_batch,
          scaling_factors, batch_size, output_ptr_batch,
          /*per_channel_scale=*/nullptr, zero_points, accum_scratch,
          input_row_sums, compute_row_sums, /*context=*/nullptr);
    }
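
    // Note: input_weights_scale is folded into scaling_factors[b] above so
    // that the accumulate call can rescale each int32 dot product back to
    // float with a single per-row factor (input scale * weight scale). The
    // aux-input and recurrent blocks below follow the same pattern with their
    // own scales.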

    if (aux_input_ptr_batch &&
        !tensor_utils::IsZeroVector(aux_input_ptr_batch,
                                    batch_size * aux_input_size)) {
      tensor_utils::BatchQuantizeFloats(
          aux_input_ptr_batch, batch_size, aux_input_size,
          aux_quantized_input_ptr_batch, scaling_factors, zero_points,
          asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= aux_input_weights_scale;
      }

      // Output += aux_input * aux_input_weights
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          aux_input_weights_ptr, num_units, aux_input_size,
          aux_quantized_input_ptr_batch, scaling_factors, batch_size,
          output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points,
          accum_scratch, aux_input_row_sums, compute_row_sums,
          /*context=*/nullptr);
    }

    // Skip quantization and the matmul when the hidden state is all zeros.
    if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch,
                                    batch_size * num_units)) {
      // Quantize hidden_state
      tensor_utils::BatchQuantizeFloats(
          hidden_state_ptr_batch, batch_size, num_units,
          quantized_hidden_state_ptr_batch, scaling_factors, zero_points,
          asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= recurrent_weights_scale;
      }

      // Output += recurrent_weights * hidden_state
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_weights_ptr, num_units, num_units,
          quantized_hidden_state_ptr_batch, scaling_factors, batch_size,
          output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points,
          accum_scratch, recurrent_row_sums, compute_row_sums,
          /*context=*/nullptr);
    }

    // Output = activation(Output) and update hidden_state
    tensor_utils::ApplyActivationToVector(
        output_ptr_batch, num_units * batch_size, activation, output_ptr_batch);
    std::copy_n(output_ptr_batch, num_units * batch_size,
                hidden_state_ptr_batch);
  } else {
    // Output = bias
    for (int k = 0; k < batch_size; k++) {
      std::copy_n(bias_ptr, num_units,
                  output_ptr_batch + k * output_batch_leading_dim);
    }

    // Skip quantization and the matmul when the input is all zeros.
    if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) {
      // Quantize input from float to int8 and compute quantization params
      // (scaling factors).
      tensor_utils::BatchQuantizeFloats(
          input_ptr_batch, batch_size, input_size, quantized_input_ptr_batch,
          scaling_factors, zero_points, asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= input_weights_scale;
      }

      // Output += input * input_weights
      for (int k = 0; k < batch_size; k++) {
        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_weights_ptr, num_units, input_size,
            quantized_input_ptr_batch + k * input_size, &scaling_factors[k],
            /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
            /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
            input_row_sums, compute_row_sums, /*context=*/nullptr);
      }
    }

    if (aux_input_ptr_batch &&
        !tensor_utils::IsZeroVector(aux_input_ptr_batch,
                                    batch_size * aux_input_size)) {
      tensor_utils::BatchQuantizeFloats(
          aux_input_ptr_batch, batch_size, aux_input_size,
          aux_quantized_input_ptr_batch, scaling_factors, zero_points,
          asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= aux_input_weights_scale;
      }

      // Output += aux_input * aux_input_weights
      for (int k = 0; k < batch_size; k++) {
        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            aux_input_weights_ptr, num_units, aux_input_size,
            aux_quantized_input_ptr_batch + k * aux_input_size,
            &scaling_factors[k],
            /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
            /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
            aux_input_row_sums, compute_row_sums, /*context=*/nullptr);
      }
    }

    // Skip quantization and the matmul when the hidden state is all zeros.
    if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch,
                                    batch_size * num_units)) {
      // Quantize hidden_state
      tensor_utils::BatchQuantizeFloats(
          hidden_state_ptr_batch, batch_size, num_units,
          quantized_hidden_state_ptr_batch, scaling_factors, zero_points,
          asymmetric_quantize_inputs);
      for (int b = 0; b < batch_size; ++b) {
        scaling_factors[b] *= recurrent_weights_scale;
      }

      // Output += recurrent_weights * hidden_state
      for (int k = 0; k < batch_size; k++) {
        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_weights_ptr, num_units, num_units,
            quantized_hidden_state_ptr_batch + k * num_units,
            &scaling_factors[k], /*n_batch=*/1,
            output_ptr_batch + k * output_batch_leading_dim,
            /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch,
            recurrent_row_sums, compute_row_sums, /*context=*/nullptr);
      }
    }

    // Output = activation(Output) and update hidden_state
    for (int k = 0; k < batch_size; k++) {
      tensor_utils::ApplyActivationToVector(
          output_ptr_batch + k * output_batch_leading_dim, num_units,
          activation, output_ptr_batch + k * output_batch_leading_dim);
      std::copy_n(output_ptr_batch + k * output_batch_leading_dim, num_units,
                  hidden_state_ptr_batch + k * num_units);
    }
  }
}
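
// Illustrative call of the hybrid convenience overload above (a sketch only;
// the buffer names are placeholders, and the accum_scratch sizing is an
// assumption rather than a documented requirement of the backend kernels):
//
//   const int batch_size = 2, input_size = 3, num_units = 4;
//   std::vector<float> input(batch_size * input_size);
//   std::vector<int8_t> input_weights(num_units * input_size);
//   std::vector<int8_t> recurrent_weights(num_units * num_units);
//   std::vector<float> bias(num_units);
//   std::vector<float> hidden_state(batch_size * num_units, 0.0f);
//   std::vector<float> output(batch_size * num_units);
//   std::vector<int8_t> quantized_input(batch_size * input_size);
//   std::vector<int8_t> quantized_hidden_state(batch_size * num_units);
//   std::vector<float> scaling_factors(batch_size);
//   std::vector<int32_t> zero_points(batch_size);
//   std::vector<int32_t> accum_scratch(num_units * batch_size);
//   std::vector<int32_t> row_sums(2 * num_units);  // input + recurrent blocks
//   bool compute_row_sums = true;
//   kernel_utils::RnnBatchStep(
//       input.data(), input_weights.data(), /*input_weights_scale=*/0.05f,
//       recurrent_weights.data(), /*recurrent_weights_scale=*/0.05f,
//       bias.data(), input_size, num_units, batch_size,
//       /*output_batch_leading_dim=*/num_units, kTfLiteActTanh,
//       quantized_input.data(), quantized_hidden_state.data(),
//       scaling_factors.data(), hidden_state.data(), output.data(),
//       /*asymmetric_quantize_inputs=*/true, zero_points.data(),
//       accum_scratch.data(), row_sums.data(), &compute_row_sums);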

}  // namespace kernel_utils
}  // namespace tflite