xref: /aosp_15_r20/external/armnn/src/armnn/layers/BatchMatMulLayer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "BatchMatMulLayer.hpp"
6 
7 #include <armnn/backends/WorkloadFactory.hpp>
8 #include <armnnUtils/Permute.hpp>
9 #include "layers/LayerCloneBase.hpp"
10 
11 namespace armnn
12 {
13 
/// Constructs a BatchMatMul layer with two input slots (X and Y operands) and one output slot.
/// @param param Descriptor carrying the per-input transpose/adjoint flags and data layouts.
/// @param name Optional name of the layer.
BatchMatMulLayer::BatchMatMulLayer(const BatchMatMulDescriptor& param, const char* name)
    : LayerWithParameters(2, 1, LayerType::BatchMatMul, param, name)
{}
17 
CreateWorkload(const IWorkloadFactory & factory) const18 std::unique_ptr<IWorkload> BatchMatMulLayer::CreateWorkload(const IWorkloadFactory& factory) const
19 {
20     BatchMatMulQueueDescriptor descriptor;
21     SetAdditionalInfo(descriptor);
22 
23     return factory.CreateWorkload(LayerType::BatchMatMul, descriptor, PrepInfoAndDesc(descriptor));
24 }
25 
Clone(Graph & graph) const26 BatchMatMulLayer* BatchMatMulLayer::Clone(Graph& graph) const
27 {
28     auto layer = CloneBase<BatchMatMulLayer>(graph, m_Param, GetName());
29 
30     return std::move(layer);
31 }
32 
/// Infers the output shape of the batched matrix multiplication from the two
/// input shapes, applying any configured transpose/adjoint permutation first
/// and then broadcasting the non-multiplied (batch) dimensions.
/// @param inputShapes Exactly two shapes: input X followed by input Y.
/// @return A single-element vector holding the inferred output shape.
std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
    ARMNN_ASSERT(inputShapes.size() == 2);

    TensorShape inputXShape = inputShapes[0];
    TensorShape inputYShape = inputShapes[1];

    // Adjoint is assumed to be square, but we will apply the permute anyway
    if(m_Param.m_TransposeX || m_Param.m_AdjointX)
    {
        auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX,
                                                               inputXShape);
        inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
    }
    if(m_Param.m_TransposeY || m_Param.m_AdjointY)
    {
        auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY,
                                                               inputYShape);
        inputYShape = armnnUtils::Permuted(inputYShape, permuteVec);
    }

    // References (not copies) into the two shapes above: the address comparisons
    // below (&shorterInput == &inputXShape) rely on this aliasing. X wins ties.
    TensorShape& longerInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
                               inputXShape : inputYShape;
    TensorShape& shorterInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
                                inputYShape : inputXShape;

    // How many leading dimensions the longer input has that the shorter one lacks;
    // the shorter shape is right-aligned against the longer one for broadcasting.
    unsigned int inputNumDimsOffset = longerInput.GetNumDimensions() - shorterInput.GetNumDimensions();

    // The output rank always matches the higher-rank input.
    unsigned int outputNumDimensions = longerInput.GetNumDimensions();

    std::vector<unsigned int> tensorDimensions(outputNumDimensions, 0);

    // The matmul row/column axes are located using the data layout of whichever
    // input is longer (ties resolved to X, consistent with longerInput above).
    const auto& longerInputDataLayout = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
                                        m_Param.m_DataLayoutX : m_Param.m_DataLayoutY;
    auto longerAxesToMul = BatchMatMulDescriptor::GetAxesToMul(longerInputDataLayout,
                                                               longerInput);

    for (unsigned int i = 0; i < outputNumDimensions; ++i)
    {
        if (i == longerAxesToMul.first)
        {
            // Row count of the product comes from X; shift the index if X is the shorter input.
            tensorDimensions[i] = &shorterInput == &inputXShape ? inputXShape[i - inputNumDimsOffset] : inputXShape[i];
        }
        else if(i == longerAxesToMul.second)
        {
            // Column count of the product comes from Y; shift the index if Y is the shorter input.
            tensorDimensions[i] = &shorterInput == &inputYShape ? inputYShape[i - inputNumDimsOffset] : inputYShape[i];
        }
        else // The other dimensions not to be multiplied (but may be broadcasted)
        {
            // Does NOT validate whether it's a valid broadcast - that's done in the validate func in WorkloadData.cpp
            // Leading dims with no counterpart in the shorter input are taken from
            // the longer input; overlapping dims take the broadcast (max) extent.
            tensorDimensions[i] = static_cast<int>(i) - static_cast<int>(inputNumDimsOffset) < 0 ?
                longerInput[i] :
                std::max(longerInput[i], shorterInput[i - inputNumDimsOffset]);
        }
    }

    auto outputShape = TensorShape(outputNumDimensions, tensorDimensions.data());
    return std::vector<TensorShape>({ outputShape });
}
92 
ValidateTensorShapesFromInputs()93 void BatchMatMulLayer::ValidateTensorShapesFromInputs()
94 {
95     VerifyLayerConnections(2, CHECK_LOCATION());
96 
97     const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
98 
99     VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
100 
101     auto inferredShapes = InferOutputShapes({
102         GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
103         GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
104 
105     ARMNN_ASSERT(inferredShapes.size() == 1);
106 
107     ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "BatchMatMulLayer");
108 }
109 
110 } // namespace armnn