/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef XLA_CPU_USE_ACL
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_acl.h"

#include "absl/base/call_once.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/dynamic_annotations.h"

namespace {
// ACL GEMM API for 32-bit matrix multiplication.

// The MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c.
// Since XLA MatMul does not use alpha and beta, they are fixed at 1.0 and 0.0.
// Matrices lhs, rhs, and out are all column-major.
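//
// Illustrative call (hypothetical shapes, not taken from real XLA output):
// to multiply a column-major 2x3 lhs by a 3x4 rhs into a 2x4 out:
//   MatMulF32(run_options, out, lhs, rhs, /*m=*/2, /*n=*/4, /*k=*/3,
//             /*batch_size=*/1, /*transpose_lhs=*/0, /*transpose_rhs=*/0);
// A return value below zero tells the caller to fall back to the Eigen path
// (see the __xla_cpu_runtime_ACL* wrappers at the bottom of this file).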
int32_t MatMulF32(const void* run_options_ptr, float* out, float* lhs,
                  float* rhs, int64_t m, int64_t n, int64_t k,
                  int64_t batch_size, int32_t transpose_lhs,
                  int32_t transpose_rhs) {
  const float alpha = 1.0f, beta = 0.0f;

  /* TODO: optimize this object creation, along with the tensor init and
   * GEMM configuration, by caching the shapes, similar to the oneDNN
   * primitive-caching feature.
   */
  acl_matmul_obj_t acl_obj;
  acl_matmul_conf_t acl_conf;

  acl_conf.is_trans_lhs = static_cast<bool>(transpose_lhs);
  acl_conf.is_trans_rhs = static_cast<bool>(transpose_rhs);

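  // ACL's TensorShape lists the innermost (contiguous) dimension first, so a
  // column-major m x k matrix maps onto TensorShape(m, k, ...) with no data
  // movement.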
  if (acl_conf.is_trans_lhs) {
    acl_conf.lhs_acc_info =
        arm_compute::TensorInfo(arm_compute::TensorShape(k, m, batch_size), 1,
                                arm_compute::DataType::F32);
  }
  if (acl_conf.is_trans_rhs) {
    acl_conf.rhs_acc_info =
        arm_compute::TensorInfo(arm_compute::TensorShape(n, k, 1, batch_size),
                                1, arm_compute::DataType::F32);
  }

  acl_conf.lhs_info =
      arm_compute::TensorInfo(arm_compute::TensorShape(m, k, batch_size), 1,
                              arm_compute::DataType::F32);
  acl_conf.rhs_info =
      arm_compute::TensorInfo(arm_compute::TensorShape(k, n, 1, batch_size), 1,
                              arm_compute::DataType::F32);
  acl_conf.out_info =
      arm_compute::TensorInfo(arm_compute::TensorShape(m, n, 1, batch_size), 1,
                              arm_compute::DataType::F32);

  /* TODO: add a TF_XLA_* flag for runtime control of fast math mode. */
  bool is_fastmath_enabled = true;
  acl_conf.gemm_info.set_fast_math(is_fastmath_enabled);
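  // With fast math enabled, ACL may select reduced-precision kernels (e.g.
  // BF16 arithmetic on cores that support it), trading strict FP32 accuracy
  // for speed.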

  // No fused activation: a default-constructed ActivationLayerInfo leaves
  // activation disabled.
  acl_conf.gemm_info.set_activation_info(arm_compute::ActivationLayerInfo());

  // Set alpha (output scaling).
  acl_conf.alpha = alpha;

  // Validate ACL transpose.
  if (acl_conf.is_trans_lhs) {
    auto acl_trans_lhs_st = arm_compute::NETranspose::validate(
        &acl_conf.lhs_acc_info, &acl_conf.lhs_info);
    if (acl_trans_lhs_st.error_code() != arm_compute::ErrorCode::OK) {
      VLOG(1) << "lhs transpose validation failed";
      return -1;
    }
  }
  if (acl_conf.is_trans_rhs) {
    auto acl_trans_rhs_st = arm_compute::NETranspose::validate(
        &acl_conf.rhs_acc_info, &acl_conf.rhs_info);
    if (acl_trans_rhs_st.error_code() != arm_compute::ErrorCode::OK) {
      VLOG(1) << "rhs transpose validation failed";
      return -1;
    }
  }

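  // ACL views each of these column-major buffers as the transpose of the XLA
  // matrix. Since out^T = rhs^T * lhs^T, the operands are handed to NEGEMM in
  // swapped (rhs, lhs) order below.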
  // Validate ACL GEMM.
  auto acl_st = arm_compute::NEGEMM::validate(
      &acl_conf.rhs_info, &acl_conf.lhs_info, nullptr, &acl_conf.out_info,
      acl_conf.alpha, beta, acl_conf.gemm_info);
  if (acl_st.error_code() != arm_compute::ErrorCode::OK) {
    VLOG(1) << "ACL GEMM validation failed";
    return -1;
  }

  static absl::once_flag flag_once;
  const xla::ExecutableRunOptions* run_options =
      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
  XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);
  const Eigen::ThreadPoolDevice* tpd =
      static_cast<const Eigen::ThreadPoolDevice*>(
          run_options->intra_op_thread_pool());
  // The threads in Compute Library are bound to cores 0..max_threads-1.
  const int max_threads = tpd->numThreads();

  // arm_compute::Scheduler does not support concurrent access, so restrict
  // its configuration to a single call.
  absl::call_once(flag_once, [&]() {
    arm_compute::Scheduler::get().set_num_threads(max_threads);
  });
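  // Note that the first caller fixes the scheduler's thread count for the
  // lifetime of the process; later calls with a differently sized pool reuse
  // that initial setting.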

  // Configure the ACL object with the config.
  acl_obj.lhs_tensor.allocator()->init(acl_conf.lhs_info);
  acl_obj.rhs_tensor.allocator()->init(acl_conf.rhs_info);
  acl_obj.out_tensor.allocator()->init(acl_conf.out_info);

  // Configure the transpose kernel for lhs, rhs, or both.
  if (acl_conf.is_trans_lhs) {
    acl_obj.lhs_acc_tensor.allocator()->init(acl_conf.lhs_acc_info);
    acl_obj.trans_lhs.configure(&acl_obj.lhs_acc_tensor, &acl_obj.lhs_tensor);
  }
  if (acl_conf.is_trans_rhs) {
    acl_obj.rhs_acc_tensor.allocator()->init(acl_conf.rhs_acc_info);
    acl_obj.trans_rhs.configure(&acl_obj.rhs_acc_tensor, &acl_obj.rhs_tensor);
  }
  // Configure GEMM (note the swapped rhs/lhs operand order; see above).
  acl_obj.gemm.configure(&acl_obj.rhs_tensor, &acl_obj.lhs_tensor, nullptr,
                         &acl_obj.out_tensor, acl_conf.alpha, beta,
                         acl_conf.gemm_info);

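  // import_memory() below wraps the caller-provided buffers in place rather
  // than copying; only the destination tensor of a transpose needs a real
  // allocation of its own.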
  // Run the transpose kernel(s).
  if (transpose_lhs && !transpose_rhs) {
    acl_obj.lhs_tensor.allocator()->allocate();
    acl_obj.lhs_acc_tensor.allocator()->import_memory(lhs);
    acl_obj.trans_lhs.run();
    acl_obj.rhs_tensor.allocator()->import_memory(rhs);
  } else if (transpose_rhs && !transpose_lhs) {
    acl_obj.rhs_tensor.allocator()->allocate();
    acl_obj.rhs_acc_tensor.allocator()->import_memory(rhs);
    acl_obj.trans_rhs.run();
    acl_obj.lhs_tensor.allocator()->import_memory(lhs);
  } else if (transpose_rhs && transpose_lhs) {
    acl_obj.lhs_tensor.allocator()->allocate();
    acl_obj.lhs_acc_tensor.allocator()->import_memory(lhs);
    acl_obj.rhs_tensor.allocator()->allocate();
    acl_obj.rhs_acc_tensor.allocator()->import_memory(rhs);
    acl_obj.trans_lhs.run();
    acl_obj.trans_rhs.run();
  } else {
    acl_obj.lhs_tensor.allocator()->import_memory(lhs);
    acl_obj.rhs_tensor.allocator()->import_memory(rhs);
  }

  acl_obj.out_tensor.allocator()->import_memory(out);

  // Execute the GEMM.
  acl_obj.gemm.run();

  acl_obj.lhs_tensor.allocator()->free();
  acl_obj.rhs_tensor.allocator()->free();
  acl_obj.out_tensor.allocator()->free();
  if (acl_conf.is_trans_lhs) acl_obj.lhs_acc_tensor.allocator()->free();
  if (acl_conf.is_trans_rhs) acl_obj.rhs_acc_tensor.allocator()->free();

  return 0;
}

}  // namespace

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ACLMatMulF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m,
    int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  if (MatMulF32(run_options_ptr, out, lhs, rhs, m, n, k, /*batch_size=*/1,
                transpose_lhs, transpose_rhs) < 0) {
    VLOG(1) << "ACL matmul failed; falling back to Eigen matmul";
    __xla_cpu_runtime_EigenMatMulF32(run_options_ptr, out, lhs, rhs, m, n, k,
                                     transpose_lhs, transpose_rhs);
  }
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ACLBatchMatMulF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m,
    int64_t n, int64_t k, int64_t batch_size, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  if (MatMulF32(run_options_ptr, out, lhs, rhs, m, n, k, batch_size,
                transpose_lhs, transpose_rhs) < 0) {
    VLOG(1) << "ACL batch matmul failed; falling back to Eigen batch matmul";
    __xla_cpu_runtime_EigenBatchMatMulF32(run_options_ptr, out, lhs, rhs, m, n,
                                          k, batch_size, transpose_lhs,
                                          transpose_rhs);
  }
}

#endif  // XLA_CPU_USE_ACL