xref: /aosp_15_r20/external/armnn/src/armnn/layers/LstmParameters.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "LayerWithParameters.hpp"
8 
9 namespace armnn
10 {
11 
12 class ScopedTensorHandle;
13 
14 struct LstmOptLayerNormParameters
15 {
16     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
17     std::shared_ptr<ConstTensorHandle> m_InputLayerNormWeights;
18     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
19     std::shared_ptr<ConstTensorHandle> m_ForgetLayerNormWeights;
20     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
21     std::shared_ptr<ConstTensorHandle> m_CellLayerNormWeights;
22     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
23     std::shared_ptr<ConstTensorHandle> m_OutputLayerNormWeights;
24 };
25 
26 struct LstmOptCifgParameters
27 {
28     /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
29     std::shared_ptr<ConstTensorHandle> m_InputToInputWeights;
30     /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
31     std::shared_ptr<ConstTensorHandle> m_RecurrentToInputWeights;
32     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
33     std::shared_ptr<ConstTensorHandle> m_InputGateBias;
34 };
35 
36 struct LstmOptProjectionParameters
37 {
38     /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
39     std::shared_ptr<ConstTensorHandle> m_ProjectionWeights;
40     /// A unique pointer to represent 1D weights tensor with dimensions [output_size].
41     std::shared_ptr<ConstTensorHandle> m_ProjectionBias;
42 };
43 
44 struct LstmOptPeepholeParameters
45 {
46     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
47     std::shared_ptr<ConstTensorHandle> m_CellToInputWeights;
48     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
49     std::shared_ptr<ConstTensorHandle> m_CellToForgetWeights;
50     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
51     std::shared_ptr<ConstTensorHandle> m_CellToOutputWeights;
52 };
53 
54 struct LstmBasicParameters
55 {
56     /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
57     std::shared_ptr<ConstTensorHandle> m_InputToForgetWeights;
58     /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
59     std::shared_ptr<ConstTensorHandle> m_InputToCellWeights;
60     /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
61     std::shared_ptr<ConstTensorHandle> m_InputToOutputWeights;
62     /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
63     std::shared_ptr<ConstTensorHandle> m_RecurrentToForgetWeights;
64     /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
65     std::shared_ptr<ConstTensorHandle> m_RecurrentToCellWeights;
66     /// A unique pointer to represent 2D weights tensor with dimensions [output_size, num_units].
67     std::shared_ptr<ConstTensorHandle> m_RecurrentToOutputWeights;
68     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
69     std::shared_ptr<ConstTensorHandle> m_ForgetGateBias;
70     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
71     std::shared_ptr<ConstTensorHandle> m_CellBias;
72     /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
73     std::shared_ptr<ConstTensorHandle> m_OutputGateBias;
74 };
75 
76 } // namespace
77