1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefUnidirectionalSequenceLstmWorkload.hpp"
7 #include "Activation.hpp"
8 #include "Encoders.hpp"
9 #include "Decoders.hpp"
10 #include "Lstm.hpp"
11 #include "LstmUtils.hpp"
12 #include "RefWorkloadUtils.hpp"
13 
14 #include <armnnUtils/Permute.hpp>
15 
16 namespace armnn
17 {
18 
RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor & descriptor,const WorkloadInfo & info)19 RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload(
20     const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
21     const WorkloadInfo& info)
22     : RefBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
23     , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
24     , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
25     , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
26     , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
27     , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
28     , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
29     , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
30     , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
31     , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
32     , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
33     , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
34     , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
35     , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
36     , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
37     , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
38     , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
39     , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
40     , m_InputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
41     , m_ForgetLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
42     , m_CellLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
43     , m_OutputLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
44 {}
45 
Execute() const46 void RefUnidirectionalSequenceLstmWorkload::Execute() const
47 {
48     Execute(m_Data.m_Inputs, m_Data.m_Outputs);
49 }
50 
ExecuteAsync(ExecutionData & executionData)51 void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(ExecutionData& executionData)
52 {
53     WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
54     Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
55 }
56 
Execute(std::vector<ITensorHandle * > inputs,std::vector<ITensorHandle * > outputs) const57 void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> inputs,
58                                                     std::vector<ITensorHandle*> outputs) const
59 {
60     TensorInfo inputInfo = GetTensorInfo(inputs[0]);
61     const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
62     const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
63     TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]);
64     TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]);
65     TensorInfo outputInfo = GetTensorInfo(outputs[2]);
66     TensorShape& inputShape = inputInfo.GetShape();
67     TensorShape& outputShape= outputInfo.GetShape();
68     auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());
69 
70     if (!m_Data.m_Parameters.m_TimeMajor)
71     {
72         // Permute to time major
73         const PermutationVector& mappings = {1U, 0U, 2U};
74         std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.GetNumElements());
75         inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings);
76         inputInfo.SetShape(inputShape);
77         armnnUtils::Permute(inputShape, mappings,  inputValue.data(), inputTensor, sizeof(float));
78 
79         outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
80         outputInfo.SetShape(outputShape);
81     }
82     unsigned int maxTime = inputShape[0];
83     unsigned int batchSize = inputShape[1];
84     unsigned int outputSize = outputShape[2];
85     unsigned int inputSize = inputShape[2];
86 
87     TensorInfo scratchInfo = outputInfo;
88     scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]});
89 
90     std::vector<float> inputGateScratchBuffer;
91     std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
92     std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
93     std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
94 
95     std::vector<float> outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.);
96     std::vector<float> cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.);
97 
98     void* outputStateOutData = outputStateOutBuffer.data();
99     void* cellStateOutData = cellStateOutBuffer.data();
100 
101     std::unique_ptr<Encoder<float>> inputGateScratch;
102     std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
103     std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
104     std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());
105 
106     std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
107     std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
108     std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
109                                                                                   forgetGateScratchBuffer.data());
110     std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
111                                                                                   outputGateScratchBuffer.data());
112 
113     const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
114     const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;
115     const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
116 
117     if (!useCifg)
118     {
119         inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
120         inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
121         inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
122     }
123 
124     std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
125     std::unique_ptr<Encoder<float>> cellStateOut   = MakeEncoder<float>(cellStateInfo, cellStateOutData);
126     std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);
127 
128     TensorInfo lstmInputInfo = inputInfo;
129     TensorShape batchInputShape = TensorShape({batchSize, inputSize});
130     lstmInputInfo.SetShape(batchInputShape);
131 
132     TensorInfo lstmOutputInfo = outputInfo;
133     lstmOutputInfo.SetShape({batchSize, outputSize});
134 
135     const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
136     const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
137     unsigned int nOutput = recurrentToOutputWeightsShape[1];
138     auto outputStateInData = inputs[1]->Map();
139     std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);
140 
141     auto cellStateInData = inputs[2]->Map();
142     std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);
143 
144     auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
145     std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
146     auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map());
147     std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
148     std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
149 
150     std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
151     std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
152         m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
153     std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
154         m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
155     std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
156         m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
157 
158     std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
159     std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
160         m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
161     std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
162         m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
163     std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
164         m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
165 
166     std::unique_ptr<Decoder<float>> inputGateBiasTensor;
167     std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
168         m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
169     std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
170         m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
171     std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
172         m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
173 
174     std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
175     std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
176     std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
177 
178     std::unique_ptr<Decoder<float>> projectionWeightsTensor;
179     std::unique_ptr<Decoder<float>> projectionBiasTensor;
180 
181     std::unique_ptr<Decoder<float>> inputLayerNormWeights;
182     std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
183     std::unique_ptr<Decoder<float>> cellLayerNormWeights;
184     std::unique_ptr<Decoder<float>> outputLayerNormWeights;
185 
186     if (useLayerNorm)
187     {
188         if (!useCifg)
189         {
190             inputLayerNormWeights = MakeDecoder<float>(
191                     m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
192         }
193         forgetLayerNormWeights = MakeDecoder<float>(
194                 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
195         cellLayerNormWeights = MakeDecoder<float>(
196                 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
197         outputLayerNormWeights = MakeDecoder<float>(
198                 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
199     }
200 
201     if (!useCifg)
202     {
203         inputToInputWeightsTensor = MakeDecoder<float>(
204             m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
205         inputGateBiasTensor = MakeDecoder<float>(
206             m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
207         recurrentToInputWeightsTensor = MakeDecoder<float>(
208             m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
209     }
210 
211     if (usePeephole)
212     {
213         cellToForgetWeightsTensor = MakeDecoder<float>(
214             m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
215         cellToOutputWeightsTensor = MakeDecoder<float>(
216             m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
217     }
218 
219     if (!useCifg && usePeephole)
220     {
221         cellToInputWeightsTensor = MakeDecoder<float>(
222             m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
223     }
224 
225     if (m_Data.m_Parameters.m_ProjectionEnabled)
226     {
227         projectionWeightsTensor = MakeDecoder<float>(
228             m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
229         if (m_ProjectionBiasTensor)
230         {
231             projectionBiasTensor = MakeDecoder<float>(
232                 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
233         }
234     }
235 
236     unsigned int batchInputSize = batchSize * inputSize;
237     unsigned int batchOutputSize = batchSize * nOutput;
238 
239     for (unsigned int t = 0; t < maxTime; ++t)
240     {
241         LstmImpl(m_Data.m_Parameters,
242                  lstmInputInfo,
243                  lstmOutputInfo,
244                  inputToOutputWeightsShape,
245                  recurrentToOutputWeightsShape,
246                  inputData,
247                  outputStateIn,
248                  cellStateIn,
249                  outputStateOut,
250                  cellStateOut,
251                  output,
252                  cellStateOutDecoder,
253                  outputDecoder,
254                  inputToInputWeightsTensor,
255                  inputToForgetWeightsTensor,
256                  inputToCellWeightsTensor,
257                  inputToOutputWeightsTensor,
258                  recurrentToInputWeightsTensor,
259                  recurrentToForgetWeightsTensor,
260                  recurrentToCellWeightsTensor,
261                  recurrentToOutputWeightsTensor,
262                  cellToInputWeightsTensor,
263                  cellToForgetWeightsTensor,
264                  cellToOutputWeightsTensor,
265                  inputGateBiasTensor,
266                  forgetGateBiasTensor,
267                  cellBiasTensor,
268                  outputGateBiasTensor,
269                  projectionWeightsTensor,
270                  projectionBiasTensor,
271                  inputLayerNormWeights,
272                  forgetLayerNormWeights,
273                  cellLayerNormWeights,
274                  outputLayerNormWeights,
275                  inputGateScratch,
276                  cellScratch,
277                  forgetGateScratch,
278                  outputGateScratch,
279                  inputGateScratchDecoder,
280                  cellScratchDecoder,
281                  forgetGateScratchDecoder,
282                  outputGateScratchDecoder,
283                  m_LayerNormEpsilon);
284 
285         currentInputData += batchInputSize;
286         inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
287         currentOutputData += batchOutputSize;
288         output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
289         outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
290 
291         // Assign output state out to the next output state in
292         outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);
293 
294         // Assign cell state out to the next cell state in
295         cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
296     }
297 
298     if (!m_Data.m_Parameters.m_TimeMajor)
299     {
300         // Permute Output back to batch major
301         const PermutationVector& mappings = {1U, 0U, 2U};
302         auto outputData = reinterpret_cast<float*>(outputs[2]->Map());
303         std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
304         outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
305         outputInfo.SetShape(outputShape);
306         armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float));
307     }
308 }
309 
310 } //namespace armnn
311