xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/RowwiseScaledMM.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/TensorBase.h>
3 #include <optional>
4 
5 
6 namespace at::cuda::detail {
7 TORCH_API void f8f8bf16_rowwise(
8     at::Tensor XQ, // FP8
9     at::Tensor WQ, // FP8
10     at::Tensor x_scale, // FP32
11     at::Tensor w_scale, // FP32
12     std::optional<at::Tensor> bias, // BF16
13     bool use_fast_accum,
14     at::Tensor& out);
15 }  // at::cuda::detail
16