/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
   ==============================================================================
*/
#ifdef GOOGLE_CUDA
#include <chrono>
#include <complex>
#include <unordered_map>
#include <vector>

#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_solvers.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"

// The CUDA cublas_api.h API contains const-correctness errors. Instead of
// casting away constness on our data, we reinterpret the cuBLAS functions as
// what they were clearly meant to be, and thus we can call the functions
// naturally.
//
// (The error is that input-only arrays are bound to parameter types
// "const T**" instead of the correct "const T* const*".)
extern "C" {
using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float* const*, int, const int*, float**,
                               int, int*, int);
using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double* const*, int, const int*, double**,
                               int, int*, int);
using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float2* const*, int, const int*, float2**,
                               int, int*, int);
using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double2* const*, int, const int*,
                               double2**, int, int*, int);

using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                               const int*, float**, int, int*, int);
using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                               const int*, double**, int, int*, int);
using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                               const int*, float2**, int, int*, int);
using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                               const int*, double2**, int, int*, int);

using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                                float**, int, int*, int);
using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                                double**, int, int*, int);
using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                                float2**, int, int*, int);
using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                                double2**, int, int*, int);

using trsm_S = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const float*,
                              const float* const*, int, float* const*, int,
                              int);
using trsm_D = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const double*,
                              const double* const*, int, double* const*, int,
                              int);
using trsm_C = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const float2*,
                              const float2* const*, int, float2* const*, int,
                              int);
using trsm_Z = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const double2*,
                              const double2* const*, int, double2* const*, int,
                              int);
}
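
// The aliases above are applied at the call sites further down by
// reinterpreting a cuBLAS entry point to its const-correct signature before
// calling it. A minimal sketch of the pattern (the variable names here are
// illustrative only, not part of this file):
//
//   auto* getrs_batched_s = reinterpret_cast<getrs_S*>(cublasSgetrsBatched);
//   getrs_batched_s(cublas_handle, CUBLAS_OP_N, n, nrhs,
//                   dev_a_ptrs /* const float* const* */, lda, dev_pivots,
//                   dev_b_ptrs, ldb, host_lapack_info, batch_size);
//
// This lets callers pass "const T* const*" pointer arrays without a
// const_cast.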

namespace tensorflow {
namespace {

using se::cuda::ScopedActivateExecutorContext;

inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
                             const void* src, uint64 bytes) {
  auto stream = context->op_device_context()->stream();
  se::DeviceMemoryBase wrapped_dst(dst);
  return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
}

// A set of initialized handles to the underlying Cuda libraries used by
// GpuSolver. We maintain one such set of handles per unique stream.
struct GpuSolverHandles {
  explicit GpuSolverHandles(cudaStream_t stream) {
    CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to create cuSolverDN instance.";
    CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
          CUSOLVER_STATUS_SUCCESS)
        << "Failed to set cuSolverDN stream.";
    CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to create cuBlas instance.";
    CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
        << "Failed to set cuBlas stream.";
  }

  ~GpuSolverHandles() {
    CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to destroy cuBlas instance.";
    CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to destroy cuSolverDN instance.";
  }
  cublasHandle_t cublas_handle;
  cusolverDnHandle_t cusolver_dn_handle;
};

static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap =
    std::unordered_map<cudaStream_t, std::unique_ptr<GpuSolverHandles>>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

#define TF_RETURN_IF_CUSOLVER_ERROR(expr)                      \
  do {                                                         \
    auto status = (expr);                                      \
    if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \
      return errors::Internal(                                 \
          __FILE__, ":", __LINE__,                             \
          ": cuSolverDN call failed with status =", status);   \
    }                                                          \
  } while (0)

#define TF_RETURN_IF_CUBLAS_ERROR(expr)                                  \
  do {                                                                   \
    auto status = (expr);                                                \
    if (TF_PREDICT_FALSE(status != CUBLAS_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                   \
                              ": cuBlas call failed status = ", status); \
    }                                                                    \
  } while (0)

GpuSolver::GpuSolver(OpKernelContext* context) : context_(context) {
  mutex_lock lock(handle_map_mutex);
  const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  cuda_stream_ = *cu_stream_ptr;
  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
  auto it = handle_map->find(cuda_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating GpuSolver handles for stream " << cuda_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda solver library
    // handles for it.
    std::unique_ptr<GpuSolverHandles> new_handles(
        new GpuSolverHandles(cuda_stream_));
    it =
        handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
            .first;
  }
  cusolver_dn_handle_ = it->second->cusolver_dn_handle;
  cublas_handle_ = it->second->cublas_handle;
}

GpuSolver::~GpuSolver() {
  for (const auto& tensor_ref : scratch_tensor_refs_) {
    tensor_ref.Unref();
  }
}

// static
void GpuSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<GpuSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_infos,
    std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
        info_checker_callback) {
  CHECK(info_checker_callback != nullptr);
  std::vector<HostLapackInfo> host_lapack_infos;
  if (dev_lapack_infos.empty()) {
    info_checker_callback(OkStatus(), host_lapack_infos);
    return;
  }

  // Launch memcpys to copy info back from the device to the host.
  for (const auto& dev_lapack_info : dev_lapack_infos) {
    bool success = true;
    auto host_copy = dev_lapack_info.CopyToHost(&success);
    OP_REQUIRES(
        solver->context(), success,
        errors::Internal(
            "Failed to launch copy of dev_lapack_info to host, debug_info = ",
            dev_lapack_info.debug_info()));
    host_lapack_infos.push_back(std::move(host_copy));
  }

  // This callback checks that all batch items in all calls were processed
  // successfully and passes status to the info_checker_callback accordingly.
  auto* stream = solver->context()->op_device_context()->stream();
  auto wrapped_info_checker_callback =
      [stream](
          GpuSolver* solver,
          std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
              info_checker_callback,
          std::vector<HostLapackInfo> host_lapack_infos) {
        ScopedActivateExecutorContext scoped_activation{stream->parent()};
        Status status;
        for (const auto& host_lapack_info : host_lapack_infos) {
          for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
            const int info_value = host_lapack_info(i);
            if (info_value != 0) {
              status = errors::InvalidArgument(
                  "Got info = ", info_value, " for batch index ", i,
                  ", expected info = 0. Debug_info = ",
                  host_lapack_info.debug_info());
            }
          }
          if (!status.ok()) {
            break;
          }
        }
        // Delete solver to release temp tensor refs.
        delete solver;

        // Delegate further error checking to provided functor.
        info_checker_callback(status, host_lapack_infos);
      };
  // Note: A std::function cannot have unique_ptr arguments (it must be copy
  // constructible, and therefore so must its arguments). We therefore release
  // ownership of solver into a raw pointer that is deleted at the end of
  // wrapped_info_checker_callback, i.e. once the cb below has run.
  auto solver_raw_ptr = solver.release();
  auto cb =
      std::bind(wrapped_info_checker_callback, solver_raw_ptr,
                std::move(info_checker_callback), std::move(host_lapack_infos));

  solver_raw_ptr->context()
      ->device()
      ->tensorflow_accelerator_device_info()
      ->event_mgr->ThenExecute(stream, std::move(cb));
}

// static
void GpuSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<GpuSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_info,
    AsyncOpKernel::DoneCallback done) {
  OpKernelContext* context = solver->context();
  auto wrapped_done = [context, done](
                          const Status& status,
                          const std::vector<HostLapackInfo>& /* unused */) {
    if (done != nullptr) {
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    } else {
      OP_REQUIRES_OK(context, status);
    }
  };
  CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
                                      wrapped_done);
}
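
// Illustrative sketch (not part of the original file) of how an async op
// kernel typically drives the two helpers above; the kernel body, tensor
// names and shapes are hypothetical:
//
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
//     auto solver = std::make_unique<GpuSolver>(context);
//     std::vector<DeviceLapackInfo> dev_info;
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
//     // ... launch e.g. solver->Getrf(...) for each batch item ...
//     GpuSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
//                                                    dev_info, done);
//   }
//
// The solver object, and the scratch tensors it references, stays alive until
// the info check runs on the event manager; only then is it deleted.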

// Allocates a temporary tensor. The GpuSolver object maintains a
// TensorReference to the underlying Tensor to prevent it from being deallocated
// prematurely.
Status GpuSolver::allocate_scoped_tensor(DataType type,
                                         const TensorShape& shape,
                                         Tensor* out_temp) {
  const Status status = context_->allocate_temp(type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

Status GpuSolver::forward_input_or_allocate_scoped_tensor(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, Tensor* out_temp) {
  const Status status = context_->forward_input_or_allocate_temp(
      candidate_input_indices, type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}
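
// Illustrative sketch (not part of the original file): a kernel helper that
// needs a temporary which must outlive its own scope can allocate it through
// the solver so the TensorReference keeps it alive; the names below are
// hypothetical and the helper is assumed to return Status:
//
//   Tensor scratch;
//   TF_RETURN_IF_ERROR(solver->allocate_scoped_tensor(
//       DataTypeToEnum<float>::value, TensorShape({batch_size, n, n}),
//       &scratch));
//
// The reference is released in ~GpuSolver(), i.e. after the async info check
// above has completed.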

// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)

// Macros to construct cusolverDn method names.
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
#define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method
#define DN_BUFSIZE_FN(method, type_prefix) \
  cusolverDn##type_prefix##method##_bufferSize

// Macros to construct cublas method names.
#define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method
#define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method
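
// For reference, the token pasting above resolves to the actual library entry
// points, e.g.:
//
//   DN_SOLVER_FN(potrf, S)          -> cusolverDnSpotrf
//   DN_BUFSIZE_FN(potrf, Z)         -> cusolverDnZpotrf_bufferSize
//   BLAS_SOLVER_FN(getrfBatched, C) -> cublasCgetrfBatched
//   DN_SOLVER_NAME(potrf, S)        -> "cusolverDnSpotrf"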

//=============================================================================
// Wrappers of cuSolverDN computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolverDn.h header file.
//
// NOTE: The cuSolver functions called below appear not to be thread-safe,
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
//=============================================================================

template <typename Scalar, typename SolverFnT>
static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasOperation_t transa,
                              cublasOperation_t transb, int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda,
                              const Scalar* beta, /* host or device pointer */
                              const Scalar* B, int ldb, Scalar* C, int ldc) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<const CudaScalar*>(beta),
                                   reinterpret_cast<const CudaScalar*>(B), ldb,
                                   reinterpret_cast<CudaScalar*>(C), ldc));
  return OkStatus();
}

#define GEAM_INSTANCE(Scalar, type_prefix)                                     \
  template <>                                                                  \
  Status GpuSolver::Geam<Scalar>(                                              \
      cublasOperation_t transa, cublasOperation_t transb, int m, int n,        \
      const Scalar* alpha, /* host or device pointer */                        \
      const Scalar* A, int lda,                                                \
      const Scalar* beta, /* host or device pointer */                         \
      const Scalar* B, int ldb, Scalar* C, int ldc) const {                    \
    return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
                    transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);        \
  }

TF_CALL_LAPACK_TYPES(GEAM_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasFillMode_t uplo, int n, Scalar* A, int lda,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return OkStatus();
}

#define POTRF_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A,   \
                                  int lda, int* dev_lapack_info) {           \
    return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix),                      \
                     DN_SOLVER_FN(potrf, type_prefix), this, context_,       \
                     cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);

#if CUDA_VERSION >= 9020
template <typename Scalar, typename SolverFnT>
static inline Status PotrfBatchedImpl(
    SolverFnT solver, GpuSolver* cuda_solver, OpKernelContext* context,
    cusolverDnHandle_t cusolver_dn_handle, cublasFillMode_t uplo, int n,
    const Scalar* const host_a_dev_ptrs[], int lda,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("PotrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, uplo, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
             dev_lapack_info->mutable_data(), batch_size));
  return OkStatus();
}

#define POTRF_BATCHED_INSTANCE(Scalar, type_prefix)                        \
  template <>                                                              \
  Status GpuSolver::PotrfBatched(                                          \
      cublasFillMode_t uplo, int n, const Scalar* const host_a_dev_ptrs[], \
      int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {        \
    return PotrfBatchedImpl(DN_SOLVER_FN(potrfBatched, type_prefix), this, \
                            context_, cusolver_dn_handle_, uplo, n,        \
                            host_a_dev_ptrs, lda, dev_lapack_info,         \
                            batch_size);                                   \
  }

TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);
#endif  // CUDA_VERSION >= 9020
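
// Illustrative sketch (not part of the original file) of how callers
// typically assemble the host_a_dev_ptrs argument for the batched wrappers:
// a host-side table whose entries point at each matrix inside one flat device
// buffer. The variable names are hypothetical and the snippet assumes a helper
// returning Status:
//
//   auto dev_info = solver->GetDeviceLapackInfo(batch_size, "potrfBatched");
//   std::vector<const float*> host_ptrs(batch_size);
//   float* base = input_copy.flat<float>().data();  // batch_size * n * n
//   for (int i = 0; i < batch_size; ++i) {
//     host_ptrs[i] = base + i * n * n;
//   }
//   TF_RETURN_IF_ERROR(solver->PotrfBatched(CUBLAS_FILL_MODE_UPPER, n,
//                                           host_ptrs.data(), n, &dev_info,
//                                           batch_size));
//
// PotrfBatchedImpl then copies this pointer table to the device with
// CopyHostToDevice before invoking the cuSolver batched kernel.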

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, int* dev_pivots,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
  return OkStatus();
}

#define GETRF_INSTANCE(Scalar, type_prefix)                                \
  template <>                                                              \
  Status GpuSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                  int* dev_pivots, int* dev_lapack_info) { \
    return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix),                    \
                     DN_SOLVER_FN(getrf, type_prefix), this, context_,     \
                     cusolver_dn_handle_, m, n, A, lda, dev_pivots,        \
                     dev_lapack_info);                                     \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasOperation_t trans, int n, int nrhs,
                               const Scalar* A, int lda, const int* pivots,
                               Scalar* B, int ldb, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
                                     CUDAComplex(A), lda, pivots,
                                     CUDAComplex(B), ldb, dev_lapack_info));
  return OkStatus();
}

#define GETRS_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Getrs<Scalar>(                                           \
      cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda,    \
      const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const {   \
    return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_,             \
                     cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
                     ldb, dev_lapack_info);                                  \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, Scalar* tau,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return OkStatus();
}

#define GEQRF_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda,            \
                                  Scalar* tau, int* dev_lapack_info) {         \
    return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix),                        \
                     DN_SOLVER_FN(geqrf, type_prefix), this, context_,         \
                     cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasSideMode_t side, cublasOperation_t trans,
                               int m, int n, int k, const Scalar* dev_a,
                               int lda, const Scalar* dev_tau, Scalar* dev_c,
                               int ldc, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
              CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
      CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return OkStatus();
}

// Unfortunately the LAPACK function name differs for the real and complex case
// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
// one separately.
#define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix)                 \
  template <>                                                                \
  Status GpuSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans,    \
                          int m, int n, int k, const Scalar* dev_a, int lda, \
                          const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
                          int* dev_lapack_info) {                            \
    return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix),       \
                     DN_SOLVER_FN(function_prefix##mqr, type_prefix), this,  \
                     context_, cusolver_dn_handle_, side, trans, m, n, k,    \
                     dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info);      \
  }

UNMQR_INSTANCE(float, or, S);
UNMQR_INSTANCE(double, or, D);
UNMQR_INSTANCE(complex64, un, C);
UNMQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, int k, Scalar* dev_a, int lda,
                               const Scalar* dev_tau, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
                                      CUDAComplex(dev_a), lda,
                                      CUDAComplex(dev_tau), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
             CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return OkStatus();
}

#define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix)                \
  template <>                                                               \
  Status GpuSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,      \
                          const Scalar* dev_tau, int* dev_lapack_info) {    \
    return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix),      \
                     DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
                     context_, cusolver_dn_handle_, m, n, k, dev_a, lda,    \
                     dev_tau, dev_lapack_info);                             \
  }

UNGQR_INSTANCE(float, or, S);
UNGQR_INSTANCE(double, or, D);
UNGQR_INSTANCE(complex64, un, C);
UNGQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cusolverEigMode_t jobz, cublasFillMode_t uplo,
                               int n, Scalar* dev_A, int lda,
                               typename Eigen::NumTraits<Scalar>::Real* dev_W,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
                                      CUDAComplex(dev_A), lda,
                                      CUDAComplex(dev_W), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
#if CUDA_VERSION >= 11070
  // TODO(b/223856016): CUDA 11.7 sometimes gives invalid outputs if the scratch
  // space is not initialized to zero.
  se::Stream* stream = context->op_device_context()->stream();
  if (!stream) {
    return errors::Internal("No GPU stream available");
  }
  uint64_t work_size_in_bytes = static_cast<uint64_t>(lwork) * sizeof(Scalar);
  se::DeviceMemoryBase dev_workspace_ptr(dev_workspace.mutable_data(),
                                         work_size_in_bytes);
  stream->ThenMemZero(&dev_workspace_ptr, work_size_in_bytes);
#endif
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
             CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return OkStatus();
}

#define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix)                   \
  template <>                                                                  \
  Status GpuSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo,       \
                          int n, Scalar* dev_A, int lda,                       \
                          typename Eigen::NumTraits<Scalar>::Real* dev_W,      \
                          int* dev_lapack_info) {                              \
    return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix),         \
                     DN_SOLVER_FN(function_prefix##evd, type_prefix), this,    \
                     context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
                     dev_W, dev_lapack_info);                                  \
  }

HEEVD_INSTANCE(float, sy, S);
HEEVD_INSTANCE(double, sy, D);
HEEVD_INSTANCE(complex64, he, C);
HEEVD_INSTANCE(complex128, he, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               signed char jobu, signed char jobvt, int m,
                               int n, Scalar* A, int lda, Scalar* S, Scalar* U,
                               int ldu, Scalar* VT, int ldvt,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
                                     CUDAComplex(A), lda, S, CUDAComplex(U),
                                     ldu, CUDAComplex(VT), ldvt,
                                     CUDAComplex(dev_workspace.mutable_data()),
                                     lwork, nullptr, dev_lapack_info));
  return OkStatus();
}

#define GESVD_INSTANCE(Scalar, type_prefix)                              \
  template <>                                                            \
  Status GpuSolver::Gesvd<Scalar>(                                       \
      signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,  \
      int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,    \
      int ldvt, int* dev_lapack_info) {                                  \
    return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix),                  \
                     DN_SOLVER_FN(gesvd, type_prefix), this, context_,   \
                     cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
                     dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info);  \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdjBatchedImpl(BufSizeFnT bufsize, SolverFnT solver,
                                       GpuSolver* cuda_solver,
                                       OpKernelContext* context,
                                       cusolverDnHandle_t cusolver_dn_handle,
                                       cusolverEigMode_t jobz, int m, int n,
                                       Scalar* A, int lda, Scalar* S, Scalar* U,
                                       int ldu, Scalar* V, int ldv,
                                       int* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  /* Default parameters for gesvdj and gesvdjBatched. */
  gesvdjInfo_t svdj_info;
  TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnCreateGesvdjInfo(&svdj_info));
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(
      cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
      ldu, CUDAComplex(V), ldv, &lwork, svdj_info, batch_size));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
      ldu, CUDAComplex(V), ldv, CUDAComplex(dev_workspace.mutable_data()),
      lwork, dev_lapack_info, svdj_info, batch_size));
  TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnDestroyGesvdjInfo(svdj_info));
  return OkStatus();
}

#define GESVDJBATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status GpuSolver::GesvdjBatched<Scalar>(                                     \
      cusolverEigMode_t jobz, int m, int n, Scalar* dev_A, int lda,            \
      Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_V, int ldv,           \
      int* dev_lapack_info, int batch_size) {                                  \
    return GesvdjBatchedImpl(DN_BUFSIZE_FN(gesvdjBatched, type_prefix),        \
                             DN_SOLVER_FN(gesvdjBatched, type_prefix), this,   \
                             context_, cusolver_dn_handle_, jobz, m, n, dev_A, \
                             lda, dev_S, dev_U, ldu, dev_V, ldv,               \
                             dev_lapack_info, batch_size);                     \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVDJBATCHED_INSTANCE);

//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cublas_api.h header file.
//=============================================================================
template <typename Scalar, typename SolverFnT>
static inline Status GetrfBatchedImpl(SolverFnT solver, GpuSolver* cuda_solver,
                                      OpKernelContext* context,
                                      cublasHandle_t cublas_handle, int n,
                                      const Scalar* const host_a_dev_ptrs[],
                                      int lda, int* dev_pivots,
                                      DeviceLapackInfo* dev_lapack_info,
                                      int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
             dev_pivots, dev_lapack_info->mutable_data(), batch_size));
  return OkStatus();
}

#define GETRF_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status GpuSolver::GetrfBatched(                                              \
      int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots,  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                     \
    return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this,   \
                            context_, cublas_handle_, n, host_a_dev_ptrs, lda, \
                            dev_pivots, dev_lapack_info, batch_size);          \
  }

TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsBatchedImpl(
    SolverFnT solver, GpuSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
    const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
    const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
    int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, trans, n, nrhs,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
      ldb, host_lapack_info, batch_size));
  return OkStatus();
}

#define GETRS_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status GpuSolver::GetrsBatched(                                              \
      cublasOperation_t trans, int n, int nrhs,                                \
      const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,   \
      const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,   \
      int batch_size) {                                                        \
    return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>(            \
                                BLAS_SOLVER_FN(getrsBatched, type_prefix)),    \
                            this, context_, cublas_handle_, trans, n, nrhs,    \
                            host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
                            ldb, host_lapack_info, batch_size);                \
  }

TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetriBatchedImpl(
    SolverFnT solver, GpuSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
    int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("GetriBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, dev_pivots,
             reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
             ldainv, dev_lapack_info->mutable_data(), batch_size));
  return OkStatus();
}

#define GETRI_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                \
  Status GpuSolver::GetriBatched(                                            \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                 \
      const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],      \
      int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {       \
    return GetriBatchedImpl(                                                 \
        reinterpret_cast<getri_##type_prefix*>(                              \
            BLAS_SOLVER_FN(getriBatched, type_prefix)),                      \
        this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \
        host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size);           \
  }

TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status MatInvBatchedImpl(
    SolverFnT solver, GpuSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("MatInvBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, n,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv,
      dev_lapack_info->mutable_data(), batch_size));
  return OkStatus();
}

#define MATINV_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                 \
  Status GpuSolver::MatInvBatched(                                            \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const Scalar* const host_a_inv_dev_ptrs[], int ldainv,                  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
    return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>(         \
                                 BLAS_SOLVER_FN(matinvBatched, type_prefix)), \
                             this, context_, cublas_handle_, n,               \
                             host_a_dev_ptrs, lda, host_a_inv_dev_ptrs,       \
                             ldainv, dev_lapack_info, batch_size);            \
  }

TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasSideMode_t side, cublasFillMode_t uplo,
                              cublasOperation_t trans, cublasDiagType_t diag,
                              int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda, Scalar* B, int ldb) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<CudaScalar*>(B), ldb));
  return OkStatus();
}

#define TRSM_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                \
  Status GpuSolver::Trsm<Scalar>(                                            \
      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
      cublasDiagType_t diag, int m, int n,                                   \
      const Scalar* alpha, /* host or device pointer */                      \
      const Scalar* A, int lda, Scalar* B, int ldb) {                        \
    return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side, \
                    uplo, trans, diag, m, n, alpha, A, lda, B, ldb);         \
  }

TF_CALL_LAPACK_TYPES(TRSM_INSTANCE);
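
// Illustrative sketch (not part of the original file): in cuBLAS's
// column-major convention, a system A X = alpha * B with lower-triangular A
// can be solved in place with a single Trsm call. The buffers, dimensions and
// the Status-returning helper assumed here are hypothetical:
//
//   const float alpha = 1.0f;
//   TF_RETURN_IF_ERROR(solver->Trsm(CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER,
//                                   CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT,
//                                   /*m=*/n, /*n=*/nrhs, &alpha,
//                                   dev_a, /*lda=*/n, dev_b, /*ldb=*/n));
//
// Callers working with row-major TensorFlow tensors usually account for the
// implicit transpose, e.g. by flipping the side, fill-mode, or trans
// arguments accordingly.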

template <typename Scalar, typename SolverFnT>
static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasFillMode_t uplo, cublasOperation_t trans,
                              cublasDiagType_t diag, int n, const Scalar* A,
                              int lda, Scalar* x, int incx) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<CudaScalar*>(x), incx));
  return OkStatus();
}

#define TRSV_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                \
  Status GpuSolver::Trsv<Scalar>(                                            \
      cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, \
      int n, const Scalar* A, int lda, Scalar* x, int incx) {                \
    return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo, \
                    trans, diag, n, A, lda, x, incx);                        \
  }

TF_CALL_LAPACK_TYPES(TRSV_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status TrsmBatchedImpl(
    SolverFnT solver, GpuSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
    cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
    const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
    Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("TrsmBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
    return errors::Internal("TrsmBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, side, uplo, trans, diag, m, n,
             reinterpret_cast<const CudaScalar*>(alpha),
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
             ldb, batch_size));
  return OkStatus();
}

#define TRSM_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                 \
  Status GpuSolver::TrsmBatched(                                              \
      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans,  \
      cublasDiagType_t diag, int m, int n, const Scalar* alpha,               \
      const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[],        \
      int ldb, int batch_size) {                                              \
    return TrsmBatchedImpl(reinterpret_cast<trsm_##type_prefix*>(             \
                               BLAS_SOLVER_FN(trsmBatched, type_prefix)),     \
                           this, context_, cublas_handle_, side, uplo, trans, \
                           diag, m, n, alpha, dev_Aarray, lda, dev_Barray,    \
                           ldb, batch_size);                                  \
  }

TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA