xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/Conv3dImpl.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Conv3dImpl.hpp"
7 
8 namespace armnn
9 {
10 
Convolve3d(const TensorShape & rInputShape,Decoder<float> & rInputDecoder,const TensorShape & rOutputShape,Encoder<float> & rOutputEncoder,const TensorShape & rFilterShape,Decoder<float> & rFilterDecoder,bool biasEnabled,Decoder<float> * pBiasDecoder,DataLayout dataLayout,unsigned int paddingTop,unsigned int paddingLeft,unsigned int paddingFront,unsigned int xStride,unsigned int yStride,unsigned int zStride,unsigned int xDilation,unsigned int yDilation,unsigned int zDilation)11 void Convolve3d(const TensorShape& rInputShape,
12                 Decoder<float>& rInputDecoder,
13                 const TensorShape& rOutputShape,
14                 Encoder<float>& rOutputEncoder,
15                 const TensorShape& rFilterShape,
16                 Decoder<float>& rFilterDecoder,
17                 bool biasEnabled,
18                 Decoder<float>* pBiasDecoder,
19                 DataLayout dataLayout,
20                 unsigned int paddingTop,
21                 unsigned int paddingLeft,
22                 unsigned int paddingFront,
23                 unsigned int xStride,
24                 unsigned int yStride,
25                 unsigned int zStride,
26                 unsigned int xDilation,
27                 unsigned int yDilation,
28                 unsigned int zDilation)
29 {
30     if (biasEnabled && !pBiasDecoder)
31     {
32         throw InvalidArgumentException("Bias is enabled but the bias data is invalid");
33     }
34     const armnnUtils::DataLayoutIndexed dataLayoutIndexed(dataLayout);
35 
36     const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
37     const unsigned int heightIndex   = dataLayoutIndexed.GetHeightIndex();
38     const unsigned int widthIndex    = dataLayoutIndexed.GetWidthIndex();
39     const unsigned int depthIndex    = dataLayoutIndexed.GetDepthIndex();
40 
41     const unsigned int inChannels   = rInputShape[channelsIndex];
42     const unsigned int outChannels  = rOutputShape[channelsIndex];
43 
44     const unsigned int batchSize    = rOutputShape[0];
45     const unsigned int outputHeight = rOutputShape[heightIndex];
46     const unsigned int outputWidth  = rOutputShape[widthIndex];
47     const unsigned int outputDepth  = rOutputShape[depthIndex];
48     const unsigned int inputHeight  = rInputShape[heightIndex];
49     const unsigned int inputWidth   = rInputShape[widthIndex];
50     const unsigned int inputDepth   = rInputShape[depthIndex];
51 
52     // Conv3d weights layout: [D,H,W,I,O]
53     const unsigned int filterDepth  = rFilterShape[0];
54     const unsigned int filterHeight = rFilterShape[1];
55     const unsigned int filterWidth  = rFilterShape[2];
56 
57     const std::vector<float> inputVec = rInputDecoder.DecodeTensor(rInputShape);
58     const std::vector<float> filterVec = rFilterDecoder.DecodeTensor(rFilterShape);
59 
60     const TensorShape biasShape{outChannels};
61     const std::vector<float> biasVec = biasEnabled ? pBiasDecoder->DecodeTensor(biasShape) : std::vector<float>();
62 
63     for (unsigned int batchIdx = 0; batchIdx < batchSize; batchIdx++)
64     {
65         for (unsigned int zOutput = 0; zOutput < outputDepth; zOutput++)
66         {
67             for (unsigned int xOutput = 0; xOutput < outputWidth; xOutput++)
68             {
69                 for (unsigned int yOutput = 0; yOutput < outputHeight; yOutput++)
70                 {
71                     for (unsigned int cOutput = 0; cOutput < outChannels; cOutput++)
72                     {
73                         // This loop goes over each output element.
74                         float sum = 0.0f;
75 
76                         // Loop over each input channel.
77                         for (unsigned int zFilter = 0; zFilter < filterDepth; zFilter++)
78                         {
79                             for (unsigned int yFilter = 0; yFilter < filterHeight; yFilter++)
80                             {
81                                 for (unsigned int xFilter = 0; xFilter < filterWidth; xFilter++)
82                                 {
83                                     for (unsigned int cInput = 0; cInput < inChannels; cInput++)
84                                     {
85                                         // This loop goes over each input element for each output element.
86                                         unsigned int filterIndex = 0;
87 
88                                         // Conv3d weights layout: [D,H,W,I,O]
89                                         // Keep this implementation, as using DataLayoutIndexed::GetIndex
90                                         // causes large performance regression.
91                                         filterIndex = zFilter * filterHeight * filterWidth * inChannels * outChannels +
92                                                       yFilter * filterWidth * inChannels * outChannels +
93                                                       xFilter * inChannels * outChannels +
94                                                       cInput * outChannels +
95                                                       cOutput;
96 
97                                         unsigned int yInput = yOutput * yStride + yFilter * yDilation;
98                                         unsigned int xInput = xOutput * xStride + xFilter * xDilation;
99                                         unsigned int zInput = zOutput * zStride + zFilter * zDilation;
100 
101                                         float inputValue;
102 
103                                         // Check if we're in the padding.
104                                         if (yInput < paddingTop || yInput >= inputHeight + paddingTop ||
105                                             xInput < paddingLeft || xInput >= inputWidth + paddingLeft ||
106                                             zInput < paddingFront || zInput >= inputDepth + paddingFront)
107                                         {
108                                             inputValue = 0.0f;
109                                         }
110                                         else
111                                         {
112                                             unsigned int inputIndex = 0;
113 
114                                             // Keep this implementation, as using DataLayoutIndexed::GetIndex
115                                             // causes large performance regression.
116                                             if (dataLayoutIndexed.GetDataLayout() == DataLayout::NDHWC)
117                                             {
118                                                 inputIndex =
119                                                         batchIdx * inputDepth * inputHeight * inputWidth * inChannels +
120                                                         (zInput-paddingFront) * inputHeight * inputWidth * inChannels +
121                                                         (yInput-paddingTop) * inputWidth * inChannels +
122                                                         (xInput-paddingLeft) * inChannels +
123                                                         cInput;
124                                             }
125                                             else
126                                             {
127                                                 // NCDHW DataLayout
128                                                 inputIndex =
129                                                         batchIdx * inputDepth * inputHeight * inputWidth * inChannels +
130                                                         inputDepth * inputHeight * inputWidth * cInput +
131                                                         (zInput-paddingFront) * inputHeight * inputWidth +
132                                                         (yInput-paddingTop) * inputWidth +
133                                                         xInput-paddingLeft;
134                                             }
135 
136                                             inputValue = inputVec[inputIndex];
137                                         }
138 
139                                         sum += filterVec[filterIndex] * inputValue;
140                                     }
141                                 }
142                             }
143                         }
144 
145                         if (biasEnabled)
146                         {
147                             sum += biasVec[cOutput];
148                         }
149 
150                         unsigned int outIdx;
151                         if (dataLayoutIndexed.GetDataLayout() == DataLayout::NDHWC)
152                         {
153                             outIdx = batchIdx * outputDepth * outputHeight * outputWidth * outChannels +
154                                      zOutput * outputHeight * outputWidth * outChannels +
155                                      yOutput * outputWidth * outChannels +
156                                      xOutput * outChannels +
157                                      cOutput;
158                         }
159                         else
160                         {
161                             // NCDHW DataLayout
162                             outIdx = batchIdx * outputDepth * outputHeight * outputWidth * outChannels +
163                                      cOutput * outputDepth * outputHeight * outputWidth +
164                                      zOutput * outputHeight * outputWidth +
165                                      yOutput * outputWidth +
166                                      xOutput;
167                         }
168 
169                         rOutputEncoder[outIdx];
170                         rOutputEncoder.Set(sum);
171                     }
172                 }
173             }
174         }
175     }
176 }
177 
178 } // namespace armnn
179