1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/TypesUtils.hpp> 9 #include <armnn/backends/WorkloadData.hpp> 10 11 #include "Encoders.hpp" 12 #include "Decoders.hpp" 13 14 namespace armnn 15 { 16 17 void LstmImpl(const LstmDescriptor& descriptor, 18 const TensorInfo& inputInfo, 19 const TensorInfo& outputInfo, 20 const TensorShape& inputToOutputWeightsShape, 21 const TensorShape& recurrentToOutputWeightsShape, 22 std::unique_ptr<Decoder<float>>& inputData, 23 std::unique_ptr<Decoder<float>>& outputStateIn, 24 std::unique_ptr<Decoder<float>>& cellStateIn, 25 std::unique_ptr<Encoder<float>>& outputStateOut, 26 std::unique_ptr<Encoder<float>>& cellStateOut, 27 std::unique_ptr<Encoder<float>>& output, 28 std::unique_ptr<Decoder<float>>& cellStateOutDecoder, 29 std::unique_ptr<Decoder<float>>& outputDecoder, 30 std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor, 31 std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor, 32 std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor, 33 std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor, 34 std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor, 35 std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor, 36 std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor, 37 std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor, 38 std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor, 39 std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor, 40 std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor, 41 std::unique_ptr<Decoder<float>>& inputGateBiasTensor, 42 std::unique_ptr<Decoder<float>>& forgetGateBiasTensor, 43 std::unique_ptr<Decoder<float>>& cellBiasTensor, 44 std::unique_ptr<Decoder<float>>& outputGateBiasTensor, 45 std::unique_ptr<Decoder<float>>& projectionWeightsTensor, 46 std::unique_ptr<Decoder<float>>& projectionBiasTensor, 47 std::unique_ptr<Decoder<float>>& inputLayerNormWeights, 48 std::unique_ptr<Decoder<float>>& forgetLayerNormWeights, 49 std::unique_ptr<Decoder<float>>& cellLayerNormWeights, 50 std::unique_ptr<Decoder<float>>& outputLayerNormWeights, 51 std::unique_ptr<Encoder<float>>& inputGateScratch, 52 std::unique_ptr<Encoder<float>>& cellScratch, 53 std::unique_ptr<Encoder<float>>& forgetGateScratch, 54 std::unique_ptr<Encoder<float>>& outputGateScratch, 55 std::unique_ptr<Decoder<float>>& inputGateScratchDecoder, 56 std::unique_ptr<Decoder<float>>& cellScratchDecoder, 57 std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder, 58 std::unique_ptr<Decoder<float>>& outputGateScratchDecoder, 59 float layerNormEpsilon); 60 61 } //namespace armnn 62