1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Descriptors.hpp> 9 #include <armnn/LstmParams.hpp> 10 #include <armnn/backends/Workload.hpp> 11 #include <armnn/backends/WorkloadData.hpp> 12 13 #include "arm_compute/graph/Tensor.h" 14 #include "arm_compute/runtime/NEON/functions/NELSTMLayer.h" 15 16 namespace armnn 17 { 18 19 class NeonLstmFloatWorkload : public FloatWorkload<LstmQueueDescriptor> 20 { 21 public: 22 NeonLstmFloatWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info); 23 virtual void Execute() const override; 24 // Replace input tensor handle with the given TensorHandle 25 void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override; 26 27 // Replace output tensor handle with the given TensorHandle 28 void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override; 29 private: 30 mutable arm_compute::NELSTMLayer m_LstmLayer; 31 32 std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor; 33 std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor; 34 std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor; 35 std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor; 36 std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor; 37 std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor; 38 std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor; 39 std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor; 40 std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor; 41 std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor; 42 std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor; 43 std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor; 44 std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor; 45 std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor; 46 std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor; 47 std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor; 48 std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor; 49 50 std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer; 51 52 std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor; 53 std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor; 54 std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor; 55 std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor; 56 57 void FreeUnusedTensors(); 58 virtual void Reconfigure(); 59 }; 60 61 arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn, 62 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, 63 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, 64 const TensorInfo& output, const LstmDescriptor &descriptor, 65 const LstmInputParamsInfo& paramsInfo); 66 67 } //namespace armnn 68