1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/rocm/rocm_blas.h"
17
18 #include "tensorflow/stream_executor/rocm/rocblas_wrapper.h"
19
20 #define EIGEN_USE_GPU
21 #include <assert.h>
22
23 #include <complex>
24
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_format.h"
27 #include "absl/types/span.h"
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
29 #include "tensorflow/stream_executor/device_memory.h"
30 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
31 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
32 #include "tensorflow/stream_executor/gpu/gpu_helpers.h"
33 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
34 #include "tensorflow/stream_executor/gpu/gpu_timer.h"
35 #include "tensorflow/stream_executor/lib/env.h"
36 #include "tensorflow/stream_executor/lib/initialize.h"
37 #include "tensorflow/stream_executor/lib/status.h"
38 #include "tensorflow/stream_executor/platform/dso_loader.h"
39 #include "tensorflow/stream_executor/platform/logging.h"
40 #include "tensorflow/stream_executor/platform/port.h"
41 #include "tensorflow/stream_executor/plugin_registry.h"
42 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
43 #include "tensorflow/stream_executor/scratch_allocator.h"
44 #include "tensorflow/stream_executor/stream_executor.h"
45
46 namespace stream_executor {
47 namespace gpu {
48
49 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
50
51 namespace wrap = tensorflow::wrap;
52
53 template <class T>
complex_cast(const DeviceMemory<T> & a)54 const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
55 const DeviceMemory<T> &a) {
56 return reinterpret_cast<
57 const typename RocBlasTypeConversionHelper<T>::mapped_type *>(
58 GpuMemory(a));
59 }
60
61 template <class T>
complex_cast(const T & a)62 const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
63 const T &a) {
64 return reinterpret_cast<
65 const typename RocBlasTypeConversionHelper<T>::mapped_type *>(&a);
66 }
67 template <class T>
complex_cast(DeviceMemory<T> * a)68 typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
69 DeviceMemory<T> *a) {
70 return reinterpret_cast<
71 typename RocBlasTypeConversionHelper<T>::mapped_type *>(
72 GpuMemoryMutable(a));
73 }
74
// Logging hook called by several BLAS entry points with a short label.
// The body is intentionally empty, so the label is discarded.
static void blas_log(const char * /*c*/) {}
76
ToString(rocblas_status status)77 static string ToString(rocblas_status status) {
78 switch (status) {
79 case rocblas_status_success:
80 return "rocblas_status_success";
81 case rocblas_status_invalid_handle:
82 return "rocblas_status_invalid_handle";
83 case rocblas_status_not_implemented:
84 return "rocblas_status_not_implemented";
85 case rocblas_status_invalid_pointer:
86 return "rocblas_status_invalid_pointer";
87 case rocblas_status_invalid_size:
88 return "rocblas_status_invalid_size";
89 case rocblas_status_memory_error:
90 return "rocblas_status_memory_error";
91 case rocblas_status_internal_error:
92 return "rocblas_status_internal_error";
93 default:
94 return absl::StrCat("<invalid rocBLAS status: ", status, ">");
95 }
96 }
97
Init()98 bool ROCMBlas::Init() {
99 gpu::ScopedActivateExecutorContext sac{parent_};
100 rocblas_status ret = wrap::rocblas_create_handle(&blas_);
101 if (ret != rocblas_status_success) {
102 LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret);
103 return false;
104 }
105
106 return true;
107 }
108
ROCMBlas(gpu::GpuExecutor * parent)109 ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
110 : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
111
~ROCMBlas()112 ROCMBlas::~ROCMBlas() {
113 if (blas_ != nullptr) {
114 gpu::ScopedActivateExecutorContext sac{parent_};
115 wrap::rocblas_destroy_handle(blas_);
116 }
117 }
118
SetStream(Stream * stream)119 bool ROCMBlas::SetStream(Stream *stream) {
120 CHECK(stream != nullptr);
121 CHECK(AsGpuStreamValue(stream) != nullptr);
122 CHECK(blas_ != nullptr);
123 gpu::ScopedActivateExecutorContext sac{parent_};
124 rocblas_status ret =
125 wrap::rocblas_set_stream(blas_, AsGpuStreamValue(stream));
126 if (ret != rocblas_status_success) {
127 LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret);
128 return false;
129 }
130
131 return true;
132 }
133
134 namespace {
135
136 // Helper functions transforming blas arguments into rocBLAS arguments.
137
ROCMBlasTranspose(blas::Transpose trans)138 rocblas_operation ROCMBlasTranspose(blas::Transpose trans) {
139 switch (trans) {
140 case blas::Transpose::kNoTranspose:
141 return rocblas_operation_none;
142 case blas::Transpose::kTranspose:
143 return rocblas_operation_transpose;
144 case blas::Transpose::kConjugateTranspose:
145 return rocblas_operation_conjugate_transpose;
146 default:
147 LOG(FATAL) << "Invalid value of blas::Transpose.";
148 }
149 }
150
ROCMBlasUpperLower(blas::UpperLower uplo)151 rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) {
152 switch (uplo) {
153 case blas::UpperLower::kUpper:
154 return rocblas_fill_upper;
155 case blas::UpperLower::kLower:
156 return rocblas_fill_lower;
157 default:
158 LOG(FATAL) << "Invalid value of blas::UpperLower.";
159 }
160 }
161
ROCMBlasDiagonal(blas::Diagonal diag)162 rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) {
163 switch (diag) {
164 case blas::Diagonal::kUnit:
165 return rocblas_diagonal_unit;
166 case blas::Diagonal::kNonUnit:
167 return rocblas_diagonal_non_unit;
168 default:
169 LOG(FATAL) << "Invalid value of blas::Diagonal.";
170 }
171 }
172
ROCMBlasSide(blas::Side side)173 rocblas_side ROCMBlasSide(blas::Side side) {
174 switch (side) {
175 case blas::Side::kLeft:
176 return rocblas_side_left;
177 case blas::Side::kRight:
178 return rocblas_side_right;
179 default:
180 LOG(FATAL) << "Invalid value of blas::Side.";
181 }
182 }
183
184 } // namespace
185
186 template <typename FuncT, typename... Args>
DoBlasInternalImpl(FuncT rocblas_func,Stream * stream,bool pointer_mode_host,bool err_on_failure,Args...args)187 bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
188 bool pointer_mode_host, bool err_on_failure,
189 Args... args) {
190 absl::MutexLock lock{&mu_};
191
192 CHECK(blas_ != nullptr);
193 if (!SetStream(stream)) {
194 return false;
195 }
196
197 gpu::ScopedActivateExecutorContext sac{parent_};
198 rocblas_status ret = rocblas_func(blas_, args...);
199 if (err_on_failure && ret != rocblas_status_success) {
200 LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": "
201 << ToString(ret);
202 }
203 return ret == rocblas_status_success;
204 }
205
DoBlasAxpy(Stream * stream,uint64_t elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)206 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha,
207 const DeviceMemory<float> &x, int incx,
208 DeviceMemory<float> *y, int incy) {
209 blas_log("DoBlasAxpy");
210 return DoBlasInternal(wrap::rocblas_saxpy, stream,
211 /* pointer_mode_host = */ true, elem_count, &alpha,
212 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
213 }
214
DoBlasAxpy(Stream * stream,uint64_t elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)215 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha,
216 const DeviceMemory<double> &x, int incx,
217 DeviceMemory<double> *y, int incy) {
218 blas_log("DoBlasAxpy");
219 return DoBlasInternal(wrap::rocblas_daxpy, stream,
220 /* pointer_mode_host = */ true, elem_count, &alpha,
221 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
222 }
223
DoBlasAxpy(Stream * stream,uint64_t elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)224 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count,
225 std::complex<float> alpha,
226 const DeviceMemory<std::complex<float>> &x, int incx,
227 DeviceMemory<std::complex<float>> *y, int incy) {
228 return DoBlasInternal(
229 wrap::rocblas_caxpy, stream, /* pointer_mode_host = */ true, elem_count,
230 complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy);
231 }
232
DoBlasAxpy(Stream * stream,uint64_t elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)233 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count,
234 std::complex<double> alpha,
235 const DeviceMemory<std::complex<double>> &x, int incx,
236 DeviceMemory<std::complex<double>> *y, int incy) {
237 return DoBlasInternal(
238 wrap::rocblas_zaxpy, stream, /* pointer_mode_host = */ true, elem_count,
239 complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy);
240 }
241
DoBlasCopy(Stream * stream,uint64_t elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)242 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
243 const DeviceMemory<float> &x, int incx,
244 DeviceMemory<float> *y, int incy) {
245 return DoBlasInternal(wrap::rocblas_scopy, stream,
246 /* pointer_mode_host = */ true, elem_count,
247 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
248 }
249
DoBlasCopy(Stream * stream,uint64_t elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)250 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
251 const DeviceMemory<double> &x, int incx,
252 DeviceMemory<double> *y, int incy) {
253 return DoBlasInternal(wrap::rocblas_dcopy, stream,
254 /* pointer_mode_host = */ true, elem_count,
255 GpuMemory(x), incx, GpuMemoryMutable(y), incy);
256 }
257
DoBlasCopy(Stream * stream,uint64_t elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)258 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
259 const DeviceMemory<std::complex<float>> &x, int incx,
260 DeviceMemory<std::complex<float>> *y, int incy) {
261 return DoBlasInternal(wrap::rocblas_ccopy, stream,
262 /* pointer_mode_host = */ true, elem_count,
263 complex_cast(x), incx, complex_cast(y), incy);
264 }
265
DoBlasCopy(Stream * stream,uint64_t elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)266 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
267 const DeviceMemory<std::complex<double>> &x, int incx,
268 DeviceMemory<std::complex<double>> *y, int incy) {
269 return DoBlasInternal(wrap::rocblas_zcopy, stream,
270 /* pointer_mode_host = */ true, elem_count,
271 complex_cast(x), incx, complex_cast(y), incy);
272 }
273
DoBlasScal(Stream * stream,uint64_t elem_count,float alpha,DeviceMemory<float> * x,int incx)274 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
275 DeviceMemory<float> *x, int incx) {
276 blas_log("DoBlasScal<float>");
277 return DoBlasInternal(wrap::rocblas_sscal, stream,
278 /* pointer_mode_host = */ true, elem_count, &alpha,
279 GpuMemoryMutable(x), incx);
280 }
281
DoBlasScal(Stream * stream,uint64_t elem_count,double alpha,DeviceMemory<double> * x,int incx)282 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,
283 DeviceMemory<double> *x, int incx) {
284 return DoBlasInternal(wrap::rocblas_dscal, stream,
285 /* pointer_mode_host = */ true, elem_count, &alpha,
286 GpuMemoryMutable(x), incx);
287 }
288
DoBlasScal(Stream * stream,uint64_t elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)289 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
290 DeviceMemory<std::complex<float>> *x, int incx) {
291 return DoBlasInternal(wrap::rocblas_csscal, stream,
292 /* pointer_mode_host = */ true, elem_count, &alpha,
293 complex_cast(x), incx);
294 }
295
DoBlasScal(Stream * stream,uint64_t elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)296 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,
297 DeviceMemory<std::complex<double>> *x, int incx) {
298 return DoBlasInternal(wrap::rocblas_zdscal, stream,
299 /* pointer_mode_host = */ true, elem_count, &alpha,
300 complex_cast(x), incx);
301 }
302
DoBlasScal(Stream * stream,uint64_t elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)303 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count,
304 std::complex<float> alpha,
305 DeviceMemory<std::complex<float>> *x, int incx) {
306 return DoBlasInternal(wrap::rocblas_cscal, stream,
307 /* pointer_mode_host = */ true, elem_count,
308 complex_cast(alpha), complex_cast(x), incx);
309 }
310
DoBlasScal(Stream * stream,uint64_t elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)311 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count,
312 std::complex<double> alpha,
313 DeviceMemory<std::complex<double>> *x, int incx) {
314 return DoBlasInternal(wrap::rocblas_zscal, stream,
315 /* pointer_mode_host = */ true, elem_count,
316 complex_cast(alpha), complex_cast(x), incx);
317 }
318
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64_t m,uint64_t n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)319 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
320 uint64_t n, float alpha, const DeviceMemory<float> &a,
321 int lda, const DeviceMemory<float> &x, int incx,
322 float beta, DeviceMemory<float> *y, int incy) {
323 blas_log("DoBlasGemv");
324 return DoBlasInternal(
325 wrap::rocblas_sgemv, stream, /* pointer_mode_host = */ true,
326 ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
327 incx, &beta, GpuMemoryMutable(y), incy);
328 }
329
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64_t m,uint64_t n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)330 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
331 uint64_t n, double alpha,
332 const DeviceMemory<double> &a, int lda,
333 const DeviceMemory<double> &x, int incx, double beta,
334 DeviceMemory<double> *y, int incy) {
335 blas_log("DoBlasGemv");
336 return DoBlasInternal(
337 wrap::rocblas_dgemv, stream, /* pointer_mode_host = */ true,
338 ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
339 incx, &beta, GpuMemoryMutable(y), incy);
340 }
341
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64_t m,uint64_t n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)342 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
343 uint64_t n, std::complex<float> alpha,
344 const DeviceMemory<std::complex<float>> &a, int lda,
345 const DeviceMemory<std::complex<float>> &x, int incx,
346 std::complex<float> beta,
347 DeviceMemory<std::complex<float>> *y, int incy) {
348 blas_log("DoBlasGemv");
349 return DoBlasInternal(
350 wrap::rocblas_cgemv, stream, /* pointer_mode_host = */ true,
351 ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
352 complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
353 }
354
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64_t m,uint64_t n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)355 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
356 uint64_t n, std::complex<double> alpha,
357 const DeviceMemory<std::complex<double>> &a, int lda,
358 const DeviceMemory<std::complex<double>> &x, int incx,
359 std::complex<double> beta,
360 DeviceMemory<std::complex<double>> *y, int incy) {
361 blas_log("DoBlasGemv\n");
362 return DoBlasInternal(
363 wrap::rocblas_zgemv, stream, /* pointer_mode_host = */ true,
364 ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
365 complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
366 }
367
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64_t n,uint64_t k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)368 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n,
369 uint64_t k, float alpha, const DeviceMemory<float> &a,
370 int lda, const DeviceMemory<float> &x, int incx,
371 float beta, DeviceMemory<float> *y, int incy) {
372 return DoBlasInternal(
373 wrap::rocblas_ssbmv, stream, /* pointer_mode_host = */ true,
374 ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x),
375 incx, &beta, GpuMemoryMutable(y), incy);
376 }
377
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64_t n,uint64_t k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)378 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n,
379 uint64_t k, double alpha,
380 const DeviceMemory<double> &a, int lda,
381 const DeviceMemory<double> &x, int incx, double beta,
382 DeviceMemory<double> *y, int incy) {
383 return DoBlasInternal(
384 wrap::rocblas_dsbmv, stream, /* pointer_mode_host = */ true,
385 ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x),
386 incx, &beta, GpuMemoryMutable(y), incy);
387 }
388
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,blas::DataType dtype,const void * alpha,const DeviceMemoryBase & a,int lda,const DeviceMemoryBase & b,int ldb,const void * beta,DeviceMemoryBase * c,int ldc,blas::ComputePrecision precision)389 port::Status ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
390 blas::Transpose transb, uint64_t m, uint64 n,
391 uint64_t k, blas::DataType dtype,
392 const void *alpha, const DeviceMemoryBase &a,
393 int lda, const DeviceMemoryBase &b, int ldb,
394 const void *beta, DeviceMemoryBase *c,
395 int ldc, blas::ComputePrecision precision) {
396 blas_log("DoBlasGemm");
397 VLOG(1) << absl::StreamFormat(
398 "doing rocBLAS GEMM: at=%d bt=%d m=%u n=%u "
399 "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p "
400 "c=%p ldc=%d",
401 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
402 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
403 if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) {
404 if (transa == blas::Transpose::kNoTranspose) {
405 if (lda < static_cast<int64_t>(m)) {
406 LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
407 "precondition violation";
408 }
409 } else {
410 if (lda < static_cast<int64_t>(k)) {
411 LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
412 << ") (transpose case); precondition violation";
413 }
414 }
415 if (transb == blas::Transpose::kNoTranspose) {
416 if (ldb < static_cast<int64_t>(k)) {
417 LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
418 << ") (no transpose case); precondition violation";
419 }
420 } else {
421 if (ldb < static_cast<int64_t>(n)) {
422 LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
423 "precondition violation";
424 }
425 }
426 }
427
428 switch (dtype) {
429 case blas::DataType::kHalf: {
430 port::StatusOr<bool> maybe_hasXDLOPS = GpuDriver::GetMFMASupport();
431 if (maybe_hasXDLOPS.ok() && maybe_hasXDLOPS.ValueOrDie()) {
432 VLOG(1) << "Using rocblas_gemm_ex";
433 return DoBlasInternalStatus(
434 wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true,
435 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb),
436 (rocblas_int)m, (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(),
437 rocblas_datatype_f16_r, lda, b.opaque(), rocblas_datatype_f16_r,
438 ldb, beta, c->opaque(), rocblas_datatype_f16_r, ldc, c->opaque(),
439 rocblas_datatype_f16_r, ldc, rocblas_datatype_f32_r,
440 rocblas_gemm_algo_standard, 0, 0);
441 } else {
442 VLOG(1) << "Using rocblas_hgemm";
443 const Eigen::half alpha_half(*static_cast<const float *>(alpha));
444 const Eigen::half beta_half(*static_cast<const float *>(beta));
445 return DoBlasInternalStatus(
446 wrap::rocblas_hgemm, stream, /* pointer_mode_host = */ true,
447 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
448 reinterpret_cast<const rocblas_half *>(&alpha_half),
449 reinterpret_cast<const rocblas_half *>(a.opaque()), lda,
450 reinterpret_cast<const rocblas_half *>(b.opaque()), ldb,
451 reinterpret_cast<const rocblas_half *>(&beta_half),
452 reinterpret_cast<rocblas_half *>(c->opaque()), ldc);
453 }
454 }
455 case blas::DataType::kBF16:
456 return DoBlasInternalStatus(
457 wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true,
458 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), (rocblas_int)m,
459 (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(),
460 rocblas_datatype_bf16_r, lda, b.opaque(), rocblas_datatype_bf16_r,
461 ldb, beta, c->opaque(), rocblas_datatype_bf16_r, ldc, c->opaque(),
462 rocblas_datatype_bf16_r, ldc, rocblas_datatype_f32_r,
463 rocblas_gemm_algo_standard, 0, 0);
464 case blas::DataType::kFloat:
465 return DoBlasInternalStatus(
466 wrap::rocblas_sgemm, stream, /* pointer_mode_host = */ true,
467 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
468 static_cast<const float *>(alpha),
469 static_cast<const float *>(a.opaque()), lda,
470 static_cast<const float *>(b.opaque()), ldb,
471 static_cast<const float *>(beta), static_cast<float *>(c->opaque()),
472 ldc);
473 case blas::DataType::kDouble:
474 return DoBlasInternalStatus(
475 wrap::rocblas_dgemm, stream, /* pointer_mode_host = */ true,
476 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
477 static_cast<const double *>(alpha),
478 static_cast<const double *>(a.opaque()), lda,
479 static_cast<const double *>(b.opaque()), ldb,
480 static_cast<const double *>(beta), static_cast<double *>(c->opaque()),
481 ldc);
482 case blas::DataType::kComplexFloat: {
483 auto cb_alpha =
484 complex_cast(*static_cast<const std::complex<float> *>(alpha));
485 auto cb_beta =
486 complex_cast(*static_cast<const std::complex<float> *>(beta));
487 return DoBlasInternalStatus(
488 wrap::rocblas_cgemm, stream, /* pointer_mode_host = */ true,
489 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
490 cb_alpha, static_cast<const rocblas_float_complex *>(a.opaque()), lda,
491 static_cast<const rocblas_float_complex *>(b.opaque()), ldb, cb_beta,
492 static_cast<rocblas_float_complex *>(c->opaque()), ldc);
493 }
494 case blas::DataType::kComplexDouble: {
495 auto cb_alpha =
496 complex_cast(*static_cast<const std::complex<double> *>(alpha));
497 auto cb_beta =
498 complex_cast(*static_cast<const std::complex<double> *>(beta));
499 return DoBlasInternalStatus(
500 wrap::rocblas_zgemm, stream, /* pointer_mode_host = */ true,
501 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
502 cb_alpha, static_cast<const rocblas_double_complex *>(a.opaque()),
503 lda, static_cast<const rocblas_double_complex *>(b.opaque()), ldb,
504 cb_beta, static_cast<rocblas_double_complex *>(c->opaque()), ldc);
505 }
506 default:
507 return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
508 blas::DataTypeString(dtype)));
509 }
510 }
511
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64_t m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)512 bool ROCMBlas::DoBlasGemvWithProfiling(
513 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, float alpha,
514 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
515 int incx, float beta, DeviceMemory<float> *y, int incy,
516 blas::ProfileResult *output_profile_result) {
517 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
518 incx, beta, y, incy,
519 output_profile_result);
520 }
521
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64_t m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)522 bool ROCMBlas::DoBlasGemvWithProfiling(
523 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, double alpha,
524 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
525 int incx, double beta, DeviceMemory<double> *y, int incy,
526 blas::ProfileResult *output_profile_result) {
527 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
528 incx, beta, y, incy,
529 output_profile_result);
530 }
531
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)532 bool ROCMBlas::DoBlasGemvWithProfiling(
533 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,
534 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
535 int lda, const DeviceMemory<std::complex<float>> &x, int incx,
536 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
537 blas::ProfileResult *output_profile_result) {
538 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
539 incx, beta, y, incy,
540 output_profile_result);
541 }
542
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)543 bool ROCMBlas::DoBlasGemvWithProfiling(
544 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,
545 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
546 int lda, const DeviceMemory<std::complex<double>> &x, int incx,
547 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
548 blas::ProfileResult *output_profile_result) {
549 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
550 incx, beta, y, incy,
551 output_profile_result);
552 }
553
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)554 bool ROCMBlas::DoBlasGemmWithProfiling(
555 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
556 uint64_t n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
557 int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
558 DeviceMemory<Eigen::half> *c, int ldc,
559 blas::ProfileResult *output_profile_result) {
560 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
561 lda, b, ldb, beta, c, ldc,
562 output_profile_result);
563 }
564
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)565 bool ROCMBlas::DoBlasGemmWithProfiling(
566 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
567 uint64_t n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
568 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
569 int ldc, blas::ProfileResult *output_profile_result) {
570 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
571 lda, b, ldb, beta, c, ldc,
572 output_profile_result);
573 }
574
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)575 bool ROCMBlas::DoBlasGemmWithProfiling(
576 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
577 uint64_t n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
578 const DeviceMemory<double> &b, int ldb, double beta,
579 DeviceMemory<double> *c, int ldc,
580 blas::ProfileResult *output_profile_result) {
581 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
582 lda, b, ldb, beta, c, ldc,
583 output_profile_result);
584 }
585
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)586 bool ROCMBlas::DoBlasGemmWithProfiling(
587 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
588 uint64_t n, uint64 k, std::complex<float> alpha,
589 const DeviceMemory<std::complex<float>> &a, int lda,
590 const DeviceMemory<std::complex<float>> &b, int ldb,
591 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
592 blas::ProfileResult *output_profile_result) {
593 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
594 lda, b, ldb, beta, c, ldc,
595 output_profile_result);
596 }
597
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)598 bool ROCMBlas::DoBlasGemmWithProfiling(
599 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
600 uint64_t n, uint64 k, std::complex<double> alpha,
601 const DeviceMemory<std::complex<double>> &a, int lda,
602 const DeviceMemory<std::complex<double>> &b, int ldb,
603 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
604 blas::ProfileResult *output_profile_result) {
605 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
606 lda, b, ldb, beta, c, ldc,
607 output_profile_result);
608 }
609
610 template <typename T>
DoBlasGemvWithProfilingImpl(Stream * stream,blas::Transpose trans,uint64_t m,uint64 n,const T & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & x,int incx,const T & beta,DeviceMemory<T> * y,int incy,blas::ProfileResult * output_profile_result)611 bool ROCMBlas::DoBlasGemvWithProfilingImpl(
612 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, const T &alpha,
613 const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
614 const T &beta, DeviceMemory<T> *y, int incy,
615 blas::ProfileResult *output_profile_result) {
616 // ROCM TODO: properly implement the interface
617 return false;
618 }
619
620 template <typename T, typename ParamType>
DoBlasGemmWithProfilingImpl(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,const ParamType & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & b,int ldb,const ParamType & beta,DeviceMemory<T> * c,int ldc,blas::ProfileResult * output_profile_result)621 bool ROCMBlas::DoBlasGemmWithProfilingImpl(
622 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
623 uint64_t n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
624 int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
625 DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
626 // ROCM TODO: properly implement the interface
627 return false;
628 }
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,const void * alpha,const DeviceMemoryBase & a,blas::DataType type_a,int lda,const DeviceMemoryBase & b,blas::DataType type_b,int ldb,const void * beta,DeviceMemoryBase * c,blas::DataType type_c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)629 port::Status ROCMBlas::DoBlasGemmWithAlgorithm(
630 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
631 uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
632 blas::DataType type_a, int lda, const DeviceMemoryBase &b,
633 blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
634 blas::DataType type_c, int ldc, blas::ComputationType computation_type,
635 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
636 // ROCM TODO: properly implement the interface
637 return port::InternalError("Not implemented on ROCm");
638 }
639
DoBlasGemmStridedBatchedWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,const void * alpha,const DeviceMemoryBase & a,blas::DataType type_a,int lda,int64_t stride_a,const DeviceMemoryBase & b,blas::DataType type_b,int ldb,int64_t stride_b,const void * beta,DeviceMemoryBase * c,blas::DataType type_c,int ldc,int64_t stride_c,int batch_count,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)640 port::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm(
641 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
642 uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
643 blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b,
644 blas::DataType type_b, int ldb, int64_t stride_b, const void *beta,
645 DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
646 int batch_count, blas::ComputationType computation_type,
647 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
648 // ROCM TODO: properly implement the interface
649 return port::InternalError("Not implemented on ROCm");
650 }
651
GetBlasGemmAlgorithms(Stream * stream,std::vector<blas::AlgorithmType> * out_algorithms)652 bool ROCMBlas::GetBlasGemmAlgorithms(
653 Stream *stream, std::vector<blas::AlgorithmType> *out_algorithms) {
654 // ROCM TODO: properly implement the interface
655 return true;
656 }
657
658 // This copies from source memory: raw_ptrs[i] to target memory:
659 // device_memory_ptr at the interval of matrix_byte_size, or vice versa.
660 // The below algorithm tries to minimize the number of memcpy by consolidating
661 // neighboring memcpy into a single request
662 template <typename MAPPED_T>
ReorganizeMemory(Stream * stream,DeviceMemory<MAPPED_T> * device_memory,const std::vector<MAPPED_T * > & raw_ptrs,int batch_count,uint64_t batch_stride,bool gather)663 port::Status ReorganizeMemory(Stream *stream,
664 DeviceMemory<MAPPED_T> *device_memory,
665 const std::vector<MAPPED_T *> &raw_ptrs,
666 int batch_count, uint64_t batch_stride,
667 bool gather) {
668 assert(batch_count > 0);
669 char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
670 char *src_ptr = reinterpret_cast<char *>(raw_ptrs[0]);
671 char *dst_ptr = device_memory_ptr;
672 size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
673 uint64_t cur_stride_size = matrix_byte_size;
674
675 for (int i = 1; i < batch_count; ++i) {
676 if (reinterpret_cast<char *>(raw_ptrs[i]) == src_ptr + cur_stride_size) {
677 cur_stride_size += matrix_byte_size;
678 } else {
679 DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
680 DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
681 bool a_status =
682 gather
683 ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
684 : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
685 if (!a_status) {
686 return port::Status(
687 port::error::INTERNAL,
688 "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
689 }
690 src_ptr = reinterpret_cast<char *>(raw_ptrs[i]);
691 dst_ptr = device_memory_ptr + i * matrix_byte_size;
692 cur_stride_size = matrix_byte_size;
693 }
694 }
695
696 DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
697 DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
698 bool a_status =
699 gather ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
700 : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
701 if (!a_status)
702 return port::Status(
703 port::error::INTERNAL,
704 "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
705 return port::Status::OK();
706 }
707
708 template <typename T>
AllocateStridedBuffer(const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type * > & raw_ptrs,int batch_count,uint64_t batch_stride,ScratchAllocator * scratch_allocator,Stream * stream,std::unique_ptr<TemporaryDeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>> * temp_memory,DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type> * device_memory,bool copy_data,bool & reallocated)709 port::Status ROCMBlas::AllocateStridedBuffer(
710 const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
711 &raw_ptrs,
712 int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator,
713 Stream *stream,
714 std::unique_ptr<TemporaryDeviceMemory<
715 typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
716 DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
717 *device_memory,
718 bool copy_data, bool &reallocated) {
719 assert(device_memory != nullptr);
720
721 using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
722
723 bool needs_allocate_strided = false;
724 for (int i = 1; i < batch_count; ++i) {
725 uint64_t tmp_batch_stride = raw_ptrs[i] - raw_ptrs[i - 1];
726 if (tmp_batch_stride != batch_stride) {
727 needs_allocate_strided = true;
728 break;
729 }
730 }
731
732 size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
733 size_t matrix_batch_byte_size = matrix_byte_size * batch_count;
734
735 // No need to do re-allocation, take the short cut and return
736 if (!needs_allocate_strided) {
737 *device_memory = DeviceMemory<MAPPED_T>(
738 DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size));
739 reallocated = false;
740 return port::Status::OK();
741 }
742
743 if (scratch_allocator != nullptr) {
744 TF_ASSIGN_OR_RETURN(
745 DeviceMemory<uint8> batch_matrix_bytes,
746 scratch_allocator->AllocateBytes(matrix_batch_byte_size));
747 *device_memory = DeviceMemory<MAPPED_T>(batch_matrix_bytes);
748 } else {
749 assert(temp_memory != nullptr);
750 TF_ASSIGN_OR_RETURN(*temp_memory, stream->AllocateTemporaryArray<MAPPED_T>(
751 matrix_batch_byte_size));
752 *device_memory =
753 DeviceMemory<MAPPED_T>(*(*temp_memory)->mutable_device_memory());
754 }
755
756 reallocated = true;
757
758 if (copy_data)
759 return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count,
760 batch_stride, true);
761 return port::Status::OK();
762 }
763
764 template <typename T, typename FuncT>
DoBlasGemmBatchedInternal(FuncT rocblas_func,Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64 k,T alpha,const absl::Span<DeviceMemory<T> * const> & a_ptrs_to_wrappers,int lda,const absl::Span<DeviceMemory<T> * const> & b_ptrs_to_wrappers,int ldb,T beta,const absl::Span<DeviceMemory<T> * const> & c_ptrs_to_wrappers,int ldc,int batch_count,ScratchAllocator * scratch_allocator)765 port::Status ROCMBlas::DoBlasGemmBatchedInternal(
766 FuncT rocblas_func, Stream *stream, blas::Transpose transa,
767 blas::Transpose transb, uint64_t m, uint64 n, uint64 k, T alpha,
768 const absl::Span<DeviceMemory<T> *const> &a_ptrs_to_wrappers, int lda,
769 const absl::Span<DeviceMemory<T> *const> &b_ptrs_to_wrappers, int ldb,
770 T beta, const absl::Span<DeviceMemory<T> *const> &c_ptrs_to_wrappers,
771 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
772 using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
773
774 // Sanity checks before making any further progress
775 uint64_t batch_stride_a = 0;
776 uint64_t batch_stride_b = 0;
777 uint64_t batch_stride_c = 0;
778
779 assert(ldc >= m);
780 batch_stride_c = ldc * n;
781
782 if (ROCMBlasTranspose(transa) == rocblas_operation_none) {
783 assert(lda >= m);
784 batch_stride_a = lda * k;
785 } else {
786 assert(lda >= k);
787 batch_stride_a = lda * m;
788 }
789
790 if (ROCMBlasTranspose(transb) == rocblas_operation_none) {
791 assert(ldb >= k);
792 batch_stride_b = ldb * n;
793 } else {
794 assert(ldb >= n);
795 batch_stride_b = ldb * k;
796 }
797
798 // Allocate local vectors to hold device pointers to matrices
799 std::vector<MAPPED_T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
800 for (int i = 0; i < batch_count; ++i) {
801 // static_cast does work when converting Eigen::half* to rocblas_half*,
802 // hence the use of reinterpret_cast
803 a_raw_ptrs.push_back(
804 reinterpret_cast<MAPPED_T *>(a_ptrs_to_wrappers[i]->opaque()));
805 b_raw_ptrs.push_back(
806 reinterpret_cast<MAPPED_T *>(b_ptrs_to_wrappers[i]->opaque()));
807 c_raw_ptrs.push_back(
808 reinterpret_cast<MAPPED_T *>(c_ptrs_to_wrappers[i]->opaque()));
809 }
810
811 DeviceMemory<MAPPED_T> a;
812 // Make sure the temporary memory are in-scope before the function returns
813 std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> a_temp;
814 bool reallocated_a, reallocated_b, reallocated_c;
815 port::Status a_allocation_status = AllocateStridedBuffer<T>(
816 a_raw_ptrs, batch_count, batch_stride_a, scratch_allocator, stream,
817 &a_temp, &a, true, reallocated_a);
818 if (a_allocation_status != port::Status::OK()) {
819 return a_allocation_status;
820 }
821
822 DeviceMemory<MAPPED_T> b;
823 std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> b_temp;
824 port::Status b_allocation_status = AllocateStridedBuffer<T>(
825 b_raw_ptrs, batch_count, batch_stride_b, scratch_allocator, stream,
826 &b_temp, &b, true, reallocated_b);
827 if (b_allocation_status != port::Status::OK()) {
828 return b_allocation_status;
829 }
830
831 DeviceMemory<MAPPED_T> c;
832 std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> c_temp;
833 port::Status c_allocation_status = AllocateStridedBuffer<T>(
834 c_raw_ptrs, batch_count, batch_stride_c, scratch_allocator, stream,
835 &c_temp, &c, true, reallocated_c); // can disable copy if beta=0
836 if (c_allocation_status != port::Status::OK()) {
837 return c_allocation_status;
838 }
839
840 MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
841 MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
842
843 bool ok;
844 ok = DoBlasInternal(rocblas_func, stream, /* pointer_mode_host = */ true,
845 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
846 n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
847 batch_stride_a, GpuMemory(b), ldb, batch_stride_b,
848 GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc,
849 batch_stride_c, batch_count);
850 if (!ok)
851 return port::Status(port::error::INTERNAL,
852 "failed BLAS call, see log for details");
853 if (reallocated_c)
854 return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count, batch_stride_c,
855 false);
856 return port::Status::OK();
857 }
858
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,float alpha,const absl::Span<DeviceMemory<Eigen::half> * const> & a,int lda,const absl::Span<DeviceMemory<Eigen::half> * const> & b,int ldb,float beta,const absl::Span<DeviceMemory<Eigen::half> * const> & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)859 bool ROCMBlas::DoBlasGemmBatched(
860 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
861 uint64_t n, uint64 k, float alpha,
862 const absl::Span<DeviceMemory<Eigen::half> *const> &a, int lda,
863 const absl::Span<DeviceMemory<Eigen::half> *const> &b, int ldb, float beta,
864 const absl::Span<DeviceMemory<Eigen::half> *const> &c, int ldc,
865 int batch_count, ScratchAllocator *scratch_allocator) {
866 blas_log("DoBlasGemmBatched");
867 const Eigen::half alpha_half(alpha);
868 const Eigen::half beta_half(beta);
869
870 port::Status status = DoBlasGemmBatchedInternal(
871 wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k,
872 alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count,
873 scratch_allocator);
874 if (!status.ok()) {
875 LOG(ERROR) << status;
876 }
877
878 return status.ok();
879 }
880
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,float alpha,const absl::Span<DeviceMemory<float> * const> & a_array,int lda,const absl::Span<DeviceMemory<float> * const> & b_array,int ldb,float beta,const absl::Span<DeviceMemory<float> * const> & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)881 bool ROCMBlas::DoBlasGemmBatched(
882 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
883 uint64_t n, uint64 k, float alpha,
884 const absl::Span<DeviceMemory<float> *const> &a_array, int lda,
885 const absl::Span<DeviceMemory<float> *const> &b_array, int ldb, float beta,
886 const absl::Span<DeviceMemory<float> *const> &c_array, int ldc,
887 int batch_count, ScratchAllocator *scratch_allocator) {
888 blas_log("DoBlasGemmBatched");
889 port::Status status = DoBlasGemmBatchedInternal(
890 wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
891 alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
892 scratch_allocator);
893 if (!status.ok()) {
894 LOG(ERROR) << status;
895 }
896 return status.ok();
897 }
898
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,double alpha,const absl::Span<DeviceMemory<double> * const> & a_array,int lda,const absl::Span<DeviceMemory<double> * const> & b_array,int ldb,double beta,const absl::Span<DeviceMemory<double> * const> & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)899 bool ROCMBlas::DoBlasGemmBatched(
900 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
901 uint64_t n, uint64 k, double alpha,
902 const absl::Span<DeviceMemory<double> *const> &a_array, int lda,
903 const absl::Span<DeviceMemory<double> *const> &b_array, int ldb,
904 double beta, const absl::Span<DeviceMemory<double> *const> &c_array,
905 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
906 blas_log("DoBlasGemmBatched");
907 port::Status status = DoBlasGemmBatchedInternal(
908 wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
909 alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
910 scratch_allocator);
911 if (!status.ok()) {
912 LOG(ERROR) << status;
913 }
914 return status.ok();
915 }
916
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,std::complex<float> alpha,const absl::Span<DeviceMemory<std::complex<float>> * const> & a_array,int lda,const absl::Span<DeviceMemory<std::complex<float>> * const> & b_array,int ldb,std::complex<float> beta,const absl::Span<DeviceMemory<std::complex<float>> * const> & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)917 bool ROCMBlas::DoBlasGemmBatched(
918 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
919 uint64_t n, uint64 k, std::complex<float> alpha,
920 const absl::Span<DeviceMemory<std::complex<float>> *const> &a_array,
921 int lda,
922 const absl::Span<DeviceMemory<std::complex<float>> *const> &b_array,
923 int ldb, std::complex<float> beta,
924 const absl::Span<DeviceMemory<std::complex<float>> *const> &c_array,
925 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
926 blas_log("DoBlasGemmBatched");
927 port::Status status = DoBlasGemmBatchedInternal(
928 wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k,
929 alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
930 scratch_allocator);
931 if (!status.ok()) {
932 LOG(ERROR) << status;
933 }
934 return status.ok();
935 }
936
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,std::complex<double> alpha,const absl::Span<DeviceMemory<std::complex<double>> * const> & a_array,int lda,const absl::Span<DeviceMemory<std::complex<double>> * const> & b_array,int ldb,std::complex<double> beta,const absl::Span<DeviceMemory<std::complex<double>> * const> & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)937 bool ROCMBlas::DoBlasGemmBatched(
938 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
939 uint64_t n, uint64 k, std::complex<double> alpha,
940 const absl::Span<DeviceMemory<std::complex<double>> *const> &a_array,
941 int lda,
942 const absl::Span<DeviceMemory<std::complex<double>> *const> &b_array,
943 int ldb, std::complex<double> beta,
944 const absl::Span<DeviceMemory<std::complex<double>> *const> &c_array,
945 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
946 blas_log("DoBlasGemmBatched");
947 port::Status status = DoBlasGemmBatchedInternal(
948 wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k,
949 alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
950 scratch_allocator);
951 if (!status.ok()) {
952 LOG(ERROR) << status;
953 }
954 return status.ok();
955 }
956
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)957 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
958 blas::UpperLower uplo, blas::Transpose transa,
959 blas::Diagonal diag, uint64_t m, uint64 n,
960 float alpha, const DeviceMemory<float> &a, int lda,
961 DeviceMemory<float> *b, int ldb) {
962 blas_log("DoBlasTrsm");
963 return DoBlasInternal(wrap::rocblas_strsm, stream,
964 /* pointer_mode_host = */ true, ROCMBlasSide(side),
965 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
966 ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
967 GpuMemoryMutable(b), ldb);
968 }
969
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)970 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
971 blas::UpperLower uplo, blas::Transpose transa,
972 blas::Diagonal diag, uint64_t m, uint64 n,
973 double alpha, const DeviceMemory<double> &a, int lda,
974 DeviceMemory<double> *b, int ldb) {
975 blas_log("DoBlasTrsm");
976 return DoBlasInternal(wrap::rocblas_dtrsm, stream,
977 /* pointer_mode_host = */ true, ROCMBlasSide(side),
978 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
979 ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
980 GpuMemoryMutable(b), ldb);
981 }
982
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)983 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
984 blas::UpperLower uplo, blas::Transpose transa,
985 blas::Diagonal diag, uint64_t m, uint64 n,
986 std::complex<float> alpha,
987 const DeviceMemory<std::complex<float>> &a, int lda,
988 DeviceMemory<std::complex<float>> *b, int ldb) {
989 return DoBlasInternal(wrap::rocblas_ctrsm, stream,
990 /* pointer_mode_host = */ true, ROCMBlasSide(side),
991 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
992 ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
993 complex_cast(a), lda, complex_cast(b), ldb);
994 }
995
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)996 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
997 blas::UpperLower uplo, blas::Transpose transa,
998 blas::Diagonal diag, uint64_t m, uint64 n,
999 std::complex<double> alpha,
1000 const DeviceMemory<std::complex<double>> &a, int lda,
1001 DeviceMemory<std::complex<double>> *b, int ldb) {
1002 return DoBlasInternal(wrap::rocblas_ztrsm, stream,
1003 /* pointer_mode_host = */ true, ROCMBlasSide(side),
1004 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
1005 ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
1006 complex_cast(a), lda, complex_cast(b), ldb);
1007 }
1008
DoBlasTrsmBatched(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,float alpha,const DeviceMemory<float * > & as,int lda,DeviceMemory<float * > * bs,int ldb,int batch_count)1009 bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1010 blas::UpperLower uplo, blas::Transpose transa,
1011 blas::Diagonal diag, uint64_t m, uint64 n,
1012 float alpha, const DeviceMemory<float *> &as,
1013 int lda, DeviceMemory<float *> *bs, int ldb,
1014 int batch_count) {
1015 return DoBlasInternal(wrap::rocblas_strsm_batched, stream,
1016 true /* = pointer_mode_host */, ROCMBlasSide(side),
1017 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
1018 ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(as),
1019 lda, GpuMemoryMutable(bs), ldb, batch_count);
1020 }
1021
DoBlasTrsmBatched(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,double alpha,const DeviceMemory<double * > & as,int lda,DeviceMemory<double * > * bs,int ldb,int batch_count)1022 bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1023 blas::UpperLower uplo, blas::Transpose transa,
1024 blas::Diagonal diag, uint64_t m, uint64 n,
1025 double alpha, const DeviceMemory<double *> &as,
1026 int lda, DeviceMemory<double *> *bs, int ldb,
1027 int batch_count) {
1028 return DoBlasInternal(wrap::rocblas_dtrsm_batched, stream,
1029 true /* = pointer_mode_host */, ROCMBlasSide(side),
1030 ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
1031 ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(as),
1032 lda, GpuMemoryMutable(bs), ldb, batch_count);
1033 }
1034
DoBlasTrsmBatched(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float> * > & as,int lda,DeviceMemory<std::complex<float> * > * bs,int ldb,int batch_count)1035 bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1036 blas::UpperLower uplo, blas::Transpose transa,
1037 blas::Diagonal diag, uint64_t m, uint64 n,
1038 std::complex<float> alpha,
1039 const DeviceMemory<std::complex<float> *> &as,
1040 int lda,
1041 DeviceMemory<std::complex<float> *> *bs,
1042 int ldb, int batch_count) {
1043 return DoBlasInternal(
1044 wrap::rocblas_ctrsm_batched, stream, true /* = pointer_mode_host */,
1045 ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
1046 ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
1047 static_cast<const rocblas_float_complex *const *>(as.opaque()), lda,
1048 static_cast<rocblas_float_complex *const *>(bs->opaque()), ldb,
1049 batch_count);
1050 }
1051
DoBlasTrsmBatched(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double> * > & as,int lda,DeviceMemory<std::complex<double> * > * bs,int ldb,int batch_count)1052 bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1053 blas::UpperLower uplo, blas::Transpose transa,
1054 blas::Diagonal diag, uint64_t m, uint64 n,
1055 std::complex<double> alpha,
1056 const DeviceMemory<std::complex<double> *> &as,
1057 int lda,
1058 DeviceMemory<std::complex<double> *> *bs,
1059 int ldb, int batch_count) {
1060 return DoBlasInternal(
1061 wrap::rocblas_ztrsm_batched, stream, true /* = pointer_mode_host */,
1062 ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
1063 ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
1064 static_cast<const rocblas_double_complex *const *>(as.opaque()), lda,
1065 static_cast<rocblas_double_complex *const *>(bs->opaque()), ldb,
1066 batch_count);
1067 }
1068
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,blas::DataType dtype,const void * alpha,const DeviceMemoryBase & a,int lda,int64_t stride_a,const DeviceMemoryBase & b,int ldb,int64_t stride_b,const void * beta,DeviceMemoryBase * c,int ldc,int64_t stride_c,int batch_count)1069 port::Status ROCMBlas::DoBlasGemmStridedBatched(
1070 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1071 uint64_t n, uint64 k, blas::DataType dtype, const void *alpha,
1072 const DeviceMemoryBase &a, int lda, int64_t stride_a,
1073 const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
1074 DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count) {
1075 VLOG(1) << absl::StreamFormat(
1076 "doing rocBLAS SGEMM Strided Batched<float>: at=%d bt=%d m=%u n=%u "
1077 "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p "
1078 "c=%p ldc=%d",
1079 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1080 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1081
1082 switch (dtype) {
1083 case blas::DataType::kHalf: {
1084 const Eigen::half alpha_half(*static_cast<const float *>(alpha));
1085 const Eigen::half beta_half(*static_cast<const float *>(beta));
1086 return DoBlasInternalStatus(
1087 wrap::rocblas_hgemm_strided_batched, stream,
1088 false, /* pointer_mode_host */
1089 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1090 reinterpret_cast<const rocblas_half *>(&alpha_half),
1091 reinterpret_cast<const rocblas_half *>(a.opaque()), lda, stride_a,
1092 reinterpret_cast<const rocblas_half *>(b.opaque()), ldb, stride_b,
1093 reinterpret_cast<const rocblas_half *>(&beta_half),
1094 reinterpret_cast<rocblas_half *>(c->opaque()), ldc, stride_c,
1095 batch_count);
1096 }
1097 case blas::DataType::kBF16:
1098 return DoBlasInternalStatus(
1099 wrap::rocblas_gemm_strided_batched_ex, stream,
1100 false, /* pointer_mode_host */
1101 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, alpha,
1102 a.opaque(), rocblas_datatype_bf16_r, lda, stride_a, b.opaque(),
1103 rocblas_datatype_bf16_r, ldb, stride_b, beta, c->opaque(),
1104 rocblas_datatype_bf16_r, ldc, stride_c, c->opaque(),
1105 rocblas_datatype_bf16_r, ldc, stride_c, batch_count,
1106 rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0);
1107 case blas::DataType::kFloat:
1108 return DoBlasInternalStatus(
1109 wrap::rocblas_sgemm_strided_batched, stream,
1110 false, /* pointer_mode_host */
1111 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1112 reinterpret_cast<const float *>(alpha),
1113 reinterpret_cast<const float *>(a.opaque()), lda, stride_a,
1114 reinterpret_cast<const float *>(b.opaque()), ldb, stride_b,
1115 reinterpret_cast<const float *>(beta),
1116 reinterpret_cast<float *>(c->opaque()), ldc, stride_c, batch_count);
1117 case blas::DataType::kDouble:
1118 return DoBlasInternalStatus(
1119 wrap::rocblas_dgemm_strided_batched, stream,
1120 false, /* pointer_mode_host */
1121 ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1122 reinterpret_cast<const double *>(alpha),
1123 reinterpret_cast<const double *>(a.opaque()), lda, stride_a,
1124 reinterpret_cast<const double *>(b.opaque()), ldb, stride_b,
1125 reinterpret_cast<const double *>(beta),
1126 reinterpret_cast<double *>(c->opaque()), ldc, stride_c, batch_count);
1127 case blas::DataType::kComplexFloat: {
1128 auto cb_alpha =
1129 complex_cast(*static_cast<const std::complex<float> *>(alpha));
1130 auto cb_beta =
1131 complex_cast(*static_cast<const std::complex<float> *>(beta));
1132 return DoBlasInternalStatus(
1133 wrap::rocblas_cgemm_strided_batched, stream,
1134 /* pointer_mode_host = */ true, ROCMBlasTranspose(transa),
1135 ROCMBlasTranspose(transb), m, n, k, cb_alpha,
1136 static_cast<const rocblas_float_complex *>(a.opaque()), lda, stride_a,
1137 static_cast<const rocblas_float_complex *>(b.opaque()), ldb, stride_b,
1138 cb_beta, static_cast<rocblas_float_complex *>(c->opaque()), ldc,
1139 stride_c, batch_count);
1140 }
1141 case blas::DataType::kComplexDouble: {
1142 auto cb_alpha =
1143 complex_cast(*static_cast<const std::complex<double> *>(alpha));
1144 auto cb_beta =
1145 complex_cast(*static_cast<const std::complex<double> *>(beta));
1146 return DoBlasInternalStatus(
1147 wrap::rocblas_zgemm_strided_batched, stream,
1148 /* pointer_mode_host = */ true, ROCMBlasTranspose(transa),
1149 ROCMBlasTranspose(transb), m, n, k, cb_alpha,
1150 static_cast<const rocblas_double_complex *>(a.opaque()), lda,
1151 stride_a, static_cast<const rocblas_double_complex *>(b.opaque()),
1152 ldb, stride_b, cb_beta,
1153 static_cast<rocblas_double_complex *>(c->opaque()), ldc, stride_c,
1154 batch_count);
1155 }
1156 default:
1157 return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
1158 blas::DataTypeString(dtype)));
1159 }
1160 }
1161
GetVersion(string * version)1162 port::Status ROCMBlas::GetVersion(string *version) {
1163 return port::UnimplementedError("");
1164 }
1165
1166 } // namespace gpu
1167
initialize_rocblas()1168 void initialize_rocblas() {
1169 auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
1170 rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
1171
1172 if (!rocBlasAlreadyRegistered) {
1173 port::Status status =
1174 PluginRegistry::Instance()
1175 ->RegisterFactory<PluginRegistry::BlasFactory>(
1176 rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
1177 [](internal::StreamExecutorInterface *parent)
1178 -> blas::BlasSupport * {
1179 gpu::GpuExecutor *rocm_executor =
1180 dynamic_cast<gpu::GpuExecutor *>(parent);
1181 if (rocm_executor == nullptr) {
1182 LOG(ERROR)
1183 << "Attempting to initialize an instance of the "
1184 "rocBLAS "
1185 << "support library with a non-ROCM StreamExecutor";
1186 return nullptr;
1187 }
1188
1189 gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
1190 if (!blas->Init()) {
1191 // Note: Init() will log a more specific error.
1192 delete blas;
1193 return nullptr;
1194 }
1195 return blas;
1196 });
1197
1198 if (!status.ok()) {
1199 LOG(ERROR) << "Unable to register rocBLAS factory: "
1200 << status.error_message();
1201 }
1202
1203 PluginRegistry::Instance()->SetDefaultFactory(
1204 rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
1205 }
1206 }
1207
1208 } // namespace stream_executor
1209
1210 REGISTER_MODULE_INITIALIZER(register_rocblas,
1211 { stream_executor::initialize_rocblas(); });
1212