1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "LstmLayer.hpp"
6
7 #include "LayerCloneBase.hpp"
8
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/TypesUtils.hpp>
11 #include <armnn/backends/TensorHandle.hpp>
12 #include <armnn/backends/WorkloadFactory.hpp>
13
14 namespace armnn
15 {
16
LstmLayer(const LstmDescriptor & param,const char * name)17 LstmLayer::LstmLayer(const LstmDescriptor& param, const char* name)
18 : LayerWithParameters(3, 4, LayerType::Lstm, param, name)
19 {
20 }
21
CreateWorkload(const IWorkloadFactory & factory) const22 std::unique_ptr<IWorkload> LstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
23 {
24 LstmQueueDescriptor descriptor;
25
26 // Basic parameters
27 descriptor.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights.get();
28 descriptor.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights.get();
29 descriptor.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights.get();
30 descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get();
31 descriptor.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights.get();
32 descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get();
33 descriptor.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias.get();
34 descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
35 descriptor.m_OutputGateBias = m_BasicParameters.m_OutputGateBias.get();
36
37 // Cifg parameters
38 if (!m_Param.m_CifgEnabled)
39 {
40 descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get();
41 descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
42 descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get();
43 }
44
45 // Projection parameters
46 if (m_Param.m_ProjectionEnabled)
47 {
48 descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get();
49 descriptor.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias.get();
50 }
51
52 // Peephole parameters
53 if (m_Param.m_PeepholeEnabled)
54 {
55 if (!m_Param.m_CifgEnabled)
56 {
57 descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
58 }
59 descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
60 descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
61 }
62
63 // Layer normalisation parameters
64 if(m_Param.m_LayerNormEnabled)
65 {
66 if (!m_Param.m_CifgEnabled)
67 {
68 descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
69 }
70 descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
71 descriptor.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights.get();
72 descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
73 }
74
75 SetAdditionalInfo(descriptor);
76
77 return factory.CreateWorkload(LayerType::Lstm, descriptor, PrepInfoAndDesc(descriptor));
78 }
79
Clone(Graph & graph) const80 LstmLayer* LstmLayer::Clone(Graph& graph) const
81 {
82 auto layer = CloneBase<LstmLayer>(graph, m_Param, GetName());
83
84 layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
85 m_BasicParameters.m_InputToForgetWeights
86 : nullptr;
87 layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
88 m_BasicParameters.m_InputToCellWeights : nullptr;
89 layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
90 m_BasicParameters.m_InputToOutputWeights : nullptr;
91 layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
92 m_BasicParameters.m_RecurrentToForgetWeights : nullptr;
93 layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
94 m_BasicParameters.m_RecurrentToCellWeights : nullptr;
95 layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
96 m_BasicParameters.m_RecurrentToOutputWeights : nullptr;
97 layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
98 m_BasicParameters.m_ForgetGateBias : nullptr;
99 layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
100 m_BasicParameters.m_CellBias : nullptr;
101 layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
102 m_BasicParameters.m_OutputGateBias : nullptr;
103
104 if (!m_Param.m_CifgEnabled)
105 {
106 layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
107 m_CifgParameters.m_InputToInputWeights : nullptr;
108 layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
109 m_CifgParameters.m_RecurrentToInputWeights : nullptr;
110 layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
111 m_CifgParameters.m_InputGateBias : nullptr;
112 }
113
114 if (m_Param.m_ProjectionEnabled)
115 {
116 layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
117 m_ProjectionParameters.m_ProjectionWeights : nullptr;
118 layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
119 m_ProjectionParameters.m_ProjectionBias : nullptr;
120 }
121
122 if (m_Param.m_PeepholeEnabled)
123 {
124 if (!m_Param.m_CifgEnabled)
125 {
126 layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
127 m_PeepholeParameters.m_CellToInputWeights : nullptr;
128 }
129 layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
130 m_PeepholeParameters.m_CellToForgetWeights : nullptr;
131 layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
132 m_PeepholeParameters.m_CellToOutputWeights : nullptr;
133 }
134
135 if (m_Param.m_LayerNormEnabled)
136 {
137 layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
138 m_LayerNormParameters.m_InputLayerNormWeights : nullptr;
139 layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
140 m_LayerNormParameters.m_ForgetLayerNormWeights : nullptr;
141 layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
142 m_LayerNormParameters.m_CellLayerNormWeights : nullptr;
143 layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
144 m_LayerNormParameters.m_OutputLayerNormWeights : nullptr;
145 }
146
147 return std::move(layer);
148 }
149
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const150 std::vector<TensorShape> LstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
151 {
152 ARMNN_ASSERT(inputShapes.size() == 3);
153
154 // Get input values for validation
155 unsigned int batchSize = inputShapes[0][0];
156 unsigned int outputSize = inputShapes[1][1];
157 unsigned int numUnits = inputShapes[2][1];
158
159 std::vector<TensorShape> outShapes;
160 outShapes.push_back(TensorShape({batchSize, numUnits * (m_Param.m_CifgEnabled ? 3 : 4)}));
161 outShapes.push_back(TensorShape({batchSize, outputSize}));
162 outShapes.push_back(TensorShape({batchSize, numUnits}));
163 outShapes.push_back(TensorShape({batchSize, outputSize}));
164
165 return outShapes;
166 }
167
ValidateTensorShapesFromInputs()168 void LstmLayer::ValidateTensorShapesFromInputs()
169 {
170 VerifyLayerConnections(3, CHECK_LOCATION());
171
172 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
173
174 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
175
176 auto inferredShapes = InferOutputShapes( {
177 GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
178 GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
179 GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()
180 });
181
182 ARMNN_ASSERT(inferredShapes.size() == 4);
183
184 // Check if the weights are nullptr
185 ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr,
186 "LstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null.");
187 ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr,
188 "LstmLayer: m_BasicParameters.m_InputToCellWeights should not be null.");
189 ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr,
190 "LstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null.");
191 ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr,
192 "LstmLayer: m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
193 ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr,
194 "LstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null.");
195 ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr,
196 "LstmLayer: m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
197 ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr,
198 "LstmLayer: m_BasicParameters.m_ForgetGateBias should not be null.");
199 ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr,
200 "LstmLayer: m_BasicParameters.m_CellBias should not be null.");
201 ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr,
202 "LstmLayer: m_BasicParameters.m_OutputGateBias should not be null.");
203
204 if (!m_Param.m_CifgEnabled)
205 {
206 ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr,
207 "LstmLayer: m_CifgParameters.m_InputToInputWeights should not be null.");
208 ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr,
209 "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not be null.");
210 ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr,
211 "LstmLayer: m_CifgParameters.m_InputGateBias should not be null.");
212
213 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "LstmLayer");
214 }
215 else
216 {
217 ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr,
218 "LstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled.");
219 ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr,
220 "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value when CIFG is enabled.");
221 ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
222 "LstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled.");
223
224 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "LstmLayer");
225 }
226
227 if (m_Param.m_ProjectionEnabled)
228 {
229 ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr,
230 "LstmLayer: m_ProjectionParameters.m_ProjectionWeights should not be null.");
231 }
232
233 if (m_Param.m_PeepholeEnabled)
234 {
235 if (!m_Param.m_CifgEnabled)
236 {
237 ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr,
238 "LstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null "
239 "when Peephole is enabled and CIFG is disabled.");
240 }
241 ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr,
242 "LstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null.");
243 ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr,
244 "LstmLayer: m_PeepholeParameters.m_CellToOutputWeights should not be null.");
245 }
246
247 ValidateAndCopyShape(
248 GetOutputSlot(1).GetTensorInfo().GetShape(), inferredShapes[1], m_ShapeInferenceMethod, "LstmLayer", 1);
249 ValidateAndCopyShape(
250 GetOutputSlot(2).GetTensorInfo().GetShape(), inferredShapes[2], m_ShapeInferenceMethod, "LstmLayer", 2);
251 ValidateAndCopyShape(
252 GetOutputSlot(3).GetTensorInfo().GetShape(), inferredShapes[3], m_ShapeInferenceMethod, "LstmLayer", 3);
253
254 if (m_Param.m_LayerNormEnabled)
255 {
256 if(!m_Param.m_CifgEnabled)
257 {
258 ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
259 "LstmLayer: m_LayerNormParameters.m_inputLayerNormWeights should not be null.");
260 }
261 ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
262 "LstmLayer: m_LayerNormParameters.m_forgetLayerNormWeights should not be null.");
263 ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
264 "LstmLayer: m_LayerNormParameters.m_cellLayerNormWeights should not be null.");
265 ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
266 "LstmLayer: m_LayerNormParameters.m_outputLayerNormWeights should not be null.");
267 }
268 }
269
GetConstantTensorsByRef() const270 Layer::ImmutableConstantTensors LstmLayer::GetConstantTensorsByRef() const
271 {
272 // For API stability DO NOT ALTER order and add new members to the end of vector
273 return {m_BasicParameters.m_InputToForgetWeights,
274 m_BasicParameters.m_InputToCellWeights,
275 m_BasicParameters.m_InputToOutputWeights,
276 m_BasicParameters.m_RecurrentToForgetWeights,
277 m_BasicParameters.m_RecurrentToCellWeights,
278 m_BasicParameters.m_RecurrentToOutputWeights,
279 m_BasicParameters.m_ForgetGateBias,
280 m_BasicParameters.m_CellBias,
281 m_BasicParameters.m_OutputGateBias,
282
283 // Cifg parameters
284 m_CifgParameters.m_InputToInputWeights,
285 m_CifgParameters.m_RecurrentToInputWeights,
286 m_CifgParameters.m_InputGateBias,
287
288 // Projection parameters
289 m_ProjectionParameters.m_ProjectionWeights,
290 m_ProjectionParameters.m_ProjectionBias,
291
292 // Peephole parameters
293 m_PeepholeParameters.m_CellToInputWeights,
294 m_PeepholeParameters.m_CellToForgetWeights,
295 m_PeepholeParameters.m_CellToOutputWeights,
296
297 // Layer normalisation parameters
298 m_LayerNormParameters.m_InputLayerNormWeights,
299 m_LayerNormParameters.m_ForgetLayerNormWeights,
300 m_LayerNormParameters.m_CellLayerNormWeights,
301 m_LayerNormParameters.m_OutputLayerNormWeights};
302 }
303
ExecuteStrategy(IStrategy & strategy) const304 void LstmLayer::ExecuteStrategy(IStrategy& strategy) const
305 {
306 std::vector<ConstTensor> constTensors;
307
308 LstmDescriptor descriptor = GetParameters();
309
310 ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
311 ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
312 ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
313 ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
314 ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
315 ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
316 ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
317 ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
318 ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);
319
320 // Cifg parameters
321 ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
322 ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
323 ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);
324
325 // Projection parameters
326 ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
327 ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);
328
329 // Peephole parameters
330 ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
331 ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
332 ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);
333
334 // Layer normalisation parameters
335 ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
336 ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
337 ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
338 ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);
339
340 // First add mandatory/basic parameters
341 if (m_BasicParameters.m_InputToForgetWeights != nullptr)
342 {
343 constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
344 managedInputToForgetWeights.Map()));
345 }
346 if (m_BasicParameters.m_InputToCellWeights != nullptr)
347 {
348 constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
349 managedInputToCellWeights.Map()));
350 }
351 if (m_BasicParameters.m_InputToOutputWeights != nullptr)
352 {
353 constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
354 managedInputToOutputWeights.Map()));
355 }
356 if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
357 {
358 constTensors.emplace_back(ConstTensor(
359 managedRecurrentToForgetWeights.GetTensorInfo(),
360 managedRecurrentToForgetWeights.Map()));
361 }
362 if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
363 {
364 constTensors.emplace_back(ConstTensor(
365 managedRecurrentToCellWeights.GetTensorInfo(),
366 managedRecurrentToCellWeights.Map()));
367 }
368 if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
369 {
370 constTensors.emplace_back(ConstTensor(
371 managedRecurrentToOutputWeights.GetTensorInfo(),
372 managedRecurrentToOutputWeights.Map()));
373 }
374 if (m_BasicParameters.m_ForgetGateBias != nullptr)
375 {
376 constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
377 managedForgetGateBias.Map()));
378 }
379 if (m_BasicParameters.m_CellBias != nullptr)
380 {
381 constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
382 managedCellBias.Map()));
383 }
384 if (m_BasicParameters.m_OutputGateBias != nullptr)
385 {
386 constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
387 managedOutputGateBias.Map()));
388 }
389
390 // Add cifg parameters
391 if (!descriptor.m_CifgEnabled)
392 {
393 if (m_CifgParameters.m_InputToInputWeights != nullptr)
394 {
395 constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
396 managedInputToInputWeights.Map()));
397 }
398 if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
399 {
400 constTensors.emplace_back(ConstTensor(
401 managedRecurrentToInputWeights.GetTensorInfo(),
402 managedRecurrentToInputWeights.Map()));
403 }
404 if (m_CifgParameters.m_InputGateBias != nullptr)
405 {
406 constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
407 managedInputGateBias.Map()));
408 }
409 }
410
411 // Add peephole parameters
412 if (descriptor.m_PeepholeEnabled)
413 {
414 if (!descriptor.m_CifgEnabled)
415 {
416 if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
417 {
418 constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
419 managedCellToInputWeights.Map()));
420 }
421 }
422 if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
423 {
424 constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
425 managedCellToForgetWeights.Map()));
426 }
427 if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
428 {
429 constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
430 managedCellToOutputWeights.Map()));
431 }
432 }
433
434 // Add projection parameters
435 if (descriptor.m_ProjectionEnabled)
436 {
437 if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
438 {
439 constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
440 managedProjectionWeights.Map()));
441 }
442 if (m_ProjectionParameters.m_ProjectionBias != nullptr)
443 {
444 constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
445 managedProjectionBias.Map()));
446 }
447 }
448
449 // Add norm parameters
450 if (descriptor.m_LayerNormEnabled)
451 {
452 if (!descriptor.m_CifgEnabled)
453 {
454 if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
455 {
456 constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
457 managedInputLayerNormWeights.Map()));
458 }
459 }
460 if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
461 {
462 constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
463 managedForgetLayerNormWeights.Map()));
464 }
465 if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
466 {
467 constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
468 managedCellLayerNormWeights.Map()));
469 }
470 if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
471 {
472 constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
473 managedOutputLayerNormWeights.Map()));
474 }
475 }
476
477 strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
478 }
479
480 } // namespace armnn
481