xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/rocm_solvers.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
   ==============================================================================
*/
/*
Mapping of GpuSolver methods to the corresponding ROCm library APIs.
/-----------------------------------------------------------------------------------------/
/ GpuSolverMethod  //    rocblasAPI    //   rocsolverAPI     //  hipsolverAPI (ROCM>4.5)   /
/ Geam             //  rocblas_Xgeam   //     ----           //     ----                   /
/ Getrf            //      ----        //  rocsolver_Xgetrf  //   hipsolverXgetrf          /
/ GetrfBatched     //      ----        // ""_Xgetrf_batched  //     ----                   /
/ GetriBatched     //      ----        // ""_Xgetri_batched  //     ----                   /
/ Getrs            //      ----        //  rocsolver_Xgetrs  //   hipsolverXgetrs          /
/ GetrsBatched     //      ----        // ""_Xgetrs_batched  //     ----                   /
/ Geqrf            //      ----        //  rocsolver_Xgeqrf  //   hipsolverXgeqrf          /
/ Heevd            //      ----        //     ----           //   hipsolverXheevd          /
/ Potrf            //      ----        //  rocsolver_Xpotrf  //   hipsolverXpotrf          /
/ PotrfBatched     //      ----        // ""_Xpotrf_batched  //   ""XpotrfBatched          /
/ Trsm             //  rocblas_Xtrsm   //     ----           //     ----                   /
/ Ungqr            //      ----        //  rocsolver_Xungqr  //   hipsolverXungqr          /
/ Unmqr            //      ----        //  rocsolver_Xunmqr  //   hipsolverXunmqr          /
/-----------------------------------------------------------------------------------------/
*/
#if TENSORFLOW_USE_ROCM
#include <complex>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_solvers.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/default/dso_loader.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rocm/rocblas_wrapper.h"

namespace tensorflow {
namespace {

using stream_executor::gpu::GpuExecutor;
using stream_executor::gpu::ScopedActivateExecutorContext;

// Copies `bytes` bytes from host memory `src` to device memory `dst` on the
// stream associated with `context`; returns false if the copy could not be
// enqueued.
inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
                             const void* src, uint64 bytes) {
  auto stream = context->op_device_context()->stream();
  se::DeviceMemoryBase wrapped_dst(dst);
  return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
}

struct GpuSolverHandles {
  explicit GpuSolverHandles(GpuExecutor* parent, hipStream_t stream) {
    parent_ = parent;
    ScopedActivateExecutorContext sac{parent_};
#if TF_ROCM_VERSION >= 40500
    CHECK(wrap::hipsolverCreate(&hipsolver_handle) == rocblas_status_success)
        << "Failed to create hipsolver instance";
#endif
    CHECK(wrap::rocblas_create_handle(&rocm_blas_handle) ==
          rocblas_status_success)
        << "Failed to create rocBlas instance.";
    CHECK(wrap::rocblas_set_stream(rocm_blas_handle, stream) ==
          rocblas_status_success)
        << "Failed to set rocBlas stream.";
  }

  ~GpuSolverHandles() {
    ScopedActivateExecutorContext sac{parent_};
    CHECK(wrap::rocblas_destroy_handle(rocm_blas_handle) ==
          rocblas_status_success)
        << "Failed to destroy rocBlas instance.";
#if TF_ROCM_VERSION >= 40500
    CHECK(wrap::hipsolverDestroy(hipsolver_handle) == rocblas_status_success)
        << "Failed to destroy hipsolver instance.";
#endif
  }
  GpuExecutor* parent_;
  rocblas_handle rocm_blas_handle;
#if TF_ROCM_VERSION >= 40500
  hipsolverHandle_t hipsolver_handle;
#endif
};

using HandleMap =
    std::unordered_map<hipStream_t, std::unique_ptr<GpuSolverHandles>>;

// Returns a singleton map used for storing initialized handles for each unique
// gpu stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}
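
// Illustrative behavior sketch (not part of the build; the contexts are
// hypothetical): solvers constructed on the same hipStream_t share one cached
// handle set, while a new stream triggers creation of a fresh set on first
// use.
//
//   GpuSolver s1(ctx_on_stream_a);  // creates handles for stream A
//   GpuSolver s2(ctx_on_stream_a);  // reuses the handles cached for stream A
//   GpuSolver s3(ctx_on_stream_b);  // creates a second set for stream B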

static mutex handle_map_mutex(LINKER_INITIALIZED);

}  // namespace

GpuSolver::GpuSolver(OpKernelContext* context) : context_(context) {
  mutex_lock lock(handle_map_mutex);
  GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(
      context->op_device_context()->stream()->parent()->implementation());
  const hipStream_t* hip_stream_ptr = CHECK_NOTNULL(
      reinterpret_cast<const hipStream_t*>(context->op_device_context()
                                               ->stream()
                                               ->implementation()
                                               ->GpuStreamMemberHack()));

  hip_stream_ = *hip_stream_ptr;
  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
  auto it = handle_map->find(hip_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating GpuSolver handles for stream " << hip_stream_;
    // Previously unseen GPU stream. Initialize a set of GPU solver library
    // handles for it.
    std::unique_ptr<GpuSolverHandles> new_handles(
        new GpuSolverHandles(gpu_executor, hip_stream_));
    it = handle_map->insert(std::make_pair(hip_stream_, std::move(new_handles)))
             .first;
  }
  rocm_blas_handle_ = it->second->rocm_blas_handle;
#if TF_ROCM_VERSION >= 40500
  hipsolver_handle_ = it->second->hipsolver_handle;
#endif
}

GpuSolver::~GpuSolver() {
  for (auto tensor_ref : scratch_tensor_refs_) {
    tensor_ref.Unref();
  }
}

// static
void GpuSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<GpuSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_infos,
    std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
        info_checker_callback) {
  CHECK(info_checker_callback != nullptr);
  std::vector<HostLapackInfo> host_lapack_infos;
  if (dev_lapack_infos.empty()) {
    info_checker_callback(Status::OK(), host_lapack_infos);
    return;
  }

  // Launch memcpys to copy info back from the device to the host.
  for (const auto& dev_lapack_info : dev_lapack_infos) {
    bool success = true;
    auto host_copy = dev_lapack_info.CopyToHost(&success);
    OP_REQUIRES(
        solver->context(), success,
        errors::Internal(
            "Failed to launch copy of dev_lapack_info to host, debug_info = ",
            dev_lapack_info.debug_info()));
    host_lapack_infos.push_back(std::move(host_copy));
  }

  // This callback checks that all batch items in all calls were processed
  // successfully and passes status to the info_checker_callback accordingly.
  auto* stream = solver->context()->op_device_context()->stream();
  auto wrapped_info_checker_callback =
      [stream](
          GpuSolver* solver,
          std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
              info_checker_callback,
          std::vector<HostLapackInfo> host_lapack_infos) {
        ScopedActivateExecutorContext scoped_activation{stream->parent()};
        Status status;
        for (const auto& host_lapack_info : host_lapack_infos) {
          for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
            const int info_value = host_lapack_info(i);
            if (info_value != 0) {
              status = errors::InvalidArgument(
                  "Got info = ", info_value, " for batch index ", i,
                  ", expected info = 0. Debug_info = ",
                  host_lapack_info.debug_info());
            }
          }
          if (!status.ok()) {
            break;
          }
        }
        // Delete solver to release temp tensor refs.
        delete solver;

        // Delegate further error checking to the provided functor.
        info_checker_callback(status, host_lapack_infos);
      };
  // Note: a std::function cannot have unique_ptr arguments (it must be copy
  // constructible, and therefore so must its arguments). We therefore release
  // ownership of `solver` into a raw pointer that is deleted at the end of
  // wrapped_info_checker_callback.
  auto solver_raw_ptr = solver.release();
  auto cb =
      std::bind(wrapped_info_checker_callback, solver_raw_ptr,
                std::move(info_checker_callback), std::move(host_lapack_infos));

  solver_raw_ptr->context()
      ->device()
      ->tensorflow_accelerator_device_info()
      ->event_mgr->ThenExecute(stream, std::move(cb));
}
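
// Illustrative kernel-side flow (a sketch, not part of this file; the buffer
// names are hypothetical and GetDeviceLapackInfo is assumed to come from
// gpu_solvers.h):
//
//   auto solver = std::make_unique<GpuSolver>(context);
//   std::vector<DeviceLapackInfo> dev_info;
//   dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
//   OP_REQUIRES_OK_ASYNC(
//       context,
//       solver->Getrf(m, n, dev_a, lda, dev_pivots,
//                     dev_info.back().mutable_data()),
//       done);
//   GpuSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
//                                                  dev_info, std::move(done));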

// static
void GpuSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<GpuSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_info,
    AsyncOpKernel::DoneCallback done) {
  OpKernelContext* context = solver->context();
  auto wrapped_done = [context, done](
                          const Status& status,
                          const std::vector<HostLapackInfo>& /* unused */) {
    if (done != nullptr) {
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    } else {
      OP_REQUIRES_OK(context, status);
    }
  };
  CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
                                      wrapped_done);
}

#define TF_RETURN_IF_ROCBLAS_ERROR(expr)                                  \
  do {                                                                    \
    auto status = (expr);                                                 \
    if (TF_PREDICT_FALSE(status != rocblas_status_success)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                    \
                              ": rocBlas call failed status = ", status); \
    }                                                                     \
  } while (0)
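
// For example,
//
//   TF_RETURN_IF_ROCBLAS_ERROR(
//       wrap::rocblas_set_stream(rocm_blas_handle_, stream));
//
// expands to a do/while(0) block that returns an errors::Internal Status
// tagged with __FILE__ and __LINE__ when the call does not report
// rocblas_status_success. The hipsolver wrappers below reuse this same macro,
// which relies on both status enums encoding success as 0.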

// Macros that specialize a solver method for all four standard numeric types.
#define TF_CALL_ROCSOLV_TYPES(m) \
  m(float, s) m(double, d) m(std::complex<float>, c) m(std::complex<double>, z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, s) m(double, d)
#define BLAS_SOLVER_FN(method, type_prefix) \
  wrap::rocblas##_##type_prefix##method

#if TF_ROCM_VERSION >= 40500
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define TF_CALL_LAPACK_TYPES_NO_REAL(m) \
  m(std::complex<float>, C) m(std::complex<double>, Z)
#define SOLVER_FN(method, hip_prefix) wrap::hipsolver##hip_prefix##method
#else
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, s) m(double, d) m(std::complex<float>, c) m(std::complex<double>, z)
#define TF_CALL_LAPACK_TYPES_NO_REAL(m) \
  m(std::complex<float>, c) m(std::complex<double>, z)
#define SOLVER_FN(method, type_prefix) wrap::rocsolver##_##type_prefix##method
#endif

// Macros to construct rocsolver/hipsolver method names.
#define ROCSOLVER_FN(method, type_prefix) \
  wrap::rocsolver##_##type_prefix##method
#define BUFSIZE_FN(method, hip_prefix) \
  wrap::hipsolver##hip_prefix##method##_bufferSize
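
// Name-construction examples (derived from the macros above): with
// TF_ROCM_VERSION >= 40500, SOLVER_FN(getrf, S) names wrap::hipsolverSgetrf
// and BUFSIZE_FN(getrf, S) names wrap::hipsolverSgetrf_bufferSize; on older
// ROCm, SOLVER_FN(getrf, s) names wrap::rocsolver_sgetrf. Plain BLAS entry
// points resolve through BLAS_SOLVER_FN, e.g. BLAS_SOLVER_FN(trsm, s) names
// wrap::rocblas_strsm.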

#if TF_ROCM_VERSION >= 40500

#define GETRF_INSTANCE(Scalar, type_prefix)                                \
  template <>                                                              \
  Status GpuSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                  int* dev_pivots, int* dev_lapack_info) { \
    mutex_lock lock(handle_map_mutex);                                     \
    int lwork;                                                             \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(getrf, type_prefix)(             \
        hipsolver_handle_, m, n, AsHipComplex(A), lda, &lwork));           \
    auto dev_work =                                                        \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);       \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrf, type_prefix)(              \
        hipsolver_handle_, m, n, AsHipComplex(A), lda,                     \
        AsHipComplex(dev_work.mutable_data()), lwork, dev_pivots,          \
        dev_lapack_info));                                                 \
    return Status::OK();                                                   \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);
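
// The instantiation above stamps out Getrf<float>, Getrf<double>,
// Getrf<std::complex<float>> and Getrf<std::complex<double>>, dispatching to
// hipsolver{S,D,C,Z}getrf with a workspace sized by the matching _bufferSize
// query. Caller sketch (the solver object and device buffers are
// hypothetical):
//
//   int* dev_pivots = ...;       // min(m, n) entries on the device
//   int* dev_lapack_info = ...;  // one entry on the device
//   TF_RETURN_IF_ERROR(solver.Getrf(m, n, dev_a, lda, dev_pivots,
//                                   dev_lapack_info));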

#define GEQRF_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Geqrf(int m, int n, Scalar* dev_A, int lda,              \
                          Scalar* dev_tau, int* dev_lapack_info) {           \
    mutex_lock lock(handle_map_mutex);                                       \
    int lwork;                                                               \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(geqrf, type_prefix)(               \
        hipsolver_handle_, m, n, AsHipComplex(dev_A), lda, &lwork));         \
    auto dev_work =                                                          \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);         \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(geqrf, type_prefix)(                \
        hipsolver_handle_, m, n, AsHipComplex(dev_A), lda,                   \
        AsHipComplex(dev_tau), AsHipComplex(dev_work.mutable_data()), lwork, \
        dev_lapack_info));                                                   \
    return Status::OK();                                                     \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

#define UNMQR_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Unmqr(hipsolverSideMode_t side,                            \
                          hipsolverOperation_t trans, int m, int n, int k,     \
                          const Scalar* dev_a, int lda, const Scalar* dev_tau, \
                          Scalar* dev_c, int ldc, int* dev_lapack_info) {      \
    mutex_lock lock(handle_map_mutex);                                         \
    using HipScalar = typename HipComplexT<Scalar>::type;                      \
    ScratchSpace<uint8> dev_a_copy = this->GetScratchSpace<uint8>(             \
        sizeof(Scalar*) * m * k, "", /*on host */ false);                      \
    if (!CopyHostToDevice(context_, dev_a_copy.mutable_data(), dev_a,          \
                          dev_a_copy.bytes())) {                               \
      return errors::Internal("Unmqr: Failed to copy ptrs to device");         \
    }                                                                          \
    ScratchSpace<uint8> dev_tau_copy = this->GetScratchSpace<uint8>(           \
        sizeof(Scalar*) * k * n, "", /*on host */ false);                      \
    if (!CopyHostToDevice(context_, dev_tau_copy.mutable_data(), dev_tau,      \
                          dev_tau_copy.bytes())) {                             \
      return errors::Internal("Unmqr: Failed to copy ptrs to device");         \
    }                                                                          \
    int lwork;                                                                 \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(unmqr, type_prefix)(                 \
        hipsolver_handle_, side, trans, m, n, k,                               \
        reinterpret_cast<HipScalar*>(dev_a_copy.mutable_data()), lda,          \
        reinterpret_cast<HipScalar*>(dev_tau_copy.mutable_data()),             \
        AsHipComplex(dev_c), ldc, &lwork));                                    \
    auto dev_work =                                                            \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);           \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(unmqr, type_prefix)(                  \
        hipsolver_handle_, side, trans, m, n, k,                               \
        reinterpret_cast<HipScalar*>(dev_a_copy.mutable_data()), lda,          \
        reinterpret_cast<HipScalar*>(dev_tau_copy.mutable_data()),             \
        AsHipComplex(dev_c), ldc, AsHipComplex(dev_work.mutable_data()),       \
        lwork, dev_lapack_info));                                              \
    return Status::OK();                                                       \
  }

TF_CALL_LAPACK_TYPES_NO_REAL(UNMQR_INSTANCE);

#define UNGQR_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,       \
                          const Scalar* dev_tau, int* dev_lapack_info) {     \
    mutex_lock lock(handle_map_mutex);                                       \
    using HipScalar = typename HipComplexT<Scalar>::type;                    \
    ScratchSpace<uint8> dev_tau_copy = this->GetScratchSpace<uint8>(         \
        sizeof(HipScalar*) * k * n, "", /*on host */ false);                 \
    if (!CopyHostToDevice(context_, dev_tau_copy.mutable_data(), dev_tau,    \
                          dev_tau_copy.bytes())) {                           \
      return errors::Internal("Ungqr: Failed to copy ptrs to device");       \
    }                                                                        \
    int lwork;                                                               \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(ungqr, type_prefix)(               \
        hipsolver_handle_, m, n, k, AsHipComplex(dev_a), lda,                \
        reinterpret_cast<HipScalar*>(dev_tau_copy.mutable_data()), &lwork)); \
    auto dev_work =                                                          \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);         \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(ungqr, type_prefix)(                \
        hipsolver_handle_, m, n, k, AsHipComplex(dev_a), lda,                \
        reinterpret_cast<HipScalar*>(dev_tau_copy.mutable_data()),           \
        AsHipComplex(dev_work.mutable_data()), lwork, dev_lapack_info));     \
    return Status::OK();                                                     \
  }

TF_CALL_LAPACK_TYPES_NO_REAL(UNGQR_INSTANCE);

#define POTRF_INSTANCE(Scalar, type_prefix)                              \
  template <>                                                            \
  Status GpuSolver::Potrf<Scalar>(hipsolverFillMode_t uplo, int n,       \
                                  Scalar* dev_A, int lda,                \
                                  int* dev_lapack_info) {                \
    mutex_lock lock(handle_map_mutex);                                   \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;              \
    int lwork;                                                           \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(potrf, type_prefix)(           \
        hipsolver_handle_, uplo, n, AsHipComplex(dev_A), lda, &lwork));  \
    auto dev_work =                                                      \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);     \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(potrf, type_prefix)(            \
        hipsolver_handle_, uplo, n, AsHipComplex(dev_A), lda,            \
        AsHipComplex(dev_work.mutable_data()), lwork, dev_lapack_info)); \
    return Status::OK();                                                 \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);

#define GETRS_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Getrs<Scalar>(hipsolverOperation_t trans, int n, int nrhs, \
                                  Scalar* A, int lda, int* dev_pivots,         \
                                  Scalar* B, int ldb, int* dev_lapack_info) {  \
    mutex_lock lock(handle_map_mutex);                                         \
    int lwork;                                                                 \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(getrs, type_prefix)(                 \
        hipsolver_handle_, trans, n, nrhs, AsHipComplex(A), lda, dev_pivots,   \
        AsHipComplex(B), ldb, &lwork));                                        \
    auto dev_work =                                                            \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);           \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrs, type_prefix)(                  \
        hipsolver_handle_, trans, n, nrhs, AsHipComplex(A), lda, dev_pivots,   \
        AsHipComplex(B), ldb, AsHipComplex(dev_work.mutable_data()), lwork,    \
        dev_lapack_info));                                                     \
    return Status::OK();                                                       \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

#define POTRF_BATCHED_INSTANCE(Scalar, type_prefix)                           \
  template <>                                                                 \
  Status GpuSolver::PotrfBatched<Scalar>(                                     \
      hipsolverFillMode_t uplo, int n, const Scalar* const host_a_dev_ptrs[], \
      int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {           \
    rocblas_stride stride = n;                                                \
    mutex_lock lock(handle_map_mutex);                                        \
    using HipScalar = typename HipComplexT<Scalar>::type;                     \
    ScratchSpace<uint8> dev_a = this->GetScratchSpace<uint8>(                 \
        sizeof(HipScalar*) * batch_size, "", /*on host */ false);             \
    if (!CopyHostToDevice(context_, dev_a.mutable_data(), host_a_dev_ptrs,    \
                          dev_a.bytes())) {                                   \
      return errors::Internal("PotrfBatched: Failed to copy ptrs to device"); \
    }                                                                         \
    int lwork;                                                                \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(potrfBatched, type_prefix)(         \
        hipsolver_handle_, uplo, n,                                           \
        reinterpret_cast<HipScalar**>(dev_a.mutable_data()), lda, &lwork,     \
        batch_size));                                                         \
    auto dev_work =                                                           \
        this->GetScratchSpace<Scalar>(lwork, "", /*on_host*/ false);          \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(potrfBatched, type_prefix)(          \
        hipsolver_handle_, uplo, n,                                           \
        reinterpret_cast<HipScalar**>(dev_a.mutable_data()), lda,             \
        AsHipComplex(dev_work.mutable_data()), lwork,                         \
        dev_lapack_info->mutable_data(), batch_size));                        \
    return Status::OK();                                                      \
  }

TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);
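
// Illustrative PotrfBatched call pattern (a sketch; `base`, `solver` and the
// matrix layout are hypothetical): the pointer array lives on the host and
// each entry addresses one device matrix, e.g. batch i at base + i * lda * n.
//
//   std::vector<float*> host_ptrs(batch_size);
//   for (int i = 0; i < batch_size; ++i) host_ptrs[i] = base + i * lda * n;
//   TF_RETURN_IF_ERROR(solver.PotrfBatched(uplo, n, host_ptrs.data(), lda,
//                                          &dev_lapack_info, batch_size));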

#define HEEVD_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Heevd<Scalar>(                                             \
      hipsolverEigMode_t jobz, hipsolverFillMode_t uplo, int n, Scalar* dev_A, \
      int lda, typename Eigen::NumTraits<Scalar>::Real* dev_W,                 \
      int* dev_lapack_info) {                                                  \
    mutex_lock lock(handle_map_mutex);                                         \
    using EigenScalar = typename Eigen::NumTraits<Scalar>::Real;               \
    int lwork;                                                                 \
    TF_RETURN_IF_ROCBLAS_ERROR(BUFSIZE_FN(heevd, type_prefix)(                 \
        hipsolver_handle_, jobz, uplo, n, AsHipComplex(dev_A), lda, dev_W,     \
        &lwork));                                                              \
    auto dev_workspace =                                                       \
        this->GetScratchSpace<Scalar>(lwork, "", /*on host */ false);          \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(heevd, type_prefix)(                  \
        hipsolver_handle_, jobz, uplo, n, AsHipComplex(dev_A), lda, dev_W,     \
        AsHipComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));  \
    return Status::OK();                                                       \
  }

TF_CALL_LAPACK_TYPES_NO_REAL(HEEVD_INSTANCE);

#else
// Pre-ROCm-4.5 path: each *_INSTANCE macro below specializes a solver method
// for the standard numeric types, constructing rocsolver method names via
// SOLVER_FN.

#define GETRF_INSTANCE(Scalar, type_prefix)                                \
  template <>                                                              \
  Status GpuSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                  int* dev_pivots, int* dev_lapack_info) { \
    mutex_lock lock(handle_map_mutex);                                     \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrf, type_prefix)(              \
        rocm_blas_handle_, m, n, reinterpret_cast<ROCmScalar*>(A), lda,    \
        dev_pivots, dev_lapack_info));                                     \
    return Status::OK();                                                   \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);

#define GEQRF_INSTANCE(Scalar, type_prefix)                                 \
  template <>                                                               \
  Status GpuSolver::Geqrf(int m, int n, Scalar* dev_A, int lda,             \
                          Scalar* dev_tau, int* dev_lapack_info) {          \
    mutex_lock lock(handle_map_mutex);                                      \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                 \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(geqrf, type_prefix)(               \
        rocm_blas_handle_, m, n, reinterpret_cast<ROCmScalar*>(dev_A), lda, \
        reinterpret_cast<ROCmScalar*>(dev_tau)));                           \
    return Status::OK();                                                    \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

#define UNMQR_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Unmqr(rocblas_side side, rocblas_operation trans, int m, \
                          int n, int k, const Scalar* dev_a, int lda,        \
                          const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
                          int* dev_lapack_info) {                            \
    mutex_lock lock(handle_map_mutex);                                       \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                  \
    ScratchSpace<uint8> dev_a_copy = this->GetScratchSpace<uint8>(           \
        sizeof(ROCmScalar*) * m * k, "", /*on host */ false);                \
    if (!CopyHostToDevice(context_, dev_a_copy.mutable_data(), dev_a,        \
                          dev_a_copy.bytes())) {                             \
      return errors::Internal("Unmqr: Failed to copy ptrs to device");       \
    }                                                                        \
    ScratchSpace<uint8> dev_tau_copy = this->GetScratchSpace<uint8>(         \
        sizeof(ROCmScalar*) * k * n, "", /*on host */ false);                \
    if (!CopyHostToDevice(context_, dev_tau_copy.mutable_data(), dev_tau,    \
                          dev_tau_copy.bytes())) {                           \
      return errors::Internal("Unmqr: Failed to copy ptrs to device");       \
    }                                                                        \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(unmqr, type_prefix)(                \
        rocm_blas_handle_, side, trans, m, n, k,                             \
        reinterpret_cast<ROCmScalar*>(dev_a_copy.mutable_data()), lda,       \
        reinterpret_cast<ROCmScalar*>(dev_tau_copy.mutable_data()),          \
        reinterpret_cast<ROCmScalar*>(dev_c), ldc));                         \
    return Status::OK();                                                     \
  }

TF_CALL_LAPACK_TYPES_NO_REAL(UNMQR_INSTANCE);

#define UNGQR_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,         \
                          const Scalar* dev_tau, int* dev_lapack_info) {       \
    mutex_lock lock(handle_map_mutex);                                         \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                    \
    ScratchSpace<uint8> dev_tau_copy = this->GetScratchSpace<uint8>(           \
        sizeof(ROCmScalar*) * k * n, "", /*on host */ false);                  \
    if (!CopyHostToDevice(context_, dev_tau_copy.mutable_data(), dev_tau,      \
                          dev_tau_copy.bytes())) {                             \
      return errors::Internal("Ungqr: Failed to copy ptrs to device");         \
    }                                                                          \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(ungqr, type_prefix)(                  \
        rocm_blas_handle_, m, n, k, reinterpret_cast<ROCmScalar*>(dev_a), lda, \
        reinterpret_cast<ROCmScalar*>(dev_tau_copy.mutable_data())));          \
    return Status::OK();                                                       \
  }

TF_CALL_LAPACK_TYPES_NO_REAL(UNGQR_INSTANCE);

#define POTRF_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Potrf<Scalar>(rocblas_fill uplo, int n, Scalar* dev_A,     \
                                  int lda, int* dev_lapack_info) {             \
    mutex_lock lock(handle_map_mutex);                                         \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                    \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(potrf, type_prefix)(                  \
        rocm_blas_handle_, uplo, n, reinterpret_cast<ROCmScalar*>(dev_A), lda, \
        dev_lapack_info));                                                     \
    return Status::OK();                                                       \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);

#define GETRS_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                 \
  Status GpuSolver::Getrs<Scalar>(rocblas_operation trans, int n, int nrhs,   \
                                  Scalar* A, int lda, const int* dev_pivots,  \
                                  Scalar* B, int ldb, int* dev_lapack_info) { \
    mutex_lock lock(handle_map_mutex);                                        \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                   \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrs, type_prefix)(                 \
        rocm_blas_handle_, trans, n, nrhs, reinterpret_cast<ROCmScalar*>(A),  \
        lda, dev_pivots, reinterpret_cast<ROCmScalar*>(B), ldb));             \
    return Status::OK();                                                      \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

#define POTRF_BATCHED_INSTANCE(Scalar, type_prefix)                           \
  template <>                                                                 \
  Status GpuSolver::PotrfBatched<Scalar>(                                     \
      rocblas_fill uplo, int n, const Scalar* const host_a_dev_ptrs[],        \
      int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {           \
    rocblas_stride stride = n;                                                \
    mutex_lock lock(handle_map_mutex);                                        \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                   \
    ScratchSpace<uint8> dev_a = this->GetScratchSpace<uint8>(                 \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);            \
    if (!CopyHostToDevice(context_, dev_a.mutable_data(), host_a_dev_ptrs,    \
                          dev_a.bytes())) {                                   \
      return errors::Internal("PotrfBatched: Failed to copy ptrs to device"); \
    }                                                                         \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(potrf_batched, type_prefix)(         \
        rocm_blas_handle_, uplo, n,                                           \
        reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda,            \
        dev_lapack_info->mutable_data(), batch_size));                        \
    return Status::OK();                                                      \
  }

TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);

#endif

#define GETRI_BATCHED_INSTANCE(Scalar, type_prefix)                           \
  template <>                                                                 \
  Status GpuSolver::GetriBatched<Scalar>(                                     \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const int* dev_pivots, const Scalar* const host_a_inverse_dev_ptrs[],   \
      int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {        \
    mutex_lock lock(handle_map_mutex);                                        \
    rocblas_stride stride = n;                                                \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                   \
    ScratchSpace<uint8> dev_a = this->GetScratchSpace<uint8>(                 \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);            \
    if (!CopyHostToDevice(context_, dev_a.mutable_data(), host_a_dev_ptrs,    \
                          dev_a.bytes())) {                                   \
      return errors::Internal("GetriBatched: Failed to copy ptrs to device"); \
    }                                                                         \
    ScratchSpace<uint8> dev_a_inverse = this->GetScratchSpace<uint8>(         \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);            \
    if (!CopyHostToDevice(context_, dev_a_inverse.mutable_data(),             \
                          host_a_inverse_dev_ptrs, dev_a_inverse.bytes())) {  \
      return errors::Internal("GetriBatched: Failed to copy ptrs to device"); \
    }                                                                         \
    ScratchSpace<uint8> pivots = this->GetScratchSpace<uint8>(                \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);            \
    if (!CopyHostToDevice(context_, pivots.mutable_data(), dev_pivots,        \
                          pivots.bytes())) {                                  \
      return errors::Internal("GetriBatched: Failed to copy ptrs to device"); \
    }                                                                         \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getri_batched, type_prefix)(         \
        rocm_blas_handle_, n,                                                 \
        reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda,            \
        reinterpret_cast<int*>(pivots.mutable_data()), stride,                \
        dev_lapack_info->mutable_data(), batch_size));                        \
    return Status::OK();                                                      \
  }

TF_CALL_ROCSOLV_TYPES(GETRI_BATCHED_INSTANCE);

#define GETRF_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status GpuSolver::GetrfBatched<Scalar>(                                      \
      int n, Scalar** A, int lda, int* dev_pivots, DeviceLapackInfo* dev_info, \
      const int batch_size) {                                                  \
    mutex_lock lock(handle_map_mutex);                                         \
    rocblas_stride stride = n;                                                 \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                    \
    ScratchSpace<uint8> dev_a = this->GetScratchSpace<uint8>(                  \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);             \
    if (!CopyHostToDevice(context_, dev_a.mutable_data(), A, dev_a.bytes())) { \
      return errors::Internal("GetrfBatched: Failed to copy ptrs to device");  \
    }                                                                          \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrf_batched, type_prefix)(          \
        rocm_blas_handle_, n, n,                                               \
        reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda, dev_pivots, \
        stride, dev_info->mutable_data(), batch_size));                        \
    return Status::OK();                                                       \
  }

TF_CALL_ROCSOLV_TYPES(GETRF_BATCHED_INSTANCE);

#define GETRS_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status GpuSolver::GetrsBatched<Scalar>(                                      \
      const rocblas_operation trans, int n, int nrhs, Scalar** A, int lda,     \
      int* dev_pivots, Scalar** B, const int ldb, int* host_lapack_info,       \
      const int batch_size) {                                                  \
    rocblas_stride stride = n;                                                 \
    mutex_lock lock(handle_map_mutex);                                         \
    using ROCmScalar = typename ROCmComplexT<Scalar>::type;                    \
    ScratchSpace<uint8> dev_a = this->GetScratchSpace<uint8>(                  \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);             \
    if (!CopyHostToDevice(context_, dev_a.mutable_data(), A, dev_a.bytes())) { \
      return errors::Internal("GetrsBatched: Failed to copy ptrs to device");  \
    }                                                                          \
    ScratchSpace<uint8> dev_b = this->GetScratchSpace<uint8>(                  \
        sizeof(ROCmScalar*) * batch_size, "", /*on host */ false);             \
    if (!CopyHostToDevice(context_, dev_b.mutable_data(), B, dev_b.bytes())) { \
      return errors::Internal("GetrsBatched: Failed to copy ptrs to device");  \
    }                                                                          \
    TF_RETURN_IF_ROCBLAS_ERROR(SOLVER_FN(getrs_batched, type_prefix)(          \
        rocm_blas_handle_, trans, n, nrhs,                                     \
        reinterpret_cast<ROCmScalar**>(dev_a.mutable_data()), lda, dev_pivots, \
        stride, reinterpret_cast<ROCmScalar**>(dev_b.mutable_data()), ldb,     \
        batch_size));                                                          \
    return Status::OK();                                                       \
  }

TF_CALL_ROCSOLV_TYPES(GETRS_BATCHED_INSTANCE);
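
// Sketch of the batched LU flow these wrappers support (the solver object and
// buffers are illustrative): the pointer arrays live on the host and are
// staged to device scratch internally via CopyHostToDevice.
//
//   std::vector<float*> a_ptrs(batch_size), b_ptrs(batch_size);
//   // ... fill with per-batch device pointers ...
//   TF_RETURN_IF_ERROR(solver.GetrfBatched(n, a_ptrs.data(), lda, dev_pivots,
//                                          &dev_info, batch_size));
//   TF_RETURN_IF_ERROR(solver.GetrsBatched(rocblas_operation_none, n, nrhs,
//                                          a_ptrs.data(), lda, dev_pivots,
//                                          b_ptrs.data(), ldb,
//                                          host_lapack_info, batch_size));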

// Allocates a temporary tensor. The GpuSolver object maintains a
// TensorReference to the underlying Tensor to prevent it from being deallocated
// prematurely.
Status GpuSolver::allocate_scoped_tensor(DataType type,
                                         const TensorShape& shape,
                                         Tensor* out_temp) {
  const Status status = context_->allocate_temp(type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

Status GpuSolver::forward_input_or_allocate_scoped_tensor(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, Tensor* out_temp) {
  const Status status = context_->forward_input_or_allocate_temp(
      candidate_input_indices, type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

template <typename Scalar, typename SolverFnT>
static inline Status TrsmImpl(GpuExecutor* gpu_executor, SolverFnT solver,
                              rocblas_handle rocm_blas_handle,
                              rocblas_side side, rocblas_fill uplo,
                              rocblas_operation trans, rocblas_diagonal diag,
                              int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda, Scalar* B, int ldb) {
  mutex_lock lock(handle_map_mutex);
  using ROCmScalar = typename ROCmComplexT<Scalar>::type;

  ScopedActivateExecutorContext sac{gpu_executor};
  TF_RETURN_IF_ROCBLAS_ERROR(solver(rocm_blas_handle, side, uplo, trans, diag,
                                    m, n,
                                    reinterpret_cast<const ROCmScalar*>(alpha),
                                    reinterpret_cast<const ROCmScalar*>(A), lda,
                                    reinterpret_cast<ROCmScalar*>(B), ldb));

  return Status::OK();
}

#define TRSM_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                 \
  Status GpuSolver::Trsm<Scalar>(                                             \
      rocblas_side side, rocblas_fill uplo, rocblas_operation trans,          \
      rocblas_diagonal diag, int m, int n,                                    \
      const Scalar* alpha, /* host or device pointer */                       \
      const Scalar* A, int lda, Scalar* B, int ldb) {                         \
    GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(                    \
        context_->op_device_context()->stream()->parent()->implementation()); \
    return TrsmImpl(gpu_executor, BLAS_SOLVER_FN(trsm, type_prefix),          \
                    rocm_blas_handle_, side, uplo, trans, diag, m, n, alpha,  \
                    A, lda, B, ldb);                                          \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(TRSM_INSTANCE);
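
// Illustrative Trsm call (a sketch; the solver object and device buffers are
// hypothetical): solves op(A) * X = alpha * B in place on B for a
// lower-triangular A. Note that Trsm is only instantiated for float and
// double here.
//
//   float alpha = 1.0f;
//   TF_RETURN_IF_ERROR(solver.Trsm(rocblas_side_left, rocblas_fill_lower,
//                                  rocblas_operation_none,
//                                  rocblas_diagonal_non_unit, m, n, &alpha,
//                                  dev_a, lda, dev_b, ldb));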

// Inverts a batch of matrices by composing GetrfBatched and GetriBatched.
template <typename Scalar, typename SolverFnT>
Status MatInvBatchedImpl(GpuExecutor* gpu_executor, SolverFnT solver,
                         rocblas_handle rocm_blas_handle, int n,
                         const Scalar* const host_a_dev_ptrs[], int lda,
                         int* dev_pivots,
                         const Scalar* const host_a_inverse_dev_ptrs[],
                         int ldainv, DeviceLapackInfo* dev_lapack_info,
                         int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using ROCmScalar = typename ROCmComplexT<Scalar>::type;
  ScopedActivateExecutorContext sac{gpu_executor};

  GetrfBatched(n, host_a_dev_ptrs, lda, dev_pivots, dev_lapack_info,
               batch_size);

  GetriBatched(n, host_a_dev_ptrs, lda, dev_pivots, host_a_inverse_dev_ptrs,
               ldainv, dev_lapack_info, batch_size);

  return Status::OK();
}

// Note: MATINVBATCHED_INSTANCE is not instantiated for any type in this file.
#define MATINVBATCHED_INSTANCE(Scalar, type_prefix)                           \
  template <>                                                                 \
  Status GpuSolver::MatInvBatched<Scalar>(                                    \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const Scalar* const host_a_inverse_dev_ptrs[], int ldainv,              \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
    GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(                    \
        context_->op_device_context()->stream()->parent()->implementation()); \
    Tensor pivots;                                                            \
    allocate_scoped_tensor(DataTypeToEnum<int>::value,                        \
                           TensorShape{batch_size, n}, &pivots);              \
    auto pivots_mat = pivots.template matrix<int>();                          \
    int* dev_pivots = pivots_mat.data();                                      \
    return MatInvBatchedImpl(                                                 \
        gpu_executor, BLAS_SOLVER_FN(matinvbatched, type_prefix),             \
        rocm_blas_handle_, n, host_a_dev_ptrs, lda, dev_pivots,               \
        host_a_inverse_dev_ptrs, ldainv, dev_lapack_info, batch_size);        \
  }

template <typename Scalar, typename SolverFnT>
Status GeamImpl(GpuExecutor* gpu_executor, SolverFnT solver,
                rocblas_handle rocm_blas_handle, rocblas_operation transa,
                rocblas_operation transb, int m, int n, const Scalar* alpha,
                /* host or device pointer */ const Scalar* A, int lda,
                const Scalar* beta,
                /* host or device pointer */ const Scalar* B, int ldb,
                Scalar* C, int ldc) {
  mutex_lock lock(handle_map_mutex);
  using ROCmScalar = typename ROCmComplexT<Scalar>::type;

  ScopedActivateExecutorContext sac{gpu_executor};
  TF_RETURN_IF_ROCBLAS_ERROR(solver(rocm_blas_handle, transa, transb, m, n,
                                    reinterpret_cast<const ROCmScalar*>(alpha),
                                    reinterpret_cast<const ROCmScalar*>(A), lda,
                                    reinterpret_cast<const ROCmScalar*>(beta),
                                    reinterpret_cast<const ROCmScalar*>(B), ldb,
                                    reinterpret_cast<ROCmScalar*>(C), ldc));
  return Status::OK();
}

#define GEAM_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                 \
  Status GpuSolver::Geam<Scalar>(                                             \
      rocblas_operation transa, rocblas_operation transb, int m, int n,       \
      const Scalar* alpha, const Scalar* A, int lda, const Scalar* beta,      \
      const Scalar* B, int ldb, Scalar* C, int ldc) {                         \
    GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(                    \
        context_->op_device_context()->stream()->parent()->implementation()); \
    return GeamImpl(gpu_executor, BLAS_SOLVER_FN(geam, type_prefix),          \
                    rocm_blas_handle_, transa, transb, m, n, alpha, A, lda,   \
                    beta, B, ldb, C, ldc);                                    \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GEAM_INSTANCE);
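
// Illustrative Geam call (a sketch; the solver object and device buffers are
// hypothetical): computes C = alpha * op(A) + beta * op(B), e.g. an
// out-of-place transpose with alpha = 1 and beta = 0. Only float and double
// are instantiated here.
//
//   float alpha = 1.0f, beta = 0.0f;
//   TF_RETURN_IF_ERROR(solver.Geam(rocblas_operation_transpose,
//                                  rocblas_operation_none, m, n, &alpha,
//                                  dev_a, lda, &beta, dev_b, ldb, dev_c,
//                                  ldc));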
}  // namespace tensorflow

#endif  // TENSORFLOW_USE_ROCM