//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "NeonQLstmWorkload.hpp"
#include "NeonWorkloadUtils.hpp"

#include "aclCommon/ArmComputeTensorUtils.hpp"

#include "neon/NeonTensorHandle.hpp"

namespace armnn
{
using namespace armcomputetensorutils;

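// NeonQLstmWorkload wraps Arm Compute Library's NEQLSTMLayer to run ArmNN's quantized
// LSTM (QLstm) layer on Neon. The constructor stages all constant weights and biases as
// ACL tensors and configures the layer once; Execute() then only has to run it.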
NeonQLstmWorkload::NeonQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info)
        : NeonBaseWorkload<QLstmQueueDescriptor>(descriptor, info)
{
    // Report Profiling Details
    ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonQLstmWorkload_Construct",
                                         descriptor.m_Parameters,
                                         info,
                                         this->GetGuid());

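    // LSTMParams is ACL's container for everything beyond the mandatory arguments of
    // NEQLSTMLayer::configure(): the optional weight tensors (CIFG, peephole, projection,
    // layer norm) and the scalar clip/scale parameters set further down.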
    arm_compute::LSTMParams<arm_compute::ITensor> qLstmParams;

    // Mandatory params
    m_InputToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());

    m_InputToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());

    m_InputToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());

    m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());

    m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());

    m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());

    m_ForgetGateBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());

    m_CellBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());

    m_OutputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
    BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());

    // Create tensors for optional params if they are enabled
    if (m_Data.m_Parameters.m_PeepholeEnabled)
    {
        m_CellToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();

        if (!m_Data.m_Parameters.m_CifgEnabled)
        {
            // In ACL this is categorised as a CIFG param and not a Peephole param
            BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
        }

        m_CellToForgetWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());

        m_CellToOutputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());

        // Set Peephole params
        qLstmParams.set_peephole_params(m_CellToForgetWeightsTensor.get(),
                                        m_CellToOutputWeightsTensor.get());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        m_ProjectionWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());

        m_ProjectionBiasTensor = std::make_unique<arm_compute::Tensor>();
        if (m_Data.m_ProjectionBias != nullptr)
        {
            BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
        }

        // Set projection params
        qLstmParams.set_projection_params(
            m_ProjectionWeightsTensor.get(),
            m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
    }

    if (m_Data.m_Parameters.m_LayerNormEnabled)
    {
        m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();

        if (!m_Data.m_Parameters.m_CifgEnabled)
        {
            BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
        }

        m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());

        m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());

        m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());

        // Set layer norm params
        qLstmParams.set_layer_normalization_params(
            m_Data.m_InputLayerNormWeights != nullptr ? m_InputLayerNormWeightsTensor.get() : nullptr,
            m_ForgetLayerNormWeightsTensor.get(),
            m_CellLayerNormWeightsTensor.get(),
            m_OutputLayerNormWeightsTensor.get());
    }

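    // CIFG (coupled input-forget gate): when enabled, the input gate is derived from the
    // forget gate, so the dedicated input-gate weights and bias below are only needed
    // (and only provided) when CIFG is disabled.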
    if (!m_Data.m_Parameters.m_CifgEnabled)
    {
        m_InputToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());

        m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());

        m_InputGateBiasTensor = std::make_unique<arm_compute::Tensor>();
        BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());

        // Set CIFG params
        qLstmParams.set_cifg_params(
            m_InputToInputWeightsTensor.get(),
            m_RecurrentToInputWeightsTensor.get(),
            m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
            m_InputGateBiasTensor.get());
    }

    // Input/Output tensors
    const arm_compute::ITensor& input         = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
    arm_compute::ITensor& outputStateIn       = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
    const arm_compute::ITensor& cellStateIn   = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();

    arm_compute::ITensor& outputStateOut = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
    arm_compute::ITensor& cellStateOut   = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
    arm_compute::ITensor& output         = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();

    // Set scalar descriptor params
    qLstmParams.set_cell_clip_params(m_Data.m_Parameters.m_CellClip);
    qLstmParams.set_projection_clip_params(m_Data.m_Parameters.m_ProjectionClip);
    qLstmParams.set_hidden_state_params(m_Data.m_Parameters.m_HiddenStateZeroPoint,
                                        m_Data.m_Parameters.m_HiddenStateScale);
    qLstmParams.set_matmul_scale_params(m_Data.m_Parameters.m_InputIntermediateScale,
                                        m_Data.m_Parameters.m_ForgetIntermediateScale,
                                        m_Data.m_Parameters.m_CellIntermediateScale,
                                        m_Data.m_Parameters.m_OutputIntermediateScale);

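    // Note the argument order expected by ACL: cell state precedes output state, which is
    // the reverse of the ArmNN input/output index ordering used above.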
    // QLSTM NEON configure
    m_QLstmLayer.configure(&input,
                           m_InputToForgetWeightsTensor.get(),
                           m_InputToCellWeightsTensor.get(),
                           m_InputToOutputWeightsTensor.get(),
                           m_RecurrentToForgetWeightsTensor.get(),
                           m_RecurrentToCellWeightsTensor.get(),
                           m_RecurrentToOutputWeightsTensor.get(),
                           m_ForgetGateBiasTensor.get(),
                           m_CellBiasTensor.get(),
                           m_OutputGateBiasTensor.get(),
                           &cellStateIn,
                           &outputStateIn,
                           &cellStateOut,
                           &outputStateOut,
                           &output,
                           qLstmParams);

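    // Copy the constant weight/bias data held by the ArmNN tensor handles into the ACL
    // staging tensors built above: mandatory tensors first, then whichever optional
    // groups are enabled.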
    // Initialise ACL tensor data for mandatory params
    InitializeArmComputeTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
    InitializeArmComputeTensorData(*m_InputToCellWeightsTensor,   m_Data.m_InputToCellWeights);
    InitializeArmComputeTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);

    InitializeArmComputeTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
    InitializeArmComputeTensorData(*m_RecurrentToCellWeightsTensor,   m_Data.m_RecurrentToCellWeights);
    InitializeArmComputeTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);

    InitializeArmComputeTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
    InitializeArmComputeTensorData(*m_CellBiasTensor,       m_Data.m_CellBias);
    InitializeArmComputeTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);

    // Initialise ACL tensor data for optional params
    if (!m_Data.m_Parameters.m_CifgEnabled)
    {
        InitializeArmComputeTensorData(*m_InputToInputWeightsTensor,     m_Data.m_InputToInputWeights);
        InitializeArmComputeTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
        InitializeArmComputeTensorData(*m_InputGateBiasTensor,           m_Data.m_InputGateBias);
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        InitializeArmComputeTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);

        if (m_Data.m_ProjectionBias != nullptr)
        {
            InitializeArmComputeTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
        }
    }

    if (m_Data.m_Parameters.m_PeepholeEnabled)
    {
        if (!m_Data.m_Parameters.m_CifgEnabled)
        {
            InitializeArmComputeTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
        }

        InitializeArmComputeTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
        InitializeArmComputeTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
    }

    if (m_Data.m_Parameters.m_LayerNormEnabled)
    {
        if (!m_Data.m_Parameters.m_CifgEnabled)
        {
            InitializeArmComputeTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
        }

        InitializeArmComputeTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
        InitializeArmComputeTensorData(*m_CellLayerNormWeightsTensor,   m_Data.m_CellLayerNormWeights);
        InitializeArmComputeTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
    }

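    // prepare() lets ACL pre-process the weights (e.g. transpose/pack them into its own
    // buffers) ahead of the first run; once that has happened the staging tensors created
    // above are no longer marked as in use and can be released below.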
    // QLSTM NEON prepare
    m_QLstmLayer.prepare();

    FreeUnusedTensors();
}

void NeonQLstmWorkload::Execute() const
{
    ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonQuantizedLstmWorkload_Execute", this->GetGuid());
    m_QLstmLayer.run();
}

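// Static validation: mirrors the tensor wiring done in the constructor but using
// arm_compute::TensorInfo objects only, so the Neon backend can ask NEQLSTMLayer whether a
// given QLstm configuration is supported without allocating or configuring anything.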
arm_compute::Status NeonQLstmWorkloadValidate(const TensorInfo& input,
                                              const TensorInfo& cellStateIn,
                                              const TensorInfo& outputStateIn,
                                              const TensorInfo& cellStateOut,
                                              const TensorInfo& outputStateOut,
                                              const TensorInfo& output,
                                              const QLstmDescriptor& descriptor,
                                              const LstmInputParamsInfo& paramsInfo)
{
    arm_compute::LSTMParams<arm_compute::ITensorInfo> aclParamsInfo;

    // Input/Output tensor info
    const arm_compute::TensorInfo aclInputInfo         = BuildArmComputeTensorInfo(input);
    const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
    const arm_compute::TensorInfo aclCellStateInInfo   = BuildArmComputeTensorInfo(cellStateIn);

    const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
    const arm_compute::TensorInfo aclCellStateOutInfo   = BuildArmComputeTensorInfo(cellStateOut);
    const arm_compute::TensorInfo aclOutputInfo         = BuildArmComputeTensorInfo(output);

    // Mandatory tensor info
    const arm_compute::TensorInfo aclInputToForgetWeightsInfo
            = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
    const arm_compute::TensorInfo aclInputToCellWeightsInfo
            = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
    const arm_compute::TensorInfo aclInputToOutputWeightsInfo
            = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
    const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
            = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
    const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
            = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
    const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
            = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
    const arm_compute::TensorInfo aclForgetGateBiasInfo
            = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
    const arm_compute::TensorInfo aclCellBiasInfo
            = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
    const arm_compute::TensorInfo aclOutputGateBiasInfo
            = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());

    // Optional tensor info
    arm_compute::TensorInfo aclInputToInputWeightsInfo;
    arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;

    arm_compute::TensorInfo aclCellToInputWeightsInfo;
    arm_compute::TensorInfo aclCellToForgetWeightsInfo;
    arm_compute::TensorInfo aclCellToOutputWeightsInfo;

    arm_compute::TensorInfo aclInputGateBiasInfo;

    arm_compute::TensorInfo aclProjectionWeightsInfo;
    arm_compute::TensorInfo aclProjectionBiasInfo;

    arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
    arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
    arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
    arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;

    // Create tensor info for optional params if they are enabled
    if (descriptor.m_PeepholeEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
        }

        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());

        // Set peephole params info
        aclParamsInfo.set_peephole_params(&aclCellToForgetWeightsInfo,
                                          &aclCellToOutputWeightsInfo);
    }

    if (descriptor.m_ProjectionEnabled)
    {
        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());

        if (paramsInfo.m_ProjectionBias != nullptr)
        {
            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias());
        }

        // Set projection params info
        aclParamsInfo.set_projection_params(
            &aclProjectionWeightsInfo,
            paramsInfo.m_ProjectionBias != nullptr ? &aclProjectionBiasInfo : nullptr);
    }

    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
        }

        aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
        aclCellLayerNormWeightsInfo   = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
        aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());

        // Set layer norm params info
        aclParamsInfo.set_layer_normalization_params(
            paramsInfo.m_InputLayerNormWeights != nullptr ? &aclInputLayerNormWeightsInfo : nullptr,
            &aclForgetLayerNormWeightsInfo,
            &aclCellLayerNormWeightsInfo,
            &aclOutputLayerNormWeightsInfo);
    }

    if (!descriptor.m_CifgEnabled)
    {
        aclInputToInputWeightsInfo     = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
        aclInputGateBiasInfo           = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());

        // Set CIFG params info
        aclParamsInfo.set_cifg_params(
            &aclInputToInputWeightsInfo,
            &aclRecurrentToInputWeightsInfo,
            paramsInfo.m_CellToInputWeights != nullptr ? &aclCellToInputWeightsInfo : nullptr,
            &aclInputGateBiasInfo);
    }

    // Set scalar descriptor params
    aclParamsInfo.set_cell_clip_params(descriptor.m_CellClip);
    aclParamsInfo.set_projection_clip_params(descriptor.m_ProjectionClip);
    aclParamsInfo.set_hidden_state_params(descriptor.m_HiddenStateZeroPoint, descriptor.m_HiddenStateScale);
    aclParamsInfo.set_matmul_scale_params(descriptor.m_InputIntermediateScale,
                                          descriptor.m_ForgetIntermediateScale,
                                          descriptor.m_CellIntermediateScale,
                                          descriptor.m_OutputIntermediateScale);

    // QLSTM NEON validate
    return arm_compute::NEQLSTMLayer::validate(&aclInputInfo,
                                               &aclInputToForgetWeightsInfo,
                                               &aclInputToCellWeightsInfo,
                                               &aclInputToOutputWeightsInfo,
                                               &aclRecurrentToForgetWeightsInfo,
                                               &aclRecurrentToCellWeightsInfo,
                                               &aclRecurrentToOutputWeightsInfo,
                                               &aclForgetGateBiasInfo,
                                               &aclCellBiasInfo,
                                               &aclOutputGateBiasInfo,
                                               &aclCellStateInInfo,
                                               &aclOutputStateInInfo,
                                               &aclCellStateOutInfo,
                                               &aclOutputStateOutInfo,
                                               &aclOutputInfo,
                                               aclParamsInfo);
}

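// Release the staging weight/bias tensors that ACL no longer marks as in use, typically
// everything that was repacked into the layer's own buffers during prepare().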
void NeonQLstmWorkload::FreeUnusedTensors()
{
    FreeTensorIfUnused(m_InputToInputWeightsTensor);
    FreeTensorIfUnused(m_InputToForgetWeightsTensor);
    FreeTensorIfUnused(m_InputToCellWeightsTensor);
    FreeTensorIfUnused(m_InputToOutputWeightsTensor);

    FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
    FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);

    FreeTensorIfUnused(m_CellToInputWeightsTensor);
    FreeTensorIfUnused(m_CellToForgetWeightsTensor);
    FreeTensorIfUnused(m_CellToOutputWeightsTensor);

    FreeTensorIfUnused(m_InputGateBiasTensor);
    FreeTensorIfUnused(m_ForgetGateBiasTensor);
    FreeTensorIfUnused(m_CellBiasTensor);
    FreeTensorIfUnused(m_OutputGateBiasTensor);

    FreeTensorIfUnused(m_ProjectionWeightsTensor);
    FreeTensorIfUnused(m_ProjectionBiasTensor);

    FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
    FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
    FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
    FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
}

} //namespace armnn