xref: /aosp_15_r20/external/armnn/src/armnn/layers/QLstmLayer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "QLstmLayer.hpp"
6 
7 #include "LayerCloneBase.hpp"
8 
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/TypesUtils.hpp>
11 #include <armnn/backends/TensorHandle.hpp>
12 #include <armnn/backends/WorkloadFactory.hpp>
13 
14 namespace armnn
15 {
16 
QLstmLayer(const QLstmDescriptor & param,const char * name)17 QLstmLayer::QLstmLayer(const QLstmDescriptor& param, const char* name)
18         : LayerWithParameters(3, 3, LayerType::QLstm, param, name)
19 {
20 }
21 
CreateWorkload(const IWorkloadFactory & factory) const22 std::unique_ptr<IWorkload> QLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
23 {
24     QLstmQueueDescriptor descriptor;
25 
26     // Basic parameters
27     descriptor.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights.get();
28     descriptor.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights.get();
29     descriptor.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights.get();
30     descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get();
31     descriptor.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights.get();
32     descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get();
33     descriptor.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias.get();
34     descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
35     descriptor.m_OutputGateBias = m_BasicParameters.m_OutputGateBias.get();
36 
37     // CIFG parameters
38     if (!m_Param.m_CifgEnabled)
39     {
40         descriptor.m_InputToInputWeights     = m_CifgParameters.m_InputToInputWeights.get();
41         descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
42         descriptor.m_InputGateBias           = m_CifgParameters.m_InputGateBias.get();
43     }
44 
45     // Projection parameters
46     if (m_Param.m_ProjectionEnabled)
47     {
48         descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get();
49         descriptor.m_ProjectionBias    = m_ProjectionParameters.m_ProjectionBias.get();
50     }
51 
52     // Peephole parameters
53     if (m_Param.m_PeepholeEnabled)
54     {
55         if (!m_Param.m_CifgEnabled)
56         {
57             descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
58         }
59 
60         descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
61         descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
62     }
63 
64     // Layer normalisation parameters
65     if(m_Param.m_LayerNormEnabled)
66     {
67         if (!m_Param.m_CifgEnabled)
68         {
69             descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
70         }
71         descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
72         descriptor.m_CellLayerNormWeights   = m_LayerNormParameters.m_CellLayerNormWeights.get();
73         descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
74     }
75 
76     SetAdditionalInfo(descriptor);
77 
78     return factory.CreateWorkload(LayerType::QLstm, descriptor, PrepInfoAndDesc(descriptor));
79 }
80 
Clone(Graph & graph) const81 QLstmLayer* QLstmLayer::Clone(Graph& graph) const
82 {
83     auto layer = CloneBase<QLstmLayer>(graph, m_Param, GetName());
84 
85     layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
86             m_BasicParameters.m_InputToForgetWeights : nullptr;
87     layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
88             m_BasicParameters.m_InputToCellWeights : nullptr;
89     layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
90             m_BasicParameters.m_InputToOutputWeights : nullptr;
91     layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
92             m_BasicParameters.m_RecurrentToForgetWeights : nullptr;
93     layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
94             m_BasicParameters.m_RecurrentToCellWeights : nullptr;
95     layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
96             m_BasicParameters.m_RecurrentToOutputWeights : nullptr;
97     layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
98             m_BasicParameters.m_ForgetGateBias : nullptr;
99     layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
100             m_BasicParameters.m_CellBias : nullptr;
101     layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
102             m_BasicParameters.m_OutputGateBias : nullptr;
103 
104     if (!m_Param.m_CifgEnabled)
105     {
106         layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
107                 m_CifgParameters.m_InputToInputWeights : nullptr;
108         layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
109                 m_CifgParameters.m_RecurrentToInputWeights : nullptr;
110         layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
111                 m_CifgParameters.m_InputGateBias : nullptr;
112     }
113 
114     if (m_Param.m_ProjectionEnabled)
115     {
116         layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
117                 m_ProjectionParameters.m_ProjectionWeights : nullptr;
118         layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
119                 m_ProjectionParameters.m_ProjectionBias : nullptr;
120     }
121 
122     if (m_Param.m_PeepholeEnabled)
123     {
124         if (!m_Param.m_CifgEnabled) {
125             layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
126                     m_PeepholeParameters.m_CellToInputWeights : nullptr;
127         }
128 
129         layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
130                 m_PeepholeParameters.m_CellToForgetWeights : nullptr;
131         layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
132                 m_PeepholeParameters.m_CellToOutputWeights : nullptr;
133     }
134 
135     if (m_Param.m_LayerNormEnabled)
136     {
137         if (!m_Param.m_CifgEnabled) {
138             layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
139                     m_LayerNormParameters.m_InputLayerNormWeights : nullptr;
140         }
141 
142         layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
143                 m_LayerNormParameters.m_ForgetLayerNormWeights : nullptr;
144         layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
145                 m_LayerNormParameters.m_CellLayerNormWeights : nullptr;
146         layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
147                 m_LayerNormParameters.m_OutputLayerNormWeights : nullptr;
148     }
149 
150     return std::move(layer);
151 }
152 
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const153 std::vector<TensorShape> QLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
154 {
155     ARMNN_ASSERT(inputShapes.size() == 3);
156 
157     // Get input values for validation
158     unsigned int batchSize = inputShapes[0][0];
159     unsigned int outputSize = inputShapes[1][1];
160     unsigned int numUnits = inputShapes[2][1];
161 
162     std::vector<TensorShape> outShapes;
163     outShapes.push_back(TensorShape({ batchSize, outputSize })); // outputStateOut
164     outShapes.push_back(TensorShape({ batchSize, numUnits })); // cellStateOut
165     outShapes.push_back(TensorShape({ batchSize, outputSize })); // output
166 
167     return outShapes;
168 }
169 
ValidateTensorShapesFromInputs()170 void QLstmLayer::ValidateTensorShapesFromInputs()
171 {
172     VerifyLayerConnections(3, CHECK_LOCATION());
173 
174     const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
175 
176     VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
177 
178     auto inferredShapes = InferOutputShapes(
179     {
180         GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), // input
181         GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), // previousOutputIn
182         GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()  // previousCellStateIn
183     });
184 
185     ARMNN_ASSERT(inferredShapes.size() == 3);
186 
187     // Check if the weights are nullptr for basic params
188     ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr,
189             "QLstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null.");
190     ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr,
191             "QLstmLayer: m_BasicParameters.m_InputToCellWeights should not be null.");
192     ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr,
193             "QLstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null.");
194     ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr,
195             "QLstmLayer: m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
196     ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr,
197             "QLstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null.");
198     ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr,
199             "QLstmLayer: m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
200     ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr,
201             "QLstmLayer: m_BasicParameters.m_ForgetGateBias should not be null.");
202     ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr,
203             "QLstmLayer: m_BasicParameters.m_CellBias should not be null.");
204     ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr,
205             "QLstmLayer: m_BasicParameters.m_OutputGateBias should not be null.");
206 
207     if (!m_Param.m_CifgEnabled)
208     {
209         ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr,
210                 "QLstmLayer: m_CifgParameters.m_InputToInputWeights should not be null.");
211         ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr,
212                 "QLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not be null.");
213         ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr,
214                 "QLstmLayer: m_CifgParameters.m_InputGateBias should not be null.");
215 
216         ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QLstmLayer");
217     }
218     else
219     {
220         ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr,
221                 "QLstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled.");
222         ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr,
223                 "QLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should "
224                              "not have a value when CIFG is enabled.");
225         ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
226                 "QLstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled.");
227 
228         ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QLstmLayer");
229     }
230 
231     if (m_Param.m_ProjectionEnabled)
232     {
233         ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr,
234                          "QLstmLayer: m_ProjectionParameters.m_ProjectionWeights should not be null.");
235     }
236 
237     if (m_Param.m_PeepholeEnabled)
238     {
239         if (!m_Param.m_CifgEnabled) {
240             ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr,
241                     "QLstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null "
242                     "when Peephole is enabled and CIFG is disabled.");
243         }
244 
245         ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr,
246                          "QLstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null.");
247         ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr,
248                          "QLstmLayer: m_PeepholeParameters.m_CellToOutputWeights should not be null.");
249     }
250 
251     ValidateAndCopyShape(
252             GetOutputSlot(1).GetTensorInfo().GetShape(), inferredShapes[1], m_ShapeInferenceMethod, "QLstmLayer", 1);
253     ValidateAndCopyShape(
254             GetOutputSlot(2).GetTensorInfo().GetShape(), inferredShapes[2], m_ShapeInferenceMethod, "QLstmLayer", 2);
255 
256     if (m_Param.m_LayerNormEnabled)
257     {
258         if(!m_Param.m_CifgEnabled)
259         {
260             ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
261                              "QLstmLayer: m_LayerNormParameters.m_InputLayerNormWeights should not be null.");
262         }
263         ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
264                          "QLstmLayer: m_LayerNormParameters.m_ForgetLayerNormWeights should not be null.");
265         ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
266                          "QLstmLayer: m_LayerNormParameters.m_CellLayerNormWeights should not be null.");
267         ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
268                          "QLstmLayer: m_LayerNormParameters.m_UutputLayerNormWeights should not be null.");
269     }
270 }
271 
GetConstantTensorsByRef() const272 Layer::ImmutableConstantTensors QLstmLayer::GetConstantTensorsByRef() const
273 {
274     // For API stability DO NOT ALTER order and add new members to the end of vector
275     return {m_BasicParameters.m_InputToForgetWeights,
276             m_BasicParameters.m_InputToCellWeights,
277             m_BasicParameters.m_InputToOutputWeights,
278             m_BasicParameters.m_RecurrentToForgetWeights,
279             m_BasicParameters.m_RecurrentToCellWeights,
280             m_BasicParameters.m_RecurrentToOutputWeights,
281             m_BasicParameters.m_ForgetGateBias,
282             m_BasicParameters.m_CellBias,
283             m_BasicParameters.m_OutputGateBias,
284 
285             // Cifg parameters
286             m_CifgParameters.m_InputToInputWeights,
287             m_CifgParameters.m_RecurrentToInputWeights,
288             m_CifgParameters.m_InputGateBias,
289 
290             // Projection parameters
291             m_ProjectionParameters.m_ProjectionWeights,
292             m_ProjectionParameters.m_ProjectionBias,
293 
294             // Peephole parameters
295             m_PeepholeParameters.m_CellToInputWeights,
296             m_PeepholeParameters.m_CellToForgetWeights,
297             m_PeepholeParameters.m_CellToOutputWeights,
298 
299             // Layer normalisation parameters
300             m_LayerNormParameters.m_InputLayerNormWeights,
301             m_LayerNormParameters.m_ForgetLayerNormWeights,
302             m_LayerNormParameters.m_CellLayerNormWeights,
303             m_LayerNormParameters.m_OutputLayerNormWeights};
304 }
305 
306 
ExecuteStrategy(IStrategy & strategy) const307 void QLstmLayer::ExecuteStrategy(IStrategy& strategy) const
308 {
309     std::vector<ConstTensor> constTensors;
310     ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
311     ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
312     ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
313     ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
314     ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
315     ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
316     ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
317     ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
318     ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);
319 
320     // Cifg parameters
321     ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
322     ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
323     ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);
324 
325     // Projection parameters
326     ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
327     ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);
328 
329     // Peephole parameters
330     ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
331     ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
332     ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);
333 
334     // Layer normalisation parameters
335     ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
336     ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
337     ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
338     ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);
339 
340     // First add mandatory/basic parameters
341     if (m_BasicParameters.m_InputToForgetWeights != nullptr)
342     {
343         constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
344                                               managedInputToForgetWeights.Map()));
345     }
346     if (m_BasicParameters.m_InputToCellWeights != nullptr)
347     {
348         constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
349                                               managedInputToCellWeights.Map()));
350     }
351     if (m_BasicParameters.m_InputToOutputWeights != nullptr)
352     {
353         constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
354                                               managedInputToOutputWeights.Map()));
355     }
356     if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
357     {
358         constTensors.emplace_back(ConstTensor(
359                 managedRecurrentToForgetWeights.GetTensorInfo(),
360                 managedRecurrentToForgetWeights.Map()));
361     }
362     if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
363     {
364         constTensors.emplace_back(ConstTensor(
365                 managedRecurrentToCellWeights.GetTensorInfo(),
366                 managedRecurrentToCellWeights.Map()));
367     }
368     if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
369     {
370         constTensors.emplace_back(ConstTensor(
371                 managedRecurrentToOutputWeights.GetTensorInfo(),
372                 managedRecurrentToOutputWeights.Map()));
373     }
374     if (m_BasicParameters.m_ForgetGateBias != nullptr)
375     {
376         constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
377                                               managedForgetGateBias.Map()));
378     }
379     if (m_BasicParameters.m_CellBias != nullptr)
380     {
381         constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
382                                               managedCellBias.Map()));
383     }
384     if (m_BasicParameters.m_OutputGateBias != nullptr)
385     {
386         constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
387                                               managedOutputGateBias.Map()));
388     }
389 
390     // Add cifig parameters
391     if (m_CifgParameters.m_InputToInputWeights != nullptr)
392     {
393         constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
394                                               managedInputToInputWeights.Map()));
395     }
396     if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
397     {
398         constTensors.emplace_back(ConstTensor(
399                 managedRecurrentToInputWeights.GetTensorInfo(),
400                 managedRecurrentToInputWeights.Map()));
401     }
402     if (m_CifgParameters.m_InputGateBias != nullptr)
403     {
404         constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
405                                               managedInputGateBias.Map()));
406     }
407 
408     // Add peephole parameters
409     if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
410     {
411         constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
412                                               managedCellToInputWeights.Map()));
413     }
414     if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
415     {
416         constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
417                                               managedCellToForgetWeights.Map()));
418     }
419     if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
420     {
421         constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
422                                               managedCellToOutputWeights.Map()));
423     }
424 
425     // Add projection parameters
426     if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
427     {
428         constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
429                                               managedProjectionWeights.Map()));
430     }
431     if (m_ProjectionParameters.m_ProjectionBias != nullptr)
432     {
433         constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
434                                               managedProjectionBias.Map()));
435     }
436 
437     // Add norm parameters
438     if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
439     {
440         constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
441                                               managedInputLayerNormWeights.Map()));
442     }
443     if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
444     {
445         constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
446                                               managedForgetLayerNormWeights.Map()));
447     }
448     if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
449     {
450         constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
451                                               managedCellLayerNormWeights.Map()));
452     }
453     if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
454     {
455         constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
456                                               managedOutputLayerNormWeights.Map()));
457     }
458     strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
459 }
460 
461 } // namespace armnn
462