//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClBatchMatMulWorkload.hpp"

#include "ClWorkloadUtils.hpp"

#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <aclCommon/ArmComputeUtils.hpp>

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <armnnUtils/Permute.hpp>
#include <armnnUtils/TensorUtils.hpp>

#include <backendsCommon/WorkloadUtils.hpp>

#include <cl/ClTensorHandle.hpp>

#include <arm_compute/runtime/CL/functions/CLGEMM.h>
#include <arm_compute/runtime/CL/functions/CLPermute.h>


namespace armnn
{

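// Checks whether this batch MatMul can run on the CL backend: adjoint inputs are rejected,
// only MatMul over the last two dimensions is accepted, and each stage the workload will build
// (an optional CLPermute per transposed input, followed by CLGEMM) is validated against
// Compute Library.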
arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
                                          const TensorInfo& inputY,
                                          const TensorInfo& output,
                                          const BatchMatMulDescriptor& descriptor)
{
    if (descriptor.m_AdjointX || descriptor.m_AdjointY)
    {
        throw Exception("Support for adjoint not implemented.");
    }
    if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW)
    {
        throw Exception("Only MatMul in the last 2 dimensions is supported.");
    }

    arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK);
    arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
    arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);

    // ClGemmMatrixMultiplyNativeKernel used by CLGEMM can only support 3 dimensional
    // tensors so try to reduce the dimensions to 3
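    // (For example, an input shaped [1, B, M, K] is expected to reach ACL as a 3-D [B, M, K]
    // tensor; the exact folding is delegated to ReduceDims / BuildArmComputeTensorInfo, so the
    // shapes quoted here are illustrative rather than guaranteed.)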
    const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX, 3);
    const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY, 3);
    const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayoutY, 3);

    arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
    arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();

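    // A requested transpose is implemented as a CLPermute that swaps the last two (matrix)
    // dimensions, so validate that permute on the reduced 3-D shape before validating the GEMM.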
    if (descriptor.m_TransposeX == true)
    {
        armnn::TensorInfo inputXStripped = armnnUtils::ReduceDims(inputX, 3);

        auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputXStripped.GetNumDimensions());
        const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
        const TensorInfo permutedXInfo = armnnUtils::Permuted(inputXStripped, permutationXVector);
        aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo, 3);

        statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo,
                                                          &aclPermutedXInfo,
                                                          aclPermutationXVector);
    }

    if (descriptor.m_TransposeY == true)
    {
        armnn::TensorInfo inputYStripped = armnnUtils::ReduceDims(inputY, 3);

        auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputYStripped.GetNumDimensions());
        const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
        const TensorInfo permutedYInfo = armnnUtils::Permuted(inputYStripped, permutationYVector);
        aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo, 3);

        statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo,
                                                          &aclPermutedYInfo,
                                                          aclPermutationYVector);
    }

    const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false,  // is inputX reshaped
                                                                   false,  // is inputY reshaped
                                                                   false); // is inputY reshaped only 1st run

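    // Validate the GEMM itself: alpha = 1.0 and beta = 0 with no bias tensor (nullptr), i.e. the
    // output is a plain X * Y product of the (possibly permuted) inputs.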
    statusGEMM = arm_compute::CLGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
                                               descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
                                               nullptr,
                                               &aclOutputInfo,
                                               1.0,
                                               0,
                                               gemm_info);

    if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
        statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
        statusGEMM.error_code()     == arm_compute::ErrorCode::OK)
    {
        return arm_compute::Status(arm_compute::ErrorCode::OK,
                                   "All Batch Mat Mul layers validate status OK.");
    }
    else
    {
        return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
                                   "BatchMatMul layer validate status failed."
                                   + statusGEMM.error_description()
                                   + statusPermuteX.error_description()
                                   + statusPermuteY.error_description());
    }
}

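// Illustrative use only (the call below is an assumption, not part of this file): the CL backend's
// layer-support checks are expected to invoke ClBatchMatMulValidate before this workload is
// created, roughly as follows:
//
//     arm_compute::Status status = ClBatchMatMulValidate(inputXInfo, inputYInfo, outputInfo, desc);
//     bool supported = (status.error_code() == arm_compute::ErrorCode::OK);
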
ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
                                             const WorkloadInfo& info,
                                             const arm_compute::CLCompileContext& clCompileContext)
    : ClBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
{
    // Report Profiling Details
    ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
                                         descriptor.m_Parameters,
                                         info,
                                         this->GetGuid());

    if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY)
    {
        throw Exception("Support for adjoint not implemented.");
    }
    if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW ||
        descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW)
    {
        throw Exception("Only MatMul in the last 2 dimensions is supported.");
    }

    m_Data.ValidateInputsOutputs("ClBatchMatMulWorkload", 2, 1);

    const arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
    const arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
    arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();

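    // Describe both inputs to ACL with at most 3 dimensions, mirroring the reduction done in
    // ClBatchMatMulValidate, because the CLGEMM kernel only handles 3 dimensional tensors.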
    inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX));
    arm_compute::TensorShape inputXTensorShape = armcomputetensorutils::BuildArmComputeTensorShape(
            info.m_InputTensorInfos[0].GetShape(), 3);
    inputX.info()->set_tensor_shape(inputXTensorShape);
    inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY));
    arm_compute::TensorShape inputYTensorShape = armcomputetensorutils::BuildArmComputeTensorShape(
            info.m_InputTensorInfos[1].GetShape(), 3);
    inputY.info()->set_tensor_shape(inputYTensorShape);

    arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
    arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();

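    // When a transpose is requested, configure a CLPermute that swaps the last two (matrix)
    // dimensions of that input into an intermediate tensor (m_PermutedTensorX / m_PermutedTensorY);
    // the GEMM below then reads the permuted copy instead of the original input.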
    if (descriptor.m_Parameters.m_TransposeX == true)
    {
        armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[0], 3);

        armnn::PermutationVector permutationXVector
                = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
        const TensorInfo permutedXInfo = armnnUtils::Permuted(strippedInfo, permutationXVector);
        const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
        armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
        armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);

        auto permuteLayerX = std::make_unique<arm_compute::CLPermute>();
        permuteLayerX->configure(clCompileContext,
                                 &inputX,
                                 &m_PermutedTensorX,
                                 aclPermutationXVector);
        m_PermuteLayerX.reset(permuteLayerX.release());
    }

    if (descriptor.m_Parameters.m_TransposeY == true)
    {
        armnn::TensorInfo strippedInfo = armnnUtils::ReduceDims(info.m_InputTensorInfos[1], 3);

        armnn::PermutationVector permutationYVector
                = GeneratePermutationVectorOnLastTwoDimensions(strippedInfo.GetNumDimensions());
        const TensorInfo permutedYInfo = armnnUtils::Permuted(strippedInfo, permutationYVector);
        const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
        armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
        armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);

        auto permuteLayerY = std::make_unique<arm_compute::CLPermute>();
        permuteLayerY->configure(clCompileContext,
                                 &inputY,
                                 &m_PermutedTensorY,
                                 aclPermutationYVector);
        m_PermuteLayerY.reset(permuteLayerY.release());
    }

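    // Configure the GEMM on the (possibly permuted) inputs: alpha = 1.0, beta = 0 and no bias, so
    // the workload computes output = X * Y (with X and/or Y transposed via the permutes above).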
    const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false,  // is inputX reshaped
                                                                   false,  // is inputY reshaped
                                                                   false); // is inputY reshaped only 1st run
    auto gemmLayer = std::make_unique<arm_compute::CLGEMM>();
    gemmLayer->configure(clCompileContext,
                         descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX,
                         descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY,
                         nullptr,
                         &output,
                         1.0,
                         0,
                         gemm_info);
    m_GEMMLayer.reset(gemmLayer.release());
}

void ClBatchMatMulWorkload::Execute() const
{
    ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClBatchMatMulWorkload_Execute", this->GetGuid());
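    // Run the permutes (when configured) before the GEMM, since the GEMM consumes the permuted
    // intermediate tensors they produce.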
    if (m_PermuteLayerX)
    {
        m_PermuteLayerX->run();
    }
    if (m_PermuteLayerY)
    {
        m_PermuteLayerY->run();
    }
    m_GEMMLayer->run();
}
} // namespace armnn