xref: /aosp_15_r20/external/armnn/src/backends/neon/workloads/NeonBatchMatMulWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "NeonBatchMatMulWorkload.hpp"
7 
8 #include "NeonWorkloadUtils.hpp"
9 
10 #include <armnn/utility/PolymorphicDowncast.hpp>
11 
12 #include <armnnUtils/Permute.hpp>
13 
14 #include <backendsCommon/WorkloadUtils.hpp>
15 
16 #include <arm_compute/runtime/NEON/functions/NEGEMM.h>
17 
18 #include <arm_compute/runtime/NEON/functions/NEPermute.h>
19 
20 
21 namespace armnn
22 {
NeonBatchMatMulValidate(const TensorInfo & inputX,const TensorInfo & inputY,const TensorInfo & output,const BatchMatMulDescriptor & descriptor)23 arm_compute::Status NeonBatchMatMulValidate(const TensorInfo& inputX,
24                                             const TensorInfo& inputY,
25                                             const TensorInfo& output,
26                                             const BatchMatMulDescriptor& descriptor)
27 {
28     if (descriptor.m_AdjointX || descriptor.m_AdjointY )
29     {
30         throw Exception("Support for adjoint not implemented.");
31     }
32     if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW )
33     {
34         throw Exception("Only supported the MatMul in the last 2 dimensions");
35     }
36 
37     const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX);
38     const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY);
39     const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
40 
41     arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK);
42     arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
43     arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
44 
45     arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
46     arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
47 
48     if (descriptor.m_TransposeX == true)
49     {
50         auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputX.GetNumDimensions());
51         const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
52         const TensorInfo permutedXInfo = armnnUtils::Permuted(inputX, permutationXVector);
53         aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo);
54 
55         statusPermuteX = arm_compute::NEPermute::validate(&aclInputXInfo,
56                                                           &aclPermutedXInfo,
57                                                           aclPermutationXVector);
58     }
59 
60     if (descriptor.m_TransposeY == true)
61     {
62         auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputY.GetNumDimensions());
63         const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
64         const TensorInfo permutedYInfo = armnnUtils::Permuted(inputY, permutationYVector);
65         aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo);
66 
67         statusPermuteY = arm_compute::NEPermute::validate(&aclInputYInfo,
68                                                           &aclPermutedYInfo,
69                                                           aclPermutationYVector);
70     }
71 
72     const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false,  // is inputX reshaped
73                                                                    false,  // is inputY reshaped
74                                                                    false); // is inputY reshaped only 1st run
75 
76     statusGEMM = arm_compute::NEGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
77                                                descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
78                                                nullptr,
79                                                &aclOutputInfo,
80                                                1.0,
81                                                0,
82                                                gemm_info);
83 
84     if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
85         statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
86         statusGEMM.error_code()     == arm_compute::ErrorCode::OK)
87     {
88         return arm_compute::Status(arm_compute::ErrorCode::OK,
89                                    "All BatchMatMul layers validate status OK.");
90     }
91     else
92     {
93         return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
94                                    "BatchMatMul layer validate status failed."
95                                    + statusGEMM.error_description()
96                                    + statusPermuteX.error_description()
97                                    + statusPermuteY.error_description());
98     }
99 
100 }
101 
NeonBatchMatMulWorkload(const BatchMatMulQueueDescriptor & descriptor,const WorkloadInfo & info)102 NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(
103     const BatchMatMulQueueDescriptor& descriptor, const WorkloadInfo& info)
104     : NeonBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
105 {
106     if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY )
107     {
108         throw Exception("Support for adjoint not implemented.");
109     }
110     if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW ||
111         descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW )
112     {
113         throw Exception("Only supported the MatMul in the last 2 dimensions");
114     }
115 
116     // Report Profiling Details
117     ARMNN_REPORT_PROFILING_WORKLOAD_DESC("NeonBatchMatMulWorkload_Construct",
118                                          descriptor.m_Parameters,
119                                          info,
120                                          this->GetGuid());
121 
122     m_Data.ValidateInputsOutputs("NeonBatchMatMulWorkload", 2, 1);
123 
124     arm_compute::ITensor& inputX = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
125     arm_compute::ITensor& inputY = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
126     auto outputHandle = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Outputs[0]);
127     arm_compute::ITensor& output = outputHandle->GetTensor();
128 
129     arm_compute::DataLayout aclDataLayoutX = ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX);
130     arm_compute::DataLayout aclDataLayoutY = ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY);
131 
132     inputX.info()->set_data_layout(aclDataLayoutX);
133     inputY.info()->set_data_layout(aclDataLayoutY);
134 
135     if (descriptor.m_Parameters.m_TransposeX == true)
136     {
137         armnn::PermutationVector permutationXVector
138                 = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[0].GetNumDimensions());
139         const TensorInfo permutedXInfo = armnnUtils::Permuted(info.m_InputTensorInfos[0], permutationXVector);
140         const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
141 
142         auto permuteLayerX = std::make_unique<arm_compute::NEPermute>();
143         BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
144         InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
145         permuteLayerX->configure(&inputX, &m_PermutedTensorX, aclPermutationXVector);
146         m_PermuteLayerX.reset(permuteLayerX.release());
147     }
148 
149     if (descriptor.m_Parameters.m_TransposeY == true)
150     {
151         armnn::PermutationVector permutationYVector
152                 = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[1].GetNumDimensions());
153         const TensorInfo permutedYInfo = armnnUtils::Permuted(info.m_InputTensorInfos[1], permutationYVector);
154         const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
155 
156         auto permuteLayerY = std::make_unique<arm_compute::NEPermute>();
157         BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
158         InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
159         permuteLayerY->configure(&inputY, &m_PermutedTensorY, aclPermutationYVector);
160         m_PermuteLayerY.reset(permuteLayerY.release());
161     }
162 
163     const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false,  // is inputX reshaped
164                                                                    false,  // is inputY reshaped
165                                                                    false); // is inputY reshaped only 1st run
166     auto gemmLayer = std::make_unique<arm_compute::NEGEMM>();
167     gemmLayer->configure(descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX,
168                          descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY,
169                          nullptr,
170                          &output,
171                          1.0,
172                          0,
173                          gemm_info);
174     m_GEMMLayer.reset(gemmLayer.release());
175 }
176 
Execute() const177 void NeonBatchMatMulWorkload::Execute() const
178 {
179     ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonBatchMatMulWorkload_Execute", this->GetGuid());
180     if (m_PermuteLayerX)
181     {
182         m_PermuteLayerX->run();
183     }
184     if (m_PermuteLayerY)
185     {
186         m_PermuteLayerY->run();
187     }
188     m_GEMMLayer->run();
189 }
190 } //namespace armnn
191