// xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/Lstm.cpp
// (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "Activation.hpp"
7 #include "Lstm.hpp"
8 #include "LstmUtils.hpp"
9 
10 namespace armnn
11 {
12 
LstmImpl(const LstmDescriptor & descriptor,const TensorInfo & inputInfo,const TensorInfo & outputInfo,const TensorShape & inputToOutputWeightsShape,const TensorShape & recurrentToOutputWeightsShape,std::unique_ptr<Decoder<float>> & inputData,std::unique_ptr<Decoder<float>> & outputStateIn,std::unique_ptr<Decoder<float>> & cellStateIn,std::unique_ptr<Encoder<float>> & outputStateOut,std::unique_ptr<Encoder<float>> & cellStateOut,std::unique_ptr<Encoder<float>> & output,std::unique_ptr<Decoder<float>> & cellStateOutDecoder,std::unique_ptr<Decoder<float>> & outputDecoder,std::unique_ptr<Decoder<float>> & inputToInputWeightsTensor,std::unique_ptr<Decoder<float>> & inputToForgetWeightsTensor,std::unique_ptr<Decoder<float>> & inputToCellWeightsTensor,std::unique_ptr<Decoder<float>> & inputToOutputWeightsTensor,std::unique_ptr<Decoder<float>> & recurrentToInputWeightsTensor,std::unique_ptr<Decoder<float>> & recurrentToForgetWeightsTensor,std::unique_ptr<Decoder<float>> & recurrentToCellWeightsTensor,std::unique_ptr<Decoder<float>> & recurrentToOutputWeightsTensor,std::unique_ptr<Decoder<float>> & cellToInputWeightsTensor,std::unique_ptr<Decoder<float>> & cellToForgetWeightsTensor,std::unique_ptr<Decoder<float>> & cellToOutputWeightsTensor,std::unique_ptr<Decoder<float>> & inputGateBiasTensor,std::unique_ptr<Decoder<float>> & forgetGateBiasTensor,std::unique_ptr<Decoder<float>> & cellBiasTensor,std::unique_ptr<Decoder<float>> & outputGateBiasTensor,std::unique_ptr<Decoder<float>> & projectionWeightsTensor,std::unique_ptr<Decoder<float>> & projectionBiasTensor,std::unique_ptr<Decoder<float>> & inputLayerNormWeights,std::unique_ptr<Decoder<float>> & forgetLayerNormWeights,std::unique_ptr<Decoder<float>> & cellLayerNormWeights,std::unique_ptr<Decoder<float>> & outputLayerNormWeights,std::unique_ptr<Encoder<float>> & inputGateScratch,std::unique_ptr<Encoder<float>> & cellScratch,std::unique_ptr<Encoder<float>> & forgetGateScratch,std::unique_ptr<Encoder<float>> & 
outputGateScratch,std::unique_ptr<Decoder<float>> & inputGateScratchDecoder,std::unique_ptr<Decoder<float>> & cellScratchDecoder,std::unique_ptr<Decoder<float>> & forgetGateScratchDecoder,std::unique_ptr<Decoder<float>> & outputGateScratchDecoder,float layerNormEpsilon)13 void LstmImpl(const LstmDescriptor& descriptor,
14               const TensorInfo& inputInfo,
15               const TensorInfo& outputInfo,
16               const TensorShape& inputToOutputWeightsShape,
17               const TensorShape& recurrentToOutputWeightsShape,
18               std::unique_ptr<Decoder<float>>& inputData,
19               std::unique_ptr<Decoder<float>>& outputStateIn,
20               std::unique_ptr<Decoder<float>>& cellStateIn,
21               std::unique_ptr<Encoder<float>>& outputStateOut,
22               std::unique_ptr<Encoder<float>>& cellStateOut,
23               std::unique_ptr<Encoder<float>>& output,
24               std::unique_ptr<Decoder<float>>& cellStateOutDecoder,
25               std::unique_ptr<Decoder<float>>& outputDecoder,
26               std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor,
27               std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor,
28               std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor,
29               std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor,
30               std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor,
31               std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor,
32               std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor,
33               std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor,
34               std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor,
35               std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor,
36               std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor,
37               std::unique_ptr<Decoder<float>>& inputGateBiasTensor,
38               std::unique_ptr<Decoder<float>>& forgetGateBiasTensor,
39               std::unique_ptr<Decoder<float>>& cellBiasTensor,
40               std::unique_ptr<Decoder<float>>& outputGateBiasTensor,
41               std::unique_ptr<Decoder<float>>& projectionWeightsTensor,
42               std::unique_ptr<Decoder<float>>& projectionBiasTensor,
43               std::unique_ptr<Decoder<float>>& inputLayerNormWeights,
44               std::unique_ptr<Decoder<float>>& forgetLayerNormWeights,
45               std::unique_ptr<Decoder<float>>& cellLayerNormWeights,
46               std::unique_ptr<Decoder<float>>& outputLayerNormWeights,
47               std::unique_ptr<Encoder<float>>& inputGateScratch,
48               std::unique_ptr<Encoder<float>>& cellScratch,
49               std::unique_ptr<Encoder<float>>& forgetGateScratch,
50               std::unique_ptr<Encoder<float>>& outputGateScratch,
51               std::unique_ptr<Decoder<float>>& inputGateScratchDecoder,
52               std::unique_ptr<Decoder<float>>& cellScratchDecoder,
53               std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder,
54               std::unique_ptr<Decoder<float>>& outputGateScratchDecoder,
55               float layerNormEpsilon)
56 {
57     // This is a porting of the LSTM::Eval() method in the Android code base
58     // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
59 
60     const TensorShape& inputShape = inputInfo.GetShape();
61     const DataType& outputType = outputInfo.GetDataType();
62 
63     const uint32_t nBatch = inputShape[0];
64     const uint32_t nInput = inputShape[1];
65 
66     const uint32_t nCell   = inputToOutputWeightsShape[0];
67     const uint32_t nOutput = recurrentToOutputWeightsShape[1];
68 
69     const bool useCifg      = descriptor.m_CifgEnabled;
70     const bool usePeephole  = descriptor.m_PeepholeEnabled;
71     const bool useLayerNorm = descriptor.m_LayerNormEnabled;
72 
73     if (!useLayerNorm)
74     {
75         // Initialize scratch buffers with bias.
76         if (!useCifg)
77         {
78             VectorBatchVectorAssign(*inputGateBiasTensor,
79                                     nCell, nBatch, *inputGateScratch);
80         }
81         VectorBatchVectorAssign(*forgetGateBiasTensor,
82                                 nCell, nBatch, *forgetGateScratch);
83         VectorBatchVectorAssign(*cellBiasTensor,
84                                 nCell, nBatch, *cellScratch);
85         VectorBatchVectorAssign(*outputGateBiasTensor,
86                                 nCell, nBatch, *outputGateScratch);
87     }
88     else
89     {
90         // Initialize scratch buffers with zeroes.
91         if (!useCifg)
92         {
93             ZeroVector(*inputGateScratch, nCell * nBatch);
94         }
95         ZeroVector(*forgetGateScratch, nCell * nBatch);
96         ZeroVector(*cellScratch      , nCell * nBatch);
97         ZeroVector(*outputGateScratch, nCell * nBatch);
98     }
99 
100     // For each batch and cell: compute input_weight * input.
101     if (!useCifg)
102     {
103         MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
104                                             nCell, nInput, *inputData, nBatch, *inputGateScratch);
105     }
106     MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
107                                         nCell, nInput, *inputData, nBatch, *forgetGateScratch);
108     MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
109                                         nCell, nInput, *inputData, nBatch, *cellScratch);
110     MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
111                                         nCell, nInput, *inputData, nBatch, *outputGateScratch);
112 
113     // For each batch and cell: compute recurrent_weight * output_state.
114     if (!useCifg)
115     {
116         MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
117                                             nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
118     }
119     MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
120                                         nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
121     MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
122                                         nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
123     MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
124                                         nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
125 
126     // For each batch and cell: update input gate.
127     if (!useCifg)
128     {
129         if (usePeephole)
130         {
131             VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
132                                                     nCell, *cellStateIn, nBatch, *inputGateScratch);
133         }
134         if (useLayerNorm)
135         {
136             MeanStddevNormalization(*inputGateScratchDecoder,
137                                     *inputGateScratch, nCell, nBatch, layerNormEpsilon);
138             VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
139                                           nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
140             VectorBatchVectorAdd(*inputGateBiasTensor,
141                                  nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
142         }
143         Activation(*inputGateScratchDecoder, *inputGateScratch,
144                    TensorInfo({nCell, nBatch}, outputType),
145                    ActivationFunction::Sigmoid, 0, 0);
146     }
147 
148     // For each batch and cell: update forget gate.
149     if (usePeephole)
150     {
151         VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
152                                                 *cellStateIn, nBatch, *forgetGateScratch);
153     }
154     if (useLayerNorm)
155     {
156         MeanStddevNormalization(*forgetGateScratchDecoder,
157                                 *forgetGateScratch, nCell, nBatch, layerNormEpsilon);
158         VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
159                                       nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
160         VectorBatchVectorAdd(*forgetGateBiasTensor,
161                              nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
162     }
163     Activation(*forgetGateScratchDecoder, *forgetGateScratch,
164                TensorInfo({nCell, nBatch}, outputType),
165                ActivationFunction::Sigmoid, 0, 0);
166 
167     // For each batch and cell: update the cell.
168     if (useLayerNorm)
169     {
170         MeanStddevNormalization(*cellScratchDecoder,
171                                 *cellScratch, nCell, nBatch, layerNormEpsilon);
172         VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
173                                       nCell, *cellScratchDecoder, nBatch, *cellScratch);
174         VectorBatchVectorAdd(*cellBiasTensor,
175                              nCell, *cellScratchDecoder, nBatch, *cellScratch);
176     }
177 
178     VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
179 
180     ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
181     float a = 0;
182     float b = 0;
183     SetActivationParameters(descriptor.m_ActivationFunc, armnnActivationFunc, a, b);
184 
185     if (descriptor.m_ActivationFunc > 0)
186     {
187         Activation(*cellScratchDecoder, *cellScratch,
188                    TensorInfo({nCell, nBatch}, outputType),
189                    armnnActivationFunc, a, b);
190     }
191     if (useCifg)
192     {
193         Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
194         VectorVectorCwiseProductAccumulate(
195             *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
196     }
197     else
198     {
199         VectorVectorCwiseProductAccumulate(
200             *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
201     }
202     if (descriptor.m_ClippingThresCell > 0.0)
203     {
204         ClipVector(*cellStateOutDecoder, nBatch * nCell, descriptor.m_ClippingThresCell, *cellStateOut);
205     }
206 
207     // For each batch and cell: update the output gate.
208     if (usePeephole)
209     {
210         VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
211                                                 nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
212     }
213     if (useLayerNorm)
214     {
215         MeanStddevNormalization(*outputGateScratchDecoder,
216                                 *outputGateScratch, nCell, nBatch, layerNormEpsilon);
217         VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
218                                       nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
219         VectorBatchVectorAdd(*outputGateBiasTensor,
220                              nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
221     }
222     Activation(*outputGateScratchDecoder, *outputGateScratch,
223                TensorInfo({nCell, nBatch}, outputType),
224                ActivationFunction::Sigmoid, 0, 0);
225 
226     if (descriptor.m_ActivationFunc > 0)
227     {
228         Activation(*cellStateOutDecoder, *cellScratch,
229                    TensorInfo({nCell, nBatch}, outputType),
230                    armnnActivationFunc, a, b);
231     }
232 
233     VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);
234 
235     // For each batch: update the projection and output_state.
236     if (descriptor.m_ProjectionEnabled)
237     {
238         if (projectionBiasTensor)
239         {
240             VectorBatchVectorAssign(*projectionBiasTensor,
241                                     nOutput, nBatch, *output);
242         }
243         MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
244                                             nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
245 
246         if (descriptor.m_ClippingThresProj > 0.0)
247         {
248             ClipVector(*outputDecoder, nBatch * nOutput, descriptor.m_ClippingThresProj, *output);
249         }
250     }
251     else
252     {
253         CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
254     }
255 
256     CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
257 }
258 
259 } //namespace armnn
260