1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/TypesUtils.hpp> 9 10 #include "RefBaseWorkload.hpp" 11 #include <armnn/backends/WorkloadData.hpp> 12 13 namespace armnn 14 { 15 16 class RefQLstmWorkload : public RefBaseWorkload<QLstmQueueDescriptor> 17 { 18 public: 19 explicit RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info); 20 21 void Execute() const override; 22 void ExecuteAsync(ExecutionData& executionData) override; 23 24 private: 25 void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; 26 std::unique_ptr<ScopedTensorHandle> m_InputToInputWeightsTensor; 27 std::unique_ptr<ScopedTensorHandle> m_InputToForgetWeightsTensor; 28 std::unique_ptr<ScopedTensorHandle> m_InputToCellWeightsTensor; 29 std::unique_ptr<ScopedTensorHandle> m_InputToOutputWeightsTensor; 30 31 std::unique_ptr<ScopedTensorHandle> m_RecurrentToInputWeightsTensor; 32 std::unique_ptr<ScopedTensorHandle> m_RecurrentToForgetWeightsTensor; 33 std::unique_ptr<ScopedTensorHandle> m_RecurrentToCellWeightsTensor; 34 std::unique_ptr<ScopedTensorHandle> m_RecurrentToOutputWeightsTensor; 35 36 std::unique_ptr<ScopedTensorHandle> m_CellToInputWeightsTensor; 37 std::unique_ptr<ScopedTensorHandle> m_CellToForgetWeightsTensor; 38 std::unique_ptr<ScopedTensorHandle> m_CellToOutputWeightsTensor; 39 40 std::unique_ptr<ScopedTensorHandle> m_InputGateBiasTensor; 41 std::unique_ptr<ScopedTensorHandle> m_ForgetGateBiasTensor; 42 std::unique_ptr<ScopedTensorHandle> m_CellBiasTensor; 43 std::unique_ptr<ScopedTensorHandle> m_OutputGateBiasTensor; 44 45 std::unique_ptr<ScopedTensorHandle> m_ProjectionWeightsTensor; 46 std::unique_ptr<ScopedTensorHandle> m_ProjectionBiasTensor; 47 48 std::unique_ptr<ScopedTensorHandle> m_InputLayerNormWeightsTensor; 49 std::unique_ptr<ScopedTensorHandle> m_ForgetLayerNormWeightsTensor; 50 std::unique_ptr<ScopedTensorHandle> m_CellLayerNormWeightsTensor; 51 std::unique_ptr<ScopedTensorHandle> m_OutputLayerNormWeightsTensor; 52 53 float m_LayerNormEpsilon = static_cast<float>(1e-8); 54 }; 55 56 } //namespace armnn 57