xref: /aosp_15_r20/external/ComputeLibrary/arm_compute/runtime/common/LSTMParams.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2018-2021 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #ifndef ARM_COMPUTE_LSTMPARAMS_H
25*c217d954SCole Faust #define ARM_COMPUTE_LSTMPARAMS_H
26*c217d954SCole Faust 
27*c217d954SCole Faust #include "arm_compute/core/Types.h"
28*c217d954SCole Faust #include "arm_compute/runtime/Tensor.h"
29*c217d954SCole Faust 
30*c217d954SCole Faust #include <cstddef>
31*c217d954SCole Faust #include <memory>
32*c217d954SCole Faust 
namespace arm_compute
{
/** Optional tensor inputs and scalar settings of an LSTM layer.
 *
 * Holds raw pointers to optional weight/bias tensors together with the flags and
 * scalars that select the LSTM variant. The pointers are stored exactly as given and
 * are never released by this class. Every setter returns a reference to @c *this, so
 * calls can be chained.
 *
 * Freshly constructed state:
 * - all tensor pointers are nullptr;
 * - all scalars (clips, intermediate scales, hidden-state zero point and scale) are zero;
 * - the CIFG flag is set;
 * - the peephole, projection and layer-normalization flags are cleared.
 *
 * @tparam T Tensor type referenced by the stored pointers.
 */
template <typename T>
class LSTMParams
{
public:
    /** Constructor. Every member takes the default given by its in-class initializer. */
    LSTMParams() = default;
    /** Copy construction is disabled because the object holds raw tensor pointers. */
    LSTMParams(const LSTMParams &) = delete;
    /** Copy assignment is disabled because the object holds raw tensor pointers. */
    LSTMParams &operator=(const LSTMParams &) = delete;
    /** Default destructor (the referenced tensors are left untouched). */
    ~LSTMParams() = default;

    /** Set the input-gate tensor parameters and clear the CIFG flag.
     *
     * After this call has_cifg_opt() returns false.
     *
     * @param[in] input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights.
     * @param[in] cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights.
     * @param[in] input_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights, S32 when @p input_to_input_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
    {
        _has_cifg_opt               = false;
        _input_to_input_weights     = input_to_input_weights;
        _recurrent_to_input_weights = recurrent_to_input_weights;
        _cell_to_input_weights      = cell_to_input_weights;
        _input_gate_bias            = input_gate_bias;
        return *this;
    }

    /** Set the projection tensor parameters and set the projection flag.
     *
     * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] projection_bias    1D weights tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights, S32 when the weights are QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias)
    {
        _has_projection     = true;
        _projection_weights = projection_weights;
        _projection_bias    = projection_bias;
        return *this;
    }

    /** Set the peephole tensor parameters and set the peephole flag.
     *
     * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
    {
        _has_peephole_opt       = true;
        _cell_to_forget_weights = cell_to_forget_weights;
        _cell_to_output_weights = cell_to_output_weights;
        return *this;
    }

    /** Set the layer-normalization tensor parameters and set the layer-normalization flag.
     *
     * @param[in] input_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] cell_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
                                               T *cell_layer_norm_weights, T *output_layer_norm_weights)
    {
        _use_layer_norm            = true;
        _input_layer_norm_weights  = input_layer_norm_weights;
        _forget_layer_norm_weights = forget_layer_norm_weights;
        _cell_layer_norm_weights   = cell_layer_norm_weights;
        _output_layer_norm_weights = output_layer_norm_weights;
        return *this;
    }

    /** Set the cell clip value.
     *
     * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cell_clip_params(float cell_clip)
    {
        _cell_clip = cell_clip;
        return *this;
    }

    /** Set the projection clip value.
     *
     * @param[in] projection_clip Value to be used to clip the projection, in case projection is enabled.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_clip_params(float projection_clip)
    {
        _projection_clip = projection_clip;
        return *this;
    }

    /** Set the scales of the intermediate matmul results of each gate.
     *
     * @param[in] input_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
     * @param[in] forget_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
     * @param[in] cell_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
     * @param[in] output_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
    {
        _input_intermediate_scale  = input_intermediate_scale;
        _forget_intermediate_scale = forget_intermediate_scale;
        _cell_intermediate_scale   = cell_intermediate_scale;
        _output_intermediate_scale = output_intermediate_scale;
        return *this;
    }

    /** Set the hidden-state zero point and scale.
     *
     * @param[in] hidden_state_zero  The zero point of the hidden state.
     * @param[in] hidden_state_scale The scale of the hidden state.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
    {
        _hidden_state_zero  = hidden_state_zero;
        _hidden_state_scale = hidden_state_scale;
        return *this;
    }

    /** @return Pointer given to set_cifg_params() as input_to_input_weights, nullptr if never set. */
    const T *input_to_input_weights() const { return _input_to_input_weights; }
    /** @return Pointer given to set_cifg_params() as recurrent_to_input_weights, nullptr if never set. */
    const T *recurrent_to_input_weights() const { return _recurrent_to_input_weights; }
    /** @return Pointer given to set_cifg_params() as cell_to_input_weights, nullptr if never set. */
    T *cell_to_input_weights() const { return _cell_to_input_weights; }
    /** @return Pointer given to set_cifg_params() as input_gate_bias, nullptr if never set. */
    const T *input_gate_bias() const { return _input_gate_bias; }
    /** @return Pointer given to set_peephole_params() as cell_to_forget_weights, nullptr if never set. */
    T *cell_to_forget_weights() const { return _cell_to_forget_weights; }
    /** @return Pointer given to set_peephole_params() as cell_to_output_weights, nullptr if never set. */
    T *cell_to_output_weights() const { return _cell_to_output_weights; }
    /** @return Pointer given to set_projection_params() as projection_weights, nullptr if never set. */
    const T *projection_weights() const { return _projection_weights; }
    /** @return Pointer given to set_projection_params() as projection_bias, nullptr if never set. */
    const T *projection_bias() const { return _projection_bias; }
    /** @return Pointer given to set_layer_normalization_params() as input_layer_norm_weights, nullptr if never set. */
    T *input_layer_norm_weights() const { return _input_layer_norm_weights; }
    /** @return Pointer given to set_layer_normalization_params() as forget_layer_norm_weights, nullptr if never set. */
    T *forget_layer_norm_weights() const { return _forget_layer_norm_weights; }
    /** @return Pointer given to set_layer_normalization_params() as cell_layer_norm_weights, nullptr if never set. */
    T *cell_layer_norm_weights() const { return _cell_layer_norm_weights; }
    /** @return Pointer given to set_layer_normalization_params() as output_layer_norm_weights, nullptr if never set. */
    T *output_layer_norm_weights() const { return _output_layer_norm_weights; }

    /** @return Value set by set_cell_clip_params() (0 by default). */
    float cell_clip() const { return _cell_clip; }
    /** @return Value set by set_projection_clip_params() (0 by default). */
    float projection_clip() const { return _projection_clip; }
    /** @return Input-gate scale set by set_matmul_scale_params() (0 by default). */
    float input_intermediate_scale() const { return _input_intermediate_scale; }
    /** @return Forget-gate scale set by set_matmul_scale_params() (0 by default). */
    float forget_intermediate_scale() const { return _forget_intermediate_scale; }
    /** @return Cell-gate scale set by set_matmul_scale_params() (0 by default). */
    float cell_intermediate_scale() const { return _cell_intermediate_scale; }
    /** @return Output-gate scale set by set_matmul_scale_params() (0 by default). */
    float output_intermediate_scale() const { return _output_intermediate_scale; }
    /** @return Zero point set by set_hidden_state_params() (0 by default). */
    int32_t hidden_state_zero() const { return _hidden_state_zero; }
    /** @return Scale set by set_hidden_state_params() (0 by default). */
    float hidden_state_scale() const { return _hidden_state_scale; }

    /** @return true once set_peephole_params() has been called. */
    bool has_peephole_opt() const { return _has_peephole_opt; }
    /** @return true once set_projection_params() has been called. */
    bool has_projection() const { return _has_projection; }
    /** @return true until set_cifg_params() is called. */
    bool has_cifg_opt() const { return _has_cifg_opt; }
    /** @return true once set_layer_normalization_params() has been called. */
    bool use_layer_norm() const { return _use_layer_norm; }

private:
    // Borrowed tensor pointers; all start out null.
    const T *_input_to_input_weights{ nullptr };
    const T *_recurrent_to_input_weights{ nullptr };
    T       *_cell_to_input_weights{ nullptr };
    const T *_input_gate_bias{ nullptr };
    T       *_cell_to_forget_weights{ nullptr };
    T       *_cell_to_output_weights{ nullptr };
    const T *_projection_weights{ nullptr };
    const T *_projection_bias{ nullptr };
    T       *_input_layer_norm_weights{ nullptr };
    T       *_forget_layer_norm_weights{ nullptr };
    T       *_cell_layer_norm_weights{ nullptr };
    T       *_output_layer_norm_weights{ nullptr };
    // Scalar settings; all start out zero.
    float   _cell_clip{ 0.0f };
    float   _projection_clip{ 0.0f };
    float   _input_intermediate_scale{ 0.0f };
    float   _forget_intermediate_scale{ 0.0f };
    float   _cell_intermediate_scale{ 0.0f };
    float   _output_intermediate_scale{ 0.0f };
    int32_t _hidden_state_zero{ 0 };
    float   _hidden_state_scale{ 0.0f };
    // Variant flags: only CIFG is on until a setter changes it.
    bool _has_peephole_opt{ false };
    bool _has_projection{ false };
    bool _has_cifg_opt{ true };
    bool _use_layer_norm{ false };
};
} // namespace arm_compute
342*c217d954SCole Faust #endif /*ARM_COMPUTE_LSTMPARAMS_H */
343