xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/reference/lstm_cell.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <cstdint>
21 
22 #include "tensorflow/lite/kernels/internal/common.h"
23 #include "tensorflow/lite/kernels/internal/reference/concatenation.h"
24 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
25 #include "tensorflow/lite/kernels/internal/types.h"
26 
27 namespace tflite {
28 namespace reference_ops {
29 
LstmCell(const LstmCellParams & params,const RuntimeShape & unextended_input_shape,const float * input_data,const RuntimeShape & unextended_prev_activ_shape,const float * prev_activ_data,const RuntimeShape & weights_shape,const float * weights_data,const RuntimeShape & unextended_bias_shape,const float * bias_data,const RuntimeShape & unextended_prev_state_shape,const float * prev_state_data,const RuntimeShape & unextended_output_state_shape,float * output_state_data,const RuntimeShape & unextended_output_activ_shape,float * output_activ_data,const RuntimeShape & unextended_concat_temp_shape,float * concat_temp_data,const RuntimeShape & unextended_activ_temp_shape,float * activ_temp_data)30 inline void LstmCell(
31     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
32     const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
33     const float* prev_activ_data, const RuntimeShape& weights_shape,
34     const float* weights_data, const RuntimeShape& unextended_bias_shape,
35     const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
36     const float* prev_state_data,
37     const RuntimeShape& unextended_output_state_shape, float* output_state_data,
38     const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
39     const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
40     const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
41   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
42   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
43   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
44   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
45   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
46   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
47   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
48   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
49   const RuntimeShape input_shape =
50       RuntimeShape::ExtendedShape(4, unextended_input_shape);
51   const RuntimeShape prev_activ_shape =
52       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
53   const RuntimeShape bias_shape =
54       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
55   const RuntimeShape prev_state_shape =
56       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
57   const RuntimeShape output_state_shape =
58       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
59   const RuntimeShape output_activ_shape =
60       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
61   const RuntimeShape concat_temp_shape =
62       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
63   const RuntimeShape activ_temp_shape =
64       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
65   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
66 
67   const int weights_dim_count = weights_shape.DimensionsCount();
68   const int batches =
69       MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
70                   output_state_shape, 0, output_activ_shape, 0);
71   const int height =
72       MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
73                   output_state_shape, 1, output_activ_shape, 1);
74   const int width =
75       MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
76                   output_state_shape, 2, output_activ_shape, 2);
77   const int input_depth = input_shape.Dims(3);
78   const int prev_activ_depth = prev_activ_shape.Dims(3);
79   const int total_input_depth = prev_activ_depth + input_depth;
80   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
81                    total_input_depth);
82   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
83   const int intern_activ_depth =
84       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
85   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
86                    intern_activ_depth * total_input_depth);
87   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
88   const int output_depth =
89       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
90                   3, output_activ_shape, 3);
91   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
92 
93   // Concatenate prev_activ and input data together
94   float const* concat_input_arrays_data[2] = {input_data, prev_activ_data};
95   const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
96                                                        &prev_activ_shape};
97   tflite::ConcatenationParams concat_params;
98   concat_params.axis = 3;
99   concat_params.inputs_count = 2;
100   Concatenation(concat_params, concat_input_arrays_shapes,
101                 concat_input_arrays_data, concat_temp_shape, concat_temp_data);
102 
103   // Fully connected
104   tflite::FullyConnectedParams fc_params;
105   fc_params.float_activation_min = std::numeric_limits<float>::lowest();
106   fc_params.float_activation_max = std::numeric_limits<float>::max();
107   FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
108                  weights_data, bias_shape, bias_data, activ_temp_shape,
109                  activ_temp_data);
110 
111   // Memory state update (the LSTM "guts")
112   for (int b = 0; b < batches; ++b) {
113     for (int w = 0; w < width; ++w) {
114       for (int h = 0; h < height; ++h) {
115         for (int c = 0; c < output_depth; ++c) {
116           const float input_gate =
117               1.f /
118               (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
119                                                       0 * output_depth + c)]));
120           const float new_input = std::tanh(activ_temp_data[Offset(
121               activ_temp_shape, b, h, w, 1 * output_depth + c)]);
122           const float forget_gate =
123               1.f /
124               (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
125                                                       2 * output_depth + c)]));
126           const float output_gate =
127               1.f /
128               (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
129                                                       3 * output_depth + c)]));
130           const float new_state =
131               input_gate * new_input +
132               forget_gate *
133                   prev_state_data[Offset(prev_state_shape, b, h, w, c)];
134           output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
135           output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
136               output_gate * std::tanh(new_state);
137         }
138       }
139     }
140   }
141 }
142 
143 // Quantized LSTM cell implementation.
144 // The quantization of the input, output arrays is as follows:
145 //  - The input activations are quantized as uint8 on the interval
146 //    [-1, 127/128].
147 //    The rationale for that is that it is the natural interval for output
148 //    activations (see next point) and these need to be concatenated together.
149 //    We could accommodate different ranges by re-scaling, but we empirically
150 //    found that setting the input activations range to be [-1, 127/128] in the
151 //    first place, removing the need for re-scaling, greatly improves accuracy.
152 //  - The output activations are quantized as uint8 on the interval
153 //    [-1, 127/128].
154 //    The rationale for that is that the definition of a LSTM cell makes them
155 //    intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128]
156 //    makes for simpler, more accurate fixed-point arithmetic.
157 //  - The output-at-previous-timestep state array is obviously quantized as
158 //    the output activations.
159 //  - The internal LSTM memory (not the output-at-previous-timestep, the other
160 //    internal state array) is int16-quantized and may use any power-of-two,
161 //    symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
162 //    StateIntegerBits below, see the below discussion of that template
163 //    parameter ("The StateIntegerBits template parameter").
164 //  - The output of the internal fully-connected node is int16-quantized
165 //    on the interval [-8, 8 * 32767/32768], the rationale for which is
166 //    explained just below ("Why [-8, 8] for fully-connected output?").
167 //
168 //
169 // === The StateIntegerBits template parameter ===
170 //
171 // The StateIntegerBits template parameter controls the fixed-point format used
172 // to represent the internal memory of the LSTM cell (not the
173 // output-at-previous-timestep, the other internal state array). It's currently
174 // a template parameter so that the model can control that. The most typical
175 // value for StateIntegerBits is 4. Other plausible values are anywhere between
176 // 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
177 // and drop that template parameter. The reason why it can't be a runtime
178 // parameter is that this controls the fixed-point format used, i.e. we need to
179 // generate actually different code based on it. In particular, we generate code
180 // for a fixed-point tanh() implementation for that format, which internally
181 // uses a fixed-point exp() implementation, which internally uses a
182 // barrel-shifter with a number of steps that depends on StateIntegerBits.
183 // Another consequence of that is that a higher value of StateIntegerBits
184 // results in a more expensive implementation (more barrel shifter steps
185 // needed).
186 //
187 //
188 // === Why [-8, 8] for fully-connected output? ===
189 //
190 // This array is only fed to Logistic and Tanh functions, for which
191 // the quantized implementation will want to use fixed-point arithmetic,
192 // requiring a power-of-two representation interval. Thus, we should right
193 // away quantize this array to a power-of-two interval; otherwise,
194 // implementation will need to rescale that, losing any benefit that a tighter
195 // representation interval might otherwise yield, while introducing some
196 // numerical error and computational overhead.
197 //
198 // Now, Logistic and Tanh
199 // are nearly constant (nearly equal to their horizontal asymptotes)
200 // outside of a small bounded interval around 0:
201 //
202 //   Logistic(4) = 1 - 1.8e-2     Tanh(4) = 1 - 6.7e-4
203 //   Logistic(8) = 1 - 3.4e-4     Tanh(8) = 1 - 2.3e-7
204 //   Logistic(16) = 1 - 1.1e-7    Tanh(16) = 1 - 2.5e-14
205 //
206 // From this, we see that clamping to [-4, 4] would be too inaccurate
207 // (the error of 1.8e-2 on Logistic would be felt even in 8bit precision)
208 // while clamping to [-16, 16] would make no difference even in float32.
209 // However, for a fixed-point implementation in 16-bit integers, using 5
210 // integer bits to represent the [-16, 16] range would leave only 11
211 // fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
212 //    representable values. Notice that this is higher than the
213 // worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
214 // Using [-8, 8] thus seems like the better compromise overall, enjoying
215 // an increment of 2.4e-4 between representable values and a worst-case
216 // clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
217 // [-16, 16].
218 //
219 // Moreover, all other things being equal, it is nice to choose the narrower
220 // representation range, as that makes the implementation of fixed-point
221 // math functions a little cheaper (each integer bit requires an additional
222 // barrel-shifter step in the implementation of exp(-x)). That is further
223 // reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
224 // sense for 32-bit float or 32-bit fixed-point quantization, but we are
225 // aiming for 16-bit fixed-point quantization of these internal nodes here.
226 //
template <int StateIntegerBits>
inline void LstmCell(const LstmCellParams& params,
                     const RuntimeShape& unextended_input_shape,
                     const uint8_t* input_data_uint8,
                     const RuntimeShape& unextended_prev_activ_shape,
                     const uint8_t* prev_activ_data_uint8,
                     const RuntimeShape& weights_shape,
                     const uint8_t* weights_data_uint8,
                     const RuntimeShape& unextended_bias_shape,
                     const int32_t* bias_data_int32,
                     const RuntimeShape& unextended_prev_state_shape,
                     const int16_t* prev_state_data_int16,
                     const RuntimeShape& unextended_output_state_shape,
                     int16_t* output_state_data_int16,
                     const RuntimeShape& unextended_output_activ_shape,
                     uint8_t* output_activ_data_uint8,
                     const RuntimeShape& unextended_concat_temp_shape,
                     uint8_t* concat_temp_data_uint8,
                     const RuntimeShape& unextended_activ_temp_shape,
                     int16_t* activ_temp_data_int16, void* gemmlowp_context) {
  (void)gemmlowp_context;  // only used in optimized code.
  // Quantization parameters pre-computed offline: the weights' zero point,
  // and the multiplier/shift pair that rescales the int32 fully-connected
  // accumulator into the 16-bit, 3-integer-bit activation format described
  // in the function comment above.
  int32_t weights_zero_point = params.weights_zero_point;
  int32_t accum_multiplier = params.accum_multiplier;
  int accum_shift = params.accum_shift;
  // All shapes are handled as rank-4; lower ranks are left-padded with
  // 1-dims, higher ranks are rejected.
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape prev_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
  const RuntimeShape bias_shape =
      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
  const RuntimeShape prev_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
  const RuntimeShape output_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
  const RuntimeShape output_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
  const RuntimeShape concat_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
  const RuntimeShape activ_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);

  // Gather dimensions information, and perform consistency checks.
  const int weights_dim_count = weights_shape.DimensionsCount();
  const int outer_size = MatchingFlatSizeSkipDim(
      input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
      output_activ_shape);
  const int input_depth = input_shape.Dims(3);
  const int prev_activ_depth = prev_activ_shape.Dims(3);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
                   total_input_depth);
  const int intern_activ_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
                   intern_activ_depth * total_input_depth);
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
  // The fully-connected output packs the four gates along depth, so its
  // depth must be 4x the per-gate (state/activation) depth.
  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
                  3, output_activ_shape, 3);
  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
  const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
  const int fc_output_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
  const int fc_accum_depth = total_input_depth;
  TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);

  // Depth-concatenate prev_activ and input data together.
  uint8_t const* concat_input_arrays_data[2] = {input_data_uint8,
                                                prev_activ_data_uint8};
  const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
                                                       &prev_activ_shape};
  tflite::ConcatenationParams concat_params;
  concat_params.axis = 3;
  concat_params.inputs_count = 2;
  Concatenation(concat_params, concat_input_arrays_shapes,
                concat_input_arrays_data, concat_temp_shape,
                concat_temp_data_uint8);

  // Implementation of the fully connected node inside the LSTM cell.
  // The operands are 8-bit integers, the accumulators are internally 32bit
  // integers, and the output is 16-bit fixed-point with 3 integer bits so
  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
  // is explained in the function comment above.
  for (int b = 0; b < fc_batches; ++b) {
    for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
      // Internal accumulation.
      // Initialize accumulator with the bias-value.
      int32_t accum = bias_data_int32[out_c];
      // Accumulation loop.
      for (int d = 0; d < fc_accum_depth; ++d) {
        // Inputs are uint8 quantized on [-1, 127/128] (see function comment
        // above), so subtracting the fixed zero point 128 recenters them to
        // signed values; weights use their own zero point from params.
        int16_t input_val =
            concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
        int16_t weights_val =
            weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
        accum += input_val * weights_val;
      }
      // Down-scale the final int32 accumulator to the scale used by our
      // (16-bit, using 3 integer bits) fixed-point format. The quantized
      // multiplier and shift here have been pre-computed offline
      // (e.g. by toco).
      accum =
          MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
      // Saturate, cast to int16, and store to the temporary activations array.
      accum = std::max(-32768, std::min(32767, accum));
      activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
    }
  }

  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
  // and muls, all done in 16-bit fixed-point.
  for (int b = 0; b < outer_size; ++b) {
    for (int c = 0; c < output_depth; ++c) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See comment
      // on the StateIntegerBits template parameter above.
      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
      F0 input_modulation_gate_output =
          gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation =
          input_gate_output * input_modulation_gate_output;
      FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
      FS prev_state_times_forget_state = forget_gate_output * prev_state;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prev_state_times_forget_state);
      // Implementation of last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number or integer bits, and each specialization can have a substantial
      // code size, and we already used above a Tanh on an input with 3 integer
      // bits, and per the table in the above function comment there is no
      // significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      output_state_data_int16[b * output_depth + c] = new_state.raw();
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory. RoundingDivideByPOT(x, 8) divides the
      // raw F0 (Q0.15) value by 2^8 with round-to-nearest, mapping
      // [-32768, 32767] onto [-128, 127]; adding 128 then produces the
      // uint8 encoding of the [-1, 127/128] output range.
      int16_t rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int16_t clamped_output_activ = std::max<int16_t>(
          -128, std::min<int16_t>(127, rescaled_output_activ));
      output_activ_data_uint8[b * output_depth + c] =
          128 + clamped_output_activ;
    }
  }
}
419 
420 }  // namespace reference_ops
421 }  // namespace tflite
422 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
423