xref: /aosp_15_r20/external/armnn/include/armnn/QuantizedLstmParams.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "TensorFwd.hpp"
8 #include "Exceptions.hpp"
9 
10 namespace armnn
11 {
12 
13 struct QuantizedLstmInputParams
14 {
QuantizedLstmInputParamsarmnn::QuantizedLstmInputParams15     QuantizedLstmInputParams()
16         : m_InputToInputWeights(nullptr)
17         , m_InputToForgetWeights(nullptr)
18         , m_InputToCellWeights(nullptr)
19         , m_InputToOutputWeights(nullptr)
20 
21         , m_RecurrentToInputWeights(nullptr)
22         , m_RecurrentToForgetWeights(nullptr)
23         , m_RecurrentToCellWeights(nullptr)
24         , m_RecurrentToOutputWeights(nullptr)
25 
26         , m_InputGateBias(nullptr)
27         , m_ForgetGateBias(nullptr)
28         , m_CellBias(nullptr)
29         , m_OutputGateBias(nullptr)
30     {
31     }
32 
33     const ConstTensor* m_InputToInputWeights;
34     const ConstTensor* m_InputToForgetWeights;
35     const ConstTensor* m_InputToCellWeights;
36     const ConstTensor* m_InputToOutputWeights;
37 
38     const ConstTensor* m_RecurrentToInputWeights;
39     const ConstTensor* m_RecurrentToForgetWeights;
40     const ConstTensor* m_RecurrentToCellWeights;
41     const ConstTensor* m_RecurrentToOutputWeights;
42 
43     const ConstTensor* m_InputGateBias;
44     const ConstTensor* m_ForgetGateBias;
45     const ConstTensor* m_CellBias;
46     const ConstTensor* m_OutputGateBias;
47 
Derefarmnn::QuantizedLstmInputParams48     const ConstTensor& Deref(const ConstTensor* tensorPtr) const
49     {
50         if (tensorPtr != nullptr)
51         {
52             const ConstTensor &temp = *tensorPtr;
53             return temp;
54         }
55         throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer");
56     }
57 
GetInputToInputWeightsarmnn::QuantizedLstmInputParams58     const ConstTensor& GetInputToInputWeights() const
59     {
60         return Deref(m_InputToInputWeights);
61     }
62 
GetInputToForgetWeightsarmnn::QuantizedLstmInputParams63     const ConstTensor& GetInputToForgetWeights() const
64     {
65         return Deref(m_InputToForgetWeights);
66     }
67 
GetInputToCellWeightsarmnn::QuantizedLstmInputParams68     const ConstTensor& GetInputToCellWeights() const
69     {
70         return Deref(m_InputToCellWeights);
71     }
72 
GetInputToOutputWeightsarmnn::QuantizedLstmInputParams73     const ConstTensor& GetInputToOutputWeights() const
74     {
75         return Deref(m_InputToOutputWeights);
76     }
77 
GetRecurrentToInputWeightsarmnn::QuantizedLstmInputParams78     const ConstTensor& GetRecurrentToInputWeights() const
79     {
80         return Deref(m_RecurrentToInputWeights);
81     }
82 
GetRecurrentToForgetWeightsarmnn::QuantizedLstmInputParams83     const ConstTensor& GetRecurrentToForgetWeights() const
84     {
85         return Deref(m_RecurrentToForgetWeights);
86     }
87 
GetRecurrentToCellWeightsarmnn::QuantizedLstmInputParams88     const ConstTensor& GetRecurrentToCellWeights() const
89     {
90         return Deref(m_RecurrentToCellWeights);
91     }
92 
GetRecurrentToOutputWeightsarmnn::QuantizedLstmInputParams93     const ConstTensor& GetRecurrentToOutputWeights() const
94     {
95         return Deref(m_RecurrentToOutputWeights);
96     }
97 
GetInputGateBiasarmnn::QuantizedLstmInputParams98     const ConstTensor& GetInputGateBias() const
99     {
100         return Deref(m_InputGateBias);
101     }
102 
GetForgetGateBiasarmnn::QuantizedLstmInputParams103     const ConstTensor& GetForgetGateBias() const
104     {
105         return Deref(m_ForgetGateBias);
106     }
107 
GetCellBiasarmnn::QuantizedLstmInputParams108     const ConstTensor& GetCellBias() const
109     {
110         return Deref(m_CellBias);
111     }
112 
GetOutputGateBiasarmnn::QuantizedLstmInputParams113     const ConstTensor& GetOutputGateBias() const
114     {
115         return Deref(m_OutputGateBias);
116     }
117 };
118 
119 struct QuantizedLstmInputParamsInfo
120 {
QuantizedLstmInputParamsInfoarmnn::QuantizedLstmInputParamsInfo121     QuantizedLstmInputParamsInfo()
122         : m_InputToInputWeights(nullptr)
123         , m_InputToForgetWeights(nullptr)
124         , m_InputToCellWeights(nullptr)
125         , m_InputToOutputWeights(nullptr)
126 
127         , m_RecurrentToInputWeights(nullptr)
128         , m_RecurrentToForgetWeights(nullptr)
129         , m_RecurrentToCellWeights(nullptr)
130         , m_RecurrentToOutputWeights(nullptr)
131 
132         , m_InputGateBias(nullptr)
133         , m_ForgetGateBias(nullptr)
134         , m_CellBias(nullptr)
135         , m_OutputGateBias(nullptr)
136     {
137     }
138 
139     const TensorInfo* m_InputToInputWeights;
140     const TensorInfo* m_InputToForgetWeights;
141     const TensorInfo* m_InputToCellWeights;
142     const TensorInfo* m_InputToOutputWeights;
143 
144     const TensorInfo* m_RecurrentToInputWeights;
145     const TensorInfo* m_RecurrentToForgetWeights;
146     const TensorInfo* m_RecurrentToCellWeights;
147     const TensorInfo* m_RecurrentToOutputWeights;
148 
149     const TensorInfo* m_InputGateBias;
150     const TensorInfo* m_ForgetGateBias;
151     const TensorInfo* m_CellBias;
152     const TensorInfo* m_OutputGateBias;
153 
154 
Derefarmnn::QuantizedLstmInputParamsInfo155     const TensorInfo& Deref(const TensorInfo* tensorInfo) const
156     {
157         if (tensorInfo != nullptr)
158         {
159             const TensorInfo &temp = *tensorInfo;
160             return temp;
161         }
162         throw InvalidArgumentException("Can't dereference a null pointer");
163     }
164 
GetInputToInputWeightsarmnn::QuantizedLstmInputParamsInfo165     const TensorInfo& GetInputToInputWeights() const
166     {
167         return Deref(m_InputToInputWeights);
168     }
GetInputToForgetWeightsarmnn::QuantizedLstmInputParamsInfo169     const TensorInfo& GetInputToForgetWeights() const
170     {
171         return Deref(m_InputToForgetWeights);
172     }
GetInputToCellWeightsarmnn::QuantizedLstmInputParamsInfo173     const TensorInfo& GetInputToCellWeights() const
174     {
175         return Deref(m_InputToCellWeights);
176     }
GetInputToOutputWeightsarmnn::QuantizedLstmInputParamsInfo177     const TensorInfo& GetInputToOutputWeights() const
178     {
179         return Deref(m_InputToOutputWeights);
180     }
181 
GetRecurrentToInputWeightsarmnn::QuantizedLstmInputParamsInfo182     const TensorInfo& GetRecurrentToInputWeights() const
183     {
184         return Deref(m_RecurrentToInputWeights);
185     }
GetRecurrentToForgetWeightsarmnn::QuantizedLstmInputParamsInfo186     const TensorInfo& GetRecurrentToForgetWeights() const
187     {
188         return Deref(m_RecurrentToForgetWeights);
189     }
GetRecurrentToCellWeightsarmnn::QuantizedLstmInputParamsInfo190     const TensorInfo& GetRecurrentToCellWeights() const
191     {
192         return Deref(m_RecurrentToCellWeights);
193     }
GetRecurrentToOutputWeightsarmnn::QuantizedLstmInputParamsInfo194     const TensorInfo& GetRecurrentToOutputWeights() const
195     {
196         return Deref(m_RecurrentToOutputWeights);
197     }
198 
GetInputGateBiasarmnn::QuantizedLstmInputParamsInfo199     const TensorInfo& GetInputGateBias() const
200     {
201         return Deref(m_InputGateBias);
202     }
GetForgetGateBiasarmnn::QuantizedLstmInputParamsInfo203     const TensorInfo& GetForgetGateBias() const
204     {
205         return Deref(m_ForgetGateBias);
206     }
GetCellBiasarmnn::QuantizedLstmInputParamsInfo207     const TensorInfo& GetCellBias() const
208     {
209         return Deref(m_CellBias);
210     }
GetOutputGateBiasarmnn::QuantizedLstmInputParamsInfo211     const TensorInfo& GetOutputGateBias() const
212     {
213         return Deref(m_OutputGateBias);
214     }
215 };
216 
217 } // namespace armnn
218 
219