//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClUnidirectionalSequenceLstmFloatWorkload.hpp"
#include "ClWorkloadUtils.hpp"

#include <aclCommon/ArmComputeUtils.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <armnn/utility/NumericCast.hpp>
#include <armnnUtils/Permute.hpp>
#include <cl/test/ClWorkloadFactoryHelper.hpp>
#include <backendsCommon/WorkloadUtils.hpp>

#include "cl/ClTensorHandle.hpp"

namespace
{
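// Converts an Arm NN axis index (counted from the outermost dimension) to the
// ACL convention, which counts axes from the innermost dimension.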
unsigned int CalcAclAxis(unsigned int numDimensions, unsigned int axis)
{
    return (numDimensions - axis) - 1;
}
} // namespace

namespace armnn
{
using namespace armcomputetensorutils;

ClUnidirectionalSequenceLstmFloatWorkload::ClUnidirectionalSequenceLstmFloatWorkload
    (const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
     const WorkloadInfo& info,
     const arm_compute::CLCompileContext& clCompileContext)
    : FloatWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
{
    // Report Profiling Details
    ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClUnidirectionalSequenceLstmFloatWorkload_Construct",
                                         descriptor.m_Parameters,
                                         info,
                                         GetGuid());

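    // Input 0 holds the input sequence; inputs 1 and 2 hold the initial output
    // state and cell state. Output 2 holds the output sequence.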
    const arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
    arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();

    TensorInfo inputInfo = info.m_InputTensorInfos[0];
    TensorInfo outputInfo = info.m_OutputTensorInfos[2];

    arm_compute::DataType armComputeDataType = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetDataType();
    armnn::DataType armnnDataType = GetArmNNDataType(armComputeDataType);

    TensorShape inputLayerShape = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetShape();
    TensorShape cellStateLayerShape = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetShape();
    TensorShape outputLayerShape = static_cast<IClTensorHandle*>(m_Data.m_Outputs[2])->GetShape();

    unsigned int maxTime = m_Data.m_Parameters.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1];
    unsigned int batchSize = m_Data.m_Parameters.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0];
    unsigned int inputSize = inputLayerShape[2];
    unsigned int outputSize = outputLayerShape[2];
    unsigned int numUnits = cellStateLayerShape[1];

    const TensorShape timeMajorShapeInput({maxTime, batchSize, inputSize});
    const TensorShape timeMajorShapeOutput({maxTime, batchSize, outputSize});

    //
    // Permute: performed if Unidirectional Sequence Layer inputs/outputs are in batch major format.
    //
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        std::unique_ptr<arm_compute::CLPermute> layer(new arm_compute::CLPermute());

        TensorInfo permuteOutInfo = inputInfo;
        permuteOutInfo.SetShape(timeMajorShapeInput);
        BuildArmComputeTensor(m_PermuteFirstOut, permuteOutInfo);
        armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermuteFirstOut);

        // Permute to time major format.
        layer->configure(clCompileContext, &input, &m_PermuteFirstOut, arm_compute::PermutationVector(0U, 2U, 1U));
        m_Permute1.reset(layer.release());
    }

    //
    // Split and Concat Tensors
    //
    for (unsigned int i = 0; i < maxTime; ++i)
    {
        arm_compute::CLTensor splitter_out;
        arm_compute::CLTensor concat_in;

        auto splitterTensorInfo = inputInfo;
        auto concatTensorInfo = outputInfo;
        splitterTensorInfo.SetShape({batchSize, inputSize});
        concatTensorInfo.SetShape({batchSize, outputSize});
        BuildArmComputeTensor(splitter_out, splitterTensorInfo);
        BuildArmComputeTensor(concat_in, concatTensorInfo);

        armcomputetensorutils::InitialiseArmComputeTensorEmpty(splitter_out);
        armcomputetensorutils::InitialiseArmComputeTensorEmpty(concat_in);

        // append to std::vector<arm_compute::CLTensor>
        m_SplitterOutputsTensors.push_back(std::move(splitter_out));
        m_ConcatInputsTensors.push_back(std::move(concat_in));
    }

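    // Take the raw pointers in a second pass: the push_back calls above may
    // reallocate the tensor vectors, which would invalidate earlier addresses.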
    for (unsigned int i = 0; i < maxTime; ++i)
    {
        // append to std::vector<arm_compute::ICLTensor*>
        m_SplitterOutputs.push_back(&m_SplitterOutputsTensors[i]);
        m_ConcatInputs.push_back(&m_ConcatInputsTensors[i]);
    }

    //
    // Split
    //
    unsigned int numberDimensions = 3;
    unsigned int dimension = 0; // splitting on 0-dimension (i.e. maxTime dimension)

    if (maxTime != 1) // ACL split does not work with only one element to split.
    {
        ViewsDescriptor splitterDesc(maxTime, numberDimensions);
        unsigned int splitterDimSizes[3] = {1, batchSize, inputSize};
        for (unsigned int outputIdx = 0u; outputIdx < maxTime; ++outputIdx)
        {
            splitterDesc.SetViewOriginCoord(outputIdx, dimension, splitterDimSizes[dimension] * outputIdx);
            for (unsigned int dimIdx = 0u; dimIdx < numberDimensions; ++dimIdx)
            {
                splitterDesc.SetViewSize(outputIdx, dimIdx, splitterDimSizes[dimIdx]);
            }
        }

        std::set<unsigned int> splitAxis = ComputeSplitAxis(splitterDesc, timeMajorShapeInput);

        std::unique_ptr<arm_compute::CLSplit> split_layer(new arm_compute::CLSplit());
        unsigned int aclAxisSplit = CalcAclAxis(splitterDesc.GetNumDimensions(), *splitAxis.begin());
        if (!m_Data.m_Parameters.m_TimeMajor)
        {
            split_layer->configure(&m_PermuteFirstOut, m_SplitterOutputs, aclAxisSplit);
        }
        else
        {
            split_layer->configure(&input, m_SplitterOutputs, aclAxisSplit);
        }

        split_layer->prepare();
        m_Splitter.reset(split_layer.release());
    }

    //
    // Lstm
    //
    arm_compute::LSTMParams<arm_compute::ICLTensor> lstm_param;

    m_InputToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
    BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());

    m_InputToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
    BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());

    m_InputToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
    BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());

    m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
    BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());

    m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
    BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());

    m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
    BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());

    m_ForgetGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
    BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());

    m_CellBiasTensor = std::make_unique<arm_compute::CLTensor>();
    BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());

    m_OutputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
    BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());

    // for future reference: check the AndroidNN API for the logic here
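    // With CIFG (coupled input-forget gate) disabled, the input gate has its own
    // weights and bias; with CIFG enabled it is derived from the forget gate.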
    if (!m_Data.m_Parameters.m_CifgEnabled)
    {
        m_InputToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
        BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());

        m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
        BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());

        m_CellToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
        if (m_Data.m_CellToInputWeights != nullptr)
        {
            BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
        }

        m_InputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
        BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());

        lstm_param.set_cifg_params(m_InputToInputWeightsTensor.get(),
                                   m_RecurrentToInputWeightsTensor.get(),
                                   m_Data.m_CellToInputWeights ? m_CellToInputWeightsTensor.get() : nullptr,
                                   m_InputGateBiasTensor.get());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        m_ProjectionWeightsTensor = std::make_unique<arm_compute::CLTensor>();
        BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());

        m_ProjectionBiasTensor = std::make_unique<arm_compute::CLTensor>();
        if (m_Data.m_ProjectionBias != nullptr)
        {
            BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
        }

        lstm_param.set_projection_params(m_ProjectionWeightsTensor.get(),
                                         m_Data.m_ProjectionBias ? m_ProjectionBiasTensor.get() : nullptr);
    }

    if (m_Data.m_Parameters.m_PeepholeEnabled)
    {
        m_CellToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
        BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());

        m_CellToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
        BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());

        lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
    }

    if (m_Data.m_Parameters.m_LayerNormEnabled)
    {
        m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
        if (!m_Data.m_Parameters.m_CifgEnabled)
        {
            BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
        }

        m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
        BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());

        m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
        BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());

        m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
        BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());

        auto inputNormWeightTensor = m_Data.m_Parameters.m_CifgEnabled ? nullptr : m_InputLayerNormWeightsTensor.get();
        lstm_param.set_layer_normalization_params(inputNormWeightTensor,
                                                  m_ForgetLayerNormWeightsTensor.get(),
                                                  m_CellLayerNormWeightsTensor.get(),
                                                  m_OutputLayerNormWeightsTensor.get());
    }

    arm_compute::ICLTensor& output_state_in = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
    arm_compute::ICLTensor& cell_state_in   = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();

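    // The "out" state tensors deliberately alias the "in" state tensors: each
    // timestep's CLLSTMLayer writes its updated state back in place, so the next
    // timestep in the sequence reads the state produced by the previous one.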
    arm_compute::ICLTensor& output_state_out = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
    arm_compute::ICLTensor& cell_state_out   = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
259 
260     m_ScratchBuffer = std::make_unique<arm_compute::CLTensor>();
261     if (m_Data.m_Parameters.m_CifgEnabled)
262     {
263         // scratch_buffer [num_units * 3, batch_size] with CIFG
264         BuildArmComputeTensor(*m_ScratchBuffer, TensorInfo({batchSize, numUnits * 3}, armnnDataType));
265     }
266     else
267     {
268         // scratch_buffer [num_units * 4, batch_size] without CIFG
269         BuildArmComputeTensor(*m_ScratchBuffer, TensorInfo({batchSize, numUnits * 4}, armnnDataType));
270     }
271 
272     // Need to be set at negative threshold to be compatible for ACL
273     float cell_threshold       = m_Data.m_Parameters.m_ClippingThresCell;
274     float projection_threshold = m_Data.m_Parameters.m_ClippingThresProj;
275 
276     // For preparing the object for the class ActivationLayerInfo, consider 5 situations
277     arm_compute::ActivationLayerInfo activationLayerInfo =
278         ConvertLstmActivationFuncToAclLayerInfo(m_Data.m_Parameters.m_ActivationFunc);
279 
    for (unsigned int i = 0; i != maxTime; ++i)
    {
        // Set LSTM input and output ITensors depending on:
        // input format (timeMajor) & number of LSTM batches (maxTime).
        arm_compute::ICLTensor* outputLSTM;
        arm_compute::ICLTensor* inputLSTM;
        // If there is only one LSTM time major batch, we will not concat OR permute.
        // Set input of LSTM to be first input ITensor.
        // Set output of LSTM to be final output ITensor.
        // LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo.
        if (maxTime == 1 && m_Data.m_Parameters.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(input.info()->tensor_shape(), 1U);
            TensorShape outputShape = GetTensorShape(output.info()->tensor_shape(), 1U);
            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            TensorShape outputShapeShrink({outputShape[1], outputShape[2]});
            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            auto acl_output_shape_shrink = BuildArmComputeTensorShape(outputShapeShrink);
            input.info()->set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = const_cast<arm_compute::ICLTensor*>(&input);
            output.info()->set_tensor_shape(acl_output_shape_shrink);
            outputLSTM = &output;
        }
        // If there is only one LSTM batch major batch, we will not concat, only permute.
        // Set input of LSTM to be output of initial permute.
        // Set output of LSTM to be first element of m_ConcatInputs & use that value later in permute.
        // LSTM output cannot be > 2 dimensions so need to resize its TensorInfo.
        else if (maxTime == 1 && !m_Data.m_Parameters.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(m_PermuteFirstOut.info()->tensor_shape(), 1U);
            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            m_PermuteFirstOut.info()->set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = &m_PermuteFirstOut;
            outputLSTM = const_cast<arm_compute::ICLTensor*>(m_ConcatInputs[i]);
        }
        // Batch major AND/OR 2+ LSTM batches so will use concat AND/OR permute later on.
        else
        {
            inputLSTM = m_SplitterOutputs[i];
            outputLSTM = const_cast<arm_compute::ICLTensor*>(m_ConcatInputs[i]);
        }

        std::unique_ptr<arm_compute::CLLSTMLayer> lstm_layer(new arm_compute::CLLSTMLayer());
        lstm_layer->configure(clCompileContext,
                              inputLSTM,
                              m_InputToForgetWeightsTensor.get(),
                              m_InputToCellWeightsTensor.get(),
                              m_InputToOutputWeightsTensor.get(),
                              m_RecurrentToForgetWeightsTensor.get(),
                              m_RecurrentToCellWeightsTensor.get(),
                              m_RecurrentToOutputWeightsTensor.get(),
                              m_ForgetGateBiasTensor.get(),
                              m_CellBiasTensor.get(),
                              m_OutputGateBiasTensor.get(),
                              &output_state_in,
                              &cell_state_in,
                              m_ScratchBuffer.get(),
                              &output_state_out,
                              &cell_state_out,
                              outputLSTM,
                              lstm_param,
                              activationLayerInfo,
                              cell_threshold,
                              projection_threshold);

        m_Layers.emplace_back(std::move(lstm_layer));
    }

    armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);

    InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
    InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
    InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
    InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
    InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
    InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
    InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
    InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
    InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);

    if (!m_Data.m_Parameters.m_CifgEnabled)
    {
        InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
        InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
        if (m_Data.m_CellToInputWeights != nullptr)
        {
            InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
        }
        InitializeArmComputeClTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias);
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
        if (m_Data.m_ProjectionBias != nullptr)
        {
            InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
        }
    }

    if (m_Data.m_Parameters.m_PeepholeEnabled)
    {
        InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
        InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
    }

    if (m_Data.m_Parameters.m_LayerNormEnabled)
    {
        if (!m_Data.m_Parameters.m_CifgEnabled)
        {
            InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
        }
        InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
        InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
        InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
    }

    // Force Compute Library to perform the necessary copying and reshaping.
    // After this, the input tensors that are no longer needed can be freed.
    for (uint32_t i = 0; i < m_Layers.size(); ++i)
    {
        m_Layers[i]->prepare();
    }

    //
    // Concat
    //

    // Expand dimensions of LSTM outputs adding one empty dimension to fit concatenate inputs.
    TensorShape shape = GetTensorShape(m_ConcatInputs[0]->info()->tensor_shape(), 1U);
    TensorShape shapeExpandTimeMajor({1, shape[0], shape[1]});
    TensorShape shapeExpandBatchMajor({shape[0], 1, shape[1]});

    if (maxTime != 1) // ACL concat does not work with only one element to concatenate.
    {
        for (unsigned int i = 0; i < maxTime; ++i)
        {
            m_ConcatInputs[i]->info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandTimeMajor));
        }

        ConcatDescriptor concatDescriptor(maxTime, numberDimensions); // maxTime = num inputs (aka. number of views).
        for (unsigned int inputIdx = 0u; inputIdx < maxTime; ++inputIdx)
        {
            concatDescriptor.SetViewOriginCoord(inputIdx, dimension, inputIdx);
            concatDescriptor.SetConcatAxis(dimension);
        }

        m_Concat.reset(new arm_compute::CLConcatenateLayer());
        unsigned int aclAxisConcat = CalcAclAxis(concatDescriptor.GetNumDimensions(),
                                                 concatDescriptor.GetConcatAxis());
        if (!m_Data.m_Parameters.m_TimeMajor)
        {
            TensorInfo concatOutputTensorInfo = outputInfo;
            concatOutputTensorInfo.SetShape(timeMajorShapeOutput);
            BuildArmComputeTensor(concat_out, concatOutputTensorInfo);
            armcomputetensorutils::InitialiseArmComputeTensorEmpty(concat_out);

            m_Concat->configure(m_ConcatInputs, &concat_out, aclAxisConcat);
        }
        else
        {
            m_Concat->configure(m_ConcatInputs, &output, aclAxisConcat);
        }

        m_Concat->prepare();
    }
    // If there is only one LSTM batch, we do not concat and/or permute.
    // Must ensure final output info is expanded to correct batch major dimensions.
    else
    {
        if (!m_Data.m_Parameters.m_TimeMajor)
        {
            output.info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandBatchMajor));
        }
        else
        {
            output.info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandTimeMajor));
        }
    }

    //
    // Permute: only done if input/output are in batch major format.
    //
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Output is now time major. Permute output back to batch major.
        std::unique_ptr<arm_compute::CLPermute> layer(new arm_compute::CLPermute());
        if (maxTime != 1)
        {
            layer->configure(clCompileContext, &concat_out, &output, arm_compute::PermutationVector(0U, 2U, 1U));
        }
        else
        {
            layer->configure(clCompileContext, m_ConcatInputs[0], &output, arm_compute::PermutationVector(0U, 2U, 1U));
        }
        m_Permute2.reset(layer.release());
    }

    FreeUnusedTensors();
}

void ClUnidirectionalSequenceLstmFloatWorkload::Execute() const
{
    ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClUnidirectionalSequenceLstmFloatWorkload_Execute", GetGuid());
    if (m_Permute1)
    {
        m_Permute1->run();
    }
    if (m_Splitter)
    {
        m_Splitter->run();
    }
    for (uint32_t i = 0; i < m_Layers.size(); ++i)
    {
        m_Layers[i]->run();
    }
    if (m_Concat)
    {
        m_Concat->run();
    }
    if (m_Permute2)
    {
        m_Permute2->run();
    }
}

arm_compute::Status
ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
                                                  const TensorInfo& outputStateIn,
                                                  const TensorInfo& cellStateIn,
                                                  const TensorInfo& output,
                                                  const Optional<TensorInfo>& hiddenStateOutput,
                                                  const Optional<TensorInfo>& cellStateOutput,
                                                  const UnidirectionalSequenceLstmDescriptor& descriptor,
                                                  const LstmInputParamsInfo& paramsInfo)
{
    IgnoreUnused(hiddenStateOutput, cellStateOutput);

    TensorShape inputLayerShape  = input.GetShape();
    TensorShape outputLayerShape = outputStateIn.GetShape();

    unsigned int maxTime    = descriptor.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1];
    unsigned int batchSize  = descriptor.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0];
    unsigned int inputSize  = inputLayerShape[2];
    unsigned int outputSize = outputLayerShape[2];

    const TensorShape timeMajorShapeInput({maxTime, batchSize, inputSize});
    const TensorShape timeMajorShapeOutput({maxTime, batchSize, outputSize});

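    // Each stage's status defaults to OK so that stages which are skipped
    // (e.g. no permute or split for a single time-major batch) pass validation.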
    arm_compute::Status statusPermute1 = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                             "Permute1 status");
    arm_compute::Status statusSplit    = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                             "Split status");
    arm_compute::Status statusLSTM     = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                             "LSTM status");
    arm_compute::Status statusConcat   = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                             "Concat status");
    arm_compute::Status statusPermute2 = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                             "Permute2 status");

    const arm_compute::TensorInfo aclInputInfo  = armcomputetensorutils::BuildArmComputeTensorInfo(input);
    const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);

    //
    // Permute validate
    //
    TensorInfo              permuteOutInfo    = TensorInfo(input);
    arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo);
    if (!descriptor.m_TimeMajor)
    {
        statusPermute1 = arm_compute::CLPermute::validate(&aclInputInfo,
                                                          &aclPermuteOutInfo,
                                                          arm_compute::PermutationVector(0U, 2U, 1U));
    }

    //
    // Split and Concat Tensors validate
    //
    std::vector<arm_compute::TensorInfo>         splitterOutputsTensorInfos;
    std::vector<arm_compute::TensorInfo>         concatInputsTensorInfos;
    std::vector<arm_compute::ITensorInfo*>       splitterOutputsTensorInfosPtr;
    std::vector<const arm_compute::ITensorInfo*> concatInputsTensorInfosPtr;
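    // reserve() guarantees the emplace_back calls below never reallocate, so the
    // addresses stored in the *Ptr vectors remain valid.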
    splitterOutputsTensorInfos.reserve(maxTime);
    concatInputsTensorInfos.reserve(maxTime);
    for (unsigned int i = 0; i < maxTime; ++i)
    {
        auto splitterTensorInfo = TensorInfo(input);
        auto concatTensorInfo   = TensorInfo(output);
        splitterTensorInfo.SetShape({batchSize, inputSize});
        concatTensorInfo.SetShape({batchSize, outputSize});

        arm_compute::TensorInfo aclSplitterTensorInfo
                                    = armcomputetensorutils::BuildArmComputeTensorInfo(splitterTensorInfo);
        arm_compute::TensorInfo aclConcatTensorInfo
                                    = armcomputetensorutils::BuildArmComputeTensorInfo(concatTensorInfo);

        splitterOutputsTensorInfos.emplace_back(aclSplitterTensorInfo);
        concatInputsTensorInfos.emplace_back(aclConcatTensorInfo);
        splitterOutputsTensorInfosPtr.emplace_back(&splitterOutputsTensorInfos[i]);
        concatInputsTensorInfosPtr.emplace_back(&concatInputsTensorInfos[i]);
    }

    //
    // Split validate
    //
    unsigned int numberDimensions = 3;
    unsigned int dimension        = 0; // splitting on 0-dimension (i.e. maxTime dimension)
    unsigned int aclAxisSplit     = CalcAclAxis(numberDimensions, dimension);

    if (maxTime != 1) // ACL split does not work with only one element to split.
    {
        if (!descriptor.m_TimeMajor)
        {
            statusSplit = arm_compute::CLSplit::validate(&aclPermuteOutInfo,
                                                         splitterOutputsTensorInfosPtr,
                                                         aclAxisSplit);
        }
        else
        {
            statusSplit = arm_compute::CLSplit::validate(&aclInputInfo, splitterOutputsTensorInfosPtr, aclAxisSplit);
        }
    }

    //
    // LSTM validate
    //

    arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;

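    // ACL requires TensorInfos for the scratch buffer and the output/cell state
    // outputs even though the caller does not supply them, so they are
    // synthesised here from the corresponding input state shapes.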
    const TensorInfo scratchBuffer  = TensorInfo(cellStateIn.GetShape(), input.GetDataType());
    const TensorInfo outputStateOut = TensorInfo(outputStateIn.GetShape(), input.GetDataType());
    const TensorInfo cellStateOut   = TensorInfo(cellStateIn.GetShape(), input.GetDataType());

    // The inputs and outputs
    const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
    const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
    const arm_compute::TensorInfo aclScratchBufferInfo = BuildArmComputeTensorInfo(scratchBuffer);
    const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
    const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);

    // Basic parameters
    const arm_compute::TensorInfo aclInputToForgetWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
    const arm_compute::TensorInfo aclInputToCellWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
    const arm_compute::TensorInfo aclInputToOutputWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
    const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
    const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
    const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
    const arm_compute::TensorInfo aclForgetGateBiasInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
    const arm_compute::TensorInfo aclCellBiasInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
    const arm_compute::TensorInfo aclOutputGateBiasInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());

    arm_compute::TensorInfo aclInputToInputWeightsInfo;
    arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
    arm_compute::TensorInfo aclCellToInputWeightsInfo;
    arm_compute::TensorInfo aclInputGateBiasInfo;
    arm_compute::TensorInfo aclProjectionWeightsInfo;
    arm_compute::TensorInfo aclProjectionBiasInfo;
    arm_compute::TensorInfo aclCellToForgetWeightsInfo;
    arm_compute::TensorInfo aclCellToOutputWeightsInfo;

    arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
    arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
    arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
    arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;

    if (!descriptor.m_CifgEnabled)
    {
        if (descriptor.m_PeepholeEnabled)
        {
            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
        }
        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());

        lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo,
                                         &aclRecurrentToInputWeightsInfo,
                                         descriptor.m_PeepholeEnabled ? &aclCellToInputWeightsInfo : nullptr,
                                         &aclInputGateBiasInfo);
    }

    if (descriptor.m_ProjectionEnabled)
    {
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias());
        }
        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());

        lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
                                               paramsInfo.m_ProjectionBias ? &aclProjectionBiasInfo : nullptr);
    }

    if (descriptor.m_PeepholeEnabled)
    {
        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());

        lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
    }

    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
        }
        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
        aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());

        lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ? nullptr :
                                                        &aclInputLayerNormWeightsInfo,
                                                        &aclForgetLayerNormWeightsInfo,
                                                        &aclCellLayerNormWeightsInfo,
                                                        &aclOutputLayerNormWeightsInfo);
    }

    // Cell and projection clipping thresholds, passed through to ACL (a value of 0 disables clipping).
    float cell_threshold = descriptor.m_ClippingThresCell;
    float projection_threshold = descriptor.m_ClippingThresProj;

    arm_compute::ActivationLayerInfo activationLayerInfo =
        ConvertLstmActivationFuncToAclLayerInfo(descriptor.m_ActivationFunc);

    for (unsigned int i = 0; i != maxTime; ++i)
    {
        // Set LSTM input and output ITensors depending on:
        // input format (timeMajor) & number of LSTM batches (maxTime).
        arm_compute::ITensorInfo* outputLSTM;
        arm_compute::ITensorInfo* inputLSTM;
        // If there is only one LSTM time major batch, we will not concat OR permute.
        // Set input of LSTM to be first input ITensor.
        // Set output of LSTM to be final output ITensor.
        // LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo.
        if (maxTime == 1 && descriptor.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(aclInputInfo.tensor_shape(), 1U);
            TensorShape outputShape = GetTensorShape(aclOutputInfo.tensor_shape(), 1U);
            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            TensorShape outputShapeShrink({outputShape[1], outputShape[2]});
            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            auto acl_output_shape_shrink = BuildArmComputeTensorShape(outputShapeShrink);
            const_cast<arm_compute::TensorInfo*>(&aclInputInfo)->set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = const_cast<arm_compute::TensorInfo*>(&aclInputInfo);
            const_cast<arm_compute::TensorInfo*>(&aclOutputInfo)->set_tensor_shape(acl_output_shape_shrink);
            outputLSTM = const_cast<arm_compute::TensorInfo*>(&aclOutputInfo);
        }
        // If there is only one LSTM batch major batch, we will not concat, only permute.
        // Set input of LSTM to be output of initial permute.
        // Set output of LSTM to be first element of m_ConcatInputs & use that value later in permute.
        // LSTM output cannot be > 2 dimensions so need to resize its TensorInfo.
        else if (maxTime == 1 && !descriptor.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(aclPermuteOutInfo.tensor_shape(), 1U);
            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            aclPermuteOutInfo.set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = &aclPermuteOutInfo;
            outputLSTM = const_cast<arm_compute::ITensorInfo*>(concatInputsTensorInfosPtr[i]);
        }
        // Batch major AND/OR 2+ LSTM batches so will use concat AND/OR permute later on.
        else
        {
            inputLSTM = splitterOutputsTensorInfosPtr[i];
            outputLSTM = const_cast<arm_compute::ITensorInfo*>(concatInputsTensorInfosPtr[i]);
        }

        statusLSTM = arm_compute::CLLSTMLayer::validate(inputLSTM,
                                                        &aclInputToForgetWeightsInfo,
                                                        &aclInputToCellWeightsInfo,
                                                        &aclInputToOutputWeightsInfo,
                                                        &aclRecurrentToForgetWeightsInfo,
                                                        &aclRecurrentToCellWeightsInfo,
                                                        &aclRecurrentToOutputWeightsInfo,
                                                        &aclForgetGateBiasInfo,
                                                        &aclCellBiasInfo,
                                                        &aclOutputGateBiasInfo,
                                                        &aclOutputStateInInfo,
                                                        &aclCellStateInInfo,
                                                        &aclScratchBufferInfo,
                                                        &aclOutputStateOutInfo,
                                                        &aclCellStateOutInfo,
                                                        outputLSTM,
                                                        lstm_params_info,
                                                        activationLayerInfo,
                                                        cell_threshold,
                                                        projection_threshold);

        if (statusLSTM.error_code() != arm_compute::ErrorCode::OK)
        {
            break;
        }
    }

    //
    // Concat validate
    //

    // Expand dimensions of LSTM outputs adding one empty dimension to fit concatenate inputs.
    TensorShape shape = GetTensorShape(concatInputsTensorInfosPtr[0]->tensor_shape(), 1U);
    TensorShape shapeExpandTimeMajor({1, shape[0], shape[1]});
    TensorShape shapeExpandBatchMajor({shape[0], 1, shape[1]});

    TensorInfo concatOutputTensorInfo = TensorInfo(output);
    concatOutputTensorInfo.SetShape(timeMajorShapeOutput);
    arm_compute::TensorInfo aclConcatOutputTensorInfo = BuildArmComputeTensorInfo(concatOutputTensorInfo);

    if (maxTime != 1) // ACL concat does not work with only one element to concatenate.
    {
        for (unsigned int i = 0; i < maxTime; ++i)
        {
            auto acl_shape_expand = BuildArmComputeTensorShape(shapeExpandTimeMajor);
            concatInputsTensorInfos[i].set_tensor_shape(acl_shape_expand);
        }

        unsigned int aclAxisConcat = CalcAclAxis(numberDimensions, dimension);
        if (!descriptor.m_TimeMajor)
        {
            statusConcat = arm_compute::CLConcatenateLayer::validate(concatInputsTensorInfosPtr,
                                                                     &aclConcatOutputTensorInfo,
                                                                     aclAxisConcat);
        }
        else
        {
            statusConcat = arm_compute::CLConcatenateLayer::validate(concatInputsTensorInfosPtr,
                                                                     &aclOutputInfo,
                                                                     aclAxisConcat);
        }
    }
    // If there is only one LSTM batch, we do not concat and/or permute.
    // Must ensure final output info is expanded to correct batch major dimensions.
    else
    {
        if (!descriptor.m_TimeMajor)
        {
            const_cast<arm_compute::TensorInfo*>(&aclOutputInfo)->set_tensor_shape(
                BuildArmComputeTensorShape(shapeExpandBatchMajor));
        }
        else
        {
            const_cast<arm_compute::TensorInfo*>(&aclOutputInfo)->set_tensor_shape(
                BuildArmComputeTensorShape(shapeExpandTimeMajor));
        }
    }

    //
    // Permute validate
    //
    if (!descriptor.m_TimeMajor)
    {
        // Output is now time major. Permute output back to batch major.
        if (maxTime != 1)
        {
            statusPermute2 = arm_compute::CLPermute::validate(&aclConcatOutputTensorInfo,
                                                              &aclOutputInfo,
                                                              arm_compute::PermutationVector(0U, 2U, 1U));
        }
        else
        {
            statusPermute2 = arm_compute::CLPermute::validate(concatInputsTensorInfosPtr[0],
                                                              &aclOutputInfo,
                                                              arm_compute::PermutationVector(0U, 2U, 1U));
        }
    }

    auto okCode = arm_compute::ErrorCode::OK;
    if (statusPermute1.error_code() == okCode &&
        statusSplit.error_code()    == okCode &&
        statusLSTM.error_code()     == okCode &&
        statusConcat.error_code()   == okCode &&
        statusPermute2.error_code() == okCode)
    {
        return arm_compute::Status(arm_compute::ErrorCode::OK,
                                   "All Unidirectional Sequence LSTM layer validate status OK.");
    }
    else
    {
        return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
                                   "Unidirectional Sequence LSTM layer validate status failed.");
    }
}

void ClUnidirectionalSequenceLstmFloatWorkload::FreeUnusedTensors()
{
    FreeTensorIfUnused(m_InputToInputWeightsTensor);
    FreeTensorIfUnused(m_InputToForgetWeightsTensor);
    FreeTensorIfUnused(m_InputToCellWeightsTensor);
    FreeTensorIfUnused(m_InputToOutputWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
    FreeTensorIfUnused(m_CellToInputWeightsTensor);
    FreeTensorIfUnused(m_CellToForgetWeightsTensor);
    FreeTensorIfUnused(m_CellToOutputWeightsTensor);
    FreeTensorIfUnused(m_InputGateBiasTensor);
    FreeTensorIfUnused(m_ForgetGateBiasTensor);
    FreeTensorIfUnused(m_CellBiasTensor);
    FreeTensorIfUnused(m_OutputGateBiasTensor);
    FreeTensorIfUnused(m_ProjectionWeightsTensor);
    FreeTensorIfUnused(m_ProjectionBiasTensor);
    FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
    FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
    FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
    FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
    FreeTensorIfUnused(m_ScratchBuffer);
}

} // namespace armnn