// File: aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// See Note [BatchLinearAlgebraLib split implementation files]
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/irange.h>

#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/TransposeType.h>
#include <ATen/native/cuda/MiscUtils.h>
#include <ATen/native/cuda/linalg/CUDASolver.h>
#include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/nan_to_num.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/where.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

static cublasOperation_t to_cublas(TransposeType trans) {
  switch (trans) {
    case TransposeType::NoTranspose: return CUBLAS_OP_N;
    case TransposeType::Transpose: return CUBLAS_OP_T;
    case TransposeType::ConjTranspose: return CUBLAS_OP_C;
  }
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}

// Some cuBLAS and cuSOLVER batched routines require the input to be a device array of pointers to the individual device matrices
// 'input' must be a contiguous tensor
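// Note: rather than building the array of pointers on the host and copying it over,
// the pointers are materialized directly on the device by the at::arange call below:
// the start is the address of the first matrix and the step is the matrix stride in
// bytes, so element i of the result holds the address of matrix i.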
template <typename scalar_t>
static Tensor get_device_pointers(const Tensor& input) {
  auto input_data = input.const_data_ptr<scalar_t>();
  int64_t input_mat_stride = matrixStride(input);

  // cublas/cusolver interface requires 'int'
  int batch_size = cuda_int_cast(batchCount(input), "batch_size");

  // if batch_size==0, then start=0 and end=0
  // if input_mat_stride==0, then step=sizeof(scalar_t)
  return at::arange(
      /*start=*/reinterpret_cast<int64_t>(input_data),
      /*end=*/reinterpret_cast<int64_t>(input_data + batch_size * input_mat_stride),
      /*step=*/static_cast<int64_t>(std::max<int64_t>(input_mat_stride, 1) * sizeof(scalar_t)),
      input.options().dtype(at::kLong));
}

namespace {

template <typename scalar_t>
void apply_ldl_factor_cusolver(
    const Tensor& A,
    const Tensor& pivots,
    const Tensor& info,
    bool upper) {
#if !defined(USE_LINALG_SOLVER)
  TORCH_CHECK(
      false,
      "Calling torch.linalg.ldl_factor on a CUDA tensor requires compiling ",
      "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support.");
#else
  auto batch_size = batchCount(A);
  auto n = cuda_int_cast(A.size(-2), "A.size(-2)");
  auto lda = cuda_int_cast(A.stride(-1), "A.stride(-1)");
  auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;

  auto a_stride = A.dim() > 2 ? A.stride(-3) : 0;
  auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;

  auto a_data = A.data_ptr<scalar_t>();
  auto pivots_data = pivots.data_ptr<int>();
  auto info_data = info.data_ptr<int>();

  auto handle = at::cuda::getCurrentCUDASolverDnHandle();

  int lwork = 0;
  at::cuda::solver::sytrf_bufferSize(handle, n, a_data, lda, &lwork);
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto work = allocator.allocate(sizeof(scalar_t) * lwork);

  for (const auto i : c10::irange(batch_size)) {
    auto* a_working_ptr = &a_data[i * a_stride];
    auto* pivots_working_ptr = &pivots_data[i * pivots_stride];
    auto* info_working_ptr = &info_data[i];
    at::cuda::solver::sytrf(
        handle,
        uplo,
        n,
        a_working_ptr,
        lda,
        pivots_working_ptr,
        reinterpret_cast<scalar_t*>(work.get()),
        lwork,
        info_working_ptr);
  }
#endif
}

template <typename scalar_t>
void apply_ldl_solve_cusolver(
    const Tensor& A,
    const Tensor& pivots,
    const Tensor& B,
    bool upper) {
#if !(defined(CUDART_VERSION) && defined(CUSOLVER_VERSION) && \
    CUSOLVER_VERSION >= 11102)
  TORCH_CHECK(
      false,
      "Calling torch.linalg.ldl_solve on a CUDA tensor requires compiling ",
      "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER 11.1.2+ (CUDA 11.3.1+) support.");
#else
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(A) > 0);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(pivots.unsqueeze(-1)) > 0);
  auto batch_size = batchCount(B);
  auto n = A.size(-2);
  auto nrhs = B.size(-1);
  auto lda = A.stride(-1);
  auto ldb = B.stride(-1);
  auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;

  auto a_stride = A.dim() > 2 ? A.stride(-3) : 0;
  auto b_stride = B.dim() > 2 ? B.stride(-3) : 0;
  auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;

  auto a_data = A.const_data_ptr<scalar_t>();
  auto b_data = B.data_ptr<scalar_t>();

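  // cusolverDnXsytrs (the 64-bit "generic" cuSOLVER API used below) takes the
  // pivot indices as int64_t, hence the conversion of the int32 pivots to kLong.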
  auto pivots_ = pivots.to(kLong);
  auto pivots_data = pivots_.const_data_ptr<int64_t>();

  // needed to run ldl_solve tests in parallel
  // see https://github.com/pytorch/pytorch/issues/82894 for examples of failures
  c10::cuda::device_synchronize();
  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  auto datatype = at::cuda::solver::get_cusolver_datatype<scalar_t>();
  size_t worksize_device = 0;
  size_t worksize_host = 0;

  TORCH_CUSOLVER_CHECK(cusolverDnXsytrs_bufferSize(
      handle,
      uplo,
      n,
      nrhs,
      datatype,
      a_data,
      lda,
      pivots_data,
      datatype,
      b_data,
      ldb,
      &worksize_device,
      &worksize_host));

  // allocate workspace storage
  auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
  auto workdata_device = device_allocator.allocate(worksize_device);
  void* workdata_device_ptr = workdata_device.get();

  auto& host_allocator = *at::getCPUAllocator();
  auto workdata_host = host_allocator.allocate(worksize_host);
  void* workdata_host_ptr = workdata_host.get();

  Tensor info = at::zeros({}, A.options().dtype(at::kInt));
  for (const auto i : c10::irange(batch_size)) {
    const auto* a_working_ptr = &a_data[i * a_stride];
    auto* b_working_ptr = &b_data[i * b_stride];
    const auto* pivots_working_ptr = &pivots_data[i * pivots_stride];
    TORCH_CUSOLVER_CHECK(cusolverDnXsytrs(
        handle,
        uplo,
        n,
        nrhs,
        datatype,
        a_working_ptr,
        lda,
        pivots_working_ptr,
        datatype,
        b_working_ptr,
        ldb,
        workdata_device_ptr,
        worksize_device,
        workdata_host_ptr,
        worksize_host,
        info.data_ptr<int>()));
  }

  // info from sytrs only reports if the i-th parameter is wrong
  // so we don't need to check it all the time
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.item().toInt() == 0);
#endif
}

} // anonymous namespace

void ldl_factor_cusolver(
    const Tensor& LD,
    const Tensor& pivots,
    const Tensor& info,
    bool upper,
    bool hermitian) {
  if (LD.is_complex()) {
    TORCH_CHECK(
        !hermitian,
        "torch.linalg.ldl_factor: complex tensors with hermitian=True flag are not supported with cuSOLVER backend. ",
        "Currently preferred backend is ",
        at::globalContext().linalgPreferredBackend(),
        ", please set 'default' or 'magma' backend with torch.backends.cuda.preferred_linalg_library");
  }
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      LD.scalar_type(), "ldl_factor_looped_cusolver", [&] {
        apply_ldl_factor_cusolver<scalar_t>(LD, pivots, info, upper);
      });
}

void ldl_solve_cusolver(
    const Tensor& LD,
    const Tensor& pivots,
    const Tensor& B,
    bool upper) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
      LD.scalar_type(), "ldl_solve_looped_cusolver", [&] {
        apply_ldl_solve_cusolver<scalar_t>(LD, pivots, B, upper);
      });
}

#if defined(USE_LINALG_SOLVER)

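// Note: the helper below produces a batch of identity matrices with Fortran
// (column-major) layout: diag_embed() of a ones tensor yields identity matrices,
// and the trailing mT() swaps the last two dimensions (a no-op value-wise for a
// symmetric matrix) so the memory layout matches what cuSOLVER/cuBLAS expect.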
inline static Tensor column_major_identity_matrix_like(const Tensor& self) {
  auto size = self.sizes();
  auto size_slice = IntArrayRef(size.data(), size.size()-1);
  return at::ones(size_slice, self.options()).diag_embed().mT();
}


// call cusolver gesvd function to calculate svd
template<typename scalar_t>
inline static void apply_svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
  const Tensor& infos, bool full_matrices, bool compute_uv,
  const bool calculate_all_batches,
  const std::vector<int64_t>& batches
) {
  using value_t = typename c10::scalar_value_type<scalar_t>::type;
  auto A_data = A.data_ptr<scalar_t>();
  auto S_data = S.data_ptr<value_t>();
  auto A_stride = matrixStride(A);
  auto S_stride = S.size(-1);

  int m = cuda_int_cast(A.size(-2), "m");
  int n = cuda_int_cast(A.size(-1), "n");
  auto k = std::min(m, n);
  int lda = std::max<int>(1, m);
  int ldvh = std::max<int>(1, n);

  TORCH_INTERNAL_ASSERT(m >= n, "cusolver gesvd only supports matrices with sizes m >= n");

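  // jobu/jobvt follow the LAPACK gesvd convention: 'A' computes the full U / V^H,
  // 'S' only their first min(m, n) columns / rows, and 'N' skips computing them.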
  char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
  auto handle = at::cuda::getCurrentCUDASolverDnHandle();

  int lwork = -1;
  at::cuda::solver::gesvd_buffersize<scalar_t>(handle, m, n, &lwork);
  TORCH_INTERNAL_ASSERT(lwork >= 0, "gesvd_buffersize failed to get needed buffer size, got lwork = ", lwork);

  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  const auto dataPtr_work = allocator.allocate(sizeof(scalar_t)*lwork);
  const auto dataPtr_rwork = allocator.allocate(sizeof(value_t)*std::min(m, n));

  // nb. We can do this .view() because V is a batch of F-contig matrices
  const auto V_view = compute_uv ? V.view({-1, n, V.size(-1)})
                                 : Tensor{};
  // V is F-contig. Since this function computes Vh, we need an auxiliary F-conj-transposed matrix to hold Vh
  const auto Vh_workspace = compute_uv ? at::empty({n, full_matrices ? n : k},
                                              A.options().memory_format(at::MemoryFormat::Contiguous)).conj()
                                       : Tensor{};
  const auto Vh_ptr = compute_uv ? Vh_workspace.data_ptr<scalar_t>()
                                 : nullptr;

  const auto U_stride = compute_uv ? matrixStride(U) : 0;
  const auto U_ptr = compute_uv ? U.data_ptr<scalar_t>() : nullptr;

  int batchsize = calculate_all_batches ? cuda_int_cast(batchCount(A), "batch size")
                                        : batches.size();


  for (int _i = 0; _i < batchsize; _i++) {
    int i = calculate_all_batches ? _i : batches[_i];

    at::cuda::solver::gesvd<scalar_t>(
      handle, job, job, m, n,
      A_data + i * A_stride,
      lda,
      S_data + i * S_stride,
      compute_uv ? U_ptr + i * U_stride : nullptr,
      lda,
      compute_uv ? Vh_ptr : nullptr,
      ldvh,
      reinterpret_cast<scalar_t*>(dataPtr_work.get()),
      lwork,
      reinterpret_cast<value_t*>(dataPtr_rwork.get()),
      infos.data_ptr<int>() + i
    );

    if (compute_uv) {
      V_view[i].copy_(Vh_workspace);
    }
  }
}

// We'll copy A inside svd_cusolver_gesvd
inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
  const Tensor& infos, bool full_matrices, bool compute_uv,
  const bool calculate_all_batches = true,
  const std::vector<int64_t>& batches = {}
) {
  // We need to pass a copy of A, as it will be overwritten
  // gesvd only knows how to handle m >= n, so in the other case we need to transpose A
  const auto not_A_H = A.size(-2) >= A.size(-1);
  Tensor Vcopy = V; // Shallow copy
#ifdef USE_ROCM
  // Similar to the case in svd_magma(), experiments have shown that the Vh tensor
  // is not guaranteed to be column major on ROCm, so we have to create a copy to
  // deal with this.
  if (!not_A_H) {
    Vcopy = at::empty_like(V.mT(),
                           V.options()
                           .device(V.device())
                           .memory_format(at::MemoryFormat::Contiguous)).mT();
  }
#endif
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvd", [&] {
    apply_svd_cusolver_gesvd<scalar_t>(cloneBatchedColumnMajor(not_A_H ? A : A.mH()),
                                       not_A_H ? U : Vcopy,
                                       S,
                                       not_A_H ? Vcopy : U,
                                       infos,
                                       full_matrices, compute_uv, calculate_all_batches, batches);
  });
#ifdef USE_ROCM
  if (!not_A_H) {
    V.copy_(Vcopy);
  }
#endif
}

// call cusolver gesvdj function to calculate svd
template<typename scalar_t>
inline static void apply_svd_cusolver_gesvdj(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
  const Tensor& infos, bool full_matrices, bool compute_uv) {
  using value_t = typename c10::scalar_value_type<scalar_t>::type;
  int m = cuda_int_cast(A.size(-2), "m");
  int n = cuda_int_cast(A.size(-1), "n");
  int k = std::min(m, n);

  // Need to pass allocated memory to the function, otherwise it fails
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t)* m * k) : c10::DataPtr{};
  auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t)* n * k) : c10::DataPtr{};

  auto A_data = A.data_ptr<scalar_t>();
  auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
  auto S_data = S.data_ptr<value_t>();
  auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());
  auto A_stride = matrixStride(A);
  auto U_stride = compute_uv ? matrixStride(U) : 0;
  auto S_stride = S.size(-1);
  auto V_stride = compute_uv ? matrixStride(V) : 0;

  int batchsize = cuda_int_cast(batchCount(A), "batch size");
  int lda = A.stride(-1);
  int ldu = compute_uv ? U.stride(-1) : m;
  int ldv = compute_uv ? V.stride(-1) : n;

  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
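  // econ == 1 requests the economy-size factorization, i.e. only the leading
  // min(m, n) columns of U and V are computed.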
  int econ = full_matrices ? 0 : 1;

  // gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU
  gesvdjInfo_t gesvdj_params;
  TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&gesvdj_params));

  // Todo: expose the following two parameters to users
  TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(gesvdj_params, std::numeric_limits<scalar_t>::epsilon()));
  TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, cusolver_gesvdj_max_sweeps));

  int lwork = -1;
  at::cuda::solver::gesvdj_buffersize<scalar_t>(
    handle, jobz, econ, m, n, A_data, lda, S_data, U_data, ldu, V_data, ldv, &lwork, gesvdj_params);
  TORCH_INTERNAL_ASSERT(lwork >= 0, "gesvdj_buffersize failed to get needed buffer size, got lwork = ", lwork);

  auto dataPtr = allocator.allocate(sizeof(scalar_t)*lwork);

  for (int i = 0; i < batchsize; i++) {
    at::cuda::solver::gesvdj<scalar_t>(
      handle, jobz, econ, m, n,
      A_data + i * A_stride,
      lda,
      S_data + i * S_stride,
      U_data + i * U_stride,
      ldu,
      V_data + i * V_stride,
      ldv,
      reinterpret_cast<scalar_t*>(dataPtr.get()),
      lwork,
      infos.data_ptr<int>() + i,
      gesvdj_params
    );

    // // The following code can be used to check or report the gesvdj residual.
    // // Note: this will introduce a device-host sync and may negatively affect the performance
    // double residual = 0;
    // TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjGetResidual(handle, gesvdj_params, &residual));
    // printf("gesvdj residual = %.6e\n", residual);
  }

  TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
}

// wrapper around apply_svd_cusolver_gesvdj that handles dtype dispatch
// note that gesvdj returns V, which is what we want
// Need to pass a copy of A, since A will be rewritten inside the function call
inline static void svd_cusolver_gesvdj(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool full_matrices, bool compute_uv) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdj", [&] {
    apply_svd_cusolver_gesvdj<scalar_t>(A, U, S, V, infos, full_matrices, compute_uv);
  });
}

// call cusolver gesvdj batched function to calculate svd
template<typename scalar_t>
inline static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
  const Tensor& infos, bool compute_uv
) {
  using value_t = typename c10::scalar_value_type<scalar_t>::type;
  int m = cuda_int_cast(A.size(-2), "m");
  int n = cuda_int_cast(A.size(-1), "n");
  int batchsize = cuda_int_cast(batchCount(A), "batch size");
  int lda = A.stride(-1);
  int ldu = compute_uv ? U.stride(-1) : m;
  int ldv = compute_uv ? V.stride(-1) : n;

  // Need to pass allocated memory to the function, otherwise it fails
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * ldu) : c10::DataPtr{};
  auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * ldv) : c10::DataPtr{};

  auto A_data = A.data_ptr<scalar_t>();
  auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
  auto S_data = S.data_ptr<value_t>();
  auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());

  TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got "
                        "m = ", m, " n = ", n);

  // gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU
  gesvdjInfo_t gesvdj_params;
  TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&gesvdj_params));

  // Todo: expose the following two parameters to users
  TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(gesvdj_params, std::numeric_limits<scalar_t>::epsilon()));
  TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, cusolver_gesvdj_max_sweeps));
  TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetSortEig(gesvdj_params, 1));

  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
  at::cuda::solver::gesvdjBatched<scalar_t>(
    handle, jobz, m, n, A_data, lda, S_data, U_data, ldu, V_data, ldv,
    infos.data_ptr<int>(), gesvdj_params, batchsize
  );

  TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
}

inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool full_matrices, bool compute_uv) {
  auto m = A.size(-2);
  auto n = A.size(-1);
  auto k = std::min(m, n);
  // The kernel assumes full_matrices == true
  // If full_matrices == false and m != n, we create auxiliary tensors of the right size and copy the results back
  auto U_ = U;
  auto V_ = V;
  if (compute_uv && !full_matrices) {
    auto sizes = A.sizes().vec();
    if (m > n) {
      // Size of U with full_matrices == True
      sizes.end()[-1] = m;
      // U, V should be a batch of Fortran contiguous arrays
      U_ = U.new_empty(sizes).mT();
    } else if (m < n) {
      // Size of V with full_matrices == True
      sizes.end()[-2] = n;
      V_ = V.new_empty(sizes).mT();
    }
  }
  // Here U_ and V_ are batches of F-contig square matrices

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdjBatched", [&] {
    apply_svd_cusolver_gesvdjBatched<scalar_t>(A, U_, S, V_, infos, compute_uv);
  });

  // Copy the result back if we created any new matrix
  if (compute_uv && !full_matrices) {
    if (!U_.is_alias_of(U)) {
      U.copy_(U_.narrow(-1, 0, k));
    }
    if (!V_.is_alias_of(V)) {
      V.copy_(V_.narrow(-1, 0, k));
    }
  }
}

template<typename scalar_t>
inline static void apply_svd_cusolver_gesvdaStridedBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
    const Tensor& infos, bool full_matrices, bool compute_uv) {
#ifndef CUDART_VERSION
  TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend.");
#else
  using value_t = typename c10::scalar_value_type<scalar_t>::type;
  int m = cuda_int_cast(A.size(-2), "m");
  int n = cuda_int_cast(A.size(-1), "n");
  TORCH_INTERNAL_ASSERT(m >= n, "cusolver gesvdaStridedBatched requires m >= n");
  int batchsize = cuda_int_cast(batchCount(A), "batch size");

  int lda = A.stride(-1);
  int ldu = compute_uv ? U.stride(-1) : m;
  int ldv = compute_uv ? V.stride(-1) : n;

  auto A_stride = matrixStride(A);
  auto S_stride = S.size(-1);
  auto rank = S_stride; // number of singular values
  auto U_stride = compute_uv ? matrixStride(U) : ldu * rank;  // The strides for "empty matrices" are needed to satisfy cusolver.
  auto V_stride = compute_uv ? matrixStride(V) : ldv * rank;

  // Need to pass allocated memory to the function, otherwise it fails
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * n) : c10::DataPtr{};
  auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * n) : c10::DataPtr{};

  auto A_data = A.data_ptr<scalar_t>();
  auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
  auto S_data = S.data_ptr<value_t>();
  auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());

  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;

  int lwork = -1;
  at::cuda::solver::gesvdaStridedBatched_buffersize<scalar_t>(
    handle, jobz, rank, m, n, A_data, lda, A_stride, S_data, S_stride, U_data, ldu, U_stride, V_data, ldv, V_stride,
    &lwork, batchsize);
  TORCH_INTERNAL_ASSERT(lwork >= 0, "gesvdaStridedBatched_buffersize failed to get needed buffer size, got lwork = ", lwork);
  auto workspace = allocator.allocate(sizeof(scalar_t)*lwork);

  // The residual Frobenius norm is always returned in double.
  // cuSOLVER remark: if the user is confident on the accuracy of singular values and singular vectors,
  //   for example, certain conditions hold (required singular value is far from zero),
  //   then the performance can be improved by passing a null pointer to h_RnrmF, i.e. no computation of residual norm.
  // Comment: calculation of the Frobenius norm is expensive and doesn't affect the accuracy of the result

  at::cuda::solver::gesvdaStridedBatched<scalar_t>(
    handle, jobz, rank, m, n, A_data, lda, A_stride, S_data, S_stride, U_data, ldu, U_stride, V_data, ldv, V_stride,
    reinterpret_cast<scalar_t*>(workspace.get()),
    lwork, infos.data_ptr<int>(),
    nullptr,  // cuSOLVER h_RnrmF is not calculated: reinterpret_cast<double*>(residual_frobenius_norm.get()),
    batchsize);
#endif
}

// We'll copy A inside svd_cusolver_gesvdaStridedBatched
inline static void svd_cusolver_gesvdaStridedBatched(
    const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
    const Tensor& infos, bool full_matrices, bool compute_uv) {
  // We need to pass a copy of A, as it will be overwritten
  // gesvdaStridedBatched just knows how to handle m >= n, so in the other case we need to transpose A
  const auto not_A_H = A.size(-2) >= A.size(-1);
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdaStridedBatched", [&] {
    apply_svd_cusolver_gesvdaStridedBatched<scalar_t>(
      cloneBatchedColumnMajor(not_A_H ? A : A.mH()),
      not_A_H ? U : V,
      S,
      not_A_H ? V : U,
      infos, full_matrices, compute_uv);
  });
}

// Check convergence of gesvdj/gesvdjBatched/gesvdaStridedBatched results.
// If not converged, return a vector that contains indices of the non-converging batches.
// If the returned vector is empty, all the matrices have converged.
// This function will cause a device-host sync.
std::vector<int64_t> _check_gesvdj_convergence(const Tensor& infos, int64_t non_converging_info) {
  at::Tensor infos_cpu = infos.cpu();
  auto infos_cpu_data = infos_cpu.data_ptr<int>();

  std::vector<int64_t> res;

  for (int64_t i = 0; i < infos.numel(); i++) {
    int info_for_batch_i = infos_cpu_data[i];

    // From the cusolver doc: if info < 0, the i-th function call parameter is wrong,
    // which would mean the PyTorch call into cusolver is wrong.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info_for_batch_i >= 0);

    // In our use case, gesvdj, gesvdjBatched, and gesvdaStridedBatched have the same notations for `info`.
    if (info_for_batch_i == non_converging_info) res.push_back(i);

    // However, it is not the same for gesvd, though we don't use this function to check gesvd convergence either.
    // If it's implemented some day in the future, this needs to be handled carefully.
  }

  return res;
}

// Depending on the number of non-converging batches,
// format the non-converging batches string as either (no leading or trailing whitespaces)
// batches 2, 3, 5  // or
// batches 2, 3, 5, 7, 11 and other 65535 batches
std::string _format_non_converging_batches(const std::vector<int64_t>& batches) {
  std::stringstream ss;
  const int too_long = 5;

  ss << "batches ";
  if (batches.size() <= too_long) {
    for (const auto i : c10::irange(batches.size() - 1)) {
      ss << batches[i] << ", ";
    }
    ss << batches.back();
  } else {
    for (const auto i : c10::irange(too_long)) {
      ss << batches[i] << ", ";
    }
    ss << "and other " << batches.size() - too_long << " batches";
  }

  return ss.str();
}

// This function returns V, not V^H.
void svd_cusolver(const Tensor& A,
                  const bool full_matrices,
                  const bool compute_uv,
                  const std::optional<c10::string_view>& driver,
                  const Tensor& U,
                  const Tensor& S,
                  const Tensor& V,
                  const Tensor& info) {
  // Here U and V are F-contig whenever they are defined (i.e. whenever compute_uv=true)
  const auto m = A.size(-2);
  const auto n = A.size(-1);
  const auto k = std::min(m, n);

  static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";

  // The default heuristic is to use gesvdj driver
#ifdef USE_ROCM
  const auto driver_v = c10::string_view("gesvdj");
#else
  const auto driver_v = driver.value_or("gesvdj");
#endif

  if (driver_v == "gesvd") {
    svd_cusolver_gesvd(A, U, S, V, info, full_matrices, compute_uv);
  } else if (driver_v == "gesvdj") {
    // See the benchmarks in
    // https://github.com/pytorch/pytorch/pull/88502#issuecomment-1303860789
    // The m <= 32 && n <= 32 restrictions come from the limitations of the cusolver backend. See the cusolver docs
    if (m <= 32 && n <= 32) {
      svd_cusolver_gesvdjBatched(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv);
    } else {
      // gesvdj driver may be numerically unstable for large sized matrix
      svd_cusolver_gesvdj(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv);
    }
  } else if (driver_v == "gesvda") {
    // cuSOLVER: gesvdaStridedBatched is preferred for "tall skinny" (m > n) matrices
    // We do a transpose here to make it also work for (m < n) matrices.
    svd_cusolver_gesvdaStridedBatched(A, U, S, V, info, full_matrices, compute_uv);
  } else {
    TORCH_CHECK(false, "torch.linalg.svd: unknown svd driver ", driver_v, " in svd_cusolver computation. ", check_svd_doc);
  }

  // Need convergence check
  if (driver_v != "gesvd") {
    // A device-host sync will be performed.
    // Todo: implement the svd_ex variant to not check result convergence, thus removing the device-host sync
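    // For the Jacobi-based drivers, an info value of min(m, n) + 1 (i.e. k + 1
    // here) signals that the iterations did not converge.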
    const auto svd_non_converging_batches = _check_gesvdj_convergence(info, k + 1);

    if (!svd_non_converging_batches.empty()) {
      TORCH_WARN_ONCE("torch.linalg.svd: During SVD computation with the selected cusolver driver, ",
                      _format_non_converging_batches(svd_non_converging_batches),
                      " failed to converge. ",
                      (driver.has_value()
                        ? "It is recommended to redo this SVD with another driver. "
                        : "A more accurate method will be used to compute the SVD as a fallback. "),
                      check_svd_doc);

      // We'll do the fallback if user doesn't specify a driver and the default heuristic doesn't converge well.
      // However, if user manually chooses a driver, should we just do a warning or a hard crash?
      if (!driver.has_value()) {
        svd_cusolver_gesvd(A, U, S, V, info, full_matrices, compute_uv, false, svd_non_converging_batches);
      }
    }
  }

  // `info` will be checked later at `TORCH_IMPL_FUNC(_linalg_svd_out)` function.
}


// Implementation of Cholesky decomposition using looped cusolverDn<T>potrf or cusolverDnXpotrf (64-bit)
template<typename scalar_t>
inline static void apply_cholesky_cusolver_potrf_looped(const Tensor& self_working_copy, bool upper, const Tensor& infos) {
  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  const auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  const int64_t n = self_working_copy.size(-1);
  const int64_t lda = std::max<int64_t>(1, n);
  const int64_t batch_size = batchCount(self_working_copy);
  const int64_t matrix_stride = matrixStride(self_working_copy);

  scalar_t* self_working_copy_ptr = self_working_copy.data_ptr<scalar_t>();
  int* infos_ptr = infos.data_ptr<int>();

#ifdef USE_CUSOLVER_64_BIT
  size_t worksize_device;
  size_t worksize_host;
  cusolverDnParams_t params;
  cudaDataType datatype = at::cuda::solver::get_cusolver_datatype<scalar_t>();
  TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(&params));
  at::cuda::solver::xpotrf_buffersize(handle, params, uplo, n, datatype, nullptr, lda, datatype, &worksize_device, &worksize_host);

  // allocate workspace storage
  auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
  auto workdata_device = device_allocator.allocate(worksize_device * batch_size);
  void* workdata_device_ptr = workdata_device.get();

  auto& host_allocator = *at::getCPUAllocator();
  auto workdata_host = host_allocator.allocate(worksize_host * batch_size);
  void* workdata_host_ptr = workdata_host.get();

  for (int64_t i = 0; i < batch_size; i++) {
    at::cuda::solver::xpotrf(
      handle, params, uplo, n, datatype,
      self_working_copy_ptr + i * matrix_stride,
      lda, datatype,
      (char*)workdata_device_ptr + i * worksize_device, worksize_device,
      (char*)workdata_host_ptr + i * worksize_host, worksize_host,
      infos_ptr + i
    );
  }

  TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params));
#else // USE_CUSOLVER_64_BIT
  int n_32 = cuda_int_cast(n, "n");
  int lda_32 = cuda_int_cast(lda, "lda");
  int lwork;
  at::cuda::solver::potrf_buffersize<scalar_t>(
    handle, uplo, n_32, nullptr, lda_32, &lwork);

  // allocate workspace storage
  auto& allocator = *at::cuda::getCUDADeviceAllocator();
  auto work_data = allocator.allocate(sizeof(scalar_t)*lwork * batch_size);
  scalar_t* work_data_ptr = static_cast<scalar_t*>(work_data.get());

  for (int64_t i = 0; i < batch_size; i++) {
    at::cuda::solver::potrf<scalar_t>(
      handle, uplo, n_32,
      self_working_copy_ptr + i * matrix_stride,
      lda_32,
      work_data_ptr + i * lwork,
      lwork,
      infos_ptr + i
    );
  }
#endif // USE_CUSOLVER_64_BIT
}

// Implementation of Cholesky decomposition using batched cusolverDn<T>potrfBatched
// Warning: cusolverDn<T>potrfBatched doesn't work quite well when matrix size or batch size is zero.
// If you write your own C++ extension and use this function, make sure you do a zero numel check for the input.
template<typename scalar_t>
inline static void apply_cholesky_cusolver_potrfBatched(const Tensor& self_working_copy, bool upper, const Tensor& infos) {
  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  const auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  const int n = cuda_int_cast(self_working_copy.size(-1), "n");
  const int lda = std::max<int>(1, n);

  const int batch_size = cuda_int_cast(batchCount(self_working_copy), "batch_size");

  // cusolver batched kernels require input be "device array of device pointers"
  Tensor self_working_copy_array = get_device_pointers<scalar_t>(self_working_copy);

  at::cuda::solver::potrfBatched<scalar_t>(
    handle, uplo, n,
    reinterpret_cast<scalar_t**>(self_working_copy_array.data_ptr()),
    lda, infos.data_ptr<int>(), batch_size);
}

void cholesky_helper_cusolver(const Tensor& input, bool upper, const Tensor& info) {
  if (input.numel() == 0) {
    return;
  }

  if (use_cusolver_potrf_batched_ && batchCount(input) > 1) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "cholesky_cusolver", [&] {
      apply_cholesky_cusolver_potrfBatched<scalar_t>(input, upper, info);
    });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "cholesky_cusolver", [&] {
      apply_cholesky_cusolver_potrf_looped<scalar_t>(input, upper, info);
    });
  }
}


template<typename scalar_t>
inline static void apply_cholesky_cusolver_potrs(Tensor& self_working_copy, const Tensor& A_column_major_copy, bool upper, Tensor& infos) {
  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  const auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  const int64_t n = self_working_copy.size(-2);
  const int64_t nrhs = self_working_copy.size(-1);
  const int64_t lda = std::max<int64_t>(1, n);
  const int64_t batch_size = batchCount(self_working_copy);
  const int64_t self_matrix_stride = matrixStride(self_working_copy);
  scalar_t* self_working_copy_ptr = self_working_copy.data_ptr<scalar_t>();

  scalar_t* A_ptr = A_column_major_copy.data_ptr<scalar_t>();
  const int64_t A_matrix_stride = matrixStride(A_column_major_copy);
  const int64_t ldb = std::max<int64_t>(1, A_column_major_copy.size(-1));

  int* infos_ptr = infos.data_ptr<int>();

#ifdef USE_CUSOLVER_64_BIT
  cusolverDnParams_t params;
  cudaDataType datatype = at::cuda::solver::get_cusolver_datatype<scalar_t>();
  TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(&params));

  for (int64_t i = 0; i < batch_size; i++) {
    at::cuda::solver::xpotrs(
      handle, params, uplo, n, nrhs, datatype,
      A_ptr + i * A_matrix_stride,
      lda, datatype,
      self_working_copy_ptr + i * self_matrix_stride,
      ldb,
      infos_ptr
    );
  }

  TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params));
#else // USE_CUSOLVER_64_BIT
  int n_32 = cuda_int_cast(n, "n");
  int nrhs_32 = cuda_int_cast(nrhs, "nrhs");
  int lda_32 = cuda_int_cast(lda, "lda");
  int ldb_32 = cuda_int_cast(ldb, "ldb");

  for (int64_t i = 0; i < batch_size; i++) {
    at::cuda::solver::potrs<scalar_t>(
      handle, uplo, n_32, nrhs_32,
      A_ptr + i * A_matrix_stride,
      lda_32,
      self_working_copy_ptr + i * self_matrix_stride,
      ldb_32,
      infos_ptr
    );
  }
#endif // USE_CUSOLVER_64_BIT
}


// This code path is only dispatched to if MAGMA is not linked in the pytorch build.
// cusolverDn<t>potrsBatched only supports nrhs == 1
template<typename scalar_t>
inline static void apply_cholesky_cusolver_potrsBatched(Tensor& self_working_copy, const Tensor& A_column_major_copy, bool upper, Tensor& infos) {
  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  const auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  const int64_t n = self_working_copy.size(-2);
  const int64_t nrhs = self_working_copy.size(-1);
  const int64_t lda = std::max<int64_t>(1, n);
  const int64_t batch_size = batchCount(self_working_copy);

  const int64_t ldb = std::max<int64_t>(1, A_column_major_copy.size(-1));

  int* infos_ptr = infos.data_ptr<int>();

  auto self_ptr_array = get_device_pointers<scalar_t>(self_working_copy);
  auto A_ptr_array = get_device_pointers<scalar_t>(A_column_major_copy);

  at::cuda::solver::potrsBatched(
    handle, uplo,
    cuda_int_cast(n, "n"),
    cuda_int_cast(nrhs, "nrhs"),
    reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr()),
    cuda_int_cast(lda, "lda"),
    reinterpret_cast<scalar_t**>(self_ptr_array.data_ptr()),
    cuda_int_cast(ldb, "ldb"),
    infos_ptr,
    cuda_int_cast(batch_size, "batch_size")
  );
}

Tensor _cholesky_solve_helper_cuda_cusolver(const Tensor& self, const Tensor& A, bool upper) {
  const int64_t batch_size = batchCount(self);
  at::Tensor infos = at::zeros({1}, self.options().dtype(at::kInt));
  at::Tensor self_working_copy = cloneBatchedColumnMajor(self);
  at::Tensor A_column_major_copy = cloneBatchedColumnMajor(A);

  const int64_t nrhs = self_working_copy.size(-1);

  // cusolverDn<t>potrsBatched only supports nrhs == 1
  if (batch_size > 1 && nrhs == 1) {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "cholesky_cuda_potrs_batched", [&] {
      apply_cholesky_cusolver_potrsBatched<scalar_t>(self_working_copy, A_column_major_copy, upper, infos);
    });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "cholesky_cuda_potrs", [&] {
      apply_cholesky_cusolver_potrs<scalar_t>(self_working_copy, A_column_major_copy, upper, infos);
    });
  }

  // info from potrs and potrsBatched only report if the i-th parameter is wrong, not about the matrix singularity, etc.
  // So we don't need to check it all the time.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.item().toInt() == 0);

  return self_working_copy;
}


void _cholesky_inverse_cusolver_potrs_based(Tensor& result, Tensor& infos, bool upper) {
  at::Tensor input_working_copy = cloneBatchedColumnMajor(result);
  at::Tensor infos_gpu = at::zeros({1}, result.options().dtype(at::kInt));
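  // The inverse is obtained by solving A * X = I: `result` is overwritten with
  // the identity matrix and then used as the right-hand side of potrs, with the
  // Cholesky factor held in `input_working_copy`.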
  result.fill_(0);
  result.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1);
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "cholesky_cuda_potri", [&] {
    apply_cholesky_cusolver_potrs<scalar_t>(result, input_working_copy, upper, infos_gpu);
  });

  // Debug only: info of cusolver potrs only check if the i-th parameter is wrong
  // Function argument `infos` is a CPU tensor, the following copy will cause a device-host sync.
  // infos.copy_(infos_gpu);
}

Tensor& cholesky_inverse_kernel_impl_cusolver(Tensor &result, Tensor& infos, bool upper) {
  _cholesky_inverse_cusolver_potrs_based(result, infos, upper);
  return result;
}


/*
  The geqrf function computes the QR decomposition of a m x n matrix A.

  Args:
  * `A` - [in] Tensor with matrices for QR decomposition,
          [out] Tensor containing R in the upper triangle of A
          and elementary reflectors below the main diagonal of A
  * `tau` - Tensor containing the magnitudes of the elementary reflectors
  * `m` - The number of rows of `input` to consider
  * `n` - The number of columns of `input` to consider (actual sizes of `input` could be larger)

  For further details, please see the cuSOLVER documentation for GEQRF.
*/
template <typename scalar_t>
static void apply_geqrf(const Tensor& A, const Tensor& tau) {
  int64_t m = A.size(-2);
  int64_t n = A.size(-1);
  int64_t lda = std::max<int64_t>(1, m);
  int64_t batch_size = batchCount(A);

  auto A_stride = matrixStride(A);
  auto tau_stride = tau.size(-1);

  auto A_data = A.data_ptr<scalar_t>();
  auto tau_data = tau.data_ptr<scalar_t>();

  auto infos = at::zeros({1}, A.options().dtype(at::kInt));
  auto infos_data = infos.data_ptr<int>();

  // get the optimal work size and allocate workspace tensor
#ifdef USE_CUSOLVER_64_BIT
  size_t worksize_device; // workspaceInBytesOnDevice
  size_t worksize_host; // workspaceInBytesOnHost
  cusolverDnParams_t params = NULL; // use default algorithm (currently it's the only option)
  at::cuda::solver::xgeqrf_bufferSize<scalar_t>(
      at::cuda::getCurrentCUDASolverDnHandle(),
      params,
      m,
      n,
      A_data,
      lda,
      tau_data,
      &worksize_device,
      &worksize_host);
#else
  int lwork;
  int m_32 = cuda_int_cast(m, "m");
  int n_32 = cuda_int_cast(n, "n");
  int lda_32 = cuda_int_cast(lda, "lda");
  at::cuda::solver::geqrf_bufferSize<scalar_t>(
      at::cuda::getCurrentCUDASolverDnHandle(), m_32, n_32, A_data, lda_32, &lwork);
#endif // USE_CUSOLVER_64_BIT

  for (decltype(batch_size) i = 0; i < batch_size; i++) {
    scalar_t* A_working_ptr = &A_data[i * A_stride];
    scalar_t* tau_working_ptr = &tau_data[i * tau_stride];
    auto handle = at::cuda::getCurrentCUDASolverDnHandle();

#ifdef USE_CUSOLVER_64_BIT
    // allocate workspace storage on device and host
    auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
    auto work_device_data = device_allocator.allocate(worksize_device);
    auto& host_allocator = *at::getCPUAllocator();
    auto work_host_data = host_allocator.allocate(worksize_host);
    at::cuda::solver::xgeqrf<scalar_t>(
        handle,
        params,
        m,
        n,
        A_working_ptr,
        lda,
        tau_working_ptr,
        static_cast<scalar_t*>(work_device_data.get()),
        worksize_device,
        static_cast<scalar_t*>(work_host_data.get()),
        worksize_host,
        infos_data);
#else
    // allocate workspace storage on device
    auto& allocator = *at::cuda::getCUDADeviceAllocator();
    auto work_data = allocator.allocate(sizeof(scalar_t) * std::max<int>(1, lwork));
    at::cuda::solver::geqrf<scalar_t>(
        handle,
        m_32,
        n_32,
        A_working_ptr,
        lda_32,
        tau_working_ptr,
        static_cast<scalar_t*>(work_data.get()),
        lwork,
        infos_data);
#endif // USE_CUSOLVER_64_BIT
  }

  // info from geqrf only reports if the i-th parameter is wrong, not about the matrix singularity
  // so we don't need to check it all the time
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.item().toInt() == 0);
}

// This is a type dispatching helper function for 'apply_geqrf'
void geqrf_cusolver(const Tensor& input, const Tensor& tau) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_cuda", [&]{
    apply_geqrf<scalar_t>(input, tau);
  });
}

/*
  The ormqr function multiplies Q with another matrix from a sequence of
  elementary reflectors, such as is produced by the geqrf function.

  Args:
  * `input`     - Tensor with elementary reflectors below the diagonal,
                  encoding the matrix Q.
  * `tau`       - Tensor containing the magnitudes of the elementary
                  reflectors.
  * `other`     - [in] Tensor containing the matrix to be multiplied.
                  [out] result of the matrix multiplication with Q.
  * `left`      - bool, determining whether `other` is left- or right-multiplied with Q.
  * `transpose` - bool, determining whether to transpose (or conjugate transpose) Q before multiplying.

  For further details, please see the cuSOLVER documentation for ORMQR and UNMQR.
*/
template <typename scalar_t>
static void apply_ormqr(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
  auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
  auto trans = transpose ? (input.is_complex() ? CUBLAS_OP_C : CUBLAS_OP_T) : CUBLAS_OP_N;

  auto input_data = input.const_data_ptr<scalar_t>();
  auto tau_data = tau.const_data_ptr<scalar_t>();
  auto other_data = other.data_ptr<scalar_t>();

  auto input_matrix_stride = matrixStride(input);
  auto other_matrix_stride = matrixStride(other);
  auto tau_stride = tau.size(-1);
  auto batch_size = batchCount(input);
  auto m = cuda_int_cast(other.size(-2), "m");
  auto n = cuda_int_cast(other.size(-1), "n");
  auto k = cuda_int_cast(tau.size(-1), "k");
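  // `input` encodes Q, which acts as an m x m matrix when multiplying `other`
  // from the left and as an n x n matrix when multiplying from the right,
  // hence the choice of leading dimension below.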
  auto lda = std::max<int>(1, left ? m : n);
  auto ldc = std::max<int>(1, m);

  // get the optimal work size and allocate workspace tensor
  int lwork;
  at::cuda::solver::ormqr_bufferSize<scalar_t>(
    at::cuda::getCurrentCUDASolverDnHandle(), side, trans, m, n, k, input_data, lda, tau_data, other_data, ldc, &lwork);

  auto info = at::zeros({1}, input.options().dtype(at::kInt));
  auto info_data = info.data_ptr<int>();

  for (auto i = decltype(batch_size){0}; i < batch_size; i++) {
    const scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
    scalar_t* other_working_ptr = &other_data[i * other_matrix_stride];
    const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];
    auto handle = at::cuda::getCurrentCUDASolverDnHandle();

    // allocate workspace storage
    auto& allocator = *at::cuda::getCUDADeviceAllocator();
    auto work_data = allocator.allocate(sizeof(scalar_t)*lwork);

    at::cuda::solver::ormqr<scalar_t>(
      handle, side, trans, m, n, k,
      input_working_ptr,
      lda,
      tau_working_ptr,
      other_working_ptr,
      ldc,
      static_cast<scalar_t*>(work_data.get()),
      lwork,
      info_data
    );

    // info from ormqr only reports if the i-th parameter is wrong
    // so we don't need to check it all the time
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.item().toInt() == 0);
  }
}

// This is a type dispatching helper function for 'apply_ormqr'
void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "ormqr_cuda", [&]{
    apply_ormqr<scalar_t>(input, tau, other, left, transpose);
  });
}
1137 
1138 /*
1139   The orgqr function allows reconstruction of an orthogonal (or unitary) matrix Q,
1140   from a sequence of elementary reflectors, such as produced by the geqrf function.
1141 
1142   Args:
1143   * `self` - Tensor with the directions of the elementary reflectors below the diagonal,
1144               it will be overwritten with the result
1145   * `tau` - Tensor containing the magnitudes of the elementary reflectors
1146 
1147   For further details, please see the cuSOLVER documentation for ORGQR and UNGQR.
1148 */
1149 template <typename scalar_t>
apply_orgqr(Tensor & self,const Tensor & tau)1150 inline static void apply_orgqr(Tensor& self, const Tensor& tau) {
1151   auto self_data = self.data_ptr<scalar_t>();
1152   auto tau_data = tau.const_data_ptr<scalar_t>();
1153   auto self_matrix_stride = matrixStride(self);
1154   auto batchsize = cuda_int_cast(batchCount(self), "batch size");
1155   auto m = cuda_int_cast(self.size(-2), "m");
1156   auto n = cuda_int_cast(self.size(-1), "n");
1157   auto k = cuda_int_cast(tau.size(-1), "k");
1158   auto tau_stride = std::max<int>(1, k);
1159   auto lda = std::max<int>(1, m);
1160 
1161   // LAPACK's requirement
1162   TORCH_INTERNAL_ASSERT(m >= n);
1163   TORCH_INTERNAL_ASSERT(n >= k);
1164 
1165   // cuSOLVER doesn't compute anything for this case, which is wrong
1166   // the result should be a matrix with 1 on the diagonal
1167   if (k == 0) {
1168     self.fill_(0);
1169     self.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1);
1170     return;
1171   }
1172 
1173   // get the optimal work size and allocate workspace tensor
1174   int lwork;
1175   at::cuda::solver::orgqr_buffersize<scalar_t>(
1176     at::cuda::getCurrentCUDASolverDnHandle(), m, n, k, self_data, lda, tau_data, &lwork);
1177 
1178   auto info = at::zeros({1}, self.options().dtype(at::kInt));
1179   auto info_data = info.data_ptr<int>();
1180 
1181   for (auto i = decltype(batchsize){0}; i < batchsize; i++) {
1182     scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
1183     const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];
1184     auto handle = at::cuda::getCurrentCUDASolverDnHandle();
1185 
1186     // allocate workspace storage
1187     auto& allocator = *at::cuda::getCUDADeviceAllocator();
1188     auto work_data = allocator.allocate(sizeof(scalar_t)*lwork);
1189 
1190     at::cuda::solver::orgqr<scalar_t>(
1191       handle, m, n, k,
1192       self_working_ptr,
1193       lda,
1194       tau_working_ptr,
1195       static_cast<scalar_t*>(work_data.get()),
1196       lwork,
1197       info_data
1198     );
1199 
1200     // info from orgqr only reports if the i-th parameter is wrong
1201     // so we don't need to check it all the time
1202     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.item().toInt() == 0);
1203   }
1204 }
1205 
1206 // This is a type dispatching helper function for 'apply_orgqr'
orgqr_helper_cusolver(Tensor & result,const Tensor & tau)1207 Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau) {
1208   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "orgqr_cuda", [&]{
1209     apply_orgqr<scalar_t>(result, tau);
1210   });
1211   return result;
1212 }
1213 
1214 template <typename scalar_t>
apply_syevd(const Tensor & values,const Tensor & vectors,const Tensor & infos,bool upper,bool compute_eigenvectors)1215 static void apply_syevd(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
1216   using value_t = typename c10::scalar_value_type<scalar_t>::type;
1217 
1218   cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
1219   cusolverEigMode_t jobz = compute_eigenvectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
1220 
1221   int64_t n = vectors.size(-1);
1222   int64_t lda = std::max<int64_t>(1, n);
1223   int64_t batch_size = batchCount(vectors);
1224 
1225   auto vectors_stride = matrixStride(vectors);
1226   auto values_stride = values.size(-1);
1227 
1228   auto vectors_data = vectors.data_ptr<scalar_t>();
1229   auto values_data = values.data_ptr<value_t>();
1230   auto infos_data = infos.data_ptr<int>();
1231 
1232   // get the optimal work size and allocate workspace tensor
1233 #ifdef USE_CUSOLVER_64_BIT
1234   size_t worksize_device; // workspaceInBytesOnDevice
1235   size_t worksize_host; // workspaceInBytesOnHost
1236   cusolverDnParams_t params = NULL; // use default algorithm (currently it's the only option)
1237   at::cuda::solver::xsyevd_bufferSize<scalar_t>(
1238       at::cuda::getCurrentCUDASolverDnHandle(),
1239       params,
1240       jobz,
1241       uplo,
1242       n,
1243       vectors_data,
1244       lda,
1245       values_data,
1246       &worksize_device,
1247       &worksize_host);
1248 #else
1249   int lwork;
1250   int n_32 = cuda_int_cast(n, "n");
1251   int lda_32 = cuda_int_cast(lda, "lda");
1252   at::cuda::solver::syevd_bufferSize<scalar_t>(
1253       at::cuda::getCurrentCUDASolverDnHandle(), jobz, uplo, n_32, vectors_data, lda_32, values_data, &lwork);
1254 #endif // USE_CUSOLVER_64_BIT
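  // NOTE: the 64-bit cusolverDnXsyevd API reports two workspace sizes (device and host,
  // in bytes), while the legacy syevd API reports a single device workspace size in
  // elements (lwork). The loop below allocates whichever workspace the selected API needs
  // for every matrix in the batch.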
1255 
1256   for (decltype(batch_size) i = 0; i < batch_size; i++) {
1257     scalar_t* vectors_working_ptr = &vectors_data[i * vectors_stride];
1258     value_t* values_working_ptr = &values_data[i * values_stride];
1259     int* info_working_ptr = &infos_data[i];
1260     auto handle = at::cuda::getCurrentCUDASolverDnHandle();
1261 
1262 #ifdef USE_CUSOLVER_64_BIT
1263     // allocate workspace storage on device and host
1264     auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
1265     auto work_device_data = device_allocator.allocate(worksize_device);
1266     auto& host_allocator = *at::getCPUAllocator();
1267     auto work_host_data = host_allocator.allocate(worksize_host);
1268     at::cuda::solver::xsyevd<scalar_t>(
1269         handle,
1270         params,
1271         jobz,
1272         uplo,
1273         n,
1274         vectors_working_ptr,
1275         lda,
1276         values_working_ptr,
1277         static_cast<scalar_t*>(work_device_data.get()),
1278         worksize_device,
1279         static_cast<scalar_t*>(work_host_data.get()),
1280         worksize_host,
1281         info_working_ptr);
1282 #else
1283     // allocate workspace storage on device
1284     auto& allocator = *at::cuda::getCUDADeviceAllocator();
1285     auto work_data = allocator.allocate(sizeof(scalar_t) * lwork);
1286     at::cuda::solver::syevd<scalar_t>(
1287         handle,
1288         jobz,
1289         uplo,
1290         n_32,
1291         vectors_working_ptr,
1292         lda_32,
1293         values_working_ptr,
1294         static_cast<scalar_t*>(work_data.get()),
1295         lwork,
1296         info_working_ptr);
1297 #endif // USE_CUSOLVER_64_BIT
1298   }
1299 }
1300 
1301 template <typename scalar_t>
1302 static void apply_syevj(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
1303   using value_t = typename c10::scalar_value_type<scalar_t>::type;
1304 
1305   cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
1306   cusolverEigMode_t jobz = compute_eigenvectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
1307 
1308   int n = cuda_int_cast(vectors.size(-1), "n");
1309   int lda = std::max<int>(1, n);
1310   auto batch_size = batchCount(vectors);
1311 
1312   auto vectors_stride = matrixStride(vectors);
1313   auto values_stride = values.size(-1);
1314 
1315   auto vectors_data = vectors.data_ptr<scalar_t>();
1316   auto values_data = values.data_ptr<value_t>();
1317   auto infos_data = infos.data_ptr<int>();
1318 
1319   // syevj_params controls the numerical accuracy of syevj
1320   // by default the tolerance is set to machine accuracy
1321   // and the maximum number of Jacobi iterations (sweeps) is 100 by default
1322   // the cuSOLVER documentation says: "15 sweeps are good enough to converge to machine accuracy"
1323   // LAPACK's Jacobi-based SVD routine (gesvj) uses a similar algorithm and caps the number of sweeps at 30
1324   // Let's use the default values for now
1325   syevjInfo_t syevj_params;
1326   TORCH_CUSOLVER_CHECK(cusolverDnCreateSyevjInfo(&syevj_params));
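  // If the defaults ever need to be overridden, cuSOLVER exposes setters on the params
  // object, e.g. (a sketch, not used here):
  //   TORCH_CUSOLVER_CHECK(cusolverDnXsyevjSetTolerance(syevj_params, 1e-7));
  //   TORCH_CUSOLVER_CHECK(cusolverDnXsyevjSetMaxSweeps(syevj_params, 15));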
1327 
1328   // get the optimal work size and allocate workspace tensor
1329   int lwork;
1330   at::cuda::solver::syevj_bufferSize<scalar_t>(
1331       at::cuda::getCurrentCUDASolverDnHandle(), jobz, uplo, n, vectors_data, lda, values_data, &lwork, syevj_params);
1332 
1333   for (decltype(batch_size) i = 0; i < batch_size; i++) {
1334     scalar_t* vectors_working_ptr = &vectors_data[i * vectors_stride];
1335     value_t* values_working_ptr = &values_data[i * values_stride];
1336     int* info_working_ptr = &infos_data[i];
1337     auto handle = at::cuda::getCurrentCUDASolverDnHandle();
1338 
1339     // allocate workspace storage on device
1340     auto& allocator = *at::cuda::getCUDADeviceAllocator();
1341     auto work_data = allocator.allocate(sizeof(scalar_t) * lwork);
1342     at::cuda::solver::syevj<scalar_t>(
1343         handle,
1344         jobz,
1345         uplo,
1346         n,
1347         vectors_working_ptr,
1348         lda,
1349         values_working_ptr,
1350         static_cast<scalar_t*>(work_data.get()),
1351         lwork,
1352         info_working_ptr,
1353         syevj_params);
1354   }
1355   TORCH_CUSOLVER_CHECK(cusolverDnDestroySyevjInfo(syevj_params));
1356 }
1357 
1358 template <typename scalar_t>
1359 static void apply_syevj_batched(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
1360   using value_t = typename c10::scalar_value_type<scalar_t>::type;
1361 
1362   cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
1363   cusolverEigMode_t jobz = compute_eigenvectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
1364 
1365   int n = cuda_int_cast(vectors.size(-1), "n");
1366   int lda = std::max<int>(1, n);
1367   int batch_size = cuda_int_cast(batchCount(vectors), "batch_size");
1368 
1369   auto vectors_data = vectors.data_ptr<scalar_t>();
1370   auto values_data = values.data_ptr<value_t>();
1371   auto infos_data = infos.data_ptr<int>();
1372 
1373   // syevj_params controls the numerical accuracy of syevj
1374   // by default the tolerance is set to machine accuracy
1375   // and the maximum number of Jacobi iterations (sweeps) is 100 by default
1376   // the cuSOLVER documentation says: "15 sweeps are good enough to converge to machine accuracy"
1377   // LAPACK's Jacobi-based SVD routine (gesvj) uses a similar algorithm and caps the number of sweeps at 30
1378   // Let's use the default values for now
1379   syevjInfo_t syevj_params;
1380   TORCH_CUSOLVER_CHECK(cusolverDnCreateSyevjInfo(&syevj_params));
1381   TORCH_CUSOLVER_CHECK(cusolverDnXsyevjSetSortEig(syevj_params, 1));
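  // sort_eig = 1 explicitly requests that syevjBatched return the eigenvalues of each
  // matrix sorted in ascending order (this option only affects the batched routine),
  // matching the ordering produced by syevd.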
1382 
1383   auto handle = at::cuda::getCurrentCUDASolverDnHandle();
1384 
1385   // get the optimal work size and allocate workspace tensor
1386   int lwork;
1387   at::cuda::solver::syevjBatched_bufferSize<scalar_t>(
1388       handle,
1389       jobz,
1390       uplo,
1391       n,
1392       vectors_data,
1393       lda,
1394       values_data,
1395       &lwork,
1396       syevj_params,
1397       batch_size);
1398 
1399   // allocate workspace storage on device
1400   auto& allocator = *at::cuda::getCUDADeviceAllocator();
1401   auto work_data = allocator.allocate(sizeof(scalar_t) * lwork);
1402   at::cuda::solver::syevjBatched<scalar_t>(
1403       handle,
1404       jobz,
1405       uplo,
1406       n,
1407       vectors_data,
1408       lda,
1409       values_data,
1410       static_cast<scalar_t*>(work_data.get()),
1411       lwork,
1412       infos_data,
1413       syevj_params,
1414       batch_size);
1415   TORCH_CUSOLVER_CHECK(cusolverDnDestroySyevjInfo(syevj_params));
1416 }
1417 
1418 static void linalg_eigh_cusolver_syevd(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
1419   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&] {
1420     apply_syevd<scalar_t>(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
1421   });
1422 }
1423 
1424 static void linalg_eigh_cusolver_syevj(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
1425   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&] {
1426     apply_syevj<scalar_t>(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
1427   });
1428 }
1429 
1430 static void linalg_eigh_cusolver_syevj_batched(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
1431   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&] {
1432     apply_syevj_batched<scalar_t>(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
1433   });
1434 }
1435 
1436 void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
1437   // For ROCm's hipSOLVER, syevj is the fastest option.
1438 #ifdef USE_ROCM
1439   linalg_eigh_cusolver_syevj(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
1440 #else
1441   if (use_cusolver_syevj_batched_ && batchCount(eigenvectors) > 1 && eigenvectors.size(-1) <= 32) {
1442     // Use syevjBatched for batched matrix operations when the matrix size is <= 32
1443     // See https://github.com/pytorch/pytorch/pull/53040#issuecomment-788264724
1444     linalg_eigh_cusolver_syevj_batched(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
1445   } else if (eigenvectors.scalar_type() == at::kFloat && eigenvectors.size(-1) >= 32 && eigenvectors.size(-1) <= 512) {
1446     // syevj is better than syevd for float32 dtype and matrix sizes 32x32 - 512x512
1447     // See https://github.com/pytorch/pytorch/pull/53040#issuecomment-788264724
1448     linalg_eigh_cusolver_syevj(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
1449   } else {
1450     linalg_eigh_cusolver_syevd(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
1451   }
1452 #endif
1453 }
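// For a concrete reading of the heuristic above on the CUDA path: a batch of 64 matrices
// of size 16x16 goes to syevjBatched (assuming the use_cusolver_syevj_batched_ heuristic
// is enabled), a single float32 256x256 matrix goes to syevj, and a float64 1024x1024
// matrix (or any other case) falls back to syevd.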
1454 
1455 // The 'apply_' prefix is used for functions that are templated by dtype and call an API
1456 // routine underneath. Since the cuSOLVER API here has a slightly different structure, we do
1457 // not prepend 'apply_' to this function.
1458 void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
1459   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
1460     self.scalar_type(),
1461     "lu_factor_cusolver",
1462     [&self,
1463      &pivots,
1464      &infos,
1465      get_pivots]() {
1466     const auto m = cuda_int_cast(self.size(-2), "m");
1467     const auto n = cuda_int_cast(self.size(-1), "n");
1468     const auto lda = std::max<int>(1, m);
1469     const auto self_stride = matrixStride(self);
1470     const auto batch_size = batchCount(self);
1471     const auto self_data = self.data_ptr<scalar_t>();
1472     const auto infos_data = infos.data_ptr<int>();
1473 
1474     const auto pivots_data = get_pivots ? pivots.data_ptr<int>() : nullptr;
1475     const auto pivots_stride = get_pivots ? pivots.size(-1) : 0;
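    // When get_pivots is false, a nullptr pivot array is passed to getrf below, which
    // makes cuSOLVER compute the LU factorization without partial pivoting.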
1476 
1477     const auto handle = at::cuda::getCurrentCUDASolverDnHandle();
1478     for (auto batch = decltype(batch_size){0}; batch < batch_size; ++batch) {
1479       at::cuda::solver::getrf<scalar_t>(
1480         handle, m, n,
1481         self_data + batch * self_stride,
1482         lda,
1483         get_pivots ? pivots_data + batch * pivots_stride : nullptr,
1484         infos_data + batch
1485       );
1486     }
1487   });
1488 
1489   // This is necessary because, for non-pivoted LU, cuSOLVER writes NaN into outputs where MAGMA writes 0.
1490   // https://github.com/pytorch/pytorch/issues/53879#issuecomment-830633572
1491   if (!get_pivots) {
1492     // nan_to_num does not work for complex inputs
1493     // https://github.com/pytorch/pytorch/issues/59247
1494     if (self.is_complex()) {
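      // self.eq(self) is false exactly at NaN entries (NaN != NaN), so this replaces
      // every NaN element of the complex factorization with 0 and leaves the rest intact.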
1495       self.copy_(at::where(self.eq(self), self, at::scalar_tensor(0., self.options())));
1496     } else {
1497       at::nan_to_num_(const_cast<Tensor&>(self), 0, std::numeric_limits<double>::infinity(),
1498         -std::numeric_limits<double>::infinity());
1499     }
1500   }
1501 }
1502 
1503 void lu_solve_looped_cusolver(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
1504   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_solve_cusolver", [&] {
1505     const auto trans = to_cublas(transpose);
1506     int n = cuda_int_cast(LU.size(-2), "n");
1507     int nrhs = cuda_int_cast(B.size(-1), "nrhs");
1508     auto batch_size = batchCount(B);
1509     auto info = at::zeros({1}, LU.options().dtype(kInt));
1510     auto info_data = info.data_ptr<int>();
1511     auto b_data = B.data_ptr<scalar_t>();
1512     auto lu_data = LU.data_ptr<scalar_t>();
1513     auto pivots_data = pivots.data_ptr<int>();
1514     auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
1515     auto lu_stride = LU.dim() > 2 ? LU.stride(-3) : 0;
1516     auto b_stride = matrixStride(B);
1517     int leading_dimension = cuda_int_cast(std::max<int>(1, n), "leading_dimension");
1518 
1519     // lu and pivots tensors can be broadcast to b
1520     // here we construct a helper indexing tensor to linearly index into lu and pivots
1521     IntArrayRef lu_batch_shape(LU.sizes().data(), LU.dim() - 2);
1522     IntArrayRef b_batch_shape(B.sizes().data(), B.dim() - 2);
1523     BroadcastLinearIndices lu_index(
1524         batchCount(LU), lu_batch_shape, b_batch_shape);
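    // Example of the broadcast indexing (a sketch): if LU has batch shape (1,) and B has
    // batch shape (4,), lu_index maps batches 0..3 of B all to LU batch 0, so every
    // right-hand-side slice is solved against the single LU factor.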
1525 
1526     auto handle = at::cuda::getCurrentCUDASolverDnHandle();
1527     for (auto batch = decltype(batch_size){0}; batch < batch_size; ++batch) {
1528       int64_t lu_index_i = lu_index(batch);
1529       at::cuda::solver::getrs<scalar_t>(
1530         handle,
1531         n,
1532         nrhs,
1533         lu_data + lu_index_i * lu_stride,
1534         leading_dimension,
1535         pivots_data + lu_index_i * pivots_stride,
1536         b_data + batch * b_stride,
1537         leading_dimension,
1538         info_data,
1539         trans);
1540 
1541       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.item().toInt() == 0);
1542     }
1543   });
1544 }
1545 
1546 #endif  // USE_LINALG_SOLVER
1547 
1548 } // namespace at::native
1549