/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef XLA_CPU_USE_ACL
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_acl.h"

#include "absl/base/call_once.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/dynamic_annotations.h"

namespace {
// ACL GEMM API for 32-bit Matrix Multiplication.

// The MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c.
// Since XLA MatMul does not use alpha and beta, they are set to 1.0 and 0.0.
// The lhs, rhs and out matrices are all column-major.
int32_t MatMulF32(const void* run_options_ptr, float* out, float* lhs,
                  float* rhs, int64_t m, int64_t n, int64_t k,
                  int64_t batch_size, int32_t transpose_lhs,
                  int32_t transpose_rhs) {
  const float alpha = 1.0f, beta = 0.0f;

  /* TODO: optimize this object creation along with tensor init and
   * gemm configuration by caching the shapes, similar to onednn
   * primitive caching feature
   */
  struct acl_matmul_obj_t acl_obj;
  struct acl_matmul_conf_t acl_conf;

  acl_conf.is_trans_lhs = (bool)transpose_lhs;
  acl_conf.is_trans_rhs = (bool)transpose_rhs;

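  // ACL's TensorShape lists dimensions fastest-varying first, so a
  // column-major m x k matrix is described as TensorShape(m, k, ...). The
  // *_acc_info shapes below describe the untransposed inputs that feed the
  // NETranspose kernels when a transpose is requested.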
  if (acl_conf.is_trans_lhs) {
    acl_conf.lhs_acc_info =
        arm_compute::TensorInfo(arm_compute::TensorShape(k, m, batch_size), 1,
                                arm_compute::DataType::F32);
  }
  if (acl_conf.is_trans_rhs) {
    acl_conf.rhs_acc_info =
        arm_compute::TensorInfo(arm_compute::TensorShape(n, k, 1, batch_size),
                                1, arm_compute::DataType::F32);
  }

  acl_conf.lhs_info =
      arm_compute::TensorInfo(arm_compute::TensorShape(m, k, batch_size), 1,
                              arm_compute::DataType::F32);
  acl_conf.rhs_info =
      arm_compute::TensorInfo(arm_compute::TensorShape(k, n, 1, batch_size), 1,
                              arm_compute::DataType::F32);
  acl_conf.out_info =
      arm_compute::TensorInfo(arm_compute::TensorShape(m, n, 1, batch_size), 1,
                              arm_compute::DataType::F32);

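  // Fast math lets ACL select reduced-precision (e.g. BF16) GEMM kernels on
  // CPUs that support them, trading a small amount of accuracy for speed.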
  /* TODO: add a TF_XLA_* flag for runtime control of fast math mode */
  bool is_fastmath_enabled = true;
  acl_conf.gemm_info.set_fast_math(is_fastmath_enabled);

  // No fused activation: a default-constructed ActivationLayerInfo is
  // disabled, so the GEMM output is left unmodified.
  acl_conf.gemm_info.set_activation_info(arm_compute::ActivationLayerInfo());

  // Set alpha (output scaling)
  acl_conf.alpha = alpha;

  // Validate ACL transpose
  if (acl_conf.is_trans_lhs) {
    auto acl_trans_lhs_st = arm_compute::NETranspose::validate(
        &acl_conf.lhs_acc_info, &acl_conf.lhs_info);
    if (acl_trans_lhs_st.error_code() != arm_compute::ErrorCode::OK) {
      VLOG(1) << "lhs transpose validation failed";
      return -1;
    }
  }
  if (acl_conf.is_trans_rhs) {
    auto acl_trans_rhs_st = arm_compute::NETranspose::validate(
        &acl_conf.rhs_acc_info, &acl_conf.rhs_info);
    if (acl_trans_rhs_st.error_code() != arm_compute::ErrorCode::OK) {
      VLOG(1) << "rhs transpose validation failed";
      return -1;
    }
  }

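  // NEGEMM treats its operands as row-major, so for the column-major
  // lhs/rhs/out used here the call effectively computes out^T = rhs^T * lhs^T,
  // which is why rhs_info is passed as the first operand and lhs_info as the
  // second.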
  // Validate ACL GEMM
  auto acl_st = arm_compute::NEGEMM::validate(
      &acl_conf.rhs_info, &acl_conf.lhs_info, nullptr, &acl_conf.out_info,
      acl_conf.alpha, 0.0f, acl_conf.gemm_info);
  if (acl_st.error_code() != arm_compute::ErrorCode::OK) {
    VLOG(1) << "ACL GEMM validation failed";
    return -1;
  }

  static absl::once_flag flag_once;
  const xla::ExecutableRunOptions* run_options =
      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
  XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);
  const Eigen::ThreadPoolDevice* tpd =
      (Eigen::ThreadPoolDevice*)(run_options->intra_op_thread_pool());
  // The Compute Library threads are bound to cores 0..max_threads-1.
  const int max_threads = tpd->numThreads();

  // arm_compute::Scheduler does not support concurrent access, so as a
  // workaround the thread count is set only once.
  absl::call_once(flag_once, [&]() {
    arm_compute::Scheduler::get().set_num_threads(max_threads);
  });

  // Configure the ACL tensors from the config
  acl_obj.lhs_tensor.allocator()->init(acl_conf.lhs_info);
  acl_obj.rhs_tensor.allocator()->init(acl_conf.rhs_info);
  acl_obj.out_tensor.allocator()->init(acl_conf.out_info);

  // Configure the transpose kernels for lhs, rhs or both
  if (acl_conf.is_trans_lhs) {
    acl_obj.lhs_acc_tensor.allocator()->init(acl_conf.lhs_acc_info);
    acl_obj.trans_lhs.configure(&acl_obj.lhs_acc_tensor, &acl_obj.lhs_tensor);
  }
  if (acl_conf.is_trans_rhs) {
    acl_obj.rhs_acc_tensor.allocator()->init(acl_conf.rhs_acc_info);
    acl_obj.trans_rhs.configure(&acl_obj.rhs_acc_tensor, &acl_obj.rhs_tensor);
  }
  // Configure GEMM
  acl_obj.gemm.configure(&acl_obj.rhs_tensor, &acl_obj.lhs_tensor, nullptr,
                         &acl_obj.out_tensor, acl_conf.alpha, 0.0f,
                         acl_conf.gemm_info);

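  // import_memory() wraps the caller-provided lhs/rhs/out buffers directly
  // (no copy); allocate() is only needed for the intermediate tensors that
  // receive the transposed copies.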
  // Run the transpose kernels (if any) and bind the input buffers
  if (transpose_lhs && !transpose_rhs) {
    acl_obj.lhs_tensor.allocator()->allocate();
    acl_obj.lhs_acc_tensor.allocator()->import_memory(lhs);
    acl_obj.trans_lhs.run();
    acl_obj.rhs_tensor.allocator()->import_memory(rhs);
  } else if (transpose_rhs && !transpose_lhs) {
    acl_obj.rhs_tensor.allocator()->allocate();
    acl_obj.rhs_acc_tensor.allocator()->import_memory(rhs);
    acl_obj.trans_rhs.run();
    acl_obj.lhs_tensor.allocator()->import_memory(lhs);
  } else if (transpose_rhs && transpose_lhs) {
    acl_obj.lhs_tensor.allocator()->allocate();
    acl_obj.lhs_acc_tensor.allocator()->import_memory(lhs);
    acl_obj.rhs_tensor.allocator()->allocate();
    acl_obj.rhs_acc_tensor.allocator()->import_memory(rhs);
    acl_obj.trans_lhs.run();
    acl_obj.trans_rhs.run();
  } else {
    acl_obj.lhs_tensor.allocator()->import_memory(lhs);
    acl_obj.rhs_tensor.allocator()->import_memory(rhs);
  }

  acl_obj.out_tensor.allocator()->import_memory(out);

  // Execute the function
  acl_obj.gemm.run();

  acl_obj.lhs_tensor.allocator()->free();
  acl_obj.rhs_tensor.allocator()->free();
  acl_obj.out_tensor.allocator()->free();
  if (acl_conf.is_trans_lhs) acl_obj.lhs_acc_tensor.allocator()->free();
  if (acl_conf.is_trans_rhs) acl_obj.rhs_acc_tensor.allocator()->free();

  return 0;
}

}  // namespace

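// Runtime entry points called from XLA-generated code. If ACL validation or
// configuration fails, MatMulF32 returns -1 and we fall back to the Eigen
// kernels.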
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ACLMatMulF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m,
    int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  if (MatMulF32(run_options_ptr, out, lhs, rhs, m, n, k, 1 /*batch_size*/,
                transpose_lhs, transpose_rhs) < 0) {
    VLOG(1) << "ACL matmul failed, falling back to Eigen matmul";
    __xla_cpu_runtime_EigenMatMulF32(run_options_ptr, out, lhs, rhs, m, n, k,
                                     transpose_lhs, transpose_rhs);
  }
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ACLBatchMatMulF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m,
    int64_t n, int64_t k, int64_t batch_size, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  if (MatMulF32(run_options_ptr, out, lhs, rhs, m, n, k, batch_size,
                transpose_lhs, transpose_rhs) < 0) {
    VLOG(1) << "ACL batch matmul failed, falling back to Eigen batch matmul";
    __xla_cpu_runtime_EigenBatchMatMulF32(run_options_ptr, out, lhs, rhs, m, n,
                                          k, batch_size, transpose_lhs,
                                          transpose_rhs);
  }
}

#endif  // XLA_CPU_USE_ACL