1 // 2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Descriptors.hpp> 9 #include <armnn/LstmParams.hpp> 10 #include "ClBaseWorkload.hpp" 11 #include <armnn/backends/WorkloadData.hpp> 12 13 #include "arm_compute/graph/Tensor.h" 14 #include "arm_compute/runtime/CL/functions/CLQLSTMLayer.h" 15 16 namespace armnn 17 { 18 19 class ClQLstmWorkload : public ClBaseWorkload<QLstmQueueDescriptor> 20 { 21 public: 22 ClQLstmWorkload(const QLstmQueueDescriptor& descriptor, 23 const WorkloadInfo& info, 24 const arm_compute::CLCompileContext& clCompileContext); 25 virtual void Execute() const override; 26 27 private: 28 mutable arm_compute::CLQLSTMLayer m_QLstmLayer; 29 30 std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor; 31 std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor; 32 std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor; 33 std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor; 34 35 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor; 36 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor; 37 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor; 38 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor; 39 40 std::unique_ptr<arm_compute::CLTensor> m_CellToInputWeightsTensor; 41 std::unique_ptr<arm_compute::CLTensor> m_CellToForgetWeightsTensor; 42 std::unique_ptr<arm_compute::CLTensor> m_CellToOutputWeightsTensor; 43 44 std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor; 45 std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor; 46 std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor; 47 std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor; 48 49 std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor; 50 std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor; 51 52 std::unique_ptr<arm_compute::CLTensor> m_InputLayerNormWeightsTensor; 53 std::unique_ptr<arm_compute::CLTensor> m_ForgetLayerNormWeightsTensor; 54 std::unique_ptr<arm_compute::CLTensor> m_CellLayerNormWeightsTensor; 55 std::unique_ptr<arm_compute::CLTensor> m_OutputLayerNormWeightsTensor; 56 57 void FreeUnusedTensors(); 58 }; 59 60 arm_compute::Status ClQLstmWorkloadValidate(const TensorInfo& input, 61 const TensorInfo& cellStateIn, 62 const TensorInfo& outputStateIn, 63 const TensorInfo& cellStateOut, 64 const TensorInfo& outputStateOut, 65 const TensorInfo& output, 66 const QLstmDescriptor& descriptor, 67 const LstmInputParamsInfo& paramsInfo); 68 } //namespace armnn 69