1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/QuantizedLstmParams.hpp> 9 #include "NeonBaseWorkload.hpp" 10 #include <armnn/backends/WorkloadData.hpp> 11 12 #include <arm_compute/graph/Tensor.h> 13 #include <arm_compute/runtime/NEON/functions/NELSTMLayerQuantized.h> 14 15 namespace armnn 16 { 17 18 class NeonQuantizedLstmWorkload : public NeonBaseWorkload<QuantizedLstmQueueDescriptor> 19 { 20 public: 21 using BaseWorkload<QuantizedLstmQueueDescriptor>::m_Data; 22 NeonQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor& descriptor, const WorkloadInfo& info); 23 virtual void Execute() const override; 24 25 private: 26 mutable arm_compute::NELSTMLayerQuantized m_QuantizedLstmLayer; 27 28 std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor; 29 std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor; 30 std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor; 31 std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor; 32 std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor; 33 std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor; 34 std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor; 35 std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor; 36 std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor; 37 std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor; 38 std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor; 39 std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor; 40 std::unique_ptr<arm_compute::Tensor> m_CellStateInTensor; 41 std::unique_ptr<arm_compute::Tensor> m_OutputStateInTensor; 42 std::unique_ptr<arm_compute::Tensor> m_CellStateOutTensor; 43 44 void FreeUnusedTensors(); 45 }; 46 47 arm_compute::Status NeonQuantizedLstmWorkloadValidate(const TensorInfo& input, 48 const TensorInfo& outputStateIn, 49 const TensorInfo& cellStateIn, 50 const TensorInfo& outputStateOut, 51 const TensorInfo& cellStateOut, 52 const QuantizedLstmInputParamsInfo& paramsInfo); 53 54 } //namespace armnn 55