1 // See Note [BatchLinearAlgebraLib split implementation files]
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 #include <ATen/Context.h>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <ATen/Dispatch.h>
6 #include <ATen/ExpandUtils.h>
7 #include <ATen/cuda/PinnedMemoryAllocator.h>
8 #include <ATen/cuda/CUDABlas.h>
9 #include <ATen/cuda/CUDAEvent.h>
10 #include <c10/cuda/CUDAStream.h>
11 #include <c10/util/irange.h>
12
13 #include <ATen/native/LinearAlgebraUtils.h>
14 #include <ATen/native/TransposeType.h>
15 #include <ATen/native/cuda/MiscUtils.h>
16 #include <ATen/native/cuda/linalg/CUDASolver.h>
17 #include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>
18
19 #ifndef AT_PER_OPERATOR_HEADERS
20 #include <ATen/Functions.h>
21 #else
22 #include <ATen/ops/arange.h>
23 #include <ATen/ops/empty.h>
24 #include <ATen/ops/nan_to_num.h>
25 #include <ATen/ops/ones.h>
26 #include <ATen/ops/scalar_tensor.h>
27 #include <ATen/ops/where.h>
28 #include <ATen/ops/zeros.h>
29 #endif
30
31 namespace at::native {
32
static cublasOperation_t to_cublas(TransposeType trans) {
34 switch (trans) {
35 case TransposeType::NoTranspose: return CUBLAS_OP_N;
36 case TransposeType::Transpose: return CUBLAS_OP_T;
37 case TransposeType::ConjTranspose: return CUBLAS_OP_C;
38 }
39 TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
40 }
41
// Some cuBLAS and cuSOLVER batched routines require the input to be a device array of pointers,
// where each pointer refers to an individual matrix that also resides on the device.
// 'input' must be a contiguous tensor
44 template <typename scalar_t>
static Tensor get_device_pointers(const Tensor& input) {
46 auto input_data = input.const_data_ptr<scalar_t>();
47 int64_t input_mat_stride = matrixStride(input);
48
49 // cublas/cusolver interface requires 'int'
50 int batch_size = cuda_int_cast(batchCount(input), "batch_size");
51
  // if batch_size == 0, then start == end and an empty tensor is returned
  // if input_mat_stride == 0, then step == sizeof(scalar_t) (arange requires a non-zero step)
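  // Illustrative example: for a contiguous batch of shape (b, n, n) with base pointer p, the
  // arange below encodes the device addresses [p, p + n*n, p + 2*n*n, ...] as int64 values,
  // which batched cuBLAS/cuSOLVER routines then reinterpret as a scalar_t** device array.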
54 return at::arange(
55 /*start=*/reinterpret_cast<int64_t>(input_data),
56 /*end=*/reinterpret_cast<int64_t>(input_data + batch_size * input_mat_stride),
57 /*step=*/static_cast<int64_t>(std::max<int64_t>(input_mat_stride, 1) * sizeof(scalar_t)),
58 input.options().dtype(at::kLong));
59 }
60
61 namespace {
62
63 template <typename scalar_t>
void apply_ldl_factor_cusolver(
65 const Tensor& A,
66 const Tensor& pivots,
67 const Tensor& info,
68 bool upper) {
69 #if !defined(USE_LINALG_SOLVER)
70 TORCH_CHECK(
71 false,
72 "Calling torch.linalg.ldl_factor on a CUDA tensor requires compiling ",
73 "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support.");
74 #else
75 auto batch_size = batchCount(A);
76 auto n = cuda_int_cast(A.size(-2), "A.size(-2)");
77 auto lda = cuda_int_cast(A.stride(-1), "A.stride(-1)");
78 auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
79
80 auto a_stride = A.dim() > 2 ? A.stride(-3) : 0;
81 auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
82
83 auto a_data = A.data_ptr<scalar_t>();
84 auto pivots_data = pivots.data_ptr<int>();
85 auto info_data = info.data_ptr<int>();
86
87 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
88
89 int lwork = 0;
90 at::cuda::solver::sytrf_bufferSize(handle, n, a_data, lda, &lwork);
91 auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
92 auto work = allocator.allocate(sizeof(scalar_t) * lwork);
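  // Typical cuSOLVER workflow: query the required workspace size once up front, allocate it from
  // the caching allocator, and reuse the same buffer for every matrix in the loop below
  // (all batch entries share n and lda, so the required size is identical).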
93
94 for (const auto i : c10::irange(batch_size)) {
95 auto* a_working_ptr = &a_data[i * a_stride];
96 auto* pivots_working_ptr = &pivots_data[i * pivots_stride];
97 auto* info_working_ptr = &info_data[i];
98 at::cuda::solver::sytrf(
99 handle,
100 uplo,
101 n,
102 a_working_ptr,
103 lda,
104 pivots_working_ptr,
105 reinterpret_cast<scalar_t*>(work.get()),
106 lwork,
107 info_working_ptr);
108 }
109 #endif
110 }
111
112 template <typename scalar_t>
void apply_ldl_solve_cusolver(
114 const Tensor& A,
115 const Tensor& pivots,
116 const Tensor& B,
117 bool upper) {
118 #if !(defined(CUDART_VERSION) && defined(CUSOLVER_VERSION) && \
119 CUSOLVER_VERSION >= 11102)
120 TORCH_CHECK(
121 false,
122 "Calling torch.linalg.ldl_solve on a CUDA tensor requires compiling ",
123 "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER 11.1.2+ (CUDA 11.3.1+) support.");
124 #else
125 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(A) > 0);
126 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(pivots.unsqueeze(-1)) > 0);
127 auto batch_size = batchCount(B);
128 auto n = A.size(-2);
129 auto nrhs = B.size(-1);
130 auto lda = A.stride(-1);
131 auto ldb = B.stride(-1);
132 auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
133
134 auto a_stride = A.dim() > 2 ? A.stride(-3) : 0;
135 auto b_stride = B.dim() > 2 ? B.stride(-3) : 0;
136 auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
137
138 auto a_data = A.const_data_ptr<scalar_t>();
139 auto b_data = B.data_ptr<scalar_t>();
140
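  // The 64-bit cusolverDnXsytrs API expects int64_t pivot indices, while the pivots tensor holds
  // 32-bit ints (as produced by sytrf above), hence the temporary kLong copy.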
141 auto pivots_ = pivots.to(kLong);
142 auto pivots_data = pivots_.const_data_ptr<int64_t>();
143
144 // needed to run ldl_solve tests in parallel
145 // see https://github.com/pytorch/pytorch/issues/82894 for examples of failures
146 c10::cuda::device_synchronize();
147 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
148 auto datatype = at::cuda::solver::get_cusolver_datatype<scalar_t>();
149 size_t worksize_device = 0;
150 size_t worksize_host = 0;
151
152 TORCH_CUSOLVER_CHECK(cusolverDnXsytrs_bufferSize(
153 handle,
154 uplo,
155 n,
156 nrhs,
157 datatype,
158 a_data,
159 lda,
160 pivots_data,
161 datatype,
162 b_data,
163 ldb,
164 &worksize_device,
165 &worksize_host));
166
167 // allocate workspace storage
168 auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
169 auto workdata_device = device_allocator.allocate(worksize_device);
170 void* workdata_device_ptr = workdata_device.get();
171
172 auto& host_allocator = *at::getCPUAllocator();
173 auto workdata_host = host_allocator.allocate(worksize_host);
174 void* workdata_host_ptr = workdata_host.get();
175
176 Tensor info = at::zeros({}, A.options().dtype(at::kInt));
177 for (const auto i : c10::irange(batch_size)) {
178 const auto* a_working_ptr = &a_data[i * a_stride];
179 auto* b_working_ptr = &b_data[i * b_stride];
180 const auto* pivots_working_ptr = &pivots_data[i * pivots_stride];
181 TORCH_CUSOLVER_CHECK(cusolverDnXsytrs(
182 handle,
183 uplo,
184 n,
185 nrhs,
186 datatype,
187 a_working_ptr,
188 lda,
189 pivots_working_ptr,
190 datatype,
191 b_working_ptr,
192 ldb,
193 workdata_device_ptr,
194 worksize_device,
195 workdata_host_ptr,
196 worksize_host,
197 info.data_ptr<int>()));
198 }
199
200 // info from sytrs only reports if the i-th parameter is wrong
201 // so we don't need to check it all the time
202 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.item().toInt() == 0);
203 #endif
204 }
205
206 } // anonymous namespace
207
void ldl_factor_cusolver(
209 const Tensor& LD,
210 const Tensor& pivots,
211 const Tensor& info,
212 bool upper,
213 bool hermitian) {
214 if (LD.is_complex()) {
215 TORCH_CHECK(
216 !hermitian,
217 "torch.linalg.ldl_factor: complex tensors with hermitian=True flag are not supported with cuSOLVER backend. ",
218 "Currently preferred backend is ",
219 at::globalContext().linalgPreferredBackend(),
220 ", please set 'default' or 'magma' backend with torch.backends.cuda.preferred_linalg_library");
221 }
222 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
223 LD.scalar_type(), "ldl_factor_looped_cusolver", [&] {
224 apply_ldl_factor_cusolver<scalar_t>(LD, pivots, info, upper);
225 });
226 }
227
void ldl_solve_cusolver(
229 const Tensor& LD,
230 const Tensor& pivots,
231 const Tensor& B,
232 bool upper) {
233 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
234 LD.scalar_type(), "ldl_solve_looped_cusolver", [&] {
235 apply_ldl_solve_cusolver<scalar_t>(LD, pivots, B, upper);
236 });
237 }
238
239 #if defined(USE_LINALG_SOLVER)
240
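// Returns a batch of identity matrices with the batch shape, dtype and device of `self`, laid out
// so that each matrix is column-major: diag_embed() produces row-major matrices and mT() flips
// them to the Fortran order that LAPACK-style routines expect.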
inline static Tensor column_major_identity_matrix_like(const Tensor& self) {
242 auto size = self.sizes();
243 auto size_slice = IntArrayRef(size.data(), size.size()-1);
244 return at::ones(size_slice, self.options()).diag_embed().mT();
245 }
246
247
248 // call cusolver gesvd function to calculate svd
249 template<typename scalar_t>
inline static void apply_svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
251 const Tensor& infos, bool full_matrices, bool compute_uv,
252 const bool calculate_all_batches,
253 const std::vector<int64_t>& batches
254 ) {
255 using value_t = typename c10::scalar_value_type<scalar_t>::type;
256 auto A_data = A.data_ptr<scalar_t>();
257 auto S_data = S.data_ptr<value_t>();
258 auto A_stride = matrixStride(A);
259 auto S_stride = S.size(-1);
260
261 int m = cuda_int_cast(A.size(-2), "m");
262 int n = cuda_int_cast(A.size(-1), "n");
263 auto k = std::min(m, n);
264 int lda = std::max<int>(1, m);
265 int ldvh = std::max<int>(1, n);
266
267 TORCH_INTERNAL_ASSERT(m >= n, "cusolver gesvd only supports matrix with sizes m >= n");
268
269 char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
270 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
271
272 int lwork = -1;
273 at::cuda::solver::gesvd_buffersize<scalar_t>(handle, m, n, &lwork);
274 TORCH_INTERNAL_ASSERT(lwork >= 0, "gesvd_buffersize failed to get needed buffer size, got lwork = ", lwork);
275
276 auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
277 const auto dataPtr_work = allocator.allocate(sizeof(scalar_t)*lwork);
278 const auto dataPtr_rwork = allocator.allocate(sizeof(value_t)*std::min(m, n));
279
280 // nb. We can do this .view() because V is a batch of F-contig matrices
281 const auto V_view = compute_uv ? V.view({-1, n, V.size(-1)})
282 : Tensor{};
283 // V is F-contig. Since this function computes Vh, we need an auxiliary F-conj-transposed matrix to hold Vh
284 const auto Vh_workspace = compute_uv ? at::empty({n, full_matrices ? n : k},
285 A.options().memory_format(at::MemoryFormat::Contiguous)).conj()
286 : Tensor{};
287 const auto Vh_ptr = compute_uv ? Vh_workspace.data_ptr<scalar_t>()
288 : nullptr;
289
290 const auto U_stride = compute_uv ? matrixStride(U) : 0;
291 const auto U_ptr = compute_uv ? U.data_ptr<scalar_t>() : nullptr;
292
293 int batchsize = calculate_all_batches ? cuda_int_cast(batchCount(A), "batch size")
294 : batches.size();
295
296
297 for(int _i = 0; _i < batchsize; _i++){
298 int i = calculate_all_batches ? _i : batches[_i];
299
300 at::cuda::solver::gesvd<scalar_t>(
301 handle, job, job, m, n,
302 A_data + i * A_stride,
303 lda,
304 S_data + i * S_stride,
305 compute_uv ? U_ptr + i * U_stride : nullptr,
306 lda,
307 compute_uv ? Vh_ptr : nullptr,
308 ldvh,
309 reinterpret_cast<scalar_t*>(dataPtr_work.get()),
310 lwork,
311 reinterpret_cast<value_t*>(dataPtr_rwork.get()),
312 infos.data_ptr<int>() + i
313 );
314
315 if (compute_uv) {
316 V_view[i].copy_(Vh_workspace);
317 }
318 }
319 }
320
321 // We'll copy A inside svd_cusolver_gesvd
inline static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
323 const Tensor& infos, bool full_matrices, bool compute_uv,
324 const bool calculate_all_batches = true,
325 const std::vector<int64_t>& batches = {}
326 ) {
327 // We need to pass a copy of A, as it will be overwritten
328 // gesvd just knows how to handle m >= n, so in the other case we need to transpose A
329 const auto not_A_H = A.size(-2) >= A.size(-1);
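  // If m < n we factor A^H instead: writing A^H = U1 S V1^H gives A = V1 S U1^H, so in the
  // dispatch below U and V simply swap roles (and a conjugate-transposed copy of A is passed).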
330 Tensor Vcopy = V; // Shallow copy
331 #ifdef USE_ROCM
332 // Similar to the case in svd_magma(), experiments have shown Vh tensor is
333 // not guaranteed to be column major on ROCM, we have to create a copy to
334 // deal with this
335 if (!not_A_H) {
336 Vcopy = at::empty_like(V.mT(),
337 V.options()
338 .device(V.device())
339 .memory_format(at::MemoryFormat::Contiguous)).mT();
340 }
341 #endif
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvd", [&] {
343 apply_svd_cusolver_gesvd<scalar_t>(cloneBatchedColumnMajor(not_A_H ? A : A.mH()),
344 not_A_H ? U : Vcopy,
345 S,
346 not_A_H ? Vcopy : U,
347 infos,
348 full_matrices, compute_uv, calculate_all_batches, batches);
349 });
350 #ifdef USE_ROCM
351 if (!not_A_H) {
352 V.copy_(Vcopy);
353 }
354 #endif
355 }
356
357 // call cusolver gesvdj function to calculate svd
358 template<typename scalar_t>
inline static void apply_svd_cusolver_gesvdj(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
360 const Tensor& infos, bool full_matrices, bool compute_uv) {
361 using value_t = typename c10::scalar_value_type<scalar_t>::type;
362 int m = cuda_int_cast(A.size(-2), "m");
363 int n = cuda_int_cast(A.size(-1), "n");
364 int k = std::min(m, n);
365
366 // Need to pass allocated memory to the function, otherwise it fails
367 auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
368 auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t)* m * k) : c10::DataPtr{};
369 auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t)* n * k) : c10::DataPtr{};
370
371 auto A_data = A.data_ptr<scalar_t>();
372 auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
373 auto S_data = S.data_ptr<value_t>();
374 auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());
375 auto A_stride = matrixStride(A);
376 auto U_stride = compute_uv ? matrixStride(U) : 0;
377 auto S_stride = S.size(-1);
378 auto V_stride = compute_uv ? matrixStride(V) : 0;
379
380 int batchsize = cuda_int_cast(batchCount(A), "batch size");
381 int lda = A.stride(-1);
382 int ldu = compute_uv ? U.stride(-1) : m;
383 int ldv = compute_uv ? V.stride(-1) : n;
384
385 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
386 auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
387 int econ = full_matrices ? 0 : 1;
388
389 // gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU
390 gesvdjInfo_t gesvdj_params;
391 TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&gesvdj_params));
392
393 // Todo: expose the following two parameters to users
394 TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(gesvdj_params, std::numeric_limits<scalar_t>::epsilon()));
395 TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, cusolver_gesvdj_max_sweeps));
396
397 int lwork = -1;
398 at::cuda::solver::gesvdj_buffersize<scalar_t>(
399 handle, jobz, econ, m, n, A_data, lda, S_data, U_data, ldu, V_data, ldv, &lwork, gesvdj_params);
400 TORCH_INTERNAL_ASSERT(lwork >= 0, "gesvdj_buffersize failed to get needed buffer size, got lwork = ", lwork);
401
402 auto dataPtr = allocator.allocate(sizeof(scalar_t)*lwork);
403
404 for(int i = 0; i < batchsize; i++){
405 at::cuda::solver::gesvdj<scalar_t>(
406 handle, jobz, econ, m, n,
407 A_data + i * A_stride,
408 lda,
409 S_data + i * S_stride,
410 U_data + i * U_stride,
411 ldu,
412 V_data + i * V_stride,
413 ldv,
414 reinterpret_cast<scalar_t*>(dataPtr.get()),
415 lwork,
416 infos.data_ptr<int>() + i,
417 gesvdj_params
418 );
419
420 // // The following code can be used to check or report the gesvdj residual.
421 // // Note: this will introduce a device-host sync and may negatively affect the performance
422 // double residual = 0;
423 // TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjGetResidual(handle, gesvdj_params, &residual));
424 // printf("gesvdj residual = %.6e\n", residual);
425 }
426
427 TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
428 }
429
430 // wrapper around apply_svd_cusolver_gesvdj that handles dtype dispatch
431 // note that gesvdj returns V, which is what we want
// Need to pass a copy of A, since A will be overwritten inside the function call
inline static void svd_cusolver_gesvdj(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool full_matrices, bool compute_uv) {
434 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdj", [&] {
435 apply_svd_cusolver_gesvdj<scalar_t>(A, U, S, V, infos, full_matrices, compute_uv);
436 });
437 }
438
439 // call cusolver gesvdj batched function to calculate svd
440 template<typename scalar_t>
inline static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
442 const Tensor& infos, bool compute_uv
443 ) {
444 using value_t = typename c10::scalar_value_type<scalar_t>::type;
445 int m = cuda_int_cast(A.size(-2), "m");
446 int n = cuda_int_cast(A.size(-1), "n");
447 int batchsize = cuda_int_cast(batchCount(A), "batch size");
448 int lda = A.stride(-1);
449 int ldu = compute_uv ? U.stride(-1) : m;
450 int ldv = compute_uv ? V.stride(-1) : n;
451
452 // Need to pass allocated memory to the function, otherwise it fails
453 auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
454 auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * ldu) : c10::DataPtr{};
455 auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * ldv) : c10::DataPtr{};
456
457 auto A_data = A.data_ptr<scalar_t>();
458 auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
459 auto S_data = S.data_ptr<value_t>();
460 auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());
461
462 TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got "
463 "m = ", m, " n = ", n);
464
465 // gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU
466 gesvdjInfo_t gesvdj_params;
467 TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&gesvdj_params));
468
469 // Todo: expose the following two parameters to users
470 TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(gesvdj_params, std::numeric_limits<scalar_t>::epsilon()));
471 TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, cusolver_gesvdj_max_sweeps));
472 TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetSortEig(gesvdj_params, 1));
473
474 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
475 auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
476 at::cuda::solver::gesvdjBatched<scalar_t>(
477 handle, jobz, m, n, A_data, lda, S_data, U_data, ldu, V_data, ldv,
478 infos.data_ptr<int>(), gesvdj_params, batchsize
479 );
480
481 TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params));
482 }
483
inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool full_matrices, bool compute_uv) {
485 auto m = A.size(-2);
486 auto n = A.size(-1);
487 auto k = std::min(m, n);
488 // The kernel assumes full_matrices == true
489 // If full_matrices == false and m != n, we create auxiliary tensors of the right size and copy the results back
490 auto U_ = U;
491 auto V_ = V;
492 if (compute_uv && !full_matrices) {
493 auto sizes = A.sizes().vec();
494 if (m > n) {
495 // Size of U with full_matrices == True
496 sizes.end()[-1] = m;
497 // U, V should be a batch of Fortran contiguous arrays
498 U_ = U.new_empty(sizes).mT();
499 } else if (m < n) {
500 // Size of V with full_matrices == True
501 sizes.end()[-2] = n;
502 V_ = V.new_empty(sizes).mT();
503 }
504 }
505 // Here U_ and V_ are batches of F-contig square matrices
506
507 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdjBatched", [&] {
508 apply_svd_cusolver_gesvdjBatched<scalar_t>(A, U_, S, V_, infos, compute_uv);
509 });
510
511 // Copy the result back if we created any new matrix
512 if (compute_uv && !full_matrices) {
513 if (!U_.is_alias_of(U)) {
514 U.copy_(U_.narrow(-1, 0, k));
515 }
516 if (!V_.is_alias_of(V)) {
517 V.copy_(V_.narrow(-1, 0, k));
518 }
519 }
520 }
521
522 template<typename scalar_t>
inline static void apply_svd_cusolver_gesvdaStridedBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
524 const Tensor& infos, bool full_matrices, bool compute_uv) {
525 #ifndef CUDART_VERSION
526 TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend.")
527 #else
528 using value_t = typename c10::scalar_value_type<scalar_t>::type;
529 int m = cuda_int_cast(A.size(-2), "m");
530 int n = cuda_int_cast(A.size(-1), "n");
531 TORCH_INTERNAL_ASSERT(m >= n, "cusolver gesvdaStridedBatched requires m >= n");
532 int batchsize = cuda_int_cast(batchCount(A), "batch size");
533
534 int lda = A.stride(-1);
535 int ldu = compute_uv ? U.stride(-1) : m;
536 int ldv = compute_uv ? V.stride(-1) : n;
537
538 auto A_stride = matrixStride(A);
539 auto S_stride = S.size(-1);
540 auto rank = S_stride; // number of singular values
541 auto U_stride = compute_uv ? matrixStride(U) : ldu * rank; // The strides for "empty matrices" are needed to satisfy cusolver.
542 auto V_stride = compute_uv ? matrixStride(V) : ldv * rank;
543
544 // Need to pass allocated memory to the function, otherwise it fails
545 auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
546 auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * n) : c10::DataPtr{};
547 auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * n) : c10::DataPtr{};
548
549 auto A_data = A.data_ptr<scalar_t>();
550 auto U_data = compute_uv ? U.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_U.get());
551 auto S_data = S.data_ptr<value_t>();
552 auto V_data = compute_uv ? V.data_ptr<scalar_t>() : reinterpret_cast<scalar_t*>(dataPtr_V.get());
553
554 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
555 auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
556
557 int lwork = -1;
558 at::cuda::solver::gesvdaStridedBatched_buffersize<scalar_t>(
559 handle, jobz, rank, m, n, A_data, lda, A_stride, S_data, S_stride, U_data, ldu, U_stride, V_data, ldv, V_stride,
560 &lwork, batchsize);
561 TORCH_INTERNAL_ASSERT(lwork >= 0, "gesvdaStridedBatched_buffersize failed to get needed buffer size, got lwork = ", lwork);
562 auto workspace = allocator.allocate(sizeof(scalar_t)*lwork);
563
564 // The residual Frobenius norm is always returned in double.
565 // cuSOLVER remark: if the user is confident on the accuracy of singular values and singular vectors,
566 // for example, certain conditions hold (required singular value is far from zero),
567 // then the performance can be improved by passing a null pointer to h_RnrmF, i.e. no computation of residual norm.
568 // Comment: calculation of Frobenius norm is expensive and doesn't affect accuracy of the result
569
570 at::cuda::solver::gesvdaStridedBatched<scalar_t>(
571 handle, jobz, rank, m, n, A_data, lda, A_stride, S_data, S_stride, U_data, ldu, U_stride, V_data, ldv, V_stride,
572 reinterpret_cast<scalar_t*>(workspace.get()),
573 lwork, infos.data_ptr<int>(),
574 nullptr, // cuSOLVER h_RnrmF is not calculated: reinterpret_cast<double*>(residual_frobenius_norm.get()),
575 batchsize);
576 #endif
577 }
578
579 // We'll copy A inside svd_cusolver_gesvdaStridedBatched
inline static void svd_cusolver_gesvdaStridedBatched(
581 const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
582 const Tensor& infos, bool full_matrices, bool compute_uv) {
583 // We need to pass a copy of A, as it will be overwritten
584 // gesvdaStridedBatched just knows how to handle m >= n, so in the other case we need to transpose A
585 const auto not_A_H = A.size(-2) >= A.size(-1);
586 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdaStridedBatched", [&] {
587 apply_svd_cusolver_gesvdaStridedBatched<scalar_t>(
588 cloneBatchedColumnMajor(not_A_H ? A : A.mH()),
589 not_A_H ? U : V,
590 S,
591 not_A_H ? V : U,
592 infos, full_matrices, compute_uv);
593 });
594 }
595
596 // Check convergence of gesvdj/gesvdjBatched/gesvdaStridedBatched results.
597 // If not converged, return a vector that contains indices of the non-converging batches.
598 // If the returned vector is empty, all the matrices are converged.
599 // This function will cause a device-host sync.
std::vector<int64_t> _check_gesvdj_convergence(const Tensor& infos, int64_t non_converging_info) {
601 at::Tensor infos_cpu = infos.cpu();
602 auto infos_cpu_data = infos_cpu.data_ptr<int>();
603
604 std::vector<int64_t> res;
605
606 for(int64_t i = 0; i < infos.numel(); i++) {
607 int info_for_batch_i = infos_cpu_data[i];
608
    // Per the cuSOLVER docs, info < 0 means the i-th function-call parameter was invalid,
    // which would indicate a bug in PyTorch's cuSOLVER bindings rather than a user error.
611 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info_for_batch_i >= 0);
612
613 // In our use case, gesvdj, gesvdjBatched, and gesvdaStridedBatched have the same notations for `info`.
614 if (info_for_batch_i == non_converging_info) res.push_back(i);
615
616 // However, it is not the same for gesvd, though we don't use this function to check gesvd convergence either.
617 // If it's implemented some day in the future, this needs to be handled carefully.
618 }
619
620 return res;
621 }
622
623 // Depending on the number of non-converging batches,
624 // format the non-converging batches string as either (no leading or trailing whitespaces)
625 // batches 2, 3, 5 // or
626 // batches 2, 3, 5, 7, 11 and other 65535 batches
std::string _format_non_converging_batches(const std::vector<int64_t>& batches) {
628 std::stringstream ss;
629 const int too_long = 5;
630
631 ss << "batches ";
632 if (batches.size() <= too_long) {
633 for (const auto i : c10::irange(batches.size() - 1)) {
634 ss << batches[i] << ", ";
635 }
636 ss << batches.back();
637 } else {
638 for (const auto i : c10::irange(too_long)) {
639 ss << batches[i] << ", ";
640 }
641 ss << "and other " << batches.size() - too_long << " batches";
642 }
643
644 return ss.str();
645 }
646
647 // This function returns V, not V^H.
void svd_cusolver(const Tensor& A,
649 const bool full_matrices,
650 const bool compute_uv,
651 const std::optional<c10::string_view>& driver,
652 const Tensor& U,
653 const Tensor& S,
654 const Tensor& V,
655 const Tensor& info) {
656 // Here U and V are F-contig whenever they are defined (i.e. whenever compute_uv=true)
657 const auto m = A.size(-2);
658 const auto n = A.size(-1);
659 const auto k = std::min(m, n);
660
661 static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
662
663 // The default heuristic is to use gesvdj driver
664 #ifdef USE_ROCM
665 const auto driver_v = c10::string_view("gesvdj");
666 #else
667 const auto driver_v = driver.value_or("gesvdj");
668 #endif
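  // The driver string comes from the user-facing `driver=` argument of torch.linalg.svd
  // (e.g. torch.linalg.svd(A, driver="gesvdj") on a CUDA tensor ends up here).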
669
670 if (driver_v == "gesvd") {
671 svd_cusolver_gesvd(A, U, S, V, info, full_matrices, compute_uv);
672 } else if (driver_v == "gesvdj") {
673 // See the benchmarks in
674 // https://github.com/pytorch/pytorch/pull/88502#issuecomment-1303860789
675 // The m <= 32 && n <= 32 restrictions come from the limitations of the cusolver backend. See the cusolver docs
676 if (m <= 32 && n <= 32) {
677 svd_cusolver_gesvdjBatched(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv);
678 } else {
679 // gesvdj driver may be numerically unstable for large sized matrix
680 svd_cusolver_gesvdj(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv);
681 }
682 } else if (driver_v == "gesvda") {
683 // cuSOLVER: gesvdaStridedBatched is preferred for "tall skinny" (m > n) matrices
684 // We do a transpose here to make it also work for (m < n) matrices.
685 svd_cusolver_gesvdaStridedBatched(A, U, S, V, info, full_matrices, compute_uv);
686 } else {
687 TORCH_CHECK(false, "torch.linalg.svd: unknown svd driver ", driver_v, " in svd_cusolver computation. ", check_svd_doc);
688 }
689
690 // Need convergence check
691 if (driver_v != "gesvd") {
692 // A device-host sync will be performed.
693 // Todo: implement the svd_ex variant to not check result convergence, thus removing the device-host sync
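    // For gesvdj/gesvdjBatched/gesvdaStridedBatched, info == min(m, n) + 1 indicates that the
    // iteration did not reach the requested tolerance within the maximum number of sweeps,
    // hence k + 1 is passed as the "non-converging" marker.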
694 const auto svd_non_converging_batches = _check_gesvdj_convergence(info, k + 1);
695
696 if (!svd_non_converging_batches.empty()) {
697 TORCH_WARN_ONCE("torch.linalg.svd: During SVD computation with the selected cusolver driver, ",
698 _format_non_converging_batches(svd_non_converging_batches),
699 " failed to converge. ",
700 (driver.has_value()
701 ? "It is recommended to redo this SVD with another driver. "
702 : "A more accurate method will be used to compute the SVD as a fallback. "),
703 check_svd_doc);
704
      // We only fall back if the user didn't specify a driver and the default heuristic failed to converge.
      // If the user explicitly chose a driver we currently only warn; whether that should instead be a hard error is an open question.
707 if (!driver.has_value()) {
708 svd_cusolver_gesvd(A, U, S, V, info, full_matrices, compute_uv, false, svd_non_converging_batches);
709 }
710 }
711 }
712
713 // `info` will be checked later at `TORCH_IMPL_FUNC(_linalg_svd_out)` function.
714 }
715
716
717 // Implementation of Cholesky decomposition using looped cusolverDn<T>potrf or cusolverDnXpotrf (64-bit)
718 template<typename scalar_t>
inline static void apply_cholesky_cusolver_potrf_looped(const Tensor& self_working_copy, bool upper, const Tensor& infos) {
720 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
721 const auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
722 const int64_t n = self_working_copy.size(-1);
723 const int64_t lda = std::max<int64_t>(1, n);
724 const int64_t batch_size = batchCount(self_working_copy);
725 const int64_t matrix_stride = matrixStride(self_working_copy);
726
727 scalar_t* self_working_copy_ptr = self_working_copy.data_ptr<scalar_t>();
728 int* infos_ptr = infos.data_ptr<int>();
729
730 #ifdef USE_CUSOLVER_64_BIT
731 size_t worksize_device;
732 size_t worksize_host;
733 cusolverDnParams_t params;
734 cudaDataType datatype = at::cuda::solver::get_cusolver_datatype<scalar_t>();
735 TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(¶ms));
736 at::cuda::solver::xpotrf_buffersize(handle, params, uplo, n, datatype, nullptr, lda, datatype, &worksize_device, &worksize_host);
737
738 // allocate workspace storage
739 auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
740 auto workdata_device = device_allocator.allocate(worksize_device * batch_size);
741 void* workdata_device_ptr = workdata_device.get();
742
743 auto& host_allocator = *at::getCPUAllocator();
744 auto workdata_host = host_allocator.allocate(worksize_host * batch_size);
745 void* workdata_host_ptr = workdata_host.get();
746
747 for (int64_t i = 0; i < batch_size; i++) {
748 at::cuda::solver::xpotrf(
749 handle, params, uplo, n, datatype,
750 self_working_copy_ptr + i * matrix_stride,
751 lda, datatype,
752 (char*)workdata_device_ptr + i * worksize_device, worksize_device,
753 (char*)workdata_host_ptr + i * worksize_host, worksize_host,
754 infos_ptr + i
755 );
756 }
757
758 TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params));
759 #else // USE_CUSOLVER_64_BIT
760 int n_32 = cuda_int_cast(n, "n");
761 int lda_32 = cuda_int_cast(lda, "lda");
762 int lwork;
763 at::cuda::solver::potrf_buffersize<scalar_t>(
764 handle, uplo, n_32, nullptr, lda_32, &lwork);
765
766 // allocate workspace storage
767 auto& allocator = *at::cuda::getCUDADeviceAllocator();
768 auto work_data = allocator.allocate(sizeof(scalar_t)*lwork * batch_size);
769 scalar_t* work_data_ptr = static_cast<scalar_t*>(work_data.get());
770
771 for (int64_t i = 0; i < batch_size; i++) {
772 at::cuda::solver::potrf<scalar_t>(
773 handle, uplo, n_32,
774 self_working_copy_ptr + i * matrix_stride,
775 lda_32,
776 work_data_ptr + i * lwork,
777 lwork,
778 infos_ptr + i
779 );
780 }
781 #endif // USE_CUSOLVER_64_BIT
782 }
783
784 // Implementation of Cholesky decomposition using batched cusolverDn<T>potrfBatched
// Warning: cusolverDn<T>potrfBatched does not behave well when the matrix size or batch size is zero.
// If you write your own C++ extension and use this function, make sure you do a zero-numel check on the input.
787 template<typename scalar_t>
inline static void apply_cholesky_cusolver_potrfBatched(const Tensor& self_working_copy, bool upper, const Tensor& infos) {
789 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
790 const auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
791 const int n = cuda_int_cast(self_working_copy.size(-1), "n");
792 const int lda = std::max<int>(1, n);
793
794 const int batch_size = cuda_int_cast(batchCount(self_working_copy), "batch_size");
795
796 // cusolver batched kernels require input be "device array of device pointers"
797 Tensor self_working_copy_array = get_device_pointers<scalar_t>(self_working_copy);
798
799 at::cuda::solver::potrfBatched<scalar_t>(
800 handle, uplo, n,
801 reinterpret_cast<scalar_t**>(self_working_copy_array.data_ptr()),
802 lda, infos.data_ptr<int>(), batch_size);
803 }
804
void cholesky_helper_cusolver(const Tensor& input, bool upper, const Tensor& info) {
806 if (input.numel() == 0) {
807 return;
808 }
809
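  // potrfBatched handles the whole batch in one call, which amortizes launch overhead for many
  // small matrices; for a single matrix (or when the batched path is disabled) we fall back to
  // the looped potrf / 64-bit Xpotrf implementation above.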
810 if (use_cusolver_potrf_batched_ && batchCount(input) > 1) {
811 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "cholesky_cusolver", [&] {
812 apply_cholesky_cusolver_potrfBatched<scalar_t>(input, upper, info);
813 });
814 } else {
815 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "cholesky_cusolver", [&] {
816 apply_cholesky_cusolver_potrf_looped<scalar_t>(input, upper, info);
817 });
818 }
819 }
820
821
822 template<typename scalar_t>
inline static void apply_cholesky_cusolver_potrs(Tensor& self_working_copy, const Tensor& A_column_major_copy, bool upper, Tensor& infos) {
824 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
825 const auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
826 const int64_t n = self_working_copy.size(-2);
827 const int64_t nrhs = self_working_copy.size(-1);
828 const int64_t lda = std::max<int64_t>(1, n);
829 const int64_t batch_size = batchCount(self_working_copy);
830 const int64_t self_matrix_stride = matrixStride(self_working_copy);
831 scalar_t* self_working_copy_ptr = self_working_copy.data_ptr<scalar_t>();
832
833 scalar_t* A_ptr = A_column_major_copy.data_ptr<scalar_t>();
834 const int64_t A_matrix_stride = matrixStride(A_column_major_copy);
835 const int64_t ldb = std::max<int64_t>(1, A_column_major_copy.size(-1));
836
837 int* infos_ptr = infos.data_ptr<int>();
838
839 #ifdef USE_CUSOLVER_64_BIT
840 cusolverDnParams_t params;
841 cudaDataType datatype = at::cuda::solver::get_cusolver_datatype<scalar_t>();
842 TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(¶ms));
843
844 for (int64_t i = 0; i < batch_size; i++) {
845 at::cuda::solver::xpotrs(
846 handle, params, uplo, n, nrhs, datatype,
847 A_ptr + i * A_matrix_stride,
848 lda, datatype,
849 self_working_copy_ptr + i * self_matrix_stride,
850 ldb,
851 infos_ptr
852 );
853 }
854
855 TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params));
856 #else // USE_CUSOLVER_64_BIT
857 int n_32 = cuda_int_cast(n, "n");
858 int nrhs_32 = cuda_int_cast(nrhs, "nrhs");
859 int lda_32 = cuda_int_cast(lda, "lda");
860 int ldb_32 = cuda_int_cast(ldb, "ldb");
861
862 for (int64_t i = 0; i < batch_size; i++) {
863 at::cuda::solver::potrs<scalar_t>(
864 handle, uplo, n_32, nrhs_32,
865 A_ptr + i * A_matrix_stride,
866 lda_32,
867 self_working_copy_ptr + i * self_matrix_stride,
868 ldb_32,
869 infos_ptr
870 );
871 }
872 #endif // USE_CUSOLVER_64_BIT
873 }
874
875
876 // This code path is only dispatched to if MAGMA is not linked in the pytorch build.
877 // cusolverDn<t>potrsBatched only supports nrhs == 1
878 template<typename scalar_t>
inline static void apply_cholesky_cusolver_potrsBatched(Tensor& self_working_copy, const Tensor& A_column_major_copy, bool upper, Tensor& infos) {
880 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
881 const auto uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
882 const int64_t n = self_working_copy.size(-2);
883 const int64_t nrhs = self_working_copy.size(-1);
884 const int64_t lda = std::max<int64_t>(1, n);
885 const int64_t batch_size = batchCount(self_working_copy);
886
887 const int64_t ldb = std::max<int64_t>(1, A_column_major_copy.size(-1));
888
889 int* infos_ptr = infos.data_ptr<int>();
890
891 auto self_ptr_array = get_device_pointers<scalar_t>(self_working_copy);
892 auto A_ptr_array = get_device_pointers<scalar_t>(A_column_major_copy);
893
894 at::cuda::solver::potrsBatched(
895 handle, uplo,
896 cuda_int_cast(n, "n"),
897 cuda_int_cast(nrhs, "nrhs"),
898 reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr()),
899 cuda_int_cast(lda, "lda"),
900 reinterpret_cast<scalar_t**>(self_ptr_array.data_ptr()),
901 cuda_int_cast(ldb, "ldb"),
902 infos_ptr,
903 cuda_int_cast(batch_size, "batch_size")
904 );
905 }
906
Tensor _cholesky_solve_helper_cuda_cusolver(const Tensor& self, const Tensor& A, bool upper) {
908 const int64_t batch_size = batchCount(self);
909 at::Tensor infos = at::zeros({1}, self.options().dtype(at::kInt));
910 at::Tensor self_working_copy = cloneBatchedColumnMajor(self);
911 at::Tensor A_column_major_copy = cloneBatchedColumnMajor(A);
912
913 const int64_t nrhs = self_working_copy.size(-1);
914
915 // cusolverDn<t>potrsBatched only supports nrhs == 1
916 if (batch_size > 1 && nrhs == 1) {
917 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "cholesky_cuda_potrs_batched", [&] {
918 apply_cholesky_cusolver_potrsBatched<scalar_t>(self_working_copy, A_column_major_copy, upper, infos);
919 });
920 } else {
921 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "cholesky_cuda_potrs", [&] {
922 apply_cholesky_cusolver_potrs<scalar_t>(self_working_copy, A_column_major_copy, upper, infos);
923 });
924 }
925
926 // info from potrs and potrsBatched only report if the i-th parameter is wrong, not about the matrix singularity, etc.
927 // So we don't need to check it all the time.
928 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.item().toInt() == 0);
929
930 return self_working_copy;
931 }
932
933
void _cholesky_inverse_cusolver_potrs_based(Tensor& result, Tensor& infos, bool upper) {
935 at::Tensor input_working_copy = cloneBatchedColumnMajor(result);
936 at::Tensor infos_gpu = at::zeros({1}, result.options().dtype(at::kInt));
937 result.fill_(0);
938 result.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1);
939 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "cholesky_cuda_potri", [&] {
940 apply_cholesky_cusolver_potrs<scalar_t>(result, input_working_copy, upper, infos_gpu);
941 });
942
943 // Debug only: info of cusolver potrs only check if the i-th parameter is wrong
944 // Function argument `infos` is a CPU tensor, the following copy will cause a device-host sync.
945 // infos.copy_(infos_gpu);
946 }
947
Tensor& cholesky_inverse_kernel_impl_cusolver(Tensor &result, Tensor& infos, bool upper) {
949 _cholesky_inverse_cusolver_potrs_based(result, infos, upper);
950 return result;
951 }
952
953
954 /*
The geqrf function computes the QR decomposition of an m x n matrix A.

Args:
* `A` - [in] Tensor with matrices for QR decomposition,
        [out] Tensor containing R in the upper triangle of A
        and elementary reflectors below the main diagonal of A
* `tau` - Tensor containing the magnitudes of the elementary reflectors

m and n are taken from the last two dimensions of `A`.
964
965 For further details, please see the cuSOLVER documentation for GEQRF.
966 */
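// Illustrative sketch (not the exact dispatch path): a QR factorization at the Python level
// decomposes into this routine plus a reconstruction of Q, roughly
//   a = torch.randn(5, 3, device="cuda")
//   reflectors, tau = torch.geqrf(a)                        # R lives in reflectors.triu()
//   q = torch.linalg.householder_product(reflectors, tau)   # materializes Q (orgqr below)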
967 template <typename scalar_t>
static void apply_geqrf(const Tensor& A, const Tensor& tau) {
969 int64_t m = A.size(-2);
970 int64_t n = A.size(-1);
971 int64_t lda = std::max<int64_t>(1, m);
972 int64_t batch_size = batchCount(A);
973
974 auto A_stride = matrixStride(A);
975 auto tau_stride = tau.size(-1);
976
977 auto A_data = A.data_ptr<scalar_t>();
978 auto tau_data = tau.data_ptr<scalar_t>();
979
980 auto infos = at::zeros({1}, A.options().dtype(at::kInt));
981 auto infos_data = infos.data_ptr<int>();
982
983 // get the optimal work size and allocate workspace tensor
984 #ifdef USE_CUSOLVER_64_BIT
985 size_t worksize_device; // workspaceInBytesOnDevice
986 size_t worksize_host; // workspaceInBytesOnHost
987 cusolverDnParams_t params = NULL; // use default algorithm (currently it's the only option)
988 at::cuda::solver::xgeqrf_bufferSize<scalar_t>(
989 at::cuda::getCurrentCUDASolverDnHandle(),
990 params,
991 m,
992 n,
993 A_data,
994 lda,
995 tau_data,
996 &worksize_device,
997 &worksize_host);
998 #else
999 int lwork;
1000 int m_32 = cuda_int_cast(m, "m");
1001 int n_32 = cuda_int_cast(n, "n");
1002 int lda_32 = cuda_int_cast(lda, "lda");
1003 at::cuda::solver::geqrf_bufferSize<scalar_t>(
1004 at::cuda::getCurrentCUDASolverDnHandle(), m_32, n_32, A_data, lda_32, &lwork);
1005 #endif // USE_CUSOLVER_64_BIT
1006
1007 for (decltype(batch_size) i = 0; i < batch_size; i++) {
1008 scalar_t* A_working_ptr = &A_data[i * A_stride];
1009 scalar_t* tau_working_ptr = &tau_data[i * tau_stride];
1010 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
1011
1012 #ifdef USE_CUSOLVER_64_BIT
1013 // allocate workspace storage on device and host
1014 auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
1015 auto work_device_data = device_allocator.allocate(worksize_device);
1016 auto& host_allocator = *at::getCPUAllocator();
1017 auto work_host_data = host_allocator.allocate(worksize_host);
1018 at::cuda::solver::xgeqrf<scalar_t>(
1019 handle,
1020 params,
1021 m,
1022 n,
1023 A_working_ptr,
1024 lda,
1025 tau_working_ptr,
1026 static_cast<scalar_t*>(work_device_data.get()),
1027 worksize_device,
1028 static_cast<scalar_t*>(work_host_data.get()),
1029 worksize_host,
1030 infos_data);
1031 #else
1032 // allocate workspace storage on device
1033 auto& allocator = *at::cuda::getCUDADeviceAllocator();
1034 auto work_data = allocator.allocate(sizeof(scalar_t) * std::max<int>(1, lwork));
1035 at::cuda::solver::geqrf<scalar_t>(
1036 handle,
1037 m_32,
1038 n_32,
1039 A_working_ptr,
1040 lda_32,
1041 tau_working_ptr,
1042 static_cast<scalar_t*>(work_data.get()),
1043 lwork,
1044 infos_data);
1045 #endif // USE_CUSOLVER_64_BIT
1046 }
1047
1048 // info from geqrf only reports if the i-th parameter is wrong, not about the matrix singularity
1049 // so we don't need to check it all the time
1050 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.item().toInt() == 0);
1051 }
1052
1053 // This is a type dispatching helper function for 'apply_geqrf'
void geqrf_cusolver(const Tensor& input, const Tensor& tau) {
1055 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_cuda", [&]{
1056 apply_geqrf<scalar_t>(input, tau);
1057 });
1058 }
1059
1060 /*
The ormqr function multiplies another matrix by Q, where Q is represented implicitly by a
sequence of elementary reflectors, such as those produced by the geqrf function.
1063
1064 Args:
1065 * `input` - Tensor with elementary reflectors below the diagonal,
1066 encoding the matrix Q.
1067 * `tau` - Tensor containing the magnitudes of the elementary
1068 reflectors.
1069 * `other` - [in] Tensor containing the matrix to be multiplied.
1070 [out] result of the matrix multiplication with Q.
1071 * `left` - bool, determining whether `other` is left- or right-multiplied with Q.
1072 * `transpose` - bool, determining whether to transpose (or conjugate transpose) Q before multiplying.
1073
1074 For further details, please see the cuSOLVER documentation for ORMQR and UNMQR.
1075 */
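// Rough semantics, for orientation (Q is the implicit orthogonal/unitary matrix defined by the
// reflectors in `input` and `tau`):
//   left=true,  transpose=false  ->  other := Q @ other
//   left=true,  transpose=true   ->  other := Q^H @ other   (conjugate transpose for complex types)
//   left=false, transpose=false  ->  other := other @ Q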
1076 template <typename scalar_t>
static void apply_ormqr(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
1078 auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
1079 auto trans = transpose ? (input.is_complex() ? CUBLAS_OP_C : CUBLAS_OP_T) : CUBLAS_OP_N;
1080
1081 auto input_data = input.const_data_ptr<scalar_t>();
1082 auto tau_data = tau.const_data_ptr<scalar_t>();
1083 auto other_data = other.data_ptr<scalar_t>();
1084
1085 auto input_matrix_stride = matrixStride(input);
1086 auto other_matrix_stride = matrixStride(other);
1087 auto tau_stride = tau.size(-1);
1088 auto batch_size = batchCount(input);
1089 auto m = cuda_int_cast(other.size(-2), "m");
1090 auto n = cuda_int_cast(other.size(-1), "n");
1091 auto k = cuda_int_cast(tau.size(-1), "k");
1092 auto lda = std::max<int>(1, left ? m : n);
1093 auto ldc = std::max<int>(1, m);
1094
1095 // get the optimal work size and allocate workspace tensor
1096 int lwork;
1097 at::cuda::solver::ormqr_bufferSize<scalar_t>(
1098 at::cuda::getCurrentCUDASolverDnHandle(), side, trans, m, n, k, input_data, lda, tau_data, other_data, ldc, &lwork);
1099
1100 auto info = at::zeros({1}, input.options().dtype(at::kInt));
1101 auto info_data = info.data_ptr<int>();
1102
1103 for (auto i = decltype(batch_size){0}; i < batch_size; i++) {
1104 const scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
1105 scalar_t* other_working_ptr = &other_data[i * other_matrix_stride];
1106 const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];
1107 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
1108
1109 // allocate workspace storage
1110 auto& allocator = *at::cuda::getCUDADeviceAllocator();
1111 auto work_data = allocator.allocate(sizeof(scalar_t)*lwork);
1112
1113 at::cuda::solver::ormqr<scalar_t>(
1114 handle, side, trans, m, n, k,
1115 input_working_ptr,
1116 lda,
1117 tau_working_ptr,
1118 other_working_ptr,
1119 ldc,
1120 static_cast<scalar_t*>(work_data.get()),
1121 lwork,
1122 info_data
1123 );
1124
1125 // info from ormqr only reports if the i-th parameter is wrong
1126 // so we don't need to check it all the time
1127 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.item().toInt() == 0);
1128 }
1129 }
1130
1131 // This is a type dispatching helper function for 'apply_ormqr'
void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "ormqr_cuda", [&]{
1134 apply_ormqr<scalar_t>(input, tau, other, left, transpose);
1135 });
1136 }
1137
1138 /*
The orgqr function reconstructs an orthogonal (or unitary) matrix Q
from a sequence of elementary reflectors, such as those produced by the geqrf function.
1141
1142 Args:
1143 * `self` - Tensor with the directions of the elementary reflectors below the diagonal,
1144 it will be overwritten with the result
1145 * `tau` - Tensor containing the magnitudes of the elementary reflectors
1146
1147 For further details, please see the cuSOLVER documentation for ORGQR and UNGQR.
1148 */
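// Example shape contract (illustrative): for `self` of shape (*, m, n) with m >= n and
// k = tau.size(-1) <= n reflectors, the call overwrites `self` with the first n columns of the
// m x m orthogonal/unitary matrix Q; this is what torch.linalg.householder_product returns.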
1149 template <typename scalar_t>
inline static void apply_orgqr(Tensor& self, const Tensor& tau) {
1151 auto self_data = self.data_ptr<scalar_t>();
1152 auto tau_data = tau.const_data_ptr<scalar_t>();
1153 auto self_matrix_stride = matrixStride(self);
1154 auto batchsize = cuda_int_cast(batchCount(self), "batch size");
1155 auto m = cuda_int_cast(self.size(-2), "m");
1156 auto n = cuda_int_cast(self.size(-1), "n");
1157 auto k = cuda_int_cast(tau.size(-1), "k");
1158 auto tau_stride = std::max<int>(1, k);
1159 auto lda = std::max<int>(1, m);
1160
1161 // LAPACK's requirement
1162 TORCH_INTERNAL_ASSERT(m >= n);
1163 TORCH_INTERNAL_ASSERT(n >= k);
1164
  // cuSOLVER computes nothing when k == 0, which is wrong:
  // the result should be a matrix with ones on the diagonal
1167 if (k == 0) {
1168 self.fill_(0);
1169 self.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1);
1170 return;
1171 }
1172
1173 // get the optimal work size and allocate workspace tensor
1174 int lwork;
1175 at::cuda::solver::orgqr_buffersize<scalar_t>(
1176 at::cuda::getCurrentCUDASolverDnHandle(), m, n, k, self_data, lda, tau_data, &lwork);
1177
1178 auto info = at::zeros({1}, self.options().dtype(at::kInt));
1179 auto info_data = info.data_ptr<int>();
1180
1181 for (auto i = decltype(batchsize){0}; i < batchsize; i++) {
1182 scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
1183 const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];
1184 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
1185
1186 // allocate workspace storage
1187 auto& allocator = *at::cuda::getCUDADeviceAllocator();
1188 auto work_data = allocator.allocate(sizeof(scalar_t)*lwork);
1189
1190 at::cuda::solver::orgqr<scalar_t>(
1191 handle, m, n, k,
1192 self_working_ptr,
1193 lda,
1194 tau_working_ptr,
1195 static_cast<scalar_t*>(work_data.get()),
1196 lwork,
1197 info_data
1198 );
1199
1200 // info from orgqr only reports if the i-th parameter is wrong
1201 // so we don't need to check it all the time
1202 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.item().toInt() == 0);
1203 }
1204 }
1205
1206 // This is a type dispatching helper function for 'apply_orgqr'
Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau) {
1208 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "orgqr_cuda", [&]{
1209 apply_orgqr<scalar_t>(result, tau);
1210 });
1211 return result;
1212 }
1213
1214 template <typename scalar_t>
static void apply_syevd(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
1216 using value_t = typename c10::scalar_value_type<scalar_t>::type;
1217
1218 cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
1219 cusolverEigMode_t jobz = compute_eigenvectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
1220
1221 int64_t n = vectors.size(-1);
1222 int64_t lda = std::max<int64_t>(1, n);
1223 int64_t batch_size = batchCount(vectors);
1224
1225 auto vectors_stride = matrixStride(vectors);
1226 auto values_stride = values.size(-1);
1227
1228 auto vectors_data = vectors.data_ptr<scalar_t>();
1229 auto values_data = values.data_ptr<value_t>();
1230 auto infos_data = infos.data_ptr<int>();
1231
1232 // get the optimal work size and allocate workspace tensor
1233 #ifdef USE_CUSOLVER_64_BIT
1234 size_t worksize_device; // workspaceInBytesOnDevice
1235 size_t worksize_host; // workspaceInBytesOnHost
1236 cusolverDnParams_t params = NULL; // use default algorithm (currently it's the only option)
1237 at::cuda::solver::xsyevd_bufferSize<scalar_t>(
1238 at::cuda::getCurrentCUDASolverDnHandle(),
1239 params,
1240 jobz,
1241 uplo,
1242 n,
1243 vectors_data,
1244 lda,
1245 values_data,
1246 &worksize_device,
1247 &worksize_host);
1248 #else
1249 int lwork;
1250 int n_32 = cuda_int_cast(n, "n");
1251 int lda_32 = cuda_int_cast(lda, "lda");
1252 at::cuda::solver::syevd_bufferSize<scalar_t>(
1253 at::cuda::getCurrentCUDASolverDnHandle(), jobz, uplo, n_32, vectors_data, lda_32, values_data, &lwork);
1254 #endif // USE_CUSOLVER_64_BIT
1255
1256 for (decltype(batch_size) i = 0; i < batch_size; i++) {
1257 scalar_t* vectors_working_ptr = &vectors_data[i * vectors_stride];
1258 value_t* values_working_ptr = &values_data[i * values_stride];
1259 int* info_working_ptr = &infos_data[i];
1260 auto handle = at::cuda::getCurrentCUDASolverDnHandle();
1261
1262 #ifdef USE_CUSOLVER_64_BIT
1263 // allocate workspace storage on device and host
1264 auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
1265 auto work_device_data = device_allocator.allocate(worksize_device);
1266 auto& host_allocator = *at::getCPUAllocator();
1267 auto work_host_data = host_allocator.allocate(worksize_host);
1268 at::cuda::solver::xsyevd<scalar_t>(
1269 handle,
1270 params,
1271 jobz,
1272 uplo,
1273 n,
1274 vectors_working_ptr,
1275 lda,
1276 values_working_ptr,
1277 static_cast<scalar_t*>(work_device_data.get()),
1278 worksize_device,
1279 static_cast<scalar_t*>(work_host_data.get()),
1280 worksize_host,
1281 info_working_ptr);
1282 #else
1283 // allocate workspace storage on device
1284 auto& allocator = *at::cuda::getCUDADeviceAllocator();
1285 auto work_data = allocator.allocate(sizeof(scalar_t) * lwork);
1286 at::cuda::solver::syevd<scalar_t>(
1287 handle,
1288 jobz,
1289 uplo,
1290 n_32,
1291 vectors_working_ptr,
1292 lda_32,
1293 values_working_ptr,
1294 static_cast<scalar_t*>(work_data.get()),
1295 lwork,
1296 info_working_ptr);
1297 #endif // USE_CUSOLVER_64_BIT
1298 }
1299 }
1300
1301 template <typename scalar_t>
static void apply_syevj(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
1303 using value_t = typename c10::scalar_value_type<scalar_t>::type;
1304
1305 cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
1306 cusolverEigMode_t jobz = compute_eigenvectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
1307
1308 int n = cuda_int_cast(vectors.size(-1), "n");
1309 int lda = std::max<int>(1, n);
1310 auto batch_size = batchCount(vectors);
1311
1312 auto vectors_stride = matrixStride(vectors);
1313 auto values_stride = values.size(-1);
1314
1315 auto vectors_data = vectors.data_ptr<scalar_t>();
1316 auto values_data = values.data_ptr<value_t>();
1317 auto infos_data = infos.data_ptr<int>();
1318
1319 // syevj_params controls the numerical accuracy of syevj
1320 // by default the tolerance is set to machine accuracy
  // the maximum number of iterations of the Jacobi method is 100 by default
  // the cuSOLVER documentation says: "15 sweeps are good enough to converge to machine accuracy"
  // LAPACK's Jacobi-based SVD routine (gesvj) uses a similar algorithm and caps it at 30 iterations
1324 // Let's use the default values for now
1325 syevjInfo_t syevj_params;
1326 TORCH_CUSOLVER_CHECK(cusolverDnCreateSyevjInfo(&syevj_params));
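  // If the defaults ever need to change, cuSOLVER exposes setters on the params
  // object, e.g. (illustrative values only, not used here):
  //   TORCH_CUSOLVER_CHECK(cusolverDnXsyevjSetTolerance(syevj_params, 1e-7));
  //   TORCH_CUSOLVER_CHECK(cusolverDnXsyevjSetMaxSweeps(syevj_params, 15));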

  // get the optimal work size and allocate workspace tensor
  int lwork;
  at::cuda::solver::syevj_bufferSize<scalar_t>(
      at::cuda::getCurrentCUDASolverDnHandle(), jobz, uplo, n, vectors_data, lda, values_data, &lwork, syevj_params);

  for (decltype(batch_size) i = 0; i < batch_size; i++) {
    scalar_t* vectors_working_ptr = &vectors_data[i * vectors_stride];
    value_t* values_working_ptr = &values_data[i * values_stride];
    int* info_working_ptr = &infos_data[i];
    auto handle = at::cuda::getCurrentCUDASolverDnHandle();

    // allocate workspace storage on device
    auto& allocator = *at::cuda::getCUDADeviceAllocator();
    auto work_data = allocator.allocate(sizeof(scalar_t) * lwork);
    at::cuda::solver::syevj<scalar_t>(
        handle,
        jobz,
        uplo,
        n,
        vectors_working_ptr,
        lda,
        values_working_ptr,
        static_cast<scalar_t*>(work_data.get()),
        lwork,
        info_working_ptr,
        syevj_params);
  }
  TORCH_CUSOLVER_CHECK(cusolverDnDestroySyevjInfo(syevj_params));
}

template <typename scalar_t>
static void apply_syevj_batched(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
  using value_t = typename c10::scalar_value_type<scalar_t>::type;

  cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  cusolverEigMode_t jobz = compute_eigenvectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;

  int n = cuda_int_cast(vectors.size(-1), "n");
  int lda = std::max<int>(1, n);
  int batch_size = cuda_int_cast(batchCount(vectors), "batch_size");

  auto vectors_data = vectors.data_ptr<scalar_t>();
  auto values_data = values.data_ptr<value_t>();
  auto infos_data = infos.data_ptr<int>();

  // syevj_params controls the numerical accuracy of syevj.
  // By default the tolerance is set to machine accuracy and the maximum number
  // of Jacobi iterations (sweeps) is 100. The cuSOLVER documentation says:
  // "15 sweeps are good enough to converge to machine accuracy".
  // LAPACK has an SVD routine based on a similar Jacobi algorithm (gesvj),
  // which caps the iteration count at 30. Let's use the default values for now.
  syevjInfo_t syevj_params;
  TORCH_CUSOLVER_CHECK(cusolverDnCreateSyevjInfo(&syevj_params));
  TORCH_CUSOLVER_CHECK(cusolverDnXsyevjSetSortEig(syevj_params, 1));
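  // Explicitly request sorted eigenvalues so the batched path returns them in
  // the same (ascending) order as the looped syevj/syevd paths.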

  auto handle = at::cuda::getCurrentCUDASolverDnHandle();

  // get the optimal work size and allocate workspace tensor
  int lwork;
  at::cuda::solver::syevjBatched_bufferSize<scalar_t>(
      handle,
      jobz,
      uplo,
      n,
      vectors_data,
      lda,
      values_data,
      &lwork,
      syevj_params,
      batch_size);

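  // Unlike the looped syevj above, syevjBatched processes the whole batch in a
  // single call: one workspace (sized for the full batch by the query above) is
  // passed in, and one info entry per matrix is written to infos_data.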
  // allocate workspace storage on device
  auto& allocator = *at::cuda::getCUDADeviceAllocator();
  auto work_data = allocator.allocate(sizeof(scalar_t) * lwork);
  at::cuda::solver::syevjBatched<scalar_t>(
      handle,
      jobz,
      uplo,
      n,
      vectors_data,
      lda,
      values_data,
      static_cast<scalar_t*>(work_data.get()),
      lwork,
      infos_data,
      syevj_params,
      batch_size);
  TORCH_CUSOLVER_CHECK(cusolverDnDestroySyevjInfo(syevj_params));
}

static void linalg_eigh_cusolver_syevd(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&] {
    apply_syevd<scalar_t>(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  });
}

static void linalg_eigh_cusolver_syevj(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&] {
    apply_syevj<scalar_t>(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  });
}

static void linalg_eigh_cusolver_syevj_batched(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&] {
    apply_syevj_batched<scalar_t>(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  });
}

void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
  // For ROCm's hipSOLVER, syevj is the fastest option.
#ifdef USE_ROCM
  linalg_eigh_cusolver_syevj(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
#else
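  // On CUDA the dispatch below follows this heuristic:
  //   syevjBatched flag on, batch > 1 and n <= 32  -> syevjBatched
  //   float32 and 32 <= n <= 512                   -> syevj
  //   everything else                              -> syevd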
  if (use_cusolver_syevj_batched_ && batchCount(eigenvectors) > 1 && eigenvectors.size(-1) <= 32) {
    // Use syevjBatched for batched matrices when the matrix size is <= 32
    // See https://github.com/pytorch/pytorch/pull/53040#issuecomment-788264724
    linalg_eigh_cusolver_syevj_batched(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  } else if (eigenvectors.scalar_type() == at::kFloat && eigenvectors.size(-1) >= 32 && eigenvectors.size(-1) <= 512) {
    // syevj outperforms syevd for float32 matrices of size 32x32 - 512x512
    // See https://github.com/pytorch/pytorch/pull/53040#issuecomment-788264724
    linalg_eigh_cusolver_syevj(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  } else {
    linalg_eigh_cusolver_syevd(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  }
#endif
}

// The 'apply_' prefix is used for functions templated by dtype that call an API
// routine underneath. Since the cuSOLVER API has a slightly different structure
// here, we do not prepend 'apply_' to this function.
void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
    self.scalar_type(),
    "lu_factor_cusolver",
    [&self,
     &pivots,
     &infos,
     get_pivots]() {
    const auto m = cuda_int_cast(self.size(-2), "m");
    const auto n = cuda_int_cast(self.size(-1), "n");
    const auto lda = std::max<int>(1, m);
    const auto self_stride = matrixStride(self);
    const auto batch_size = batchCount(self);
    const auto self_data = self.data_ptr<scalar_t>();
    const auto infos_data = infos.data_ptr<int>();

    const auto pivots_data = get_pivots ? pivots.data_ptr<int>() : nullptr;
    const auto pivots_stride = get_pivots ? pivots.size(-1) : 0;

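    // When get_pivots is false, getrf receives a null pivot pointer, which per
    // the cuSOLVER documentation selects LU factorization without pivoting.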
    const auto handle = at::cuda::getCurrentCUDASolverDnHandle();
    for (auto batch = decltype(batch_size){0}; batch < batch_size; ++batch) {
      at::cuda::solver::getrf<scalar_t>(
        handle, m, n,
        self_data + batch * self_stride,
        lda,
        get_pivots ? pivots_data + batch * pivots_stride : nullptr,
        infos_data + batch
      );
    }
  });

  // This is necessary because cuSOLVER writes NaN into outputs that MAGMA sets
  // to 0 for non-pivoted LU.
  // https://github.com/pytorch/pytorch/issues/53879#issuecomment-830633572
  if (!get_pivots) {
    // nan_to_num does not work for complex inputs
    // https://github.com/pytorch/pytorch/issues/59247
    if (self.is_complex()) {
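      // self.eq(self) is false exactly at NaN entries (NaN != NaN), so this
      // replaces NaNs with zero and leaves all other values untouched.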
      self.copy_(at::where(self.eq(self), self, at::scalar_tensor(0., self.options())));
    } else {
      at::nan_to_num_(const_cast<Tensor&>(self), 0, std::numeric_limits<double>::infinity(),
                      -std::numeric_limits<double>::infinity());
    }
  }
}

void lu_solve_looped_cusolver(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_solve_cusolver", [&] {
    const auto trans = to_cublas(transpose);
    int n = cuda_int_cast(LU.size(-2), "n");
    int nrhs = cuda_int_cast(B.size(-1), "nrhs");
    auto batch_size = batchCount(B);
    auto info = at::zeros({1}, LU.options().dtype(kInt));
    auto info_data = info.data_ptr<int>();
    auto b_data = B.data_ptr<scalar_t>();
    auto lu_data = LU.data_ptr<scalar_t>();
    auto pivots_data = pivots.data_ptr<int>();
    auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
    auto lu_stride = LU.dim() > 2 ? LU.stride(-3) : 0;
    auto b_stride = matrixStride(B);
    int leading_dimension = cuda_int_cast(std::max<int>(1, n), "leading_dimension");

    // lu and pivots tensors can be broadcast to b
    // here we construct a helper indexing tensor to linearly index into lu and pivots
    IntArrayRef lu_batch_shape(LU.sizes().data(), LU.dim() - 2);
    IntArrayRef b_batch_shape(B.sizes().data(), B.dim() - 2);
    BroadcastLinearIndices lu_index(
        batchCount(LU), lu_batch_shape, b_batch_shape);
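    // Example: if LU has batch shape (1,) and B has batch shape (3,), lu_index
    // maps B's batch indices 0, 1, 2 all to LU's batch 0 (standard broadcasting).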

    auto handle = at::cuda::getCurrentCUDASolverDnHandle();
    for (auto batch = decltype(batch_size){0}; batch < batch_size; ++batch) {
      int64_t lu_index_i = lu_index(batch);
      at::cuda::solver::getrs<scalar_t>(
        handle,
        n,
        nrhs,
        lu_data + lu_index_i * lu_stride,
        leading_dimension,
        pivots_data + lu_index_i * pivots_stride,
        b_data + batch * b_stride,
        leading_dimension,
        info_data,
        trans);

      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.item().toInt() == 0);
    }
  });
}

#endif // USE_LINALG_SOLVER

} // namespace at::native