xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/Lstm.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/TypesUtils.hpp>
9 #include <armnn/backends/WorkloadData.hpp>
10 
11 #include "Encoders.hpp"
12 #include "Decoders.hpp"
13 
14 namespace armnn
15 {
16 
/// @brief Reference-backend LSTM entry point. Only the declaration is visible
///        here; the definition lives elsewhere (presumably Lstm.cpp in this
///        directory — confirm).
///
/// Every tensor is passed as a `Decoder<float>` or `Encoder<float>` view held in
/// a `std::unique_ptr`, taken by non-const reference. The declaration alone does
/// not show whether the callee reseats these pointers, or which of them may be
/// null. NOTE(review): the optional groups (input-gate/CIFG, peephole
/// "cellTo*", projection, layer-norm) are presumably allowed to be null
/// depending on fields of @p descriptor. Verify against the definition before
/// relying on this.
///
/// NOTE(review): by naming convention a Decoder<float> is presumably read
/// through and an Encoder<float> written through. See Decoders.hpp and
/// Encoders.hpp to confirm.
///
/// Parameter groups, in declaration order:
///  - @p descriptor: LSTM configuration (LstmDescriptor).
///  - @p inputInfo, @p outputInfo: TensorInfo of the input and output tensors.
///  - @p inputToOutputWeightsShape, @p recurrentToOutputWeightsShape: weight
///    shapes, presumably used to derive unit/output sizes — TODO confirm.
///  - @p inputData, @p outputStateIn, @p cellStateIn: input activation and
///    incoming recurrent state.
///  - @p outputStateOut, @p cellStateOut, @p output: output tensors (Encoders).
///  - @p cellStateOutDecoder, @p outputDecoder: Decoder views, presumably over
///    the same buffers as @p cellStateOut and @p output so the results can be
///    read back — confirm at the call site.
///  - input-to-gate and recurrent-to-gate weights, each ordered
///    input, forget, cell, output.
///  - cell-to-gate (peephole) weights, ordered input, forget, output.
///  - gate biases, ordered input, forget, cell, output.
///  - projection weights and bias.
///  - layer-norm weights, ordered input, forget, cell, output.
///  - scratch Encoders, then matching scratch Decoders, each ordered
///    input, CELL, FORGET, output. This differs from the weight/bias order
///    above, and all of them share one type, so a swapped argument compiles
///    silently.
///  - @p layerNormEpsilon: presumably the epsilon used by layer normalisation —
///    confirm in the definition.
17 void LstmImpl(const LstmDescriptor& descriptor,
18               const TensorInfo& inputInfo,
19               const TensorInfo& outputInfo,
20               const TensorShape& inputToOutputWeightsShape,
21               const TensorShape& recurrentToOutputWeightsShape,
                 // Input activation and incoming recurrent state.
22               std::unique_ptr<Decoder<float>>& inputData,
23               std::unique_ptr<Decoder<float>>& outputStateIn,
24               std::unique_ptr<Decoder<float>>& cellStateIn,
                 // Output tensors.
25               std::unique_ptr<Encoder<float>>& outputStateOut,
26               std::unique_ptr<Encoder<float>>& cellStateOut,
27               std::unique_ptr<Encoder<float>>& output,
                 // Decoder views of outputs (presumably the same buffers as cellStateOut / output).
28               std::unique_ptr<Decoder<float>>& cellStateOutDecoder,
29               std::unique_ptr<Decoder<float>>& outputDecoder,
                 // Input-to-gate weights: input, forget, cell, output.
30               std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor,
31               std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor,
32               std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor,
33               std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor,
                 // Recurrent-to-gate weights: input, forget, cell, output.
34               std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor,
35               std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor,
36               std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor,
37               std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor,
                 // Cell-to-gate (peephole) weights: input, forget, output.
38               std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor,
39               std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor,
40               std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor,
                 // Gate biases: input, forget, cell, output.
41               std::unique_ptr<Decoder<float>>& inputGateBiasTensor,
42               std::unique_ptr<Decoder<float>>& forgetGateBiasTensor,
43               std::unique_ptr<Decoder<float>>& cellBiasTensor,
44               std::unique_ptr<Decoder<float>>& outputGateBiasTensor,
                 // Projection.
45               std::unique_ptr<Decoder<float>>& projectionWeightsTensor,
46               std::unique_ptr<Decoder<float>>& projectionBiasTensor,
                 // Layer-norm weights: input, forget, cell, output.
47               std::unique_ptr<Decoder<float>>& inputLayerNormWeights,
48               std::unique_ptr<Decoder<float>>& forgetLayerNormWeights,
49               std::unique_ptr<Decoder<float>>& cellLayerNormWeights,
50               std::unique_ptr<Decoder<float>>& outputLayerNormWeights,
                 // Scratch buffers. NOTE: ordered input, CELL, FORGET, output (differs from the groups above).
51               std::unique_ptr<Encoder<float>>& inputGateScratch,
52               std::unique_ptr<Encoder<float>>& cellScratch,
53               std::unique_ptr<Encoder<float>>& forgetGateScratch,
54               std::unique_ptr<Encoder<float>>& outputGateScratch,
                 // Decoder views of the scratch buffers, in the same input, cell, forget, output order.
55               std::unique_ptr<Decoder<float>>& inputGateScratchDecoder,
56               std::unique_ptr<Decoder<float>>& cellScratchDecoder,
57               std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder,
58               std::unique_ptr<Decoder<float>>& outputGateScratchDecoder,
59               float layerNormEpsilon);
60 
61 } //namespace armnn
62