//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "LstmLayer.hpp"

#include "LayerCloneBase.hpp"

#include <armnn/LstmParams.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnn/backends/TensorHandle.hpp>
#include <armnn/backends/WorkloadFactory.hpp>

namespace armnn
{

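// An LstmLayer has 3 input slots (input, output state, cell state) and 4 output slots
// (scratch buffer, output state, cell state, output), as reflected in the base-class
// constructor arguments below and in InferOutputShapes().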
LstmLayer::LstmLayer(const LstmDescriptor& param, const char* name)
        : LayerWithParameters(3, 4, LayerType::Lstm, param, name)
{
}

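// Populates an LstmQueueDescriptor with the constant tensors that are present for the
// enabled features, then asks the backend factory to create the corresponding workload.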
std::unique_ptr<IWorkload> LstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
    LstmQueueDescriptor descriptor;

    // Basic parameters
    descriptor.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights.get();
    descriptor.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights.get();
    descriptor.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights.get();
    descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get();
    descriptor.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights.get();
    descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get();
    descriptor.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias.get();
    descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
    descriptor.m_OutputGateBias = m_BasicParameters.m_OutputGateBias.get();

    // Cifg parameters
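    // (CIFG = coupled input and forget gate: when enabled the input gate is derived from the
    // forget gate, so no separate input-gate weights or bias are supplied.)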
    if (!m_Param.m_CifgEnabled)
    {
        descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get();
        descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
        descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get();
    }

    // Projection parameters
    if (m_Param.m_ProjectionEnabled)
    {
        descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get();
        descriptor.m_ProjectionBias    = m_ProjectionParameters.m_ProjectionBias.get();
    }

    // Peephole parameters
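    // (Peephole connections let the gates read the cell state directly.)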
    if (m_Param.m_PeepholeEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
        }
        descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
        descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
    }

    // Layer normalisation parameters
    if (m_Param.m_LayerNormEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
        }
        descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
        descriptor.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights.get();
        descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
    }

    SetAdditionalInfo(descriptor);

    return factory.CreateWorkload(LayerType::Lstm, descriptor, PrepInfoAndDesc(descriptor));
}

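// Clone copies the descriptor and shares the constant tensor handles with the original layer;
// the ternaries below simply propagate null handles for parameters that were never set.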
LstmLayer* LstmLayer::Clone(Graph& graph) const
{
    auto layer = CloneBase<LstmLayer>(graph, m_Param, GetName());

    layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
            m_BasicParameters.m_InputToForgetWeights : nullptr;
    layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
            m_BasicParameters.m_InputToCellWeights : nullptr;
    layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
            m_BasicParameters.m_InputToOutputWeights : nullptr;
    layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
            m_BasicParameters.m_RecurrentToForgetWeights : nullptr;
    layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
            m_BasicParameters.m_RecurrentToCellWeights : nullptr;
    layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
            m_BasicParameters.m_RecurrentToOutputWeights : nullptr;
    layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
            m_BasicParameters.m_ForgetGateBias : nullptr;
    layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
            m_BasicParameters.m_CellBias : nullptr;
    layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
            m_BasicParameters.m_OutputGateBias : nullptr;

    if (!m_Param.m_CifgEnabled)
    {
        layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
                m_CifgParameters.m_InputToInputWeights : nullptr;
        layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
                m_CifgParameters.m_RecurrentToInputWeights : nullptr;
        layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
                m_CifgParameters.m_InputGateBias : nullptr;
    }

    if (m_Param.m_ProjectionEnabled)
    {
        layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
               m_ProjectionParameters.m_ProjectionWeights : nullptr;
        layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
               m_ProjectionParameters.m_ProjectionBias : nullptr;
    }

    if (m_Param.m_PeepholeEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
                m_PeepholeParameters.m_CellToInputWeights : nullptr;
        }
        layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
               m_PeepholeParameters.m_CellToForgetWeights : nullptr;
        layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
               m_PeepholeParameters.m_CellToOutputWeights : nullptr;
    }

    if (m_Param.m_LayerNormEnabled)
    {
        layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
               m_LayerNormParameters.m_InputLayerNormWeights : nullptr;
        layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
               m_LayerNormParameters.m_ForgetLayerNormWeights : nullptr;
        layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
               m_LayerNormParameters.m_CellLayerNormWeights : nullptr;
        layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
               m_LayerNormParameters.m_OutputLayerNormWeights : nullptr;
    }

    return std::move(layer);
}

std::vector<TensorShape> LstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
    ARMNN_ASSERT(inputShapes.size() == 3);

    // Get input values for validation
    unsigned int batchSize = inputShapes[0][0];
    unsigned int outputSize = inputShapes[1][1];
    unsigned int numUnits = inputShapes[2][1];

    std::vector<TensorShape> outShapes;
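    // Four outputs: [0] scratch buffer {batch, numUnits * 4} ({batch, numUnits * 3} with CIFG,
    // since the input gate is dropped), [1] output state {batch, outputSize},
    // [2] cell state {batch, numUnits}, [3] output {batch, outputSize}.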
    outShapes.push_back(TensorShape({batchSize, numUnits * (m_Param.m_CifgEnabled ? 3 : 4)}));
    outShapes.push_back(TensorShape({batchSize, outputSize}));
    outShapes.push_back(TensorShape({batchSize, numUnits}));
    outShapes.push_back(TensorShape({batchSize, outputSize}));

    return outShapes;
}

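// Verifies that all constant tensors required by the enabled features are present (and that
// CIFG-only tensors are absent when CIFG is enabled), then validates and propagates the four
// inferred output shapes.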
void LstmLayer::ValidateTensorShapesFromInputs()
{
    VerifyLayerConnections(3, CHECK_LOCATION());

    const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();

    VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);

    auto inferredShapes = InferOutputShapes( {
        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
        GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()
    });

    ARMNN_ASSERT(inferredShapes.size() == 4);

    // Check that the mandatory weights and biases have been set
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_InputToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr,
                     "LstmLayer: m_BasicParameters.m_ForgetGateBias should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr,
                     "LstmLayer: m_BasicParameters.m_CellBias should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr,
                     "LstmLayer: m_BasicParameters.m_OutputGateBias should not be null.");

    if (!m_Param.m_CifgEnabled)
    {
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr,
                         "LstmLayer: m_CifgParameters.m_InputToInputWeights should not be null.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr,
                         "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not be null.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr,
                         "LstmLayer: m_CifgParameters.m_InputGateBias should not be null.");

        ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "LstmLayer");
    }
    else
    {
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr,
            "LstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr,
            "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value when CIFG is enabled.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
            "LstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled.");

        ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "LstmLayer");
    }

    if (m_Param.m_ProjectionEnabled)
    {
        ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr,
                         "LstmLayer: m_ProjectionParameters.m_ProjectionWeights should not be null.");
    }

    if (m_Param.m_PeepholeEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr,
                             "LstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null "
                             "when Peephole is enabled and CIFG is disabled.");
        }
        ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr,
                         "LstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null.");
        ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr,
                         "LstmLayer: m_PeepholeParameters.m_CellToOutputWeights should not be null.");
    }

    ValidateAndCopyShape(
            GetOutputSlot(1).GetTensorInfo().GetShape(), inferredShapes[1], m_ShapeInferenceMethod, "LstmLayer", 1);
    ValidateAndCopyShape(
            GetOutputSlot(2).GetTensorInfo().GetShape(), inferredShapes[2], m_ShapeInferenceMethod, "LstmLayer", 2);
    ValidateAndCopyShape(
            GetOutputSlot(3).GetTensorInfo().GetShape(), inferredShapes[3], m_ShapeInferenceMethod, "LstmLayer", 3);

    if (m_Param.m_LayerNormEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
                             "LstmLayer: m_LayerNormParameters.m_InputLayerNormWeights should not be null.");
        }
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
                         "LstmLayer: m_LayerNormParameters.m_ForgetLayerNormWeights should not be null.");
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
                         "LstmLayer: m_LayerNormParameters.m_CellLayerNormWeights should not be null.");
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
                         "LstmLayer: m_LayerNormParameters.m_OutputLayerNormWeights should not be null.");
    }
}

Layer::ImmutableConstantTensors LstmLayer::GetConstantTensorsByRef() const
{
    // For API stability DO NOT ALTER the order; add new members to the end of the vector.
    return {m_BasicParameters.m_InputToForgetWeights,
            m_BasicParameters.m_InputToCellWeights,
            m_BasicParameters.m_InputToOutputWeights,
            m_BasicParameters.m_RecurrentToForgetWeights,
            m_BasicParameters.m_RecurrentToCellWeights,
            m_BasicParameters.m_RecurrentToOutputWeights,
            m_BasicParameters.m_ForgetGateBias,
            m_BasicParameters.m_CellBias,
            m_BasicParameters.m_OutputGateBias,

            // Cifg parameters
            m_CifgParameters.m_InputToInputWeights,
            m_CifgParameters.m_RecurrentToInputWeights,
            m_CifgParameters.m_InputGateBias,

            // Projection parameters
            m_ProjectionParameters.m_ProjectionWeights,
            m_ProjectionParameters.m_ProjectionBias,

            // Peephole parameters
            m_PeepholeParameters.m_CellToInputWeights,
            m_PeepholeParameters.m_CellToForgetWeights,
            m_PeepholeParameters.m_CellToOutputWeights,

            // Layer normalisation parameters
            m_LayerNormParameters.m_InputLayerNormWeights,
            m_LayerNormParameters.m_ForgetLayerNormWeights,
            m_LayerNormParameters.m_CellLayerNormWeights,
            m_LayerNormParameters.m_OutputLayerNormWeights};
}

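// Gathers every constant tensor that is present and passes them, together with the descriptor,
// to the visiting strategy.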
void LstmLayer::ExecuteStrategy(IStrategy& strategy) const
{
    std::vector<ConstTensor> constTensors;

    LstmDescriptor descriptor = GetParameters();

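    // The managed handles keep the underlying constant data mapped for the lifetime of this
    // function, so the ConstTensor objects created below reference valid memory.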
    ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
    ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
    ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
    ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
    ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
    ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
    ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
    ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
    ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);

    // Cifg parameters
    ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
    ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
    ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);

    // Projection parameters
    ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
    ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);

    // Peephole parameters
    ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
    ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
    ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);

    // Layer normalisation parameters
    ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
    ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
    ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
    ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);

    // First add mandatory/basic parameters
    if (m_BasicParameters.m_InputToForgetWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
                                              managedInputToForgetWeights.Map()));
    }
    if (m_BasicParameters.m_InputToCellWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
                                              managedInputToCellWeights.Map()));
    }
    if (m_BasicParameters.m_InputToOutputWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
                                              managedInputToOutputWeights.Map()));
    }
    if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(
                managedRecurrentToForgetWeights.GetTensorInfo(),
                managedRecurrentToForgetWeights.Map()));
    }
    if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(
                managedRecurrentToCellWeights.GetTensorInfo(),
                managedRecurrentToCellWeights.Map()));
    }
    if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(
                managedRecurrentToOutputWeights.GetTensorInfo(),
                managedRecurrentToOutputWeights.Map()));
    }
    if (m_BasicParameters.m_ForgetGateBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
                                              managedForgetGateBias.Map()));
    }
    if (m_BasicParameters.m_CellBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
                                              managedCellBias.Map()));
    }
    if (m_BasicParameters.m_OutputGateBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
                                              managedOutputGateBias.Map()));
    }

    // Add cifg parameters
    if (!descriptor.m_CifgEnabled)
    {
        if (m_CifgParameters.m_InputToInputWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
                                                  managedInputToInputWeights.Map()));
        }
        if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(
                    managedRecurrentToInputWeights.GetTensorInfo(),
                    managedRecurrentToInputWeights.Map()));
        }
        if (m_CifgParameters.m_InputGateBias != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
                                                  managedInputGateBias.Map()));
        }
    }

    // Add peephole parameters
    if (descriptor.m_PeepholeEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
            {
                constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
                                                      managedCellToInputWeights.Map()));
            }
        }
        if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
                                                  managedCellToForgetWeights.Map()));
        }
        if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
                                                  managedCellToOutputWeights.Map()));
        }
    }

    // Add projection parameters
    if (descriptor.m_ProjectionEnabled)
    {
        if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
                                                  managedProjectionWeights.Map()));
        }
        if (m_ProjectionParameters.m_ProjectionBias != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
                                                  managedProjectionBias.Map()));
        }
    }

    // Add norm parameters
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
            {
                constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
                                                      managedInputLayerNormWeights.Map()));
            }
        }
        if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
                                                  managedForgetLayerNormWeights.Map()));
        }
        if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
                                                  managedCellLayerNormWeights.Map()));
        }
        if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
                                                  managedOutputLayerNormWeights.Map()));
        }
    }

    strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
}

} // namespace armnn