1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
#include "BatchMatMulLayer.hpp"

#include <armnn/backends/WorkloadFactory.hpp>
#include <armnnUtils/Permute.hpp>
#include "layers/LayerCloneBase.hpp"

#include <algorithm>
#include <string>
10
11 namespace armnn
12 {
13
BatchMatMulLayer(const BatchMatMulDescriptor & param,const char * name)14 BatchMatMulLayer::BatchMatMulLayer(const BatchMatMulDescriptor& param, const char* name)
15 : LayerWithParameters(2, 1, LayerType::BatchMatMul, param, name)
16 {}
17
CreateWorkload(const IWorkloadFactory & factory) const18 std::unique_ptr<IWorkload> BatchMatMulLayer::CreateWorkload(const IWorkloadFactory& factory) const
19 {
20 BatchMatMulQueueDescriptor descriptor;
21 SetAdditionalInfo(descriptor);
22
23 return factory.CreateWorkload(LayerType::BatchMatMul, descriptor, PrepInfoAndDesc(descriptor));
24 }
25
Clone(Graph & graph) const26 BatchMatMulLayer* BatchMatMulLayer::Clone(Graph& graph) const
27 {
28 auto layer = CloneBase<BatchMatMulLayer>(graph, m_Param, GetName());
29
30 return std::move(layer);
31 }
32
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const33 std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
34 {
35 ARMNN_ASSERT(inputShapes.size() == 2);
36
37 TensorShape inputXShape = inputShapes[0];
38 TensorShape inputYShape = inputShapes[1];
39
40 // Adjoint is assumed to be square, but we will apply the permute anyway
41 if(m_Param.m_TransposeX || m_Param.m_AdjointX)
42 {
43 auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX,
44 inputXShape);
45 inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
46 }
47 if(m_Param.m_TransposeY || m_Param.m_AdjointY)
48 {
49 auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY,
50 inputYShape);
51 inputYShape = armnnUtils::Permuted(inputYShape, permuteVec);
52 }
53
54 TensorShape& longerInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
55 inputXShape : inputYShape;
56 TensorShape& shorterInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
57 inputYShape : inputXShape;
58
59 unsigned int inputNumDimsOffset = longerInput.GetNumDimensions() - shorterInput.GetNumDimensions();
60
61 unsigned int outputNumDimensions = longerInput.GetNumDimensions();
62
63 std::vector<unsigned int> tensorDimensions(outputNumDimensions, 0);
64
65 const auto& longerInputDataLayout = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
66 m_Param.m_DataLayoutX : m_Param.m_DataLayoutY;
67 auto longerAxesToMul = BatchMatMulDescriptor::GetAxesToMul(longerInputDataLayout,
68 longerInput);
69
70 for (unsigned int i = 0; i < outputNumDimensions; ++i)
71 {
72 if (i == longerAxesToMul.first)
73 {
74 tensorDimensions[i] = &shorterInput == &inputXShape ? inputXShape[i - inputNumDimsOffset] : inputXShape[i];
75 }
76 else if(i == longerAxesToMul.second)
77 {
78 tensorDimensions[i] = &shorterInput == &inputYShape ? inputYShape[i - inputNumDimsOffset] : inputYShape[i];
79 }
80 else // The other dimensions not to be multiplied (but may be broadcasted)
81 {
82 // Does NOT validate whether it's a valid broadcast - that's done in the validate func in WorkloadData.cpp
83 tensorDimensions[i] = static_cast<int>(i) - static_cast<int>(inputNumDimsOffset) < 0 ?
84 longerInput[i] :
85 std::max(longerInput[i], shorterInput[i - inputNumDimsOffset]);
86 }
87 }
88
89 auto outputShape = TensorShape(outputNumDimensions, tensorDimensions.data());
90 return std::vector<TensorShape>({ outputShape });
91 }
92
ValidateTensorShapesFromInputs()93 void BatchMatMulLayer::ValidateTensorShapesFromInputs()
94 {
95 VerifyLayerConnections(2, CHECK_LOCATION());
96
97 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
98
99 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
100
101 auto inferredShapes = InferOutputShapes({
102 GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
103 GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
104
105 ARMNN_ASSERT(inferredShapes.size() == 1);
106
107 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "BatchMatMulLayer");
108 }
109
110 } // namespace armnn