/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_

#include <complex>

#define TENSORFLOW_USE_HIPSOLVER \
  (TENSORFLOW_USE_ROCM && (TF_ROCM_VERSION >= 40500))
#define TENSORFLOW_USE_ROCSOLVER \
  (TENSORFLOW_USE_ROCM && (TF_ROCM_VERSION < 40500))
#define TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER \
  (!TENSORFLOW_USE_ROCM || TENSORFLOW_USE_HIPSOLVER)

#if !TENSORFLOW_USE_ROCM
#include "third_party/gpus/cuda/include/cusolverDn.h"
using gpusolverHandle_t = cusolverDnHandle_t;
#else
#include "rocm/rocm_config.h"
// Macros to ease the transition from rocsolver to hipsolver.
#if TENSORFLOW_USE_HIPSOLVER
#include "tensorflow/stream_executor/rocm/hipsolver_wrapper.h"
using gpusolverHandle_t = hipsolverHandle_t;
#else  // TENSORFLOW_USE_ROCSOLVER
#include "tensorflow/stream_executor/rocm/rocblas_wrapper.h"
#include "tensorflow/stream_executor/rocm/rocsolver_wrapper.h"
using gpusolverHandle_t = rocblas_handle;
#endif  // TENSORFLOW_USE_HIPSOLVER
#endif  // TENSORFLOW_USE_ROCM

#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"

namespace xla {
namespace gpu {

class GpuSolverContext {
 public:
  // stream may be nullptr, in which case the context can only be used for
  // buffer size queries.
  static StatusOr<GpuSolverContext> Create(se::Stream* stream);
  GpuSolverContext() = default;
  ~GpuSolverContext();

  GpuSolverContext(const GpuSolverContext&) = delete;
  GpuSolverContext(GpuSolverContext&&);
  GpuSolverContext& operator=(const GpuSolverContext&) = delete;
  GpuSolverContext& operator=(GpuSolverContext&&);

  bool SupportsPotrfBatched() const { return true; }

  // Computes the Cholesky factorization A = L * L^T for a single matrix.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<float> a,
               int lda, se::DeviceMemory<int> lapack_info,
               se::DeviceMemoryBase workspace);
  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<double> a,
               int lda, se::DeviceMemory<int> lapack_info,
               se::DeviceMemoryBase workspace);
  Status Potrf(se::blas::UpperLower uplo, int n,
               se::DeviceMemory<std::complex<float>> a, int lda,
               se::DeviceMemory<int> lapack_info,
               se::DeviceMemoryBase workspace);
  Status Potrf(se::blas::UpperLower uplo, int n,
               se::DeviceMemory<std::complex<double>> a, int lda,
               se::DeviceMemory<int> lapack_info,
               se::DeviceMemoryBase workspace);

  // Computes the Cholesky factorization of multiple matrices. See
  // https://docs.nvidia.com/cuda/cusolver/index.html#cuSolverDN-lt-t-gt-batchpotrf
  //
  // `as` is a list of pointers to the batch_size individual n x n matrices
  // that make up the input array.
  Status PotrfBatched(se::blas::UpperLower uplo, int n,
                      se::DeviceMemory<float*> as, int lda,
                      se::DeviceMemory<int> lapack_info, int batch_size);
  Status PotrfBatched(se::blas::UpperLower uplo, int n,
                      se::DeviceMemory<double*> as, int lda,
                      se::DeviceMemory<int> lapack_info, int batch_size);
  Status PotrfBatched(se::blas::UpperLower uplo, int n,
                      se::DeviceMemory<std::complex<float>*> as, int lda,
                      se::DeviceMemory<int> lapack_info, int batch_size);
  Status PotrfBatched(se::blas::UpperLower uplo, int n,
                      se::DeviceMemory<std::complex<double>*> as, int lda,
                      se::DeviceMemory<int> lapack_info, int batch_size);

  // Returns the max size of the `workspace` required by Potrf and
  // PotrfBatched, in number of elements of `type`.
  //
  // (cusolver's PotrfBatched doesn't require a workspace per se -- it uses
  // the input array as scratch.  But we do need to materialize the `as`
  // input, and we do this in the workspace.)
  //
  // This is a bit of a hack; we could instead split it up into two functions.
  // But at the moment, it's an implementation detail of CholeskyThunk whether
  // it calls Potrf or PotrfBatched, so we need to allocate enough scratch
  // space for either.
  //
  // In practice, this does not result in a notable increase in scratch space
  // needed, because both cases require a relatively small amount of scratch.
  StatusOr<int64_t> PotrfBufferSize(PrimitiveType type,
                                    se::blas::UpperLower uplo, int n, int lda,
                                    int batch_size);

 private:
  GpuSolverContext(se::Stream* stream, gpusolverHandle_t handle);

  gpusolverHandle_t handle() const { return handle_; }

  se::Stream* stream_ = nullptr;
  gpusolverHandle_t handle_ = nullptr;
};
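
// Illustrative usage sketch (not part of this header): how a caller might
// size the scratch buffer up front and then launch a single-matrix
// factorization. The `stream`, `n`, `a`, `lapack_info`, and `workspace`
// variables below are assumptions for the example, not names defined by this
// API; error handling uses the codebase's TF_ASSIGN_OR_RETURN /
// TF_RETURN_IF_ERROR macros.
//
//   TF_ASSIGN_OR_RETURN(GpuSolverContext context,
//                       GpuSolverContext::Create(stream));
//   // Query the worst-case workspace size, in elements of F32.
//   TF_ASSIGN_OR_RETURN(
//       int64_t workspace_elems,
//       context.PotrfBufferSize(F32, se::blas::UpperLower::kLower, n,
//                               /*lda=*/n, /*batch_size=*/1));
//   // ... allocate on device: `workspace` (workspace_elems floats), `a`
//   // (the n x n column-major input matrix), `lapack_info` (one int) ...
//   TF_RETURN_IF_ERROR(context.Potrf(se::blas::UpperLower::kLower, n, a,
//                                    /*lda=*/n, lapack_info, workspace));
//   // On success, `a` holds the factor in place and `lapack_info` is 0.
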
}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_