1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
17
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/types.h"
26
27 namespace tflite {
28 namespace reference_ops {
29
LstmCell(const LstmCellParams & params,const RuntimeShape & unextended_input_shape,const float * input_data,const RuntimeShape & unextended_prev_activ_shape,const float * prev_activ_data,const RuntimeShape & weights_shape,const float * weights_data,const RuntimeShape & unextended_bias_shape,const float * bias_data,const RuntimeShape & unextended_prev_state_shape,const float * prev_state_data,const RuntimeShape & unextended_output_state_shape,float * output_state_data,const RuntimeShape & unextended_output_activ_shape,float * output_activ_data,const RuntimeShape & unextended_concat_temp_shape,float * concat_temp_data,const RuntimeShape & unextended_activ_temp_shape,float * activ_temp_data)30 inline void LstmCell(
31 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
32 const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
33 const float* prev_activ_data, const RuntimeShape& weights_shape,
34 const float* weights_data, const RuntimeShape& unextended_bias_shape,
35 const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
36 const float* prev_state_data,
37 const RuntimeShape& unextended_output_state_shape, float* output_state_data,
38 const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
39 const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
40 const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
41 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
42 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
43 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
44 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
45 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
46 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
47 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
48 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
49 const RuntimeShape input_shape =
50 RuntimeShape::ExtendedShape(4, unextended_input_shape);
51 const RuntimeShape prev_activ_shape =
52 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
53 const RuntimeShape bias_shape =
54 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
55 const RuntimeShape prev_state_shape =
56 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
57 const RuntimeShape output_state_shape =
58 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
59 const RuntimeShape output_activ_shape =
60 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
61 const RuntimeShape concat_temp_shape =
62 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
63 const RuntimeShape activ_temp_shape =
64 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
65 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
66
67 const int weights_dim_count = weights_shape.DimensionsCount();
68 const int batches =
69 MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
70 output_state_shape, 0, output_activ_shape, 0);
71 const int height =
72 MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
73 output_state_shape, 1, output_activ_shape, 1);
74 const int width =
75 MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
76 output_state_shape, 2, output_activ_shape, 2);
77 const int input_depth = input_shape.Dims(3);
78 const int prev_activ_depth = prev_activ_shape.Dims(3);
79 const int total_input_depth = prev_activ_depth + input_depth;
80 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
81 total_input_depth);
82 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
83 const int intern_activ_depth =
84 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
85 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
86 intern_activ_depth * total_input_depth);
87 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
88 const int output_depth =
89 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
90 3, output_activ_shape, 3);
91 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
92
93 // Concatenate prev_activ and input data together
94 float const* concat_input_arrays_data[2] = {input_data, prev_activ_data};
95 const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
96 &prev_activ_shape};
97 tflite::ConcatenationParams concat_params;
98 concat_params.axis = 3;
99 concat_params.inputs_count = 2;
100 Concatenation(concat_params, concat_input_arrays_shapes,
101 concat_input_arrays_data, concat_temp_shape, concat_temp_data);
102
103 // Fully connected
104 tflite::FullyConnectedParams fc_params;
105 fc_params.float_activation_min = std::numeric_limits<float>::lowest();
106 fc_params.float_activation_max = std::numeric_limits<float>::max();
107 FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
108 weights_data, bias_shape, bias_data, activ_temp_shape,
109 activ_temp_data);
110
111 // Memory state update (the LSTM "guts")
112 for (int b = 0; b < batches; ++b) {
113 for (int w = 0; w < width; ++w) {
114 for (int h = 0; h < height; ++h) {
115 for (int c = 0; c < output_depth; ++c) {
116 const float input_gate =
117 1.f /
118 (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
119 0 * output_depth + c)]));
120 const float new_input = std::tanh(activ_temp_data[Offset(
121 activ_temp_shape, b, h, w, 1 * output_depth + c)]);
122 const float forget_gate =
123 1.f /
124 (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
125 2 * output_depth + c)]));
126 const float output_gate =
127 1.f /
128 (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
129 3 * output_depth + c)]));
130 const float new_state =
131 input_gate * new_input +
132 forget_gate *
133 prev_state_data[Offset(prev_state_shape, b, h, w, c)];
134 output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
135 output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
136 output_gate * std::tanh(new_state);
137 }
138 }
139 }
140 }
141 }
142
143 // Quantized LSTM cell implementation.
144 // The quantization of the input, output arrays is as follows:
145 // - The input activations are quantized as uint8 on the interval
146 // [-1, 127/128].
147 // The rationale for that is that it is the natural interval for output
148 // activations (see next point) and these need to be concatenated together.
149 // We could accommodate different ranges by re-scaling, but we empirically
150 // found that setting the input activations range to be [-1, 127/128] in the
151 // first place, removing the need for re-scaling, greatly improves accuracy.
152 // - The output activations are quantized as uint8 on the interval
153 // [-1, 127/128].
154 // The rationale for that is that the definition of a LSTM cell makes them
155 // intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128]
156 // makes for simpler, more accurate fixed-point arithmetic.
157 // - The output-at-previous-timestep state array is obviously quantized as
158 // the output activations.
159 // - The internal LSTM memory (not the output-at-previous-timestep, the other
160 // internal state array) is int16-quantized and may use any power-of-two,
161 // symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
162 // StateIntegerBits below, see the below discussion of that template
163 // parameter ("The StateIntegerBits template parameter").
164 // - The output of the internal fully-connected node is int16-quantized
165 // on the interval [-8, 8 * 32767/32768], the rationale for which is
166 // explained just below ("Why [-8, 8] for fully-connected output?").
167 //
168 //
169 // === The StateIntegerBits template parameter ===
170 //
171 // The StateIntegerBits template parameter controls the fixed-point format used
172 // to represent the internal memory of the LSTM cell (not the
173 // output-at-previous-timestep, the other internal state array). It's currently
174 // a template parameter so that the model can control that. The most typical
175 // value for StateIntegerBits is 4. Other plausible values are anywhere between
176 // 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
177 // and drop that template parameter. The reason why it can't be a runtime
178 // parameter is that this controls the fixed-point format used, i.e. we need to
179 // generate actually different code based on it. In particular, we generate code
180 // for a fixed-point tanh() implementation for that format, which internally
181 // uses a fixed-point exp() implementation, which internally uses a
182 // barrel-shifter with a number of steps that depends on StateIntegerBits.
183 // Another consequence of that is that a higher value of StateIntegerBits
184 // results in a more expensive implementation (more barrel shifter steps
185 // needed).
186 //
187 //
188 // === Why [-8, 8] for fully-connected output? ===
189 //
190 // This array is only fed to Logistic and Tanh functions, for which
191 // the quantized implementation will want to use fixed-point arithmetic,
192 // requiring a power-of-two representation interval. Thus, we should right
193 // away quantize this array to a power-of-two interval; otherwise,
194 // implementation will need to rescale that, losing any benefit that a tighter
195 // representation interval might otherwise yield, while introducing some
196 // numerical error and computational overhead.
197 //
198 // Now, Logistic and Tanh
199 // are nearly constant (nearly equal to their horizontal asymptotes)
200 // outside of a small bounded interval around 0:
201 //
202 // Logistic(4) = 1 - 1.8e-2 Tanh(4) = 1 - 6.7e-4
203 // Logistic(8) = 1 - 3.4e-4 Tanh(8) = 1 - 2.3e-7
204 // Logistic(16) = 1 - 1.1e-7 Tanh(16) = 1 - 2.5e-14
205 //
206 // From this, we see that clamping to [-4, 4] would be too inaccurate
207 // (the error of 1.8e-2 on Logistic would be felt even in 8bit precision)
208 // while clamping to [-16, 16] would make no difference even in float32.
209 // However, for a fixed-point implementation in 16-bit integers, using 5
210 // integer bits to represent the [-16, 16] range would leave only 11
211 // fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
212 // representable values. Notice that this is higher than the
213 // worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
214 // Using [-8, 8] thus seems like the better compromise overall, enjoying
215 // an increment of 2.4e-4 between representable values and a worst-case
216 // clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
217 // [-16, 16].
218 //
219 // Moreover, all other things being equal, it is nice to choose the narrower
220 // representation range, as that makes the implementation of fixed-point
221 // math functions a little cheaper (each integer bit requires an additional
222 // barrel-shifter step in the implementation of exp(-x)). That is further
223 // reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
224 // sense for 32-bit float or 32-bit fixed-point quantization, but we are
225 // aiming for 16-bit fixed-point quantization of these internal nodes here.
226 //
227 template <int StateIntegerBits>
LstmCell(const LstmCellParams & params,const RuntimeShape & unextended_input_shape,const uint8_t * input_data_uint8,const RuntimeShape & unextended_prev_activ_shape,const uint8_t * prev_activ_data_uint8,const RuntimeShape & weights_shape,const uint8_t * weights_data_uint8,const RuntimeShape & unextended_bias_shape,const int32_t * bias_data_int32,const RuntimeShape & unextended_prev_state_shape,const int16_t * prev_state_data_int16,const RuntimeShape & unextended_output_state_shape,int16_t * output_state_data_int16,const RuntimeShape & unextended_output_activ_shape,uint8_t * output_activ_data_uint8,const RuntimeShape & unextended_concat_temp_shape,uint8_t * concat_temp_data_uint8,const RuntimeShape & unextended_activ_temp_shape,int16_t * activ_temp_data_int16,void * gemmlowp_context)228 inline void LstmCell(const LstmCellParams& params,
229 const RuntimeShape& unextended_input_shape,
230 const uint8_t* input_data_uint8,
231 const RuntimeShape& unextended_prev_activ_shape,
232 const uint8_t* prev_activ_data_uint8,
233 const RuntimeShape& weights_shape,
234 const uint8_t* weights_data_uint8,
235 const RuntimeShape& unextended_bias_shape,
236 const int32_t* bias_data_int32,
237 const RuntimeShape& unextended_prev_state_shape,
238 const int16_t* prev_state_data_int16,
239 const RuntimeShape& unextended_output_state_shape,
240 int16_t* output_state_data_int16,
241 const RuntimeShape& unextended_output_activ_shape,
242 uint8_t* output_activ_data_uint8,
243 const RuntimeShape& unextended_concat_temp_shape,
244 uint8_t* concat_temp_data_uint8,
245 const RuntimeShape& unextended_activ_temp_shape,
246 int16_t* activ_temp_data_int16, void* gemmlowp_context) {
247 (void)gemmlowp_context; // only used in optimized code.
248 int32_t weights_zero_point = params.weights_zero_point;
249 int32_t accum_multiplier = params.accum_multiplier;
250 int accum_shift = params.accum_shift;
251 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
252 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
253 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
254 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
255 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
256 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
257 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
258 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
259 const RuntimeShape input_shape =
260 RuntimeShape::ExtendedShape(4, unextended_input_shape);
261 const RuntimeShape prev_activ_shape =
262 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
263 const RuntimeShape bias_shape =
264 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
265 const RuntimeShape prev_state_shape =
266 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
267 const RuntimeShape output_state_shape =
268 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
269 const RuntimeShape output_activ_shape =
270 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
271 const RuntimeShape concat_temp_shape =
272 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
273 const RuntimeShape activ_temp_shape =
274 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
275 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
276
277 // Gather dimensions information, and perform consistency checks.
278 const int weights_dim_count = weights_shape.DimensionsCount();
279 const int outer_size = MatchingFlatSizeSkipDim(
280 input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
281 output_activ_shape);
282 const int input_depth = input_shape.Dims(3);
283 const int prev_activ_depth = prev_activ_shape.Dims(3);
284 const int total_input_depth = prev_activ_depth + input_depth;
285 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
286 total_input_depth);
287 const int intern_activ_depth =
288 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
289 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
290 intern_activ_depth * total_input_depth);
291 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
292 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
293 const int output_depth =
294 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
295 3, output_activ_shape, 3);
296 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
297 const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
298 const int fc_output_depth =
299 MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
300 const int fc_accum_depth = total_input_depth;
301 TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
302
303 // Depth-concatenate prev_activ and input data together.
304 uint8_t const* concat_input_arrays_data[2] = {input_data_uint8,
305 prev_activ_data_uint8};
306 const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
307 &prev_activ_shape};
308 tflite::ConcatenationParams concat_params;
309 concat_params.axis = 3;
310 concat_params.inputs_count = 2;
311 Concatenation(concat_params, concat_input_arrays_shapes,
312 concat_input_arrays_data, concat_temp_shape,
313 concat_temp_data_uint8);
314
315 // Implementation of the fully connected node inside the LSTM cell.
316 // The operands are 8-bit integers, the accumulators are internally 32bit
317 // integers, and the output is 16-bit fixed-point with 3 integer bits so
318 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
319 // is explained in the function comment above.
320 for (int b = 0; b < fc_batches; ++b) {
321 for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
322 // Internal accumulation.
323 // Initialize accumulator with the bias-value.
324 int32_t accum = bias_data_int32[out_c];
325 // Accumulation loop.
326 for (int d = 0; d < fc_accum_depth; ++d) {
327 int16_t input_val =
328 concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
329 int16_t weights_val =
330 weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
331 accum += input_val * weights_val;
332 }
333 // Down-scale the final int32 accumulator to the scale used by our
334 // (16-bit, using 3 integer bits) fixed-point format. The quantized
335 // multiplier and shift here have been pre-computed offline
336 // (e.g. by toco).
337 accum =
338 MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
339 // Saturate, cast to int16, and store to the temporary activations array.
340 accum = std::max(-32768, std::min(32767, accum));
341 activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
342 }
343 }
344
345 // Rest of the LSTM cell: tanh and logistic math functions, and some adds
346 // and muls, all done in 16-bit fixed-point.
347 for (int b = 0; b < outer_size; ++b) {
348 for (int c = 0; c < output_depth; ++c) {
349 // Define the fixed-point data types that we will use here. All use
350 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
351 // They only differ by the number of integral vs. fractional bits,
352 // determining the range of values that they can represent.
353 //
354 // F0 uses 0 integer bits, range [-1, 1].
355 // This is the return type of math functions such as tanh, logistic,
356 // whose range is in [-1, 1].
357 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
358 // F3 uses 3 integer bits, range [-8, 8].
359 // This is the range of the previous fully-connected node's output,
360 // which is our input here.
361 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
362 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
363 // 2^StateIntegerBits]. It's used to represent the internal state, whose
364 // number of integer bits is currently dictated by the model. See comment
365 // on the StateIntegerBits template parameter above.
366 using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
367 // Implementation of input gate, using fixed-point logistic function.
368 F3 input_gate_input = F3::FromRaw(
369 activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
370 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
371 // Implementation of input modulation gate, using fixed-point tanh
372 // function.
373 F3 input_modulation_gate_input = F3::FromRaw(
374 activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
375 F0 input_modulation_gate_output =
376 gemmlowp::tanh(input_modulation_gate_input);
377 // Implementation of forget gate, using fixed-point logistic function.
378 F3 forget_gate_input = F3::FromRaw(
379 activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
380 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
381 // Implementation of output gate, using fixed-point logistic function.
382 F3 output_gate_input = F3::FromRaw(
383 activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
384 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
385 // Implementation of internal multiplication nodes, still in fixed-point.
386 F0 input_times_input_modulation =
387 input_gate_output * input_modulation_gate_output;
388 FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
389 FS prev_state_times_forget_state = forget_gate_output * prev_state;
390 // Implementation of internal addition node, saturating.
391 FS new_state = gemmlowp::SaturatingAdd(
392 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
393 prev_state_times_forget_state);
394 // Implementation of last internal Tanh node, still in fixed-point.
395 // Since a Tanh fixed-point implementation is specialized for a given
396 // number or integer bits, and each specialization can have a substantial
397 // code size, and we already used above a Tanh on an input with 3 integer
398 // bits, and per the table in the above function comment there is no
399 // significant accuracy to be lost by clamping to [-8, +8] for a
400 // 3-integer-bits representation, let us just do that. This helps people
401 // porting this to targets where code footprint must be minimized.
402 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
403 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
404 // Store the new internal state back to memory, as 16-bit integers.
405 // Note: here we store the original value with StateIntegerBits, not
406 // the rescaled 3-integer-bits value fed to tanh.
407 output_state_data_int16[b * output_depth + c] = new_state.raw();
408 // Down-scale the output activations to 8-bit integers, saturating,
409 // and store back to memory.
410 int16_t rescaled_output_activ =
411 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
412 int16_t clamped_output_activ = std::max<int16_t>(
413 -128, std::min<int16_t>(127, rescaled_output_activ));
414 output_activ_data_uint8[b * output_depth + c] =
415 128 + clamped_output_activ;
416 }
417 }
418 }
419
420 } // namespace reference_ops
421 } // namespace tflite
422 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LSTM_CELL_H_
423