1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "UnidirectionalSequenceLstmLayer.hpp"
6
7 #include "LayerCloneBase.hpp"
8
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/TypesUtils.hpp>
11 #include <armnn/backends/TensorHandle.hpp>
12 #include <armnn/backends/WorkloadFactory.hpp>
13
14 namespace armnn
15 {
16
// Constructs the layer with 3 input slots and 3 output slots (the slot counts
// are checked again in ValidateTensorShapesFromInputs), registering it as
// LayerType::UnidirectionalSequenceLstm with the given descriptor and name.
UnidirectionalSequenceLstmLayer::UnidirectionalSequenceLstmLayer(const LstmDescriptor& param, const char* name)
    : LayerWithParameters(3, 3, LayerType::UnidirectionalSequenceLstm, param, name)
{
}
21
CreateWorkload(const IWorkloadFactory & factory) const22 std::unique_ptr<IWorkload> UnidirectionalSequenceLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
23 {
24 UnidirectionalSequenceLstmQueueDescriptor descriptor;
25
26 // Basic parameters
27 descriptor.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights.get();
28 descriptor.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights.get();
29 descriptor.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights.get();
30 descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get();
31 descriptor.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights.get();
32 descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get();
33 descriptor.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias.get();
34 descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
35 descriptor.m_OutputGateBias = m_BasicParameters.m_OutputGateBias.get();
36
37 // Cifg parameters
38 if (!m_Param.m_CifgEnabled)
39 {
40 descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get();
41 descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
42 descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get();
43 }
44
45 // Projection parameters
46 if (m_Param.m_ProjectionEnabled)
47 {
48 descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get();
49 descriptor.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias.get();
50 }
51
52 // Peephole parameters
53 if (m_Param.m_PeepholeEnabled)
54 {
55 if (!m_Param.m_CifgEnabled)
56 {
57 descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
58 }
59 descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
60 descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
61 }
62
63 // Layer normalisation parameters
64 if(m_Param.m_LayerNormEnabled)
65 {
66 if (!m_Param.m_CifgEnabled)
67 {
68 descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
69 }
70 descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
71 descriptor.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights.get();
72 descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
73 }
74
75 SetAdditionalInfo(descriptor);
76
77 return factory.CreateWorkload(LayerType::UnidirectionalSequenceLstm, descriptor, PrepInfoAndDesc(descriptor));
78 }
79
Clone(Graph & graph) const80 UnidirectionalSequenceLstmLayer* UnidirectionalSequenceLstmLayer::Clone(Graph& graph) const
81 {
82 auto layer = CloneBase<UnidirectionalSequenceLstmLayer>(graph, m_Param, GetName());
83
84 layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
85 m_BasicParameters.m_InputToForgetWeights
86 : nullptr;
87 layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
88 m_BasicParameters.m_InputToCellWeights : nullptr;
89 layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
90 m_BasicParameters.m_InputToOutputWeights : nullptr;
91 layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
92 m_BasicParameters.m_RecurrentToForgetWeights : nullptr;
93 layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
94 m_BasicParameters.m_RecurrentToCellWeights : nullptr;
95 layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
96 m_BasicParameters.m_RecurrentToOutputWeights : nullptr;
97 layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
98 m_BasicParameters.m_ForgetGateBias : nullptr;
99 layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
100 m_BasicParameters.m_CellBias : nullptr;
101 layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
102 m_BasicParameters.m_OutputGateBias : nullptr;
103
104 if (!m_Param.m_CifgEnabled)
105 {
106 layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
107 m_CifgParameters.m_InputToInputWeights : nullptr;
108 layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
109 m_CifgParameters.m_RecurrentToInputWeights : nullptr;
110 layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
111 m_CifgParameters.m_InputGateBias : nullptr;
112 }
113
114 if (m_Param.m_ProjectionEnabled)
115 {
116 layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
117 m_ProjectionParameters.m_ProjectionWeights : nullptr;
118 layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
119 m_ProjectionParameters.m_ProjectionBias : nullptr;
120 }
121
122 if (m_Param.m_PeepholeEnabled)
123 {
124 if (!m_Param.m_CifgEnabled)
125 {
126 layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
127 m_PeepholeParameters.m_CellToInputWeights : nullptr;
128 }
129 layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
130 m_PeepholeParameters.m_CellToForgetWeights : nullptr;
131 layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
132 m_PeepholeParameters.m_CellToOutputWeights : nullptr;
133 }
134
135 if (m_Param.m_LayerNormEnabled)
136 {
137 layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
138 m_LayerNormParameters.m_InputLayerNormWeights : nullptr;
139 layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
140 m_LayerNormParameters.m_ForgetLayerNormWeights : nullptr;
141 layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
142 m_LayerNormParameters.m_CellLayerNormWeights : nullptr;
143 layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
144 m_LayerNormParameters.m_OutputLayerNormWeights : nullptr;
145 }
146
147 return std::move(layer);
148 }
149
/// Infers the shape of the output sequence tensor from the three input shapes.
///
/// @param inputShapes  exactly 3 shapes: the input sequence (inputShapes[0]),
///                     and the two state inputs; outputSize is taken from
///                     inputShapes[1][1] (presumably outputStateIn is
///                     [batch, outputSize] — consistent with how it is used).
/// @return a single shape: the first two dimensions of the sequence input with
///         the feature dimension replaced by outputSize.
///
/// Note: the previous version branched on m_Param.m_TimeMajor but both
/// branches pushed the identical shape — correct, because the formula
/// [dim0, dim1, outputSize] holds for both time-major ([maxTime, batch, ...])
/// and batch-major ([batch, maxTime, ...]) layouts. The dead duplication is
/// collapsed into a single statement.
std::vector<TensorShape> UnidirectionalSequenceLstmLayer::InferOutputShapes(
    const std::vector<TensorShape>& inputShapes) const
{
    ARMNN_ASSERT(inputShapes.size() == 3);

    // Get input values for validation
    unsigned int outputSize = inputShapes[1][1];

    std::vector<TensorShape> outShapes;
    outShapes.push_back(TensorShape({inputShapes[0][0], inputShapes[0][1], outputSize}));
    return outShapes;
}
169
/// Validates the layer's connections and constant tensors, then propagates the
/// inferred output shape onto output slot 2.
///
/// Checks performed, in order:
///  1. all 3 input slots are connected;
///  2. the output shape is compatible with the active shape-inference method;
///  3. every tensor required by the descriptor flags is non-null, and the
///     CIFG tensors are null when CIFG is enabled (they must not coexist);
///  4. the inferred shape is validated against / copied onto output slot 2.
void UnidirectionalSequenceLstmLayer::ValidateTensorShapesFromInputs()
{
    VerifyLayerConnections(3, CHECK_LOCATION());

    // Output slot 2 is the one validated against the inferred sequence output.
    const TensorShape& outputShape = GetOutputSlot(2).GetTensorInfo().GetShape();

    VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);

    // InferOutputShapes expects exactly the three input shapes, in slot order.
    auto inferredShapes = InferOutputShapes( {
        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
        GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()
    });

    ARMNN_ASSERT(inferredShapes.size() == 1);

    // Check if the weights are nullptr
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr,
                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr,
                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr,
                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr,
                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToForgetWeights "
                     "should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr,
                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr,
                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToOutputWeights "
                     "should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr,
                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_ForgetGateBias should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr,
                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_CellBias should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr,
                     "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_OutputGateBias should not be null.");

    // CIFG off: the input-gate tensors are required. CIFG on: they must be
    // absent, since the input gate is derived from the forget gate.
    if (!m_Param.m_CifgEnabled)
    {
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr,
                         "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputToInputWeights should not be null.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr,
                         "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_RecurrentToInputWeights "
                         "should not be null.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr,
                         "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputGateBias should not be null.");
    }
    else
    {
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr,
            "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value "
            "when CIFG is enabled.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr,
            "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value "
            "when CIFG is enabled.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
            "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputGateBias should not have a value "
            "when CIFG is enabled.");
    }

    // Note: only the projection weights are mandatory here; the projection
    // bias is optional even with projection enabled.
    if (m_Param.m_ProjectionEnabled)
    {
        ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr,
                         "UnidirectionalSequenceLstmLayer: m_ProjectionParameters.m_ProjectionWeights "
                         "should not be null.");
    }

    if (m_Param.m_PeepholeEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr,
                             "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToInputWeights "
                             "should not be null "
                             "when Peephole is enabled and CIFG is disabled.");
        }
        ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr,
                         "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToForgetWeights "
                         "should not be null.");
        ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr,
                         "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToOutputWeights "
                         "should not be null.");
    }

    if (m_Param.m_LayerNormEnabled)
    {
        if(!m_Param.m_CifgEnabled)
        {
            ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
                             "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_inputLayerNormWeights "
                             "should not be null.");
        }
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
                         "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_forgetLayerNormWeights "
                         "should not be null.");
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
                         "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_cellLayerNormWeights "
                         "should not be null.");
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
                         "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_outputLayerNormWeights "
                         "should not be null.");
    }

    ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "UnidirectionalSequenceLstmLayer");
}
276
/// Returns references to all of the layer's constant tensor handles, grouped
/// as basic / CIFG / projection / peephole / layer-norm parameters. Handles
/// belonging to disabled features are still included (they may be null) so
/// the ordering is stable regardless of descriptor flags.
Layer::ImmutableConstantTensors UnidirectionalSequenceLstmLayer::GetConstantTensorsByRef() const
{
    // For API stability DO NOT ALTER order and add new members to the end of vector
    return {m_BasicParameters.m_InputToForgetWeights,
            m_BasicParameters.m_InputToCellWeights,
            m_BasicParameters.m_InputToOutputWeights,
            m_BasicParameters.m_RecurrentToForgetWeights,
            m_BasicParameters.m_RecurrentToCellWeights,
            m_BasicParameters.m_RecurrentToOutputWeights,
            m_BasicParameters.m_ForgetGateBias,
            m_BasicParameters.m_CellBias,
            m_BasicParameters.m_OutputGateBias,

            // Cifg parameters
            m_CifgParameters.m_InputToInputWeights,
            m_CifgParameters.m_RecurrentToInputWeights,
            m_CifgParameters.m_InputGateBias,

            // Projection parameters
            m_ProjectionParameters.m_ProjectionWeights,
            m_ProjectionParameters.m_ProjectionBias,

            // Peephole parameters
            m_PeepholeParameters.m_CellToInputWeights,
            m_PeepholeParameters.m_CellToForgetWeights,
            m_PeepholeParameters.m_CellToOutputWeights,

            // Layer normalisation parameters
            m_LayerNormParameters.m_InputLayerNormWeights,
            m_LayerNormParameters.m_ForgetLayerNormWeights,
            m_LayerNormParameters.m_CellLayerNormWeights,
            m_LayerNormParameters.m_OutputLayerNormWeights};
}
310
/// Visits this layer with the given strategy, handing it the descriptor and a
/// flat list of the layer's constant tensors.
///
/// Every present constant tensor is wrapped in a ManagedConstTensorHandle
/// (which maps the handle for reading and unmaps it when the handle is
/// destroyed at the end of this function) and copied into constTensors as a
/// ConstTensor view. Tensors are appended in a fixed order — basic, CIFG,
/// peephole, projection, layer-norm — skipping null handles and groups whose
/// feature flag is disabled; consumers of the strategy presumably rely on
/// this order, so do not reorder the emission blocks.
void UnidirectionalSequenceLstmLayer::ExecuteStrategy(IStrategy& strategy) const
{
    std::vector<ConstTensor> constTensors;

    // Local copy of the descriptor used for the feature-flag checks below.
    LstmDescriptor descriptor = GetParameters();

    // Managed handles keep the underlying constant data mapped for the
    // lifetime of this function; the ConstTensors below point into them.
    ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
    ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
    ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
    ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
    ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
    ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
    ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
    ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
    ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);

    // Cifg parameters
    ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
    ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
    ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);

    // Projection parameters
    ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
    ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);

    // Peephole parameters
    ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
    ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
    ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);

    // Layer normalisation parameters
    ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
    ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
    ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
    ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);

    // First add mandatory/basic parameters
    if (m_BasicParameters.m_InputToForgetWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
                                              managedInputToForgetWeights.Map()));
    }
    if (m_BasicParameters.m_InputToCellWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
                                              managedInputToCellWeights.Map()));
    }
    if (m_BasicParameters.m_InputToOutputWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
                                              managedInputToOutputWeights.Map()));
    }
    if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(
            managedRecurrentToForgetWeights.GetTensorInfo(),
            managedRecurrentToForgetWeights.Map()));
    }
    if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(
            managedRecurrentToCellWeights.GetTensorInfo(),
            managedRecurrentToCellWeights.Map()));
    }
    if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(
            managedRecurrentToOutputWeights.GetTensorInfo(),
            managedRecurrentToOutputWeights.Map()));
    }
    if (m_BasicParameters.m_ForgetGateBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
                                              managedForgetGateBias.Map()));
    }
    if (m_BasicParameters.m_CellBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
                                              managedCellBias.Map()));
    }
    if (m_BasicParameters.m_OutputGateBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
                                              managedOutputGateBias.Map()));
    }

    // Add cifg parameters
    if (!descriptor.m_CifgEnabled)
    {
        if (m_CifgParameters.m_InputToInputWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
                                                  managedInputToInputWeights.Map()));
        }
        if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(
                managedRecurrentToInputWeights.GetTensorInfo(),
                managedRecurrentToInputWeights.Map()));
        }
        if (m_CifgParameters.m_InputGateBias != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
                                                  managedInputGateBias.Map()));
        }
    }

    // Add peephole parameters
    if (descriptor.m_PeepholeEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
            {
                constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
                                                      managedCellToInputWeights.Map()));
            }
        }
        if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
                                                  managedCellToForgetWeights.Map()));
        }
        if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
                                                  managedCellToOutputWeights.Map()));
        }
    }

    // Add projection parameters
    if (descriptor.m_ProjectionEnabled)
    {
        if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
                                                  managedProjectionWeights.Map()));
        }
        if (m_ProjectionParameters.m_ProjectionBias != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
                                                  managedProjectionBias.Map()));
        }
    }

    // Add norm parameters
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
            {
                constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
                                                      managedInputLayerNormWeights.Map()));
            }
        }
        if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
                                                  managedForgetLayerNormWeights.Map()));
        }
        if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
                                                  managedCellLayerNormWeights.Map()));
        }
        if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
                                                  managedOutputLayerNormWeights.Map()));
        }
    }

    strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
}
486
487 } // namespace armnn
488