/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_

#include <complex>

#define TENSORFLOW_USE_HIPSOLVER \
  (TENSORFLOW_USE_ROCM && (TF_ROCM_VERSION >= 40500))
#define TENSORFLOW_USE_ROCSOLVER \
  (TENSORFLOW_USE_ROCM && (TF_ROCM_VERSION < 40500))
#define TENSORFLOW_USE_CUSOLVER_OR_HIPSOLVER \
  (!TENSORFLOW_USE_ROCM || TENSORFLOW_USE_HIPSOLVER)

#if !TENSORFLOW_USE_ROCM
#include "third_party/gpus/cuda/include/cusolverDn.h"
using gpusolverHandle_t = cusolverDnHandle_t;
#else
#include "rocm/rocm_config.h"
// Macros to ease the transition from rocsolver to hipsolver.
#if TENSORFLOW_USE_HIPSOLVER
#include "tensorflow/stream_executor/rocm/hipsolver_wrapper.h"
using gpusolverHandle_t = hipsolverHandle_t;
#else  // TENSORFLOW_USE_ROCSOLVER
#include "tensorflow/stream_executor/rocm/rocblas_wrapper.h"
#include "tensorflow/stream_executor/rocm/rocsolver_wrapper.h"
using gpusolverHandle_t = rocblas_handle;
#endif  // TENSORFLOW_USE_HIPSOLVER
#endif  // TENSORFLOW_USE_ROCM

#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"

namespace xla {
namespace gpu {

class GpuSolverContext {
 public:
  // `stream` may be nullptr, in which case the context can only be used for
  // buffer size queries.
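  //
  // Illustrative use (a sketch only, not part of this header's contract;
  // assumes `stream` is a live se::Stream owned by the caller and that
  // TF_ASSIGN_OR_RETURN from statusor.h is in scope):
  //
  //   TF_ASSIGN_OR_RETURN(GpuSolverContext context,
  //                       GpuSolverContext::Create(stream));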
  static StatusOr<GpuSolverContext> Create(se::Stream* stream);
  GpuSolverContext() = default;
  ~GpuSolverContext();

  GpuSolverContext(const GpuSolverContext&) = delete;
  GpuSolverContext(GpuSolverContext&&);
  GpuSolverContext& operator=(const GpuSolverContext&) = delete;
  GpuSolverContext& operator=(GpuSolverContext&&);

  // Whether the underlying solver library supports PotrfBatched.
  bool SupportsPotrfBatched() const {
    return true;
  }

  // Computes the Cholesky factorization A = L * L^T (A = L * L^H for the
  // complex overloads) of a single matrix.  Returns Status::OK() if the
  // kernel was launched successfully.  See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
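  //
  // Illustrative call (a sketch; device allocation of `a`, `lapack_info`,
  // and `workspace` is elided, and TF_RETURN_IF_ERROR is assumed in scope):
  //
  //   TF_RETURN_IF_ERROR(context.Potrf(se::blas::UpperLower::kLower, n, a,
  //                                    /*lda=*/n, lapack_info, workspace));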
  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<float> a,
               int lda, se::DeviceMemory<int> lapack_info,
               se::DeviceMemoryBase workspace);
  Status Potrf(se::blas::UpperLower uplo, int n, se::DeviceMemory<double> a,
               int lda, se::DeviceMemory<int> lapack_info,
               se::DeviceMemoryBase workspace);
  Status Potrf(se::blas::UpperLower uplo, int n,
               se::DeviceMemory<std::complex<float>> a, int lda,
               se::DeviceMemory<int> lapack_info,
               se::DeviceMemoryBase workspace);
  Status Potrf(se::blas::UpperLower uplo, int n,
               se::DeviceMemory<std::complex<double>> a, int lda,
               se::DeviceMemory<int> lapack_info,
               se::DeviceMemoryBase workspace);

  // Computes the Cholesky factorization of multiple matrices.  See
  // https://docs.nvidia.com/cuda/cusolver/index.html#cuSolverDN-lt-t-gt-batchpotrf
  //
  // `as` is a list of pointers to the batch_size individual n x n matrices
  // that make up the input array.
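  //
  // Illustrative call (a sketch; assumes the caller has already materialized
  // `as`, the array of batch_size device pointers, e.g. in the workspace):
  //
  //   TF_RETURN_IF_ERROR(context.PotrfBatched(se::blas::UpperLower::kLower,
  //                                           n, as, /*lda=*/n, lapack_info,
  //                                           batch_size));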
  Status PotrfBatched(se::blas::UpperLower uplo, int n,
                      se::DeviceMemory<float*> as, int lda,
                      se::DeviceMemory<int> lapack_info, int batch_size);
  Status PotrfBatched(se::blas::UpperLower uplo, int n,
                      se::DeviceMemory<double*> as, int lda,
                      se::DeviceMemory<int> lapack_info, int batch_size);
  Status PotrfBatched(se::blas::UpperLower uplo, int n,
                      se::DeviceMemory<std::complex<float>*> as, int lda,
                      se::DeviceMemory<int> lapack_info, int batch_size);
  Status PotrfBatched(se::blas::UpperLower uplo, int n,
                      se::DeviceMemory<std::complex<double>*> as, int lda,
                      se::DeviceMemory<int> lapack_info, int batch_size);

  // Returns the max size of the `workspace` required by Potrf and PotrfBatched,
  // in number of elements of `type`.
  //
  // (cusolver's PotrfBatched doesn't require a workspace per se -- it uses the
  // input array as scratch.  But we do need to materialize the `as` input, and
  // we do this in the workspace.)
  //
  // This is a bit of a hack; we could instead split it up into two functions.
  // But at the moment, it's an implementation detail of CholeskyThunk whether
  // it calls Potrf or PotrfBatched, so we need to allocate enough scratch space
  // for either.
  //
  // In practice, this does not result in a notable increase in scratch space
  // needed, because both cases require a relatively small amount of scratch.
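  //
  // Illustrative sizing (a sketch; the byte conversion via
  // ShapeUtil::ByteSizeOfPrimitiveType is an assumption about the caller,
  // not something this API performs):
  //
  //   TF_ASSIGN_OR_RETURN(int64_t workspace_elems,
  //                       context.PotrfBufferSize(F32, uplo, n, /*lda=*/n,
  //                                               batch_size));
  //   int64_t workspace_bytes =
  //       workspace_elems * ShapeUtil::ByteSizeOfPrimitiveType(F32);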
  StatusOr<int64_t> PotrfBufferSize(PrimitiveType type,
                                    se::blas::UpperLower uplo, int n, int lda,
                                    int batch_size);

 private:
  GpuSolverContext(se::Stream* stream, gpusolverHandle_t handle);

  gpusolverHandle_t handle() const { return handle_; }

  se::Stream* stream_ = nullptr;
  gpusolverHandle_t handle_ = nullptr;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_CONTEXT_H_