xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/BatchMatMulImpl.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "BatchMatMulImpl.hpp"
7 
8 #include <armnn/backends/WorkloadData.hpp>
9 #include <armnn/Logging.hpp>
10 #include <armnnUtils/Permute.hpp>
11 
12 namespace armnn
13 {
14 
BatchMatMul(const BatchMatMulDescriptor & params,const TensorInfo & inputXInfo,const TensorInfo & inputYInfo,const TensorInfo & outputInfo,Decoder<float> & inputXDecoder,Decoder<float> & inputYDecoder,Encoder<float> & outputEncoder)15 BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params,
16                          const TensorInfo& inputXInfo,
17                          const TensorInfo& inputYInfo,
18                          const TensorInfo& outputInfo,
19                          Decoder<float>& inputXDecoder,
20                          Decoder<float>& inputYDecoder,
21                          Encoder<float>& outputEncoder)
22     : params(params),
23       inputXInfo(inputXInfo),
24       inputYInfo(inputYInfo),
25       outputInfo(outputInfo),
26       inputXDecoder(inputXDecoder),
27       inputYDecoder(inputYDecoder),
28       outputEncoder(outputEncoder)
29 {
30     inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
31     inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
32     // At this point, we don't touch the input decoders - just the resultant vectors
33 
34     ApplyParams();
35 
36     ApplyBatchMatMul();
37 }
38 
ApplyBatchMatMul()39 void BatchMatMul::ApplyBatchMatMul()
40 {
41     auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX,
42                                                           inputXInfo.GetShape());
43     auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
44                                                           inputYInfo.GetShape());
45     AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
46 
47     unsigned int inputXColDim = axesXToMul.second;
48     unsigned int inputYRowDim = axesYToMul.first;
49 
50     unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
51 
52     auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
53     {
54         float sum = 0.0f;
55 
56         // InputYRowSize is synonymous with inputXColSize
57         for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
58             auto xIdx = curIdx;
59             xIdx[inputXColDim] = inputYRowIdx;
60 
61             auto yIdx = curIdx;
62             yIdx[inputYRowDim] = inputYRowIdx;
63 
64             sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
65         }
66 
67         SetValueAt(sum, DataSlot::Output, curIdx);
68     };
69 
70     auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
71     RecurseTensor(outputInfo,
72                   batchMatMulOperation,
73                   startIdx,
74                   0);
75 }
76 
ApplyParams()77 void BatchMatMul::ApplyParams()
78 {
79     if(params.m_TransposeX)
80     {
81         Transpose(DataSlot::InputX);
82     }
83     else if(params.m_AdjointX)
84     {
85         Adjoint(DataSlot::InputX);
86     }
87     if(params.m_TransposeY)
88     {
89         Transpose(DataSlot::InputY);
90     }
91     else if(params.m_AdjointY)
92     {
93         Adjoint(DataSlot::InputY);
94     }
95 }
96 
Transpose(DataSlot type)97 void BatchMatMul::Transpose(DataSlot type)
98 {
99     // AKA the permute of the tensor
100     // This modifies the tensor's info.
101 
102     switch(type)
103     {
104         case DataSlot::InputX:
105         {
106             auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
107                                                                    inputXInfo.GetShape());
108             inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
109             std::vector<float> temp(inputXData.size());
110             armnnUtils::Permute(inputXInfo.GetShape(),
111                                 permuteVec,
112                                 inputXData.data(),
113                                 temp.data(),
114                                 sizeof(float));
115             inputXData = temp;
116             break;
117         }
118         case DataSlot::InputY:
119         {
120             auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
121                                                                    inputYInfo.GetShape());
122             inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
123             std::vector<float> temp(inputYData.size());
124             armnnUtils::Permute(inputYInfo.GetShape(),
125                                 permuteVec,
126                                 inputYData.data(),
127                                 temp.data(),
128                                 sizeof(float));
129             inputYData = temp;
130             break;
131         }
132         case DataSlot::Output: // We needn't transpose the output tensor
133         default:
134             break;
135     }
136 }
137 
Adjoint(DataSlot type)138 void BatchMatMul::Adjoint(DataSlot type)
139 {
140     // Finding the adjoint of a square matrix:
141     // Calculate the cofactor of each element (using Gauss elimination here)
142     // Apply a transpose to it (this also modifies the tensor's info)
143 
144     TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
145     const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
146     const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout,inputInfo.GetShape());
147 
148     ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
149     // We grab a copy of the tensor data to prevent overwriting
150     std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
151 
152     // The sub-matrix is the resultant matrix when the row and column of the current index is removed
153     unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
154     std::vector<std::vector<float>> subMat(subMatAxisSize,
155                                            std::vector<float>(subMatAxisSize));
156 
157     // Lambdas for each sub-step of the cofactor operation
158     auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
159     {
160         float diff = std::fabs(a-b);
161         float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
162         return (diff <= bound) || (diff < std::numeric_limits<float>::min());
163     };
164 
165     float swapMultiplier = std::numeric_limits<float>::max();
166     auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
167     {
168         // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run)
169         for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
170         {
171             float tmp = subMat[rowIdxA][colIdx];
172             subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
173             subMat[rowIdxB][colIdx] = tmp;
174         }
175         swapMultiplier *= -1.0f;
176     };
177 
178     auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
179     {
180         unsigned int result = std::numeric_limits<unsigned int>::max();
181 
182         // The original diagonal has been checked and is invalid
183         for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
184         {
185             if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
186             {
187                 result = rowIdx;
188                 break;
189             }
190         }
191         return result;
192     };
193 
194     auto eliminate = [&](const float& pivot, unsigned int pivotPos)
195     {
196         for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
197         {
198             float multiplierNumerator = subMat[rowIdx][pivotPos];
199             if(almostEquals(multiplierNumerator, 0.0f))
200             {
201                 continue;
202             }
203             float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
204                                                             // Hence the almostEquals usage to counteract this
205             for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
206             {
207                 // We start at col=pivotPos as we have assumed that all elements
208                 // to our left have been eliminated to zero already
209 
210                 // We subtract based on the element directly above us in our pivot row
211                 subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
212             }
213         }
214     };
215 
216     auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
217     {
218         auto row = curIdx[axesToAdjoint.first];
219         auto col = curIdx[axesToAdjoint.second];
220 
221         float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
222 
223         for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
224         {
225             for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
226             {
227                 unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
228                 unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
229                 auto cloneIdx = curIdx;
230                 cloneIdx[axesToAdjoint.first] = outerRow;
231                 cloneIdx[axesToAdjoint.second] = outerCol;
232                 subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
233             }
234         }
235 
236         float determinant = 1.0f;
237 
238         // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
239         switch(subMatAxisSize)
240         {
241             case 0:
242             {
243                 determinant = GetValueAt(type, curIdx, inputDataClone);
244                 break;
245             }
246             case 1:
247             {
248                 // If the resultant sub-matrix is just one element - that's the determinant
249                 determinant = subMat[0][0];
250                 break;
251             }
252             case 2:
253             {
254                 // For a 2x2 sub-matrix, the determinant is just a*d-b*c
255                 determinant = subMat[0][0] * subMat[1][1] -
256                               subMat[0][1] * subMat[1][0];
257                 break;
258             }
259             default:
260             {
261                 // Gaussian elimination to find the determinant of this sub-matrix
262                 swapMultiplier = 1.0f;
263                 // March diagonally down the pivots and if it's invalid (a zero), swap the row with the
264                 // nearest non-zero down within the column
265                 for(unsigned int pivotRow = 0, pivotCol = 0;
266                     pivotRow < subMatAxisSize;
267                     pivotRow++, pivotCol++)
268                 {
269                     float& pivot = subMat[pivotRow][pivotCol];
270 
271                     if(almostEquals(pivot, 0.0f))
272                     {
273                         unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
274                         if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
275                         {
276                             // No valid pivot down this column, which means that this pivot remains a zero.
277                             // This results in the determinant for this entire sub-matrix to just be zero.
278                             determinant = 0.0f;
279                             break;
280                         }
281                         swapRows(pivotRow, nextValidPivotRowIdx);
282                     }
283                     determinant *= pivot;
284                     // The actual elimination bit (which will update/propagate to the pivots down the line)
285                     eliminate(pivot, pivotRow); // Synonymous with pivotCol
286                 }
287 
288                 determinant *= swapMultiplier;
289                 break;
290             }
291         }
292         float cofactor = minorMultiplier * determinant;
293         SetValueAt(cofactor, type, curIdx);
294     };
295 
296     auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
297     RecurseTensor(inputInfo,
298                   cofactorOperation,
299                   startIdx,
300                   0);
301 
302     Transpose(type);
303 }
304 
RecurseTensor(const TensorInfo & tensorInfo,const std::function<void (const std::vector<unsigned int> &)> & operation,std::vector<unsigned int> & curIdx,unsigned int curDim)305 void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
306                                 const std::function<void(const std::vector<unsigned int>&)>& operation,
307                                 std::vector<unsigned int>& curIdx,
308                                 unsigned int curDim)
309 {
310     if(!(curDim < tensorInfo.GetNumDimensions()))
311     {
312         // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
313         operation(curIdx);
314         return;
315     }
316 
317     for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
318     {
319         curIdx[curDim] = i;
320         RecurseTensor(tensorInfo,
321                       operation,
322                       curIdx,
323                       curDim + 1);
324     }
325 }
326 
AdjustAxesToMulForUnequalRanks(std::pair<unsigned int,unsigned int> & axesXToMul,std::pair<unsigned int,unsigned int> & axesYToMul)327 void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
328                                                  std::pair<unsigned int, unsigned int>& axesYToMul)
329 {
330     int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
331                    static_cast<int>(inputYInfo.GetNumDimensions());
332     if(rankDiff == 0)
333     {
334         return;
335     }
336     else if(rankDiff < 0)
337     {
338         // Y is the larger one
339         axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
340         axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
341     }
342     else if(rankDiff > 0)
343     {
344         // X is the larger one
345         axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
346         axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
347     }
348 }
349 
GetValueAt(DataSlot type,std::vector<unsigned int> idx,const std::vector<float> & customData)350 float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
351 {
352     // This gets the data from the input vector that we have, Not the decoder
353     // But for the output, it is operating on the encoder itself
354 
355     AdjustToSafeIdx(type, idx);
356     unsigned int flatIdx = CalcFlatIdx(type, idx);
357     float value = 0.0f;
358     switch(type)
359     {
360         case DataSlot::InputX:
361             value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
362             break;
363         case DataSlot::InputY:
364             value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
365             break;
366         case DataSlot::Output:
367             outputEncoder[flatIdx];
368             value = outputEncoder.Get();
369             break;
370         default:
371             break;
372     }
373 
374     return value;
375 }
376 
SetValueAt(float value,DataSlot type,std::vector<unsigned int> idx)377 void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
378 {
379     AdjustToSafeIdx(type, idx);
380     unsigned int flatIdx = CalcFlatIdx(type, idx);
381     switch(type)
382     {
383         case DataSlot::InputX:
384             inputXData[flatIdx] = value;
385             break;
386         case DataSlot::InputY:
387             inputYData[flatIdx] = value;
388             break;
389         case DataSlot::Output:
390             outputEncoder[flatIdx];
391             outputEncoder.Set(value);
392             break;
393         default:
394             break;
395     }
396 }
397 
AdjustToSafeIdx(DataSlot type,std::vector<unsigned int> & idx)398 void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
399 {
400     for(unsigned int dim = 0; dim < idx.size(); dim++)
401     {
402         switch(type)
403         {
404             case DataSlot::InputX:
405             {
406                 auto xRank = inputXInfo.GetNumDimensions();
407                 auto xDiff = outputInfo.GetNumDimensions() - xRank;
408                 if (dim < xDiff ||
409                     idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1)
410                 {
411                     idx[dim] = 0; // Broadcasting
412                 }
413                 break;
414             }
415             case DataSlot::InputY:
416             {
417                 auto yRank = inputYInfo.GetNumDimensions();
418                 auto yDiff = outputInfo.GetNumDimensions() - yRank;
419                 if (dim < yDiff ||
420                     idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1)
421                 {
422                     idx[dim] = 0;
423                 }
424                 break;
425             }
426             case DataSlot::Output:
427             {
428                 // Our indices are based off the output
429                 break;
430             }
431             default:
432                 break;
433         }
434     }
435 }
436 
CalcFlatIdx(DataSlot type,const std::vector<unsigned int> & idx)437 unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
438 {
439     unsigned int result = idx[idx.size()-1];
440     unsigned int dimMultiplier = 1;
441     unsigned int offset;
442 
443     // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
444     for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--)
445     {
446         switch(type)
447         {
448             case DataSlot::InputX:
449                 offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
450                 dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
451                 break;
452             case DataSlot::InputY:
453                 offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
454                 dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
455                 break;
456             case DataSlot::Output:
457                 dimMultiplier *= outputInfo.GetShape()[i+1];
458                 break;
459             default:
460                 break;
461         }
462         result += (idx[i] * dimMultiplier);
463     }
464     return result;
465 }
466 
467 } // namespace armnn