1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Descriptors.hpp> 9 #include <armnn/LstmParams.hpp> 10 #include <armnn/backends/Workload.hpp> 11 #include <armnn/backends/WorkloadData.hpp> 12 #include "NeonBaseWorkload.hpp" 13 14 #include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h" 15 #include "arm_compute/runtime/NEON/functions/NEPermute.h" 16 #include "arm_compute/runtime/NEON/functions/NESplit.h" 17 #include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h" 18 19 namespace armnn 20 { 21 22 class NeonUnidirectionalSequenceLstmWorkload : public NeonBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor> 23 { 24 public: 25 NeonUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor, 26 const WorkloadInfo& info); 27 virtual void Execute() const override; 28 29 private: 30 31 // 32 // ACL layers required to fully form a Unidirectional Sequence LSTM layer. 33 // 34 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute1; 35 mutable std::unique_ptr<arm_compute::IFunction> m_Splitter; 36 mutable std::vector<std::unique_ptr<arm_compute::NEQLSTMLayer>> m_Layers; 37 mutable std::unique_ptr<arm_compute::NEConcatenateLayer> m_Concat; 38 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute2; 39 40 // 41 // ACL LSTM arm_compute::Tensors. 42 // 43 std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor; 44 std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor; 45 std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor; 46 std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor; 47 std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor; 48 std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor; 49 std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor; 50 std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor; 51 std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor; 52 std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor; 53 std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor; 54 std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor; 55 std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor; 56 std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor; 57 std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor; 58 std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor; 59 std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor; 60 61 std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor; 62 std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor; 63 std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor; 64 std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor; 65 66 // 67 // Additional ACL arm_compute::Tensors and std::vector<arm_compute::Tensor>. 68 // Required to perform splitting, concatenation and permutations. 69 // 70 arm_compute::Tensor m_PermuteFirstOut; 71 std::vector<arm_compute::Tensor> m_SplitterOutputsTensors; 72 std::vector<arm_compute::Tensor> m_ConcatInputsTensors; 73 std::vector<arm_compute::ITensor*> m_SplitterOutputs; 74 std::vector<const arm_compute::ITensor*> m_ConcatInputs; 75 arm_compute::Tensor concat_out; 76 77 void FreeUnusedTensors(); 78 }; 79 80 arm_compute::Status 81 NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input, 82 const TensorInfo& outputStateIn, 83 const TensorInfo& cellStateIn, 84 const TensorInfo& outputStateOut, 85 const TensorInfo& cellStateOut, 86 const TensorInfo& output, 87 const UnidirectionalSequenceLstmDescriptor& descriptor, 88 const LstmInputParamsInfo& paramsInfo); 89 90 } //namespace armnn 91