1 // 2 // Copyright © 2019 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "BaseIterator.hpp" 7 #include <armnn/Tensor.hpp> 8 9 #include <functional> 10 11 namespace armnn 12 { 13 14 struct BroadcastLoop 15 { 16 BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape); 17 18 BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape); 19 GetNumDimensionsarmnn::BroadcastLoop20 unsigned int GetNumDimensions() 21 { 22 return static_cast<unsigned int>(m_DimData.size()); 23 } 24 25 template <typename Func, typename DecoderOp, typename EncoderOp> Unrollarmnn::BroadcastLoop26 void Unroll(Func operationFunc, 27 unsigned int dimension, 28 DecoderOp& inData0, 29 DecoderOp& inData1, 30 EncoderOp& outData) 31 { 32 if (dimension >= GetNumDimensions()) 33 { 34 outData.Set(operationFunc(inData0.Get(), inData1.Get())); 35 return; 36 } 37 38 unsigned int inData0Movement = 0; 39 unsigned int inData1Movement = 0; 40 unsigned int outDataMovement = 0; 41 42 for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++) 43 { 44 Unroll(operationFunc, dimension + 1, inData0, inData1, outData); 45 46 inData0 += m_DimData[dimension].m_Stride1; 47 inData1 += m_DimData[dimension].m_Stride2; 48 outData += m_DimData[dimension].m_StrideOut; 49 50 inData0Movement += m_DimData[dimension].m_Stride1; 51 inData1Movement += m_DimData[dimension].m_Stride2; 52 outDataMovement += m_DimData[dimension].m_StrideOut; 53 } 54 55 // move iterator back to the start 56 inData0 -= inData0Movement; 57 inData1 -= inData1Movement; 58 outData -= outDataMovement; 59 } 60 61 template <typename Func, typename DecoderOp, typename EncoderOp> Unrollarmnn::BroadcastLoop62 void Unroll(Func operationFunc, 63 unsigned int dimension, 64 DecoderOp& inData, 65 EncoderOp& outData) 66 { 67 if (dimension >= GetNumDimensions()) 68 { 69 outData.Set(operationFunc(inData.Get())); 70 return; 71 } 72 73 unsigned int inDataMovement = 0; 74 unsigned int outDataMovement = 0; 75 76 for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++) 77 { 78 Unroll(operationFunc, dimension + 1, inData, outData); 79 80 inData += m_DimData[dimension].m_Stride1; 81 outData += m_DimData[dimension].m_StrideOut; 82 83 inDataMovement += m_DimData[dimension].m_Stride1; 84 outDataMovement += m_DimData[dimension].m_StrideOut; 85 } 86 87 // move iterator back to the start 88 inData -= inDataMovement; 89 outData -= outDataMovement; 90 } 91 92 private: 93 // Struct to hold the dimension data. 94 struct BroadcastDimensionData 95 { 96 unsigned int m_DimSize; 97 unsigned int m_StrideOut; 98 unsigned int m_Stride1; 99 unsigned int m_Stride2; 100 }; 101 102 std::vector<BroadcastDimensionData> m_DimData; 103 }; 104 105 } //namespace armnn