1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "TensorFwd.hpp" 8 #include "Exceptions.hpp" 9 10 namespace armnn 11 { 12 13 struct QuantizedLstmInputParams 14 { QuantizedLstmInputParamsarmnn::QuantizedLstmInputParams15 QuantizedLstmInputParams() 16 : m_InputToInputWeights(nullptr) 17 , m_InputToForgetWeights(nullptr) 18 , m_InputToCellWeights(nullptr) 19 , m_InputToOutputWeights(nullptr) 20 21 , m_RecurrentToInputWeights(nullptr) 22 , m_RecurrentToForgetWeights(nullptr) 23 , m_RecurrentToCellWeights(nullptr) 24 , m_RecurrentToOutputWeights(nullptr) 25 26 , m_InputGateBias(nullptr) 27 , m_ForgetGateBias(nullptr) 28 , m_CellBias(nullptr) 29 , m_OutputGateBias(nullptr) 30 { 31 } 32 33 const ConstTensor* m_InputToInputWeights; 34 const ConstTensor* m_InputToForgetWeights; 35 const ConstTensor* m_InputToCellWeights; 36 const ConstTensor* m_InputToOutputWeights; 37 38 const ConstTensor* m_RecurrentToInputWeights; 39 const ConstTensor* m_RecurrentToForgetWeights; 40 const ConstTensor* m_RecurrentToCellWeights; 41 const ConstTensor* m_RecurrentToOutputWeights; 42 43 const ConstTensor* m_InputGateBias; 44 const ConstTensor* m_ForgetGateBias; 45 const ConstTensor* m_CellBias; 46 const ConstTensor* m_OutputGateBias; 47 Derefarmnn::QuantizedLstmInputParams48 const ConstTensor& Deref(const ConstTensor* tensorPtr) const 49 { 50 if (tensorPtr != nullptr) 51 { 52 const ConstTensor &temp = *tensorPtr; 53 return temp; 54 } 55 throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer"); 56 } 57 GetInputToInputWeightsarmnn::QuantizedLstmInputParams58 const ConstTensor& GetInputToInputWeights() const 59 { 60 return Deref(m_InputToInputWeights); 61 } 62 GetInputToForgetWeightsarmnn::QuantizedLstmInputParams63 const ConstTensor& GetInputToForgetWeights() const 64 { 65 return Deref(m_InputToForgetWeights); 66 } 67 GetInputToCellWeightsarmnn::QuantizedLstmInputParams68 const ConstTensor& GetInputToCellWeights() const 69 { 70 return Deref(m_InputToCellWeights); 71 } 72 GetInputToOutputWeightsarmnn::QuantizedLstmInputParams73 const ConstTensor& GetInputToOutputWeights() const 74 { 75 return Deref(m_InputToOutputWeights); 76 } 77 GetRecurrentToInputWeightsarmnn::QuantizedLstmInputParams78 const ConstTensor& GetRecurrentToInputWeights() const 79 { 80 return Deref(m_RecurrentToInputWeights); 81 } 82 GetRecurrentToForgetWeightsarmnn::QuantizedLstmInputParams83 const ConstTensor& GetRecurrentToForgetWeights() const 84 { 85 return Deref(m_RecurrentToForgetWeights); 86 } 87 GetRecurrentToCellWeightsarmnn::QuantizedLstmInputParams88 const ConstTensor& GetRecurrentToCellWeights() const 89 { 90 return Deref(m_RecurrentToCellWeights); 91 } 92 GetRecurrentToOutputWeightsarmnn::QuantizedLstmInputParams93 const ConstTensor& GetRecurrentToOutputWeights() const 94 { 95 return Deref(m_RecurrentToOutputWeights); 96 } 97 GetInputGateBiasarmnn::QuantizedLstmInputParams98 const ConstTensor& GetInputGateBias() const 99 { 100 return Deref(m_InputGateBias); 101 } 102 GetForgetGateBiasarmnn::QuantizedLstmInputParams103 const ConstTensor& GetForgetGateBias() const 104 { 105 return Deref(m_ForgetGateBias); 106 } 107 GetCellBiasarmnn::QuantizedLstmInputParams108 const ConstTensor& GetCellBias() const 109 { 110 return Deref(m_CellBias); 111 } 112 GetOutputGateBiasarmnn::QuantizedLstmInputParams113 const ConstTensor& GetOutputGateBias() const 114 { 115 return Deref(m_OutputGateBias); 116 } 117 }; 118 119 struct QuantizedLstmInputParamsInfo 120 { QuantizedLstmInputParamsInfoarmnn::QuantizedLstmInputParamsInfo121 QuantizedLstmInputParamsInfo() 122 : m_InputToInputWeights(nullptr) 123 , m_InputToForgetWeights(nullptr) 124 , m_InputToCellWeights(nullptr) 125 , m_InputToOutputWeights(nullptr) 126 127 , m_RecurrentToInputWeights(nullptr) 128 , m_RecurrentToForgetWeights(nullptr) 129 , m_RecurrentToCellWeights(nullptr) 130 , m_RecurrentToOutputWeights(nullptr) 131 132 , m_InputGateBias(nullptr) 133 , m_ForgetGateBias(nullptr) 134 , m_CellBias(nullptr) 135 , m_OutputGateBias(nullptr) 136 { 137 } 138 139 const TensorInfo* m_InputToInputWeights; 140 const TensorInfo* m_InputToForgetWeights; 141 const TensorInfo* m_InputToCellWeights; 142 const TensorInfo* m_InputToOutputWeights; 143 144 const TensorInfo* m_RecurrentToInputWeights; 145 const TensorInfo* m_RecurrentToForgetWeights; 146 const TensorInfo* m_RecurrentToCellWeights; 147 const TensorInfo* m_RecurrentToOutputWeights; 148 149 const TensorInfo* m_InputGateBias; 150 const TensorInfo* m_ForgetGateBias; 151 const TensorInfo* m_CellBias; 152 const TensorInfo* m_OutputGateBias; 153 154 Derefarmnn::QuantizedLstmInputParamsInfo155 const TensorInfo& Deref(const TensorInfo* tensorInfo) const 156 { 157 if (tensorInfo != nullptr) 158 { 159 const TensorInfo &temp = *tensorInfo; 160 return temp; 161 } 162 throw InvalidArgumentException("Can't dereference a null pointer"); 163 } 164 GetInputToInputWeightsarmnn::QuantizedLstmInputParamsInfo165 const TensorInfo& GetInputToInputWeights() const 166 { 167 return Deref(m_InputToInputWeights); 168 } GetInputToForgetWeightsarmnn::QuantizedLstmInputParamsInfo169 const TensorInfo& GetInputToForgetWeights() const 170 { 171 return Deref(m_InputToForgetWeights); 172 } GetInputToCellWeightsarmnn::QuantizedLstmInputParamsInfo173 const TensorInfo& GetInputToCellWeights() const 174 { 175 return Deref(m_InputToCellWeights); 176 } GetInputToOutputWeightsarmnn::QuantizedLstmInputParamsInfo177 const TensorInfo& GetInputToOutputWeights() const 178 { 179 return Deref(m_InputToOutputWeights); 180 } 181 GetRecurrentToInputWeightsarmnn::QuantizedLstmInputParamsInfo182 const TensorInfo& GetRecurrentToInputWeights() const 183 { 184 return Deref(m_RecurrentToInputWeights); 185 } GetRecurrentToForgetWeightsarmnn::QuantizedLstmInputParamsInfo186 const TensorInfo& GetRecurrentToForgetWeights() const 187 { 188 return Deref(m_RecurrentToForgetWeights); 189 } GetRecurrentToCellWeightsarmnn::QuantizedLstmInputParamsInfo190 const TensorInfo& GetRecurrentToCellWeights() const 191 { 192 return Deref(m_RecurrentToCellWeights); 193 } GetRecurrentToOutputWeightsarmnn::QuantizedLstmInputParamsInfo194 const TensorInfo& GetRecurrentToOutputWeights() const 195 { 196 return Deref(m_RecurrentToOutputWeights); 197 } 198 GetInputGateBiasarmnn::QuantizedLstmInputParamsInfo199 const TensorInfo& GetInputGateBias() const 200 { 201 return Deref(m_InputGateBias); 202 } GetForgetGateBiasarmnn::QuantizedLstmInputParamsInfo203 const TensorInfo& GetForgetGateBias() const 204 { 205 return Deref(m_ForgetGateBias); 206 } GetCellBiasarmnn::QuantizedLstmInputParamsInfo207 const TensorInfo& GetCellBias() const 208 { 209 return Deref(m_CellBias); 210 } GetOutputGateBiasarmnn::QuantizedLstmInputParamsInfo211 const TensorInfo& GetOutputGateBias() const 212 { 213 return Deref(m_OutputGateBias); 214 } 215 }; 216 217 } // namespace armnn 218 219