/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
#ifdef GOOGLE_CUDA
#include <chrono>
#include <complex>
#include <unordered_map>
#include <vector>

#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_solvers.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"

// The CUDA cublas_api.h header contains const-correctness errors. Instead of
// casting away constness on our data, we reinterpret the cuBLAS functions as
// what they were clearly meant to be, so that we can call them naturally.
//
// (The error is that input-only arrays are bound to parameter types
// "const T**" instead of the correct "const T* const*".)
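// For example, the getrs aliases below declare the input pointer array as
// "const T* const*"; the batched wrappers further down reinterpret_cast the
// corresponding cuBLAS entry points to these alias types (see
// GETRS_BATCHED_INSTANCE and friends) rather than const_cast-ing their data.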
extern "C" {
using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float* const*, int, const int*, float**,
                               int, int*, int);
using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double* const*, int, const int*, double**,
                               int, int*, int);
using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float2* const*, int, const int*, float2**,
                               int, int*, int);
using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double2* const*, int, const int*,
                               double2**, int, int*, int);

using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                               const int*, float**, int, int*, int);
using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                               const int*, double**, int, int*, int);
using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                               const int*, float2**, int, int*, int);
using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                               const int*, double2**, int, int*, int);

using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                                float**, int, int*, int);
using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                                double**, int, int*, int);
using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                                float2**, int, int*, int);
using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                                double2**, int, int*, int);

using trsm_S = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const float*,
                              const float* const*, int, float* const*, int,
                              int);
using trsm_D = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const double*,
                              const double* const*, int, double* const*, int,
                              int);
using trsm_C = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const float2*,
                              const float2* const*, int, float2* const*, int,
                              int);
using trsm_Z = cublasStatus_t(cublasContext*, cublasSideMode_t,
                              cublasFillMode_t, cublasOperation_t,
                              cublasDiagType_t, int, int, const double2*,
                              const double2* const*, int, double2* const*, int,
                              int);
}

namespace tensorflow {
namespace {

using se::cuda::ScopedActivateExecutorContext;

inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
                             const void* src, uint64 bytes) {
  auto stream = context->op_device_context()->stream();
  se::DeviceMemoryBase wrapped_dst(dst);
  return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
}

// A set of initialized handles to the underlying Cuda libraries used by
// GpuSolver. We maintain one such set of handles per unique stream.
struct GpuSolverHandles {
  explicit GpuSolverHandles(cudaStream_t stream) {
    CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to create cuSolverDN instance.";
    CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
          CUSOLVER_STATUS_SUCCESS)
        << "Failed to set cuSolverDN stream.";
    CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to create cuBlas instance.";
    CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
        << "Failed to set cuBlas stream.";
  }

  ~GpuSolverHandles() {
    CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to destroy cuBlas instance.";
    CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to destroy cuSolverDN instance.";
  }
  cublasHandle_t cublas_handle;
  cusolverDnHandle_t cusolver_dn_handle;
};

static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap =
    std::unordered_map<cudaStream_t, std::unique_ptr<GpuSolverHandles>>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

#define TF_RETURN_IF_CUSOLVER_ERROR(expr)                      \
  do {                                                         \
    auto status = (expr);                                      \
    if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \
      return errors::Internal(                                 \
          __FILE__, ":", __LINE__,                             \
          ": cuSolverDN call failed with status =", status);   \
    }                                                          \
  } while (0)

#define TF_RETURN_IF_CUBLAS_ERROR(expr)                                  \
  do {                                                                   \
    auto status = (expr);                                                \
    if (TF_PREDICT_FALSE(status != CUBLAS_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                   \
                              ": cuBlas call failed status = ", status); \
    }                                                                    \
  } while (0)

GpuSolver::GpuSolver(OpKernelContext* context) : context_(context) {
  mutex_lock lock(handle_map_mutex);
  const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  cuda_stream_ = *cu_stream_ptr;
  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
  auto it = handle_map->find(cuda_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating GpuSolver handles for stream " << cuda_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda solver library
    // handles for it.
    std::unique_ptr<GpuSolverHandles> new_handles(
        new GpuSolverHandles(cuda_stream_));
    it =
        handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
            .first;
  }
  cusolver_dn_handle_ = it->second->cusolver_dn_handle;
  cublas_handle_ = it->second->cublas_handle;
}

GpuSolver::~GpuSolver() {
  for (const auto& tensor_ref : scratch_tensor_refs_) {
    tensor_ref.Unref();
  }
}

// static
void GpuSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<GpuSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_infos,
    std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
        info_checker_callback) {
  CHECK(info_checker_callback != nullptr);
  std::vector<HostLapackInfo> host_lapack_infos;
  if (dev_lapack_infos.empty()) {
    info_checker_callback(OkStatus(), host_lapack_infos);
    return;
  }

  // Launch memcpys to copy info back from the device to the host.
  for (const auto& dev_lapack_info : dev_lapack_infos) {
    bool success = true;
    auto host_copy = dev_lapack_info.CopyToHost(&success);
    OP_REQUIRES(
        solver->context(), success,
        errors::Internal(
            "Failed to launch copy of dev_lapack_info to host, debug_info = ",
            dev_lapack_info.debug_info()));
    host_lapack_infos.push_back(std::move(host_copy));
  }

  // This callback checks that all batch items in all calls were processed
  // successfully and passes status to the info_checker_callback accordingly.
  auto* stream = solver->context()->op_device_context()->stream();
  auto wrapped_info_checker_callback =
      [stream](
          GpuSolver* solver,
          std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
              info_checker_callback,
          std::vector<HostLapackInfo> host_lapack_infos) {
        ScopedActivateExecutorContext scoped_activation{stream->parent()};
        Status status;
        for (const auto& host_lapack_info : host_lapack_infos) {
          for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
            const int info_value = host_lapack_info(i);
            if (info_value != 0) {
              status = errors::InvalidArgument(
                  "Got info = ", info_value, " for batch index ", i,
                  ", expected info = 0. Debug_info = ",
                  host_lapack_info.debug_info());
            }
          }
          if (!status.ok()) {
            break;
          }
        }
        // Delete solver to release temp tensor refs.
        delete solver;

        // Delegate further error checking to provided functor.
        info_checker_callback(status, host_lapack_infos);
      };
  // Note: An std::function cannot have unique_ptr arguments (it must be copy
  // constructible and therefore so must its arguments). Therefore, we release
  // solver into a raw pointer to be deleted at the end of
  // wrapped_info_checker_callback.
  // Release ownership of solver. It will be deleted in the cb callback.
  auto solver_raw_ptr = solver.release();
  auto cb =
      std::bind(wrapped_info_checker_callback, solver_raw_ptr,
                std::move(info_checker_callback), std::move(host_lapack_infos));

  solver_raw_ptr->context()
      ->device()
      ->tensorflow_accelerator_device_info()
      ->event_mgr->ThenExecute(stream, std::move(cb));
}

// static
void GpuSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<GpuSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_info,
    AsyncOpKernel::DoneCallback done) {
  OpKernelContext* context = solver->context();
  auto wrapped_done = [context, done](
                          const Status& status,
                          const std::vector<HostLapackInfo>& /* unused */) {
    if (done != nullptr) {
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    } else {
      OP_REQUIRES_OK(context, status);
    }
  };
  CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
                                      wrapped_done);
}

// Allocates a temporary tensor. The GpuSolver object maintains a
// TensorReference to the underlying Tensor to prevent it from being
// deallocated prematurely.
Status GpuSolver::allocate_scoped_tensor(DataType type,
                                         const TensorShape& shape,
                                         Tensor* out_temp) {
  const Status status = context_->allocate_temp(type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

Status GpuSolver::forward_input_or_allocate_scoped_tensor(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, Tensor* out_temp) {
  const Status status = context_->forward_input_or_allocate_temp(
      candidate_input_indices, type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)
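// For example, TF_CALL_LAPACK_TYPES(GETRF_INSTANCE) stamps out the Getrf
// specialization below for float, double, std::complex<float>, and
// std::complex<double>, pairing each type with its LAPACK-style prefix
// (S, D, C, Z).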

// Macros to construct cusolverDn method names.
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
#define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method
#define DN_BUFSIZE_FN(method, type_prefix) \
  cusolverDn##type_prefix##method##_bufferSize

// Macros to construct cublas method names.
#define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method
#define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method
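// Example expansions: DN_SOLVER_FN(potrf, S) produces cusolverDnSpotrf,
// DN_BUFSIZE_FN(potrf, S) produces cusolverDnSpotrf_bufferSize, and
// BLAS_SOLVER_FN(getrfBatched, S) produces cublasSgetrfBatched.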

//=============================================================================
// Wrappers of cuSolverDN computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolver_api.h header file.
//
// NOTE: The cuSolver functions called below appear not to be thread safe,
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
//=============================================================================
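// The global lock referred to above is handle_map_mutex; each wrapper below
// acquires it for the duration of the cuSolverDN or cuBLAS call.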

template <typename Scalar, typename SolverFnT>
static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasOperation_t transa,
                              cublasOperation_t transb, int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda,
                              const Scalar* beta, /* host or device pointer */
                              const Scalar* B, int ldb, Scalar* C, int ldc) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<const CudaScalar*>(beta),
                                   reinterpret_cast<const CudaScalar*>(B), ldb,
                                   reinterpret_cast<CudaScalar*>(C), ldc));
  return OkStatus();
}

#define GEAM_INSTANCE(Scalar, type_prefix)                                     \
  template <>                                                                  \
  Status GpuSolver::Geam<Scalar>(                                              \
      cublasOperation_t transa, cublasOperation_t transb, int m, int n,        \
      const Scalar* alpha, /* host or device pointer */                        \
      const Scalar* A, int lda,                                                \
      const Scalar* beta, /* host or device pointer */                         \
      const Scalar* B, int ldb, Scalar* C, int ldc) const {                    \
    return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
                    transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);        \
  }

TF_CALL_LAPACK_TYPES(GEAM_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasFillMode_t uplo, int n, Scalar* A, int lda,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return OkStatus();
}

#define POTRF_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A,   \
                                  int lda, int* dev_lapack_info) {           \
    return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix),                      \
                     DN_SOLVER_FN(potrf, type_prefix), this, context_,       \
                     cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);

#if CUDA_VERSION >= 9020
template <typename Scalar, typename SolverFnT>
static inline Status PotrfBatchedImpl(
    SolverFnT solver, GpuSolver* cuda_solver, OpKernelContext* context,
    cusolverDnHandle_t cusolver_dn_handle, cublasFillMode_t uplo, int n,
    const Scalar* const host_a_dev_ptrs[], int lda,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("PotrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, uplo, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
             dev_lapack_info->mutable_data(), batch_size));
  return OkStatus();
}

#define POTRF_BATCHED_INSTANCE(Scalar, type_prefix)                        \
  template <>                                                              \
  Status GpuSolver::PotrfBatched(                                          \
      cublasFillMode_t uplo, int n, const Scalar* const host_a_dev_ptrs[], \
      int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {        \
    return PotrfBatchedImpl(DN_SOLVER_FN(potrfBatched, type_prefix), this, \
                            context_, cusolver_dn_handle_, uplo, n,        \
                            host_a_dev_ptrs, lda, dev_lapack_info,         \
                            batch_size);                                   \
  }

TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);
#endif  // CUDA_VERSION >= 9020

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, int* dev_pivots,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
  return OkStatus();
}

#define GETRF_INSTANCE(Scalar, type_prefix)                                \
  template <>                                                              \
  Status GpuSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                  int* dev_pivots, int* dev_lapack_info) { \
    return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix),                    \
                     DN_SOLVER_FN(getrf, type_prefix), this, context_,     \
                     cusolver_dn_handle_, m, n, A, lda, dev_pivots,        \
                     dev_lapack_info);                                     \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasOperation_t trans, int n, int nrhs,
                               const Scalar* A, int lda, const int* pivots,
                               Scalar* B, int ldb, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
                                     CUDAComplex(A), lda, pivots,
                                     CUDAComplex(B), ldb, dev_lapack_info));
  return OkStatus();
}

#define GETRS_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status GpuSolver::Getrs<Scalar>(                                           \
      cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda,    \
      const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const {   \
    return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_,             \
                     cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
                     ldb, dev_lapack_info);                                  \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, Scalar* tau,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return OkStatus();
}

#define GEQRF_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status GpuSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda,            \
                                  Scalar* tau, int* dev_lapack_info) {         \
    return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix),                        \
                     DN_SOLVER_FN(geqrf, type_prefix), this, context_,         \
                     cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasSideMode_t side, cublasOperation_t trans,
                               int m, int n, int k, const Scalar* dev_a,
                               int lda, const Scalar* dev_tau, Scalar* dev_c,
                               int ldc, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
              CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
      CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return OkStatus();
}

// Unfortunately the LAPACK function name differs for the real and complex case
// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
// one separately.
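// For example, UNMQR_INSTANCE(float, or, S) resolves to cusolverDnSormqr (and
// its _bufferSize variant), while UNMQR_INSTANCE(complex64, un, C) resolves to
// cusolverDnCunmqr.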
#define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix)                 \
  template <>                                                                \
  Status GpuSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans,    \
                          int m, int n, int k, const Scalar* dev_a, int lda, \
                          const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
                          int* dev_lapack_info) {                            \
    return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix),       \
                     DN_SOLVER_FN(function_prefix##mqr, type_prefix), this,  \
                     context_, cusolver_dn_handle_, side, trans, m, n, k,    \
                     dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info);      \
  }

UNMQR_INSTANCE(float, or, S);
UNMQR_INSTANCE(double, or, D);
UNMQR_INSTANCE(complex64, un, C);
UNMQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, int k, Scalar* dev_a, int lda,
                               const Scalar* dev_tau, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
                                      CUDAComplex(dev_a), lda,
                                      CUDAComplex(dev_tau), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
             CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return OkStatus();
}

#define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix)                \
  template <>                                                               \
  Status GpuSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,      \
                          const Scalar* dev_tau, int* dev_lapack_info) {    \
    return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix),      \
                     DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
                     context_, cusolver_dn_handle_, m, n, k, dev_a, lda,    \
                     dev_tau, dev_lapack_info);                             \
  }

UNGQR_INSTANCE(float, or, S);
UNGQR_INSTANCE(double, or, D);
UNGQR_INSTANCE(complex64, un, C);
UNGQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cusolverEigMode_t jobz, cublasFillMode_t uplo,
                               int n, Scalar* dev_A, int lda,
                               typename Eigen::NumTraits<Scalar>::Real* dev_W,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
                                      CUDAComplex(dev_A), lda,
                                      CUDAComplex(dev_W), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
#if CUDA_VERSION >= 11070
  // TODO(b/223856016): CUDA 11.7 sometimes gives invalid outputs if the
  // scratch space is not initialized to zero.
  se::Stream* stream = context->op_device_context()->stream();
  if (!stream) {
    return errors::Internal("No GPU stream available");
  }
  uint64_t work_size_in_bytes = static_cast<uint64_t>(lwork) * sizeof(Scalar);
  se::DeviceMemoryBase dev_workspace_ptr(dev_workspace.mutable_data(),
                                         work_size_in_bytes);
  stream->ThenMemZero(&dev_workspace_ptr, work_size_in_bytes);
#endif
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
             CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return OkStatus();
}

#define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix)                   \
  template <>                                                                  \
  Status GpuSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo,       \
                          int n, Scalar* dev_A, int lda,                       \
                          typename Eigen::NumTraits<Scalar>::Real* dev_W,      \
                          int* dev_lapack_info) {                              \
    return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix),         \
                     DN_SOLVER_FN(function_prefix##evd, type_prefix), this,    \
                     context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
                     dev_W, dev_lapack_info);                                  \
  }

HEEVD_INSTANCE(float, sy, S);
HEEVD_INSTANCE(double, sy, D);
HEEVD_INSTANCE(complex64, he, C);
HEEVD_INSTANCE(complex128, he, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdImpl(BufSizeFnT bufsize, SolverFnT solver,
                               GpuSolver* cuda_solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               signed char jobu, signed char jobvt, int m,
                               int n, Scalar* A, int lda, Scalar* S, Scalar* U,
                               int ldu, Scalar* VT, int ldvt,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
                                     CUDAComplex(A), lda, S, CUDAComplex(U),
                                     ldu, CUDAComplex(VT), ldvt,
                                     CUDAComplex(dev_workspace.mutable_data()),
                                     lwork, nullptr, dev_lapack_info));
  return OkStatus();
}

#define GESVD_INSTANCE(Scalar, type_prefix)                              \
  template <>                                                            \
  Status GpuSolver::Gesvd<Scalar>(                                       \
      signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,  \
      int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,    \
      int ldvt, int* dev_lapack_info) {                                  \
    return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix),                  \
                     DN_SOLVER_FN(gesvd, type_prefix), this, context_,   \
                     cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
                     dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info);  \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdjBatchedImpl(BufSizeFnT bufsize, SolverFnT solver,
                                       GpuSolver* cuda_solver,
                                       OpKernelContext* context,
                                       cusolverDnHandle_t cusolver_dn_handle,
                                       cusolverEigMode_t jobz, int m, int n,
                                       Scalar* A, int lda, Scalar* S, Scalar* U,
                                       int ldu, Scalar* V, int ldv,
                                       int* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  /* Default parameters for gesvdj and gesvdjBatched. */
  gesvdjInfo_t svdj_info;
  TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnCreateGesvdjInfo(&svdj_info));
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(
      cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
      ldu, CUDAComplex(V), ldv, &lwork, svdj_info, batch_size));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
      ldu, CUDAComplex(V), ldv, CUDAComplex(dev_workspace.mutable_data()),
      lwork, dev_lapack_info, svdj_info, batch_size));
  TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnDestroyGesvdjInfo(svdj_info));
  return OkStatus();
}

#define GESVDJBATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status GpuSolver::GesvdjBatched<Scalar>(                                     \
      cusolverEigMode_t jobz, int m, int n, Scalar* dev_A, int lda,            \
      Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_V, int ldv,           \
      int* dev_lapack_info, int batch_size) {                                  \
    return GesvdjBatchedImpl(DN_BUFSIZE_FN(gesvdjBatched, type_prefix),        \
                             DN_SOLVER_FN(gesvdjBatched, type_prefix), this,   \
                             context_, cusolver_dn_handle_, jobz, m, n, dev_A, \
                             lda, dev_S, dev_U, ldu, dev_V, ldv,               \
                             dev_lapack_info, batch_size);                     \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVDJBATCHED_INSTANCE);

//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cublas_api.h header file.
//=============================================================================
template <typename Scalar, typename SolverFnT>
static inline Status GetrfBatchedImpl(SolverFnT solver, GpuSolver* cuda_solver,
                                      OpKernelContext* context,
                                      cublasHandle_t cublas_handle, int n,
                                      const Scalar* const host_a_dev_ptrs[],
                                      int lda, int* dev_pivots,
                                      DeviceLapackInfo* dev_lapack_info,
                                      int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
             dev_pivots, dev_lapack_info->mutable_data(), batch_size));
  return OkStatus();
}

#define GETRF_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status GpuSolver::GetrfBatched(                                              \
      int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots,  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                     \
    return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this,   \
                            context_, cublas_handle_, n, host_a_dev_ptrs, lda, \
                            dev_pivots, dev_lapack_info, batch_size);          \
  }

TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsBatchedImpl(
    SolverFnT solver, GpuSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
    const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
    const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
    int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, trans, n, nrhs,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
      ldb, host_lapack_info, batch_size));
  return OkStatus();
}

#define GETRS_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status GpuSolver::GetrsBatched(                                              \
      cublasOperation_t trans, int n, int nrhs,                                \
      const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,   \
      const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,   \
      int batch_size) {                                                        \
    return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>(            \
                                BLAS_SOLVER_FN(getrsBatched, type_prefix)),    \
                            this, context_, cublas_handle_, trans, n, nrhs,    \
                            host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
                            ldb, host_lapack_info, batch_size);                \
  }

TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetriBatchedImpl(
    SolverFnT solver, GpuSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
    int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("GetriBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, dev_pivots,
             reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
             ldainv, dev_lapack_info->mutable_data(), batch_size));
  return OkStatus();
}

#define GETRI_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                \
  Status GpuSolver::GetriBatched(                                            \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                 \
      const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],      \
      int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {       \
    return GetriBatchedImpl(                                                 \
        reinterpret_cast<getri_##type_prefix*>(                              \
            BLAS_SOLVER_FN(getriBatched, type_prefix)),                      \
        this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \
        host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size);           \
  }

TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status MatInvBatchedImpl(
    SolverFnT solver, GpuSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("MatInvBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, n,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv,
      dev_lapack_info->mutable_data(), batch_size));
  return OkStatus();
}

#define MATINV_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                 \
  Status GpuSolver::MatInvBatched(                                            \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const Scalar* const host_a_inv_dev_ptrs[], int ldainv,                  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
    return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>(         \
                                 BLAS_SOLVER_FN(matinvBatched, type_prefix)), \
                             this, context_, cublas_handle_, n,               \
                             host_a_dev_ptrs, lda, host_a_inv_dev_ptrs,       \
                             ldainv, dev_lapack_info, batch_size);            \
  }

TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasSideMode_t side, cublasFillMode_t uplo,
                              cublasOperation_t trans, cublasDiagType_t diag,
                              int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda, Scalar* B, int ldb) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<CudaScalar*>(B), ldb));
  return OkStatus();
}

#define TRSM_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                \
  Status GpuSolver::Trsm<Scalar>(                                            \
      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
      cublasDiagType_t diag, int m, int n,                                   \
      const Scalar* alpha, /* host or device pointer */                      \
      const Scalar* A, int lda, Scalar* B, int ldb) {                        \
    return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side, \
                    uplo, trans, diag, m, n, alpha, A, lda, B, ldb);         \
  }

TF_CALL_LAPACK_TYPES(TRSM_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasFillMode_t uplo, cublasOperation_t trans,
                              cublasDiagType_t diag, int n, const Scalar* A,
                              int lda, Scalar* x, int incx) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<CudaScalar*>(x), incx));
  return OkStatus();
}

#define TRSV_INSTANCE(Scalar, type_prefix)                                   \
  template <>                                                                \
  Status GpuSolver::Trsv<Scalar>(                                            \
      cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, \
      int n, const Scalar* A, int lda, Scalar* x, int incx) {                \
    return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo, \
                    trans, diag, n, A, lda, x, incx);                        \
  }

TF_CALL_LAPACK_TYPES(TRSV_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status TrsmBatchedImpl(
    SolverFnT solver, GpuSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
    cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
    const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
    Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("TrsmBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
    return errors::Internal("TrsmBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, side, uplo, trans, diag, m, n,
             reinterpret_cast<const CudaScalar*>(alpha),
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
             ldb, batch_size));
  return OkStatus();
}

#define TRSM_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                 \
  Status GpuSolver::TrsmBatched(                                              \
      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans,  \
      cublasDiagType_t diag, int m, int n, const Scalar* alpha,               \
      const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[],        \
      int ldb, int batch_size) {                                              \
    return TrsmBatchedImpl(reinterpret_cast<trsm_##type_prefix*>(             \
                               BLAS_SOLVER_FN(trsmBatched, type_prefix)),     \
                           this, context_, cublas_handle_, side, uplo, trans, \
                           diag, m, n, alpha, dev_Aarray, lda, dev_Barray,    \
                           ldb, batch_size);                                  \
  }

TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA