xref: /aosp_15_r20/external/armnn/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmWorkload.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Descriptors.hpp>
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/backends/Workload.hpp>
11 #include <armnn/backends/WorkloadData.hpp>
12 #include "NeonBaseWorkload.hpp"
13 
14 #include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h"
15 #include "arm_compute/runtime/NEON/functions/NEPermute.h"
16 #include "arm_compute/runtime/NEON/functions/NESplit.h"
17 #include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"
18 
19 namespace armnn
20 {
21 
22 class NeonUnidirectionalSequenceLstmWorkload : public NeonBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>
23 {
24 public:
25     NeonUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
26                                            const WorkloadInfo& info);
27     virtual void Execute() const override;
28 
29 private:
30 
31     //
32     // ACL layers required to fully form a Unidirectional Sequence LSTM layer.
33     //
34     mutable std::unique_ptr<arm_compute::NEPermute> m_Permute1;
35     mutable std::unique_ptr<arm_compute::IFunction> m_Splitter;
36     mutable std::vector<std::unique_ptr<arm_compute::NEQLSTMLayer>> m_Layers;
37     mutable std::unique_ptr<arm_compute::NEConcatenateLayer> m_Concat;
38     mutable std::unique_ptr<arm_compute::NEPermute> m_Permute2;
39 
40     //
41     // ACL LSTM arm_compute::Tensors.
42     //
43     std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor;
44     std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor;
45     std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor;
46     std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor;
47     std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor;
48     std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor;
49     std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor;
50     std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor;
51     std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor;
52     std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor;
53     std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor;
54     std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor;
55     std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor;
56     std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor;
57     std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor;
58     std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor;
59     std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor;
60 
61     std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
62     std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
63     std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
64     std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
65 
66     //
67     // Additional ACL arm_compute::Tensors and std::vector<arm_compute::Tensor>.
68     // Required to perform splitting, concatenation and permutations.
69     //
70     arm_compute::Tensor m_PermuteFirstOut;
71     std::vector<arm_compute::Tensor> m_SplitterOutputsTensors;
72     std::vector<arm_compute::Tensor> m_ConcatInputsTensors;
73     std::vector<arm_compute::ITensor*> m_SplitterOutputs;
74     std::vector<const arm_compute::ITensor*> m_ConcatInputs;
75     arm_compute::Tensor concat_out;
76 
77     void FreeUnusedTensors();
78 };
79 
80 arm_compute::Status
81 NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input,
82                                                const TensorInfo& outputStateIn,
83                                                const TensorInfo& cellStateIn,
84                                                const TensorInfo& outputStateOut,
85                                                const TensorInfo& cellStateOut,
86                                                const TensorInfo& output,
87                                                const UnidirectionalSequenceLstmDescriptor& descriptor,
88                                                const LstmInputParamsInfo& paramsInfo);
89 
90 } //namespace armnn
91