1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "QLstmLayer.hpp"
6
7 #include "LayerCloneBase.hpp"
8
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/TypesUtils.hpp>
11 #include <armnn/backends/TensorHandle.hpp>
12 #include <armnn/backends/WorkloadFactory.hpp>
13
14 namespace armnn
15 {
16
/// Constructor for a QLstm layer: 3 input slots (input, previousOutputIn,
/// previousCellStateIn) and 3 output slots (outputStateOut, cellStateOut, output).
/// @param param The QLstm descriptor holding the feature flags (CIFG, projection,
///              peephole, layer normalisation) used throughout this layer.
/// @param name  Optional layer name.
QLstmLayer::QLstmLayer(const QLstmDescriptor& param, const char* name)
    : LayerWithParameters(3, 3, LayerType::QLstm, param, name)
{
}
21
/// Creates the backend workload for this layer.
/// Fills a QLstmQueueDescriptor with raw, non-owning pointers into the layer's
/// constant tensor handles; ownership stays with the layer. Optional parameter
/// groups are only populated when the corresponding descriptor flag permits,
/// otherwise those descriptor fields keep their default values.
/// @param factory The workload factory that will create the workload.
/// @return The created workload.
std::unique_ptr<IWorkload> QLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
    QLstmQueueDescriptor descriptor;

    // Basic parameters
    descriptor.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights.get();
    descriptor.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights.get();
    descriptor.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights.get();
    descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get();
    descriptor.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights.get();
    descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get();
    descriptor.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias.get();
    descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
    descriptor.m_OutputGateBias = m_BasicParameters.m_OutputGateBias.get();

    // CIFG parameters (input gate tensors only exist when CIFG is disabled)
    if (!m_Param.m_CifgEnabled)
    {
        descriptor.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights.get();
        descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
        descriptor.m_InputGateBias = m_CifgParameters.m_InputGateBias.get();
    }

    // Projection parameters
    if (m_Param.m_ProjectionEnabled)
    {
        descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get();
        descriptor.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias.get();
    }

    // Peephole parameters
    if (m_Param.m_PeepholeEnabled)
    {
        // Cell-to-input peephole weights also depend on CIFG being disabled.
        if (!m_Param.m_CifgEnabled)
        {
            descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
        }

        descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
        descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
    }

    // Layer normalisation parameters
    if(m_Param.m_LayerNormEnabled)
    {
        // Input layer norm weights also depend on CIFG being disabled.
        if (!m_Param.m_CifgEnabled)
        {
            descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
        }
        descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
        descriptor.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights.get();
        descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
    }

    SetAdditionalInfo(descriptor);

    return factory.CreateWorkload(LayerType::QLstm, descriptor, PrepInfoAndDesc(descriptor));
}
80
Clone(Graph & graph) const81 QLstmLayer* QLstmLayer::Clone(Graph& graph) const
82 {
83 auto layer = CloneBase<QLstmLayer>(graph, m_Param, GetName());
84
85 layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
86 m_BasicParameters.m_InputToForgetWeights : nullptr;
87 layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
88 m_BasicParameters.m_InputToCellWeights : nullptr;
89 layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
90 m_BasicParameters.m_InputToOutputWeights : nullptr;
91 layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
92 m_BasicParameters.m_RecurrentToForgetWeights : nullptr;
93 layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
94 m_BasicParameters.m_RecurrentToCellWeights : nullptr;
95 layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
96 m_BasicParameters.m_RecurrentToOutputWeights : nullptr;
97 layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
98 m_BasicParameters.m_ForgetGateBias : nullptr;
99 layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
100 m_BasicParameters.m_CellBias : nullptr;
101 layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
102 m_BasicParameters.m_OutputGateBias : nullptr;
103
104 if (!m_Param.m_CifgEnabled)
105 {
106 layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
107 m_CifgParameters.m_InputToInputWeights : nullptr;
108 layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
109 m_CifgParameters.m_RecurrentToInputWeights : nullptr;
110 layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
111 m_CifgParameters.m_InputGateBias : nullptr;
112 }
113
114 if (m_Param.m_ProjectionEnabled)
115 {
116 layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
117 m_ProjectionParameters.m_ProjectionWeights : nullptr;
118 layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
119 m_ProjectionParameters.m_ProjectionBias : nullptr;
120 }
121
122 if (m_Param.m_PeepholeEnabled)
123 {
124 if (!m_Param.m_CifgEnabled) {
125 layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
126 m_PeepholeParameters.m_CellToInputWeights : nullptr;
127 }
128
129 layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
130 m_PeepholeParameters.m_CellToForgetWeights : nullptr;
131 layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
132 m_PeepholeParameters.m_CellToOutputWeights : nullptr;
133 }
134
135 if (m_Param.m_LayerNormEnabled)
136 {
137 if (!m_Param.m_CifgEnabled) {
138 layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
139 m_LayerNormParameters.m_InputLayerNormWeights : nullptr;
140 }
141
142 layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
143 m_LayerNormParameters.m_ForgetLayerNormWeights : nullptr;
144 layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
145 m_LayerNormParameters.m_CellLayerNormWeights : nullptr;
146 layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
147 m_LayerNormParameters.m_OutputLayerNormWeights : nullptr;
148 }
149
150 return std::move(layer);
151 }
152
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const153 std::vector<TensorShape> QLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
154 {
155 ARMNN_ASSERT(inputShapes.size() == 3);
156
157 // Get input values for validation
158 unsigned int batchSize = inputShapes[0][0];
159 unsigned int outputSize = inputShapes[1][1];
160 unsigned int numUnits = inputShapes[2][1];
161
162 std::vector<TensorShape> outShapes;
163 outShapes.push_back(TensorShape({ batchSize, outputSize })); // outputStateOut
164 outShapes.push_back(TensorShape({ batchSize, numUnits })); // cellStateOut
165 outShapes.push_back(TensorShape({ batchSize, outputSize })); // output
166
167 return outShapes;
168 }
169
ValidateTensorShapesFromInputs()170 void QLstmLayer::ValidateTensorShapesFromInputs()
171 {
172 VerifyLayerConnections(3, CHECK_LOCATION());
173
174 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
175
176 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
177
178 auto inferredShapes = InferOutputShapes(
179 {
180 GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), // input
181 GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), // previousOutputIn
182 GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape() // previousCellStateIn
183 });
184
185 ARMNN_ASSERT(inferredShapes.size() == 3);
186
187 // Check if the weights are nullptr for basic params
188 ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr,
189 "QLstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null.");
190 ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr,
191 "QLstmLayer: m_BasicParameters.m_InputToCellWeights should not be null.");
192 ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr,
193 "QLstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null.");
194 ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr,
195 "QLstmLayer: m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
196 ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr,
197 "QLstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null.");
198 ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr,
199 "QLstmLayer: m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
200 ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr,
201 "QLstmLayer: m_BasicParameters.m_ForgetGateBias should not be null.");
202 ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr,
203 "QLstmLayer: m_BasicParameters.m_CellBias should not be null.");
204 ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr,
205 "QLstmLayer: m_BasicParameters.m_OutputGateBias should not be null.");
206
207 if (!m_Param.m_CifgEnabled)
208 {
209 ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr,
210 "QLstmLayer: m_CifgParameters.m_InputToInputWeights should not be null.");
211 ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr,
212 "QLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not be null.");
213 ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr,
214 "QLstmLayer: m_CifgParameters.m_InputGateBias should not be null.");
215
216 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QLstmLayer");
217 }
218 else
219 {
220 ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr,
221 "QLstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled.");
222 ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr,
223 "QLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should "
224 "not have a value when CIFG is enabled.");
225 ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
226 "QLstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled.");
227
228 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QLstmLayer");
229 }
230
231 if (m_Param.m_ProjectionEnabled)
232 {
233 ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr,
234 "QLstmLayer: m_ProjectionParameters.m_ProjectionWeights should not be null.");
235 }
236
237 if (m_Param.m_PeepholeEnabled)
238 {
239 if (!m_Param.m_CifgEnabled) {
240 ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr,
241 "QLstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null "
242 "when Peephole is enabled and CIFG is disabled.");
243 }
244
245 ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr,
246 "QLstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null.");
247 ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr,
248 "QLstmLayer: m_PeepholeParameters.m_CellToOutputWeights should not be null.");
249 }
250
251 ValidateAndCopyShape(
252 GetOutputSlot(1).GetTensorInfo().GetShape(), inferredShapes[1], m_ShapeInferenceMethod, "QLstmLayer", 1);
253 ValidateAndCopyShape(
254 GetOutputSlot(2).GetTensorInfo().GetShape(), inferredShapes[2], m_ShapeInferenceMethod, "QLstmLayer", 2);
255
256 if (m_Param.m_LayerNormEnabled)
257 {
258 if(!m_Param.m_CifgEnabled)
259 {
260 ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
261 "QLstmLayer: m_LayerNormParameters.m_InputLayerNormWeights should not be null.");
262 }
263 ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
264 "QLstmLayer: m_LayerNormParameters.m_ForgetLayerNormWeights should not be null.");
265 ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
266 "QLstmLayer: m_LayerNormParameters.m_CellLayerNormWeights should not be null.");
267 ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
268 "QLstmLayer: m_LayerNormParameters.m_UutputLayerNormWeights should not be null.");
269 }
270 }
271
/// Returns references to all of the layer's constant tensor handles, grouped
/// by parameter category. Handles for feature groups that are disabled in the
/// descriptor may be null — callers should check before dereferencing.
Layer::ImmutableConstantTensors QLstmLayer::GetConstantTensorsByRef() const
{
    // For API stability DO NOT ALTER order and add new members to the end of vector
    return {m_BasicParameters.m_InputToForgetWeights,
            m_BasicParameters.m_InputToCellWeights,
            m_BasicParameters.m_InputToOutputWeights,
            m_BasicParameters.m_RecurrentToForgetWeights,
            m_BasicParameters.m_RecurrentToCellWeights,
            m_BasicParameters.m_RecurrentToOutputWeights,
            m_BasicParameters.m_ForgetGateBias,
            m_BasicParameters.m_CellBias,
            m_BasicParameters.m_OutputGateBias,

            // Cifg parameters
            m_CifgParameters.m_InputToInputWeights,
            m_CifgParameters.m_RecurrentToInputWeights,
            m_CifgParameters.m_InputGateBias,

            // Projection parameters
            m_ProjectionParameters.m_ProjectionWeights,
            m_ProjectionParameters.m_ProjectionBias,

            // Peephole parameters
            m_PeepholeParameters.m_CellToInputWeights,
            m_PeepholeParameters.m_CellToForgetWeights,
            m_PeepholeParameters.m_CellToOutputWeights,

            // Layer normalisation parameters
            m_LayerNormParameters.m_InputLayerNormWeights,
            m_LayerNormParameters.m_ForgetLayerNormWeights,
            m_LayerNormParameters.m_CellLayerNormWeights,
            m_LayerNormParameters.m_OutputLayerNormWeights};
}
305
306
ExecuteStrategy(IStrategy & strategy) const307 void QLstmLayer::ExecuteStrategy(IStrategy& strategy) const
308 {
309 std::vector<ConstTensor> constTensors;
310 ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
311 ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
312 ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
313 ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
314 ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
315 ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
316 ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
317 ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
318 ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);
319
320 // Cifg parameters
321 ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
322 ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
323 ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);
324
325 // Projection parameters
326 ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
327 ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);
328
329 // Peephole parameters
330 ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
331 ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
332 ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);
333
334 // Layer normalisation parameters
335 ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
336 ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
337 ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
338 ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);
339
340 // First add mandatory/basic parameters
341 if (m_BasicParameters.m_InputToForgetWeights != nullptr)
342 {
343 constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
344 managedInputToForgetWeights.Map()));
345 }
346 if (m_BasicParameters.m_InputToCellWeights != nullptr)
347 {
348 constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
349 managedInputToCellWeights.Map()));
350 }
351 if (m_BasicParameters.m_InputToOutputWeights != nullptr)
352 {
353 constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
354 managedInputToOutputWeights.Map()));
355 }
356 if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
357 {
358 constTensors.emplace_back(ConstTensor(
359 managedRecurrentToForgetWeights.GetTensorInfo(),
360 managedRecurrentToForgetWeights.Map()));
361 }
362 if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
363 {
364 constTensors.emplace_back(ConstTensor(
365 managedRecurrentToCellWeights.GetTensorInfo(),
366 managedRecurrentToCellWeights.Map()));
367 }
368 if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
369 {
370 constTensors.emplace_back(ConstTensor(
371 managedRecurrentToOutputWeights.GetTensorInfo(),
372 managedRecurrentToOutputWeights.Map()));
373 }
374 if (m_BasicParameters.m_ForgetGateBias != nullptr)
375 {
376 constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
377 managedForgetGateBias.Map()));
378 }
379 if (m_BasicParameters.m_CellBias != nullptr)
380 {
381 constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
382 managedCellBias.Map()));
383 }
384 if (m_BasicParameters.m_OutputGateBias != nullptr)
385 {
386 constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
387 managedOutputGateBias.Map()));
388 }
389
390 // Add cifig parameters
391 if (m_CifgParameters.m_InputToInputWeights != nullptr)
392 {
393 constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
394 managedInputToInputWeights.Map()));
395 }
396 if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
397 {
398 constTensors.emplace_back(ConstTensor(
399 managedRecurrentToInputWeights.GetTensorInfo(),
400 managedRecurrentToInputWeights.Map()));
401 }
402 if (m_CifgParameters.m_InputGateBias != nullptr)
403 {
404 constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
405 managedInputGateBias.Map()));
406 }
407
408 // Add peephole parameters
409 if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
410 {
411 constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
412 managedCellToInputWeights.Map()));
413 }
414 if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
415 {
416 constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
417 managedCellToForgetWeights.Map()));
418 }
419 if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
420 {
421 constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
422 managedCellToOutputWeights.Map()));
423 }
424
425 // Add projection parameters
426 if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
427 {
428 constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
429 managedProjectionWeights.Map()));
430 }
431 if (m_ProjectionParameters.m_ProjectionBias != nullptr)
432 {
433 constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
434 managedProjectionBias.Map()));
435 }
436
437 // Add norm parameters
438 if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
439 {
440 constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
441 managedInputLayerNormWeights.Map()));
442 }
443 if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
444 {
445 constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
446 managedForgetLayerNormWeights.Map()));
447 }
448 if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
449 {
450 constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
451 managedCellLayerNormWeights.Map()));
452 }
453 if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
454 {
455 constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
456 managedOutputLayerNormWeights.Map()));
457 }
458 strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
459 }
460
461 } // namespace armnn
462