//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "NeonUnidirectionalSequenceLstmFloatWorkload.hpp"
#include "NeonWorkloadUtils.hpp"

#include <aclCommon/ArmComputeUtils.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <armnn/utility/NumericCast.hpp>
#include <armnnUtils/Permute.hpp>
#include <neon/test/NeonWorkloadFactoryHelper.hpp>
#include <backendsCommon/WorkloadUtils.hpp>

#include "neon/NeonTensorHandle.hpp"

namespace
{
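// ACL orders tensor dimensions in reverse relative to Arm NN, so an Arm NN axis has to be
// flipped before it is passed to an ACL function. For example, Arm NN axis 0 of a
// 3-dimensional tensor maps to ACL axis (3 - 0) - 1 = 2.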
unsigned int CalcAclAxis(unsigned int numDimensions, unsigned int axis)
{
    return (numDimensions - axis) - 1;
}
} //namespace

namespace armnn
{
using namespace armcomputetensorutils;

NeonUnidirectionalSequenceLstmFloatWorkload::NeonUnidirectionalSequenceLstmFloatWorkload
    (const UnidirectionalSequenceLstmQueueDescriptor& descriptor, const WorkloadInfo& info)
    : FloatWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
{
    // Report Profiling Details
    ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonUnidirectionalSequenceLstmFloatWorkload_Construct",
                                         descriptor.m_Parameters,
                                         info,
                                         GetGuid());

    const arm_compute::ITensor& input = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
    arm_compute::ITensor& output = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();

    TensorInfo inputInfo = info.m_InputTensorInfos[0];
    TensorInfo outputInfo = info.m_OutputTensorInfos[0];

    arm_compute::DataType armComputeDataType = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetDataType();
    armnn::DataType armnnDataType = GetArmNNDataType(armComputeDataType);

    TensorShape inputLayerShape = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetShape();
    TensorShape cellStateLayerShape = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetShape();
    TensorShape outputLayerShape = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[2])->GetShape();

    unsigned int maxTime = m_Data.m_Parameters.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1];
    unsigned int batchSize = m_Data.m_Parameters.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0];
    unsigned int inputSize = inputLayerShape[2];
    unsigned int outputSize = outputLayerShape[2];
    unsigned int numUnits = cellStateLayerShape[1];

    const TensorShape timeMajorShapeInput({maxTime, batchSize, inputSize});
    const TensorShape timeMajorShapeOutput({maxTime, batchSize, outputSize});
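    // Illustration (batch major, i.e. m_TimeMajor == false): an input of shape
    // [batchSize, maxTime, inputSize] is permuted below to the time major shape
    // [maxTime, batchSize, inputSize] so that it can be split into one 2D tensor per timestep.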

    //
    // Permute: performed if Unidirectional Sequence Layer inputs/outputs are in batch major format.
    //
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        std::unique_ptr<arm_compute::NEPermute> layer(new arm_compute::NEPermute());

        TensorInfo permuteOutInfo = inputInfo;
        permuteOutInfo.SetShape(timeMajorShapeInput);
        BuildArmComputeTensor(m_PermuteFirstOut, permuteOutInfo);
        armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermuteFirstOut);

        // Permute to time major format.
        layer->configure(&input, &m_PermuteFirstOut, arm_compute::PermutationVector(0U, 2U, 1U));
        m_Permute1.reset(layer.release());
    }

    //
    // Split and Concat Tensors
    //
    for (unsigned int i = 0; i < maxTime; ++i)
    {
        arm_compute::Tensor splitter_out;
        arm_compute::Tensor concat_in;

        auto splitterTensorInfo = inputInfo;
        auto concatTensorInfo = outputInfo;
        splitterTensorInfo.SetShape({batchSize, inputSize});
        concatTensorInfo.SetShape({batchSize, outputSize});
        BuildArmComputeTensor(splitter_out, splitterTensorInfo);
        BuildArmComputeTensor(concat_in, concatTensorInfo);

        armcomputetensorutils::InitialiseArmComputeTensorEmpty(splitter_out);
        armcomputetensorutils::InitialiseArmComputeTensorEmpty(concat_in);

        // append to std::vector<arm_compute::Tensor>
        m_SplitterOutputsTensors.push_back(std::move(splitter_out));
        m_ConcatInputsTensors.push_back(std::move(concat_in));
    }

    for (unsigned int i = 0; i < maxTime; ++i)
    {
        // append to std::vector<arm_compute::ITensor*>
        m_SplitterOutputs.push_back(&m_SplitterOutputsTensors[i]);
        m_ConcatInputs.push_back(&m_ConcatInputsTensors[i]);
    }

    //
    // Split
    //
    unsigned int numberDimensions = 3;
    unsigned int dimension = 0; // splitting on 0-dimension (i.e. maxTime dimension)
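    // Each view of the splitter selects one timestep: view i starts at offset i along the
    // maxTime dimension and has size {1, batchSize, inputSize}, matching the 2D per-timestep
    // tensors built above.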

    if (maxTime != 1) // ACL split does not work with only one element to split.
    {
        ViewsDescriptor splitterDesc(maxTime, numberDimensions);
        unsigned int splitterDimSizes[3] = {1, batchSize, inputSize};
        for (unsigned int outputIdx = 0u; outputIdx < maxTime; ++outputIdx)
        {
            splitterDesc.SetViewOriginCoord(outputIdx, dimension, splitterDimSizes[dimension] * outputIdx);
            for (unsigned int dimIdx = 0u; dimIdx < numberDimensions; ++dimIdx)
            {
                splitterDesc.SetViewSize(outputIdx, dimIdx, splitterDimSizes[dimIdx]);
            }
        }

        std::set<unsigned int> splitAxis = ComputeSplitAxis(splitterDesc, timeMajorShapeInput);

        std::unique_ptr<arm_compute::NESplit> split_layer(new arm_compute::NESplit());
        unsigned int aclAxisSplit = CalcAclAxis(splitterDesc.GetNumDimensions(), *splitAxis.begin());
        if (!m_Data.m_Parameters.m_TimeMajor)
        {
            split_layer->configure(&m_PermuteFirstOut, m_SplitterOutputs, aclAxisSplit);
        }
        else
        {
            split_layer->configure(&input, m_SplitterOutputs, aclAxisSplit);
        }

        split_layer->prepare();
        m_Splitter.reset(split_layer.release());
    }

    //
    // Lstm
    //
    arm_compute::LSTMParams<arm_compute::ITensor> lstm_param;
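    // LSTMParams carries the optional tensors for the CIFG, projection, peephole and
    // layer normalization variants; only the parameter sets enabled in the descriptor
    // are populated below.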

    m_InputToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());

    m_InputToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());

    m_InputToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());

    m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());

    m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());

    m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());

    m_ForgetGateBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());

    m_CellBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());

    m_OutputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());

    // For future reference: check the AndroidNN API for the logic here.
    if (!m_Data.m_Parameters.m_CifgEnabled)
    {
        m_InputToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());

        m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());

        m_CellToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        if (m_Data.m_CellToInputWeights != nullptr)
        {
            BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
        }

        m_InputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());

        lstm_param.set_cifg_params(m_InputToInputWeightsTensor.get(),
                                   m_RecurrentToInputWeightsTensor.get(),
                                   m_Data.m_CellToInputWeights ? m_CellToInputWeightsTensor.get() : nullptr,
                                   m_InputGateBiasTensor.get());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        m_ProjectionWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());

        m_ProjectionBiasTensor = std::make_unique<arm_compute::Tensor>();
        if (m_Data.m_ProjectionBias != nullptr)
        {
            BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
        }

        lstm_param.set_projection_params(m_ProjectionWeightsTensor.get(),
                                         m_Data.m_ProjectionBias ? m_ProjectionBiasTensor.get() : nullptr);
    }

    if (m_Data.m_Parameters.m_PeepholeEnabled)
    {
        m_CellToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());

        m_CellToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());

        lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
    }

    if (m_Data.m_Parameters.m_LayerNormEnabled)
    {
        m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        if (!m_Data.m_Parameters.m_CifgEnabled)
        {
            BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
        }

        m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());

        m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());

        m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());

        auto inputNormWeightTensor = m_Data.m_Parameters.m_CifgEnabled ? nullptr : m_InputLayerNormWeightsTensor.get();
        lstm_param.set_layer_normalization_params(inputNormWeightTensor,
                                                  m_ForgetLayerNormWeightsTensor.get(),
                                                  m_CellLayerNormWeightsTensor.get(),
                                                  m_OutputLayerNormWeightsTensor.get());
    }

    arm_compute::ITensor& output_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
    arm_compute::ITensor& cell_state_in   = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();

    arm_compute::ITensor& output_state_out = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
    arm_compute::ITensor& cell_state_out = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
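    // Note: the state "out" tensors intentionally alias the state "in" handles. ACL updates
    // the state in place, so each per-timestep LSTM layer configured below consumes the
    // state written by the previous timestep.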

    m_ScratchBuffer = std::make_unique<arm_compute::Tensor>();
    if (m_Data.m_Parameters.m_CifgEnabled)
    {
        // scratch_buffer [num_units * 3, batch_size] with CIFG
        BuildArmComputeTensor(*m_ScratchBuffer, TensorInfo({batchSize, numUnits * 3}, armnnDataType));
    }
    else
    {
        // scratch_buffer [num_units * 4, batch_size] without CIFG
        BuildArmComputeTensor(*m_ScratchBuffer, TensorInfo({batchSize, numUnits * 4}, armnnDataType));
    }
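    // The scratch buffer holds one numUnits-wide slice of intermediate data per gate:
    // forget, cell and output always, plus the input gate when CIFG is disabled.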

    // The clipping thresholds need to be set to negative values to be compatible with ACL.
    float cell_threshold       = m_Data.m_Parameters.m_ClippingThresCell;
    float projection_threshold = m_Data.m_Parameters.m_ClippingThresProj;

    // Build the ACL ActivationLayerInfo from the LSTM activation function; the conversion
    // helper covers the five supported cases.
    arm_compute::ActivationLayerInfo activationLayerInfo =
        ConvertLstmActivationFuncToAclLayerInfo(m_Data.m_Parameters.m_ActivationFunc);

    for (unsigned int i = 0; i != maxTime; ++i)
    {
        // Set LSTM input and output ITensors depending on:
        // input format (timeMajor) & number of LSTM batches (maxTime).
        arm_compute::ITensor* outputLSTM;
        arm_compute::ITensor* inputLSTM;

        // If there is only one LSTM time major batch, we will not concat OR permute.
        // Set input of LSTM to be first input ITensor.
        // Set output of LSTM to be final output ITensor.
        // LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo.
        if (maxTime == 1 && m_Data.m_Parameters.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(input.info()->tensor_shape(), 1U);
            TensorShape outputShape = GetTensorShape(output.info()->tensor_shape(), 1U);

            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            TensorShape outputShapeShrink({outputShape[1], outputShape[2]});

            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            auto acl_output_shape_shrink = BuildArmComputeTensorShape(outputShapeShrink);

            input.info()->set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = const_cast<arm_compute::ITensor*>(&input);

            output.info()->set_tensor_shape(acl_output_shape_shrink);
            outputLSTM = &output;
        }
        // If there is only one LSTM batch major batch, we will not concat, only permute.
        // Set input of LSTM to be output of initial permute.
        // Set output of LSTM to be first element of m_ConcatInputs & use that value later in permute.
        // LSTM output cannot be > 2 dimensions so need to resize its TensorInfo.
        else if (maxTime == 1 && !m_Data.m_Parameters.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(m_PermuteFirstOut.info()->tensor_shape(), 1U);
            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            m_PermuteFirstOut.info()->set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = &m_PermuteFirstOut;

            outputLSTM = const_cast<arm_compute::ITensor*>(m_ConcatInputs[i]);
        }
        // Batch major AND/OR 2+ LSTM batches so will use concat AND/OR permute later on.
        else
        {
            inputLSTM = m_SplitterOutputs[i];
            outputLSTM = const_cast<arm_compute::ITensor*>(m_ConcatInputs[i]);
        }

        std::unique_ptr<arm_compute::NELSTMLayer> lstm_layer(new arm_compute::NELSTMLayer());
        lstm_layer->configure(inputLSTM,
                              m_InputToForgetWeightsTensor.get(),
                              m_InputToCellWeightsTensor.get(),
                              m_InputToOutputWeightsTensor.get(),
                              m_RecurrentToForgetWeightsTensor.get(),
                              m_RecurrentToCellWeightsTensor.get(),
                              m_RecurrentToOutputWeightsTensor.get(),
                              m_ForgetGateBiasTensor.get(),
                              m_CellBiasTensor.get(),
                              m_OutputGateBiasTensor.get(),
                              &output_state_in,
                              &cell_state_in,
                              m_ScratchBuffer.get(),
                              &output_state_out,
                              &cell_state_out,
                              outputLSTM,
                              lstm_param,
                              activationLayerInfo,
                              cell_threshold,
                              projection_threshold);

        m_Layers.emplace_back(std::move(lstm_layer));
    }
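    // One NELSTMLayer is configured per timestep. All timesteps share the same weight, bias
    // and state tensors; only the input/output ITensors differ, so the recurrence is carried
    // through the shared state tensors rather than through explicit inter-layer connections.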

    armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);

    InitializeArmComputeTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
    InitializeArmComputeTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
    InitializeArmComputeTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
    InitializeArmComputeTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
    InitializeArmComputeTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
    InitializeArmComputeTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
    InitializeArmComputeTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
    InitializeArmComputeTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
    InitializeArmComputeTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);

    if (!m_Data.m_Parameters.m_CifgEnabled)
    {
        InitializeArmComputeTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
        InitializeArmComputeTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
        if (m_Data.m_CellToInputWeights != nullptr)
        {
            InitializeArmComputeTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
        }
        InitializeArmComputeTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias);
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        InitializeArmComputeTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
        if (m_Data.m_ProjectionBias != nullptr)
        {
            InitializeArmComputeTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
        }
    }

    if (m_Data.m_Parameters.m_PeepholeEnabled)
    {
        InitializeArmComputeTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
        InitializeArmComputeTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
    }

    if (m_Data.m_Parameters.m_LayerNormEnabled)
    {
        if (!m_Data.m_Parameters.m_CifgEnabled)
        {
            InitializeArmComputeTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
        }
        InitializeArmComputeTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
        InitializeArmComputeTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
        InitializeArmComputeTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
    }

    // Force Compute Library to perform the necessary copying and reshaping.
    // After this, delete all the input tensors that will no longer be needed.
    for (uint32_t i = 0; i < m_Layers.size(); ++i)
    {
        m_Layers[i]->prepare();
    }

    //
    // Concat
    //

    // Expand dimensions of LSTM outputs adding one empty dimension to fit concatenate inputs.
    TensorShape shape = GetTensorShape(m_ConcatInputs[0]->info()->tensor_shape(), 1U);
    TensorShape shapeExpandTimeMajor({1, shape[0], shape[1]});
    TensorShape shapeExpandBatchMajor({shape[0], 1, shape[1]});
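    // Example: a per-timestep LSTM output of shape [batchSize, outputSize] becomes
    // [1, batchSize, outputSize] (time major) or [batchSize, 1, outputSize] (batch major),
    // restoring the timestep dimension that was dropped for the 2D LSTM tensors.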

    if (maxTime != 1) // ACL concat does not work with only one element to concatenate.
    {
        for (unsigned int i = 0; i < maxTime; ++i)
        {
            m_ConcatInputs[i]->info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandTimeMajor));
        }

        ConcatDescriptor concatDescriptor(maxTime, numberDimensions); // maxTime = num inputs (aka. number of views).
        for (unsigned int inputIdx = 0u; inputIdx < maxTime; ++inputIdx)
        {
            concatDescriptor.SetViewOriginCoord(inputIdx, dimension, inputIdx);
            concatDescriptor.SetConcatAxis(dimension);
        }

        m_Concat.reset(new arm_compute::NEConcatenateLayer());
        unsigned int aclAxisConcat = CalcAclAxis(concatDescriptor.GetNumDimensions(), concatDescriptor.GetConcatAxis());
        if (!m_Data.m_Parameters.m_TimeMajor)
        {
            TensorInfo concatOutputTensorInfo = outputInfo;
            concatOutputTensorInfo.SetShape(timeMajorShapeOutput);
            BuildArmComputeTensor(concat_out, concatOutputTensorInfo);
            armcomputetensorutils::InitialiseArmComputeTensorEmpty(concat_out);

            m_Concat->configure(m_ConcatInputs, &concat_out, aclAxisConcat);
        }
        else
        {
            m_Concat->configure(m_ConcatInputs, &output, aclAxisConcat);
        }

        m_Concat->prepare();
    }
    // If only one LSTM batch, we do not concat and/or permute.
    // Must ensure final output info is expanded to correct batch major dimensions.
    else
    {
        if (!m_Data.m_Parameters.m_TimeMajor)
        {
            output.info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandBatchMajor));
        }
        else
        {
            output.info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandTimeMajor));
        }
    }

    //
    // Permute: only done if input/output are in batch major format.
    //
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Output is now time major. Permute it back to batch major.
        std::unique_ptr<arm_compute::NEPermute> layer(new arm_compute::NEPermute());
        if (maxTime != 1)
        {
            layer->configure(&concat_out, &output, arm_compute::PermutationVector(0U, 2U, 1U));
        }
        else
        {
            layer->configure(m_ConcatInputs[0], &output, arm_compute::PermutationVector(0U, 2U, 1U));
        }
        m_Permute2.reset(layer.release());
    }

    FreeUnusedTensors();
}

void NeonUnidirectionalSequenceLstmFloatWorkload::Execute() const
{
    ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonUnidirectionalSequenceLstmFloatWorkload_Execute", GetGuid());
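    // Run the stages in the order they were configured: permute to time major (if needed),
    // split into timesteps, run each per-timestep LSTM, concatenate the results and permute
    // back to batch major (if needed). Stages that were not configured are null and skipped.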
    if (m_Permute1)
    {
        m_Permute1->run();
    }
    if (m_Splitter)
    {
        m_Splitter->run();
    }
    for (uint32_t i = 0; i < m_Layers.size(); ++i)
    {
        m_Layers[i]->run();
    }
    if (m_Concat)
    {
        m_Concat->run();
    }
    if (m_Permute2)
    {
        m_Permute2->run();
    }
}

arm_compute::Status
NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
                                                    const TensorInfo& outputStateIn,
                                                    const TensorInfo& cellStateIn,
                                                    const TensorInfo& outputStateOut,
                                                    const TensorInfo& cellStateOut,
                                                    const TensorInfo& output,
                                                    const UnidirectionalSequenceLstmDescriptor& descriptor,
                                                    const LstmInputParamsInfo& paramsInfo)
{
    TensorShape inputLayerShape = input.GetShape();
    TensorShape outputLayerShape = outputStateIn.GetShape();

    unsigned int maxTime = descriptor.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1];
    unsigned int batchSize = descriptor.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0];
    unsigned int inputSize = inputLayerShape[2];
    unsigned int outputSize = outputLayerShape[2];

    const TensorShape timeMajorShapeInput({maxTime, batchSize, inputSize});
    const TensorShape timeMajorShapeOutput({maxTime, batchSize, outputSize});

    arm_compute::Status statusPermute1 = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                             "Permute1 status");
    arm_compute::Status statusSplit = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                          "Split status");
    arm_compute::Status statusLSTM = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                         "LSTM status");
    arm_compute::Status statusConcat = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                           "Concat status");
    arm_compute::Status statusPermute2 = arm_compute::Status(arm_compute::ErrorCode::OK,
                                                             "Permute2 status");
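    // Each status defaults to OK so that stages which are skipped for a given configuration
    // (e.g. no permute when the input is already time major) do not fail the combined check
    // at the end of this function.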

    const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
    const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);

    //
    // Permute validate
    //
    TensorInfo permuteOutInfo = TensorInfo(input);
    arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo);
    if (!descriptor.m_TimeMajor)
    {
        statusPermute1 = arm_compute::NEPermute::validate(&aclInputInfo,
                                                          &aclPermuteOutInfo,
                                                          arm_compute::PermutationVector(0U, 2U, 1U));
    }

    //
    // Split and Concat Tensors validate
    //
    std::vector<arm_compute::TensorInfo> splitterOutputsTensorInfos;
    std::vector<arm_compute::TensorInfo> concatInputsTensorInfos;
    std::vector<arm_compute::ITensorInfo*> splitterOutputsTensorInfosPtr;
    std::vector<const arm_compute::ITensorInfo*> concatInputsTensorInfosPtr;
    splitterOutputsTensorInfos.reserve(maxTime);
    concatInputsTensorInfos.reserve(maxTime);
    for (unsigned int i = 0; i < maxTime; ++i)
    {
        arm_compute::TensorInfo splitter_out;
        arm_compute::TensorInfo concat_in;

        auto splitterTensorInfo = TensorInfo(input);
        auto concatTensorInfo   = TensorInfo(output);
        splitterTensorInfo.SetShape({batchSize, inputSize});
        concatTensorInfo.SetShape({batchSize, outputSize});

        arm_compute::TensorInfo aclSplitterTensorInfo
            = armcomputetensorutils::BuildArmComputeTensorInfo(splitterTensorInfo);
        arm_compute::TensorInfo aclConcatTensorInfo
            = armcomputetensorutils::BuildArmComputeTensorInfo(concatTensorInfo);

        splitterOutputsTensorInfos.emplace_back(aclSplitterTensorInfo);
        concatInputsTensorInfos.emplace_back(aclConcatTensorInfo);
        splitterOutputsTensorInfosPtr.emplace_back(&splitterOutputsTensorInfos[i]);
        concatInputsTensorInfosPtr.emplace_back(&concatInputsTensorInfos[i]);
    }

    //
    // Split validate
    //
    unsigned int numberDimensions = 3;
    unsigned int dimension = 0; // splitting on 0-dimension (i.e. maxTime dimension)
    unsigned int aclAxisSplit = CalcAclAxis(numberDimensions, dimension);
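    // Splitting on Arm NN dimension 0 of a 3D tensor: CalcAclAxis(3, 0) == 2, i.e. the
    // outermost Arm NN dimension is the innermost ACL axis.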

    if (maxTime != 1) // ACL split does not work with only one element to split.
    {
        if (!descriptor.m_TimeMajor)
        {
            statusSplit = arm_compute::NESplit::validate(&aclPermuteOutInfo,
                                                         splitterOutputsTensorInfosPtr,
                                                         aclAxisSplit);
        }
        else
        {
            statusSplit = arm_compute::NESplit::validate(&aclInputInfo, splitterOutputsTensorInfosPtr, aclAxisSplit);
        }
    }

    //
    // LSTM validate
    //

    arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;

    const TensorInfo scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType());

    // The inputs and outputs
    const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
    const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
    const arm_compute::TensorInfo aclScratchBufferInfo = BuildArmComputeTensorInfo(scratchBuffer);
    const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
    const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);

    // Basic parameters
    const arm_compute::TensorInfo aclInputToForgetWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
    const arm_compute::TensorInfo aclInputToCellWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
    const arm_compute::TensorInfo aclInputToOutputWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
    const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
    const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
    const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
    const arm_compute::TensorInfo aclForgetGateBiasInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
    const arm_compute::TensorInfo aclCellBiasInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
    const arm_compute::TensorInfo aclOutputGateBiasInfo
                                      = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());

    arm_compute::TensorInfo aclInputToInputWeightsInfo;
    arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
    arm_compute::TensorInfo aclCellToInputWeightsInfo;
    arm_compute::TensorInfo aclInputGateBiasInfo;
    arm_compute::TensorInfo aclProjectionWeightsInfo;
    arm_compute::TensorInfo aclProjectionBiasInfo;
    arm_compute::TensorInfo aclCellToForgetWeightsInfo;
    arm_compute::TensorInfo aclCellToOutputWeightsInfo;

    arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
    arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
    arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
    arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;

    if (!descriptor.m_CifgEnabled)
    {
        if (descriptor.m_PeepholeEnabled)
        {
            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
        }
        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());

        lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo,
                                         &aclRecurrentToInputWeightsInfo,
                                         descriptor.m_PeepholeEnabled ? &aclCellToInputWeightsInfo : nullptr,
                                         &aclInputGateBiasInfo);
    }

    if (descriptor.m_ProjectionEnabled)
    {
        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias());
        }
        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());

        lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
                                               paramsInfo.m_ProjectionBias ? &aclProjectionBiasInfo : nullptr);
    }

    if (descriptor.m_PeepholeEnabled)
    {
        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());

        lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
    }

    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
        }
        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
        aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());

        lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ? nullptr :
                                                        &aclInputLayerNormWeightsInfo,
                                                        &aclForgetLayerNormWeightsInfo,
                                                        &aclCellLayerNormWeightsInfo,
                                                        &aclOutputLayerNormWeightsInfo);
    }

    // The clipping thresholds need to be set to negative values to be compatible with ACL.
    float cell_threshold = descriptor.m_ClippingThresCell;
    float projection_threshold = descriptor.m_ClippingThresProj;

    arm_compute::ActivationLayerInfo activationLayerInfo =
        ConvertLstmActivationFuncToAclLayerInfo(descriptor.m_ActivationFunc);

    for (unsigned int i = 0; i != maxTime; ++i)
    {
        // Set LSTM input and output ITensors depending on:
        // input format (timeMajor) & number of LSTM batches (maxTime).
        arm_compute::ITensorInfo* outputLSTM;
        arm_compute::ITensorInfo* inputLSTM;

        // If there is only one LSTM time major batch, we will not concat OR permute.
        // Set input of LSTM to be first input ITensor.
        // Set output of LSTM to be final output ITensor.
        // LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo.
        if (maxTime == 1 && descriptor.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(aclInputInfo.tensor_shape(), 1U);
            TensorShape outputShape = GetTensorShape(aclOutputInfo.tensor_shape(), 1U);

            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            TensorShape outputShapeShrink({outputShape[1], outputShape[2]});

            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            auto acl_output_shape_shrink = BuildArmComputeTensorShape(outputShapeShrink);

            const_cast<arm_compute::TensorInfo*>(&aclInputInfo)->set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = const_cast<arm_compute::TensorInfo*>(&aclInputInfo);

            const_cast<arm_compute::TensorInfo*>(&aclOutputInfo)->set_tensor_shape(acl_output_shape_shrink);
            outputLSTM = const_cast<arm_compute::TensorInfo*>(&aclOutputInfo);
        }
        // If there is only one LSTM batch major batch, we will not concat, only permute.
        // Set input of LSTM to be output of initial permute.
        // Set output of LSTM to be first element of m_ConcatInputs & use that value later in permute.
        // LSTM output cannot be > 2 dimensions so need to resize its TensorInfo.
        else if (maxTime == 1 && !descriptor.m_TimeMajor)
        {
            TensorShape inputShape = GetTensorShape(aclPermuteOutInfo.tensor_shape(), 1U);
            TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
            auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
            aclPermuteOutInfo.set_tensor_shape(acl_input_shape_shrink);
            inputLSTM = &aclPermuteOutInfo;

            outputLSTM = const_cast<arm_compute::ITensorInfo*>(concatInputsTensorInfosPtr[i]);
        }
        // Batch major AND/OR 2+ LSTM batches so will use concat AND/OR permute later on.
        else
        {
            inputLSTM = splitterOutputsTensorInfosPtr[i];
            outputLSTM = const_cast<arm_compute::ITensorInfo*>(concatInputsTensorInfosPtr[i]);
        }

        statusLSTM = arm_compute::NELSTMLayer::validate(inputLSTM,
                                                        &aclInputToForgetWeightsInfo,
                                                        &aclInputToCellWeightsInfo,
                                                        &aclInputToOutputWeightsInfo,
                                                        &aclRecurrentToForgetWeightsInfo,
                                                        &aclRecurrentToCellWeightsInfo,
                                                        &aclRecurrentToOutputWeightsInfo,
                                                        &aclForgetGateBiasInfo,
                                                        &aclCellBiasInfo,
                                                        &aclOutputGateBiasInfo,
                                                        &aclOutputStateInInfo,
                                                        &aclCellStateInInfo,
                                                        &aclScratchBufferInfo,
                                                        &aclOutputStateOutInfo,
                                                        &aclCellStateOutInfo,
                                                        outputLSTM,
                                                        lstm_params_info,
                                                        activationLayerInfo,
                                                        cell_threshold,
                                                        projection_threshold);

        if (statusLSTM.error_code() != arm_compute::ErrorCode::OK)
        {
            break;
        }
    }

    //
    // Concat validate
    //

    // Expand dimensions of LSTM outputs adding one empty dimension to fit concatenate inputs.
    TensorShape shape = GetTensorShape(concatInputsTensorInfosPtr[0]->tensor_shape(), 1U);
    TensorShape shapeExpandTimeMajor({1, shape[0], shape[1]});
    TensorShape shapeExpandBatchMajor({shape[0], 1, shape[1]});

    TensorInfo concatOutputTensorInfo = TensorInfo(output);
    concatOutputTensorInfo.SetShape(timeMajorShapeOutput);
    arm_compute::TensorInfo aclConcatOutputTensorInfo = BuildArmComputeTensorInfo(concatOutputTensorInfo);

    if (maxTime != 1) // ACL concat does not work with only one element to concatenate.
    {
        for (unsigned int i = 0; i < maxTime; ++i)
        {
            auto acl_shape_expand = BuildArmComputeTensorShape(shapeExpandTimeMajor);
            concatInputsTensorInfos[i].set_tensor_shape(acl_shape_expand);
        }

        unsigned int aclAxisConcat = CalcAclAxis(numberDimensions, dimension);
        if (!descriptor.m_TimeMajor)
        {
            statusConcat = arm_compute::NEConcatenateLayer::validate(concatInputsTensorInfosPtr,
                                                                     &aclConcatOutputTensorInfo,
                                                                     aclAxisConcat);
        }
        else
        {
            statusConcat = arm_compute::NEConcatenateLayer::validate(concatInputsTensorInfosPtr,
                                                                     &aclOutputInfo,
                                                                     aclAxisConcat);
        }
    }
    // If only one LSTM batch, we do not concat and/or permute.
    // Must ensure final output info is expanded to correct batch major dimensions.
    else
    {
        if (!descriptor.m_TimeMajor)
        {
            const_cast<arm_compute::TensorInfo*>(&aclOutputInfo)->set_tensor_shape(
                BuildArmComputeTensorShape(shapeExpandBatchMajor));
        }
        else
        {
            const_cast<arm_compute::TensorInfo*>(&aclOutputInfo)->set_tensor_shape(
                BuildArmComputeTensorShape(shapeExpandTimeMajor));
        }
    }

    //
    // Permute validate
    //
    if (!descriptor.m_TimeMajor)
    {
        // Output is now time major. Permute it back to batch major.
        if (maxTime != 1)
        {
            statusPermute2 = arm_compute::NEPermute::validate(&aclConcatOutputTensorInfo,
                                                              &aclOutputInfo,
                                                              arm_compute::PermutationVector(0U, 2U, 1U));
        }
        else
        {
            statusPermute2 = arm_compute::NEPermute::validate(concatInputsTensorInfosPtr[0],
                                                              &aclOutputInfo,
                                                              arm_compute::PermutationVector(0U, 2U, 1U));
        }
    }

    auto okCode = arm_compute::ErrorCode::OK;
    if (statusPermute1.error_code() == okCode &&
        statusSplit.error_code()    == okCode &&
        statusLSTM.error_code()     == okCode &&
        statusConcat.error_code()   == okCode &&
        statusPermute2.error_code() == okCode)
    {
        return arm_compute::Status(arm_compute::ErrorCode::OK,
                                   "All Unidirectional Sequence LSTM layer validate status OK.");
    }
    else
    {
        return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
                                   "Unidirectional Sequence LSTM layer validate status failed.");
    }
}

void NeonUnidirectionalSequenceLstmFloatWorkload::FreeUnusedTensors()
{
    FreeTensorIfUnused(m_InputToInputWeightsTensor);
    FreeTensorIfUnused(m_InputToForgetWeightsTensor);
    FreeTensorIfUnused(m_InputToCellWeightsTensor);
    FreeTensorIfUnused(m_InputToOutputWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
    FreeTensorIfUnused(m_CellToInputWeightsTensor);
    FreeTensorIfUnused(m_CellToForgetWeightsTensor);
    FreeTensorIfUnused(m_CellToOutputWeightsTensor);
    FreeTensorIfUnused(m_InputGateBiasTensor);
    FreeTensorIfUnused(m_ForgetGateBiasTensor);
    FreeTensorIfUnused(m_CellBiasTensor);
    FreeTensorIfUnused(m_OutputGateBiasTensor);
    FreeTensorIfUnused(m_ProjectionWeightsTensor);
    FreeTensorIfUnused(m_ProjectionBiasTensor);
    FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
    FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
    FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
    FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
    FreeTensorIfUnused(m_ScratchBuffer);
}

} //namespace armnn