/*
 * Copyright (c) 2018-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/runtime/NEON/functions/NELSTMLayer.h"

#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/InfoHelpers.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/common/LSTMParams.h"
#include "src/common/utils/Log.h"

namespace arm_compute
{
using namespace arm_compute::misc::shape_calculator;
using namespace arm_compute::utils::info_helpers;

NELSTMLayer::~NELSTMLayer() = default;

NELSTMLayer::NELSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
    : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _accum_input_gate1(), _subtract_input_gate(), _pixelwise_mul_input_gate(), _activation_input_gate(),
      _fully_connected_forget_gate(), _accum_forget_gate1(), _pixelwise_mul_forget_gate(), _activation_forget_gate(), _fully_connected_cell_state(), _gemm_cell_state1(), _transpose_cell_state(),
      _accum_cell_state1(), _accum_cell_state2(), _pixelwise_mul_cell_state1(), _activation_cell_state(), _cell_clip(), _pixelwise_mul_cell_state2(), _fully_connected_output(),
      _pixelwise_mul_output_state1(), _accum_output1(), _activation_output(), _activation_output_state(), _pixelwise_mul_output_state2(), _fully_connected_output_state(), _projection_clip(),
      _copy_cell_state(), _copy_output(), _concat_scratch_buffer(), _concat_inputs_forget_gate(), _concat_weights_forget_gate(), _concat_weights_input_gate(), _concat_weights_output(),
      _mean_std_norm_input_gate(), _pixelwise_mul_input_gate_coeff(), _accum_input_gate_bias(), _mean_std_norm_forget_gate(), _pixelwise_mul_forget_gate_coeff(), _accum_forget_gate_bias(),
      _mean_std_norm_cell_gate(), _pixelwise_mul_cell_gate_coeff(), _accum_cell_gate_bias(), _mean_std_norm_output_gate(), _pixelwise_mul_output_gate_coeff(), _accum_output_gate_bias(), _input_gate_out1(),
      _input_gate_out2(), _input_gate_out3(), _input_gate_out4(), _forget_gate_out1(), _forget_gate_out2(), _forget_gate_out3(), _forget_gate_out4(), _forget_gate_out5(), _forget_gate_out6(),
      _cell_state_out1(), _cell_state_out2(), _cell_state_out3(), _cell_state_out4(), _cell_state_out5(), _output1(), _output2(), _output3(), _output4(), _cell_state_activation(), _output_state1(), _ones(),
      _input_layer_norm_out1(), _input_layer_norm_out2(), _forget_layer_norm_out1(), _forget_layer_norm_out2(), _cell_layer_norm_out1(), _cell_layer_norm_out2(), _output_layer_norm_out1(),
      _output_layer_norm_out2(), _run_peephole_opt(false), _run_cifg_opt(false), _perform_cell_clipping(false), _has_projection_weights(false), _perform_projection_clipping(false), _is_prepared(false),
      _is_layer_norm_lstm(false)
{
}

void NELSTMLayer::configure(const ITensor *input,
                            const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights,
                            const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
                            const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
                            const ITensor *output_state_in, const ITensor *cell_state_in,
                            ITensor *scratch_buffer, ITensor *output_state_out, ITensor *cell_state_out, ITensor *output,
                            const LSTMParams<ITensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input,
                                 input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                 forget_gate_bias, cell_bias, output_gate_bias,
                                 output_state_in, cell_state_in,
                                 scratch_buffer, output_state_out, cell_state_out, output);
    ARM_COMPUTE_LOG_PARAMS(input,
                           input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                           recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                           forget_gate_bias, cell_bias, output_gate_bias,
                           output_state_in, cell_state_in,
                           scratch_buffer, output_state_out, cell_state_out, output,
                           lstm_params, activation_info, cell_threshold, projection_threshold);

    _is_layer_norm_lstm = lstm_params.use_layer_norm();

    // Set lstm parameters
    LSTMParams<ITensorInfo> lstm_params_info{};
    build_lstm_params_tensor_info(lstm_params, &lstm_params_info);

    // Validate
    ARM_COMPUTE_ERROR_THROW_ON(NELSTMLayer::validate(input->info(), input_to_forget_weights->info(),
                                                     input_to_cell_weights->info(), input_to_output_weights->info(),
                                                     recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
                                                     forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
                                                     output_state_in->info(), cell_state_in->info(),
                                                     scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
                                                     lstm_params_info, activation_info, cell_threshold, projection_threshold));

    const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();

    // Configure block that calculates the forget gate
    // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
    // We optimize this as follows:
    // forget_gate = Activation( (input,output_state_in) * (input_to_forget_weights,recurrent_to_forget_weights) + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
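    // Illustrative note: with input of shape [num_inputs, num_batches] and output_state_in of
    // shape [num_outputs, num_batches], the X-axis concatenations below build a single
    // [num_inputs + num_outputs, num_batches] tensor (and matching concatenated weights), so one
    // fully connected layer replaces two matrix multiplications plus an addition.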
    _forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

    std::vector<const ITensor *> inputs_vector;
    inputs_vector.emplace_back(input);
    inputs_vector.emplace_back(output_state_in);

    _memory_group.manage(&_forget_gate_out2);
    _concat_inputs_forget_gate.configure(inputs_vector, &_forget_gate_out2, Window::DimX);

    std::vector<const ITensor *> weights_vector;

    weights_vector.emplace_back(input_to_forget_weights);
    weights_vector.emplace_back(recurrent_to_forget_weights);

    _concat_weights_forget_gate.configure(weights_vector, &_forget_gate_out6, Window::DimX);

    _memory_group.manage(&_forget_gate_out5);
    _fully_connected_forget_gate.configure(&_forget_gate_out2, &_forget_gate_out6, (_is_layer_norm_lstm) ? nullptr : forget_gate_bias, &_forget_gate_out5);
    _memory_group.manage(&_forget_gate_out1);
    _memory_group.manage(&_forget_gate_out3);
    _forget_gate_out6.allocator()->allocate();

    Tensor *forget_gate_out = &_forget_gate_out5;
    if(lstm_params.has_peephole_opt())
    {
        _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

        _run_peephole_opt = true;
        _memory_group.manage(&_forget_gate_out4);
        _pixelwise_mul_forget_gate.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_forget_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
        _accum_forget_gate1.configure(&_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3, ConvertPolicy::SATURATE);
        _forget_gate_out4.allocator()->allocate();
        _forget_gate_out5.allocator()->allocate();
        forget_gate_out = &_forget_gate_out3;
    }
    else
    {
        _forget_gate_out3.allocator()->allocate();
    }
    if(_is_layer_norm_lstm)
    {
        _forget_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _forget_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _memory_group.manage(&_forget_layer_norm_out1);
        _memory_group.manage(&_forget_layer_norm_out2);
        _mean_std_norm_forget_gate.configure(forget_gate_out);
        _pixelwise_mul_forget_gate_coeff.configure(forget_gate_out, lstm_params.forget_layer_norm_weights(), &_forget_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
        // forget_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
        forget_gate_out->allocator()->allocate();
        _accum_forget_gate_bias.configure(&_forget_layer_norm_out1, forget_gate_bias, &_forget_layer_norm_out2, ConvertPolicy::SATURATE);
        _forget_layer_norm_out1.allocator()->allocate();
        forget_gate_out = &_forget_layer_norm_out2;
    }
    _activation_forget_gate.configure(forget_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));

    // Configure block that calculates the input gate
    // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
    // input_gate = 1 - forget_gate, with CIFG
    // We optimize this as follows:
    // input_gate = Activation((input,output_state) * (input_to_input_weights,recurrent_to_input_weights) + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
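    // With CIFG, no input-gate weights are needed: a tensor of ones is filled at run time (see
    // run()) and input_gate is obtained as ones - forget_gate through a saturating subtraction.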
    _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    Tensor *input_gate_out = &_input_gate_out1;
    if(lstm_params.has_cifg_opt())
    {
        _memory_group.manage(&_input_gate_out1);
        _ones.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _subtract_input_gate.configure(&_ones, forget_gate_out, &_input_gate_out1, ConvertPolicy::SATURATE);
        _ones.allocator()->allocate();
        _run_cifg_opt = true;
    }
    else
    {
        _input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

        std::vector<const ITensor *> lstm_weights;
        lstm_weights.emplace_back(lstm_params.input_to_input_weights());
        lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());

        _concat_weights_input_gate.configure(lstm_weights, &_input_gate_out2, Window::DimX);

        _memory_group.manage(&_input_gate_out1);
        _memory_group.manage(&_input_gate_out4);

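        // _forget_gate_out2 still holds the concatenated (input, output_state_in) tensor built
        // for the forget gate, so it is reused as the input of this fully connected layer.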
        _fully_connected_input_gate.configure(&_forget_gate_out2, &_input_gate_out2, (_is_layer_norm_lstm) ? nullptr : lstm_params.input_gate_bias(), &_input_gate_out3);
        _input_gate_out2.allocator()->allocate();
        input_gate_out = &_input_gate_out3;

        if(_run_peephole_opt)
        {
            _memory_group.manage(&_input_gate_out4);
            _pixelwise_mul_input_gate.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_input_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
            _accum_input_gate1.configure(&_input_gate_out3, &_input_gate_out4, &_input_gate_out1, ConvertPolicy::SATURATE);
            _input_gate_out3.allocator()->allocate();
            _input_gate_out4.allocator()->allocate();
            input_gate_out = &_input_gate_out1;
        }
        else
        {
            _input_gate_out1.allocator()->allocate();
        }

        if(_is_layer_norm_lstm)
        {
            _input_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
            _input_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
            _memory_group.manage(&_input_layer_norm_out1);
            _memory_group.manage(&_input_layer_norm_out2);
            _mean_std_norm_input_gate.configure(input_gate_out);
            _pixelwise_mul_input_gate_coeff.configure(input_gate_out, lstm_params.input_layer_norm_weights(), &_input_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
            // input_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
            input_gate_out->allocator()->allocate();
            _accum_input_gate_bias.configure(&_input_layer_norm_out1, lstm_params.input_gate_bias(), &_input_layer_norm_out2, ConvertPolicy::SATURATE);
            _input_layer_norm_out1.allocator()->allocate();
            input_gate_out = &_input_layer_norm_out2;
        }
        _activation_input_gate.configure(input_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    }

    // Configure block that calculates the cell state
    // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
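    // The recurrent term is computed explicitly here: recurrent_to_cell_weights is transposed and
    // multiplied with output_state_in via a GEMM, while the fully connected layer below handles
    // input * input_to_cell_weights (plus cell_bias when layer normalization is not used).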
    TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
    _cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
    _cell_state_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

    _memory_group.manage(&_cell_state_out1);
    _fully_connected_cell_state.configure(input, input_to_cell_weights, (_is_layer_norm_lstm) ? nullptr : cell_bias, &_cell_state_out1);
    _memory_group.manage(&_cell_state_out2);
    _transpose_cell_state.configure(recurrent_to_cell_weights, &_cell_state_out2);
    _memory_group.manage(&_cell_state_out3);
    _gemm_cell_state1.configure(output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
    _cell_state_out2.allocator()->allocate();
    _memory_group.manage(&_cell_state_out4);
    _accum_cell_state1.configure(&_cell_state_out1, &_cell_state_out3, &_cell_state_out4, ConvertPolicy::SATURATE);
    Tensor *cell_state_out_ptr = &_cell_state_out4;
    if(_is_layer_norm_lstm)
    {
        _cell_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _cell_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _memory_group.manage(&_cell_layer_norm_out1);
        _memory_group.manage(&_cell_layer_norm_out2);
        _mean_std_norm_cell_gate.configure(cell_state_out_ptr);
        _pixelwise_mul_cell_gate_coeff.configure(cell_state_out_ptr, lstm_params.cell_layer_norm_weights(), &_cell_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
        // cell_state_out_ptr is going to be reassigned, so allocate the tensor that it was assigned to before
        cell_state_out_ptr->allocator()->allocate();
        _accum_cell_gate_bias.configure(&_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2, ConvertPolicy::SATURATE);
        _cell_layer_norm_out1.allocator()->allocate();
        cell_state_out_ptr = &_cell_layer_norm_out2;
    }
    _activation_cell_state.configure(cell_state_out_ptr, nullptr, activation_info);
    _memory_group.manage(&_cell_state_out5);
    _pixelwise_mul_cell_state1.configure(cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    cell_state_out_ptr->allocator()->allocate();
    _pixelwise_mul_cell_state2.configure(forget_gate_out, cell_state_in, &_cell_state_out3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    _accum_cell_state2.configure(&_cell_state_out5, &_cell_state_out3, &_cell_state_out1, ConvertPolicy::SATURATE);
    _cell_state_out3.allocator()->allocate();
    _cell_state_out5.allocator()->allocate();
    // Perform clipping
    if(cell_threshold != 0.f)
    {
        _perform_cell_clipping = true;
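        // LU_BOUNDED_RELU computes min(a, max(b, x)); with a = cell_threshold and
        // b = -cell_threshold this clamps the cell state to [-cell_threshold, cell_threshold].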
        _cell_clip.configure(&_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold));
    }

    // Configure block that calculates the output
    // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
    // We optimize this as follows:
    // output_state_out = Activation( (input,output_state_in) * (input_to_output_weights, recurrent_to_output_weights) + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
    _output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _output4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

    std::vector<const ITensor *> in_out_weights;
    in_out_weights.emplace_back(input_to_output_weights);
    in_out_weights.emplace_back(recurrent_to_output_weights);

    _concat_weights_output.configure(in_out_weights, &_output2, Window::DimX);
    _memory_group.manage(&_output1);
    _memory_group.manage(&_output4);

    _fully_connected_output.configure(&_forget_gate_out2, &_output2, (_is_layer_norm_lstm) ? nullptr : output_gate_bias, &_output4);

    _output2.allocator()->allocate();
    _forget_gate_out2.allocator()->allocate();

    Tensor *output_gate_out = &_output4;
    if(lstm_params.has_peephole_opt())
    {
        _output3.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));

        _memory_group.manage(&_output3);
        _pixelwise_mul_output_state1.configure(&_cell_state_out1, lstm_params.cell_to_output_weights(), &_output3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
        _accum_output1.configure(&_output4, &_output3, &_output1, ConvertPolicy::SATURATE);
        _output4.allocator()->allocate();
        output_gate_out = &_output1;

        // Allocate intermediate buffers
        _output3.allocator()->allocate();
    }
    else
    {
        _output1.allocator()->allocate();
    }
    if(_is_layer_norm_lstm)
    {
        _output_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _output_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
        _memory_group.manage(&_output_layer_norm_out1);
        _memory_group.manage(&_output_layer_norm_out2);
        _mean_std_norm_output_gate.configure(output_gate_out);
        _pixelwise_mul_output_gate_coeff.configure(output_gate_out, lstm_params.output_layer_norm_weights(), &_output_layer_norm_out1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
        // output_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
        output_gate_out->allocator()->allocate();
        _accum_output_gate_bias.configure(&_output_layer_norm_out1, output_gate_bias, &_output_layer_norm_out2, ConvertPolicy::SATURATE);
        _output_layer_norm_out1.allocator()->allocate();
        output_gate_out = &_output_layer_norm_out2;
    }
    _activation_output.configure(output_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));

    // Configure block that calculates the output state
    /** lstm_res = PixelwiseMul(output, Activation(cell_state))
     *
     *                      -- Clip(lstm_res * projection_weights + projection_bias, projection_threshold) , if there is a projection
     *                     /
     *  output_state =  --
     *                     \
     *                      -- lstm_res , otherwise
     */
    ITensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
    _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
    _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));

    _memory_group.manage(&_cell_state_activation);
    _activation_output_state.configure(&_cell_state_out1, &_cell_state_activation, activation_info);
    _pixelwise_mul_output_state2.configure(&_cell_state_activation, output_gate_out, output_state_out_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    _cell_state_activation.allocator()->allocate();
    output_gate_out->allocator()->allocate();

    if(lstm_params.has_projection())
    {
        _has_projection_weights = true;
        _fully_connected_output_state.configure(output_state_out_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out);
        _output_state1.allocator()->allocate();
        // Perform clipping
        if(projection_threshold != 0.f)
        {
            _perform_projection_clipping = true;
            _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, projection_threshold, -projection_threshold));
        }
    }

    // Copy cell state and output
    _copy_cell_state.configure(&_cell_state_out1, cell_state_out);
    _copy_output.configure(output_state_out, output);

    // Vector for holding the tensors to store in scratch buffer
    std::vector<const ITensor *> scratch_inputs;
    if(!lstm_params.has_cifg_opt())
    {
        scratch_inputs.emplace_back(input_gate_out);
    }
    scratch_inputs.emplace_back(&_cell_state_out1);
    scratch_inputs.emplace_back(forget_gate_out);
    scratch_inputs.emplace_back(output_gate_out);
    _concat_scratch_buffer.configure(scratch_inputs, scratch_buffer, Window::DimX);
    input_gate_out->allocator()->allocate();
    _cell_state_out1.allocator()->allocate();
    forget_gate_out->allocator()->allocate();
    output_gate_out->allocator()->allocate();
}

Status NELSTMLayer::validate(const ITensorInfo *input,
                             const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
                             const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
                             const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
                             const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
                             const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
                             const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input,
                                        input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                        recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                        forget_gate_bias, cell_bias, output_gate_bias,
                                        output_state_in, cell_state_in,
                                        scratch_buffer, output_state_out, cell_state_out, output);

    // Check data types
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input,
                                                       input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                                       recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                                       forget_gate_bias, cell_bias, output_gate_bias,
                                                       output_state_in, cell_state_in,
                                                       scratch_buffer, output_state_out, cell_state_out, output);

    // Check dimensions
    ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
    ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
    ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0)
                                && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
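    // The scratch buffer concatenates the gate outputs along the X axis: 4 * num_cells elements
    // without CIFG (input, cell, forget and output gates) or 3 * num_cells with CIFG, where the
    // input gate is omitted.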

    const unsigned int num_batches = input->dimension(1);
    const unsigned int num_cells   = input_to_output_weights->dimension(1);

    if(lstm_params.use_layer_norm())
    {
        // If CIFG is used, input layer normalization weights tensor is omitted
        if(lstm_params.has_cifg_opt())
        {
            ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights() != nullptr);
        }
        else
        {
            ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
            ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->num_dimensions() > 1);
            ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_cells);
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.input_layer_norm_weights());
        }

        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->num_dimensions() > 1);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->num_dimensions() > 1);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->num_dimensions() > 1);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_cells);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_cells);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_cells);
    }

    // Check peephole optimization
    if(lstm_params.has_peephole_opt())
    {
        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
    }

    TensorShape      units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
    TensorShape      num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
    const TensorInfo units_out_transposed_info  = TensorInfo(units_out_transposed_shape, 1, input->data_type());
    const TensorInfo num_units_transposed_info  = TensorInfo(num_units_transposed_shape, 1, input->data_type());

    TensorInfo input_gate      = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
    TensorInfo forget_gate     = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
    TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
    TensorInfo cell_state_tmp  = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());

    std::vector<const ITensorInfo *> inputs_vector;
    inputs_vector.emplace_back(input);
    inputs_vector.emplace_back(output_state_in);
    const TensorShape concat_shape       = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
    TensorInfo        forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());
    ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(inputs_vector, &forget_gate_concat, Window::DimX));

    // Validate forget gate
    ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(input, input_to_forget_weights, (lstm_params.use_layer_norm()) ? nullptr : forget_gate_bias, &forget_gate));

    if(lstm_params.has_peephole_opt())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
        ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
    }
    if(lstm_params.use_layer_norm())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&forget_gate));
        ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
                                                                        RoundingPolicy::TO_ZERO));
        ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
    }
    ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));

    // Validate input gate
    if(!lstm_params.has_cifg_opt())
    {
        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
                                            lstm_params.recurrent_to_input_weights(),
                                            lstm_params.input_gate_bias());
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);

        std::vector<const ITensorInfo *> lstm_weights;
        lstm_weights.emplace_back(lstm_params.input_to_input_weights());
        lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
        TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
        TensorInfo  lstm_gate_concat          = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
        ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(lstm_weights, &lstm_gate_concat, Window::DimX));
        ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), (lstm_params.use_layer_norm()) ? nullptr : lstm_params.input_gate_bias(), &input_gate));

        if(lstm_params.has_peephole_opt())
        {
            ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
            ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
            ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
            ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
        }

        if(lstm_params.use_layer_norm())
        {
            ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&input_gate));
            ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
            ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE));
        }
        ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
    }
    else
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtraction::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
    }

    // Validate cell state
    ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(input, input_to_cell_weights, (lstm_params.use_layer_norm()) ? nullptr : cell_bias, &cell_state_tmp));
    ARM_COMPUTE_RETURN_ON_ERROR(NEGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
    ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
    if(lstm_params.use_layer_norm())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
        ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
                                                                        RoundingPolicy::TO_ZERO));
        ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
    }
    ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, activation_info));
    ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
    ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
    ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
    if(cell_threshold != 0.f)
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold,
                                                                                                              -cell_threshold)));
    }

    // Validate output gate tmp
    std::vector<const ITensorInfo *> in_out_weights;
    in_out_weights.emplace_back(input_to_output_weights);
    in_out_weights.emplace_back(recurrent_to_output_weights);
    TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
    TensorInfo  in_out_gate_concat          = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
    ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(in_out_weights, &in_out_gate_concat, Window::DimX));

    ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(input, input_to_output_weights, (lstm_params.use_layer_norm()) ? nullptr : output_gate_bias, &output_gate_tmp));

    if(lstm_params.has_peephole_opt())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
                                                                        RoundingPolicy::TO_ZERO));
        ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
    }
    if(lstm_params.use_layer_norm())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
        ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
                                                                        RoundingPolicy::TO_ZERO));
        ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE));
    }
    ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));

    // Validate output state
    ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
    ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
    if(lstm_params.has_projection())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
        if(projection_threshold != 0.f)
        {
            ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output_state_out, output_state_out,
                                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, projection_threshold, -projection_threshold)));
        }
    }

    // Validate copy kernel
    ARM_COMPUTE_RETURN_ON_ERROR(NECopy::validate(&cell_state_tmp, cell_state_out));
    ARM_COMPUTE_RETURN_ON_ERROR(NECopy::validate(output_state_out, output));

    // Validate scratch concatenation
    std::vector<const ITensorInfo *> inputs_vector_info_raw;
    if(!lstm_params.has_cifg_opt())
    {
        inputs_vector_info_raw.push_back(&input_gate);
    }
    inputs_vector_info_raw.push_back(&cell_state_tmp);
    inputs_vector_info_raw.push_back(&forget_gate);
    inputs_vector_info_raw.push_back(&output_gate_tmp);

    ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer, Window::DimX));
    return Status{};
}

void NELSTMLayer::run()
{
    prepare();

    MemoryGroupResourceScope scope_mg(_memory_group);

    _concat_inputs_forget_gate.run();
    _fully_connected_forget_gate.run();

    if(_run_peephole_opt)
    {
        _pixelwise_mul_forget_gate.run();
        _accum_forget_gate1.run();
    }
    if(_is_layer_norm_lstm)
    {
        _mean_std_norm_forget_gate.run();
        _pixelwise_mul_forget_gate_coeff.run();
        _accum_forget_gate_bias.run();
    }
    _activation_forget_gate.run();

    if(_run_cifg_opt)
    {
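        // CIFG: fill the ones tensor so that input_gate can be computed as 1 - forget_gate.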
        if(_ones.info()->data_type() == DataType::F16)
        {
            std::fill_n(reinterpret_cast<half *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 1);
        }
        else
        {
            std::fill_n(reinterpret_cast<float *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 1);
        }
        _subtract_input_gate.run();
    }
    else
    {
        _fully_connected_input_gate.run();

        if(_run_peephole_opt)
        {
            _pixelwise_mul_input_gate.run();
            _accum_input_gate1.run();
        }

        if(_is_layer_norm_lstm)
        {
            _mean_std_norm_input_gate.run();
            _pixelwise_mul_input_gate_coeff.run();
            _accum_input_gate_bias.run();
        }
        _activation_input_gate.run();
    }

    _fully_connected_cell_state.run();
    _transpose_cell_state.run();
    _gemm_cell_state1.run();
    _accum_cell_state1.run();
    if(_is_layer_norm_lstm)
    {
        _mean_std_norm_cell_gate.run();
        _pixelwise_mul_cell_gate_coeff.run();
        _accum_cell_gate_bias.run();
    }

    _activation_cell_state.run();
    _pixelwise_mul_cell_state1.run();
    _pixelwise_mul_cell_state2.run();
    _accum_cell_state2.run();

    if(_perform_cell_clipping)
    {
        _cell_clip.run();
    }

    _fully_connected_output.run();
    if(_run_peephole_opt)
    {
        _pixelwise_mul_output_state1.run();
        _accum_output1.run();
    }
    if(_is_layer_norm_lstm)
    {
        _mean_std_norm_output_gate.run();
        _pixelwise_mul_output_gate_coeff.run();
        _accum_output_gate_bias.run();
    }
    _activation_output.run();

    _activation_output_state.run();
    _pixelwise_mul_output_state2.run();

    if(_has_projection_weights)
    {
        _fully_connected_output_state.run();
        if(_perform_projection_clipping)
        {
            _projection_clip.run();
        }
    }

    _copy_cell_state.run();
    _copy_output.run();

    _concat_scratch_buffer.run();
}

void NELSTMLayer::prepare()
{
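    // The weight concatenations depend only on constant weights, so they are run once on the
    // first call and skipped on subsequent runs.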
    if(!_is_prepared)
    {
        _concat_weights_forget_gate.run();
        if(!_run_cifg_opt)
        {
            _concat_weights_input_gate.run();
        }
        _concat_weights_output.run();
        _is_prepared = true;
    }
}
} // namespace arm_compute
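
/* A minimal usage sketch (illustrative only; the tensor names below are hypothetical, and tensor
 * initialization, allocation and data filling follow the library's usual Tensor workflow):
 *
 *   NELSTMLayer lstm;
 *   lstm.configure(&input, &input_to_forget_w, &input_to_cell_w, &input_to_output_w,
 *                  &recurrent_to_forget_w, &recurrent_to_cell_w, &recurrent_to_output_w,
 *                  &forget_gate_bias, &cell_bias, &output_gate_bias,
 *                  &output_state_in, &cell_state_in,
 *                  &scratch_buffer, &output_state_out, &cell_state_out, &output,
 *                  lstm_params, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH),
 *                  cell_threshold, projection_threshold);
 *   lstm.run(); // prepare() is invoked internally on the first run
 */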