xref: /aosp_15_r20/external/armnn/include/armnn/QuantizedLstmParams.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "TensorFwd.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "Exceptions.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker namespace armnn
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker struct QuantizedLstmInputParams
14*89c4ff92SAndroid Build Coastguard Worker {
QuantizedLstmInputParamsarmnn::QuantizedLstmInputParams15*89c4ff92SAndroid Build Coastguard Worker     QuantizedLstmInputParams()
16*89c4ff92SAndroid Build Coastguard Worker         : m_InputToInputWeights(nullptr)
17*89c4ff92SAndroid Build Coastguard Worker         , m_InputToForgetWeights(nullptr)
18*89c4ff92SAndroid Build Coastguard Worker         , m_InputToCellWeights(nullptr)
19*89c4ff92SAndroid Build Coastguard Worker         , m_InputToOutputWeights(nullptr)
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker         , m_RecurrentToInputWeights(nullptr)
22*89c4ff92SAndroid Build Coastguard Worker         , m_RecurrentToForgetWeights(nullptr)
23*89c4ff92SAndroid Build Coastguard Worker         , m_RecurrentToCellWeights(nullptr)
24*89c4ff92SAndroid Build Coastguard Worker         , m_RecurrentToOutputWeights(nullptr)
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker         , m_InputGateBias(nullptr)
27*89c4ff92SAndroid Build Coastguard Worker         , m_ForgetGateBias(nullptr)
28*89c4ff92SAndroid Build Coastguard Worker         , m_CellBias(nullptr)
29*89c4ff92SAndroid Build Coastguard Worker         , m_OutputGateBias(nullptr)
30*89c4ff92SAndroid Build Coastguard Worker     {
31*89c4ff92SAndroid Build Coastguard Worker     }
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_InputToInputWeights;
34*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_InputToForgetWeights;
35*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_InputToCellWeights;
36*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_InputToOutputWeights;
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_RecurrentToInputWeights;
39*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_RecurrentToForgetWeights;
40*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_RecurrentToCellWeights;
41*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_RecurrentToOutputWeights;
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_InputGateBias;
44*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_ForgetGateBias;
45*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_CellBias;
46*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor* m_OutputGateBias;
47*89c4ff92SAndroid Build Coastguard Worker 
Derefarmnn::QuantizedLstmInputParams48*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& Deref(const ConstTensor* tensorPtr) const
49*89c4ff92SAndroid Build Coastguard Worker     {
50*89c4ff92SAndroid Build Coastguard Worker         if (tensorPtr != nullptr)
51*89c4ff92SAndroid Build Coastguard Worker         {
52*89c4ff92SAndroid Build Coastguard Worker             const ConstTensor &temp = *tensorPtr;
53*89c4ff92SAndroid Build Coastguard Worker             return temp;
54*89c4ff92SAndroid Build Coastguard Worker         }
55*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer");
56*89c4ff92SAndroid Build Coastguard Worker     }
57*89c4ff92SAndroid Build Coastguard Worker 
GetInputToInputWeightsarmnn::QuantizedLstmInputParams58*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetInputToInputWeights() const
59*89c4ff92SAndroid Build Coastguard Worker     {
60*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_InputToInputWeights);
61*89c4ff92SAndroid Build Coastguard Worker     }
62*89c4ff92SAndroid Build Coastguard Worker 
GetInputToForgetWeightsarmnn::QuantizedLstmInputParams63*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetInputToForgetWeights() const
64*89c4ff92SAndroid Build Coastguard Worker     {
65*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_InputToForgetWeights);
66*89c4ff92SAndroid Build Coastguard Worker     }
67*89c4ff92SAndroid Build Coastguard Worker 
GetInputToCellWeightsarmnn::QuantizedLstmInputParams68*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetInputToCellWeights() const
69*89c4ff92SAndroid Build Coastguard Worker     {
70*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_InputToCellWeights);
71*89c4ff92SAndroid Build Coastguard Worker     }
72*89c4ff92SAndroid Build Coastguard Worker 
GetInputToOutputWeightsarmnn::QuantizedLstmInputParams73*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetInputToOutputWeights() const
74*89c4ff92SAndroid Build Coastguard Worker     {
75*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_InputToOutputWeights);
76*89c4ff92SAndroid Build Coastguard Worker     }
77*89c4ff92SAndroid Build Coastguard Worker 
GetRecurrentToInputWeightsarmnn::QuantizedLstmInputParams78*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetRecurrentToInputWeights() const
79*89c4ff92SAndroid Build Coastguard Worker     {
80*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_RecurrentToInputWeights);
81*89c4ff92SAndroid Build Coastguard Worker     }
82*89c4ff92SAndroid Build Coastguard Worker 
GetRecurrentToForgetWeightsarmnn::QuantizedLstmInputParams83*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetRecurrentToForgetWeights() const
84*89c4ff92SAndroid Build Coastguard Worker     {
85*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_RecurrentToForgetWeights);
86*89c4ff92SAndroid Build Coastguard Worker     }
87*89c4ff92SAndroid Build Coastguard Worker 
GetRecurrentToCellWeightsarmnn::QuantizedLstmInputParams88*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetRecurrentToCellWeights() const
89*89c4ff92SAndroid Build Coastguard Worker     {
90*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_RecurrentToCellWeights);
91*89c4ff92SAndroid Build Coastguard Worker     }
92*89c4ff92SAndroid Build Coastguard Worker 
GetRecurrentToOutputWeightsarmnn::QuantizedLstmInputParams93*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetRecurrentToOutputWeights() const
94*89c4ff92SAndroid Build Coastguard Worker     {
95*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_RecurrentToOutputWeights);
96*89c4ff92SAndroid Build Coastguard Worker     }
97*89c4ff92SAndroid Build Coastguard Worker 
GetInputGateBiasarmnn::QuantizedLstmInputParams98*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetInputGateBias() const
99*89c4ff92SAndroid Build Coastguard Worker     {
100*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_InputGateBias);
101*89c4ff92SAndroid Build Coastguard Worker     }
102*89c4ff92SAndroid Build Coastguard Worker 
GetForgetGateBiasarmnn::QuantizedLstmInputParams103*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetForgetGateBias() const
104*89c4ff92SAndroid Build Coastguard Worker     {
105*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_ForgetGateBias);
106*89c4ff92SAndroid Build Coastguard Worker     }
107*89c4ff92SAndroid Build Coastguard Worker 
GetCellBiasarmnn::QuantizedLstmInputParams108*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetCellBias() const
109*89c4ff92SAndroid Build Coastguard Worker     {
110*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_CellBias);
111*89c4ff92SAndroid Build Coastguard Worker     }
112*89c4ff92SAndroid Build Coastguard Worker 
GetOutputGateBiasarmnn::QuantizedLstmInputParams113*89c4ff92SAndroid Build Coastguard Worker     const ConstTensor& GetOutputGateBias() const
114*89c4ff92SAndroid Build Coastguard Worker     {
115*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_OutputGateBias);
116*89c4ff92SAndroid Build Coastguard Worker     }
117*89c4ff92SAndroid Build Coastguard Worker };
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker struct QuantizedLstmInputParamsInfo
120*89c4ff92SAndroid Build Coastguard Worker {
QuantizedLstmInputParamsInfoarmnn::QuantizedLstmInputParamsInfo121*89c4ff92SAndroid Build Coastguard Worker     QuantizedLstmInputParamsInfo()
122*89c4ff92SAndroid Build Coastguard Worker         : m_InputToInputWeights(nullptr)
123*89c4ff92SAndroid Build Coastguard Worker         , m_InputToForgetWeights(nullptr)
124*89c4ff92SAndroid Build Coastguard Worker         , m_InputToCellWeights(nullptr)
125*89c4ff92SAndroid Build Coastguard Worker         , m_InputToOutputWeights(nullptr)
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker         , m_RecurrentToInputWeights(nullptr)
128*89c4ff92SAndroid Build Coastguard Worker         , m_RecurrentToForgetWeights(nullptr)
129*89c4ff92SAndroid Build Coastguard Worker         , m_RecurrentToCellWeights(nullptr)
130*89c4ff92SAndroid Build Coastguard Worker         , m_RecurrentToOutputWeights(nullptr)
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker         , m_InputGateBias(nullptr)
133*89c4ff92SAndroid Build Coastguard Worker         , m_ForgetGateBias(nullptr)
134*89c4ff92SAndroid Build Coastguard Worker         , m_CellBias(nullptr)
135*89c4ff92SAndroid Build Coastguard Worker         , m_OutputGateBias(nullptr)
136*89c4ff92SAndroid Build Coastguard Worker     {
137*89c4ff92SAndroid Build Coastguard Worker     }
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_InputToInputWeights;
140*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_InputToForgetWeights;
141*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_InputToCellWeights;
142*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_InputToOutputWeights;
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_RecurrentToInputWeights;
145*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_RecurrentToForgetWeights;
146*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_RecurrentToCellWeights;
147*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_RecurrentToOutputWeights;
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_InputGateBias;
150*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_ForgetGateBias;
151*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_CellBias;
152*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo* m_OutputGateBias;
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker 
Derefarmnn::QuantizedLstmInputParamsInfo155*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& Deref(const TensorInfo* tensorInfo) const
156*89c4ff92SAndroid Build Coastguard Worker     {
157*89c4ff92SAndroid Build Coastguard Worker         if (tensorInfo != nullptr)
158*89c4ff92SAndroid Build Coastguard Worker         {
159*89c4ff92SAndroid Build Coastguard Worker             const TensorInfo &temp = *tensorInfo;
160*89c4ff92SAndroid Build Coastguard Worker             return temp;
161*89c4ff92SAndroid Build Coastguard Worker         }
162*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException("Can't dereference a null pointer");
163*89c4ff92SAndroid Build Coastguard Worker     }
164*89c4ff92SAndroid Build Coastguard Worker 
GetInputToInputWeightsarmnn::QuantizedLstmInputParamsInfo165*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetInputToInputWeights() const
166*89c4ff92SAndroid Build Coastguard Worker     {
167*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_InputToInputWeights);
168*89c4ff92SAndroid Build Coastguard Worker     }
GetInputToForgetWeightsarmnn::QuantizedLstmInputParamsInfo169*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetInputToForgetWeights() const
170*89c4ff92SAndroid Build Coastguard Worker     {
171*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_InputToForgetWeights);
172*89c4ff92SAndroid Build Coastguard Worker     }
GetInputToCellWeightsarmnn::QuantizedLstmInputParamsInfo173*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetInputToCellWeights() const
174*89c4ff92SAndroid Build Coastguard Worker     {
175*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_InputToCellWeights);
176*89c4ff92SAndroid Build Coastguard Worker     }
GetInputToOutputWeightsarmnn::QuantizedLstmInputParamsInfo177*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetInputToOutputWeights() const
178*89c4ff92SAndroid Build Coastguard Worker     {
179*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_InputToOutputWeights);
180*89c4ff92SAndroid Build Coastguard Worker     }
181*89c4ff92SAndroid Build Coastguard Worker 
GetRecurrentToInputWeightsarmnn::QuantizedLstmInputParamsInfo182*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetRecurrentToInputWeights() const
183*89c4ff92SAndroid Build Coastguard Worker     {
184*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_RecurrentToInputWeights);
185*89c4ff92SAndroid Build Coastguard Worker     }
GetRecurrentToForgetWeightsarmnn::QuantizedLstmInputParamsInfo186*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetRecurrentToForgetWeights() const
187*89c4ff92SAndroid Build Coastguard Worker     {
188*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_RecurrentToForgetWeights);
189*89c4ff92SAndroid Build Coastguard Worker     }
GetRecurrentToCellWeightsarmnn::QuantizedLstmInputParamsInfo190*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetRecurrentToCellWeights() const
191*89c4ff92SAndroid Build Coastguard Worker     {
192*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_RecurrentToCellWeights);
193*89c4ff92SAndroid Build Coastguard Worker     }
GetRecurrentToOutputWeightsarmnn::QuantizedLstmInputParamsInfo194*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetRecurrentToOutputWeights() const
195*89c4ff92SAndroid Build Coastguard Worker     {
196*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_RecurrentToOutputWeights);
197*89c4ff92SAndroid Build Coastguard Worker     }
198*89c4ff92SAndroid Build Coastguard Worker 
GetInputGateBiasarmnn::QuantizedLstmInputParamsInfo199*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetInputGateBias() const
200*89c4ff92SAndroid Build Coastguard Worker     {
201*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_InputGateBias);
202*89c4ff92SAndroid Build Coastguard Worker     }
GetForgetGateBiasarmnn::QuantizedLstmInputParamsInfo203*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetForgetGateBias() const
204*89c4ff92SAndroid Build Coastguard Worker     {
205*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_ForgetGateBias);
206*89c4ff92SAndroid Build Coastguard Worker     }
GetCellBiasarmnn::QuantizedLstmInputParamsInfo207*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetCellBias() const
208*89c4ff92SAndroid Build Coastguard Worker     {
209*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_CellBias);
210*89c4ff92SAndroid Build Coastguard Worker     }
GetOutputGateBiasarmnn::QuantizedLstmInputParamsInfo211*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetOutputGateBias() const
212*89c4ff92SAndroid Build Coastguard Worker     {
213*89c4ff92SAndroid Build Coastguard Worker         return Deref(m_OutputGateBias);
214*89c4ff92SAndroid Build Coastguard Worker     }
215*89c4ff92SAndroid Build Coastguard Worker };
216*89c4ff92SAndroid Build Coastguard Worker 
217*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
218*89c4ff92SAndroid Build Coastguard Worker 
219