1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Descriptors.hpp> 9 #include <armnn/LstmParams.hpp> 10 #include <armnn/backends/Workload.hpp> 11 #include <armnn/backends/WorkloadData.hpp> 12 13 #include <arm_compute/graph/Tensor.h> 14 #include <arm_compute/runtime/CL/functions/CLLSTMLayer.h> 15 #include <arm_compute/runtime/CL/functions/CLPermute.h> 16 #include <arm_compute/runtime/CL/functions/CLSplit.h> 17 #include <arm_compute/runtime/CL/functions/CLConcatenateLayer.h> 18 19 namespace armnn 20 { 21 22 class ClUnidirectionalSequenceLstmFloatWorkload : public FloatWorkload<UnidirectionalSequenceLstmQueueDescriptor> 23 { 24 public: 25 ClUnidirectionalSequenceLstmFloatWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor, 26 const WorkloadInfo& info, 27 const arm_compute::CLCompileContext& clCompileContext); 28 virtual void Execute() const override; 29 30 private: 31 32 // 33 // ACL layers required to fully form a Unidirectional Sequence LSTM layer. 34 // 35 36 // permutation for input (only used when input is batch major) 37 mutable std::unique_ptr<arm_compute::CLPermute> m_Permute1; 38 mutable std::unique_ptr<arm_compute::IFunction> m_Splitter; 39 mutable std::vector<std::unique_ptr<arm_compute::CLLSTMLayer>> m_Layers; 40 mutable std::unique_ptr<arm_compute::CLConcatenateLayer> m_Concat; 41 // permutation for output (only used when input is batch major) 42 mutable std::unique_ptr<arm_compute::CLPermute> m_Permute2; 43 44 // 45 // ACL LSTM arm_compute::CLTensors. 46 // 47 std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor; 48 std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor; 49 std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor; 50 std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor; 51 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor; 52 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor; 53 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor; 54 std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor; 55 std::unique_ptr<arm_compute::CLTensor> m_CellToInputWeightsTensor; 56 std::unique_ptr<arm_compute::CLTensor> m_CellToForgetWeightsTensor; 57 std::unique_ptr<arm_compute::CLTensor> m_CellToOutputWeightsTensor; 58 std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor; 59 std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor; 60 std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor; 61 std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor; 62 std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor; 63 std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor; 64 65 std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer; 66 67 std::unique_ptr<arm_compute::CLTensor> m_InputLayerNormWeightsTensor; 68 std::unique_ptr<arm_compute::CLTensor> m_ForgetLayerNormWeightsTensor; 69 std::unique_ptr<arm_compute::CLTensor> m_CellLayerNormWeightsTensor; 70 std::unique_ptr<arm_compute::CLTensor> m_OutputLayerNormWeightsTensor; 71 72 // 73 // Additional ACL arm_compute::CLTensors and std::vector<arm_compute::CLTensor>. 74 // Required to perform splitting, concatenation and permutations. 75 // 76 arm_compute::CLTensor m_PermuteFirstOut; 77 std::vector<arm_compute::CLTensor> m_SplitterOutputsTensors; 78 std::vector<arm_compute::CLTensor> m_ConcatInputsTensors; 79 std::vector<arm_compute::ICLTensor*> m_SplitterOutputs; 80 std::vector<const arm_compute::ICLTensor*> m_ConcatInputs; 81 arm_compute::CLTensor concat_out; 82 83 void FreeUnusedTensors(); 84 }; 85 86 arm_compute::Status 87 ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, 88 const TensorInfo& outputStateIn, 89 const TensorInfo& cellStateIn, 90 const TensorInfo& output, 91 const Optional<TensorInfo>& hiddenStateOutput, 92 const Optional<TensorInfo>& cellStateOutput, 93 const UnidirectionalSequenceLstmDescriptor& descriptor, 94 const LstmInputParamsInfo& paramsInfo); 95 96 } //namespace armnn 97