xref: /aosp_15_r20/external/tensorflow/tensorflow/stream_executor/rocm/rocm_blas.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/rocm/rocm_blas.h"
17 
18 #include "tensorflow/stream_executor/rocm/rocblas_wrapper.h"
19 
20 #define EIGEN_USE_GPU
21 #include <assert.h>
22 
23 #include <complex>
24 
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_format.h"
27 #include "absl/types/span.h"
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
29 #include "tensorflow/stream_executor/device_memory.h"
30 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
31 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
32 #include "tensorflow/stream_executor/gpu/gpu_helpers.h"
33 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
34 #include "tensorflow/stream_executor/gpu/gpu_timer.h"
35 #include "tensorflow/stream_executor/lib/env.h"
36 #include "tensorflow/stream_executor/lib/initialize.h"
37 #include "tensorflow/stream_executor/lib/status.h"
38 #include "tensorflow/stream_executor/platform/dso_loader.h"
39 #include "tensorflow/stream_executor/platform/logging.h"
40 #include "tensorflow/stream_executor/platform/port.h"
41 #include "tensorflow/stream_executor/plugin_registry.h"
42 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
43 #include "tensorflow/stream_executor/scratch_allocator.h"
44 #include "tensorflow/stream_executor/stream_executor.h"
45 
46 namespace stream_executor {
47 namespace gpu {
48 
49 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
50 
51 namespace wrap = tensorflow::wrap;
52 
53 template <class T>
complex_cast(const DeviceMemory<T> & a)54 const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
55     const DeviceMemory<T> &a) {
56   return reinterpret_cast<
57       const typename RocBlasTypeConversionHelper<T>::mapped_type *>(
58       GpuMemory(a));
59 }
60 
61 template <class T>
complex_cast(const T & a)62 const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
63     const T &a) {
64   return reinterpret_cast<
65       const typename RocBlasTypeConversionHelper<T>::mapped_type *>(&a);
66 }
67 template <class T>
complex_cast(DeviceMemory<T> * a)68 typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
69     DeviceMemory<T> *a) {
70   return reinterpret_cast<
71       typename RocBlasTypeConversionHelper<T>::mapped_type *>(
72       GpuMemoryMutable(a));
73 }
74 
// Logging hook invoked by the BLAS entry points in this file. It is a
// deliberate no-op: the tag passed in is discarded.
static void blas_log(const char *c) {
  (void)c;
}
76 
ToString(rocblas_status status)77 static string ToString(rocblas_status status) {
78   switch (status) {
79     case rocblas_status_success:
80       return "rocblas_status_success";
81     case rocblas_status_invalid_handle:
82       return "rocblas_status_invalid_handle";
83     case rocblas_status_not_implemented:
84       return "rocblas_status_not_implemented";
85     case rocblas_status_invalid_pointer:
86       return "rocblas_status_invalid_pointer";
87     case rocblas_status_invalid_size:
88       return "rocblas_status_invalid_size";
89     case rocblas_status_memory_error:
90       return "rocblas_status_memory_error";
91     case rocblas_status_internal_error:
92       return "rocblas_status_internal_error";
93     default:
94       return absl::StrCat("<invalid rocBLAS status: ", status, ">");
95   }
96 }
97 
Init()98 bool ROCMBlas::Init() {
99   gpu::ScopedActivateExecutorContext sac{parent_};
100   rocblas_status ret = wrap::rocblas_create_handle(&blas_);
101   if (ret != rocblas_status_success) {
102     LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret);
103     return false;
104   }
105 
106   return true;
107 }
108 
ROCMBlas(gpu::GpuExecutor * parent)109 ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
110     : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
111 
~ROCMBlas()112 ROCMBlas::~ROCMBlas() {
113   if (blas_ != nullptr) {
114     gpu::ScopedActivateExecutorContext sac{parent_};
115     wrap::rocblas_destroy_handle(blas_);
116   }
117 }
118 
SetStream(Stream * stream)119 bool ROCMBlas::SetStream(Stream *stream) {
120   CHECK(stream != nullptr);
121   CHECK(AsGpuStreamValue(stream) != nullptr);
122   CHECK(blas_ != nullptr);
123   gpu::ScopedActivateExecutorContext sac{parent_};
124   rocblas_status ret =
125       wrap::rocblas_set_stream(blas_, AsGpuStreamValue(stream));
126   if (ret != rocblas_status_success) {
127     LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret);
128     return false;
129   }
130 
131   return true;
132 }
133 
134 namespace {
135 
136 // Helper functions transforming blas arguments into rocBLAS arguments.
137 
ROCMBlasTranspose(blas::Transpose trans)138 rocblas_operation ROCMBlasTranspose(blas::Transpose trans) {
139   switch (trans) {
140     case blas::Transpose::kNoTranspose:
141       return rocblas_operation_none;
142     case blas::Transpose::kTranspose:
143       return rocblas_operation_transpose;
144     case blas::Transpose::kConjugateTranspose:
145       return rocblas_operation_conjugate_transpose;
146     default:
147       LOG(FATAL) << "Invalid value of blas::Transpose.";
148   }
149 }
150 
ROCMBlasUpperLower(blas::UpperLower uplo)151 rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) {
152   switch (uplo) {
153     case blas::UpperLower::kUpper:
154       return rocblas_fill_upper;
155     case blas::UpperLower::kLower:
156       return rocblas_fill_lower;
157     default:
158       LOG(FATAL) << "Invalid value of blas::UpperLower.";
159   }
160 }
161 
ROCMBlasDiagonal(blas::Diagonal diag)162 rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) {
163   switch (diag) {
164     case blas::Diagonal::kUnit:
165       return rocblas_diagonal_unit;
166     case blas::Diagonal::kNonUnit:
167       return rocblas_diagonal_non_unit;
168     default:
169       LOG(FATAL) << "Invalid value of blas::Diagonal.";
170   }
171 }
172 
ROCMBlasSide(blas::Side side)173 rocblas_side ROCMBlasSide(blas::Side side) {
174   switch (side) {
175     case blas::Side::kLeft:
176       return rocblas_side_left;
177     case blas::Side::kRight:
178       return rocblas_side_right;
179     default:
180       LOG(FATAL) << "Invalid value of blas::Side.";
181   }
182 }
183 
184 }  // namespace
185 
186 template <typename FuncT, typename... Args>
DoBlasInternalImpl(FuncT rocblas_func,Stream * stream,bool pointer_mode_host,bool err_on_failure,Args...args)187 bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
188                                   bool pointer_mode_host, bool err_on_failure,
189                                   Args... args) {
190   absl::MutexLock lock{&mu_};
191 
192   CHECK(blas_ != nullptr);
193   if (!SetStream(stream)) {
194     return false;
195   }
196 
197   gpu::ScopedActivateExecutorContext sac{parent_};
198   rocblas_status ret = rocblas_func(blas_, args...);
199   if (err_on_failure && ret != rocblas_status_success) {
200     LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": "
201                << ToString(ret);
202   }
203   return ret == rocblas_status_success;
204 }
205 
DoBlasAxpy(Stream * stream,uint64_t elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)206 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha,
207                           const DeviceMemory<float> &x, int incx,
208                           DeviceMemory<float> *y, int incy) {
209   blas_log("DoBlasAxpy");
210   return DoBlasInternal(wrap::rocblas_saxpy, stream,
211                         /* pointer_mode_host = */ true, elem_count, &alpha,
212                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
213 }
214 
DoBlasAxpy(Stream * stream,uint64_t elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)215 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha,
216                           const DeviceMemory<double> &x, int incx,
217                           DeviceMemory<double> *y, int incy) {
218   blas_log("DoBlasAxpy");
219   return DoBlasInternal(wrap::rocblas_daxpy, stream,
220                         /* pointer_mode_host = */ true, elem_count, &alpha,
221                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
222 }
223 
DoBlasAxpy(Stream * stream,uint64_t elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)224 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count,
225                           std::complex<float> alpha,
226                           const DeviceMemory<std::complex<float>> &x, int incx,
227                           DeviceMemory<std::complex<float>> *y, int incy) {
228   return DoBlasInternal(
229       wrap::rocblas_caxpy, stream, /* pointer_mode_host = */ true, elem_count,
230       complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy);
231 }
232 
DoBlasAxpy(Stream * stream,uint64_t elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)233 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64_t elem_count,
234                           std::complex<double> alpha,
235                           const DeviceMemory<std::complex<double>> &x, int incx,
236                           DeviceMemory<std::complex<double>> *y, int incy) {
237   return DoBlasInternal(
238       wrap::rocblas_zaxpy, stream, /* pointer_mode_host = */ true, elem_count,
239       complex_cast(alpha), complex_cast(x), incx, complex_cast(y), incy);
240 }
241 
DoBlasCopy(Stream * stream,uint64_t elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)242 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
243                           const DeviceMemory<float> &x, int incx,
244                           DeviceMemory<float> *y, int incy) {
245   return DoBlasInternal(wrap::rocblas_scopy, stream,
246                         /* pointer_mode_host = */ true, elem_count,
247                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
248 }
249 
DoBlasCopy(Stream * stream,uint64_t elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)250 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
251                           const DeviceMemory<double> &x, int incx,
252                           DeviceMemory<double> *y, int incy) {
253   return DoBlasInternal(wrap::rocblas_dcopy, stream,
254                         /* pointer_mode_host = */ true, elem_count,
255                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
256 }
257 
DoBlasCopy(Stream * stream,uint64_t elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)258 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
259                           const DeviceMemory<std::complex<float>> &x, int incx,
260                           DeviceMemory<std::complex<float>> *y, int incy) {
261   return DoBlasInternal(wrap::rocblas_ccopy, stream,
262                         /* pointer_mode_host = */ true, elem_count,
263                         complex_cast(x), incx, complex_cast(y), incy);
264 }
265 
DoBlasCopy(Stream * stream,uint64_t elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)266 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
267                           const DeviceMemory<std::complex<double>> &x, int incx,
268                           DeviceMemory<std::complex<double>> *y, int incy) {
269   return DoBlasInternal(wrap::rocblas_zcopy, stream,
270                         /* pointer_mode_host = */ true, elem_count,
271                         complex_cast(x), incx, complex_cast(y), incy);
272 }
273 
DoBlasScal(Stream * stream,uint64_t elem_count,float alpha,DeviceMemory<float> * x,int incx)274 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
275                           DeviceMemory<float> *x, int incx) {
276   blas_log("DoBlasScal<float>");
277   return DoBlasInternal(wrap::rocblas_sscal, stream,
278                         /* pointer_mode_host = */ true, elem_count, &alpha,
279                         GpuMemoryMutable(x), incx);
280 }
281 
DoBlasScal(Stream * stream,uint64_t elem_count,double alpha,DeviceMemory<double> * x,int incx)282 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,
283                           DeviceMemory<double> *x, int incx) {
284   return DoBlasInternal(wrap::rocblas_dscal, stream,
285                         /* pointer_mode_host = */ true, elem_count, &alpha,
286                         GpuMemoryMutable(x), incx);
287 }
288 
DoBlasScal(Stream * stream,uint64_t elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)289 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
290                           DeviceMemory<std::complex<float>> *x, int incx) {
291   return DoBlasInternal(wrap::rocblas_csscal, stream,
292                         /* pointer_mode_host = */ true, elem_count, &alpha,
293                         complex_cast(x), incx);
294 }
295 
DoBlasScal(Stream * stream,uint64_t elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)296 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,
297                           DeviceMemory<std::complex<double>> *x, int incx) {
298   return DoBlasInternal(wrap::rocblas_zdscal, stream,
299                         /* pointer_mode_host = */ true, elem_count, &alpha,
300                         complex_cast(x), incx);
301 }
302 
DoBlasScal(Stream * stream,uint64_t elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)303 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count,
304                           std::complex<float> alpha,
305                           DeviceMemory<std::complex<float>> *x, int incx) {
306   return DoBlasInternal(wrap::rocblas_cscal, stream,
307                         /* pointer_mode_host = */ true, elem_count,
308                         complex_cast(alpha), complex_cast(x), incx);
309 }
310 
DoBlasScal(Stream * stream,uint64_t elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)311 bool ROCMBlas::DoBlasScal(Stream *stream, uint64_t elem_count,
312                           std::complex<double> alpha,
313                           DeviceMemory<std::complex<double>> *x, int incx) {
314   return DoBlasInternal(wrap::rocblas_zscal, stream,
315                         /* pointer_mode_host = */ true, elem_count,
316                         complex_cast(alpha), complex_cast(x), incx);
317 }
318 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64_t m,uint64_t n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)319 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
320                           uint64_t n, float alpha, const DeviceMemory<float> &a,
321                           int lda, const DeviceMemory<float> &x, int incx,
322                           float beta, DeviceMemory<float> *y, int incy) {
323   blas_log("DoBlasGemv");
324   return DoBlasInternal(
325       wrap::rocblas_sgemv, stream, /* pointer_mode_host = */ true,
326       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
327       incx, &beta, GpuMemoryMutable(y), incy);
328 }
329 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64_t m,uint64_t n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)330 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
331                           uint64_t n, double alpha,
332                           const DeviceMemory<double> &a, int lda,
333                           const DeviceMemory<double> &x, int incx, double beta,
334                           DeviceMemory<double> *y, int incy) {
335   blas_log("DoBlasGemv");
336   return DoBlasInternal(
337       wrap::rocblas_dgemv, stream, /* pointer_mode_host = */ true,
338       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
339       incx, &beta, GpuMemoryMutable(y), incy);
340 }
341 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64_t m,uint64_t n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)342 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
343                           uint64_t n, std::complex<float> alpha,
344                           const DeviceMemory<std::complex<float>> &a, int lda,
345                           const DeviceMemory<std::complex<float>> &x, int incx,
346                           std::complex<float> beta,
347                           DeviceMemory<std::complex<float>> *y, int incy) {
348   blas_log("DoBlasGemv");
349   return DoBlasInternal(
350       wrap::rocblas_cgemv, stream, /* pointer_mode_host = */ true,
351       ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
352       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
353 }
354 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64_t m,uint64_t n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)355 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
356                           uint64_t n, std::complex<double> alpha,
357                           const DeviceMemory<std::complex<double>> &a, int lda,
358                           const DeviceMemory<std::complex<double>> &x, int incx,
359                           std::complex<double> beta,
360                           DeviceMemory<std::complex<double>> *y, int incy) {
361   blas_log("DoBlasGemv\n");
362   return DoBlasInternal(
363       wrap::rocblas_zgemv, stream, /* pointer_mode_host = */ true,
364       ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
365       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
366 }
367 
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64_t n,uint64_t k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)368 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n,
369                           uint64_t k, float alpha, const DeviceMemory<float> &a,
370                           int lda, const DeviceMemory<float> &x, int incx,
371                           float beta, DeviceMemory<float> *y, int incy) {
372   return DoBlasInternal(
373       wrap::rocblas_ssbmv, stream, /* pointer_mode_host = */ true,
374       ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x),
375       incx, &beta, GpuMemoryMutable(y), incy);
376 }
377 
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64_t n,uint64_t k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)378 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n,
379                           uint64_t k, double alpha,
380                           const DeviceMemory<double> &a, int lda,
381                           const DeviceMemory<double> &x, int incx, double beta,
382                           DeviceMemory<double> *y, int incy) {
383   return DoBlasInternal(
384       wrap::rocblas_dsbmv, stream, /* pointer_mode_host = */ true,
385       ROCMBlasUpperLower(uplo), n, k, &alpha, GpuMemory(a), lda, GpuMemory(x),
386       incx, &beta, GpuMemoryMutable(y), incy);
387 }
388 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,blas::DataType dtype,const void * alpha,const DeviceMemoryBase & a,int lda,const DeviceMemoryBase & b,int ldb,const void * beta,DeviceMemoryBase * c,int ldc,blas::ComputePrecision precision)389 port::Status ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
390                                   blas::Transpose transb, uint64_t m, uint64 n,
391                                   uint64_t k, blas::DataType dtype,
392                                   const void *alpha, const DeviceMemoryBase &a,
393                                   int lda, const DeviceMemoryBase &b, int ldb,
394                                   const void *beta, DeviceMemoryBase *c,
395                                   int ldc, blas::ComputePrecision precision) {
396   blas_log("DoBlasGemm");
397   VLOG(1) << absl::StreamFormat(
398       "doing rocBLAS GEMM: at=%d bt=%d m=%u n=%u "
399       "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p "
400       "c=%p ldc=%d",
401       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
402       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
403   if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) {
404     if (transa == blas::Transpose::kNoTranspose) {
405       if (lda < static_cast<int64_t>(m)) {
406         LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
407                         "precondition violation";
408       }
409     } else {
410       if (lda < static_cast<int64_t>(k)) {
411         LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
412                      << ") (transpose case); precondition violation";
413       }
414     }
415     if (transb == blas::Transpose::kNoTranspose) {
416       if (ldb < static_cast<int64_t>(k)) {
417         LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
418                      << ") (no transpose case); precondition violation";
419       }
420     } else {
421       if (ldb < static_cast<int64_t>(n)) {
422         LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
423                         "precondition violation";
424       }
425     }
426   }
427 
428   switch (dtype) {
429     case blas::DataType::kHalf: {
430       port::StatusOr<bool> maybe_hasXDLOPS = GpuDriver::GetMFMASupport();
431       if (maybe_hasXDLOPS.ok() && maybe_hasXDLOPS.ValueOrDie()) {
432         VLOG(1) << "Using rocblas_gemm_ex";
433         return DoBlasInternalStatus(
434             wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true,
435             ROCMBlasTranspose(transa), ROCMBlasTranspose(transb),
436             (rocblas_int)m, (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(),
437             rocblas_datatype_f16_r, lda, b.opaque(), rocblas_datatype_f16_r,
438             ldb, beta, c->opaque(), rocblas_datatype_f16_r, ldc, c->opaque(),
439             rocblas_datatype_f16_r, ldc, rocblas_datatype_f32_r,
440             rocblas_gemm_algo_standard, 0, 0);
441       } else {
442         VLOG(1) << "Using rocblas_hgemm";
443         const Eigen::half alpha_half(*static_cast<const float *>(alpha));
444         const Eigen::half beta_half(*static_cast<const float *>(beta));
445         return DoBlasInternalStatus(
446             wrap::rocblas_hgemm, stream, /* pointer_mode_host = */ true,
447             ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
448             reinterpret_cast<const rocblas_half *>(&alpha_half),
449             reinterpret_cast<const rocblas_half *>(a.opaque()), lda,
450             reinterpret_cast<const rocblas_half *>(b.opaque()), ldb,
451             reinterpret_cast<const rocblas_half *>(&beta_half),
452             reinterpret_cast<rocblas_half *>(c->opaque()), ldc);
453       }
454     }
455     case blas::DataType::kBF16:
456       return DoBlasInternalStatus(
457           wrap::rocblas_gemm_ex, stream, /* pointer_mode_host = */ true,
458           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), (rocblas_int)m,
459           (rocblas_int)n, (rocblas_int)k, alpha, a.opaque(),
460           rocblas_datatype_bf16_r, lda, b.opaque(), rocblas_datatype_bf16_r,
461           ldb, beta, c->opaque(), rocblas_datatype_bf16_r, ldc, c->opaque(),
462           rocblas_datatype_bf16_r, ldc, rocblas_datatype_f32_r,
463           rocblas_gemm_algo_standard, 0, 0);
464     case blas::DataType::kFloat:
465       return DoBlasInternalStatus(
466           wrap::rocblas_sgemm, stream, /* pointer_mode_host = */ true,
467           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
468           static_cast<const float *>(alpha),
469           static_cast<const float *>(a.opaque()), lda,
470           static_cast<const float *>(b.opaque()), ldb,
471           static_cast<const float *>(beta), static_cast<float *>(c->opaque()),
472           ldc);
473     case blas::DataType::kDouble:
474       return DoBlasInternalStatus(
475           wrap::rocblas_dgemm, stream, /* pointer_mode_host = */ true,
476           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
477           static_cast<const double *>(alpha),
478           static_cast<const double *>(a.opaque()), lda,
479           static_cast<const double *>(b.opaque()), ldb,
480           static_cast<const double *>(beta), static_cast<double *>(c->opaque()),
481           ldc);
482     case blas::DataType::kComplexFloat: {
483       auto cb_alpha =
484           complex_cast(*static_cast<const std::complex<float> *>(alpha));
485       auto cb_beta =
486           complex_cast(*static_cast<const std::complex<float> *>(beta));
487       return DoBlasInternalStatus(
488           wrap::rocblas_cgemm, stream, /* pointer_mode_host = */ true,
489           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
490           cb_alpha, static_cast<const rocblas_float_complex *>(a.opaque()), lda,
491           static_cast<const rocblas_float_complex *>(b.opaque()), ldb, cb_beta,
492           static_cast<rocblas_float_complex *>(c->opaque()), ldc);
493     }
494     case blas::DataType::kComplexDouble: {
495       auto cb_alpha =
496           complex_cast(*static_cast<const std::complex<double> *>(alpha));
497       auto cb_beta =
498           complex_cast(*static_cast<const std::complex<double> *>(beta));
499       return DoBlasInternalStatus(
500           wrap::rocblas_zgemm, stream, /* pointer_mode_host = */ true,
501           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
502           cb_alpha, static_cast<const rocblas_double_complex *>(a.opaque()),
503           lda, static_cast<const rocblas_double_complex *>(b.opaque()), ldb,
504           cb_beta, static_cast<rocblas_double_complex *>(c->opaque()), ldc);
505     }
506     default:
507       return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
508                                               blas::DataTypeString(dtype)));
509   }
510 }
511 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64_t m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)512 bool ROCMBlas::DoBlasGemvWithProfiling(
513     Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, float alpha,
514     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
515     int incx, float beta, DeviceMemory<float> *y, int incy,
516     blas::ProfileResult *output_profile_result) {
517   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
518                                      incx, beta, y, incy,
519                                      output_profile_result);
520 }
521 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64_t m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)522 bool ROCMBlas::DoBlasGemvWithProfiling(
523     Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, double alpha,
524     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
525     int incx, double beta, DeviceMemory<double> *y, int incy,
526     blas::ProfileResult *output_profile_result) {
527   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
528                                      incx, beta, y, incy,
529                                      output_profile_result);
530 }
531 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)532 bool ROCMBlas::DoBlasGemvWithProfiling(
533     Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,
534     std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
535     int lda, const DeviceMemory<std::complex<float>> &x, int incx,
536     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
537     blas::ProfileResult *output_profile_result) {
538   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
539                                      incx, beta, y, incy,
540                                      output_profile_result);
541 }
542 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)543 bool ROCMBlas::DoBlasGemvWithProfiling(
544     Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,
545     std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
546     int lda, const DeviceMemory<std::complex<double>> &x, int incx,
547     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
548     blas::ProfileResult *output_profile_result) {
549   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
550                                      incx, beta, y, incy,
551                                      output_profile_result);
552 }
553 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)554 bool ROCMBlas::DoBlasGemmWithProfiling(
555     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
556     uint64_t n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
557     int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
558     DeviceMemory<Eigen::half> *c, int ldc,
559     blas::ProfileResult *output_profile_result) {
560   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
561                                      lda, b, ldb, beta, c, ldc,
562                                      output_profile_result);
563 }
564 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)565 bool ROCMBlas::DoBlasGemmWithProfiling(
566     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
567     uint64_t n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
568     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
569     int ldc, blas::ProfileResult *output_profile_result) {
570   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
571                                      lda, b, ldb, beta, c, ldc,
572                                      output_profile_result);
573 }
574 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)575 bool ROCMBlas::DoBlasGemmWithProfiling(
576     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
577     uint64_t n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
578     const DeviceMemory<double> &b, int ldb, double beta,
579     DeviceMemory<double> *c, int ldc,
580     blas::ProfileResult *output_profile_result) {
581   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
582                                      lda, b, ldb, beta, c, ldc,
583                                      output_profile_result);
584 }
585 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)586 bool ROCMBlas::DoBlasGemmWithProfiling(
587     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
588     uint64_t n, uint64 k, std::complex<float> alpha,
589     const DeviceMemory<std::complex<float>> &a, int lda,
590     const DeviceMemory<std::complex<float>> &b, int ldb,
591     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
592     blas::ProfileResult *output_profile_result) {
593   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
594                                      lda, b, ldb, beta, c, ldc,
595                                      output_profile_result);
596 }
597 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)598 bool ROCMBlas::DoBlasGemmWithProfiling(
599     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
600     uint64_t n, uint64 k, std::complex<double> alpha,
601     const DeviceMemory<std::complex<double>> &a, int lda,
602     const DeviceMemory<std::complex<double>> &b, int ldb,
603     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
604     blas::ProfileResult *output_profile_result) {
605   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
606                                      lda, b, ldb, beta, c, ldc,
607                                      output_profile_result);
608 }
609 
610 template <typename T>
DoBlasGemvWithProfilingImpl(Stream * stream,blas::Transpose trans,uint64_t m,uint64 n,const T & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & x,int incx,const T & beta,DeviceMemory<T> * y,int incy,blas::ProfileResult * output_profile_result)611 bool ROCMBlas::DoBlasGemvWithProfilingImpl(
612     Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, const T &alpha,
613     const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
614     const T &beta, DeviceMemory<T> *y, int incy,
615     blas::ProfileResult *output_profile_result) {
616   // ROCM TODO: properly implement the interface
617   return false;
618 }
619 
620 template <typename T, typename ParamType>
DoBlasGemmWithProfilingImpl(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,const ParamType & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & b,int ldb,const ParamType & beta,DeviceMemory<T> * c,int ldc,blas::ProfileResult * output_profile_result)621 bool ROCMBlas::DoBlasGemmWithProfilingImpl(
622     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
623     uint64_t n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
624     int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
625     DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
626   // ROCM TODO: properly implement the interface
627   return false;
628 }
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,const void * alpha,const DeviceMemoryBase & a,blas::DataType type_a,int lda,const DeviceMemoryBase & b,blas::DataType type_b,int ldb,const void * beta,DeviceMemoryBase * c,blas::DataType type_c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)629 port::Status ROCMBlas::DoBlasGemmWithAlgorithm(
630     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
631     uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
632     blas::DataType type_a, int lda, const DeviceMemoryBase &b,
633     blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
634     blas::DataType type_c, int ldc, blas::ComputationType computation_type,
635     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
636   // ROCM TODO: properly implement the interface
637   return port::InternalError("Not implemented on ROCm");
638 }
639 
DoBlasGemmStridedBatchedWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,const void * alpha,const DeviceMemoryBase & a,blas::DataType type_a,int lda,int64_t stride_a,const DeviceMemoryBase & b,blas::DataType type_b,int ldb,int64_t stride_b,const void * beta,DeviceMemoryBase * c,blas::DataType type_c,int ldc,int64_t stride_c,int batch_count,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)640 port::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm(
641     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
642     uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
643     blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b,
644     blas::DataType type_b, int ldb, int64_t stride_b, const void *beta,
645     DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
646     int batch_count, blas::ComputationType computation_type,
647     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
648   // ROCM TODO: properly implement the interface
649   return port::InternalError("Not implemented on ROCm");
650 }
651 
GetBlasGemmAlgorithms(Stream * stream,std::vector<blas::AlgorithmType> * out_algorithms)652 bool ROCMBlas::GetBlasGemmAlgorithms(
653     Stream *stream, std::vector<blas::AlgorithmType> *out_algorithms) {
654   // ROCM TODO: properly implement the interface
655   return true;
656 }
657 
658 // This copies from source memory: raw_ptrs[i] to target memory:
659 // device_memory_ptr at the interval of matrix_byte_size, or vice versa.
660 // The below algorithm tries to minimize the number of memcpy by consolidating
661 // neighboring memcpy into a single request
662 template <typename MAPPED_T>
ReorganizeMemory(Stream * stream,DeviceMemory<MAPPED_T> * device_memory,const std::vector<MAPPED_T * > & raw_ptrs,int batch_count,uint64_t batch_stride,bool gather)663 port::Status ReorganizeMemory(Stream *stream,
664                               DeviceMemory<MAPPED_T> *device_memory,
665                               const std::vector<MAPPED_T *> &raw_ptrs,
666                               int batch_count, uint64_t batch_stride,
667                               bool gather) {
668   assert(batch_count > 0);
669   char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
670   char *src_ptr = reinterpret_cast<char *>(raw_ptrs[0]);
671   char *dst_ptr = device_memory_ptr;
672   size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
673   uint64_t cur_stride_size = matrix_byte_size;
674 
675   for (int i = 1; i < batch_count; ++i) {
676     if (reinterpret_cast<char *>(raw_ptrs[i]) == src_ptr + cur_stride_size) {
677       cur_stride_size += matrix_byte_size;
678     } else {
679       DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
680       DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
681       bool a_status =
682           gather
683               ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
684               : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
685       if (!a_status) {
686         return port::Status(
687             port::error::INTERNAL,
688             "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
689       }
690       src_ptr = reinterpret_cast<char *>(raw_ptrs[i]);
691       dst_ptr = device_memory_ptr + i * matrix_byte_size;
692       cur_stride_size = matrix_byte_size;
693     }
694   }
695 
696   DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
697   DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
698   bool a_status =
699       gather ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
700              : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
701   if (!a_status)
702     return port::Status(
703         port::error::INTERNAL,
704         "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
705   return port::Status::OK();
706 }
707 
708 template <typename T>
AllocateStridedBuffer(const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type * > & raw_ptrs,int batch_count,uint64_t batch_stride,ScratchAllocator * scratch_allocator,Stream * stream,std::unique_ptr<TemporaryDeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>> * temp_memory,DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type> * device_memory,bool copy_data,bool & reallocated)709 port::Status ROCMBlas::AllocateStridedBuffer(
710     const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
711         &raw_ptrs,
712     int batch_count, uint64_t batch_stride, ScratchAllocator *scratch_allocator,
713     Stream *stream,
714     std::unique_ptr<TemporaryDeviceMemory<
715         typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
716     DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
717         *device_memory,
718     bool copy_data, bool &reallocated) {
719   assert(device_memory != nullptr);
720 
721   using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
722 
723   bool needs_allocate_strided = false;
724   for (int i = 1; i < batch_count; ++i) {
725     uint64_t tmp_batch_stride = raw_ptrs[i] - raw_ptrs[i - 1];
726     if (tmp_batch_stride != batch_stride) {
727       needs_allocate_strided = true;
728       break;
729     }
730   }
731 
732   size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
733   size_t matrix_batch_byte_size = matrix_byte_size * batch_count;
734 
735   // No need to do re-allocation, take the short cut and return
736   if (!needs_allocate_strided) {
737     *device_memory = DeviceMemory<MAPPED_T>(
738         DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size));
739     reallocated = false;
740     return port::Status::OK();
741   }
742 
743   if (scratch_allocator != nullptr) {
744     TF_ASSIGN_OR_RETURN(
745         DeviceMemory<uint8> batch_matrix_bytes,
746         scratch_allocator->AllocateBytes(matrix_batch_byte_size));
747     *device_memory = DeviceMemory<MAPPED_T>(batch_matrix_bytes);
748   } else {
749     assert(temp_memory != nullptr);
750     TF_ASSIGN_OR_RETURN(*temp_memory, stream->AllocateTemporaryArray<MAPPED_T>(
751                                           matrix_batch_byte_size));
752     *device_memory =
753         DeviceMemory<MAPPED_T>(*(*temp_memory)->mutable_device_memory());
754   }
755 
756   reallocated = true;
757 
758   if (copy_data)
759     return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count,
760                             batch_stride, true);
761   return port::Status::OK();
762 }
763 
764 template <typename T, typename FuncT>
DoBlasGemmBatchedInternal(FuncT rocblas_func,Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64 k,T alpha,const absl::Span<DeviceMemory<T> * const> & a_ptrs_to_wrappers,int lda,const absl::Span<DeviceMemory<T> * const> & b_ptrs_to_wrappers,int ldb,T beta,const absl::Span<DeviceMemory<T> * const> & c_ptrs_to_wrappers,int ldc,int batch_count,ScratchAllocator * scratch_allocator)765 port::Status ROCMBlas::DoBlasGemmBatchedInternal(
766     FuncT rocblas_func, Stream *stream, blas::Transpose transa,
767     blas::Transpose transb, uint64_t m, uint64 n, uint64 k, T alpha,
768     const absl::Span<DeviceMemory<T> *const> &a_ptrs_to_wrappers, int lda,
769     const absl::Span<DeviceMemory<T> *const> &b_ptrs_to_wrappers, int ldb,
770     T beta, const absl::Span<DeviceMemory<T> *const> &c_ptrs_to_wrappers,
771     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
772   using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
773 
774   // Sanity checks before making any further progress
775   uint64_t batch_stride_a = 0;
776   uint64_t batch_stride_b = 0;
777   uint64_t batch_stride_c = 0;
778 
779   assert(ldc >= m);
780   batch_stride_c = ldc * n;
781 
782   if (ROCMBlasTranspose(transa) == rocblas_operation_none) {
783     assert(lda >= m);
784     batch_stride_a = lda * k;
785   } else {
786     assert(lda >= k);
787     batch_stride_a = lda * m;
788   }
789 
790   if (ROCMBlasTranspose(transb) == rocblas_operation_none) {
791     assert(ldb >= k);
792     batch_stride_b = ldb * n;
793   } else {
794     assert(ldb >= n);
795     batch_stride_b = ldb * k;
796   }
797 
798   // Allocate local vectors to hold device pointers to matrices
799   std::vector<MAPPED_T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
800   for (int i = 0; i < batch_count; ++i) {
801     // static_cast does work when converting Eigen::half* to rocblas_half*,
802     // hence the use of reinterpret_cast
803     a_raw_ptrs.push_back(
804         reinterpret_cast<MAPPED_T *>(a_ptrs_to_wrappers[i]->opaque()));
805     b_raw_ptrs.push_back(
806         reinterpret_cast<MAPPED_T *>(b_ptrs_to_wrappers[i]->opaque()));
807     c_raw_ptrs.push_back(
808         reinterpret_cast<MAPPED_T *>(c_ptrs_to_wrappers[i]->opaque()));
809   }
810 
811   DeviceMemory<MAPPED_T> a;
812   // Make sure the temporary memory are in-scope before the function returns
813   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> a_temp;
814   bool reallocated_a, reallocated_b, reallocated_c;
815   port::Status a_allocation_status = AllocateStridedBuffer<T>(
816       a_raw_ptrs, batch_count, batch_stride_a, scratch_allocator, stream,
817       &a_temp, &a, true, reallocated_a);
818   if (a_allocation_status != port::Status::OK()) {
819     return a_allocation_status;
820   }
821 
822   DeviceMemory<MAPPED_T> b;
823   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> b_temp;
824   port::Status b_allocation_status = AllocateStridedBuffer<T>(
825       b_raw_ptrs, batch_count, batch_stride_b, scratch_allocator, stream,
826       &b_temp, &b, true, reallocated_b);
827   if (b_allocation_status != port::Status::OK()) {
828     return b_allocation_status;
829   }
830 
831   DeviceMemory<MAPPED_T> c;
832   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> c_temp;
833   port::Status c_allocation_status = AllocateStridedBuffer<T>(
834       c_raw_ptrs, batch_count, batch_stride_c, scratch_allocator, stream,
835       &c_temp, &c, true, reallocated_c);  // can disable copy if beta=0
836   if (c_allocation_status != port::Status::OK()) {
837     return c_allocation_status;
838   }
839 
840   MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
841   MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
842 
843   bool ok;
844   ok = DoBlasInternal(rocblas_func, stream, /* pointer_mode_host = */ true,
845                       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
846                       n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
847                       batch_stride_a, GpuMemory(b), ldb, batch_stride_b,
848                       GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc,
849                       batch_stride_c, batch_count);
850   if (!ok)
851     return port::Status(port::error::INTERNAL,
852                         "failed BLAS call, see log for details");
853   if (reallocated_c)
854     return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count, batch_stride_c,
855                             false);
856   return port::Status::OK();
857 }
858 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,float alpha,const absl::Span<DeviceMemory<Eigen::half> * const> & a,int lda,const absl::Span<DeviceMemory<Eigen::half> * const> & b,int ldb,float beta,const absl::Span<DeviceMemory<Eigen::half> * const> & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)859 bool ROCMBlas::DoBlasGemmBatched(
860     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
861     uint64_t n, uint64 k, float alpha,
862     const absl::Span<DeviceMemory<Eigen::half> *const> &a, int lda,
863     const absl::Span<DeviceMemory<Eigen::half> *const> &b, int ldb, float beta,
864     const absl::Span<DeviceMemory<Eigen::half> *const> &c, int ldc,
865     int batch_count, ScratchAllocator *scratch_allocator) {
866   blas_log("DoBlasGemmBatched");
867   const Eigen::half alpha_half(alpha);
868   const Eigen::half beta_half(beta);
869 
870   port::Status status = DoBlasGemmBatchedInternal(
871       wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k,
872       alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count,
873       scratch_allocator);
874   if (!status.ok()) {
875     LOG(ERROR) << status;
876   }
877 
878   return status.ok();
879 }
880 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,float alpha,const absl::Span<DeviceMemory<float> * const> & a_array,int lda,const absl::Span<DeviceMemory<float> * const> & b_array,int ldb,float beta,const absl::Span<DeviceMemory<float> * const> & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)881 bool ROCMBlas::DoBlasGemmBatched(
882     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
883     uint64_t n, uint64 k, float alpha,
884     const absl::Span<DeviceMemory<float> *const> &a_array, int lda,
885     const absl::Span<DeviceMemory<float> *const> &b_array, int ldb, float beta,
886     const absl::Span<DeviceMemory<float> *const> &c_array, int ldc,
887     int batch_count, ScratchAllocator *scratch_allocator) {
888   blas_log("DoBlasGemmBatched");
889   port::Status status = DoBlasGemmBatchedInternal(
890       wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
891       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
892       scratch_allocator);
893   if (!status.ok()) {
894     LOG(ERROR) << status;
895   }
896   return status.ok();
897 }
898 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,double alpha,const absl::Span<DeviceMemory<double> * const> & a_array,int lda,const absl::Span<DeviceMemory<double> * const> & b_array,int ldb,double beta,const absl::Span<DeviceMemory<double> * const> & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)899 bool ROCMBlas::DoBlasGemmBatched(
900     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
901     uint64_t n, uint64 k, double alpha,
902     const absl::Span<DeviceMemory<double> *const> &a_array, int lda,
903     const absl::Span<DeviceMemory<double> *const> &b_array, int ldb,
904     double beta, const absl::Span<DeviceMemory<double> *const> &c_array,
905     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
906   blas_log("DoBlasGemmBatched");
907   port::Status status = DoBlasGemmBatchedInternal(
908       wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
909       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
910       scratch_allocator);
911   if (!status.ok()) {
912     LOG(ERROR) << status;
913   }
914   return status.ok();
915 }
916 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,std::complex<float> alpha,const absl::Span<DeviceMemory<std::complex<float>> * const> & a_array,int lda,const absl::Span<DeviceMemory<std::complex<float>> * const> & b_array,int ldb,std::complex<float> beta,const absl::Span<DeviceMemory<std::complex<float>> * const> & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)917 bool ROCMBlas::DoBlasGemmBatched(
918     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
919     uint64_t n, uint64 k, std::complex<float> alpha,
920     const absl::Span<DeviceMemory<std::complex<float>> *const> &a_array,
921     int lda,
922     const absl::Span<DeviceMemory<std::complex<float>> *const> &b_array,
923     int ldb, std::complex<float> beta,
924     const absl::Span<DeviceMemory<std::complex<float>> *const> &c_array,
925     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
926   blas_log("DoBlasGemmBatched");
927   port::Status status = DoBlasGemmBatchedInternal(
928       wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k,
929       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
930       scratch_allocator);
931   if (!status.ok()) {
932     LOG(ERROR) << status;
933   }
934   return status.ok();
935 }
936 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,std::complex<double> alpha,const absl::Span<DeviceMemory<std::complex<double>> * const> & a_array,int lda,const absl::Span<DeviceMemory<std::complex<double>> * const> & b_array,int ldb,std::complex<double> beta,const absl::Span<DeviceMemory<std::complex<double>> * const> & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)937 bool ROCMBlas::DoBlasGemmBatched(
938     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
939     uint64_t n, uint64 k, std::complex<double> alpha,
940     const absl::Span<DeviceMemory<std::complex<double>> *const> &a_array,
941     int lda,
942     const absl::Span<DeviceMemory<std::complex<double>> *const> &b_array,
943     int ldb, std::complex<double> beta,
944     const absl::Span<DeviceMemory<std::complex<double>> *const> &c_array,
945     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
946   blas_log("DoBlasGemmBatched");
947   port::Status status = DoBlasGemmBatchedInternal(
948       wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k,
949       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
950       scratch_allocator);
951   if (!status.ok()) {
952     LOG(ERROR) << status;
953   }
954   return status.ok();
955 }
956 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)957 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
958                           blas::UpperLower uplo, blas::Transpose transa,
959                           blas::Diagonal diag, uint64_t m, uint64 n,
960                           float alpha, const DeviceMemory<float> &a, int lda,
961                           DeviceMemory<float> *b, int ldb) {
962   blas_log("DoBlasTrsm");
963   return DoBlasInternal(wrap::rocblas_strsm, stream,
964                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
965                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
966                         ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
967                         GpuMemoryMutable(b), ldb);
968 }
969 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)970 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
971                           blas::UpperLower uplo, blas::Transpose transa,
972                           blas::Diagonal diag, uint64_t m, uint64 n,
973                           double alpha, const DeviceMemory<double> &a, int lda,
974                           DeviceMemory<double> *b, int ldb) {
975   blas_log("DoBlasTrsm");
976   return DoBlasInternal(wrap::rocblas_dtrsm, stream,
977                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
978                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
979                         ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(a), lda,
980                         GpuMemoryMutable(b), ldb);
981 }
982 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)983 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
984                           blas::UpperLower uplo, blas::Transpose transa,
985                           blas::Diagonal diag, uint64_t m, uint64 n,
986                           std::complex<float> alpha,
987                           const DeviceMemory<std::complex<float>> &a, int lda,
988                           DeviceMemory<std::complex<float>> *b, int ldb) {
989   return DoBlasInternal(wrap::rocblas_ctrsm, stream,
990                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
991                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
992                         ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
993                         complex_cast(a), lda, complex_cast(b), ldb);
994 }
995 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)996 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
997                           blas::UpperLower uplo, blas::Transpose transa,
998                           blas::Diagonal diag, uint64_t m, uint64 n,
999                           std::complex<double> alpha,
1000                           const DeviceMemory<std::complex<double>> &a, int lda,
1001                           DeviceMemory<std::complex<double>> *b, int ldb) {
1002   return DoBlasInternal(wrap::rocblas_ztrsm, stream,
1003                         /* pointer_mode_host = */ true, ROCMBlasSide(side),
1004                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
1005                         ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
1006                         complex_cast(a), lda, complex_cast(b), ldb);
1007 }
1008 
DoBlasTrsmBatched(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,float alpha,const DeviceMemory<float * > & as,int lda,DeviceMemory<float * > * bs,int ldb,int batch_count)1009 bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1010                                  blas::UpperLower uplo, blas::Transpose transa,
1011                                  blas::Diagonal diag, uint64_t m, uint64 n,
1012                                  float alpha, const DeviceMemory<float *> &as,
1013                                  int lda, DeviceMemory<float *> *bs, int ldb,
1014                                  int batch_count) {
1015   return DoBlasInternal(wrap::rocblas_strsm_batched, stream,
1016                         true /* = pointer_mode_host */, ROCMBlasSide(side),
1017                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
1018                         ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(as),
1019                         lda, GpuMemoryMutable(bs), ldb, batch_count);
1020 }
1021 
DoBlasTrsmBatched(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,double alpha,const DeviceMemory<double * > & as,int lda,DeviceMemory<double * > * bs,int ldb,int batch_count)1022 bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1023                                  blas::UpperLower uplo, blas::Transpose transa,
1024                                  blas::Diagonal diag, uint64_t m, uint64 n,
1025                                  double alpha, const DeviceMemory<double *> &as,
1026                                  int lda, DeviceMemory<double *> *bs, int ldb,
1027                                  int batch_count) {
1028   return DoBlasInternal(wrap::rocblas_dtrsm_batched, stream,
1029                         true /* = pointer_mode_host */, ROCMBlasSide(side),
1030                         ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
1031                         ROCMBlasDiagonal(diag), m, n, &alpha, GpuMemory(as),
1032                         lda, GpuMemoryMutable(bs), ldb, batch_count);
1033 }
1034 
DoBlasTrsmBatched(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float> * > & as,int lda,DeviceMemory<std::complex<float> * > * bs,int ldb,int batch_count)1035 bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1036                                  blas::UpperLower uplo, blas::Transpose transa,
1037                                  blas::Diagonal diag, uint64_t m, uint64 n,
1038                                  std::complex<float> alpha,
1039                                  const DeviceMemory<std::complex<float> *> &as,
1040                                  int lda,
1041                                  DeviceMemory<std::complex<float> *> *bs,
1042                                  int ldb, int batch_count) {
1043   return DoBlasInternal(
1044       wrap::rocblas_ctrsm_batched, stream, true /* = pointer_mode_host */,
1045       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
1046       ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
1047       static_cast<const rocblas_float_complex *const *>(as.opaque()), lda,
1048       static_cast<rocblas_float_complex *const *>(bs->opaque()), ldb,
1049       batch_count);
1050 }
1051 
DoBlasTrsmBatched(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double> * > & as,int lda,DeviceMemory<std::complex<double> * > * bs,int ldb,int batch_count)1052 bool ROCMBlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1053                                  blas::UpperLower uplo, blas::Transpose transa,
1054                                  blas::Diagonal diag, uint64_t m, uint64 n,
1055                                  std::complex<double> alpha,
1056                                  const DeviceMemory<std::complex<double> *> &as,
1057                                  int lda,
1058                                  DeviceMemory<std::complex<double> *> *bs,
1059                                  int ldb, int batch_count) {
1060   return DoBlasInternal(
1061       wrap::rocblas_ztrsm_batched, stream, true /* = pointer_mode_host */,
1062       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
1063       ROCMBlasDiagonal(diag), m, n, complex_cast(alpha),
1064       static_cast<const rocblas_double_complex *const *>(as.opaque()), lda,
1065       static_cast<rocblas_double_complex *const *>(bs->opaque()), ldb,
1066       batch_count);
1067 }
1068 
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64_t n,uint64 k,blas::DataType dtype,const void * alpha,const DeviceMemoryBase & a,int lda,int64_t stride_a,const DeviceMemoryBase & b,int ldb,int64_t stride_b,const void * beta,DeviceMemoryBase * c,int ldc,int64_t stride_c,int batch_count)1069 port::Status ROCMBlas::DoBlasGemmStridedBatched(
1070     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1071     uint64_t n, uint64 k, blas::DataType dtype, const void *alpha,
1072     const DeviceMemoryBase &a, int lda, int64_t stride_a,
1073     const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
1074     DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count) {
1075   VLOG(1) << absl::StreamFormat(
1076       "doing rocBLAS SGEMM Strided Batched<float>: at=%d bt=%d m=%u n=%u "
1077       "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p "
1078       "c=%p ldc=%d",
1079       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1080       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1081 
1082   switch (dtype) {
1083     case blas::DataType::kHalf: {
1084       const Eigen::half alpha_half(*static_cast<const float *>(alpha));
1085       const Eigen::half beta_half(*static_cast<const float *>(beta));
1086       return DoBlasInternalStatus(
1087           wrap::rocblas_hgemm_strided_batched, stream,
1088           false, /* pointer_mode_host */
1089           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1090           reinterpret_cast<const rocblas_half *>(&alpha_half),
1091           reinterpret_cast<const rocblas_half *>(a.opaque()), lda, stride_a,
1092           reinterpret_cast<const rocblas_half *>(b.opaque()), ldb, stride_b,
1093           reinterpret_cast<const rocblas_half *>(&beta_half),
1094           reinterpret_cast<rocblas_half *>(c->opaque()), ldc, stride_c,
1095           batch_count);
1096     }
1097     case blas::DataType::kBF16:
1098       return DoBlasInternalStatus(
1099           wrap::rocblas_gemm_strided_batched_ex, stream,
1100           false, /* pointer_mode_host */
1101           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, alpha,
1102           a.opaque(), rocblas_datatype_bf16_r, lda, stride_a, b.opaque(),
1103           rocblas_datatype_bf16_r, ldb, stride_b, beta, c->opaque(),
1104           rocblas_datatype_bf16_r, ldc, stride_c, c->opaque(),
1105           rocblas_datatype_bf16_r, ldc, stride_c, batch_count,
1106           rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0);
1107     case blas::DataType::kFloat:
1108       return DoBlasInternalStatus(
1109           wrap::rocblas_sgemm_strided_batched, stream,
1110           false, /* pointer_mode_host */
1111           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1112           reinterpret_cast<const float *>(alpha),
1113           reinterpret_cast<const float *>(a.opaque()), lda, stride_a,
1114           reinterpret_cast<const float *>(b.opaque()), ldb, stride_b,
1115           reinterpret_cast<const float *>(beta),
1116           reinterpret_cast<float *>(c->opaque()), ldc, stride_c, batch_count);
1117     case blas::DataType::kDouble:
1118       return DoBlasInternalStatus(
1119           wrap::rocblas_dgemm_strided_batched, stream,
1120           false, /* pointer_mode_host */
1121           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1122           reinterpret_cast<const double *>(alpha),
1123           reinterpret_cast<const double *>(a.opaque()), lda, stride_a,
1124           reinterpret_cast<const double *>(b.opaque()), ldb, stride_b,
1125           reinterpret_cast<const double *>(beta),
1126           reinterpret_cast<double *>(c->opaque()), ldc, stride_c, batch_count);
1127     case blas::DataType::kComplexFloat: {
1128       auto cb_alpha =
1129           complex_cast(*static_cast<const std::complex<float> *>(alpha));
1130       auto cb_beta =
1131           complex_cast(*static_cast<const std::complex<float> *>(beta));
1132       return DoBlasInternalStatus(
1133           wrap::rocblas_cgemm_strided_batched, stream,
1134           /* pointer_mode_host = */ true, ROCMBlasTranspose(transa),
1135           ROCMBlasTranspose(transb), m, n, k, cb_alpha,
1136           static_cast<const rocblas_float_complex *>(a.opaque()), lda, stride_a,
1137           static_cast<const rocblas_float_complex *>(b.opaque()), ldb, stride_b,
1138           cb_beta, static_cast<rocblas_float_complex *>(c->opaque()), ldc,
1139           stride_c, batch_count);
1140     }
1141     case blas::DataType::kComplexDouble: {
1142       auto cb_alpha =
1143           complex_cast(*static_cast<const std::complex<double> *>(alpha));
1144       auto cb_beta =
1145           complex_cast(*static_cast<const std::complex<double> *>(beta));
1146       return DoBlasInternalStatus(
1147           wrap::rocblas_zgemm_strided_batched, stream,
1148           /* pointer_mode_host = */ true, ROCMBlasTranspose(transa),
1149           ROCMBlasTranspose(transb), m, n, k, cb_alpha,
1150           static_cast<const rocblas_double_complex *>(a.opaque()), lda,
1151           stride_a, static_cast<const rocblas_double_complex *>(b.opaque()),
1152           ldb, stride_b, cb_beta,
1153           static_cast<rocblas_double_complex *>(c->opaque()), ldc, stride_c,
1154           batch_count);
1155     }
1156     default:
1157       return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
1158                                               blas::DataTypeString(dtype)));
1159   }
1160 }
1161 
GetVersion(string * version)1162 port::Status ROCMBlas::GetVersion(string *version) {
1163   return port::UnimplementedError("");
1164 }
1165 
1166 }  // namespace gpu
1167 
initialize_rocblas()1168 void initialize_rocblas() {
1169   auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
1170       rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
1171 
1172   if (!rocBlasAlreadyRegistered) {
1173     port::Status status =
1174         PluginRegistry::Instance()
1175             ->RegisterFactory<PluginRegistry::BlasFactory>(
1176                 rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
1177                 [](internal::StreamExecutorInterface *parent)
1178                     -> blas::BlasSupport * {
1179                   gpu::GpuExecutor *rocm_executor =
1180                       dynamic_cast<gpu::GpuExecutor *>(parent);
1181                   if (rocm_executor == nullptr) {
1182                     LOG(ERROR)
1183                         << "Attempting to initialize an instance of the "
1184                            "rocBLAS "
1185                         << "support library with a non-ROCM StreamExecutor";
1186                     return nullptr;
1187                   }
1188 
1189                   gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
1190                   if (!blas->Init()) {
1191                     // Note: Init() will log a more specific error.
1192                     delete blas;
1193                     return nullptr;
1194                   }
1195                   return blas;
1196                 });
1197 
1198     if (!status.ok()) {
1199       LOG(ERROR) << "Unable to register rocBLAS factory: "
1200                  << status.error_message();
1201     }
1202 
1203     PluginRegistry::Instance()->SetDefaultFactory(
1204         rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
1205   }
1206 }
1207 
1208 }  // namespace stream_executor
1209 
1210 REGISTER_MODULE_INITIALIZER(register_rocblas,
1211                             { stream_executor::initialize_rocblas(); });
1212