1 #include <cstdint>
2 #include <c10/util/Exception.h>
3 #include <c10/core/Scalar.h>
4 #include <c10/core/ScalarType.h>
5 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
6 #include <ATen/core/Tensor.h>
7 #include <ATen/core/NamedTensor.h>
8 #include <ATen/Dispatch.h>
9 #include <ATen/ExpandUtils.h>
10 #include <ATen/OpMathType.h>
11 #include <ATen/TensorUtils.h>
12 #include <ATen/cuda/CUDABlas.h>
13 #include <ATen/cuda/tunable/Tunable.h>
14 #include <ATen/cuda/tunable/TunableGemm.h>
15 #include <ATen/native/Resize.h>
16 #include <c10/util/MaybeOwned.h>
17 #include <ATen/native/cuda/RowwiseScaledMM.h>
18
19 #ifndef AT_PER_OPERATOR_HEADERS
20 #include <ATen/Functions.h>
21 #include <ATen/NativeFunctions.h>
22 #else
23 #include <ATen/ops/_addmm_activation_native.h>
24 #include <ATen/ops/_efficientzerotensor.h>
25 #include <ATen/ops/_scaled_mm_native.h>
26 #include <ATen/ops/_unsafe_view_native.h>
27 #include <ATen/ops/abs.h>
28 #include <ATen/ops/addmm_native.h>
29 #include <ATen/ops/addmv_native.h>
30 #include <ATen/ops/baddbmm_native.h>
31 #include <ATen/ops/bmm_native.h>
32 #include <ATen/ops/copy_native.h>
33 #include <ATen/ops/dot_native.h>
34 #include <ATen/ops/empty.h>
35 #include <ATen/ops/gelu.h>
36 #include <ATen/ops/max.h>
37 #include <ATen/ops/mm_native.h>
38 #include <ATen/ops/mul.h>
39 #include <ATen/ops/relu.h>
40 #include <ATen/ops/ones.h>
41 #include <ATen/ops/scalar_tensor_native.h>
42 #include <ATen/ops/vdot_native.h>
43 #endif
44
45 namespace at::native {
46
47 namespace {
48
49 // TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492
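// If `resolve_conj` is set and the tensor has its conjugate bit set, materialize the
// conjugation into an owned tensor; otherwise just borrow the input as-is.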
50 c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) {
51 if (resolve_conj && tensor.is_conj()) {
52 return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
53 } else {
54 return c10::MaybeOwned<Tensor>::borrowed(tensor);
55 }
56 }
57
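// Decides whether `tensor` can be handed to cuBLAS directly, possibly by flagging it as
// transposed, or whether it must be cloned into contiguous memory first. cuBLAS expects
// column-major operands, so a row-major (contiguous) tensor is expressed via the transpose flag.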
58 c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) {
59 if (tensor.is_non_overlapping_and_dense()) { // common case
60 transpose_tensor = tensor.is_contiguous();
61 return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor);
62 }
63 IntArrayRef tensor_strides = tensor.strides();
64 IntArrayRef tensor_sizes = tensor.sizes();
65 if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
66 transpose_tensor = false;
67 return resolve_conj_if_indicated(tensor, !transpose_result);
68 } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
69 transpose_tensor = true;
70 return resolve_conj_if_indicated(tensor, transpose_result);
71 } else {
72 transpose_tensor = true;
73 return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
74 }
75 }
76
77 c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) {
78 if (tensor.is_non_overlapping_and_dense()) { // common case
79 transpose_tensor = tensor.is_contiguous();
80 return resolve_conj_if_indicated(tensor, true);
81 }
82 IntArrayRef tensor_strides = tensor.strides();
83 IntArrayRef tensor_sizes = tensor.sizes();
84 if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
85 transpose_tensor = false;
86 return resolve_conj_if_indicated(tensor, true);
87 } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
88 transpose_tensor = true;
89 return resolve_conj_if_indicated(tensor, true);
90 } else {
91 transpose_tensor = true;
92 return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
93 }
94 }
95
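// Collects everything a cuBLAS GEMM call needs for result = mat1 @ mat2: prepared operand
// tensors, transpose characters, m/n/k and leading dimensions. If the result buffer is laid
// out row-major, the operands are swapped and transposed instead, using (A*B)^T = B^T * A^T.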
96 struct cublasCommonArgs {
97 cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) {
98 bool transpose_result, transpose_mat1, transpose_mat2;
99 result = prepare_matrix_for_cublas(c, transpose_result);
100 mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
101 matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);
102 auto mat1_sizes = mat1.sizes();
103 auto mat2_sizes = mat2.sizes();
104 if (transpose_result) {
105 transpose_mat1 = !transpose_mat1;
106 transpose_mat2 = !transpose_mat2;
107 mat1_sizes = mata->sizes();
108 mat2_sizes = matb->sizes();
109 }
110
111 m = mat1_sizes[transpose_result ? 1 : 0];
112 k = mat1_sizes[transpose_result ? 0 : 1];
113 n = mat2_sizes[transpose_result ? 0 : 1];
114 lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0);
115 ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0);
116 result_ld = result->stride(transpose_result ? 0 : 1);
117 transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n';
118 transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n';
119 }
120 char transa, transb;
121 int64_t m, n, k;
122 int64_t lda, ldb, result_ld;
123 c10::MaybeOwned<Tensor> mata, matb, result;
124 };
125 } // namespace
126
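// Batched counterpart of prepare_matrix_for_cublas: inspects the strides of a 3-D batch of
// matrices, decides whether it can be passed to bgemm directly or as a transpose, and reports
// the leading dimension. Layouts BLAS cannot express (including zero strides) are cloned contiguous.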
127 c10::MaybeOwned<Tensor> prepare_batch_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, int64_t& ld_tensor, bool transpose_result, int64_t m, int64_t n) {
128 IntArrayRef tensor_strides = tensor.strides();
129 c10::MaybeOwned<Tensor> tensor_;
130 int fast_dim = transpose_result ? 2 : 1;
131 int leading_dim = transpose_result ? 1 : 2;
132
133 if (tensor_strides[fast_dim] == 1 &&
134 (tensor_strides[leading_dim] >= std::max<int64_t>(1, m))) {
135 transpose_tensor = false;
136 tensor_ = resolve_conj_if_indicated(tensor, true);
137 ld_tensor = tensor_->strides()[leading_dim];
138 } else if ((tensor_strides[leading_dim] == 1) &&
139 (tensor_strides[fast_dim] >= std::max<int64_t>(1, n))) {
140 transpose_tensor = true;
141 tensor_ = resolve_conj_if_indicated(tensor, false);
142 ld_tensor = tensor_->strides()[fast_dim];
143 } else {
144 transpose_tensor = !transpose_result;
145 // gemm call requires leading dimension and stride parameters to be non-zero
146 bool is_stride_non_zero = tensor.strides()[1] != 0 && tensor.strides()[2] != 0;
147 if (tensor.is_contiguous() && is_stride_non_zero) {
148 tensor_ = resolve_conj_if_indicated(tensor, transpose_result);
149 } else {
150 tensor_ = c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
151 }
152 ld_tensor = tensor_->strides()[1];
153 }
154
155 return tensor_;
156 }
157
158 namespace {
159
160 enum class Activation {
161 None,
162 RELU,
163 GELU,
164 };
165
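// Maps the Activation enum used in this file onto the epilogue enum understood by the
// cuBLASLt gemm_and_bias path.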
166 cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
167 switch (a) {
168 case Activation::None:
169 return cuda::blas::GEMMAndBiasActivationEpilogue::None;
170 case Activation::RELU:
171 return cuda::blas::GEMMAndBiasActivationEpilogue::RELU;
172 case Activation::GELU:
173 return cuda::blas::GEMMAndBiasActivationEpilogue::GELU;
174 default:
175 TORCH_CHECK(false);
176 return cuda::blas::GEMMAndBiasActivationEpilogue::None;
177 }
178 }
179
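// Reads DISABLE_ADDMM_CUDA_LT (and, on ROCm, DISABLE_ADDMM_HIP_LT) to decide whether the
// cuBLASLt/hipBLASLt addmm path should be disabled.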
180 static bool getDisableAddmmCudaLt() {
181 static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
182 #ifdef USE_ROCM
183 // Allow both the CUDA and HIP env var names for ROCm builds.
184 // Also, the Lt path is currently disabled by default on ROCm; set the variable to "0" to enable it.
185 if (env_value == nullptr) {
186 env_value = std::getenv("DISABLE_ADDMM_HIP_LT");
187 }
188 if (env_value != nullptr && strcmp(env_value, "0") == 0) {
189 return false;
190 }
191 return true;
192 #else
193 if (env_value != nullptr && strcmp(env_value, "1") == 0) {
194 return true;
195 }
196 return false;
197 #endif
198 }
199
200 #ifdef USE_ROCM
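// hipBLASLt is only used on the architectures listed below (gfx90a / MI200 and gfx94x / MI300);
// any other device trips the check and throws.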
201 static bool isSupportedHipLtROCmArch(int index) {
202 hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
203 std::string device_arch = prop->gcnArchName;
204 static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
205 for (std::string arch : archs) {
206 size_t substring = device_arch.find(arch);
207 if (substring != std::string::npos) {
208 return true;
209 }
210 }
211 TORCH_CHECK(false, "Attempting to use hipBLASLt on an unsupported architecture!");
212 return false;
213 }
214 #endif
215
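// Routes a gemm+bias through the TunableOp framework. A separate tunable op instance is kept
// per (transa, transb) combination so each layout is tuned and cached independently.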
216 template <typename scalar_t>
217 static void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
218 bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
219 bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
220 at::cuda::tunable::GemmAndBiasParams<scalar_t> params;
221 params.transa = args.transa;
222 params.transb = args.transb;
223 params.m = args.m;
224 params.n = args.n;
225 params.k = args.k;
226 params.alpha = alpha.to<at::opmath_type<scalar_t>>();
227 params.a = args.mata->const_data_ptr<scalar_t>();
228 params.lda = args.lda;
229 params.b = args.matb->const_data_ptr<scalar_t>();
230 params.ldb = args.ldb;
231 params.c = args.result->data_ptr<scalar_t>();
232 params.ldc = args.result_ld;
233 params.bias = bias;
234 params.activation = activation;
235 if (transa_ && transb_) {
236 static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T> gemm{};
237 gemm(&params);
238 }
239 else if (transa_ && !transb_) {
240 static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N> gemm{};
241 gemm(&params);
242 }
243 else if (!transa_ && transb_) {
244 static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T> gemm{};
245 gemm(&params);
246 }
247 else if (!transa_ && !transb_) {
248 static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N> gemm{};
249 gemm(&params);
250 }
251 else {
252 TORCH_CHECK(false, "unreachable");
253 }
254 }
255
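// Shared implementation behind addmm, mm and _addmm_activation on CUDA. Depending on dtype,
// shapes and build configuration it either takes the cuBLASLt gemm_and_bias path (optionally
// fusing a RELU/GELU epilogue) or falls back to a plain gemm followed by a separate activation.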
256 Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) {
257 // Make sure to keep addmm_cuda below in sync with this code; it
258 // preflights a check to try to avoid actually needing to call
259 // expand().
260 TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D");
261 TORCH_CHECK(
262 mat1.dtype() == mat2.dtype(),
263 "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
264 )
265
266 TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
267 checkAllSameGPU(__func__, targs);
268
269 IntArrayRef mat1_sizes = mat1.sizes();
270 IntArrayRef mat2_sizes = mat2.sizes();
271 IntArrayRef self__sizes;
272 bool useLtInterface = false;
273 static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
274 at::ScalarType scalar_type = self.scalar_type();
275 c10::MaybeOwned<Tensor> self_;
276 if (&result != &self) {
277 #if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM)
278 // Strangely, if mat2 has only 1 row or column, we get a
279 // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
280 // The check self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
281 // ensures the Lt interface is used only when self is a bias vector.
282 // For CUDA 11.4+, cublasLtMatmul is activated.
283 // The last two conditions skip 16-bit transA / non-trans-B operands whose
284 // leading dim >> rows when they are sliced from a large tensor;
285 // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
286 if (!disable_addmm_cuda_lt) {
287 useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
288 result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
289 self.is_contiguous() && result.is_contiguous() &&
290 #ifdef USE_ROCM
291 isSupportedHipLtROCmArch(self.device().index()) &&
292 (scalar_type == at::ScalarType::Float ||
293 scalar_type == at::ScalarType::Half ||
294 scalar_type == at::ScalarType::BFloat16) &&
295 #else
296 (scalar_type == at::ScalarType::Double ||
297 scalar_type == at::ScalarType::Float ||
298 scalar_type == at::ScalarType::Half ||
299 scalar_type == at::ScalarType::BFloat16) &&
300 #endif
301 #if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 && !defined(USE_ROCM))
302 mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
303 #else
304 mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
305 mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
306 mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
307 // avoid leading dim >> rows bugs
308 ((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
309 (mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
310 (scalar_type != at::ScalarType::Half &&
311 scalar_type != at::ScalarType::BFloat16)) &&
312 ((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
313 (mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
314 (scalar_type != at::ScalarType::Half &&
315 scalar_type != at::ScalarType::BFloat16));
316 #endif
317 }
318 #endif
319 if (!useLtInterface) {
320 self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
321 }
322 self__sizes = self_->sizes();
323 } else {
324 #if defined(USE_ROCM)
325 useLtInterface = !disable_addmm_cuda_lt &&
326 result.dim() == 2 && result.is_contiguous() &&
327 isSupportedHipLtROCmArch(self.device().index()) &&
328 (scalar_type == at::ScalarType::Float ||
329 scalar_type == at::ScalarType::Half ||
330 scalar_type == at::ScalarType::BFloat16);
331 #endif
332 self_ = c10::MaybeOwned<Tensor>::borrowed(self);
333 self__sizes = self_->sizes();
334 TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
335 TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
336 TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
337 }
338
339 if (&result != &self) {
340 at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
341 if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
342 at::native::copy_(result, *self_);
343 }
344 }
345
346
347 IntArrayRef result_sizes = result.sizes();
348 if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
349 return result;
350 }
351
352 cublasCommonArgs args(mat1, mat2, result);
353
354 if (mat1.numel() == 0) {
355 // By definition, when beta==0, values in self should be ignored. nans and infs
356 // should not propagate
357 if (beta.toComplexDouble() == 0.) {
358 return result.zero_();
359 }
360 // TODO: We could squeeze some perf by calling at::cuda::mul_out here instead, to bypass the dispatcher.
361 // That requires fixing some internal build dependencies though.
362 return at::mul_out(
363 result,
364 self.expand(result.sizes()),
365 at::native::scalar_tensor(
366 beta,
367 self.scalar_type(),
368 std::nullopt /* layout */,
369 at::kCPU,
370 std::nullopt /* pin_memory */));
371 }
372
373 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
374
375 if (useLtInterface) {
376 #if defined(USE_ROCM)
377 AT_DISPATCH_FLOATING_TYPES_AND2(
378 at::ScalarType::Half,
379 at::ScalarType::BFloat16,
380 scalar_type,
381 "addmm_cuda_lt",
382 [&] {
383 auto tuning_ctx = at::cuda::tunable::getTuningContext();
384 if (tuning_ctx->IsTunableOpEnabled()) {
385 launchTunableGemmAndBias<scalar_t>(
386 args,
387 alpha,
388 (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
389 activation_to_gemm_and_blas_arg(activation));
390 }
391 else {
392 at::cuda::blas::gemm_and_bias<scalar_t>(
393 args.transa == 't',
394 args.transb == 't',
395 args.m,
396 args.n,
397 args.k,
398 alpha.to<at::opmath_type<scalar_t>>(),
399 args.mata->const_data_ptr<scalar_t>(),
400 args.lda,
401 args.matb->const_data_ptr<scalar_t>(),
402 args.ldb,
403 // This condition is needed for the mm case on ROCm's hipblasLt path:
404 // pass the bias ptr as null to avoid accuracy issues in the mm case.
405 (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
406 args.result->data_ptr<scalar_t>(),
407 args.result_ld,
408 activation_to_gemm_and_blas_arg(activation)
409 );
410 }});
411 #else
412 auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
413 #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080))
414 // The GELU epilogue is not supported (and does not compile!) prior
415 // to CUDA 11.4. We have also observed accuracy issues with the
416 // GELU epilogue in 11.4, so the GELU epilogue path is disabled
417 // for CUDA versions < 11.8.
418 if (activation == Activation::GELU)
419 activation_epilogue = cuda::blas::GEMMAndBiasActivationEpilogue::None;
420 #endif
421
422 AT_DISPATCH_FLOATING_TYPES_AND2(
423 at::ScalarType::Half,
424 at::ScalarType::BFloat16,
425 scalar_type,
426 "addmm_cuda_lt",
427 [&] {
428 auto tuning_ctx = at::cuda::tunable::getTuningContext();
429 if (tuning_ctx->IsTunableOpEnabled()) {
430 launchTunableGemmAndBias<scalar_t>(
431 args,
432 alpha,
433 self.const_data_ptr<scalar_t>(),
434 activation_epilogue);
435 }
436 else {
437 at::cuda::blas::gemm_and_bias<scalar_t>(
438 args.transa == 't',
439 args.transb == 't',
440 args.m,
441 args.n,
442 args.k,
443 alpha.to<at::opmath_type<scalar_t>>(),
444 args.mata->const_data_ptr<scalar_t>(),
445 args.lda,
446 args.matb->const_data_ptr<scalar_t>(),
447 args.ldb,
448 self.const_data_ptr<scalar_t>(),
449 args.result->data_ptr<scalar_t>(),
450 args.result_ld,
451 activation_epilogue
452 );
453 }});
454 #endif
455 } else
456 {
457 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
458 at::ScalarType::Half,
459 at::ScalarType::BFloat16,
460 scalar_type,
461 "addmm_cuda",
462 [&] {
463 using opmath_t = at::opmath_type<scalar_t>;
464 opmath_t alpha_val = alpha.to<opmath_t>();
465 opmath_t beta_val = beta.to<opmath_t>();
466 const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
467 const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
468 scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
469 at::cuda::blas::gemm<scalar_t>(
470 args.transa,
471 args.transb,
472 args.m,
473 args.n,
474 args.k,
475 alpha_val,
476 mat1_ptr,
477 args.lda,
478 mat2_ptr,
479 args.ldb,
480 beta_val,
481 result_ptr,
482 args.result_ld);
483 });
484 switch (activation) {
485 case Activation::RELU:
486 at::relu_(const_cast<Tensor&>(*args.result));
487 break;
488 case Activation::GELU:
489 at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
490 break;
491 default: break;
492 }
493 }
494
495 // Preprocessor gate here needs to match the inverse of the check
496 // gating activation_to_gemm_and_blas_arg above; here we are manually
497 // performing a post-GELU because we weren't able to use the GELU
498 // epilogue above.
499 #if !(defined(CUDA_VERSION) && CUDA_VERSION >= 11080) && !defined(USE_ROCM)
500 if (useLtInterface && activation == Activation::GELU) {
501 at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
502 }
503 #endif
504
505 if (!result.is_same(*args.result)) {
506 result.copy_(*args.result);
507 }
508 return result;
509 }
510
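// Shared implementation behind baddbmm and bmm on CUDA: prepares the result and both batches
// for cuBLAS, then dispatches to gemm for a single batch or strided-batched bgemm otherwise.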
511 const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
512 // handle pathological cases that blas may not like
513 if (result.numel() == 0) {
514 return result;
515 } else if (batch1.size(2) == 0) {
516 if (beta.to<c10::complex<double>>() == 0.0) {
517 return result.zero_();
518 } else {
519 return result.mul_(beta);
520 }
521 }
522
523 bool transpose_result = false;
524 c10::MaybeOwned<Tensor> result_;
525 IntArrayRef result_strides = result.strides();
526 IntArrayRef result_sizes = result.sizes();
527
528 if ((result_strides[1] == 1) &&
529 ((result_sizes[2] == 1) || (result_strides[2] >= std::max<int64_t>(1, result_sizes[1])))) {
530 result_ = resolve_conj_if_indicated(result, true);
531 } else if ((result_strides[2] == 1) &&
532 (result_sizes[1] == 1 || (result_strides[1] >= std::max<int64_t>(1, result_sizes[2])))) {
533 transpose_result = true;
534 result_ = resolve_conj_if_indicated(result, true);
535 } else {
536 result_ = c10::MaybeOwned<Tensor>::owned(result.transpose(1, 2).clone(at::MemoryFormat::Contiguous).transpose(1, 2));
537 }
538
539 int leading_dim = transpose_result ? 1 : 2;
540
541 int64_t m = result_sizes[transpose_result ? 2 : 1];
542 int64_t n = result_sizes[leading_dim];
543 int64_t k = (transpose_result ? batch2 : batch1).sizes()[leading_dim];
544
545 int64_t lda, ldb, ldc;
546 bool transpose_batch1, transpose_batch2;
547 auto batch1_ = prepare_batch_matrix_for_cublas(transpose_result ? batch2 : batch1, transpose_batch1, lda, transpose_result, m, k);
548 auto batch2_ = prepare_batch_matrix_for_cublas(transpose_result ? batch1 : batch2, transpose_batch2, ldb, transpose_result, k, n);
549
550 ldc = result_->strides()[leading_dim];
551 int64_t num_batches = result_->sizes()[0];
552
553 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
554
555 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] {
556 using opmath_t = at::opmath_type<scalar_t>;
557 opmath_t alpha_val = alpha.to<opmath_t>();
558 opmath_t beta_val = beta.to<opmath_t>();
559 const scalar_t* batch1_ptr = batch1_->const_data_ptr<scalar_t>();
560 const scalar_t* batch2_ptr = batch2_->const_data_ptr<scalar_t>();
561 scalar_t* result_ptr = result_->mutable_data_ptr<scalar_t>();
562 const auto transa = transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n';
563 const auto transb = transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n';
564 // If the batch size is 1, call gemm rather than bgemm
565 if (num_batches == 1) {
566 at::cuda::blas::gemm<scalar_t>(
567 transa, transb,
568 m, n, k,
569 alpha_val,
570 batch1_ptr, lda,
571 batch2_ptr, ldb,
572 beta_val,
573 result_ptr, ldc);
574 } else {
575 at::cuda::blas::bgemm<scalar_t>(
576 transa, transb,
577 m, n, k,
578 alpha_val,
579 batch1_ptr, lda, batch1_->strides()[0],
580 batch2_ptr, ldb, batch2_->strides()[0],
581 beta_val,
582 result_ptr, ldc, result_->strides()[0],
583 num_batches
584 );
585 }
586 });
587 if (!result.is_same(*result_)) {
588 result.copy_(*result_);
589 }
590 return result;
591 }
592
593 } // anonymous namespace
594
595 TORCH_IMPL_FUNC(addmm_out_cuda)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
596 addmm_out_cuda_impl(const_cast<Tensor&>(result), self, mat1, mat2, beta, alpha);
597 }
598
599 TORCH_IMPL_FUNC(addmm_activation_out_cuda)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu, const Tensor& result) {
600 addmm_out_cuda_impl(const_cast<Tensor&>(result), self, mat1, mat2, beta, alpha, use_gelu ? Activation::GELU : Activation::RELU);
601 }
602
603 TORCH_IMPL_FUNC(mm_out_cuda)(const Tensor& self, const Tensor& mat2, const Tensor& result) {
604 addmm_out_cuda_impl(const_cast<Tensor&>(result), result, self, mat2, 0, 1);
605 }
606
607 TORCH_IMPL_FUNC(baddbmm_out_cuda)(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
608 {
609 at::NoNamesGuard guard;
610 baddbmm_out_cuda_impl(result, self, batch1, batch2, beta, alpha);
611 }
612 }
613
614 TORCH_IMPL_FUNC(bmm_out_cuda)(const Tensor& batch1, const Tensor& batch2, const Tensor &result) {
615 Scalar beta(0.0);
616 Scalar alpha(1.0);
617 {
618 NoNamesGuard guard;
619 baddbmm_out_cuda_impl(result, result, batch1, batch2, beta, alpha);
620 }
621 }
622
623 namespace {
624
625 inline void dot_check(const Tensor& self, const Tensor& other) {
626 TORCH_CHECK(
627 self.dim() == 1 && other.dim() == 1,
628 "1D tensors expected, but got ",
629 self.dim(),
630 "D and ",
631 other.dim(),
632 "D tensors");
633 TORCH_CHECK(
634 self.scalar_type() == other.scalar_type(),
635 "dot : expected both vectors to have same dtype, but found ",
636 self.scalar_type(),
637 " and ",
638 other.scalar_type());
639 TORCH_CHECK(
640 self.numel() == other.numel(),
641 "inconsistent tensor size, expected tensor [",
642 self.numel(),
643 "] and src [",
644 other.numel(),
645 "] to have the same number of elements, but got ",
646 self.numel(),
647 " and ",
648 other.numel(),
649 " elements respectively");
650 TORCH_CHECK(
651 (self.numel() <= INT_MAX) && (self.stride(0) <= INT_MAX) &&
652 (other.stride(0) <= INT_MAX),
653 "dot only supports n, incx, incy with the bound [val] <= %d",
654 INT_MAX);
655 }
656
657 } // anonymous namespace
658
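// For complex inputs, conjugation is folded into the choice of dot vs. vdot so that the
// conjugate bit never has to be materialized before calling cuBLAS.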
659 Tensor dot_cuda(const Tensor& self, const Tensor& other) {
660 if (self.is_complex()) {
661 if (self.is_conj()) {
662 if (other.is_conj()) {
663 return (dot_cuda(self.conj(), other.conj())).conj();
664 } else {
665 return vdot_cuda(self.conj(), other);
666 }
667 } else if (other.is_conj()) {
668 return vdot_cuda(other.conj(), self);
669 }
670 }
671
672 at::NoNamesGuard guard;
673 dot_check(self, other);
674
675 const int n = static_cast<int>(self.numel());
676 int incx = static_cast<int>(self.stride(0));
677 int incy = static_cast<int>(other.stride(0));
678 if (n == 1) {
679 incx = 1;
680 incy = 1;
681 }
682
683 if (self._is_zerotensor() || other._is_zerotensor()) {
684 return at::_efficientzerotensor({}, self.options());
685 }
686
687 return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
688 ScalarType::Half, ScalarType::BFloat16,
689 self.scalar_type(), "dot",
690 [&] {
691 Tensor result = at::empty({}, self.options());
692
693 auto handle = at::cuda::getCurrentCUDABlasHandle();
694 at::cuda::blas::PointerModeGuard pointerModeGuard(handle, CUBLAS_POINTER_MODE_DEVICE);
695 at::cuda::blas::dot<scalar_t>(
696 handle,
697 n,
698 self.const_data_ptr<scalar_t>(),
699 incx,
700 other.const_data_ptr<scalar_t>(),
701 incy,
702 result.mutable_data_ptr<scalar_t>());
703
704 return result;
705 });
706 }
707
708 Tensor vdot_cuda(const Tensor& self, const Tensor& other) {
709 if (!self.is_complex()) {
710 return dot_cuda(self, other);
711 }
712
713 if (self.is_conj()) {
714 if (other.is_conj()) {
715 return vdot_cuda(other.conj(), self.conj());
716 } else {
717 return dot_cuda(self.conj(), other);
718 }
719 } else if (other.is_conj()) {
720 return (dot_cuda(self, other.conj())).conj();
721 }
722
723 at::NoNamesGuard guard;
724 dot_check(self, other);
725
726 if (self._is_zerotensor() || other._is_zerotensor()) {
727 return at::_efficientzerotensor({}, self.options());
728 }
729
730 const int n = static_cast<int>(self.numel());
731 int incx = static_cast<int>(self.stride(0));
732 int incy = static_cast<int>(other.stride(0));
733 if (n == 1) {
734 incx = 1;
735 incy = 1;
736 }
737
738 return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] {
739 Tensor result = at::empty({}, self.options());
740
741 auto handle = at::cuda::getCurrentCUDABlasHandle();
742 at::cuda::blas::PointerModeGuard pointerModeGuard(
743 handle, CUBLAS_POINTER_MODE_DEVICE);
744 at::cuda::blas::vdot<scalar_t>(
745 handle,
746 n,
747 self.const_data_ptr<scalar_t>(),
748 incx,
749 other.const_data_ptr<scalar_t>(),
750 incy,
751 result.mutable_data_ptr<scalar_t>());
752
753 return result;
754 });
755 }
756
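// Matrix-vector multiply. cuBLAS gemv only understands column-major data, so a column-major
// `mat` is passed with 'n', a row-major `mat` with 't' and swapped dimensions, and anything
// else is copied to contiguous memory first.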
757 TORCH_IMPL_FUNC(addmv_out_cuda)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) {
758 c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
759 auto betaval = beta_.toComplexDouble();
760 if (mat.numel() == 0) {
761 // shortcut for an empty matrix
762 // By definition, when beta==0, values in self should be ignored. nans and infs
763 // should not propagate
764 if (betaval == 0.0) {
765 result.zero_();
766 } else {
767 at::mul_out(
768 const_cast<Tensor&>(result),
769 self,
770 at::native::scalar_tensor(
771 beta_, self.scalar_type(), std::nullopt /* layout */, at::kCPU, std::nullopt /* pin_memory */));
772 }
773 } else {
774 if (!result.is_same(*self_) && betaval != 0.0) { //if beta is 0, result contents will be zeroed later
775 at::native::copy_(const_cast<Tensor&>(result), *self_);
776 }
777 if (result.numel() != 0) {
778 auto r_stride = result.stride(0);
779 auto vec_stride = vec.stride(0);
780
781 // Check for contiguity of `vec` and update `vec_stride` accordingly
782 const auto vec_contiguous = vec_stride == 0 ? vec.contiguous() : vec;
783 // A vector can be contiguous and have a stride of zero if it is of length 1
784 vec_stride = std::max<int64_t>(vec_contiguous.stride(0), 1LL);
785
786 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mat.scalar_type(), "addmv_impl_cuda", [&] {
787 auto beta = beta_.to<scalar_t>();
788 auto alpha = alpha_.to<scalar_t>();
789 if (mat.stride(0) == 1 && mat.stride(1) >= std::max<int64_t>(1, mat.size(0))) {
790 at::cuda::blas::gemv<scalar_t>('n',
791 mat.size(0), mat.size(1), alpha, mat.const_data_ptr<scalar_t>(), mat.stride(1), vec_contiguous.const_data_ptr<scalar_t>(),
792 vec_stride, beta, result.mutable_data_ptr<scalar_t>(), r_stride);
793 }
794 else if (mat.stride(1) == 1 && mat.stride(0) >= std::max<int64_t>(1, mat.size(1))) {
795 at::cuda::blas::gemv<scalar_t>('t',
796 mat.size(1), mat.size(0), alpha, mat.const_data_ptr<scalar_t>(), mat.stride(0),
797 vec_contiguous.const_data_ptr<scalar_t>(), vec_stride, beta, result.mutable_data_ptr<scalar_t>(), r_stride);
798 }
799 else {
800 Tensor cmat = mat.contiguous();
801 at::cuda::blas::gemv<scalar_t>('t',
802 mat.size(1), mat.size(0), alpha, cmat.const_data_ptr<scalar_t>(), cmat.stride(0),
803 vec_contiguous.const_data_ptr<scalar_t>(), vec_stride, beta, result.mutable_data_ptr<scalar_t>(), r_stride);
804 }
805 });
806 }
807 }
808 }
809
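// int8 x int8 -> int32 matrix multiply. The shape and alignment checks below reflect what the
// cuBLAS int8 gemm path accepts; `result` must already be a contiguous kInt tensor of the right shape.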
810 Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result) {
811 // NOTE: cuBLAS is currently broken for some combinations of transposed inputs.
812 TORCH_CHECK(self.dim() == 2, "Expected self to be of dimension 2 but got ", self.dim());
813 TORCH_CHECK(mat2.dim() == 2, "Expected mat2 to be of dimension 2 but got ", mat2.dim());
814 TORCH_CHECK(self.size(0) > 16, "self.size(0) needs to be greater than 16, but got ", self.size(0));
815 TORCH_CHECK(self.size(1) > 0 && self.size(1) % 8 == 0, "self.size(1) needs to be greater than 0 and a multiple of 8, but got ", self.size(1));
816 TORCH_CHECK(self.size(1) == mat2.size(0), "self.size(1) needs to match mat2.size(0) but got ", self.size(1), " and ", mat2.size(0));
817 TORCH_CHECK(mat2.size(1) > 0 && mat2.size(1) % 8 == 0, "mat2.size(1) needs to be greater than 0 and a multiple of 8, but got ", mat2.size(1));
818
819 TORCH_CHECK(result.dtype() == at::kInt, "Expected result dtype to be of type kInt but got ", result.dtype());
820 TORCH_CHECK(result.size(0) == self.size(0), "Expected result.size(0) to be ", self.size(0), " but got ", result.size(0));
821 TORCH_CHECK(result.size(1) == mat2.size(1), "Expected result.size(1) to be ", mat2.size(1), " but got ", result.size(1));
822
823 TORCH_CHECK(result.dim() == 2, "Expected result to be of dimension 2 but got ", result.dim());
824
825 TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
826
827 #if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || defined(USE_ROCM)
828 cublasCommonArgs args(self, mat2, result);
829
830 at::cuda::blas::int8_gemm(
831 args.transa == 't',
832 args.transb == 't',
833 args.m,
834 args.n,
835 args.k,
836 args.mata->data_ptr<int8_t>(),
837 args.lda,
838 args.matb->data_ptr<int8_t>(),
839 args.ldb,
840 args.result->data_ptr<int32_t>(),
841 args.result_ld);
842
843 if (!result.is_same(*args.result)) {
844 result.copy_(*args.result);
845 }
846 #else
847 #if !defined(USE_ROCM) && defined(CUDA_VERSION)
848 TORCH_CHECK(false, "_int_mm_out_cuda not compiled for CUDA ", CUDA_VERSION);
849 #else
850 TORCH_CHECK(false, "_int_mm_out_cuda not compiled for this platform.");
851 #endif
852 #endif
853
854 return result;
855 }
856
857 Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
858 Tensor result = at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
859 return _int_mm_out_cuda(self, mat2, result);
860 }
861
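// Gate for torch._scaled_mm: requires compute capability 9.0+ (Hopper) or 8.9 (Ada) on CUDA,
// or an MI300-class (gfx94x) GPU on ROCm.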
862 static bool _scaled_mm_allowed_device() {
863 auto dprops = at::cuda::getCurrentDeviceProperties();
864 #ifdef USE_ROCM
865 std::string device_arch = dprops->gcnArchName;
866 static const std::vector<std::string> archs = {"gfx940", "gfx941", "gfx942"};
867 for (std::string arch : archs) {
868 size_t substring = device_arch.find(arch);
869 if (substring != std::string::npos) {
870 return true;
871 }
872 }
873 return false;
874 #else
875 return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
876 #endif
877 }
878
879 namespace{
880
881 enum class ScalingType {
882 TensorWise,
883 RowWise,
884 Error
885 };
886 /*
887 * Scaling Type Determination:
888 * ---------------------------
889 * Conditions and corresponding Scaling Types:
890 *
891 * - If scale_a.numel() == 1 && scale_b.numel() == 1:
892 * - Returns TensorWise.
893 *
894 * - Else if scale_a is a (dim_m, 1) tensor and scale_b is a (1, dim_n) tensor:
895 * - Returns RowWise.
896 *
897 * - Otherwise:
898 * - Returns Error.
899 */
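// Illustrative example: for a 64x128 @ 128x32 matmul, TensorWise scaling passes scale_a and
// scale_b with numel() == 1, while RowWise scaling passes a contiguous fp32 scale_a of shape
// (64, 1) and a contiguous fp32 scale_b of shape (1, 32).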
900
901 // Validates the scale tensors passed to scaled_mm
902 // and returns the type of scaling, i.e. which kernel to use
903 ScalingType get_scaling_type(
904 const at::Tensor& scale_a,
905 const at::Tensor& scale_b,
906 int64_t dim_m,
907 int64_t dim_n) {
908 // Both Per-Tensor and Row-wise scaling expect fp32 tensors
909 TORCH_CHECK(
910 scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
911 "Both scale_a and scale_b must be float (fp32) tensors.");
912
913 // Check the scalar (single-element) scale case for per-tensor scaling
914 if (scale_a.numel() == 1 && scale_b.numel() == 1) {
915 return ScalingType::TensorWise;
916 }
917
918 // For non-TensorWise scaling, enforce 2D input tensors
919 TORCH_CHECK(
920 scale_a.dim() == 2 && scale_b.dim() == 2,
921 "For non-TensorWise scaling, scale tensors must be 2-dimensional, "
922 "but got scale_a.dim()=",
923 scale_a.dim(),
924 " and scale_b.dim()=",
925 scale_b.dim());
926
927 // Check for RowWise scaling
928 if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 &&
929 scale_b.size(0) == 1 && scale_b.size(1) == dim_n) {
930 #if !defined(USE_ROCM) && !defined(_MSC_VER) || \
931 (defined(USE_ROCM) && ROCM_VERSION >= 60000)
932 TORCH_CHECK(
933 scale_a.is_contiguous() && scale_b.is_contiguous(),
934 "Both scale_a and scale_b must be contiguous for RowWise scaling.");
935 return ScalingType::RowWise;
936 #else
937 TORCH_CHECK(false, "Per-row scaling is not supported for this platform!");
938 return ScalingType::Error;
939 #endif
940 }
941
942 // If we reach here, the input doesn't match any valid scaling type
943 TORCH_CHECK(
944 false,
945 "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. "
946 "For RowWise scaling, scale_a should be (",
947 dim_m,
948 ", 1) and scale_b should be (1, ",
949 dim_n,
950 "). "
951 "Got scale_a.size()=(",
952 scale_a.size(0),
953 ", ",
954 scale_a.size(1),
955 ") and ",
956 "scale_b.size()=(",
957 scale_b.size(0),
958 ", ",
959 scale_b.size(1),
960 ")");
961
962 return ScalingType::Error;
963 }
964
965 } // namespace
966
967 // Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
968 // Scales are only applicable when matrices are of Float8 type and are assumed to be equal to 1.0 by default.
969 // If the output matrix type is a 16- or 32-bit type, neither scale_result is applied nor amax is computed.
970 // Known limitations:
971 // - Only works if mat1 is row-major and mat2 is column-major
972 // - Only works if matrices sizes are divisible by 32
973 // - For row-wise scaling, scale_a should have shape (mat1.size(0), 1)
974 //   and scale_b should have shape (1, mat2.size(1))
975 // Arguments:
976 // - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
977 // - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
978 // - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
979 // - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type
980 // - `scale_a`: a scalar or (mat1.size(0), 1) tensor with the inverse scale(s) of `mat1`, only needed if `mat1` is a float8 type
981 // - `scale_b`: a scalar or (1, mat2.size(1)) tensor with the inverse scale(s) of `mat2`, only needed if `mat2` is a float8 type
982 // - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
983 // - `use_fast_accum`: if true, enables fast float8 accumulation
984 // - `out`: a reference to the output tensor
985
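// For example (illustrative): a contiguous Float8 tensor `a` of shape (M, K) together with
// `b = w.t()`, where `w` is a contiguous Float8 tensor of shape (N, K), satisfies the
// row-major / column-major requirement noted above.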
986 Tensor&
987 _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
988 const Tensor& scale_a,
989 const Tensor& scale_b,
990 const std::optional<at::Tensor>& bias,
991 const std::optional<at::Tensor>& scale_result,
992 std::optional<c10::ScalarType> out_dtype,
993 bool use_fast_accum,
994 Tensor& out) {
995 // Check sizes
996 bool allowed_device = _scaled_mm_allowed_device();
997 TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
998 TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
999 TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
1000 TORCH_CHECK(
1001 mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
1002 mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
1003
1004 // Check what type of scaling we are doing based on inputs
1005 ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
1006 TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
1007
1008 TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
1009 "scale_result must be a float scalar");
1010 TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
1011 " but got ", bias->numel());
1012 TORCH_CHECK(
1013 mat1.sizes()[1] % 16 == 0,
1014 "Expected trailing dimension of mat1 to be divisible by 16 ",
1015 "but got mat1 shape: (",
1016 mat1.sizes()[0],
1017 "x",
1018 mat1.sizes()[1],
1019 ".");
1020 TORCH_CHECK(mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0, "mat2 shape (", mat2.sizes()[0], "x",
1021 mat2.sizes()[1], ") must be divisible by 16");
1022 // Check types
1023 TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
1024 TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
1025 TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
1026 // Type restrictions imposed by CuBLASLt as of CUDA-12.1
1027 TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
1028 "Multiplication of two Float8_e5m2 matrices is not supported");
1029 if (bias) {
1030 TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
1031 TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
1032 "Bias must be either Half or BFloat16, but got ", bias->scalar_type());
1033 TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) ||
1034 bias->scalar_type() == ScalarType::BFloat16,
1035 "Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
1036 TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half,
1037 "Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
1038 }
1039 {
1040 auto bias_ = bias.value_or(Tensor());
1041 auto scale_result_ = scale_result.value_or(Tensor());
1042
1043 TensorArg targs[]{{out, "out", 0}, {mat1, "mat1", 1}, {mat2, "mat2", 2},
1044 {bias_, "bias", 3}, {scale_a, "scale_a", 4}, {scale_b, "scale_b", 5},
1045 {scale_result_, "scale_result", 6}};
1046 checkAllSameGPU(__func__, targs);
1047 }
1048 // Validation checks have passed; resize the output to its actual size.
1049 IntArrayRef mat1_sizes = mat1.sizes();
1050 IntArrayRef mat2_sizes = mat2.sizes();
1051 at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
1052
1053 // We are doing row-wise scaling
1054 if (scaling_choice == ScalingType::RowWise) {
1055 TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
1056 at::cuda::detail::f8f8bf16_rowwise(
1057 mat1,
1058 mat2,
1059 scale_a,
1060 scale_b,
1061 bias,
1062 use_fast_accum,
1063 out);
1064 return out;
1065 }
1066
1067 cublasCommonArgs args(mat1, mat2, out);
1068 const auto out_dtype_ = args.result->scalar_type();
1069 TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
1070
1071 // Some scaled_gemms require an amax to populate; let's create one here.
1072 Tensor amax = at::empty({0}, mat1.options().dtype(ScalarType::Float));
1073
1074 #ifdef USE_ROCM
1075 auto tuning_ctx = at::cuda::tunable::getTuningContext();
1076 if (tuning_ctx->IsTunableOpEnabled()) {
1077 #define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
1078 if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
1079 if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
1080 static at::cuda::tunable::ScaledGemmTunableOp< \
1081 at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
1082 BLASOP_A, BLASOP_B> scaledgemm{}; \
1083 scaledgemm(&params); \
1084 } \
1085 else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
1086 static at::cuda::tunable::ScaledGemmTunableOp< \
1087 at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
1088 BLASOP_A, BLASOP_B> scaledgemm{}; \
1089 scaledgemm(&params); \
1090 } \
1091 } \
1092 else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
1093 if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
1094 static at::cuda::tunable::ScaledGemmTunableOp< \
1095 at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
1096 BLASOP_A, BLASOP_B> scaledgemm{}; \
1097 scaledgemm(&params); \
1098 } \
1099 else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
1100 static at::cuda::tunable::ScaledGemmTunableOp< \
1101 at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
1102 BLASOP_A, BLASOP_B> scaledgemm{}; \
1103 scaledgemm(&params); \
1104 } \
1105 }
1106 AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
1107 bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
1108 bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
1109 at::cuda::tunable::ScaledGemmParams<scalar_t> params;
1110 params.transa = args.transa;
1111 params.transb = args.transb;
1112 params.m = args.m;
1113 params.n = args.n;
1114 params.k = args.k;
1115 params.a = args.mata->data_ptr();
1116 params.a_scale_ptr = scale_a.data_ptr();
1117 params.lda = args.lda;
1118 params.a_dtype = args.mata->scalar_type();
1119 params.b = args.matb->data_ptr();
1120 params.b_scale_ptr = scale_b.data_ptr();
1121 params.ldb = args.ldb;
1122 params.b_dtype = args.matb->scalar_type();
1123 params.bias_ptr = bias ? bias->data_ptr(): nullptr;
1124 params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
1125 params.c = args.result->data_ptr();
1126 params.c_scale_ptr = scale_result ? scale_result->data_ptr() : nullptr;
1127 params.ldc = args.result_ld;
1128 params.c_dtype = out_dtype_;
1129 params.amax_ptr = amax.data_ptr();
1130 params.use_fast_accum = use_fast_accum;
1131 if (transa_ && transb_) {
1132 TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
1133 }
1134 else if (transa_ && !transb_) {
1135 TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
1136 }
1137 else if (!transa_ && transb_) {
1138 TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
1139 }
1140 else if (!transa_ && !transb_) {
1141 TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
1142 }
1143 else {
1144 TORCH_CHECK(false, "unreachable");
1145 }
1146 }),
1147 kHalf, kBFloat16, kFloat8_e4m3fnuz, kFloat8_e5m2fnuz, AT_EXPAND(AT_FLOATING_TYPES));
1148 #undef TUNABLE_DISPATCH
1149 }
1150 else
1151 #endif
1152 {
1153 #if defined(USE_ROCM) && ROCM_VERSION >= 60200
1154 // hipBLASLt requires scaleD to be set to something in order to use AMAX
1155 auto dummy_options = TensorOptions().dtype(kFloat).device(kCUDA);
1156 auto dummy_scale = at::ones(1, dummy_options);
1157 #endif
1158 at::cuda::blas::scaled_gemm(
1159 args.transa,
1160 args.transb,
1161 args.m,
1162 args.n,
1163 args.k,
1164 args.mata->data_ptr(),
1165 scale_a.data_ptr(),
1166 args.lda,
1167 args.mata->scalar_type(),
1168 args.matb->data_ptr(),
1169 scale_b.data_ptr(),
1170 args.ldb,
1171 args.matb->scalar_type(),
1172 bias ? bias->data_ptr(): nullptr,
1173 bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
1174 args.result->data_ptr(),
1175 #if defined(USE_ROCM) && ROCM_VERSION >= 60200
1176 scale_result ? scale_result->data_ptr() : dummy_scale.data_ptr(),
1177 #else
1178 scale_result ? scale_result->data_ptr() : nullptr,
1179 #endif
1180 args.result_ld,
1181 out_dtype_,
1182 amax.data_ptr(),
1183 use_fast_accum);
1184 }
1185
1186 return out;
1187 }
1188
1189 Tensor
1190 _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
1191 const Tensor& scale_a,
1192 const Tensor& scale_b,
1193 const std::optional<at::Tensor>& bias,
1194 const std::optional<at::Tensor>& scale_result,
1195 std::optional<c10::ScalarType> out_dtype,
1196 bool use_fast_accum) {
1197 const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
1198 Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
1199 return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
1200 }
1201
1202 } // namespace at::native
1203