xref: /aosp_15_r20/external/armnn/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "UnidirectionalSequenceLstmLayer.hpp"
6 
7 #include "LayerCloneBase.hpp"
8 
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/TypesUtils.hpp>
11 #include <armnn/backends/TensorHandle.hpp>
12 #include <armnn/backends/WorkloadFactory.hpp>
13 
14 namespace armnn
15 {
16 
UnidirectionalSequenceLstmLayer(const LstmDescriptor & param,const char * name)17 UnidirectionalSequenceLstmLayer::UnidirectionalSequenceLstmLayer(const LstmDescriptor& param, const char* name)
18         : LayerWithParameters(3, 3, LayerType::UnidirectionalSequenceLstm, param, name)
19 {
20 }
21 
CreateWorkload(const IWorkloadFactory & factory) const22 std::unique_ptr<IWorkload> UnidirectionalSequenceLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
23 {
24     UnidirectionalSequenceLstmQueueDescriptor descriptor;
25 
26     // Basic parameters
27     descriptor.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights.get();
28     descriptor.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights.get();
29     descriptor.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights.get();
30     descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get();
31     descriptor.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights.get();
32     descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get();
33     descriptor.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias.get();
34     descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
35     descriptor.m_OutputGateBias = m_BasicParameters.m_OutputGateBias.get();
36 
37     // Cifg parameters
38     if (!m_Param.m_CifgEnabled)
39     {
40         descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get();
41         descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
42         descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get();
43     }
44 
45     // Projection parameters
46     if (m_Param.m_ProjectionEnabled)
47     {
48         descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get();
49         descriptor.m_ProjectionBias    = m_ProjectionParameters.m_ProjectionBias.get();
50     }
51 
52     // Peephole parameters
53     if (m_Param.m_PeepholeEnabled)
54     {
55         if (!m_Param.m_CifgEnabled)
56         {
57             descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
58         }
59         descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
60         descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
61     }
62 
63     // Layer normalisation parameters
64     if(m_Param.m_LayerNormEnabled)
65     {
66         if (!m_Param.m_CifgEnabled)
67         {
68             descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
69         }
70         descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
71         descriptor.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights.get();
72         descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
73     }
74 
75     SetAdditionalInfo(descriptor);
76 
77     return factory.CreateWorkload(LayerType::UnidirectionalSequenceLstm, descriptor, PrepInfoAndDesc(descriptor));
78 }
79 
Clone(Graph & graph) const80 UnidirectionalSequenceLstmLayer* UnidirectionalSequenceLstmLayer::Clone(Graph& graph) const
81 {
82     auto layer = CloneBase<UnidirectionalSequenceLstmLayer>(graph, m_Param, GetName());
83 
84     layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
85             m_BasicParameters.m_InputToForgetWeights
86                 : nullptr;
87     layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
88             m_BasicParameters.m_InputToCellWeights : nullptr;
89     layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
90             m_BasicParameters.m_InputToOutputWeights : nullptr;
91     layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
92             m_BasicParameters.m_RecurrentToForgetWeights : nullptr;
93     layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
94             m_BasicParameters.m_RecurrentToCellWeights : nullptr;
95     layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
96             m_BasicParameters.m_RecurrentToOutputWeights : nullptr;
97     layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
98             m_BasicParameters.m_ForgetGateBias : nullptr;
99     layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
100             m_BasicParameters.m_CellBias : nullptr;
101     layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
102             m_BasicParameters.m_OutputGateBias : nullptr;
103 
104     if (!m_Param.m_CifgEnabled)
105     {
106         layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
107                 m_CifgParameters.m_InputToInputWeights : nullptr;
108         layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
109                 m_CifgParameters.m_RecurrentToInputWeights : nullptr;
110         layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
111                 m_CifgParameters.m_InputGateBias : nullptr;
112     }
113 
114     if (m_Param.m_ProjectionEnabled)
115     {
116         layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
117                m_ProjectionParameters.m_ProjectionWeights : nullptr;
118         layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
119                m_ProjectionParameters.m_ProjectionBias : nullptr;
120     }
121 
122     if (m_Param.m_PeepholeEnabled)
123     {
124         if (!m_Param.m_CifgEnabled)
125         {
126             layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
127                 m_PeepholeParameters.m_CellToInputWeights : nullptr;
128         }
129         layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
130                m_PeepholeParameters.m_CellToForgetWeights : nullptr;
131         layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
132                m_PeepholeParameters.m_CellToOutputWeights : nullptr;
133     }
134 
135     if (m_Param.m_LayerNormEnabled)
136     {
137         layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
138                m_LayerNormParameters.m_InputLayerNormWeights : nullptr;
139         layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
140                m_LayerNormParameters.m_ForgetLayerNormWeights : nullptr;
141         layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
142                m_LayerNormParameters.m_CellLayerNormWeights : nullptr;
143         layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
144                m_LayerNormParameters.m_OutputLayerNormWeights : nullptr;
145     }
146 
147     return std::move(layer);
148 }
149 
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const150 std::vector<TensorShape> UnidirectionalSequenceLstmLayer::InferOutputShapes(
151     const std::vector<TensorShape>& inputShapes) const
152 {
153     ARMNN_ASSERT(inputShapes.size() == 3);
154 
155     // Get input values for validation
156     unsigned int outputSize = inputShapes[1][1];
157 
158     std::vector<TensorShape> outShapes;
159     if (m_Param.m_TimeMajor)
160     {
161         outShapes.push_back(TensorShape({inputShapes[0][0], inputShapes[0][1], outputSize}));
162     }
163     else
164     {
165         outShapes.push_back(TensorShape({inputShapes[0][0], inputShapes[0][1], outputSize}));
166     }
167     return outShapes;
168 }
169 
ValidateTensorShapesFromInputs()170 void UnidirectionalSequenceLstmLayer::ValidateTensorShapesFromInputs()
171 {
172     VerifyLayerConnections(3, CHECK_LOCATION());
173 
174     const TensorShape& outputShape = GetOutputSlot(2).GetTensorInfo().GetShape();
175 
176     VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
177 
178     auto inferredShapes = InferOutputShapes( {
179         GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
180         GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
181         GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()
182     });
183 
184     ARMNN_ASSERT(inferredShapes.size() == 1);
185 
186     // Check if the weights are nullptr
187     ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr,
188                      "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null.");
189     ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr,
190                      "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToCellWeights should not be null.");
191     ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr,
192                      "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null.");
193     ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr,
194                      "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToForgetWeights "
195                      "should not be null.");
196     ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr,
197                      "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null.");
198     ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr,
199                      "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToOutputWeights "
200                      "should not be null.");
201     ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr,
202                      "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_ForgetGateBias should not be null.");
203     ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr,
204                      "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_CellBias should not be null.");
205     ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr,
206                      "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_OutputGateBias should not be null.");
207 
208     if (!m_Param.m_CifgEnabled)
209     {
210         ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr,
211                          "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputToInputWeights should not be null.");
212         ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr,
213                          "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_RecurrentToInputWeights "
214                          "should not be null.");
215         ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr,
216                          "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputGateBias should not be null.");
217     }
218     else
219     {
220         ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr,
221             "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value "
222             "when CIFG is enabled.");
223         ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr,
224             "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value "
225             "when CIFG is enabled.");
226         ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
227             "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputGateBias should not have a value "
228             "when CIFG is enabled.");
229     }
230 
231     if (m_Param.m_ProjectionEnabled)
232     {
233         ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr,
234                          "UnidirectionalSequenceLstmLayer: m_ProjectionParameters.m_ProjectionWeights "
235                          "should not be null.");
236     }
237 
238     if (m_Param.m_PeepholeEnabled)
239     {
240         if (!m_Param.m_CifgEnabled)
241         {
242             ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr,
243                              "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToInputWeights "
244                              "should not be null "
245                              "when Peephole is enabled and CIFG is disabled.");
246         }
247         ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr,
248                          "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToForgetWeights "
249                          "should not be null.");
250         ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr,
251                          "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToOutputWeights "
252                          "should not be null.");
253     }
254 
255     if (m_Param.m_LayerNormEnabled)
256     {
257         if(!m_Param.m_CifgEnabled)
258         {
259             ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
260                              "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_inputLayerNormWeights "
261                              "should not be null.");
262         }
263         ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
264                          "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_forgetLayerNormWeights "
265                          "should not be null.");
266         ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
267                          "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_cellLayerNormWeights "
268                          "should not be null.");
269         ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
270                          "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_outputLayerNormWeights "
271                          "should not be null.");
272     }
273 
274     ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "UnidirectionalSequenceLstmLayer");
275 }
276 
GetConstantTensorsByRef() const277 Layer::ImmutableConstantTensors UnidirectionalSequenceLstmLayer::GetConstantTensorsByRef() const
278 {
279     // For API stability DO NOT ALTER order and add new members to the end of vector
280     return {m_BasicParameters.m_InputToForgetWeights,
281             m_BasicParameters.m_InputToCellWeights,
282             m_BasicParameters.m_InputToOutputWeights,
283             m_BasicParameters.m_RecurrentToForgetWeights,
284             m_BasicParameters.m_RecurrentToCellWeights,
285             m_BasicParameters.m_RecurrentToOutputWeights,
286             m_BasicParameters.m_ForgetGateBias,
287             m_BasicParameters.m_CellBias,
288             m_BasicParameters.m_OutputGateBias,
289 
290             // Cifg parameters
291             m_CifgParameters.m_InputToInputWeights,
292             m_CifgParameters.m_RecurrentToInputWeights,
293             m_CifgParameters.m_InputGateBias,
294 
295             // Projection parameters
296             m_ProjectionParameters.m_ProjectionWeights,
297             m_ProjectionParameters.m_ProjectionBias,
298 
299             // Peephole parameters
300             m_PeepholeParameters.m_CellToInputWeights,
301             m_PeepholeParameters.m_CellToForgetWeights,
302             m_PeepholeParameters.m_CellToOutputWeights,
303 
304             // Layer normalisation parameters
305             m_LayerNormParameters.m_InputLayerNormWeights,
306             m_LayerNormParameters.m_ForgetLayerNormWeights,
307             m_LayerNormParameters.m_CellLayerNormWeights,
308             m_LayerNormParameters.m_OutputLayerNormWeights};
309 }
310 
ExecuteStrategy(IStrategy & strategy) const311 void UnidirectionalSequenceLstmLayer::ExecuteStrategy(IStrategy& strategy) const
312 {
313     std::vector<ConstTensor> constTensors;
314 
315     LstmDescriptor descriptor = GetParameters();
316 
317     ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
318     ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
319     ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
320     ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
321     ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
322     ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
323     ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
324     ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
325     ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);
326 
327     // Cifg parameters
328     ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
329     ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
330     ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);
331 
332     // Projection parameters
333     ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
334     ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);
335 
336     // Peephole parameters
337     ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
338     ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
339     ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);
340 
341     // Layer normalisation parameters
342     ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
343     ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
344     ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
345     ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);
346 
347     // First add mandatory/basic parameters
348     if (m_BasicParameters.m_InputToForgetWeights != nullptr)
349     {
350         constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
351                                               managedInputToForgetWeights.Map()));
352     }
353     if (m_BasicParameters.m_InputToCellWeights != nullptr)
354     {
355         constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
356                                               managedInputToCellWeights.Map()));
357     }
358     if (m_BasicParameters.m_InputToOutputWeights != nullptr)
359     {
360         constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
361                                               managedInputToOutputWeights.Map()));
362     }
363     if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
364     {
365         constTensors.emplace_back(ConstTensor(
366                 managedRecurrentToForgetWeights.GetTensorInfo(),
367                 managedRecurrentToForgetWeights.Map()));
368     }
369     if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
370     {
371         constTensors.emplace_back(ConstTensor(
372                 managedRecurrentToCellWeights.GetTensorInfo(),
373                 managedRecurrentToCellWeights.Map()));
374     }
375     if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
376     {
377         constTensors.emplace_back(ConstTensor(
378                 managedRecurrentToOutputWeights.GetTensorInfo(),
379                 managedRecurrentToOutputWeights.Map()));
380     }
381     if (m_BasicParameters.m_ForgetGateBias != nullptr)
382     {
383         constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
384                                               managedForgetGateBias.Map()));
385     }
386     if (m_BasicParameters.m_CellBias != nullptr)
387     {
388         constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
389                                               managedCellBias.Map()));
390     }
391     if (m_BasicParameters.m_OutputGateBias != nullptr)
392     {
393         constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
394                                               managedOutputGateBias.Map()));
395     }
396 
397     // Add cifg parameters
398     if (!descriptor.m_CifgEnabled)
399     {
400         if (m_CifgParameters.m_InputToInputWeights != nullptr)
401         {
402             constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
403                                                   managedInputToInputWeights.Map()));
404         }
405         if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
406         {
407             constTensors.emplace_back(ConstTensor(
408                     managedRecurrentToInputWeights.GetTensorInfo(),
409                     managedRecurrentToInputWeights.Map()));
410         }
411         if (m_CifgParameters.m_InputGateBias != nullptr)
412         {
413             constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
414                                                   managedInputGateBias.Map()));
415         }
416     }
417 
418     // Add peephole parameters
419     if (descriptor.m_PeepholeEnabled)
420     {
421         if (!descriptor.m_CifgEnabled)
422         {
423             if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
424             {
425                 constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
426                                                       managedCellToInputWeights.Map()));
427             }
428         }
429         if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
430         {
431             constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
432                                                   managedCellToForgetWeights.Map()));
433         }
434         if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
435         {
436             constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
437                                                   managedCellToOutputWeights.Map()));
438         }
439     }
440 
441     // Add projection parameters
442     if (descriptor.m_ProjectionEnabled)
443     {
444         if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
445         {
446             constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
447                                                   managedProjectionWeights.Map()));
448         }
449         if (m_ProjectionParameters.m_ProjectionBias != nullptr)
450         {
451             constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
452                                                   managedProjectionBias.Map()));
453         }
454     }
455 
456     // Add norm parameters
457     if (descriptor.m_LayerNormEnabled)
458     {
459         if (!descriptor.m_CifgEnabled)
460         {
461             if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
462             {
463                 constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
464                                                       managedInputLayerNormWeights.Map()));
465             }
466         }
467         if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
468         {
469             constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
470                                                   managedForgetLayerNormWeights.Map()));
471         }
472         if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
473         {
474             constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
475                                                   managedCellLayerNormWeights.Map()));
476         }
477         if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
478         {
479             constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
480                                                   managedOutputLayerNormWeights.Map()));
481         }
482     }
483 
484     strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
485 }
486 
487 } // namespace armnn
488