// Note [BatchLinearAlgebraLib split implementation files]
//
// There are two files that implement the interfaces found in
// BatchLinearAlgebraLib.h
// - BatchLinearAlgebraLib.cpp
// - BatchLinearAlgebraLibBlas.cpp (this file)
//
// In order to support the ROCm build target, the use of cublas and
// cusolver APIs needed to be split into separate source files to
// accommodate the hipify step of the ROCm build process.
//
// To create this current file, the original file
// BatchLinearAlgebraLib.cpp was copied to
// BatchLinearAlgebraLibBlas.cpp, then any functions that used cusolver
// APIs were removed. Similarly, in the original file
// BatchLinearAlgebraLib.cpp, any use of cublas APIs was removed.
// The net result is a split of the BatchLinearAlgebraLib
// implementation files. The original file BatchLinearAlgebraLib.cpp
// contains the full, original git history for both files.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/irange.h>

#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/TransposeType.h>
#include <ATen/native/cuda/MiscUtils.h>
#include <ATen/native/cuda/linalg/CUDASolver.h>
#include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/nan_to_num.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/where.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

static cublasOperation_t to_cublas(TransposeType trans) {
  switch (trans) {
    case TransposeType::NoTranspose: return CUBLAS_OP_N;
    case TransposeType::Transpose: return CUBLAS_OP_T;
    case TransposeType::ConjTranspose: return CUBLAS_OP_C;
  }
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}

// Some cuBLAS and cuSOLVER batched routines require input to be a device array of pointers to device individual matrices
// 'input' must be a contiguous tensor
template <typename scalar_t>
static Tensor get_device_pointers(const Tensor& input) {
  auto input_data = input.const_data_ptr<scalar_t>();
  int64_t input_mat_stride = matrixStride(input);

  // cublas/cusolver interface requires 'int'
  int batch_size = cuda_int_cast(batchCount(input), "batch_size");

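  // The batch is contiguous, so matrix i starts at input_data + i * input_mat_stride.
  // at::arange over the raw byte addresses builds this array of device pointers
  // directly on the GPU, avoiding a host-side loop and a host-to-device copy.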
  // if batch_size==0, then start=0 and end=0
  // if input_mat_stride==0, then step=sizeof(scalar_t)
  return at::arange(
      /*start=*/reinterpret_cast<int64_t>(input_data),
      /*end=*/reinterpret_cast<int64_t>(input_data + batch_size * input_mat_stride),
      /*step=*/static_cast<int64_t>(std::max<int64_t>(input_mat_stride, 1) * sizeof(scalar_t)),
      input.options().dtype(at::kLong));
}

template <typename scalar_t>
void apply_geqrf_batched(const Tensor& input, const Tensor& tau) {
  auto batch_size = cuda_int_cast(batchCount(input), "batch_size");
  auto m = cuda_int_cast(input.size(-2), "m");
  auto n = cuda_int_cast(input.size(-1), "n");
  auto lda = std::max<int>(1, m);

  // cuBLAS batched geqrf requires input to be the device array of pointers to device single matrices
  Tensor input_ptr_array = get_device_pointers<scalar_t>(input);
  Tensor tau_ptr_array = get_device_pointers<scalar_t>(tau.unsqueeze(-1));
  auto input_ptr_array_data = reinterpret_cast<scalar_t**>(input_ptr_array.data_ptr());
  auto tau_ptr_array_data = reinterpret_cast<scalar_t**>(tau_ptr_array.data_ptr());

  int info;
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  at::cuda::blas::geqrfBatched(handle, m, n, input_ptr_array_data, lda, tau_ptr_array_data, &info, batch_size);

  // info only indicates wrong arguments to geqrfBatched call
  // info is a host variable, we can check it without device synchronization
  TORCH_INTERNAL_ASSERT(info == 0);
}

void geqrf_batched_cublas(const Tensor& input, const Tensor& tau) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_batched_cuda", [&]{
    apply_geqrf_batched<scalar_t>(input, tau);
  });
}

template <typename scalar_t>
static void apply_lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
  // This function only works with square matrices
  TORCH_INTERNAL_ASSERT(A.size(-2) == A.size(-1));

  auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
  auto n = cuda_int_cast(A.size(-2), "n");
  auto lda = cuda_int_cast(std::max<int>(1, n), "lda");

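  // cuBLAS getrfBatched disables partial pivoting when given a null pivot array,
  // which is what a null pivots_data below requests when get_pivots is false.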
  auto pivots_data = get_pivots ? pivots.data_ptr<int>() : nullptr;
  auto infos_data = infos.data_ptr<int>();
  Tensor a_ptr_array = get_device_pointers<scalar_t>(A);
  auto a_ptr_array_data = reinterpret_cast<scalar_t**>(a_ptr_array.data_ptr());

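  // The factorization is performed in place on each A[i]; infos_data is a device
  // array that receives the per-matrix status codes.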
  at::cuda::blas::getrfBatched(n, a_ptr_array_data, lda, pivots_data, infos_data, batch_size);
}

void lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "lu_factor_cublas", [&]{
    apply_lu_factor_batched_cublas<scalar_t>(A, pivots, infos, get_pivots);
  });
}

template <typename scalar_t>
static void apply_lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
  TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(B), "batch_size of LU and B must be the same");
  TORCH_INTERNAL_ASSERT(batchCount(LU) == batchCount(pivots.unsqueeze(-1)), "batch_size of LU and pivots must be the same");
  const auto trans = to_cublas(transpose);

  auto pivots_data = pivots.const_data_ptr<int>();
  auto batch_size = cuda_int_cast(batchCount(LU), "batch_size");
  auto m = cuda_int_cast(LU.size(-2), "m");
  auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
  auto lda = cuda_int_cast(std::max<int>(1, m), "lda");
  int info = 0;

  Tensor lu_ptr_array = get_device_pointers<scalar_t>(LU);
  Tensor b_ptr_array = get_device_pointers<scalar_t>(B);
  auto lu_ptr_array_data = reinterpret_cast<const scalar_t* const*>(lu_ptr_array.const_data_ptr());
  auto b_ptr_array_data = reinterpret_cast<scalar_t**>(b_ptr_array.data_ptr());

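  // getrsBatched reads the LU factors and pivots without modifying them; the
  // const_casts below only satisfy the wrapper's non-const parameter types.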
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  at::cuda::blas::getrsBatched(handle, trans, m, nrhs, const_cast<scalar_t**>(lu_ptr_array_data),
    lda, const_cast<int*>(pivots_data), b_ptr_array_data, lda, &info, batch_size);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
}

void lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_solve_cublas", [&]{
    apply_lu_solve_batched_cublas<scalar_t>(LU, pivots, B, trans);
  });
}

template <typename scalar_t>
static void apply_triangular_solve(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
  cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  const auto trans = to_cublas(transpose);
  cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
  cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  auto A_data = A.data_ptr<scalar_t>();
  auto B_data = B.data_ptr<scalar_t>();
  auto A_mat_stride = matrixStride(A);
  auto B_mat_stride = matrixStride(B);
  auto batch_size = batchCount(A);
  // This allows rectangular A and B to be passed when left == true
  auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
  auto n = cuda_int_cast(B.size(-1), "n");
  auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
  auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));

  auto alpha = scalar_t{1};

  for (decltype(batch_size) i = 0; i < batch_size; i++) {
    scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
    scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
    auto handle = at::cuda::getCurrentCUDABlasHandle();
    at::cuda::blas::trsm(handle, side, uplo, trans, diag, m, n, &alpha, A_working_ptr, lda, B_working_ptr, ldb);
  }
}

void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
    apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
  });
}

template <typename scalar_t>
static void apply_triangular_solve_batched(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
  cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  const auto trans = to_cublas(transpose);
  cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
  cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
  // This allows rectangular A and B to be passed when left == true
  auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
  auto n = cuda_int_cast(B.size(-1), "n");
  auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
  auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));

  auto alpha = scalar_t{1};

  // cuBLAS batched trsm requires input to be the device array of pointers to device single matrices
  Tensor A_ptr_array = get_device_pointers<scalar_t>(A);
  Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
  auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
  auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());

  auto handle = at::cuda::getCurrentCUDABlasHandle();
  at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, m, n, &alpha, A_ptr_array_data, lda, B_ptr_array_data, ldb, batch_size);
}

void triangular_solve_batched_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
  // Work around the following bug on CUDA < 12.1:
  // RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasStrsmBatched`
  // See https://github.com/pytorch/pytorch/issues/79191#issuecomment-1154222580
#if defined(CUSOLVER_VERSION) && CUSOLVER_VERSION < 12100
  constexpr auto max_batch_size = 524280;
  if (B.size(-1) > max_batch_size) {
    auto n_chunks = (B.size(-1) + max_batch_size - 1) / max_batch_size; // ceildiv
    // chunk the last dimension into n_chunks pieces of at most max_batch_size columns each
    auto splits = B.chunk(n_chunks, /*dim=*/-1);
    for (const Tensor& b : splits) {
      triangular_solve_batched_cublas(A, b, left, upper, transpose, unitriangular);
    }
    return;
  }
#endif
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
    apply_triangular_solve_batched<scalar_t>(A, B, left, upper, transpose, unitriangular);
  });
}

template <typename scalar_t>
inline void apply_gels_batched(const Tensor& A, Tensor& B, Tensor& infos) {
  auto trans = CUBLAS_OP_N;
  auto m = cuda_int_cast(A.size(-2), "m");
  auto n = cuda_int_cast(A.size(-1), "n");

  auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
  // cuBLAS from cuda10 and older doesn't work with nrhs == 0 (cuda11 works)
  // so we need to put this early return
  if (nrhs == 0) {
    return;
  }

  auto batch_size = cuda_int_cast(batchCount(B), "batch_size");
  auto lda = std::max<int>(1, m);
  auto ldb = std::max<int>(1, m);

  // cuBLAS's requirement
  TORCH_CHECK(
      m >= n,
      "torch.linalg.lstsq: only overdetermined systems (input.size(-2) >= input.size(-1)) are allowed on CUDA with cuBLAS backend.");

  // cuBLAS documentation says:
  // Matrices Aarray[i] should not overlap; otherwise, undefined behavior is expected.
  // explicitly broadcast the batch dimensions of A
  IntArrayRef A_batch_sizes(A.sizes().data(), A.dim() - 2);
  IntArrayRef B_batch_sizes(B.sizes().data(), B.dim() - 2);
  std::vector<int64_t> expand_batch_portion = at::infer_size(A_batch_sizes, B_batch_sizes);
  expand_batch_portion.insert(expand_batch_portion.end(), {A.size(-2), A.size(-1)});
  Tensor A_expanded = A.expand({expand_batch_portion});
  Tensor A_broadcasted = cloneBatchedColumnMajor(A_expanded);
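  // cloneBatchedColumnMajor materializes a separate column-major copy for every
  // broadcasted batch element, so the Aarray[i] pointers handed to cuBLAS never alias.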

  // cuBLAS batched gels requires input to be the device array of pointers to device single matrices
  Tensor A_ptr_array = get_device_pointers<scalar_t>(A_broadcasted);
  Tensor B_ptr_array = get_device_pointers<scalar_t>(B);
  auto A_ptr_array_data = reinterpret_cast<scalar_t**>(A_ptr_array.data_ptr());
  auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());

  auto infos_data = infos.data_ptr<int>();
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  int info;

  at::cuda::blas::gelsBatched<scalar_t>(
    handle, trans, m, n, nrhs,
    A_ptr_array_data, lda,
    B_ptr_array_data, ldb,
    &info,
    infos_data,
    batch_size);
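  // infos_data is the device-side array that receives the per-matrix status of each solve.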

  // negative info indicates that an argument to gelsBatched call is invalid
  TORCH_INTERNAL_ASSERT(info == 0);
}

// This is a type dispatching helper function for 'apply_gels_batched'
void gels_batched_cublas(const Tensor& a, Tensor& b, Tensor& infos) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(a.scalar_type(), "gels_batched_cublas", [&]{
    apply_gels_batched<scalar_t>(a, b, infos);
  });
}

} // namespace at::native