1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "ClBaseWorkload.hpp" 9 10 #include <armnn/QuantizedLstmParams.hpp> 11 #include <armnn/backends/Workload.hpp> 12 #include <armnn/backends/WorkloadData.hpp> 13 14 #include <arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h> 15 16 namespace armnn 17 { 18 19 arm_compute::Status ClQuantizedLstmWorkloadValidate(const TensorInfo& input, const TensorInfo& previousCellStateIn, 20 const TensorInfo& previousOutputIn, const TensorInfo& cellStateOut, 21 const TensorInfo& output, 22 const QuantizedLstmInputParamsInfo& paramsInfo); 23 24 class ClQuantizedLstmWorkload : public ClBaseWorkload<QuantizedLstmQueueDescriptor> 25 { 26 public: 27 ClQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor& descriptor, 28 const WorkloadInfo& info, 29 const arm_compute::CLCompileContext& clCompileContext); 30 void Execute() const override; 31 32 private: 33 mutable arm_compute::CLLSTMLayerQuantized m_QuantizedLstmLayer; 34 35 std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor; 36 std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor; 37 std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor; 38 std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor; 39 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor; 40 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor; 41 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor; 42 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor; 43 std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor; 44 std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor; 45 std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor; 46 std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor; 47 48 void FreeUnusedTensors(); 49 }; 50 51 } //namespace armnn 52 53 54