/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_GPU_SOLVERS_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_GPU_SOLVERS_H_

// This header declares the class GpuSolver, which contains wrappers of linear
// algebra solvers in the cuBlas/cuSolverDN or rocSOLVER libraries for use in
// TensorFlow kernels.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <functional>
#include <vector>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#else
#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/rocblas.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/rocm/rocsolver_wrapper.h"
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

#if GOOGLE_CUDA
// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
  typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CUDAComplexT<T>::type* CUDAComplex(const T* p) {
  return reinterpret_cast<const typename CUDAComplexT<T>::type*>(p);
}
template <typename T>
inline typename CUDAComplexT<T>::type* CUDAComplex(T* p) {
  return reinterpret_cast<typename CUDAComplexT<T>::type*>(p);
}

// Template to give the Cublas adjoint operation for real and complex types.
template <typename T>
cublasOperation_t CublasAdjointOp() {
  return Eigen::NumTraits<T>::IsComplex ? CUBLAS_OP_C : CUBLAS_OP_T;
}
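
// For illustration, a hypothetical cuBLAS call that applies the adjoint of a
// complex matrix could use these helpers to convert std::complex<> pointers.
// The names cublas_handle, dev_A, dev_x and dev_y are assumptions for this
// sketch, not declared in this header:
//
//   std::complex<float> alpha(1, 0), beta(0, 0);
//   cublasCgemv(cublas_handle, CublasAdjointOp<std::complex<float>>(), m, n,
//               CUDAComplex(&alpha), CUDAComplex(dev_A), lda,
//               CUDAComplex(dev_x), 1, CUDAComplex(&beta),
//               CUDAComplex(dev_y), 1);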
#else  // TENSORFLOW_USE_ROCM
// Type traits to get ROCm complex types from std::complex<T>.
template <typename T>
struct ROCmComplexT {
  typedef T type;
};
template <>
struct ROCmComplexT<std::complex<float>> {
  typedef rocblas_float_complex type;
};
template <>
struct ROCmComplexT<std::complex<double>> {
  typedef rocblas_double_complex type;
};
// Converts pointers of std::complex<> to pointers of
// rocblas_float_complex/rocblas_double_complex. No type conversion for
// non-complex types.
template <typename T>
inline const typename ROCmComplexT<T>::type* ROCmComplex(const T* p) {
  return reinterpret_cast<const typename ROCmComplexT<T>::type*>(p);
}
template <typename T>
inline typename ROCmComplexT<T>::type* ROCmComplex(T* p) {
  return reinterpret_cast<typename ROCmComplexT<T>::type*>(p);
}

// Type traits to get HIP complex types from std::complex<>.
template <typename T>
struct HipComplexT {
  typedef T type;
};

template <>
struct HipComplexT<std::complex<float>> {
  typedef hipFloatComplex type;
};

template <>
struct HipComplexT<std::complex<double>> {
  typedef hipDoubleComplex type;
};

// Converts pointers of std::complex<> to pointers of
// hipFloatComplex/hipDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename HipComplexT<T>::type* AsHipComplex(const T* p) {
  return reinterpret_cast<const typename HipComplexT<T>::type*>(p);
}

template <typename T>
inline typename HipComplexT<T>::type* AsHipComplex(T* p) {
  return reinterpret_cast<typename HipComplexT<T>::type*>(p);
}

// Template to give the Rocblas adjoint operation for real and complex types.
template <typename T>
rocblas_operation RocblasAdjointOp() {
  return Eigen::NumTraits<T>::IsComplex ? rocblas_operation_conjugate_transpose
                                        : rocblas_operation_transpose;
}

#if TF_ROCM_VERSION >= 40500
using gpuSolverOp_t = hipsolverOperation_t;
using gpuSolverFill_t = hipsolverFillMode_t;
using gpuSolverSide_t = hipsolverSideMode_t;
#else
using gpuSolverOp_t = rocblas_operation;
using gpuSolverFill_t = rocblas_fill;
using gpuSolverSide_t = rocblas_side;
#endif
#endif

// Container of LAPACK info data (an array of int) generated on-device by
// a GpuSolver call. One or more such objects can be passed to
// GpuSolver::CheckLapackInfoAndDeleteSolverAsync() along with a callback to
// check the LAPACK info data after the corresponding kernels
// finish and LAPACK info has been copied from the device to the host.
class DeviceLapackInfo;

// Host-side copy of LAPACK info.
class HostLapackInfo;

// The GpuSolver class provides a simplified templated API for the dense linear
// solvers implemented in cuSolverDN (http://docs.nvidia.com/cuda/cusolver) and
// cuBlas (http://docs.nvidia.com/cuda/cublas/#blas-like-extension/).
// An object of this class wraps static cuSolver and cuBlas instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSolver
// object. To check the final status of the kernels run, call
// GpuSolver::CheckLapackInfoAndDeleteSolverAsync() with the GpuSolver object
// to set a callback that will be invoked with the status of the kernels
// launched thus far as arguments.
//
// Example of an asynchronous TensorFlow kernel using GpuSolver:
//
// template <typename Scalar>
// class SymmetricPositiveDefiniteSolveOpGpu : public AsyncOpKernel {
//  public:
//   explicit SymmetricPositiveDefiniteSolveOpGpu(OpKernelConstruction* context)
//       : AsyncOpKernel(context) { }
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
//     // 1. Set up input and output device ptrs. See, e.g.,
//     // matrix_inverse_op.cc for a full example.
//     ...
//
//     // 2. Initialize the solver object.
//     std::unique_ptr<GpuSolver> solver(new GpuSolver(context));
//
//     // 3. Launch the two compute kernels back to back on the stream without
//     // synchronizing.
//     std::vector<DeviceLapackInfo> dev_info;
//     const int batch_size = 1;
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
//     // Compute the Cholesky decomposition of the input matrix.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrf(uplo, n, dev_matrix_ptrs, n,
//                                        dev_info.back().mutable_data()),
//                          done);
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrs"));
//     // Use the Cholesky decomposition of the input matrix to solve A X = RHS.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrs(uplo, n, nrhs, dev_matrix_ptrs, n,
//                                        dev_output_ptrs, ldrhs,
//                                        dev_info.back().mutable_data()),
//                          done);
//
//     // 4. Check the status after the computation finishes and call done.
//     GpuSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
//                                                    dev_info,
//                                                    std::move(done));
//   }
// };

template <typename Scalar>
class ScratchSpace;

class GpuSolver {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit GpuSolver(OpKernelContext* context);
  virtual ~GpuSolver();

  // Launches a memcpy of solver status data specified by dev_lapack_info from
  // device to the host, and asynchronously invokes the given callback when the
  // copy is complete. The first Status argument to the callback will be
  // Status::OK if all lapack infos retrieved are zero, otherwise an error
  // status is given. The second argument contains a host-side copy of the
  // entire set of infos retrieved, and can be used for generating detailed
  // error messages.
  // `info_checker_callback` must call the DoneCallback of any asynchronous
  // OpKernel within which `solver` is used.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<GpuSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
          info_checker_callback);

  // Simpler version to use if no special error checking / messages are needed
  // apart from checking that the Status of all calls was Status::OK.
  // `done` may be nullptr.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<GpuSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      AsyncOpKernel::DoneCallback done);

  // Returns a ScratchSpace. The GpuSolver object maintains a TensorReference
  // to the underlying Tensor to prevent it from being deallocated prematurely.
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(const TensorShape& shape,
                                       const std::string& debug_info,
                                       bool on_host);
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(int64_t size,
                                       const std::string& debug_info,
                                       bool on_host);
  // Returns a DeviceLapackInfo that will live for the duration of the
  // GpuSolver object.
  inline DeviceLapackInfo GetDeviceLapackInfo(int64_t size,
                                              const std::string& debug_info);
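
  // For example (illustrative only; workspace_size and batch_size are
  // placeholders, not defined in this header), a kernel might request device
  // scratch memory and a per-batch LAPACK info slot like this:
  //
  //   auto workspace = solver->GetScratchSpace<Scalar>(
  //       workspace_size, "getrf_workspace", /*on_host=*/false);
  //   auto dev_info = solver->GetDeviceLapackInfo(batch_size, "getrf");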
  // Allocates a temporary tensor that will live for the duration of the
  // GpuSolver object.
  Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
                                Tensor* scoped_tensor);
  Status forward_input_or_allocate_scoped_tensor(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);

  OpKernelContext* context() { return context_; }

#if TENSORFLOW_USE_ROCM
  // ====================================================================
  // Wrappers for ROCSolver start here.
  //
  // The method names below map to those in ROCSolver/Hipsolver, which follow
  // the naming convention in LAPACK. See rocm_solvers.cc for a mapping of
  // GpuSolverMethod to library API.

  // LU factorization.
  // Computes LU factorization with partial pivoting P * A = L * U.
  template <typename Scalar>
  Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
               int* info);

  // Uses LU factorization to solve A * X = B.
  template <typename Scalar>
  Status Getrs(const gpuSolverOp_t trans, int n, int nrhs, Scalar* A, int lda,
               const int* dev_pivots, Scalar* B, int ldb, int* dev_lapack_info);

  template <typename Scalar>
  Status GetrfBatched(int n, Scalar** dev_A, int lda, int* dev_pivots,
                      DeviceLapackInfo* info, const int batch_count);

  // No GetrsBatched for HipSolver yet.
  template <typename Scalar>
  Status GetrsBatched(const rocblas_operation trans, int n, int nrhs,
                      Scalar** A, int lda, int* dev_pivots, Scalar** B,
                      const int ldb, int* lapack_info, const int batch_count);

  // Computes the Cholesky factorization A = L * L^H for a single matrix.
  template <typename Scalar>
  Status Potrf(gpuSolverFill_t uplo, int n, Scalar* dev_A, int lda,
               int* dev_lapack_info);

  // Computes matrix inverses for a batch of small matrices. Uses the outputs
  // from GetrfBatched. No HipSolver implementation yet.
  template <typename Scalar>
  Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      const int* dev_pivots,
                      const Scalar* const host_a_inverse_dev_ptrs[], int ldainv,
                      DeviceLapackInfo* dev_lapack_info, int batch_size);

  // Computes matrix inverses for a batch of small matrices with size n < 32.
  // Returns Status::OK() if the kernel was launched successfully. Uses
  // GetrfBatched and GetriBatched.
  template <typename Scalar>
  Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                       const Scalar* const host_a_inverse_dev_ptrs[],
                       int ldainv, DeviceLapackInfo* dev_lapack_info,
                       int batch_size);

  // Cholesky factorization.
  // Computes the Cholesky factorization A = L * L^H for a batch of small
  // matrices.
  template <typename Scalar>
  Status PotrfBatched(gpuSolverFill_t uplo, int n,
                      const Scalar* const host_a_dev_ptrs[], int lda,
                      DeviceLapackInfo* dev_lapack_info, int batch_size);

  template <typename Scalar>
  Status Trsm(rocblas_side side, rocblas_fill uplo, rocblas_operation trans,
              rocblas_diagonal diag, int m, int n, const Scalar* alpha,
              const Scalar* A, int lda, Scalar* B, int ldb);

  // QR factorization.
  // Computes QR factorization A = Q * R.
  template <typename Scalar>
  Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
               int* dev_lapack_info);
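
  // A hypothetical call sequence for factoring an m x n column-major matrix
  // dev_A (names and leading dimensions here are illustrative, not declared
  // in this header):
  //
  //   Scalar* dev_tau = ...;  // min(m, n) elements.
  //   TF_RETURN_IF_ERROR(solver->Geqrf(m, n, dev_A, m, dev_tau, dev_info));
  //   // Then Ungqr() can materialize the first n columns of Q in dev_A, or
  //   // Unmqr() can apply Q (or Q^H) to another matrix without forming it.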
  // This function performs the matrix-matrix addition/transposition
  // C = alpha * op(A) + beta * op(B).
  template <typename Scalar>
  Status Geam(rocblas_operation transa, rocblas_operation transb, int m, int n,
              const Scalar* alpha, /* host or device pointer */
              const Scalar* A, int lda,
              const Scalar* beta, /* host or device pointer */
              const Scalar* B, int ldb, Scalar* C, int ldc);

  // Overwrite matrix C by product of C and the unitary Householder matrix Q.
  // The Householder matrix Q is represented by the output from Geqrf in dev_a
  // and dev_tau.
  template <typename Scalar>
  Status Unmqr(gpuSolverSide_t side, gpuSolverOp_t trans, int m, int n, int k,
               const Scalar* dev_a, int lda, const Scalar* dev_tau,
               Scalar* dev_c, int ldc, int* dev_lapack_info);

  // Overwrites QR factorization produced by Geqrf by the unitary Householder
  // matrix Q. On input, the Householder matrix Q is represented by the output
  // from Geqrf in dev_a and dev_tau. On output, dev_a is overwritten with the
  // first n columns of Q. Requires m >= n >= 0.
  template <typename Scalar>
  Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda,
               const Scalar* dev_tau, int* dev_lapack_info);

#if TF_ROCM_VERSION >= 40500
  // Hermitian (Symmetric) Eigen decomposition.
  template <typename Scalar>
  Status Heevd(gpuSolverOp_t jobz, gpuSolverFill_t uplo, int n, Scalar* dev_A,
               int lda, typename Eigen::NumTraits<Scalar>::Real* dev_W,
               int* dev_lapack_info);
#endif

#else  // GOOGLE_CUDA
  // ====================================================================
  // Wrappers for cuSolverDN and cuBlas solvers start here.
  //
  // Apart from capitalization of the first letter, the method names below
  // map to those in cuSolverDN and cuBlas, which follow the naming
  // convention in LAPACK. See, e.g.,
  // http://docs.nvidia.com/cuda/cusolver/#naming-convention

  // This function performs the matrix-matrix addition/transposition
  // C = alpha * op(A) + beta * op(B).
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-geam
  // NOTE(ebrevdo): Does not support in-place transpose of non-square
  // matrices.
  template <typename Scalar>
  Status Geam(cublasOperation_t transa, cublasOperation_t transb, int m, int n,
              const Scalar* alpha, /* host or device pointer */
              const Scalar* A, int lda,
              const Scalar* beta, /* host or device pointer */
              const Scalar* B, int ldb, Scalar* C,
              int ldc) const TF_MUST_USE_RESULT;

  // Computes the Cholesky factorization A = L * L^H for a single matrix.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  template <typename Scalar>
  Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

#if CUDA_VERSION >= 9020
  // Computes the Cholesky factorization A = L * L^H for a batch of small
  // matrices.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-potrfBatched
  template <typename Scalar>
  Status PotrfBatched(cublasFillMode_t uplo, int n,
                      const Scalar* const host_a_dev_ptrs[], int lda,
                      DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;
#endif  // CUDA_VERSION >= 9020
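
  // Note on the batched wrappers below: the `host_a_dev_ptrs` arguments are
  // host-side arrays holding one device pointer per matrix in the batch. A
  // hypothetical setup for GetrfBatched (dev_matrices, dev_pivots and
  // batch_size are illustrative names, not declared here) might look like:
  //
  //   std::vector<Scalar*> host_ptrs(batch_size);
  //   for (int i = 0; i < batch_size; ++i) {
  //     host_ptrs[i] = dev_matrices + i * n * n;
  //   }
  //   auto dev_info = solver->GetDeviceLapackInfo(batch_size, "getrfBatched");
  //   TF_RETURN_IF_ERROR(solver->GetrfBatched(n, host_ptrs.data(), n,
  //                                           dev_pivots, &dev_info,
  //                                           batch_size));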
  // LU factorization.
  // Computes LU factorization with partial pivoting P * A = L * U.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
  template <typename Scalar>
  Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Uses LU factorization to solve A * X = B.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
  template <typename Scalar>
  Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
               int lda, const int* pivots, Scalar* B, int ldb,
               int* dev_lapack_info) const TF_MUST_USE_RESULT;

  // Computes partially pivoted LU factorizations for a batch of small
  // matrices.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
  template <typename Scalar>
  Status GetrfBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Batched linear solver using LU factorization from getrfBatched.
  // Note that lapack_info is returned on the host, unlike most of the other
  // functions, which return it on the device. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
  template <typename Scalar>
  Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                      const Scalar* const dev_Aarray[], int lda,
                      const int* devIpiv, const Scalar* const dev_Barray[],
                      int ldb, int* host_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices. Uses the outputs
  // from GetrfBatched. Returns Status::OK() if the kernel was launched
  // successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getribatched
  template <typename Scalar>
  Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      const int* dev_pivots,
                      const Scalar* const host_a_inverse_dev_ptrs[], int ldainv,
                      DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices with size n < 32.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-matinvbatched
  template <typename Scalar>
  Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                       const Scalar* const host_a_inverse_dev_ptrs[],
                       int ldainv, DeviceLapackInfo* dev_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;

  // QR factorization.
  // Computes QR factorization A = Q * R.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
  template <typename Scalar>
  Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
               int* dev_lapack_info) TF_MUST_USE_RESULT;
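
  // A hypothetical least-squares solve of A * X = B with m >= n could apply
  // Q^H to B with Unmqr (declared below) and then solve R * X = Q^H * B with
  // Trsm. All names other than the GpuSolver methods are illustrative:
  //
  //   TF_RETURN_IF_ERROR(solver->Geqrf(m, n, dev_A, lda, dev_tau, dev_info1));
  //   TF_RETURN_IF_ERROR(solver->Unmqr(CUBLAS_SIDE_LEFT,
  //                                    CublasAdjointOp<Scalar>(), m, nrhs, n,
  //                                    dev_A, lda, dev_tau, dev_B, ldb,
  //                                    dev_info2));
  //   const Scalar one(1);
  //   TF_RETURN_IF_ERROR(solver->Trsm(CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_UPPER,
  //                                   CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, n,
  //                                   nrhs, &one, dev_A, lda, dev_B, ldb));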
  // Overwrite matrix C by product of C and the unitary Householder matrix Q.
  // The Householder matrix Q is represented by the output from Geqrf in dev_a
  // and dev_tau.
  // Notice: If Scalar is real, only trans=CUBLAS_OP_N or trans=CUBLAS_OP_T is
  // supported. If Scalar is complex, trans=CUBLAS_OP_N or trans=CUBLAS_OP_C is
  // supported.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
  template <typename Scalar>
  Status Unmqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
               int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
               Scalar* dev_c, int ldc, int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites QR factorization produced by Geqrf by the unitary Householder
  // matrix Q. On input, the Householder matrix Q is represented by the output
  // from Geqrf in dev_a and dev_tau. On output, dev_a is overwritten with the
  // first n columns of Q. Requires m >= n >= 0.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
  template <typename Scalar>
  Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda,
               const Scalar* dev_tau, int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Hermitian (Symmetric) Eigen decomposition.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
  template <typename Scalar>
  Status Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n,
               Scalar* dev_A, int lda,
               typename Eigen::NumTraits<Scalar>::Real* dev_W,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Singular value decomposition.
  // Returns Status::OK() if the kernel was launched successfully.
  // TODO(rmlarsen, volunteers): Add support for complex types.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
  template <typename Scalar>
  Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
               int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
               int ldvt, int* dev_lapack_info) TF_MUST_USE_RESULT;

  template <typename Scalar>
  Status GesvdjBatched(cusolverEigMode_t jobz, int m, int n, Scalar* dev_A,
                       int lda, Scalar* dev_S, Scalar* dev_U, int ldu,
                       Scalar* dev_V, int ldv, int* dev_lapack_info,
                       int batch_size);

  // Triangular solve.
  // Returns Status::OK() if the kernel was launched successfully.
  // See https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsm
  template <typename Scalar>
  Status Trsm(cublasSideMode_t side, cublasFillMode_t uplo,
              cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
              const Scalar* alpha, const Scalar* A, int lda, Scalar* B,
              int ldb);

  template <typename Scalar>
  Status Trsv(cublasFillMode_t uplo, cublasOperation_t trans,
              cublasDiagType_t diag, int n, const Scalar* A, int lda, Scalar* x,
              int incx);

  // See
  // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsmbatched
  template <typename Scalar>
  Status TrsmBatched(cublasSideMode_t side, cublasFillMode_t uplo,
                     cublasOperation_t trans, cublasDiagType_t diag, int m,
                     int n, const Scalar* alpha,
                     const Scalar* const dev_Aarray[], int lda,
                     Scalar* dev_Barray[], int ldb, int batch_size);
#endif

 private:
  OpKernelContext* context_;  // not owned.
#if GOOGLE_CUDA
  cudaStream_t cuda_stream_;
  cusolverDnHandle_t cusolver_dn_handle_;
  cublasHandle_t cublas_handle_;
#else  // TENSORFLOW_USE_ROCM
  hipStream_t hip_stream_;
  rocblas_handle rocm_blas_handle_;
#endif

  std::vector<TensorReference> scratch_tensor_refs_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSolver);
};
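
// For reference, a hypothetical custom info-checker callback for the first
// CheckLapackInfoAndDeleteSolverAsync overload above might look like this,
// where context and done are the surrounding AsyncOpKernel's arguments
// (illustrative sketch, not part of the API):
//
//   auto info_checker = [context, done](
//       const Status& status, const std::vector<HostLapackInfo>& host_infos) {
//     OP_REQUIRES_OK_ASYNC(context, status, done);
//     // Optionally inspect host_infos here for detailed error reporting.
//     done();
//   };
//   GpuSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
//                                                  std::move(info_checker));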
// Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp.
template <typename Scalar>
class ScratchSpace {
 public:
  ScratchSpace(OpKernelContext* context, int64_t size, bool on_host)
      : ScratchSpace(context, TensorShape({size}), "", on_host) {}

  ScratchSpace(OpKernelContext* context, int64_t size,
               const std::string& debug_info, bool on_host)
      : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}

  ScratchSpace(OpKernelContext* context, const TensorShape& shape,
               const std::string& debug_info, bool on_host)
      : context_(context), debug_info_(debug_info), on_host_(on_host) {
    AllocatorAttributes alloc_attr;
    if (on_host) {
      // Allocate pinned memory on the host to avoid unnecessary
      // synchronization.
      alloc_attr.set_on_host(true);
      alloc_attr.set_gpu_compatible(true);
    }
    TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
                                       &scratch_tensor_, alloc_attr));
  }

  virtual ~ScratchSpace() {}

  Scalar* mutable_data() {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  const Scalar* data() const {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  Scalar& operator()(int64_t i) {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  const Scalar& operator()(int64_t i) const {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  int64_t bytes() const { return scratch_tensor_.TotalBytes(); }
  int64_t size() const { return scratch_tensor_.NumElements(); }
  const std::string& debug_info() const { return debug_info_; }

  Tensor& tensor() { return scratch_tensor_; }
  const Tensor& tensor() const { return scratch_tensor_; }

  // Returns true if this ScratchSpace is in host memory.
  bool on_host() const { return on_host_; }

 protected:
  OpKernelContext* context() const { return context_; }

 private:
  OpKernelContext* context_;  // not owned
  const std::string debug_info_;
  const bool on_host_;
  Tensor scratch_tensor_;
};

class HostLapackInfo : public ScratchSpace<int> {
 public:
  HostLapackInfo(OpKernelContext* context, int64_t size,
                 const std::string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ true) {}
};

class DeviceLapackInfo : public ScratchSpace<int> {
 public:
  DeviceLapackInfo(OpKernelContext* context, int64_t size,
                   const std::string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ false) {}

  // Allocates a new scratch space on the host and launches a copy of the
  // contents of *this to the new scratch space. Sets success to true if
  // the copy kernel was launched successfully.
  HostLapackInfo CopyToHost(bool* success) const {
    CHECK(success != nullptr);
    HostLapackInfo copy(context(), size(), debug_info());
    auto stream = context()->op_device_context()->stream();
    se::DeviceMemoryBase wrapped_src(
        static_cast<void*>(const_cast<int*>(this->data())));
    *success =
        stream->ThenMemcpy(copy.mutable_data(), wrapped_src, this->bytes())
            .ok();
    return copy;
  }
};

template <typename Scalar>
ScratchSpace<Scalar> GpuSolver::GetScratchSpace(const TensorShape& shape,
                                                const std::string& debug_info,
                                                bool on_host) {
  ScratchSpace<Scalar> new_scratch_space(context_, shape, debug_info, on_host);
  scratch_tensor_refs_.emplace_back(new_scratch_space.tensor());
  return std::move(new_scratch_space);
}

template <typename Scalar>
ScratchSpace<Scalar> GpuSolver::GetScratchSpace(int64_t size,
                                                const std::string& debug_info,
                                                bool on_host) {
  return GetScratchSpace<Scalar>(TensorShape({size}), debug_info, on_host);
}

inline DeviceLapackInfo GpuSolver::GetDeviceLapackInfo(
    int64_t size, const std::string& debug_info) {
  DeviceLapackInfo new_dev_info(context_, size, debug_info);
  scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
  return new_dev_info;
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_GPU_SOLVERS_H_