xref: /aosp_15_r20/external/armnn/src/backends/cl/workloads/ClQLstmWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ClQLstmWorkload.hpp"
7 #include "ClWorkloadUtils.hpp"
8 
9 #include "aclCommon/ArmComputeTensorUtils.hpp"
10 
11 #include "cl/ClTensorHandle.hpp"
12 
13 namespace armnn
14 {
15 using namespace armcomputetensorutils;
16 
ClQLstmWorkload(const QLstmQueueDescriptor & descriptor,const WorkloadInfo & info,const arm_compute::CLCompileContext & clCompileContext)17 ClQLstmWorkload::ClQLstmWorkload(const QLstmQueueDescriptor& descriptor,
18                                  const WorkloadInfo& info,
19                                  const arm_compute::CLCompileContext& clCompileContext)
20     : ClBaseWorkload<QLstmQueueDescriptor>(descriptor, info)
21 {
22     // Report Profiling Details
23     ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClQLstmWorkload_Construct",
24                                          descriptor.m_Parameters,
25                                          info,
26                                          this->GetGuid());
27 
28     arm_compute::LSTMParams<arm_compute::ICLTensor> qLstmParams;
29 
30     // Mandatory params
31     m_InputToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
32     BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());
33 
34     m_InputToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
35     BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());
36 
37     m_InputToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
38     BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());
39 
40     m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
41     BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());
42 
43     m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
44     BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());
45 
46     m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
47     BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());
48 
49     m_ForgetGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
50     BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());
51 
52     m_CellBiasTensor = std::make_unique<arm_compute::CLTensor>();
53     BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());
54 
55     m_OutputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
56     BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());
57 
58     // Create tensors for optional params if they are enabled
59     if (m_Data.m_Parameters.m_PeepholeEnabled)
60     {
61         m_CellToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
62 
63         if (!m_Data.m_Parameters.m_CifgEnabled)
64         {
65             // In ACL this is categorised as a CIFG param and not a Peephole param
66             BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
67         }
68 
69         m_CellToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
70         BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());
71 
72         m_CellToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
73         BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());
74 
75         // Set Peephole params
76         qLstmParams.set_peephole_params(m_CellToForgetWeightsTensor.get(),
77                                         m_CellToOutputWeightsTensor.get());
78     }
79 
80     if (m_Data.m_Parameters.m_ProjectionEnabled)
81     {
82         m_ProjectionWeightsTensor = std::make_unique<arm_compute::CLTensor>();
83         BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());
84 
85         m_ProjectionBiasTensor = std::make_unique<arm_compute::CLTensor>();
86         if (m_Data.m_ProjectionBias != nullptr)
87         {
88             BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
89         }
90 
91         // Set projection params
92         qLstmParams.set_projection_params(
93             m_ProjectionWeightsTensor.get(),
94             m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
95     }
96 
97     if (m_Data.m_Parameters.m_LayerNormEnabled)
98     {
99         m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
100 
101         if (!m_Data.m_Parameters.m_CifgEnabled)
102         {
103             BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
104         }
105 
106         m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
107         BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
108 
109         m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
110         BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
111 
112         m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
113         BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
114 
115         // Set layer norm params
116         qLstmParams.set_layer_normalization_params(
117             m_Data.m_InputLayerNormWeights != nullptr ? m_InputLayerNormWeightsTensor.get() : nullptr,
118             m_ForgetLayerNormWeightsTensor.get(),
119             m_CellLayerNormWeightsTensor.get(),
120             m_OutputLayerNormWeightsTensor.get());
121     }
122 
123     if (!m_Data.m_Parameters.m_CifgEnabled)
124     {
125         m_InputToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
126         BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());
127 
128         m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
129         BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());
130 
131         m_InputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
132         BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());
133 
134         // Set CIFG params
135         qLstmParams.set_cifg_params(
136             m_InputToInputWeightsTensor.get(),
137             m_RecurrentToInputWeightsTensor.get(),
138             m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
139             m_InputGateBiasTensor.get());
140     }
141 
142     // Input/Output tensors
143     const arm_compute::ICLTensor& input         = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
144     arm_compute::ICLTensor&       outputStateIn = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
145     arm_compute::ICLTensor&       cellStateIn   = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
146 
147     arm_compute::ICLTensor& outputStateOut = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
148     arm_compute::ICLTensor& cellStateOut   = static_cast<IClTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
149     arm_compute::ICLTensor& output         = static_cast<IClTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
150 
151     // Set scalar descriptor params
152     qLstmParams.set_cell_clip_params(m_Data.m_Parameters.m_CellClip);
153     qLstmParams.set_projection_clip_params(m_Data.m_Parameters.m_ProjectionClip);
154     qLstmParams.set_hidden_state_params(m_Data.m_Parameters.m_HiddenStateZeroPoint,
155                                         m_Data.m_Parameters.m_HiddenStateScale);
156     qLstmParams.set_matmul_scale_params(m_Data.m_Parameters.m_InputIntermediateScale,
157                                         m_Data.m_Parameters.m_ForgetIntermediateScale,
158                                         m_Data.m_Parameters.m_CellIntermediateScale,
159                                         m_Data.m_Parameters.m_OutputIntermediateScale);
160 
161     {
162         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ClQLstmWorkload_configure");
163         // QLSTM CL configure
164         m_QLstmLayer.configure(clCompileContext,
165                                &input,
166                                m_InputToForgetWeightsTensor.get(),
167                                m_InputToCellWeightsTensor.get(),
168                                m_InputToOutputWeightsTensor.get(),
169                                m_RecurrentToForgetWeightsTensor.get(),
170                                m_RecurrentToCellWeightsTensor.get(),
171                                m_RecurrentToOutputWeightsTensor.get(),
172                                m_ForgetGateBiasTensor.get(),
173                                m_CellBiasTensor.get(),
174                                m_OutputGateBiasTensor.get(),
175                                &cellStateIn,
176                                &outputStateIn,
177                                &cellStateOut,
178                                &outputStateOut,
179                                &output,
180                                qLstmParams);
181     }
182 
183     // Initialise ACL tensor data for mandatory params
184     InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
185     InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor,   m_Data.m_InputToCellWeights);
186     InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
187 
188     InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
189     InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor,   m_Data.m_RecurrentToCellWeights);
190     InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
191 
192     InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
193     InitializeArmComputeClTensorData(*m_CellBiasTensor,       m_Data.m_CellBias);
194     InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
195 
196     // Initialise ACL tensor data for optional params
197     if (!m_Data.m_Parameters.m_CifgEnabled)
198     {
199         InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor,     m_Data.m_InputToInputWeights);
200         InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
201         InitializeArmComputeClTensorData(*m_InputGateBiasTensor,           m_Data.m_InputGateBias);
202     }
203 
204     if (m_Data.m_Parameters.m_ProjectionEnabled)
205     {
206         InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
207 
208         if (m_Data.m_ProjectionBias != nullptr)
209         {
210             InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
211         }
212     }
213 
214     if (m_Data.m_Parameters.m_PeepholeEnabled)
215     {
216         if (!m_Data.m_Parameters.m_CifgEnabled)
217         {
218             InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
219         }
220 
221         InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
222         InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
223     }
224 
225     if (m_Data.m_Parameters.m_LayerNormEnabled)
226     {
227         if (!m_Data.m_Parameters.m_CifgEnabled)
228         {
229             InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
230         }
231         InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
232         InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor,   m_Data.m_CellLayerNormWeights);
233         InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
234     }
235 
236     m_QLstmLayer.prepare();
237 
238     FreeUnusedTensors();
239 }
240 
Execute() const241 void ClQLstmWorkload::Execute() const
242 {
243     ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClQuantizedLstmWorkload_Execute", this->GetGuid());
244     m_QLstmLayer.run();
245 }
246 
ClQLstmWorkloadValidate(const TensorInfo & input,const TensorInfo & cellStateIn,const TensorInfo & outputStateIn,const TensorInfo & cellStateOut,const TensorInfo & outputStateOut,const TensorInfo & output,const QLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo)247 arm_compute::Status ClQLstmWorkloadValidate(const TensorInfo& input,
248                                               const TensorInfo& cellStateIn,
249                                               const TensorInfo& outputStateIn,
250                                               const TensorInfo& cellStateOut,
251                                               const TensorInfo& outputStateOut,
252                                               const TensorInfo& output,
253                                               const QLstmDescriptor& descriptor,
254                                               const LstmInputParamsInfo& paramsInfo)
255 {
256     arm_compute::LSTMParams<arm_compute::ITensorInfo> aclParamsInfo;
257 
258     // Input/Output tensor info
259     const arm_compute::TensorInfo aclInputInfo         = BuildArmComputeTensorInfo(input);
260     const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
261     const arm_compute::TensorInfo aclCellStateInInfo   = BuildArmComputeTensorInfo(cellStateIn);
262 
263     const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
264     const arm_compute::TensorInfo aclCellStateOutInfo   = BuildArmComputeTensorInfo(cellStateOut);
265     const arm_compute::TensorInfo aclOutputInfo         = BuildArmComputeTensorInfo(output);
266 
267     // Mandatory tensor info
268     const arm_compute::TensorInfo aclInputToForgetWeightsInfo
269         = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
270     const arm_compute::TensorInfo aclInputToCellWeightsInfo
271         = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
272     const arm_compute::TensorInfo aclInputToOutputWeightsInfo
273         = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
274     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
275         = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
276     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
277         = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
278     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
279         = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
280     const arm_compute::TensorInfo aclForgetGateBiasInfo
281         = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
282     const arm_compute::TensorInfo aclCellBiasInfo
283         = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
284     const arm_compute::TensorInfo aclOutputGateBiasInfo
285         = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());
286 
287     // Optional tensor info
288     arm_compute::TensorInfo aclInputToInputWeightsInfo;
289     arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
290 
291     arm_compute::TensorInfo aclCellToInputWeightsInfo;
292     arm_compute::TensorInfo aclCellToForgetWeightsInfo;
293     arm_compute::TensorInfo aclCellToOutputWeightsInfo;
294 
295     arm_compute::TensorInfo aclInputGateBiasInfo;
296 
297     arm_compute::TensorInfo aclProjectionWeightsInfo;
298     arm_compute::TensorInfo aclProjectionBiasInfo;
299 
300     arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
301     arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
302     arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
303     arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
304 
305     // Create tensor info for optional params if they are enabled
306     if (descriptor.m_PeepholeEnabled)
307     {
308         if (!descriptor.m_CifgEnabled)
309         {
310             aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
311         }
312 
313         aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
314         aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());
315 
316         // Set peephole params info
317         aclParamsInfo.set_peephole_params(&aclCellToForgetWeightsInfo,
318                                           &aclCellToOutputWeightsInfo);
319     }
320 
321     if (descriptor.m_ProjectionEnabled)
322     {
323         aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());
324 
325         if (paramsInfo.m_ProjectionBias != nullptr)
326         {
327             aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias());
328         }
329 
330         // Set projection params info
331         aclParamsInfo.set_projection_params(
332             &aclProjectionWeightsInfo,
333             paramsInfo.m_ProjectionBias != nullptr ? &aclProjectionBiasInfo : nullptr);
334     }
335 
336     if (descriptor.m_LayerNormEnabled)
337     {
338         if (!descriptor.m_CifgEnabled)
339         {
340             aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
341         }
342 
343         aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
344         aclCellLayerNormWeightsInfo   = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
345         aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());
346 
347         // Set layer norm params info
348         aclParamsInfo.set_layer_normalization_params(
349             paramsInfo.m_InputLayerNormWeights != nullptr ? &aclInputLayerNormWeightsInfo : nullptr,
350             &aclForgetLayerNormWeightsInfo,
351             &aclCellLayerNormWeightsInfo,
352             &aclOutputLayerNormWeightsInfo);
353     }
354 
355     if (!descriptor.m_CifgEnabled)
356     {
357         aclInputToInputWeightsInfo     = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
358         aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
359         aclInputGateBiasInfo           = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());
360 
361         // Set CIFG params info
362         aclParamsInfo.set_cifg_params(
363             &aclInputToInputWeightsInfo,
364             &aclRecurrentToInputWeightsInfo,
365             paramsInfo.m_CellToInputWeights != nullptr ? &aclCellToInputWeightsInfo : nullptr,
366             &aclInputGateBiasInfo);
367     }
368 
369     // Set scalar descriptor params
370     aclParamsInfo.set_cell_clip_params(descriptor.m_CellClip);
371     aclParamsInfo.set_projection_clip_params(descriptor.m_ProjectionClip);
372     aclParamsInfo.set_hidden_state_params(descriptor.m_HiddenStateZeroPoint, descriptor.m_HiddenStateScale);
373     aclParamsInfo.set_matmul_scale_params(descriptor.m_InputIntermediateScale,
374                                           descriptor.m_ForgetIntermediateScale,
375                                           descriptor.m_CellIntermediateScale,
376                                           descriptor.m_OutputIntermediateScale);
377 
378     // QLSTM CL validate
379     return arm_compute::CLQLSTMLayer::validate(&aclInputInfo,
380                                                &aclInputToForgetWeightsInfo,
381                                                &aclInputToCellWeightsInfo,
382                                                &aclInputToOutputWeightsInfo,
383                                                &aclRecurrentToForgetWeightsInfo,
384                                                &aclRecurrentToCellWeightsInfo,
385                                                &aclRecurrentToOutputWeightsInfo,
386                                                &aclForgetGateBiasInfo,
387                                                &aclCellBiasInfo,
388                                                &aclOutputGateBiasInfo,
389                                                &aclCellStateInInfo,
390                                                &aclOutputStateInInfo,
391                                                &aclCellStateOutInfo,
392                                                &aclOutputStateOutInfo,
393                                                &aclOutputInfo,
394                                                aclParamsInfo);
395 }
396 
FreeUnusedTensors()397 void ClQLstmWorkload::FreeUnusedTensors()
398 {
399     FreeTensorIfUnused(m_InputToInputWeightsTensor);
400     FreeTensorIfUnused(m_InputToForgetWeightsTensor);
401     FreeTensorIfUnused(m_InputToCellWeightsTensor);
402     FreeTensorIfUnused(m_InputToOutputWeightsTensor);
403 
404     FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
405     FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
406     FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
407     FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
408 
409     FreeTensorIfUnused(m_CellToInputWeightsTensor);
410     FreeTensorIfUnused(m_CellToForgetWeightsTensor);
411     FreeTensorIfUnused(m_CellToOutputWeightsTensor);
412 
413     FreeTensorIfUnused(m_InputGateBiasTensor);
414     FreeTensorIfUnused(m_ForgetGateBiasTensor);
415     FreeTensorIfUnused(m_CellBiasTensor);
416     FreeTensorIfUnused(m_OutputGateBiasTensor);
417 
418     FreeTensorIfUnused(m_ProjectionWeightsTensor);
419     FreeTensorIfUnused(m_ProjectionBiasTensor);
420 
421     FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
422     FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
423     FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
424     FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
425 }
426 
427 } //namespace armnn
428