xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Blas.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <cstdint>
2 #include <c10/util/Exception.h>
3 #include <c10/core/Scalar.h>
4 #include <c10/core/ScalarType.h>
5 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
6 #include <ATen/core/Tensor.h>
7 #include <ATen/core/NamedTensor.h>
8 #include <ATen/Dispatch.h>
9 #include <ATen/ExpandUtils.h>
10 #include <ATen/OpMathType.h>
11 #include <ATen/TensorUtils.h>
12 #include <ATen/cuda/CUDABlas.h>
13 #include <ATen/cuda/tunable/Tunable.h>
14 #include <ATen/cuda/tunable/TunableGemm.h>
15 #include <ATen/native/Resize.h>
16 #include <c10/util/MaybeOwned.h>
17 #include <ATen/native/cuda/RowwiseScaledMM.h>
18 
19 #ifndef AT_PER_OPERATOR_HEADERS
20 #include <ATen/Functions.h>
21 #include <ATen/NativeFunctions.h>
22 #else
23 #include <ATen/ops/_addmm_activation_native.h>
24 #include <ATen/ops/_efficientzerotensor.h>
25 #include <ATen/ops/_scaled_mm_native.h>
26 #include <ATen/ops/_unsafe_view_native.h>
27 #include <ATen/ops/abs.h>
28 #include <ATen/ops/addmm_native.h>
29 #include <ATen/ops/addmv_native.h>
30 #include <ATen/ops/baddbmm_native.h>
31 #include <ATen/ops/bmm_native.h>
32 #include <ATen/ops/copy_native.h>
33 #include <ATen/ops/dot_native.h>
34 #include <ATen/ops/empty.h>
35 #include <ATen/ops/gelu.h>
36 #include <ATen/ops/max.h>
37 #include <ATen/ops/mm_native.h>
38 #include <ATen/ops/mul.h>
39 #include <ATen/ops/relu.h>
40 #include <ATen/ops/ones.h>
41 #include <ATen/ops/scalar_tensor_native.h>
42 #include <ATen/ops/vdot_native.h>
43 #endif
44 
45 namespace at::native {
46 
47 namespace {
48 
49 // TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492
50 c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) {
51   if (resolve_conj && tensor.is_conj()) {
52     return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
53   } else {
54     return c10::MaybeOwned<Tensor>::borrowed(tensor);
55   }
56 }
57 
58 c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) {
59   if (tensor.is_non_overlapping_and_dense()) { // common case
60       transpose_tensor = tensor.is_contiguous();
61       return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor);
62   }
63   IntArrayRef tensor_strides = tensor.strides();
64   IntArrayRef tensor_sizes = tensor.sizes();
65   if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
66     transpose_tensor = false;
67     return resolve_conj_if_indicated(tensor, !transpose_result);
68   } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
69     transpose_tensor = true;
70     return resolve_conj_if_indicated(tensor, transpose_result);
71   } else {
72     transpose_tensor = true;
73     return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
74   }
75 }
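// Illustrative example (hypothetical tensors) of how the stride checks above classify
// a 2-D operand for cuBLAS's column-major convention:
//
//   at::Tensor a = at::rand({4, 8}, at::kCUDA);           // row-major, strides {8, 1}
//   at::Tensor b = a.t();                                  // column-major view, strides {1, 8}
//   at::Tensor c = a.slice(/*dim=*/1, 0, 8, /*step=*/2);   // strides {8, 2}: no unit-stride dim
//
// `a` and `b` are borrowed as-is, with `transpose_tensor` reporting whether a cuBLAS
// transpose flag is needed; `c` falls through to the final branch and is cloned
// contiguously before the gemm.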
76 
77 c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) {
78   if (tensor.is_non_overlapping_and_dense()) { // common case
79       transpose_tensor = tensor.is_contiguous();
80       return resolve_conj_if_indicated(tensor, true);
81   }
82   IntArrayRef tensor_strides = tensor.strides();
83   IntArrayRef tensor_sizes = tensor.sizes();
84   if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
85     transpose_tensor = false;
86     return resolve_conj_if_indicated(tensor, true);
87   } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
88     transpose_tensor = true;
89     return resolve_conj_if_indicated(tensor, true);
90   } else {
91     transpose_tensor = true;
92     return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
93   }
94 }
95 
96 struct cublasCommonArgs {
97   cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) {
98     bool transpose_result, transpose_mat1, transpose_mat2;
99     result = prepare_matrix_for_cublas(c, transpose_result);
100     mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
101     matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);
102     auto mat1_sizes = mat1.sizes();
103     auto mat2_sizes = mat2.sizes();
104     if (transpose_result) {
105       transpose_mat1 = !transpose_mat1;
106       transpose_mat2 = !transpose_mat2;
107       mat1_sizes = mata->sizes();
108       mat2_sizes = matb->sizes();
109     }
110 
111     m = mat1_sizes[transpose_result ? 1 : 0];
112     k = mat1_sizes[transpose_result ? 0 : 1];
113     n = mat2_sizes[transpose_result ? 0 : 1];
114     lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0);
115     ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0);
116     result_ld = result->stride(transpose_result ? 0 : 1);
117     transa = transpose_mat1 ?  mata->is_conj() ? 'c' : 't' : 'n';
118     transb = transpose_mat2 ?  matb->is_conj() ? 'c' : 't' : 'n';
119   }
120   char transa, transb;
121   int64_t m, n, k;
122   int64_t lda, ldb, result_ld;
123   c10::MaybeOwned<Tensor> mata, matb, result;
124 };
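// Worked example (hypothetical shapes) of the operand swap performed by the constructor
// above: with row-major mat1 (2x3), mat2 (3x4) and c (2x4), prepare_matrix_for_cublas(c, ...)
// reports transpose_result == true, so mata/matb are swapped and cuBLAS is asked for
// C^T = B^T * A^T in column-major terms:
//   transa = transb = 'n', m = 4, n = 2, k = 3, lda = 4, ldb = 3, result_ld = 4.
// No copy of the output is needed because the row-major buffers already hold the
// transposed column-major data; 'c' replaces 't' whenever an operand is conjugated.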
125 } // namespace
126 
127 c10::MaybeOwned<Tensor> prepare_batch_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, int64_t& ld_tensor, bool transpose_result, int64_t m, int64_t n) {
128   IntArrayRef tensor_strides = tensor.strides();
129   c10::MaybeOwned<Tensor> tensor_;
130   int fast_dim = transpose_result ? 2 : 1;
131   int leading_dim = transpose_result ? 1 : 2;
132 
133   if (tensor_strides[fast_dim] == 1 &&
134     (tensor_strides[leading_dim] >= std::max<int64_t>(1, m))) {
135     transpose_tensor = false;
136     tensor_ = resolve_conj_if_indicated(tensor, true);
137     ld_tensor = tensor_->strides()[leading_dim];
138   } else if ((tensor_strides[leading_dim] == 1) &&
139     (tensor_strides[fast_dim] >= std::max<int64_t>(1, n))) {
140     transpose_tensor = true;
141     tensor_ = resolve_conj_if_indicated(tensor, false);
142     ld_tensor = tensor_->strides()[fast_dim];
143   } else {
144     transpose_tensor = !transpose_result;
145     // gemm call requires leading dimension and stride parameters to be non-zero
146     bool is_stride_non_zero = tensor.strides()[1] != 0 && tensor.strides()[2] != 0;
147     if (tensor.is_contiguous() && is_stride_non_zero) {
148       tensor_ = resolve_conj_if_indicated(tensor, transpose_result);
149     } else {
150       tensor_ = c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
151     }
152     ld_tensor = tensor_->strides()[1];
153   }
154 
155   return tensor_;
156 }
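// The batched variant above mirrors the 2-D helper, but inspects the strides of dims 1
// and 2 of the (batch, rows, cols) tensor and additionally guards against zero strides,
// which the gemm call cannot accept; the leading dimension returned through ld_tensor
// always comes from the prepared (possibly cloned) tensor.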
157 
158 namespace {
159 
160 enum class Activation {
161   None,
162   RELU,
163   GELU,
164 };
165 
166 cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
167   switch (a) {
168     case Activation::None:
169       return cuda::blas::GEMMAndBiasActivationEpilogue::None;
170     case Activation::RELU:
171       return cuda::blas::GEMMAndBiasActivationEpilogue::RELU;
172     case Activation::GELU:
173       return cuda::blas::GEMMAndBiasActivationEpilogue::GELU;
174     default:
175       TORCH_CHECK(false);
176       return cuda::blas::GEMMAndBiasActivationEpilogue::None;
177   }
178 }
179 
180 static bool getDisableAddmmCudaLt() {
181     static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
182 #ifdef USE_ROCM
183     // allow both CUDA and HIP env var names for ROCm builds
184     // also, the current default for ROCm builds is to disable the Lt path
185     if (env_value == nullptr) {
186         env_value = std::getenv("DISABLE_ADDMM_HIP_LT");
187     }
188     if (env_value != nullptr && strcmp(env_value, "0") == 0) {
189       return false;
190     }
191     return true;
192 #else
193     if (env_value != nullptr && strcmp(env_value, "1") == 0) {
194       return true;
195     }
196     return false;
197 #endif
198 }
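// Usage note for the flag read above (the value is captured once into a static in
// addmm_out_cuda_impl, so it must be set before the first matmul on the GPU):
//   DISABLE_ADDMM_CUDA_LT=1   -> force the plain cuBLAS gemm path on CUDA builds
//   DISABLE_ADDMM_HIP_LT=0    -> opt back in to hipblasLt on ROCm builds,
//                                where the Lt path is disabled by default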
199 
200 #ifdef USE_ROCM
201 static bool isSupportedHipLtROCmArch(int index) {
202     hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
203     std::string device_arch = prop->gcnArchName;
204     static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
205     for (std::string arch : archs) {
206         size_t substring = device_arch.find(arch);
207         if (substring != std::string::npos) {
208             return true;
209         }
210     }
211     TORCH_CHECK(false, "Attempting to use hipBLASLt on an unsupported architecture!");
212     return false;
213 }
214 #endif
215 
216 template <typename scalar_t>
217 static void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
218   bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
219   bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
220   at::cuda::tunable::GemmAndBiasParams<scalar_t> params;
221   params.transa = args.transa;
222   params.transb = args.transb;
223   params.m = args.m;
224   params.n = args.n;
225   params.k = args.k;
226   params.alpha = alpha.to<at::opmath_type<scalar_t>>();
227   params.a = args.mata->const_data_ptr<scalar_t>();
228   params.lda = args.lda;
229   params.b = args.matb->const_data_ptr<scalar_t>();
230   params.ldb = args.ldb;
231   params.c = args.result->data_ptr<scalar_t>();
232   params.ldc = args.result_ld;
233   params.bias = bias;
234   params.activation = activation;
235   if (transa_ && transb_) {
236     static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T> gemm{};
237     gemm(&params);
238   }
239   else if (transa_ && !transb_) {
240     static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N> gemm{};
241     gemm(&params);
242   }
243   else if (!transa_ && transb_) {
244     static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T> gemm{};
245     gemm(&params);
246   }
247   else if (!transa_ && !transb_) {
248     static at::cuda::tunable::GemmAndBiasTunableOp<scalar_t, at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N> gemm{};
249     gemm(&params);
250   }
251   else {
252     TORCH_CHECK(false, "unreachable");
253   }
254 }
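// Note on the dispatch above: transa/transb are collapsed to two booleans that select
// one of four statically cached TunableOp instances, one per (opA, opB) pair, so tuning
// results are reused across calls. Whether this path runs at all is decided by
// tuning_ctx->IsTunableOpEnabled() at the call sites below (typically controlled via the
// PYTORCH_TUNABLEOP_ENABLED environment variable; see ATen/cuda/tunable/Tunable.h).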
255 
256 Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) {
257   // Make sure to keep addmm_cuda below in sync with this code; it
258   // preflights a check to try to avoid actually needing to call
259   // expand().
260   TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D");
261   TORCH_CHECK(
262     mat1.dtype() == mat2.dtype(),
263     "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
264   )
265 
266   TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
267   checkAllSameGPU(__func__, targs);
268 
269   IntArrayRef mat1_sizes = mat1.sizes();
270   IntArrayRef mat2_sizes = mat2.sizes();
271   IntArrayRef self__sizes;
272   bool useLtInterface = false;
273   static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
274   at::ScalarType scalar_type = self.scalar_type();
275   c10::MaybeOwned<Tensor> self_;
276   if (&result != &self) {
277 #if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM)
278     // Strangely, if mat2 has only 1 row or column, we get a
279     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
280     // The condition self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
281     // is there to use the Lt interface only when self is a bias.
282     // For CUDA 11.4, cublasLtMatmul is activated.
283     // The last two conditions are to skip 16b transA and non-trans-B having a
284     // leading dim >> rows when they are sliced from a large tensor;
285     // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
286     if (!disable_addmm_cuda_lt) {
287       useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
288           result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
289           self.is_contiguous() && result.is_contiguous() &&
290 #ifdef USE_ROCM
291           isSupportedHipLtROCmArch(self.device().index()) &&
292           (scalar_type == at::ScalarType::Float ||
293            scalar_type == at::ScalarType::Half ||
294            scalar_type == at::ScalarType::BFloat16) &&
295 #else
296           (scalar_type == at::ScalarType::Double ||
297            scalar_type == at::ScalarType::Float ||
298            scalar_type == at::ScalarType::Half ||
299            scalar_type == at::ScalarType::BFloat16) &&
300 #endif
301 #if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 && !defined(USE_ROCM))
302           mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
303 #else
304           mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
305           mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
306           mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
307           // avoid leading dim >> rows bugs
308           ((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
309            (mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
310            (scalar_type != at::ScalarType::Half &&
311             scalar_type != at::ScalarType::BFloat16)) &&
312           ((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
313            (mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
314            (scalar_type != at::ScalarType::Half &&
315             scalar_type != at::ScalarType::BFloat16));
316 #endif
317     }
318 #endif
319     if (!useLtInterface) {
320       self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
321     }
322     self__sizes = self_->sizes();
323   } else {
324 #if defined(USE_ROCM)
325     useLtInterface = !disable_addmm_cuda_lt &&
326         result.dim() == 2 && result.is_contiguous() &&
327         isSupportedHipLtROCmArch(self.device().index()) &&
328         (scalar_type == at::ScalarType::Float ||
329           scalar_type == at::ScalarType::Half ||
330           scalar_type == at::ScalarType::BFloat16);
331 #endif
332     self_ = c10::MaybeOwned<Tensor>::borrowed(self);
333     self__sizes = self_->sizes();
334     TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
335     TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
336     TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
337   }
338 
339   if (&result != &self) {
340     at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
341     if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
342       at::native::copy_(result, *self_);
343     }
344   }
345 
346 
347   IntArrayRef result_sizes = result.sizes();
348   if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
349     return result;
350   }
351 
352   cublasCommonArgs args(mat1, mat2, result);
353 
354   if (mat1.numel() == 0) {
355     // By definition, when beta==0, values in self should be ignored. nans and infs
356     // should not propagate
357     if (beta.toComplexDouble() == 0.) {
358       return result.zero_();
359     }
360     // TODO: We could squeeze some perf by calling at::cuda::mul_out here instead, to bypass the dispatcher.
361     // That requires fixing some internal build dependencies though.
362     return at::mul_out(
363         result,
364         self.expand(result.sizes()),
365         at::native::scalar_tensor(
366             beta,
367             self.scalar_type(),
368             std::nullopt /* layout */,
369             at::kCPU,
370             std::nullopt /* pin_memory */));
371   }
372 
373   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
374 
375   if (useLtInterface) {
376 #if defined(USE_ROCM)
377     AT_DISPATCH_FLOATING_TYPES_AND2(
378         at::ScalarType::Half,
379         at::ScalarType::BFloat16,
380         scalar_type,
381         "addmm_cuda_lt",
382         [&] {
383         auto tuning_ctx = at::cuda::tunable::getTuningContext();
384         if (tuning_ctx->IsTunableOpEnabled()) {
385           launchTunableGemmAndBias<scalar_t>(
386               args,
387               alpha,
388               (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
389               activation_to_gemm_and_blas_arg(activation));
390         }
391         else {
392           at::cuda::blas::gemm_and_bias<scalar_t>(
393               args.transa == 't',
394               args.transb == 't',
395               args.m,
396               args.n,
397               args.k,
398               alpha.to<at::opmath_type<scalar_t>>(),
399               args.mata->const_data_ptr<scalar_t>(),
400               args.lda,
401               args.matb->const_data_ptr<scalar_t>(),
402               args.ldb,
403               // This condition is needed for the mm case on ROCm for the hipblasLt path.
404               // Passing the bias ptr as null avoids accuracy issues for the mm case.
405               (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
406               args.result->data_ptr<scalar_t>(),
407               args.result_ld,
408               activation_to_gemm_and_blas_arg(activation)
409           );
410         }});
411 #else
412     auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
413 #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080))
414     // GELU is not supported (and does not compile!) prior
415     // to CUDA 11.4. Have observed accuracy issues with
416     // GELU epilogue in 11.4; disabling the GELU epilogue
417     // path for CUDA version < 11.8.
418     if (activation == Activation::GELU)
419       activation_epilogue = cuda::blas::GEMMAndBiasActivationEpilogue::None;
420 #endif
421 
422     AT_DISPATCH_FLOATING_TYPES_AND2(
423         at::ScalarType::Half,
424         at::ScalarType::BFloat16,
425         scalar_type,
426         "addmm_cuda_lt",
427         [&] {
428         auto tuning_ctx = at::cuda::tunable::getTuningContext();
429         if (tuning_ctx->IsTunableOpEnabled()) {
430           launchTunableGemmAndBias<scalar_t>(
431               args,
432               alpha,
433               self.const_data_ptr<scalar_t>(),
434               activation_epilogue);
435         }
436         else {
437           at::cuda::blas::gemm_and_bias<scalar_t>(
438               args.transa == 't',
439               args.transb == 't',
440               args.m,
441               args.n,
442               args.k,
443               alpha.to<at::opmath_type<scalar_t>>(),
444               args.mata->const_data_ptr<scalar_t>(),
445               args.lda,
446               args.matb->const_data_ptr<scalar_t>(),
447               args.ldb,
448               self.const_data_ptr<scalar_t>(),
449               args.result->data_ptr<scalar_t>(),
450               args.result_ld,
451               activation_epilogue
452           );
453         }});
454 #endif
455   } else
456   {
457     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
458         at::ScalarType::Half,
459         at::ScalarType::BFloat16,
460         scalar_type,
461         "addmm_cuda",
462         [&] {
463           using opmath_t = at::opmath_type<scalar_t>;
464           opmath_t alpha_val = alpha.to<opmath_t>();
465           opmath_t beta_val = beta.to<opmath_t>();
466           const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
467           const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
468           scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
469           at::cuda::blas::gemm<scalar_t>(
470               args.transa,
471               args.transb,
472               args.m,
473               args.n,
474               args.k,
475               alpha_val,
476               mat1_ptr,
477               args.lda,
478               mat2_ptr,
479               args.ldb,
480               beta_val,
481               result_ptr,
482               args.result_ld);
483         });
484     switch (activation) {
485       case Activation::RELU:
486         at::relu_(const_cast<Tensor&>(*args.result));
487         break;
488       case Activation::GELU:
489         at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
490         break;
491       default: break;
492     }
493   }
494 
495 // Preprocessor gate here needs to match the inverse of the check
496 // gating activation_to_gemm_and_blas_arg above; here we are manually
497 // performing a post-GELU because we weren't able to use the GELU
498 // epilogue above.
499 #if !(defined(CUDA_VERSION) && CUDA_VERSION >= 11080) && !defined(USE_ROCM)
500   if (useLtInterface && activation == Activation::GELU) {
501     at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
502   }
503 #endif
504 
505   if (!result.is_same(*args.result)) {
506     result.copy_(*args.result);
507   }
508   return result;
509 }
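// Condensed view of the path selection implemented above:
//   1. mat1.numel() == 0 -> result.zero_() when beta == 0, otherwise beta * self;
//      no gemm is launched.
//   2. useLtInterface    -> gemm_and_bias (cublasLt/hipblasLt), fusing the bias and,
//      where supported, the RELU/GELU epilogue.
//   3. otherwise         -> plain at::cuda::blas::gemm followed by an in-place
//      at::relu_ / at::gelu_(..., "tanh") on the result.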
510 
511 const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
512   // handle pathological cases that blas may not like
513   if (result.numel() == 0) {
514     return result;
515   } else if (batch1.size(2) == 0) {
516     if (beta.to<c10::complex<double>>() == 0.0) {
517       return result.zero_();
518     } else {
519       return result.mul_(beta);
520     }
521   }
522 
523   bool transpose_result = false;
524   c10::MaybeOwned<Tensor> result_;
525   IntArrayRef result_strides = result.strides();
526   IntArrayRef result_sizes = result.sizes();
527 
528   if ((result_strides[1] == 1) &&
529       ((result_sizes[2] == 1) || (result_strides[2] >= std::max<int64_t>(1, result_sizes[1])))) {
530     result_ = resolve_conj_if_indicated(result, true);
531   } else if ((result_strides[2] == 1) &&
532     (result_sizes[1] == 1 || (result_strides[1] >= std::max<int64_t>(1, result_sizes[2])))) {
533     transpose_result = true;
534     result_ = resolve_conj_if_indicated(result, true);
535   } else {
536     result_ = c10::MaybeOwned<Tensor>::owned(result.transpose(1, 2).clone(at::MemoryFormat::Contiguous).transpose(1, 2));
537   }
538 
539   int leading_dim = transpose_result ? 1 : 2;
540 
541   int64_t m = result_sizes[transpose_result ? 2 : 1];
542   int64_t n = result_sizes[leading_dim];
543   int64_t k = (transpose_result ? batch2 : batch1).sizes()[leading_dim];
544 
545   int64_t lda, ldb, ldc;
546   bool transpose_batch1, transpose_batch2;
547   auto batch1_ = prepare_batch_matrix_for_cublas(transpose_result ? batch2 : batch1, transpose_batch1, lda, transpose_result, m, k);
548   auto batch2_ = prepare_batch_matrix_for_cublas(transpose_result ? batch1 : batch2, transpose_batch2, ldb, transpose_result, k, n);
549 
550   ldc = result_->strides()[leading_dim];
551   int64_t num_batches = result_->sizes()[0];
552 
553   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
554 
555   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] {
556     using opmath_t = at::opmath_type<scalar_t>;
557     opmath_t alpha_val = alpha.to<opmath_t>();
558     opmath_t beta_val = beta.to<opmath_t>();
559     const scalar_t* batch1_ptr = batch1_->const_data_ptr<scalar_t>();
560     const scalar_t* batch2_ptr = batch2_->const_data_ptr<scalar_t>();
561     scalar_t* result_ptr = result_->mutable_data_ptr<scalar_t>();
562     const auto transa = transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n';
563     const auto transb = transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n';
564     // If the batch count is 1, call gemm rather than bgemm
565     if (num_batches == 1) {
566       at::cuda::blas::gemm<scalar_t>(
567           transa, transb,
568           m, n, k,
569           alpha_val,
570           batch1_ptr, lda,
571           batch2_ptr, ldb,
572           beta_val,
573           result_ptr, ldc);
574     } else {
575       at::cuda::blas::bgemm<scalar_t>(
576         transa, transb,
577         m, n, k,
578         alpha_val,
579         batch1_ptr, lda, batch1_->strides()[0],
580         batch2_ptr, ldb, batch2_->strides()[0],
581         beta_val,
582         result_ptr, ldc, result_->strides()[0],
583         num_batches
584       );
585    }
586   });
587   if (!result.is_same(*result_)) {
588     result.copy_(*result_);
589   }
590   return result;
591 }
592 
593 } // anonymous namespace
594 
595 TORCH_IMPL_FUNC(addmm_out_cuda)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
596   addmm_out_cuda_impl(const_cast<Tensor&>(result), self, mat1, mat2, beta, alpha);
597 }
598 
599 TORCH_IMPL_FUNC(addmm_activation_out_cuda)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu, const Tensor& result) {
600   addmm_out_cuda_impl(const_cast<Tensor&>(result), self, mat1, mat2, beta, alpha, use_gelu ? Activation::GELU : Activation::RELU);
601 }
602 
603 TORCH_IMPL_FUNC(mm_out_cuda)(const Tensor& self, const Tensor& mat2, const Tensor& result) {
604   addmm_out_cuda_impl(const_cast<Tensor&>(result), result, self, mat2, 0, 1);
605 }
606 
607 TORCH_IMPL_FUNC(baddbmm_out_cuda)(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
608   {
609     at::NoNamesGuard guard;
610     baddbmm_out_cuda_impl(result, self, batch1, batch2, beta, alpha);
611   }
612 }
613 
614 TORCH_IMPL_FUNC(bmm_out_cuda)(const Tensor& batch1, const Tensor& batch2, const Tensor &result) {
615   Scalar beta(0.0);
616   Scalar alpha(1.0);
617   {
618     NoNamesGuard guard;
619     baddbmm_out_cuda_impl(result, result, batch1, batch2, beta, alpha);
620   }
621 }
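// As the two wrappers above show, mm and bmm reuse the addmm/baddbmm kernels by passing
// the (freshly allocated) result tensor as `self` with beta = 0 and alpha = 1, so the
// "self" contribution is never read. Roughly:
//   at::mm(a, b) -> mm_out_cuda -> addmm_out_cuda_impl(result, result, a, b, 0, 1)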
622 
623 namespace {
624 
625 inline void dot_check(const Tensor& self, const Tensor& other) {
626   TORCH_CHECK(
627       self.dim() == 1 && other.dim() == 1,
628       "1D tensors expected, but got ",
629       self.dim(),
630       "D and ",
631       other.dim(),
632       "D tensors");
633   TORCH_CHECK(
634       self.scalar_type() == other.scalar_type(),
635       "dot : expected both vectors to have same dtype, but found ",
636       self.scalar_type(),
637       " and ",
638       other.scalar_type());
639   TORCH_CHECK(
640       self.numel() == other.numel(),
641       "inconsistent tensor size, expected tensor [",
642       self.numel(),
643       "] and src [",
644       other.numel(),
645       "] to have the same number of elements, but got ",
646       self.numel(),
647       " and ",
648       other.numel(),
649       " elements respectively");
650   TORCH_CHECK(
651       (self.numel() <= INT_MAX) && (self.stride(0) <= INT_MAX) &&
652           (other.stride(0) <= INT_MAX),
653       "dot only supports n, incx, incy with the bound [val] <= ",
654       INT_MAX);
655 }
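// dot_check rejects, for example, non-1-D inputs, mismatched dtypes or lengths, and
// vectors whose length or stride would overflow the 32-bit n/incx/incy parameters
// expected by the cuBLAS dot routines.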
656 
657 } // anonymous namespace
658 
659 Tensor dot_cuda(const Tensor& self, const Tensor& other) {
660   if (self.is_complex()) {
661     if (self.is_conj()) {
662       if (other.is_conj()) {
663         return (dot_cuda(self.conj(), other.conj())).conj();
664        } else {
665          return vdot_cuda(self.conj(), other);
666        }
667     } else if (other.is_conj()) {
668       return vdot_cuda(other.conj(), self);
669     }
670   }
671 
672   at::NoNamesGuard guard;
673   dot_check(self, other);
674 
675   const int n = static_cast<int>(self.numel());
676   int incx = static_cast<int>(self.stride(0));
677   int incy = static_cast<int>(other.stride(0));
678   if (n == 1) {
679     incx = 1;
680     incy = 1;
681   }
682 
683   if (self._is_zerotensor() || other._is_zerotensor()) {
684     return at::_efficientzerotensor({}, self.options());
685   }
686 
687   return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
688       ScalarType::Half, ScalarType::BFloat16,
689       self.scalar_type(), "dot",
690       [&] {
691         Tensor result = at::empty({}, self.options());
692 
693         auto handle = at::cuda::getCurrentCUDABlasHandle();
694         at::cuda::blas::PointerModeGuard pointerModeGuard(handle, CUBLAS_POINTER_MODE_DEVICE);
695         at::cuda::blas::dot<scalar_t>(
696             handle,
697             n,
698             self.const_data_ptr<scalar_t>(),
699             incx,
700             other.const_data_ptr<scalar_t>(),
701             incy,
702             result.mutable_data_ptr<scalar_t>());
703 
704         return result;
705       });
706 }
707 
708 Tensor vdot_cuda(const Tensor& self, const Tensor& other) {
709   if (!self.is_complex()) {
710     return dot_cuda(self, other);
711   }
712 
713   if (self.is_conj()) {
714     if (other.is_conj()) {
715       return vdot_cuda(other.conj(), self.conj());
716     } else {
717       return dot_cuda(self.conj(), other);
718     }
719   } else if (other.is_conj()) {
720     return (dot_cuda(self, other.conj())).conj();
721   }
722 
723   at::NoNamesGuard guard;
724   dot_check(self, other);
725 
726   if (self._is_zerotensor() || other._is_zerotensor()) {
727     return at::_efficientzerotensor({}, self.options());
728   }
729 
730   const int n = static_cast<int>(self.numel());
731   int incx = static_cast<int>(self.stride(0));
732   int incy = static_cast<int>(other.stride(0));
733   if (n == 1) {
734     incx = 1;
735     incy = 1;
736   }
737 
738   return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] {
739     Tensor result = at::empty({}, self.options());
740 
741     auto handle = at::cuda::getCurrentCUDABlasHandle();
742     at::cuda::blas::PointerModeGuard pointerModeGuard(
743         handle, CUBLAS_POINTER_MODE_DEVICE);
744     at::cuda::blas::vdot<scalar_t>(
745         handle,
746         n,
747         self.const_data_ptr<scalar_t>(),
748         incx,
749         other.const_data_ptr<scalar_t>(),
750         incy,
751         result.mutable_data_ptr<scalar_t>());
752 
753     return result;
754   });
755 }
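// Conjugation handling in dot_cuda/vdot_cuda above: lazily conjugated tensors
// (is_conj() == true) are never materialized; the calls are rewritten via the identities
// dot(conj(a), conj(b)) == conj(dot(a, b)) and vdot(a, b) == dot(conj(a), b), so the
// underlying cuBLAS dot/vdot kernels only ever see plain (non-conjugated) data.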
756 
757 TORCH_IMPL_FUNC(addmv_out_cuda)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) {
758   c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
759   auto betaval = beta_.toComplexDouble();
760   if (mat.numel() == 0) {
761     // shortcut for an empty matrix
762     // By definition, when beta==0, values in self should be ignored. nans and infs
763     // should not propagate
764     if (betaval == 0.0) {
765       result.zero_();
766     } else {
767       at::mul_out(
768           const_cast<Tensor&>(result),
769           self,
770           at::native::scalar_tensor(
771               beta_, self.scalar_type(), std::nullopt /* layout */, at::kCPU, std::nullopt /* pin_memory */));
772     }
773   } else {
774     if (!result.is_same(*self_) && betaval != 0.0) { //if beta is 0, result contents will be zeroed later
775       at::native::copy_(const_cast<Tensor&>(result), *self_);
776     }
777     if (result.numel() != 0) {
778       auto r_stride = result.stride(0);
779       auto vec_stride = vec.stride(0);
780 
781       // Check for contiguity of `vec` and update `vec_stride` accordingly
782       const auto vec_contiguous = vec_stride == 0 ? vec.contiguous() : vec;
783       // A vector can be contiguous and have a stride of zero if it is of length 1
784       vec_stride = std::max<int64_t>(vec_contiguous.stride(0), 1LL);
785 
786       AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mat.scalar_type(), "addmv_impl_cuda", [&] {
787         auto beta = beta_.to<scalar_t>();
788         auto alpha = alpha_.to<scalar_t>();
789         if (mat.stride(0) == 1 && mat.stride(1) >= std::max<int64_t>(1, mat.size(0))) {
790           at::cuda::blas::gemv<scalar_t>('n',
791             mat.size(0), mat.size(1), alpha, mat.const_data_ptr<scalar_t>(), mat.stride(1), vec_contiguous.const_data_ptr<scalar_t>(),
792             vec_stride, beta, result.mutable_data_ptr<scalar_t>(), r_stride);
793         }
794         else if (mat.stride(1) == 1 && mat.stride(0) >= std::max<int64_t>(1, mat.size(1))) {
795           at::cuda::blas::gemv<scalar_t>('t',
796             mat.size(1), mat.size(0), alpha, mat.const_data_ptr<scalar_t>(), mat.stride(0),
797             vec_contiguous.const_data_ptr<scalar_t>(), vec_stride, beta, result.mutable_data_ptr<scalar_t>(), r_stride);
798         }
799         else {
800           Tensor cmat = mat.contiguous();
801           at::cuda::blas::gemv<scalar_t>('t',
802               mat.size(1), mat.size(0), alpha, cmat.const_data_ptr<scalar_t>(), cmat.stride(0),
803               vec_contiguous.const_data_ptr<scalar_t>(), vec_stride, beta, result.mutable_data_ptr<scalar_t>(), r_stride);
804         }
805       });
806     }
807   }
808 }
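// gemv layout selection above: a column-major `mat` (stride(0) == 1) maps directly to
// gemv('n', ...), a row-major `mat` (stride(1) == 1) is handled as its transpose via
// gemv('t', ...), and any other layout is first copied to a contiguous buffer. The clamp
// of vec_stride to >= 1 covers length-1 vectors, whose reported stride can be 0.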
809 
810 Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result) {
811   // NOTE: cuBLAS is currently broken for some combination of transposed inputs.
812   TORCH_CHECK(self.dim() == 2, "Expected self to be of dimension 2 but got ", self.dim());
813   TORCH_CHECK(mat2.dim() == 2, "Expected mat2 to be of dimension 2 but got ", mat2.dim());
814   TORCH_CHECK(self.size(0) > 16, "self.size(0) needs to be greater than 16, but got ", self.size(0));
815   TORCH_CHECK(self.size(1) > 0 && self.size(1) % 8 == 0, "self.size(1) needs to be greater than 0 and a multiple of 8, but got ", self.size(1));
816   TORCH_CHECK(self.size(1) == mat2.size(0), "self.size(1) needs to match mat2.size(0) but got ", self.size(1), " and ", mat2.size(0));
817   TORCH_CHECK(mat2.size(1) > 0 && mat2.size(1) % 8 == 0, "mat2.size(1) needs to be greater than 0 and a multiple of 8, but got ", mat2.size(1));
818 
819   TORCH_CHECK(result.dtype() == at::kInt, "Expected result dtype to be of type kInt but got ", result.dtype());
820   TORCH_CHECK(result.size(0) == self.size(0), "Expected result.size(0) to be ", self.size(0), " but got ", result.size(0));
821   TORCH_CHECK(result.size(1) == mat2.size(1), "Expected result.size(1) to be ", mat2.size(1), " but got ", result.size(1));
822 
823   TORCH_CHECK(result.dim() == 2, "Expected result to be of dimension 2 but got ", result.dim());
824 
825   TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
826 
827 #if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || defined(USE_ROCM)
828   cublasCommonArgs args(self, mat2, result);
829 
830   at::cuda::blas::int8_gemm(
831       args.transa == 't',
832       args.transb == 't',
833       args.m,
834       args.n,
835       args.k,
836       args.mata->data_ptr<int8_t>(),
837       args.lda,
838       args.matb->data_ptr<int8_t>(),
839       args.ldb,
840       args.result->data_ptr<int32_t>(),
841       args.result_ld);
842 
843   if (!result.is_same(*args.result)) {
844     result.copy_(*args.result);
845   }
846 #else
847 #if !defined(USE_ROCM) && defined(CUDA_VERSION)
848   TORCH_CHECK(false, "_int_mm_out_cuda not compiled for CUDA ", CUDA_VERSION);
849 #else
850   TORCH_CHECK(false, "_int_mm_out_cuda not compiled for this platform.");
851 #endif
852 #endif
853 
854   return result;
855 }
856 
857 Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
858   Tensor result = at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
859   return _int_mm_out_cuda(self, mat2, result);
860 }
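// Illustrative use (hypothetical tensors; shapes must satisfy the checks in
// _int_mm_out_cuda: self.size(0) > 16 and K, N positive multiples of 8):
//   at::Tensor a = at::randint(-128, 128, {32, 64},
//                              at::device(at::kCUDA).dtype(at::kChar));
//   at::Tensor b = at::randint(-128, 128, {64, 24},
//                              at::device(at::kCUDA).dtype(at::kChar));
//   at::Tensor c = at::_int_mm(a, b);  // int32 result of shape {32, 24}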
861 
862 static bool _scaled_mm_allowed_device() {
863     auto dprops = at::cuda::getCurrentDeviceProperties();
864 #ifdef USE_ROCM
865     std::string device_arch = dprops->gcnArchName;
866     static const std::vector<std::string> archs = {"gfx940", "gfx941", "gfx942"};
867     for (std::string arch : archs) {
868         size_t substring = device_arch.find(arch);
869         if (substring != std::string::npos) {
870             return true;
871         }
872     }
873     return false;
874 #else
875     return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
876 #endif
877 }
878 
879 namespace{
880 
881 enum class ScalingType {
882   TensorWise,
883   RowWise,
884   Error
885 };
886 /*
887  * Scaling Type Determination:
888  * ---------------------------
889  * Conditions and corresponding Scaling Types:
890  *
891  * - If scale_a.numel() == 1 && scale_b.numel() == 1:
892  *   - Returns TensorWise.
893  *
894  * - Else if scale_a.size() == (dim_m, 1) && scale_b.size() == (1, dim_n) (both 2-D, fp32, contiguous):
895  *   - Returns RowWise.
896  *
897  * - Otherwise:
898  *   - Returns Error.
899  */
900 
901 // Validates the scale tensors passed to scaled_mm
902 // and returns the type of scaling / which kernel to use
903 ScalingType get_scaling_type(
904     const at::Tensor& scale_a,
905     const at::Tensor& scale_b,
906     int64_t dim_m,
907     int64_t dim_n) {
908   // Both Per-Tensor and Row-wise scaling expect fp32 tensors
909   TORCH_CHECK(
910       scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
911       "Both scale_a and scale_b must be float (fp32) tensors.");
912 
913   // Check the singular scale case for per-tensor scaling
914   if (scale_a.numel() == 1 && scale_b.numel() == 1) {
915     return ScalingType::TensorWise;
916   }
917 
918   // For non-TensorWise scaling, enforce 2D input tensors
919   TORCH_CHECK(
920       scale_a.dim() == 2 && scale_b.dim() == 2,
921       "For non-TensorWise scaling, scale tensors must be 2-dimensional, "
922       "but got scale_a.dim()=",
923       scale_a.dim(),
924       " and scale_b.dim()=",
925       scale_b.dim());
926 
927   // Check for RowWise scaling
928   if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 &&
929       scale_b.size(0) == 1 && scale_b.size(1) == dim_n) {
930 #if !defined(USE_ROCM) && !defined(_MSC_VER) || \
931     (defined(USE_ROCM) && ROCM_VERSION >= 60000)
932     TORCH_CHECK(
933         scale_a.is_contiguous() && scale_b.is_contiguous(),
934         "Both scale_a and scale_b must be contiguous for RowWise scaling.");
935     return ScalingType::RowWise;
936 #else
937     TORCH_CHECK(false, "Per-row scaling is not supported for this platform!");
938     return ScalingType::Error;
939 #endif
940   }
941 
942   // If we reach here, the input doesn't match any valid scaling type
943   TORCH_CHECK(
944       false,
945       "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. "
946       "For RowWise scaling, scale_a should be (",
947       dim_m,
948       ", 1) and scale_b should be (1, ",
949       dim_n,
950       "). "
951       "Got scale_a.size()=(",
952       scale_a.size(0),
953       ", ",
954       scale_a.size(1),
955       ") and ",
956       "scale_b.size()=(",
957       scale_b.size(0),
958       ", ",
959       scale_b.size(1),
960       ")");
961 
962   return ScalingType::Error;
963 }
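// Example scale shapes accepted by get_scaling_type above, for mat1 of shape (M, K) and
// mat2 of shape (K, N) (values here are hypothetical):
//   TensorWise: scale_a and scale_b are single-element fp32 tensors, e.g. at::ones({}).
//   RowWise:    scale_a of shape (M, 1) and scale_b of shape (1, N), both fp32 and
//               contiguous, and only on builds where the row-wise kernel is compiled in.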
964 
965 } // namespace
966 
967 // Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
968 // Scales are only applicable when matrices are of Float8 type and are assumed to be equal to 1.0 by default.
969 // If the output matrix type is a 16- or 32-bit type, neither scale_result is applied nor amax is computed.
970 // Known limitations:
971 //  - Only works if mat1 is row-major and mat2 is column-major
972 //  - Only works if matrix sizes are divisible by 32
973 //  - If 1-dimensional tensors are used then scale_a should have size = mat1.size(0)
974 //    and scale_b should have size = mat2.size(1)
975 //  Arguments:
976 //    - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
977 //    - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
978 //    - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
979 //    - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type
980 //    - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
981 //    - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
982 //    - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
983 //    - `use_fast_accum`: if true, enables fast float8 accumulation
984 //    - `out`: a reference to the output tensor
985 
986 Tensor&
987 _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
988           const Tensor& scale_a,
989           const Tensor& scale_b,
990           const std::optional<at::Tensor>& bias,
991           const std::optional<at::Tensor>& scale_result,
992           std::optional<c10::ScalarType> out_dtype,
993           bool use_fast_accum,
994           Tensor& out) {
995   // Check sizes
996   bool allowed_device = _scaled_mm_allowed_device();
997   TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
998   TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
999   TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
1000   TORCH_CHECK(
1001       mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
1002       mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
1003 
1004   // Check what type of scaling we are doing based on inputs
1005   ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
1006   TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
1007 
1008   TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
1009        "scale_result must be a float scalar");
1010   TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
1011        " but got ", bias->numel());
1012   TORCH_CHECK(
1013       mat1.sizes()[1] % 16 == 0,
1014       "Expected trailing dimension of mat1 to be divisible by 16 ",
1015       "but got mat1 shape: (",
1016       mat1.sizes()[0],
1017       "x",
1018       mat1.sizes()[1],
1019       ").");
1020   TORCH_CHECK(mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0, "mat2 shape (", mat2.sizes()[0], "x",
1021        mat2.sizes()[1], ") must be divisible by 16");
1022   // Check types
1023   TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
1024   TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
1025   TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
1026   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
1027   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
1028         "Multiplication of two Float8_e5m2 matrices is not supported");
1029   if (bias) {
1030     TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
1031     TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
1032          "Bias must be either Half or BFloat16, but got ", bias->scalar_type());
1033     TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) ||
1034           bias->scalar_type() == ScalarType::BFloat16,
1035           "Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
1036     TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half,
1037           "Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
1038   }
1039   {
1040     auto bias_ = bias.value_or(Tensor());
1041     auto scale_result_ = scale_result.value_or(Tensor());
1042 
1043     TensorArg targs[]{{out, "out", 0}, {mat1, "mat1", 1}, {mat2, "mat2", 2},
1044                       {bias_, "bias", 3}, {scale_a, "scale_a", 4}, {scale_b, "scale_b", 5},
1045                       {scale_result_, "scale_result", 6}};
1046     checkAllSameGPU(__func__, targs);
1047   }
1048   // Validation checks have passed; let's resize the output to the actual size
1049   IntArrayRef mat1_sizes = mat1.sizes();
1050   IntArrayRef mat2_sizes = mat2.sizes();
1051   at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
1052 
1053   // We are doing row-wise scaling
1054   if (scaling_choice == ScalingType::RowWise) {
1055     TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
1056     at::cuda::detail::f8f8bf16_rowwise(
1057         mat1,
1058         mat2,
1059         scale_a,
1060         scale_b,
1061         bias,
1062         use_fast_accum,
1063         out);
1064     return out;
1065   }
1066 
1067   cublasCommonArgs args(mat1, mat2, out);
1068   const auto out_dtype_ = args.result->scalar_type();
1069   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
1070 
1071   // Some scaled_gemms require an amax to populate; let's create one here
1072   Tensor amax = at::empty({0}, mat1.options().dtype(ScalarType::Float));
1073 
1074 #ifdef USE_ROCM
1075   auto tuning_ctx = at::cuda::tunable::getTuningContext();
1076   if (tuning_ctx->IsTunableOpEnabled()) {
1077 #define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B)                            \
1078         if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) {        \
1079           if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \
1080             static at::cuda::tunable::ScaledGemmTunableOp<              \
1081                 at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t,     \
1082                 BLASOP_A, BLASOP_B> scaledgemm{};                       \
1083             scaledgemm(&params);                                        \
1084           }                                                             \
1085           else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
1086             static at::cuda::tunable::ScaledGemmTunableOp<              \
1087                 at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t,     \
1088                 BLASOP_A, BLASOP_B> scaledgemm{};                       \
1089             scaledgemm(&params);                                        \
1090           }                                                             \
1091         }                                                               \
1092         else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) {   \
1093           if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) {      \
1094             static at::cuda::tunable::ScaledGemmTunableOp<              \
1095                 at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t,     \
1096                 BLASOP_A, BLASOP_B> scaledgemm{};                       \
1097             scaledgemm(&params);                                        \
1098           }                                                             \
1099           else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
1100             static at::cuda::tunable::ScaledGemmTunableOp<              \
1101                 at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t,     \
1102                 BLASOP_A, BLASOP_B> scaledgemm{};                       \
1103             scaledgemm(&params);                                        \
1104           }                                                             \
1105         }
1106     AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
1107       bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
1108       bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
1109       at::cuda::tunable::ScaledGemmParams<scalar_t> params;
1110       params.transa = args.transa;
1111       params.transb = args.transb;
1112       params.m = args.m;
1113       params.n = args.n;
1114       params.k = args.k;
1115       params.a = args.mata->data_ptr();
1116       params.a_scale_ptr = scale_a.data_ptr();
1117       params.lda = args.lda;
1118       params.a_dtype = args.mata->scalar_type();
1119       params.b = args.matb->data_ptr();
1120       params.b_scale_ptr = scale_b.data_ptr();
1121       params.ldb = args.ldb;
1122       params.b_dtype = args.matb->scalar_type();
1123       params.bias_ptr = bias ? bias->data_ptr(): nullptr;
1124       params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
1125       params.c = args.result->data_ptr();
1126       params.c_scale_ptr = scale_result ? scale_result->data_ptr() : nullptr;
1127       params.ldc = args.result_ld;
1128       params.c_dtype = out_dtype_;
1129       params.amax_ptr = amax.data_ptr();
1130       params.use_fast_accum = use_fast_accum;
1131       if (transa_ && transb_) {
1132         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
1133       }
1134       else if (transa_ && !transb_) {
1135         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
1136       }
1137       else if (!transa_ && transb_) {
1138         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
1139       }
1140       else if (!transa_ && !transb_) {
1141         TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
1142       }
1143       else {
1144         TORCH_CHECK(false, "unreachable");
1145       }
1146     }),
1147     kHalf, kBFloat16, kFloat8_e4m3fnuz, kFloat8_e5m2fnuz, AT_EXPAND(AT_FLOATING_TYPES));
1148 #undef TUNABLE_DISPATCH
1149   }
1150   else
1151 #endif
1152   {
1153 #if defined(USE_ROCM) && ROCM_VERSION >= 60200
1154   // hipBlasLT requires scaleD to be set to something in order to use AMAX
1155     auto dummy_options = TensorOptions().dtype(kFloat).device(kCUDA);
1156     auto dummy_scale = at::ones(1, dummy_options);
1157 #endif
1158     at::cuda::blas::scaled_gemm(
1159         args.transa,
1160         args.transb,
1161         args.m,
1162         args.n,
1163         args.k,
1164         args.mata->data_ptr(),
1165         scale_a.data_ptr(),
1166         args.lda,
1167         args.mata->scalar_type(),
1168         args.matb->data_ptr(),
1169         scale_b.data_ptr(),
1170         args.ldb,
1171         args.matb->scalar_type(),
1172         bias ? bias->data_ptr(): nullptr,
1173         bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
1174         args.result->data_ptr(),
1175 #if defined(USE_ROCM) && ROCM_VERSION >= 60200
1176         scale_result ? scale_result->data_ptr() : dummy_scale.data_ptr(),
1177 #else
1178         scale_result ? scale_result->data_ptr() : nullptr,
1179 #endif
1180         args.result_ld,
1181         out_dtype_,
1182         amax.data_ptr(),
1183         use_fast_accum);
1184   }
1185 
1186   return out;
1187 }
1188 
1189 Tensor
1190 _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
1191           const Tensor& scale_a,
1192           const Tensor& scale_b,
1193           const std::optional<at::Tensor>& bias,
1194           const std::optional<at::Tensor>& scale_result,
1195           std::optional<c10::ScalarType> out_dtype,
1196           bool use_fast_accum) {
1197   const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
1198   Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
1199   return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
1200 }
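// Minimal usage sketch (hypothetical tensors; it relies only on the checks implemented
// above: fp8 inputs, row-major mat_a, column-major mat_b, sizes divisible by 16, scales
// on the same GPU):
//   at::Tensor a = at::randn({32, 64}, at::kCUDA).to(at::kFloat8_e4m3fn);
//   at::Tensor b = at::randn({16, 64}, at::kCUDA).to(at::kFloat8_e4m3fn).t();
//   at::Tensor scale_a = at::ones({}, at::device(at::kCUDA).dtype(at::kFloat));
//   at::Tensor scale_b = at::ones({}, at::device(at::kCUDA).dtype(at::kFloat));
//   at::Tensor out = at::_scaled_mm(a, b, scale_a, scale_b,
//                                   /*bias=*/{}, /*scale_result=*/{},
//                                   /*out_dtype=*/at::kBFloat16,
//                                   /*use_fast_accum=*/false);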
1201 
1202 } // namespace at::native
1203