/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_GPU_SOLVERS_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_GPU_SOLVERS_H_

// This header declares the class GpuSolver, which contains wrappers of linear
// algebra solvers in the cuBlas/cuSolverDN or rocSolver libraries for use in
// TensorFlow kernels.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <functional>
#include <vector>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#else
#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/rocblas.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/rocm/rocsolver_wrapper.h"
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

#if GOOGLE_CUDA
// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
  typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CUDAComplexT<T>::type* CUDAComplex(const T* p) {
  return reinterpret_cast<const typename CUDAComplexT<T>::type*>(p);
}
template <typename T>
inline typename CUDAComplexT<T>::type* CUDAComplex(T* p) {
  return reinterpret_cast<typename CUDAComplexT<T>::type*>(p);
}

// Template to give the Cublas adjoint operation for real and complex types.
template <typename T>
cublasOperation_t CublasAdjointOp() {
  return Eigen::NumTraits<T>::IsComplex ? CUBLAS_OP_C : CUBLAS_OP_T;
}
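
// For illustration only: a minimal, hypothetical sketch of how the helpers
// above can forward std::complex buffers to a raw cuBLAS call. The handle,
// sizes, and device pointers (handle, m, n, lda, ldc, dev_a, dev_c) are
// placeholders, not part of this header.
//
//   std::complex<float> alpha(1.0f, 0.0f), beta(0.0f, 0.0f);
//   // Writes C = alpha * A^H (conjugate transpose, since Scalar is complex).
//   cublasCgeam(handle, CublasAdjointOp<std::complex<float>>(), CUBLAS_OP_N,
//               m, n, CUDAComplex(&alpha), CUDAComplex(dev_a), lda,
//               CUDAComplex(&beta), CUDAComplex(dev_c), ldc,
//               CUDAComplex(dev_c), ldc);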
#else  // TENSORFLOW_USE_ROCM
// Type traits to get ROCm complex types from std::complex<T>.
template <typename T>
struct ROCmComplexT {
  typedef T type;
};
template <>
struct ROCmComplexT<std::complex<float>> {
  typedef rocblas_float_complex type;
};
template <>
struct ROCmComplexT<std::complex<double>> {
  typedef rocblas_double_complex type;
};
// Converts pointers of std::complex<> to pointers of
// ROCmComplex/ROCmDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename ROCmComplexT<T>::type* ROCmComplex(const T* p) {
  return reinterpret_cast<const typename ROCmComplexT<T>::type*>(p);
}
template <typename T>
inline typename ROCmComplexT<T>::type* ROCmComplex(T* p) {
  return reinterpret_cast<typename ROCmComplexT<T>::type*>(p);
}

// Type traits to get HIP complex types from std::complex<>

template <typename T>
struct HipComplexT {
  typedef T type;
};

template <>
struct HipComplexT<std::complex<float>> {
  typedef hipFloatComplex type;
};

template <>
struct HipComplexT<std::complex<double>> {
  typedef hipDoubleComplex type;
};

// Convert pointers of std::complex<> to pointers of
// hipFloatComplex/hipDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename HipComplexT<T>::type* AsHipComplex(const T* p) {
  return reinterpret_cast<const typename HipComplexT<T>::type*>(p);
}

template <typename T>
inline typename HipComplexT<T>::type* AsHipComplex(T* p) {
  return reinterpret_cast<typename HipComplexT<T>::type*>(p);
}
// Template to give the Rocblas adjoint operation for real and complex types.
template <typename T>
rocblas_operation RocblasAdjointOp() {
  return Eigen::NumTraits<T>::IsComplex ? rocblas_operation_conjugate_transpose
                                        : rocblas_operation_transpose;
}

#if TF_ROCM_VERSION >= 40500
using gpuSolverOp_t = hipsolverOperation_t;
using gpuSolverFill_t = hipsolverFillMode_t;
using gpuSolverSide_t = hipsolverSideMode_t;
#else
using gpuSolverOp_t = rocblas_operation;
using gpuSolverFill_t = rocblas_fill;
using gpuSolverSide_t = rocblas_side;
#endif
#endif

// Container of LAPACK info data (an array of int) generated on-device by
// a GpuSolver call. One or more such objects can be passed to
// GpuSolver::CheckLapackInfoAndDeleteSolverAsync() along with a callback to
// check the LAPACK info data after the corresponding kernels
// finish and LAPACK info has been copied from the device to the host.
class DeviceLapackInfo;

// Host-side copy of LAPACK info.
class HostLapackInfo;

// The GpuSolver class provides a simplified templated API for the dense linear
// solvers implemented in cuSolverDN (http://docs.nvidia.com/cuda/cusolver) and
// cuBlas (http://docs.nvidia.com/cuda/cublas/#blas-like-extension/).
// An object of this class wraps static cuSolver and cuBlas instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSolver
// object. To check the final status of the kernels run, pass the GpuSolver
// object to CheckLapackInfoAndDeleteSolverAsync() to set a callback that
// will be invoked with the status of the kernels launched thus far as
// arguments.
//
// Example of an asynchronous TensorFlow kernel using GpuSolver:
//
// template <typename Scalar>
// class SymmetricPositiveDefiniteSolveOpGpu : public AsyncOpKernel {
//  public:
//   explicit SymmetricPositiveDefiniteSolveOpGpu(OpKernelConstruction* context)
//       : AsyncOpKernel(context) { }
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
//     // 1. Set up input and output device ptrs. See, e.g.,
//     // matrix_inverse_op.cc for a full example.
//     ...
//
//     // 2. Initialize the solver object.
//     std::unique_ptr<GpuSolver> solver(new GpuSolver(context));
//
//     // 3. Launch the two compute kernels back to back on the stream without
//     // synchronizing.
//     std::vector<DeviceLapackInfo> dev_info;
//     const int batch_size = 1;
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
//     // Compute the Cholesky decomposition of the input matrix.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrf(uplo, n, dev_matrix_ptrs, n,
//                                        dev_info.back().mutable_data()),
//                          done);
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrs"));
//     // Use the Cholesky decomposition of the input matrix to solve A X = RHS.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrs(uplo, n, nrhs, dev_matrix_ptrs, n,
//                                        dev_output_ptrs, ldrhs,
//                                        dev_info.back().mutable_data()),
//                          done);
//
//     // 4. Check the status after the computation finishes and call done.
//     GpuSolver::CheckLapackInfoAndDeleteSolverAsync(
//         std::move(solver), dev_info, std::move(done));
//   }
// };

template <typename Scalar>
class ScratchSpace;

class GpuSolver {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit GpuSolver(OpKernelContext* context);
  virtual ~GpuSolver();

  // Launches a memcpy of solver status data specified by dev_lapack_info from
  // device to the host, and asynchronously invokes the given callback when the
  // copy is complete. The first Status argument to the callback will be
  // Status::OK if all lapack infos retrieved are zero, otherwise an error
  // status is given. The second argument contains a host-side copy of the
  // entire set of infos retrieved, and can be used for generating detailed
  // error messages.
  // `info_checker_callback` must call the DoneCallback of any asynchronous
  // OpKernel within which `solver` is used.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<GpuSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
          info_checker_callback);
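
  // For illustration only: a minimal, hypothetical sketch of a custom
  // `info_checker_callback`. It assumes `context`, `done`, `solver`, and
  // `dev_info` are set up as in the example above; the callback itself is
  // responsible for invoking `done`.
  //
  //   auto info_checker = [context, done](
  //       const Status& status,
  //       const std::vector<HostLapackInfo>& /* host_infos */) {
  //     OP_REQUIRES_OK_ASYNC(context, status, done);
  //     done();
  //   };
  //   GpuSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
  //                                                  dev_info,
  //                                                  std::move(info_checker));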

  // Simpler version to use if no special error checking / messages are needed
  // apart from checking that the Status of all calls was Status::OK.
  // `done` may be nullptr.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<GpuSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      AsyncOpKernel::DoneCallback done);

  // Returns a ScratchSpace. The GpuSolver object maintains a TensorReference
  // to the underlying Tensor to prevent it from being deallocated prematurely.
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(const TensorShape& shape,
                                       const std::string& debug_info,
                                       bool on_host);
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(int64_t size,
                                       const std::string& debug_info,
                                       bool on_host);
  // Returns a DeviceLapackInfo that will live for the duration of the
  // GpuSolver object.
  inline DeviceLapackInfo GetDeviceLapackInfo(int64_t size,
                                              const std::string& debug_info);

  // Allocates a temporary tensor that will live for the duration of the
  // GpuSolver object.
  Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
                                Tensor* scoped_tensor);
  Status forward_input_or_allocate_scoped_tensor(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);
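
  // For illustration only: a hypothetical sketch of allocating a scoped
  // temporary that stays alive as long as the solver. `solver`, `Scalar`,
  // and `n` are placeholders, not part of this header.
  //
  //   Tensor tau;
  //   TF_RETURN_IF_ERROR(
  //       solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
  //                                      TensorShape({n}), &tau));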

  OpKernelContext* context() { return context_; }

#if TENSORFLOW_USE_ROCM
  // ====================================================================
  // Wrappers for ROCSolver start here.
  //
  // The method names below map to those in ROCSolver/hipSolver, which follow
  // the naming convention in LAPACK. See rocm_solvers.cc for a mapping of
  // each GpuSolver method to the corresponding library API.

  // LU factorization.
  // Computes LU factorization with partial pivoting P * A = L * U.
  template <typename Scalar>
  Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
               int* info);

  // Uses LU factorization to solve A * X = B.
  template <typename Scalar>
  Status Getrs(const gpuSolverOp_t trans, int n, int nrhs, Scalar* A, int lda,
               const int* dev_pivots, Scalar* B, int ldb, int* dev_lapack_info);

  template <typename Scalar>
  Status GetrfBatched(int n, Scalar** dev_A, int lda, int* dev_pivots,
                      DeviceLapackInfo* info, const int batch_count);

  // No GetrsBatched for HipSolver yet.
  template <typename Scalar>
  Status GetrsBatched(const rocblas_operation trans, int n, int nrhs,
                      Scalar** A, int lda, int* dev_pivots, Scalar** B,
                      const int ldb, int* lapack_info, const int batch_count);

  // Computes the Cholesky factorization A = L * L^H for a single matrix.
  template <typename Scalar>
  Status Potrf(gpuSolverFill_t uplo, int n, Scalar* dev_A, int lda,
               int* dev_lapack_info);

  // Computes matrix inverses for a batch of small matrices. Uses the outputs
  // from GetrfBatched. No HipSolver implementation yet.
  template <typename Scalar>
  Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      const int* dev_pivots,
                      const Scalar* const host_a_inverse_dev_ptrs[], int ldainv,
                      DeviceLapackInfo* dev_lapack_info, int batch_size);

  // Computes matrix inverses for a batch of small matrices with size n < 32.
  // Returns Status::OK() if the kernel was launched successfully. Uses
  // GetrfBatched and GetriBatched.
  template <typename Scalar>
  Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                       const Scalar* const host_a_inverse_dev_ptrs[],
                       int ldainv, DeviceLapackInfo* dev_lapack_info,
                       int batch_size);

  // Cholesky factorization
  // Computes the Cholesky factorization A = L * L^H for a batch of small
  // matrices.
  template <typename Scalar>
  Status PotrfBatched(gpuSolverFill_t uplo, int n,
                      const Scalar* const host_a_dev_ptrs[], int lda,
                      DeviceLapackInfo* dev_lapack_info, int batch_size);

  template <typename Scalar>
  Status Trsm(rocblas_side side, rocblas_fill uplo, rocblas_operation trans,
              rocblas_diagonal diag, int m, int n, const Scalar* alpha,
              const Scalar* A, int lda, Scalar* B, int ldb);

  // QR factorization.
  // Computes QR factorization A = Q * R.
  template <typename Scalar>
  Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
               int* dev_lapack_info);

  // This function performs the matrix-matrix addition/transposition
  //   C = alpha * op(A) + beta * op(B).
  template <typename Scalar>
  Status Geam(rocblas_operation transa, rocblas_operation transb, int m, int n,
              const Scalar* alpha, /* host or device pointer */
              const Scalar* A, int lda,
              const Scalar* beta, /* host or device pointer */
              const Scalar* B, int ldb, Scalar* C, int ldc);

  // Overwrite matrix C by product of C and the unitary Householder matrix Q.
  // The Householder matrix Q is represented by the output from Geqrf in dev_a
  // and dev_tau.
  template <typename Scalar>
  Status Unmqr(gpuSolverSide_t side, gpuSolverOp_t trans, int m, int n, int k,
               const Scalar* dev_a, int lda, const Scalar* dev_tau,
               Scalar* dev_c, int ldc, int* dev_lapack_info);

  // Overwrites QR factorization produced by Geqrf by the unitary Householder
  // matrix Q. On input, the Householder matrix Q is represented by the output
  // from Geqrf in dev_a and dev_tau. On output, dev_a is overwritten with the
  // first n columns of Q. Requires m >= n >= 0.
  template <typename Scalar>
  Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda,
               const Scalar* dev_tau, int* dev_lapack_info);

#if TF_ROCM_VERSION >= 40500
  // Hermitian (Symmetric) Eigen decomposition.
  template <typename Scalar>
  Status Heevd(gpuSolverOp_t jobz, gpuSolverFill_t uplo, int n, Scalar* dev_A,
               int lda, typename Eigen::NumTraits<Scalar>::Real* dev_W,
               int* dev_lapack_info);
#endif

#else  // GOOGLE_CUDA
  // ====================================================================
  // Wrappers for cuSolverDN and cuBlas solvers start here.
  //
  // Apart from capitalization of the first letter, the method names below
  // map to those in cuSolverDN and cuBlas, which follow the naming
  // convention in LAPACK. See, e.g.,
  // http://docs.nvidia.com/cuda/cusolver/#naming-convention

  // This function performs the matrix-matrix addition/transposition
  //   C = alpha * op(A) + beta * op(B).
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-geam
  // NOTE(ebrevdo): Does not support in-place transpose of non-square
  // matrices.

  template <typename Scalar>
  Status Geam(cublasOperation_t transa, cublasOperation_t transb, int m, int n,
              const Scalar* alpha, /* host or device pointer */
              const Scalar* A, int lda,
              const Scalar* beta, /* host or device pointer */
              const Scalar* B, int ldb, Scalar* C,
              int ldc) const TF_MUST_USE_RESULT;
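
  // For illustration only: a hypothetical sketch of using Geam to write the
  // transpose of a square device matrix out of place, i.e. B = A^T. The
  // `solver`, `n`, `dev_a`, and `dev_b` names are placeholders, not part of
  // this header.
  //
  //   const Scalar alpha(1);
  //   const Scalar beta(0);
  //   TF_RETURN_IF_ERROR(solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, n, n, &alpha,
  //                                   dev_a, n, &beta, dev_b, n, dev_b, n));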

  // Computes the Cholesky factorization A = L * L^H for a single matrix.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  template <typename Scalar>
  Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

#if CUDA_VERSION >= 9020
  // Computes the Cholesky factorization A = L * L^H for a batch of small
  // matrices.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-potrfBatched
  template <typename Scalar>
  Status PotrfBatched(cublasFillMode_t uplo, int n,
                      const Scalar* const host_a_dev_ptrs[], int lda,
                      DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;
#endif  // CUDA_VERSION >= 9020
  // LU factorization.
  // Computes LU factorization with partial pivoting P * A = L * U.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
  template <typename Scalar>
  Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Uses LU factorization to solve A * X = B.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
  template <typename Scalar>
  Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
               int lda, const int* pivots, Scalar* B, int ldb,
               int* dev_lapack_info) const TF_MUST_USE_RESULT;

  // Computes partially pivoted LU factorizations for a batch of small matrices.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
  template <typename Scalar>
  Status GetrfBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Batched linear solver using LU factorization from getrfBatched.
  // Notice that lapack_info is returned on the host, as opposed to
  // most of the other functions that return it on the device. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
  template <typename Scalar>
  Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                      const Scalar* const dev_Aarray[], int lda,
                      const int* devIpiv, const Scalar* const dev_Barray[],
                      int ldb, int* host_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices. Uses the outputs
  // from GetrfBatched. Returns Status::OK() if the kernel was launched
  // successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getribatched
  template <typename Scalar>
  Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      const int* dev_pivots,
                      const Scalar* const host_a_inverse_dev_ptrs[], int ldainv,
                      DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices with size n < 32.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-matinvbatched
  template <typename Scalar>
  Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                       const Scalar* const host_a_inverse_dev_ptrs[],
                       int ldainv, DeviceLapackInfo* dev_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;

  // QR factorization.
  // Computes QR factorization A = Q * R.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
  template <typename Scalar>
  Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrite matrix C by product of C and the unitary Householder matrix Q.
  // The Householder matrix Q is represented by the output from Geqrf in dev_a
  // and dev_tau.
  // Notice: If Scalar is real, only trans=CUBLAS_OP_N or trans=CUBLAS_OP_T is
  // supported. If Scalar is complex, trans=CUBLAS_OP_N or trans=CUBLAS_OP_C is
  // supported.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
  template <typename Scalar>
  Status Unmqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
               int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
               Scalar* dev_c, int ldc, int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites QR factorization produced by Geqrf by the unitary Householder
  // matrix Q. On input, the Householder matrix Q is represented by the output
  // from Geqrf in dev_a and dev_tau. On output, dev_a is overwritten with the
  // first n columns of Q. Requires m >= n >= 0.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
  template <typename Scalar>
  Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda,
               const Scalar* dev_tau, int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Hermitian (Symmetric) Eigen decomposition.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
  template <typename Scalar>
  Status Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n,
               Scalar* dev_A, int lda,
               typename Eigen::NumTraits<Scalar>::Real* dev_W,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Singular value decomposition.
  // Returns Status::OK() if the kernel was launched successfully.
  // TODO(rmlarsen, volunteers): Add support for complex types.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
  template <typename Scalar>
  Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
               int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
               int ldvt, int* dev_lapack_info) TF_MUST_USE_RESULT;
  template <typename Scalar>
  Status GesvdjBatched(cusolverEigMode_t jobz, int m, int n, Scalar* dev_A,
                       int lda, Scalar* dev_S, Scalar* dev_U, int ldu,
                       Scalar* dev_V, int ldv, int* dev_lapack_info,
                       int batch_size);

  // Triangular solve
  // Returns Status::OK() if the kernel was launched successfully.
  // See https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsm
  template <typename Scalar>
  Status Trsm(cublasSideMode_t side, cublasFillMode_t uplo,
              cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
              const Scalar* alpha, const Scalar* A, int lda, Scalar* B,
              int ldb);
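
  // For illustration only: a hypothetical sketch of solving L * X = B for X
  // with a lower-triangular L, overwriting B in place. `solver`, `m`, `n`,
  // `dev_l`, and `dev_b` are placeholders, not part of this header.
  //
  //   const Scalar alpha(1);
  //   TF_RETURN_IF_ERROR(solver->Trsm(CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER,
  //                                   CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m, n,
  //                                   &alpha, dev_l, m, dev_b, m));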

  template <typename Scalar>
  Status Trsv(cublasFillMode_t uplo, cublasOperation_t trans,
              cublasDiagType_t diag, int n, const Scalar* A, int lda, Scalar* x,
              int incx);

  // See
  // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsmbatched
  template <typename Scalar>
  Status TrsmBatched(cublasSideMode_t side, cublasFillMode_t uplo,
                     cublasOperation_t trans, cublasDiagType_t diag, int m,
                     int n, const Scalar* alpha,
                     const Scalar* const dev_Aarray[], int lda,
                     Scalar* dev_Barray[], int ldb, int batch_size);
#endif

 private:
  OpKernelContext* context_;  // not owned.
#if GOOGLE_CUDA
  cudaStream_t cuda_stream_;
  cusolverDnHandle_t cusolver_dn_handle_;
  cublasHandle_t cublas_handle_;
#else  // TENSORFLOW_USE_ROCM
  hipStream_t hip_stream_;
  rocblas_handle rocm_blas_handle_;
#endif

  std::vector<TensorReference> scratch_tensor_refs_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSolver);
};

// Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp.
template <typename Scalar>
class ScratchSpace {
 public:
  ScratchSpace(OpKernelContext* context, int64_t size, bool on_host)
      : ScratchSpace(context, TensorShape({size}), "", on_host) {}

  ScratchSpace(OpKernelContext* context, int64_t size,
               const std::string& debug_info, bool on_host)
      : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}

  ScratchSpace(OpKernelContext* context, const TensorShape& shape,
               const std::string& debug_info, bool on_host)
      : context_(context), debug_info_(debug_info), on_host_(on_host) {
    AllocatorAttributes alloc_attr;
    if (on_host) {
      // Allocate pinned memory on the host to avoid unnecessary
      // synchronization.
      alloc_attr.set_on_host(true);
      alloc_attr.set_gpu_compatible(true);
    }
    TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
                                       &scratch_tensor_, alloc_attr));
  }

  virtual ~ScratchSpace() {}

  Scalar* mutable_data() {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  const Scalar* data() const {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  Scalar& operator()(int64_t i) {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  const Scalar& operator()(int64_t i) const {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  int64_t bytes() const { return scratch_tensor_.TotalBytes(); }
  int64_t size() const { return scratch_tensor_.NumElements(); }
  const std::string& debug_info() const { return debug_info_; }

  Tensor& tensor() { return scratch_tensor_; }
  const Tensor& tensor() const { return scratch_tensor_; }

  // Returns true if this ScratchSpace is in host memory.
  bool on_host() const { return on_host_; }

 protected:
  OpKernelContext* context() const { return context_; }

 private:
  OpKernelContext* context_;  // not owned
  const std::string debug_info_;
  const bool on_host_;
  Tensor scratch_tensor_;
};

class HostLapackInfo : public ScratchSpace<int> {
 public:
  HostLapackInfo(OpKernelContext* context, int64_t size,
                 const std::string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ true) {}
};

class DeviceLapackInfo : public ScratchSpace<int> {
 public:
  DeviceLapackInfo(OpKernelContext* context, int64_t size,
                   const std::string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ false) {}

  // Allocates a new scratch space on the host and launches a copy of the
  // contents of *this to the new scratch space. Sets success to true if
  // the copy kernel was launched successfully.
  HostLapackInfo CopyToHost(bool* success) const {
    CHECK(success != nullptr);
    HostLapackInfo copy(context(), size(), debug_info());
    auto stream = context()->op_device_context()->stream();
    se::DeviceMemoryBase wrapped_src(
        static_cast<void*>(const_cast<int*>(this->data())));
    *success =
        stream->ThenMemcpy(copy.mutable_data(), wrapped_src, this->bytes())
            .ok();
    return copy;
  }
};
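
// For illustration only: a minimal, hypothetical sketch of checking LAPACK
// info on the host after the corresponding kernels have completed (e.g. from
// inside the callback passed to CheckLapackInfoAndDeleteSolverAsync, or after
// synchronizing the stream). `dev_info` is a placeholder DeviceLapackInfo.
//
//   bool copy_launched = false;
//   HostLapackInfo host_info = dev_info.CopyToHost(&copy_launched);
//   // ... wait for the stream to finish before reading host_info ...
//   if (copy_launched) {
//     for (int64_t i = 0; i < host_info.size(); ++i) {
//       if (host_info(i) != 0) {
//         LOG(WARNING) << "LAPACK info[" << i << "] = " << host_info(i)
//                      << " for " << host_info.debug_info();
//       }
//     }
//   }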

template <typename Scalar>
ScratchSpace<Scalar> GpuSolver::GetScratchSpace(const TensorShape& shape,
                                                const std::string& debug_info,
                                                bool on_host) {
  ScratchSpace<Scalar> new_scratch_space(context_, shape, debug_info, on_host);
  scratch_tensor_refs_.emplace_back(new_scratch_space.tensor());
  return std::move(new_scratch_space);
}

template <typename Scalar>
ScratchSpace<Scalar> GpuSolver::GetScratchSpace(int64_t size,
                                                const std::string& debug_info,
                                                bool on_host) {
  return GetScratchSpace<Scalar>(TensorShape({size}), debug_info, on_host);
}

inline DeviceLapackInfo GpuSolver::GetDeviceLapackInfo(
    int64_t size, const std::string& debug_info) {
  DeviceLapackInfo new_dev_info(context_, size, debug_info);
  scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
  return new_dev_info;
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_GPU_SOLVERS_H_