1*c217d954SCole Faust /* 2*c217d954SCole Faust * Copyright (c) 2018-2021 Arm Limited. 3*c217d954SCole Faust * 4*c217d954SCole Faust * SPDX-License-Identifier: MIT 5*c217d954SCole Faust * 6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy 7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to 8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the 9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is 11*c217d954SCole Faust * furnished to do so, subject to the following conditions: 12*c217d954SCole Faust * 13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all 14*c217d954SCole Faust * copies or substantial portions of the Software. 15*c217d954SCole Faust * 16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22*c217d954SCole Faust * SOFTWARE. 23*c217d954SCole Faust */ 24*c217d954SCole Faust #ifndef ARM_COMPUTE_LSTMPARAMS_H 25*c217d954SCole Faust #define ARM_COMPUTE_LSTMPARAMS_H 26*c217d954SCole Faust 27*c217d954SCole Faust #include "arm_compute/core/Types.h" 28*c217d954SCole Faust #include "arm_compute/runtime/Tensor.h" 29*c217d954SCole Faust 30*c217d954SCole Faust #include <cstddef> 31*c217d954SCole Faust #include <memory> 32*c217d954SCole Faust 33*c217d954SCole Faust namespace arm_compute 34*c217d954SCole Faust { 35*c217d954SCole Faust template <typename T> 36*c217d954SCole Faust class LSTMParams 37*c217d954SCole Faust { 38*c217d954SCole Faust public: 39*c217d954SCole Faust /** Constructor */ LSTMParams()40*c217d954SCole Faust LSTMParams() 41*c217d954SCole Faust : _input_to_input_weights(nullptr), 42*c217d954SCole Faust _recurrent_to_input_weights(nullptr), 43*c217d954SCole Faust _cell_to_input_weights(nullptr), 44*c217d954SCole Faust _input_gate_bias(nullptr), 45*c217d954SCole Faust _cell_to_forget_weights(nullptr), 46*c217d954SCole Faust _cell_to_output_weights(nullptr), 47*c217d954SCole Faust _projection_weights(nullptr), 48*c217d954SCole Faust _projection_bias(nullptr), 49*c217d954SCole Faust _input_layer_norm_weights(nullptr), 50*c217d954SCole Faust _forget_layer_norm_weights(nullptr), 51*c217d954SCole Faust _cell_layer_norm_weights(nullptr), 52*c217d954SCole Faust _output_layer_norm_weights(nullptr), 53*c217d954SCole Faust _cell_clip(0.f), 54*c217d954SCole Faust _projection_clip(0.0f), 55*c217d954SCole Faust _input_intermediate_scale(0.0f), 56*c217d954SCole Faust _forget_intermediate_scale(0.0f), 57*c217d954SCole Faust _cell_intermediate_scale(0.0f), 58*c217d954SCole Faust _output_intermediate_scale(0.0f), 59*c217d954SCole Faust _hidden_state_zero(0), 60*c217d954SCole Faust _hidden_state_scale(0.0f), 61*c217d954SCole Faust _has_peephole_opt(false), 62*c217d954SCole Faust _has_projection(false), 63*c217d954SCole Faust _has_cifg_opt(true), 64*c217d954SCole Faust _use_layer_norm(false) 65*c217d954SCole Faust { 66*c217d954SCole Faust } 67*c217d954SCole Faust /** Prevent instances of this class from being copied (As this class contains pointers) */ 68*c217d954SCole Faust LSTMParams(const LSTMParams &) = delete; 69*c217d954SCole Faust /** Prevent instances of this class from being copied (As this class contains pointers) */ 70*c217d954SCole Faust LSTMParams &operator=(const LSTMParams &) = delete; 71*c217d954SCole Faust /** Default destructor */ 72*c217d954SCole Faust ~LSTMParams() = default; 73*c217d954SCole Faust /** Set CIFG tensor parameters. 74*c217d954SCole Faust * 75*c217d954SCole Faust * @param[in] input_to_input_weights 2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32. 76*c217d954SCole Faust * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights. 77*c217d954SCole Faust * @param[in] cell_to_input_weights 1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights. 78*c217d954SCole Faust * @param[in] input_gate_bias 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights, S32 when @p input_to_input_weights is QSYMM8 79*c217d954SCole Faust * 80*c217d954SCole Faust * @return Reference to this LSTMParams object 81*c217d954SCole Faust */ set_cifg_params(const T * input_to_input_weights,const T * recurrent_to_input_weights,T * cell_to_input_weights,const T * input_gate_bias)82*c217d954SCole Faust LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias) 83*c217d954SCole Faust { 84*c217d954SCole Faust _input_to_input_weights = input_to_input_weights; 85*c217d954SCole Faust _recurrent_to_input_weights = recurrent_to_input_weights; 86*c217d954SCole Faust _cell_to_input_weights = cell_to_input_weights; 87*c217d954SCole Faust _input_gate_bias = input_gate_bias; 88*c217d954SCole Faust _has_cifg_opt = false; 89*c217d954SCole Faust return *this; 90*c217d954SCole Faust } 91*c217d954SCole Faust /** Set projection tensor parameters. 92*c217d954SCole Faust * 93*c217d954SCole Faust * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Data types supported: QSYMM8/F16/F32. 94*c217d954SCole Faust * @param[in] projection_bias 1D weights tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights, S32 when @p input_to_input_weights is QSYMM8. 95*c217d954SCole Faust * 96*c217d954SCole Faust * @return Reference to this LSTMParams object 97*c217d954SCole Faust */ set_projection_params(const T * projection_weights,const T * projection_bias)98*c217d954SCole Faust LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias) 99*c217d954SCole Faust { 100*c217d954SCole Faust _projection_weights = projection_weights; 101*c217d954SCole Faust _projection_bias = projection_bias; 102*c217d954SCole Faust _has_projection = true; 103*c217d954SCole Faust return *this; 104*c217d954SCole Faust } 105*c217d954SCole Faust /** Set peephole tensor parameters. 106*c217d954SCole Faust * 107*c217d954SCole Faust * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data type supported: Data types supported: QSYMM16/F16/F32. 108*c217d954SCole Faust * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights. 109*c217d954SCole Faust * 110*c217d954SCole Faust * @return Reference to this LSTMParams object 111*c217d954SCole Faust */ set_peephole_params(T * cell_to_forget_weights,T * cell_to_output_weights)112*c217d954SCole Faust LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights) 113*c217d954SCole Faust { 114*c217d954SCole Faust _cell_to_forget_weights = cell_to_forget_weights; 115*c217d954SCole Faust _cell_to_output_weights = cell_to_output_weights; 116*c217d954SCole Faust _has_peephole_opt = true; 117*c217d954SCole Faust return *this; 118*c217d954SCole Faust } 119*c217d954SCole Faust /** Set layer normalization tensor parameters. 120*c217d954SCole Faust * 121*c217d954SCole Faust * @param[in] input_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Data types supported: QSYMM16/F16/F32. 122*c217d954SCole Faust * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights. 123*c217d954SCole Faust * @param[in] cell_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights. 124*c217d954SCole Faust * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights. 125*c217d954SCole Faust * 126*c217d954SCole Faust * @return Reference to this LSTMParams object 127*c217d954SCole Faust */ set_layer_normalization_params(T * input_layer_norm_weights,T * forget_layer_norm_weights,T * cell_layer_norm_weights,T * output_layer_norm_weights)128*c217d954SCole Faust LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights, 129*c217d954SCole Faust T *cell_layer_norm_weights, T *output_layer_norm_weights) 130*c217d954SCole Faust { 131*c217d954SCole Faust _input_layer_norm_weights = input_layer_norm_weights; 132*c217d954SCole Faust _forget_layer_norm_weights = forget_layer_norm_weights; 133*c217d954SCole Faust _cell_layer_norm_weights = cell_layer_norm_weights; 134*c217d954SCole Faust _output_layer_norm_weights = output_layer_norm_weights; 135*c217d954SCole Faust _use_layer_norm = true; 136*c217d954SCole Faust return *this; 137*c217d954SCole Faust } 138*c217d954SCole Faust 139*c217d954SCole Faust /** Set cell clip value. 140*c217d954SCole Faust * 141*c217d954SCole Faust * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation. 142*c217d954SCole Faust * 143*c217d954SCole Faust * @return Reference to this LSTMParams object 144*c217d954SCole Faust */ set_cell_clip_params(float cell_clip)145*c217d954SCole Faust LSTMParams &set_cell_clip_params(float cell_clip) 146*c217d954SCole Faust { 147*c217d954SCole Faust _cell_clip = cell_clip; 148*c217d954SCole Faust return *this; 149*c217d954SCole Faust } 150*c217d954SCole Faust 151*c217d954SCole Faust /** Set projection clip value. 152*c217d954SCole Faust * 153*c217d954SCole Faust * @param[in] projection_clip Value to be used to clip the projection, in case projection is enabled. 154*c217d954SCole Faust * 155*c217d954SCole Faust * @return Reference to this LSTMParams object 156*c217d954SCole Faust */ set_projection_clip_params(float projection_clip)157*c217d954SCole Faust LSTMParams &set_projection_clip_params(float projection_clip) 158*c217d954SCole Faust { 159*c217d954SCole Faust _projection_clip = projection_clip; 160*c217d954SCole Faust return *this; 161*c217d954SCole Faust } 162*c217d954SCole Faust 163*c217d954SCole Faust /** Set scale of the intermediate results of matmul of each layer parameters. 164*c217d954SCole Faust * 165*c217d954SCole Faust * @param[in] input_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate. 166*c217d954SCole Faust * @param[in] forget_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate. 167*c217d954SCole Faust * @param[in] cell_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate. 168*c217d954SCole Faust * @param[in] output_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate. 169*c217d954SCole Faust * 170*c217d954SCole Faust * @return Reference to this LSTMParams object 171*c217d954SCole Faust */ set_matmul_scale_params(float input_intermediate_scale,float forget_intermediate_scale,float cell_intermediate_scale,float output_intermediate_scale)172*c217d954SCole Faust LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale) 173*c217d954SCole Faust { 174*c217d954SCole Faust _input_intermediate_scale = input_intermediate_scale; 175*c217d954SCole Faust _forget_intermediate_scale = forget_intermediate_scale; 176*c217d954SCole Faust _cell_intermediate_scale = cell_intermediate_scale; 177*c217d954SCole Faust _output_intermediate_scale = output_intermediate_scale; 178*c217d954SCole Faust return *this; 179*c217d954SCole Faust } 180*c217d954SCole Faust 181*c217d954SCole Faust /** Set hidden state zero and scale parameters. 182*c217d954SCole Faust * 183*c217d954SCole Faust * @param[in] hidden_state_zero The zero point of the hidden state. 184*c217d954SCole Faust * @param[in] hidden_state_scale The scale of the hidden state. 185*c217d954SCole Faust * 186*c217d954SCole Faust * @return Reference to this LSTMParams object 187*c217d954SCole Faust */ set_hidden_state_params(int32_t hidden_state_zero,float hidden_state_scale)188*c217d954SCole Faust LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale) 189*c217d954SCole Faust { 190*c217d954SCole Faust _hidden_state_zero = hidden_state_zero; 191*c217d954SCole Faust _hidden_state_scale = hidden_state_scale; 192*c217d954SCole Faust return *this; 193*c217d954SCole Faust } 194*c217d954SCole Faust input_to_input_weights()195*c217d954SCole Faust const T *input_to_input_weights() const 196*c217d954SCole Faust { 197*c217d954SCole Faust return _input_to_input_weights; 198*c217d954SCole Faust } 199*c217d954SCole Faust recurrent_to_input_weights()200*c217d954SCole Faust const T *recurrent_to_input_weights() const 201*c217d954SCole Faust { 202*c217d954SCole Faust return _recurrent_to_input_weights; 203*c217d954SCole Faust } 204*c217d954SCole Faust cell_to_input_weights()205*c217d954SCole Faust T *cell_to_input_weights() const 206*c217d954SCole Faust { 207*c217d954SCole Faust return _cell_to_input_weights; 208*c217d954SCole Faust } 209*c217d954SCole Faust input_gate_bias()210*c217d954SCole Faust const T *input_gate_bias() const 211*c217d954SCole Faust { 212*c217d954SCole Faust return _input_gate_bias; 213*c217d954SCole Faust } 214*c217d954SCole Faust cell_to_forget_weights()215*c217d954SCole Faust T *cell_to_forget_weights() const 216*c217d954SCole Faust { 217*c217d954SCole Faust return _cell_to_forget_weights; 218*c217d954SCole Faust } 219*c217d954SCole Faust cell_to_output_weights()220*c217d954SCole Faust T *cell_to_output_weights() const 221*c217d954SCole Faust { 222*c217d954SCole Faust return _cell_to_output_weights; 223*c217d954SCole Faust } 224*c217d954SCole Faust projection_weights()225*c217d954SCole Faust const T *projection_weights() const 226*c217d954SCole Faust { 227*c217d954SCole Faust return _projection_weights; 228*c217d954SCole Faust } 229*c217d954SCole Faust projection_bias()230*c217d954SCole Faust const T *projection_bias() const 231*c217d954SCole Faust { 232*c217d954SCole Faust return _projection_bias; 233*c217d954SCole Faust } 234*c217d954SCole Faust input_layer_norm_weights()235*c217d954SCole Faust T *input_layer_norm_weights() const 236*c217d954SCole Faust { 237*c217d954SCole Faust return _input_layer_norm_weights; 238*c217d954SCole Faust } 239*c217d954SCole Faust forget_layer_norm_weights()240*c217d954SCole Faust T *forget_layer_norm_weights() const 241*c217d954SCole Faust { 242*c217d954SCole Faust return _forget_layer_norm_weights; 243*c217d954SCole Faust } 244*c217d954SCole Faust cell_layer_norm_weights()245*c217d954SCole Faust T *cell_layer_norm_weights() const 246*c217d954SCole Faust { 247*c217d954SCole Faust return _cell_layer_norm_weights; 248*c217d954SCole Faust } 249*c217d954SCole Faust output_layer_norm_weights()250*c217d954SCole Faust T *output_layer_norm_weights() const 251*c217d954SCole Faust { 252*c217d954SCole Faust return _output_layer_norm_weights; 253*c217d954SCole Faust } 254*c217d954SCole Faust cell_clip()255*c217d954SCole Faust float cell_clip() const 256*c217d954SCole Faust { 257*c217d954SCole Faust return _cell_clip; 258*c217d954SCole Faust } 259*c217d954SCole Faust projection_clip()260*c217d954SCole Faust float projection_clip() const 261*c217d954SCole Faust { 262*c217d954SCole Faust return _projection_clip; 263*c217d954SCole Faust } 264*c217d954SCole Faust input_intermediate_scale()265*c217d954SCole Faust float input_intermediate_scale() const 266*c217d954SCole Faust { 267*c217d954SCole Faust return _input_intermediate_scale; 268*c217d954SCole Faust } 269*c217d954SCole Faust forget_intermediate_scale()270*c217d954SCole Faust float forget_intermediate_scale() const 271*c217d954SCole Faust { 272*c217d954SCole Faust return _forget_intermediate_scale; 273*c217d954SCole Faust } 274*c217d954SCole Faust cell_intermediate_scale()275*c217d954SCole Faust float cell_intermediate_scale() const 276*c217d954SCole Faust { 277*c217d954SCole Faust return _cell_intermediate_scale; 278*c217d954SCole Faust } 279*c217d954SCole Faust output_intermediate_scale()280*c217d954SCole Faust float output_intermediate_scale() const 281*c217d954SCole Faust { 282*c217d954SCole Faust return _output_intermediate_scale; 283*c217d954SCole Faust } 284*c217d954SCole Faust hidden_state_zero()285*c217d954SCole Faust int32_t hidden_state_zero() const 286*c217d954SCole Faust { 287*c217d954SCole Faust return _hidden_state_zero; 288*c217d954SCole Faust } 289*c217d954SCole Faust hidden_state_scale()290*c217d954SCole Faust float hidden_state_scale() const 291*c217d954SCole Faust { 292*c217d954SCole Faust return _hidden_state_scale; 293*c217d954SCole Faust } 294*c217d954SCole Faust has_peephole_opt()295*c217d954SCole Faust bool has_peephole_opt() const 296*c217d954SCole Faust { 297*c217d954SCole Faust return _has_peephole_opt; 298*c217d954SCole Faust } 299*c217d954SCole Faust has_projection()300*c217d954SCole Faust bool has_projection() const 301*c217d954SCole Faust { 302*c217d954SCole Faust return _has_projection; 303*c217d954SCole Faust } 304*c217d954SCole Faust has_cifg_opt()305*c217d954SCole Faust bool has_cifg_opt() const 306*c217d954SCole Faust { 307*c217d954SCole Faust return _has_cifg_opt; 308*c217d954SCole Faust } 309*c217d954SCole Faust use_layer_norm()310*c217d954SCole Faust bool use_layer_norm() const 311*c217d954SCole Faust { 312*c217d954SCole Faust return _use_layer_norm; 313*c217d954SCole Faust } 314*c217d954SCole Faust 315*c217d954SCole Faust private: 316*c217d954SCole Faust const T *_input_to_input_weights; 317*c217d954SCole Faust const T *_recurrent_to_input_weights; 318*c217d954SCole Faust T *_cell_to_input_weights; 319*c217d954SCole Faust const T *_input_gate_bias; 320*c217d954SCole Faust T *_cell_to_forget_weights; 321*c217d954SCole Faust T *_cell_to_output_weights; 322*c217d954SCole Faust const T *_projection_weights; 323*c217d954SCole Faust const T *_projection_bias; 324*c217d954SCole Faust T *_input_layer_norm_weights; 325*c217d954SCole Faust T *_forget_layer_norm_weights; 326*c217d954SCole Faust T *_cell_layer_norm_weights; 327*c217d954SCole Faust T *_output_layer_norm_weights; 328*c217d954SCole Faust float _cell_clip; 329*c217d954SCole Faust float _projection_clip; 330*c217d954SCole Faust float _input_intermediate_scale; 331*c217d954SCole Faust float _forget_intermediate_scale; 332*c217d954SCole Faust float _cell_intermediate_scale; 333*c217d954SCole Faust float _output_intermediate_scale; 334*c217d954SCole Faust int32_t _hidden_state_zero; 335*c217d954SCole Faust float _hidden_state_scale; 336*c217d954SCole Faust bool _has_peephole_opt; 337*c217d954SCole Faust bool _has_projection; 338*c217d954SCole Faust bool _has_cifg_opt; 339*c217d954SCole Faust bool _use_layer_norm; 340*c217d954SCole Faust }; 341*c217d954SCole Faust } 342*c217d954SCole Faust #endif /*ARM_COMPUTE_LSTMPARAMS_H */ 343