// Note [BatchLinearAlgebraLib split implementation files]
//
// There are two files that implement the interfaces found in
// BatchLinearAlgebraLib.h
// - BatchLinearAlgebraLib.cpp
// - BatchLinearAlgebraLibBlas.cpp (this file)
//
// In order to support the ROCm build target, the use of cublas and
// cusolver APIs needed to be split into separate source files to
// accommodate the hipify step of the ROCm build process.
//
// To create this current file, the original file
// BatchLinearAlgebraLib.cpp was copied to
// BatchLinearAlgebraLibBlas.cpp, then any functions that used cusolver
// APIs were removed. Similarly, in the original file
// BatchLinearAlgebraLib.cpp, any use of cublas APIs was removed.
// The net result is a split of the BatchLinearAlgebraLib
// implementation files. The original file BatchLinearAlgebraLib.cpp
// contains the full, original git history for both files.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/irange.h>

#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/TransposeType.h>
#include <ATen/native/cuda/MiscUtils.h>
#include <ATen/native/cuda/linalg/CUDASolver.h>
#include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/nan_to_num.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/where.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

static cublasOperation_t to_cublas(TransposeType trans) {
  switch (trans) {
    case TransposeType::NoTranspose: return CUBLAS_OP_N;
    case TransposeType::Transpose: return CUBLAS_OP_T;
    case TransposeType::ConjTranspose: return CUBLAS_OP_C;
  }
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}

// Some cuBLAS and cuSOLVER batched routines require input to be a device array of pointers to device individual matrices
// 'input' must be a contiguous tensor
template <typename scalar_t>
static Tensor get_device_pointers(const Tensor& input) {
  auto input_data = input.const_data_ptr<scalar_t>();
  int64_t input_mat_stride = matrixStride(input);

  // cublas/cusolver interface requires 'int'
  int batch_size = cuda_int_cast(batchCount(input), "batch_size");

  // if batch_size==0, then start=0 and end=0
  // if input_mat_stride==0, then step=sizeof(scalar_t)
  return at::arange(
      /*start=*/reinterpret_cast<int64_t>(input_data),
      /*end=*/reinterpret_cast<int64_t>(input_data + batch_size * input_mat_stride),
      /*step=*/static_cast<int64_t>(std::max<int64_t>(input_mat_stride, 1) * sizeof(scalar_t)),
      input.options().dtype(at::kLong));
}
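
// For illustration only (hypothetical numbers): with scalar_t == float, batch_size == 3
// and input_mat_stride == 4, the arange above produces the three device addresses
//   base, base + 16, base + 32   (in bytes),
// i.e. one pointer per matrix in the batch, stored as int64_t on the input's device and
// ready to be reinterpret_cast to scalar_t** for the batched cuBLAS/cuSOLVER calls below.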

template <typename scalar_t>
void apply_geqrf_batched(const Tensor& input, const Tensor& tau) {
  auto batch_size = cuda_int_cast(batchCount(input), "batch_size");
  auto m = cuda_int_cast(input.size(-2), "m");
  auto n = cuda_int_cast(input.size(-1), "n");
  auto lda = std::max<int>(1, m);

  // cuBLAS batched geqrf requires input to be the device array of pointers to device single matrices
  Tensor input_ptr_array = get_device_pointers<scalar_t>(input);
  Tensor tau_ptr_array = get_device_pointers<scalar_t>(tau.unsqueeze(-1));
  auto input_ptr_array_data = reinterpret_cast<scalar_t**>(input_ptr_array.data_ptr());
  auto tau_ptr_array_data = reinterpret_cast<scalar_t**>(tau_ptr_array.data_ptr());

  int info;
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  at::cuda::blas::geqrfBatched(handle, m, n, input_ptr_array_data, lda, tau_ptr_array_data, &info, batch_size);

  // info only indicates wrong arguments to geqrfBatched call
  // info is a host variable, we can check it without device synchronization
  TORCH_INTERNAL_ASSERT(info == 0);
}
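
// As with LAPACK geqrf, on exit each matrix in 'input' holds R in and above the
// diagonal and the Householder reflectors below it; 'tau' holds the corresponding
// scalar factors needed to reconstruct or apply Q (e.g. via orgqr/ormqr).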

void geqrf_batched_cublas(const Tensor& input, const Tensor& tau) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_batched_cuda", [&]{
    apply_geqrf_batched<scalar_t>(input, tau);
  });
}

template <typename scalar_t>
static void apply_lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
  // This function only works with square matrices
  TORCH_INTERNAL_ASSERT(A.size(-2) == A.size(-1));

  auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
  auto n = cuda_int_cast(A.size(-2), "n");
  auto lda = cuda_int_cast(std::max<int>(1, n), "lda");

  auto pivots_data = get_pivots ? pivots.data_ptr<int>() : nullptr;
  auto infos_data = infos.data_ptr<int>();
  Tensor a_ptr_array = get_device_pointers<scalar_t>(A);
  auto a_ptr_array_data = reinterpret_cast<scalar_t**>(a_ptr_array.data_ptr());

  at::cuda::blas::getrfBatched(n, a_ptr_array_data, lda, pivots_data, infos_data, batch_size);
}
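
// When get_pivots is false, passing nullptr for the pivot array asks cublas<t>getrfBatched
// to factorize without partial pivoting (see the cuBLAS getrfBatched documentation);
// 'infos' still receives the per-matrix status on the device, and the LU factors
// overwrite A in place.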

void lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "lu_factor_cublas", [&]{
    apply_lu_factor_batched_cublas<scalar_t>(A, pivots, infos, get_pivots);
  });
}

template <typename scalar_t>
static void apply_lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
  TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(B), "batch_size of LU and B must be the same");
  TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same");
  const auto trans = to_cublas(transpose);

  auto pivots_data = pivots.const_data_ptr<int>();
  auto batch_size = cuda_int_cast(batchCount(LU), "batch_size");
  auto m = cuda_int_cast(LU.size(-2), "m");
  auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
  auto lda = cuda_int_cast(std::max<int>(1, m), "lda");
  int info = 0;

  Tensor lu_ptr_array = get_device_pointers<scalar_t>(LU);
  Tensor b_ptr_array = get_device_pointers<scalar_t>(B);
  auto lu_ptr_array_data = reinterpret_cast<const scalar_t* const*>(lu_ptr_array.const_data_ptr());
  auto b_ptr_array_data = reinterpret_cast<scalar_t**>(b_ptr_array.data_ptr());

  auto handle = at::cuda::getCurrentCUDABlasHandle();
  at::cuda::blas::getrsBatched(handle, trans, m, nrhs, const_cast<scalar_t**>(lu_ptr_array_data),
    lda, const_cast<int*>(pivots_data), b_ptr_array_data, lda, &info, batch_size);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
}

void lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_solve_cublas", [&]{
    apply_lu_solve_batched_cublas<scalar_t>(LU, pivots, B, trans);
  });
}
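
// A minimal usage sketch (illustrative only; 'a' and 'b' are assumed to already be
// batched, square, column-major tensors in the layout these wrappers expect):
//
//   auto pivots = at::empty({batch, n}, a.options().dtype(at::kInt));
//   auto infos  = at::zeros({batch}, a.options().dtype(at::kInt));
//   lu_factor_batched_cublas(a, pivots, infos, /*get_pivots=*/true);   // 'a' now holds LU
//   lu_solve_batched_cublas(a, pivots, b, TransposeType::NoTranspose); // 'b' now holds X
//
// The higher-level linalg entry points add the error checking and layout handling
// that these raw wrappers deliberately omit.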

template <typename scalar_t>
static void apply_triangular_solve(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
  cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  const auto trans = to_cublas(transpose);
  cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
  cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  auto A_data = A.data_ptr<scalar_t>();
  auto B_data = B.data_ptr<scalar_t>();
  auto A_mat_stride = matrixStride(A);
  auto B_mat_stride = matrixStride(B);
  auto batch_size = batchCount(A);
  // This allows passing rectangular A and B when left = True
  auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
  auto n = cuda_int_cast(B.size(-1), "n");
  auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
  auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));

  auto alpha = scalar_t{1};

  for (decltype(batch_size) i = 0; i < batch_size; i++) {
    scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
    scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
    auto handle = at::cuda::getCurrentCUDABlasHandle();
    at::cuda::blas::trsm(handle, side, uplo, trans, diag, m, n, &alpha, A_working_ptr, lda, B_working_ptr, ldb);
  }
}
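
// As with BLAS trsm, each B matrix is overwritten in place with the solution of
// op(A) X = alpha * B (left) or X op(A) = alpha * B (right), with alpha fixed to 1 here.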

void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
    apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
  });
}

template <typename scalar_t>
static void apply_triangular_solve_batched(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
  cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  const auto trans = to_cublas(transpose);
  cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
  cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
  // This allows passing rectangular A and B when left = True
  auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
  auto n = cuda_int_cast(B.size(-1), "n");
  auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
  auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));

  auto alpha = scalar_t{1};

  // cuBLAS batched trsm requires input to be the device array of pointers to device single matrices
  Tensor A_ptr_array = get_device_pointers<scalar_t>(A);
  Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
  auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
  auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());

  auto handle = at::cuda::getCurrentCUDABlasHandle();
  at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, m, n, &alpha, A_ptr_array_data, lda, B_ptr_array_data, ldb, batch_size);
}

void triangular_solve_batched_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
  // Work around the following bug on CUDA < 12.1:
  // RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasStrsmBatched`
  // See https://github.com/pytorch/pytorch/issues/79191#issuecomment-1154222580
#if defined(CUSOLVER_VERSION) && CUSOLVER_VERSION < 12100
  constexpr auto max_batch_size = 524280;
  if (B.size(-1) > max_batch_size) {
    auto n_chunks = (B.size(-1) + max_batch_size - 1) / max_batch_size; // ceildiv
    auto splits = B.split(n_chunks, /*dim=*/-1);
    for (const Tensor& b : splits) {
      triangular_solve_batched_cublas(A, b, left, upper, transpose, unitriangular);
    }
    return;
  }
#endif
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
    apply_triangular_solve_batched<scalar_t>(A, B, left, upper, transpose, unitriangular);
  });
}

template <typename scalar_t>
inline void apply_gels_batched(const Tensor& A, Tensor& B, Tensor& infos) {
  auto trans = CUBLAS_OP_N;
  auto m = cuda_int_cast(A.size(-2), "m");
  auto n = cuda_int_cast(A.size(-1), "n");

  auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
  // cuBLAS from CUDA 10 and older doesn't work with nrhs == 0 (CUDA 11 does),
  // so we need this early return
  if (nrhs == 0) {
    return;
  }

  auto batch_size = cuda_int_cast(batchCount(B), "batch_size");
  auto lda = std::max<int>(1, m);
  auto ldb = std::max<int>(1, m);

  // cuBLAS's requirement
  TORCH_CHECK(
    m >= n,
    "torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA with cuBLAS backend.");

  // cuBLAS documentation says:
  // Matrices Aarray[i] should not overlap; otherwise, undefined behavior is expected.
  // explicitly broadcast the batch dimensions of A
  IntArrayRef A_batch_sizes(A.sizes().data(), A.dim() - 2);
  IntArrayRef B_batch_sizes(B.sizes().data(), B.dim() - 2);
  std::vector<int64_t> expand_batch_portion = at::infer_size(A_batch_sizes, B_batch_sizes);
  expand_batch_portion.insert(expand_batch_portion.end(), {A.size(-2), A.size(-1)});
  Tensor A_expanded = A.expand({expand_batch_portion});
  Tensor A_broadcasted = cloneBatchedColumnMajor(A_expanded);

  // cuBLAS batched gels requires input to be the device array of pointers to device single matrices
  Tensor A_ptr_array = get_device_pointers<scalar_t>(A_broadcasted);
  Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
  auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
  auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());

  auto infos_data = infos.data_ptr<int>();
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  int info;

  at::cuda::blas::gelsBatched<scalar_t>(
    handle, trans, m, n, nrhs,
    A_ptr_array_data, lda,
    B_ptr_array_data, ldb,
    &info,
    infos_data,
    batch_size);

  // negative info indicates that an argument to the gelsBatched call is invalid
  TORCH_INTERNAL_ASSERT(info == 0);
}
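
// On successful exit, as with LAPACK gels for an overdetermined system, each B matrix
// has its first n rows overwritten with the least-squares solution; 'infos' receives
// the per-matrix status written by cuBLAS on the device.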

// This is a type dispatching helper function for 'apply_gels_batched'
void gels_batched_cublas(const Tensor& a, Tensor& b, Tensor& infos) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(a.scalar_type(), "gels_batched_cublas", [&]{
    apply_gels_batched<scalar_t>(a, b, infos);
  });
}

} // namespace at::native