xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/Broadcast.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
#pragma once

#include "BaseIterator.hpp"
#include <armnn/Tensor.hpp>

#include <functional>
#include <vector>
10 
11 namespace armnn
12 {
13 
14 struct BroadcastLoop
15 {
16     BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape);
17 
18     BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape);
19 
GetNumDimensionsarmnn::BroadcastLoop20     unsigned int GetNumDimensions()
21     {
22         return static_cast<unsigned int>(m_DimData.size());
23     }
24 
25     template <typename Func, typename DecoderOp, typename EncoderOp>
Unrollarmnn::BroadcastLoop26     void Unroll(Func operationFunc,
27                 unsigned int dimension,
28                 DecoderOp& inData0,
29                 DecoderOp& inData1,
30                 EncoderOp& outData)
31     {
32         if (dimension >= GetNumDimensions())
33         {
34             outData.Set(operationFunc(inData0.Get(), inData1.Get()));
35             return;
36         }
37 
38         unsigned int inData0Movement = 0;
39         unsigned int inData1Movement = 0;
40         unsigned int outDataMovement = 0;
41 
42         for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
43         {
44             Unroll(operationFunc, dimension + 1, inData0, inData1, outData);
45 
46             inData0 += m_DimData[dimension].m_Stride1;
47             inData1 += m_DimData[dimension].m_Stride2;
48             outData += m_DimData[dimension].m_StrideOut;
49 
50             inData0Movement += m_DimData[dimension].m_Stride1;
51             inData1Movement += m_DimData[dimension].m_Stride2;
52             outDataMovement += m_DimData[dimension].m_StrideOut;
53         }
54 
55         // move iterator back to the start
56         inData0 -= inData0Movement;
57         inData1 -= inData1Movement;
58         outData -= outDataMovement;
59     }
60 
61     template <typename Func, typename DecoderOp, typename EncoderOp>
Unrollarmnn::BroadcastLoop62     void Unroll(Func operationFunc,
63                 unsigned int dimension,
64                 DecoderOp& inData,
65                 EncoderOp& outData)
66     {
67         if (dimension >= GetNumDimensions())
68         {
69             outData.Set(operationFunc(inData.Get()));
70             return;
71         }
72 
73         unsigned int inDataMovement = 0;
74         unsigned int outDataMovement = 0;
75 
76         for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
77         {
78             Unroll(operationFunc, dimension + 1, inData, outData);
79 
80             inData += m_DimData[dimension].m_Stride1;
81             outData += m_DimData[dimension].m_StrideOut;
82 
83             inDataMovement += m_DimData[dimension].m_Stride1;
84             outDataMovement += m_DimData[dimension].m_StrideOut;
85         }
86 
87         // move iterator back to the start
88         inData -= inDataMovement;
89         outData -= outDataMovement;
90     }
91 
92 private:
93     // Struct to hold the dimension data.
94     struct BroadcastDimensionData
95     {
96         unsigned int m_DimSize;
97         unsigned int m_StrideOut;
98         unsigned int m_Stride1;
99         unsigned int m_Stride2;
100     };
101 
102     std::vector<BroadcastDimensionData> m_DimData;
103 };
104 
105 } //namespace armnn