1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Tensor.hpp> 9 10 #include "../ConversionUtils.hpp" 11 12 namespace armnn_driver 13 { 14 FlattenFullyConnectedInput(const armnn::TensorShape & inputShape,const armnn::TensorShape & weightsShape)15inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape& inputShape, 16 const armnn::TensorShape& weightsShape) 17 { 18 if (inputShape.GetNumDimensions() > 2U) 19 { 20 unsigned int totalInputElements = inputShape.GetNumElements(); 21 unsigned int inputSize = weightsShape[1]; 22 23 unsigned int batchSize = totalInputElements / inputSize; 24 25 if(totalInputElements % batchSize != 0) 26 { 27 throw std::runtime_error("Failed to deduce tensor shape"); 28 } 29 30 return armnn::TensorShape({batchSize, inputSize}); 31 } 32 else 33 { 34 return inputShape; 35 } 36 } 37 VerifyFullyConnectedShapes(const armnn::TensorShape & inputShape,const armnn::TensorShape & weightsShape,const armnn::TensorShape & outputShape,bool transposeWeightMatrix)38inline bool VerifyFullyConnectedShapes(const armnn::TensorShape& inputShape, 39 const armnn::TensorShape& weightsShape, 40 const armnn::TensorShape& outputShape, 41 bool transposeWeightMatrix) 42 { 43 unsigned int dimIdx = transposeWeightMatrix ? 0 : 1; 44 return (inputShape[0] == outputShape[0] && weightsShape[dimIdx] == outputShape[1]); 45 } 46 47 }