1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include "TensorFwd.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include "Exceptions.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker namespace armnn 11*89c4ff92SAndroid Build Coastguard Worker { 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker struct QuantizedLstmInputParams 14*89c4ff92SAndroid Build Coastguard Worker { QuantizedLstmInputParamsarmnn::QuantizedLstmInputParams15*89c4ff92SAndroid Build Coastguard Worker QuantizedLstmInputParams() 16*89c4ff92SAndroid Build Coastguard Worker : m_InputToInputWeights(nullptr) 17*89c4ff92SAndroid Build Coastguard Worker , m_InputToForgetWeights(nullptr) 18*89c4ff92SAndroid Build Coastguard Worker , m_InputToCellWeights(nullptr) 19*89c4ff92SAndroid Build Coastguard Worker , m_InputToOutputWeights(nullptr) 20*89c4ff92SAndroid Build Coastguard Worker 21*89c4ff92SAndroid Build Coastguard Worker , m_RecurrentToInputWeights(nullptr) 22*89c4ff92SAndroid Build Coastguard Worker , m_RecurrentToForgetWeights(nullptr) 23*89c4ff92SAndroid Build Coastguard Worker , m_RecurrentToCellWeights(nullptr) 24*89c4ff92SAndroid Build Coastguard Worker , m_RecurrentToOutputWeights(nullptr) 25*89c4ff92SAndroid Build Coastguard Worker 26*89c4ff92SAndroid Build Coastguard Worker , m_InputGateBias(nullptr) 27*89c4ff92SAndroid Build Coastguard Worker , m_ForgetGateBias(nullptr) 28*89c4ff92SAndroid Build Coastguard Worker , m_CellBias(nullptr) 29*89c4ff92SAndroid Build Coastguard Worker , m_OutputGateBias(nullptr) 30*89c4ff92SAndroid Build Coastguard Worker { 31*89c4ff92SAndroid Build Coastguard Worker } 32*89c4ff92SAndroid Build Coastguard Worker 33*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_InputToInputWeights; 34*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_InputToForgetWeights; 35*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_InputToCellWeights; 36*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_InputToOutputWeights; 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_RecurrentToInputWeights; 39*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_RecurrentToForgetWeights; 40*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_RecurrentToCellWeights; 41*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_RecurrentToOutputWeights; 42*89c4ff92SAndroid Build Coastguard Worker 43*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_InputGateBias; 44*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_ForgetGateBias; 45*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_CellBias; 46*89c4ff92SAndroid Build Coastguard Worker const ConstTensor* m_OutputGateBias; 47*89c4ff92SAndroid Build Coastguard Worker Derefarmnn::QuantizedLstmInputParams48*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& Deref(const ConstTensor* tensorPtr) const 49*89c4ff92SAndroid Build Coastguard Worker { 50*89c4ff92SAndroid Build Coastguard Worker if (tensorPtr != nullptr) 51*89c4ff92SAndroid Build Coastguard Worker { 52*89c4ff92SAndroid Build Coastguard Worker const ConstTensor &temp = *tensorPtr; 53*89c4ff92SAndroid Build Coastguard Worker return temp; 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer"); 56*89c4ff92SAndroid Build Coastguard Worker } 57*89c4ff92SAndroid Build Coastguard Worker GetInputToInputWeightsarmnn::QuantizedLstmInputParams58*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetInputToInputWeights() const 59*89c4ff92SAndroid Build Coastguard Worker { 60*89c4ff92SAndroid Build Coastguard Worker return Deref(m_InputToInputWeights); 61*89c4ff92SAndroid Build Coastguard Worker } 62*89c4ff92SAndroid Build Coastguard Worker GetInputToForgetWeightsarmnn::QuantizedLstmInputParams63*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetInputToForgetWeights() const 64*89c4ff92SAndroid Build Coastguard Worker { 65*89c4ff92SAndroid Build Coastguard Worker return Deref(m_InputToForgetWeights); 66*89c4ff92SAndroid Build Coastguard Worker } 67*89c4ff92SAndroid Build Coastguard Worker GetInputToCellWeightsarmnn::QuantizedLstmInputParams68*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetInputToCellWeights() const 69*89c4ff92SAndroid Build Coastguard Worker { 70*89c4ff92SAndroid Build Coastguard Worker return Deref(m_InputToCellWeights); 71*89c4ff92SAndroid Build Coastguard Worker } 72*89c4ff92SAndroid Build Coastguard Worker GetInputToOutputWeightsarmnn::QuantizedLstmInputParams73*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetInputToOutputWeights() const 74*89c4ff92SAndroid Build Coastguard Worker { 75*89c4ff92SAndroid Build Coastguard Worker return Deref(m_InputToOutputWeights); 76*89c4ff92SAndroid Build Coastguard Worker } 77*89c4ff92SAndroid Build Coastguard Worker GetRecurrentToInputWeightsarmnn::QuantizedLstmInputParams78*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetRecurrentToInputWeights() const 79*89c4ff92SAndroid Build Coastguard Worker { 80*89c4ff92SAndroid Build Coastguard Worker return Deref(m_RecurrentToInputWeights); 81*89c4ff92SAndroid Build Coastguard Worker } 82*89c4ff92SAndroid Build Coastguard Worker GetRecurrentToForgetWeightsarmnn::QuantizedLstmInputParams83*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetRecurrentToForgetWeights() const 84*89c4ff92SAndroid Build Coastguard Worker { 85*89c4ff92SAndroid Build Coastguard Worker return Deref(m_RecurrentToForgetWeights); 86*89c4ff92SAndroid Build Coastguard Worker } 87*89c4ff92SAndroid Build Coastguard Worker GetRecurrentToCellWeightsarmnn::QuantizedLstmInputParams88*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetRecurrentToCellWeights() const 89*89c4ff92SAndroid Build Coastguard Worker { 90*89c4ff92SAndroid Build Coastguard Worker return Deref(m_RecurrentToCellWeights); 91*89c4ff92SAndroid Build Coastguard Worker } 92*89c4ff92SAndroid Build Coastguard Worker GetRecurrentToOutputWeightsarmnn::QuantizedLstmInputParams93*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetRecurrentToOutputWeights() const 94*89c4ff92SAndroid Build Coastguard Worker { 95*89c4ff92SAndroid Build Coastguard Worker return Deref(m_RecurrentToOutputWeights); 96*89c4ff92SAndroid Build Coastguard Worker } 97*89c4ff92SAndroid Build Coastguard Worker GetInputGateBiasarmnn::QuantizedLstmInputParams98*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetInputGateBias() const 99*89c4ff92SAndroid Build Coastguard Worker { 100*89c4ff92SAndroid Build Coastguard Worker return Deref(m_InputGateBias); 101*89c4ff92SAndroid Build Coastguard Worker } 102*89c4ff92SAndroid Build Coastguard Worker GetForgetGateBiasarmnn::QuantizedLstmInputParams103*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetForgetGateBias() const 104*89c4ff92SAndroid Build Coastguard Worker { 105*89c4ff92SAndroid Build Coastguard Worker return Deref(m_ForgetGateBias); 106*89c4ff92SAndroid Build Coastguard Worker } 107*89c4ff92SAndroid Build Coastguard Worker GetCellBiasarmnn::QuantizedLstmInputParams108*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetCellBias() const 109*89c4ff92SAndroid Build Coastguard Worker { 110*89c4ff92SAndroid Build Coastguard Worker return Deref(m_CellBias); 111*89c4ff92SAndroid Build Coastguard Worker } 112*89c4ff92SAndroid Build Coastguard Worker GetOutputGateBiasarmnn::QuantizedLstmInputParams113*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& GetOutputGateBias() const 114*89c4ff92SAndroid Build Coastguard Worker { 115*89c4ff92SAndroid Build Coastguard Worker return Deref(m_OutputGateBias); 116*89c4ff92SAndroid Build Coastguard Worker } 117*89c4ff92SAndroid Build Coastguard Worker }; 118*89c4ff92SAndroid Build Coastguard Worker 119*89c4ff92SAndroid Build Coastguard Worker struct QuantizedLstmInputParamsInfo 120*89c4ff92SAndroid Build Coastguard Worker { QuantizedLstmInputParamsInfoarmnn::QuantizedLstmInputParamsInfo121*89c4ff92SAndroid Build Coastguard Worker QuantizedLstmInputParamsInfo() 122*89c4ff92SAndroid Build Coastguard Worker : m_InputToInputWeights(nullptr) 123*89c4ff92SAndroid Build Coastguard Worker , m_InputToForgetWeights(nullptr) 124*89c4ff92SAndroid Build Coastguard Worker , m_InputToCellWeights(nullptr) 125*89c4ff92SAndroid Build Coastguard Worker , m_InputToOutputWeights(nullptr) 126*89c4ff92SAndroid Build Coastguard Worker 127*89c4ff92SAndroid Build Coastguard Worker , m_RecurrentToInputWeights(nullptr) 128*89c4ff92SAndroid Build Coastguard Worker , m_RecurrentToForgetWeights(nullptr) 129*89c4ff92SAndroid Build Coastguard Worker , m_RecurrentToCellWeights(nullptr) 130*89c4ff92SAndroid Build Coastguard Worker , m_RecurrentToOutputWeights(nullptr) 131*89c4ff92SAndroid Build Coastguard Worker 132*89c4ff92SAndroid Build Coastguard Worker , m_InputGateBias(nullptr) 133*89c4ff92SAndroid Build Coastguard Worker , m_ForgetGateBias(nullptr) 134*89c4ff92SAndroid Build Coastguard Worker , m_CellBias(nullptr) 135*89c4ff92SAndroid Build Coastguard Worker , m_OutputGateBias(nullptr) 136*89c4ff92SAndroid Build Coastguard Worker { 137*89c4ff92SAndroid Build Coastguard Worker } 138*89c4ff92SAndroid Build Coastguard Worker 139*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_InputToInputWeights; 140*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_InputToForgetWeights; 141*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_InputToCellWeights; 142*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_InputToOutputWeights; 143*89c4ff92SAndroid Build Coastguard Worker 144*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_RecurrentToInputWeights; 145*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_RecurrentToForgetWeights; 146*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_RecurrentToCellWeights; 147*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_RecurrentToOutputWeights; 148*89c4ff92SAndroid Build Coastguard Worker 149*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_InputGateBias; 150*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_ForgetGateBias; 151*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_CellBias; 152*89c4ff92SAndroid Build Coastguard Worker const TensorInfo* m_OutputGateBias; 153*89c4ff92SAndroid Build Coastguard Worker 154*89c4ff92SAndroid Build Coastguard Worker Derefarmnn::QuantizedLstmInputParamsInfo155*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& Deref(const TensorInfo* tensorInfo) const 156*89c4ff92SAndroid Build Coastguard Worker { 157*89c4ff92SAndroid Build Coastguard Worker if (tensorInfo != nullptr) 158*89c4ff92SAndroid Build Coastguard Worker { 159*89c4ff92SAndroid Build Coastguard Worker const TensorInfo &temp = *tensorInfo; 160*89c4ff92SAndroid Build Coastguard Worker return temp; 161*89c4ff92SAndroid Build Coastguard Worker } 162*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Can't dereference a null pointer"); 163*89c4ff92SAndroid Build Coastguard Worker } 164*89c4ff92SAndroid Build Coastguard Worker GetInputToInputWeightsarmnn::QuantizedLstmInputParamsInfo165*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetInputToInputWeights() const 166*89c4ff92SAndroid Build Coastguard Worker { 167*89c4ff92SAndroid Build Coastguard Worker return Deref(m_InputToInputWeights); 168*89c4ff92SAndroid Build Coastguard Worker } GetInputToForgetWeightsarmnn::QuantizedLstmInputParamsInfo169*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetInputToForgetWeights() const 170*89c4ff92SAndroid Build Coastguard Worker { 171*89c4ff92SAndroid Build Coastguard Worker return Deref(m_InputToForgetWeights); 172*89c4ff92SAndroid Build Coastguard Worker } GetInputToCellWeightsarmnn::QuantizedLstmInputParamsInfo173*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetInputToCellWeights() const 174*89c4ff92SAndroid Build Coastguard Worker { 175*89c4ff92SAndroid Build Coastguard Worker return Deref(m_InputToCellWeights); 176*89c4ff92SAndroid Build Coastguard Worker } GetInputToOutputWeightsarmnn::QuantizedLstmInputParamsInfo177*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetInputToOutputWeights() const 178*89c4ff92SAndroid Build Coastguard Worker { 179*89c4ff92SAndroid Build Coastguard Worker return Deref(m_InputToOutputWeights); 180*89c4ff92SAndroid Build Coastguard Worker } 181*89c4ff92SAndroid Build Coastguard Worker GetRecurrentToInputWeightsarmnn::QuantizedLstmInputParamsInfo182*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetRecurrentToInputWeights() const 183*89c4ff92SAndroid Build Coastguard Worker { 184*89c4ff92SAndroid Build Coastguard Worker return Deref(m_RecurrentToInputWeights); 185*89c4ff92SAndroid Build Coastguard Worker } GetRecurrentToForgetWeightsarmnn::QuantizedLstmInputParamsInfo186*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetRecurrentToForgetWeights() const 187*89c4ff92SAndroid Build Coastguard Worker { 188*89c4ff92SAndroid Build Coastguard Worker return Deref(m_RecurrentToForgetWeights); 189*89c4ff92SAndroid Build Coastguard Worker } GetRecurrentToCellWeightsarmnn::QuantizedLstmInputParamsInfo190*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetRecurrentToCellWeights() const 191*89c4ff92SAndroid Build Coastguard Worker { 192*89c4ff92SAndroid Build Coastguard Worker return Deref(m_RecurrentToCellWeights); 193*89c4ff92SAndroid Build Coastguard Worker } GetRecurrentToOutputWeightsarmnn::QuantizedLstmInputParamsInfo194*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetRecurrentToOutputWeights() const 195*89c4ff92SAndroid Build Coastguard Worker { 196*89c4ff92SAndroid Build Coastguard Worker return Deref(m_RecurrentToOutputWeights); 197*89c4ff92SAndroid Build Coastguard Worker } 198*89c4ff92SAndroid Build Coastguard Worker GetInputGateBiasarmnn::QuantizedLstmInputParamsInfo199*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetInputGateBias() const 200*89c4ff92SAndroid Build Coastguard Worker { 201*89c4ff92SAndroid Build Coastguard Worker return Deref(m_InputGateBias); 202*89c4ff92SAndroid Build Coastguard Worker } GetForgetGateBiasarmnn::QuantizedLstmInputParamsInfo203*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetForgetGateBias() const 204*89c4ff92SAndroid Build Coastguard Worker { 205*89c4ff92SAndroid Build Coastguard Worker return Deref(m_ForgetGateBias); 206*89c4ff92SAndroid Build Coastguard Worker } GetCellBiasarmnn::QuantizedLstmInputParamsInfo207*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetCellBias() const 208*89c4ff92SAndroid Build Coastguard Worker { 209*89c4ff92SAndroid Build Coastguard Worker return Deref(m_CellBias); 210*89c4ff92SAndroid Build Coastguard Worker } GetOutputGateBiasarmnn::QuantizedLstmInputParamsInfo211*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetOutputGateBias() const 212*89c4ff92SAndroid Build Coastguard Worker { 213*89c4ff92SAndroid Build Coastguard Worker return Deref(m_OutputGateBias); 214*89c4ff92SAndroid Build Coastguard Worker } 215*89c4ff92SAndroid Build Coastguard Worker }; 216*89c4ff92SAndroid Build Coastguard Worker 217*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn 218*89c4ff92SAndroid Build Coastguard Worker 219